From c4f51d87f815e47eb0db1fd4caa33ab9c7473614 Mon Sep 17 00:00:00 2001
From: Marin Tomic <marin.tomic@memgraph.io>
Date: Fri, 6 Jul 2018 15:12:45 +0200
Subject: [PATCH] Implement Reset for distributed operators

Reviewers: teon.banek, msantl, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1467
---
 src/distributed/bfs_rpc_clients.cpp           | 11 +++++
 src/distributed/bfs_rpc_clients.hpp           |  3 ++
 src/distributed/bfs_rpc_messages.lcp          |  4 ++
 src/distributed/bfs_rpc_server.hpp            |  9 ++++
 src/distributed/bfs_subcursor.hpp             |  4 +-
 src/distributed/produce_rpc_server.cpp        | 26 ++++++++++++
 src/distributed/produce_rpc_server.hpp        |  5 +++
 src/distributed/pull_produce_rpc_messages.lcp |  7 ++++
 src/distributed/pull_rpc_clients.cpp          | 21 +++++++---
 src/distributed/pull_rpc_clients.hpp          |  5 ++-
 src/query/plan/operator.cpp                   | 42 +++++++++++++++----
 tests/unit/CMakeLists.txt                     |  3 ++
 tests/unit/distributed_interpretation.cpp     |  1 +
 tests/unit/distributed_query_plan.cpp         |  6 +--
 tests/unit/distributed_reset.cpp              | 35 ++++++++++++++++
 15 files changed, 164 insertions(+), 18 deletions(-)
 create mode 100644 tests/unit/distributed_reset.cpp

diff --git a/src/distributed/bfs_rpc_clients.cpp b/src/distributed/bfs_rpc_clients.cpp
index c0a29d9eb..2751af940 100644
--- a/src/distributed/bfs_rpc_clients.cpp
+++ b/src/distributed/bfs_rpc_clients.cpp
@@ -44,6 +44,17 @@ void BfsRpcClients::RegisterSubcursors(
       ->RegisterSubcursors(subcursor_ids);
 }
 
