Merge remote-tracking branch 'origin/project-pineapples' into E118-MG-lexicographically-ordered-storage

This commit is contained in:
János Benjamin Antal 2022-09-02 11:41:03 +02:00
commit 947baedbe6
31 changed files with 3334 additions and 1 deletions

View File

@ -171,7 +171,7 @@ jobs:
# Run leftover CTest tests (all except unit and benchmark tests).
cd build
ctest -E "(memgraph__unit|memgraph__benchmark)" --output-on-failure
ctest -E "(memgraph__unit|memgraph__benchmark|memgraph__simulation)" --output-on-failure
- name: Run drivers tests
run: |
@ -262,6 +262,15 @@ jobs:
cd build
ctest -R memgraph__unit --output-on-failure -j$THREADS
- name: Run simulation tests
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
# Run simulation tests.
cd build
ctest -R memgraph__simulation --output-on-failure -j$THREADS
- name: Run e2e tests
run: |
# TODO(gitbuda): Setup mgclient and pymgclient properly.

View File

@ -5,6 +5,7 @@ add_subdirectory(lisp)
add_subdirectory(utils)
add_subdirectory(requests)
add_subdirectory(io)
add_subdirectory(io/simulator)
add_subdirectory(kvstore)
add_subdirectory(telemetry)
add_subdirectory(communication)

68
src/io/address.hpp Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <compare>
#include <fmt/format.h>
#include <boost/asio/ip/tcp.hpp>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace memgraph::io {
struct Address {
// It's important for all participants to have a
// unique identifier - IP and port alone are not
// enough, and may change over the lifecycle of
// the nodes. Particularly storage nodes may change
// their IP addresses over time, and the system
// should gracefully update its information
// about them.
boost::uuids::uuid unique_id;
boost::asio::ip::address last_known_ip;
uint16_t last_known_port;
static Address TestAddress(uint16_t port) {
return Address{
.unique_id = boost::uuids::uuid{boost::uuids::random_generator()()},
.last_known_port = port,
};
}
static Address UniqueLocalAddress() {
return Address{
.unique_id = boost::uuids::uuid{boost::uuids::random_generator()()},
};
}
friend bool operator==(const Address &lhs, const Address &rhs) = default;
/// unique_id is most dominant for ordering, then last_known_ip, then last_known_port
friend bool operator<(const Address &lhs, const Address &rhs) {
if (lhs.unique_id != rhs.unique_id) {
return lhs.unique_id < rhs.unique_id;
}
if (lhs.last_known_ip != rhs.last_known_ip) {
return lhs.last_known_ip < rhs.last_known_ip;
}
return lhs.last_known_port < rhs.last_known_port;
}
std::string ToString() const {
return fmt::format("Address {{ unique_id: {}, last_known_ip: {}, last_known_port: {} }}",
boost::uuids::to_string(unique_id), last_known_ip.to_string(), last_known_port);
}
};
}; // namespace memgraph::io

26
src/io/errors.hpp Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
namespace memgraph::io {
// Signifies that a retriable operation was unable to
// complete after a configured number of retries.
// This is an empty tag type: it carries no data of its own.
struct RetriesExhausted {};
// Signifies that a request was unable to receive a response
// within some configured timeout duration. It is important
// to remember that in distributed systems, a timeout does
// not signify that a request was not received or processed.
// It may be the case that the request was fully processed
// but that the response was not received.
// This is an empty tag type: it carries no data of its own.
struct TimedOut {};
}; // namespace memgraph::io

262
src/io/future.hpp Normal file
View File

@ -0,0 +1,262 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <utility>
#include "io/errors.hpp"
#include "utils/logging.hpp"
namespace memgraph::io {
// Shared lives in the `details` namespace because it is an
// implementation detail. A Promise or a Future can only be
// constructed by passing a Shared in, so user code should
// obtain them through the FuturePromisePair helpers instead
// of constructing a Shared directly.
namespace details {
/// Single-slot state shared between one Future and one Promise.
///
/// The slot (item_) and the consumed_/waiting_ flags are guarded by mu_.
/// Every public member locks mu_ itself, except Take(), which is only
/// called from Wait() and TryGet() while they already hold the lock.
template <typename T>
class Shared {
  mutable std::condition_variable cv_;
  mutable std::mutex mu_;
  std::optional<T> item_;
  bool consumed_ = false;
  bool waiting_ = false;
  // Optional hook that Wait() invokes while the item is still missing.
  // A `true` return value makes Wait() re-check the slot instead of
  // blocking on the condition variable.
  std::function<bool()> simulator_notifier_ = nullptr;

 public:
  explicit Shared(std::function<bool()> simulator_notifier) : simulator_notifier_(std::move(simulator_notifier)) {}
  Shared() = default;
  Shared(Shared &&) = delete;
  Shared &operator=(Shared &&) = delete;
  Shared(const Shared &) = delete;
  Shared &operator=(const Shared &) = delete;
  ~Shared() = default;

  /// Takes the item out of our optional item_ and returns it.
  /// Does not lock mu_; the caller is expected to hold it.
  T Take() {
    MG_ASSERT(item_, "Take called without item_ being present");
    MG_ASSERT(!consumed_, "Take called on already-consumed Future");
    T ret = std::move(item_).value();
    item_.reset();
    consumed_ = true;
    return ret;
  }

  /// Blocks until the slot has been filled and then returns the item.
  T Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    waiting_ = true;

    while (!item_) {
      bool simulator_progressed = false;
      if (simulator_notifier_) [[unlikely]] {
        // We can't hold our own lock while notifying
        // the simulator because notifying the simulator
        // involves acquiring the simulator's mutex
        // to guarantee that our notification linearizes
        // with the simulator's condition variable.
        // However, the simulator may acquire our
        // mutex to check if we are being awaited,
        // while determining system quiescence,
        // so we have to get out of its way to avoid
        // a cyclical deadlock.
        lock.unlock();
        simulator_progressed = std::invoke(simulator_notifier_);
        lock.lock();
        if (item_) {
          // item may have been filled while we
          // had dropped our mutex while notifying
          // the simulator of our waiting_ status.
          break;
        }
      }
      if (!simulator_progressed) [[likely]] {
        cv_.wait(lock);
      }
      MG_ASSERT(!consumed_, "Future consumed twice!");
    }

    waiting_ = false;

    return Take();
  }

  /// Returns true if the slot currently holds an item.
  bool IsReady() const {
    std::unique_lock<std::mutex> lock(mu_);
    // std::optional only converts to bool explicitly, which does not
    // apply in a return statement, so ask for the state directly.
    return item_.has_value();
  }

  /// Non-blocking: returns the item if it is present, std::nullopt otherwise.
  std::optional<T> TryGet() {
    std::unique_lock<std::mutex> lock(mu_);

    if (item_) {
      return Take();
    }

    return std::nullopt;
  }

  /// Stores the item in the slot and wakes up every waiter.
  /// Asserts that the slot was neither filled nor consumed before.
  void Fill(T item) {
    {
      std::unique_lock<std::mutex> lock(mu_);

      MG_ASSERT(!consumed_, "Promise filled after it was already consumed!");
      MG_ASSERT(!item_, "Promise filled twice!");

      // item is our own by-value copy, so move it into the slot
      // rather than copying it a second time.
      item_ = std::move(item);
    }  // lock released before condition variable notification

    cv_.notify_all();
  }

  /// Returns true while a call to Wait() is in progress.
  bool IsAwaited() const {
    std::unique_lock<std::mutex> lock(mu_);
    return waiting_;
  }
};
} // namespace details
/// Consumer side of a Future/Promise pair. Move-only; the item can be
/// obtained at most once, either through TryGet or through Wait.
template <typename T>
class Future {
  bool consumed_or_moved_ = false;
  std::shared_ptr<details::Shared<T>> shared_;

 public:
  explicit Future(std::shared_ptr<details::Shared<T>> shared) : shared_(std::move(shared)) {}

  Future() = delete;
  Future(Future &&old) noexcept {
    MG_ASSERT(!old.consumed_or_moved_, "Future moved from after already being moved from or consumed.");
    shared_ = std::move(old.shared_);
    consumed_or_moved_ = old.consumed_or_moved_;
    old.consumed_or_moved_ = true;
  }
  Future &operator=(Future &&old) noexcept {
    MG_ASSERT(!old.consumed_or_moved_, "Future moved from after already being moved from or consumed.");
    shared_ = std::move(old.shared_);
    // We now own a live shared state, whatever state this object was in
    // before. The flag is cleared before `old` is marked, so that a
    // self-assignment still ends up marked as moved-from.
    consumed_or_moved_ = false;
    old.consumed_or_moved_ = true;
    return *this;
  }
  Future(const Future &) = delete;
  Future &operator=(const Future &) = delete;
  ~Future() = default;

  /// Returns true if the Future is ready to
  /// be consumed using TryGet or Wait (prefer Wait
  /// if you know it's ready, because it doesn't
  /// return an optional).
  bool IsReady() {
    MG_ASSERT(!consumed_or_moved_, "Called IsReady after Future already consumed!");
    return shared_->IsReady();
  }

  /// Non-blocking method that returns the inner
  /// item if it's already ready, or std::nullopt
  /// if it is not ready yet.
  std::optional<T> TryGet() {
    MG_ASSERT(!consumed_or_moved_, "Called TryGet after Future already consumed!");
    std::optional<T> ret = shared_->TryGet();
    if (ret) {
      // Only a successful TryGet consumes the Future.
      consumed_or_moved_ = true;
    }
    return ret;
  }

  /// Block on the corresponding promise to be filled,
  /// returning the inner item when ready.
  T Wait() && {
    MG_ASSERT(!consumed_or_moved_, "Future should only be consumed with Wait once!");
    T ret = shared_->Wait();
    consumed_or_moved_ = true;
    return ret;
  }

  /// Marks this Future as canceled.
  void Cancel() {
    MG_ASSERT(!consumed_or_moved_, "Future::Cancel called on a future that was already moved or consumed!");
    consumed_or_moved_ = true;
  }
};
/// Producer side of a Future/Promise pair. Move-only; it must be filled
/// (or moved away) before it is destroyed, otherwise the destructor asserts.
template <typename T>
class Promise {
  std::shared_ptr<details::Shared<T>> shared_;
  bool filled_or_moved_ = false;

 public:
  explicit Promise(std::shared_ptr<details::Shared<T>> shared) : shared_(std::move(shared)) {}

  Promise() = delete;
  Promise(Promise &&old) noexcept {
    MG_ASSERT(!old.filled_or_moved_, "Promise moved from after already being moved from or filled.");
    shared_ = std::move(old.shared_);
    old.filled_or_moved_ = true;
  }
  Promise &operator=(Promise &&old) noexcept {
    MG_ASSERT(!old.filled_or_moved_, "Promise moved from after already being moved from or filled.");
    shared_ = std::move(old.shared_);
    // We now own an unfilled promise, whatever state this object was in
    // before. The flag is cleared before `old` is marked, so that a
    // self-assignment still ends up marked as moved-from.
    filled_or_moved_ = false;
    old.filled_or_moved_ = true;
    return *this;
  }
  Promise(const Promise &) = delete;
  Promise &operator=(const Promise &) = delete;

  ~Promise() { MG_ASSERT(filled_or_moved_, "Promise destroyed before its associated Future was filled!"); }

  // Fill the expected item into the Future.
  void Fill(T item) {
    MG_ASSERT(!filled_or_moved_, "Promise::Fill called on a promise that is already filled or moved!");
    // item is our own by-value copy; hand it on without copying it again.
    shared_->Fill(std::move(item));
    filled_or_moved_ = true;
  }

  /// Returns true while the associated shared state reports an ongoing Wait().
  bool IsAwaited() { return shared_->IsAwaited(); }

  /// Moves this Promise into a unique_ptr.
  std::unique_ptr<Promise<T>> ToUnique() && {
    std::unique_ptr<Promise<T>> up = std::make_unique<Promise<T>>(std::move(shared_));

    filled_or_moved_ = true;

    return up;
  }
};
/// Creates a connected Future/Promise pair backed by one freshly allocated
/// shared state that has no simulator notifier attached.
template <typename T>
std::pair<Future<T>, Promise<T>> FuturePromisePair() {
  auto state = std::make_shared<details::Shared<T>>();

  Future<T> consumer_side{state};
  Promise<T> producer_side{state};

  return std::pair<Future<T>, Promise<T>>(std::move(consumer_side), std::move(producer_side));
}
/// Creates a connected Future/Promise pair whose shared state is constructed
/// with the given simulator notifier.
template <typename T>
std::pair<Future<T>, Promise<T>> FuturePromisePairWithNotifier(std::function<bool()> simulator_notifier) {
  auto state = std::make_shared<details::Shared<T>>(simulator_notifier);

  Future<T> consumer_side{state};
  Promise<T> producer_side{state};

  return std::pair<Future<T>, Promise<T>>(std::move(consumer_side), std::move(producer_side));
}
}; // namespace memgraph::io

View File

@ -0,0 +1,35 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <random>
#include "io/address.hpp"
#include "io/local_transport/local_transport.hpp"
#include "io/local_transport/local_transport_handle.hpp"
#include "io/transport.hpp"
namespace memgraph::io::local_transport {
/// Owns the single LocalTransportHandle that every transport registered
/// through this system shares.
class LocalSystem {
  std::shared_ptr<LocalTransportHandle> local_transport_handle_ = std::make_shared<LocalTransportHandle>();

 public:
  /// Builds an Io for `address` whose transport uses this system's handle.
  Io<LocalTransport> Register(Address address) {
    LocalTransport transport_for_address(local_transport_handle_, address);

    return Io{transport_for_address, address};
  }

  /// Forwards the shutdown request to the shared handle.
  void ShutDown() {
    local_transport_handle_->ShutDown();
  }
};
} // namespace memgraph::io::local_transport

View File

@ -0,0 +1,67 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <memory>
#include <random>
#include <utility>
#include "io/address.hpp"
#include "io/local_transport/local_transport_handle.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
namespace memgraph::io::local_transport {
/// Transport bound to one Address that forwards all of its work to a
/// LocalTransportHandle shared with other LocalTransport instances.
class LocalTransport {
  std::shared_ptr<LocalTransportHandle> local_transport_handle_;
  const Address address_;

 public:
  LocalTransport(std::shared_ptr<LocalTransportHandle> local_transport_handle, Address address)
      : local_transport_handle_(std::move(local_transport_handle)), address_(address) {}

  /// Submits `request` to `to_address` on behalf of this transport's own
  /// address and returns the future that will receive the response.
  template <Message Request, Message Response>
  ResponseFuture<Response> Request(Address to_address, RequestId request_id, Request request, Duration timeout) {
    auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<Response>>();

    Address from_address = address_;

    local_transport_handle_->SubmitRequest(to_address, from_address, request_id, std::move(request), timeout,
                                           std::move(promise));

    // `future` is a structured binding, so it has to be moved explicitly.
    return std::move(future);
  }

  /// Receives one message through the shared handle, waiting at most `timeout`.
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) {
    return local_transport_handle_->template Receive<Ms...>(timeout);
  }

  /// Hands `message` over to the shared handle for delivery.
  template <Message M>
  void Send(Address to_address, Address from_address, RequestId request_id, M &&message) {
    return local_transport_handle_->template Send<M>(to_address, from_address, request_id, std::forward<M>(message));
  }

