diff --git a/src/query/v2/shard_request_manager.hpp b/src/query/v2/shard_request_manager.hpp index 20bae7b97..4db77e645 100644 --- a/src/query/v2/shard_request_manager.hpp +++ b/src/query/v2/shard_request_manager.hpp @@ -105,6 +105,7 @@ struct ExecutionState { class ShardRequestManagerInterface { public: using VertexAccessor = memgraph::query::v2::accessors::VertexAccessor; + using EdgeAccessor = memgraph::query::v2::accessors::EdgeAccessor; ShardRequestManagerInterface() = default; ShardRequestManagerInterface(const ShardRequestManagerInterface &) = delete; ShardRequestManagerInterface(ShardRequestManagerInterface &&) = delete; @@ -122,7 +123,8 @@ class ShardRequestManagerInterface { ExpandOneRequest request) = 0; virtual std::vector Request(ExecutionState &state, std::vector new_edges) = 0; - + virtual std::vector Request(ExecutionState &state, + GetPropertiesRequest request) = 0; virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0; virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0; virtual storage::v3::LabelId NameToLabel(const std::string &name) const = 0; @@ -146,6 +148,7 @@ class ShardRequestManager : public ShardRequestManagerInterface { using ShardMap = memgraph::coordinator::ShardMap; using CompoundKey = memgraph::coordinator::PrimaryKey; using VertexAccessor = memgraph::query::v2::accessors::VertexAccessor; + using EdgeAccessor = memgraph::query::v2::accessors::EdgeAccessor; ShardRequestManager(CoordinatorClient coord, memgraph::io::Io &&io) : coord_cli_(std::move(coord)), io_(std::move(io)) {} @@ -353,6 +356,21 @@ class ShardRequestManager : public ShardRequestManagerInterface { return result_rows; } + std::vector Request(ExecutionState &state, + GetPropertiesRequest requests) override { + MaybeInitializeExecutionState(state, std::move(requests)); + SendAllRequests(state); + + std::vector responses; + // 2. Block untill all the futures are exhausted + do { + AwaitOnResponses(state, responses); + } while (!state.shard_cache.empty()); + + MaybeCompleteState(state); + return responses; + } + private: enum class PaginatedResponseState { Pending, PartiallyFinished }; @@ -373,6 +391,13 @@ class ShardRequestManager : public ShardRequestManagerInterface { } } + template + void ThrowIfStateExecuting(ExecutionState &state) const { + if (state.state == ExecutionState::EXECUTING) [[unlikely]] { + throw std::runtime_error("State is completed and must be reset"); + } + } + template void MaybeCompleteState(ExecutionState &state) const { if (state.requests.empty()) { @@ -508,6 +533,33 @@ class ShardRequestManager : public ShardRequestManagerInterface { state.state = ExecutionState::EXECUTING; } + void MaybeInitializeExecutionState(ExecutionState &state, GetPropertiesRequest request) { + ThrowIfStateCompleted(state); + ThrowIfStateExecuting(state); + + std::map per_shard_request_table; + auto top_level_rqst_template = request; + top_level_rqst_template.transaction_id = transaction_id_; + top_level_rqst_template.vertices_and_edges.clear(); + + state.transaction_id = transaction_id_; + + for (auto &[vertex, maybe_edge] : request.vertices_and_edges) { + auto shard = + shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second)); + if (!per_shard_request_table.contains(shard)) { + per_shard_request_table.insert(std::pair(shard, top_level_rqst_template)); + state.shard_cache.push_back(shard); + } + per_shard_request_table[shard].vertices_and_edges.push_back({std::move(vertex), maybe_edge}); + } + + for (auto &[shard, rqst] : per_shard_request_table) { + state.requests.push_back(std::move(rqst)); + } + state.state = ExecutionState::EXECUTING; + } + StorageClient &GetStorageClientForShard(Shard shard) { if (!storage_cli_manager_.Exists(shard)) { AddStorageClientToManager(shard); @@ -532,7 +584,8 @@ class ShardRequestManager : public ShardRequestManagerInterface { storage_cli_manager_.AddClient(target_shard, std::move(cli)); } - void SendAllRequests(ExecutionState &state) { + template + void SendAllRequests(ExecutionState &state) { int64_t shard_idx = 0; for (const auto &request : state.requests) { const auto ¤t_shard = state.shard_cache[shard_idx]; @@ -581,9 +634,6 @@ class ShardRequestManager : public ShardRequestManagerInterface { int64_t request_idx = 0; for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) { - // This is fine because all new_vertices of each request end up on the same shard - const auto labels = state.requests[request_idx].new_vertices[0].label_ids; - auto &storage_client = GetStorageClientForShard(*shard_it); auto poll_result = storage_client.AwaitAsyncWriteRequest(); @@ -649,6 +699,38 @@ class ShardRequestManager : public ShardRequestManagerInterface { } } + void AwaitOnResponses(ExecutionState &state, std::vector &responses) { + auto &shard_cache_ref = state.shard_cache; + int64_t request_idx = 0; + + for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) { + auto &storage_client = GetStorageClientForShard(*shard_it); + + auto poll_result = storage_client.PollAsyncReadRequest(); + if (!poll_result) { + ++shard_it; + ++request_idx; + continue; + } + + if (poll_result->HasError()) { + throw std::runtime_error("GetProperties request timed out"); + } + + ReadResponses response_variant = poll_result->GetValue(); + auto response = std::get(response_variant); + if (response.result != GetPropertiesResponse::SUCCESS) { + throw std::runtime_error("GetProperties request did not succeed"); + } + + responses.push_back(std::move(response)); + shard_it = shard_cache_ref.erase(shard_it); + // Needed to maintain the 1-1 mapping between the ShardCache and the requests. + auto it = state.requests.begin() + request_idx; + state.requests.erase(it); + } + } + void AwaitOnPaginatedRequests(ExecutionState &state, std::vector &responses, std::map &paginated_response_tracker) { diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp index 3188458ce..9a2d6dcc8 100644 --- a/src/storage/v3/shard_rsm.cpp +++ b/src/storage/v3/shard_rsm.cpp @@ -518,6 +518,9 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) { if (req.vertices_and_edges.empty()) { return msgs::GetPropertiesResponse{.result = msgs::GetPropertiesResponse::FAILURE}; } + if (req.property_ids.empty()) { + return msgs::GetPropertiesResponse{.result = msgs::GetPropertiesResponse::SUCCESS}; + } auto shard_acc = shard_->Access(req.transaction_id); auto dba = DbAccessor{&shard_acc}; diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index 9e1a4c71e..2d7e6c49c 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -32,3 +32,4 @@ add_simulation_test(trial_query_storage/query_storage_test.cpp) add_simulation_test(sharded_map.cpp) add_simulation_test(shard_rsm.cpp) add_simulation_test(cluster_property_test.cpp) +add_simulation_test(shard_request_manager.cpp) diff --git a/tests/simulation/shard_request_manager.cpp b/tests/simulation/shard_request_manager.cpp index 746ab385f..e62e15318 100644 --- a/tests/simulation/shard_request_manager.cpp +++ b/tests/simulation/shard_request_manager.cpp @@ -221,8 +221,22 @@ void TestExpandOne(msgs::ShardRequestManagerInterface &shard_request_manager) { MG_ASSERT(result_rows.size() == 2); } -template -void TestAggregate(ShardRequestManager &io) {} +void TestGetProperties(msgs::ShardRequestManagerInterface &shard_request_manager) { + using PropVal = msgs::Value; + + auto label_id = shard_request_manager.NameToLabel("test_label"); + msgs::VertexId v1{{label_id}, {PropVal(int64_t(1)), PropVal(int64_t(0))}}; + msgs::VertexId v2{{label_id}, {PropVal(int64_t(13)), PropVal(int64_t(13))}}; + + msgs::ExecutionState state; + msgs::GetPropertiesRequest request; + + request.vertices_and_edges.push_back({v1}); + request.vertices_and_edges.push_back({v2}); + + auto result = shard_request_manager.Request(state, std::move(request)); + MG_ASSERT(result.size() == 2); +} void DoTest() { SimulatorConfig config{ @@ -343,6 +357,7 @@ void DoTest() { TestScanVertices(io); TestCreateVertices(io); TestCreateExpand(io); + TestGetProperties(io); simulator.ShutDown();