diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp index 1336addae..5f5d103cc 100644 --- a/src/query/v2/request_router.hpp +++ b/src/query/v2/request_router.hpp @@ -11,6 +11,7 @@ #pragma once +#include <algorithm> #include <chrono> #include <deque> #include <iostream> @@ -99,6 +100,7 @@ class RequestRouterInterface { virtual std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) = 0; virtual std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) = 0; virtual std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) = 0; + virtual std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest request) = 0; virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0; virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0; @@ -355,6 +357,28 @@ class RequestRouter : public RequestRouterInterface { return result_rows; } + std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest requests) override { + ExecutionState<msgs::GetPropertiesRequest> state = {}; + InitializeExecutionState(state, std::move(requests)); + for (auto &request : state.requests) { + auto &storage_client = GetStorageClientForShard(request.shard); + msgs::ReadRequests req = request.request; + request.async_request_token = storage_client.SendAsyncReadRequest(req); + } + + std::vector<msgs::GetPropertiesResponse> responses; + do { + DriveReadResponses(state, responses); + } while (!state.requests.empty()); + + std::vector<msgs::GetPropertiesResultRow> result; + for (auto &res : responses) { + std::move(res.result_row.begin(), res.result_row.end(), std::back_inserter(result)); + } + + return result; + } + std::optional<storage::v3::PropertyId> MaybeNameToProperty(const std::string &name) const override { return shards_map_.GetPropertyId(name); } @@ -498,6 +522,44 @@ class RequestRouter : public RequestRouterInterface { return requests; } + void InitializeExecutionState(ExecutionState<msgs::GetPropertiesRequest> &state, msgs::GetPropertiesRequest request) { + std::map<Shard, msgs::GetPropertiesRequest> per_shard_request_table; + auto top_level_rqst_template = request; + top_level_rqst_template.transaction_id = transaction_id_; + top_level_rqst_template.vertex_ids.clear(); + top_level_rqst_template.vertices_and_edges.clear(); + + state.transaction_id = transaction_id_; + + for (auto &vertex : request.vertex_ids) { + auto shard = + shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second)); + if (!per_shard_request_table.contains(shard)) { + per_shard_request_table.insert(std::pair(shard, top_level_rqst_template)); + } + per_shard_request_table[shard].vertex_ids.emplace_back(std::move(vertex)); + } + + for (auto &[vertex, maybe_edge] : request.vertices_and_edges) { + auto shard = + shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second)); + if (!per_shard_request_table.contains(shard)) { + per_shard_request_table.insert(std::pair(shard, top_level_rqst_template)); + } + per_shard_request_table[shard].vertices_and_edges.emplace_back(std::move(vertex), maybe_edge); + } + + for (auto &[shard, rqst] : per_shard_request_table) { + ShardRequestState<msgs::GetPropertiesRequest> shard_request_state{ + .shard = shard, + .request = std::move(rqst), + .async_request_token = std::nullopt, + }; + + state.requests.emplace_back(std::move(shard_request_state)); + } + } + StorageClient &GetStorageClientForShard(Shard shard) { if (!storage_cli_manager_.Exists(shard)) { AddStorageClientToManager(shard); diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index 9e1a4c71e..cd5fc0a4a 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -32,3 +32,4 @@ add_simulation_test(trial_query_storage/query_storage_test.cpp) add_simulation_test(sharded_map.cpp) add_simulation_test(shard_rsm.cpp) add_simulation_test(cluster_property_test.cpp) +add_simulation_test(request_router.cpp) diff --git a/tests/simulation/common.hpp b/tests/simulation/common.hpp index a73bf37a5..fcdc1338c 100644 --- a/tests/simulation/common.hpp +++ b/tests/simulation/common.hpp @@ -76,14 +76,10 @@ class MockedShardRsm { using WriteRequests = msgs::WriteRequests; using WriteResponses = msgs::WriteResponses; - // ExpandOneResponse Read(ExpandOneRequest rqst); - // GetPropertiesResponse Read(GetPropertiesRequest rqst); msgs::ScanVerticesResponse ReadImpl(msgs::ScanVerticesRequest rqst) { msgs::ScanVerticesResponse ret; auto as_prop_val = storage::conversions::ConvertPropertyVector(rqst.start_id.second); - if (!IsKeyInRange(as_prop_val)) { - ret.success = false; - } else if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) { + if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) { msgs::Value val(int64_t(0)); ret.next_start_id = std::make_optional<msgs::VertexId>(); ret.next_start_id->second = @@ -91,37 +87,46 @@ class MockedShardRsm { msgs::ScanResultRow result; result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); ret.results.push_back(std::move(result)); - ret.success = true; } else if (as_prop_val == ShardRsmKey{PropertyValue(1), PropertyValue(0)}) { msgs::ScanResultRow result; msgs::Value val(int64_t(1)); result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); ret.results.push_back(std::move(result)); - ret.success = true; } else if (as_prop_val == ShardRsmKey{PropertyValue(12), PropertyValue(13)}) { msgs::ScanResultRow result; msgs::Value val(int64_t(444)); result.props.push_back(std::make_pair(msgs::PropertyId::FromUint(0), val)); ret.results.push_back(std::move(result)); - ret.success = true; - } else { - ret.success = false; } return ret; } msgs::ExpandOneResponse ReadImpl(msgs::ExpandOneRequest rqst) { return {}; } - msgs::ExpandOneResponse ReadImpl(msgs::GetPropertiesRequest rqst) { return {}; } + msgs::GetPropertiesResponse ReadImpl(msgs::GetPropertiesRequest rqst) { + msgs::GetPropertiesResponse resp; + auto &vertices = rqst.vertex_ids; + for (auto &vertex : vertices) { + auto as_prop_val = storage::conversions::ConvertPropertyVector(vertex.second); + if (as_prop_val == ShardRsmKey{PropertyValue(0), PropertyValue(0)}) { + resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)}); + } else if (as_prop_val == ShardRsmKey{PropertyValue(1), PropertyValue(0)}) { + resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)}); + } else if (as_prop_val == ShardRsmKey{PropertyValue(13), PropertyValue(13)}) { + resp.result_row.push_back(msgs::GetPropertiesResultRow{.vertex = std::move(vertex)}); + } + } + return resp; + } ReadResponses Read(ReadRequests read_requests) { return {std::visit([this]<typename T>(T &&request) { return ReadResponses{ReadImpl(std::forward<T>(request))}; }, std::move(read_requests))}; } - msgs::CreateVerticesResponse ApplyImpl(msgs::CreateVerticesRequest rqst) { return {.success = true}; } + msgs::CreateVerticesResponse ApplyImpl(msgs::CreateVerticesRequest rqst) { return {}; } msgs::DeleteVerticesResponse ApplyImpl(msgs::DeleteVerticesRequest rqst) { return {}; } msgs::UpdateVerticesResponse ApplyImpl(msgs::UpdateVerticesRequest rqst) { return {}; } - msgs::CreateExpandResponse ApplyImpl(msgs::CreateExpandRequest rqst) { return {.success = true}; } + msgs::CreateExpandResponse ApplyImpl(msgs::CreateExpandRequest rqst) { return {}; } msgs::DeleteEdgesResponse ApplyImpl(msgs::DeleteEdgesRequest rqst) { return {}; } msgs::UpdateEdgesResponse ApplyImpl(msgs::UpdateEdgesRequest rqst) { return {}; } msgs::CommitResponse ApplyImpl(msgs::CommitRequest rqst) { return {}; } diff --git a/tests/simulation/request_router.cpp b/tests/simulation/request_router.cpp index af8ec62ef..bc5168483 100644 --- a/tests/simulation/request_router.cpp +++ b/tests/simulation/request_router.cpp @@ -152,9 +152,7 @@ void RunStorageRaft(Raft<IoImpl, MockedShardRsm, WriteRequests, WriteResponses, } void TestScanVertices(query::v2::RequestRouterInterface &request_router) { - msgs::ExecutionState<ScanVerticesRequest> state{.label = "test_label"}; - - auto result = request_router.Request(state); + auto result = request_router.ScanVertices("test_label"); MG_ASSERT(result.size() == 2); { auto prop = result[0].GetProperty(msgs::PropertyId::FromUint(0)); @@ -162,18 +160,10 @@ void TestScanVertices(query::v2::RequestRouterInterface &request_router) { prop = result[1].GetProperty(msgs::PropertyId::FromUint(0)); MG_ASSERT(prop.int_v == 444); } - - result = request_router.Request(state); - { - MG_ASSERT(result.size() == 1); - auto prop = result[0].GetProperty(msgs::PropertyId::FromUint(0)); - MG_ASSERT(prop.int_v == 1); - } } void TestCreateVertices(query::v2::RequestRouterInterface &request_router) { using PropVal = msgs::Value; - msgs::ExecutionState<CreateVerticesRequest> state; std::vector<msgs::NewVertex> new_vertices; auto label_id = request_router.NameToLabel("test_label"); msgs::NewVertex a1{.primary_key = {PropVal(int64_t(1)), PropVal(int64_t(0))}}; @@ -183,13 +173,13 @@ void TestCreateVertices(query::v2::RequestRouterInterface &request_router) { new_vertices.push_back(std::move(a1)); new_vertices.push_back(std::move(a2)); - auto result = request_router.Request(state, std::move(new_vertices)); + auto result = request_router.CreateVertices(std::move(new_vertices)); MG_ASSERT(result.size() == 2); } void TestCreateExpand(query::v2::RequestRouterInterface &request_router) { using PropVal = msgs::Value; - msgs::ExecutionState<msgs::CreateExpandRequest> state; + msgs::CreateExpandRequest state; std::vector<msgs::NewExpand> new_expands; const auto edge_type_id = request_router.NameToEdgeType("edge_type"); @@ -203,24 +193,42 @@ void TestCreateExpand(query::v2::RequestRouterInterface &request_router) { new_expands.push_back(std::move(expand_1)); new_expands.push_back(std::move(expand_2)); - auto responses = request_router.Request(state, std::move(new_expands)); + auto responses = request_router.CreateExpand(std::move(new_expands)); MG_ASSERT(responses.size() == 2); - MG_ASSERT(responses[0].success); - MG_ASSERT(responses[1].success); + MG_ASSERT(!responses[0].error); + MG_ASSERT(!responses[1].error); } void TestExpandOne(query::v2::RequestRouterInterface &request_router) { - msgs::ExecutionState<msgs::ExpandOneRequest> state{}; + msgs::ExpandOneRequest state{}; msgs::ExpandOneRequest request; const auto edge_type_id = request_router.NameToEdgeType("edge_type"); const auto label = msgs::Label{request_router.NameToLabel("test_label")}; request.src_vertices.push_back(msgs::VertexId{label, {msgs::Value(int64_t(0)), msgs::Value(int64_t(0))}}); request.edge_types.push_back(msgs::EdgeType{edge_type_id}); request.direction = msgs::EdgeDirection::BOTH; - auto result_rows = request_router.Request(state, std::move(request)); + auto result_rows = request_router.ExpandOne(std::move(request)); MG_ASSERT(result_rows.size() == 2); } +void TestGetProperties(query::v2::RequestRouterInterface &request_router) { + using PropVal = msgs::Value; + + auto label_id = request_router.NameToLabel("test_label"); + msgs::VertexId v0{{label_id}, {PropVal(int64_t(0)), PropVal(int64_t(0))}}; + msgs::VertexId v1{{label_id}, {PropVal(int64_t(1)), PropVal(int64_t(0))}}; + msgs::VertexId v2{{label_id}, {PropVal(int64_t(13)), PropVal(int64_t(13))}}; + + msgs::GetPropertiesRequest request; + + request.vertex_ids.push_back({v0}); + request.vertex_ids.push_back({v1}); + request.vertex_ids.push_back({v2}); + + auto result = request_router.GetProperties(std::move(request)); + MG_ASSERT(result.size() == 3); +} + template <typename RequestRouter> void TestAggregate(RequestRouter &request_router) {} @@ -345,6 +353,7 @@ void DoTest() { TestScanVertices(request_router); TestCreateVertices(request_router); TestCreateExpand(request_router); + TestGetProperties(request_router); simulator.ShutDown(); diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp index 50f578bb2..5e91d0d5a 100644 --- a/tests/unit/query_v2_expression_evaluator.cpp +++ b/tests/unit/query_v2_expression_evaluator.cpp @@ -51,6 +51,8 @@ using memgraph::msgs::CreateVerticesResponse; using memgraph::msgs::ExpandOneRequest; using memgraph::msgs::ExpandOneResponse; using memgraph::msgs::ExpandOneResultRow; +using memgraph::msgs::GetPropertiesRequest; +using memgraph::msgs::GetPropertiesResultRow; using memgraph::msgs::NewExpand; using memgraph::msgs::NewVertex; using memgraph::msgs::ScanVerticesRequest; @@ -93,6 +95,8 @@ class MockedRequestRouter : public RequestRouterInterface { std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> /* new_edges */) override { return {}; } + std::vector<GetPropertiesResultRow> GetProperties(GetPropertiesRequest rqst) override { return {}; } + const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override { return properties_.IdToName(id.AsUint()); }