  Time Now() const { return local_transport_handle_->Now(); }

  bool ShouldShutDown() const { return local_transport_handle_->ShouldShutDown(); }

  /// Draws one value from `distrib`, using a std::random_device as the
  /// source of randomness.
  template <class D = std::poisson_distribution<>, class Return = uint64_t>
  Return Rand(D distrib) {
    std::random_device rng;
    return distrib(rng);
  }
};
}; // namespace memgraph::io::local_transport

View File

@ -0,0 +1,139 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <map>
#include <mutex>
#include "io/errors.hpp"
#include "io/message_conversion.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
namespace memgraph::io::local_transport {
class LocalTransportHandle {
mutable std::mutex mu_{};
mutable std::condition_variable cv_;
bool should_shut_down_ = false;
// the responses to requests that are being waited on
std::map<PromiseKey, DeadlineAndOpaquePromise> promises_;
// messages that are sent to servers that may later receive them
std::vector<OpaqueMessage> can_receive_;
public:
void ShutDown() {
std::unique_lock<std::mutex> lock(mu_);
should_shut_down_ = true;
cv_.notify_all();
}
bool ShouldShutDown() const {
std::unique_lock<std::mutex> lock(mu_);
return should_shut_down_;
}
static Time Now() {
auto nano_time = std::chrono::system_clock::now();
return std::chrono::time_point_cast<std::chrono::microseconds>(nano_time);
}
template <Message... Ms>
requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) {
std::unique_lock lock(mu_);
Time before = Now();
while (can_receive_.empty()) {
Time now = Now();
// protection against non-monotonic timesources
auto maxed_now = std::max(now, before);
auto elapsed = maxed_now - before;
if (timeout < elapsed) {
return TimedOut{};
}
Duration relative_timeout = timeout - elapsed;
std::cv_status cv_status_value = cv_.wait_for(lock, relative_timeout);
if (cv_status_value == std::cv_status::timeout) {
return TimedOut{};
}
}
auto current_message = std::move(can_receive_.back());
can_receive_.pop_back();
auto m_opt = std::move(current_message).Take<Ms...>();
return std::move(m_opt).value();
}
template <Message M>
void Send(Address to_address, Address from_address, RequestId request_id, M &&message) {
std::any message_any(std::forward<M>(message));
OpaqueMessage opaque_message{
.from_address = from_address, .request_id = request_id, .message = std::move(message_any)};
PromiseKey promise_key{.requester_address = to_address,
.request_id = opaque_message.request_id,
.replier_address = opaque_message.from_address};
{
std::unique_lock<std::mutex> lock(mu_);
if (promises_.contains(promise_key)) {
// complete waiting promise if it's there
DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key));
promises_.erase(promise_key);
dop.promise.Fill(std::move(opaque_message));
} else {
can_receive_.emplace_back(std::move(opaque_message));
}
} // lock dropped
cv_.notify_all();
}
template <Message Request, Message Response>
void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request,
Duration timeout, ResponsePromise<Response> promise) {
const bool port_matches = to_address.last_known_port == from_address.last_known_port;
const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip;
MG_ASSERT(port_matches && ip_matches);
const Time deadline = Now() + timeout;
{
std::unique_lock<std::mutex> lock(mu_);
PromiseKey promise_key{
.requester_address = from_address, .request_id = request_id, .replier_address = to_address};
OpaquePromise opaque_promise(std::move(promise).ToUnique());
DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)};
promises_.emplace(std::move(promise_key), std::move(dop));
} // lock dropped
Send(to_address, from_address, request_id, std::forward<Request>(request));
}
};
} // namespace memgraph::io::local_transport

View File

@ -0,0 +1,206 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "io/transport.hpp"
namespace memgraph::io {
using memgraph::io::Duration;
using memgraph::io::Message;
using memgraph::io::Time;
/// Key under which a pending response promise is stored: who asked, with
/// which request id, and who is expected to reply.
struct PromiseKey {
  Address requester_address;
  uint64_t request_id;
  // TODO(tyler) possibly remove replier_address from promise key
  // once we want to support DSR.
  Address replier_address;

  /// Lexicographic ordering over (requester_address, request_id, replier_address).
  friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) {
    if (lhs.requester_address == rhs.requester_address) {
      if (lhs.request_id == rhs.request_id) {
        return lhs.replier_address < rhs.replier_address;
      }
      return lhs.request_id < rhs.request_id;
    }
    return lhs.requester_address < rhs.requester_address;
  }
};
/// A message whose concrete type has been erased into a std::any, together
/// with its addressing information and request id.
struct OpaqueMessage {
  Address to_address;
  Address from_address;
  uint64_t request_id;
  std::any message;

  /// Walks the parameter pack one type at a time. When the type held by
  /// the std::any equals Head, the value is extracted and returned;
  /// otherwise the search continues with the remaining types in Rest.
  /// When no type is left, std::nullopt is returned.
  ///
  /// Return is the complete std::variant<Ts...> type. It is passed
  /// separately so that it still names the whole pack while Head/Rest
  /// shrink during the recursion.
  template <typename Return, Message Head, Message... Rest>
  std::optional<Return> Unpack(std::any &&a) {
    if (a.type() == typeid(Head)) {
      Head concrete = std::any_cast<Head>(std::move(a));
      return concrete;
    }

    if constexpr (sizeof...(Rest) == 0) {
      return std::nullopt;
    } else {
      return Unpack<Return, Rest...>(std::move(a));
    }
  }

  /// Convenience entry point for conversion: the caller supplies only
  /// the list of candidate types, and gets back a variant over exactly
  /// those types if the std::any holds one of them.
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) std::optional<std::variant<Ms...>> VariantFromAny(std::any &&a) {
    return Unpack<std::variant<Ms...>, Ms...>(std::move(a));
  }

  /// Consumes the message and wraps it in a RequestEnvelope if it holds one
  /// of Ms...; returns std::nullopt otherwise.
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) std::optional<RequestEnvelope<Ms...>> Take() && {
    auto variant_opt = VariantFromAny<Ms...>(std::move(message));

    if (!variant_opt) {
      return std::nullopt;
    }

    return RequestEnvelope<Ms...>{
        .message = std::move(*variant_opt),
        .request_id = request_id,
        .to_address = to_address,
        .from_address = from_address,
    };
  }
};
/// Type-erased interface for operating on a promise that is only known
/// through a void pointer. Instances can be neither copied nor moved.
class OpaquePromiseTraitBase {
 public:
  OpaquePromiseTraitBase() = default;

  OpaquePromiseTraitBase(const OpaquePromiseTraitBase &) = delete;
  OpaquePromiseTraitBase(OpaquePromiseTraitBase &&old) = delete;
  OpaquePromiseTraitBase &operator=(const OpaquePromiseTraitBase &) = delete;
  OpaquePromiseTraitBase &operator=(OpaquePromiseTraitBase &&) = delete;

  /// Type information of the message type this trait was created for.
  virtual const std::type_info *TypeInfo() const = 0;
  /// Whether the promise behind `ptr` is currently being awaited.
  virtual bool IsAwaited(void *ptr) const = 0;
  /// Completes the promise behind `ptr` with the given message.
  virtual void Fill(void *ptr, OpaqueMessage &&) const = 0;
  /// Completes the promise behind `ptr` with a timeout.
  virtual void TimeOut(void *ptr) const = 0;

  virtual ~OpaquePromiseTraitBase() = default;
};
template <typename T>
class OpaquePromiseTrait : public OpaquePromiseTraitBase {
public:
const std::type_info *TypeInfo() const override { return &typeid(T); };
bool IsAwaited(void *ptr) const override { return static_cast<ResponsePromise<T> *>(ptr)->IsAwaited(); };
void Fill(void *ptr, OpaqueMessage &&opaque_message) const override {
T message = std::any_cast<T>(std::move(opaque_message.message));
auto response_envelope = ResponseEnvelope<T>{.message = std::move(message),
.request_id = opaque_message.request_id,
.to_address = opaque_message.to_address,
.from_address = opaque_message.from_address};
auto promise = static_cast<ResponsePromise<T> *>(ptr);
auto unique_promise = std::unique_ptr<ResponsePromise<T>>(promise);
unique_promise->Fill(std::move(response_envelope));
};
void TimeOut(void *ptr) const override {
auto promise = static_cast<ResponsePromise<T> *>(ptr);
auto unique_promise = std::unique_ptr<ResponsePromise<T>>(promise);
ResponseResult<T> result = TimedOut{};
unique_promise->Fill(std::move(result));
}
};
/// Move-only, type-erased owner of a heap-allocated ResponsePromise<T>.
///
/// The promise is held as a void pointer; trait_ remembers its real type.
/// ptr_ must be null on destruction, so the promise has to be taken, filled
/// or timed out first (or this object has to be moved from).
class OpaquePromise {
  void *ptr_;
  std::unique_ptr<OpaquePromiseTraitBase> trait_;

 public:
  OpaquePromise(OpaquePromise &&old) noexcept : ptr_(old.ptr_), trait_(std::move(old.trait_)) {
    MG_ASSERT(old.ptr_ != nullptr);
    old.ptr_ = nullptr;
  }

  OpaquePromise &operator=(OpaquePromise &&old) noexcept {
    MG_ASSERT(ptr_ == nullptr);
    MG_ASSERT(old.ptr_ != nullptr);
    MG_ASSERT(this != &old);
    ptr_ = old.ptr_;
    trait_ = std::move(old.trait_);
    old.ptr_ = nullptr;
    return *this;
  }

  OpaquePromise(const OpaquePromise &) = delete;
  OpaquePromise &operator=(const OpaquePromise &) = delete;

  /// Hands the promise back with its concrete type.
  /// Asserts that T is the type this OpaquePromise was created with.
  template <typename T>
  std::unique_ptr<ResponsePromise<T>> Take() && {
    MG_ASSERT(typeid(T) == *trait_->TypeInfo());
    MG_ASSERT(ptr_ != nullptr);

    auto ptr = static_cast<ResponsePromise<T> *>(ptr_);

    ptr_ = nullptr;

    // The pointee is a ResponsePromise<T>, not a T, so that is the type
    // the returned unique_ptr has to be constructed with.
    return std::unique_ptr<ResponsePromise<T>>(ptr);
  }

  /// Takes ownership of `promise` and records its type in trait_.
  template <typename T>
  explicit OpaquePromise(std::unique_ptr<ResponsePromise<T>> promise)
      : ptr_(static_cast<void *>(promise.release())), trait_(std::make_unique<OpaquePromiseTrait<T>>()) {}

  bool IsAwaited() {
    MG_ASSERT(ptr_ != nullptr);
    return trait_->IsAwaited(ptr_);
  }

  /// Completes the promise with a timeout; ptr_ is cleared afterwards.
  void TimeOut() {
    MG_ASSERT(ptr_ != nullptr);
    trait_->TimeOut(ptr_);
    ptr_ = nullptr;
  }

  /// Completes the promise with the given message; ptr_ is cleared afterwards.
  void Fill(OpaqueMessage &&opaque_message) {
    MG_ASSERT(ptr_ != nullptr);
    trait_->Fill(ptr_, std::move(opaque_message));
    ptr_ = nullptr;
  }

  ~OpaquePromise() {
    MG_ASSERT(ptr_ == nullptr, "OpaquePromise destroyed without being explicitly timed out or filled");
  }
};
// Pairs a type-erased response promise with a deadline.
struct DeadlineAndOpaquePromise {
// NOTE(review): presumably the point in time after which the promise should be timed out - confirm against users.
Time deadline;
OpaquePromise promise;
};
} // namespace memgraph::io

913
src/io/rsm/raft.hpp Normal file
View File

@ -0,0 +1,913 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
// TODO(tyler) buffer out-of-order Append buffers on the Followers to reassemble more quickly
// TODO(tyler) handle granular batch sizes based on simple flow control
#pragma once
#include <deque>
#include <iostream>
#include <map>
#include <set>
#include <thread>
#include <unordered_map>
#include <vector>
#include "io/simulator/simulator.hpp"
#include "io/transport.hpp"
#include "utils/concepts.hpp"
namespace memgraph::io::rsm {
/// Timeout and replication tunables
// The chrono literals let the durations below be written as `100ms` etc.
using namespace std::chrono_literals;
// Lower and upper bound of the election timeout.
static constexpr auto kMinimumElectionTimeout = 100ms;
static constexpr auto kMaximumElectionTimeout = 200ms;
// Lower and upper bound of the broadcast timeout. The largest broadcast
// timeout has to stay below the smallest election timeout (checked below).
static constexpr auto kMinimumBroadcastTimeout = 40ms;
static constexpr auto kMaximumBroadcastTimeout = 60ms;
// Bounds of the randomized interval between two Cron() calls in Raft::Run.
static constexpr auto kMinimumCronInterval = 1ms;
static constexpr auto kMaximumCronInterval = 2ms;
// Bounds of the randomized timeout Raft::Run passes to ReceiveWithTimeout.
static constexpr auto kMinimumReceiveTimeout = 40ms;
static constexpr auto kMaximumReceiveTimeout = 60ms;
// Compile-time sanity checks: every minimum is strictly below its maximum,
// and broadcasts always time out sooner than elections.
static_assert(kMinimumElectionTimeout > kMaximumBroadcastTimeout,
"The broadcast timeout has to be smaller than the election timeout!");
static_assert(kMinimumElectionTimeout < kMaximumElectionTimeout,
"The minimum election timeout has to be smaller than the maximum election timeout!");
static_assert(kMinimumBroadcastTimeout < kMaximumBroadcastTimeout,
"The minimum broadcast timeout has to be smaller than the maximum broadcast timeout!");
static_assert(kMinimumCronInterval < kMaximumCronInterval,
"The minimum cron interval has to be smaller than the maximum cron interval!");
static_assert(kMinimumReceiveTimeout < kMaximumReceiveTimeout,
"The minimum receive timeout has to be smaller than the maximum receive timeout!");
// NOTE(review): presumably the maximum number of log entries sent in one AppendRequest - confirm against the sender.
static constexpr size_t kMaximumAppendBatchSize = 1024;
// Integer aliases used by the raft messages and state below. They are all
// uint64_t; the distinct names only document the intent of a value.
using Term = uint64_t;
using LogIndex = uint64_t;
using LogSize = uint64_t;
using RequestId = uint64_t;
/// Message wrapping a single WriteOperation. It is one of the message
/// types Raft::Run asks the Io layer to receive.
template <typename WriteOperation>
struct WriteRequest {
WriteOperation operation;
};
/// WriteResponse is returned to a client after
/// their WriteRequest was entered in to the raft
/// log and it reached consensus.
///
/// WriteReturn is the result of applying the WriteRequest to
/// ReplicatedState, and if the ReplicatedState::write
/// method is deterministic, all replicas will
/// have the same ReplicatedState after applying
/// the submitted WriteRequest.
template <typename WriteReturn>
struct WriteResponse {
bool success;
WriteReturn write_return;
// NOTE(review): presumably the leader the client should retry against when success is false - confirm.
std::optional<Address> retry_leader;
// NOTE(review): presumably the log index the write was stored at - confirm.
LogIndex raft_index;
};
/// Message wrapping a single ReadOperation. It is one of the message
/// types Raft::Run asks the Io layer to receive.
template <typename ReadOperation>
struct ReadRequest {
ReadOperation operation;
};
/// Reply to a ReadRequest, carrying the ReadReturn value.
template <typename ReadReturn>
struct ReadResponse {
bool success;
ReadReturn read_return;
// NOTE(review): presumably the leader the client should retry against when success is false - confirm.
std::optional<Address> retry_leader;
};
/// AppendRequest is a raft-level message that the Leader
/// periodically broadcasts to all Follower peers. This
/// serves three main roles:
/// 1. acts as a heartbeat from the Leader to the Follower
/// 2. replicates new data that the Leader has received to the Follower
/// 3. informs Follower peers when the commit index has increased,
///    signalling that it is now safe to apply log items to the
///    replicated state machine
///
/// Every scalar member has a default of 0, so a default-constructed
/// request never carries indeterminate values.
template <typename WriteRequest>
struct AppendRequest {
  Term term = 0;
  LogIndex batch_start_log_index = 0;
  Term last_log_term = 0;
  std::vector<std::pair<Term, WriteRequest>> entries;
  LogSize leader_commit = 0;
};
/// Reply to an AppendRequest.
struct AppendResponse {
bool success;
Term term;
Term last_log_term;
// a small optimization over the raft paper, tells
// the leader the offset that we are interested in
// to send log offsets from for us. This will only
// be useful at the beginning of a leader's term.
LogSize log_size;
};
/// Raft-level message asking for a vote. It carries the sender's term,
/// log size and last log term.
///
/// Every member has a default of 0, so a default-constructed request
/// never carries indeterminate values.
struct VoteRequest {
  Term term = 0;
  LogSize log_size = 0;
  Term last_log_term = 0;
};
/// Reply to a VoteRequest.
///
/// Every member has a default (0 / false), so a default-constructed
/// response never carries indeterminate values and does not grant a vote.
struct VoteResponse {
  Term term = 0;
  LogSize committed_log_size = 0;
  bool vote_granted = false;
};
/// Term and log bookkeeping that Raft keeps in its state_ member,
/// separately from the role-specific state held in role_.
template <typename WriteRequest>
struct CommonState {
Term term = 0;
// Each log entry is stored together with a Term.
std::vector<std::pair<Term, WriteRequest>> log;
LogSize committed_log_size = 0;
// NOTE(review): presumably how much of the log has been applied to the replicated state - confirm.
LogSize applied_size = 0;
};
/// Per-follower bookkeeping; Leader keeps one instance per follower Address.
struct FollowerTracker {
// NOTE(review): presumably the next log index to send to this follower - confirm.
LogIndex next_index = 0;
// NOTE(review): presumably the log size this follower has acknowledged - confirm.
LogSize confirmed_log_size = 0;
};
/// Describes a client request that Leader tracks in its
/// pending_client_requests map, keyed by LogIndex.
struct PendingClientRequest {
RequestId request_id;
// NOTE(review): presumably the address of the requesting client - confirm.
Address address;
Time received_at;
};
/// Role-specific state for the Leader alternative of Role.
struct Leader {
  // One tracker per follower address.
  std::map<Address, FollowerTracker> followers;
  // Client requests that are still pending, keyed by log index.
  std::unordered_map<LogIndex, PendingClientRequest> pending_client_requests;
  Time last_broadcast = Time::min();

