diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f4c303daf..6a510aee6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(query/v2) add_subdirectory(slk) add_subdirectory(rpc) add_subdirectory(auth) +add_subdirectory(coordinator) if (MG_ENTERPRISE) add_subdirectory(audit) diff --git a/src/coordinator/CMakeLists.txt b/src/coordinator/CMakeLists.txt new file mode 100644 index 000000000..198223f5a --- /dev/null +++ b/src/coordinator/CMakeLists.txt @@ -0,0 +1,10 @@ +set(coordinator_src_files + coordinator.hpp + shard_map.hpp + hybrid_logical_clock.hpp) + +find_package(fmt REQUIRED) +find_package(Threads REQUIRED) + +add_library(mg-coordinator STATIC ${coordinator_src_files}) +target_link_libraries(mg-coordinator stdc++fs Threads::Threads fmt::fmt mg-utils) diff --git a/src/coordinator/coordinator.hpp b/src/coordinator/coordinator.hpp new file mode 100644 index 000000000..3382a4006 --- /dev/null +++ b/src/coordinator/coordinator.hpp @@ -0,0 +1,259 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <optional> +#include <string> +#include <unordered_set> +#include <variant> +#include <vector> + +#include "coordinator/hybrid_logical_clock.hpp" +#include "coordinator/shard_map.hpp" +#include "io/simulator/simulator.hpp" +#include "io/time.hpp" +#include "io/transport.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/schemas.hpp" + +namespace memgraph::coordinator { + +using memgraph::storage::v3::LabelId; +using memgraph::storage::v3::PropertyId; +using Address = memgraph::io::Address; +using SimT = memgraph::io::simulator::SimulatorTransport; +using memgraph::storage::v3::SchemaProperty; + +struct HlcRequest { + Hlc last_shard_map_version; +}; + +struct HlcResponse { + Hlc new_hlc; + std::optional<ShardMap> fresher_shard_map; +}; + +struct GetShardMapRequest { + // No state +}; + +struct GetShardMapResponse { + ShardMap shard_map; +}; + +struct AllocateHlcBatchResponse { + bool success; + Hlc low; + Hlc high; +}; + +struct AllocateEdgeIdBatchRequest { + size_t batch_size; +}; + +struct AllocateEdgeIdBatchResponse { + uint64_t low; + uint64_t high; +}; + +struct AllocatePropertyIdsRequest { + std::vector<std::string> property_names; +}; + +struct AllocatePropertyIdsResponse { + std::map<std::string, PropertyId> property_ids; +}; + +struct SplitShardRequest { + Hlc previous_shard_map_version; + LabelId label_id; + CompoundKey split_key; +}; + +struct SplitShardResponse { + bool success; +}; + +struct RegisterStorageEngineRequest { + Address address; +}; + +struct RegisterStorageEngineResponse { + bool success; +}; + +struct DeregisterStorageEngineRequest { + Address address; +}; + +struct DeregisterStorageEngineResponse { + bool success; +}; + +struct InitializeLabelRequest { + std::string label_name; + std::vector<SchemaProperty> schema; + Hlc last_shard_map_version; +}; + +struct InitializeLabelResponse { + bool success; + std::optional<ShardMap> fresher_shard_map; +}; + +struct HeartbeatRequest {}; +struct HeartbeatResponse {}; + +using CoordinatorWriteRequests = + std::variant<HlcRequest, AllocateEdgeIdBatchRequest, SplitShardRequest, RegisterStorageEngineRequest, + DeregisterStorageEngineRequest, InitializeLabelRequest, AllocatePropertyIdsRequest>; +using CoordinatorWriteResponses = + std::variant<HlcResponse, AllocateEdgeIdBatchResponse, SplitShardResponse, RegisterStorageEngineResponse, + DeregisterStorageEngineResponse, InitializeLabelResponse, AllocatePropertyIdsResponse>; + +using CoordinatorReadRequests = std::variant<GetShardMapRequest, HeartbeatRequest>; +using CoordinatorReadResponses = std::variant<GetShardMapResponse, HeartbeatResponse>; + +class Coordinator { + public: + explicit Coordinator(ShardMap sm) : shard_map_{std::move(sm)} {} + + // NOLINTNEXTLINE(readability-convert-member-functions-to-static + CoordinatorReadResponses Read(CoordinatorReadRequests requests) { + return std::visit([&](auto &&request) { return HandleRead(std::forward<decltype(request)>(request)); }, + std::move(requests)); // NOLINT(hicpp-move-const-arg,performance-move-const-arg) + } + + // NOLINTNEXTLINE(readability-convert-member-functions-to-static + CoordinatorWriteResponses Apply(CoordinatorWriteRequests requests) { + return std::visit([&](auto &&request) mutable { return ApplyWrite(std::forward<decltype(request)>(request)); }, + std::move(requests)); + } + + private: + ShardMap shard_map_; + uint64_t highest_allocated_timestamp_; + + /// Query engines need to periodically request batches of unique edge IDs. + uint64_t highest_allocated_edge_id_; + + static CoordinatorReadResponses HandleRead(HeartbeatRequest && /* heartbeat_request */) { + return HeartbeatResponse{}; + } + + CoordinatorReadResponses HandleRead(GetShardMapRequest && /* get_shard_map_request */) { + GetShardMapResponse res; + res.shard_map = shard_map_; + return res; + } + + CoordinatorWriteResponses ApplyWrite(HlcRequest &&hlc_request) { + HlcResponse res{}; + + auto hlc_shard_map = shard_map_.GetHlc(); + + MG_ASSERT(!(hlc_request.last_shard_map_version.logical_id > hlc_shard_map.logical_id)); + + res.new_hlc = Hlc{ + .logical_id = ++highest_allocated_timestamp_, + // TODO(tyler) probably pass some more context to the Coordinator here + // so that we can use our wall clock and enforce monotonicity. + // .coordinator_wall_clock = io_.Now(), + }; + + // Allways return fresher shard_map for now. + res.fresher_shard_map = std::make_optional(shard_map_); + + return res; + } + + CoordinatorWriteResponses ApplyWrite(AllocateEdgeIdBatchRequest &&ahr) { + AllocateEdgeIdBatchResponse res{}; + + uint64_t low = highest_allocated_edge_id_; + + highest_allocated_edge_id_ += ahr.batch_size; + + uint64_t high = highest_allocated_edge_id_; + + res.low = low; + res.high = high; + + return res; + } + + /// This splits the shard immediately beneath the provided + /// split key, keeping the assigned peers identical for now, + /// but letting them be gradually migrated over time. + CoordinatorWriteResponses ApplyWrite(SplitShardRequest &&split_shard_request) { + SplitShardResponse res{}; + + if (split_shard_request.previous_shard_map_version != shard_map_.shard_map_version) { + res.success = false; + } else { + res.success = shard_map_.SplitShard(split_shard_request.previous_shard_map_version, split_shard_request.label_id, + split_shard_request.split_key); + } + + return res; + } + + /// This adds the provided storage engine to the standby storage engine pool, + /// which can be used to rebalance storage over time. + static CoordinatorWriteResponses ApplyWrite(RegisterStorageEngineRequest && /* register_storage_engine_request */) { + RegisterStorageEngineResponse res{}; + // TODO + + return res; + } + + /// This begins the process of draining the provided storage engine from all raft + /// clusters that it might be participating in. + static CoordinatorWriteResponses ApplyWrite(DeregisterStorageEngineRequest && /* register_storage_engine_request */) { + DeregisterStorageEngineResponse res{}; + // TODO + // const Address &address = register_storage_engine_request.address; + // storage_engine_pool_.erase(address); + // res.success = true; + + return res; + } + + CoordinatorWriteResponses ApplyWrite(InitializeLabelRequest &&initialize_label_request) { + InitializeLabelResponse res{}; + + bool success = shard_map_.InitializeNewLabel(initialize_label_request.label_name, initialize_label_request.schema, + initialize_label_request.last_shard_map_version); + + if (success) { + res.fresher_shard_map = shard_map_; + res.success = false; + } else { + res.fresher_shard_map = std::nullopt; + res.success = true; + } + + return res; + } + + CoordinatorWriteResponses ApplyWrite(AllocatePropertyIdsRequest &&allocate_property_ids_request) { + AllocatePropertyIdsResponse res{}; + + auto property_ids = shard_map_.AllocatePropertyIds(allocate_property_ids_request.property_names); + + res.property_ids = property_ids; + + return res; + } +}; + +} // namespace memgraph::coordinator diff --git a/src/coordinator/coordinator_client.hpp b/src/coordinator/coordinator_client.hpp new file mode 100644 index 000000000..f73ecfcd2 --- /dev/null +++ b/src/coordinator/coordinator_client.hpp @@ -0,0 +1,25 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "coordinator/coordinator.hpp" +#include "io/rsm/rsm_client.hpp" + +namespace memgraph::coordinator { + +using memgraph::io::rsm::RsmClient; + +template <typename IoImpl> +using CoordinatorClient = RsmClient<IoImpl, CoordinatorWriteRequests, CoordinatorWriteResponses, + CoordinatorReadRequests, CoordinatorReadResponses>; + +} // namespace memgraph::coordinator diff --git a/src/coordinator/coordinator_rsm.hpp b/src/coordinator/coordinator_rsm.hpp new file mode 100644 index 000000000..94bbe8351 --- /dev/null +++ b/src/coordinator/coordinator_rsm.hpp @@ -0,0 +1,23 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "coordinator/coordinator.hpp" +#include "io/rsm/raft.hpp" + +namespace memgraph::coordinator { + +template <typename IoImpl> +using CoordinatorRsm = memgraph::io::rsm::Raft<IoImpl, Coordinator, CoordinatorWriteRequests, CoordinatorWriteResponses, + CoordinatorReadRequests, CoordinatorReadResponses>; + +} // namespace memgraph::coordinator diff --git a/src/coordinator/hybrid_logical_clock.hpp b/src/coordinator/hybrid_logical_clock.hpp new file mode 100644 index 000000000..75d3a2fbb --- /dev/null +++ b/src/coordinator/hybrid_logical_clock.hpp @@ -0,0 +1,28 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "io/time.hpp" + +namespace memgraph::coordinator { + +using Time = memgraph::io::Time; + +/// Hybrid-logical clock +struct Hlc { + uint64_t logical_id; + Time coordinator_wall_clock; + + bool operator==(const Hlc &other) const = default; +}; + +} // namespace memgraph::coordinator diff --git a/src/coordinator/shard_map.hpp b/src/coordinator/shard_map.hpp new file mode 100644 index 000000000..ce9f779dd --- /dev/null +++ b/src/coordinator/shard_map.hpp @@ -0,0 +1,189 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <map> +#include <vector> + +#include "coordinator/hybrid_logical_clock.hpp" +#include "io/address.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "storage/v3/schemas.hpp" + +namespace memgraph::coordinator { + +using memgraph::io::Address; +using memgraph::storage::v3::LabelId; +using memgraph::storage::v3::PropertyId; +using memgraph::storage::v3::SchemaProperty; + +enum class Status : uint8_t { + CONSENSUS_PARTICIPANT, + INITIALIZING, + // TODO(tyler) this will possibly have more states, + // depending on the reconfiguration protocol that we + // implement. +}; + +struct AddressAndStatus { + memgraph::io::Address address; + Status status; +}; + +using CompoundKey = std::vector<memgraph::storage::v3::PropertyValue>; +using Shard = std::vector<AddressAndStatus>; +using Shards = std::map<CompoundKey, Shard>; +using LabelName = std::string; +using PropertyName = std::string; +using PropertyMap = std::map<PropertyName, PropertyId>; + +struct LabelSpace { + std::vector<SchemaProperty> schema; + std::map<CompoundKey, Shard> shards; +}; + +struct ShardMap { + Hlc shard_map_version; + uint64_t max_property_id; + std::map<PropertyName, PropertyId> properties; + uint64_t max_label_id; + std::map<LabelName, LabelId> labels; + std::map<LabelId, LabelSpace> label_spaces; + std::map<LabelId, std::vector<SchemaProperty>> schemas; + + // TODO(gabor) later we will want to update the wallclock time with + // the given Io<impl>'s time as well + Hlc IncrementShardMapVersion() noexcept { + ++shard_map_version.logical_id; + return shard_map_version; + } + + Hlc GetHlc() const noexcept { return shard_map_version; } + + bool SplitShard(Hlc previous_shard_map_version, LabelId label_id, const CompoundKey &key) { + if (previous_shard_map_version != shard_map_version) { + return false; + } + + auto &label_space = label_spaces.at(label_id); + auto &shards_in_map = label_space.shards; + + MG_ASSERT(!shards_in_map.contains(key)); + MG_ASSERT(label_spaces.contains(label_id)); + + // Finding the Shard that the new CompoundKey should map to. + auto prev = std::prev(shards_in_map.upper_bound(key)); + Shard duplicated_shard = prev->second; + + // Apply the split + shards_in_map[key] = duplicated_shard; + + return true; + } + + bool InitializeNewLabel(std::string label_name, std::vector<SchemaProperty> schema, Hlc last_shard_map_version) { + if (shard_map_version != last_shard_map_version || labels.contains(label_name)) { + return false; + } + + const LabelId label_id = LabelId::FromUint(++max_label_id); + + labels.emplace(std::move(label_name), label_id); + + LabelSpace label_space{ + .schema = std::move(schema), + .shards = Shards{}, + }; + + label_spaces.emplace(label_id, label_space); + + IncrementShardMapVersion(); + + return true; + } + + void AddServer(Address server_address) { + // Find a random place for the server to plug in + } + + Shards GetShardsForRange(const LabelName &label_name, const CompoundKey &start_key, const CompoundKey &end_key) const { + MG_ASSERT(start_key <= end_key); + MG_ASSERT(labels.contains(label_name)); + + LabelId label_id = labels.at(label_name); + + const auto &label_space = label_spaces.at(label_id); + + const auto &shards_for_label = label_space.shards; + + MG_ASSERT(shards_for_label.begin()->first <= start_key, + "the ShardMap must always contain a minimal key that is less than or equal to any requested key"); + + auto it = std::prev(shards_for_label.upper_bound(start_key)); + const auto end_it = shards_for_label.upper_bound(end_key); + + Shards shards{}; + + std::copy(it, end_it, std::inserter(shards, shards.end())); + + return shards; + } + + Shard GetShardForKey(const LabelName &label_name, const CompoundKey &key) const { + MG_ASSERT(labels.contains(label_name)); + + LabelId label_id = labels.at(label_name); + + const auto &label_space = label_spaces.at(label_id); + + MG_ASSERT(label_space.shards.begin()->first <= key, + "the ShardMap must always contain a minimal key that is less than or equal to any requested key"); + + return std::prev(label_space.shards.upper_bound(key))->second; + } + + PropertyMap AllocatePropertyIds(const std::vector<PropertyName> &new_properties) { + PropertyMap ret{}; + + bool mutated = false; + + for (const auto &property_name : new_properties) { + if (properties.contains(property_name)) { + auto property_id = properties.at(property_name); + ret.emplace(property_name, property_id); + } else { + mutated = true; + + const PropertyId property_id = PropertyId::FromUint(++max_property_id); + ret.emplace(property_name, property_id); + properties.emplace(property_name, property_id); + } + } + + if (mutated) { + IncrementShardMapVersion(); + } + + return ret; + } + + std::optional<PropertyId> GetPropertyId(const std::string &property_name) const { + if (properties.contains(property_name)) { + return properties.at(property_name); + } + + return std::nullopt; + } +}; + +} // namespace memgraph::coordinator diff --git a/src/io/address.hpp b/src/io/address.hpp index 914c8cb86..286bab577 100644 --- a/src/io/address.hpp +++ b/src/io/address.hpp @@ -45,6 +45,15 @@ struct Address { }; } + /// Returns a new ID with the same IP and port but a unique UUID. + Address ForkUniqueAddress() { + return Address{ + .unique_id = boost::uuids::uuid{boost::uuids::random_generator()()}, + .last_known_ip = last_known_ip, + .last_known_port = last_known_port, + }; + } + friend bool operator==(const Address &lhs, const Address &rhs) = default; /// unique_id is most dominant for ordering, then last_known_ip, then last_known_port diff --git a/src/io/local_transport/local_transport.hpp b/src/io/local_transport/local_transport.hpp index f08392a87..bdf964bed 100644 --- a/src/io/local_transport/local_transport.hpp +++ b/src/io/local_transport/local_transport.hpp @@ -31,9 +31,9 @@ class LocalTransport { LocalTransport(std::shared_ptr<LocalTransportHandle> local_transport_handle, Address address) : local_transport_handle_(std::move(local_transport_handle)), address_(address) {} - template <Message Request, Message Response> - ResponseFuture<Response> Request(Address to_address, RequestId request_id, Request request, Duration timeout) { - auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<Response>>(); + template <Message RequestT, Message ResponseT> + ResponseFuture<ResponseT> Request(Address to_address, RequestId request_id, RequestT request, Duration timeout) { + auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<ResponseT>>(); Address from_address = address_; diff --git a/src/io/local_transport/local_transport_handle.hpp b/src/io/local_transport/local_transport_handle.hpp index 8536ff716..1afade9e6 100644 --- a/src/io/local_transport/local_transport_handle.hpp +++ b/src/io/local_transport/local_transport_handle.hpp @@ -113,9 +113,9 @@ class LocalTransportHandle { cv_.notify_all(); } - template <Message Request, Message Response> - void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request, - Duration timeout, ResponsePromise<Response> promise) { + template <Message RequestT, Message ResponseT> + void SubmitRequest(Address to_address, Address from_address, RequestId request_id, RequestT &&request, + Duration timeout, ResponsePromise<ResponseT> promise) { const bool port_matches = to_address.last_known_port == from_address.last_known_port; const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip; @@ -132,7 +132,7 @@ class LocalTransportHandle { promises_.emplace(std::move(promise_key), std::move(dop)); } // lock dropped - Send(to_address, from_address, request_id, std::forward<Request>(request)); + Send(to_address, from_address, request_id, std::forward<RequestT>(request)); } }; diff --git a/src/io/rsm/raft.hpp b/src/io/rsm/raft.hpp index c38d3da74..ea7286cbd 100644 --- a/src/io/rsm/raft.hpp +++ b/src/io/rsm/raft.hpp @@ -184,13 +184,15 @@ concept FollowerOrCandidate = memgraph::utils::SameAsAnyOf<Role, Follower, Candi /* all ReplicatedState classes should have an Apply method -that returns our WriteResponseValue: +that returns our WriteResponseValue after consensus, and +a Read method that returns our ReadResponseValue without +requiring consensus. ReadResponse Read(ReadOperation); WriteResponseValue ReplicatedState::Apply(WriteRequest); -for examples: -if the state is uint64_t, and WriteRequest is `struct PlusOne {};`, +For example: +If the state is uint64_t, and WriteRequest is `struct PlusOne {};`, and WriteResponseValue is also uint64_t (the new value), then each call to state.Apply(PlusOne{}) will return the new value after incrementing it. 0, 1, 2, 3... and this will be sent back diff --git a/src/io/rsm/rsm_client.hpp b/src/io/rsm/rsm_client.hpp new file mode 100644 index 000000000..f837ee0ec --- /dev/null +++ b/src/io/rsm/rsm_client.hpp @@ -0,0 +1,136 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <iostream> +#include <optional> +#include <vector> + +#include "io/address.hpp" +#include "io/rsm/raft.hpp" +#include "utils/result.hpp" + +namespace memgraph::io::rsm { + +using memgraph::io::Address; +using memgraph::io::Duration; +using memgraph::io::ResponseEnvelope; +using memgraph::io::ResponseFuture; +using memgraph::io::ResponseResult; +using memgraph::io::Time; +using memgraph::io::TimedOut; +using memgraph::io::rsm::ReadRequest; +using memgraph::io::rsm::ReadResponse; +using memgraph::io::rsm::WriteRequest; +using memgraph::io::rsm::WriteResponse; +using memgraph::utils::BasicResult; + +template <typename IoImpl, typename WriteRequestT, typename WriteResponseT, typename ReadRequestT, + typename ReadResponseT> +class RsmClient { + using ServerPool = std::vector<Address>; + + Io<IoImpl> io_; + Address leader_; + ServerPool server_addrs_; + + template <typename ResponseT> + void PossiblyRedirectLeader(const ResponseT &response) { + if (response.retry_leader) { + MG_ASSERT(!response.success, "retry_leader should never be set for successful responses"); + leader_ = response.retry_leader.value(); + spdlog::debug("client redirected to leader server {}", leader_.ToString()); + } else if (!response.success) { + std::uniform_int_distribution<size_t> addr_distrib(0, (server_addrs_.size() - 1)); + size_t addr_index = io_.Rand(addr_distrib); + leader_ = server_addrs_[addr_index]; + + spdlog::debug( + "client NOT redirected to leader server despite our success failing to be processed (it probably was sent to " + "a RSM Candidate) trying a random one at index {} with address {}", + addr_index, leader_.ToString()); + } + } + + public: + RsmClient(Io<IoImpl> io, Address leader, ServerPool server_addrs) + : io_{io}, leader_{leader}, server_addrs_{server_addrs} {} + + RsmClient() = delete; + + BasicResult<TimedOut, WriteResponseT> SendWriteRequest(WriteRequestT req) { + WriteRequest<WriteRequestT> client_req; + client_req.operation = req; + + const Duration overall_timeout = io_.GetDefaultTimeout(); + const Time before = io_.Now(); + + do { + spdlog::debug("client sending WriteRequest to Leader {}", leader_.ToString()); + ResponseFuture<WriteResponse<WriteResponseT>> response_future = + io_.template Request<WriteRequest<WriteRequestT>, WriteResponse<WriteResponseT>>(leader_, client_req); + ResponseResult<WriteResponse<WriteResponseT>> response_result = std::move(response_future).Wait(); + + if (response_result.HasError()) { + spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString()); + return response_result.GetError(); + } + + ResponseEnvelope<WriteResponse<WriteResponseT>> &&response_envelope = std::move(response_result.GetValue()); + WriteResponse<WriteResponseT> &&write_response = std::move(response_envelope.message); + + if (write_response.success) { + return std::move(write_response.write_return); + } + + PossiblyRedirectLeader(write_response); + } while (io_.Now() < before + overall_timeout); + + return TimedOut{}; + } + + BasicResult<TimedOut, ReadResponseT> SendReadRequest(ReadRequestT req) { + ReadRequest<ReadRequestT> read_req; + read_req.operation = req; + + const Duration overall_timeout = io_.GetDefaultTimeout(); + const Time before = io_.Now(); + + do { + spdlog::debug("client sending ReadRequest to Leader {}", leader_.ToString()); + + ResponseFuture<ReadResponse<ReadResponseT>> get_response_future = + io_.template Request<ReadRequest<ReadRequestT>, ReadResponse<ReadResponseT>>(leader_, read_req); + + // receive response + ResponseResult<ReadResponse<ReadResponseT>> get_response_result = std::move(get_response_future).Wait(); + + if (get_response_result.HasError()) { + spdlog::debug("client timed out while trying to communicate with leader server {}", leader_.ToString()); + return get_response_result.GetError(); + } + + ResponseEnvelope<ReadResponse<ReadResponseT>> &&get_response_envelope = std::move(get_response_result.GetValue()); + ReadResponse<ReadResponseT> &&read_get_response = std::move(get_response_envelope.message); + + if (read_get_response.success) { + return std::move(read_get_response.read_return); + } + + PossiblyRedirectLeader(read_get_response); + } while (io_.Now() < before + overall_timeout); + + return TimedOut{}; + } +}; + +} // namespace memgraph::io::rsm diff --git a/src/io/rsm/shard_rsm.hpp b/src/io/rsm/shard_rsm.hpp new file mode 100644 index 000000000..480f40238 --- /dev/null +++ b/src/io/rsm/shard_rsm.hpp @@ -0,0 +1,155 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +/// The ShardRsm is a simple in-memory raft-backed kv store that can be used for simple testing +/// and implementation of some query engine logic before storage engines are fully implemented. +/// +/// To implement multiple read and write commands, change the StorageRead* and StorageWrite* requests +/// and responses to a std::variant of the different options, and route them to specific handlers in +/// the ShardRsm's Read and Apply methods. Remember that Read is called immediately when the Raft +/// leader receives the request, and does not replicate anything over Raft. Apply is called only +/// AFTER the StorageWriteRequest is replicated to a majority of Raft peers, and the result of calling +/// ShardRsm::Apply(StorageWriteRequest) is returned to the client that submitted the request. + +#include <deque> +#include <iostream> +#include <map> +#include <optional> +#include <set> +#include <thread> +#include <vector> + +#include "coordinator/hybrid_logical_clock.hpp" +#include "io/address.hpp" +#include "io/rsm/raft.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" +#include "utils/logging.hpp" + +namespace memgraph::io::rsm { + +using memgraph::coordinator::Hlc; +using memgraph::storage::v3::LabelId; +using memgraph::storage::v3::PropertyValue; + +using ShardRsmKey = std::vector<PropertyValue>; + +struct StorageWriteRequest { + LabelId label_id; + Hlc transaction_id; + ShardRsmKey key; + std::optional<int> value; +}; + +struct StorageWriteResponse { + bool shard_rsm_success; + std::optional<int> last_value; + // Only has a value if the given shard does not contain the requested key + std::optional<Hlc> latest_known_shard_map_version{std::nullopt}; +}; + +struct StorageReadRequest { + LabelId label_id; + Hlc transaction_id; + ShardRsmKey key; +}; + +struct StorageReadResponse { + bool shard_rsm_success; + std::optional<int> value; + // Only has a value if the given shard does not contain the requested key + std::optional<Hlc> latest_known_shard_map_version{std::nullopt}; +}; + +class ShardRsm { + std::map<ShardRsmKey, int> state_; + ShardRsmKey minimum_key_; + std::optional<ShardRsmKey> maximum_key_{std::nullopt}; + Hlc shard_map_version_; + + // The key is not located in this shard + bool IsKeyInRange(const ShardRsmKey &key) const { + if (maximum_key_) [[likely]] { + return (key >= minimum_key_ && key <= maximum_key_); + } + return key >= minimum_key_; + } + + public: + StorageReadResponse Read(StorageReadRequest request) const { + StorageReadResponse ret; + + if (!IsKeyInRange(request.key)) { + ret.latest_known_shard_map_version = shard_map_version_; + ret.shard_rsm_success = false; + } else if (state_.contains(request.key)) { + ret.value = state_.at(request.key); + ret.shard_rsm_success = true; + } else { + ret.shard_rsm_success = false; + ret.value = std::nullopt; + } + return ret; + } + + StorageWriteResponse Apply(StorageWriteRequest request) { + StorageWriteResponse ret; + + // Key is outside the prohibited range + if (!IsKeyInRange(request.key)) { + ret.latest_known_shard_map_version = shard_map_version_; + ret.shard_rsm_success = false; + } + // Key exist + else if (state_.contains(request.key)) { + auto &val = state_[request.key]; + + /* + * Delete + */ + if (!request.value) { + ret.shard_rsm_success = true; + ret.last_value = val; + state_.erase(state_.find(request.key)); + } + + /* + * Update + */ + // Does old_value match? + if (request.value == val) { + ret.last_value = val; + ret.shard_rsm_success = true; + + val = request.value.value(); + + } else { + ret.last_value = val; + ret.shard_rsm_success = false; + } + } + /* + * Create + */ + else { + ret.last_value = std::nullopt; + ret.shard_rsm_success = true; + + state_.emplace(request.key, request.value.value()); + } + + return ret; + } +}; + +} // namespace memgraph::io::rsm diff --git a/src/io/simulator/simulator.hpp b/src/io/simulator/simulator.hpp index 354aae6ac..a28ae16df 100644 --- a/src/io/simulator/simulator.hpp +++ b/src/io/simulator/simulator.hpp @@ -23,6 +23,7 @@ namespace memgraph::io::simulator { class Simulator { std::mt19937 rng_; std::shared_ptr<SimulatorHandle> simulator_handle_; + uint16_t auto_port_ = 0; public: explicit Simulator(SimulatorConfig config) @@ -30,6 +31,11 @@ class Simulator { void ShutDown() { simulator_handle_->ShutDown(); } + Io<SimulatorTransport> RegisterNew() { + Address address = Address::TestAddress(auto_port_++); + return Register(address); + } + Io<SimulatorTransport> Register(Address address) { std::uniform_int_distribution<uint64_t> seed_distrib; uint64_t seed = seed_distrib(rng_); diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp index 2706f798c..4b5a2e890 100644 --- a/src/io/simulator/simulator_transport.hpp +++ b/src/io/simulator/simulator_transport.hpp @@ -32,11 +32,11 @@ class SimulatorTransport { SimulatorTransport(std::shared_ptr<SimulatorHandle> simulator_handle, Address address, uint64_t seed) : simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {} - template <Message Request, Message Response> - ResponseFuture<Response> Request(Address address, uint64_t request_id, Request request, Duration timeout) { + template <Message RequestT, Message ResponseT> + ResponseFuture<ResponseT> Request(Address address, uint64_t request_id, RequestT request, Duration timeout) { std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); }; auto [future, promise] = - memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(maybe_tick_simulator); + memgraph::io::FuturePromisePairWithNotifier<ResponseResult<ResponseT>>(maybe_tick_simulator); simulator_handle_->SubmitRequest(address, address_, request_id, std::move(request), timeout, std::move(promise)); diff --git a/src/io/transport.hpp b/src/io/transport.hpp index 31592c9c3..04b1d2d62 100644 --- a/src/io/transport.hpp +++ b/src/io/transport.hpp @@ -76,20 +76,23 @@ class Io { /// without an explicit timeout set. void SetDefaultTimeout(Duration timeout) { default_timeout_ = timeout; } + /// Returns the current default timeout for this Io instance. + Duration GetDefaultTimeout() { return default_timeout_; } + /// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients. - template <Message Request, Message Response> - ResponseFuture<Response> RequestWithTimeout(Address address, Request request, Duration timeout) { + template <Message RequestT, Message ResponseT> + ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) { const RequestId request_id = ++request_id_counter_; - return implementation_.template Request<Request, Response>(address, request_id, request, timeout); + return implementation_.template Request<RequestT, ResponseT>(address, request_id, request, timeout); } /// Issue a request that times out after the default timeout. This tends /// to be used by clients. - template <Message Request, Message Response> - ResponseFuture<Response> Request(Address address, Request request) { + template <Message RequestT, Message ResponseT> + ResponseFuture<ResponseT> Request(Address address, RequestT request) { const RequestId request_id = ++request_id_counter_; const Duration timeout = default_timeout_; - return implementation_.template Request<Request, Response>(address, request_id, std::move(request), timeout); + return implementation_.template Request<RequestT, ResponseT>(address, request_id, std::move(request), timeout); } /// Wait for an explicit number of microseconds for a request of one of the diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index e868ca0ef..b44dacb36 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -30,3 +30,5 @@ add_simulation_test(basic_request.cpp address) add_simulation_test(raft.cpp address) add_simulation_test(trial_query_storage/query_storage_test.cpp address) + +add_simulation_test(sharded_map.cpp address) diff --git a/tests/simulation/raft.cpp b/tests/simulation/raft.cpp index 7033d04e6..92750d744 100644 --- a/tests/simulation/raft.cpp +++ b/tests/simulation/raft.cpp @@ -20,6 +20,7 @@ #include "io/address.hpp" #include "io/rsm/raft.hpp" +#include "io/rsm/rsm_client.hpp" #include "io/simulator/simulator.hpp" #include "io/simulator/simulator_transport.hpp" @@ -33,6 +34,7 @@ using memgraph::io::Time; using memgraph::io::rsm::Raft; using memgraph::io::rsm::ReadRequest; using memgraph::io::rsm::ReadResponse; +using memgraph::io::rsm::RsmClient; using memgraph::io::rsm::WriteRequest; using memgraph::io::rsm::WriteResponse; using memgraph::io::simulator::Simulator; @@ -147,7 +149,6 @@ void RunSimulation() { std::vector<Address> srv_2_peers = {srv_addr_1, srv_addr_3}; std::vector<Address> srv_3_peers = {srv_addr_1, srv_addr_2}; - // TODO(tyler / gabor) supply default TestState to Raft constructor using RaftClass = Raft<SimulatorTransport, TestState, CasRequest, CasResponse, GetRequest, GetResponse>; RaftClass srv_1{std::move(srv_io_1), srv_1_peers, TestState{}}; RaftClass srv_2{std::move(srv_io_2), srv_2_peers, TestState{}}; @@ -165,16 +166,20 @@ void RunSimulation() { spdlog::info("beginning test after servers have become quiescent"); std::mt19937 cli_rng_{0}; - Address server_addrs[]{srv_addr_1, srv_addr_2, srv_addr_3}; + std::vector<Address> server_addrs{srv_addr_1, srv_addr_2, srv_addr_3}; Address leader = server_addrs[0]; + RsmClient<SimulatorTransport, CasRequest, CasResponse, GetRequest, GetResponse> client(cli_io, leader, server_addrs); + const int key = 0; std::optional<int> last_known_value = 0; bool success = false; for (int i = 0; !success; i++) { - // send request + /* + * Write Request + */ CasRequest cas_req; cas_req.key = key; @@ -182,40 +187,12 @@ void RunSimulation() { cas_req.new_value = i; - WriteRequest<CasRequest> cli_req; - cli_req.operation = cas_req; - - spdlog::info("client sending CasRequest to Leader {} ", leader.last_known_port); - ResponseFuture<WriteResponse<CasResponse>> cas_response_future = - cli_io.Request<WriteRequest<CasRequest>, WriteResponse<CasResponse>>(leader, cli_req); - - // receive cas_response - ResponseResult<WriteResponse<CasResponse>> cas_response_result = std::move(cas_response_future).Wait(); - - if (cas_response_result.HasError()) { - spdlog::info("client timed out while trying to communicate with assumed Leader server {}", - leader.last_known_port); + auto write_cas_response_result = client.SendWriteRequest(cas_req); + if (write_cas_response_result.HasError()) { + // timed out continue; } - - ResponseEnvelope<WriteResponse<CasResponse>> cas_response_envelope = cas_response_result.GetValue(); - WriteResponse<CasResponse> write_cas_response = cas_response_envelope.message; - - if (write_cas_response.retry_leader) { - MG_ASSERT(!write_cas_response.success, "retry_leader should never be set for successful responses"); - leader = write_cas_response.retry_leader.value(); - spdlog::info("client redirected to leader server {}", leader.last_known_port); - } else if (!write_cas_response.success) { - std::uniform_int_distribution<size_t> addr_distrib(0, 2); - size_t addr_index = addr_distrib(cli_rng_); - leader = server_addrs[addr_index]; - - spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, - leader.last_known_port); - continue; - } - - CasResponse cas_response = write_cas_response.write_return; + CasResponse cas_response = write_cas_response_result.GetValue(); bool cas_succeeded = cas_response.cas_success; @@ -228,47 +205,18 @@ void RunSimulation() { continue; } + /* + * Get Request + */ GetRequest get_req; get_req.key = key; - ReadRequest<GetRequest> read_req; - read_req.operation = get_req; - - spdlog::info("client sending GetRequest to Leader {}", leader.last_known_port); - - ResponseFuture<ReadResponse<GetResponse>> get_response_future = - cli_io.Request<ReadRequest<GetRequest>, ReadResponse<GetResponse>>(leader, read_req); - - // receive response - ResponseResult<ReadResponse<GetResponse>> get_response_result = std::move(get_response_future).Wait(); - - if (get_response_result.HasError()) { - spdlog::info("client timed out while trying to communicate with Leader server {}", leader.last_known_port); + auto read_get_response_result = client.SendReadRequest(get_req); + if (read_get_response_result.HasError()) { + // timed out continue; } - - ResponseEnvelope<ReadResponse<GetResponse>> get_response_envelope = get_response_result.GetValue(); - ReadResponse<GetResponse> read_get_response = get_response_envelope.message; - - if (!read_get_response.success) { - // sent to a non-leader - continue; - } - - if (read_get_response.retry_leader) { - MG_ASSERT(!read_get_response.success, "retry_leader should never be set for successful responses"); - leader = read_get_response.retry_leader.value(); - spdlog::info("client redirected to Leader server {}", leader.last_known_port); - } else if (!read_get_response.success) { - std::uniform_int_distribution<size_t> addr_distrib(0, 2); - size_t addr_index = addr_distrib(cli_rng_); - leader = server_addrs[addr_index]; - - spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, - leader.last_known_port); - } - - GetResponse get_response = read_get_response.read_return; + GetResponse get_response = read_get_response_result.GetValue(); MG_ASSERT(get_response.value == i); diff --git a/tests/simulation/sharded_map.cpp b/tests/simulation/sharded_map.cpp new file mode 100644 index 000000000..cd9d7db1a --- /dev/null +++ b/tests/simulation/sharded_map.cpp @@ -0,0 +1,356 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <chrono> +#include <deque> +#include <iostream> +#include <map> +#include <optional> +#include <set> +#include <thread> +#include <vector> + +#include "common/types.hpp" +#include "coordinator/coordinator_client.hpp" +#include "coordinator/coordinator_rsm.hpp" +#include "io/address.hpp" +#include "io/errors.hpp" +#include "io/rsm/raft.hpp" +#include "io/rsm/rsm_client.hpp" +#include "io/rsm/shard_rsm.hpp" +#include "io/simulator/simulator.hpp" +#include "io/simulator/simulator_transport.hpp" +#include "storage/v3/id_types.hpp" +#include "storage/v3/schemas.hpp" +#include "utils/result.hpp" + +using memgraph::coordinator::AddressAndStatus; +using memgraph::coordinator::CompoundKey; +using memgraph::coordinator::Coordinator; +using memgraph::coordinator::CoordinatorClient; +using memgraph::coordinator::CoordinatorRsm; +using memgraph::coordinator::HlcRequest; +using memgraph::coordinator::HlcResponse; +using memgraph::coordinator::Shard; +using memgraph::coordinator::ShardMap; +using memgraph::coordinator::Shards; +using memgraph::coordinator::Status; +using memgraph::io::Address; +using memgraph::io::Io; +using memgraph::io::ResponseEnvelope; +using memgraph::io::ResponseFuture; +using memgraph::io::Time; +using memgraph::io::TimedOut; +using memgraph::io::rsm::Raft; +using memgraph::io::rsm::ReadRequest; +using memgraph::io::rsm::ReadResponse; +using memgraph::io::rsm::RsmClient; +using memgraph::io::rsm::ShardRsm; +using memgraph::io::rsm::StorageReadRequest; +using memgraph::io::rsm::StorageReadResponse; +using memgraph::io::rsm::StorageWriteRequest; +using memgraph::io::rsm::StorageWriteResponse; +using memgraph::io::rsm::WriteRequest; +using memgraph::io::rsm::WriteResponse; +using memgraph::io::simulator::Simulator; +using memgraph::io::simulator::SimulatorConfig; +using memgraph::io::simulator::SimulatorStats; +using memgraph::io::simulator::SimulatorTransport; +using memgraph::storage::v3::LabelId; +using memgraph::storage::v3::SchemaProperty; +using memgraph::utils::BasicResult; + +using ShardClient = + RsmClient<SimulatorTransport, StorageWriteRequest, StorageWriteResponse, StorageReadRequest, StorageReadResponse>; +namespace { + +const std::string label_name = std::string("test_label"); + +ShardMap CreateDummyShardmap(Address a_io_1, Address a_io_2, Address a_io_3, Address b_io_1, Address b_io_2, + Address b_io_3) { + ShardMap sm; + + // register new properties + const std::vector<std::string> property_names = {"property_1", "property_2"}; + const auto properties = sm.AllocatePropertyIds(property_names); + const auto property_id_1 = properties.at("property_1"); + const auto property_id_2 = properties.at("property_2"); + const auto type_1 = memgraph::common::SchemaType::INT; + const auto type_2 = memgraph::common::SchemaType::INT; + + // register new label space + std::vector<SchemaProperty> schema = { + SchemaProperty{.property_id = property_id_1, .type = type_1}, + SchemaProperty{.property_id = property_id_2, .type = type_2}, + }; + bool label_success = sm.InitializeNewLabel(label_name, schema, sm.shard_map_version); + MG_ASSERT(label_success); + + const LabelId label_id = sm.labels.at(label_name); + auto &label_space = sm.label_spaces.at(label_id); + Shards &shards_for_label = label_space.shards; + + // add first shard at [0, 0] + AddressAndStatus aas1_1{.address = a_io_1, .status = Status::CONSENSUS_PARTICIPANT}; + AddressAndStatus aas1_2{.address = a_io_2, .status = Status::CONSENSUS_PARTICIPANT}; + AddressAndStatus aas1_3{.address = a_io_3, .status = Status::CONSENSUS_PARTICIPANT}; + + Shard shard1 = {aas1_1, aas1_2, aas1_3}; + + const auto key1 = memgraph::storage::v3::PropertyValue(0); + const auto key2 = memgraph::storage::v3::PropertyValue(0); + const CompoundKey compound_key_1 = {key1, key2}; + shards_for_label.emplace(compound_key_1, shard1); + + // add second shard at [12, 13] + AddressAndStatus aas2_1{.address = b_io_1, .status = Status::CONSENSUS_PARTICIPANT}; + AddressAndStatus aas2_2{.address = b_io_2, .status = Status::CONSENSUS_PARTICIPANT}; + AddressAndStatus aas2_3{.address = b_io_3, .status = Status::CONSENSUS_PARTICIPANT}; + + Shard shard2 = {aas2_1, aas2_2, aas2_3}; + + auto key3 = memgraph::storage::v3::PropertyValue(12); + auto key4 = memgraph::storage::v3::PropertyValue(13); + CompoundKey compound_key_2 = {key3, key4}; + shards_for_label[compound_key_2] = shard2; + + return sm; +} + +std::optional<ShardClient> DetermineShardLocation(const Shard &target_shard, const std::vector<Address> &a_addrs, + ShardClient a_client, const std::vector<Address> &b_addrs, + ShardClient b_client) { + for (const auto &addr : target_shard) { + if (addr.address == b_addrs[0]) { + return b_client; + } + if (addr.address == a_addrs[0]) { + return a_client; + } + } + return {}; +} + +} // namespace + +using ConcreteCoordinatorRsm = CoordinatorRsm<SimulatorTransport>; +using ConcreteShardRsm = Raft<SimulatorTransport, ShardRsm, StorageWriteRequest, StorageWriteResponse, + StorageReadRequest, StorageReadResponse>; + +template <typename IoImpl> +void RunStorageRaft( + Raft<IoImpl, ShardRsm, StorageWriteRequest, StorageWriteResponse, StorageReadRequest, StorageReadResponse> server) { + server.Run(); +} + +int main() { + SimulatorConfig config{ + .drop_percent = 5, + .perform_timeouts = true, + .scramble_messages = true, + .rng_seed = 0, + .start_time = Time::min() + std::chrono::microseconds{256 * 1024}, + .abort_time = Time::min() + std::chrono::microseconds{2 * 8 * 1024 * 1024}, + }; + + auto simulator = Simulator(config); + + Io<SimulatorTransport> cli_io = simulator.RegisterNew(); + + // Register + Io<SimulatorTransport> a_io_1 = simulator.RegisterNew(); + Io<SimulatorTransport> a_io_2 = simulator.RegisterNew(); + Io<SimulatorTransport> a_io_3 = simulator.RegisterNew(); + + Io<SimulatorTransport> b_io_1 = simulator.RegisterNew(); + Io<SimulatorTransport> b_io_2 = simulator.RegisterNew(); + Io<SimulatorTransport> b_io_3 = simulator.RegisterNew(); + + // Preconfigure coordinator with kv shard 'A' and 'B' + auto sm1 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(), + b_io_2.GetAddress(), b_io_3.GetAddress()); + auto sm2 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(), + b_io_2.GetAddress(), b_io_3.GetAddress()); + auto sm3 = CreateDummyShardmap(a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress(), b_io_1.GetAddress(), + b_io_2.GetAddress(), b_io_3.GetAddress()); + + // Spin up shard A + std::vector<Address> a_addrs = {a_io_1.GetAddress(), a_io_2.GetAddress(), a_io_3.GetAddress()}; + + std::vector<Address> a_1_peers = {a_addrs[1], a_addrs[2]}; + std::vector<Address> a_2_peers = {a_addrs[0], a_addrs[2]}; + std::vector<Address> a_3_peers = {a_addrs[0], a_addrs[1]}; + + ConcreteShardRsm a_1{std::move(a_io_1), a_1_peers, ShardRsm{}}; + ConcreteShardRsm a_2{std::move(a_io_2), a_2_peers, ShardRsm{}}; + ConcreteShardRsm a_3{std::move(a_io_3), a_3_peers, ShardRsm{}}; + + auto a_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_1)); + simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[0]); + + auto a_thread_2 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_2)); + simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[1]); + + auto a_thread_3 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(a_3)); + simulator.IncrementServerCountAndWaitForQuiescentState(a_addrs[2]); + + // Spin up shard B + std::vector<Address> b_addrs = {b_io_1.GetAddress(), b_io_2.GetAddress(), b_io_3.GetAddress()}; + + std::vector<Address> b_1_peers = {b_addrs[1], b_addrs[2]}; + std::vector<Address> b_2_peers = {b_addrs[0], b_addrs[2]}; + std::vector<Address> b_3_peers = {b_addrs[0], b_addrs[1]}; + + ConcreteShardRsm b_1{std::move(b_io_1), b_1_peers, ShardRsm{}}; + ConcreteShardRsm b_2{std::move(b_io_2), b_2_peers, ShardRsm{}}; + ConcreteShardRsm b_3{std::move(b_io_3), b_3_peers, ShardRsm{}}; + + auto b_thread_1 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_1)); + simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[0]); + + auto b_thread_2 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_2)); + simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[1]); + + auto b_thread_3 = std::jthread(RunStorageRaft<SimulatorTransport>, std::move(b_3)); + simulator.IncrementServerCountAndWaitForQuiescentState(b_addrs[2]); + + // Spin up coordinators + + Io<SimulatorTransport> c_io_1 = simulator.RegisterNew(); + Io<SimulatorTransport> c_io_2 = simulator.RegisterNew(); + Io<SimulatorTransport> c_io_3 = simulator.RegisterNew(); + + std::vector<Address> c_addrs = {c_io_1.GetAddress(), c_io_2.GetAddress(), c_io_3.GetAddress()}; + + std::vector<Address> c_1_peers = {c_addrs[1], c_addrs[2]}; + std::vector<Address> c_2_peers = {c_addrs[0], c_addrs[2]}; + std::vector<Address> c_3_peers = {c_addrs[0], c_addrs[1]}; + + ConcreteCoordinatorRsm c_1{std::move(c_io_1), c_1_peers, Coordinator{(sm1)}}; + ConcreteCoordinatorRsm c_2{std::move(c_io_2), c_2_peers, Coordinator{(sm2)}}; + ConcreteCoordinatorRsm c_3{std::move(c_io_3), c_3_peers, Coordinator{(sm3)}}; + + auto c_thread_1 = std::jthread([c_1]() mutable { c_1.Run(); }); + simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[0]); + + auto c_thread_2 = std::jthread([c_2]() mutable { c_2.Run(); }); + simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[1]); + + auto c_thread_3 = std::jthread([c_3]() mutable { c_3.Run(); }); + simulator.IncrementServerCountAndWaitForQuiescentState(c_addrs[2]); + + std::cout << "beginning test after servers have become quiescent" << std::endl; + + // Have client contact coordinator RSM for a new transaction ID and + // also get the current shard map + CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, c_addrs[0], c_addrs); + + ShardClient shard_a_client(cli_io, a_addrs[0], a_addrs); + ShardClient shard_b_client(cli_io, b_addrs[0], b_addrs); + + memgraph::coordinator::HlcRequest req; + + // Last ShardMap Version The query engine knows about. + ShardMap client_shard_map; + req.last_shard_map_version = client_shard_map.GetHlc(); + + while (true) { + // Create CompoundKey + const auto cm_key_1 = memgraph::storage::v3::PropertyValue(3); + const auto cm_key_2 = memgraph::storage::v3::PropertyValue(4); + + const CompoundKey compound_key = {cm_key_1, cm_key_2}; + + // Look for Shard + BasicResult<TimedOut, memgraph::coordinator::CoordinatorWriteResponses> read_res = + coordinator_client.SendWriteRequest(req); + + if (read_res.HasError()) { + // timeout + continue; + } + + auto coordinator_read_response = read_res.GetValue(); + HlcResponse hlc_response = std::get<HlcResponse>(coordinator_read_response); + + // Transaction ID to be used later... + auto transaction_id = hlc_response.new_hlc; + + if (hlc_response.fresher_shard_map) { + client_shard_map = hlc_response.fresher_shard_map.value(); + } + + auto target_shard = client_shard_map.GetShardForKey(label_name, compound_key); + + // Determine which shard to send the requests to. This should be a more proper client cache in the "real" version. + auto storage_client_opt = DetermineShardLocation(target_shard, a_addrs, shard_a_client, b_addrs, shard_b_client); + MG_ASSERT(storage_client_opt); + + auto storage_client = storage_client_opt.value(); + + LabelId label_id = client_shard_map.labels.at(label_name); + + // Have client use shard map to decide which shard to communicate + // with in order to write a new value + // client_shard_map. + StorageWriteRequest storage_req; + storage_req.label_id = label_id; + storage_req.key = compound_key; + storage_req.value = 1000; + storage_req.transaction_id = transaction_id; + + auto write_response_result = storage_client.SendWriteRequest(storage_req); + if (write_response_result.HasError()) { + // timed out + continue; + } + auto write_response = write_response_result.GetValue(); + + bool cas_succeeded = write_response.shard_rsm_success; + + if (!cas_succeeded) { + continue; + } + // Have client use shard map to decide which shard to communicate + // with to read that same value back + + StorageReadRequest storage_get_req; + storage_get_req.label_id = label_id; + storage_get_req.key = compound_key; + storage_get_req.transaction_id = transaction_id; + + auto get_response_result = storage_client.SendReadRequest(storage_get_req); + if (get_response_result.HasError()) { + // timed out + continue; + } + auto get_response = get_response_result.GetValue(); + auto val = get_response.value.value(); + + MG_ASSERT(val == 1000); + break; + } + + simulator.ShutDown(); + + SimulatorStats stats = simulator.Stats(); + + std::cout << "total messages: " << stats.total_messages << std::endl; + std::cout << "dropped messages: " << stats.dropped_messages << std::endl; + std::cout << "timed out requests: " << stats.timed_out_requests << std::endl; + std::cout << "total requests: " << stats.total_requests << std::endl; + std::cout << "total responses: " << stats.total_responses << std::endl; + std::cout << "simulator ticks: " << stats.simulator_ticks << std::endl; + + std::cout << "========================== SUCCESS :) ==========================" << std::endl; + + return 0; +} diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 1c08943c1..c112fe9b2 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -402,12 +402,6 @@ add_custom_target(test_lcp ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/test_lcp) add_test(test_lcp ${CMAKE_CURRENT_BINARY_DIR}/test_lcp) add_dependencies(memgraph__unit test_lcp) -# Test websocket -find_package(Boost REQUIRED) - -add_unit_test(websocket.cpp) -target_link_libraries(${test_prefix}websocket mg-communication Boost::headers) - # Test future add_unit_test(future.cpp) target_link_libraries(${test_prefix}future mg-io)