From 7a3caa320cb1044d155b192634cd53b56ca76aa6 Mon Sep 17 00:00:00 2001
From: Kostas Kyrimis <kostaskyrim@gmail.com>
Date: Thu, 24 Nov 2022 14:25:20 +0200
Subject: [PATCH] WiP

---
 src/query/v2/shard_request_manager.hpp     | 92 ++++++++++++++++++++--
 src/storage/v3/shard_rsm.cpp               |  3 +
 tests/simulation/CMakeLists.txt            |  1 +
 tests/simulation/shard_request_manager.cpp | 19 ++++-
 4 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/src/query/v2/shard_request_manager.hpp b/src/query/v2/shard_request_manager.hpp
index 20bae7b97..4db77e645 100644
--- a/src/query/v2/shard_request_manager.hpp
+++ b/src/query/v2/shard_request_manager.hpp
@@ -105,6 +105,7 @@ struct ExecutionState {
 class ShardRequestManagerInterface {
  public:
   using VertexAccessor = memgraph::query::v2::accessors::VertexAccessor;
+  using EdgeAccessor = memgraph::query::v2::accessors::EdgeAccessor;
   ShardRequestManagerInterface() = default;
   ShardRequestManagerInterface(const ShardRequestManagerInterface &) = delete;
   ShardRequestManagerInterface(ShardRequestManagerInterface &&) = delete;
@@ -122,7 +123,8 @@ class ShardRequestManagerInterface {
                                                   ExpandOneRequest request) = 0;
   virtual std::vector<CreateExpandResponse> Request(ExecutionState<CreateExpandRequest> &state,
                                                     std::vector<NewExpand> new_edges) = 0;
-
+  virtual std::vector<GetPropertiesResponse> Request(ExecutionState<GetPropertiesRequest> &state,
+                                                     GetPropertiesRequest request) = 0;
   virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0;
   virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
   virtual storage::v3::LabelId NameToLabel(const std::string &name) const = 0;
@@ -146,6 +148,7 @@ class ShardRequestManager : public ShardRequestManagerInterface {
   using ShardMap = memgraph::coordinator::ShardMap;
   using CompoundKey = memgraph::coordinator::PrimaryKey;
   using VertexAccessor = memgraph::query::v2::accessors::VertexAccessor;
+  using EdgeAccessor = memgraph::query::v2::accessors::EdgeAccessor;
   ShardRequestManager(CoordinatorClient coord, memgraph::io::Io<TTransport> &&io)
       : coord_cli_(std::move(coord)), io_(std::move(io)) {}
 
@@ -353,6 +356,21 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     return result_rows;
   }
 
+  std::vector<GetPropertiesResponse> Request(ExecutionState<GetPropertiesRequest> &state,
+                                             GetPropertiesRequest requests) override {
+    MaybeInitializeExecutionState(state, std::move(requests));
+    SendAllRequests(state);
+
+    std::vector<GetPropertiesResponse> responses;
+    // 2. Block untill all the futures are exhausted
+    do {
+      AwaitOnResponses(state, responses);
+    } while (!state.shard_cache.empty());
+
+    MaybeCompleteState(state);
+    return responses;
+  }
+
  private:
   enum class PaginatedResponseState { Pending, PartiallyFinished };
 
@@ -373,6 +391,13 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     }
   }
 
