Allow the RsmClient to store multiple in-flight requests. Update the ShardRequestManager to use the new request tokens and refactor some bug-prone aspects of it
parent 5c0e41ed44
commit 631d18465b
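The diff below introduces an AsyncRequestToken handle so that a single RsmClient can track any number of outstanding reads and writes instead of one of each. As a rough sketch of the intended call pattern (editor's illustration, not code from this commit; the helper, its parameters, and the fully qualified token type are assumptions):

    #include <vector>

    // Issue several reads at once; each Send returns a token that identifies
    // the request inside the client's async_reads_ map.
    template <typename RsmClientT, typename ReadRequestT>
    void DriveReads(RsmClientT &client, std::vector<ReadRequestT> requests) {
      std::vector<memgraph::io::rsm::AsyncRequestToken> tokens;
      tokens.reserve(requests.size());
      for (auto &request : requests) {
        tokens.push_back(client.SendAsyncReadRequest(request));
      }
      for (auto &token : tokens) {
        // Await returns std::nullopt when the request had to be resent (for
        // example after a leader redirect), so the caller loops until a
        // value or TimedOut comes back.
        auto result = client.AwaitAsyncReadRequest(token);
        while (!result) {
          result = client.AwaitAsyncReadRequest(token);
        }
      }
    }

This mirrors how the ShardRequestManager drives its per-shard requests further down in the diff.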
@@ -14,6 +14,7 @@
 #include <iostream>
 #include <optional>
 #include <type_traits>
+#include <unordered_map>
 #include <vector>
 
 #include "io/address.hpp"
@@ -36,6 +37,21 @@ using memgraph::io::rsm::WriteRequest;
 using memgraph::io::rsm::WriteResponse;
 using memgraph::utils::BasicResult;
 
+class AsyncRequestToken {
+  size_t id_;
+
+ public:
+  AsyncRequestToken(size_t id) : id_(id) {}
+  size_t GetId() const { return id_; }
+};
+
+template <typename RequestT, typename ResponseT>
+struct AsyncRequest {
+  Time start_time;
+  RequestT request;
+  ResponseFuture<ResponseT> future;
+};
+
 template <typename IoImpl, typename WriteRequestT, typename WriteResponseT, typename ReadRequestT,
           typename ReadResponseT>
 class RsmClient {
@@ -47,13 +63,10 @@ class RsmClient {
 
   /// State for single async read/write operations. In the future this could become a map
   /// of async operations that can be accessed via an ID etc...
-  std::optional<Time> async_read_before_;
-  std::optional<ResponseFuture<ReadResponse<ReadResponseT>>> async_read_;
-  ReadRequestT current_read_request_;
+  std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_;
+  std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_;
 
-  std::optional<Time> async_write_before_;
-  std::optional<ResponseFuture<WriteResponse<WriteResponseT>>> async_write_;
-  WriteRequestT current_write_request_;
+  size_t async_token_generator_ = 0;
 
   void SelectRandomLeader() {
     std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
@@ -156,42 +169,56 @@ class RsmClient {
   }
 
   /// AsyncRead methods
-  void SendAsyncReadRequest(const ReadRequestT &req) {
-    MG_ASSERT(!async_read_);
+  AsyncRequestToken SendAsyncReadRequest(const ReadRequestT &req) {
+    size_t token = async_token_generator_++;
 
     ReadRequest<ReadRequestT> read_req = {.operation = req};
 
-    if (!async_read_before_) {
-      async_read_before_ = io_.Now();
-    }
-    current_read_request_ = std::move(req);
-    async_read_ = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+    AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req),
+    };
+
+    async_reads_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken(token);
   }
 
-  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest() {
-    MG_ASSERT(async_read_);
+  void ResendAsyncReadRequest(AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
 
-    if (!async_read_->IsReady()) {
+    ReadRequest<ReadRequestT> read_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+  }
+
+  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }
 
-    return AwaitAsyncReadRequest();
+    return AwaitAsyncReadRequest(token);
   }
 
-  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest() {
-    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(*async_read_).Wait();
-    async_read_.reset();
+  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait();
 
     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_read_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();
 
     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_read_before_ = std::nullopt;
+      async_reads_.erase(token.GetId());
       return TimedOut{};
     }
 
     if (!result_has_error) {
       ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
       ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
@@ -199,54 +226,69 @@ class RsmClient {
       PossiblyRedirectLeader(read_get_response);
 
       if (read_get_response.success) {
-        async_read_before_ = std::nullopt;
+        async_reads_.erase(token.GetId());
         return std::move(read_get_response.read_return);
       }
-      SendAsyncReadRequest(current_read_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncReadRequest(current_read_request_);
     }
 
+    ResendAsyncReadRequest(token);
+
     return std::nullopt;
   }
 
   /// AsyncWrite methods
-  void SendAsyncWriteRequest(const WriteRequestT &req) {
-    MG_ASSERT(!async_write_);
+  AsyncRequestToken SendAsyncWriteRequest(const WriteRequestT &req) {
+    size_t token = async_token_generator_++;
 
     WriteRequest<WriteRequestT> write_req = {.operation = req};
 
-    if (!async_write_before_) {
-      async_write_before_ = io_.Now();
-    }
-    current_write_request_ = std::move(req);
-    async_write_ = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+    AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req),
+    };
+
+    async_writes_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken(token);
   }
 
-  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest() {
-    MG_ASSERT(async_write_);
+  void ResendAsyncWriteRequest(AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
 
-    if (!async_write_->IsReady()) {
+    WriteRequest<WriteRequestT> write_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+  }
+
+  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }
 
-    return AwaitAsyncWriteRequest();
+    return AwaitAsyncWriteRequest(token);
   }
 
-  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest() {
-    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(*async_write_).Wait();
-    async_write_.reset();
+  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait();
 
     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_write_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();
 
     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_write_before_ = std::nullopt;
+      async_writes_.erase(token.GetId());
       return TimedOut{};
     }
 
     if (!result_has_error) {
       ResponseEnvelope<WriteResponse<WriteResponseT>> &&get_response_envelope =
           std::move(get_response_result.GetValue());
@@ -255,14 +297,15 @@ class RsmClient {
       PossiblyRedirectLeader(write_get_response);
 
       if (write_get_response.success) {
-        async_write_before_ = std::nullopt;
+        async_writes_.erase(token.GetId());
         return std::move(write_get_response.write_return);
       }
-      SendAsyncWriteRequest(current_write_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncWriteRequest(current_write_request_);
     }
 
+    ResendAsyncWriteRequest(token);
+
     return std::nullopt;
   }
 };
@@ -72,6 +72,13 @@ class RsmStorageClientManager {
   std::map<Shard, TStorageClient> cli_cache_;
 };
 
+template <typename TRequest>
+struct ShardRequestState {
+  memgraph::coordinator::Shard shard;
+  TRequest request;
+  std::optional<io::rsm::AsyncRequestToken> async_request_token;
+};
+
 template <typename TRequest>
 struct ExecutionState {
   using CompoundKey = memgraph::io::rsm::ShardRsmKey;
@@ -91,14 +98,13 @@ struct ExecutionState {
   // it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
   // empty, it means that all of the requests have completed succefully.
   // TODO(gvolfing)
-  // Maybe make this into a more complex object to be able to keep track of paginated resutls. E.g. instead of a vector
+  // Maybe make this into a more complex object to be able to keep track of paginated results. E.g. instead of a vector
   // of Shards make it into a std::vector<std::pair<Shard, PaginatedResultType>> (probably a struct instead of a pair)
   // where PaginatedResultType is an enum signaling the progress on the given request. This way we can easily check if
   // a partial response on a shard(if there is one) is finished and we can send off the request for the next batch.
-  std::vector<Shard> shard_cache;
   // 1-1 mapping with `shard_cache`.
   // A vector that tracks request metadata for each shard (For example, next_id for a ScanAll on Shard A)
-  std::vector<TRequest> requests;
+  std::vector<ShardRequestState<TRequest>> requests;
   State state = INITIALIZING;
 };
 
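ShardRequestState bundles what used to live in the parallel `shard_cache` and `requests` vectors, whose indices had to be erased in lockstep (the "1-1 mapping" the old comments warn about). A minimal sketch of the difference, with simplified stand-in types (editor's illustration, not code from this commit):

    #include <optional>
    #include <vector>

    struct Shard {};
    struct Request {};

    // Before: two vectors that must stay index-aligned; erasing from one
    // without the other silently corrupts the mapping.
    std::vector<Shard> shard_cache;
    std::vector<Request> requests_old;

    // After: one element per in-flight request, so a single erase keeps the
    // shard, the payload, and the async token together.
    struct ShardRequestStateSketch {
      Shard shard;
      Request request;
      std::optional<size_t> async_request_token;  // stands in for AsyncRequestToken
    };
    std::vector<ShardRequestStateSketch> requests_new;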
@@ -259,8 +265,8 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     };
 
     std::map<Shard, PaginatedResponseState> paginated_response_tracker;
-    for (const auto &shard : state.shard_cache) {
-      paginated_response_tracker.insert(std::make_pair(shard, PaginatedResponseState::Pending));
+    for (const auto &request : state.requests) {
+      paginated_response_tracker.insert(std::make_pair(request.shard, PaginatedResponseState::Pending));
     }
 
     do {
@@ -278,15 +284,14 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     MG_ASSERT(!new_vertices.empty());
     MaybeInitializeExecutionState(state, new_vertices);
     std::vector<CreateVerticesResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
 
     // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
+    SendAllRequests(state);
 
     // 2. Block untill all the futures are exhausted
     do {
       AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
+    } while (!state.requests.empty());
 
     MaybeCompleteState(state);
     // TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
@@ -299,11 +304,9 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     MG_ASSERT(!new_edges.empty());
     MaybeInitializeExecutionState(state, new_edges);
     std::vector<CreateExpandResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-    size_t id{0};
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-      WriteRequests req = state.requests[id];
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      WriteRequests req = request.request;
       auto write_response_result = storage_client.SendWriteRequest(std::move(req));
       if (write_response_result.HasError()) {
         throw std::runtime_error("CreateVertices request timedout");
@@ -315,9 +318,9 @@ class ShardRequestManager : public ShardRequestManagerInterface {
         throw std::runtime_error("CreateExpand request did not succeed");
       }
       responses.push_back(mapped_response);
-      shard_it = shard_cache_ref.erase(shard_it);
     }
     // We are done with this state
+    state.requests.clear();
     MaybeCompleteState(state);
     return responses;
   }
@@ -330,15 +333,14 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     // must be fetched again with an ExpandOne(Edges.dst)
     MaybeInitializeExecutionState(state, std::move(request));
     std::vector<ExpandOneResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
 
     // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
+    SendAllRequests(state);
 
     // 2. Block untill all the futures are exhausted
     do {
       AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
+    } while (!state.requests.empty());
     std::vector<ExpandOneResultRow> result_rows;
     const auto total_row_count = std::accumulate(
         responses.begin(), responses.end(), 0,
@@ -402,13 +404,17 @@ class ShardRequestManager : public ShardRequestManagerInterface {
       if (!per_shard_request_table.contains(shard)) {
         CreateVerticesRequest create_v_rqst{.transaction_id = transaction_id_};
         per_shard_request_table.insert(std::pair(shard, std::move(create_v_rqst)));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex));
     }
 
-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<CreateVerticesRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
     state.state = ExecutionState<CreateVerticesRequest>::EXECUTING;
   }
@@ -445,8 +451,12 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     }
 
     for (auto &[shard, request] : per_shard_request_table) {
-      state.shard_cache.push_back(shard);
-      state.requests.push_back(std::move(request));
+      ShardRequestState<CreateExpandRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
     state.state = ExecutionState<CreateExpandRequest>::EXECUTING;
   }
@@ -470,11 +480,18 @@ class ShardRequestManager : public ShardRequestManagerInterface {
     for (auto &shards : multi_shards) {
       for (auto &[key, shard] : shards) {
         MG_ASSERT(!shard.empty());
-        state.shard_cache.push_back(std::move(shard));
-        ScanVerticesRequest rqst;
-        rqst.transaction_id = transaction_id_;
-        rqst.start_id.second = storage::conversions::ConvertValueVector(key);
-        state.requests.push_back(std::move(rqst));
+
+        ScanVerticesRequest request;
+        request.transaction_id = transaction_id_;
+        request.start_id.second = storage::conversions::ConvertValueVector(key);
+
+        ShardRequestState<ScanVerticesRequest> shard_request_state{
+            .shard = shard,
+            .request = std::move(request),
+            .async_request_token = std::nullopt,
+        };
+
+        state.requests.emplace_back(std::move(shard_request_state));
       }
     }
     state.state = ExecutionState<ScanVerticesRequest>::EXECUTING;
@@ -497,13 +514,18 @@ class ShardRequestManager : public ShardRequestManagerInterface {
           shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
       if (!per_shard_request_table.contains(shard)) {
         per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].src_vertices.push_back(vertex);
     }
 
-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<ExpandOneRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+
+      state.requests.emplace_back(std::move(shard_request_state));
     }
     state.state = ExecutionState<ExpandOneRequest>::EXECUTING;
   }
@@ -533,65 +555,46 @@ class ShardRequestManager : public ShardRequestManagerInterface {
   }
 
   void SendAllRequests(ExecutionState<ScanVerticesRequest> &state) {
-    int64_t shard_idx = 0;
-    for (const auto &request : state.requests) {
-      const auto &current_shard = state.shard_cache[shard_idx];
+    for (auto &request : state.requests) {
+      const auto &current_shard = request.shard;
 
       auto &storage_client = GetStorageClientForShard(current_shard);
-      ReadRequests req = request;
-      storage_client.SendAsyncReadRequest(request);
+      ReadRequests req = request.request;
 
-      ++shard_idx;
+      request.async_request_token = storage_client.SendAsyncReadRequest(request.request);
     }
   }
 
-  void SendAllRequests(ExecutionState<CreateVerticesRequest> &state,
-                       std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
-    size_t id = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
-      // This is fine because all new_vertices of each request end up on the same shard
-      const auto labels = state.requests[id].new_vertices[0].label_ids;
-      auto req_deep_copy = state.requests[id];
+  void SendAllRequests(ExecutionState<CreateVerticesRequest> &state) {
+    for (auto &request : state.requests) {
+      auto req_deep_copy = request.request;
 
       for (auto &new_vertex : req_deep_copy.new_vertices) {
         new_vertex.label_ids.erase(new_vertex.label_ids.begin());
       }
 
-      auto &storage_client = GetStorageClientForShard(*shard_it);
+      auto &storage_client = GetStorageClientForShard(request.shard);
 
       WriteRequests req = req_deep_copy;
-      storage_client.SendAsyncWriteRequest(req);
-      ++id;
+      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
    }
  }
 
-  void SendAllRequests(ExecutionState<ExpandOneRequest> &state,
-                       std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
-    size_t id = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-      ReadRequests req = state.requests[id];
-      storage_client.SendAsyncReadRequest(req);
-      ++id;
+  void SendAllRequests(ExecutionState<ExpandOneRequest> &state) {
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      ReadRequests req = request.request;
+      request.async_request_token = storage_client.SendAsyncReadRequest(req);
     }
   }
 
   void AwaitOnResponses(ExecutionState<CreateVerticesRequest> &state, std::vector<CreateVerticesResponse> &responses) {
-    auto &shard_cache_ref = state.shard_cache;
-    int64_t request_idx = 0;
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
 
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      // This is fine because all new_vertices of each request end up on the same shard
-      const auto labels = state.requests[request_idx].new_vertices[0].label_ids;
-
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto poll_result = storage_client.AwaitAsyncWriteRequest();
-      if (!poll_result) {
-        ++shard_it;
-        ++request_idx;
-
-        continue;
+      auto poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
+      while (!poll_result) {
+        poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
       }
 
       if (poll_result->HasError()) {
@@ -605,26 +608,17 @@ class ShardRequestManager : public ShardRequestManagerInterface {
         throw std::runtime_error("CreateVertices request did not succeed");
       }
       responses.push_back(response);
-
-      shard_it = shard_cache_ref.erase(shard_it);
-      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-      auto it = state.requests.begin() + request_idx;
-      state.requests.erase(it);
     }
+    state.requests.clear();
   }
 
   void AwaitOnResponses(ExecutionState<ExpandOneRequest> &state, std::vector<ExpandOneResponse> &responses) {
-    auto &shard_cache_ref = state.shard_cache;
-    int64_t request_idx = 0;
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
 
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto poll_result = storage_client.PollAsyncReadRequest();
-      if (!poll_result) {
-        ++shard_it;
-        ++request_idx;
-
-        continue;
+      auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
+      while (!poll_result) {
+        poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
       }
 
       if (poll_result->HasError()) {
@@ -642,36 +636,28 @@ class ShardRequestManager : public ShardRequestManagerInterface {
       }
 
       responses.push_back(std::move(response));
-      shard_it = shard_cache_ref.erase(shard_it);
-      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-      auto it = state.requests.begin() + request_idx;
-      state.requests.erase(it);
     }
+    state.requests.clear();
   }
 
   void AwaitOnPaginatedRequests(ExecutionState<ScanVerticesRequest> &state,
                                 std::vector<ScanVerticesResponse> &responses,
                                 std::map<Shard, PaginatedResponseState> &paginated_response_tracker) {
-    auto &shard_cache_ref = state.shard_cache;
+    std::vector<int> to_erase{};
 
-    // Find the first request that is not holding a paginated response.
-    int64_t request_idx = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      if (paginated_response_tracker.at(*shard_it) != PaginatedResponseState::Pending) {
-        ++shard_it;
-        ++request_idx;
+    for (int i = 0; i < state.requests.size(); i++) {
+      auto &request = state.requests[i];
+      // only operate on paginated requests
+      if (paginated_response_tracker.at(request.shard) != PaginatedResponseState::Pending) {
         continue;
       }
 
-      auto &storage_client = GetStorageClientForShard(*shard_it);
+      auto &storage_client = GetStorageClientForShard(request.shard);
 
-      auto await_result = storage_client.AwaitAsyncReadRequest();
-      if (!await_result) {
-        // Redirection has occured.
-        ++shard_it;
-        ++request_idx;
-
-        continue;
+      // drive it to completion
+      auto await_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
+      while (!await_result) {
+        await_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
      }
 
       if (await_result->HasError()) {
@@ -685,17 +671,22 @@ class ShardRequestManager : public ShardRequestManagerInterface {
       }
 
       if (!response.next_start_id) {
-        paginated_response_tracker.erase((*shard_it));
-        shard_cache_ref.erase(shard_it);
-        // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-        auto it = state.requests.begin() + request_idx;
-        state.requests.erase(it);
+        paginated_response_tracker.erase(request.shard);
+        to_erase.push_back(i);
 
       } else {
-        state.requests[request_idx].start_id.second = response.next_start_id->second;
-        paginated_response_tracker[*shard_it] = PaginatedResponseState::PartiallyFinished;
+        request.request.start_id.second = response.next_start_id->second;
+        paginated_response_tracker[request.shard] = PaginatedResponseState::PartiallyFinished;
       }
 
       responses.push_back(std::move(response));
+
+      // reverse sort to_erase to remove requests in reverse order for correctness
+      std::sort(to_erase.begin(), to_erase.end(), std::greater<>());
+
+      auto requests_begin = state.requests.begin();
+      for (int i : to_erase) {
+        state.requests.erase(requests_begin + i);
+      }
     }
   }
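The last hunk collects exhausted request indices in `to_erase` and removes them in descending order. A small self-contained demonstration of why the reverse sort matters (editor's illustration, not code from this commit):

    #include <algorithm>
    #include <cassert>
    #include <functional>
    #include <vector>

    int main() {
      std::vector<int> requests{10, 20, 30, 40};
      std::vector<int> to_erase{1, 3};  // indices to drop

      // Erasing index 1 first would shift the element at index 3 down to
      // index 2, so erasing in descending order keeps the collected
      // indices valid throughout.
      std::sort(to_erase.begin(), to_erase.end(), std::greater<>());
      for (int i : to_erase) {
        requests.erase(requests.begin() + i);
      }
      assert((requests == std::vector<int>{10, 30}));
    }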