Add abstract interface to Middleware and rename it. Cleaned up member function implementation and tests
This commit is contained in:
parent
ba06d29a35
commit
e442bf435a
@ -45,15 +45,16 @@ class RsmStorageClientManager {
|
||||
RsmStorageClientManager(const RsmStorageClientManager &) = delete;
|
||||
RsmStorageClientManager(RsmStorageClientManager &&) = delete;
|
||||
|
||||
void AddClient(const std::string &label, CompoundKey cm_k, TStorageClient client) {
|
||||
cli_cache_[label].insert({std::move(cm_k), std::move(client)});
|
||||
void AddClient(const std::string &label, CompoundKey key, TStorageClient client) {
|
||||
cli_cache_[label].insert({std::move(key), std::move(client)});
|
||||
}
|
||||
|
||||
bool Exists(const std::string &label, const CompoundKey &cm_k) { return cli_cache_[label].contains(cm_k); }
|
||||
bool Exists(const std::string &label, const CompoundKey &key) { return cli_cache_[label].contains(key); }
|
||||
|
||||
void PurgeCache() { cli_cache_.clear(); }
|
||||
// void EvictFromCache(std::vector<TStorageClient>);
|
||||
TStorageClient &GetClient(const std::string &label, CompoundKey key) { return cli_cache_[label].find(key)->second; }
|
||||
TStorageClient &GetClient(const std::string &label, const CompoundKey &key) {
|
||||
return cli_cache_[label].find(key)->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::map<CompoundKey, TStorageClient>> cli_cache_;
|
||||
@ -63,31 +64,51 @@ template <typename TRequest>
|
||||
struct ExecutionState {
|
||||
using CompoundKey = memgraph::io::rsm::ShardRsmKey;
|
||||
using Shard = memgraph::coordinator::Shard;
|
||||
std::optional<std::vector<Shard>> state_;
|
||||
std::string label;
|
||||
// using CompoundKey = memgraph::coordinator::CompoundKey;
|
||||
const std::string label;
|
||||
// CompoundKey is optional because some operators require to iterate over all the available keys
|
||||
// of a shard. One example is ScanAll, where we only require the field label.
|
||||
std::optional<CompoundKey> key;
|
||||
// Transaction id to be filled by the ShardRequestManager implementation
|
||||
memgraph::coordinator::Hlc transaction_id;
|
||||
// Initialized by ShardRequestManager implementation. This vector is filled with the shards that
|
||||
// the ShardRequestManager impl will send requests to. When a request to a shard exhausts it, meaning that
|
||||
// it pulled all the requested data from the given Shard, it will be removed from the Vector. When the Vector becomes
|
||||
// empty, it means that all of the requests have completed succefully.
|
||||
std::optional<std::vector<Shard>> state_;
|
||||
// 1-1 mapping with `state_`.
|
||||
// A vector that tracks request metatdata for each shard (For example, next_id for a ScanAll on Shard A)
|
||||
std::vector<TRequest> requests;
|
||||
};
|
||||
|
||||
namespace rsm = memgraph::io::rsm;
|
||||
class ShardRequestManagerInterface {
|
||||
public:
|
||||
ShardRequestManagerInterface() = default;
|
||||
virtual void StartTransaction() = 0;
|
||||
virtual std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) = 0;
|
||||
virtual ~ShardRequestManagerInterface() {}
|
||||
ShardRequestManagerInterface(const ShardRequestManagerInterface &) = delete;
|
||||
ShardRequestManagerInterface(ShardRequestManagerInterface &&) = delete;
|
||||
};
|
||||
|
||||
// TODO(kostasrim)rename this class template
|
||||
template <typename TTransport, typename... Rest>
|
||||
class QueryEngineMiddleware {
|
||||
class ShardRequestManager : public ShardRequestManagerInterface {
|
||||
public:
|
||||
using StorageWriteRequest = memgraph::io::rsm::StorageWriteRequest;
|
||||
using StorageWriteResponse = memgraph::io::rsm::StorageWriteResponse;
|
||||
using StorageClient =
|
||||
memgraph::coordinator::RsmClient<TTransport, rsm::StorageWriteRequest, rsm::StorageWriteResponse, Rest...>;
|
||||
memgraph::coordinator::RsmClient<TTransport, StorageWriteRequest, StorageWriteResponse, Rest...>;
|
||||
using CoordinatorClient = memgraph::coordinator::CoordinatorClient<TTransport>;
|
||||
using Address = memgraph::io::Address;
|
||||
using Shard = memgraph::coordinator::Shard;
|
||||
using ShardMap = memgraph::coordinator::ShardMap;
|
||||
using CompoundKey = memgraph::coordinator::CompoundKey;
|
||||
QueryEngineMiddleware(CoordinatorClient coord, memgraph::io::Io<TTransport> &&io)
|
||||
ShardRequestManager(CoordinatorClient coord, memgraph::io::Io<TTransport> &&io)
|
||||
: coord_cli_(std::move(coord)), io_(std::move(io)) {}
|
||||
|
||||
void StartTransaction() {
|
||||
~ShardRequestManager() override {}
|
||||
|
||||
void StartTransaction() override {
|
||||
memgraph::coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()};
|
||||
auto read_res = coord_cli_.SendReadRequest(req);
|
||||
if (read_res.HasError()) {
|
||||
@ -99,20 +120,20 @@ class QueryEngineMiddleware {
|
||||
// Transaction ID to be used later...
|
||||
transaction_id_ = hlc_response.new_hlc;
|
||||
|
||||
if (hlc_response.fresher_shard_map) {
|
||||
shards_map_ = hlc_response.fresher_shard_map.value();
|
||||
} else {
|
||||
if (!hlc_response.fresher_shard_map) {
|
||||
throw std::runtime_error("Should handle gracefully!");
|
||||
}
|
||||
shards_map_ = hlc_response.fresher_shard_map.value();
|
||||
}
|
||||
|
||||
std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) {
|
||||
MaybeUpdateExecutionState(state);
|
||||
std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) override {
|
||||
MaybeInitializeExecutionState(state);
|
||||
std::vector<ScanVerticesResponse> responses;
|
||||
auto &state_ref = *state.state_;
|
||||
size_t id = 0;
|
||||
for (auto shard_it = state_ref.begin(); shard_it != state_ref.end(); ++id) {
|
||||
auto &storage_client = GetStorageClientForShard(state.label, state.requests[id].start_id.second);
|
||||
// TODO(kostasrim) Currently requests return the result directly. Adjust this when the API works MgFuture instead.
|
||||
auto read_response_result = storage_client.SendReadRequest(state.requests[id]);
|
||||
// RETRY on timeouts?
|
||||
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
|
||||
@ -120,7 +141,7 @@ class QueryEngineMiddleware {
|
||||
throw std::runtime_error("Read request error");
|
||||
}
|
||||
if (read_response_result.GetValue().success == false) {
|
||||
throw std::runtime_error("ReadRequest failed");
|
||||
throw std::runtime_error("Request did not succeed");
|
||||
}
|
||||
responses.push_back(read_response_result.GetValue());
|
||||
if (!read_response_result.GetValue().next_start_id) {
|
||||
@ -130,30 +151,12 @@ class QueryEngineMiddleware {
|
||||
++shard_it;
|
||||
}
|
||||
}
|
||||
// TODO(kostasrim) Update state accordingly
|
||||
// TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
|
||||
// result of storage_client.SendReadRequest()).
|
||||
return responses;
|
||||
// For a future based API. Also maybe introduce a `Retry` function that accepts a lambda which is the request
|
||||
// and a number denoting the number of times the request is retried until an exception or an error is returned.
|
||||
// std::vector<memgraph::io::future<ScanAllVerticesRequest>> requests;
|
||||
// for (const auto &shard : state.state_) {
|
||||
// auto &storage_client = GetStorageClientForShard(state.Label, rqst.label);
|
||||
// requests.push_back(client->Request(rqst));
|
||||
// }
|
||||
//
|
||||
// std::vector<ScanAllVerticesResponse> responses;
|
||||
// for (auto &f : requests) {
|
||||
// f.wait();
|
||||
// if (f.HasError()) {
|
||||
// // handle error
|
||||
// }
|
||||
// responses.push_back(std::move(f).Value());
|
||||
// }
|
||||
}
|
||||
|
||||
// CreateVerticesResponse Request(CreateVerticesRequest rqst, ExecutionState &state) {
|
||||
// // MaybeUpdateShardMap();
|
||||
// // MaybeUpdateExecutionState();
|
||||
// }
|
||||
std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state) {}
|
||||
|
||||
// size_t TestRequest(ExecutionState &state) {
|
||||
// MaybeUpdateShardMap(state);
|
||||
@ -239,7 +242,7 @@ class QueryEngineMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
void MaybeUpdateExecutionState(ExecutionState<ScanVerticesRequest> &state) {
|
||||
void MaybeInitializeExecutionState(ExecutionState<ScanVerticesRequest> &state) {
|
||||
if (state.state_) {
|
||||
return;
|
||||
}
|
||||
@ -257,19 +260,20 @@ class QueryEngineMiddleware {
|
||||
|
||||
// std::vector<storageclient> GetStorageClientFromShardforRange(const std::string &label, const CompoundKey &start,
|
||||
// const CompoundKey &end);
|
||||
StorageClient &GetStorageClientForShard(const std::string &label, const CompoundKey &cm_k) {
|
||||
if (storage_cli_manager_.Exists(label, cm_k)) {
|
||||
return storage_cli_manager_.GetClient(label, cm_k);
|
||||
StorageClient &GetStorageClientForShard(const std::string &label, const CompoundKey &key) {
|
||||
if (storage_cli_manager_.Exists(label, key)) {
|
||||
return storage_cli_manager_.GetClient(label, key);
|
||||
}
|
||||
auto target_shard = shards_map_.GetShardForKey(label, cm_k);
|
||||
AddStorageClientToManager(std::move(target_shard), label, cm_k);
|
||||
return storage_cli_manager_.GetClient(label, cm_k);
|
||||
auto target_shard = shards_map_.GetShardForKey(label, key);
|
||||
AddStorageClientToManager(std::move(target_shard), label, key);
|
||||
return storage_cli_manager_.GetClient(label, key);
|
||||
}
|
||||
|
||||
void AddStorageClientToManager(Shard target_shard, const std::string &label, const CompoundKey &cm_k) {
|
||||
MG_ASSERT(!target_shard.empty());
|
||||
auto leader_addr = target_shard.front();
|
||||
std::vector<Address> addresses;
|
||||
addresses.reserve(target_shard.size());
|
||||
for (auto &address : target_shard) {
|
||||
addresses.push_back(std::move(address.address));
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ using memgraph::storage::v3::PropertyValue;
|
||||
|
||||
using ShardRsmKey = std::vector<memgraph::storage::v3::PropertyValue>;
|
||||
|
||||
class ShardRsmV2 {
|
||||
class MockedShardRsm {
|
||||
std::map<ShardRsmKey, int> state_;
|
||||
ShardRsmKey minimum_key_;
|
||||
std::optional<ShardRsmKey> maximum_key_{std::nullopt};
|
||||
@ -101,22 +101,6 @@ class ShardRsmV2 {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// StorageReadResponse Read(StorageReadRequest request) {
|
||||
// StorageReadResponse ret;
|
||||
//
|
||||
// if (!IsKeyInRange(request.key)) {
|
||||
// ret.latest_known_shard_map_version = shard_map_version_;
|
||||
// ret.shard_rsm_success = false;
|
||||
// } else if (state_.contains(request.key)) {
|
||||
// ret.value = state_[request.key];
|
||||
// ret.shard_rsm_success = true;
|
||||
// } else {
|
||||
// ret.shard_rsm_success = false;
|
||||
// ret.value = std::nullopt;
|
||||
// }
|
||||
// return ret;
|
||||
// }
|
||||
//
|
||||
StorageWriteResponse Apply(StorageWriteRequest request) {
|
||||
StorageWriteResponse ret;
|
||||
|
||||
|
@ -55,7 +55,6 @@ using memgraph::io::rsm::Raft;
|
||||
using memgraph::io::rsm::ReadRequest;
|
||||
using memgraph::io::rsm::ReadResponse;
|
||||
using memgraph::io::rsm::RsmClient;
|
||||
using memgraph::io::rsm::ShardRsm;
|
||||
using memgraph::io::rsm::StorageReadRequest;
|
||||
using memgraph::io::rsm::StorageReadResponse;
|
||||
using memgraph::io::rsm::StorageWriteRequest;
|
||||
@ -117,16 +116,43 @@ ShardMap CreateDummyShardmap(memgraph::coordinator::Address a_io_1, memgraph::co
|
||||
} // namespace
|
||||
|
||||
using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>;
|
||||
using ConcreteStorageRsm = Raft<SimulatorTransport, ShardRsmV2, StorageWriteRequest, StorageWriteResponse,
|
||||
using ConcreteStorageRsm = Raft<SimulatorTransport, MockedShardRsm, StorageWriteRequest, StorageWriteResponse,
|
||||
ScanVerticesRequest, ScanVerticesResponse>;
|
||||
|
||||
template <typename IoImpl>
|
||||
void RunStorageRaft(
|
||||
Raft<IoImpl, ShardRsmV2, StorageWriteRequest, StorageWriteResponse, ScanVerticesRequest, ScanVerticesResponse>
|
||||
Raft<IoImpl, MockedShardRsm, StorageWriteRequest, StorageWriteResponse, ScanVerticesRequest, ScanVerticesResponse>
|
||||
server) {
|
||||
server.Run();
|
||||
}
|
||||
|
||||
template <typename ShardRequestManager>
|
||||
void TestScanAll(ShardRequestManager &io) {
|
||||
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
|
||||
state.key = std::make_optional<CompoundKey>(
|
||||
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
|
||||
|
||||
auto result = io.Request(state);
|
||||
MG_ASSERT(result.size() == 2);
|
||||
{
|
||||
auto &list_of_values_1 = std::get<ListedValues>(result[0].values);
|
||||
MG_ASSERT(list_of_values_1.properties[0][0].int_v == 0);
|
||||
auto &list_of_values_2 = std::get<ListedValues>(result[1].values);
|
||||
MG_ASSERT(list_of_values_2.properties[0][0].int_v == 444);
|
||||
}
|
||||
|
||||
result = io.Request(state);
|
||||
{
|
||||
MG_ASSERT(result.size() == 1);
|
||||
auto &list_of_values_1 = std::get<ListedValues>(result[0].values);
|
||||
MG_ASSERT(list_of_values_1.properties[0][0].int_v == 1);
|
||||
}
|
||||
|
||||
// Exhaust it, request should be empty
|
||||
result = io.Request(state);
|
||||
MG_ASSERT(result.size() == 0);
|
||||
}
|
||||
|
||||
int main() {
|
||||
SimulatorConfig config{
|
||||
.drop_percent = 0,
|
||||
@ -165,9 +191,9 @@ int main() {
|
||||
std::vector<Address> a_2_peers = {a_addrs[0], a_addrs[2]};
|
||||
std::vector<Address> a_3_peers = {a_addrs[0], a_addrs[1]};
|
||||
|
||||
ConcreteStorageRsm a_1{std::move(a_io_1), a_1_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm a_2{std::move(a_io_2), a_2_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm a_3{std::move(a_io_3), a_3_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm a_1{std::move(a_io_1), a_1_peers, MockedShardRsm{}};
|
||||
ConcreteStorageRsm a_2{std::move(a_io_2), a_2_peers, MockedShardRsm{}};
|
||||
ConcreteStorageRsm a_3{std::move(a_io_3), a_3_peers, MockedShardRsm{}};
|
||||
|
||||
auto a_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_1));
|
||||
simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[0]);
|
||||
@ -185,9 +211,9 @@ int main() {
|
||||
std::vector<Address> b_2_peers = {b_addrs[0], b_addrs[2]};
|
||||
std::vector<Address> b_3_peers = {b_addrs[0], b_addrs[1]};
|
||||
|
||||
ConcreteStorageRsm b_1{std::move(b_io_1), b_1_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm b_2{std::move(b_io_2), b_2_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm b_3{std::move(b_io_3), b_3_peers, ShardRsmV2{}};
|
||||
ConcreteStorageRsm b_1{std::move(b_io_1), b_1_peers, MockedShardRsm{}};
|
||||
ConcreteStorageRsm b_2{std::move(b_io_2), b_2_peers, MockedShardRsm{}};
|
||||
ConcreteStorageRsm b_3{std::move(b_io_3), b_3_peers, MockedShardRsm{}};
|
||||
|
||||
auto b_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_1));
|
||||
simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[0]);
|
||||
@ -229,31 +255,15 @@ int main() {
|
||||
// also get the current shard map
|
||||
CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, c_addrs[0], c_addrs);
|
||||
|
||||
QueryEngineMiddleware<SimulatorTransport, ScanVerticesRequest, ScanVerticesResponse> io(std::move(coordinator_client),
|
||||
std::move(cli_io));
|
||||
ShardRequestManager<SimulatorTransport, ScanVerticesRequest, ScanVerticesResponse> io(std::move(coordinator_client),
|
||||
std::move(cli_io));
|
||||
|
||||
ExecutionState<ScanVerticesRequest> state;
|
||||
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
|
||||
state.key = std::make_optional<CompoundKey>(
|
||||
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
|
||||
state.label = "test_label";
|
||||
|
||||
// auto result = io.TestRequest(state);
|
||||
io.StartTransaction();
|
||||
auto result = io.Request(state);
|
||||
auto &list_of_values = std::get<ListedValues>(result[0].values);
|
||||
std::cout << "Result is: " << list_of_values.properties[0][0].int_v << std::endl;
|
||||
auto &list_of_values_v2 = std::get<ListedValues>(result[1].values);
|
||||
std::cout << "Result is: " << list_of_values_v2.properties[0][0].int_v << std::endl;
|
||||
|
||||
result = io.Request(state);
|
||||
std::cout << "Result is: " << result.size() << std::endl;
|
||||
|
||||
// auto &list_of_values2 = std::get<ListedValues>(result[0].values);
|
||||
// std::cout << "Result is: " << list_of_values2.properties[0][0].int_v << std::endl;
|
||||
|
||||
// exhaust it
|
||||
result = io.Request(state);
|
||||
std::cout << "Result is: " << result.size() << std::endl;
|
||||
TestScanAll(io);
|
||||
|
||||
simulator.ShutDown();
|
||||
return 0;
|
||||
|
Loading…
Reference in New Issue
Block a user