From 14c9e6845630e71ea7d141b4447e80230ffcc275 Mon Sep 17 00:00:00 2001 From: Tyler Neely <tylerneely@gmail.com> Date: Fri, 12 Aug 2022 08:24:32 +0200 Subject: [PATCH 1/4] Transport prototype (#466) --- .github/workflows/diff.yaml | 11 +- src/CMakeLists.txt | 1 + src/io/address.hpp | 60 ++++ src/io/errors.hpp | 26 ++ src/io/future.hpp | 262 ++++++++++++++++++ src/io/simulator/CMakeLists.txt | 8 + src/io/simulator/message_conversion.hpp | 177 ++++++++++++ src/io/simulator/simulator.hpp | 45 +++ src/io/simulator/simulator_config.hpp | 30 ++ src/io/simulator/simulator_handle.cpp | 154 ++++++++++ src/io/simulator/simulator_handle.hpp | 206 ++++++++++++++ src/io/simulator/simulator_stats.hpp | 25 ++ src/io/simulator/simulator_transport.hpp | 65 +++++ src/io/time.hpp | 21 ++ src/io/transport.hpp | 130 +++++++++ tests/CMakeLists.txt | 3 + tests/benchmark/CMakeLists.txt | 3 + tests/benchmark/future.cpp | 30 ++ tests/simulation/CMakeLists.txt | 30 ++ tests/simulation/basic_request.cpp | 87 ++++++ .../trial_query_storage/messages.hpp | 34 +++ .../query_storage_test.cpp | 85 ++++++ tests/unit/CMakeLists.txt | 4 + tests/unit/future.cpp | 55 ++++ 24 files changed, 1551 insertions(+), 1 deletion(-) create mode 100644 src/io/address.hpp create mode 100644 src/io/errors.hpp create mode 100644 src/io/future.hpp create mode 100644 src/io/simulator/CMakeLists.txt create mode 100644 src/io/simulator/message_conversion.hpp create mode 100644 src/io/simulator/simulator.hpp create mode 100644 src/io/simulator/simulator_config.hpp create mode 100644 src/io/simulator/simulator_handle.cpp create mode 100644 src/io/simulator/simulator_handle.hpp create mode 100644 src/io/simulator/simulator_stats.hpp create mode 100644 src/io/simulator/simulator_transport.hpp create mode 100644 src/io/time.hpp create mode 100644 src/io/transport.hpp create mode 100644 tests/benchmark/future.cpp create mode 100644 tests/simulation/CMakeLists.txt create mode 100644 tests/simulation/basic_request.cpp create mode 100644 tests/simulation/trial_query_storage/messages.hpp create mode 100644 tests/simulation/trial_query_storage/query_storage_test.cpp create mode 100644 tests/unit/future.cpp diff --git a/.github/workflows/diff.yaml b/.github/workflows/diff.yaml index bf6a39147..ce2caea75 100644 --- a/.github/workflows/diff.yaml +++ b/.github/workflows/diff.yaml @@ -171,7 +171,7 @@ jobs: # Run leftover CTest tests (all except unit and benchmark tests). cd build - ctest -E "(memgraph__unit|memgraph__benchmark)" --output-on-failure + ctest -E "(memgraph__unit|memgraph__benchmark|memgraph__simulation)" --output-on-failure - name: Run drivers tests run: | @@ -262,6 +262,15 @@ jobs: cd build ctest -R memgraph__unit --output-on-failure -j$THREADS + - name: Run simulation tests + run: | + # Activate toolchain. + source /opt/toolchain-v4/activate + + # Run unit tests. + cd build + ctest -R memgraph__simulation --output-on-failure -j$THREADS + - name: Run e2e tests run: | # TODO(gitbuda): Setup mgclient and pymgclient properly. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index efc653b9a..f4c303daf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(lisp) add_subdirectory(utils) add_subdirectory(requests) add_subdirectory(io) +add_subdirectory(io/simulator) add_subdirectory(kvstore) add_subdirectory(telemetry) add_subdirectory(communication) diff --git a/src/io/address.hpp b/src/io/address.hpp new file mode 100644 index 000000000..94a231e07 --- /dev/null +++ b/src/io/address.hpp @@ -0,0 +1,60 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <compare> + +#include <fmt/format.h> +#include <boost/asio/ip/tcp.hpp> +#include <boost/uuid/uuid.hpp> +#include <boost/uuid/uuid_io.hpp> + +namespace memgraph::io { +struct Address { + // It's important for all participants to have a + // unique identifier - IP and port alone are not + // enough, and may change over the lifecycle of + // the nodes. Particularly storage nodes may change + // their IP addresses over time, and the system + // should gracefully update its information + // about them. + boost::uuids::uuid unique_id; + boost::asio::ip::address last_known_ip; + uint16_t last_known_port; + + static Address TestAddress(uint16_t port) { + Address ret; + ret.last_known_port = port; + return ret; + } + + friend bool operator==(const Address &lhs, const Address &rhs) = default; + + /// unique_id is most dominant for ordering, then last_known_ip, then last_known_port + friend bool operator<(const Address &lhs, const Address &rhs) { + if (lhs.unique_id != rhs.unique_id) { + return lhs.unique_id < rhs.unique_id; + } + + if (lhs.last_known_ip != rhs.last_known_ip) { + return lhs.last_known_ip < rhs.last_known_ip; + } + + return lhs.last_known_port < rhs.last_known_port; + } + + std::string ToString() const { + return fmt::format("Address {{ unique_id: {}, last_known_ip: {}, last_known_port: {} }}", + boost::uuids::to_string(unique_id), last_known_ip.to_string(), last_known_port); + } +}; +}; // namespace memgraph::io diff --git a/src/io/errors.hpp b/src/io/errors.hpp new file mode 100644 index 000000000..7df2171d9 --- /dev/null +++ b/src/io/errors.hpp @@ -0,0 +1,26 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +namespace memgraph::io { +// Signifies that a retriable operation was unable to +// complete after a configured number of retries. +struct RetriesExhausted {}; + +// Signifies that a request was unable to receive a response +// within some configured timeout duration. It is important +// to remember that in distributed systems, a timeout does +// not signify that a request was not received or processed. +// It may be the case that the request was fully processed +// but that the response was not received. +struct TimedOut {}; +}; // namespace memgraph::io diff --git a/src/io/future.hpp b/src/io/future.hpp new file mode 100644 index 000000000..7b9a4461c --- /dev/null +++ b/src/io/future.hpp @@ -0,0 +1,262 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <condition_variable> +#include <memory> +#include <mutex> +#include <optional> +#include <thread> +#include <utility> + +#include "io/errors.hpp" +#include "utils/logging.hpp" + +namespace memgraph::io { + +// Shared is in an anonymous namespace, and the only way to +// construct a Promise or Future is to pass a Shared in. This +// ensures that Promises and Futures can only be constructed +// in this translation unit. +namespace details { +template <typename T> +class Shared { + mutable std::condition_variable cv_; + mutable std::mutex mu_; + std::optional<T> item_; + bool consumed_ = false; + bool waiting_ = false; + std::function<bool()> simulator_notifier_ = nullptr; + + public: + explicit Shared(std::function<bool()> simulator_notifier) : simulator_notifier_(simulator_notifier) {} + Shared() = default; + Shared(Shared &&) = delete; + Shared &operator=(Shared &&) = delete; + Shared(const Shared &) = delete; + Shared &operator=(const Shared &) = delete; + ~Shared() = default; + + /// Takes the item out of our optional item_ and returns it. + T Take() { + MG_ASSERT(item_, "Take called without item_ being present"); + MG_ASSERT(!consumed_, "Take called on already-consumed Future"); + + T ret = std::move(item_).value(); + item_.reset(); + + consumed_ = true; + + return ret; + } + + T Wait() { + std::unique_lock<std::mutex> lock(mu_); + waiting_ = true; + + while (!item_) { + bool simulator_progressed = false; + if (simulator_notifier_) [[unlikely]] { + // We can't hold our own lock while notifying + // the simulator because notifying the simulator + // involves acquiring the simulator's mutex + // to guarantee that our notification linearizes + // with the simulator's condition variable. + // However, the simulator may acquire our + // mutex to check if we are being awaited, + // while determining system quiescence, + // so we have to get out of its way to avoid + // a cyclical deadlock. + lock.unlock(); + simulator_progressed = std::invoke(simulator_notifier_); + lock.lock(); + if (item_) { + // item may have been filled while we + // had dropped our mutex while notifying + // the simulator of our waiting_ status. + break; + } + } + if (!simulator_progressed) [[likely]] { + cv_.wait(lock); + } + MG_ASSERT(!consumed_, "Future consumed twice!"); + } + + waiting_ = false; + + return Take(); + } + + bool IsReady() const { + std::unique_lock<std::mutex> lock(mu_); + return item_; + } + + std::optional<T> TryGet() { + std::unique_lock<std::mutex> lock(mu_); + + if (item_) { + return Take(); + } + + return std::nullopt; + } + + void Fill(T item) { + { + std::unique_lock<std::mutex> lock(mu_); + + MG_ASSERT(!consumed_, "Promise filled after it was already consumed!"); + MG_ASSERT(!item_, "Promise filled twice!"); + + item_ = item; + } // lock released before condition variable notification + + cv_.notify_all(); + } + + bool IsAwaited() const { + std::unique_lock<std::mutex> lock(mu_); + return waiting_; + } +}; +} // namespace details + +template <typename T> +class Future { + bool consumed_or_moved_ = false; + std::shared_ptr<details::Shared<T>> shared_; + + public: + explicit Future(std::shared_ptr<details::Shared<T>> shared) : shared_(shared) {} + + Future() = delete; + Future(Future &&old) noexcept { + MG_ASSERT(!old.consumed_or_moved_, "Future moved from after already being moved from or consumed."); + shared_ = std::move(old.shared_); + consumed_or_moved_ = old.consumed_or_moved_; + old.consumed_or_moved_ = true; + } + + Future &operator=(Future &&old) noexcept { + MG_ASSERT(!old.consumed_or_moved_, "Future moved from after already being moved from or consumed."); + shared_ = std::move(old.shared_); + old.consumed_or_moved_ = true; + } + + Future(const Future &) = delete; + Future &operator=(const Future &) = delete; + ~Future() = default; + + /// Returns true if the Future is ready to + /// be consumed using TryGet or Wait (prefer Wait + /// if you know it's ready, because it doesn't + /// return an optional. + bool IsReady() { + MG_ASSERT(!consumed_or_moved_, "Called IsReady after Future already consumed!"); + return shared_->IsReady(); + } + + /// Non-blocking method that returns the inner + /// item if it's already ready, or std::nullopt + /// if it is not ready yet. + std::optional<T> TryGet() { + MG_ASSERT(!consumed_or_moved_, "Called TryGet after Future already consumed!"); + std::optional<T> ret = shared_->TryGet(); + if (ret) { + consumed_or_moved_ = true; + } + return ret; + } + + /// Block on the corresponding promise to be filled, + /// returning the inner item when ready. + T Wait() && { + MG_ASSERT(!consumed_or_moved_, "Future should only be consumed with Wait once!"); + T ret = shared_->Wait(); + consumed_or_moved_ = true; + return ret; + } + + /// Marks this Future as canceled. + void Cancel() { + MG_ASSERT(!consumed_or_moved_, "Future::Cancel called on a future that was already moved or consumed!"); + consumed_or_moved_ = true; + } +}; + +template <typename T> +class Promise { + std::shared_ptr<details::Shared<T>> shared_; + bool filled_or_moved_ = false; + + public: + explicit Promise(std::shared_ptr<details::Shared<T>> shared) : shared_(shared) {} + + Promise() = delete; + Promise(Promise &&old) noexcept { + MG_ASSERT(!old.filled_or_moved_, "Promise moved from after already being moved from or filled."); + shared_ = std::move(old.shared_); + old.filled_or_moved_ = true; + } + + Promise &operator=(Promise &&old) noexcept { + MG_ASSERT(!old.filled_or_moved_, "Promise moved from after already being moved from or filled."); + shared_ = std::move(old.shared_); + old.filled_or_moved_ = true; + } + Promise(const Promise &) = delete; + Promise &operator=(const Promise &) = delete; + + ~Promise() { MG_ASSERT(filled_or_moved_, "Promise destroyed before its associated Future was filled!"); } + + // Fill the expected item into the Future. + void Fill(T item) { + MG_ASSERT(!filled_or_moved_, "Promise::Fill called on a promise that is already filled or moved!"); + shared_->Fill(item); + filled_or_moved_ = true; + } + + bool IsAwaited() { return shared_->IsAwaited(); } + + /// Moves this Promise into a unique_ptr. + std::unique_ptr<Promise<T>> ToUnique() && { + std::unique_ptr<Promise<T>> up = std::make_unique<Promise<T>>(std::move(shared_)); + + filled_or_moved_ = true; + + return up; + } +}; + +template <typename T> +std::pair<Future<T>, Promise<T>> FuturePromisePair() { + std::shared_ptr<details::Shared<T>> shared = std::make_shared<details::Shared<T>>(); + + Future<T> future = Future<T>(shared); + Promise<T> promise = Promise<T>(shared); + + return std::make_pair(std::move(future), std::move(promise)); +} + +template <typename T> +std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifier(std::function<bool()> simulator_notifier) { + std::shared_ptr<details::Shared<T>> shared = std::make_shared<details::Shared<T>>(simulator_notifier); + + Future<T> future = Future<T>(shared); + Promise<T> promise = Promise<T>(shared); + + return std::make_pair(std::move(future), std::move(promise)); +} + +}; // namespace memgraph::io diff --git a/src/io/simulator/CMakeLists.txt b/src/io/simulator/CMakeLists.txt new file mode 100644 index 000000000..1cb61d8d9 --- /dev/null +++ b/src/io/simulator/CMakeLists.txt @@ -0,0 +1,8 @@ +set(io_simulator_sources + simulator_handle.cpp) + +find_package(fmt REQUIRED) +find_package(Threads REQUIRED) + +add_library(mg-io-simulator STATIC ${io_simulator_sources}) +target_link_libraries(mg-io-simulator stdc++fs Threads::Threads fmt::fmt mg-utils) diff --git a/src/io/simulator/message_conversion.hpp b/src/io/simulator/message_conversion.hpp new file mode 100644 index 000000000..f16c60f65 --- /dev/null +++ b/src/io/simulator/message_conversion.hpp @@ -0,0 +1,177 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "io/transport.hpp" + +namespace memgraph::io::simulator { + +using memgraph::io::Duration; +using memgraph::io::Message; +using memgraph::io::Time; + +struct OpaqueMessage { + Address from_address; + uint64_t request_id; + std::any message; + + /// Recursively tries to match a specific type from the outer + /// variant's parameter pack against the type of the std::any, + /// and if it matches, make it concrete and return it. Otherwise, + /// move on and compare the any with the next type from the + /// parameter pack. + /// + /// Return is the full std::variant<Ts...> type that holds the + /// full parameter pack without interfering with recursive + /// narrowing expansion. + template <typename Return, Message Head, Message... Rest> + std::optional<Return> Unpack(std::any &&a) { + if (typeid(Head) == a.type()) { + Head concrete = std::any_cast<Head>(std::move(a)); + return concrete; + } + + if constexpr (sizeof...(Rest) > 0) { + return Unpack<Return, Rest...>(std::move(a)); + } else { + return std::nullopt; + } + } + + /// High level "user-facing" conversion function that lets + /// people interested in conversion only supply a single + /// parameter pack for the types that they want to compare + /// with the any and potentially include in the returned + /// variant. + template <Message... Ms> + requires(sizeof...(Ms) > 0) std::optional<std::variant<Ms...>> VariantFromAny(std::any &&a) { + return Unpack<std::variant<Ms...>, Ms...>(std::move(a)); + } + + template <Message... Ms> + requires(sizeof...(Ms) > 0) std::optional<RequestEnvelope<Ms...>> Take() && { + std::optional<std::variant<Ms...>> m_opt = VariantFromAny<Ms...>(std::move(message)); + + if (m_opt) { + return RequestEnvelope<Ms...>{ + .message = std::move(*m_opt), + .request_id = request_id, + .from_address = from_address, + }; + } + + return std::nullopt; + } +}; + +class OpaquePromiseTraitBase { + public: + virtual const std::type_info *TypeInfo() const = 0; + virtual bool IsAwaited(void *ptr) const = 0; + virtual void Fill(void *ptr, OpaqueMessage &&) const = 0; + virtual void TimeOut(void *ptr) const = 0; + + virtual ~OpaquePromiseTraitBase() = default; + OpaquePromiseTraitBase() = default; + OpaquePromiseTraitBase(const OpaquePromiseTraitBase &) = delete; + OpaquePromiseTraitBase &operator=(const OpaquePromiseTraitBase &) = delete; + OpaquePromiseTraitBase(OpaquePromiseTraitBase &&old) = delete; + OpaquePromiseTraitBase &operator=(OpaquePromiseTraitBase &&) = delete; +}; + +template <typename T> +class OpaquePromiseTrait : public OpaquePromiseTraitBase { + public: + const std::type_info *TypeInfo() const override { return &typeid(T); }; + + bool IsAwaited(void *ptr) const override { return static_cast<ResponsePromise<T> *>(ptr)->IsAwaited(); }; + + void Fill(void *ptr, OpaqueMessage &&opaque_message) const override { + T message = std::any_cast<T>(std::move(opaque_message.message)); + auto response_envelope = ResponseEnvelope<T>{.message = std::move(message), + .request_id = opaque_message.request_id, + .from_address = opaque_message.from_address}; + auto promise = static_cast<ResponsePromise<T> *>(ptr); + auto unique_promise = std::unique_ptr<ResponsePromise<T>>(promise); + unique_promise->Fill(std::move(response_envelope)); + }; + + void TimeOut(void *ptr) const override { + auto promise = static_cast<ResponsePromise<T> *>(ptr); + auto unique_promise = std::unique_ptr<ResponsePromise<T>>(promise); + ResponseResult<T> result = TimedOut{}; + unique_promise->Fill(std::move(result)); + } +}; + +class OpaquePromise { + void *ptr_; + std::unique_ptr<OpaquePromiseTraitBase> trait_; + + public: + OpaquePromise(OpaquePromise &&old) noexcept : ptr_(old.ptr_), trait_(std::move(old.trait_)) { + MG_ASSERT(old.ptr_ != nullptr); + old.ptr_ = nullptr; + } + + OpaquePromise &operator=(OpaquePromise &&old) noexcept { + MG_ASSERT(ptr_ == nullptr); + MG_ASSERT(old.ptr_ != nullptr); + MG_ASSERT(this != &old); + ptr_ = old.ptr_; + trait_ = std::move(old.trait_); + old.ptr_ = nullptr; + return *this; + } + + OpaquePromise(const OpaquePromise &) = delete; + OpaquePromise &operator=(const OpaquePromise &) = delete; + + template <typename T> + std::unique_ptr<ResponsePromise<T>> Take() && { + MG_ASSERT(typeid(T) == *trait_->TypeInfo()); + MG_ASSERT(ptr_ != nullptr); + + auto ptr = static_cast<ResponsePromise<T> *>(ptr_); + + ptr_ = nullptr; + + return std::unique_ptr<T>(ptr); + } + + template <typename T> + explicit OpaquePromise(std::unique_ptr<ResponsePromise<T>> promise) + : ptr_(static_cast<void *>(promise.release())), trait_(std::make_unique<OpaquePromiseTrait<T>>()) {} + + bool IsAwaited() { + MG_ASSERT(ptr_ != nullptr); + return trait_->IsAwaited(ptr_); + } + + void TimeOut() { + MG_ASSERT(ptr_ != nullptr); + trait_->TimeOut(ptr_); + ptr_ = nullptr; + } + + void Fill(OpaqueMessage &&opaque_message) { + MG_ASSERT(ptr_ != nullptr); + trait_->Fill(ptr_, std::move(opaque_message)); + ptr_ = nullptr; + } + + ~OpaquePromise() { + MG_ASSERT(ptr_ == nullptr, "OpaquePromise destroyed without being explicitly timed out or filled"); + } +}; + +} // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator.hpp b/src/io/simulator/simulator.hpp new file mode 100644 index 000000000..354aae6ac --- /dev/null +++ b/src/io/simulator/simulator.hpp @@ -0,0 +1,45 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <memory> +#include <random> + +#include "io/address.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/simulator/simulator_handle.hpp" +#include "io/simulator/simulator_transport.hpp" + +namespace memgraph::io::simulator { +class Simulator { + std::mt19937 rng_; + std::shared_ptr<SimulatorHandle> simulator_handle_; + + public: + explicit Simulator(SimulatorConfig config) + : rng_(std::mt19937{config.rng_seed}), simulator_handle_{std::make_shared<SimulatorHandle>(config)} {} + + void ShutDown() { simulator_handle_->ShutDown(); } + + Io<SimulatorTransport> Register(Address address) { + std::uniform_int_distribution<uint64_t> seed_distrib; + uint64_t seed = seed_distrib(rng_); + return Io{SimulatorTransport{simulator_handle_, address, seed}, address}; + } + + void IncrementServerCountAndWaitForQuiescentState(Address address) { + simulator_handle_->IncrementServerCountAndWaitForQuiescentState(address); + } + + SimulatorStats Stats() { return simulator_handle_->Stats(); } +}; +}; // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator_config.hpp b/src/io/simulator/simulator_config.hpp new file mode 100644 index 000000000..4719488d2 --- /dev/null +++ b/src/io/simulator/simulator_config.hpp @@ -0,0 +1,30 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> + +#include "io/time.hpp" + +namespace memgraph::io::simulator { + +using memgraph::io::Time; + +struct SimulatorConfig { + uint8_t drop_percent = 0; + bool perform_timeouts = false; + bool scramble_messages = true; + uint64_t rng_seed = 0; + Time start_time = Time::min(); + Time abort_time = Time::max(); +}; +}; // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator_handle.cpp b/src/io/simulator/simulator_handle.cpp new file mode 100644 index 000000000..05585f551 --- /dev/null +++ b/src/io/simulator/simulator_handle.cpp @@ -0,0 +1,154 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "io/simulator/simulator_handle.hpp" +#include "io/address.hpp" +#include "io/errors.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/simulator/simulator_stats.hpp" +#include "io/time.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::simulator { + +using memgraph::io::Duration; +using memgraph::io::Time; + +void SimulatorHandle::ShutDown() { + std::unique_lock<std::mutex> lock(mu_); + should_shut_down_ = true; + cv_.notify_all(); +} + +bool SimulatorHandle::ShouldShutDown() const { + std::unique_lock<std::mutex> lock(mu_); + return should_shut_down_; +} + +void SimulatorHandle::IncrementServerCountAndWaitForQuiescentState(Address address) { + std::unique_lock<std::mutex> lock(mu_); + server_addresses_.insert(address); + + while (true) { + const size_t blocked_servers = BlockedServers(); + + const bool all_servers_blocked = blocked_servers == server_addresses_.size(); + + if (all_servers_blocked) { + return; + } + + cv_.wait(lock); + } +} + +size_t SimulatorHandle::BlockedServers() { + size_t blocked_servers = blocked_on_receive_; + + for (auto &[promise_key, opaque_promise] : promises_) { + if (opaque_promise.promise.IsAwaited() && server_addresses_.contains(promise_key.requester_address)) { + blocked_servers++; + } + } + + return blocked_servers; +} + +bool SimulatorHandle::MaybeTickSimulator() { + std::unique_lock<std::mutex> lock(mu_); + + const size_t blocked_servers = BlockedServers(); + + if (blocked_servers < server_addresses_.size()) { + // we only need to advance the simulator when all + // servers have reached a quiescent state, blocked + // on their own futures or receive methods. + return false; + } + + stats_.simulator_ticks++; + + cv_.notify_all(); + + TimeoutPromisesPastDeadline(); + + if (in_flight_.empty()) { + // return early here because there are no messages to schedule + + // We tick the clock forward when all servers are blocked but + // there are no in-flight messages to schedule delivery of. + std::poisson_distribution<> time_distrib(50); + Duration clock_advance = std::chrono::microseconds{time_distrib(rng_)}; + cluster_wide_time_microseconds_ += clock_advance; + + MG_ASSERT(cluster_wide_time_microseconds_ < config_.abort_time, + "Cluster has executed beyond its configured abort_time, and something may be failing to make progress " + "in an expected amount of time."); + return true; + } + + if (config_.scramble_messages) { + // scramble messages + std::uniform_int_distribution<size_t> swap_distrib(0, in_flight_.size() - 1); + const size_t swap_index = swap_distrib(rng_); + std::swap(in_flight_[swap_index], in_flight_.back()); + } + + auto [to_address, opaque_message] = std::move(in_flight_.back()); + in_flight_.pop_back(); + + std::uniform_int_distribution<int> drop_distrib(0, 99); + const int drop_threshold = drop_distrib(rng_); + const bool should_drop = drop_threshold < config_.drop_percent; + + if (should_drop) { + stats_.dropped_messages++; + } + + PromiseKey promise_key{.requester_address = to_address, + .request_id = opaque_message.request_id, + .replier_address = opaque_message.from_address}; + + if (promises_.contains(promise_key)) { + // complete waiting promise if it's there + DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key)); + promises_.erase(promise_key); + + const bool normal_timeout = config_.perform_timeouts && (dop.deadline < cluster_wide_time_microseconds_); + + if (should_drop || normal_timeout) { + stats_.timed_out_requests++; + dop.promise.TimeOut(); + } else { + stats_.total_responses++; + dop.promise.Fill(std::move(opaque_message)); + } + } else if (should_drop) { + // don't add it anywhere, let it drop + } else { + // add to can_receive_ if not + const auto &[om_vec, inserted] = can_receive_.try_emplace(to_address, std::vector<OpaqueMessage>()); + om_vec->second.emplace_back(std::move(opaque_message)); + } + + return true; +} + +Time SimulatorHandle::Now() const { + std::unique_lock<std::mutex> lock(mu_); + return cluster_wide_time_microseconds_; +} + +SimulatorStats SimulatorHandle::Stats() { + std::unique_lock<std::mutex> lock(mu_); + return stats_; +} +} // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp new file mode 100644 index 000000000..6abaa129d --- /dev/null +++ b/src/io/simulator/simulator_handle.hpp @@ -0,0 +1,206 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <any> +#include <compare> +#include <iostream> +#include <map> +#include <memory> +#include <optional> +#include <set> +#include <utility> +#include <variant> +#include <vector> + +#include "io/address.hpp" +#include "io/errors.hpp" +#include "io/simulator/message_conversion.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/simulator/simulator_stats.hpp" +#include "io/time.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::simulator { + +using memgraph::io::Duration; +using memgraph::io::Time; + +struct PromiseKey { + Address requester_address; + uint64_t request_id; + // TODO(tyler) possibly remove replier_address from promise key + // once we want to support DSR. + Address replier_address; + + public: + friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) { + if (lhs.requester_address != rhs.requester_address) { + return lhs.requester_address < rhs.requester_address; + } + + if (lhs.request_id != rhs.request_id) { + return lhs.request_id < rhs.request_id; + } + + return lhs.replier_address < rhs.replier_address; + } +}; + +struct DeadlineAndOpaquePromise { + Time deadline; + OpaquePromise promise; +}; + +class SimulatorHandle { + mutable std::mutex mu_{}; + mutable std::condition_variable cv_; + + // messages that have not yet been scheduled or dropped + std::vector<std::pair<Address, OpaqueMessage>> in_flight_; + + // the responses to requests that are being waited on + std::map<PromiseKey, DeadlineAndOpaquePromise> promises_; + + // messages that are sent to servers that may later receive them + std::map<Address, std::vector<OpaqueMessage>> can_receive_; + + Time cluster_wide_time_microseconds_; + bool should_shut_down_ = false; + SimulatorStats stats_; + size_t blocked_on_receive_ = 0; + std::set<Address> server_addresses_; + std::mt19937 rng_; + SimulatorConfig config_; + + /// Returns the number of servers currently blocked on Receive, plus + /// the servers that are blocked on Futures that were created through + /// SimulatorTransport::Request. + /// + /// TODO(tyler) investigate whether avoiding consideration of Futures + /// increases determinism. + size_t BlockedServers(); + + void TimeoutPromisesPastDeadline() { + const Time now = cluster_wide_time_microseconds_; + + for (auto &[promise_key, dop] : promises_) { + if (dop.deadline < now) { + spdlog::debug("timing out request from requester {} to replier {}.", promise_key.requester_address.ToString(), + promise_key.replier_address.ToString()); + std::move(dop).promise.TimeOut(); + promises_.erase(promise_key); + + stats_.timed_out_requests++; + } + } + } + + public: + explicit SimulatorHandle(SimulatorConfig config) + : cluster_wide_time_microseconds_(config.start_time), rng_(config.rng_seed), config_(config) {} + + void IncrementServerCountAndWaitForQuiescentState(Address address); + + /// This method causes most of the interesting simulation logic to happen, wrt network behavior. + /// It checks to see if all background "server" threads are blocked on new messages, and if so, + /// it will decide whether to drop, reorder, or deliver in-flight messages based on the SimulatorConfig + /// that was used to create the Simulator. + bool MaybeTickSimulator(); + + void ShutDown(); + + bool ShouldShutDown() const; + + template <Message Request, Message Response> + void SubmitRequest(Address to_address, Address from_address, uint64_t request_id, Request &&request, Duration timeout, + ResponsePromise<Response> &&promise) { + std::unique_lock<std::mutex> lock(mu_); + + const Time deadline = cluster_wide_time_microseconds_ + timeout; + + std::any message(request); + OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message)}; + in_flight_.emplace_back(std::make_pair(to_address, std::move(om))); + + PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address}; + OpaquePromise opaque_promise(std::move(promise).ToUnique()); + DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)}; + promises_.emplace(std::move(promise_key), std::move(dop)); + + stats_.total_messages++; + stats_.total_requests++; + + cv_.notify_all(); + } + + template <Message... Ms> + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(const Address &receiver, Duration timeout) { + std::unique_lock<std::mutex> lock(mu_); + + blocked_on_receive_ += 1; + + const Time deadline = cluster_wide_time_microseconds_ + timeout; + + while (!should_shut_down_ && (cluster_wide_time_microseconds_ < deadline)) { + if (can_receive_.contains(receiver)) { + std::vector<OpaqueMessage> &can_rx = can_receive_.at(receiver); + if (!can_rx.empty()) { + OpaqueMessage message = std::move(can_rx.back()); + can_rx.pop_back(); + + // TODO(tyler) search for item in can_receive_ that matches the desired types, rather + // than asserting that the last item in can_rx matches. + auto m_opt = std::move(message).Take<Ms...>(); + + blocked_on_receive_ -= 1; + + return std::move(m_opt).value(); + } + } + + lock.unlock(); + bool made_progress = MaybeTickSimulator(); + lock.lock(); + if (!should_shut_down_ && !made_progress) { + cv_.wait(lock); + } + } + + blocked_on_receive_ -= 1; + + return TimedOut{}; + } + + template <Message M> + void Send(Address to_address, Address from_address, uint64_t request_id, M message) { + std::unique_lock<std::mutex> lock(mu_); + std::any message_any(std::move(message)); + OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message_any)}; + in_flight_.emplace_back(std::make_pair(std::move(to_address), std::move(om))); + + stats_.total_messages++; + + cv_.notify_all(); + } + + Time Now() const; + + template <class D = std::poisson_distribution<>, class Return = uint64_t> + Return Rand(D distrib) { + std::unique_lock<std::mutex> lock(mu_); + return distrib(rng_); + } + + SimulatorStats Stats(); +}; +}; // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator_stats.hpp b/src/io/simulator/simulator_stats.hpp new file mode 100644 index 000000000..7f529a456 --- /dev/null +++ b/src/io/simulator/simulator_stats.hpp @@ -0,0 +1,25 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> + +namespace memgraph::io::simulator { +struct SimulatorStats { + uint64_t total_messages = 0; + uint64_t dropped_messages = 0; + uint64_t timed_out_requests = 0; + uint64_t total_requests = 0; + uint64_t total_responses = 0; + uint64_t simulator_ticks = 0; +}; +}; // namespace memgraph::io::simulator diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp new file mode 100644 index 000000000..b67371ff0 --- /dev/null +++ b/src/io/simulator/simulator_transport.hpp @@ -0,0 +1,65 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <memory> +#include <utility> + +#include "io/address.hpp" +#include "io/simulator/simulator_handle.hpp" +#include "io/time.hpp" + +namespace memgraph::io::simulator { + +using memgraph::io::Duration; +using memgraph::io::Time; + +class SimulatorTransport { + std::shared_ptr<SimulatorHandle> simulator_handle_; + const Address address_; + std::mt19937 rng_; + + public: + SimulatorTransport(std::shared_ptr<SimulatorHandle> simulator_handle, Address address, uint64_t seed) + : simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {} + + template <Message Request, Message Response> + ResponseFuture<Response> Request(Address address, uint64_t request_id, Request request, Duration timeout) { + std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); }; + auto [future, promise] = + memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(maybe_tick_simulator); + + simulator_handle_->SubmitRequest(address, address_, request_id, std::move(request), timeout, std::move(promise)); + + return std::move(future); + } + + template <Message... Ms> + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) { + return simulator_handle_->template Receive<Ms...>(address_, timeout); + } + + template <Message M> + void Send(Address address, uint64_t request_id, M message) { + return simulator_handle_->template Send<M>(address, address_, request_id, message); + } + + Time Now() const { return simulator_handle_->Now(); } + + bool ShouldShutDown() const { return simulator_handle_->ShouldShutDown(); } + + template <class D = std::poisson_distribution<>, class Return = uint64_t> + Return Rand(D distrib) { + return distrib(rng_); + } +}; +}; // namespace memgraph::io::simulator diff --git a/src/io/time.hpp b/src/io/time.hpp new file mode 100644 index 000000000..57f58cab1 --- /dev/null +++ b/src/io/time.hpp @@ -0,0 +1,21 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> + +namespace memgraph::io { + +using Duration = std::chrono::microseconds; +using Time = std::chrono::time_point<std::chrono::local_t, Duration>; + +} // namespace memgraph::io diff --git a/src/io/transport.hpp b/src/io/transport.hpp new file mode 100644 index 000000000..a9e550434 --- /dev/null +++ b/src/io/transport.hpp @@ -0,0 +1,130 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> +#include <concepts> +#include <random> +#include <variant> + +#include "io/address.hpp" +#include "io/errors.hpp" +#include "io/future.hpp" +#include "io/time.hpp" +#include "utils/result.hpp" + +namespace memgraph::io { + +using memgraph::utils::BasicResult; + +// TODO(tyler) ensure that Message continues to represent +// reasonable constraints around message types over time, +// as we adapt things to use Thrift-generated message types. +template <typename T> +concept Message = std::same_as<T, std::decay_t<T>>; + +template <Message M> +struct ResponseEnvelope { + M message; + uint64_t request_id; + Address from_address; +}; + +template <Message M> +using ResponseResult = BasicResult<TimedOut, ResponseEnvelope<M>>; + +template <Message M> +using ResponseFuture = memgraph::io::Future<ResponseResult<M>>; + +template <Message M> +using ResponsePromise = memgraph::io::Promise<ResponseResult<M>>; + +template <Message... Ms> +struct RequestEnvelope { + std::variant<Ms...> message; + uint64_t request_id; + Address from_address; +}; + +template <Message... Ms> +using RequestResult = BasicResult<TimedOut, RequestEnvelope<Ms...>>; + +template <typename I> +class Io { + I implementation_; + Address address_; + uint64_t request_id_counter_ = 0; + Duration default_timeout_ = std::chrono::microseconds{50000}; + + public: + Io(I io, Address address) : implementation_(io), address_(address) {} + + /// Set the default timeout for all requests that are issued + /// without an explicit timeout set. + void SetDefaultTimeout(Duration timeout) { default_timeout_ = timeout; } + + /// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients. + template <Message Request, Message Response> + ResponseFuture<Response> RequestWithTimeout(Address address, Request request, Duration timeout) { + const uint64_t request_id = ++request_id_counter_; + return implementation_.template Request<Request, Response>(address, request_id, request, timeout); + } + + /// Issue a request that times out after the default timeout. This tends + /// to be used by clients. + template <Message Request, Message Response> + ResponseFuture<Response> Request(Address address, Request request) { + const uint64_t request_id = ++request_id_counter_; + const Duration timeout = default_timeout_; + return implementation_.template Request<Request, Response>(address, request_id, std::move(request), timeout); + } + + /// Wait for an explicit number of microseconds for a request of one of the + /// provided types to arrive. This tends to be used by servers. + template <Message... Ms> + RequestResult<Ms...> ReceiveWithTimeout(Duration timeout) { + return implementation_.template Receive<Ms...>(timeout); + } + + /// Wait the default number of microseconds for a request of one of the + /// provided types to arrive. This tends to be used by servers. + template <Message... Ms> + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive() { + const Duration timeout = default_timeout_; + return implementation_.template Receive<Ms...>(timeout); + } + + /// Send a message in a best-effort fashion. This is used for messaging where + /// responses are not necessarily expected, and for servers to respond to requests. + /// If you need reliable delivery, this must be built on-top. TCP is not enough for most use cases. + template <Message M> + void Send(Address address, uint64_t request_id, M message) { + return implementation_.template Send<M>(address, request_id, std::move(message)); + } + + /// The current system time. This time source should be preferred over any other, + /// because it lets us deterministically control clocks from tests for making + /// things like timeouts deterministic. + Time Now() const { return implementation_.Now(); } + + /// Returns true if the system should shut-down. + bool ShouldShutDown() const { return implementation_.ShouldShutDown(); } + + /// Returns a random number within the specified distribution. + template <class D = std::poisson_distribution<>, class Return = uint64_t> + Return Rand(D distrib) { + return implementation_.template Rand<D, Return>(distrib); + } + + Address GetAddress() { return address_; } +}; +}; // namespace memgraph::io diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 02535dcc7..664c010c8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,9 @@ add_subdirectory(stress) # concurrent test binaries add_subdirectory(concurrent) +# simulation test binaries +add_subdirectory(simulation) + # manual test binaries add_subdirectory(manual) diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt index 4bf8374b0..31f0eebc0 100644 --- a/tests/benchmark/CMakeLists.txt +++ b/tests/benchmark/CMakeLists.txt @@ -62,3 +62,6 @@ target_link_libraries(${test_prefix}storage_v2_gc mg-storage-v2) add_benchmark(storage_v2_property_store.cpp) target_link_libraries(${test_prefix}storage_v2_property_store mg-storage-v2) + +add_benchmark(future.cpp) +target_link_libraries(${test_prefix}future mg-io) diff --git a/tests/benchmark/future.cpp b/tests/benchmark/future.cpp new file mode 100644 index 000000000..abbe3fb98 --- /dev/null +++ b/tests/benchmark/future.cpp @@ -0,0 +1,30 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <benchmark/benchmark.h> + +#include "io/future.hpp" + +static void FuturePairFillWait(benchmark::State &state) { + uint64_t counter = 0; + while (state.KeepRunning()) { + auto [future, promise] = memgraph::io::FuturePromisePair<int>(); + promise.Fill(1); + std::move(future).Wait(); + + ++counter; + } + state.SetItemsProcessed(counter); +} + +BENCHMARK(FuturePairFillWait)->Unit(benchmark::kNanosecond)->UseRealTime(); + +BENCHMARK_MAIN(); diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt new file mode 100644 index 000000000..142657401 --- /dev/null +++ b/tests/simulation/CMakeLists.txt @@ -0,0 +1,30 @@ +set(test_prefix memgraph__simulation__) + +find_package(gflags) + +add_custom_target(memgraph__simulation) + +function(add_simulation_test test_cpp san) + # get exec name (remove extension from the abs path) + get_filename_component(exec_name ${test_cpp} NAME_WE) + set(target_name ${test_prefix}${exec_name}) + add_executable(${target_name} ${test_cpp}) + + # OUTPUT_NAME sets the real name of a target when it is built and can be + # used to help create two targets of the same name even though CMake + # requires unique logical target names + set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name}) + target_link_libraries(${target_name} gtest gmock mg-utils mg-io mg-io-simulator) + + # sanitize + target_compile_options(${target_name} PRIVATE -fsanitize=${san}) + target_link_options(${target_name} PRIVATE -fsanitize=${san}) + + # register test + add_test(${target_name} ${exec_name}) + add_dependencies(memgraph__simulation ${target_name}) +endfunction(add_simulation_test) + +add_simulation_test(basic_request.cpp address) + +add_simulation_test(trial_query_storage/query_storage_test.cpp address) diff --git a/tests/simulation/basic_request.cpp b/tests/simulation/basic_request.cpp new file mode 100644 index 000000000..ac3190ad7 --- /dev/null +++ b/tests/simulation/basic_request.cpp @@ -0,0 +1,87 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <thread> + +#include "io/simulator/simulator.hpp" + +using memgraph::io::Address; +using memgraph::io::Io; +using memgraph::io::ResponseFuture; +using memgraph::io::ResponseResult; +using memgraph::io::simulator::Simulator; +using memgraph::io::simulator::SimulatorConfig; +using memgraph::io::simulator::SimulatorTransport; + +struct CounterRequest { + uint64_t proposal; +}; + +struct CounterResponse { + uint64_t highest_seen; +}; + +void run_server(Io<SimulatorTransport> io) { + uint64_t highest_seen = 0; + + while (!io.ShouldShutDown()) { + std::cout << "[SERVER] Is receiving..." << std::endl; + auto request_result = io.Receive<CounterRequest>(); + if (request_result.HasError()) { + std::cout << "[SERVER] Error, continue" << std::endl; + continue; + } + auto request_envelope = request_result.GetValue(); + auto req = std::get<CounterRequest>(request_envelope.message); + + highest_seen = std::max(highest_seen, req.proposal); + auto srv_res = CounterResponse{highest_seen}; + + io.Send(request_envelope.from_address, request_envelope.request_id, srv_res); + } +} + +int main() { + auto config = SimulatorConfig{ + .drop_percent = 0, + .perform_timeouts = true, + .scramble_messages = true, + .rng_seed = 0, + }; + auto simulator = Simulator(config); + + auto cli_addr = Address::TestAddress(1); + auto srv_addr = Address::TestAddress(2); + + Io<SimulatorTransport> cli_io = simulator.Register(cli_addr); + Io<SimulatorTransport> srv_io = simulator.Register(srv_addr); + + auto srv_thread = std::jthread(run_server, std::move(srv_io)); + simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr); + + for (int i = 1; i < 3; ++i) { + // send request + CounterRequest cli_req; + cli_req.proposal = i; + auto res_f = cli_io.Request<CounterRequest, CounterResponse>(srv_addr, cli_req); + auto res_rez = std::move(res_f).Wait(); + if (!res_rez.HasError()) { + std::cout << "[CLIENT] Got a valid response" << std::endl; + auto env = res_rez.GetValue(); + MG_ASSERT(env.message.highest_seen == i); + } else { + std::cout << "[CLIENT] Got an error" << std::endl; + } + } + + simulator.ShutDown(); + return 0; +} diff --git a/tests/simulation/trial_query_storage/messages.hpp b/tests/simulation/trial_query_storage/messages.hpp new file mode 100644 index 000000000..8db78a54c --- /dev/null +++ b/tests/simulation/trial_query_storage/messages.hpp @@ -0,0 +1,34 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <optional> +#include <string> +#include <vector> + +namespace memgraph::tests::simulation { + +struct Vertex { + std::string key; +}; + +struct ScanVerticesRequest { + int64_t count; + std::optional<int64_t> continuation; +}; + +struct VerticesResponse { + std::vector<Vertex> vertices; + std::optional<int64_t> continuation; +}; + +} // namespace memgraph::tests::simulation diff --git a/tests/simulation/trial_query_storage/query_storage_test.cpp b/tests/simulation/trial_query_storage/query_storage_test.cpp new file mode 100644 index 000000000..9cdff4ee6 --- /dev/null +++ b/tests/simulation/trial_query_storage/query_storage_test.cpp @@ -0,0 +1,85 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <iostream> + +#include "io/address.hpp" +#include "io/simulator/simulator.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/simulator/simulator_transport.hpp" +#include "io/transport.hpp" + +#include "messages.hpp" + +namespace memgraph::tests::simulation { +using memgraph::io::Io; +using memgraph::io::simulator::SimulatorTransport; + +void run_server(Io<SimulatorTransport> io) { + while (!io.ShouldShutDown()) { + std::cout << "[STORAGE] Is receiving..." << std::endl; + auto request_result = io.Receive<ScanVerticesRequest>(); + if (request_result.HasError()) { + std::cout << "[STORAGE] Error, continue" << std::endl; + continue; + } + auto request_envelope = request_result.GetValue(); + auto req = std::get<ScanVerticesRequest>(request_envelope.message); + + VerticesResponse response{}; + const int64_t start_index = std::invoke([&req] { + if (req.continuation.has_value()) { + return *req.continuation; + } + return 0L; + }); + for (auto index = start_index; index < start_index + req.count; ++index) { + response.vertices.push_back({std::string("Vertex_") + std::to_string(index)}); + } + io.Send(request_envelope.from_address, request_envelope.request_id, response); + } +} + +} // namespace memgraph::tests::simulation + +int main() { + using memgraph::io::Address; + using memgraph::io::Io; + using memgraph::io::simulator::Simulator; + using memgraph::io::simulator::SimulatorConfig; + using memgraph::io::simulator::SimulatorTransport; + using memgraph::tests::simulation::run_server; + using memgraph::tests::simulation::ScanVerticesRequest; + using memgraph::tests::simulation::VerticesResponse; + auto config = SimulatorConfig{ + .drop_percent = 0, + .perform_timeouts = true, + .scramble_messages = true, + .rng_seed = 0, + }; + auto simulator = Simulator(config); + + auto cli_addr = Address::TestAddress(1); + auto srv_addr = Address::TestAddress(2); + + Io<SimulatorTransport> cli_io = simulator.Register(cli_addr); + Io<SimulatorTransport> srv_io = simulator.Register(srv_addr); + + auto srv_thread = std::jthread(run_server, std::move(srv_io)); + simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr); + + auto req = ScanVerticesRequest{2, std::nullopt}; + + auto res_f = cli_io.Request<ScanVerticesRequest, VerticesResponse>(srv_addr, req); + auto res_rez = std::move(res_f).Wait(); + simulator.ShutDown(); + return 0; +} diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 55fb7b01b..9de4860ef 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -396,3 +396,7 @@ find_package(Boost REQUIRED) add_unit_test(websocket.cpp) target_link_libraries(${test_prefix}websocket mg-communication Boost::headers) + +# Test future +add_unit_test(future.cpp) +target_link_libraries(${test_prefix}future mg-io) \ No newline at end of file diff --git a/tests/unit/future.cpp b/tests/unit/future.cpp new file mode 100644 index 000000000..490e19bbc --- /dev/null +++ b/tests/unit/future.cpp @@ -0,0 +1,55 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <string> +#include <thread> + +#include "gtest/gtest.h" + +#include "io/future.hpp" + +using namespace memgraph::io; + +void Fill(Promise<std::string> promise_1) { promise_1.Fill("success"); } + +void Wait(Future<std::string> future_1, Promise<std::string> promise_2) { + std::string result_1 = std::move(future_1).Wait(); + EXPECT_TRUE(result_1 == "success"); + promise_2.Fill("it worked"); +} + +TEST(Future, BasicLifecycle) { + std::atomic_bool waiting = false; + + std::function<bool()> notifier = [&] { + waiting.store(true, std::memory_order_seq_cst); + return false; + }; + + auto [future_1, promise_1] = FuturePromisePairWithNotifier<std::string>(notifier); + auto [future_2, promise_2] = FuturePromisePair<std::string>(); + + std::jthread t1(Wait, std::move(future_1), std::move(promise_2)); + + // spin in a loop until the promise signals + // that it is waiting + while (!waiting.load(std::memory_order_acquire)) { + std::this_thread::yield(); + } + + std::jthread t2(Fill, std::move(promise_1)); + + t1.join(); + t2.join(); + + std::string result_2 = std::move(future_2).Wait(); + EXPECT_TRUE(result_2 == "it worked"); +} From a40403e3ce8c7c0755ee13373172fd02fe846dc3 Mon Sep 17 00:00:00 2001 From: Tyler Neely <tylerneely@gmail.com> Date: Mon, 29 Aug 2022 13:49:51 +0200 Subject: [PATCH 2/4] Add local transport (#512) * Create LocalTransport Io provider for sending messages to components on the same machine * Move src/io/simulation/message_conversion.hpp to src/io/message_conversion.hpp for use in other Io providers --- src/io/address.hpp | 7 + src/io/local_transport/local_system.hpp | 35 +++++ src/io/local_transport/local_transport.hpp | 67 +++++++++ .../local_transport_handle.hpp | 139 ++++++++++++++++++ src/io/{simulator => }/message_conversion.hpp | 33 ++++- src/io/simulator/simulator_handle.hpp | 47 ++---- src/io/simulator/simulator_transport.hpp | 4 +- src/io/time.hpp | 2 +- src/io/transport.hpp | 19 ++- tests/unit/CMakeLists.txt | 6 +- tests/unit/local_transport.cpp | 86 +++++++++++ 11 files changed, 397 insertions(+), 48 deletions(-) create mode 100644 src/io/local_transport/local_system.hpp create mode 100644 src/io/local_transport/local_transport.hpp create mode 100644 src/io/local_transport/local_transport_handle.hpp rename src/io/{simulator => }/message_conversion.hpp (87%) create mode 100644 tests/unit/local_transport.cpp diff --git a/src/io/address.hpp b/src/io/address.hpp index 94a231e07..19dd55948 100644 --- a/src/io/address.hpp +++ b/src/io/address.hpp @@ -16,6 +16,7 @@ #include <fmt/format.h> #include <boost/asio/ip/tcp.hpp> #include <boost/uuid/uuid.hpp> +#include <boost/uuid/uuid_generators.hpp> #include <boost/uuid/uuid_io.hpp> namespace memgraph::io { @@ -37,6 +38,12 @@ struct Address { return ret; } + static Address UniqueLocalAddress() { + return Address{ + .unique_id = boost::uuids::uuid{boost::uuids::random_generator()()}, + }; + } + friend bool operator==(const Address &lhs, const Address &rhs) = default; /// unique_id is most dominant for ordering, then last_known_ip, then last_known_port diff --git a/src/io/local_transport/local_system.hpp b/src/io/local_transport/local_system.hpp new file mode 100644 index 000000000..fd0628943 --- /dev/null +++ b/src/io/local_transport/local_system.hpp @@ -0,0 +1,35 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <random> + +#include "io/address.hpp" +#include "io/local_transport/local_transport.hpp" +#include "io/local_transport/local_transport_handle.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::local_transport { + +class LocalSystem { + std::shared_ptr<LocalTransportHandle> local_transport_handle_ = std::make_shared<LocalTransportHandle>(); + + public: + Io<LocalTransport> Register(Address address) { + LocalTransport local_transport(local_transport_handle_, address); + return Io{local_transport, address}; + } + + void ShutDown() { local_transport_handle_->ShutDown(); } +}; + +} // namespace memgraph::io::local_transport diff --git a/src/io/local_transport/local_transport.hpp b/src/io/local_transport/local_transport.hpp new file mode 100644 index 000000000..f08392a87 --- /dev/null +++ b/src/io/local_transport/local_transport.hpp @@ -0,0 +1,67 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> +#include <memory> +#include <random> +#include <utility> + +#include "io/address.hpp" +#include "io/local_transport/local_transport_handle.hpp" +#include "io/time.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::local_transport { + +class LocalTransport { + std::shared_ptr<LocalTransportHandle> local_transport_handle_; + const Address address_; + + public: + LocalTransport(std::shared_ptr<LocalTransportHandle> local_transport_handle, Address address) + : local_transport_handle_(std::move(local_transport_handle)), address_(address) {} + + template <Message Request, Message Response> + ResponseFuture<Response> Request(Address to_address, RequestId request_id, Request request, Duration timeout) { + auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<Response>>(); + + Address from_address = address_; + + local_transport_handle_->SubmitRequest(to_address, from_address, request_id, std::move(request), timeout, + std::move(promise)); + + return std::move(future); + } + + template <Message... Ms> + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) { + Address from_address = address_; + return local_transport_handle_->template Receive<Ms...>(timeout); + } + + template <Message M> + void Send(Address to_address, Address from_address, RequestId request_id, M &&message) { + return local_transport_handle_->template Send<M>(to_address, from_address, request_id, std::forward<M>(message)); + } + + Time Now() const { return local_transport_handle_->Now(); } + + bool ShouldShutDown() const { return local_transport_handle_->ShouldShutDown(); } + + template <class D = std::poisson_distribution<>, class Return = uint64_t> + Return Rand(D distrib) { + std::random_device rng; + return distrib(rng); + } +}; +}; // namespace memgraph::io::local_transport diff --git a/src/io/local_transport/local_transport_handle.hpp b/src/io/local_transport/local_transport_handle.hpp new file mode 100644 index 000000000..8536ff716 --- /dev/null +++ b/src/io/local_transport/local_transport_handle.hpp @@ -0,0 +1,139 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> +#include <condition_variable> +#include <iostream> +#include <map> +#include <mutex> + +#include "io/errors.hpp" +#include "io/message_conversion.hpp" +#include "io/time.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::local_transport { + +class LocalTransportHandle { + mutable std::mutex mu_{}; + mutable std::condition_variable cv_; + bool should_shut_down_ = false; + + // the responses to requests that are being waited on + std::map<PromiseKey, DeadlineAndOpaquePromise> promises_; + + // messages that are sent to servers that may later receive them + std::vector<OpaqueMessage> can_receive_; + + public: + void ShutDown() { + std::unique_lock<std::mutex> lock(mu_); + should_shut_down_ = true; + cv_.notify_all(); + } + + bool ShouldShutDown() const { + std::unique_lock<std::mutex> lock(mu_); + return should_shut_down_; + } + + static Time Now() { + auto nano_time = std::chrono::system_clock::now(); + return std::chrono::time_point_cast<std::chrono::microseconds>(nano_time); + } + + template <Message... Ms> + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) { + std::unique_lock lock(mu_); + + Time before = Now(); + + while (can_receive_.empty()) { + Time now = Now(); + + // protection against non-monotonic timesources + auto maxed_now = std::max(now, before); + auto elapsed = maxed_now - before; + + if (timeout < elapsed) { + return TimedOut{}; + } + + Duration relative_timeout = timeout - elapsed; + + std::cv_status cv_status_value = cv_.wait_for(lock, relative_timeout); + + if (cv_status_value == std::cv_status::timeout) { + return TimedOut{}; + } + } + + auto current_message = std::move(can_receive_.back()); + can_receive_.pop_back(); + + auto m_opt = std::move(current_message).Take<Ms...>(); + + return std::move(m_opt).value(); + } + + template <Message M> + void Send(Address to_address, Address from_address, RequestId request_id, M &&message) { + std::any message_any(std::forward<M>(message)); + OpaqueMessage opaque_message{ + .from_address = from_address, .request_id = request_id, .message = std::move(message_any)}; + + PromiseKey promise_key{.requester_address = to_address, + .request_id = opaque_message.request_id, + .replier_address = opaque_message.from_address}; + + { + std::unique_lock<std::mutex> lock(mu_); + + if (promises_.contains(promise_key)) { + // complete waiting promise if it's there + DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key)); + promises_.erase(promise_key); + + dop.promise.Fill(std::move(opaque_message)); + } else { + can_receive_.emplace_back(std::move(opaque_message)); + } + } // lock dropped + + cv_.notify_all(); + } + + template <Message Request, Message Response> + void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request, + Duration timeout, ResponsePromise<Response> promise) { + const bool port_matches = to_address.last_known_port == from_address.last_known_port; + const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip; + + MG_ASSERT(port_matches && ip_matches); + const Time deadline = Now() + timeout; + + { + std::unique_lock<std::mutex> lock(mu_); + + PromiseKey promise_key{ + .requester_address = from_address, .request_id = request_id, .replier_address = to_address}; + OpaquePromise opaque_promise(std::move(promise).ToUnique()); + DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)}; + promises_.emplace(std::move(promise_key), std::move(dop)); + } // lock dropped + + Send(to_address, from_address, request_id, std::forward<Request>(request)); + } +}; + +} // namespace memgraph::io::local_transport diff --git a/src/io/simulator/message_conversion.hpp b/src/io/message_conversion.hpp similarity index 87% rename from src/io/simulator/message_conversion.hpp rename to src/io/message_conversion.hpp index f16c60f65..53881583b 100644 --- a/src/io/simulator/message_conversion.hpp +++ b/src/io/message_conversion.hpp @@ -13,13 +13,35 @@ #include "io/transport.hpp" -namespace memgraph::io::simulator { +namespace memgraph::io { using memgraph::io::Duration; using memgraph::io::Message; using memgraph::io::Time; +struct PromiseKey { + Address requester_address; + uint64_t request_id; + // TODO(tyler) possibly remove replier_address from promise key + // once we want to support DSR. + Address replier_address; + + public: + friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) { + if (lhs.requester_address != rhs.requester_address) { + return lhs.requester_address < rhs.requester_address; + } + + if (lhs.request_id != rhs.request_id) { + return lhs.request_id < rhs.request_id; + } + + return lhs.replier_address < rhs.replier_address; + } +}; + struct OpaqueMessage { + Address to_address; Address from_address; uint64_t request_id; std::any message; @@ -65,6 +87,7 @@ struct OpaqueMessage { return RequestEnvelope<Ms...>{ .message = std::move(*m_opt), .request_id = request_id, + .to_address = to_address, .from_address = from_address, }; } @@ -99,6 +122,7 @@ class OpaquePromiseTrait : public OpaquePromiseTraitBase { T message = std::any_cast<T>(std::move(opaque_message.message)); auto response_envelope = ResponseEnvelope<T>{.message = std::move(message), .request_id = opaque_message.request_id, + .to_address = opaque_message.to_address, .from_address = opaque_message.from_address}; auto promise = static_cast<ResponsePromise<T> *>(ptr); auto unique_promise = std::unique_ptr<ResponsePromise<T>>(promise); @@ -174,4 +198,9 @@ class OpaquePromise { } }; -} // namespace memgraph::io::simulator +struct DeadlineAndOpaquePromise { + Time deadline; + OpaquePromise promise; +}; + +} // namespace memgraph::io diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp index 6abaa129d..08a3837ee 100644 --- a/src/io/simulator/simulator_handle.hpp +++ b/src/io/simulator/simulator_handle.hpp @@ -24,7 +24,7 @@ #include "io/address.hpp" #include "io/errors.hpp" -#include "io/simulator/message_conversion.hpp" +#include "io/message_conversion.hpp" #include "io/simulator/simulator_config.hpp" #include "io/simulator/simulator_stats.hpp" #include "io/time.hpp" @@ -32,35 +32,6 @@ namespace memgraph::io::simulator { -using memgraph::io::Duration; -using memgraph::io::Time; - -struct PromiseKey { - Address requester_address; - uint64_t request_id; - // TODO(tyler) possibly remove replier_address from promise key - // once we want to support DSR. - Address replier_address; - - public: - friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) { - if (lhs.requester_address != rhs.requester_address) { - return lhs.requester_address < rhs.requester_address; - } - - if (lhs.request_id != rhs.request_id) { - return lhs.request_id < rhs.request_id; - } - - return lhs.replier_address < rhs.replier_address; - } -}; - -struct DeadlineAndOpaquePromise { - Time deadline; - OpaquePromise promise; -}; - class SimulatorHandle { mutable std::mutex mu_{}; mutable std::condition_variable cv_; @@ -122,14 +93,17 @@ class SimulatorHandle { bool ShouldShutDown() const; template <Message Request, Message Response> - void SubmitRequest(Address to_address, Address from_address, uint64_t request_id, Request &&request, Duration timeout, - ResponsePromise<Response> &&promise) { + void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request, + Duration timeout, ResponsePromise<Response> &&promise) { std::unique_lock<std::mutex> lock(mu_); const Time deadline = cluster_wide_time_microseconds_ + timeout; std::any message(request); - OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message)}; + OpaqueMessage om{.to_address = to_address, + .from_address = from_address, + .request_id = request_id, + .message = std::move(message)}; in_flight_.emplace_back(std::make_pair(to_address, std::move(om))); PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address}; @@ -182,10 +156,13 @@ class SimulatorHandle { } template <Message M> - void Send(Address to_address, Address from_address, uint64_t request_id, M message) { + void Send(Address to_address, Address from_address, RequestId request_id, M message) { std::unique_lock<std::mutex> lock(mu_); std::any message_any(std::move(message)); - OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message_any)}; + OpaqueMessage om{.to_address = to_address, + .from_address = from_address, + .request_id = request_id, + .message = std::move(message_any)}; in_flight_.emplace_back(std::make_pair(std::move(to_address), std::move(om))); stats_.total_messages++; diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp index b67371ff0..2706f798c 100644 --- a/src/io/simulator/simulator_transport.hpp +++ b/src/io/simulator/simulator_transport.hpp @@ -49,8 +49,8 @@ class SimulatorTransport { } template <Message M> - void Send(Address address, uint64_t request_id, M message) { - return simulator_handle_->template Send<M>(address, address_, request_id, message); + void Send(Address to_address, Address from_address, uint64_t request_id, M message) { + return simulator_handle_->template Send<M>(to_address, from_address, request_id, message); } Time Now() const { return simulator_handle_->Now(); } diff --git a/src/io/time.hpp b/src/io/time.hpp index 57f58cab1..b07f90154 100644 --- a/src/io/time.hpp +++ b/src/io/time.hpp @@ -16,6 +16,6 @@ namespace memgraph::io { using Duration = std::chrono::microseconds; -using Time = std::chrono::time_point<std::chrono::local_t, Duration>; +using Time = std::chrono::time_point<std::chrono::system_clock, Duration>; } // namespace memgraph::io diff --git a/src/io/transport.hpp b/src/io/transport.hpp index a9e550434..31592c9c3 100644 --- a/src/io/transport.hpp +++ b/src/io/transport.hpp @@ -32,10 +32,13 @@ using memgraph::utils::BasicResult; template <typename T> concept Message = std::same_as<T, std::decay_t<T>>; +using RequestId = uint64_t; + template <Message M> struct ResponseEnvelope { M message; - uint64_t request_id; + RequestId request_id; + Address to_address; Address from_address; }; @@ -51,7 +54,8 @@ using ResponsePromise = memgraph::io::Promise<ResponseResult<M>>; template <Message... Ms> struct RequestEnvelope { std::variant<Ms...> message; - uint64_t request_id; + RequestId request_id; + Address to_address; Address from_address; }; @@ -62,7 +66,7 @@ template <typename I> class Io { I implementation_; Address address_; - uint64_t request_id_counter_ = 0; + RequestId request_id_counter_ = 0; Duration default_timeout_ = std::chrono::microseconds{50000}; public: @@ -75,7 +79,7 @@ class Io { /// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients. template <Message Request, Message Response> ResponseFuture<Response> RequestWithTimeout(Address address, Request request, Duration timeout) { - const uint64_t request_id = ++request_id_counter_; + const RequestId request_id = ++request_id_counter_; return implementation_.template Request<Request, Response>(address, request_id, request, timeout); } @@ -83,7 +87,7 @@ class Io { /// to be used by clients. template <Message Request, Message Response> ResponseFuture<Response> Request(Address address, Request request) { - const uint64_t request_id = ++request_id_counter_; + const RequestId request_id = ++request_id_counter_; const Duration timeout = default_timeout_; return implementation_.template Request<Request, Response>(address, request_id, std::move(request), timeout); } @@ -107,8 +111,9 @@ class Io { /// responses are not necessarily expected, and for servers to respond to requests. /// If you need reliable delivery, this must be built on-top. TCP is not enough for most use cases. template <Message M> - void Send(Address address, uint64_t request_id, M message) { - return implementation_.template Send<M>(address, request_id, std::move(message)); + void Send(Address to_address, RequestId request_id, M message) { + Address from_address = address_; + return implementation_.template Send<M>(to_address, from_address, request_id, std::move(message)); } /// The current system time. This time source should be preferred over any other, diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 9de4860ef..a5e71347a 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -399,4 +399,8 @@ target_link_libraries(${test_prefix}websocket mg-communication Boost::headers) # Test future add_unit_test(future.cpp) -target_link_libraries(${test_prefix}future mg-io) \ No newline at end of file +target_link_libraries(${test_prefix}future mg-io) + +# Test Local Transport +add_unit_test(local_transport.cpp) +target_link_libraries(${test_prefix}local_transport mg-io) diff --git a/tests/unit/local_transport.cpp b/tests/unit/local_transport.cpp new file mode 100644 index 000000000..aa03325de --- /dev/null +++ b/tests/unit/local_transport.cpp @@ -0,0 +1,86 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <chrono> +#include <thread> + +#include <gtest/gtest.h> + +#include "io/local_transport/local_system.hpp" +#include "io/local_transport/local_transport.hpp" +#include "io/transport.hpp" + +namespace memgraph::io::tests { + +using memgraph::io::local_transport::LocalSystem; +using memgraph::io::local_transport::LocalTransport; + +struct CounterRequest { + uint64_t proposal; +}; + +struct CounterResponse { + uint64_t highest_seen; +}; + +void RunServer(Io<LocalTransport> io) { + uint64_t highest_seen = 0; + + while (!io.ShouldShutDown()) { + spdlog::info("[SERVER] Is receiving..."); + auto request_result = io.Receive<CounterRequest>(); + if (request_result.HasError()) { + spdlog::info("[SERVER] timed out, continue"); + continue; + } + auto request_envelope = request_result.GetValue(); + ASSERT_TRUE(std::holds_alternative<CounterRequest>(request_envelope.message)); + auto req = std::get<CounterRequest>(request_envelope.message); + + highest_seen = std::max(highest_seen, req.proposal); + auto srv_res = CounterResponse{highest_seen}; + + io.Send(request_envelope.from_address, request_envelope.request_id, srv_res); + } +} + +TEST(LocalTransport, BasicRequest) { + LocalSystem local_system; + + // rely on uuid to be unique on default Address + auto cli_addr = Address::UniqueLocalAddress(); + auto srv_addr = Address::UniqueLocalAddress(); + + Io<LocalTransport> cli_io = local_system.Register(cli_addr); + Io<LocalTransport> srv_io = local_system.Register(srv_addr); + + auto srv_thread = std::jthread(RunServer, std::move(srv_io)); + + for (int i = 1; i < 3; ++i) { + // send request + CounterRequest cli_req; + auto value = 1; // i; + cli_req.proposal = value; + spdlog::info("[CLIENT] sending request"); + auto res_f = cli_io.Request<CounterRequest, CounterResponse>(srv_addr, cli_req); + spdlog::info("[CLIENT] waiting on future"); + + auto res_rez = std::move(res_f).Wait(); + spdlog::info("[CLIENT] future returned"); + MG_ASSERT(!res_rez.HasError()); + spdlog::info("[CLIENT] Got a valid response"); + auto env = res_rez.GetValue(); + MG_ASSERT(env.message.highest_seen == value); + } + + local_system.ShutDown(); +} +} // namespace memgraph::io::tests From c0d03888f46c804f422d37b3085ed3c5148c1e70 Mon Sep 17 00:00:00 2001 From: Tyler Neely <tylerneely@gmail.com> Date: Tue, 30 Aug 2022 15:07:34 +0200 Subject: [PATCH 3/4] Implement basic raft version (#498) --- src/io/address.hpp | 7 +- src/io/rsm/raft.hpp | 913 ++++++++++++++++++++++++++ src/io/simulator/simulator_handle.cpp | 16 +- src/io/simulator/simulator_handle.hpp | 8 - tests/simulation/CMakeLists.txt | 2 + tests/simulation/raft.cpp | 315 +++++++++ 6 files changed, 1236 insertions(+), 25 deletions(-) create mode 100644 src/io/rsm/raft.hpp create mode 100644 tests/simulation/raft.cpp diff --git a/src/io/address.hpp b/src/io/address.hpp index 19dd55948..914c8cb86 100644 --- a/src/io/address.hpp +++ b/src/io/address.hpp @@ -33,9 +33,10 @@ struct Address { uint16_t last_known_port; static Address TestAddress(uint16_t port) { - Address ret; - ret.last_known_port = port; - return ret; + return Address{ + .unique_id = boost::uuids::uuid{boost::uuids::random_generator()()}, + .last_known_port = port, + }; } static Address UniqueLocalAddress() { diff --git a/src/io/rsm/raft.hpp b/src/io/rsm/raft.hpp new file mode 100644 index 000000000..c38d3da74 --- /dev/null +++ b/src/io/rsm/raft.hpp @@ -0,0 +1,913 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// TODO(tyler) buffer out-of-order Append buffers on the Followers to reassemble more quickly +// TODO(tyler) handle granular batch sizes based on simple flow control + +#pragma once + +#include <deque> +#include <iostream> +#include <map> +#include <set> +#include <thread> +#include <unordered_map> +#include <vector> + +#include "io/simulator/simulator.hpp" +#include "io/transport.hpp" +#include "utils/concepts.hpp" + +namespace memgraph::io::rsm { + +/// Timeout and replication tunables +using namespace std::chrono_literals; +static constexpr auto kMinimumElectionTimeout = 100ms; +static constexpr auto kMaximumElectionTimeout = 200ms; +static constexpr auto kMinimumBroadcastTimeout = 40ms; +static constexpr auto kMaximumBroadcastTimeout = 60ms; +static constexpr auto kMinimumCronInterval = 1ms; +static constexpr auto kMaximumCronInterval = 2ms; +static constexpr auto kMinimumReceiveTimeout = 40ms; +static constexpr auto kMaximumReceiveTimeout = 60ms; +static_assert(kMinimumElectionTimeout > kMaximumBroadcastTimeout, + "The broadcast timeout has to be smaller than the election timeout!"); +static_assert(kMinimumElectionTimeout < kMaximumElectionTimeout, + "The minimum election timeout has to be smaller than the maximum election timeout!"); +static_assert(kMinimumBroadcastTimeout < kMaximumBroadcastTimeout, + "The minimum broadcast timeout has to be smaller than the maximum broadcast timeout!"); +static_assert(kMinimumCronInterval < kMaximumCronInterval, + "The minimum cron interval has to be smaller than the maximum cron interval!"); +static_assert(kMinimumReceiveTimeout < kMaximumReceiveTimeout, + "The minimum receive timeout has to be smaller than the maximum receive timeout!"); +static constexpr size_t kMaximumAppendBatchSize = 1024; + +using Term = uint64_t; +using LogIndex = uint64_t; +using LogSize = uint64_t; +using RequestId = uint64_t; + +template <typename WriteOperation> +struct WriteRequest { + WriteOperation operation; +}; + +/// WriteResponse is returned to a client after +/// their WriteRequest was entered in to the raft +/// log and it reached consensus. +/// +/// WriteReturn is the result of applying the WriteRequest to +/// ReplicatedState, and if the ReplicatedState::write +/// method is deterministic, all replicas will +/// have the same ReplicatedState after applying +/// the submitted WriteRequest. +template <typename WriteReturn> +struct WriteResponse { + bool success; + WriteReturn write_return; + std::optional<Address> retry_leader; + LogIndex raft_index; +}; + +template <typename ReadOperation> +struct ReadRequest { + ReadOperation operation; +}; + +template <typename ReadReturn> +struct ReadResponse { + bool success; + ReadReturn read_return; + std::optional<Address> retry_leader; +}; + +/// AppendRequest is a raft-level message that the Leader +/// periodically broadcasts to all Follower peers. This +/// serves three main roles: +/// 1. acts as a heartbeat from the Leader to the Follower +/// 2. replicates new data that the Leader has received to the Follower +/// 3. informs Follower peers when the commit index has increased, +/// signalling that it is now safe to apply log items to the +/// replicated state machine +template <typename WriteRequest> +struct AppendRequest { + Term term = 0; + LogIndex batch_start_log_index; + Term last_log_term; + std::vector<std::pair<Term, WriteRequest>> entries; + LogSize leader_commit; +}; + +struct AppendResponse { + bool success; + Term term; + Term last_log_term; + // a small optimization over the raft paper, tells + // the leader the offset that we are interested in + // to send log offsets from for us. This will only + // be useful at the beginning of a leader's term. + LogSize log_size; +}; + +struct VoteRequest { + Term term = 0; + LogSize log_size; + Term last_log_term; +}; + +struct VoteResponse { + Term term = 0; + LogSize committed_log_size; + bool vote_granted = false; +}; + +template <typename WriteRequest> +struct CommonState { + Term term = 0; + std::vector<std::pair<Term, WriteRequest>> log; + LogSize committed_log_size = 0; + LogSize applied_size = 0; +}; + +struct FollowerTracker { + LogIndex next_index = 0; + LogSize confirmed_log_size = 0; +}; + +struct PendingClientRequest { + RequestId request_id; + Address address; + Time received_at; +}; + +struct Leader { + std::map<Address, FollowerTracker> followers; + std::unordered_map<LogIndex, PendingClientRequest> pending_client_requests; + Time last_broadcast = Time::min(); + + std::string static ToString() { return "\tLeader \t"; } +}; + +struct Candidate { + std::map<Address, LogSize> successful_votes; + Time election_began = Time::min(); + std::set<Address> outstanding_votes; + + std::string static ToString() { return "\tCandidate\t"; } +}; + +struct Follower { + Time last_received_append_entries_timestamp; + Address leader_address; + + std::string static ToString() { return "\tFollower \t"; } +}; + +using Role = std::variant<Candidate, Leader, Follower>; + +template <typename Role> +concept AllRoles = memgraph::utils::SameAsAnyOf<Role, Leader, Follower, Candidate>; + +template <typename Role> +concept LeaderOrFollower = memgraph::utils::SameAsAnyOf<Role, Leader, Follower>; + +template <typename Role> +concept FollowerOrCandidate = memgraph::utils::SameAsAnyOf<Role, Follower, Candidate>; + +/* +all ReplicatedState classes should have an Apply method +that returns our WriteResponseValue: + +ReadResponse Read(ReadOperation); +WriteResponseValue ReplicatedState::Apply(WriteRequest); + +for examples: +if the state is uint64_t, and WriteRequest is `struct PlusOne {};`, +and WriteResponseValue is also uint64_t (the new value), then +each call to state.Apply(PlusOne{}) will return the new value +after incrementing it. 0, 1, 2, 3... and this will be sent back +to the client that requested the mutation. + +In practice, these mutations will usually be predicated on some +previous value, so that they are idempotent, functioning similarly +to a CAS operation. +*/ +template <typename WriteOperation, typename ReadOperation, typename ReplicatedState, typename WriteResponseValue, + typename ReadResponseValue> +concept Rsm = requires(ReplicatedState state, WriteOperation w, ReadOperation r) { + { state.Read(r) } -> std::same_as<ReadResponseValue>; + { state.Apply(w) } -> std::same_as<WriteResponseValue>; +}; + +/// Parameter Purpose +/// -------------------------- +/// IoImpl the concrete Io provider - SimulatorTransport, ThriftTransport, etc... +/// ReplicatedState the high-level data structure that is managed by the raft-backed replicated state machine +/// WriteOperation the individual operation type that is applied to the ReplicatedState in identical order +/// across each replica +/// WriteResponseValue the return value of calling ReplicatedState::Apply(WriteOperation), which is executed in +/// identical order across all replicas after an WriteOperation reaches consensus. +/// ReadOperation the type of operations that do not require consensus before executing directly +/// on a const ReplicatedState & +/// ReadResponseValue the return value of calling ReplicatedState::Read(ReadOperation), which is executed directly +/// without going through consensus first +template <typename IoImpl, typename ReplicatedState, typename WriteOperation, typename WriteResponseValue, + typename ReadOperation, typename ReadResponseValue> +requires Rsm<WriteOperation, ReadOperation, ReplicatedState, WriteResponseValue, ReadResponseValue> +class Raft { + CommonState<WriteOperation> state_; + Role role_ = Candidate{}; + Io<IoImpl> io_; + std::vector<Address> peers_; + ReplicatedState replicated_state_; + + public: + Raft(Io<IoImpl> &&io, std::vector<Address> peers, ReplicatedState &&replicated_state) + : io_(std::forward<Io<IoImpl>>(io)), + peers_(peers), + replicated_state_(std::forward<ReplicatedState>(replicated_state)) {} + + void Run() { + Time last_cron = io_.Now(); + + while (!io_.ShouldShutDown()) { + const auto now = io_.Now(); + const Duration random_cron_interval = RandomTimeout(kMinimumCronInterval, kMaximumCronInterval); + if (now - last_cron > random_cron_interval) { + Cron(); + last_cron = now; + } + + const Duration receive_timeout = RandomTimeout(kMinimumReceiveTimeout, kMaximumReceiveTimeout); + + auto request_result = + io_.template ReceiveWithTimeout<ReadRequest<ReadOperation>, AppendRequest<WriteOperation>, AppendResponse, + WriteRequest<WriteOperation>, VoteRequest, VoteResponse>(receive_timeout); + if (request_result.HasError()) { + continue; + } + + auto request = std::move(request_result.GetValue()); + + Handle(std::move(request.message), request.request_id, request.from_address); + } + } + + private: + // Raft paper - 5.3 + // When the entry has been safely replicated, the leader applies the + // entry to its state machine and returns the result of that + // execution to the client. + // + // "Safely replicated" is defined as being known to be present + // on at least a majority of all peers (inclusive of the Leader). + void BumpCommitIndexAndReplyToClients(Leader &leader) { + auto confirmed_log_sizes = std::vector<LogSize>{}; + + // We include our own log size in the calculation of the log + // confirmed log size that is present on at least a majority of all peers. + confirmed_log_sizes.push_back(state_.log.size()); + + for (const auto &[addr, f] : leader.followers) { + confirmed_log_sizes.push_back(f.confirmed_log_size); + Log("Follower at port ", addr.last_known_port, " has confirmed log size of: ", f.confirmed_log_size); + } + + // reverse sort from highest to lowest (using std::ranges::greater) + std::ranges::sort(confirmed_log_sizes, std::ranges::greater()); + + // This is a particularly correctness-critical calculation because it + // determines the committed log size that will be broadcast in + // the next AppendRequest. + // + // If the following sizes are recorded for clusters of different numbers of peers, + // these are the expected sizes that are considered to have reached consensus: + // + // state | expected value | (confirmed_log_sizes.size() / 2) + // [1] 1 (1 / 2) => 0 + // [2, 1] 1 (2 / 2) => 1 + // [3, 2, 1] 2 (3 / 2) => 1 + // [4, 3, 2, 1] 2 (4 / 2) => 2 + // [5, 4, 3, 2, 1] 3 (5 / 2) => 2 + const size_t majority_index = confirmed_log_sizes.size() / 2; + const LogSize new_committed_log_size = confirmed_log_sizes[majority_index]; + + // We never go backwards in history. + MG_ASSERT(state_.committed_log_size <= new_committed_log_size, + "as a Leader, we have previously set our committed_log_size to {}, but our Followers have a majority " + "committed_log_size of {}", + state_.committed_log_size, new_committed_log_size); + + state_.committed_log_size = new_committed_log_size; + + // For each size between the old size and the new one (inclusive), + // Apply that log's WriteOperation to our replicated_state_, + // and use the specific return value of the ReplicatedState::Apply + // method (WriteResponseValue) to respond to the requester. + for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) { + const LogIndex apply_index = state_.applied_size; + const auto &write_request = state_.log[apply_index].second; + const WriteResponseValue write_return = replicated_state_.Apply(write_request); + + if (leader.pending_client_requests.contains(apply_index)) { + const PendingClientRequest client_request = std::move(leader.pending_client_requests.at(apply_index)); + leader.pending_client_requests.erase(apply_index); + + const WriteResponse<WriteResponseValue> resp{ + .success = true, + .write_return = std::move(write_return), + .raft_index = apply_index, + }; + + io_.Send(client_request.address, client_request.request_id, std::move(resp)); + } + } + + Log("committed_log_size is now ", state_.committed_log_size); + } + + // Raft paper - 5.1 + // AppendEntries RPCs are initiated by leaders to replicate log entries and to provide a form of heartbeat + void BroadcastAppendEntries(std::map<Address, FollowerTracker> &followers) { + for (auto &[address, follower] : followers) { + const LogIndex next_index = follower.next_index; + + const auto missing = state_.log.size() - next_index; + const auto batch_size = std::min(missing, kMaximumAppendBatchSize); + const auto start_index = next_index; + const auto end_index = start_index + batch_size; + + // advance follower's next index + follower.next_index += batch_size; + + std::vector<std::pair<Term, WriteOperation>> entries; + + entries.insert(entries.begin(), state_.log.begin() + start_index, state_.log.begin() + end_index); + + const Term previous_term_from_index = PreviousTermFromIndex(start_index); + + Log("sending ", entries.size(), " entries to Follower ", address.last_known_port, + " which are above its next_index of ", next_index); + + AppendRequest<WriteOperation> ar{ + .term = state_.term, + .batch_start_log_index = start_index, + .last_log_term = previous_term_from_index, + .entries = std::move(entries), + .leader_commit = state_.committed_log_size, + }; + + // request_id not necessary to set because it's not a Future-backed Request. + static constexpr RequestId request_id = 0; + + io_.Send(address, request_id, std::move(ar)); + } + } + + // Raft paper - 5.2 + // Raft uses randomized election timeouts to ensure that split votes are rare and that they are resolved quickly + Duration RandomTimeout(Duration min, Duration max) { + std::uniform_int_distribution time_distrib(min.count(), max.count()); + + const auto rand_micros = io_.Rand(time_distrib); + + return Duration{rand_micros}; + } + + Duration RandomTimeout(int min_micros, int max_micros) { + std::uniform_int_distribution time_distrib(min_micros, max_micros); + + const int rand_micros = io_.Rand(time_distrib); + + return std::chrono::microseconds{rand_micros}; + } + + Term PreviousTermFromIndex(LogIndex index) const { + if (index == 0 || state_.log.size() + 1 <= index) { + return 0; + } + + const auto &[term, data] = state_.log.at(index - 1); + return term; + } + + Term CommittedLogTerm() { + MG_ASSERT(state_.log.size() >= state_.committed_log_size); + if (state_.log.empty() || state_.committed_log_size == 0) { + return 0; + } + + const auto &[term, data] = state_.log.at(state_.committed_log_size - 1); + return term; + } + + Term LastLogTerm() const { + if (state_.log.empty()) { + return 0; + } + + const auto &[term, data] = state_.log.back(); + return term; + } + + template <typename... Ts> + void Log(Ts &&...args) { + const Time now = io_.Now(); + const auto micros = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count(); + const Term term = state_.term; + const std::string role_string = std::visit([&](const auto &role) { return role.ToString(); }, role_); + + std::ostringstream out; + + out << '\t' << static_cast<int>(micros) << "\t" << term << "\t" << io_.GetAddress().last_known_port; + + out << role_string; + + (out << ... << args); + + spdlog::info(out.str()); + } + + ///////////////////////////////////////////////////////////// + /// Raft-related Cron methods + /// + /// Cron + std::visit is how events are dispatched + /// to certain code based on Raft role. + /// + /// Cron(role) takes as the first argument a reference to its + /// role, and as the second argument, the message that has + /// been received. + ///////////////////////////////////////////////////////////// + + /// Periodic protocol maintenance. + void Cron() { + // dispatch periodic logic based on our role to a specific Cron method. + std::optional<Role> new_role = std::visit([&](auto &role) { return Cron(role); }, role_); + + if (new_role) { + role_ = std::move(new_role).value(); + } + } + + // Raft paper - 5.2 + // Candidates keep sending Vote to peers until: + // 1. receiving Append with a higher term (become Follower) + // 2. receiving Vote with a higher term (become a Follower) + // 3. receiving a quorum of responses to our last batch of Vote (become a Leader) + std::optional<Role> Cron(Candidate &candidate) { + const auto now = io_.Now(); + const Duration election_timeout = RandomTimeout(kMinimumElectionTimeout, kMaximumElectionTimeout); + const auto election_timeout_us = std::chrono::duration_cast<std::chrono::milliseconds>(election_timeout).count(); + + if (now - candidate.election_began > election_timeout) { + state_.term++; + Log("becoming Candidate for term ", state_.term, " after leader timeout of ", election_timeout_us, + "ms elapsed since last election attempt"); + + const VoteRequest request{ + .term = state_.term, + .log_size = state_.log.size(), + .last_log_term = LastLogTerm(), + }; + + auto outstanding_votes = std::set<Address>(); + + for (const auto &peer : peers_) { + // request_id not necessary to set because it's not a Future-backed Request. + static constexpr auto request_id = 0; + io_.template Send<VoteRequest>(peer, request_id, request); + outstanding_votes.insert(peer); + } + + return Candidate{ + .successful_votes = std::map<Address, LogIndex>(), + .election_began = now, + .outstanding_votes = outstanding_votes, + }; + } + return std::nullopt; + } + + // Raft paper - 5.2 + // Followers become candidates if we haven't heard from the leader + // after a randomized timeout. + std::optional<Role> Cron(Follower &follower) { + const auto now = io_.Now(); + const auto time_since_last_append_entries = now - follower.last_received_append_entries_timestamp; + + // randomized follower timeout + const Duration election_timeout = RandomTimeout(kMinimumElectionTimeout, kMaximumElectionTimeout); + + if (time_since_last_append_entries > election_timeout) { + // become a Candidate if we haven't heard from the Leader after this timeout + return Candidate{}; + } + + return std::nullopt; + } + + // Leaders (re)send AppendRequest to followers. + std::optional<Role> Cron(Leader &leader) { + const Time now = io_.Now(); + const Duration broadcast_timeout = RandomTimeout(kMinimumBroadcastTimeout, kMaximumBroadcastTimeout); + + if (now - leader.last_broadcast > broadcast_timeout) { + BroadcastAppendEntries(leader.followers); + leader.last_broadcast = now; + } + + return std::nullopt; + } + + ///////////////////////////////////////////////////////////// + /// Raft-related Handle methods + /// + /// Handle + std::visit is how events are dispatched + /// to certain code based on Raft role. + /// + /// Handle(role, message, ...) + /// takes as the first argument a reference + /// to its role, and as the second argument, the + /// message that has been received. + ///////////////////////////////////////////////////////////// + + using ReceiveVariant = std::variant<ReadRequest<ReadOperation>, AppendRequest<WriteOperation>, AppendResponse, + WriteRequest<WriteOperation>, VoteRequest, VoteResponse>; + + void Handle(ReceiveVariant &&message_variant, RequestId request_id, Address from_address) { + // dispatch the message to a handler based on our role, + // which can be specified in the Handle first argument, + // or it can be `auto` if it's a handler for several roles + // or messages. + std::optional<Role> new_role = std::visit( + [&](auto &&msg, auto &role) mutable { + return Handle(role, std::forward<decltype(msg)>(msg), request_id, from_address); + }, + std::forward<ReceiveVariant>(message_variant), role_); + + // TODO(tyler) (M3) maybe replace std::visit with get_if for explicit prioritized matching, [[likely]] etc... + if (new_role) { + role_ = std::move(new_role).value(); + } + } + + // all roles can receive Vote and possibly become a follower + template <AllRoles ALL> + std::optional<Role> Handle(ALL & /* variable */, VoteRequest &&req, RequestId request_id, Address from_address) { + Log("received VoteRequest from ", from_address.last_known_port, " with term ", req.term); + const bool last_log_term_dominates = req.last_log_term >= LastLogTerm(); + const bool term_dominates = req.term > state_.term; + const bool log_size_dominates = req.log_size >= state_.log.size(); + const bool new_leader = last_log_term_dominates && term_dominates && log_size_dominates; + + if (new_leader) { + MG_ASSERT(req.term > state_.term); + MG_ASSERT(std::max(req.term, state_.term) == req.term); + } + + const VoteResponse res{ + .term = std::max(req.term, state_.term), + .committed_log_size = state_.committed_log_size, + .vote_granted = new_leader, + }; + + io_.Send(from_address, request_id, res); + + if (new_leader) { + // become a follower + state_.term = req.term; + return Follower{ + .last_received_append_entries_timestamp = io_.Now(), + .leader_address = from_address, + }; + } + + if (term_dominates) { + Log("received a vote from an inferior candidate. Becoming Candidate"); + state_.term = std::max(state_.term, req.term) + 1; + return Candidate{}; + } + + return std::nullopt; + } + + std::optional<Role> Handle(Candidate &candidate, VoteResponse &&res, RequestId /* variable */, Address from_address) { + Log("received VoteResponse"); + + if (!res.vote_granted || res.term != state_.term) { + Log("received unsuccessful VoteResponse from term ", res.term, " when our candidacy term is ", state_.term); + // we received a delayed VoteResponse from the past, which has to do with an election that is + // no longer valid. We can simply drop this. + return std::nullopt; + } + + MG_ASSERT(candidate.outstanding_votes.contains(from_address), + "Received unexpected VoteResponse from server not present in Candidate's outstanding_votes!"); + candidate.outstanding_votes.erase(from_address); + + MG_ASSERT(!candidate.successful_votes.contains(from_address), + "Received unexpected VoteResponse from server already in Candidate's successful_votes!"); + candidate.successful_votes.insert({from_address, res.committed_log_size}); + + if (candidate.successful_votes.size() >= candidate.outstanding_votes.size()) { + std::map<Address, FollowerTracker> followers{}; + + for (const auto &[address, committed_log_size] : candidate.successful_votes) { + FollowerTracker follower{ + .next_index = committed_log_size, + .confirmed_log_size = committed_log_size, + }; + followers.insert({address, follower}); + } + for (const auto &address : candidate.outstanding_votes) { + FollowerTracker follower{ + .next_index = state_.log.size(), + .confirmed_log_size = 0, + }; + followers.insert({address, follower}); + } + + Log("becoming Leader at term ", state_.term); + + BroadcastAppendEntries(followers); + + return Leader{ + .followers = std::move(followers), + .pending_client_requests = std::unordered_map<LogIndex, PendingClientRequest>(), + }; + } + + return std::nullopt; + } + + template <LeaderOrFollower LOF> + std::optional<Role> Handle(LOF & /* variable */, VoteResponse && /* variable */, RequestId /* variable */, + Address /* variable */) { + Log("non-Candidate received VoteResponse"); + return std::nullopt; + } + + template <AllRoles ALL> + std::optional<Role> Handle(ALL &role, AppendRequest<WriteOperation> &&req, RequestId request_id, + Address from_address) { + // log size starts out as state_.committed_log_size and only if everything is successful do we + // switch it to the log length. + AppendResponse res{ + .success = false, + .term = state_.term, + .last_log_term = CommittedLogTerm(), + .log_size = state_.log.size(), + }; + + if constexpr (std::is_same<ALL, Leader>()) { + MG_ASSERT(req.term != state_.term, "Multiple leaders are acting under the term ", req.term); + } + + const bool is_candidate = std::is_same<ALL, Candidate>(); + const bool is_failed_competitor = is_candidate && req.term == state_.term; + const Time now = io_.Now(); + + // Raft paper - 5.2 + // While waiting for votes, a candidate may receive an + // AppendEntries RPC from another server claiming to be leader. If + // the leader’s term (included in its RPC) is at least as large as + // the candidate’s current term, then the candidate recognizes the + // leader as legitimate and returns to follower state. + if (req.term > state_.term || is_failed_competitor) { + // become follower of this leader, reply with our log status + state_.term = req.term; + + io_.Send(from_address, request_id, res); + + Log("becoming Follower of Leader ", from_address.last_known_port, " at term ", req.term); + return Follower{ + .last_received_append_entries_timestamp = now, + .leader_address = from_address, + }; + } + + if (req.term < state_.term) { + // nack this request from an old leader + io_.Send(from_address, request_id, res); + + return std::nullopt; + } + + // at this point, we're dealing with our own leader + if constexpr (std::is_same<ALL, Follower>()) { + // small specialization for when we're already a Follower + MG_ASSERT(role.leader_address == from_address, "Multiple Leaders are acting under the same term number!"); + role.last_received_append_entries_timestamp = now; + } else { + Log("Somehow entered Follower-specific logic as a non-Follower"); + MG_ASSERT(false, "Somehow entered Follower-specific logic as a non-Follower"); + } + + // Handle steady-state conditions. + if (req.batch_start_log_index != state_.log.size()) { + Log("req.batch_start_log_index of ", req.batch_start_log_index, " does not match our log size of ", + state_.log.size()); + } else if (req.last_log_term != LastLogTerm()) { + Log("req.last_log_term differs from our leader term at that slot, expected: ", LastLogTerm(), " but got ", + req.last_log_term); + } else { + // happy path - Apply log + Log("applying batch of ", req.entries.size(), " entries to our log starting at index ", + req.batch_start_log_index); + + const auto resize_length = req.batch_start_log_index; + + MG_ASSERT(resize_length >= state_.committed_log_size, + "Applied history from Leader which goes back in time from our commit_index"); + + // possibly chop-off stuff that was replaced by + // things with different terms (we got data that + // hasn't reached consensus yet, which is normal) + state_.log.resize(resize_length); + + if (req.entries.size() > 0) { + auto &[first_term, op] = req.entries.at(0); + MG_ASSERT(LastLogTerm() <= first_term); + } + + state_.log.insert(state_.log.end(), std::make_move_iterator(req.entries.begin()), + std::make_move_iterator(req.entries.end())); + + MG_ASSERT(req.leader_commit >= state_.committed_log_size); + state_.committed_log_size = std::min(req.leader_commit, state_.log.size()); + + for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) { + const auto &write_request = state_.log[state_.applied_size].second; + replicated_state_.Apply(write_request); + } + + res.success = true; + } + + res.last_log_term = LastLogTerm(); + res.log_size = state_.log.size(); + + Log("returning log_size of ", res.log_size); + + io_.Send(from_address, request_id, res); + + return std::nullopt; + } + + std::optional<Role> Handle(Leader &leader, AppendResponse &&res, RequestId /* variable */, Address from_address) { + if (res.term != state_.term) { + Log("received AppendResponse related to a previous term when we (presumably) were the leader"); + return std::nullopt; + } + + // TODO(tyler) when we have dynamic membership, this assert will become incorrect, but we should + // keep it in-place until then because it has bug finding value. + MG_ASSERT(leader.followers.contains(from_address), "received AppendResponse from unknown Follower"); + + // at this point, we know the term matches and we know this Follower + + FollowerTracker &follower = leader.followers.at(from_address); + + if (res.success) { + Log("got successful AppendResponse from ", from_address.last_known_port, " with log_size of ", res.log_size); + follower.next_index = std::max(follower.next_index, res.log_size); + } else { + Log("got unsuccessful AppendResponse from ", from_address.last_known_port, " with log_size of ", res.log_size); + follower.next_index = res.log_size; + } + + follower.confirmed_log_size = std::max(follower.confirmed_log_size, res.log_size); + + BumpCommitIndexAndReplyToClients(leader); + + return std::nullopt; + } + + template <FollowerOrCandidate FOC> + std::optional<Role> Handle(FOC & /* variable */, AppendResponse && /* variable */, RequestId /* variable */, + Address /* variable */) { + // we used to be the leader, and are getting old delayed responses + return std::nullopt; + } + + ///////////////////////////////////////////////////////////// + /// RSM-related handle methods + ///////////////////////////////////////////////////////////// + + // Leaders are able to immediately respond to the requester (with a ReadResponseValue) applied to the ReplicatedState + std::optional<Role> Handle(Leader & /* variable */, ReadRequest<ReadOperation> &&req, RequestId request_id, + Address from_address) { + Log("handling ReadOperation"); + ReadOperation read_operation = req.operation; + + ReadResponseValue read_return = replicated_state_.Read(read_operation); + + const ReadResponse<ReadResponseValue> resp{ + .success = true, + .read_return = std::move(read_return), + .retry_leader = std::nullopt, + }; + + io_.Send(from_address, request_id, resp); + + return std::nullopt; + } + + // Candidates should respond with a failure, similar to the Candidate + WriteRequest failure below + std::optional<Role> Handle(Candidate & /* variable */, ReadRequest<ReadOperation> && /* variable */, + RequestId request_id, Address from_address) { + Log("received ReadOperation - not redirecting because no Leader is known"); + const ReadResponse<ReadResponseValue> res{ + .success = false, + }; + + io_.Send(from_address, request_id, res); + + Cron(); + + return std::nullopt; + } + + // Followers should respond with a redirection, similar to the Follower + WriteRequest response below + std::optional<Role> Handle(Follower &follower, ReadRequest<ReadOperation> && /* variable */, RequestId request_id, + Address from_address) { + Log("redirecting client to known Leader with port ", follower.leader_address.last_known_port); + + const ReadResponse<ReadResponseValue> res{ + .success = false, + .retry_leader = follower.leader_address, + }; + + io_.Send(from_address, request_id, res); + + return std::nullopt; + } + + // Raft paper - 8 + // When a client first starts up, it connects to a randomly chosen + // server. If the client’s first choice is not the leader, that + // server will reject the client’s request and supply information + // about the most recent leader it has heard from. + std::optional<Role> Handle(Follower &follower, WriteRequest<WriteOperation> && /* variable */, RequestId request_id, + Address from_address) { + Log("redirecting client to known Leader with port ", follower.leader_address.last_known_port); + + const WriteResponse<WriteResponseValue> res{ + .success = false, + .retry_leader = follower.leader_address, + }; + + io_.Send(from_address, request_id, res); + + return std::nullopt; + } + + std::optional<Role> Handle(Candidate & /* variable */, WriteRequest<WriteOperation> && /* variable */, + RequestId request_id, Address from_address) { + Log("received WriteRequest - not redirecting because no Leader is known"); + + const WriteResponse<WriteResponseValue> res{ + .success = false, + }; + + io_.Send(from_address, request_id, res); + + Cron(); + + return std::nullopt; + } + + // only leaders actually handle replication requests from clients + std::optional<Role> Handle(Leader &leader, WriteRequest<WriteOperation> &&req, RequestId request_id, + Address from_address) { + Log("handling WriteRequest"); + + // we are the leader. add item to log and send Append to peers + MG_ASSERT(state_.term >= LastLogTerm()); + state_.log.emplace_back(std::pair(state_.term, std::move(req.operation))); + + LogIndex log_index = state_.log.size() - 1; + + PendingClientRequest pcr{ + .request_id = request_id, + .address = from_address, + .received_at = io_.Now(), + }; + + leader.pending_client_requests.emplace(log_index, pcr); + + BroadcastAppendEntries(leader.followers); + + return std::nullopt; + } +}; + +}; // namespace memgraph::io::rsm diff --git a/src/io/simulator/simulator_handle.cpp b/src/io/simulator/simulator_handle.cpp index 05585f551..16b2b71a1 100644 --- a/src/io/simulator/simulator_handle.cpp +++ b/src/io/simulator/simulator_handle.cpp @@ -38,7 +38,7 @@ void SimulatorHandle::IncrementServerCountAndWaitForQuiescentState(Address addre server_addresses_.insert(address); while (true) { - const size_t blocked_servers = BlockedServers(); + const size_t blocked_servers = blocked_on_receive_; const bool all_servers_blocked = blocked_servers == server_addresses_.size(); @@ -50,22 +50,10 @@ void SimulatorHandle::IncrementServerCountAndWaitForQuiescentState(Address addre } } -size_t SimulatorHandle::BlockedServers() { - size_t blocked_servers = blocked_on_receive_; - - for (auto &[promise_key, opaque_promise] : promises_) { - if (opaque_promise.promise.IsAwaited() && server_addresses_.contains(promise_key.requester_address)) { - blocked_servers++; - } - } - - return blocked_servers; -} - bool SimulatorHandle::MaybeTickSimulator() { std::unique_lock<std::mutex> lock(mu_); - const size_t blocked_servers = BlockedServers(); + const size_t blocked_servers = blocked_on_receive_; if (blocked_servers < server_addresses_.size()) { // we only need to advance the simulator when all diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp index 08a3837ee..8a9e77c58 100644 --- a/src/io/simulator/simulator_handle.hpp +++ b/src/io/simulator/simulator_handle.hpp @@ -53,14 +53,6 @@ class SimulatorHandle { std::mt19937 rng_; SimulatorConfig config_; - /// Returns the number of servers currently blocked on Receive, plus - /// the servers that are blocked on Futures that were created through - /// SimulatorTransport::Request. - /// - /// TODO(tyler) investigate whether avoiding consideration of Futures - /// increases determinism. - size_t BlockedServers(); - void TimeoutPromisesPastDeadline() { const Time now = cluster_wide_time_microseconds_; diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index 142657401..e868ca0ef 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -27,4 +27,6 @@ endfunction(add_simulation_test) add_simulation_test(basic_request.cpp address) +add_simulation_test(raft.cpp address) + add_simulation_test(trial_query_storage/query_storage_test.cpp address) diff --git a/tests/simulation/raft.cpp b/tests/simulation/raft.cpp new file mode 100644 index 000000000..7033d04e6 --- /dev/null +++ b/tests/simulation/raft.cpp @@ -0,0 +1,315 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <chrono> +#include <deque> +#include <iostream> +#include <map> +#include <optional> +#include <set> +#include <thread> +#include <vector> + +#include "io/address.hpp" +#include "io/rsm/raft.hpp" +#include "io/simulator/simulator.hpp" +#include "io/simulator/simulator_transport.hpp" + +using memgraph::io::Address; +using memgraph::io::Duration; +using memgraph::io::Io; +using memgraph::io::ResponseEnvelope; +using memgraph::io::ResponseFuture; +using memgraph::io::ResponseResult; +using memgraph::io::Time; +using memgraph::io::rsm::Raft; +using memgraph::io::rsm::ReadRequest; +using memgraph::io::rsm::ReadResponse; +using memgraph::io::rsm::WriteRequest; +using memgraph::io::rsm::WriteResponse; +using memgraph::io::simulator::Simulator; +using memgraph::io::simulator::SimulatorConfig; +using memgraph::io::simulator::SimulatorStats; +using memgraph::io::simulator::SimulatorTransport; + +struct CasRequest { + int key; + std::optional<int> old_value; + std::optional<int> new_value; +}; + +struct CasResponse { + bool cas_success; + std::optional<int> last_value; +}; + +struct GetRequest { + int key; +}; + +struct GetResponse { + std::optional<int> value; +}; + +class TestState { + std::map<int, int> state_; + + public: + GetResponse Read(GetRequest request) { + GetResponse ret; + if (state_.contains(request.key)) { + ret.value = state_[request.key]; + } + return ret; + } + + CasResponse Apply(CasRequest request) { + CasResponse ret; + + // Key exist + if (state_.contains(request.key)) { + auto &val = state_[request.key]; + + /* + * Delete + */ + if (!request.new_value) { + ret.last_value = val; + ret.cas_success = true; + + state_.erase(state_.find(request.key)); + } + + /* + * Update + */ + // Does old_value match? + if (request.old_value == val) { + ret.last_value = val; + ret.cas_success = true; + + val = request.new_value.value(); + } else { + ret.last_value = val; + ret.cas_success = false; + } + } + /* + * Create + */ + else { + ret.last_value = std::nullopt; + ret.cas_success = true; + + state_.emplace(request.key, std::move(request.new_value).value()); + } + + return ret; + } +}; + +template <typename IoImpl> +void RunRaft(Raft<IoImpl, TestState, CasRequest, CasResponse, GetRequest, GetResponse> server) { + server.Run(); +} + +void RunSimulation() { + SimulatorConfig config{ + .drop_percent = 5, + .perform_timeouts = true, + .scramble_messages = true, + .rng_seed = 0, + .start_time = Time::min() + std::chrono::microseconds{256 * 1024}, + .abort_time = Time::min() + std::chrono::microseconds{8 * 1024 * 128}, + }; + + auto simulator = Simulator(config); + + auto cli_addr = Address::TestAddress(1); + auto srv_addr_1 = Address::TestAddress(2); + auto srv_addr_2 = Address::TestAddress(3); + auto srv_addr_3 = Address::TestAddress(4); + + Io<SimulatorTransport> cli_io = simulator.Register(cli_addr); + Io<SimulatorTransport> srv_io_1 = simulator.Register(srv_addr_1); + Io<SimulatorTransport> srv_io_2 = simulator.Register(srv_addr_2); + Io<SimulatorTransport> srv_io_3 = simulator.Register(srv_addr_3); + + std::vector<Address> srv_1_peers = {srv_addr_2, srv_addr_3}; + std::vector<Address> srv_2_peers = {srv_addr_1, srv_addr_3}; + std::vector<Address> srv_3_peers = {srv_addr_1, srv_addr_2}; + + // TODO(tyler / gabor) supply default TestState to Raft constructor + using RaftClass = Raft<SimulatorTransport, TestState, CasRequest, CasResponse, GetRequest, GetResponse>; + RaftClass srv_1{std::move(srv_io_1), srv_1_peers, TestState{}}; + RaftClass srv_2{std::move(srv_io_2), srv_2_peers, TestState{}}; + RaftClass srv_3{std::move(srv_io_3), srv_3_peers, TestState{}}; + + auto srv_thread_1 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_1)); + simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_1); + + auto srv_thread_2 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_2)); + simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_2); + + auto srv_thread_3 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_3)); + simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_3); + + spdlog::info("beginning test after servers have become quiescent"); + + std::mt19937 cli_rng_{0}; + Address server_addrs[]{srv_addr_1, srv_addr_2, srv_addr_3}; + Address leader = server_addrs[0]; + + const int key = 0; + std::optional<int> last_known_value = 0; + + bool success = false; + + for (int i = 0; !success; i++) { + // send request + CasRequest cas_req; + cas_req.key = key; + + cas_req.old_value = last_known_value; + + cas_req.new_value = i; + + WriteRequest<CasRequest> cli_req; + cli_req.operation = cas_req; + + spdlog::info("client sending CasRequest to Leader {} ", leader.last_known_port); + ResponseFuture<WriteResponse<CasResponse>> cas_response_future = + cli_io.Request<WriteRequest<CasRequest>, WriteResponse<CasResponse>>(leader, cli_req); + + // receive cas_response + ResponseResult<WriteResponse<CasResponse>> cas_response_result = std::move(cas_response_future).Wait(); + + if (cas_response_result.HasError()) { + spdlog::info("client timed out while trying to communicate with assumed Leader server {}", + leader.last_known_port); + continue; + } + + ResponseEnvelope<WriteResponse<CasResponse>> cas_response_envelope = cas_response_result.GetValue(); + WriteResponse<CasResponse> write_cas_response = cas_response_envelope.message; + + if (write_cas_response.retry_leader) { + MG_ASSERT(!write_cas_response.success, "retry_leader should never be set for successful responses"); + leader = write_cas_response.retry_leader.value(); + spdlog::info("client redirected to leader server {}", leader.last_known_port); + } else if (!write_cas_response.success) { + std::uniform_int_distribution<size_t> addr_distrib(0, 2); + size_t addr_index = addr_distrib(cli_rng_); + leader = server_addrs[addr_index]; + + spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, + leader.last_known_port); + continue; + } + + CasResponse cas_response = write_cas_response.write_return; + + bool cas_succeeded = cas_response.cas_success; + + spdlog::info("Client received CasResponse! success: {} last_known_value {}", cas_succeeded, (int)*last_known_value); + + if (cas_succeeded) { + last_known_value = i; + } else { + last_known_value = cas_response.last_value; + continue; + } + + GetRequest get_req; + get_req.key = key; + + ReadRequest<GetRequest> read_req; + read_req.operation = get_req; + + spdlog::info("client sending GetRequest to Leader {}", leader.last_known_port); + + ResponseFuture<ReadResponse<GetResponse>> get_response_future = + cli_io.Request<ReadRequest<GetRequest>, ReadResponse<GetResponse>>(leader, read_req); + + // receive response + ResponseResult<ReadResponse<GetResponse>> get_response_result = std::move(get_response_future).Wait(); + + if (get_response_result.HasError()) { + spdlog::info("client timed out while trying to communicate with Leader server {}", leader.last_known_port); + continue; + } + + ResponseEnvelope<ReadResponse<GetResponse>> get_response_envelope = get_response_result.GetValue(); + ReadResponse<GetResponse> read_get_response = get_response_envelope.message; + + if (!read_get_response.success) { + // sent to a non-leader + continue; + } + + if (read_get_response.retry_leader) { + MG_ASSERT(!read_get_response.success, "retry_leader should never be set for successful responses"); + leader = read_get_response.retry_leader.value(); + spdlog::info("client redirected to Leader server {}", leader.last_known_port); + } else if (!read_get_response.success) { + std::uniform_int_distribution<size_t> addr_distrib(0, 2); + size_t addr_index = addr_distrib(cli_rng_); + leader = server_addrs[addr_index]; + + spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index, + leader.last_known_port); + } + + GetResponse get_response = read_get_response.read_return; + + MG_ASSERT(get_response.value == i); + + spdlog::info("client successfully cas'd a value and read it back! value: {}", i); + + success = true; + } + + MG_ASSERT(success); + + simulator.ShutDown(); + + SimulatorStats stats = simulator.Stats(); + + spdlog::info("total messages: {}", stats.total_messages); + spdlog::info("dropped messages: {}", stats.dropped_messages); + spdlog::info("timed out requests: {}", stats.timed_out_requests); + spdlog::info("total requests: {}", stats.total_requests); + spdlog::info("total responses: {}", stats.total_responses); + spdlog::info("simulator ticks: {}", stats.simulator_ticks); + + spdlog::info("========================== SUCCESS :) =========================="); + + /* + this is implicit in jthread's dtor + srv_thread_1.join(); + srv_thread_2.join(); + srv_thread_3.join(); + */ +} + +int main() { + int n_tests = 50; + + for (int i = 0; i < n_tests; i++) { + spdlog::info("========================== NEW SIMULATION {} ==========================", i); + spdlog::info("\tTime\t\tTerm\tPort\tRole\t\tMessage\n"); + RunSimulation(); + } + + spdlog::info("passed {} tests!", n_tests); + + return 0; +} From c6447eb48b2f69f0ff469d427fcaba159f62a2ee Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis <kostaskyrim@gmail.com> Date: Thu, 1 Sep 2022 18:54:47 +0300 Subject: [PATCH 4/4] Add shard requests responses (#526) --- src/query/v2/requests.hpp | 223 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 src/query/v2/requests.hpp diff --git a/src/query/v2/requests.hpp b/src/query/v2/requests.hpp new file mode 100644 index 000000000..421a7fd54 --- /dev/null +++ b/src/query/v2/requests.hpp @@ -0,0 +1,223 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <chrono> +#include <iostream> +#include <map> +#include <optional> +#include <unordered_map> +#include <variant> +#include <vector> + +#include "storage/v3/id_types.hpp" +#include "storage/v3/property_value.hpp" + +/// Hybrid-logical clock +struct Hlc { + uint64_t logical_id; + using Duration = std::chrono::microseconds; + using Time = std::chrono::time_point<std::chrono::system_clock, Duration>; + Time coordinator_wall_clock; + + bool operator==(const Hlc &other) const = default; +}; + +struct Label { + size_t id; +}; + +// TODO(kostasrim) update this with CompoundKey, same for the rest of the file. +using PrimaryKey = std::vector<memgraph::storage::v3::PropertyValue>; +using VertexId = std::pair<Label, PrimaryKey>; +using Gid = size_t; +using PropertyId = memgraph::storage::v3::PropertyId; + +struct EdgeType { + std::string name; +}; + +struct EdgeId { + VertexId id; + Gid gid; +}; + +struct Vertex { + VertexId id; + std::vector<Label> labels; +}; + +struct Edge { + VertexId src; + VertexId dst; + EdgeType type; +}; + +struct PathPart { + Vertex dst; + Gid edge; +}; + +struct Path { + Vertex src; + std::vector<PathPart> parts; +}; + +struct Null {}; + +struct Value { + enum Type { NILL, BOOL, INT64, DOUBLE, STRING, LIST, MAP, VERTEX, EDGE, PATH }; + union { + Null null_v; + bool bool_v; + uint64_t int_v; + double double_v; + std::string string_v; + std::vector<Value> list_v; + std::map<std::string, Value> map_v; + Vertex vertex_v; + Edge edge_v; + Path path_v; + }; + + Type type; +}; + +struct ValuesMap { + std::unordered_map<PropertyId, Value> values_map; +}; + +struct MappedValues { + std::vector<ValuesMap> values_map; +}; + +struct ListedValues { + std::vector<std::vector<Value>> properties; +}; + +using Values = std::variant<ListedValues, MappedValues>; + +struct Expression { + std::string expression; +}; + +struct Filter { + std::string filter_expression; +}; + +enum class OrderingDirection { ASCENDING = 1, DESCENDING = 2 }; + +struct OrderBy { + Expression expression; + OrderingDirection direction; +}; + +enum class StorageView { OLD = 0, NEW = 1 }; + +struct ScanVerticesRequest { + Hlc transaction_id; + size_t start_id; + std::optional<std::vector<std::string>> props_to_return; + std::optional<std::vector<std::string>> filter_expressions; + std::optional<size_t> batch_limit; + StorageView storage_view; +}; + +struct ScanVerticesResponse { + bool success; + Values values; + std::optional<VertexId> next_start_id; +}; + +using VertexOrEdgeIds = std::variant<VertexId, EdgeId>; + +struct GetPropertiesRequest { + Hlc transaction_id; + VertexOrEdgeIds vertex_or_edge_ids; + std::vector<PropertyId> property_ids; + std::vector<Expression> expressions; + bool only_unique = false; + std::optional<std::vector<OrderBy>> order_by; + std::optional<size_t> limit; + std::optional<Filter> filter; +}; + +struct GetPropertiesResponse { + bool success; + Values values; +}; + +enum class EdgeDirection : uint8_t { OUT = 1, IN = 2, BOTH = 3 }; + +struct ExpandOneRequest { + Hlc transaction_id; + std::vector<VertexId> src_vertices; + std::vector<EdgeType> edge_types; + EdgeDirection direction; + bool only_unique_neighbor_rows = false; + // The empty optional means return all of the properties, while an empty + // list means do not return any properties + // TODO(antaljanosbenjamin): All of the special values should be communicated through a single vertex object + // after schema is implemented + // Special values are accepted: + // * __mg__labels + std::optional<std::vector<PropertyId>> src_vertex_properties; + // TODO(antaljanosbenjamin): All of the special values should be communicated through a single vertex object + // after schema is implemented + // Special values are accepted: + // * __mg__dst_id (Vertex, but without labels) + // * __mg__type (binary) + std::optional<std::vector<PropertyId>> edge_properties; + // QUESTION(antaljanosbenjamin): Maybe also add possibility to expressions evaluated on the source vertex? + // List of expressions evaluated on edges + std::vector<Expression> expressions; + std::optional<std::vector<OrderBy>> order_by; + std::optional<size_t> limit; + std::optional<Filter> filter; +}; + +struct ExpandOneResultRow { + // NOTE: This struct could be a single Values with columns something like this: + // src_vertex(Vertex), vertex_prop1(Value), vertex_prop2(Value), edges(list<Value>) + // where edges might be a list of: + // 1. list<Value> if only a defined list of edge properties are returned + // 2. map<binary, Value> if all of the edge properties are returned + // The drawback of this is currently the key of the map is always interpreted as a string in Value, not as an + // integer, which should be in case of mapped properties. + Vertex src_vertex; + std::optional<Values> src_vertex_properties; + Values edges; +}; + +struct ExpandOneResponse { + std::vector<ExpandOneResultRow> result; +}; + +struct NewVertex { + std::vector<Label> label_ids; + std::map<PropertyId, Value> properties; +}; + +struct CreateVerticesRequest { + Hlc transaction_id; + std::vector<NewVertex> new_vertices; +}; + +struct CreateVerticesResponse { + bool success; +}; + +using ReadRequests = std::variant<ExpandOneRequest, GetPropertiesRequest, ScanVerticesRequest>; +using ReadResponses = std::variant<ExpandOneResponse, GetPropertiesResponse, ScanVerticesResponse>; + +using WriteRequests = CreateVerticesRequest; +using WriteResponses = CreateVerticesResponse;