Merge pull request #651 from memgraph/T1157-MG-concurrent-RsmClient-requests
Support concurrent RsmClient requests
commit 53040c6758
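The change below replaces RsmClient's single in-flight async request slot with per-request state stored in maps keyed by an AsyncRequestToken, so a caller can keep several read and write requests outstanding at once; RequestRouter then issues one async request per shard and drives them all to completion. A minimal usage sketch of the token-based client API (illustrative only; `client`, `read_requests`, and `Process` are assumed names, not part of this commit):

    // Start every request up front, then await each one; tokens identify the in-flight requests.
    std::vector<memgraph::io::rsm::AsyncRequestToken> tokens;
    for (const auto &req : read_requests) {
      tokens.push_back(client.SendAsyncReadRequest(req));
    }
    for (const auto &token : tokens) {
      auto result = client.AwaitAsyncReadRequest(token);
      while (!result) {  // nullopt means the request was re-sent (e.g. leader redirect); await again
        result = client.AwaitAsyncReadRequest(token);
      }
      if (result->HasError()) {
        // request timed out (TimedOut)
      } else {
        Process(result->GetValue());
      }
    }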
@@ -14,6 +14,7 @@
 #include <iostream>
 #include <optional>
 #include <type_traits>
+#include <unordered_map>
 #include <vector>

 #include "io/address.hpp"
@@ -36,6 +37,21 @@ using memgraph::io::rsm::WriteRequest;
 using memgraph::io::rsm::WriteResponse;
 using memgraph::utils::BasicResult;

+class AsyncRequestToken {
+  size_t id_;
+
+ public:
+  explicit AsyncRequestToken(size_t id) : id_(id) {}
+  size_t GetId() const { return id_; }
+};
+
+template <typename RequestT, typename ResponseT>
+struct AsyncRequest {
+  Time start_time;
+  RequestT request;
+  ResponseFuture<ResponseT> future;
+};
+
 template <typename IoImpl, typename WriteRequestT, typename WriteResponseT, typename ReadRequestT,
           typename ReadResponseT>
 class RsmClient {
@@ -47,23 +63,17 @@ class RsmClient {

   /// State for single async read/write operations. In the future this could become a map
   /// of async operations that can be accessed via an ID etc...
-  std::optional<Time> async_read_before_;
-  std::optional<ResponseFuture<ReadResponse<ReadResponseT>>> async_read_;
-  ReadRequestT current_read_request_;
+  std::unordered_map<size_t, AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>>> async_reads_;
+  std::unordered_map<size_t, AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>>> async_writes_;

-  std::optional<Time> async_write_before_;
-  std::optional<ResponseFuture<WriteResponse<WriteResponseT>>> async_write_;
-  WriteRequestT current_write_request_;
+  size_t async_token_generator_ = 0;

   void SelectRandomLeader() {
     std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
     size_t addr_index = io_.Rand(addr_distrib);
     leader_ = server_addrs_[addr_index];

-    spdlog::debug(
-        "client NOT redirected to leader server despite our success failing to be processed (it probably was sent to "
-        "a RSM Candidate) trying a random one at index {} with address {}",
-        addr_index, leader_.ToString());
+    spdlog::debug("selecting a random leader at index {} with address {}", addr_index, leader_.ToString());
   }

   template <typename ResponseT>
@@ -91,107 +101,74 @@ class RsmClient {
   ~RsmClient() = default;

   BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) {
-    WriteRequest<WriteRequestT> client_req;
-    client_req.operation = req;
-
-    const Duration overall_timeout = io_.GetDefaultTimeout();
-    const Time before = io_.Now();
-
-    do {
-      spdlog::debug("client sending WriteRequest to Leader {}", leader_.ToString());
-      ResponseFuture<WriteResponse<WriteResponseT>> response_future =
-          io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, client_req);
-      ResponseResult<WriteResponse<WriteResponseT>> response_result = std::move(response_future).Wait();
-
-      if (response_result.HasError()) {
-        spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-        return response_result.GetError();
-      }
-
-      ResponseEnvelope<WriteResponse<WriteResponseT>> &&response_envelope = std::move(response_result.GetValue());
-      WriteResponse<WriteResponseT> &&write_response = std::move(response_envelope.message);
-
-      if (write_response.success) {
-        return std::move(write_response.write_return);
-      }
-
-      PossiblyRedirectLeader(write_response);
-    } while (io_.Now() < before + overall_timeout);
-
-    return TimedOut{};
+    auto token = SendAsyncWriteRequest(req);
+    auto poll_result = AwaitAsyncWriteRequest(token);
+    while (!poll_result) {
+      poll_result = AwaitAsyncWriteRequest(token);
+    }
+    return poll_result.value();
   }

   BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) {
-    ReadRequest<ReadRequestT> read_req;
-    read_req.operation = req;
-
-    const Duration overall_timeout = io_.GetDefaultTimeout();
-    const Time before = io_.Now();
-
-    do {
-      spdlog::debug("client sending ReadRequest to Leader {}", leader_.ToString());
-
-      ResponseFuture<ReadResponse<ReadResponseT>> get_response_future =
-          io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
-
-      // receive response
-      ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(get_response_future).Wait();
-
-      if (get_response_result.HasError()) {
-        spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-        return get_response_result.GetError();
-      }
-
-      ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
-      ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
-
-      if (read_get_response.success) {
-        return std::move(read_get_response.read_return);
-      }
-
-      PossiblyRedirectLeader(read_get_response);
-    } while (io_.Now() < before + overall_timeout);
-
-    return TimedOut{};
+    auto token = SendAsyncReadRequest(req);
+    auto poll_result = AwaitAsyncReadRequest(token);
+    while (!poll_result) {
+      poll_result = AwaitAsyncReadRequest(token);
+    }
+    return poll_result.value();
   }

   /// AsyncRead methods
-  void SendAsyncReadRequest(const ReadRequestT &req) {
-    MG_ASSERT(!async_read_);
+  AsyncRequestToken SendAsyncReadRequest(const ReadRequestT &req) {
+    size_t token = async_token_generator_++;

     ReadRequest<ReadRequestT> read_req = {.operation = req};

-    if (!async_read_before_) {
-      async_read_before_ = io_.Now();
-    }
-    current_read_request_ = std::move(req);
-    async_read_ = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+    AsyncRequest<ReadRequestT, ReadResponse<ReadResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req),
+    };
+
+    async_reads_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken{token};
   }

-  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest() {
-    MG_ASSERT(async_read_);
+  void ResendAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());

-    if (!async_read_->IsReady()) {
+    ReadRequest<ReadRequestT> read_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);
+  }
+
+  std::optional<BasicResult<TimedOut, ReadResponseT>> PollAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }

     return AwaitAsyncReadRequest();
   }

-  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest() {
-    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(*async_read_).Wait();
-    async_read_.reset();
+  std::optional<BasicResult<TimedOut, ReadResponseT>> AwaitAsyncReadRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_reads_.at(token.GetId());
+    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(async_request.future).Wait();

     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_read_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();

     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_read_before_ = std::nullopt;
+      async_reads_.erase(token.GetId());
       return TimedOut{};
     }

     if (!result_has_error) {
       ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue());
       ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message);
@@ -199,54 +176,70 @@ class RsmClient {
       PossiblyRedirectLeader(read_get_response);

       if (read_get_response.success) {
-        async_read_before_ = std::nullopt;
+        async_reads_.erase(token.GetId());
+        spdlog::debug("returning read_return for RSM request");
         return std::move(read_get_response.read_return);
       }
-      SendAsyncReadRequest(current_read_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncReadRequest(current_read_request_);
     }

+    ResendAsyncReadRequest(token);
+
     return std::nullopt;
   }

   /// AsyncWrite methods
-  void SendAsyncWriteRequest(const WriteRequestT &req) {
-    MG_ASSERT(!async_write_);
+  AsyncRequestToken SendAsyncWriteRequest(const WriteRequestT &req) {
+    size_t token = async_token_generator_++;

     WriteRequest<WriteRequestT> write_req = {.operation = req};

-    if (!async_write_before_) {
-      async_write_before_ = io_.Now();
-    }
-    current_write_request_ = std::move(req);
-    async_write_ = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+    AsyncRequest<WriteRequestT, WriteResponse<WriteResponseT>> async_request{
+        .start_time = io_.Now(),
+        .request = std::move(req),
+        .future = io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req),
+    };
+
+    async_writes_.emplace(token, std::move(async_request));
+
+    return AsyncRequestToken{token};
   }

-  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest() {
-    MG_ASSERT(async_write_);
+  void ResendAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());

-    if (!async_write_->IsReady()) {
+    WriteRequest<WriteRequestT> write_req = {.operation = async_request.request};
+
+    async_request.future =
+        io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, write_req);
+  }
+
+  std::optional<BasicResult<TimedOut, WriteResponseT>> PollAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+
+    if (!async_request.future.IsReady()) {
       return std::nullopt;
     }

     return AwaitAsyncWriteRequest();
   }

-  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest() {
-    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(*async_write_).Wait();
-    async_write_.reset();
+  std::optional<BasicResult<TimedOut, WriteResponseT>> AwaitAsyncWriteRequest(const AsyncRequestToken &token) {
+    auto &async_request = async_writes_.at(token.GetId());
+    ResponseResult<WriteResponse<WriteResponseT>> get_response_result = std::move(async_request.future).Wait();

     const Duration overall_timeout = io_.GetDefaultTimeout();
-    const bool past_time_out = io_.Now() < *async_write_before_ + overall_timeout;
+    const bool past_time_out = io_.Now() > async_request.start_time + overall_timeout;
     const bool result_has_error = get_response_result.HasError();

     if (result_has_error && past_time_out) {
       // TODO static assert the exact type of error.
       spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString());
-      async_write_before_ = std::nullopt;
+      async_writes_.erase(token.GetId());
       return TimedOut{};
     }

     if (!result_has_error) {
       ResponseEnvelope<WriteResponse<WriteResponseT>> &&get_response_envelope =
           std::move(get_response_result.GetValue());
@@ -255,14 +248,15 @@ class RsmClient {
       PossiblyRedirectLeader(write_get_response);

       if (write_get_response.success) {
-        async_write_before_ = std::nullopt;
+        async_writes_.erase(token.GetId());
         return std::move(write_get_response.write_return);
       }
-      SendAsyncWriteRequest(current_write_request_);
-    } else if (result_has_error) {
+    } else {
       SelectRandomLeader();
-      SendAsyncWriteRequest(current_write_request_);
     }

+    ResendAsyncWriteRequest(token);
+
     return std::nullopt;
   }
 };

@@ -180,7 +180,7 @@ class DistributedCreateNodeCursor : public Cursor {
     auto &request_router = context.request_router;
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      request_router->Request(state_, NodeCreationInfoToRequest(context, frame));
+      request_router->CreateVertices(NodeCreationInfoToRequest(context, frame));
     }
     PlaceNodeOnTheFrame(frame, context);
     return true;
@@ -191,7 +191,7 @@ class DistributedCreateNodeCursor : public Cursor {

   void Shutdown() override { input_cursor_->Shutdown(); }

-  void Reset() override { state_ = {}; }
+  void Reset() override {}

   void PlaceNodeOnTheFrame(Frame &frame, ExecutionContext &context) {
     // TODO(kostasrim) Make this work with batching
@@ -252,7 +252,6 @@ class DistributedCreateNodeCursor : public Cursor {
   std::vector<const NodeCreationInfo *> nodes_info_;
   std::vector<std::vector<std::pair<storage::v3::PropertyId, msgs::Value>>> src_vertex_props_;
   std::vector<msgs::PrimaryKey> primary_keys_;
-  ExecutionState<msgs::CreateVerticesRequest> state_;
 };

 bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
@@ -365,7 +364,6 @@ class ScanAllCursor : public Cursor {
   std::optional<decltype(vertices_.value().begin())> vertices_it_;
   const char *op_name_;
   std::vector<msgs::ScanVerticesResponse> current_batch;
-  ExecutionState<msgs::ScanVerticesRequest> request_state;
 };

 class DistributedScanAllAndFilterCursor : public Cursor {
@@ -384,14 +382,21 @@ class DistributedScanAllAndFilterCursor : public Cursor {
     ResetExecutionState();
   }

+  enum class State : int8_t { INITIALIZING, COMPLETED };
+
   using VertexAccessor = accessors::VertexAccessor;

   bool MakeRequest(RequestRouterInterface &request_router, ExecutionContext &context) {
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      current_batch = request_router.Request(request_state_);
+      std::optional<std::string> request_label = std::nullopt;
+      if (label_.has_value()) {
+        request_label = request_router.LabelToName(*label_);
+      }
+      current_batch = request_router.ScanVertices(request_label);
     }
     current_vertex_it = current_batch.begin();
+    request_state_ = State::COMPLETED;
     return !current_batch.empty();
   }

@@ -403,19 +408,15 @@ class DistributedScanAllAndFilterCursor : public Cursor {
       if (MustAbort(context)) {
        throw HintedAbortError();
      }
-      using State = ExecutionState<msgs::ScanVerticesRequest>;

-      if (request_state_.state == State::INITIALIZING) {
+      if (request_state_ == State::INITIALIZING) {
         if (!input_cursor_->Pull(frame, context)) {
           return false;
         }
       }

-      request_state_.label =
-          label_.has_value() ? std::make_optional(request_router.LabelToName(*label_)) : std::nullopt;
-
       if (current_vertex_it == current_batch.end() &&
-          (request_state_.state == State::COMPLETED || !MakeRequest(request_router, context))) {
+          (request_state_ == State::COMPLETED || !MakeRequest(request_router, context))) {
         ResetExecutionState();
         continue;
       }
@@ -431,7 +432,7 @@ class DistributedScanAllAndFilterCursor : public Cursor {
   void ResetExecutionState() {
     current_batch.clear();
     current_vertex_it = current_batch.end();
-    request_state_ = ExecutionState<msgs::ScanVerticesRequest>{};
+    request_state_ = State::INITIALIZING;
   }

   void Reset() override {
@@ -445,7 +446,7 @@ class DistributedScanAllAndFilterCursor : public Cursor {
   const char *op_name_;
   std::vector<VertexAccessor> current_batch;
   std::vector<VertexAccessor>::iterator current_vertex_it;
-  ExecutionState<msgs::ScanVerticesRequest> request_state_;
+  State request_state_ = State::INITIALIZING;
   std::optional<storage::v3::LabelId> label_;
   std::optional<std::pair<storage::v3::PropertyId, Expression *>> property_expression_pair_;
   std::optional<std::vector<Expression *>> filter_expressions_;
@@ -2343,7 +2344,7 @@ class DistributedCreateExpandCursor : public Cursor {
     ResetExecutionState();
     {
       SCOPED_REQUEST_WAIT_PROFILE;
-      request_router->Request(state_, ExpandCreationInfoToRequest(context, frame));
+      request_router->CreateExpand(ExpandCreationInfoToRequest(context, frame));
     }
     return true;
   }
@@ -2423,11 +2424,10 @@ class DistributedCreateExpandCursor : public Cursor {
   }

  private:
-  void ResetExecutionState() { state_ = {}; }
+  void ResetExecutionState() {}

   const UniqueCursorPtr input_cursor_;
   const CreateExpand &self_;
-  ExecutionState<msgs::CreateExpandRequest> state_;
 };

 class DistributedExpandCursor : public Cursor {
@@ -2474,8 +2474,7 @@ class DistributedExpandCursor : public Cursor {
     request.edge_properties.emplace();
     request.src_vertices.push_back(get_dst_vertex(edge, direction));
     request.direction = (direction == EdgeAtom::Direction::IN) ? msgs::EdgeDirection::OUT : msgs::EdgeDirection::IN;
-    ExecutionState<msgs::ExpandOneRequest> request_state;
-    auto result_rows = context.request_router->Request(request_state, std::move(request));
+    auto result_rows = context.request_router->ExpandOne(std::move(request));
     MG_ASSERT(result_rows.size() == 1);
     auto &result_row = result_rows.front();
     frame[self_.common_.node_symbol] = accessors::VertexAccessor(
@@ -2500,10 +2499,9 @@ class DistributedExpandCursor : public Cursor {
     // to not fetch any properties of the edges
     request.edge_properties.emplace();
     request.src_vertices.push_back(vertex.Id());
-    ExecutionState<msgs::ExpandOneRequest> request_state;
-    auto result_rows = std::invoke([&context, &request_state, &request]() mutable {
+    auto result_rows = std::invoke([&context, &request]() mutable {
       SCOPED_REQUEST_WAIT_PROFILE;
-      return context.request_router->Request(request_state, std::move(request));
+      return context.request_router->ExpandOne(std::move(request));
     });
     MG_ASSERT(result_rows.size() == 1);
     auto &result_row = result_rows.front();

@@ -71,34 +71,28 @@ class RsmStorageClientManager {
   std::map<Shard, TStorageClient> cli_cache_;
 };

+template <typename TRequest>
+struct ShardRequestState {
+  memgraph::coordinator::Shard shard;
+  TRequest request;
+  std::optional<io::rsm::AsyncRequestToken> async_request_token;
+};
+
 template <typename TRequest>
 struct ExecutionState {
   using CompoundKey = io::rsm::ShardRsmKey;
   using Shard = coordinator::Shard;

-  enum State : int8_t { INITIALIZING, EXECUTING, COMPLETED };
-
   // label is optional because some operators can create/remove etc, vertices. These kind of requests contain the label
   // on the request itself.
   std::optional<std::string> label;
-  // CompoundKey is optional because some operators require to iterate over all the available keys
-  // of a shard. One example is ScanAll, where we only require the field label.
-  std::optional<CompoundKey> key;
   // Transaction id to be filled by the RequestRouter implementation
   coordinator::Hlc transaction_id;
   // Initialized by RequestRouter implementation. This vector is filled with the shards that
   // the RequestRouter impl will send requests to. When a request to a shard exhausts it, meaning that
   // it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
   // empty, it means that all of the requests have completed succefully.
-  // TODO(gvolfing)
-  // Maybe make this into a more complex object to be able to keep track of paginated resutls. E.g. instead of a vector
-  // of Shards make it into a std::vector<std::pair<Shard, PaginatedResultType>> (probably a struct instead of a pair)
-  // where PaginatedResultType is an enum signaling the progress on the given request. This way we can easily check if
-  // a partial response on a shard(if there is one) is finished and we can send off the request for the next batch.
-  std::vector<Shard> shard_cache;
-  // 1-1 mapping with `shard_cache`.
-  // A vector that tracks request metadata for each shard (For example, next_id for a ScanAll on Shard A)
-  std::vector<TRequest> requests;
-  State state = INITIALIZING;
+  std::vector<ShardRequestState<TRequest>> requests;
 };

 class RequestRouterInterface {
@@ -114,13 +108,10 @@ class RequestRouterInterface {

   virtual void StartTransaction() = 0;
   virtual void Commit() = 0;
-  virtual std::vector<VertexAccessor> Request(ExecutionState<msgs::ScanVerticesRequest> &state) = 0;
-  virtual std::vector<msgs::CreateVerticesResponse> Request(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                                            std::vector<msgs::NewVertex> new_vertices) = 0;
-  virtual std::vector<msgs::ExpandOneResultRow> Request(ExecutionState<msgs::ExpandOneRequest> &state,
-                                                        msgs::ExpandOneRequest request) = 0;
-  virtual std::vector<msgs::CreateExpandResponse> Request(ExecutionState<msgs::CreateExpandRequest> &state,
-                                                          std::vector<msgs::NewExpand> new_edges) = 0;
+  virtual std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) = 0;
+  virtual std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) = 0;
+  virtual std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) = 0;
+  virtual std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) = 0;

   virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0;
   virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
@@ -246,99 +237,121 @@ class RequestRouter : public RequestRouterInterface {
   bool IsPrimaryLabel(storage::v3::LabelId label) const override { return shards_map_.label_spaces.contains(label); }

   // TODO(kostasrim) Simplify return result
-  std::vector<VertexAccessor> Request(ExecutionState<msgs::ScanVerticesRequest> &state) override {
-    MaybeInitializeExecutionState(state);
-    std::vector<msgs::ScanVerticesResponse> responses;
-
-    SendAllRequests(state);
-    auto all_requests_gathered = [](auto &paginated_rsp_tracker) {
-      return std::ranges::all_of(paginated_rsp_tracker, [](const auto &state) {
-        return state.second == PaginatedResponseState::PartiallyFinished;
-      });
-    };
-
-    std::map<Shard, PaginatedResponseState> paginated_response_tracker;
-    for (const auto &shard : state.shard_cache) {
-      paginated_response_tracker.insert(std::make_pair(shard, PaginatedResponseState::Pending));
-    }
-
-    do {
-      AwaitOnPaginatedRequests(state, responses, paginated_response_tracker);
-    } while (!all_requests_gathered(paginated_response_tracker));
-
-    MaybeCompleteState(state);
-    // TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
-    // result of storage_client.SendReadRequest()).
-    return PostProcess(std::move(responses));
-  }
-
-  std::vector<msgs::CreateVerticesResponse> Request(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                                    std::vector<msgs::NewVertex> new_vertices) override {
-    MG_ASSERT(!new_vertices.empty());
-    MaybeInitializeExecutionState(state, new_vertices);
-    std::vector<msgs::CreateVerticesResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-
-    // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
-
-    // 2. Block untill all the futures are exhausted
-    do {
-      AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
-
-    MaybeCompleteState(state);
-    // TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
-    // result of storage_client.SendReadRequest()).
-    return responses;
-  }
-
-  std::vector<msgs::CreateExpandResponse> Request(ExecutionState<msgs::CreateExpandRequest> &state,
-                                                  std::vector<msgs::NewExpand> new_edges) override {
-    MG_ASSERT(!new_edges.empty());
-    MaybeInitializeExecutionState(state, new_edges);
-    std::vector<msgs::CreateExpandResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-    size_t id{0};
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-      msgs::WriteRequests req = state.requests[id];
-      auto write_response_result = storage_client.SendWriteRequest(std::move(req));
-      if (write_response_result.HasError()) {
-        throw std::runtime_error("CreateVertices request timedout");
-      }
-      msgs::WriteResponses response_variant = write_response_result.GetValue();
-      msgs::CreateExpandResponse mapped_response = std::get<msgs::CreateExpandResponse>(response_variant);
-
-      if (mapped_response.error) {
-        throw std::runtime_error("CreateExpand request did not succeed");
-      }
-      responses.push_back(mapped_response);
-      shard_it = shard_cache_ref.erase(shard_it);
-    }
-    // We are done with this state
-    MaybeCompleteState(state);
-    return responses;
-  }
-
-  std::vector<msgs::ExpandOneResultRow> Request(ExecutionState<msgs::ExpandOneRequest> &state,
-                                                msgs::ExpandOneRequest request) override {
-    // TODO(kostasrim)Update to limit the batch size here
-    // Expansions of the destination must be handled by the caller. For example
-    // match (u:L1 { prop : 1 })-[:Friend]-(v:L1)
-    // For each vertex U, the ExpandOne will result in <U, Edges>. The destination vertex and its properties
-    // must be fetched again with an ExpandOne(Edges.dst)
-    MaybeInitializeExecutionState(state, std::move(request));
-    std::vector<msgs::ExpandOneResponse> responses;
-    auto &shard_cache_ref = state.shard_cache;
-
-    // 1. Send the requests.
-    SendAllRequests(state, shard_cache_ref);
-
-    // 2. Block untill all the futures are exhausted
-    do {
-      AwaitOnResponses(state, responses);
-    } while (!state.shard_cache.empty());
-
+  std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override {
+    ExecutionState<msgs::ScanVerticesRequest> state = {};
+    state.label = label;
+
+    // create requests
+    InitializeExecutionState(state);
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::ReadRequests req = request.request;
+
+      request.async_request_token = storage_client.SendAsyncReadRequest(request.request);
+    }
+
+    // drive requests to completion
+    std::vector<msgs::ScanVerticesResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveReadResponses(state, responses);
+    } while (!state.requests.empty());
+
+    // convert responses into VertexAccessor objects to return
+    std::vector<VertexAccessor> accessors;
+    accessors.reserve(responses.size());
+    for (auto &response : responses) {
+      for (auto &result_row : response.results) {
+        accessors.emplace_back(VertexAccessor(std::move(result_row.vertex), std::move(result_row.props), this));
+      }
+    }
+
+    return accessors;
+  }
+
+  std::vector<msgs::CreateVerticesResponse> CreateVertices(std::vector<msgs::NewVertex> new_vertices) override {
+    ExecutionState<msgs::CreateVerticesRequest> state = {};
+    MG_ASSERT(!new_vertices.empty());
+
+    // create requests
+    InitializeExecutionState(state, new_vertices);
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto req_deep_copy = request.request;
+
+      for (auto &new_vertex : req_deep_copy.new_vertices) {
+        new_vertex.label_ids.erase(new_vertex.label_ids.begin());
+      }
+
+      auto &storage_client = GetStorageClientForShard(request.shard);
+
+      msgs::WriteRequests req = req_deep_copy;
+      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
+    }
+
+    // drive requests to completion
+    std::vector<msgs::CreateVerticesResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveWriteResponses(state, responses);
+    } while (!state.requests.empty());
+
+    return responses;
+  }
+
+  std::vector<msgs::CreateExpandResponse> CreateExpand(std::vector<msgs::NewExpand> new_edges) override {
+    ExecutionState<msgs::CreateExpandRequest> state = {};
+    MG_ASSERT(!new_edges.empty());
+
+    // create requests
+    InitializeExecutionState(state, new_edges);
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::WriteRequests req = request.request;
+      request.async_request_token = storage_client.SendAsyncWriteRequest(req);
+    }
+
+    // drive requests to completion
+    std::vector<msgs::CreateExpandResponse> responses;
+    responses.reserve(state.requests.size());
+    do {
+      DriveWriteResponses(state, responses);
+    } while (!state.requests.empty());
+
+    return responses;
+  }
+
+  std::vector<msgs::ExpandOneResultRow> ExpandOne(msgs::ExpandOneRequest request) override {
+    ExecutionState<msgs::ExpandOneRequest> state = {};
+    // TODO(kostasrim)Update to limit the batch size here
+    // Expansions of the destination must be handled by the caller. For example
+    // match (u:L1 { prop : 1 })-[:Friend]-(v:L1)
+    // For each vertex U, the ExpandOne will result in <U, Edges>. The destination vertex and its properties
+    // must be fetched again with an ExpandOne(Edges.dst)
+
+    // create requests
+    InitializeExecutionState(state, std::move(request));
+
+    // begin all requests in parallel
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+      msgs::ReadRequests req = request.request;
+      request.async_request_token = storage_client.SendAsyncReadRequest(req);
+    }
+
+    // drive requests to completion
+    std::vector<msgs::ExpandOneResponse> responses;
+    responses.reserve(state.requests.size());
+
+    do {
+      DriveReadResponses(state, responses);
+    } while (!state.requests.empty());
+
+    // post-process responses
     std::vector<msgs::ExpandOneResultRow> result_rows;
     const auto total_row_count = std::accumulate(responses.begin(), responses.end(), 0,
                                                  [](const int64_t partial_count, const msgs::ExpandOneResponse &resp) {
@@ -350,7 +363,7 @@ class RequestRouter : public RequestRouterInterface {
       result_rows.insert(result_rows.end(), std::make_move_iterator(response.result.begin()),
                          std::make_move_iterator(response.result.end()));
     }
-    MaybeCompleteState(state);
+
     return result_rows;
   }

@@ -367,71 +380,35 @@ class RequestRouter : public RequestRouterInterface {
   }

  private:
-  enum class PaginatedResponseState { Pending, PartiallyFinished };
-
-  std::vector<VertexAccessor> PostProcess(std::vector<msgs::ScanVerticesResponse> &&responses) const {
-    std::vector<VertexAccessor> accessors;
-    for (auto &response : responses) {
-      for (auto &result_row : response.results) {
-        accessors.emplace_back(VertexAccessor(std::move(result_row.vertex), std::move(result_row.props), this));
-      }
-    }
-    return accessors;
-  }
-
-  template <typename ExecutionState>
-  void ThrowIfStateCompleted(ExecutionState &state) const {
-    if (state.state == ExecutionState::COMPLETED) [[unlikely]] {
-      throw std::runtime_error("State is completed and must be reset");
-    }
-  }
-
-  template <typename ExecutionState>
-  void MaybeCompleteState(ExecutionState &state) const {
-    if (state.requests.empty()) {
-      state.state = ExecutionState::COMPLETED;
-    }
-  }
-
-  template <typename ExecutionState>
-  bool ShallNotInitializeState(ExecutionState &state) const {
-    return state.state != ExecutionState::INITIALIZING;
-  }
-
-  void MaybeInitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state,
-                                     std::vector<msgs::NewVertex> new_vertices) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::CreateVerticesRequest> &state,
+                                std::vector<msgs::NewVertex> new_vertices) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::CreateVerticesRequest> per_shard_request_table;

     for (auto &new_vertex : new_vertices) {
-      MG_ASSERT(!new_vertex.label_ids.empty(), "This is error!");
+      MG_ASSERT(!new_vertex.label_ids.empty(), "No label_ids provided for new vertex in RequestRouter::CreateVertices");
       auto shard = shards_map_.GetShardForKey(new_vertex.label_ids[0].id,
                                               storage::conversions::ConvertPropertyVector(new_vertex.primary_key));
       if (!per_shard_request_table.contains(shard)) {
         msgs::CreateVerticesRequest create_v_rqst{.transaction_id = transaction_id_};
         per_shard_request_table.insert(std::pair(shard, std::move(create_v_rqst)));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].new_vertices.push_back(std::move(new_vertex));
     }

-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<msgs::CreateVerticesRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::CreateVerticesRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state,
-                                     std::vector<msgs::NewExpand> new_expands) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::CreateExpandRequest> &state,
+                                std::vector<msgs::NewExpand> new_expands) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::CreateExpandRequest> per_shard_request_table;
@@ -459,18 +436,16 @@ class RequestRouter : public RequestRouterInterface {
     }

     for (auto &[shard, request] : per_shard_request_table) {
-      state.shard_cache.push_back(shard);
-      state.requests.push_back(std::move(request));
+      ShardRequestState<msgs::CreateExpandRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::CreateExpandRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
-
+  void InitializeExecutionState(ExecutionState<msgs::ScanVerticesRequest> &state) {
     std::vector<coordinator::Shards> multi_shards;
     state.transaction_id = transaction_id_;
     if (!state.label) {
@@ -484,21 +459,23 @@ class RequestRouter : public RequestRouterInterface {
     for (auto &shards : multi_shards) {
       for (auto &[key, shard] : shards) {
         MG_ASSERT(!shard.empty());
-        state.shard_cache.push_back(std::move(shard));
-        msgs::ScanVerticesRequest rqst;
-        rqst.transaction_id = transaction_id_;
-        rqst.start_id.second = storage::conversions::ConvertValueVector(key);
-        state.requests.push_back(std::move(rqst));
+        msgs::ScanVerticesRequest request;
+        request.transaction_id = transaction_id_;
+        request.start_id.second = storage::conversions::ConvertValueVector(key);
+
+        ShardRequestState<msgs::ScanVerticesRequest> shard_request_state{
+            .shard = shard,
+            .request = std::move(request),
+            .async_request_token = std::nullopt,
+        };
+
+        state.requests.emplace_back(std::move(shard_request_state));
       }
     }
-    state.state = ExecutionState<msgs::ScanVerticesRequest>::EXECUTING;
   }

-  void MaybeInitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) {
-    ThrowIfStateCompleted(state);
-    if (ShallNotInitializeState(state)) {
-      return;
-    }
+  void InitializeExecutionState(ExecutionState<msgs::ExpandOneRequest> &state, msgs::ExpandOneRequest request) {
     state.transaction_id = transaction_id_;

     std::map<Shard, msgs::ExpandOneRequest> per_shard_request_table;
@@ -511,15 +488,19 @@ class RequestRouter : public RequestRouterInterface {
           shards_map_.GetShardForKey(vertex.first.id, storage::conversions::ConvertPropertyVector(vertex.second));
       if (!per_shard_request_table.contains(shard)) {
         per_shard_request_table.insert(std::pair(shard, top_level_rqst_template));
-        state.shard_cache.push_back(shard);
       }
       per_shard_request_table[shard].src_vertices.push_back(vertex);
     }

-    for (auto &[shard, rqst] : per_shard_request_table) {
-      state.requests.push_back(std::move(rqst));
+    for (auto &[shard, request] : per_shard_request_table) {
+      ShardRequestState<msgs::ExpandOneRequest> shard_request_state{
+          .shard = shard,
+          .request = request,
+          .async_request_token = std::nullopt,
+      };
+
+      state.requests.emplace_back(std::move(shard_request_state));
     }
-    state.state = ExecutionState<msgs::ExpandOneRequest>::EXECUTING;
   }

   StorageClient &GetStorageClientForShard(Shard shard) {
@ -546,173 +527,54 @@ class RequestRouter : public RequestRouterInterface {
|
|||||||
storage_cli_manager_.AddClient(target_shard, std::move(cli));
|
storage_cli_manager_.AddClient(target_shard, std::move(cli));
|
||||||
}
|
}
|
||||||
|
|
||||||
void SendAllRequests(ExecutionState<msgs::ScanVerticesRequest> &state) {
|
template <typename RequestT, typename ResponseT>
|
||||||
int64_t shard_idx = 0;
|
void DriveReadResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
|
||||||
for (const auto &request : state.requests) {
|
for (auto &request : state.requests) {
|
||||||
const auto ¤t_shard = state.shard_cache[shard_idx];
|
auto &storage_client = GetStorageClientForShard(request.shard);
|
||||||
|
|
||||||
auto &storage_client = GetStorageClientForShard(current_shard);
|
auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
|
||||||
msgs::ReadRequests req = request;
|
while (!poll_result) {
|
||||||
storage_client.SendAsyncReadRequest(request);
|
poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
|
||||||
|
|
||||||
++shard_idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SendAllRequests(ExecutionState<msgs::CreateVerticesRequest> &state,
|
|
||||||
std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
|
|
||||||
size_t id = 0;
|
|
||||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
|
|
||||||
// This is fine because all new_vertices of each request end up on the same shard
|
|
||||||
const auto labels = state.requests[id].new_vertices[0].label_ids;
|
|
||||||
auto req_deep_copy = state.requests[id];
|
|
||||||
|
|
||||||
for (auto &new_vertex : req_deep_copy.new_vertices) {
|
|
||||||
new_vertex.label_ids.erase(new_vertex.label_ids.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &storage_client = GetStorageClientForShard(*shard_it);
|
|
||||||
|
|
||||||
msgs::WriteRequests req = req_deep_copy;
|
|
||||||
storage_client.SendAsyncWriteRequest(req);
|
|
||||||
++id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SendAllRequests(ExecutionState<msgs::ExpandOneRequest> &state,
|
|
||||||
std::vector<memgraph::coordinator::Shard> &shard_cache_ref) {
|
|
||||||
size_t id = 0;
|
|
||||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++shard_it) {
|
|
||||||
auto &storage_client = GetStorageClientForShard(*shard_it);
|
|
||||||
msgs::ReadRequests req = state.requests[id];
|
|
||||||
storage_client.SendAsyncReadRequest(req);
|
|
||||||
++id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void AwaitOnResponses(ExecutionState<msgs::CreateVerticesRequest> &state,
|
|
||||||
std::vector<msgs::CreateVerticesResponse> &responses) {
|
|
||||||
auto &shard_cache_ref = state.shard_cache;
|
|
||||||
int64_t request_idx = 0;
|
|
||||||
|
|
||||||
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
|
|
||||||
// This is fine because all new_vertices of each request end up on the same shard
|
|
||||||
const auto labels = state.requests[request_idx].new_vertices[0].label_ids;
|
|
||||||
|
|
||||||
auto &storage_client = GetStorageClientForShard(*shard_it);
|
|
||||||
|
|
||||||
auto poll_result = storage_client.AwaitAsyncWriteRequest();
|
|
||||||
if (!poll_result) {
|
|
||||||
++shard_it;
|
|
||||||
++request_idx;
|
|
||||||
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (poll_result->HasError()) {
|
if (poll_result->HasError()) {
|
||||||
throw std::runtime_error("CreateVertices request timed out");
|
throw std::runtime_error("RequestRouter Read request timed out");
|
||||||
}
|
|
||||||
|
|
||||||
msgs::WriteResponses response_variant = poll_result->GetValue();
|
|
||||||
auto response = std::get<msgs::CreateVerticesResponse>(response_variant);
|
|
||||||
|
|
||||||
if (response.error) {
|
|
||||||
throw std::runtime_error("CreateVertices request did not succeed");
|
|
||||||
}
|
|
||||||
responses.push_back(response);
|
|
||||||
|
|
||||||
shard_it = shard_cache_ref.erase(shard_it);
|
|
||||||
// Needed to maintain the 1-1 mapping between the ShardCache and the requests.
|
|
||||||
auto it = state.requests.begin() + request_idx;
|
|
||||||
state.requests.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
-  void AwaitOnResponses(ExecutionState<msgs::ExpandOneRequest> &state,
-                        std::vector<msgs::ExpandOneResponse> &responses) {
-    auto &shard_cache_ref = state.shard_cache;
-    int64_t request_idx = 0;
-
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto poll_result = storage_client.PollAsyncReadRequest();
-      if (!poll_result) {
-        ++shard_it;
-        ++request_idx;
-        continue;
-      }
-
-      if (poll_result->HasError()) {
-        throw std::runtime_error("ExpandOne request timed out");
       }

       msgs::ReadResponses response_variant = poll_result->GetValue();
-      auto response = std::get<msgs::ExpandOneResponse>(response_variant);
+      auto response = std::get<ResponseT>(response_variant);
-      // -NOTE-
-      // Currently a boolean flag for signaling the overall success of the
-      // ExpandOne request does not exist. But it should, so here we assume
-      // that it is already in place.
       if (response.error) {
-        throw std::runtime_error("ExpandOne request did not succeed");
+        throw std::runtime_error("RequestRouter Read request did not succeed");
       }

       responses.push_back(std::move(response));
-      shard_it = shard_cache_ref.erase(shard_it);
-      // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-      auto it = state.requests.begin() + request_idx;
-      state.requests.erase(it);
     }
+    state.requests.clear();
   }
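Only fragments of the new read-side driver are visible in the hunk above: the templated std::get<ResponseT> line, the "RequestRouter Read request ..." error messages, and the trailing state.requests.clear(). A sketch of what that driver plausibly looks like, assuming it mirrors the DriveWriteResponses helper below; the exact body is a reconstruction, not a verbatim copy of the change:

  template <typename RequestT, typename ResponseT>
  void DriveReadResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
    for (auto &request : state.requests) {
      auto &storage_client = GetStorageClientForShard(request.shard);

      // Keep awaiting until the request is resolved; a redirect simply retries.
      auto poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
      while (!poll_result) {
        poll_result = storage_client.AwaitAsyncReadRequest(request.async_request_token.value());
      }

      if (poll_result->HasError()) {
        throw std::runtime_error("RequestRouter Read request timed out");
      }

      msgs::ReadResponses response_variant = poll_result->GetValue();
      auto response = std::get<ResponseT>(response_variant);
      if (response.error) {
        throw std::runtime_error("RequestRouter Read request did not succeed");
      }
      responses.push_back(std::move(response));
    }
    state.requests.clear();
  }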
-  void AwaitOnPaginatedRequests(ExecutionState<msgs::ScanVerticesRequest> &state,
-                                std::vector<msgs::ScanVerticesResponse> &responses,
-                                std::map<Shard, PaginatedResponseState> &paginated_response_tracker) {
-    auto &shard_cache_ref = state.shard_cache;
-
-    // Find the first request that is not holding a paginated response.
-    int64_t request_idx = 0;
-    for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end();) {
-      if (paginated_response_tracker.at(*shard_it) != PaginatedResponseState::Pending) {
-        ++shard_it;
-        ++request_idx;
-        continue;
-      }
-
-      auto &storage_client = GetStorageClientForShard(*shard_it);
-
-      auto await_result = storage_client.AwaitAsyncReadRequest();
-
-      if (!await_result) {
-        // Redirection has occurred.
-        ++shard_it;
-        ++request_idx;
-        continue;
-      }
-
-      if (await_result->HasError()) {
-        throw std::runtime_error("ScanAll request timed out");
-      }
-
-      msgs::ReadResponses read_response_variant = await_result->GetValue();
-      auto response = std::get<msgs::ScanVerticesResponse>(read_response_variant);
-
-      if (response.error) {
-        throw std::runtime_error("ScanAll request did not succeed");
-      }
-
-      if (!response.next_start_id) {
-        paginated_response_tracker.erase((*shard_it));
-        shard_cache_ref.erase(shard_it);
-        // Needed to maintain the 1-1 mapping between the ShardCache and the requests.
-        auto it = state.requests.begin() + request_idx;
-        state.requests.erase(it);
-      } else {
-        state.requests[request_idx].start_id.second = response.next_start_id->second;
-        paginated_response_tracker[*shard_it] = PaginatedResponseState::PartiallyFinished;
-      }
-
-      responses.push_back(std::move(response));
-    }
-  }
+  template <typename RequestT, typename ResponseT>
+  void DriveWriteResponses(ExecutionState<RequestT> &state, std::vector<ResponseT> &responses) {
+    for (auto &request : state.requests) {
+      auto &storage_client = GetStorageClientForShard(request.shard);
+
+      auto poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
+      while (!poll_result) {
+        poll_result = storage_client.AwaitAsyncWriteRequest(request.async_request_token.value());
+      }
+
+      if (poll_result->HasError()) {
+        throw std::runtime_error("RequestRouter Write request timed out");
+      }
+
+      msgs::WriteResponses response_variant = poll_result->GetValue();
+      auto response = std::get<ResponseT>(response_variant);
+      if (response.error) {
+        throw std::runtime_error("RequestRouter Write request did not succeed");
+      }
+
+      responses.push_back(std::move(response));
+    }
+    state.requests.clear();
+  }
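DriveWriteResponses deduces both template parameters from its arguments, so callers only name the message types once. An illustrative call, assuming a CreateExpand flow whose requests were already sent and whose type names follow the pattern used in the tests below:

  ExecutionState<msgs::CreateExpandRequest> expand_state;    // assumed to be populated and sent elsewhere
  std::vector<msgs::CreateExpandResponse> expand_responses;
  DriveWriteResponses(expand_state, expand_responses);       // RequestT/ResponseT deduced from the arguments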
   void SetUpNameIdMappers() {

@ -164,8 +164,6 @@ void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std
     return;
   }

-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
-
   auto label_id = request_router.NameToLabel("test_label");

   msgs::NewVertex nv{.primary_key = primary_key};
@ -174,7 +172,7 @@ void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std
   std::vector<msgs::NewVertex> new_vertices;
   new_vertices.push_back(std::move(nv));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));

   RC_ASSERT(result.size() == 1);
   RC_ASSERT(!result[0].error.has_value());
@ -184,9 +182,7 @@ void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std

 void ExecuteOp(query::v2::RequestRouter<SimulatorTransport> &request_router, std::set<CompoundKey> &correctness_model,
                ScanAll scan_all) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> request{.label = "test_label"};
-
-  auto results = request_router.Request(request);
+  auto results = request_router.ScanVertices("test_label");

   RC_ASSERT(results.size() == correctness_model.size());
@ -174,8 +174,6 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se
     return;
   }

-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
-
   auto label_id = request_router.NameToLabel("test_label");

   msgs::NewVertex nv{.primary_key = primary_key};
@ -184,7 +182,7 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se
   std::vector<msgs::NewVertex> new_vertices;
   new_vertices.push_back(std::move(nv));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));

   MG_ASSERT(result.size() == 1);
   MG_ASSERT(!result[0].error.has_value());
@ -194,9 +192,7 @@ void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::se

 void ExecuteOp(query::v2::RequestRouter<LocalTransport> &request_router, std::set<CompoundKey> &correctness_model,
                ScanAll scan_all) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> request{.label = "test_label"};
-
-  auto results = request_router.Request(request);
+  auto results = request_router.ScanVertices("test_label");

   MG_ASSERT(results.size() == correctness_model.size());
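All four ExecuteOp hunks above make the same change: the caller stops constructing an ExecutionState and calling the generic Request overloads, and instead calls a verb-named method on the router. A condensed before/after sketch of the calling convention (illustrative, not a literal line from the diff):

  // Before: the caller owned the per-operation state.
  query::v2::ExecutionState<msgs::ScanVerticesRequest> request{.label = "test_label"};
  auto old_results = request_router.Request(request);

  // After: the router keeps that state internally.
  auto new_results = request_router.ScanVertices("test_label");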
@ -111,15 +111,12 @@ ShardMap TestShardMap() {

 template <typename RequestRouter>
 void TestScanAll(RequestRouter &request_router) {
-  query::v2::ExecutionState<msgs::ScanVerticesRequest> state{.label = kLabelName};
-
-  auto result = request_router.Request(state);
+  auto result = request_router.ScanVertices(kLabelName);

   EXPECT_EQ(result.size(), 2);
 }

 void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
   using PropVal = msgs::Value;
-  query::v2::ExecutionState<msgs::CreateVerticesRequest> state;
   std::vector<msgs::NewVertex> new_vertices;
   auto label_id = request_router.NameToLabel(kLabelName);
   msgs::NewVertex a1{.primary_key = {PropVal(int64_t(0)), PropVal(int64_t(0))}};
@ -129,14 +126,13 @@ void TestCreateVertices(query::v2::RequestRouterInterface &request_router) {
   new_vertices.push_back(std::move(a1));
   new_vertices.push_back(std::move(a2));

-  auto result = request_router.Request(state, std::move(new_vertices));
+  auto result = request_router.CreateVertices(std::move(new_vertices));
   EXPECT_EQ(result.size(), 1);
   EXPECT_FALSE(result[0].error.has_value()) << result[0].error->message;
 }

 void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
   using PropVal = msgs::Value;
-  query::v2::ExecutionState<msgs::CreateExpandRequest> state;
   std::vector<msgs::NewExpand> new_expands;

   const auto edge_type_id = request_router.NameToEdgeType("edge_type");
@ -150,20 +146,19 @@ void TestCreateExpand(query::v2::RequestRouterInterface &request_router) {
   new_expands.push_back(std::move(expand_1));
   new_expands.push_back(std::move(expand_2));

-  auto responses = request_router.Request(state, std::move(new_expands));
+  auto responses = request_router.CreateExpand(std::move(new_expands));
   MG_ASSERT(responses.size() == 1);
   MG_ASSERT(!responses[0].error.has_value());
 }

 void TestExpandOne(query::v2::RequestRouterInterface &request_router) {
-  query::v2::ExecutionState<msgs::ExpandOneRequest> state{};
   msgs::ExpandOneRequest request;
   const auto edge_type_id = request_router.NameToEdgeType("edge_type");
   const auto label = msgs::Label{request_router.NameToLabel("test_label")};
   request.src_vertices.push_back(msgs::VertexId{label, {msgs::Value(int64_t(0)), msgs::Value(int64_t(0))}});
   request.edge_types.push_back(msgs::EdgeType{edge_type_id});
   request.direction = msgs::EdgeDirection::BOTH;
-  auto result_rows = request_router.Request(state, std::move(request));
+  auto result_rows = request_router.ExpandOne(std::move(request));
   MG_ASSERT(result_rows.size() == 1);
   MG_ASSERT(result_rows[0].in_edges_with_all_properties.size() == 1);
   MG_ASSERT(result_rows[0].out_edges_with_all_properties.size() == 1);
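The assertions above only count edges. A hedged follow-up that walks the same rows, relying only on the two member names the assertions already use:

  // Illustrative only: with one incoming and one outgoing edge asserted per row,
  // each row carries exactly two edges in total.
  for (const auto &row : result_rows) {
    const auto total_edges =
        row.in_edges_with_all_properties.size() + row.out_edges_with_all_properties.size();
    MG_ASSERT(total_edges == 2);
  }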
@ -82,23 +82,15 @@ class MockedRequestRouter : public RequestRouterInterface {
   }
   void StartTransaction() override {}
   void Commit() override {}
-  std::vector<VertexAccessor> Request(ExecutionState<memgraph::msgs::ScanVerticesRequest> &state) override {
-    return {};
-  }
-
-  std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state,
-                                              std::vector<memgraph::msgs::NewVertex> new_vertices) override {
-    return {};
-  }
-
-  std::vector<ExpandOneResultRow> Request(ExecutionState<ExpandOneRequest> &state, ExpandOneRequest request) override {
-    return {};
-  }
-
-  std::vector<CreateExpandResponse> Request(ExecutionState<CreateExpandRequest> &state,
-                                            std::vector<NewExpand> new_edges) override {
-    return {};
-  }
+  std::vector<VertexAccessor> ScanVertices(std::optional<std::string> /* label */) override { return {}; }
+
+  std::vector<CreateVerticesResponse> CreateVertices(std::vector<memgraph::msgs::NewVertex> new_vertices) override {
+    return {};
+  }
+
+  std::vector<ExpandOneResultRow> ExpandOne(ExpandOneRequest request) override { return {}; }
+
+  std::vector<CreateExpandResponse> CreateExpand(std::vector<NewExpand> new_edges) override { return {}; }

   const std::string &PropertyToName(memgraph::storage::v3::PropertyId id) const override {
     return properties_.IdToName(id.AsUint());
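With the mock now implementing the verb-style interface, code that only needs a RequestRouterInterface reference can be exercised without any shards. A minimal sketch, assuming the mock is default-constructible in this test fixture:

  MockedRequestRouter router;                        // assumption: no constructor arguments needed here
  RequestRouterInterface &iface = router;
  auto vertices = iface.ScanVertices("test_label");  // the mock returns an empty vector
  MG_ASSERT(vertices.empty());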