  /// Tab-delimited label for this role.
  static std::string ToString() {
    return std::string{"\tLeader   \t"};
  }
};
/// Role-specific state for the Candidate alternative of Role.
struct Candidate {
  // Votes received so far, together with a LogSize per voter.
  std::map<Address, LogSize> successful_votes;
  Time election_began = Time::min();
  std::set<Address> outstanding_votes;

  /// Tab-delimited label for this role.
  static std::string ToString() {
    return std::string{"\tCandidate\t"};
  }
};
/// Role-specific state for the Follower alternative of Role.
struct Follower {
  Time last_received_append_entries_timestamp;
  Address leader_address;

  /// Tab-delimited label for this role.
  static std::string ToString() {
    return std::string{"\tFollower \t"};
  }
};
/// The role a Raft instance currently has; Raft stores it in role_ and
/// starts out as a Candidate.
using Role = std::variant<Candidate, Leader, Follower>;
/// Satisfied by each of the three role types.
template <typename Role>
concept AllRoles = memgraph::utils::SameAsAnyOf<Role, Leader, Follower, Candidate>;
/// Satisfied by Leader and Follower, but not by Candidate.
template <typename Role>
concept LeaderOrFollower = memgraph::utils::SameAsAnyOf<Role, Leader, Follower>;
/// Satisfied by Follower and Candidate, but not by Leader.
template <typename Role>
concept FollowerOrCandidate = memgraph::utils::SameAsAnyOf<Role, Follower, Candidate>;
/*
All ReplicatedState classes should have a Read method that returns
our ReadResponseValue and an Apply method that returns our
WriteResponseValue:
ReadResponseValue ReplicatedState::Read(ReadOperation);
WriteResponseValue ReplicatedState::Apply(WriteOperation);
For example:
if the state is uint64_t, and WriteOperation is `struct PlusOne {};`,
and WriteResponseValue is also uint64_t (the new value), then
each call to state.Apply(PlusOne{}) will return the new value
after incrementing it. 0, 1, 2, 3... and this will be sent back
to the client that requested the mutation.
In practice, these mutations will usually be predicated on some
previous value, so that they are idempotent, functioning similarly
to a CAS operation.
*/
/// Rsm constrains a ReplicatedState type: for a ReadOperation `r` and a
/// WriteOperation `w`, `state.Read(r)` has to return exactly
/// ReadResponseValue and `state.Apply(w)` has to return exactly
/// WriteResponseValue.
template <typename WriteOperation, typename ReadOperation, typename ReplicatedState, typename WriteResponseValue,
typename ReadResponseValue>
concept Rsm = requires(ReplicatedState state, WriteOperation w, ReadOperation r) {
{ state.Read(r) } -> std::same_as<ReadResponseValue>;
{ state.Apply(w) } -> std::same_as<WriteResponseValue>;
};
/// Parameter Purpose
/// --------------------------
/// IoImpl the concrete Io provider - SimulatorTransport, ThriftTransport, etc...
/// ReplicatedState the high-level data structure that is managed by the raft-backed replicated state machine
/// WriteOperation the individual operation type that is applied to the ReplicatedState in identical order
/// across each replica
/// WriteResponseValue the return value of calling ReplicatedState::Apply(WriteOperation), which is executed in
/// identical order across all replicas after a WriteOperation reaches consensus.
/// ReadOperation the type of operations that do not require consensus before executing directly
/// on a const ReplicatedState &
/// ReadResponseValue the return value of calling ReplicatedState::Read(ReadOperation), which is executed directly
/// without going through consensus first
template <typename IoImpl, typename ReplicatedState, typename WriteOperation, typename WriteResponseValue,
typename ReadOperation, typename ReadResponseValue>
requires Rsm<WriteOperation, ReadOperation, ReplicatedState, WriteResponseValue, ReadResponseValue>
class Raft {
// Role-independent Raft state: this class reads and writes its term, log,
// committed_log_size and applied_size.
CommonState<WriteOperation> state_;
// Every server starts out as a Candidate.
Role role_ = Candidate{};
// Transport handle used for sending, receiving, reading the clock and drawing random numbers.
Io<IoImpl> io_;
// Addresses that VoteRequests are sent to when an election is started.
std::vector<Address> peers_;
// The application-level state that committed log entries are applied to and reads are served from.
ReplicatedState replicated_state_;
public:
/// Builds a Raft server that starts out as a Candidate.
///
/// @param io                transport handle; ownership is taken over
/// @param peers             addresses of the other servers in the cluster
/// @param replicated_state  initial application state; ownership is taken over
Raft(Io<IoImpl> &&io, std::vector<Address> peers, ReplicatedState &&replicated_state)
    : io_(std::move(io)),
      // `peers` is a by-value sink parameter, so move it instead of copying it a second time.
      peers_(std::move(peers)),
      replicated_state_(std::forward<ReplicatedState>(replicated_state)) {}
void Run() {
Time last_cron = io_.Now();
while (!io_.ShouldShutDown()) {
const auto now = io_.Now();
const Duration random_cron_interval = RandomTimeout(kMinimumCronInterval, kMaximumCronInterval);
if (now - last_cron > random_cron_interval) {
Cron();
last_cron = now;
}
const Duration receive_timeout = RandomTimeout(kMinimumReceiveTimeout, kMaximumReceiveTimeout);
auto request_result =
io_.template ReceiveWithTimeout<ReadRequest<ReadOperation>, AppendRequest<WriteOperation>, AppendResponse,
WriteRequest<WriteOperation>, VoteRequest, VoteResponse>(receive_timeout);
if (request_result.HasError()) {
continue;
}
auto request = std::move(request_result.GetValue());
Handle(std::move(request.message), request.request_id, request.from_address);
}
}
private:
// Raft paper - 5.3
// When the entry has been safely replicated, the leader applies the
// entry to its state machine and returns the result of that
// execution to the client.
//
// "Safely replicated" is defined as being known to be present
// on at least a majority of all peers (inclusive of the Leader).
/// Recomputes the committed log size from the sizes confirmed by a majority of
/// servers (this Leader included), applies every newly committed entry to the
/// replicated state, and answers the clients whose writes were just applied.
///
/// @param leader  the Leader role state holding follower progress and pending client requests
void BumpCommitIndexAndReplyToClients(Leader &leader) {
  auto confirmed_log_sizes = std::vector<LogSize>{};
  confirmed_log_sizes.reserve(leader.followers.size() + 1);

  // Our own log size takes part in the majority calculation as well.
  confirmed_log_sizes.push_back(state_.log.size());

  for (const auto &[addr, f] : leader.followers) {
    confirmed_log_sizes.push_back(f.confirmed_log_size);
    Log("Follower at port ", addr.last_known_port, " has confirmed log size of: ", f.confirmed_log_size);
  }

  // reverse sort from highest to lowest (using std::ranges::greater)
  std::ranges::sort(confirmed_log_sizes, std::ranges::greater());

  // This is a particularly correctness-critical calculation because it
  // determines the committed log size that will be broadcast in
  // the next AppendRequest.
  //
  // If the following sizes are recorded for clusters of different numbers of peers,
  // these are the expected sizes that are considered to have reached consensus:
  //
  // state           | expected value | (confirmed_log_sizes.size() / 2)
  // [1]               1                (1 / 2) => 0
  // [2, 1]            1                (2 / 2) => 1
  // [3, 2, 1]         2                (3 / 2) => 1
  // [4, 3, 2, 1]      2                (4 / 2) => 2
  // [5, 4, 3, 2, 1]   3                (5 / 2) => 2
  const size_t majority_index = confirmed_log_sizes.size() / 2;
  const LogSize new_committed_log_size = confirmed_log_sizes[majority_index];

  // We never go backwards in history.
  MG_ASSERT(state_.committed_log_size <= new_committed_log_size,
            "as a Leader, we have previously set our committed_log_size to {}, but our Followers have a majority "
            "committed_log_size of {}",
            state_.committed_log_size, new_committed_log_size);

  state_.committed_log_size = new_committed_log_size;

  // Apply every committed-but-unapplied entry in log order, and hand the
  // return value of ReplicatedState::Apply back to the client that is
  // waiting on that index, if there is one.
  for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) {
    const LogIndex apply_index = state_.applied_size;
    const auto &write_request = state_.log[apply_index].second;
    // Deliberately non-const so that it can really be moved into the response below.
    WriteResponseValue write_return = replicated_state_.Apply(write_request);

    // Single lookup instead of contains() + at() + erase(key).
    const auto pending = leader.pending_client_requests.find(apply_index);
    if (pending != leader.pending_client_requests.end()) {
      const PendingClientRequest client_request = std::move(pending->second);
      leader.pending_client_requests.erase(pending);

      WriteResponse<WriteResponseValue> resp{
          .success = true,
          .write_return = std::move(write_return),
          .raft_index = apply_index,
      };

      io_.Send(client_request.address, client_request.request_id, std::move(resp));
    }
  }

