Add abstract interface to Middleware and rename it. Cleaned up member function implementation and tests

This commit is contained in:
Kostas Kyrimis 2022-09-05 15:46:03 +03:00
parent ba06d29a35
commit e442bf435a
3 changed files with 91 additions and 93 deletions
src/query/v2
tests/simulation

View File

@ -45,15 +45,16 @@ class RsmStorageClientManager {
RsmStorageClientManager(const RsmStorageClientManager &) = delete;
RsmStorageClientManager(RsmStorageClientManager &&) = delete;
void AddClient(const std::string &label, CompoundKey cm_k, TStorageClient client) {
cli_cache_[label].insert({std::move(cm_k), std::move(client)});
void AddClient(const std::string &label, CompoundKey key, TStorageClient client) {
cli_cache_[label].insert({std::move(key), std::move(client)});
}
bool Exists(const std::string &label, const CompoundKey &cm_k) { return cli_cache_[label].contains(cm_k); }
bool Exists(const std::string &label, const CompoundKey &key) { return cli_cache_[label].contains(key); }
void PurgeCache() { cli_cache_.clear(); }
// void EvictFromCache(std::vector<TStorageClient>);
TStorageClient &GetClient(const std::string &label, CompoundKey key) { return cli_cache_[label].find(key)->second; }
TStorageClient &GetClient(const std::string &label, const CompoundKey &key) {
return cli_cache_[label].find(key)->second;
}
private:
std::unordered_map<std::string, std::map<CompoundKey, TStorageClient>> cli_cache_;
@ -63,31 +64,51 @@ template <typename TRequest>
struct ExecutionState {
using CompoundKey = memgraph::io::rsm::ShardRsmKey;
using Shard = memgraph::coordinator::Shard;
std::optional<std::vector<Shard>> state_;
std::string label;
// using CompoundKey = memgraph::coordinator::CompoundKey;
const std::string label;
// CompoundKey is optional because some operators need to iterate over all the available keys
// of a shard. One example is ScanAll, where we only require the field label.
std::optional<CompoundKey> key;
// Transaction id to be filled by the ShardRequestManager implementation
memgraph::coordinator::Hlc transaction_id;
// Initialized by ShardRequestManager implementation. This vector is filled with the shards that
// the ShardRequestManager impl will send requests to. When a request to a shard exhausts it, meaning that
// it pulled all the requested data from the given Shard, it will be removed from the vector. When the vector becomes
// empty, it means that all of the requests have completed successfully.
std::optional<std::vector<Shard>> state_;
// 1-1 mapping with `state_`.
// A vector that tracks request metadata for each shard (for example, next_id for a ScanAll on Shard A)
std::vector<TRequest> requests;
};
namespace rsm = memgraph::io::rsm;
/// Abstract interface of a shard request manager.
///
/// Concrete managers are used through this polymorphic base, so instances are
/// neither copyable nor movable: both construction and assignment from another
/// instance are deleted to rule out accidental slicing.
class ShardRequestManagerInterface {
 public:
  ShardRequestManagerInterface() = default;
  virtual ~ShardRequestManagerInterface() = default;

  ShardRequestManagerInterface(const ShardRequestManagerInterface &) = delete;
  ShardRequestManagerInterface(ShardRequestManagerInterface &&) = delete;
  // Deleting only the constructors would still leave an implicitly generated
  // copy-assignment operator; delete the assignments explicitly as well.
  ShardRequestManagerInterface &operator=(const ShardRequestManagerInterface &) = delete;
  ShardRequestManagerInterface &operator=(ShardRequestManagerInterface &&) = delete;

  /// Begins a new transaction. Must be implemented by the concrete manager.
  virtual void StartTransaction() = 0;

  /// Executes a vertex scan described by `state`.
  ///
  /// @param state Execution state of the scan; taken by non-const reference,
  ///              so implementations may modify it.
  /// @return The responses collected for this call.
  virtual std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) = 0;
};
// TODO(kostasrim) rename this class template
template <typename TTransport, typename... Rest>
class QueryEngineMiddleware {
class ShardRequestManager : public ShardRequestManagerInterface {
public:
using StorageWriteRequest = memgraph::io::rsm::StorageWriteRequest;
using StorageWriteResponse = memgraph::io::rsm::StorageWriteResponse;
using StorageClient =
memgraph::coordinator::RsmClient<TTransport, rsm::StorageWriteRequest, rsm::StorageWriteResponse, Rest...>;
memgraph::coordinator::RsmClient<TTransport, StorageWriteRequest, StorageWriteResponse, Rest...>;
using CoordinatorClient = memgraph::coordinator::CoordinatorClient<TTransport>;
using Address = memgraph::io::Address;
using Shard = memgraph::coordinator::Shard;
using ShardMap = memgraph::coordinator::ShardMap;
using CompoundKey = memgraph::coordinator::CompoundKey;
QueryEngineMiddleware(CoordinatorClient coord, memgraph::io::Io<TTransport> &&io)
ShardRequestManager(CoordinatorClient coord, memgraph::io::Io<TTransport> &&io)
: coord_cli_(std::move(coord)), io_(std::move(io)) {}
void StartTransaction() {
~ShardRequestManager() override {}
void StartTransaction() override {
memgraph::coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()};
auto read_res = coord_cli_.SendReadRequest(req);
if (read_res.HasError()) {
@ -99,20 +120,20 @@ class QueryEngineMiddleware {
// Transaction ID to be used later...
transaction_id_ = hlc_response.new_hlc;
if (hlc_response.fresher_shard_map) {
shards_map_ = hlc_response.fresher_shard_map.value();
} else {
if (!hlc_response.fresher_shard_map) {
throw std::runtime_error("Should handle gracefully!");
}
shards_map_ = hlc_response.fresher_shard_map.value();
}
std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) {
MaybeUpdateExecutionState(state);
std::vector<ScanVerticesResponse> Request(ExecutionState<ScanVerticesRequest> &state) override {
MaybeInitializeExecutionState(state);
std::vector<ScanVerticesResponse> responses;
auto &state_ref = *state.state_;
size_t id = 0;
for (auto shard_it = state_ref.begin(); shard_it != state_ref.end(); ++id) {
auto &storage_client = GetStorageClientForShard(state.label, state.requests[id].start_id.second);
// TODO(kostasrim) Currently requests return the result directly. Adjust this when the API works with MgFuture instead.
auto read_response_result = storage_client.SendReadRequest(state.requests[id]);
// RETRY on timeouts?
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
@ -120,7 +141,7 @@ class QueryEngineMiddleware {
throw std::runtime_error("Read request error");
}
if (read_response_result.GetValue().success == false) {
throw std::runtime_error("ReadRequest failed");
throw std::runtime_error("Request did not succeed");
}
responses.push_back(read_response_result.GetValue());
if (!read_response_result.GetValue().next_start_id) {
@ -130,30 +151,12 @@ class QueryEngineMiddleware {
++shard_it;
}
}
// TODO(kostasrim) Update state accordingly
// TODO(kostasrim) Before returning start prefetching the batch (this shall be done once we get MgFuture as return
// result of storage_client.SendReadRequest()).
return responses;
// For a future based API. Also maybe introduce a `Retry` function that accepts a lambda which is the request
// and a number denoting the number of times the request is retried until an exception or an error is returned.
// std::vector<memgraph::io::future<ScanAllVerticesRequest>> requests;
// for (const auto &shard : state.state_) {
// auto &storage_client = GetStorageClientForShard(state.Label, rqst.label);
// requests.push_back(client->Request(rqst));
// }
//
// std::vector<ScanAllVerticesResponse> responses;
// for (auto &f : requests) {
// f.wait();
// if (f.HasError()) {
// // handle error
// }
// responses.push_back(std::move(f).Value());
// }
}
// CreateVerticesResponse Request(CreateVerticesRequest rqst, ExecutionState &state) {
// // MaybeUpdateShardMap();
// // MaybeUpdateExecutionState();
// }
std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state) {}
// size_t TestRequest(ExecutionState &state) {
// MaybeUpdateShardMap(state);
@ -239,7 +242,7 @@ class QueryEngineMiddleware {
}
}
void MaybeUpdateExecutionState(ExecutionState<ScanVerticesRequest> &state) {
void MaybeInitializeExecutionState(ExecutionState<ScanVerticesRequest> &state) {
if (state.state_) {
return;
}
@ -257,19 +260,20 @@ class QueryEngineMiddleware {
// std::vector<storageclient> GetStorageClientFromShardforRange(const std::string &label, const CompoundKey &start,
// const CompoundKey &end);
StorageClient &GetStorageClientForShard(const std::string &label, const CompoundKey &cm_k) {
if (storage_cli_manager_.Exists(label, cm_k)) {
return storage_cli_manager_.GetClient(label, cm_k);
StorageClient &GetStorageClientForShard(const std::string &label, const CompoundKey &key) {
if (storage_cli_manager_.Exists(label, key)) {
return storage_cli_manager_.GetClient(label, key);
}
auto target_shard = shards_map_.GetShardForKey(label, cm_k);
AddStorageClientToManager(std::move(target_shard), label, cm_k);
return storage_cli_manager_.GetClient(label, cm_k);
auto target_shard = shards_map_.GetShardForKey(label, key);
AddStorageClientToManager(std::move(target_shard), label, key);
return storage_cli_manager_.GetClient(label, key);
}
void AddStorageClientToManager(Shard target_shard, const std::string &label, const CompoundKey &cm_k) {
MG_ASSERT(!target_shard.empty());
auto leader_addr = target_shard.front();
std::vector<Address> addresses;
addresses.reserve(target_shard.size());
for (auto &address : target_shard) {
addresses.push_back(std::move(address.address));
}

View File

@ -54,7 +54,7 @@ using memgraph::storage::v3::PropertyValue;
using ShardRsmKey = std::vector<memgraph::storage::v3::PropertyValue>;
class ShardRsmV2 {
class MockedShardRsm {
std::map<ShardRsmKey, int> state_;
ShardRsmKey minimum_key_;
std::optional<ShardRsmKey> maximum_key_{std::nullopt};
@ -101,22 +101,6 @@ class ShardRsmV2 {
return ret;
}
// StorageReadResponse Read(StorageReadRequest request) {
// StorageReadResponse ret;
//
// if (!IsKeyInRange(request.key)) {
// ret.latest_known_shard_map_version = shard_map_version_;
// ret.shard_rsm_success = false;
// } else if (state_.contains(request.key)) {
// ret.value = state_[request.key];
// ret.shard_rsm_success = true;
// } else {
// ret.shard_rsm_success = false;
// ret.value = std::nullopt;
// }
// return ret;
// }
//
StorageWriteResponse Apply(StorageWriteRequest request) {
StorageWriteResponse ret;

View File

@ -55,7 +55,6 @@ using memgraph::io::rsm::Raft;
using memgraph::io::rsm::ReadRequest;
using memgraph::io::rsm::ReadResponse;
using memgraph::io::rsm::RsmClient;
using memgraph::io::rsm::ShardRsm;
using memgraph::io::rsm::StorageReadRequest;
using memgraph::io::rsm::StorageReadResponse;
using memgraph::io::rsm::StorageWriteRequest;
@ -117,16 +116,43 @@ ShardMap CreateDummyShardmap(memgraph::coordinator::Address a_io_1, memgraph::co
} // namespace
using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>;
using ConcreteStorageRsm = Raft<SimulatorTransport, ShardRsmV2, StorageWriteRequest, StorageWriteResponse,
using ConcreteStorageRsm = Raft<SimulatorTransport, MockedShardRsm, StorageWriteRequest, StorageWriteResponse,
ScanVerticesRequest, ScanVerticesResponse>;
template <typename IoImpl>
void RunStorageRaft(
Raft<IoImpl, ShardRsmV2, StorageWriteRequest, StorageWriteResponse, ScanVerticesRequest, ScanVerticesResponse>
Raft<IoImpl, MockedShardRsm, StorageWriteRequest, StorageWriteResponse, ScanVerticesRequest, ScanVerticesResponse>
server) {
server.Run();
}
template <typename ShardRequestManager>
void TestScanAll(ShardRequestManager &io) {
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
state.key = std::make_optional<CompoundKey>(
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
auto result = io.Request(state);
MG_ASSERT(result.size() == 2);
{
auto &list_of_values_1 = std::get<ListedValues>(result[0].values);
MG_ASSERT(list_of_values_1.properties[0][0].int_v == 0);
auto &list_of_values_2 = std::get<ListedValues>(result[1].values);
MG_ASSERT(list_of_values_2.properties[0][0].int_v == 444);
}
result = io.Request(state);
{
MG_ASSERT(result.size() == 1);
auto &list_of_values_1 = std::get<ListedValues>(result[0].values);
MG_ASSERT(list_of_values_1.properties[0][0].int_v == 1);
}
// Exhaust it, request should be empty
result = io.Request(state);
MG_ASSERT(result.size() == 0);
}
int main() {
SimulatorConfig config{
.drop_percent = 0,
@ -165,9 +191,9 @@ int main() {
std::vector<Address> a_2_peers = {a_addrs[0], a_addrs[2]};
std::vector<Address> a_3_peers = {a_addrs[0], a_addrs[1]};
ConcreteStorageRsm a_1{std::move(a_io_1), a_1_peers, ShardRsmV2{}};
ConcreteStorageRsm a_2{std::move(a_io_2), a_2_peers, ShardRsmV2{}};
ConcreteStorageRsm a_3{std::move(a_io_3), a_3_peers, ShardRsmV2{}};
ConcreteStorageRsm a_1{std::move(a_io_1), a_1_peers, MockedShardRsm{}};
ConcreteStorageRsm a_2{std::move(a_io_2), a_2_peers, MockedShardRsm{}};
ConcreteStorageRsm a_3{std::move(a_io_3), a_3_peers, MockedShardRsm{}};
auto a_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_1));
simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[0]);
@ -185,9 +211,9 @@ int main() {
std::vector<Address> b_2_peers = {b_addrs[0], b_addrs[2]};
std::vector<Address> b_3_peers = {b_addrs[0], b_addrs[1]};
ConcreteStorageRsm b_1{std::move(b_io_1), b_1_peers, ShardRsmV2{}};
ConcreteStorageRsm b_2{std::move(b_io_2), b_2_peers, ShardRsmV2{}};
ConcreteStorageRsm b_3{std::move(b_io_3), b_3_peers, ShardRsmV2{}};
ConcreteStorageRsm b_1{std::move(b_io_1), b_1_peers, MockedShardRsm{}};
ConcreteStorageRsm b_2{std::move(b_io_2), b_2_peers, MockedShardRsm{}};
ConcreteStorageRsm b_3{std::move(b_io_3), b_3_peers, MockedShardRsm{}};
auto b_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_1));
simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[0]);
@ -229,31 +255,15 @@ int main() {
// also get the current shard map
CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, c_addrs[0], c_addrs);
QueryEngineMiddleware<SimulatorTransport, ScanVerticesRequest, ScanVerticesResponse> io(std::move(coordinator_client),
std::move(cli_io));
ShardRequestManager<SimulatorTransport, ScanVerticesRequest, ScanVerticesResponse> io(std::move(coordinator_client),
std::move(cli_io));
ExecutionState<ScanVerticesRequest> state;
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
state.key = std::make_optional<CompoundKey>(
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
state.label = "test_label";
// auto result = io.TestRequest(state);
io.StartTransaction();
auto result = io.Request(state);
auto &list_of_values = std::get<ListedValues>(result[0].values);
std::cout << "Result is: " << list_of_values.properties[0][0].int_v << std::endl;
auto &list_of_values_v2 = std::get<ListedValues>(result[1].values);
std::cout << "Result is: " << list_of_values_v2.properties[0][0].int_v << std::endl;
result = io.Request(state);
std::cout << "Result is: " << result.size() << std::endl;
// auto &list_of_values2 = std::get<ListedValues>(result[0].values);
// std::cout << "Result is: " << list_of_values2.properties[0][0].int_v << std::endl;
// exhaust it
result = io.Request(state);
std::cout << "Result is: " << result.size() << std::endl;
TestScanAll(io);
simulator.ShutDown();
return 0;