Make RemotePull operator async
Reviewers: teon.banek, florijan
Reviewed By: teon.banek
Subscribers: pullbot
Differential Revision: https://phabricator.memgraph.io/D1168
This commit is contained in:
parent
f84b0b0787
commit
a66351c3f4
@ -32,10 +32,10 @@ class RemotePullRpcClients {
|
||||
int batch_size = kDefaultBatchSize) {
|
||||
return clients_.ExecuteOnWorker<RemotePullData>(
|
||||
worker_id,
|
||||
[&dba, plan_id, params, symbols, batch_size](ClientPool &client) {
|
||||
auto result =
|
||||
client.Call<RemotePullRpc>(dba.transaction_id(), plan_id, params,
|
||||
symbols, batch_size, true, true);
|
||||
[&dba, plan_id, params, symbols, batch_size](ClientPool &client_pool) {
|
||||
auto result = client_pool.Call<RemotePullRpc>(
|
||||
dba.transaction_id(), plan_id, params, symbols, batch_size, true,
|
||||
true);
|
||||
|
||||
auto handle_vertex = [&dba](auto &v) {
|
||||
dba.remote_vertices().emplace(v.global_address.gid(),
|
||||
|
@ -43,9 +43,10 @@ class RpcWorkerClients {
|
||||
auto ExecuteOnWorker(
|
||||
int worker_id,
|
||||
std::function<TResult(communication::rpc::ClientPool &)> execute) {
|
||||
auto &client = GetClientPool(worker_id);
|
||||
return std::async(std::launch::async,
|
||||
[execute, &client]() { return execute(client); });
|
||||
auto &client_pool = GetClientPool(worker_id);
|
||||
return std::async(std::launch::async, [execute, &client_pool]() {
|
||||
return execute(client_pool);
|
||||
});
|
||||
}
|
||||
|
||||
/** Asynchronously executes the `execute` function on all worker rpc clients
|
||||
|
@ -2566,10 +2566,12 @@ std::unique_ptr<Cursor> ProduceRemote::MakeCursor(
|
||||
}
|
||||
|
||||
PullRemote::PullRemote(const std::shared_ptr<LogicalOperator> &input,
|
||||
int64_t plan_id, const std::vector<Symbol> &symbols)
|
||||
int64_t plan_id, const std::vector<Symbol> &symbols,
|
||||
bool pull_local)
|
||||
: input_(input ? input : std::make_shared<Once>()),
|
||||
plan_id_(plan_id),
|
||||
symbols_(symbols) {}
|
||||
symbols_(symbols),
|
||||
pull_local_(pull_local) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(PullRemote);
|
||||
|
||||
@ -2577,64 +2579,114 @@ PullRemote::PullRemoteCursor::PullRemoteCursor(const PullRemote &self,
|
||||
database::GraphDbAccessor &db)
|
||||
: self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {
|
||||
worker_ids_ = db_.db().remote_pull_clients().GetWorkerIds();
|
||||
// remove master from the worker_ids list
|
||||
// Remove master from the worker ids list.
|
||||
worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0));
|
||||
}
|
||||
|
||||
void PullRemote::PullRemoteCursor::EndRemotePull() {
|
||||
if (remote_pull_ended_) return;
|
||||
db_.db().remote_pull_clients().EndAllRemotePulls(db_.transaction().id_,
|
||||
self_.plan_id());
|
||||
remote_pull_ended_ = true;
|
||||
std::vector<std::future<void>> futures;
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
futures.emplace_back(db_.db().remote_pull_clients().EndRemotePull(
|
||||
worker_id, db_.transaction().id_, self_.plan_id()));
|
||||
}
|
||||
for (auto &future : futures) future.wait();
|
||||
worker_ids_.clear();
|
||||
}
|
||||
|
||||
bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) {
|
||||
if (input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
while (worker_ids_.size() > 0 && results_.empty()) {
|
||||
last_pulled_worker_ = (last_pulled_worker_ + 1) % worker_ids_.size();
|
||||
auto remote_results =
|
||||
db_.db()
|
||||
.remote_pull_clients()
|
||||
.RemotePull(db_, worker_ids_[last_pulled_worker_], self_.plan_id(),
|
||||
context.parameters_, self_.symbols())
|
||||
.get();
|
||||
|
||||
auto get_results = [&]() {
|
||||
for (auto &frame : remote_results.frames) {
|
||||
results_.emplace(std::move(frame));
|
||||
}
|
||||
auto insert_future_for_worker = [&](int worker_id) {
|
||||
remote_pulls_[worker_id] = db_.db().remote_pull_clients().RemotePull(
|
||||
db_, worker_id, self_.plan_id(), context.parameters_, self_.symbols());
|
||||
};
|
||||
|
||||
if (!remote_pulls_initialized_) {
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
insert_future_for_worker(worker_id);
|
||||
}
|
||||
remote_pulls_initialized_ = true;
|
||||
}
|
||||
|
||||
bool have_remote_results = false;
|
||||
while (!have_remote_results && !worker_ids_.empty()) {
|
||||
// If we don't have results for a worker, check if its remote pull
|
||||
// finished and save results locally.
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
if (!remote_results_[worker_id].empty()) continue;
|
||||
|
||||
auto found_it = remote_pulls_.find(worker_id);
|
||||
if (found_it == remote_pulls_.end()) continue;
|
||||
|
||||
auto &remote_pull = found_it->second;
|
||||
if (!remote_pull.valid()) continue;
|
||||
|
||||
auto remote_results = remote_pull.get();
|
||||
switch (remote_results.pull_state) {
|
||||
case distributed::RemotePullState::CURSOR_EXHAUSTED:
|
||||
get_results();
|
||||
worker_ids_.erase(worker_ids_.begin() + last_pulled_worker_);
|
||||
remote_results_[worker_id] = std::move(remote_results.frames);
|
||||
remote_pulls_.erase(found_it);
|
||||
break;
|
||||
case distributed::RemotePullState::CURSOR_IN_PROGRESS:
|
||||
get_results();
|
||||
remote_results_[worker_id] = std::move(remote_results.frames);
|
||||
insert_future_for_worker(worker_id);
|
||||
break;
|
||||
case distributed::RemotePullState::SERIALIZATION_ERROR:
|
||||
EndRemotePull();
|
||||
throw mvcc::SerializationError(
|
||||
"Serialization error occured during PullRemote!");
|
||||
"Serialization error occured during PullRemote !");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// if the results_ are still empty, we've exhausted all worker results
|
||||
if (results_.empty()) {
|
||||
EndRemotePull();
|
||||
return false;
|
||||
// Get locally stored results from workers in a round-robin fashion.
|
||||
int num_workers = worker_ids_.size();
|
||||
for (int i = 0; i < num_workers; ++i) {
|
||||
int worker_id_index = (last_pulled_worker_id_index_ + i) % num_workers;
|
||||
int worker_id = worker_ids_[worker_id_index];
|
||||
|
||||
if (!remote_results_[worker_id].empty()) {
|
||||
last_pulled_worker_id_index_ = worker_id_index;
|
||||
have_remote_results = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto &result = results_.front();
|
||||
// If there are no remote results available, pull and return local results.
|
||||
if (!have_remote_results && self_.pull_local() &&
|
||||
input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// No more remote results, make sure local results get exhausted.
|
||||
if (!have_remote_results) {
|
||||
if (self_.pull_local() && input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int pull_from_worker_id = worker_ids_[last_pulled_worker_id_index_];
|
||||
{
|
||||
auto &result = remote_results_[pull_from_worker_id].back();
|
||||
for (size_t i = 0; i < self_.symbols().size(); ++i) {
|
||||
frame[self_.symbols()[i]] = std::move(result[i]);
|
||||
}
|
||||
results_.pop();
|
||||
}
|
||||
remote_results_[pull_from_worker_id].resize(
|
||||
remote_results_[pull_from_worker_id].size() - 1);
|
||||
|
||||
// Remove the worker if we exhausted all locally stored results and there are
|
||||
// no more pending remote pulls for that worker.
|
||||
if (remote_results_[pull_from_worker_id].empty() &&
|
||||
remote_pulls_.find(pull_from_worker_id) == remote_pulls_.end()) {
|
||||
worker_ids_.erase(worker_ids_.begin() + last_pulled_worker_id_index_);
|
||||
db_.db()
|
||||
.remote_pull_clients()
|
||||
.EndRemotePull(pull_from_worker_id, db_.transaction().id_,
|
||||
self_.plan_id())
|
||||
.wait();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
|
||||
#include <experimental/optional>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
@ -17,6 +16,7 @@
|
||||
#include "boost/serialization/shared_ptr.hpp"
|
||||
#include "boost/serialization/unique_ptr.hpp"
|
||||
|
||||
#include "distributed/remote_pull_produce_rpc_messages.hpp"
|
||||
#include "query/common.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/semantic/symbol.hpp"
|
||||
@ -2274,7 +2274,7 @@ class ProduceRemote : public LogicalOperator {
|
||||
class PullRemote : public LogicalOperator {
|
||||
public:
|
||||
PullRemote(const std::shared_ptr<LogicalOperator> &input, int64_t plan_id,
|
||||
const std::vector<Symbol> &symbols);
|
||||
const std::vector<Symbol> &symbols, bool pull_local = true);
|
||||
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(
|
||||
database::GraphDbAccessor &db) const override;
|
||||
@ -2284,11 +2284,13 @@ class PullRemote : public LogicalOperator {
|
||||
}
|
||||
const auto &symbols() const { return symbols_; }
|
||||
auto plan_id() const { return plan_id_; }
|
||||
auto pull_local() const { return pull_local_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
int64_t plan_id_ = 0;
|
||||
std::vector<Symbol> symbols_;
|
||||
bool pull_local_ = true;
|
||||
|
||||
PullRemote() {}
|
||||
|
||||
@ -2304,10 +2306,13 @@ class PullRemote : public LogicalOperator {
|
||||
const PullRemote &self_;
|
||||
database::GraphDbAccessor &db_;
|
||||
const std::unique_ptr<Cursor> input_cursor_;
|
||||
std::queue<std::vector<query::TypedValue>> results_;
|
||||
std::unordered_map<int, std::future<distributed::RemotePullData>>
|
||||
remote_pulls_;
|
||||
std::unordered_map<int, std::vector<std::vector<query::TypedValue>>>
|
||||
remote_results_;
|
||||
std::vector<int> worker_ids_;
|
||||
int last_pulled_worker_ = -1;
|
||||
bool remote_pull_ended_ = false;
|
||||
int last_pulled_worker_id_index_ = 0;
|
||||
bool remote_pulls_initialized_ = false;
|
||||
};
|
||||
|
||||
friend class boost::serialization::access;
|
||||
@ -2317,6 +2322,7 @@ class PullRemote : public LogicalOperator {
|
||||
ar &input_;
|
||||
ar &plan_id_;
|
||||
ar &symbols_;
|
||||
ar &pull_local_;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -23,6 +23,8 @@
|
||||
#include "transactions/engine_master.hpp"
|
||||
|
||||
#include "distributed_common.hpp"
|
||||
#include "query/interpreter.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "query_common.hpp"
|
||||
#include "query_plan_common.hpp"
|
||||
|
||||
@ -339,3 +341,27 @@ TEST_F(DistributedGraphDbTest, WorkerOwnedDbAccessors) {
|
||||
VertexAccessor v_in_w2{v_ga, dba_w2};
|
||||
EXPECT_EQ(v_in_w2.PropsAt(prop).Value<int64_t>(), 42);
|
||||
}
|
||||
|
||||
TEST_F(DistributedGraphDbTest, RemotePullTest) {
|
||||
using Interpreter = query::Interpreter;
|
||||
std::map<std::string, query::TypedValue> params = {};
|
||||
|
||||
GraphDbAccessor dba(master());
|
||||
|
||||
ResultStreamFaker result;
|
||||
Interpreter interpreter_;
|
||||
interpreter_("OPTIONAL MATCH(n) UNWIND(RANGE(0, 20)) AS X RETURN 1", dba,
|
||||
params, false)
|
||||
.PullAll(result);
|
||||
|
||||
// Three instances (master + 2 workers) with 21 results each.
|
||||
uint expected_results = 3U * 21;
|
||||
ASSERT_EQ(result.GetHeader().size(), 1U);
|
||||
EXPECT_EQ(result.GetHeader()[0], "1");
|
||||
ASSERT_EQ(result.GetResults().size(), expected_results);
|
||||
|
||||
for (uint i = 0; i < expected_results; ++i) {
|
||||
ASSERT_EQ(result.GetResults()[i].size(), 1U);
|
||||
ASSERT_EQ(result.GetResults()[i][0].Value<int64_t>(), 1);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user