From c4f51d87f815e47eb0db1fd4caa33ab9c7473614 Mon Sep 17 00:00:00 2001 From: Marin Tomic <marin.tomic@memgraph.io> Date: Fri, 6 Jul 2018 15:12:45 +0200 Subject: [PATCH] Implement Reset for distributed operators Reviewers: teon.banek, msantl, buda Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1467 --- src/distributed/bfs_rpc_clients.cpp | 11 +++++ src/distributed/bfs_rpc_clients.hpp | 3 ++ src/distributed/bfs_rpc_messages.lcp | 4 ++ src/distributed/bfs_rpc_server.hpp | 9 ++++ src/distributed/bfs_subcursor.hpp | 4 +- src/distributed/produce_rpc_server.cpp | 26 ++++++++++++ src/distributed/produce_rpc_server.hpp | 5 +++ src/distributed/pull_produce_rpc_messages.lcp | 7 ++++ src/distributed/pull_rpc_clients.cpp | 21 +++++++--- src/distributed/pull_rpc_clients.hpp | 5 ++- src/query/plan/operator.cpp | 42 +++++++++++++++---- tests/unit/CMakeLists.txt | 3 ++ tests/unit/distributed_interpretation.cpp | 1 + tests/unit/distributed_query_plan.cpp | 6 +-- tests/unit/distributed_reset.cpp | 35 ++++++++++++++++ 15 files changed, 164 insertions(+), 18 deletions(-) create mode 100644 tests/unit/distributed_reset.cpp diff --git a/src/distributed/bfs_rpc_clients.cpp b/src/distributed/bfs_rpc_clients.cpp index c0a29d9eb..2751af940 100644 --- a/src/distributed/bfs_rpc_clients.cpp +++ b/src/distributed/bfs_rpc_clients.cpp @@ -44,6 +44,17 @@ void BfsRpcClients::RegisterSubcursors( ->RegisterSubcursors(subcursor_ids); } +void BfsRpcClients::ResetSubcursors( + const std::unordered_map<int16_t, int64_t> &subcursor_ids) { + auto futures = clients_->ExecuteOnWorkers<void>( + db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) { + auto res = client.template Call<ResetSubcursorRpc>( + subcursor_ids.at(worker_id)); + CHECK(res) << "ResetSubcursor RPC failed!"; + }); + subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))->Reset(); +} + void BfsRpcClients::RemoveBfsSubcursors( const std::unordered_map<int16_t, int64_t> &subcursor_ids) { auto futures = clients_->ExecuteOnWorkers<void>( diff --git a/src/distributed/bfs_rpc_clients.hpp b/src/distributed/bfs_rpc_clients.hpp index a60acdf29..38ed19b60 100644 --- a/src/distributed/bfs_rpc_clients.hpp +++ b/src/distributed/bfs_rpc_clients.hpp @@ -27,6 +27,9 @@ class BfsRpcClients { void RegisterSubcursors( const std::unordered_map<int16_t, int64_t> &subcursor_ids); + void ResetSubcursors( + const std::unordered_map<int16_t, int64_t> &subcursor_ids); + void RemoveBfsSubcursors( const std::unordered_map<int16_t, int64_t> &subcursor_ids); diff --git a/src/distributed/bfs_rpc_messages.lcp b/src/distributed/bfs_rpc_messages.lcp index 4cb7c42b7..2971772d3 100644 --- a/src/distributed/bfs_rpc_messages.lcp +++ b/src/distributed/bfs_rpc_messages.lcp @@ -186,6 +186,10 @@ cpp<# cpp<#)))) (:response ())) +(lcp:define-rpc reset-subcursor + (:request ((subcursor-id :int64_t))) + (:response ())) + (lcp:define-rpc remove-bfs-subcursor (:request ((member :int64_t))) (:response ())) diff --git a/src/distributed/bfs_rpc_server.hpp b/src/distributed/bfs_rpc_server.hpp index 2c6832030..8ce85bdd1 100644 --- a/src/distributed/bfs_rpc_server.hpp +++ b/src/distributed/bfs_rpc_server.hpp @@ -38,6 +38,15 @@ class BfsRpcServer { res.Save(res_builder); }); + server_->Register<ResetSubcursorRpc>([this](const auto &req_reader, + auto *res_builder) { + ResetSubcursorReq req; + req.Load(req_reader); + subcursor_storage_->Get(req.subcursor_id)->Reset(); + ResetSubcursorRes res; + res.Save(res_builder); + }); + server_->Register<RemoveBfsSubcursorRpc>( [this](const auto &req_reader, auto *res_builder) { RemoveBfsSubcursorReq req; diff --git a/src/distributed/bfs_subcursor.hpp b/src/distributed/bfs_subcursor.hpp index b71724b9d..1b3eab669 100644 --- a/src/distributed/bfs_subcursor.hpp +++ b/src/distributed/bfs_subcursor.hpp @@ -73,11 +73,11 @@ class ExpandBfsSubcursor { /// Reconstruct the part of path to given vertex stored on this worker. PathSegment ReconstructPath(storage::VertexAddress vertex_addr); - - private: + /// Used to reset subcursor state before starting expansion from new source. void Reset(); + private: /// Expands to a local or remote vertex, returns true if expansion was /// successful. bool ExpandToVertex(EdgeAccessor edge, VertexAccessor vertex); diff --git a/src/distributed/produce_rpc_server.cpp b/src/distributed/produce_rpc_server.cpp index b2a925e79..6040fa375 100644 --- a/src/distributed/produce_rpc_server.cpp +++ b/src/distributed/produce_rpc_server.cpp @@ -52,6 +52,12 @@ PullState ProduceRpcServer::OngoingProduce::Accumulate() { } } +void ProduceRpcServer::OngoingProduce::Reset() { + cursor_->Reset(); + accumulation_.clear(); + cursor_state_ = PullState::CURSOR_IN_PROGRESS; +} + std::pair<std::vector<query::TypedValue>, PullState> ProduceRpcServer::OngoingProduce::PullOneFromCursor() { std::vector<query::TypedValue> results; @@ -105,6 +111,15 @@ ProduceRpcServer::ProduceRpcServer( res.Save(res_builder); }); + produce_rpc_server_.Register<ResetCursorRpc>( + [this](const auto &req_reader, auto *res_builder) { + ResetCursorReq req; + req.Load(req_reader); + Reset(req); + ResetCursorRes res; + res.Save(res_builder); + }); + produce_rpc_server_.Register<TransactionCommandAdvancedRpc>( [this](const auto &req_reader, auto *res_builder) { TransactionCommandAdvancedReq req; @@ -174,4 +189,15 @@ PullResData ProduceRpcServer::Pull(const PullReq &req) { return result; } +void ProduceRpcServer::Reset(const ResetCursorReq &req) { + auto key_tuple = std::make_tuple(req.tx_id, req.command_id, req.plan_id); + std::lock_guard<std::mutex> guard{ongoing_produces_lock_}; + auto found = ongoing_produces_.find(key_tuple); + // It is fine if the cursor doesn't exist yet. Creating a new cursor is the + // same thing as reseting an existing one. + if (found != ongoing_produces_.end()) { + found->second.Reset(); + } +} + } // namespace distributed diff --git a/src/distributed/produce_rpc_server.hpp b/src/distributed/produce_rpc_server.hpp index 492e061dd..d496b5ab4 100644 --- a/src/distributed/produce_rpc_server.hpp +++ b/src/distributed/produce_rpc_server.hpp @@ -46,6 +46,8 @@ class ProduceRpcServer { /// CURSOR_EXHAUSTED. If an error occurs, an appropriate value is returned. PullState Accumulate(); + void Reset(); + private: database::GraphDbAccessor dba_; query::Context context_; @@ -87,6 +89,9 @@ class ProduceRpcServer { /// Performs a single remote pull for the given request. PullResData Pull(const PullReq &req); + + /// Resets the cursor for an ongoing produce. + void Reset(const ResetCursorReq &req); }; } // namespace distributed diff --git a/src/distributed/pull_produce_rpc_messages.lcp b/src/distributed/pull_produce_rpc_messages.lcp index 2e38849ef..842ad2ece 100644 --- a/src/distributed/pull_produce_rpc_messages.lcp +++ b/src/distributed/pull_produce_rpc_messages.lcp @@ -541,6 +541,13 @@ void PullResData::LoadGraphElement( ;; TODO make a separate RPC for the continuation of an existing pull, as an ;; optimization not to have to send the full PullReqData pack every time. +(lcp:define-rpc reset-cursor + (:request + ((tx-id "tx::TransactionId") + (plan-id :int64_t) + (command-id "tx::CommandId"))) + (:response ())) + (lcp:define-rpc transaction-command-advanced (:request ((member "tx::TransactionId"))) (:response ())) diff --git a/src/distributed/pull_rpc_clients.cpp b/src/distributed/pull_rpc_clients.cpp index 85bf8b07a..d03f82e72 100644 --- a/src/distributed/pull_rpc_clients.cpp +++ b/src/distributed/pull_rpc_clients.cpp @@ -8,27 +8,38 @@ namespace distributed { utils::Future<PullData> PullRpcClients::Pull( - database::GraphDbAccessor &dba, int worker_id, int64_t plan_id, + database::GraphDbAccessor *dba, int worker_id, int64_t plan_id, tx::CommandId command_id, const Parameters ¶ms, const std::vector<query::Symbol> &symbols, int64_t timestamp, bool accumulate, int batch_size) { return clients_.ExecuteOnWorker< - PullData>(worker_id, [&dba, plan_id, command_id, params, symbols, + PullData>(worker_id, [dba, plan_id, command_id, params, symbols, timestamp, accumulate, batch_size]( int worker_id, ClientPool &client_pool) { - auto load_pull_res = [&dba](const auto &res_reader) { + auto load_pull_res = [dba](const auto &res_reader) { PullRes res; - res.Load(res_reader, &dba); + res.Load(res_reader, dba); return res; }; auto result = client_pool.CallWithLoad<PullRpc>( - load_pull_res, dba.transaction_id(), dba.transaction().snapshot(), + load_pull_res, dba->transaction_id(), dba->transaction().snapshot(), plan_id, command_id, params, symbols, timestamp, accumulate, batch_size, true, true); return PullData{result->data.pull_state, std::move(result->data.frames)}; }); } +utils::Future<void> PullRpcClients::ResetCursor(database::GraphDbAccessor *dba, + int worker_id, int64_t plan_id, + tx::CommandId command_id) { + return clients_.ExecuteOnWorker<void>( + worker_id, [dba, plan_id, command_id](int worker_id, auto &client) { + auto res = client.template Call<ResetCursorRpc>(dba->transaction_id(), + plan_id, command_id); + CHECK(res) << "ResetCursorRpc failed!"; + }); +} + std::vector<utils::Future<void>> PullRpcClients::NotifyAllTransactionCommandAdvanced(tx::TransactionId tx_id) { return clients_.ExecuteOnWorkers<void>( diff --git a/src/distributed/pull_rpc_clients.hpp b/src/distributed/pull_rpc_clients.hpp index 23367b649..030bcecb7 100644 --- a/src/distributed/pull_rpc_clients.hpp +++ b/src/distributed/pull_rpc_clients.hpp @@ -29,13 +29,16 @@ class PullRpcClients { /// @todo: it might be cleaner to split Pull into {InitRemoteCursor, /// Pull, RemoteAccumulate}, but that's a lot of refactoring and more /// RPC calls. - utils::Future<PullData> Pull(database::GraphDbAccessor &dba, int worker_id, + utils::Future<PullData> Pull(database::GraphDbAccessor *dba, int worker_id, int64_t plan_id, tx::CommandId command_id, const Parameters ¶ms, const std::vector<query::Symbol> &symbols, int64_t timestamp, bool accumulate, int batch_size = kDefaultBatchSize); + utils::Future<void> ResetCursor(database::GraphDbAccessor *dba, int worker_id, + int64_t plan_id, tx::CommandId command_id); + auto GetWorkerIds() { return clients_.GetWorkerIds(); } std::vector<utils::Future<void>> NotifyAllTransactionCommandAdvanced( diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index e227a8bfc..f965ea8c9 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1327,7 +1327,8 @@ class DistributedExpandBfsCursor : public query::plan::Cursor { } void Reset() override { - LOG(FATAL) << "`Reset` not supported in distributed"; + db_.db().bfs_subcursor_clients().ResetSubcursors(subcursor_ids_); + pull_pos_ = subcursor_ids_.end(); } private: @@ -3271,11 +3272,28 @@ class RemotePuller { throw HintedAbortError(); case distributed::PullState::QUERY_ERROR: throw QueryRuntimeException( - "Query runtime error occurred duing PullRemote !"); + "Query runtime error occurred during PullRemote !"); } } } + void Reset() { + worker_ids_ = db_.db().pull_clients().GetWorkerIds(); + // Remove master from the worker ids list. + worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0)); + + // We must clear remote_pulls before reseting cursors to make sure that all + // outstanding remote pulls are done. Otherwise we might try to reset cursor + // during its pull. + remote_pulls_.clear(); + for (auto &worker_id : worker_ids_) { + db_.db().pull_clients().ResetCursor(&db_, worker_id, plan_id_, + command_id_); + } + remote_results_.clear(); + remote_pulls_initialized_ = false; + } + auto Workers() { return worker_ids_; } int GetWorkerId(int worker_id_index) { return worker_ids_[worker_id_index]; } @@ -3322,7 +3340,7 @@ class RemotePuller { void UpdatePullForWorker(int worker_id, Context &context) { remote_pulls_[worker_id] = db_.db().pull_clients().Pull( - db_, worker_id, plan_id_, command_id_, context.parameters_, symbols_, + &db_, worker_id, plan_id_, command_id_, context.parameters_, symbols_, context.timestamp_, false); } }; @@ -3412,7 +3430,9 @@ class PullRemoteCursor : public Cursor { } void Reset() override { - throw QueryRuntimeException("Unsupported: Reset during PullRemote!"); + if (input_cursor_) input_cursor_->Reset(); + remote_puller_.Reset(); + last_pulled_worker_id_index_ = 0; } private: @@ -3465,7 +3485,10 @@ class SynchronizeCursor : public Cursor { } void Reset() override { - throw QueryRuntimeException("Unsupported: Reset during Synchronize!"); + input_cursor_->Reset(); + pull_remote_cursor_->Reset(); + initial_pull_done_ = false; + local_frames_.clear(); } private: @@ -3487,7 +3510,7 @@ class SynchronizeCursor : public Cursor { for (auto worker_id : db.pull_clients().GetWorkerIds()) { if (worker_id == db.WorkerId()) continue; worker_accumulations.emplace_back(db.pull_clients().Pull( - context.db_accessor_, worker_id, self_.pull_remote()->plan_id(), + &context.db_accessor_, worker_id, self_.pull_remote()->plan_id(), command_id_, context.parameters_, self_.pull_remote()->symbols(), context.timestamp_, true, 0)); } @@ -3774,7 +3797,12 @@ class PullRemoteOrderByCursor : public Cursor { } void Reset() { - throw QueryRuntimeException("Unsupported: Reset during PullRemoteOrderBy!"); + input_->Reset(); + remote_puller_.Reset(); + merge_.clear(); + missing_results_from_.clear(); + missing_master_result_ = false; + merge_initialized_ = false; } private: diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e5fe56cdd..466de9102 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -103,6 +103,9 @@ target_link_libraries(${test_prefix}distributed_interpretation memgraph_lib kvst add_unit_test(distributed_query_plan.cpp) target_link_libraries(${test_prefix}distributed_query_plan memgraph_lib kvstore_dummy_lib) +add_unit_test(distributed_reset.cpp) +target_link_libraries(${test_prefix}distributed_reset memgraph_lib kvstore_dummy_lib) + add_unit_test(distributed_serialization.cpp) target_link_libraries(${test_prefix}distributed_serialization memgraph_lib kvstore_dummy_lib) diff --git a/tests/unit/distributed_interpretation.cpp b/tests/unit/distributed_interpretation.cpp index a18b38374..8c7f02e7d 100644 --- a/tests/unit/distributed_interpretation.cpp +++ b/tests/unit/distributed_interpretation.cpp @@ -6,6 +6,7 @@ #include "database/graph_db.hpp" #include "distributed/plan_consumer.hpp" +#include "distributed/plan_dispatcher.hpp" #include "distributed/pull_rpc_clients.hpp" #include "distributed_common.hpp" #include "query/interpreter.hpp" diff --git a/tests/unit/distributed_query_plan.cpp b/tests/unit/distributed_query_plan.cpp index 7f3fea5f3..0ca1da1d2 100644 --- a/tests/unit/distributed_query_plan.cpp +++ b/tests/unit/distributed_query_plan.cpp @@ -67,7 +67,7 @@ TEST_F(DistributedQueryPlan, PullProduceRpc) { std::vector<query::Symbol> symbols{ctx.symbol_table_[*x_ne]}; auto remote_pull = [this, &command_id, ¶ms, &symbols]( GraphDbAccessor &dba, int worker_id) { - return master().pull_clients().Pull(dba, worker_id, plan_id, command_id, + return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id, params, symbols, 0, false, 3); }; auto expect_first_batch = [](auto &batch) { @@ -187,7 +187,7 @@ TEST_F(DistributedQueryPlan, PullProduceRpcWithGraphElements) { ctx.symbol_table_[*return_m], p_sym}; auto remote_pull = [this, &command_id, ¶ms, &symbols]( GraphDbAccessor &dba, int worker_id) { - return master().pull_clients().Pull(dba, worker_id, plan_id, command_id, + return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id, params, symbols, 0, false, 3); }; auto future_w1_results = remote_pull(dba, 1); @@ -363,7 +363,7 @@ TEST_F(DistributedTransactionTimeout, Timeout) { auto remote_pull = [this, &command_id, ¶ms, &symbols, &dba]() { return master() .pull_clients() - .Pull(dba, 1, plan_id, command_id, params, symbols, 0, false, 1) + .Pull(&dba, 1, plan_id, command_id, params, symbols, 0, false, 1) .get() .pull_state; }; diff --git a/tests/unit/distributed_reset.cpp b/tests/unit/distributed_reset.cpp new file mode 100644 index 000000000..ba93b4e77 --- /dev/null +++ b/tests/unit/distributed_reset.cpp @@ -0,0 +1,35 @@ +#include "gtest/gtest.h" + +#include "distributed/plan_dispatcher.hpp" +#include "distributed_common.hpp" +#include "query/context.hpp" +#include "query/interpret/frame.hpp" + +class DistributedReset : public DistributedGraphDbTest { + protected: + DistributedReset() : DistributedGraphDbTest("reset") {} +}; + +TEST_F(DistributedReset, ResetTest) { + query::SymbolTable symbol_table; + auto once = std::make_shared<query::plan::Once>(); + auto pull_remote = std::make_shared<query::plan::PullRemote>( + once, 42, std::vector<query::Symbol>()); + master().plan_dispatcher().DispatchPlan(42, once, symbol_table); + database::GraphDbAccessor dba{master()}; + query::Frame frame(0); + query::Context context(dba); + auto pull_remote_cursor = pull_remote->MakeCursor(dba); + + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(pull_remote_cursor->Pull(frame, context)); + } + EXPECT_FALSE(pull_remote_cursor->Pull(frame, context)); + + pull_remote_cursor->Reset(); + + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(pull_remote_cursor->Pull(frame, context)); + } + EXPECT_FALSE(pull_remote_cursor->Pull(frame, context)); +}