From 2320f95dd1456c678da7de134097557a821e5940 Mon Sep 17 00:00:00 2001
From: Tyler Neely <t@jujit.su>
Date: Thu, 18 Aug 2022 14:01:47 +0000
Subject: [PATCH] Update the coordinator to include request for initializing a
 new shard map

---
 src/coordinator/coordinator.hpp  | 39 ++++++++++++++++++++++++++++----
 src/coordinator/shard_map.hpp    | 33 ++++++++++++++++++++-------
 tests/simulation/sharded_map.cpp | 15 ++++++------
 3 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/src/coordinator/coordinator.hpp b/src/coordinator/coordinator.hpp
index fcd073d15..8eb070c72 100644
--- a/src/coordinator/coordinator.hpp
+++ b/src/coordinator/coordinator.hpp
@@ -87,10 +87,22 @@ struct DeregisterStorageEngineResponse {
   bool success;
 };
 
-using WriteRequests = std::variant<AllocateHlcBatchRequest, AllocateEdgeIdBatchRequest, SplitShardRequest,
-                                   RegisterStorageEngineRequest, DeregisterStorageEngineRequest>;
-using WriteResponses = std::variant<AllocateHlcBatchResponse, AllocateEdgeIdBatchResponse, SplitShardResponse,
-                                    RegisterStorageEngineResponse, DeregisterStorageEngineResponse>;
+struct InitializeLabelRequest {
+  std::string label_name;
+  Hlc last_shard_map_version;
+};
+
+struct InitializeLabelResponse {
+  bool success;
+  std::optional<ShardMap> fresher_shard_map;
+};
+
+using WriteRequests =
+    std::variant<AllocateHlcBatchRequest, AllocateEdgeIdBatchRequest, SplitShardRequest, RegisterStorageEngineRequest,
+                 DeregisterStorageEngineRequest, InitializeLabelRequest>;
+using WriteResponses =
+    std::variant<AllocateHlcBatchResponse, AllocateEdgeIdBatchResponse, SplitShardResponse,
+                 RegisterStorageEngineResponse, DeregisterStorageEngineResponse, InitializeLabelResponse>;
 
 using ReadRequests = std::variant<HlcRequest, GetShardMapRequest>;
 using ReadResponses = std::variant<HlcResponse, GetShardMapResponse>;
@@ -123,7 +135,7 @@ class Coordinator {
 
     MG_ASSERT(!(hlc_request.last_shard_map_version.logical_id > hlc_shard_map.logical_id));
 
-    res.new_hlc = shard_map_.UpdateShardMapVersion();
+    res.new_hlc = shard_map_.IncrementShardMapVersion();
 
     // res.fresher_shard_map = hlc_request.last_shard_map_version.logical_id < hlc_shard_map.logical_id
     //                             ? std::make_optional(shard_map_)
@@ -199,6 +211,23 @@ class Coordinator {
     return res;
   }
 
+  WriteResponses ApplyWrite(InitializeLabelRequest &&initialize_label_request) {
+    InitializeLabelResponse res{};
+
+    bool success = shard_map_.InitializeNewLabel(initialize_label_request.label_name,
+                                                 initialize_label_request.last_shard_map_version);
+
+    if (success) {
+      res.fresher_shard_map = shard_map_;
+      res.success = false;
+    } else {
+      res.fresher_shard_map = std::nullopt;
+      res.success = true;
+    }
+
+    return res;
+  }
+
  public:
   explicit Coordinator(ShardMap sm) : shard_map_{(sm)} {}
 
diff --git a/src/coordinator/shard_map.hpp b/src/coordinator/shard_map.hpp
index bc891db5e..7624ff904 100644
--- a/src/coordinator/shard_map.hpp
+++ b/src/coordinator/shard_map.hpp
@@ -46,16 +46,9 @@ struct ShardMap {
   Hlc shard_map_version;
   std::map<Label, Shards> shards;
 
-  // TODO(gabor) later we will want to update the wallclock time with
-  // the given Io<impl>'s time as well. This function should just be
-  // replaced with operator== since it is already overloaded for Hlc
-  // objects.
-  bool CompareShardMapVersions(Hlc one, Hlc two) { return one.logical_id == two.logical_id; }
-
- public:
   // TODO(gabor) later we will want to update the wallclock time with
   // the given Io<impl>'s time as well
-  Hlc UpdateShardMapVersion() noexcept {
+  Hlc IncrementShardMapVersion() noexcept {
     ++shard_map_version.logical_id;
     return shard_map_version;
   }
@@ -83,12 +76,29 @@ struct ShardMap {
 
       // Apply the split
       shards_in_map[key] = shard_to_map_to;
+
       return true;
     }
 
     return false;
   }
 
+  bool InitializeNewLabel(std::string label_name, Hlc last_shard_map_version) {
+    if (shard_map_version != last_shard_map_version) {
+      return false;
+    }
+
+    if (shards.contains(label_name)) {
+      return false;
+    }
+
+    shards.emplace(label_name, Shards{});
+
+    IncrementShardMapVersion();
+
+    return true;
+  }
+
   void AddServer(Address server_address) {
     // Find a random place for the server to plug in
   }
@@ -106,6 +116,13 @@ struct ShardMap {
 
     return asd2;
   }
+
+ private:
+  // TODO(gabor) later we will want to update the wallclock time with
+  // the given Io<impl>'s time as well. This function should just be
+  // replaced with operator== since it is already overloaded for Hlc
+  // objects.
+  bool CompareShardMapVersions(Hlc one, Hlc two) { return one.logical_id == two.logical_id; }
 };
 
 }  // namespace memgraph::coordinator
