Add test for sharded_map
parent 1b2c8f6b29
commit 9aeca7a4b3

Changed paths under: src, tests/simulation
@@ -87,8 +87,6 @@ using ReadRequests = std::variant<HlcRequest, GetShardMapRequest>;
 using ReadResponses = std::variant<HlcResponse, GetShardMapResponse>;

 class Coordinator {
   using StandbySotrageEnginePool = std::unordered_set<Address>;

   ShardMap shard_map_;
   /// The highest reserved timestamp / highest allocated timestamp
   /// is a way of minimizing communication involving query engines
@@ -108,17 +106,16 @@ class Coordinator {
   /// Increment our
   ReadResponses Read(HlcRequest &&hlc_request) {
     HlcResponse res{};
-    shard_map_.UpdateShardMapVersion();
-
-    res.new_hlc = shard_map_.GetHlc();
-
+    auto hlc_shard_map = shard_map_.GetHlc();
     // TODO(gabor) Once the wallclock update is implemented, this
     // comparison should also be updated
-    if (hlc_request.last_shard_map_version.logical_id == res.new_hlc.logical_id) {
-      res.fresher_shard_map = shard_map_;
-    } else {
-      res.fresher_shard_map = {};
-    }
+    MG_ASSERT(!(hlc_request.last_shard_map_version.logical_id > hlc_shard_map.logical_id));
+
+    res.new_hlc = shard_map_.UpdateShardMapVersion();
+
+    res.fresher_shard_map = hlc_request.last_shard_map_version.logical_id < hlc_shard_map.logical_id
+                                ? std::make_optional(shard_map_)
+                                : std::nullopt;

     return res;
   }
@@ -173,6 +170,8 @@ class Coordinator {
   }

  public:
+  explicit Coordinator(ShardMap sm) : shard_map_{std::move(sm)} {}
+
   ReadResponses Read(ReadRequests requests) {
     return std::visit([&](auto &&requests) { return Read(requests); }, std::move(requests));
   }
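A minimal sketch of how a query engine consumes this response (assumes the HlcResponse fields above; the helper name is invented and is not part of this commit):

    // Illustrative only: adopt the coordinator's shard map when it is newer.
    // fresher_shard_map is only engaged when the client's version was stale,
    // so an empty optional means "already up to date".
    void MaybeRefreshShardMap(ShardMap &client_shard_map, const HlcResponse &res) {
      if (res.fresher_shard_map) {
        client_shard_map = res.fresher_shard_map.value();
      }
    }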
@@ -55,7 +55,10 @@ struct ShardMap {
  public:
   // TODO(gabor) later we will want to update the wallclock time with
   // the given Io<impl>'s time as well
-  void UpdateShardMapVersion() noexcept { ++shard_map_version.logical_id; }
+  Hlc UpdateShardMapVersion() noexcept {
+    ++shard_map_version.logical_id;
+    return shard_map_version;
+  }

   Hlc GetHlc() const noexcept { return shard_map_version; }

@@ -90,6 +93,8 @@ struct ShardMap {
     // Find a random place for the server to plug in
   }

+  std::map<Label, Shards> &GetShards() noexcept { return shards; }
+
   Shards GetShardsForRange(Label label, CompoundKey start, CompoundKey end);

   Shard GetShardForKey(Label label, CompoundKey key);
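A short sketch of why the bump now returns the Hlc: the caller observes the new version without a second GetHlc() round trip (illustrative, assumes Hlc::logical_id as above; not part of this commit):

    void ExampleVersionBump() {
      ShardMap sm;
      Hlc before = sm.GetHlc();
      Hlc after = sm.UpdateShardMapVersion();  // increments and hands back the new version
      MG_ASSERT(after.logical_id == before.logical_id + 1);
    }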
@@ -89,13 +89,16 @@ class StorageRsm {
   StorageGetResponse Read(StorageGetRequest request) {
     StorageGetResponse ret;

-    if (IsKeyInRange(request.key)) {
+    if (!IsKeyInRange(request.key)) {
+      std::cout << "ONE" << std::endl;
       ret.latest_known_shard_map_version = shard_map_version_;
       ret.shard_rsm_success = false;
     } else if (state_.contains(request.key)) {
+      std::cout << "TWO" << std::endl;
       ret.value = state_[request.key];
       ret.shard_rsm_success = true;
     } else {
+      std::cout << "THREE" << std::endl;
       ret.shard_rsm_success = false;
       ret.value = std::nullopt;
     }
@@ -106,13 +109,15 @@ class StorageRsm {
     StorageWriteResponse ret;

     // Key is outside the range this shard is responsible for
-    if (IsKeyInRange(request.key)) {
+    if (!IsKeyInRange(request.key)) {
       ret.latest_known_shard_map_version = shard_map_version_;
       ret.shard_rsm_success = false;
+      std::cout << "WRITE 0" << std::endl;
     }
     // Key exists
     else if (state_.contains(request.key)) {
       auto &val = state_[request.key];
+      std::cout << "WRITE 1" << std::endl;

       /*
        * Delete
@@ -121,6 +126,7 @@ class StorageRsm {
       ret.shard_rsm_success = true;
       ret.last_value = val;
       state_.erase(state_.find(request.key));
+      std::cout << "WRITE 2" << std::endl;
     }

     /*
@@ -132,9 +138,12 @@ class StorageRsm {
       ret.shard_rsm_success = true;

       val = request.value.value();
+      std::cout << "WRITE 3" << std::endl;
+
     } else {
       ret.last_value = val;
       ret.shard_rsm_success = false;
+      std::cout << "WRITE 4" << std::endl;
     }
   }
   /*
@@ -145,8 +154,10 @@ class StorageRsm {
       ret.shard_rsm_success = true;

       state_.emplace(request.key, std::move(request.value).value());
+      std::cout << "WRITE 5" << std::endl;
     }

+    std::cout << "WRITE ret" << std::endl;
     return ret;
   }
 };
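Taken together, the write branches give StorageWriteRequest upsert-style semantics driven by its optional value: an engaged optional inserts or updates, and, per the Delete comment above, an empty optional erases an existing key. A hedged sketch of constructing each request shape (illustrative, mirrors the field usage in the test below; not part of this commit):

    // Illustrative only; assumes StorageWriteRequest{key, value} as used above.
    StorageWriteRequest upsert;
    upsert.key = {memgraph::storage::PropertyValue(3), memgraph::storage::PropertyValue(4)};
    upsert.value = 1000;         // engaged optional -> insert ("WRITE 5") or update ("WRITE 3")

    StorageWriteRequest erase;
    erase.key = upsert.key;
    erase.value = std::nullopt;  // empty optional -> delete an existing key ("WRITE 2")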
@@ -31,4 +31,4 @@ add_simulation_test(raft.cpp address)

 add_simulation_test(trial_query_storage/query_storage_test.cpp address)

-#add_simulation_test(sharded_map.cpp address)
+add_simulation_test(sharded_map.cpp address)
@@ -253,8 +253,8 @@ void RunSimulation() {
   std::vector<Address> server_addrs{srv_addr_1, srv_addr_2, srv_addr_3};
   Address leader = server_addrs[0];

-  RsmClient<Io<SimulatorTransport>, CasRequest, CasResponse, GetRequest, GetResponse> client(
-      std::move(cli_io), std::move(leader), std::move(server_addrs));
+  RsmClient<Io<SimulatorTransport>, CasRequest, CasResponse, GetRequest, GetResponse> client(cli_io, leader,
+                                                                                             server_addrs);

   const int key = 0;
   std::optional<int> last_known_value;
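The client is now constructed from copies rather than moved-from arguments; assuming Io<SimulatorTransport> is copyable (which this change relies on), one client-side handle can back several RsmClients, which is exactly what the new sharded_map test needs:

    // Illustrative only: the same cli_io backs a coordinator client and two
    // storage clients, using the aliases the test defines below.
    CoordinatorClient coordinator_client(cli_io, c_addrs[2], c_addrs);
    StorageClient shard_a_client(cli_io, a_addrs[0], a_addrs);
    StorageClient shard_b_client(cli_io, b_addrs[0], b_addrs);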
@@ -24,12 +24,21 @@
 #include "io/rsm/shard_rsm.hpp"
 #include "io/simulator/simulator.hpp"
 #include "io/simulator/simulator_transport.hpp"
+#include "utils/rsm_client.hpp"

+using memgraph::coordinator::Address;
+using memgraph::coordinator::AddressAndStatus;
+using memgraph::coordinator::CompoundKey;
+using memgraph::coordinator::Coordinator;
+using memgraph::coordinator::Shard;
+using memgraph::coordinator::ShardMap;
+using memgraph::coordinator::Shards;
+using memgraph::coordinator::Status;
 using memgraph::io::Address;
 using memgraph::io::Io;
 using memgraph::io::ResponseEnvelope;
 using memgraph::io::ResponseFuture;
 using memgraph::io::Time;
 using memgraph::io::rsm::CoordinatorRsm;
 using memgraph::io::rsm::Raft;
 using memgraph::io::rsm::ReadRequest;
@@ -46,60 +55,96 @@ using memgraph::io::simulator::SimulatorConfig;
 using memgraph::io::simulator::SimulatorStats;
 using memgraph::io::simulator::SimulatorTransport;

+namespace {
+ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Address b_io_1, Address b_io_2,
+                             Address b_io_3) {
+  ShardMap sm1;
+  auto &shards = sm1.GetShards();
+
+  // label1: one shard keyed by the compound key {3, 4}
+  std::string label1 = std::string("label1");
+  auto key1 = memgraph::storage::v3::PropertyValue(3);
+  auto key2 = memgraph::storage::v3::PropertyValue(4);
+  CompoundKey cm1 = {key1, key2};
+  AddressAndStatus aas1_1{.address = a_io_1, .status = Status::CONSENSUS_PARTICIPANT};
+  AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT};
+  AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT};
+
+  Shard shard1 = {aas1_1, aas1_2, aas1_3};
+  Shards shards1;
+  shards1[cm1] = shard1;
+
+  shards[label1] = shards1;
+
+  // label2: one shard keyed by the compound key {12, 13}
+  std::string label2 = std::string("label2");
+  auto key3 = memgraph::storage::v3::PropertyValue(12);
+  auto key4 = memgraph::storage::v3::PropertyValue(13);
+  CompoundKey cm2 = {key3, key4};
+  AddressAndStatus aas2_1{.address = b_io_1, .status = Status::CONSENSUS_PARTICIPANT};
+  AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT};
+  AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT};
+
+  Shard shard2 = {aas2_1, aas2_2, aas2_3};
+  Shards shards2;
+  shards2[cm2] = shard2;
+
+  shards[label2] = shards2;
+
+  return sm1;
+}
+}  // namespace
 using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>;
 using ConcreteStorageRsm = Raft<SimulatorTransport, StorageRsm, StorageWriteRequest, StorageWriteResponse,
                                 StorageGetRequest, StorageGetResponse>;

 template <typename IoImpl>
 void RunStorageRaft(
     Raft<IoImpl, StorageRsm, StorageWriteRequest, StorageWriteResponse, StorageGetRequest, StorageGetResponse> server) {
   server.Run();
 }

 int main() {
   SimulatorConfig config{
+      /*
       .drop_percent = 5,
       .perform_timeouts = true,
       .scramble_messages = true,
       .rng_seed = 0,
       .start_time = 256 * 1024,
       .abort_time = std::chrono::microseconds{8 * 1024 * 1024},
+      */
+      .drop_percent = 5,
+      .perform_timeouts = true,
+      .scramble_messages = true,
+      .rng_seed = 0,
+      .start_time = Time::min() + std::chrono::microseconds{256 * 1024},
+      .abort_time = Time::min() + std::chrono::microseconds{8 * 1024 * 1024},
   };

   auto simulator = Simulator(config);

   Io<SimulatorTransport> cli_io = simulator.RegisterNew();

-  // spin up coordinators
-  // auto c_thread_1 = std::jthread(RunRaft<Coordinator>, std::move(c_1));
-  // simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[0]);
-
-  Io<SimulatorTransport> c_io_1 = simulator.RegisterNew();
-  Io<SimulatorTransport> c_io_2 = simulator.RegisterNew();
-  Io<SimulatorTransport> c_io_3 = simulator.RegisterNew();
-  // auto c_thread_2 = std::jthread(RunRaft<Coordinator>, std::move(c_2));
-  // simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[1]);
-
-  Address c_addrs[] = {c_io_1.GetAddress(), c_io_2.GetAddress(), c_io_3.GetAddress()};
-
-  std::vector<Address> c_1_peers = {c_addrs[1], c_addrs[2]};
-  std::vector<Address> c_2_peers = {c_addrs[0], c_addrs[2]};
-  std::vector<Address> c_3_peers = {c_addrs[0], c_addrs[1]};
-
-  ConcreteCoordinatorRsm c_1{std::move(c_io_1), c_1_peers, Coordinator{}};
-  ConcreteCoordinatorRsm c_2{std::move(c_io_2), c_2_peers, Coordinator{}};
-  ConcreteCoordinatorRsm c_3{std::move(c_io_3), c_3_peers, Coordinator{}};
-
-  auto c_thread_1 = std::jthread([c_1]() mutable { c_1.Run(); });
-  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[0]);
-
-  /*
-  auto c_thread_2 = std::jthread(RunRaft<Coordinator>, std::move(c_2));
-  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[1]);
-
-  auto c_thread_3 = std::jthread(RunRaft<Coordinator>, std::move(c_3));
-  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[2]);
-  */
-
-  // spin up shard A
-  // auto c_thread_3 = std::jthread(RunRaft<Coordinator>, std::move(c_3));
-  // simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[2]);
-
+  // Register
   Io<SimulatorTransport> a_io_1 = simulator.RegisterNew();
   Io<SimulatorTransport> a_io_2 = simulator.RegisterNew();
   Io<SimulatorTransport> a_io_3 = simulator.RegisterNew();

-  Address a_addrs[] = {a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress()};
+  Io<SimulatorTransport> b_io_1 = simulator.RegisterNew();
+  Io<SimulatorTransport> b_io_2 = simulator.RegisterNew();
+  Io<SimulatorTransport> b_io_3 = simulator.RegisterNew();
+
+  // Preconfigure each coordinator replica with kv shards 'A' and 'B'
+  auto sm1 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(),
+                                 b_io_2.GetAddress(), b_io_3.GetAddress());
+  auto sm2 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(),
+                                 b_io_2.GetAddress(), b_io_3.GetAddress());
+  auto sm3 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(),
+                                 b_io_2.GetAddress(), b_io_3.GetAddress());
+
+  // Spin up shard A
+  std::vector<Address> a_addrs = {a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress()};

   std::vector<Address> a_1_peers = {a_addrs[1], a_addrs[2]};
   std::vector<Address> a_2_peers = {a_addrs[0], a_addrs[2]};
@@ -109,13 +154,17 @@ int main() {
   ConcreteStorageRsm a_2{std::move(a_io_2), a_2_peers, StorageRsm{}};
   ConcreteStorageRsm a_3{std::move(a_io_3), a_3_peers, StorageRsm{}};

-  // spin up shard B
-  Io<SimulatorTransport> b_io_1 = simulator.RegisterNew();
-  Io<SimulatorTransport> b_io_2 = simulator.RegisterNew();
-  Io<SimulatorTransport> b_io_3 = simulator.RegisterNew();
+  auto a_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_1));
+  simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[0]);

-  Address b_addrs[] = {b_io_1.GetAddress(), b_io_2.GetAddress(), b_io_3.GetAddress()};
+  auto a_thread_2 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_2));
+  simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[1]);
+
+  auto a_thread_3 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_3));
+  simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[2]);
+
+  // Spin up shard B
+  std::vector<Address> b_addrs = {b_io_1.GetAddress(), b_io_2.GetAddress(), b_io_3.GetAddress()};

   std::vector<Address> b_1_peers = {b_addrs[1], b_addrs[2]};
   std::vector<Address> b_2_peers = {b_addrs[0], b_addrs[2]};
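Each replica follows the same spin-up pattern: move the RSM into a jthread, then block until the simulator reaches a quiescent state so the next registration starts from a stable cluster. A generic helper capturing the pattern (SpinUp is an invented name, illustrative only; not part of this commit):

    // Illustrative only: spin up one storage replica and wait for quiescence.
    template <typename Rsm>
    std::jthread SpinUp(Simulator &simulator, Rsm rsm, Address addr) {
      auto t = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(rsm));
      simulator.IncrementServerCountAndWaitForQuiescentState(addr);
      return t;
    }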
@@ -125,7 +174,122 @@ int main() {
   ConcreteStorageRsm b_2{std::move(b_io_2), b_2_peers, StorageRsm{}};
   ConcreteStorageRsm b_3{std::move(b_io_3), b_3_peers, StorageRsm{}};

+  auto b_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_1));
+  simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[0]);
+
+  auto b_thread_2 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_2));
+  simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[1]);
+
+  auto b_thread_3 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_3));
+  simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[2]);
+
+  std::cout << "beginning test after servers have become quiescent" << std::endl;
+
+  // Spin up coordinators
+  Io<SimulatorTransport> c_io_1 = simulator.RegisterNew();
+  Io<SimulatorTransport> c_io_2 = simulator.RegisterNew();
+  Io<SimulatorTransport> c_io_3 = simulator.RegisterNew();
+
+  std::vector<Address> c_addrs = {c_io_1.GetAddress(), c_io_2.GetAddress(), c_io_3.GetAddress()};
+
+  std::vector<Address> c_1_peers = {c_addrs[1], c_addrs[2]};
+  std::vector<Address> c_2_peers = {c_addrs[0], c_addrs[2]};
+  std::vector<Address> c_3_peers = {c_addrs[0], c_addrs[1]};
+
+  ConcreteCoordinatorRsm c_1{std::move(c_io_1), c_1_peers, Coordinator{std::move(sm1)}};
+  ConcreteCoordinatorRsm c_2{std::move(c_io_2), c_2_peers, Coordinator{std::move(sm2)}};
+  ConcreteCoordinatorRsm c_3{std::move(c_io_3), c_3_peers, Coordinator{std::move(sm3)}};
+
+  auto c_thread_1 = std::jthread([c_1]() mutable { c_1.Run(); });
+  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[0]);
+
+  auto c_thread_2 = std::jthread([c_2]() mutable { c_2.Run(); });
+  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[1]);
+
+  auto c_thread_3 = std::jthread([c_3]() mutable { c_3.Run(); });
+  simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[2]);
+
+  // Have the client contact the coordinator RSM for a new transaction ID and
+  // also get the current shard map
+  using CoordinatorClient =
+      RsmClient<Io<SimulatorTransport>, memgraph::coordinator::WriteRequests, memgraph::coordinator::WriteResponses,
+                memgraph::coordinator::ReadRequests, memgraph::coordinator::ReadResponses>;
+  CoordinatorClient coordinator_client(cli_io, c_addrs[2], c_addrs);
+
+  using StorageClient = RsmClient<Io<SimulatorTransport>, StorageWriteRequest, StorageWriteResponse, StorageGetRequest,
+                                  StorageGetResponse>;
+  StorageClient shard_a_client(cli_io, a_addrs[0], a_addrs);
+  StorageClient shard_b_client(cli_io, b_addrs[0], b_addrs);
+
+  memgraph::coordinator::HlcRequest req;
+
+  // Last ShardMap version the query engine knows about.
+  ShardMap client_shard_map;
+  req.last_shard_map_version = client_shard_map.GetHlc();
+
+  while (true) {
+    // auto read_res_opt = coordinator_client.SendReadRequest(req);
+    // if (!read_res_opt) {
+    //   std::cout << "ERROR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!0" << std::endl;
+    //   continue;
+    // }
+    // auto read_res = read_res_opt.value();
+    // auto res = std::get<memgraph::coordinator::HlcResponse>(read_res.read_return);
+    // auto transaction_id = res.new_hlc;
+    // client_shard_map = res.fresher_shard_map.value();
+    // (a compilable version of this sequence is sketched after this file)
+
+    // Have the client use the shard map to decide which shard to communicate
+    // with in order to write a new value
+    StorageWriteRequest storage_req;
+    auto write_key_1 = memgraph::storage::PropertyValue(3);
+    auto write_key_2 = memgraph::storage::PropertyValue(4);
+    storage_req.key = {write_key_1, write_key_2};
+    storage_req.value = 1000;
+    auto write_res_opt = shard_a_client.SendWriteRequest(storage_req);
+    if (!write_res_opt) {
+      std::cout << "ERROR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1" << std::endl;
+      continue;
+    }
+    auto write_res = write_res_opt.value().write_return;
+
+    std::optional<int> last_known_value;
+    if (write_res.shard_rsm_success) {
+      last_known_value = storage_req.value;
+    } else {
+      last_known_value = write_res.last_value;
+      continue;
+    }
+
+    // Have the client use the shard map to decide which shard to communicate
+    // with to read that same value back
+    StorageGetRequest storage_get_req;
+    storage_get_req.key = {write_key_1, write_key_2};
+    auto get_res_opt = shard_a_client.SendReadRequest(storage_get_req);
+    if (!get_res_opt) {
+      std::cout << "ERROR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!2" << std::endl;
+      continue;
+    }
+    auto get_res = get_res_opt.value();
+    auto val = get_res.read_return.value.value();
+
+    std::cout << "val -> " << val << std::endl;
+
+    MG_ASSERT(get_res.read_return.value == 1000);
+    break;
+  }

   simulator.ShutDown();

   return 0;
 }
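The commented-out block at the top of the test loop sketches the eventual coordinator handshake; a hedged, compilable version of that sequence (names taken from the commented code; the guard around fresher_shard_map is an addition, and the helper itself is not part of this commit):

    // Illustrative only: fetch a new HLC / transaction ID and adopt a fresher
    // shard map when the coordinator has one.
    void RefreshFromCoordinator(CoordinatorClient &coordinator_client, memgraph::coordinator::HlcRequest req,
                                ShardMap &client_shard_map) {
      auto read_res_opt = coordinator_client.SendReadRequest(req);
      if (!read_res_opt) return;  // timed out; the caller retries
      auto res = std::get<memgraph::coordinator::HlcResponse>(read_res_opt->read_return);
      [[maybe_unused]] auto transaction_id = res.new_hlc;
      if (res.fresher_shard_map) {
        client_shard_map = res.fresher_shard_map.value();
      }
    }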
tests/simulation/utils/rsm_client.hpp (new file, 111 lines)
@@ -0,0 +1,111 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

#pragma once

#include <iostream>
#include <optional>
#include <random>
#include <vector>

#include "io/address.hpp"
#include "io/rsm/raft.hpp"

using memgraph::io::Address;
using memgraph::io::ResponseEnvelope;
using memgraph::io::ResponseFuture;
using memgraph::io::ResponseResult;
using memgraph::io::rsm::ReadRequest;
using memgraph::io::rsm::ReadResponse;
using memgraph::io::rsm::WriteRequest;
using memgraph::io::rsm::WriteResponse;

template <typename IoImpl, typename WriteRequestT, typename WriteResponseT, typename ReadRequestT,
          typename ReadResponseT>
class RsmClient {
  using ServerPool = std::vector<Address>;

  IoImpl io_;
  Address leader_;

  std::mt19937 cli_rng_{0};
  ServerPool server_addrs_;

  // On a failed response, either follow the leader hint the server gave us or
  // fall back to a random server; returns std::nullopt when the caller must
  // re-send the request.
  template <typename ResponseT>
  std::optional<ResponseT> CheckForCorrectLeader(ResponseT response) {
    if (response.retry_leader) {
      MG_ASSERT(!response.success, "retry_leader should never be set for successful responses");
      leader_ = response.retry_leader.value();
      std::cout << "client redirected to leader server " << leader_.last_known_port << std::endl;
    } else if (!response.success) {
      std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1));
      size_t addr_index = addr_distrib(cli_rng_);
      leader_ = server_addrs_[addr_index];

      std::cout << "client NOT redirected to leader server, trying a random one at index " << addr_index
                << " with port " << leader_.last_known_port << std::endl;
      return std::nullopt;
    }

    return response;
  }

 public:
  RsmClient(IoImpl io, Address leader, ServerPool server_addrs)
      : io_{io}, leader_{leader}, server_addrs_{server_addrs} {}

  RsmClient() = delete;

  std::optional<WriteResponse<WriteResponseT>> SendWriteRequest(WriteRequestT req) {
    WriteRequest<WriteRequestT> client_req;
    client_req.operation = req;

    std::cout << "client sending write request to leader " << leader_.last_known_port << std::endl;
    ResponseFuture<WriteResponse<WriteResponseT>> response_future =
        io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, client_req);
    ResponseResult<WriteResponse<WriteResponseT>> response_result = std::move(response_future).Wait();

    if (response_result.HasError()) {
      std::cout << "client timed out while trying to communicate with leader server" << std::endl;
      return std::nullopt;
    }

    ResponseEnvelope<WriteResponse<WriteResponseT>> response_envelope = response_result.GetValue();
    WriteResponse<WriteResponseT> write_response = response_envelope.message;

    return CheckForCorrectLeader(write_response);
  }

  std::optional<ReadResponse<ReadResponseT>> SendReadRequest(ReadRequestT req) {
    ReadRequest<ReadRequestT> read_req;
    read_req.operation = req;

    std::cout << "client sending read request to leader " << leader_.last_known_port << std::endl;
    ResponseFuture<ReadResponse<ReadResponseT>> get_response_future =
        io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req);

    // receive the response
    ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(get_response_future).Wait();

    if (get_response_result.HasError()) {
      std::cout << "client timed out while trying to communicate with leader server" << std::endl;
      return std::nullopt;
    }

    ResponseEnvelope<ReadResponse<ReadResponseT>> get_response_envelope = get_response_result.GetValue();
    ReadResponse<ReadResponseT> read_get_response = get_response_envelope.message;

    return CheckForCorrectLeader(read_get_response);
  }
};
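A short sketch of the retry loop this class expects its callers to run (RetryWrite is an invented helper, illustrative only; it relies on the success flag read by CheckForCorrectLeader above):

    // Illustrative only: resend until a successful response arrives.
    template <typename Client, typename RequestT>
    auto RetryWrite(Client &client, RequestT req) {
      while (true) {
        auto maybe_res = client.SendWriteRequest(req);
        if (maybe_res && maybe_res->success) {
          return *maybe_res;
        }
        // On std::nullopt (timeout, or a bounce off a non-leader with no hint)
        // or on success == false (redirected via retry_leader), leader_ has
        // already been updated, so simply sending again makes progress.
      }
    }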