+  template <typename ExecutionState>
+  void ThrowIfStateExecuting(ExecutionState &state) const {
+    if (state.state == ExecutionState::EXECUTING) [[unlikely]] {
+      throw std::runtime_error("State is completed and must be reset");
+    }
+  }
+
   template <typename ExecutionState>
   void MaybeCompleteState(ExecutionState &state) const {
     if (state.requests.empty()) {
@@ -508,6 +533,33 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     state.state = ExecutionState<ExpandOneRequest>::EXECUTING;
   }
 
+  void MaybeInitializeExecutionState(ExecutionState<GetPropertiesRequest> &state, GetPropertiesRequest request) {
+    ThrowIfStateCompleted(state);
+    ThrowIfStateExecuting(state);
+
+    std::map<Shard, GetPropertiesRequest> per_shard_request_table;
+    auto top_level_rqst_template = request;
+    top_level_rqst_template.transaction_id = transaction_id_;
+    top_level_rqst_template.vertices_and_edges.clear();
+
+    state.transaction_id = transaction_id_;
+
+    for (auto &[vertex, maybe_edge] : request.vertices_and_edges) {
+      auto shard =
+          shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
+      if (!per_shard_request_table.contains(shard)) {
+        per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
+        state.shard_cache.push_back(shard);
+      }
+      per_shard_request_table[shard].vertices_and_edges.push_back({std::move(vertex), maybe_edge});
+    }
+
+    for (auto &[shard, rqst] : per_shard_request_table) {
+      state.requests.push_back(std::move(rqst));
+    }
+    state.state = ExecutionState<GetPropertiesRequest>::EXECUTING;
+  }
+
   StorageClient &GetStorageClientForShard(Shard shard) {
     if (!storage_cli_manager_.Exists(shard)) {
       AddStorageClientToManager(shard);
@@ -532,7 +584,8 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     storage_cli_manager_.AddClient(target_shard, std::move(cli));
   }
 
-  void SendAllRequests(ExecutionState<ScanVerticesRequest> &state) {
+  template <typename TRequest>
+  void SendAllRequests(ExecutionState<TRequest> &state) {
     int64_t shard_idx = 0;
     for (const auto &request : state.requests) {
       const auto &current_shard = state.shard_cache[shard_idx];
@@ -581,9 +634,6 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     int64_t request_idx = 0;
 
     for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      // This is fine because all new_vertices of each request end up on the same shard
-      const auto labels = state.requests[request_idx].new_vertices[0].label_ids;
-
       auto &storage_client = GetStorageClientForShard(*shard_it);
 
       auto poll_result = storage_client.AwaitAsyncWriteRequest();
@@ -649,6 +699,38 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     }
   }
 
+  void AwaitOnResponses(ExecutionState<GetPropertiesRequest> &state, std::vector<GetPropertiesResponse> &responses) {
+    auto &shard_cache_ref = state.shard_cache;
+    int64_t request_idx = 0;
+
+    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
+      auto &storage_client = GetStorageClientForShard(*shard_it);
+
+      auto poll_result = storage_client.PollAsyncReadRequest();
+      if (!poll_result) {
+        ++shard_it;
+        ++request_idx;
+        continue;
+      }
+
+      if (poll_result->HasError()) {
+        throw std::runtime_error("GetProperties request timed out");
+      }
+
+      ReadResponses response_variant = poll_result->GetValue();
+      auto response = std::get<GetPropertiesResponse>(response_variant);
+      if (response.result != GetPropertiesResponse::SUCCESS) {
+        throw std::runtime_error("GetProperties request did not succeed");
+      }
+
+      responses.push_back(std::move(response));
+      shard_it = shard_cache_ref.erase(shard_it);
+      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
+      auto it = state.requests.begin() + request_idx;
+      state.requests.erase(it);
+    }
+  }
+
   void AwaitOnPaginatedRequests(ExecutionState<ScanVerticesRequest> &state,
                                 std::vector<ScanVerticesResponse> &responses,
                                 std::map<Shard, PaginatedResponseState> &paginated_response_tracker) {
diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp
index 3188458ce..9a2d6dcc8 100644
--- a/src/storage/v3/shard_rsm.cpp
+++ b/src/storage/v3/shard_rsm.cpp
@@ -518,6 +518,9 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) {
   if (req.vertices_and_edges.empty()) {
     return msgs::GetPropertiesResponse{.result = msgs::GetPropertiesResponse::FAILURE};
   }
+  if (req.property_ids.empty()) {
+    return msgs::GetPropertiesResponse{.result = msgs::GetPropertiesResponse::SUCCESS};
+  }
 
   auto shard_acc = shard_->Access(req.transaction_id);
   auto dba = DbAccessor{&shard_acc};
diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt
index 9e1a4c71e..2d7e6c49c 100644
--- a/tests/simulation/CMakeLists.txt
+++ b/tests/simulation/CMakeLists.txt
@@ -32,3 +32,4 @@ add_simulation_test(trial_query_storage/query_storage_test.cpp)
 add_simulation_test(sharded_map.cpp)
 add_simulation_test(shard_rsm.cpp)
 add_simulation_test(cluster_property_test.cpp)
+add_simulation_test(shard_request_manager.cpp)
diff --git a/tests/simulation/shard_request_manager.cpp b/tests/simulation/shard_request_manager.cpp
index 746ab385f..e62e15318 100644
--- a/tests/simulation/shard_request_manager.cpp
+++ b/tests/simulation/shard_request_manager.cpp
@@ -221,8 +221,22 @@ void TestExpandOne(msgs::ShardRequestManagerInterface &shard_request_manager) {
   MG_ASSERT(result_rows.size() == 2);
 }
 
-template <typename ShardRequestManager>
-void TestAggregate(ShardRequestManager &io) {}
+void TestGetProperties(msgs::ShardRequestManagerInterface &shard_request_manager) {
+  using PropVal = msgs::Value;
+
+  auto label_id = shard_request_manager.NameToLabel("test_label");
+  msgs::VertexId v1{{label_id}, {PropVal(int64_t(1)), PropVal(int64_t(0))}};
+  msgs::VertexId v2{{label_id}, {PropVal(int64_t(13)), PropVal(int64_t(13))}};
+
+  msgs::ExecutionState<msgs::GetPropertiesRequest> state;
+  msgs::GetPropertiesRequest request;
+
+  request.vertices_and_edges.push_back({v1});
+  request.vertices_and_edges.push_back({v2});
+
+  auto result = shard_request_manager.Request(state, std::move(request));
+  MG_ASSERT(result.size() == 2);
+}
 
 void DoTest() {
   SimulatorConfig config{
@@ -343,6 +357,7 @@ void DoTest() {
   TestScanVertices(io);
   TestCreateVertices(io);
   TestCreateExpand(io);
+  TestGetProperties(io);
 
   simulator.ShutDown();