diff --git a/src/io/rsm/rsm_client.hpp b/src/io/rsm/rsm_client.hpp
index b60380b08..920866c7a 100644
--- a/src/io/rsm/rsm_client.hpp
+++ b/src/io/rsm/rsm_client.hpp
@@ -14,6 +14,7 @@
 #include <iostream>
 #include <optional>
 #include <type_traits>
+#include <unordered_map>
 #include <vector>

 #include "io/address.hpp"
@@ -36,6 +37,21 @@ using memgraph::io::rsm::WriteRequest;
 using memgraph::io::rsm::WriteResponse;
 using memgraph::utils::BasicResult;

+class AsyncRequestToken {
+  size_t id_;
+
+ public:
+  explicit AsyncRequestToken(size_t id) : id_(id) {}
+  size_t GetId() const { return id_; }
+};
+
+template <typename RequestT, typename ResponseT>
+struct AsyncRequest {
+  Time start_time;
+  RequestT request;
+  ResponseFuture<ResponseT> future;
+};
+
 template <typename IoImpl, typename WriteRequestT, typename WriteResponseT, typename ReadRequestT,
           typename ReadResponseT>
 class RsmClient {
@@ -47,23 +63,17 @@ class RsmClient {

-  /// State for single async read/write operations. In the future this could become a map
-  /// of async operations that can be accessed via an ID etc...
-  std::optional<Time> async_read_before_;
-  std::optional<ResponseFuture<ReadResponse<ReadResponseT>>> async_read_;
-  ReadRequestT current_read_request_;
+  /// State for in-flight async read/write operations, keyed by the ID carried in an
+  /// AsyncRequestToken, so several operations can be tracked concurrently.
+  std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_;
+  std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_;

-  std::optional<Time> async_write_before_;
-  std::optional<ResponseFuture<WriteResponse<WriteResponseT>>> async_write_;
-  WriteRequestT current_write_request_;
+  size_t async_token_generator_ = 0;

   void SelectRandomLeader() {
     std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
     size_t addr_index = io_.Rand(addr_distrib);
     leader_ = server_addrs_[addr_index];

-    spdlog::debug(
-        "client NOT redirected to leader server despite our success failing to be processed (it probably was sent to "
-        "a RSM Candidate) trying a random one at index {} with address {}",
-        addr_index, leader_.ToString());
+    spdlog::debug("selecting a random leader at index {} with address {}", addr_index, leader_.ToString());
   }

   template <typename ResponseT>
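Reviewer note on the hunk above: the single optional slot per direction (the reason the old SendAsync* methods asserted `MG_ASSERT(!async_read_)`) becomes a map keyed by a generated token ID, so a caller can hold several requests in flight at once. A minimal, self-contained sketch of the same bookkeeping pattern follows; it uses std::future as a stand-in for ResponseFuture, and ToyClient/Token are illustrative names, not Memgraph APIs:

```cpp
#include <chrono>
#include <cstddef>
#include <future>
#include <optional>
#include <unordered_map>

// Hypothetical stand-in for AsyncRequestToken: an opaque handle to one operation.
struct Token {
  size_t id;
};

class ToyClient {
  std::unordered_map<size_t, std::future<int>> in_flight_;
  size_t next_id_ = 0;

 public:
  // Start an operation; callers may hold any number of outstanding tokens.
  Token Send(int request) {
    size_t id = next_id_++;
    in_flight_.emplace(id, std::async(std::launch::async, [request] { return request * 2; }));
    return Token{id};
  }

  // Non-blocking poll: nullopt while the operation is still running.
  std::optional<int> Poll(const Token &token) {
    auto &future = in_flight_.at(token.id);
    if (future.wait_for(std::chrono::seconds(0)) != std::future_status::ready) {
      return std::nullopt;
    }
    int result = future.get();
    in_flight_.erase(token.id);  // completed entries are erased, like async_reads_.erase(...)
    return result;
  }
};

int main() {
  ToyClient client;
  Token a = client.Send(10);
  Token b = client.Send(20);  // two requests in flight at once

  std::optional<int> ra;
  std::optional<int> rb;
  while (!ra || !rb) {  // drain in any order
    if (!ra) ra = client.Poll(a);
    if (!rb) rb = client.Poll(b);
  }
  return (*ra == 20 && *rb == 40) ? 0 : 1;
}
```

Because the state is keyed by token rather than stored in one slot, one client can serve a request per shard, which is exactly what RequestRouter exploits later in this diff.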
@@ -91,107 +101,74 @@ class RsmClient {
   ~RsmClient() = default;

   BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) {
-    WriteRequest<WriteRequestT> client_req;
-    client_req.operation = req;
-
-    const Duration overall_timeout = io_.GetDefaultTimeout();
-    const Time before = io_.Now();
-
-    do {
-      spdlog::debug("client sending WriteRequest to Leader {}", leader_.ToString());
-      ResponseFuture<WriteResponse<WriteResponseT>> response_future =
-          io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, client_req);
-      ResponseResult<WriteResponse<WriteResponseT>> response_result = std::move(response_future).Wait();
-
-      if (response_result.HasError()) {
-        spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-        return response_result.GetError();
-      }
-
-      ResponseEnvelope<WriteResponse<WriteResponseT>> &&response_envelope = std::move(response_result.GetValue());
-      WriteResponse<WriteResponseT> &&write_response = std::move(response_envelope.message);
-
-      if (write_response.success) {
-        return std::move(write_response.write_return);
-      }
-
-      PossiblyRedirectLeader(write_response);
-    } while (io_.Now() < before + overall_timeout);
-
-    return TimedOut{};
+    auto token = SendAsyncWriteRequest(req);
+    auto poll_result = AwaitAsyncWriteRequest(token);
+    while (!poll_result) {
+      poll_result = AwaitAsyncWriteRequest(token);
+    }
+    return poll_result.value();
   }

   BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) {
-    ReadRequest<ReadRequestT> read_req;
-    read_req.operation = req;
-
-    const Duration overall_timeout = io_.GetDefaultTimeout();
-    const Time before = io_.Now();
-
-    do {
-      spdlog::debug("client sending ReadRequest to Leader {}", leader_.ToString());
-
-      ResponseFuture<ReadResponse<ReadResponseT>> get_response_future =
-          io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
-
-      // receive response
-      ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(get_response_future).Wait();
-
-      if (get_response_result.HasError()) {
-        spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-        return get_response_result.GetError();
-      }
-
-      ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
-      ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
-
-      if (read_get_response.success) {
-        return std::move(read_get_response.read_return);
-      }
-
-      PossiblyRedirectLeader(read_get_response);
-    } while (io_.Now() < before + overall_timeout);
-
-    return TimedOut{};
+    auto token = SendAsyncReadRequest(req);
+    auto poll_result = AwaitAsyncReadRequest(token);
+    while (!poll_result) {
+      poll_result = AwaitAsyncReadRequest(token);
+    }
+    return poll_result.value();
   }

   /// AsyncRead methods
-  void SendAsyncReadRequest(const ReadRequestT &req) {
-    MG_ASSERT(!async_read_);
+  AsyncRequestToken SendAsyncReadRequest(const ReadRequestT &req) {
+    size_t token = async_token_generator_++;

     ReadRequest<ReadRequestT> read_req = {.operation = req};

-    if (!async_read_before_) {
-      async_read_before_ = io_.Now();
-    }
-    current_read_request_ = std::move(req);
-    async_read_ = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+    AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req),
+    };
+
+    async_reads_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken{token};
   }

-  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest() {
-    MG_ASSERT(async_read_);
+  void ResendAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());

-    if (!async_read_->IsReady()) {
+    ReadRequest<ReadRequestT> read_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+  }
+
+  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }

-    return AwaitAsyncReadRequest();
+    return AwaitAsyncReadRequest(token);
   }

-  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest() {
-    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(*async_read_).Wait();
-    async_read_.reset();
+  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+
+    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait();

     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_read_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();

     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_read_before_ = std::nullopt;
+      async_reads_.erase(token.GetId());
       return TimedOut{};
     }
+
     if (!result_has_error) {
       ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
       ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
@@ -199,54 +176,70 @@ class RsmClient {
       PossiblyRedirectLeader(read_get_response);

       if (read_get_response.success) {
-        async_read_before_ = std::nullopt;
+        async_reads_.erase(token.GetId());
+        spdlog::debug("returning read_return for RSM request");
         return std::move(read_get_response.read_return);
       }
-      SendAsyncReadRequest(current_read_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncReadRequest(current_read_request_);
     }
+
+    ResendAsyncReadRequest(token);
+
     return std::nullopt;
   }

   /// AsyncWrite methods
-  void SendAsyncWriteRequest(const WriteRequestT &req) {
-    MG_ASSERT(!async_write_);
+  AsyncRequestToken SendAsyncWriteRequest(const WriteRequestT &req) {
+    size_t token = async_token_generator_++;

     WriteRequest<WriteRequestT> write_req = {.operation = req};

-    if (!async_write_before_) {
-      async_write_before_ = io_.Now();
-    }
-    current_write_request_ = std::move(req);
-    async_write_ = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+    AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req),
+    };
+
+    async_writes_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken{token};
   }

-  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest() {
-    MG_ASSERT(async_write_);
+  void ResendAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());

-    if (!async_write_->IsReady()) {
+    WriteRequest<WriteRequestT> write_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+  }
+
+  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }

-    return AwaitAsyncWriteRequest();
+    return AwaitAsyncWriteRequest(token);
   }

-  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest() {
-    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(*async_write_).Wait();
-    async_write_.reset();
+  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait();

     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_write_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();

     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_write_before_ = std::nullopt;
+      async_writes_.erase(token.GetId());
       return TimedOut{};
     }
+
     if (!result_has_error) {
       ResponseEnvelope<WriteResponse<WriteResponseT>> &&get_response_envelope =
           std::move(get_response_result.GetValue());
@@ -255,14 +248,15 @@ class RsmClient {
       PossiblyRedirectLeader(write_get_response);

       if (write_get_response.success) {
-        async_write_before_ = std::nullopt;
+        async_writes_.erase(token.GetId());
         return std::move(write_get_response.write_return);
       }
-      SendAsyncWriteRequest(current_write_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncWriteRequest(current_write_request_);
     }
+
+    ResendAsyncWriteRequest(token);
+
     return std::nullopt;
   }
 };
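Note the flipped comparison in both Await methods: the old code computed past_time_out with `<`, which inverted the test, so a request looked "timed out" while still inside its window and never after the deadline had actually passed. A small illustration of the corrected predicate with plain std::chrono (not the io_ abstraction):

```cpp
#include <cassert>
#include <chrono>

int main() {
  using namespace std::chrono;
  const auto start_time = steady_clock::now();
  const auto overall_timeout = milliseconds(100);

  // Simulate a clock reading well past the deadline.
  const auto now = start_time + milliseconds(250);

  // A request is past its deadline only once "now" has moved beyond start + timeout.
  const bool past_time_out = now > start_time + overall_timeout;
  assert(past_time_out);

  // The old `<` form evaluates to false here, i.e. the timeout would never fire late.
  assert((now < start_time + overall_timeout) == false);
  return 0;
}
```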
diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp
index 0f1415694..db2a1372d 100644
--- a/src/query/v2/plan/operator.cpp
+++ b/src/query/v2/plan/operator.cpp
@@ -180,7 +180,7 @@ class DistributedCreateNodeCursor : public Cursor {
     auto &request_router = context.request_router;
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      request_router->Request(state_, NodeCreationInfoToRequest(context, frame));
+      request_router->CreateVertices(NodeCreationInfoToRequest(context, frame));
     }
     PlaceNodeOnTheFrame(frame, context);
     return true;
@@ -191,7 +191,7 @@

   void Shutdown() override { input_cursor_->Shutdown(); }

-  void Reset() override { state_ = {}; }
+  void Reset() override {}

   void PlaceNodeOnTheFrame(Frame &frame, ExecutionContext &context) {
     // TODO(kostasrim) Make this work with batching
@@ -252,7 +252,6 @@
   std::vector<const NodeCreationInfo *> nodes_info_;
   std::vector<std::vector<std::pair<storage::v3::PropertyId, msgs::Value>>> src_vertex_props_;
   std::vector<msgs::PrimaryKey> primary_keys_;
-  ExecutionState<msgs::CreateVerticesRequest> state_;
 };

 bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
@@ -382,7 +381,6 @@ class ScanAllCursor : public Cursor {
   std::optional<decltype(vertices_.value().begin())> vertices_it_;
   const char *op_name_;
   std::vector<msgs::ScanVerticesResponse> current_batch;
-  ExecutionState<msgs::ScanVerticesRequest> request_state;
 };

 class DistributedScanAllAndFilterCursor : public Cursor {
@@ -401,14 +399,21 @@
     ResetExecutionState();
   }

+  enum class State : int8_t { INITIALIZING, COMPLETED };
+
   using VertexAccessor = accessors::VertexAccessor;

   bool MakeRequest(RequestRouterInterface &request_router, ExecutionContext &context) {
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      current_batch = request_router.Request(request_state_);
+      std::optional<std::string> request_label = std::nullopt;
+      if (label_.has_value()) {
+        request_label = request_router.LabelToName(*label_);
+      }
+      current_batch = request_router.ScanVertices(request_label);
     }
     current_vertex_it = current_batch.begin();
+    request_state_ = State::COMPLETED;
     return !current_batch.empty();
   }
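The cursor drops the three-state ExecutionState enum (INITIALIZING / EXECUTING / COMPLETED) for a local two-state flag: EXECUTING disappears because ScanVertices is now fully synchronous from the cursor's point of view. A trimmed sketch of the resulting state machine; BatchCursor and Fetch are hypothetical stand-ins for the cursor and ScanVertices, not the real operator API:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Two states suffice once the fetch itself is synchronous:
// either we have not fetched yet, or the batch is complete.
enum class State : int8_t { INITIALIZING, COMPLETED };

class BatchCursor {
  State state_ = State::INITIALIZING;
  std::vector<std::string> batch_;

  std::vector<std::string> Fetch() { return {"a", "b"}; }  // stand-in for ScanVertices

 public:
  bool Pull() {
    if (state_ == State::INITIALIZING) {
      batch_ = Fetch();
      state_ = State::COMPLETED;  // one synchronous call fills the whole batch
    }
    if (batch_.empty()) return false;
    batch_.pop_back();
    return true;
  }

  void Reset() {
    state_ = State::INITIALIZING;
    batch_.clear();
  }
};

int main() {
  BatchCursor cursor;
  int pulls = 0;
  while (cursor.Pull()) ++pulls;
  return pulls == 2 ? 0 : 1;
}
```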
@@ -420,19 +425,15 @@
       if (MustAbort(context)) {
         throw HintedAbortError();
       }
-      using State = ExecutionState<msgs::ScanVerticesRequest>;

-      if (request_state_.state == State::INITIALIZING) {
+      if (request_state_ == State::INITIALIZING) {
         if (!input_cursor_->Pull(frame, context)) {
           return false;
         }
       }

-      request_state_.label =
-          label_.has_value() ? std::make_optional(request_router.LabelToName(*label_)) : std::nullopt;
-
       if (current_vertex_it == current_batch.end() &&
-          (request_state_.state == State::COMPLETED || !MakeRequest(request_router, context))) {
+          (request_state_ == State::COMPLETED || !MakeRequest(request_router, context))) {
         ResetExecutionState();
         continue;
       }
@@ -448,7 +449,7 @@
   void ResetExecutionState() {
     current_batch.clear();
     current_vertex_it = current_batch.end();
-    request_state_ = ExecutionState<msgs::ScanVerticesRequest>{};
+    request_state_ = State::INITIALIZING;
   }

   void Reset() override {
@@ -462,7 +463,7 @@
   const char *op_name_;
   std::vector<VertexAccessor> current_batch;
   std::vector<VertexAccessor>::iterator current_vertex_it;
-  ExecutionState<msgs::ScanVerticesRequest> request_state_;
+  State request_state_ = State::INITIALIZING;
   std::optional<storage::v3::LabelId> label_;
   std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_;
   std::optional<std::vector<Expression *>> filter_expressions_;
@@ -2377,7 +2378,7 @@ class DistributedCreateExpandCursor : public Cursor {
     ResetExecutionState();
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      request_router->Request(state_, ExpandCreationInfoToRequest(context, frame));
+      request_router->CreateExpand(ExpandCreationInfoToRequest(context, frame));
     }
     return true;
   }
@@ -2457,11 +2458,10 @@
   }

  private:
-  void ResetExecutionState() { state_ = {}; }
+  void ResetExecutionState() {}

   const UniqueCursorPtr input_cursor_;
   const CreateExpand &self_;
-  ExecutionState<msgs::CreateExpandRequest> state_;
 };

 class DistributedExpandCursor : public Cursor {
@@ -2508,8 +2508,7 @@
     request.edge_properties.emplace();
     request.src_vertices.push_back(get_dst_vertex(edge, direction));
     request.direction = (direction == EdgeAtom::Direction::IN) ? msgs::EdgeDirection::OUT : msgs::EdgeDirection::IN;
-    ExecutionState<msgs::ExpandOneRequest> request_state;
-    auto result_rows = context.request_router->Request(request_state, std::move(request));
+    auto result_rows = context.request_router->ExpandOne(std::move(request));
     MG_ASSERT(result_rows.size() == 1);
     auto &result_row = result_rows.front();
     frame[self_.common_.node_symbol] = accessors::VertexAccessor(
@@ -2534,10 +2533,9 @@
     // to not fetch any properties of the edges
     request.edge_properties.emplace();
     request.src_vertices.push_back(vertex.Id());
-    ExecutionState<msgs::ExpandOneRequest> request_state;
-    auto result_rows = std::invoke([&context, &request_state, &request]() mutable {
+    auto result_rows = std::invoke([&context, &request]() mutable {
       SCOPED_REQUEST_WAIT_PROFILE;
-      return context.request_router->Request(request_state, std::move(request));
+      return context.request_router->ExpandOne(std::move(request));
     });
     MG_ASSERT(result_rows.size() == 1);
     auto &result_row = result_rows.front();
diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp
index fbf8c514c..996272fdc 100644
--- a/src/query/v2/request_router.hpp
+++ b/src/query/v2/request_router.hpp
@@ -71,34 +71,28 @@ class RsmStorageClientManager {
   std::map<Shard, TStorageClient> cli_cache_;
 };

+template <typename TRequest>
+struct ShardRequestState {
+  memgraph::coordinator::Shard shard;
+  TRequest request;
+  std::optional<io::rsm::AsyncRequestToken> async_request_token;
+};
+
 template <typename TRequest>
 struct ExecutionState {
   using CompoundKey = io::rsm::ShardRsmKey;
   using Shard = coordinator::Shard;

-  enum State : int8_t { INITIALIZING, EXECUTING, COMPLETED };
   // label is optional because some operators can create/remove etc., vertices. These kinds of requests contain the
   // label on the request itself.
   std::optional<std::string> label;
-  // CompoundKey is optional because some operators require to iterate over all the available keys
-  // of a shard. One example is ScanAll, where we only require the field label.
-  std::optional<CompoundKey> key;
   // Transaction id to be filled by the RequestRouter implementation
   coordinator::Hlc transaction_id;
   // Initialized by RequestRouter implementation. This vector is filled with the shards that
   // the RequestRouter impl will send requests to. When a request to a shard exhausts it, meaning that
   // it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
   // empty, it means that all of the requests have completed successfully.
-  // TODO(gvolfing)
-  // Maybe make this into a more complex object to be able to keep track of paginated resutls. E.g. instead of a vector
-  // of Shards make it into a std::vector<std::pair<Shard, PaginatedResultType>> (probably a struct instead of a pair)
-  // where PaginatedResultType is an enum signaling the progress on the given request. This way we can easily check if
-  // a partial response on a shard(if there is one) is finished and we can send off the request for the next batch.
-  std::vector<Shard> shard_cache;
-  // 1-1 mapping with `shard_cache`.
-  // A vector that tracks request metadata for each shard (For example, next_id for a ScanAll on Shard A)
-  std::vector<TRequest> requests;
-  State state = INITIALIZING;
+  std::vector<ShardRequestState<TRequest>> requests;
 };
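ShardRequestState is the pivotal data-structure change in this file: shard, request payload, and in-flight token now travel together, where the old ExecutionState kept shard_cache and requests as two index-aligned vectors (see the removed "1-1 mapping" comments above). A reduced sketch of why the bundled form is safer; Shard, Request, and Token here are toy types, not the real coordinator structs:

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Toy stand-ins for coordinator::Shard, a request payload, and AsyncRequestToken.
struct Shard { int id; };
struct Request { std::string payload; };
struct Token { size_t id; };

// One entry per shard: nothing to keep index-aligned across containers.
struct ShardRequest {
  Shard shard;
  Request request;
  std::optional<Token> token;  // nullopt until the async send happens
};

int main() {
  std::vector<ShardRequest> requests = {
      {.shard = {1}, .request = {"scan part 1"}, .token = std::nullopt},
      {.shard = {2}, .request = {"scan part 2"}, .token = std::nullopt},
  };
  // Erasing one entry can no longer desynchronize a parallel vector, which is the
  // bookkeeping the removed "Needed to maintain the 1-1 mapping" comments did by hand.
  requests.erase(requests.begin());
  return requests.size() == 1 ? 0 : 1;
}
```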

 class RequestRouterInterface {
@@ -114,13 +108,10 @@
   virtual void StartTransaction() = 0;
   virtual void Commit() = 0;

-  virtual std::vector<VertexAccessor> Request(ExecutionState<msgs::ScanVerticesRequest> &state) = 0;
-  virtual std::vector<msgs::CreateVerticesResponse> Request(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                                            std::vector<msgs::NewVertex> new_vertices) = 0;
-  virtual std::vector<msgs::ExpandOneResultRow> Request(ExecutionState<msgs::ExpandOneRequest> &state,
-                                                        msgs::ExpandOneRequest request) = 0;
-  virtual std::vector<msgs::CreateExpandResponse> Request(ExecutionState<msgs::CreateExpandRequest> &state,
-                                                          std::vector<msgs::NewExpand> new_edges) = 0;
+  virtual std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) = 0;
+  virtual std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) = 0;
+  virtual std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) = 0;
+  virtual std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) = 0;

   virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0;
   virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
@@ -246,99 +237,121 @@ class RequestRouter : public RequestRouterInterface {
   bool IsPrimaryLabel(storage::v3::LabelId label) const override { return shards_map_.label_spaces.contains(label); }

   // TODO(kostasrim) Simplify return result
-  std::vector<VertexAccessor> Request(ExecutionState<msgs::ScanVerticesRequest> &state) override {
-    MaybeInitializeExecutionState(state);
+  std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override {
+    ExecutionState<msgs::ScanVerticesRequest> state = {};
+    state.label = label;
+
+    // create requests
+    InitializeExecutionState(state);
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::ReadRequests req = request.request;
+
+      request.async_request_token = storage_client.SendAsyncReadRequest(req);
+    }
+
+    // drive requests to completion
     std::vector<msgs::ScanVerticesResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveReadResponses(state, responses);
+    } while (!state.requests.empty());

-    SendAllRequests(state);
-    auto all_requests_gathered = [](auto &paginated_rsp_tracker) {
-      return std::ranges::all_of(paginated_rsp_tracker, [](const auto &state) {
-        return state.second == PaginatedResponseState::PartiallyFinished;
-      });
-    };
-
-    std::map<Shard, PaginatedResponseState> paginated_response_tracker;
-    for (const auto &shard : state.shard_cache) {
-      paginated_response_tracker.insert(std::make_pair(shard, PaginatedResponseState::Pending));
+    // convert responses into VertexAccessor objects to return
+    std::vector<VertexAccessor> accessors;
+    accessors.reserve(responses.size());
+    for (auto &response : responses) {
+      for (auto &result_row : response.results) {
+        accessors.emplace_back(VertexAccessor(std::move(result_row.vertex), std::move(result_row.props), this));
+      }
     }

-    do {
-      AwaitOnPaginatedRequests(state, responses, paginated_response_tracker);
-    } while (!all_requests_gathered(paginated_response_tracker));
-
-    MaybeCompleteState(state);
-    // TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
-    // result of storage_client.SendReadRequest()).
-    return PostProcess(std::move(responses));
+    return accessors;
   }
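ScanVertices above establishes the shape every operation in this file now follows: build per-shard requests, start them all asynchronously, then drive them to completion. A compact, runnable sketch of that scatter-gather skeleton, with std::future standing in for the RSM client machinery and ScatterGather as an illustrative name only:

```cpp
#include <future>
#include <string>
#include <vector>

std::vector<std::string> ScatterGather(const std::vector<std::string> &per_shard_requests) {
  // 1. begin all requests in parallel
  std::vector<std::future<std::string>> in_flight;
  in_flight.reserve(per_shard_requests.size());
  for (const auto &req : per_shard_requests) {
    in_flight.push_back(std::async(std::launch::async, [req] { return req + ": done"; }));
  }

  // 2. drive them to completion; completion order does not matter,
  //    we simply block on each one, much as DriveReadResponses does via Await.
  std::vector<std::string> responses;
  responses.reserve(in_flight.size());
  for (auto &fut : in_flight) {
    responses.push_back(fut.get());
  }
  return responses;
}

int main() {
  auto responses = ScatterGather({"shard 1", "shard 2", "shard 3"});
  return responses.size() == 3 ? 0 : 1;
}
```

The win over the removed code is latency: every shard works concurrently, and the total wait is roughly the slowest shard rather than the sum of all shards.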

-  std::vector<msgs::CreateVerticesResponse> Request(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                                    std::vector<msgs::NewVertex> new_vertices) override {
+  std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) override {
+    ExecutionState<msgs::CreateVerticesRequest> state = {};
     MG_ASSERT(!new_vertices.empty());
-    MaybeInitializeExecutionState(state, new_vertices);
-    std::vector<msgs::CreateVerticesResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;

-    // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
+    // create requests
+    InitializeExecutionState(state, new_vertices);

-    // 2. Block untill all the futures are exhausted
-    do {
-      AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto req_deep_copy = request.request;

-    MaybeCompleteState(state);
-    // TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
-    // result of storage_client.SendReadRequest()).
-    return responses;
-  }
-
-  std::vector<msgs::CreateExpandResponse> Request(ExecutionState<msgs::CreateExpandRequest> &state,
-                                                  std::vector<msgs::NewExpand> new_edges) override {
-    MG_ASSERT(!new_edges.empty());
-    MaybeInitializeExecutionState(state, new_edges);
-    std::vector<msgs::CreateExpandResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-    size_t id{0};
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-      msgs::WriteRequests req = state.requests[id];
-      auto write_response_result = storage_client.SendWriteRequest(std::move(req));
-      if (write_response_result.HasError()) {
-        throw std::runtime_error("CreateVertices request timedout");
+      // the first label id is the primary label, which the target shard already implies
+      for (auto &new_vertex : req_deep_copy.new_vertices) {
+        new_vertex.label_ids.erase(new_vertex.label_ids.begin());
       }
-      msgs::WriteResponses response_variant = write_response_result.GetValue();
-      msgs::CreateExpandResponse mapped_response = std::get<msgs::CreateExpandResponse>(response_variant);
-      if (mapped_response.error) {
-        throw std::runtime_error("CreateExpand request did not succeed");
-      }
-      responses.push_back(mapped_response);
-      shard_it = shard_cache_ref.erase(shard_it);
+
+      auto &storage_client = GetStorageClientForShard(request.shard);
+
+      msgs::WriteRequests req = req_deep_copy;
+      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
     }
-    // We are done with this state
-    MaybeCompleteState(state);
+
+    // drive requests to completion
+    std::vector<msgs::CreateVerticesResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveWriteResponses(state, responses);
+    } while (!state.requests.empty());
+
     return responses;
   }

-  std::vector<msgs::ExpandOneResultRow> Request(ExecutionState<msgs::ExpandOneRequest> &state,
-                                                msgs::ExpandOneRequest request) override {
+  std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) override {
+    ExecutionState<msgs::CreateExpandRequest> state = {};
+    MG_ASSERT(!new_edges.empty());
+
+    // create requests
+    InitializeExecutionState(state, new_edges);
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::WriteRequests req = request.request;
+      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
+    }
+
+    // drive requests to completion
+    std::vector<msgs::CreateExpandResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveWriteResponses(state, responses);
+    } while (!state.requests.empty());
+
+    return responses;
+  }
+
+  std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) override {
+    ExecutionState<msgs::ExpandOneRequest> state = {};
     // TODO(kostasrim) Update to limit the batch size here
     // Expansions of the destination must be handled by the caller. For example
     // match (u:L1 { prop : 1 })-[:Friend]-(v:L1)
     // For each vertex U, the ExpandOne will result in <U, Edges>. The destination vertex and its properties
     // must be fetched again with an ExpandOne(Edges.dst)
-    MaybeInitializeExecutionState(state, std::move(request));
+
+    // create requests
+    InitializeExecutionState(state, std::move(request));
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::ReadRequests req = request.request;
+      request.async_request_token = storage_client.SendAsyncReadRequest(req);
+    }
+
+    // drive requests to completion
     std::vector<msgs::ExpandOneResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-
-    // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
-
-    // 2. Block untill all the futures are exhausted
+    responses.reserve(state.requests.size());
     do {
-      AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
+      DriveReadResponses(state, responses);
+    } while (!state.requests.empty());
+
+    // post-process responses
     std::vector<msgs::ExpandOneResultRow> result_rows;
     const auto total_row_count = std::accumulate(
         responses.begin(), responses.end(), 0, [](const int64_t partial_count, const msgs::ExpandOneResponse &resp) {
@@ -350,7 +363,7 @@ class RequestRouter : public RequestRouterInterface {
       result_rows.insert(result_rows.end(), std::make_move_iterator(response.result.begin()),
                          std::make_move_iterator(response.result.end()));
     }
-    MaybeCompleteState(state);
+
     return result_rows;
   }

@@ -367,71 +380,35 @@ class RequestRouter : public RequestRouterInterface {
   }

  private:
-  enum class PaginatedResponseState { Pending, PartiallyFinished };
-
-  std::vector<VertexAccessor> PostProcess(std::vector<msgs::ScanVerticesResponse> &&responses) const {
-    std::vector<VertexAccessor> accessors;
-    for (auto &response : responses) {
-      for (auto &result_row : response.results) {
-        accessors.emplace_back(VertexAccessor(std::move(result_row.vertex), std::move(result_row.props), this));
-      }
-    }
-    return accessors;
-  }
-
-  template <typename ExecutionState>
-  void ThrowIfStateCompleted(ExecutionState &state) const {
-    if (state.state == ExecutionState::COMPLETED) [[unlikely]] {
-      throw std::runtime_error("State is completed and must be reset");
-    }
-  }
-
-  template <typename ExecutionState>
-  void MaybeCompleteState(ExecutionState &state) const {
-    if (state.requests.empty()) {
-      state.state = ExecutionState::COMPLETED;
-    }
-  }
-
-  template <typename ExecutionState>
-  bool ShallNotInitializeState(ExecutionState &state) const {
-    return state.state != ExecutionState::INITIALIZING;
-  }
-
-  void MaybeInitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                     std::vector<msgs::NewVertex> new_vertices) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state,
+                                std::vector<msgs::NewVertex> new_vertices) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::CreateVerticesRequest> per_shard_request_table;

     for (auto &new_vertex : new_vertices) {
-      MG_ASSERT(!new_vertex.label_ids.empty(), "This is error!");
+      MG_ASSERT(!new_vertex.label_ids.empty(), "No label_ids provided for new vertex in RequestRouter::CreateVertices");
       auto shard = shards_map_.GetShardForKey(new_vertex.label_ids[0].id,
                                               storage::conversions::ConvertPropertyVector(new_vertex.primary_key));
       if (!per_shard_request_table.contains(shard)) {
         msgs::CreateVerticesRequest create_v_rqst{.transaction_id = transaction_id_};
         per_shard_request_table.insert(std::pair(shard, std::move(create_v_rqst)));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex));
     }

-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<msgs::CreateVerticesRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::CreateVerticesRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state,
-                                     std::vector<msgs::NewExpand> new_expands) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state,
+                                std::vector<msgs::NewExpand> new_expands) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::CreateExpandRequest> per_shard_request_table;
@@ -459,18 +436,16 @@ class RequestRouter : public RequestRouterInterface {
     }

     for (auto &[shard, request] : per_shard_request_table) {
-      state.shard_cache.push_back(shard);
-      state.requests.push_back(std::move(request));
+      ShardRequestState<msgs::CreateExpandRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::CreateExpandRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
-
+  void InitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) {
     std::vector<coordinator::Shards> multi_shards;
     state.transaction_id = transaction_id_;
     if (!state.label) {
@@ -484,21 +459,23 @@ class RequestRouter : public RequestRouterInterface {
     for (auto &shards : multi_shards) {
       for (auto &[key, shard] : shards) {
         MG_ASSERT(!shard.empty());
-        state.shard_cache.push_back(std::move(shard));
-        msgs::ScanVerticesRequest rqst;
-        rqst.transaction_id = transaction_id_;
-        rqst.start_id.second = storage::conversions::ConvertValueVector(key);
-        state.requests.push_back(std::move(rqst));
+
+        msgs::ScanVerticesRequest request;
+        request.transaction_id = transaction_id_;
+        request.start_id.second = storage::conversions::ConvertValueVector(key);
+
+        ShardRequestState<msgs::ScanVerticesRequest> shard_request_state{
+            .shard = shard,
+            .request = std::move(request),
+            .async_request_token = std::nullopt,
+        };
+
+        state.requests.emplace_back(std::move(shard_request_state));
       }
     }
-    state.state = ExecutionState<msgs::ScanVerticesRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::ExpandOneRequest> per_shard_request_table;
@@ -511,15 +488,19 @@
           shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
       if (!per_shard_request_table.contains(shard)) {
         per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].src_vertices.push_back(vertex);
     }

-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<msgs::ExpandOneRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::ExpandOneRequest>::EXECUTING;
   }

   StorageClient &GetStorageClientForShard(Shard shard) {
@@ -546,173 +527,54 @@ class RequestRouter : public RequestRouterInterface {
     storage_cli_manager_.AddClient(target_shard, std::move(cli));
   }

-  void SendAllRequests(ExecutionState<msgs::ScanVerticesRequest> &state) {
-    int64_t shard_idx = 0;
-    for (const auto &request : state.requests) {
-      const auto &current_shard = state.shard_cache[shard_idx];
+  template <typename RequestT, typename ResponseT>
+  void DriveReadResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);

-      auto &storage_client = GetStorageClientForShard(current_shard);
-      msgs::ReadRequests req = request;
-      storage_client.SendAsyncReadRequest(request);
-
-      ++shard_idx;
-    }
-  }
-
-  void SendAllRequests(ExecutionState<msgs::CreateVerticesRequest> &state,
-                       std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
-    size_t id = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
-      // This is fine because all new_vertices of each request end up on the same shard
-      const auto labels = state.requests[id].new_vertices[0].label_ids;
-      auto req_deep_copy = state.requests[id];
-
-      for (auto &new_vertex : req_deep_copy.new_vertices) {
-        new_vertex.label_ids.erase(new_vertex.label_ids.begin());
-      }
-
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      msgs::WriteRequests req = req_deep_copy;
-      storage_client.SendAsyncWriteRequest(req);
-      ++id;
-    }
-  }
-
-  void SendAllRequests(ExecutionState<msgs::ExpandOneRequest> &state,
-                       std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
-    size_t id = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-      msgs::ReadRequests req = state.requests[id];
-      storage_client.SendAsyncReadRequest(req);
-      ++id;
-    }
-  }
-
-  void AwaitOnResponses(ExecutionState<msgs::CreateVerticesRequest> &state,
-                        std::vector<msgs::CreateVerticesResponse> &responses) {
-    auto &shard_cache_ref = state.shard_cache;
-    int64_t request_idx = 0;
-
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      // This is fine because all new_vertices of each request end up on the same shard
-      const auto labels = state.requests[request_idx].new_vertices[0].label_ids;
-
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto poll_result = storage_client.AwaitAsyncWriteRequest();
-      if (!poll_result) {
-        ++shard_it;
-        ++request_idx;
-
-        continue;
+      auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
+      while (!poll_result) {
+        poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
       }

       if (poll_result->HasError()) {
-        throw std::runtime_error("CreateVertices request timed out");
-      }
-
-      msgs::WriteResponses response_variant = poll_result->GetValue();
-      auto response = std::get<msgs::CreateVerticesResponse>(response_variant);
-
-      if (response.error) {
-        throw std::runtime_error("CreateVertices request did not succeed");
-      }
-      responses.push_back(response);
-
-      shard_it = shard_cache_ref.erase(shard_it);
-      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-      auto it = state.requests.begin() + request_idx;
-      state.requests.erase(it);
-    }
-  }
-
-  void AwaitOnResponses(ExecutionState<msgs::ExpandOneRequest> &state,
-                        std::vector<msgs::ExpandOneResponse> &responses) {
-    auto &shard_cache_ref = state.shard_cache;
-    int64_t request_idx = 0;
-
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto poll_result = storage_client.PollAsyncReadRequest();
-      if (!poll_result) {
-        ++shard_it;
-        ++request_idx;
-        continue;
-      }
-
-      if (poll_result->HasError()) {
-        throw std::runtime_error("ExpandOne request timed out");
+        throw std::runtime_error("RequestRouter Read request timed out");
       }

       msgs::ReadResponses response_variant = poll_result->GetValue();
-      auto response = std::get<msgs::ExpandOneResponse>(response_variant);
-      // -NOTE-
-      // Currently a boolean flag for signaling the overall success of the
-      // ExpandOne request does not exist. But it should, so here we assume
-      // that it is already in place.
+      auto response = std::get<ResponseT>(response_variant);
       if (response.error) {
-        throw std::runtime_error("ExpandOne request did not succeed");
+        throw std::runtime_error("RequestRouter Read request did not succeed");
      }

       responses.push_back(std::move(response));
-      shard_it = shard_cache_ref.erase(shard_it);
-      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-      auto it = state.requests.begin() + request_idx;
-      state.requests.erase(it);
     }
+    state.requests.clear();
   }

-  void AwaitOnPaginatedRequests(ExecutionState<msgs::ScanVerticesRequest> &state,
-                                std::vector<msgs::ScanVerticesResponse> &responses,
-                                std::map<Shard, PaginatedResponseState> &paginated_response_tracker) {
-    auto &shard_cache_ref = state.shard_cache;
+  template <typename RequestT, typename ResponseT>
+  void DriveWriteResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);

-    // Find the first request that is not holding a paginated response.
-    int64_t request_idx = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      if (paginated_response_tracker.at(*shard_it) != PaginatedResponseState::Pending) {
-        ++shard_it;
-        ++request_idx;
-        continue;
+      auto poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
+      while (!poll_result) {
+        poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
       }

-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto await_result = storage_client.AwaitAsyncReadRequest();
-
-      if (!await_result) {
-        // Redirection has occured.
-        ++shard_it;
-        ++request_idx;
-        continue;
+      if (poll_result->HasError()) {
+        throw std::runtime_error("RequestRouter Write request timed out");
       }

-      if (await_result->HasError()) {
-        throw std::runtime_error("ScanAll request timed out");
-      }
-
-      msgs::ReadResponses read_response_variant = await_result->GetValue();
-      auto response = std::get<msgs::ScanVerticesResponse>(read_response_variant);
+      msgs::WriteResponses response_variant = poll_result->GetValue();
+      auto response = std::get<ResponseT>(response_variant);
       if (response.error) {
-        throw std::runtime_error("ScanAll request did not succeed");
+        throw std::runtime_error("RequestRouter Write request did not succeed");
       }

-      if (!response.next_start_id) {
-        paginated_response_tracker.erase((*shard_it));
-        shard_cache_ref.erase(shard_it);
-        // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-        auto it = state.requests.begin() + request_idx;
-        state.requests.erase(it);
-
-      } else {
-        state.requests[request_idx].start_id.second = response.next_start_id->second;
-        paginated_response_tracker[*shard_it] = PaginatedResponseState::PartiallyFinished;
-      }
       responses.push_back(std::move(response));
     }
+    state.requests.clear();
   }

   void SetUpNameIdMappers() {
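With the Drive helpers in place, callers no longer thread an ExecutionState through every call, which is what shrinks the test updates below to one-liners. A toy illustration of the caller-facing difference; Router here is a hypothetical stand-in for RequestRouterInterface, not the real class:

```cpp
#include <optional>
#include <string>
#include <vector>

// Mirrors the shape of the interface after this change: the verbs carry their
// inputs directly instead of taking an ExecutionState out-parameter.
class Router {
 public:
  std::vector<std::string> ScanVertices(std::optional<std::string> label) {
    return {label.value_or("<all>") + "/v1", label.value_or("<all>") + "/v2"};
  }
};

int main() {
  Router router;
  // Callers hold no request state between calls; each call is self-contained.
  auto rows = router.ScanVertices("test_label");
  return rows.size() == 2 ? 0 : 1;
}
```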
diff --git a/tests/simulation/test_cluster.hpp b/tests/simulation/test_cluster.hpp
index 99f617cda..1392a0632 100644
--- a/tests/simulation/test_cluster.hpp
+++ b/tests/simulation/test_cluster.hpp
@@ -164,8 +164,6 @@ void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std
     return;
   }

-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
-
   auto label_id = request_router.NameToLabel("test_label");

   msgs::NewVertex nv{.primary_key = primary_key};
@@ -174,7 +172,7 @@
   std::vector<msgs::NewVertex> new_vertices;
   new_vertices.push_back(std::move(nv));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));

   RC_ASSERT(result.size() == 1);
   RC_ASSERT(!result[0].error.has_value());
@@ -184,9 +182,7 @@

 void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std::set<CompoundKey> &correctness_model,
                ScanAll scan_all) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> request{.label = "test_label"};
-
-  auto results = request_router.Request(request);
+  auto results = request_router.ScanVertices("test_label");

   RC_ASSERT(results.size() == correctness_model.size());

diff --git a/tests/unit/high_density_shard_create_scan.cpp b/tests/unit/high_density_shard_create_scan.cpp
index 22af9c702..2be48fc77 100644
--- a/tests/unit/high_density_shard_create_scan.cpp
+++ b/tests/unit/high_density_shard_create_scan.cpp
@@ -174,8 +174,6 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se
     return;
   }

-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
-
   auto label_id = request_router.NameToLabel("test_label");

   msgs::NewVertex nv{.primary_key = primary_key};
@@ -184,7 +182,7 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se
   std::vector<msgs::NewVertex> new_vertices;
   new_vertices.push_back(std::move(nv));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));

   MG_ASSERT(result.size() == 1);
   MG_ASSERT(!result[0].error.has_value());
@@ -194,9 +192,7 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se

 void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::set<CompoundKey> &correctness_model,
                ScanAll scan_all) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> request{.label = "test_label"};
-
-  auto results = request_router.Request(request);
+  auto results = request_router.ScanVertices("test_label");

   MG_ASSERT(results.size() == correctness_model.size());

diff --git a/tests/unit/machine_manager.cpp b/tests/unit/machine_manager.cpp
index 0b081e5a1..748233737 100644
--- a/tests/unit/machine_manager.cpp
+++ b/tests/unit/machine_manager.cpp
@@ -111,15 +111,12 @@ ShardMap TestShardMap() {

 template <typename RequestRouter>
 void TestScanAll(RequestRouter &request_router) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> state{.label = kLabelName};
-
-  auto result = request_router.Request(state);
+  auto result = request_router.ScanVertices(kLabelName);
   EXPECT_EQ(result.size(), 2);
 }

 void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
   using PropVal = msgs::Value;
-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
   std::vector<msgs::NewVertex> new_vertices;
   auto label_id = request_router.NameToLabel(kLabelName);
   msgs::NewVertex a1{.primary_key = {PropVal(int64_t(0)), PropVal(int64_t(0))}};
@@ -129,14 +126,13 @@ void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
   new_vertices.push_back(std::move(a1));
   new_vertices.push_back(std::move(a2));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));
   EXPECT_EQ(result.size(), 1);
   EXPECT_FALSE(result[0].error.has_value()) << result[0].error->message;
 }

 void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
   using PropVal = msgs::Value;
-  query::v2::ExecutionState<msgs::CreateExpandRequest> state;
   std::vector<msgs::NewExpand> new_expands;

   const auto edge_type_id = request_router.NameToEdgeType("edge_type");
@@ -150,20 +146,19 @@ void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
   new_expands.push_back(std::move(expand_1));
   new_expands.push_back(std::move(expand_2));

-  auto responses = request_router.Request(state, std::move(new_expands));
+  auto responses = request_router.CreateExpand(std::move(new_expands));
   MG_ASSERT(responses.size() == 1);
   MG_ASSERT(!responses[0].error.has_value());
 }

 void TestExpandOne(query::v2::RequestRouterInterface &request_router) {
-  query::v2::ExecutionState<msgs::ExpandOneRequest> state{};
   msgs::ExpandOneRequest request;
   const auto edge_type_id = request_router.NameToEdgeType("edge_type");
   const auto label = msgs::Label{request_router.NameToLabel("test_label")};
   request.src_vertices.push_back(msgs::VertexId{label, {msgs::Value(int64_t(0)), msgs::Value(int64_t(0))}});
   request.edge_types.push_back(msgs::EdgeType{edge_type_id});
   request.direction = msgs::EdgeDirection::BOTH;

-  auto result_rows = request_router.Request(state, std::move(request));
+  auto result_rows = request_router.ExpandOne(std::move(request));
   MG_ASSERT(result_rows.size() == 1);
   MG_ASSERT(result_rows[0].in_edges_with_all_properties.size() == 1);
   MG_ASSERT(result_rows[0].out_edges_with_all_properties.size() == 1);
diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp
index 112ecd29e..5f77ed4e7 100644
--- a/tests/unit/query_v2_expression_evaluator.cpp
+++ b/tests/unit/query_v2_expression_evaluator.cpp
@@ -82,23 +82,15 @@ class MockedRequestRouter : public RequestRouterInterface {
   }
   void StartTransaction() override {}
   void Commit() override {}
-  std::vector<VertexAccessor> Request(ExecutionState<memgraph::msgs::ScanVerticesRequest> &state) override {
+  std::vector<VertexAccessor> ScanVertices(std::optional<std::string> /* label */) override { return {}; }
+
+  std::vector<CreateVerticesResponse> CreateVertices(std::vector<memgraph::msgs::NewVertex> /* new_vertices */) override {
     return {};
   }
-  std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state,
-                                              std::vector<memgraph::msgs::NewVertex> new_vertices) override {
-    return {};
-  }
+  std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest /* request */) override { return {}; }

-  std::vector<ExpandOneResultRow> Request(ExecutionState<ExpandOneRequest> &state, ExpandOneRequest request) override {
-    return {};
-  }
-
-  std::vector<CreateExpandResponse> Request(ExecutionState<CreateExpandRequest> &state,
-                                            std::vector<NewExpand> new_edges) override {
-    return {};
-  }
+  std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> /* new_edges */) override { return {}; }

   const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override {
     return properties_.IdToName(id.AsUint());