  Log("committed_log_size is now ", state_.committed_log_size);
}
// Raft paper - 5.1
// AppendEntries RPCs are initiated by leaders to replicate log entries and to provide a form of heartbeat
void BroadcastAppendEntries(std::map<Address, FollowerTracker> &followers) {
for (auto &[address, follower] : followers) {
const LogIndex next_index = follower.next_index;
const auto missing = state_.log.size() - next_index;
const auto batch_size = std::min(missing, kMaximumAppendBatchSize);
const auto start_index = next_index;
const auto end_index = start_index + batch_size;
// advance follower's next index
follower.next_index += batch_size;
std::vector<std::pair<Term, WriteOperation>> entries;
entries.insert(entries.begin(), state_.log.begin() + start_index, state_.log.begin() + end_index);
const Term previous_term_from_index = PreviousTermFromIndex(start_index);
Log("sending ", entries.size(), " entries to Follower ", address.last_known_port,
" which are above its next_index of ", next_index);
AppendRequest<WriteOperation> ar{
.term = state_.term,
.batch_start_log_index = start_index,
.last_log_term = previous_term_from_index,
.entries = std::move(entries),
.leader_commit = state_.committed_log_size,
};
// request_id not necessary to set because it's not a Future-backed Request.
static constexpr RequestId request_id = 0;
io_.Send(address, request_id, std::move(ar));
}
}
// Raft paper - 5.2
// Raft uses randomized election timeouts to ensure that split votes are rare and that they are resolved quickly
/// Draws a uniformly distributed Duration from the closed range [min, max],
/// using the transport's random number source.
Duration RandomTimeout(Duration min, Duration max) {
  std::uniform_int_distribution tick_distribution(min.count(), max.count());

  const auto sampled_ticks = io_.Rand(tick_distribution);

  return Duration{sampled_ticks};
}

/// Same as above, with the bounds given as whole microseconds.
Duration RandomTimeout(int min_micros, int max_micros) {
  std::uniform_int_distribution micros_distribution(min_micros, max_micros);

  const int sampled_micros = io_.Rand(micros_distribution);

  return std::chrono::microseconds{sampled_micros};
}
/// Returns the term of the log entry directly before `index`, or 0 when there
/// is no such entry (index is 0, or index lies beyond one-past-the-end of the log).
Term PreviousTermFromIndex(LogIndex index) const {
  const bool has_predecessor = index != 0 && index < state_.log.size() + 1;
  if (!has_predecessor) {
    return 0;
  }

  return state_.log.at(index - 1).first;
}
/// Returns the term of the last committed log entry, or 0 when nothing has
/// been committed yet. Asserts that the committed size never exceeds the log.
Term CommittedLogTerm() {
  MG_ASSERT(state_.log.size() >= state_.committed_log_size);

  const bool nothing_committed = state_.committed_log_size == 0 || state_.log.empty();
  if (nothing_committed) {
    return 0;
  }

  return state_.log.at(state_.committed_log_size - 1).first;
}
/// Returns the term of the newest log entry, or 0 for an empty log.
Term LastLogTerm() const {
  if (!state_.log.empty()) {
    return state_.log.back().first;
  }

  return 0;
}
template <typename... Ts>
void Log(Ts &&...args) {
const Time now = io_.Now();
const auto micros = std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count();
const Term term = state_.term;
const std::string role_string = std::visit([&](const auto &role) { return role.ToString(); }, role_);
std::ostringstream out;
out << '\t' << static_cast<int>(micros) << "\t" << term << "\t" << io_.GetAddress().last_known_port;
out << role_string;
(out << ... << args);
spdlog::info(out.str());
}
/////////////////////////////////////////////////////////////
/// Raft-related Cron methods
///
/// Cron + std::visit is how periodic maintenance is dispatched
/// to certain code based on Raft role.
///
/// Cron(role) takes as its only argument a reference to the
/// current role, and returns the role to switch to, if any.
/////////////////////////////////////////////////////////////
/// Periodic protocol maintenance.
/// Periodic protocol maintenance.
void Cron() {
// dispatch periodic logic based on our role to a specific Cron method.
std::optional<Role> new_role = std::visit([&](auto &role) { return Cron(role); }, role_);
if (new_role) {
role_ = std::move(new_role).value();
}
}
// Raft paper - 5.2
// Candidates keep sending Vote to peers until:
// 1. receiving Append with a higher term (become Follower)
// 2. receiving Vote with a higher term (become a Follower)
// 3. receiving a quorum of responses to our last batch of Vote (become a Leader)
/// Candidate maintenance: once a randomized election timeout has elapsed since
/// the last attempt, bump the term, send a VoteRequest to every peer and
/// return a fresh Candidate state that tracks the new election.
/// Returns std::nullopt while the current attempt has not timed out.
std::optional<Role> Cron(Candidate &candidate) {
  const auto now = io_.Now();
  const Duration election_timeout = RandomTimeout(kMinimumElectionTimeout, kMaximumElectionTimeout);

  if (now - candidate.election_began <= election_timeout) {
    return std::nullopt;
  }

  const auto election_timeout_ms = std::chrono::duration_cast<std::chrono::milliseconds>(election_timeout).count();

  state_.term++;
  Log("becoming Candidate for term ", state_.term, " after leader timeout of ", election_timeout_ms,
      "ms elapsed since last election attempt");

  const VoteRequest request{
      .term = state_.term,
      .log_size = state_.log.size(),
      .last_log_term = LastLogTerm(),
  };

  // request_id not necessary to set because it's not a Future-backed Request.
  static constexpr auto request_id = 0;

  std::set<Address> awaiting_vote;
  for (const auto &peer : peers_) {
    io_.template Send<VoteRequest>(peer, request_id, request);
    awaiting_vote.insert(peer);
  }

  return Candidate{
      .successful_votes = std::map<Address, LogIndex>(),
      .election_began = now,
      .outstanding_votes = std::move(awaiting_vote),
  };
}
// Raft paper - 5.2
// Followers become candidates if we haven't heard from the leader
// after a randomized timeout.
/// Follower maintenance: turn into a Candidate when no AppendRequest has been
/// accepted from the Leader for longer than a randomized election timeout.
std::optional<Role> Cron(Follower &follower) {
  const auto now = io_.Now();
  const auto silence = now - follower.last_received_append_entries_timestamp;

  // randomized follower timeout
  const Duration election_timeout = RandomTimeout(kMinimumElectionTimeout, kMaximumElectionTimeout);

  if (silence <= election_timeout) {
    return std::nullopt;
  }

  // the Leader has been quiet for too long
  return Candidate{};
}
// Leaders (re)send AppendRequest to followers.
/// Leader maintenance: (re)broadcast AppendRequests once a randomized
/// broadcast timeout has passed since the previous Cron-driven broadcast.
/// Never changes role.
std::optional<Role> Cron(Leader &leader) {
  const Time now = io_.Now();
  const Duration broadcast_timeout = RandomTimeout(kMinimumBroadcastTimeout, kMaximumBroadcastTimeout);

  const bool broadcast_is_due = now - leader.last_broadcast > broadcast_timeout;
  if (broadcast_is_due) {
    BroadcastAppendEntries(leader.followers);
    leader.last_broadcast = now;
  }

  return std::nullopt;
}
/////////////////////////////////////////////////////////////
/// Raft-related Handle methods
///
/// Handle + std::visit is how events are dispatched
/// to certain code based on Raft role.
///
/// Handle(role, message, ...)
/// takes as the first argument a reference
/// to its role, and as the second argument, the
/// message that has been received.
/////////////////////////////////////////////////////////////
/// Every message type a Raft server can receive.
using ReceiveVariant = std::variant<ReadRequest<ReadOperation>, AppendRequest<WriteOperation>, AppendResponse,
                                    WriteRequest<WriteOperation>, VoteRequest, VoteResponse>;

/// Routes a received message to the Handle overload selected by the pair
/// (current role, message type) and switches role if that overload asks for it.
///
/// @param message_variant  the received message; consumed
/// @param request_id       id to echo back in any response
/// @param from_address     sender of the message
void Handle(ReceiveVariant &&message_variant, RequestId request_id, Address from_address) {
  // The first Handle parameter of an overload names the role(s) it applies to;
  // it is a constrained template parameter when one handler serves several roles.
  auto dispatch = [&](auto &&msg, auto &role) {
    return Handle(role, std::forward<decltype(msg)>(msg), request_id, from_address);
  };

  std::optional<Role> next_role = std::visit(dispatch, std::move(message_variant), role_);

  // TODO(tyler) (M3) maybe replace std::visit with get_if for explicit prioritized matching, [[likely]] etc...
  if (next_role.has_value()) {
    role_ = std::move(*next_role);
  }
}
// all roles can receive Vote and possibly become a follower
/// VoteRequest handler shared by all roles.
///
/// The vote is granted only when the requester's term is strictly newer than
/// ours and its log is at least as recent and as long as ours; in that case we
/// adopt its term and become its Follower. A requester with a newer term but an
/// inferior log is refused, and we become a Candidate in an even newer term.
template <AllRoles ALL>
std::optional<Role> Handle(ALL & /* variable */, VoteRequest &&req, RequestId request_id, Address from_address) {
  Log("received VoteRequest from ", from_address.last_known_port, " with term ", req.term);

  const bool term_dominates = req.term > state_.term;
  const bool last_log_term_dominates = req.last_log_term >= LastLogTerm();
  const bool log_size_dominates = req.log_size >= state_.log.size();
  const bool new_leader = last_log_term_dominates && term_dominates && log_size_dominates;

  const Term response_term = std::max(req.term, state_.term);

  if (new_leader) {
    MG_ASSERT(req.term > state_.term);
    MG_ASSERT(response_term == req.term);
  }

  const VoteResponse res{
      .term = response_term,
      .committed_log_size = state_.committed_log_size,
      .vote_granted = new_leader,
  };
  io_.Send(from_address, request_id, res);

  if (new_leader) {
    // become a follower
    state_.term = req.term;
    return Follower{
        .last_received_append_entries_timestamp = io_.Now(),
        .leader_address = from_address,
    };
  }

  if (!term_dominates) {
    return std::nullopt;
  }

  Log("received a vote from an inferior candidate. Becoming Candidate");
  state_.term = std::max(state_.term, req.term) + 1;
  return Candidate{};
}
/// Candidate handling of a VoteResponse.
///
/// Refusals and responses for another term are dropped. A granted vote moves
/// the sender from outstanding_votes to successful_votes; as soon as the
/// granted votes are at least as many as the outstanding ones, this server
/// becomes Leader and immediately broadcasts AppendRequests.
std::optional<Role> Handle(Candidate &candidate, VoteResponse &&res, RequestId /* variable */, Address from_address) {
  Log("received VoteResponse");

  const bool counts_for_this_election = res.vote_granted && res.term == state_.term;
  if (!counts_for_this_election) {
    Log("received unsuccessful VoteResponse from term ", res.term, " when our candidacy term is ", state_.term);
    // we received a delayed VoteResponse from the past, which has to do with an election that is
    // no longer valid. We can simply drop this.
    return std::nullopt;
  }

  MG_ASSERT(candidate.outstanding_votes.contains(from_address),
            "Received unexpected VoteResponse from server not present in Candidate's outstanding_votes!");
  candidate.outstanding_votes.erase(from_address);

  MG_ASSERT(!candidate.successful_votes.contains(from_address),
            "Received unexpected VoteResponse from server already in Candidate's successful_votes!");
  candidate.successful_votes.insert({from_address, res.committed_log_size});

  const bool election_won = candidate.successful_votes.size() >= candidate.outstanding_votes.size();
  if (!election_won) {
    return std::nullopt;
  }

  std::map<Address, FollowerTracker> followers{};

  // Voters told us their committed log size, so replication resumes from there.
  for (const auto &[voter, voter_committed_log_size] : candidate.successful_votes) {
    followers.insert({voter, FollowerTracker{
                                 .next_index = voter_committed_log_size,
                                 .confirmed_log_size = voter_committed_log_size,
                             }});
  }

  // Nothing is known about the remaining peers yet.
  for (const auto &silent_peer : candidate.outstanding_votes) {
    followers.insert({silent_peer, FollowerTracker{
                                       .next_index = state_.log.size(),
                                       .confirmed_log_size = 0,
                                   }});
  }

  Log("becoming Leader at term ", state_.term);

  BroadcastAppendEntries(followers);

  return Leader{
      .followers = std::move(followers),
      .pending_client_requests = std::unordered_map<LogIndex, PendingClientRequest>(),
  };
}
template <LeaderOrFollower LOF>
std::optional<Role> Handle(LOF & /* variable */, VoteResponse && /* variable */, RequestId /* variable */,
Address /* variable */) {
Log("non-Candidate received VoteResponse");
return std::nullopt;
}
/// AppendRequest handler shared by all roles.
///
/// - A request with a newer term (or, for a Candidate, the same term) makes us
///   a Follower of the sender; the request itself is answered with a rejection.
/// - A request with an older term is rejected.
/// - Otherwise the sender is our current Leader: if the batch lines up with the
///   end of our log it is appended, the commit point is advanced and newly
///   committed entries are applied to the replicated state.
///
/// Every path sends exactly one AppendResponse describing our log.
template <AllRoles ALL>
std::optional<Role> Handle(ALL &role, AppendRequest<WriteOperation> &&req, RequestId request_id,
                           Address from_address) {
  // The response starts out as a rejection carrying our current log size and
  // the term of our last committed entry; only the happy path below marks it
  // successful and refreshes the log fields.
  AppendResponse res{
      .success = false,
      .term = state_.term,
      .last_log_term = CommittedLogTerm(),
      .log_size = state_.log.size(),
  };

  if constexpr (std::is_same<ALL, Leader>()) {
    // The message needs a `{}` placeholder, otherwise the offending term is never printed.
    MG_ASSERT(req.term != state_.term, "Multiple leaders are acting under the term {}", req.term);
  }

  const bool is_candidate = std::is_same<ALL, Candidate>();
  const bool is_failed_competitor = is_candidate && req.term == state_.term;
  const Time now = io_.Now();

  // Raft paper - 5.2
  // While waiting for votes, a candidate may receive an
  // AppendEntries RPC from another server claiming to be leader. If
  // the leader's term (included in its RPC) is at least as large as
  // the candidate's current term, then the candidate recognizes the
  // leader as legitimate and returns to follower state.
  if (req.term > state_.term || is_failed_competitor) {
    // become follower of this leader, reply with our log status
    state_.term = req.term;

    io_.Send(from_address, request_id, res);

    Log("becoming Follower of Leader ", from_address.last_known_port, " at term ", req.term);
    return Follower{
        .last_received_append_entries_timestamp = now,
        .leader_address = from_address,
    };
  }

  if (req.term < state_.term) {
    // nack this request from an old leader
    io_.Send(from_address, request_id, res);

    return std::nullopt;
  }

  // at this point, we're dealing with our own leader
  if constexpr (std::is_same<ALL, Follower>()) {
    // small specialization for when we're already a Follower
    MG_ASSERT(role.leader_address == from_address, "Multiple Leaders are acting under the same term number!");
    role.last_received_append_entries_timestamp = now;
  } else {
    Log("Somehow entered Follower-specific logic as a non-Follower");
    MG_ASSERT(false, "Somehow entered Follower-specific logic as a non-Follower");
  }

  // Handle steady-state conditions.
  if (req.batch_start_log_index != state_.log.size()) {
    Log("req.batch_start_log_index of ", req.batch_start_log_index, " does not match our log size of ",
        state_.log.size());
  } else if (req.last_log_term != LastLogTerm()) {
    Log("req.last_log_term differs from our leader term at that slot, expected: ", LastLogTerm(), " but got ",
        req.last_log_term);
  } else {
    // happy path - Apply log
    Log("applying batch of ", req.entries.size(), " entries to our log starting at index ",
        req.batch_start_log_index);

    const auto resize_length = req.batch_start_log_index;

    MG_ASSERT(resize_length >= state_.committed_log_size,
              "Applied history from Leader which goes back in time from our commit_index");

    // possibly chop-off stuff that was replaced by
    // things with different terms (we got data that
    // hasn't reached consensus yet, which is normal)
    state_.log.resize(resize_length);

    if (req.entries.size() > 0) {
      auto &[first_term, op] = req.entries.at(0);
      MG_ASSERT(LastLogTerm() <= first_term);
    }

    state_.log.insert(state_.log.end(), std::make_move_iterator(req.entries.begin()),
                      std::make_move_iterator(req.entries.end()));

    MG_ASSERT(req.leader_commit >= state_.committed_log_size);
    // We can only consider committed what we actually hold in our log.
    state_.committed_log_size = std::min(req.leader_commit, state_.log.size());

    for (; state_.applied_size < state_.committed_log_size; state_.applied_size++) {
      const auto &write_request = state_.log[state_.applied_size].second;
      replicated_state_.Apply(write_request);
    }

    res.success = true;
  }

  // Report the state of our log as it is now, whether or not the batch was accepted.
  res.last_log_term = LastLogTerm();
  res.log_size = state_.log.size();

  Log("returning log_size of ", res.log_size);

  io_.Send(from_address, request_id, res);

  return std::nullopt;
}
/// Leader handling of an AppendResponse: responses for another term are
/// dropped; otherwise the sender's replication progress is updated from the
/// reported log size and the commit point is re-evaluated.
std::optional<Role> Handle(Leader &leader, AppendResponse &&res, RequestId /* variable */, Address from_address) {
  if (res.term != state_.term) {
    Log("received AppendResponse related to a previous term when we (presumably) were the leader");
    return std::nullopt;
  }

  // TODO(tyler) when we have dynamic membership, this assert will become incorrect, but we should
  // keep it in-place until then because it has bug finding value.
  MG_ASSERT(leader.followers.contains(from_address), "received AppendResponse from unknown Follower");

  // at this point, we know the term matches and we know this Follower
  FollowerTracker &tracker = leader.followers.at(from_address);

  Log(res.success ? "got successful AppendResponse from " : "got unsuccessful AppendResponse from ",
      from_address.last_known_port, " with log_size of ", res.log_size);

  // A success may only move next_index forward; a failure resets it to what
  // the follower says it has.
  tracker.next_index = res.success ? std::max(tracker.next_index, res.log_size) : res.log_size;
  tracker.confirmed_log_size = std::max(tracker.confirmed_log_size, res.log_size);

  BumpCommitIndexAndReplyToClients(leader);

  return std::nullopt;
}
template <FollowerOrCandidate FOC>
std::optional<Role> Handle(FOC & /* variable */, AppendResponse && /* variable */, RequestId /* variable */,
Address /* variable */) {
// we used to be the leader, and are getting old delayed responses
return std::nullopt;
}
/////////////////////////////////////////////////////////////
/// RSM-related handle methods
/////////////////////////////////////////////////////////////
// Leaders are able to immediately respond to the requester (with a ReadResponseValue) applied to the ReplicatedState
/// Leader handling of a ReadRequest: run the read against the replicated
/// state right away and send the result back to the requester.
std::optional<Role> Handle(Leader & /* variable */, ReadRequest<ReadOperation> &&req, RequestId request_id,
                           Address from_address) {
  Log("handling ReadOperation");

  ReadOperation operation = req.operation;
  ReadResponseValue result = replicated_state_.Read(operation);

  const ReadResponse<ReadResponseValue> resp{
      .success = true,
      .read_return = std::move(result),
      .retry_leader = std::nullopt,
  };
  io_.Send(from_address, request_id, resp);

  return std::nullopt;
}
// Candidates should respond with a failure, similar to the Candidate + WriteRequest failure below
/// Candidate handling of a ReadRequest: reply with a failure that carries no
/// redirection (no Leader is known), then run Cron.
std::optional<Role> Handle(Candidate & /* variable */, ReadRequest<ReadOperation> && /* variable */,
                           RequestId request_id, Address from_address) {
  Log("received ReadOperation - not redirecting because no Leader is known");

  const ReadResponse<ReadResponseValue> rejection{
      .success = false,
  };
  io_.Send(from_address, request_id, rejection);

  Cron();

  return std::nullopt;
}
// Followers should respond with a redirection, similar to the Follower + WriteRequest response below
/// Follower handling of a ReadRequest: reply with a failure that points the
/// client at the Leader this Follower knows about.
std::optional<Role> Handle(Follower &follower, ReadRequest<ReadOperation> && /* variable */, RequestId request_id,
                           Address from_address) {
  Log("redirecting client to known Leader with port ", follower.leader_address.last_known_port);

  const ReadResponse<ReadResponseValue> redirection{
      .success = false,
      .retry_leader = follower.leader_address,
  };
  io_.Send(from_address, request_id, redirection);

  return std::nullopt;
}
// Raft paper - 8
// When a client first starts up, it connects to a randomly chosen
// server. If the client's first choice is not the leader, that
// server will reject the client's request and supply information
// about the most recent leader it has heard from.
/// Follower handling of a WriteRequest: reply with a failure that points the
/// client at the Leader this Follower knows about.
std::optional<Role> Handle(Follower &follower, WriteRequest<WriteOperation> && /* variable */, RequestId request_id,
                           Address from_address) {
  Log("redirecting client to known Leader with port ", follower.leader_address.last_known_port);

  const WriteResponse<WriteResponseValue> redirection{
      .success = false,
      .retry_leader = follower.leader_address,
  };
  io_.Send(from_address, request_id, redirection);

  return std::nullopt;
}
/// Candidate handling of a WriteRequest: reply with a failure that carries no
/// redirection (no Leader is known), then run Cron.
std::optional<Role> Handle(Candidate & /* variable */, WriteRequest<WriteOperation> && /* variable */,
                           RequestId request_id, Address from_address) {
  Log("received WriteRequest - not redirecting because no Leader is known");

  const WriteResponse<WriteResponseValue> rejection{
      .success = false,
  };
  io_.Send(from_address, request_id, rejection);

  Cron();

  return std::nullopt;
}
// only leaders actually handle replication requests from clients
/// Leader handling of a WriteRequest: append the operation to the log under
/// the current term, remember who to answer once that index is applied, and
/// start replicating right away. The client is answered later, from
/// BumpCommitIndexAndReplyToClients.
std::optional<Role> Handle(Leader &leader, WriteRequest<WriteOperation> &&req, RequestId request_id,
                           Address from_address) {
  Log("handling WriteRequest");

  // we are the leader. add item to log and send Append to peers
  MG_ASSERT(state_.term >= LastLogTerm());

  // The new entry lands at the current end of the log.
  const LogIndex log_index = state_.log.size();
  state_.log.emplace_back(state_.term, std::move(req.operation));

  leader.pending_client_requests.emplace(log_index, PendingClientRequest{
                                                        .request_id = request_id,
                                                        .address = from_address,
                                                        .received_at = io_.Now(),
                                                    });

  BroadcastAppendEntries(leader.followers);

  return std::nullopt;
}
};
}; // namespace memgraph::io::rsm

