diff --git a/src/io/future.hpp b/src/io/future.hpp
index 98437b496..585f18938 100644
--- a/src/io/future.hpp
+++ b/src/io/future.hpp
@@ -35,10 +35,13 @@ class Shared {
   std::optional<T> item_;
   bool consumed_ = false;
   bool waiting_ = false;
-  std::function<bool()> simulator_notifier_ = nullptr;
+  bool filled_ = false;
+  std::function<bool()> wait_notifier_ = nullptr;
+  std::function<void()> fill_notifier_ = nullptr;
 
  public:
-  explicit Shared(std::function<bool()> simulator_notifier) : simulator_notifier_(simulator_notifier) {}
+  explicit Shared(std::function<bool()> wait_notifier, std::function<void()> fill_notifier)
+      : wait_notifier_(wait_notifier), fill_notifier_(fill_notifier) {}
   Shared() = default;
   Shared(Shared &&) = delete;
   Shared &operator=(Shared &&) = delete;
@@ -64,7 +67,7 @@ class Shared {
     waiting_ = true;
 
     while (!item_) {
-      if (simulator_notifier_) [[unlikely]] {
+      if (wait_notifier_) [[unlikely]] {
         // We can't hold our own lock while notifying
         // the simulator because notifying the simulator
         // involves acquiring the simulator's mutex
@@ -76,7 +79,7 @@ class Shared {
         // so we have to get out of its way to avoid
         // a cyclical deadlock.
         lock.unlock();
-        std::invoke(simulator_notifier_);
+        std::invoke(wait_notifier_);
         lock.lock();
         if (item_) {
           // item may have been filled while we
@@ -115,11 +118,19 @@ class Shared {
       std::unique_lock<std::mutex> lock(mu_);
 
       MG_ASSERT(!consumed_, "Promise filled after it was already consumed!");
-      MG_ASSERT(!item_, "Promise filled twice!");
+      MG_ASSERT(!filled_, "Promise filled twice!");
 
       item_ = item;
+      filled_ = true;
     }  // lock released before condition variable notification
 
+    if (fill_notifier_) {
+      spdlog::trace("calling fill notifier");
+      std::invoke(fill_notifier_);
+    } else {
+      spdlog::trace("not calling fill notifier");
+    }
+
     cv_.notify_all();
   }
 
@@ -251,8 +262,9 @@ std::pair<Future<T>, Promise<T>> FuturePromisePair() {
 }
 
 template <typename T>
-std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifier(std::function<bool()> simulator_notifier) {
-  std::shared_ptr<Shared<T>> shared = std::make_shared<Shared<T>>(simulator_notifier);
+std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifications(std::function<bool()> wait_notifier,
+                                                                    std::function<void()> fill_notifier) {
+  std::shared_ptr<Shared<T>> shared = std::make_shared<Shared<T>>(wait_notifier, fill_notifier);
 
   Future<T> future = Future<T>(shared);
   Promise<T> promise = Promise<T>(shared);
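Note: the pairing function above now installs two independent hooks: a wait notifier, polled while a consumer blocks in Future::Wait, and a fill notifier, invoked once after Promise::Fill releases the lock. A minimal sketch of that behavior — mine, not code from this patch, assuming only FuturePromisePairWithNotifications<T> as declared above:

    #include <iostream>
    #include "io/future.hpp"

    int main() {
      auto [future, promise] = memgraph::io::FuturePromisePairWithNotifications<int>(
          // wait notifier: invoked in a loop inside Future::Wait while the item is absent
          []() -> bool {
            std::cout << "waiting\n";
            return false;
          },
          // fill notifier: invoked exactly once, after Promise::Fill
          []() { std::cout << "filled\n"; });

      promise.Fill(42);                               // prints "filled"
      std::cout << std::move(future).Wait() << '\n';  // prints 42; the wait notifier never
                                                      // fires because the item is already present
    }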
diff --git a/src/io/local_transport/local_transport.hpp b/src/io/local_transport/local_transport.hpp
index 258df6385..b64cabf1d 100644
--- a/src/io/local_transport/local_transport.hpp
+++ b/src/io/local_transport/local_transport.hpp
@@ -31,9 +31,10 @@ class LocalTransport {
       : local_transport_handle_(std::move(local_transport_handle)) {}
 
   template <Message RequestT, Message ResponseT>
-  ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) {
-    return local_transport_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address,
-                                                                                std::move(request), timeout);
+  ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request,
+                                    std::function<void()> fill_notifier, Duration timeout) {
+    return local_transport_handle_->template SubmitRequest<RequestT, ResponseT>(
+        to_address, from_address, std::move(request), timeout, fill_notifier);
   }
 
   template <Message... Ms>
diff --git a/src/io/local_transport/local_transport_handle.hpp b/src/io/local_transport/local_transport_handle.hpp
index 2303ae735..38538620f 100644
--- a/src/io/local_transport/local_transport_handle.hpp
+++ b/src/io/local_transport/local_transport_handle.hpp
@@ -140,8 +140,12 @@ class LocalTransportHandle {
 
   template <Message RequestT, Message ResponseT>
   ResponseFuture<ResponseT> SubmitRequest(Address to_address, Address from_address, RequestT &&request,
-                                          Duration timeout) {
-    auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<ResponseT>>();
+                                          Duration timeout, std::function<void()> fill_notifier) {
+    auto [future, promise] = memgraph::io::FuturePromisePairWithNotifications<ResponseResult<ResponseT>>(
+        // set null notifier for when the Future::Wait is called
+        nullptr,
+        // set notifier for when Promise::Fill is called
+        std::forward<std::function<void()>>(fill_notifier));
 
     const bool port_matches = to_address.last_known_port == from_address.last_known_port;
     const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip;
diff --git a/src/io/notifier.hpp b/src/io/notifier.hpp
new file mode 100644
index 000000000..e6b073046
--- /dev/null
+++ b/src/io/notifier.hpp
@@ -0,0 +1,77 @@
+// Copyright 2022 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <condition_variable>
+#include <cstddef>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include <spdlog/spdlog.h>
+
+namespace memgraph::io {
+
+class ReadinessToken {
+  size_t id_;
+
+ public:
+  explicit ReadinessToken(size_t id) : id_(id) {}
+  size_t GetId() const { return id_; }
+};
+
+class Inner {
+  std::condition_variable cv_;
+  std::mutex mu_;
+  std::vector<ReadinessToken> ready_;
+
+ public:
+  void Notify(ReadinessToken readiness_token) {
+    {
+      std::unique_lock<std::mutex> lock(mu_);
+      spdlog::trace("Notifier notifying token {}", readiness_token.GetId());
+      ready_.emplace_back(readiness_token);
+    }  // mutex dropped
+
+    cv_.notify_all();
+  }
+
+  ReadinessToken Await() {
+    std::unique_lock<std::mutex> lock(mu_);
+
+    while (ready_.empty()) {
+      cv_.wait(lock);
+    }
+
+    ReadinessToken ret = ready_.back();
+    ready_.pop_back();
+    return ret;
+  }
+};
+
+class Notifier {
+  std::shared_ptr<Inner> inner_;
+
+ public:
+  Notifier() : inner_(std::make_shared<Inner>()) {}
+  Notifier(const Notifier &) = default;
+  Notifier &operator=(const Notifier &) = default;
+  Notifier(Notifier &&old) = default;
+  Notifier &operator=(Notifier &&old) = default;
+  ~Notifier() = default;
+
+  void Notify(ReadinessToken readiness_token) { inner_->Notify(readiness_token); }
+
+  ReadinessToken Await() { return inner_->Await(); }
+};
+
+}  // namespace memgraph::io
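Note: since notifier.hpp is a new file, here is a usage sketch (mine, not part of the patch) of its cross-thread contract. Notifier is a cheap value type whose copies share one Inner, so producers can Notify from arbitrary threads while a single consumer blocks in Await; tokens are popped from the back of ready_, so readiness order is LIFO, not FIFO:

    #include <thread>
    #include "io/notifier.hpp"

    int main() {
      memgraph::io::Notifier notifier;  // copies share the same Inner

      // producers signal readiness from any thread
      std::jthread t1([=]() mutable { notifier.Notify(memgraph::io::ReadinessToken{0}); });
      std::jthread t2([=]() mutable { notifier.Notify(memgraph::io::ReadinessToken{1}); });

      // the consumer blocks until *some* token is ready; arrival order is unspecified
      for (int i = 0; i < 2; i++) {
        memgraph::io::ReadinessToken ready = notifier.Await();
        spdlog::info("token {} became ready", ready.GetId());
      }
    }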
diff --git a/src/io/rsm/rsm_client.hpp b/src/io/rsm/rsm_client.hpp
index 920866c7a..1283ec3dc 100644
--- a/src/io/rsm/rsm_client.hpp
+++ b/src/io/rsm/rsm_client.hpp
@@ -19,6 +19,7 @@
 
 #include "io/address.hpp"
 #include "io/errors.hpp"
+#include "io/notifier.hpp"
 #include "io/rsm/raft.hpp"
 #include "utils/result.hpp"
@@ -37,18 +38,11 @@
 using memgraph::io::rsm::WriteRequest;
 using memgraph::io::rsm::WriteResponse;
 using memgraph::utils::BasicResult;
 
-class AsyncRequestToken {
-  size_t id_;
-
- public:
-  explicit AsyncRequestToken(size_t id) : id_(id) {}
-  size_t GetId() const { return id_; }
-};
-
 template <typename RequestT, typename ResponseT>
 struct AsyncRequest {
   Time start_time;
   RequestT request;
+  Notifier notifier;
   ResponseFuture<ResponseT> future;
 };
@@ -66,8 +60,6 @@ class RsmClient {
   std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_;
   std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_;
 
-  size_t async_token_generator_ = 0;
-
   void SelectRandomLeader() {
     std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
     size_t addr_index = io_.Rand(addr_distrib);
@@ -101,61 +93,63 @@ class RsmClient {
   ~RsmClient() = default;
 
   BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) {
-    auto token = SendAsyncWriteRequest(req);
-    auto poll_result = AwaitAsyncWriteRequest(token);
+    Notifier notifier;
+    ReadinessToken readiness_token{0};
+    SendAsyncWriteRequest(req, notifier, readiness_token);
+    auto poll_result = AwaitAsyncWriteRequest(readiness_token);
     while (!poll_result) {
-      poll_result = AwaitAsyncWriteRequest(token);
+      poll_result = AwaitAsyncWriteRequest(readiness_token);
     }
     return poll_result.value();
   }
 
   BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) {
-    auto token = SendAsyncReadRequest(req);
-    auto poll_result = AwaitAsyncReadRequest(token);
+    Notifier notifier;
+    ReadinessToken readiness_token{0};
+    SendAsyncReadRequest(req, notifier, readiness_token);
+    auto poll_result = AwaitAsyncReadRequest(readiness_token);
     while (!poll_result) {
-      poll_result = AwaitAsyncReadRequest(token);
+      poll_result = AwaitAsyncReadRequest(readiness_token);
     }
     return poll_result.value();
   }
 
   /// AsyncRead methods
-  AsyncRequestToken SendAsyncReadRequest(const ReadRequestT &req) {
-    size_t token = async_token_generator_++;
-
+  void SendAsyncReadRequest(const ReadRequestT &req, Notifier notifier, ReadinessToken readiness_token) {
     ReadRequest<ReadRequestT> read_req = {.operation = req};
 
     AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{
         .start_time = io_.Now(),
        .request = std::move(req),
-        .future = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req),
+        .notifier = notifier,
+        .future = io_.template RequestWithNotification<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(
+            leader_, read_req, notifier, readiness_token),
     };
 
-    async_reads_.emplace(token, std::move(async_request));
-
-    return AsyncRequestToken{token};
+    async_reads_.emplace(readiness_token.GetId(), std::move(async_request));
   }
 
-  void ResendAsyncReadRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_reads_.at(token.GetId());
+  void ResendAsyncReadRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_reads_.at(readiness_token.GetId());
 
     ReadRequest<ReadRequestT> read_req = {.operation = async_request.request};
 
-    async_request.future =
-        io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+    async_request.future = io_.template RequestWithNotification<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(
+        leader_, read_req, async_request.notifier, readiness_token);
   }
 
-  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_reads_.at(token.GetId());
+  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_reads_.at(readiness_token.GetId());
     if (!async_request.future.IsReady()) {
       return std::nullopt;
     }
 
-    return AwaitAsyncReadRequest();
+    return AwaitAsyncReadRequest(readiness_token);
   }
 
-  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_reads_.at(token.GetId());
+  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_reads_.at(readiness_token.GetId());
     ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait();
 
     const Duration overall_timeout = io_.GetDefaultTimeout();
@@ -165,7 +159,7 @@ class RsmClient {
 
     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
      spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_reads_.erase(token.GetId());
+      async_reads_.erase(readiness_token.GetId());
       return TimedOut{};
     }
@@ -176,7 +170,7 @@ class RsmClient {
       PossiblyRedirectLeader(read_get_response);
 
       if (read_get_response.success) {
-        async_reads_.erase(token.GetId());
+        async_reads_.erase(readiness_token.GetId());
         spdlog::debug("returning read_return for RSM request");
         return std::move(read_get_response.read_return);
       }
@@ -184,49 +178,48 @@ class RsmClient {
       SelectRandomLeader();
     }
 
-    ResendAsyncReadRequest(token);
+    ResendAsyncReadRequest(readiness_token);
     return std::nullopt;
   }
 
   /// AsyncWrite methods
-  AsyncRequestToken SendAsyncWriteRequest(const WriteRequestT &req) {
-    size_t token = async_token_generator_++;
-
+  void SendAsyncWriteRequest(const WriteRequestT &req, Notifier notifier, ReadinessToken readiness_token) {
     WriteRequest<WriteRequestT> write_req = {.operation = req};
 
     AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{
         .start_time = io_.Now(),
         .request = std::move(req),
-        .future = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req),
+        .notifier = notifier,
+        .future = io_.template RequestWithNotification<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(
+            leader_, write_req, notifier, readiness_token),
     };
 
-    async_writes_.emplace(token, std::move(async_request));
-
-    return AsyncRequestToken{token};
+    async_writes_.emplace(readiness_token.GetId(), std::move(async_request));
   }
 
-  void ResendAsyncWriteRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_writes_.at(token.GetId());
+  void ResendAsyncWriteRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_writes_.at(readiness_token.GetId());
 
     WriteRequest<WriteRequestT> write_req = {.operation = async_request.request};
 
     async_request.future =
-        io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+        io_.template RequestWithNotification<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(
+            leader_, write_req, async_request.notifier, readiness_token);
   }
 
-  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_writes_.at(token.GetId());
+  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_writes_.at(readiness_token.GetId());
     if (!async_request.future.IsReady()) {
       return std::nullopt;
     }
 
-    return AwaitAsyncWriteRequest();
+    return AwaitAsyncWriteRequest(readiness_token);
   }
 
-  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const AsyncRequestToken &token) {
-    auto &async_request = async_writes_.at(token.GetId());
+  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const ReadinessToken &readiness_token) {
+    auto &async_request = async_writes_.at(readiness_token.GetId());
     ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait();
 
     const Duration overall_timeout = io_.GetDefaultTimeout();
@@ -236,7 +229,7 @@ class RsmClient {
 
     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_writes_.erase(token.GetId());
+      async_writes_.erase(readiness_token.GetId());
       return TimedOut{};
     }
@@ -248,14 +241,14 @@ class RsmClient {
       PossiblyRedirectLeader(write_get_response);
 
       if (write_get_response.success) {
-        async_writes_.erase(token.GetId());
+        async_writes_.erase(readiness_token.GetId());
 
         return std::move(write_get_response.write_return);
       }
     } else {
       SelectRandomLeader();
     }
 
-    ResendAsyncWriteRequest(token);
+    ResendAsyncWriteRequest(readiness_token);
     return std::nullopt;
   }
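Note: the async API above is designed to be driven by exactly the Notify/Await rendezvous that the blocking SendReadRequest/SendWriteRequest wrappers perform. A runnable toy of that loop's shape — my sketch, not code from this patch — with a thread standing in for the raft leader and an optional standing in for the response future:

    #include <mutex>
    #include <optional>
    #include <thread>
    #include "io/notifier.hpp"

    int main() {
      memgraph::io::Notifier notifier;
      std::mutex mu;
      std::optional<int> response;  // stands in for AsyncRequest::future

      // "server": fill the response, then notify token 0 (the role of Promise::Fill)
      std::jthread server([&, notifier]() mutable {
        {
          std::lock_guard<std::mutex> lock(mu);
          response = 42;
        }
        notifier.Notify(memgraph::io::ReadinessToken{0});
      });

      // "client": the SendReadRequest-style loop over the async primitives
      std::optional<int> poll_result;
      while (!poll_result) {
        auto ready = notifier.Await();  // blocks, like AwaitAsyncReadRequest
        spdlog::trace("token {} ready", ready.GetId());
        std::lock_guard<std::mutex> lock(mu);
        poll_result = response;  // plays the role of PollAsyncReadRequest(ready)
      }
      spdlog::info("got response {}", *poll_result);
    }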
diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp
index 5a5ad1ec0..0cd6b77b0 100644
--- a/src/io/simulator/simulator_handle.hpp
+++ b/src/io/simulator/simulator_handle.hpp
@@ -105,12 +105,16 @@ class SimulatorHandle {
 
   template <Message Request, Message Response>
   ResponseFuture<Response> SubmitRequest(Address to_address, Address from_address, Request &&request, Duration timeout,
-                                         std::function<bool()> &&maybe_tick_simulator) {
+                                         std::function<bool()> &&maybe_tick_simulator,
+                                         std::function<void()> &&fill_notifier) {
     spdlog::trace("submitting request to {}", to_address.last_known_port);
 
     auto type_info = TypeInfoFor(request);
 
-    auto [future, promise] = memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(
-        std::forward<std::function<bool()>>(maybe_tick_simulator));
+    auto [future, promise] = memgraph::io::FuturePromisePairWithNotifications<ResponseResult<Response>>(
+        // set notifier for when the Future::Wait is called
+        std::forward<std::function<bool()>>(maybe_tick_simulator),
+        // set notifier for when Promise::Fill is called
+        std::forward<std::function<void()>>(fill_notifier));
 
     std::unique_lock<std::mutex> lock(mu_);
diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp
index 5e5a24aa9..2107c34ca 100644
--- a/src/io/simulator/simulator_transport.hpp
+++ b/src/io/simulator/simulator_transport.hpp
@@ -15,6 +15,7 @@
 #include <utility>
 
 #include "io/address.hpp"
+#include "io/notifier.hpp"
 #include "io/simulator/simulator_handle.hpp"
 #include "io/time.hpp"
@@ -33,11 +34,13 @@ class SimulatorTransport {
       : simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {}
 
   template <Message RequestT, Message ResponseT>
-  ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) {
+  ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request,
+                                    std::function<void()> notification, Duration timeout) {
     std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); };
 
-    return simulator_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address, std::move(request),
-                                                                          timeout, std::move(maybe_tick_simulator));
+    return simulator_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address, std::move(request),
+                                                                          timeout, std::move(maybe_tick_simulator),
+                                                                          std::move(notification));
   }
 
   template <Message... Ms>
diff --git a/src/io/transport.hpp b/src/io/transport.hpp
index 4994cb436..2c2e79060 100644
--- a/src/io/transport.hpp
+++ b/src/io/transport.hpp
@@ -20,6 +20,7 @@
 #include "io/errors.hpp"
 #include "io/future.hpp"
 #include "io/message_histogram_collector.hpp"
+#include "io/notifier.hpp"
 #include "io/time.hpp"
 #include "utils/result.hpp"
@@ -84,7 +85,9 @@ class Io {
   template <Message RequestT, Message ResponseT>
   ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) {
     const Address from_address = address_;
-    return implementation_.template Request<RequestT, ResponseT>(address, from_address, request, timeout);
+    std::function<void()> fill_notifier = nullptr;
+    return implementation_.template Request<RequestT, ResponseT>(address, from_address, request, fill_notifier,
+                                                                 timeout);
   }
 
   /// Issue a request that times out after the default timeout. This tends
@@ -93,7 +96,30 @@ class Io {
   ResponseFuture<ResponseT> Request(Address to_address, RequestT request) {
     const Duration timeout = default_timeout_;
     const Address from_address = address_;
-    return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request), timeout);
+    std::function<void()> fill_notifier = nullptr;
+    return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
+                                                                 fill_notifier, timeout);
+  }
+
+  /// Issue a request that will notify a Notifier when it is filled or times out.
+  template <Message RequestT, Message ResponseT>
+  ResponseFuture<ResponseT> RequestWithNotification(Address to_address, RequestT request, Notifier notifier,
+                                                    ReadinessToken readiness_token) {
+    const Duration timeout = default_timeout_;
+    const Address from_address = address_;
+    std::function<void()> fill_notifier = std::bind(&Notifier::Notify, notifier, readiness_token);
+    return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
+                                                                 fill_notifier, timeout);
+  }
+
+  /// Issue a request that will notify a Notifier when it is filled or times out.
+  template <Message RequestT, Message ResponseT>
+  ResponseFuture<ResponseT> RequestWithNotificationAndTimeout(Address to_address, RequestT request, Notifier notifier,
+                                                              ReadinessToken readiness_token, Duration timeout) {
+    const Address from_address = address_;
+    std::function<void()> fill_notifier = std::bind(&Notifier::Notify, notifier, readiness_token);
+    return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request),
+                                                                 fill_notifier, timeout);
   }
 
   /// Wait for an explicit number of microseconds for a request of one of the
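Note: RequestWithNotification packages the notification as a plain std::function<void()> via std::bind, which copies both the Notifier handle (a shared_ptr internally) and the token into the callable. That is what keeps the callback valid long after the caller's stack frame is gone. A small sketch (mine, not from this patch) of just that property:

    #include <functional>
    #include "io/notifier.hpp"

    int main() {
      memgraph::io::Notifier notifier;
      std::function<void()> fill_notifier;
      {
        memgraph::io::ReadinessToken readiness_token{7};
        fill_notifier = std::bind(&memgraph::io::Notifier::Notify, notifier, readiness_token);
      }  // the token is out of scope; the bound copy inside fill_notifier lives on

      fill_notifier();  // equivalent to notifier.Notify(ReadinessToken{7})
      spdlog::info("token {} ready", notifier.Await().GetId());  // prints: token 7 ready
    }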
diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp
index 996272fdc..2d563ade0 100644
--- a/src/query/v2/request_router.hpp
+++ b/src/query/v2/request_router.hpp
@@ -31,6 +31,7 @@
 #include "coordinator/shard_map.hpp"
 #include "io/address.hpp"
 #include "io/errors.hpp"
+#include "io/notifier.hpp"
 #include "io/rsm/raft.hpp"
 #include "io/rsm/rsm_client.hpp"
 #include "io/rsm/shard_rsm.hpp"
@@ -75,25 +76,11 @@ template <typename TRequest>
 struct ShardRequestState {
   memgraph::coordinator::Shard shard;
   TRequest request;
-  std::optional<AsyncRequestToken> async_request_token;
 };
 
+// maps from ReadinessToken's internal size_t to the associated state
 template <typename TRequest>
-struct ExecutionState {
-  using CompoundKey = io::rsm::ShardRsmKey;
-  using Shard = coordinator::Shard;
-
-  // label is optional because some operators can create/remove etc, vertices. These kind of requests contain the label
-  // on the request itself.
-  std::optional<std::string> label;
-  // Transaction id to be filled by the RequestRouter implementation
-  coordinator::Hlc transaction_id;
-  // Initialized by RequestRouter implementation. This vector is filled with the shards that
-  // the RequestRouter impl will send requests to. When a request to a shard exhausts it, meaning that
-  // it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
-  // empty, it means that all of the requests have completed succefully.
-  std::vector<ShardRequestState<TRequest>> requests;
-};
+using RunningRequests = std::unordered_map<size_t, ShardRequestState<TRequest>>;
 
 class RequestRouterInterface {
  public:
@@ -238,26 +225,25 @@ class RequestRouter : public RequestRouterInterface {
   // TODO(kostasrim) Simplify return result
   std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override {
-    ExecutionState<msgs::ScanVerticesRequest> state = {};
-    state.label = label;
-
     // create requests
-    InitializeExecutionState(state);
+    std::vector<ShardRequestState<msgs::ScanVerticesRequest>> unsent_requests = RequestsForScanVertices(label);
+    spdlog::trace("created {} ScanVertices requests", unsent_requests.size());
 
     // begin all requests in parallel
-    for (auto &request : state.requests) {
+    RunningRequests<msgs::ScanVerticesRequest> running_requests = {};
+    running_requests.reserve(unsent_requests.size());
+    for (size_t i = 0; i < unsent_requests.size(); i++) {
+      auto &request = unsent_requests[i];
+      io::ReadinessToken readiness_token{i};
       auto &storage_client = GetStorageClientForShard(request.shard);
-      msgs::ReadRequests req = request.request;
-
-      request.async_request_token = storage_client.SendAsyncReadRequest(request.request);
+      storage_client.SendAsyncReadRequest(request.request, notifier_, readiness_token);
+      running_requests.emplace(readiness_token.GetId(), request);
     }
+    spdlog::trace("sent {} ScanVertices requests in parallel", running_requests.size());
 
     // drive requests to completion
-    std::vector<msgs::ScanVerticesResponse> responses;
-    responses.reserve(state.requests.size());
-    do {
-      DriveReadResponses(state, responses);
-    } while (!state.requests.empty());
+    auto responses = DriveReadResponses<msgs::ScanVerticesRequest, msgs::ScanVerticesResponse>(running_requests);
+    spdlog::trace("got back {} ScanVertices responses after driving to completion", responses.size());
 
     // convert responses into VertexAccessor objects to return
     std::vector<VertexAccessor> accessors;
@@ -272,62 +258,53 @@ class RequestRouter : public RequestRouterInterface {
   }
 
   std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) override {
-    ExecutionState<msgs::CreateVerticesRequest> state = {};
     MG_ASSERT(!new_vertices.empty());
 
     // create requests
-    InitializeExecutionState(state, new_vertices);
+    std::vector<ShardRequestState<msgs::CreateVerticesRequest>> unsent_requests =
+        RequestsForCreateVertices(new_vertices);
 
     // begin all requests in parallel
-    for (auto &request : state.requests) {
-      auto req_deep_copy = request.request;
-
-      for (auto &new_vertex : req_deep_copy.new_vertices) {
+    RunningRequests<msgs::CreateVerticesRequest> running_requests = {};
+    running_requests.reserve(unsent_requests.size());
+    for (size_t i = 0; i < unsent_requests.size(); i++) {
+      auto &request = unsent_requests[i];
+      io::ReadinessToken readiness_token{i};
+      for (auto &new_vertex : request.request.new_vertices) {
        new_vertex.label_ids.erase(new_vertex.label_ids.begin());
       }
-
      auto &storage_client = GetStorageClientForShard(request.shard);
-
-      msgs::WriteRequests req = req_deep_copy;
-      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
+      storage_client.SendAsyncWriteRequest(request.request, notifier_, readiness_token);
+      running_requests.emplace(readiness_token.GetId(), request);
     }
 
     // drive requests to completion
-    std::vector<msgs::CreateVerticesResponse> responses;
-    responses.reserve(state.requests.size());
-    do {
-      DriveWriteResponses(state, responses);
-    } while (!state.requests.empty());
-
-    return responses;
+    return DriveWriteResponses<msgs::CreateVerticesRequest, msgs::CreateVerticesResponse>(running_requests);
   }
 
   std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) override {
-    ExecutionState<msgs::CreateExpandRequest> state = {};
     MG_ASSERT(!new_edges.empty());
 
     // create requests
-    InitializeExecutionState(state, new_edges);
+    std::vector<ShardRequestState<msgs::CreateExpandRequest>> unsent_requests = RequestsForCreateExpand(new_edges);
 
     // begin all requests in parallel
-    for (auto &request : state.requests) {
+    RunningRequests<msgs::CreateExpandRequest> running_requests = {};
+    running_requests.reserve(unsent_requests.size());
+    for (size_t i = 0; i < unsent_requests.size(); i++) {
+      auto &request = unsent_requests[i];
+      io::ReadinessToken readiness_token{i};
       auto &storage_client = GetStorageClientForShard(request.shard);
       msgs::WriteRequests req = request.request;
-      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
+      storage_client.SendAsyncWriteRequest(req, notifier_, readiness_token);
+      running_requests.emplace(readiness_token.GetId(), request);
     }
 
     // drive requests to completion
-    std::vector<msgs::CreateExpandResponse> responses;
-    responses.reserve(state.requests.size());
-    do {
-      DriveWriteResponses(state, responses);
-    } while (!state.requests.empty());
-
-    return responses;
+    return DriveWriteResponses<msgs::CreateExpandRequest, msgs::CreateExpandResponse>(running_requests);
   }
 
   std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) override {
-    ExecutionState<msgs::ExpandOneRequest> state = {};
     // TODO(kostasrim)Update to limit the batch size here
     // Expansions of the destination must be handled by the caller. For example
     // match (u:L1 { prop : 1 })-[:Friend]-(v:L1)
@@ -335,21 +312,22 @@ class RequestRouter : public RequestRouterInterface {
     // must be fetched again with an ExpandOne(Edges.dst)
 
     // create requests
-    InitializeExecutionState(state, std::move(request));
+    std::vector<ShardRequestState<msgs::ExpandOneRequest>> unsent_requests = RequestsForExpandOne(request);
 
     // begin all requests in parallel
-    for (auto &request : state.requests) {
+    RunningRequests<msgs::ExpandOneRequest> running_requests = {};
+    running_requests.reserve(unsent_requests.size());
+    for (size_t i = 0; i < unsent_requests.size(); i++) {
+      auto &request = unsent_requests[i];
+      io::ReadinessToken readiness_token{i};
       auto &storage_client = GetStorageClientForShard(request.shard);
       msgs::ReadRequests req = request.request;
-      request.async_request_token = storage_client.SendAsyncReadRequest(req);
+      storage_client.SendAsyncReadRequest(req, notifier_, readiness_token);
+      running_requests.emplace(readiness_token.GetId(), request);
     }
 
     // drive requests to completion
-    std::vector<msgs::ExpandOneResponse> responses;
-    responses.reserve(state.requests.size());
-    do {
-      DriveReadResponses(state, responses);
-    } while (!state.requests.empty());
+    auto responses = DriveReadResponses<msgs::ExpandOneRequest, msgs::ExpandOneResponse>(running_requests);
 
     // post-process responses
     std::vector<msgs::ExpandOneResultRow> result_rows;
@@ -380,10 +358,8 @@ class RequestRouter : public RequestRouterInterface {
   }
 
  private:
-  void InitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                std::vector<msgs::NewVertex> new_vertices) {
-    state.transaction_id = transaction_id_;
-
+  std::vector<ShardRequestState<msgs::CreateVerticesRequest>> RequestsForCreateVertices(
+      const std::vector<msgs::NewVertex> &new_vertices) {
     std::map<Shard, msgs::CreateVerticesRequest> per_shard_request_table;
 
     for (auto &new_vertex : new_vertices) {
@@ -397,20 +373,21 @@ class RequestRouter : public RequestRouterInterface {
       per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex));
     }
 
+    std::vector<ShardRequestState<msgs::CreateVerticesRequest>> requests = {};
+
     for (auto &[shard, request] : per_shard_request_table) {
       ShardRequestState<msgs::CreateVerticesRequest> shard_request_state{
           .shard = shard,
           .request = request,
-          .async_request_token = std::nullopt,
       };
-      state.requests.emplace_back(std::move(shard_request_state));
+      requests.emplace_back(std::move(shard_request_state));
     }
+
+    return requests;
   }
 
-  void InitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state,
-                                std::vector<msgs::NewExpand> new_expands) {
-    state.transaction_id = transaction_id_;
-
+  std::vector<ShardRequestState<msgs::CreateExpandRequest>> RequestsForCreateExpand(
+      const std::vector<msgs::NewExpand> &new_expands) {
     std::map<Shard, msgs::CreateExpandRequest> per_shard_request_table;
     auto ensure_shard_exists_in_table = [&per_shard_request_table,
                                          transaction_id = transaction_id_](const Shard &shard) {
@@ -435,27 +412,33 @@ class RequestRouter : public RequestRouterInterface {
       per_shard_request_table[shard_src_vertex].new_expands.push_back(std::move(new_expand));
     }
 
+    std::vector<ShardRequestState<msgs::CreateExpandRequest>> requests = {};
+
     for (auto &[shard, request] : per_shard_request_table) {
       ShardRequestState<msgs::CreateExpandRequest> shard_request_state{
           .shard = shard,
           .request = request,
-          .async_request_token = std::nullopt,
       };
-      state.requests.emplace_back(std::move(shard_request_state));
+      requests.emplace_back(std::move(shard_request_state));
     }
+
+    return requests;
   }
 
-  void InitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) {
+  std::vector<ShardRequestState<msgs::ScanVerticesRequest>> RequestsForScanVertices(
+      const std::optional<std::string> &label) {
     std::vector<Shards> multi_shards;
-    state.transaction_id = transaction_id_;
-    if (!state.label) {
-      multi_shards = shards_map_.GetAllShards();
-    } else {
-      const auto label_id = shards_map_.GetLabelId(*state.label);
+    if (label) {
+      const auto label_id = shards_map_.GetLabelId(*label);
       MG_ASSERT(label_id);
       MG_ASSERT(IsPrimaryLabel(*label_id));
-      multi_shards = {shards_map_.GetShardsForLabel(*state.label)};
+      multi_shards = {shards_map_.GetShardsForLabel(*label)};
+    } else {
+      multi_shards = shards_map_.GetAllShards();
     }
+
+    std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests = {};
+
     for (auto &shards : multi_shards) {
       for (auto &[key, shard] : shards) {
         MG_ASSERT(!shard.empty());
@@ -467,22 +450,21 @@ class RequestRouter : public RequestRouterInterface {
         ShardRequestState<msgs::ScanVerticesRequest> shard_request_state{
             .shard = shard,
             .request = std::move(request),
-            .async_request_token = std::nullopt,
         };
 
-        state.requests.emplace_back(std::move(shard_request_state));
+        requests.emplace_back(std::move(shard_request_state));
       }
     }
+
+    return requests;
   }
 
-  void InitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) {
-    state.transaction_id = transaction_id_;
-
+  std::vector<ShardRequestState<msgs::ExpandOneRequest>> RequestsForExpandOne(const msgs::ExpandOneRequest &request) {
     std::map<Shard, msgs::ExpandOneRequest> per_shard_request_table;
-    auto top_level_rqst_template = request;
+    msgs::ExpandOneRequest top_level_rqst_template = request;
     top_level_rqst_template.transaction_id = transaction_id_;
     top_level_rqst_template.src_vertices.clear();
-    state.requests.clear();
+
     for (auto &vertex : request.src_vertices) {
       auto shard = shards_map_.GetShardForKey(vertex.first.id,
                                               storage::conversions::ConvertPropertyVector(vertex.second));
@@ -492,15 +474,18 @@ class RequestRouter : public RequestRouterInterface {
       per_shard_request_table[shard].src_vertices.push_back(vertex);
     }
 
+    std::vector<ShardRequestState<msgs::ExpandOneRequest>> requests = {};
+
     for (auto &[shard, request] : per_shard_request_table) {
       ShardRequestState<msgs::ExpandOneRequest> shard_request_state{
           .shard = shard,
          .request = request,
-          .async_request_token = std::nullopt,
       };
-      state.requests.emplace_back(std::move(shard_request_state));
+      requests.emplace_back(std::move(shard_request_state));
     }
+
+    return requests;
   }
 
   StorageClient &GetStorageClientForShard(Shard shard) {
@@ -528,14 +513,22 @@ class RequestRouter : public RequestRouterInterface {
   }
 
   template <typename RequestT, typename ResponseT>
-  void DriveReadResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
-    for (auto &request : state.requests) {
+  std::vector<ResponseT> DriveReadResponses(RunningRequests<RequestT> &running_requests) {
+    // Store responses in a map based on the corresponding request
+    // offset, so that they can be reassembled in the correct order
+    // even if they came back in randomized orders.
+    std::map<size_t, ResponseT> response_map;
+
+    while (response_map.size() < running_requests.size()) {
+      auto ready = notifier_.Await();
+      auto &request = running_requests.at(ready.GetId());
       auto &storage_client = GetStorageClientForShard(request.shard);
 
-      auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
-      while (!poll_result) {
-        poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
+      auto poll_result = storage_client.PollAsyncReadRequest(ready);
+      if (!poll_result.has_value()) {
+        // the request was resent after a leader change; wait for the next notification
+        continue;
       }
 
       if (poll_result->HasError()) {
         throw std::runtime_error("RequestRouter Read request timed out");
@@ -547,20 +540,40 @@ class RequestRouter : public RequestRouterInterface {
         throw std::runtime_error("RequestRouter Read request did not succeed");
       }
 
-      responses.push_back(std::move(response));
+      // the readiness token has an ID based on the request vector offset
+      response_map.emplace(ready.GetId(), std::move(response));
     }
-    state.requests.clear();
+
+    std::vector<ResponseT> responses;
+    responses.reserve(running_requests.size());
+
+    int last = -1;
+    for (auto &&[offset, response] : response_map) {
+      MG_ASSERT(last + 1 == offset);
+      responses.emplace_back(std::forward<ResponseT>(response));
+      last = offset;
+    }
+
+    return responses;
   }
 
   template <typename RequestT, typename ResponseT>
-  void DriveWriteResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
-    for (auto &request : state.requests) {
+  std::vector<ResponseT> DriveWriteResponses(RunningRequests<RequestT> &running_requests) {
+    // Store responses in a map based on the corresponding request
+    // offset, so that they can be reassembled in the correct order
+    // even if they came back in randomized orders.
+    std::map<size_t, ResponseT> response_map;
+
+    while (response_map.size() < running_requests.size()) {
+      auto ready = notifier_.Await();
+      auto &request = running_requests.at(ready.GetId());
       auto &storage_client = GetStorageClientForShard(request.shard);
 
-      auto poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
-      while (!poll_result) {
-        poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
+      auto poll_result = storage_client.PollAsyncWriteRequest(ready);
+      if (!poll_result.has_value()) {
+        // the request was resent after a leader change; wait for the next notification
+        continue;
       }
 
       if (poll_result->HasError()) {
         throw std::runtime_error("RequestRouter Write request timed out");
@@ -572,9 +585,21 @@ class RequestRouter : public RequestRouterInterface {
         throw std::runtime_error("RequestRouter Write request did not succeed");
       }
 
-      responses.push_back(std::move(response));
+      // the readiness token has an ID based on the request vector offset
+      response_map.emplace(ready.GetId(), std::move(response));
     }
-    state.requests.clear();
+
+    std::vector<ResponseT> responses;
+    responses.reserve(running_requests.size());
+
+    int last = -1;
+    for (auto &&[offset, response] : response_map) {
+      MG_ASSERT(last + 1 == offset);
+      responses.emplace_back(std::forward<ResponseT>(response));
+      last = offset;
+    }
+
+    return responses;
   }
 
   void SetUpNameIdMappers() {
@@ -603,6 +628,7 @@ class RequestRouter : public RequestRouterInterface {
   RsmStorageClientManager<StorageClient> storage_cli_manager_;
   io::Io<TTransport> io_;
   coordinator::Hlc transaction_id_;
+  io::Notifier notifier_ = {};
   // TODO(kostasrim) Add batch prefetching
 };
 }  // namespace memgraph::query::v2
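Note: both Drive* functions above rely on the readiness-token id doubling as the request's offset in the original request vector. Responses complete in arbitrary (simulator-randomized) order, but keying them by that offset in a std::map and then walking the map restores issue order. The idea in isolation — my sketch, not code from this patch:

    #include <cassert>
    #include <cstddef>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      std::map<size_t, std::string> response_map;
      response_map.emplace(2, "third");  // completion order is randomized...
      response_map.emplace(0, "first");
      response_map.emplace(1, "second");

      std::vector<std::string> responses;
      responses.reserve(response_map.size());
      int last = -1;
      for (auto &&[offset, response] : response_map) {  // ...but map iteration is ascending by key
        assert(last + 1 == static_cast<int>(offset));   // token ids are dense vector offsets
        responses.emplace_back(std::move(response));
        last = static_cast<int>(offset);
      }
      // responses == {"first", "second", "third"}
    }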
diff --git a/tests/unit/future.cpp b/tests/unit/future.cpp
index 490e19bbc..866a74fce 100644
--- a/tests/unit/future.cpp
+++ b/tests/unit/future.cpp
@@ -28,13 +28,19 @@ void Wait(Future<std::string> future_1, Promise<std::string> promise_2) {
 
 TEST(Future, BasicLifecycle) {
   std::atomic_bool waiting = false;
+  std::atomic_bool filled = false;
 
-  std::function<bool()> notifier = [&] {
+  std::function<bool()> wait_notifier = [&] {
     waiting.store(true, std::memory_order_seq_cst);
     return false;
   };
 
-  auto [future_1, promise_1] = FuturePromisePairWithNotifier<std::string>(notifier);
+  std::function<void()> fill_notifier = [&] {
+    filled.store(true, std::memory_order_seq_cst);
+  };
+
+  auto [future_1, promise_1] = FuturePromisePairWithNotifications<std::string>(wait_notifier, fill_notifier);
   auto [future_2, promise_2] = FuturePromisePair<std::string>();
 
   std::jthread t1(Wait, std::move(future_1), std::move(promise_2));
@@ -50,6 +56,8 @@ TEST(Future, BasicLifecycle) {
   t1.join();
   t2.join();
 
+  EXPECT_TRUE(filled.load(std::memory_order_acquire));
+
   std::string result_2 = std::move(future_2).Wait();
   EXPECT_TRUE(result_2 == "it worked");
 }
diff --git a/tests/unit/high_density_shard_create_scan.cpp b/tests/unit/high_density_shard_create_scan.cpp
index 2be48fc77..cefa238ed 100644
--- a/tests/unit/high_density_shard_create_scan.cpp
+++ b/tests/unit/high_density_shard_create_scan.cpp
@@ -194,7 +194,8 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::set<CompoundKey> &correctness_model,
                ScanAll scan_all) {
   auto results = request_router.ScanVertices("test_label");
 
-  MG_ASSERT(results.size() == correctness_model.size());
+  spdlog::trace("got {} results, model size is {}", results.size(), correctness_model.size());
+  EXPECT_EQ(results.size(), correctness_model.size());
 
   for (const auto &vertex_accessor : results) {
     const auto properties = vertex_accessor.Properties();
diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp
index 5f77ed4e7..50f578bb2 100644
--- a/tests/unit/query_v2_expression_evaluator.cpp
+++ b/tests/unit/query_v2_expression_evaluator.cpp
@@ -84,13 +84,14 @@ class MockedRequestRouter : public RequestRouterInterface {
   void Commit() override {}
 
   std::vector<VertexAccessor> ScanVertices(std::optional<std::string> /* label */) override { return {}; }
 
-  std::vector<CreateVerticesResponse> CreateVertices(std::vector<NewVertex> new_vertices) override {
+  std::vector<CreateVerticesResponse> CreateVertices(
+      std::vector<NewVertex> /* new_vertices */) override {
     return {};
   }
 
-  std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest request) override { return {}; }
+  std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest /* request */) override { return {}; }
 
-  std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> new_edges) override { return {}; }
+  std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> /* new_edges */) override { return {}; }
 
   const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override {
     return properties_.IdToName(id.AsUint());