diff --git a/tests/simulation/sharded_map.cpp b/tests/simulation/sharded_map.cpp
index 677932067..6a1f82220 100644
--- a/tests/simulation/sharded_map.cpp
+++ b/tests/simulation/sharded_map.cpp
@@ -49,8 +49,8 @@ using memgraph::io::rsm::Raft;
 using memgraph::io::rsm::ReadRequest;
 using memgraph::io::rsm::ReadResponse;
 using memgraph::io::rsm::RsmClient;
-using memgraph::io::rsm::StorageGetRequest;
-using memgraph::io::rsm::StorageGetResponse;
+using memgraph::io::rsm::StorageReadRequest;
+using memgraph::io::rsm::StorageReadResponse;
 using memgraph::io::rsm::StorageRsm;
 using memgraph::io::rsm::StorageWriteRequest;
 using memgraph::io::rsm::StorageWriteResponse;
@@ -62,8 +62,8 @@ using memgraph::io::simulator::SimulatorStats;
 using memgraph::io::simulator::SimulatorTransport;
 using memgraph::utils::BasicResult;
 
-using StorageClient =
-    RsmClient<Io<SimulatorTransport>, StorageWriteRequest, StorageWriteResponse, StorageGetRequest, StorageGetResponse>;
+using StorageClient = RsmClient<Io<SimulatorTransport>, StorageWriteRequest, StorageWriteResponse, StorageReadRequest,
+                                StorageReadResponse>;
 namespace {
 
 ShardMap CreateDummyShardmap(memgraph::coordinator::Address a_io_1, memgraph::coordinator::Address a_io_2,
@@ -122,11 +122,12 @@ std::optional<StorageClient> DetermineShardLocation(Shard target_shard, const st
 
 using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>;
 using ConcreteStorageRsm = Raft<SimulatorTransport, StorageRsm, StorageWriteRequest, StorageWriteResponse,
-                                StorageGetRequest, StorageGetResponse>;
+                                StorageReadRequest, StorageReadResponse>;
 
 template <typename IoImpl>
 void RunStorageRaft(
-    Raft<IoImpl, StorageRsm, StorageWriteRequest, StorageWriteResponse, StorageGetRequest, StorageGetResponse> server) {
+    Raft<IoImpl, StorageRsm, StorageWriteRequest, StorageWriteResponse, StorageReadRequest, StorageReadResponse>
+        server) {
   server.Run();
 }
 
@@ -307,7 +308,7 @@ int main() {
     // Have client use shard map to decide which shard to communicate
     // with to read that same value back
 
-    StorageGetRequest storage_get_req;
+    StorageReadRequest storage_get_req;
     storage_get_req.key = {write_key_1, write_key_2};
 
     auto get_response_result = storage_client.SendReadRequest(storage_get_req);