View File

@ -0,0 +1,8 @@
# Source files that make up the simulator library.
set(io_simulator_sources
simulator_handle.cpp)
# Packages the library links against below.
find_package(fmt REQUIRED)
find_package(Threads REQUIRED)
# Static library target for the simulator.
add_library(mg-io-simulator STATIC ${io_simulator_sources})
target_link_libraries(mg-io-simulator stdc++fs Threads::Threads fmt::fmt mg-utils)

View File

@ -0,0 +1,45 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <memory>
#include <random>
#include "io/address.hpp"
#include "io/simulator/simulator_config.hpp"
#include "io/simulator/simulator_handle.hpp"
#include "io/simulator/simulator_transport.hpp"
namespace memgraph::io::simulator {
class Simulator {
std::mt19937 rng_;
std::shared_ptr<SimulatorHandle> simulator_handle_;
public:
explicit Simulator(SimulatorConfig config)
: rng_(std::mt19937{config.rng_seed}), simulator_handle_{std::make_shared<SimulatorHandle>(config)} {}
void ShutDown() { simulator_handle_->ShutDown(); }
Io<SimulatorTransport> Register(Address address) {
std::uniform_int_distribution<uint64_t> seed_distrib;
uint64_t seed = seed_distrib(rng_);
return Io{SimulatorTransport{simulator_handle_, address, seed}, address};
}
void IncrementServerCountAndWaitForQuiescentState(Address address) {
simulator_handle_->IncrementServerCountAndWaitForQuiescentState(address);
}
SimulatorStats Stats() { return simulator_handle_->Stats(); }
};
}; // namespace memgraph::io::simulator

View File

@ -0,0 +1,30 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include "io/time.hpp"
namespace memgraph::io::simulator {
using memgraph::io::Time;
/// Knobs that control how the simulated network behaves.
struct SimulatorConfig {
  /// Messages are dropped when a uniform draw from [0, 99] is below this value.
  uint8_t drop_percent{0};
  /// Whether a response that arrives after its request's deadline times the request out.
  bool perform_timeouts{false};
  /// Whether in-flight messages are delivered in randomized order.
  bool scramble_messages{true};
  /// Seed for the simulator's random number generators.
  uint64_t rng_seed{0};
  /// Initial value of the simulated cluster-wide clock.
  Time start_time{Time::min()};
  /// The simulation asserts if the simulated clock reaches this time.
  Time abort_time{Time::max()};
};
}; // namespace memgraph::io::simulator

View File

@ -0,0 +1,142 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "io/simulator/simulator_handle.hpp"
#include "io/address.hpp"
#include "io/errors.hpp"
#include "io/simulator/simulator_config.hpp"
#include "io/simulator/simulator_stats.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
namespace memgraph::io::simulator {
using memgraph::io::Duration;
using memgraph::io::Time;
/// Raises the shutdown flag and wakes every thread waiting on the condition variable.
void SimulatorHandle::ShutDown() {
  const std::lock_guard<std::mutex> guard(mu_);

  should_shut_down_ = true;
  cv_.notify_all();
}
/// Reads the shutdown flag under the lock.
bool SimulatorHandle::ShouldShutDown() const {
  const std::lock_guard<std::mutex> guard(mu_);

  return should_shut_down_;
}
/// Registers `address` as a server and blocks the caller until the number of
/// servers blocked on receive equals the number of registered servers.
void SimulatorHandle::IncrementServerCountAndWaitForQuiescentState(Address address) {
  std::unique_lock<std::mutex> lock(mu_);

  server_addresses_.insert(address);

  // The predicate is checked before the first wait and after every wake-up.
  cv_.wait(lock, [this] {
    const size_t blocked_servers = blocked_on_receive_;
    return blocked_servers == server_addresses_.size();
  });
}
// Advances the simulation by one step if every registered server is blocked.
//
// Returns false without doing anything while at least one server is still
// running. Otherwise it times out overdue promises and then either advances
// the simulated clock (no messages in flight) or takes exactly one in-flight
// message and drops it, uses it to complete a waiting promise, or queues it
// for the destination to receive. Returns true in all of those cases.
bool SimulatorHandle::MaybeTickSimulator() {
std::unique_lock<std::mutex> lock(mu_);
const size_t blocked_servers = blocked_on_receive_;
if (blocked_servers < server_addresses_.size()) {
// we only need to advance the simulator when all
// servers have reached a quiescent state, blocked
// on their own futures or receive methods.
return false;
}
stats_.simulator_ticks++;
cv_.notify_all();
TimeoutPromisesPastDeadline();
if (in_flight_.empty()) {
// return early here because there are no messages to schedule
// We tick the clock forward when all servers are blocked but
// there are no in-flight messages to schedule delivery of.
std::poisson_distribution<> time_distrib(50);
Duration clock_advance = std::chrono::microseconds{time_distrib(rng_)};
cluster_wide_time_microseconds_ += clock_advance;
MG_ASSERT(cluster_wide_time_microseconds_ < config_.abort_time,
"Cluster has executed beyond its configured abort_time, and something may be failing to make progress "
"in an expected amount of time.");
return true;
}
if (config_.scramble_messages) {
// scramble messages
// by swapping a randomly chosen message into the last slot, which is the one delivered next
std::uniform_int_distribution<size_t> swap_distrib(0, in_flight_.size() - 1);
const size_t swap_index = swap_distrib(rng_);
std::swap(in_flight_[swap_index], in_flight_.back());
}
auto [to_address, opaque_message] = std::move(in_flight_.back());
in_flight_.pop_back();
// The drop decision is drawn for every message, even when drop_percent is 0.
std::uniform_int_distribution<int> drop_distrib(0, 99);
const int drop_threshold = drop_distrib(rng_);
const bool should_drop = drop_threshold < config_.drop_percent;
if (should_drop) {
stats_.dropped_messages++;
}
// A message is treated as a response when a promise is registered for
// (recipient, request id, sender).
PromiseKey promise_key{.requester_address = to_address,
.request_id = opaque_message.request_id,
.replier_address = opaque_message.from_address};
if (promises_.contains(promise_key)) {
// complete waiting promise if it's there
DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key));
promises_.erase(promise_key);
const bool normal_timeout = config_.perform_timeouts && (dop.deadline < cluster_wide_time_microseconds_);
if (should_drop || normal_timeout) {
// a dropped response is surfaced to the requester as a timeout
stats_.timed_out_requests++;
dop.promise.TimeOut();
} else {
stats_.total_responses++;
dop.promise.Fill(std::move(opaque_message));
}
} else if (should_drop) {
// don't add it anywhere, let it drop
} else {
// no promise is waiting for this message: queue it in can_receive_ for
// the destination, creating the destination's queue if it does not exist yet
const auto &[om_vec, inserted] = can_receive_.try_emplace(to_address, std::vector<OpaqueMessage>());
om_vec->second.emplace_back(std::move(opaque_message));
}
return true;
}
/// Reads the simulated cluster-wide clock under the lock.
Time SimulatorHandle::Now() const {
  const std::lock_guard<std::mutex> guard(mu_);

  return cluster_wide_time_microseconds_;
}
/// Returns a copy of the statistics collected so far, taken under the lock.
SimulatorStats SimulatorHandle::Stats() {
  const std::lock_guard<std::mutex> guard(mu_);

  return stats_;
}
} // namespace memgraph::io::simulator

View File

@ -0,0 +1,175 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <any>
#include <compare>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <utility>
#include <variant>
#include <vector>
#include "io/address.hpp"
#include "io/errors.hpp"
#include "io/message_conversion.hpp"
#include "io/simulator/simulator_config.hpp"
#include "io/simulator/simulator_stats.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
namespace memgraph::io::simulator {
class SimulatorHandle {
mutable std::mutex mu_{};
mutable std::condition_variable cv_;
// messages that have not yet been scheduled or dropped
std::vector<std::pair<Address, OpaqueMessage>> in_flight_;
// the responses to requests that are being waited on
std::map<PromiseKey, DeadlineAndOpaquePromise> promises_;
// messages that are sent to servers that may later receive them
std::map<Address, std::vector<OpaqueMessage>> can_receive_;
Time cluster_wide_time_microseconds_;
bool should_shut_down_ = false;
SimulatorStats stats_;
size_t blocked_on_receive_ = 0;
std::set<Address> server_addresses_;
std::mt19937 rng_;
SimulatorConfig config_;
void TimeoutPromisesPastDeadline() {
const Time now = cluster_wide_time_microseconds_;
for (auto &[promise_key, dop] : promises_) {
if (dop.deadline < now) {
spdlog::debug("timing out request from requester {} to replier {}.", promise_key.requester_address.ToString(),
promise_key.replier_address.ToString());
std::move(dop).promise.TimeOut();
promises_.erase(promise_key);
stats_.timed_out_requests++;
}
}
}
public:
explicit SimulatorHandle(SimulatorConfig config)
: cluster_wide_time_microseconds_(config.start_time), rng_(config.rng_seed), config_(config) {}
void IncrementServerCountAndWaitForQuiescentState(Address address);
/// This method causes most of the interesting simulation logic to happen, wrt network behavior.
/// It checks to see if all background "server" threads are blocked on new messages, and if so,
/// it will decide whether to drop, reorder, or deliver in-flight messages based on the SimulatorConfig
/// that was used to create the Simulator.
bool MaybeTickSimulator();
void ShutDown();
bool ShouldShutDown() const;
template <Message Request, Message Response>
void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request,
Duration timeout, ResponsePromise<Response> &&promise) {
std::unique_lock<std::mutex> lock(mu_);
const Time deadline = cluster_wide_time_microseconds_ + timeout;
std::any message(request);
OpaqueMessage om{.to_address = to_address,
.from_address = from_address,
.request_id = request_id,
.message = std::move(message)};
in_flight_.emplace_back(std::make_pair(to_address, std::move(om)));
PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address};
OpaquePromise opaque_promise(std::move(promise).ToUnique());
DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)};
promises_.emplace(std::move(promise_key), std::move(dop));
stats_.total_messages++;
stats_.total_requests++;
cv_.notify_all();
}
template <Message... Ms>
requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(const Address &receiver, Duration timeout) {
std::unique_lock<std::mutex> lock(mu_);
blocked_on_receive_ += 1;
const Time deadline = cluster_wide_time_microseconds_ + timeout;
while (!should_shut_down_ && (cluster_wide_time_microseconds_ < deadline)) {
if (can_receive_.contains(receiver)) {
std::vector<OpaqueMessage> &can_rx = can_receive_.at(receiver);
if (!can_rx.empty()) {
OpaqueMessage message = std::move(can_rx.back());
can_rx.pop_back();
// TODO(tyler) search for item in can_receive_ that matches the desired types, rather
// than asserting that the last item in can_rx matches.
auto m_opt = std::move(message).Take<Ms...>();
blocked_on_receive_ -= 1;
return std::move(m_opt).value();
}
}
lock.unlock();
bool made_progress = MaybeTickSimulator();
lock.lock();
if (!should_shut_down_ && !made_progress) {
cv_.wait(lock);
}
}
blocked_on_receive_ -= 1;
return TimedOut{};
}
template <Message M>
void Send(Address to_address, Address from_address, RequestId request_id, M message) {
std::unique_lock<std::mutex> lock(mu_);
std::any message_any(std::move(message));
OpaqueMessage om{.to_address = to_address,
.from_address = from_address,
.request_id = request_id,
.message = std::move(message_any)};
in_flight_.emplace_back(std::make_pair(std::move(to_address), std::move(om)));
stats_.total_messages++;
cv_.notify_all();
}
Time Now() const;
template <class D = std::poisson_distribution<>, class Return = uint64_t>
Return Rand(D distrib) {
std::unique_lock<std::mutex> lock(mu_);
return distrib(rng_);
}
SimulatorStats Stats();
};
}; // namespace memgraph::io::simulator