+void BfsRpcClients::ResetSubcursors(
+    const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
+  auto futures = clients_->ExecuteOnWorkers<void>(
+      db_->WorkerId(), [&subcursor_ids](int worker_id, auto &client) {
+        auto res = client.template Call<ResetSubcursorRpc>(
+            subcursor_ids.at(worker_id));
+        CHECK(res) << "ResetSubcursor RPC failed!";
+      });
+  subcursor_storage_->Get(subcursor_ids.at(db_->WorkerId()))->Reset();
+}
+
 void BfsRpcClients::RemoveBfsSubcursors(
     const std::unordered_map<int16_t, int64_t> &subcursor_ids) {
   auto futures = clients_->ExecuteOnWorkers<void>(
diff --git a/src/distributed/bfs_rpc_clients.hpp b/src/distributed/bfs_rpc_clients.hpp
index a60acdf29..38ed19b60 100644
--- a/src/distributed/bfs_rpc_clients.hpp
+++ b/src/distributed/bfs_rpc_clients.hpp
@@ -27,6 +27,9 @@ class BfsRpcClients {
   void RegisterSubcursors(
       const std::unordered_map<int16_t, int64_t> &subcursor_ids);
 
+  void ResetSubcursors(
+      const std::unordered_map<int16_t, int64_t> &subcursor_ids);
+
   void RemoveBfsSubcursors(
       const std::unordered_map<int16_t, int64_t> &subcursor_ids);
 
diff --git a/src/distributed/bfs_rpc_messages.lcp b/src/distributed/bfs_rpc_messages.lcp
index 4cb7c42b7..2971772d3 100644
--- a/src/distributed/bfs_rpc_messages.lcp
+++ b/src/distributed/bfs_rpc_messages.lcp
@@ -186,6 +186,10 @@ cpp<#
                                 cpp<#))))
   (:response ()))
 
+(lcp:define-rpc reset-subcursor
+    (:request ((subcursor-id :int64_t)))
+    (:response ()))
+
 (lcp:define-rpc remove-bfs-subcursor
     (:request ((member :int64_t)))
   (:response ()))
diff --git a/src/distributed/bfs_rpc_server.hpp b/src/distributed/bfs_rpc_server.hpp
index 2c6832030..8ce85bdd1 100644
--- a/src/distributed/bfs_rpc_server.hpp
+++ b/src/distributed/bfs_rpc_server.hpp
@@ -38,6 +38,15 @@ class BfsRpcServer {
           res.Save(res_builder);
         });
 
+    server_->Register<ResetSubcursorRpc>([this](const auto &req_reader,
+                                                auto *res_builder) {
+      ResetSubcursorReq req;
+      req.Load(req_reader);
+      subcursor_storage_->Get(req.subcursor_id)->Reset();
+      ResetSubcursorRes res;
+      res.Save(res_builder);
+    });
+
     server_->Register<RemoveBfsSubcursorRpc>(
         [this](const auto &req_reader, auto *res_builder) {
           RemoveBfsSubcursorReq req;
diff --git a/src/distributed/bfs_subcursor.hpp b/src/distributed/bfs_subcursor.hpp
index b71724b9d..1b3eab669 100644
--- a/src/distributed/bfs_subcursor.hpp
+++ b/src/distributed/bfs_subcursor.hpp
@@ -73,11 +73,11 @@ class ExpandBfsSubcursor {
 
   /// Reconstruct the part of path to given vertex stored on this worker.
   PathSegment ReconstructPath(storage::VertexAddress vertex_addr);
-
- private:
+ 
   /// Used to reset subcursor state before starting expansion from new source.
   void Reset();
 
+ private:
   /// Expands to a local or remote vertex, returns true if expansion was
   /// successful.
   bool ExpandToVertex(EdgeAccessor edge, VertexAccessor vertex);
diff --git a/src/distributed/produce_rpc_server.cpp b/src/distributed/produce_rpc_server.cpp
index b2a925e79..6040fa375 100644
--- a/src/distributed/produce_rpc_server.cpp
+++ b/src/distributed/produce_rpc_server.cpp
@@ -52,6 +52,12 @@ PullState ProduceRpcServer::OngoingProduce::Accumulate() {
   }
 }
 
+void ProduceRpcServer::OngoingProduce::Reset() {
+  cursor_->Reset();
+  accumulation_.clear();
+  cursor_state_ = PullState::CURSOR_IN_PROGRESS;
+}
+
 std::pair<std::vector<query::TypedValue>, PullState>
 ProduceRpcServer::OngoingProduce::PullOneFromCursor() {
   std::vector<query::TypedValue> results;
@@ -105,6 +111,15 @@ ProduceRpcServer::ProduceRpcServer(
         res.Save(res_builder);
       });
 
+  produce_rpc_server_.Register<ResetCursorRpc>(
+      [this](const auto &req_reader, auto *res_builder) {
+        ResetCursorReq req;
+        req.Load(req_reader);
+        Reset(req);
+        ResetCursorRes res;
+        res.Save(res_builder);
+      });
+
   produce_rpc_server_.Register<TransactionCommandAdvancedRpc>(
       [this](const auto &req_reader, auto *res_builder) {
         TransactionCommandAdvancedReq req;
@@ -174,4 +189,15 @@ PullResData ProduceRpcServer::Pull(const PullReq &req) {
   return result;
 }
 
+void ProduceRpcServer::Reset(const ResetCursorReq &req) {
+  auto key_tuple = std::make_tuple(req.tx_id, req.command_id, req.plan_id);
+  std::lock_guard<std::mutex> guard{ongoing_produces_lock_};
+  auto found = ongoing_produces_.find(key_tuple);
+  // It is fine if the cursor doesn't exist yet. Creating a new cursor is the
+  // same thing as reseting an existing one.
+  if (found != ongoing_produces_.end()) {
+    found->second.Reset();
+  }
+}
+
 }  // namespace distributed
diff --git a/src/distributed/produce_rpc_server.hpp b/src/distributed/produce_rpc_server.hpp
index 492e061dd..d496b5ab4 100644
--- a/src/distributed/produce_rpc_server.hpp
+++ b/src/distributed/produce_rpc_server.hpp
@@ -46,6 +46,8 @@ class ProduceRpcServer {
     /// CURSOR_EXHAUSTED. If an error occurs, an appropriate value is returned.
     PullState Accumulate();
 
+    void Reset();
+
    private:
     database::GraphDbAccessor dba_;
     query::Context context_;
@@ -87,6 +89,9 @@ class ProduceRpcServer {
 
   /// Performs a single remote pull for the given request.
   PullResData Pull(const PullReq &req);
+
+  /// Resets the cursor for an ongoing produce.
+  void Reset(const ResetCursorReq &req);
 };
 
 }  // namespace distributed
diff --git a/src/distributed/pull_produce_rpc_messages.lcp b/src/distributed/pull_produce_rpc_messages.lcp
index 2e38849ef..842ad2ece 100644
--- a/src/distributed/pull_produce_rpc_messages.lcp
+++ b/src/distributed/pull_produce_rpc_messages.lcp
@@ -541,6 +541,13 @@ void PullResData::LoadGraphElement(
 ;; TODO make a separate RPC for the continuation of an existing pull, as an
 ;; optimization not to have to send the full PullReqData pack every time.
 
+(lcp:define-rpc reset-cursor
+    (:request
+     ((tx-id "tx::TransactionId")
+      (plan-id :int64_t)
+      (command-id "tx::CommandId")))
+    (:response ()))
+
 (lcp:define-rpc transaction-command-advanced
     (:request ((member "tx::TransactionId")))
   (:response ()))
diff --git a/src/distributed/pull_rpc_clients.cpp b/src/distributed/pull_rpc_clients.cpp
index 85bf8b07a..d03f82e72 100644
--- a/src/distributed/pull_rpc_clients.cpp
+++ b/src/distributed/pull_rpc_clients.cpp
@@ -8,27 +8,38 @@
 namespace distributed {
 
 utils::Future<PullData> PullRpcClients::Pull(
-    database::GraphDbAccessor &dba, int worker_id, int64_t plan_id,
+    database::GraphDbAccessor *dba, int worker_id, int64_t plan_id,
     tx::CommandId command_id, const Parameters &params,
     const std::vector<query::Symbol> &symbols, int64_t timestamp,
     bool accumulate, int batch_size) {
   return clients_.ExecuteOnWorker<
-      PullData>(worker_id, [&dba, plan_id, command_id, params, symbols,
+      PullData>(worker_id, [dba, plan_id, command_id, params, symbols,
                             timestamp, accumulate, batch_size](
                                int worker_id, ClientPool &client_pool) {
-    auto load_pull_res = [&dba](const auto &res_reader) {
+    auto load_pull_res = [dba](const auto &res_reader) {
       PullRes res;
-      res.Load(res_reader, &dba);
+      res.Load(res_reader, dba);
       return res;
     };
     auto result = client_pool.CallWithLoad<PullRpc>(
-        load_pull_res, dba.transaction_id(), dba.transaction().snapshot(),
+        load_pull_res, dba->transaction_id(), dba->transaction().snapshot(),
         plan_id, command_id, params, symbols, timestamp, accumulate, batch_size,
         true, true);
     return PullData{result->data.pull_state, std::move(result->data.frames)};
   });
 }
 
+utils::Future<void> PullRpcClients::ResetCursor(database::GraphDbAccessor *dba,
+                                                int worker_id, int64_t plan_id,
+                                                tx::CommandId command_id) {
+  return clients_.ExecuteOnWorker<void>(
+      worker_id, [dba, plan_id, command_id](int worker_id, auto &client) {
+        auto res = client.template Call<ResetCursorRpc>(dba->transaction_id(),
+                                                        plan_id, command_id);
+        CHECK(res) << "ResetCursorRpc failed!";
+      });
+}
+
 std::vector<utils::Future<void>>
 PullRpcClients::NotifyAllTransactionCommandAdvanced(tx::TransactionId tx_id) {
   return clients_.ExecuteOnWorkers<void>(
diff --git a/src/distributed/pull_rpc_clients.hpp b/src/distributed/pull_rpc_clients.hpp
index 23367b649..030bcecb7 100644
--- a/src/distributed/pull_rpc_clients.hpp
+++ b/src/distributed/pull_rpc_clients.hpp
@@ -29,13 +29,16 @@ class PullRpcClients {
   /// @todo: it might be cleaner to split Pull into {InitRemoteCursor,
   /// Pull, RemoteAccumulate}, but that's a lot of refactoring and more
   /// RPC calls.
-  utils::Future<PullData> Pull(database::GraphDbAccessor &dba, int worker_id,
+  utils::Future<PullData> Pull(database::GraphDbAccessor *dba, int worker_id,
                                int64_t plan_id, tx::CommandId command_id,
                                const Parameters &params,
                                const std::vector<query::Symbol> &symbols,
                                int64_t timestamp, bool accumulate,
                                int batch_size = kDefaultBatchSize);
 
+  utils::Future<void> ResetCursor(database::GraphDbAccessor *dba, int worker_id,
+                                  int64_t plan_id, tx::CommandId command_id);
+
   auto GetWorkerIds() { return clients_.GetWorkerIds(); }
 
   std::vector<utils::Future<void>> NotifyAllTransactionCommandAdvanced(
diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index e227a8bfc..f965ea8c9 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -1327,7 +1327,8 @@ class DistributedExpandBfsCursor : public query::plan::Cursor {
   }
 
   void Reset() override {
-    LOG(FATAL) << "`Reset` not supported in distributed";
+    db_.db().bfs_subcursor_clients().ResetSubcursors(subcursor_ids_);
+    pull_pos_ = subcursor_ids_.end();
   }
 
  private:
@@ -3271,11 +3272,28 @@ class RemotePuller {
           throw HintedAbortError();
         case distributed::PullState::QUERY_ERROR:
           throw QueryRuntimeException(
-              "Query runtime error occurred duing PullRemote !");
+              "Query runtime error occurred during PullRemote !");
       }
     }
   }
 
+  void Reset() {
+    worker_ids_ = db_.db().pull_clients().GetWorkerIds();
+    // Remove master from the worker ids list.
+    worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0));
+
+    // We must clear remote_pulls before reseting cursors to make sure that all
+    // outstanding remote pulls are done. Otherwise we might try to reset cursor
+    // during its pull.
+    remote_pulls_.clear();
+    for (auto &worker_id : worker_ids_) {
+      db_.db().pull_clients().ResetCursor(&db_, worker_id, plan_id_,
+                                          command_id_);
+    }
+    remote_results_.clear();
+    remote_pulls_initialized_ = false;
+  }
+
   auto Workers() { return worker_ids_; }
 
   int GetWorkerId(int worker_id_index) { return worker_ids_[worker_id_index]; }
@@ -3322,7 +3340,7 @@ class RemotePuller {
 
   void UpdatePullForWorker(int worker_id, Context &context) {
     remote_pulls_[worker_id] = db_.db().pull_clients().Pull(
-        db_, worker_id, plan_id_, command_id_, context.parameters_, symbols_,
+        &db_, worker_id, plan_id_, command_id_, context.parameters_, symbols_,
         context.timestamp_, false);
   }
 };
@@ -3412,7 +3430,9 @@ class PullRemoteCursor : public Cursor {
   }
 
   void Reset() override {
-    throw QueryRuntimeException("Unsupported: Reset during PullRemote!");
+    if (input_cursor_) input_cursor_->Reset();
+    remote_puller_.Reset();
+    last_pulled_worker_id_index_ = 0;
   }
 
  private:
@@ -3465,7 +3485,10 @@ class SynchronizeCursor : public Cursor {
   }
 
   void Reset() override {
-    throw QueryRuntimeException("Unsupported: Reset during Synchronize!");
+    input_cursor_->Reset();
+    pull_remote_cursor_->Reset();
+    initial_pull_done_ = false;
+    local_frames_.clear();
   }
 
  private:
@@ -3487,7 +3510,7 @@ class SynchronizeCursor : public Cursor {
       for (auto worker_id : db.pull_clients().GetWorkerIds()) {
         if (worker_id == db.WorkerId()) continue;
         worker_accumulations.emplace_back(db.pull_clients().Pull(
-            context.db_accessor_, worker_id, self_.pull_remote()->plan_id(),
+            &context.db_accessor_, worker_id, self_.pull_remote()->plan_id(),
             command_id_, context.parameters_, self_.pull_remote()->symbols(),
             context.timestamp_, true, 0));
       }
@@ -3774,7 +3797,12 @@ class PullRemoteOrderByCursor : public Cursor {
   }
 
   void Reset() {
-    throw QueryRuntimeException("Unsupported: Reset during PullRemoteOrderBy!");
+    input_->Reset();
+    remote_puller_.Reset();
+    merge_.clear();
+    missing_results_from_.clear();
+    missing_master_result_ = false;
+    merge_initialized_ = false;
   }
 
  private:
diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt
index e5fe56cdd..466de9102 100644
--- a/tests/unit/CMakeLists.txt
+++ b/tests/unit/CMakeLists.txt
@@ -103,6 +103,9 @@ target_link_libraries(${test_prefix}distributed_interpretation memgraph_lib kvst
 add_unit_test(distributed_query_plan.cpp)
 target_link_libraries(${test_prefix}distributed_query_plan memgraph_lib kvstore_dummy_lib)
 
+add_unit_test(distributed_reset.cpp)
+target_link_libraries(${test_prefix}distributed_reset memgraph_lib kvstore_dummy_lib)
+
 add_unit_test(distributed_serialization.cpp)
 target_link_libraries(${test_prefix}distributed_serialization memgraph_lib kvstore_dummy_lib)
 
diff --git a/tests/unit/distributed_interpretation.cpp b/tests/unit/distributed_interpretation.cpp
index a18b38374..8c7f02e7d 100644
--- a/tests/unit/distributed_interpretation.cpp
+++ b/tests/unit/distributed_interpretation.cpp
@@ -6,6 +6,7 @@
 
 #include "database/graph_db.hpp"
 #include "distributed/plan_consumer.hpp"
+#include "distributed/plan_dispatcher.hpp"
 #include "distributed/pull_rpc_clients.hpp"
 #include "distributed_common.hpp"
 #include "query/interpreter.hpp"
diff --git a/tests/unit/distributed_query_plan.cpp b/tests/unit/distributed_query_plan.cpp
index 7f3fea5f3..0ca1da1d2 100644
--- a/tests/unit/distributed_query_plan.cpp
+++ b/tests/unit/distributed_query_plan.cpp
@@ -67,7 +67,7 @@ TEST_F(DistributedQueryPlan, PullProduceRpc) {
   std::vector<query::Symbol> symbols{ctx.symbol_table_[*x_ne]};
   auto remote_pull = [this, &command_id, &params, &symbols](
                          GraphDbAccessor &dba, int worker_id) {
-    return master().pull_clients().Pull(dba, worker_id, plan_id, command_id,
+    return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id,
                                         params, symbols, 0, false, 3);
   };
   auto expect_first_batch = [](auto &batch) {
@@ -187,7 +187,7 @@ TEST_F(DistributedQueryPlan, PullProduceRpcWithGraphElements) {
                                      ctx.symbol_table_[*return_m], p_sym};
   auto remote_pull = [this, &command_id, &params, &symbols](
                          GraphDbAccessor &dba, int worker_id) {
-    return master().pull_clients().Pull(dba, worker_id, plan_id, command_id,
+    return master().pull_clients().Pull(&dba, worker_id, plan_id, command_id,
                                         params, symbols, 0, false, 3);
   };
   auto future_w1_results = remote_pull(dba, 1);
@@ -363,7 +363,7 @@ TEST_F(DistributedTransactionTimeout, Timeout) {
   auto remote_pull = [this, &command_id, &params, &symbols, &dba]() {
     return master()
         .pull_clients()
-        .Pull(dba, 1, plan_id, command_id, params, symbols, 0, false, 1)
+        .Pull(&dba, 1, plan_id, command_id, params, symbols, 0, false, 1)
         .get()
         .pull_state;
   };
diff --git a/tests/unit/distributed_reset.cpp b/tests/unit/distributed_reset.cpp
new file mode 100644
index 000000000..ba93b4e77
--- /dev/null
+++ b/tests/unit/distributed_reset.cpp
@@ -0,0 +1,35 @@
+#include "gtest/gtest.h"
+
+#include "distributed/plan_dispatcher.hpp"
+#include "distributed_common.hpp"
+#include "query/context.hpp"
+#include "query/interpret/frame.hpp"
+
+class DistributedReset : public DistributedGraphDbTest {
+ protected:
+  DistributedReset() : DistributedGraphDbTest("reset") {}
+};
+
+TEST_F(DistributedReset, ResetTest) {
+  query::SymbolTable symbol_table;
+  auto once = std::make_shared<query::plan::Once>();
+  auto pull_remote = std::make_shared<query::plan::PullRemote>(
+      once, 42, std::vector<query::Symbol>());
+  master().plan_dispatcher().DispatchPlan(42, once, symbol_table);
+  database::GraphDbAccessor dba{master()};
+  query::Frame frame(0);
+  query::Context context(dba);
+  auto pull_remote_cursor = pull_remote->MakeCursor(dba);
+
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_TRUE(pull_remote_cursor->Pull(frame, context));
+  }
+  EXPECT_FALSE(pull_remote_cursor->Pull(frame, context));
+
+  pull_remote_cursor->Reset();
+
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_TRUE(pull_remote_cursor->Pull(frame, context));
+  }
+  EXPECT_FALSE(pull_remote_cursor->Pull(frame, context));
+}