Added CreateVertices request test and polished

This commit is contained in:
Kostas Kyrimis 2022-09-05 22:28:15 +03:00
parent c8bc4c7dbc
commit 57836f7c2b
3 changed files with 44 additions and 81 deletions
src/query/v2
tests/simulation

View File

@ -42,22 +42,23 @@ class RsmStorageClientManager {
public:
using CompoundKey = memgraph::io::rsm::ShardRsmKey;
using Shard = memgraph::coordinator::Shard;
using LabelId = memgraph::storage::v3::LabelId;
RsmStorageClientManager() = default;
RsmStorageClientManager(const RsmStorageClientManager &) = delete;
RsmStorageClientManager(RsmStorageClientManager &&) = delete;
void AddClient(const std::string &label, Shard key, TStorageClient client) {
cli_cache_[label].insert({std::move(key), std::move(client)});
void AddClient(const LabelId label_id, Shard key, TStorageClient client) {
cli_cache_[label_id].insert({std::move(key), std::move(client)});
}
bool Exists(const std::string &label, const Shard &key) { return cli_cache_[label].contains(key); }
bool Exists(const LabelId label_id, const Shard &key) { return cli_cache_[label_id].contains(key); }
void PurgeCache() { cli_cache_.clear(); }
TStorageClient &GetClient(const std::string &label, const Shard &key) { return cli_cache_[label].find(key)->second; }
TStorageClient &GetClient(const LabelId label_id, const Shard &key) { return cli_cache_[label_id].find(key)->second; }
private:
std::map<std::string, std::map<Shard, TStorageClient>> cli_cache_;
std::map<LabelId, std::map<Shard, TStorageClient>> cli_cache_;
};
template <typename TRequest>
@ -99,10 +100,9 @@ class ShardRequestManagerInterface {
template <typename TTransport, typename... Rest>
class ShardRequestManager : public ShardRequestManagerInterface {
public:
using StorageWriteRequest = memgraph::io::rsm::StorageWriteRequest;
using StorageWriteResponse = memgraph::io::rsm::StorageWriteResponse;
using StorageClient =
memgraph::coordinator::RsmClient<TTransport, StorageWriteRequest, StorageWriteResponse, Rest...>;
using WriteRequests = CreateVerticesRequest;
using WriteResponses = CreateVerticesResponse;
using StorageClient = memgraph::coordinator::RsmClient<TTransport, WriteRequests, WriteResponses, Rest...>;
using CoordinatorClient = memgraph::coordinator::CoordinatorClient<TTransport>;
using Address = memgraph::io::Address;
using Shard = memgraph::coordinator::Shard;
@ -167,16 +167,16 @@ class ShardRequestManager : public ShardRequestManagerInterface {
std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state,
std::vector<NewVertexLabel> new_vertices) {
MG_ASSERT(!new_vertices.empty());
MaybeInitializeExecutionState(state, std::move(new_vertices));
MaybeInitializeExecutionState(state, new_vertices);
std::vector<CreateVerticesResponse> responses;
auto &shard_cache_ref = state.shard_cache;
size_t id = 0;
for (auto shard_it = shard_cache_ref.begin(); shard_it != shard_cache_ref.end(); ++id) {
// This is fine because all new_vertices of each request end up on the same shard
Label label = state.requests[id].new_vertices[0].label_ids;
const Label label = state.requests[id].new_vertices[0].label_ids;
auto primary_key = state.requests[id].new_vertices[0].primary_key;
auto &storage_client = GetStorageClientForShard(label, primary_key);
auto read_response_result = storage_client.SendReadRequest(state.requests[id]);
auto &storage_client = GetStorageClientForShard(*shard_it, label.id);
auto read_response_result = storage_client.SendWriteRequest(state.requests[id]);
// RETRY on timeouts?
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
if (read_response_result.HasError()) {
@ -254,7 +254,7 @@ class ShardRequestManager : public ShardRequestManagerInterface {
state.shard_cache.push_back(shard);
}
per_shard_request_table[shard].new_vertices.push_back(
NewVertex{.label_ids = shards_map_.GetLabelId(new_vertex.label),
NewVertex{.label_ids = {shards_map_.GetLabelId(new_vertex.label)},
.primary_key = std::move(new_vertex.primary_key),
.properties = std::move(new_vertex.properties)});
}
@ -285,16 +285,20 @@ class ShardRequestManager : public ShardRequestManagerInterface {
// std::vector<storageclient> GetStorageClientFromShardforRange(const std::string &label, const CompoundKey &start,
// const CompoundKey &end);
template <typename TLabel>
StorageClient &GetStorageClientForShard(const TLabel &label, const CompoundKey &key) {
auto shard = shards_map_.GetShardForKey(label, key);
if (!storage_cli_manager_.Exists(label, shard)) {
AddStorageClientToManager(shard, label);
StorageClient &GetStorageClientForShard(Shard shard, LabelId label_id) {
if (!storage_cli_manager_.Exists(label_id, shard)) {
AddStorageClientToManager(shard, label_id);
}
return storage_cli_manager_.GetClient(label, shard);
return storage_cli_manager_.GetClient(label_id, shard);
}
void AddStorageClientToManager(Shard target_shard, const std::string &label) {
StorageClient &GetStorageClientForShard(const std::string &label, const CompoundKey &key) {
auto shard = shards_map_.GetShardForKey(label, key);
auto label_id = shards_map_.GetLabelId(label);
return GetStorageClientForShard(std::move(shard), label_id);
}
void AddStorageClientToManager(Shard target_shard, const LabelId &label_id) {
MG_ASSERT(!target_shard.empty());
auto leader_addr = target_shard.front();
std::vector<Address> addresses;
@ -303,7 +307,7 @@ class ShardRequestManager : public ShardRequestManagerInterface {
addresses.push_back(std::move(address.address));
}
auto cli = StorageClient(io_, std::move(leader_addr.address), std::move(addresses));
storage_cli_manager_.AddClient(label, target_shard, std::move(cli));
storage_cli_manager_.AddClient(label_id, target_shard, std::move(cli));
}
ShardMap shards_map_;

View File

@ -101,52 +101,5 @@ class MockedShardRsm {
return ret;
}
StorageWriteResponse Apply(StorageWriteRequest request) {
StorageWriteResponse ret;
// Key is outside the prohibited range
if (!IsKeyInRange(request.key)) {
ret.latest_known_shard_map_version = shard_map_version_;
ret.shard_rsm_success = false;
}
// Key exist
else if (state_.contains(request.key)) {
auto &val = state_[request.key];
/*
* Delete
*/
if (!request.value) {
ret.shard_rsm_success = true;
ret.last_value = val;
state_.erase(state_.find(request.key));
}
/*
* Update
*/
// Does old_value match?
if (request.value == val) {
ret.last_value = val;
ret.shard_rsm_success = true;
val = request.value.value();
} else {
ret.last_value = val;
ret.shard_rsm_success = false;
}
}
/*
* Create
*/
else {
ret.last_value = std::nullopt;
ret.shard_rsm_success = true;
state_.emplace(request.key, std::move(request.value).value());
}
return ret;
}
CreateVerticesResponse Apply(CreateVerticesRequest request) { return CreateVerticesResponse{.success = true}; }
};

View File

@ -116,21 +116,19 @@ ShardMap CreateDummyShardmap(memgraph::coordinator::Address a_io_1, memgraph::co
} // namespace
using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>;
using ConcreteStorageRsm = Raft<SimulatorTransport, MockedShardRsm, StorageWriteRequest, StorageWriteResponse,
using ConcreteStorageRsm = Raft<SimulatorTransport, MockedShardRsm, CreateVerticesRequest, CreateVerticesResponse,
ScanVerticesRequest, ScanVerticesResponse>;
template <typename IoImpl>
void RunStorageRaft(
Raft<IoImpl, MockedShardRsm, StorageWriteRequest, StorageWriteResponse, ScanVerticesRequest, ScanVerticesResponse>
server) {
void RunStorageRaft(Raft<IoImpl, MockedShardRsm, CreateVerticesRequest, CreateVerticesResponse, ScanVerticesRequest,
ScanVerticesResponse>
server) {
server.Run();
}
template <typename ShardRequestManager>
void TestScanAll(ShardRequestManager &io) {
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
state.key = std::make_optional<CompoundKey>(
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
auto result = io.Request(state);
MG_ASSERT(result.size() == 2);
@ -154,7 +152,18 @@ void TestScanAll(ShardRequestManager &io) {
}
template <typename ShardRequestManager>
void TestCreateVertices(ShardRequestManager &io) {}
void TestCreateVertices(ShardRequestManager &io) {
using PropVal = memgraph::storage::v3::PropertyValue;
ExecutionState<CreateVerticesRequest> state;
std::vector<NewVertexLabel> new_vertices;
NewVertexLabel a1{.label = "test_label", .primary_key = {PropVal(1), PropVal(0)}};
NewVertexLabel a2{.label = "test_label", .primary_key = {PropVal(13), PropVal(13)}};
new_vertices.push_back(std::move(a1));
new_vertices.push_back(std::move(a2));
auto result = io.Request(state, std::move(new_vertices));
MG_ASSERT(result.size() == 2);
}
template <typename ShardRequestManager>
void TestExpand(ShardRequestManager &io) {}
@ -267,12 +276,9 @@ int main() {
ShardRequestManager<SimulatorTransport, ScanVerticesRequest, ScanVerticesResponse> io(std::move(coordinator_client),
std::move(cli_io));
ExecutionState<ScanVerticesRequest> state{.label = "test_label"};
state.key = std::make_optional<CompoundKey>(
std::vector{memgraph::storage::v3::PropertyValue(0), memgraph::storage::v3::PropertyValue(0)});
io.StartTransaction();
TestScanAll(io);
TestCreateVertices(io);
simulator.ShutDown();
return 0;