View File

@ -0,0 +1,25 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
namespace memgraph::io::simulator {
/// Plain counters describing what a simulation run has done so far.
/// Every counter starts at zero.
struct SimulatorStats {
  // messages queued for delivery (requests and plain sends)
  uint64_t total_messages{0};
  // NOTE(review): presumably messages discarded by the simulator; confirm against the handle implementation
  uint64_t dropped_messages{0};
  // requests whose promise was timed out instead of being filled
  uint64_t timed_out_requests{0};
  // requests submitted with a response promise attached
  uint64_t total_requests{0};
  // responses that filled a waiting promise
  uint64_t total_responses{0};
  // NOTE(review): presumably the number of simulator ticks performed; confirm against the handle implementation
  uint64_t simulator_ticks{0};
};
}; // namespace memgraph::io::simulator

View File

@ -0,0 +1,65 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <memory>
#include <utility>
#include "io/address.hpp"
#include "io/simulator/simulator_handle.hpp"
#include "io/time.hpp"
namespace memgraph::io::simulator {
using memgraph::io::Duration;
using memgraph::io::Time;
/// Transport implementation that routes every operation through a shared
/// SimulatorHandle on behalf of one fixed address.
class SimulatorTransport {
  std::shared_ptr<SimulatorHandle> simulator_handle_;
  const Address address_;
  std::mt19937 rng_;

 public:
  /// @param simulator_handle shared simulator state that all operations are forwarded to
  /// @param address the address this transport sends from and receives for
  /// @param seed seed of this transport's private random number generator
  SimulatorTransport(std::shared_ptr<SimulatorHandle> simulator_handle, Address address, uint64_t seed)
      : simulator_handle_(std::move(simulator_handle)), address_(address), rng_(std::mt19937{seed}) {}

  /// Submits `request` to `address` and returns a future for the response.
  /// The future is created with a notifier that ticks the simulator.
  template <Message Request, Message Response>
  ResponseFuture<Response> Request(Address address, uint64_t request_id, Request request, Duration timeout) {
    // The notifier is stored inside the returned future, which can outlive or be
    // moved independently of this transport. Capture a copy of the shared handle
    // instead of `this` so the callback never dereferences a dangling pointer.
    std::function<bool()> maybe_tick_simulator = [handle = simulator_handle_] {
      return handle->MaybeTickSimulator();
    };
    auto [future, promise] =
        memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(maybe_tick_simulator);

    simulator_handle_->SubmitRequest(address, address_, request_id, std::move(request), timeout, std::move(promise));

    return std::move(future);
  }

  /// Waits up to `timeout` (simulated time) for a message addressed to this transport.
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) {
    return simulator_handle_->template Receive<Ms...>(address_, timeout);
  }

  /// Queues `message` for delivery to `to_address`; no response is awaited.
  template <Message M>
  void Send(Address to_address, Address from_address, uint64_t request_id, M message) {
    // `message` is a by-value sink parameter: hand it on without a second copy.
    return simulator_handle_->template Send<M>(to_address, from_address, request_id, std::move(message));
  }

  Time Now() const { return simulator_handle_->Now(); }

  bool ShouldShutDown() const { return simulator_handle_->ShouldShutDown(); }

  /// Draws one value from `distrib` using this transport's own generator
  /// (not the generator owned by the simulator handle).
  template <class D = std::poisson_distribution<>, class Return = uint64_t>
  Return Rand(D distrib) {
    return distrib(rng_);
  }
};
}; // namespace memgraph::io::simulator

21
src/io/time.hpp Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
namespace memgraph::io {
// Durations in the io layer are counted in microseconds.
using Duration = std::chrono::microseconds;
// A point in time with microsecond resolution, expressed against std::chrono::system_clock.
using Time = std::chrono::time_point<std::chrono::system_clock, Duration>;
} // namespace memgraph::io

135
src/io/transport.hpp Normal file
View File

@ -0,0 +1,135 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <concepts>
#include <random>
#include <variant>
#include "io/address.hpp"
#include "io/errors.hpp"
#include "io/future.hpp"
#include "io/time.hpp"
#include "utils/result.hpp"
namespace memgraph::io {
using memgraph::utils::BasicResult;
// TODO(tyler) ensure that Message continues to represent
// reasonable constraints around message types over time,
// as we adapt things to use Thrift-generated message types.
// A type satisfies Message only when it is identical to its own decayed type,
// i.e. it is not a reference, not cv-qualified, and not an array or function type.
template <typename T>
concept Message = std::same_as<T, std::decay_t<T>>;
// Identifier attached to a request and carried by the envelopes below.
using RequestId = uint64_t;
// A response of type M together with its routing metadata.
template <Message M>
struct ResponseEnvelope {
M message;
RequestId request_id;
Address to_address;
Address from_address;
};
// Either a TimedOut error or the envelope of a response of type M.
template <Message M>
using ResponseResult = BasicResult<TimedOut, ResponseEnvelope<M>>;
// Future side of a pending request; it yields a ResponseResult<M>.
template <Message M>
using ResponseFuture = memgraph::io::Future<ResponseResult<M>>;
// Promise side of a pending request; it is completed with a ResponseResult<M>.
template <Message M>
using ResponsePromise = memgraph::io::Promise<ResponseResult<M>>;
// A received message, which holds exactly one of the types Ms, together with its routing metadata.
template <Message... Ms>
struct RequestEnvelope {
std::variant<Ms...> message;
RequestId request_id;
Address to_address;
Address from_address;
};
// Either a TimedOut error or the envelope of a received message.
template <Message... Ms>
using RequestResult = BasicResult<TimedOut, RequestEnvelope<Ms...>>;
/// Thin facade over a transport implementation `I`. It owns the local address,
/// generates request ids from a per-instance counter, and applies a default
/// timeout to calls that do not specify one.
template <typename I>
class Io {
  I implementation_;
  Address address_;
  RequestId request_id_counter_ = 0;
  Duration default_timeout_ = std::chrono::microseconds{50000};

 public:
  /// @param io the transport implementation every call is forwarded to
  /// @param address the address used as the sender of messages sent through Send
  Io(I io, Address address) : implementation_(std::move(io)), address_(std::move(address)) {}

  /// Set the default timeout for all requests that are issued
  /// without an explicit timeout set.
  void SetDefaultTimeout(Duration timeout) { default_timeout_ = timeout; }

  /// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients.
  template <Message Request, Message Response>
  ResponseFuture<Response> RequestWithTimeout(Address address, Request request, Duration timeout) {
    const RequestId request_id = ++request_id_counter_;
    // `request` is a by-value sink parameter; move it on, exactly as Request() does.
    return implementation_.template Request<Request, Response>(address, request_id, std::move(request), timeout);
  }

  /// Issue a request that times out after the default timeout. This tends
  /// to be used by clients.
  template <Message Request, Message Response>
  ResponseFuture<Response> Request(Address address, Request request) {
    const RequestId request_id = ++request_id_counter_;
    const Duration timeout = default_timeout_;
    return implementation_.template Request<Request, Response>(address, request_id, std::move(request), timeout);
  }

  /// Wait for an explicit number of microseconds for a request of one of the
  /// provided types to arrive. This tends to be used by servers.
  /// At least one message type must be given, as for Receive().
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) RequestResult<Ms...> ReceiveWithTimeout(Duration timeout) {
    return implementation_.template Receive<Ms...>(timeout);
  }

  /// Wait the default number of microseconds for a request of one of the
  /// provided types to arrive. This tends to be used by servers.
  template <Message... Ms>
  requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive() {
    const Duration timeout = default_timeout_;
    return implementation_.template Receive<Ms...>(timeout);
  }

  /// Send a message in a best-effort fashion. This is used for messaging where
  /// responses are not necessarily expected, and for servers to respond to requests.
  /// If you need reliable delivery, this must be built on-top. TCP is not enough for most use cases.
  template <Message M>
  void Send(Address to_address, RequestId request_id, M message) {
    Address from_address = address_;
    return implementation_.template Send<M>(to_address, from_address, request_id, std::move(message));
  }

  /// The current system time. This time source should be preferred over any other,
  /// because it lets us deterministically control clocks from tests for making
  /// things like timeouts deterministic.
  Time Now() const { return implementation_.Now(); }

  /// Returns true if the system should shut-down.
  bool ShouldShutDown() const { return implementation_.ShouldShutDown(); }

  /// Returns a random number within the specified distribution.
  template <class D = std::poisson_distribution<>, class Return = uint64_t>
  Return Rand(D distrib) {
    return implementation_.template Rand<D, Return>(distrib);
  }

  /// Returns a copy of the address this Io instance sends from.
  Address GetAddress() const { return address_; }
};
}; // namespace memgraph::io

223
src/query/v2/requests.hpp Normal file
View File

@ -0,0 +1,223 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <cstdint>
#include <iostream>
#include <map>
#include <new>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
#include "storage/v3/id_types.hpp"
#include "storage/v3/property_value.hpp"
/// Hybrid-logical clock
/// Pairs a logical counter with a wall-clock reading; two Hlc values are equal
/// only when both members are equal (defaulted member-wise comparison).
struct Hlc {
uint64_t logical_id;
// Microsecond-resolution time types, declared locally in this struct.
using Duration = std::chrono::microseconds;
using Time = std::chrono::time_point<std::chrono::system_clock, Duration>;
Time coordinator_wall_clock;
bool operator==(const Hlc &other) const = default;
};
// A label, referred to by a numeric id.
struct Label {
size_t id;
};
// TODO(kostasrim) update this with CompoundKey, same for the rest of the file.
// A primary key is an ordered list of property values.
using PrimaryKey = std::vector<memgraph::storage::v3::PropertyValue>;
// A vertex is identified by a label together with a primary key.
using VertexId = std::pair<Label, PrimaryKey>;
using Gid = size_t;
using PropertyId = memgraph::storage::v3::PropertyId;
// An edge type, referred to by name.
struct EdgeType {
std::string name;
};
// Identifies an edge by a vertex id plus a Gid.
// NOTE(review): presumably `id` is the owning/source vertex and `gid` the edge's own id - confirm.
struct EdgeId {
VertexId id;
Gid gid;
};
// A vertex: its id and the list of labels attached to it.
struct Vertex {
VertexId id;
std::vector<Label> labels;
};
// An edge: the ids of its two endpoint vertices and its type.
struct Edge {
VertexId src;
VertexId dst;
EdgeType type;
};
// One step of a path: the edge taken (by Gid) and the vertex it leads to.
struct PathPart {
Vertex dst;
Gid edge;
};
// A path: a starting vertex followed by a sequence of steps.
struct Path {
Vertex src;
std::vector<PathPart> parts;
};
// Empty placeholder type, used as the "no value" alternative of Value.
struct Null {};
struct Value {
enum Type { NILL, BOOL, INT64, DOUBLE, STRING, LIST, MAP, VERTEX, EDGE, PATH };
union {
Null null_v;
bool bool_v;
uint64_t int_v;
double double_v;
std::string string_v;
std::vector<Value> list_v;
std::map<std::string, Value> map_v;
Vertex vertex_v;
Edge edge_v;
Path path_v;
};
Type type;
};
// Values keyed by property id.
struct ValuesMap {
std::unordered_map<PropertyId, Value> values_map;
};
// A list of property-id-keyed value maps.
struct MappedValues {
std::vector<ValuesMap> values_map;
};
// A list of rows, each row being a positional list of values.
struct ListedValues {
std::vector<std::vector<Value>> properties;
};
// Result values are reported either positionally (ListedValues) or keyed by property id (MappedValues).
using Values = std::variant<ListedValues, MappedValues>;
// An expression carried as text.
struct Expression {
std::string expression;
};
// A filter expression carried as text.
struct Filter {
std::string filter_expression;
};
enum class OrderingDirection { ASCENDING = 1, DESCENDING = 2 };
// One ordering criterion: the expression to order by and the direction.
struct OrderBy {
Expression expression;
OrderingDirection direction;
};
// NOTE(review): presumably selects which version of the data a read observes - confirm against the storage layer.
enum class StorageView { OLD = 0, NEW = 1 };
// Request to scan vertices, optionally restricted by properties, filters and a batch limit.
struct ScanVerticesRequest {
Hlc transaction_id;
// NOTE(review): this is a size_t while ScanVerticesResponse::next_start_id is a VertexId - confirm which is intended.
size_t start_id;
std::optional<std::vector<std::string>> props_to_return;
std::optional<std::vector<std::string>> filter_expressions;
std::optional<size_t> batch_limit;
StorageView storage_view;
};
// Result of a vertex scan.
struct ScanVerticesResponse {
bool success;
Values values;
// NOTE(review): presumably empty when there is nothing left to scan - confirm.
std::optional<VertexId> next_start_id;
};
// Either a vertex id or an edge id.
using VertexOrEdgeIds = std::variant<VertexId, EdgeId>;
// Request to read properties and/or expressions of a vertex or an edge.
struct GetPropertiesRequest {
Hlc transaction_id;
VertexOrEdgeIds vertex_or_edge_ids;
std::vector<PropertyId> property_ids;
std::vector<Expression> expressions;
bool only_unique = false;
std::optional<std::vector<OrderBy>> order_by;
std::optional<size_t> limit;
std::optional<Filter> filter;
};
// Result of a GetPropertiesRequest.
struct GetPropertiesResponse {
bool success;
Values values;
};
// Edge direction selector; BOTH has the numeric value of OUT and IN combined (1 | 2 == 3).
enum class EdgeDirection : uint8_t { OUT = 1, IN = 2, BOTH = 3 };
// Request to expand from a set of source vertices along edges of the given types and direction.
struct ExpandOneRequest {
Hlc transaction_id;
std::vector<VertexId> src_vertices;
std::vector<EdgeType> edge_types;
EdgeDirection direction;
bool only_unique_neighbor_rows = false;
// The empty optional means return all of the properties, while an empty
// list means do not return any properties
// TODO(antaljanosbenjamin): All of the special values should be communicated through a single vertex object
// after schema is implemented
// Special values are accepted:
// * __mg__labels
std::optional<std::vector<PropertyId>> src_vertex_properties;
// TODO(antaljanosbenjamin): All of the special values should be communicated through a single vertex object
// after schema is implemented
// Special values are accepted:
// * __mg__dst_id (Vertex, but without labels)
// * __mg__type (binary)
std::optional<std::vector<PropertyId>> edge_properties;
// QUESTION(antaljanosbenjamin): Maybe also add possibility to expressions evaluated on the source vertex?
// List of expressions evaluated on edges
std::vector<Expression> expressions;
std::optional<std::vector<OrderBy>> order_by;
std::optional<size_t> limit;
std::optional<Filter> filter;
};
// One result row of an expansion: a source vertex, optionally its properties, and its edges.
struct ExpandOneResultRow {
// NOTE: This struct could be a single Values with columns something like this:
// src_vertex(Vertex), vertex_prop1(Value), vertex_prop2(Value), edges(list<Value>)
// where edges might be a list of:
// 1. list<Value> if only a defined list of edge properties are returned
// 2. map<binary, Value> if all of the edge properties are returned
// The drawback of this is currently the key of the map is always interpreted as a string in Value, not as an
// integer, which should be in case of mapped properties.
Vertex src_vertex;
std::optional<Values> src_vertex_properties;
Values edges;
};
// Result of an ExpandOneRequest: a list of result rows.
struct ExpandOneResponse {
std::vector<ExpandOneResultRow> result;
};
// Description of a vertex to create: its labels and its property values keyed by property id.
struct NewVertex {
std::vector<Label> label_ids;
std::map<PropertyId, Value> properties;
};
// Request to create a batch of vertices.
struct CreateVerticesRequest {
Hlc transaction_id;
std::vector<NewVertex> new_vertices;
};
// Result of a CreateVerticesRequest.
struct CreateVerticesResponse {
bool success;
};
// The request/response types grouped by whether they read or write.
// There is currently a single write request/response pair, so those aliases are not variants.
using ReadRequests = std::variant<ExpandOneRequest, GetPropertiesRequest, ScanVerticesRequest>;
using ReadResponses = std::variant<ExpandOneResponse, GetPropertiesResponse, ScanVerticesResponse>;
using WriteRequests = CreateVerticesRequest;
using WriteResponses = CreateVerticesResponse;

