Make ShardRequestManager work with futures (#588)
The communication between the ShardRequestManager and the RsmClient used to be direct. In this PR this changes into a future-based communication type. The RsmClient stores state about the currently processed future (either read or write request) and exposes blocking and non-blocking functionality to obtain the filled future. The ShardRequestManager -for now- will send of the set of requests present in the ExecutionState and block on each of them until the requests are completed or the set of paginated responses(caused by, for example the batch-limit in ScanAll) are ready for the next round.
This commit is contained in:
parent
5347c06d76
commit
d06132cb33
@ -148,6 +148,14 @@ class Future {
|
||||
old.consumed_or_moved_ = true;
|
||||
}
|
||||
|
||||
Future &operator=(Future &&old) noexcept {
|
||||
MG_ASSERT(!old.consumed_or_moved_, "Future moved from after already being moved from or consumed.");
|
||||
shared_ = std::move(old.shared_);
|
||||
consumed_or_moved_ = old.consumed_or_moved_;
|
||||
old.consumed_or_moved_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
Future(const Future &) = delete;
|
||||
Future &operator=(const Future &) = delete;
|
||||
~Future() = default;
|
||||
|
@ -13,9 +13,11 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "io/address.hpp"
|
||||
#include "io/errors.hpp"
|
||||
#include "io/rsm/raft.hpp"
|
||||
#include "utils/result.hpp"
|
||||
|
||||
@ -43,21 +45,36 @@ class RsmClient {
|
||||
Address leader_;
|
||||
ServerPool server_addrs_;
|
||||
|
||||
/// State for single async read/write operations. In the future this could become a map
|
||||
/// of async operations that can be accessed via an ID etc...
|
||||
std::optional<Time> async_read_before_;
|
||||
std::optional<ResponseFuture<ReadResponse<ReadResponseT>>> async_read_;
|
||||
ReadRequestT current_read_request_;
|
||||
|
||||
std::optional<Time> async_write_before_;
|
||||
std::optional<ResponseFuture<WriteResponse<WriteResponseT>>> async_write_;
|
||||
WriteRequestT current_write_request_;
|
||||
|
||||
void SelectRandomLeader() {
|
||||
std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
|
||||
size_t addr_index = io_.Rand(addr_distrib);
|
||||
leader_ = server_addrs_[addr_index];
|
||||
|
||||
spdlog::debug(
|
||||
"client NOT redirected to leader server despite our success failing to be processed (it probably was sent to "
|
||||
"a RSM Candidate) trying a random one at index {} with address {}",
|
||||
addr_index, leader_.ToString());
|
||||
}
|
||||
|
||||
template <typename ResponseT>
|
||||
void PossiblyRedirectLeader(const ResponseT &response) {
|
||||
if (response.retry_leader) {
|
||||
MG_ASSERT(!response.success, "retry_leader should never be set for successful responses");
|
||||
leader_ = response.retry_leader.value();
|
||||
spdlog::debug("client redirected to leader server {}", leader_.ToString());
|
||||
} else if (!response.success) {
|
||||
std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
|
||||
size_t addr_index = io_.Rand(addr_distrib);
|
||||
leader_ = server_addrs_[addr_index];
|
||||
|
||||
spdlog::debug(
|
||||
"client NOT redirected to leader server despite our success failing to be processed (it probably was sent to "
|
||||
"a RSM Candidate) trying a random one at index {} with address {}",
|
||||
addr_index, leader_.ToString());
|
||||
}
|
||||
if (!response.success) {
|
||||
SelectRandomLeader();
|
||||
}
|
||||
}
|
||||
|
||||
@ -65,7 +82,13 @@ class RsmClient {
|
||||
RsmClient(Io<IoImpl> io, Address leader, ServerPool server_addrs)
|
||||
: io_{io}, leader_{leader}, server_addrs_{server_addrs} {}
|
||||
|
||||
RsmClient(const RsmClient &) = delete;
|
||||
RsmClient &operator=(const RsmClient &) = delete;
|
||||
RsmClient(RsmClient &&) noexcept = default;
|
||||
RsmClient &operator=(RsmClient &&) noexcept = default;
|
||||
|
||||
RsmClient() = delete;
|
||||
~RsmClient() = default;
|
||||
|
||||
BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) {
|
||||
WriteRequest<WriteRequestT> client_req;
|
||||
@ -131,6 +154,117 @@ class RsmClient {
|
||||
|
||||
return TimedOut{};
|
||||
}
|
||||
|
||||
/// AsyncRead methods
|
||||
void SendAsyncReadRequest(const ReadRequestT &req) {
|
||||
MG_ASSERT(!async_read_);
|
||||
|
||||
ReadRequest<ReadRequestT> read_req = {.operation = req};
|
||||
|
||||
if (!async_read_before_) {
|
||||
async_read_before_ = io_.Now();
|
||||
}
|
||||
current_read_request_ = std::move(req);
|
||||
async_read_ = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
|
||||
}
|
||||
|
||||
std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest() {
|
||||
MG_ASSERT(async_read_);
|
||||
|
||||
if (!async_read_->IsReady()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return AwaitAsyncReadRequest();
|
||||
}
|
||||
|
||||
std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest() {
|
||||
ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(*async_read_).Wait();
|
||||
async_read_.reset();
|
||||
|
||||
const Duration overall_timeout = io_.GetDefaultTimeout();
|
||||
const bool past_time_out = io_.Now() < *async_read_before_ + overall_timeout;
|
||||
const bool result_has_error = get_response_result.HasError();
|
||||
|
||||
if (result_has_error && past_time_out) {
|
||||
// TODO static assert the exact type of error.
|
||||
spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
|
||||
async_read_before_ = std::nullopt;
|
||||
return TimedOut{};
|
||||
}
|
||||
if (!result_has_error) {
|
||||
ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
|
||||
ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
|
||||
|
||||
PossiblyRedirectLeader(read_get_response);
|
||||
|
||||
if (read_get_response.success) {
|
||||
async_read_before_ = std::nullopt;
|
||||
return std::move(read_get_response.read_return);
|
||||
}
|
||||
SendAsyncReadRequest(current_read_request_);
|
||||
} else if (result_has_error) {
|
||||
SelectRandomLeader();
|
||||
SendAsyncReadRequest(current_read_request_);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/// AsyncWrite methods
|
||||
void SendAsyncWriteRequest(const WriteRequestT &req) {
|
||||
MG_ASSERT(!async_write_);
|
||||
|
||||
WriteRequest<WriteRequestT> write_req = {.operation = req};
|
||||
|
||||
if (!async_write_before_) {
|
||||
async_write_before_ = io_.Now();
|
||||
}
|
||||
current_write_request_ = std::move(req);
|
||||
async_write_ = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
|
||||
}
|
||||
|
||||
std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest() {
|
||||
MG_ASSERT(async_write_);
|
||||
|
||||
if (!async_write_->IsReady()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return AwaitAsyncWriteRequest();
|
||||
}
|
||||
|
||||
std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest() {
|
||||
ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(*async_write_).Wait();
|
||||
async_write_.reset();
|
||||
|
||||
const Duration overall_timeout = io_.GetDefaultTimeout();
|
||||
const bool past_time_out = io_.Now() < *async_write_before_ + overall_timeout;
|
||||
const bool result_has_error = get_response_result.HasError();
|
||||
|
||||
if (result_has_error && past_time_out) {
|
||||
// TODO static assert the exact type of error.
|
||||
spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
|
||||
async_write_before_ = std::nullopt;
|
||||
return TimedOut{};
|
||||
}
|
||||
if (!result_has_error) {
|
||||
ResponseEnvelope<WriteResponse<WriteResponseT>> &&get_response_envelope =
|
||||
std::move(get_response_result.GetValue());
|
||||
WriteResponse<WriteResponseT> &&write_get_response = std::move(get_response_envelope.message);
|
||||
|
||||
PossiblyRedirectLeader(write_get_response);
|
||||
|
||||
if (write_get_response.success) {
|
||||
async_write_before_ = std::nullopt;
|
||||
return std::move(write_get_response.write_return);
|
||||
}
|
||||
SendAsyncWriteRequest(current_write_request_);
|
||||
} else if (result_has_error) {
|
||||
SelectRandomLeader();
|
||||
SendAsyncWriteRequest(current_write_request_);
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace memgraph::io::rsm
|
||||
|
@ -389,25 +389,7 @@ class DistributedScanAllAndFilterCursor : public Cursor {
|
||||
current_batch.clear();
|
||||
current_vertex_it = current_batch.end();
|
||||
request_state_ = msgs::ExecutionState<msgs::ScanVerticesRequest>{};
|
||||
|
||||
auto request = msgs::ScanVerticesRequest{};
|
||||
if (label_.has_value()) {
|
||||
request.label = msgs::Label{.id = label_.value()};
|
||||
}
|
||||
if (property_expression_pair_.has_value()) {
|
||||
request.property_expression_pair = std::make_pair(
|
||||
property_expression_pair_.value().first,
|
||||
expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(property_expression_pair_.value().second));
|
||||
}
|
||||
if (filter_expressions_.has_value()) {
|
||||
auto res = std::vector<std::string>{};
|
||||
res.reserve(filter_expressions_->size());
|
||||
std::transform(filter_expressions_->begin(), filter_expressions_->end(), std::back_inserter(res),
|
||||
[](auto &filter) { return expr::ExpressiontoStringWhileReplacingNodeAndEdgeSymbols(filter); });
|
||||
|
||||
request.filter_expressions = res;
|
||||
}
|
||||
request_state_.requests.emplace_back(request);
|
||||
request_state_.label = "label";
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
|
@ -473,6 +473,7 @@ struct ExpandOneResultRow {
|
||||
};
|
||||
|
||||
struct ExpandOneResponse {
|
||||
bool success;
|
||||
std::vector<ExpandOneResultRow> result;
|
||||
};
|
||||
|
||||
|
@ -86,6 +86,11 @@ struct ExecutionState {
|
||||
// the ShardRequestManager impl will send requests to. When a request to a shard exhausts it, meaning that
|
||||
// it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
|
||||
// empty, it means that all of the requests have completed succefully.
|
||||
// TODO(gvolfing)
|
||||
// Maybe make this into a more complex object to be able to keep track of paginated resutls. E.g. instead of a vector
|
||||
// of Shards make it into a std::vector<std::pair<Shard, PaginatedResultType>> (probably a struct instead of a pair)
|
||||
// where PaginatedResultType is an enum signaling the progress on the given request. This way we can easily check if
|
||||
// a partial response on a shard(if there is one) is finished and we can send off the request for the next batch.
|
||||
std::vector<Shard> shard_cache;
|
||||
// 1-1 mapping with `shard_cache`.
|
||||
// A vector that tracks request metatdata for each shard (For example, next_id for a ScanAll on Shard A)
|
||||
@ -233,34 +238,22 @@ class ShardRequestManager : public ShardRequestManagerInterface {
|
||||
std::vector<VertexAccessor> Request(ExecutionState<ScanVerticesRequest> &state) override {
|
||||
MaybeInitializeExecutionState(state);
|
||||
std::vector<ScanVerticesResponse> responses;
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
size_t id = 0;
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
|
||||
auto &storage_client = GetStorageClientForShard(
|
||||
*state.label, storage::conversions::ConvertPropertyVector(state.requests[id].start_id.second));
|
||||
// TODO(kostasrim) Currently requests return the result directly. Adjust this when the API works MgFuture
|
||||
// instead.
|
||||
ReadRequests req = state.requests[id];
|
||||
auto read_response_result = storage_client.SendReadRequest(req);
|
||||
// RETRY on timeouts?
|
||||
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
|
||||
if (read_response_result.HasError()) {
|
||||
throw std::runtime_error("ScanAll request timedout");
|
||||
}
|
||||
ReadResponses read_response_variant = read_response_result.GetValue();
|
||||
auto &response = std::get<ScanVerticesResponse>(read_response_variant);
|
||||
if (!response.success) {
|
||||
throw std::runtime_error("ScanAll request did not succeed");
|
||||
}
|
||||
if (!response.next_start_id) {
|
||||
shard_it = shard_cache_ref.erase(shard_it);
|
||||
} else {
|
||||
state.requests[id].start_id.second = response.next_start_id->second;
|
||||
++shard_it;
|
||||
}
|
||||
responses.push_back(std::move(response));
|
||||
|
||||
SendAllRequests(state);
|
||||
auto all_requests_gathered = [](auto &paginated_rsp_tracker) {
|
||||
return std::ranges::all_of(paginated_rsp_tracker, [](const auto &state) {
|
||||
return state.second == PaginatedResponseState::PartiallyFinished;
|
||||
});
|
||||
};
|
||||
|
||||
std::map<Shard, PaginatedResponseState> paginated_response_tracker;
|
||||
for (const auto &shard : state.shard_cache) {
|
||||
paginated_response_tracker.insert(std::make_pair(shard, PaginatedResponseState::Pending));
|
||||
}
|
||||
// We are done with this state
|
||||
do {
|
||||
AwaitOnPaginatedRequests(state, responses, paginated_response_tracker);
|
||||
} while (!all_requests_gathered(paginated_response_tracker));
|
||||
|
||||
MaybeCompleteState(state);
|
||||
// TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
|
||||
// result of storage_client.SendReadRequest()).
|
||||
@ -273,32 +266,15 @@ class ShardRequestManager : public ShardRequestManagerInterface {
|
||||
MaybeInitializeExecutionState(state, new_vertices);
|
||||
std::vector<CreateVerticesResponse> responses;
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
size_t id = 0;
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
|
||||
// This is fine because all new_vertices of each request end up on the same shard
|
||||
const auto labels = state.requests[id].new_vertices[0].label_ids;
|
||||
for (auto &new_vertex : state.requests[id].new_vertices) {
|
||||
new_vertex.label_ids.erase(new_vertex.label_ids.begin());
|
||||
}
|
||||
auto primary_key = state.requests[id].new_vertices[0].primary_key;
|
||||
auto &storage_client = GetStorageClientForShard(*shard_it, labels[0].id);
|
||||
WriteRequests req = state.requests[id];
|
||||
auto write_response_result = storage_client.SendWriteRequest(req);
|
||||
// RETRY on timeouts?
|
||||
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
|
||||
if (write_response_result.HasError()) {
|
||||
throw std::runtime_error("CreateVertices request timedout");
|
||||
}
|
||||
WriteResponses response_variant = write_response_result.GetValue();
|
||||
CreateVerticesResponse mapped_response = std::get<CreateVerticesResponse>(response_variant);
|
||||
|
||||
if (!mapped_response.success) {
|
||||
throw std::runtime_error("CreateVertices request did not succeed");
|
||||
}
|
||||
responses.push_back(mapped_response);
|
||||
shard_it = shard_cache_ref.erase(shard_it);
|
||||
}
|
||||
// We are done with this state
|
||||
// 1. Send the requests.
|
||||
SendAllRequests(state, shard_cache_ref);
|
||||
|
||||
// 2. Block untill all the futures are exhausted
|
||||
do {
|
||||
AwaitOnResponses(state, responses);
|
||||
} while (!state.shard_cache.empty());
|
||||
|
||||
MaybeCompleteState(state);
|
||||
// TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
|
||||
// result of storage_client.SendReadRequest()).
|
||||
@ -314,25 +290,22 @@ class ShardRequestManager : public ShardRequestManagerInterface {
|
||||
MaybeInitializeExecutionState(state);
|
||||
std::vector<ExpandOneResponse> responses;
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
size_t id = 0;
|
||||
// pending_requests on shards
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
|
||||
const Label primary_label = state.requests[id].src_vertices[0].first;
|
||||
auto &storage_client = GetStorageClientForShard(*shard_it, primary_label.id);
|
||||
ReadRequests req = state.requests[id];
|
||||
auto read_response_result = storage_client.SendReadRequest(req);
|
||||
// RETRY on timeouts?
|
||||
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map
|
||||
if (read_response_result.HasError()) {
|
||||
throw std::runtime_error("ExpandOne request timedout");
|
||||
}
|
||||
auto &response = std::get<ExpandOneResponse>(read_response_result.GetValue());
|
||||
responses.push_back(std::move(response));
|
||||
}
|
||||
|
||||
// 1. Send the requests.
|
||||
SendAllRequests(state, shard_cache_ref);
|
||||
|
||||
// 2. Block untill all the futures are exhausted
|
||||
do {
|
||||
AwaitOnResponses(state, responses);
|
||||
} while (!state.shard_cache.empty());
|
||||
|
||||
MaybeCompleteState(state);
|
||||
return responses;
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PaginatedResponseState { Pending, PartiallyFinished };
|
||||
|
||||
std::vector<VertexAccessor> PostProcess(std::vector<ScanVerticesResponse> &&responses) const {
|
||||
std::vector<VertexAccessor> accessors;
|
||||
for (auto &response : responses) {
|
||||
@ -467,6 +440,170 @@ class ShardRequestManager : public ShardRequestManagerInterface {
|
||||
storage_cli_manager_.AddClient(label_id, target_shard, std::move(cli));
|
||||
}
|
||||
|
||||
void SendAllRequests(ExecutionState<ScanVerticesRequest> &state) {
|
||||
for (const auto &request : state.requests) {
|
||||
auto &storage_client =
|
||||
GetStorageClientForShard(*state.label, storage::conversions::ConvertPropertyVector(request.start_id.second));
|
||||
ReadRequests req = request;
|
||||
storage_client.SendAsyncReadRequest(request);
|
||||
}
|
||||
}
|
||||
|
||||
void SendAllRequests(ExecutionState<CreateVerticesRequest> &state,
|
||||
std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
|
||||
size_t id = 0;
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
|
||||
// This is fine because all new_vertices of each request end up on the same shard
|
||||
const auto labels = state.requests[id].new_vertices[0].label_ids;
|
||||
auto req_deep_copy = state.requests[id];
|
||||
|
||||
for (auto &new_vertex : req_deep_copy.new_vertices) {
|
||||
new_vertex.label_ids.erase(new_vertex.label_ids.begin());
|
||||
}
|
||||
|
||||
auto &storage_client = GetStorageClientForShard(*shard_it, labels[0].id);
|
||||
|
||||
WriteRequests req = req_deep_copy;
|
||||
storage_client.SendAsyncWriteRequest(req);
|
||||
++id;
|
||||
}
|
||||
}
|
||||
|
||||
void SendAllRequests(ExecutionState<ExpandOneRequest> &state,
|
||||
std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
|
||||
size_t id = 0;
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
|
||||
const Label primary_label = state.requests[id].src_vertices[0].first;
|
||||
auto &storage_client = GetStorageClientForShard(*shard_it, primary_label.id);
|
||||
ReadRequests req = state.requests[id];
|
||||
storage_client.SendAsyncReadRequest(req);
|
||||
}
|
||||
}
|
||||
|
||||
void AwaitOnResponses(ExecutionState<CreateVerticesRequest> &state, std::vector<CreateVerticesResponse> &responses) {
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
int64_t request_idx = 0;
|
||||
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
|
||||
// This is fine because all new_vertices of each request end up on the same shard
|
||||
const auto labels = state.requests[request_idx].new_vertices[0].label_ids;
|
||||
|
||||
auto &storage_client = GetStorageClientForShard(*shard_it, labels[0].id);
|
||||
|
||||
auto poll_result = storage_client.AwaitAsyncWriteRequest();
|
||||
if (!poll_result) {
|
||||
++shard_it;
|
||||
++request_idx;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if (poll_result->HasError()) {
|
||||
throw std::runtime_error("CreateVertices request timed out");
|
||||
}
|
||||
|
||||
WriteResponses response_variant = poll_result->GetValue();
|
||||
auto response = std::get<CreateVerticesResponse>(response_variant);
|
||||
|
||||
if (!response.success) {
|
||||
throw std::runtime_error("CreateVertices request did not succeed");
|
||||
}
|
||||
responses.push_back(response);
|
||||
|
||||
shard_it = shard_cache_ref.erase(shard_it);
|
||||
// Needed to maintain the 1-1 mapping between the ShardCache and the requests.
|
||||
auto it = state.requests.begin() + request_idx;
|
||||
state.requests.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void AwaitOnResponses(ExecutionState<ExpandOneRequest> &state, std::vector<ExpandOneResponse> &responses) {
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
int64_t request_idx = 0;
|
||||
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++request_idx) {
|
||||
auto &storage_client = GetStorageClientForShard(
|
||||
*state.label,
|
||||
storage::conversions::ConvertPropertyVector(state.requests[request_idx].src_vertices[0].second));
|
||||
|
||||
auto poll_result = storage_client.PollAsyncReadRequest();
|
||||
if (!poll_result) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (poll_result->HasError()) {
|
||||
throw std::runtime_error("ExpandOne request timed out");
|
||||
}
|
||||
|
||||
ReadResponses response_variant = poll_result->GetValue();
|
||||
auto response = std::get<ExpandOneResponse>(response_variant);
|
||||
// -NOTE-
|
||||
// Currently a boolean flag for signaling the overall success of the
|
||||
// ExpandOne request does not exist. But it should, so here we assume
|
||||
// that it is already in place.
|
||||
if (!response.success) {
|
||||
throw std::runtime_error("ExpandOne request did not succeed");
|
||||
}
|
||||
|
||||
responses.push_back(std::move(response));
|
||||
shard_it = shard_cache_ref.erase(shard_it);
|
||||
// Needed to maintain the 1-1 mapping between the ShardCache and the requests.
|
||||
auto it = state.requests.begin() + request_idx;
|
||||
state.requests.erase(it);
|
||||
--request_idx;
|
||||
}
|
||||
}
|
||||
|
||||
void AwaitOnPaginatedRequests(ExecutionState<ScanVerticesRequest> &state,
|
||||
std::vector<ScanVerticesResponse> &responses,
|
||||
std::map<Shard, PaginatedResponseState> &paginated_response_tracker) {
|
||||
auto &shard_cache_ref = state.shard_cache;
|
||||
|
||||
// Find the first request that is not holding a paginated response.
|
||||
int64_t request_idx = 0;
|
||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
|
||||
if (paginated_response_tracker.at(*shard_it) != PaginatedResponseState::Pending) {
|
||||
++shard_it;
|
||||
++request_idx;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto &storage_client = GetStorageClientForShard(
|
||||
*state.label, storage::conversions::ConvertPropertyVector(state.requests[request_idx].start_id.second));
|
||||
auto await_result = storage_client.AwaitAsyncReadRequest();
|
||||
|
||||
if (!await_result) {
|
||||
// Redirection has occured.
|
||||
++shard_it;
|
||||
++request_idx;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (await_result->HasError()) {
|
||||
throw std::runtime_error("ScanAll request timed out");
|
||||
}
|
||||
|
||||
ReadResponses read_response_variant = await_result->GetValue();
|
||||
auto response = std::get<ScanVerticesResponse>(read_response_variant);
|
||||
if (!response.success) {
|
||||
throw std::runtime_error("ScanAll request did not succeed");
|
||||
}
|
||||
|
||||
if (!response.next_start_id) {
|
||||
paginated_response_tracker.erase((*shard_it));
|
||||
shard_cache_ref.erase(shard_it);
|
||||
// Needed to maintain the 1-1 mapping between the ShardCache and the requests.
|
||||
auto it = state.requests.begin() + request_idx;
|
||||
state.requests.erase(it);
|
||||
|
||||
} else {
|
||||
state.requests[request_idx].start_id.second = response.next_start_id->second;
|
||||
paginated_response_tracker[*shard_it] = PaginatedResponseState::PartiallyFinished;
|
||||
}
|
||||
responses.push_back(std::move(response));
|
||||
}
|
||||
}
|
||||
|
||||
ShardMap shards_map_;
|
||||
CoordinatorClient coord_cli_;
|
||||
RsmStorageClientManager<StorageClient> storage_cli_manager_;
|
||||
|
@ -131,15 +131,15 @@ ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Add
|
||||
return sm;
|
||||
}
|
||||
|
||||
std::optional<ShardClient> DetermineShardLocation(const Shard &target_shard, const std::vector<Address> &a_addrs,
|
||||
ShardClient a_client, const std::vector<Address> &b_addrs,
|
||||
ShardClient b_client) {
|
||||
std::optional<ShardClient *> DetermineShardLocation(const Shard &target_shard, const std::vector<Address> &a_addrs,
|
||||
ShardClient &a_client, const std::vector<Address> &b_addrs,
|
||||
ShardClient &b_client) {
|
||||
for (const auto &addr : target_shard) {
|
||||
if (addr.address == b_addrs[0]) {
|
||||
return b_client;
|
||||
return &b_client;
|
||||
}
|
||||
if (addr.address == a_addrs[0]) {
|
||||
return a_client;
|
||||
return &a_client;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
@ -313,7 +313,7 @@ int main() {
|
||||
storage_req.value = 1000;
|
||||
storage_req.transaction_id = transaction_id;
|
||||
|
||||
auto write_response_result = storage_client.SendWriteRequest(storage_req);
|
||||
auto write_response_result = storage_client->SendWriteRequest(storage_req);
|
||||
if (write_response_result.HasError()) {
|
||||
// timed out
|
||||
continue;
|
||||
@ -333,7 +333,7 @@ int main() {
|
||||
storage_get_req.key = compound_key;
|
||||
storage_get_req.transaction_id = transaction_id;
|
||||
|
||||
auto get_response_result = storage_client.SendReadRequest(storage_get_req);
|
||||
auto get_response_result = storage_client->SendReadRequest(storage_get_req);
|
||||
if (get_response_result.HasError()) {
|
||||
// timed out
|
||||
continue;
|
||||
|
@ -33,6 +33,10 @@
|
||||
|
||||
namespace memgraph::io::tests {
|
||||
|
||||
static const std::string kLabelName{"test_label"};
|
||||
static const std::string kProperty1{"property_1"};
|
||||
static const std::string kProperty2{"property_2"};
|
||||
|
||||
using memgraph::coordinator::Coordinator;
|
||||
using memgraph::coordinator::CoordinatorClient;
|
||||
using memgraph::coordinator::CoordinatorReadRequests;
|
||||
@ -63,13 +67,11 @@ using ShardClient = RsmClient<LocalTransport, WriteRequests, WriteResponses, Rea
|
||||
ShardMap TestShardMap() {
|
||||
ShardMap sm{};
|
||||
|
||||
const std::string label_name = std::string("test_label");
|
||||
|
||||
// register new properties
|
||||
const std::vector<std::string> property_names = {"property_1", "property_2"};
|
||||
const std::vector<std::string> property_names = {kProperty1, kProperty2};
|
||||
const auto properties = sm.AllocatePropertyIds(property_names);
|
||||
const auto property_id_1 = properties.at("property_1");
|
||||
const auto property_id_2 = properties.at("property_2");
|
||||
const auto property_id_1 = properties.at(kProperty1);
|
||||
const auto property_id_2 = properties.at(kProperty2);
|
||||
const auto type_1 = memgraph::common::SchemaType::INT;
|
||||
const auto type_2 = memgraph::common::SchemaType::INT;
|
||||
|
||||
@ -81,8 +83,8 @@ ShardMap TestShardMap() {
|
||||
|
||||
const size_t replication_factor = 1;
|
||||
|
||||
std::optional<LabelId> label_id = sm.InitializeNewLabel(label_name, schema, replication_factor, sm.shard_map_version);
|
||||
MG_ASSERT(label_id);
|
||||
const auto label_id = sm.InitializeNewLabel(kLabelName, schema, replication_factor, sm.shard_map_version);
|
||||
EXPECT_TRUE(label_id.has_value());
|
||||
|
||||
// split the shard at N split points
|
||||
// NB: this is the logic that should be provided by the "split file"
|
||||
@ -98,9 +100,9 @@ ShardMap TestShardMap() {
|
||||
|
||||
const CompoundKey split_point = {key1, key2};
|
||||
|
||||
const bool split_success = sm.SplitShard(sm.shard_map_version, label_id.value(), split_point);
|
||||
const auto split_success = sm.SplitShard(sm.shard_map_version, label_id.value(), split_point);
|
||||
|
||||
MG_ASSERT(split_success);
|
||||
EXPECT_TRUE(split_success);
|
||||
}
|
||||
|
||||
return sm;
|
||||
@ -108,10 +110,10 @@ ShardMap TestShardMap() {
|
||||
|
||||
template <typename ShardRequestManager>
|
||||
void TestScanAll(ShardRequestManager &shard_request_manager) {
|
||||
msgs::ExecutionState<msgs::ScanVerticesRequest> state{.label = "test_label"};
|
||||
msgs::ExecutionState<msgs::ScanVerticesRequest> state{.label = kLabelName};
|
||||
|
||||
auto result = shard_request_manager.Request(state);
|
||||
MG_ASSERT(result.size() == 2, "{}", result.size());
|
||||
EXPECT_EQ(result.size(), 2);
|
||||
}
|
||||
|
||||
template <typename ShardRequestManager>
|
||||
@ -119,7 +121,7 @@ void TestCreateVertices(ShardRequestManager &shard_request_manager) {
|
||||
using PropVal = msgs::Value;
|
||||
msgs::ExecutionState<msgs::CreateVerticesRequest> state;
|
||||
std::vector<msgs::NewVertex> new_vertices;
|
||||
auto label_id = shard_request_manager.LabelNameToLabelId("test_label");
|
||||
auto label_id = shard_request_manager.LabelNameToLabelId(kLabelName);
|
||||
msgs::NewVertex a1{.primary_key = {PropVal(int64_t(0)), PropVal(int64_t(0))}};
|
||||
a1.label_ids.push_back({label_id});
|
||||
msgs::NewVertex a2{.primary_key = {PropVal(int64_t(13)), PropVal(int64_t(13))}};
|
||||
@ -128,7 +130,7 @@ void TestCreateVertices(ShardRequestManager &shard_request_manager) {
|
||||
new_vertices.push_back(std::move(a2));
|
||||
|
||||
auto result = shard_request_manager.Request(state, std::move(new_vertices));
|
||||
MG_ASSERT(result.size() == 1);
|
||||
EXPECT_EQ(result.size(), 1);
|
||||
}
|
||||
|
||||
template <typename ShardRequestManager>
|
||||
|
Loading…
Reference in New Issue
Block a user