View File

@ -10,6 +10,9 @@ add_subdirectory(stress)
# concurrent test binaries
add_subdirectory(concurrent)
# simulation test binaries
add_subdirectory(simulation)
# manual test binaries
add_subdirectory(manual)

View File

@ -62,3 +62,6 @@ target_link_libraries(${test_prefix}storage_v2_gc mg-storage-v2)
add_benchmark(storage_v2_property_store.cpp)
target_link_libraries(${test_prefix}storage_v2_property_store mg-storage-v2)
add_benchmark(future.cpp)
target_link_libraries(${test_prefix}future mg-io)

View File

@ -0,0 +1,30 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <benchmark/benchmark.h>
#include "io/future.hpp"
/// Benchmarks one full round trip through a future/promise pair: create the
/// pair, fill the promise, then wait on the future. Each round trip is
/// reported as one processed item.
static void FuturePairFillWait(benchmark::State &state) {
  uint64_t round_trips = 0;

  for (auto _ : state) {
    auto [future, promise] = memgraph::io::FuturePromisePair<int>();
    promise.Fill(1);
    std::move(future).Wait();

    round_trips += 1;
  }

  state.SetItemsProcessed(round_trips);
}
BENCHMARK(FuturePairFillWait)->Unit(benchmark::kNanosecond)->UseRealTime();
BENCHMARK_MAIN();

View File

@ -0,0 +1,32 @@
set(test_prefix memgraph__simulation__)
find_package(gflags)
add_custom_target(memgraph__simulation)
# Builds one simulation test executable and registers it with CTest.
# Arguments:
#   test_cpp - test source file; the executable is named after it, minus the extension
#   san      - value passed to -fsanitize= for both compiling and linking (e.g. address)
# The new target is also added as a dependency of the memgraph__simulation umbrella target.
function(add_simulation_test test_cpp san)
# get exec name (remove extension from the abs path)
get_filename_component(exec_name ${test_cpp} NAME_WE)
set(target_name ${test_prefix}${exec_name})
add_executable(${target_name} ${test_cpp})
# OUTPUT_NAME sets the real name of a target when it is built and can be
# used to help create two targets of the same name even though CMake
# requires unique logical target names
set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name})
target_link_libraries(${target_name} gtest gmock mg-utils mg-io mg-io-simulator)
# sanitize
target_compile_options(${target_name} PRIVATE -fsanitize=${san})
target_link_options(${target_name} PRIVATE -fsanitize=${san})
# register test
add_test(${target_name} ${exec_name})
add_dependencies(memgraph__simulation ${target_name})
endfunction(add_simulation_test)
add_simulation_test(basic_request.cpp address)
add_simulation_test(raft.cpp address)
add_simulation_test(trial_query_storage/query_storage_test.cpp address)

View File

@ -0,0 +1,87 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <thread>
#include "io/simulator/simulator.hpp"
using memgraph::io::Address;
using memgraph::io::Io;
using memgraph::io::ResponseFuture;
using memgraph::io::ResponseResult;
using memgraph::io::simulator::Simulator;
using memgraph::io::simulator::SimulatorConfig;
using memgraph::io::simulator::SimulatorTransport;
// Request sent by the client: a proposed counter value.
struct CounterRequest {
uint64_t proposal;
};
// Response sent by the server: the largest proposal it has seen so far.
struct CounterResponse {
uint64_t highest_seen;
};
/// Server loop: until shutdown is requested, receive CounterRequests, track the
/// largest proposal seen, and answer each request with that maximum.
void run_server(Io<SimulatorTransport> io) {
  uint64_t max_proposal = 0;

  while (!io.ShouldShutDown()) {
    std::cout << "[SERVER] Is receiving..." << std::endl;
    auto received = io.Receive<CounterRequest>();

    if (!received.HasError()) {
      auto envelope = received.GetValue();
      auto counter_request = std::get<CounterRequest>(envelope.message);

      // keep the running maximum of all proposals
      if (max_proposal < counter_request.proposal) {
        max_proposal = counter_request.proposal;
      }

      auto reply = CounterResponse{max_proposal};
      io.Send(envelope.from_address, envelope.request_id, reply);
    } else {
      // a receive error means nothing arrived; just try again
      std::cout << "[SERVER] Error, continue" << std::endl;
    }
  }
}
/// Runs one client against one simulated counter server: sends the proposals
/// 1 and 2 and asserts that every successful response reports the proposal
/// just sent as the highest value seen.
int main() {
  auto config = SimulatorConfig{
      .drop_percent = 0,
      .perform_timeouts = true,
      .scramble_messages = true,
      .rng_seed = 0,
  };
  auto simulator = Simulator(config);

  auto cli_addr = Address::TestAddress(1);
  auto srv_addr = Address::TestAddress(2);

  Io<SimulatorTransport> cli_io = simulator.Register(cli_addr);
  Io<SimulatorTransport> srv_io = simulator.Register(srv_addr);

  auto srv_thread = std::jthread(run_server, std::move(srv_io));
  simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr);

  // The proposal and the reported maximum are both uint64_t, so the loop index
  // uses the same type to avoid a signed/unsigned assignment and comparison.
  for (uint64_t i = 1; i < 3; ++i) {
    // send request
    CounterRequest cli_req;
    cli_req.proposal = i;
    auto res_f = cli_io.Request<CounterRequest, CounterResponse>(srv_addr, cli_req);
    auto res_rez = std::move(res_f).Wait();
    if (!res_rez.HasError()) {
      std::cout << "[CLIENT] Got a valid response" << std::endl;
      auto env = res_rez.GetValue();
      MG_ASSERT(env.message.highest_seen == i);
    } else {
      std::cout << "[CLIENT] Got an error" << std::endl;
    }
  }

  simulator.ShutDown();
  return 0;
}

315
tests/simulation/raft.cpp Normal file
View File

@ -0,0 +1,315 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <deque>
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <thread>
#include <vector>
#include "io/address.hpp"
#include "io/rsm/raft.hpp"
#include "io/simulator/simulator.hpp"
#include "io/simulator/simulator_transport.hpp"
using memgraph::io::Address;
using memgraph::io::Duration;
using memgraph::io::Io;
using memgraph::io::ResponseEnvelope;
using memgraph::io::ResponseFuture;
using memgraph::io::ResponseResult;
using memgraph::io::Time;
using memgraph::io::rsm::Raft;
using memgraph::io::rsm::ReadRequest;
using memgraph::io::rsm::ReadResponse;
using memgraph::io::rsm::WriteRequest;
using memgraph::io::rsm::WriteResponse;
using memgraph::io::simulator::Simulator;
using memgraph::io::simulator::SimulatorConfig;
using memgraph::io::simulator::SimulatorStats;
using memgraph::io::simulator::SimulatorTransport;
/// Compare-and-swap request on a single integer key.
/// An empty `new_value` asks for the key to be deleted.
struct CasRequest {
  int key;
  std::optional<int> old_value;
  std::optional<int> new_value;
};

/// Outcome of a CasRequest: whether it was applied, and the value stored under
/// the key before the request was processed (empty if the key was absent).
struct CasResponse {
  bool cas_success;
  std::optional<int> last_value;
};

/// Read request for a single key.
struct GetRequest {
  int key;
};

/// Value stored under the requested key, or empty if the key is absent.
struct GetResponse {
  std::optional<int> value;
};

/// Minimal key-value state machine used to exercise Raft in the simulation.
class TestState {
  std::map<int, int> state_;

 public:
  /// Returns the value stored under `request.key`, if any.
  GetResponse Read(GetRequest request) {
    GetResponse ret;
    const auto it = state_.find(request.key);
    if (it != state_.end()) {
      ret.value = it->second;
    }
    return ret;
  }

  /// Applies a compare-and-swap request.
  ///  * key absent, new_value set:   create the key (old_value is not checked).
  ///  * key absent, new_value empty: nothing to delete; reported as success.
  ///  * key present, new_value empty: delete the key (old_value is not checked).
  ///  * key present, new_value set:  update only if old_value equals the stored value.
  /// `last_value` always reports what was stored before this request.
  CasResponse Apply(CasRequest request) {
    CasResponse ret;

    const auto it = state_.find(request.key);

    /*
     *   Create
     */
    if (it == state_.end()) {
      ret.last_value = std::nullopt;
      ret.cas_success = true;
      // A delete of a missing key must not call value() on an empty optional.
      if (request.new_value) {
        state_.emplace(request.key, *request.new_value);
      }
      return ret;
    }

    ret.last_value = it->second;

    /*
     *   Delete
     */
    if (!request.new_value) {
      ret.cas_success = true;
      state_.erase(it);
      // Return immediately: the element is gone, so the update logic below
      // must not touch it.
      return ret;
    }

    /*
     *   Update
     */
    // Does old_value match?
    if (request.old_value == it->second) {
      ret.cas_success = true;
      it->second = *request.new_value;
    } else {
      ret.cas_success = false;
    }

    return ret;
  }
};
// Thread entry point: takes a Raft server by value and calls its Run() method.
template <typename IoImpl>
void RunRaft(Raft<IoImpl, TestState, CasRequest, CasResponse, GetRequest, GetResponse> server) {
server.Run();
}
/// Runs one simulated three-server Raft cluster with a lossy, reordering
/// network and a single client. The client keeps issuing compare-and-swap
/// requests (following leader redirects, or trying a random server when it
/// gets no redirect) until one CAS succeeds and the written value can be read
/// back from the leader. It then shuts the simulator down and logs its stats.
void RunSimulation() {
  SimulatorConfig config{
      .drop_percent = 5,
      .perform_timeouts = true,
      .scramble_messages = true,
      .rng_seed = 0,
      .start_time = Time::min() + std::chrono::microseconds{256 * 1024},
      .abort_time = Time::min() + std::chrono::microseconds{8 * 1024 * 128},
  };

  auto simulator = Simulator(config);

  auto cli_addr = Address::TestAddress(1);
  auto srv_addr_1 = Address::TestAddress(2);
  auto srv_addr_2 = Address::TestAddress(3);
  auto srv_addr_3 = Address::TestAddress(4);

  Io<SimulatorTransport> cli_io = simulator.Register(cli_addr);
  Io<SimulatorTransport> srv_io_1 = simulator.Register(srv_addr_1);
  Io<SimulatorTransport> srv_io_2 = simulator.Register(srv_addr_2);
  Io<SimulatorTransport> srv_io_3 = simulator.Register(srv_addr_3);

  std::vector<Address> srv_1_peers = {srv_addr_2, srv_addr_3};
  std::vector<Address> srv_2_peers = {srv_addr_1, srv_addr_3};
  std::vector<Address> srv_3_peers = {srv_addr_1, srv_addr_2};

  // TODO(tyler / gabor) supply default TestState to Raft constructor
  using RaftClass = Raft<SimulatorTransport, TestState, CasRequest, CasResponse, GetRequest, GetResponse>;
  RaftClass srv_1{std::move(srv_io_1), srv_1_peers, TestState{}};
  RaftClass srv_2{std::move(srv_io_2), srv_2_peers, TestState{}};
  RaftClass srv_3{std::move(srv_io_3), srv_3_peers, TestState{}};

  auto srv_thread_1 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_1));
  simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_1);

  auto srv_thread_2 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_2));
  simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_2);

  auto srv_thread_3 = std::jthread(RunRaft<SimulatorTransport>, std::move(srv_3));
  simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr_3);

  spdlog::info("beginning test after servers have become quiescent");

  std::mt19937 cli_rng_{0};
  Address server_addrs[]{srv_addr_1, srv_addr_2, srv_addr_3};
  Address leader = server_addrs[0];

  const int key = 0;
  std::optional<int> last_known_value = 0;

  bool success = false;

  for (int i = 0; !success; i++) {
    // send request
    CasRequest cas_req;
    cas_req.key = key;

    cas_req.old_value = last_known_value;

    cas_req.new_value = i;

    WriteRequest<CasRequest> cli_req;
    cli_req.operation = cas_req;

    spdlog::info("client sending CasRequest to Leader {} ", leader.last_known_port);
    ResponseFuture<WriteResponse<CasResponse>> cas_response_future =
        cli_io.Request<WriteRequest<CasRequest>, WriteResponse<CasResponse>>(leader, cli_req);

    // receive cas_response
    ResponseResult<WriteResponse<CasResponse>> cas_response_result = std::move(cas_response_future).Wait();

    if (cas_response_result.HasError()) {
      spdlog::info("client timed out while trying to communicate with assumed Leader server {}",
                   leader.last_known_port);
      continue;
    }

    ResponseEnvelope<WriteResponse<CasResponse>> cas_response_envelope = cas_response_result.GetValue();
    WriteResponse<CasResponse> write_cas_response = cas_response_envelope.message;

    if (write_cas_response.retry_leader) {
      MG_ASSERT(!write_cas_response.success, "retry_leader should never be set for successful responses");
      leader = write_cas_response.retry_leader.value();
      spdlog::info("client redirected to leader server {}", leader.last_known_port);
      // The write was not applied, so its write_return carries no result.
      // Retry against the new leader instead of interpreting it.
      continue;
    }

    if (!write_cas_response.success) {
      std::uniform_int_distribution<size_t> addr_distrib(0, 2);
      size_t addr_index = addr_distrib(cli_rng_);
      leader = server_addrs[addr_index];

      spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index,
                   leader.last_known_port);
      continue;
    }

    CasResponse cas_response = write_cas_response.write_return;

    bool cas_succeeded = cas_response.cas_success;

    // last_known_value can be empty (a failed CAS may report that the key does
    // not exist), so it must not be dereferenced unconditionally.
    if (last_known_value) {
      spdlog::info("Client received CasResponse! success: {} last_known_value {}", cas_succeeded, *last_known_value);
    } else {
      spdlog::info("Client received CasResponse! success: {} last_known_value <none>", cas_succeeded);
    }

    if (cas_succeeded) {
      last_known_value = i;
    } else {
      last_known_value = cas_response.last_value;
      continue;
    }

    GetRequest get_req;
    get_req.key = key;

    ReadRequest<GetRequest> read_req;
    read_req.operation = get_req;

    spdlog::info("client sending GetRequest to Leader {}", leader.last_known_port);
    ResponseFuture<ReadResponse<GetResponse>> get_response_future =
        cli_io.Request<ReadRequest<GetRequest>, ReadResponse<GetResponse>>(leader, read_req);

    // receive response
    ResponseResult<ReadResponse<GetResponse>> get_response_result = std::move(get_response_future).Wait();

    if (get_response_result.HasError()) {
      spdlog::info("client timed out while trying to communicate with Leader server {}", leader.last_known_port);
      continue;
    }

    ResponseEnvelope<ReadResponse<GetResponse>> get_response_envelope = get_response_result.GetValue();
    ReadResponse<GetResponse> read_get_response = get_response_envelope.message;

    // Redirects have to be examined before the generic failure case, otherwise
    // a failed read could never update `leader`.
    if (read_get_response.retry_leader) {
      MG_ASSERT(!read_get_response.success, "retry_leader should never be set for successful responses");
      leader = read_get_response.retry_leader.value();
      spdlog::info("client redirected to Leader server {}", leader.last_known_port);
      continue;
    }

    if (!read_get_response.success) {
      // sent to a non-leader
      std::uniform_int_distribution<size_t> addr_distrib(0, 2);
      size_t addr_index = addr_distrib(cli_rng_);
      leader = server_addrs[addr_index];

      spdlog::info("client NOT redirected to leader server, trying a random one at index {} with port {}", addr_index,
                   leader.last_known_port);
      continue;
    }

    GetResponse get_response = read_get_response.read_return;

    MG_ASSERT(get_response.value == i);

    spdlog::info("client successfully cas'd a value and read it back! value: {}", i);

    success = true;
  }

  MG_ASSERT(success);

  simulator.ShutDown();

  SimulatorStats stats = simulator.Stats();

  spdlog::info("total messages:     {}", stats.total_messages);
  spdlog::info("dropped messages:   {}", stats.dropped_messages);
  spdlog::info("timed out requests: {}", stats.timed_out_requests);
  spdlog::info("total requests:     {}", stats.total_requests);
  spdlog::info("total responses:    {}", stats.total_responses);
  spdlog::info("simulator ticks:    {}", stats.simulator_ticks);

  spdlog::info("========================== SUCCESS :) ==========================");

  /*
  this is implicit in jthread's dtor
  srv_thread_1.join();
  srv_thread_2.join();
  srv_thread_3.join();
  */
}
// Entry point of the simulation test binary: executes RunSimulation() a fixed
// number of times back to back, logging a banner and a column header before
// each run and a summary line once every run has returned.
int main() {
  const int n_tests = 50;

  int run = 0;
  while (run < n_tests) {
    spdlog::info("========================== NEW SIMULATION {} ==========================", run);
    spdlog::info("\tTime\t\tTerm\tPort\tRole\t\tMessage\n");
    RunSimulation();
    ++run;
  }

  spdlog::info("passed {} tests!", n_tests);
  return 0;
}

View File

@ -0,0 +1,34 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include <string>
#include <vector>
namespace memgraph::tests::simulation {
// A vertex as carried inside the simulation test messages. It holds nothing
// but a string key.
struct Vertex {
std::string key;
};
// Request message asking a server for a batch of vertices.
struct ScanVerticesRequest {
// Number of vertices requested. The simulated storage server in this
// changeset produces exactly this many vertices per request.
int64_t count;
// Optional resume position. The simulated storage server in this changeset
// uses it as the index of the first vertex to return and starts at 0 when it
// is empty.
std::optional<int64_t> continuation;
};
// Response message carrying a batch of vertices.
struct VerticesResponse {
// The vertices returned for the request.
std::vector<Vertex> vertices;
// Optional continuation value.
// NOTE(review): presumably the position a follow-up ScanVerticesRequest should
// resume from — confirm; the simulated storage server in this changeset never
// sets it.
std::optional<int64_t> continuation;
};
} // namespace memgraph::tests::simulation

View File

@ -0,0 +1,85 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <iostream>
#include "io/address.hpp"
#include "io/simulator/simulator.hpp"
#include "io/simulator/simulator_config.hpp"
#include "io/simulator/simulator_transport.hpp"
#include "io/transport.hpp"
#include "messages.hpp"
namespace memgraph::tests::simulation {
using memgraph::io::Io;
using memgraph::io::simulator::SimulatorTransport;
// Simulated storage server loop for the ScanVertices test.
//
// Until `io` reports shutdown, waits for a ScanVerticesRequest and answers it
// with a VerticesResponse holding `count` synthetic vertices whose keys are
// "Vertex_<index>". Indices start at the request's continuation value, or at 0
// when the request carries none. A receive that reports an error is logged and
// the loop simply tries again.
//
// @param io  transport handle used both to receive requests and to send the
//            responses back to the requester.
void run_server(Io<SimulatorTransport> io) {
  while (!io.ShouldShutDown()) {
    std::cout << "[STORAGE] Is receiving..." << std::endl;
    auto request_result = io.Receive<ScanVerticesRequest>();
    if (request_result.HasError()) {
      std::cout << "[STORAGE] Error, continue" << std::endl;
      continue;
    }

    auto request_envelope = request_result.GetValue();
    // Bind by reference: the envelope is a local that outlives every use of the
    // request, so there is no need to copy it out of the variant.
    const auto &req = std::get<ScanVerticesRequest>(request_envelope.message);

    // value_or always yields int64_t, whichever builtin type int64_t aliases.
    // A lambda that returns `*continuation` on one path and `0L` on the other
    // only compiles on platforms where int64_t happens to be `long`.
    const int64_t start_index = req.continuation.value_or(0);

    VerticesResponse response{};
    for (auto index = start_index; index < start_index + req.count; ++index) {
      response.vertices.push_back({std::string("Vertex_") + std::to_string(index)});
    }

    io.Send(request_envelope.from_address, request_envelope.request_id, response);
  }
}
} // namespace memgraph::tests::simulation
// Drives one ScanVertices request/response round trip through the simulator:
// registers a client and a server address, runs run_server on its own thread,
// requests two vertices and validates the reply.
//
// Returns 0 when a response with the requested number of vertices arrives and
// 1 otherwise, so that a broken round trip actually fails the test.
int main() {
  using memgraph::io::Address;
  using memgraph::io::Io;
  using memgraph::io::simulator::Simulator;
  using memgraph::io::simulator::SimulatorConfig;
  using memgraph::io::simulator::SimulatorTransport;
  using memgraph::tests::simulation::run_server;
  using memgraph::tests::simulation::ScanVerticesRequest;
  using memgraph::tests::simulation::VerticesResponse;

  auto config = SimulatorConfig{
      .drop_percent = 0,
      .perform_timeouts = true,
      .scramble_messages = true,
      .rng_seed = 0,
  };
  auto simulator = Simulator(config);

  auto cli_addr = Address::TestAddress(1);
  auto srv_addr = Address::TestAddress(2);

  Io<SimulatorTransport> cli_io = simulator.Register(cli_addr);
  Io<SimulatorTransport> srv_io = simulator.Register(srv_addr);

  auto srv_thread = std::jthread(run_server, std::move(srv_io));
  simulator.IncrementServerCountAndWaitForQuiescentState(srv_addr);

  constexpr int64_t kRequestedVertices = 2;
  auto req = ScanVerticesRequest{kRequestedVertices, std::nullopt};

  auto res_f = cli_io.Request<ScanVerticesRequest, VerticesResponse>(srv_addr, req);
  auto res_rez = std::move(res_f).Wait();

  // Shut down before looking at the result so that shutdown is requested on
  // every exit path; run_server only leaves its loop once its Io reports
  // shutdown, and srv_thread is joined when it goes out of scope.
  simulator.ShutDown();

  if (res_rez.HasError()) {
    std::cerr << "[CLIENT] ScanVerticesRequest failed" << std::endl;
    return 1;
  }

  auto res_envelope = res_rez.GetValue();
  const auto received_vertices = static_cast<int64_t>(res_envelope.message.vertices.size());
  if (received_vertices != kRequestedVertices) {
    std::cerr << "[CLIENT] expected " << kRequestedVertices << " vertices, got " << received_vertices << std::endl;
    return 1;
  }

  return 0;
}

View File

@ -407,3 +407,11 @@ find_package(Boost REQUIRED)
add_unit_test(websocket.cpp)
target_link_libraries(${test_prefix}websocket mg-communication Boost::headers)
# Unit test for the io Future/Promise pair (future.cpp), linked against mg-io
add_unit_test(future.cpp)
target_link_libraries(${test_prefix}future mg-io)
# Unit test for the io local transport (local_transport.cpp), linked against mg-io
add_unit_test(local_transport.cpp)
target_link_libraries(${test_prefix}local_transport mg-io)

55
tests/unit/future.cpp Normal file
View File

@ -0,0 +1,55 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "io/future.hpp"
using namespace memgraph::io;
// Thread body used by the test: completes the given promise with the string
// "success".
void Fill(Promise<std::string> promise_1) {
  promise_1.Fill("success");
}
// Thread body used by the test: blocks on `future_1`, expects it to resolve to
// "success", and then fills `promise_2` with "it worked" so the owner of the
// matching future can observe that this thread got its value.
void Wait(Future<std::string> future_1, Promise<std::string> promise_2) {
  const std::string result_1 = std::move(future_1).Wait();
  // EXPECT_EQ reports both operands when the check fails, whereas EXPECT_TRUE
  // on an `==` expression only reports "false".
  EXPECT_EQ(result_1, "success");
  promise_2.Fill("it worked");
}
// Exercises a Future/Promise hand-off across two threads:
//  - t1 runs Wait(): it blocks on future_1 and then fills promise_2;
//  - the notifier attached to the first pair sets `waiting` when it is invoked,
//    and the test holds t2 back until that has happened;
//  - t2 runs Fill(): it fills promise_1, after which future_2 must resolve to
//    "it worked".
TEST(Future, BasicLifecycle) {
  std::atomic_bool waiting = false;

  // NOTE(review): presumably invoked by the future once a thread starts waiting
  // on it — confirm against io/future.hpp.
  std::function<bool()> notifier = [&] {
    waiting.store(true, std::memory_order_seq_cst);
    return false;
  };

  auto [future_1, promise_1] = FuturePromisePairWithNotifier<std::string>(notifier);
  auto [future_2, promise_2] = FuturePromisePair<std::string>();

  std::jthread t1(Wait, std::move(future_1), std::move(promise_2));

  // Spin until the notifier has run, i.e. until `waiting` has been set from
  // the waiting side, before starting the thread that fills the promise.
  while (!waiting.load(std::memory_order_acquire)) {
    std::this_thread::yield();
  }

  std::jthread t2(Fill, std::move(promise_1));

  t1.join();
  t2.join();

  const std::string result_2 = std::move(future_2).Wait();
  // EXPECT_EQ prints both values on a mismatch, unlike EXPECT_TRUE(a == b).
  EXPECT_EQ(result_2, "it worked");
}

View File

@ -0,0 +1,86 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <thread>
#include <gtest/gtest.h>
#include "io/local_transport/local_system.hpp"
#include "io/local_transport/local_transport.hpp"
#include "io/transport.hpp"
namespace memgraph::io::tests {
using memgraph::io::local_transport::LocalSystem;
using memgraph::io::local_transport::LocalTransport;
// Request message of the local-transport test: the value a client proposes to
// the counter server.
struct CounterRequest {
uint64_t proposal;
};
// Response message of the local-transport test. RunServer fills it with the
// largest proposal it has received so far.
struct CounterResponse {
uint64_t highest_seen;
};
// Server loop of the local-transport test. Remembers the largest proposal it
// has received and answers every CounterRequest with that running maximum.
// The loop ends once `io` reports shutdown; a receive that returns an error is
// logged and retried.
void RunServer(Io<LocalTransport> io) {
  uint64_t max_proposal = 0;

  while (!io.ShouldShutDown()) {
    spdlog::info("[SERVER] Is receiving...");
    auto received = io.Receive<CounterRequest>();
    if (received.HasError()) {
      spdlog::info("[SERVER] timed out, continue");
      continue;
    }

    auto envelope = received.GetValue();
    ASSERT_TRUE(std::holds_alternative<CounterRequest>(envelope.message));
    auto request = std::get<CounterRequest>(envelope.message);

    // Keep the running maximum of all proposals seen so far.
    if (max_proposal < request.proposal) {
      max_proposal = request.proposal;
    }

    auto reply = CounterResponse{max_proposal};
    io.Send(envelope.from_address, envelope.request_id, reply);
  }
}
// Sends two CounterRequests from a client Io to a RunServer thread over the
// local transport and checks that every reply echoes the expected running
// maximum.
//
// Failures are reported through gtest instead of MG_ASSERT, so a broken
// transport records a test failure rather than aborting the whole test binary.
// The system is shut down on every path.
TEST(LocalTransport, BasicRequest) {
  LocalSystem local_system;

  // The two addresses must differ; this relies on UniqueLocalAddress()
  // returning a distinct address on every call.
  auto cli_addr = Address::UniqueLocalAddress();
  auto srv_addr = Address::UniqueLocalAddress();

  Io<LocalTransport> cli_io = local_system.Register(cli_addr);
  Io<LocalTransport> srv_io = local_system.Register(srv_addr);

  auto srv_thread = std::jthread(RunServer, std::move(srv_io));

  // The same proposal is sent in every round, so the server's running maximum
  // (and therefore every reply) must equal it. Typed as uint64_t to match
  // CounterRequest::proposal and CounterResponse::highest_seen exactly.
  const uint64_t value = 1;

  for (int i = 1; i < 3; ++i) {
    // send request
    CounterRequest cli_req;
    cli_req.proposal = value;
    spdlog::info("[CLIENT] sending request");
    auto res_f = cli_io.Request<CounterRequest, CounterResponse>(srv_addr, cli_req);
    spdlog::info("[CLIENT] waiting on future");
    auto res_rez = std::move(res_f).Wait();
    spdlog::info("[CLIENT] future returned");

    if (res_rez.HasError()) {
      // Leave the loop rather than returning, so ShutDown() below still runs.
      ADD_FAILURE() << "request " << i << " returned an error instead of a response";
      break;
    }
    spdlog::info("[CLIENT] Got a valid response");

    auto env = res_rez.GetValue();
    EXPECT_EQ(env.message.highest_seen, value);
  }

  // Always reached: RunServer only leaves its loop after its Io reports
  // shutdown, and srv_thread is joined when it goes out of scope.
  local_system.ShutDown();
}
} // namespace memgraph::io::tests