Merge pull request #634 from memgraph/T1122-MG-ShardManager-ThreadPool

Add shard manager thread pool
This commit is contained in:
Tyler Neely 2022-11-04 16:41:53 +01:00 committed by GitHub
commit d85fb94bc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 772 additions and 207 deletions

View File

@ -0,0 +1,156 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
#include <variant>
#include "coordinator/coordinator.hpp"
#include "coordinator/coordinator_rsm.hpp"
#include "coordinator/shard_map.hpp"
#include "io/address.hpp"
#include "io/future.hpp"
#include "io/messages.hpp"
#include "io/rsm/raft.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
#include "query/v2/requests.hpp"
namespace memgraph::coordinator::coordinator_worker {
/// Obligations:
/// * ShutDown
/// * Cron
/// * RouteMessage
using coordinator::Coordinator;
using coordinator::CoordinatorRsm;
using io::Address;
using io::RequestId;
using io::Time;
using io::messages::CoordinatorMessages;
using msgs::ReadRequests;
using msgs::ReadResponses;
using msgs::WriteRequests;
using msgs::WriteResponses;
/// Tells the worker to stop: processing it makes CoordinatorWorker::Run return.
struct ShutDown {};
/// Asks the worker to invoke the coordinator's periodic Cron() hook.
struct Cron {};
/// A coordinator-bound message together with the envelope data needed to handle it.
struct RouteMessage {
// The payload that is handed to the coordinator.
CoordinatorMessages message;
// Id of the request this message belongs to; forwarded to the coordinator.
RequestId request_id;
// Destination address of the original envelope (carried along, not read by the worker in this file).
Address to;
// Sender address of the original envelope; forwarded to the coordinator.
Address from;
};
/// Everything that can be placed on a coordinator worker's Queue.
using Message = std::variant<RouteMessage, Cron, ShutDown>;
/// Shared state behind a Queue: the pending messages plus the
/// synchronization primitives used to hand them between threads.
struct QueueInner {
// Guards `queue`; Queue::Push and Queue::Pop lock it before touching the deque.
std::mutex mu{};
// Notified after a Push; Pop waits on it while the queue is empty.
std::condition_variable cv;
// TODO(tyler) handle simulator communication std::shared_ptr<std::atomic<int>> blocked;
// TODO(tyler) investigate using a priority queue that prioritizes messages in a way that
// improves overall QoS. For example, maybe we want to schedule raft Append messages
// ahead of Read messages or generally writes before reads for lowering the load on the
// overall system faster etc... When we do this, we need to make sure to avoid
// starvation by sometimes randomizing priorities, rather than following a strict
// prioritization.
// Pending messages in FIFO order: pushed at the back, popped from the front.
std::deque<Message> queue;
};
/// A blocking, unbounded FIFO of Messages.
///
/// There are two reasons to implement our own Queue instead of using
/// one off-the-shelf:
/// 1. we will need to know in the simulator when all threads are waiting
/// 2. we will want to implement our own priority queue within this for QoS
///
/// Copies of a Queue share the same underlying QueueInner, so a producer and
/// a consumer can each hold their own copy. A moved-from Queue has no inner
/// state and must not be used; Push and Pop assert on this.
class Queue {
  std::shared_ptr<QueueInner> inner_ = std::make_shared<QueueInner>();

 public:
  /// Appends `message` to the back of the queue and wakes up waiting consumers.
  void Push(Message &&message) {
    // use_count() is 0 only for an empty (e.g. moved-from) shared_ptr
    MG_ASSERT(inner_.use_count() > 0);
    {
      const std::lock_guard<std::mutex> guard(inner_->mu);
      inner_->queue.emplace_back(std::move(message));
    }  // lock dropped before notifying condition variable
    inner_->cv.notify_all();
  }

  /// Blocks until a message is available, then removes and returns the oldest one.
  Message Pop() {
    MG_ASSERT(inner_.use_count() > 0);
    std::unique_lock<std::mutex> lock(inner_->mu);
    // The predicate overload re-checks the condition after every wakeup,
    // so spurious wakeups never let us read from an empty queue.
    inner_->cv.wait(lock, [this] { return !inner_->queue.empty(); });
    Message message = std::move(inner_->queue.front());
    inner_->queue.pop_front();
    return message;
  }
};
/// A CoordinatorWorker owns a CoordinatorRsm instance and feeds it the
/// messages that the MachineManager pushes onto the shared Queue.
///
/// Run() blocks on the queue and only returns after a ShutDown message has
/// been processed, so it is meant to be executed on its own thread.
template <typename IoImpl>
class CoordinatorWorker {
  // NOTE: coordinator_ is constructed from io_, so io_ must stay declared
  // before coordinator_ (members are initialized in declaration order).
  io::Io<IoImpl> io_;
  Queue queue_;
  CoordinatorRsm<IoImpl> coordinator_;

  // Each Process overload handles one Message alternative and returns
  // whether Run() should keep looping.

  bool Process(ShutDown && /*shut_down*/) { return false; }

  bool Process(Cron && /* cron */) {
    coordinator_.Cron();
    return true;
  }

  bool Process(RouteMessage &&route_message) {
    coordinator_.Handle(std::move(route_message.message), route_message.request_id, route_message.from);
    return true;
  }

 public:
  /// @param io Io handle kept by the worker; the coordinator RSM receives a ForkLocal() of it.
  /// @param queue Queue to consume from; copies of a Queue share the same underlying state.
  /// @param coordinator Initial coordinator state, moved into the RSM.
  ///
  /// NOTE(review): ForkLocal() hands the RSM an Io with a freshly forked address rather
  /// than io_'s own address. MachineManager sets the worker Io's address to the
  /// coordinator address before constructing the worker — confirm the extra fork is intended.
  CoordinatorWorker(io::Io<IoImpl> io, Queue queue, Coordinator coordinator)
      : io_(std::move(io)),
        queue_(std::move(queue)),
        // ForkLocal() already yields a temporary, so no std::move is needed here.
        coordinator_{io_.ForkLocal(), {}, std::move(coordinator)} {}

  CoordinatorWorker(CoordinatorWorker &&) noexcept = default;
  CoordinatorWorker &operator=(CoordinatorWorker &&) noexcept = default;
  CoordinatorWorker(const CoordinatorWorker &) = delete;
  CoordinatorWorker &operator=(const CoordinatorWorker &) = delete;

  ~CoordinatorWorker() = default;

  /// Pops and processes messages until one of them requests shutdown.
  void Run() {
    while (true) {
      Message message = queue_.Pop();

      // Dispatch to the Process overload matching the alternative held by `message`.
      const bool should_continue = std::visit(
          [this](auto &&msg) { return this->Process(std::forward<decltype(msg)>(msg)); }, std::move(message));

      if (!should_continue) {
        return;
      }
    }
  }
};
} // namespace memgraph::coordinator::coordinator_worker

View File

@ -31,14 +31,9 @@ class LocalTransport {
: local_transport_handle_(std::move(local_transport_handle)) {} : local_transport_handle_(std::move(local_transport_handle)) {}
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestId request_id, RequestT request, ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) {
Duration timeout) { return local_transport_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address,
auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<ResponseT>>(); std::move(request), timeout);
local_transport_handle_->SubmitRequest(to_address, from_address, request_id, std::move(request), timeout,
std::move(promise));
return std::move(future);
} }
template <Message... Ms> template <Message... Ms>
@ -61,8 +56,6 @@ class LocalTransport {
return distrib(rng); return distrib(rng);
} }
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies() { LatencyHistogramSummaries ResponseLatencies() { return local_transport_handle_->ResponseLatencies(); }
return local_transport_handle_->ResponseLatencies();
}
}; };
}; // namespace memgraph::io::local_transport }; // namespace memgraph::io::local_transport

View File

@ -30,6 +30,7 @@ class LocalTransportHandle {
mutable std::condition_variable cv_; mutable std::condition_variable cv_;
bool should_shut_down_ = false; bool should_shut_down_ = false;
MessageHistogramCollector histograms_; MessageHistogramCollector histograms_;
RequestId request_id_counter_ = 0;
// the responses to requests that are being waited on // the responses to requests that are being waited on
std::map<PromiseKey, DeadlineAndOpaquePromise> promises_; std::map<PromiseKey, DeadlineAndOpaquePromise> promises_;
@ -56,7 +57,7 @@ class LocalTransportHandle {
return should_shut_down_; return should_shut_down_;
} }
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies() { LatencyHistogramSummaries ResponseLatencies() {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
return histograms_.ResponseLatencies(); return histograms_.ResponseLatencies();
} }
@ -113,8 +114,7 @@ class LocalTransportHandle {
.message = std::move(message_any), .message = std::move(message_any),
.type_info = type_info}; .type_info = type_info};
PromiseKey promise_key{ PromiseKey promise_key{.requester_address = to_address, .request_id = opaque_message.request_id};
.requester_address = to_address, .request_id = opaque_message.request_id, .replier_address = from_address};
{ {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
@ -139,8 +139,10 @@ class LocalTransportHandle {
} }
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
void SubmitRequest(Address to_address, Address from_address, RequestId request_id, RequestT &&request, ResponseFuture<ResponseT> SubmitRequest(Address to_address, Address from_address, RequestT &&request,
Duration timeout, ResponsePromise<ResponseT> promise) { Duration timeout) {
auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<ResponseT>>();
const bool port_matches = to_address.last_known_port == from_address.last_known_port; const bool port_matches = to_address.last_known_port == from_address.last_known_port;
const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip; const bool ip_matches = to_address.last_known_ip == from_address.last_known_ip;
@ -149,17 +151,22 @@ class LocalTransportHandle {
const auto now = Now(); const auto now = Now();
const Time deadline = now + timeout; const Time deadline = now + timeout;
RequestId request_id = 0;
{ {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
PromiseKey promise_key{ request_id = ++request_id_counter_;
.requester_address = from_address, .request_id = request_id, .replier_address = to_address}; PromiseKey promise_key{.requester_address = from_address, .request_id = request_id};
OpaquePromise opaque_promise(std::move(promise).ToUnique()); OpaquePromise opaque_promise(std::move(promise).ToUnique());
DeadlineAndOpaquePromise dop{.requested_at = now, .deadline = deadline, .promise = std::move(opaque_promise)}; DeadlineAndOpaquePromise dop{.requested_at = now, .deadline = deadline, .promise = std::move(opaque_promise)};
MG_ASSERT(!promises_.contains(promise_key));
promises_.emplace(std::move(promise_key), std::move(dop)); promises_.emplace(std::move(promise_key), std::move(dop));
} // lock dropped } // lock dropped
Send(to_address, from_address, request_id, std::forward<RequestT>(request)); Send(to_address, from_address, request_id, std::forward<RequestT>(request));
return std::move(future);
} }
}; };

View File

@ -11,6 +11,8 @@
#pragma once #pragma once
#include <boost/core/demangle.hpp>
#include "io/transport.hpp" #include "io/transport.hpp"
#include "utils/type_info_ref.hpp" #include "utils/type_info_ref.hpp"
@ -19,9 +21,6 @@ namespace memgraph::io {
struct PromiseKey { struct PromiseKey {
Address requester_address; Address requester_address;
uint64_t request_id; uint64_t request_id;
// TODO(tyler) possibly remove replier_address from promise key
// once we want to support DSR.
Address replier_address;
public: public:
friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) { friend bool operator<(const PromiseKey &lhs, const PromiseKey &rhs) {
@ -29,12 +28,8 @@ struct PromiseKey {
return lhs.requester_address < rhs.requester_address; return lhs.requester_address < rhs.requester_address;
} }
if (lhs.request_id != rhs.request_id) {
return lhs.request_id < rhs.request_id; return lhs.request_id < rhs.request_id;
} }
return lhs.replier_address < rhs.replier_address;
}
}; };
struct OpaqueMessage { struct OpaqueMessage {
@ -90,6 +85,10 @@ struct OpaqueMessage {
}; };
} }
std::string demangled_name = "\"" + boost::core::demangle(message.type().name()) + "\"";
spdlog::error("failed to cast message of type {} to expected request type (probably in Receive argument types)",
demangled_name);
return std::nullopt; return std::nullopt;
} }
}; };

View File

@ -20,6 +20,7 @@
#include "io/time.hpp" #include "io/time.hpp"
#include "utils/histogram.hpp" #include "utils/histogram.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/print_helpers.hpp"
#include "utils/type_info_ref.hpp" #include "utils/type_info_ref.hpp"
namespace memgraph::io { namespace memgraph::io {
@ -57,6 +58,35 @@ struct LatencyHistogramSummary {
} }
}; };
/// Latency summaries keyed by a name string (one LatencyHistogramSummary per entry).
struct LatencyHistogramSummaries {
  std::unordered_map<std::string, LatencyHistogramSummary> latencies;

  /// Renders the summaries as a right-aligned, pipe-separated text table:
  /// one header row, one row per entry of `latencies`, then a blank line.
  ///
  /// Rows follow the iteration order of the unordered_map, which is unspecified.
  /// The method only reads `latencies`, so it is const-qualified and can be
  /// called on const objects and references.
  std::string SummaryTable() const {
    std::string output;

    // Appends a single formatted row; the first column is 50 wide, the rest 8.
    const auto row = [&output](const auto &c1, const auto &c2, const auto &c3, const auto &c4, const auto &c5,
                               const auto &c6, const auto &c7) {
      output +=
          fmt::format("{: >50} | {: >8} | {: >8} | {: >8} | {: >8} | {: >8} | {: >8}\n", c1, c2, c3, c4, c5, c6, c7);
    };

    row("name", "count", "min (μs)", "med (μs)", "p99 (μs)", "max (μs)", "sum (ms)");

    for (const auto &[name, histo] : latencies) {
      // The division turns `sum` into the "ms" column; this assumes the stored
      // durations count microseconds — TODO confirm against LatencyHistogramSummary.
      row(name, histo.count, histo.p0.count(), histo.p50.count(), histo.p99.count(), histo.p100.count(),
          histo.sum.count() / 1000);
    }

    output += "\n";
    return output;
  }

  /// Streams the underlying map using the project's print helpers.
  friend std::ostream &operator<<(std::ostream &out, const LatencyHistogramSummaries &histo) {
    using memgraph::utils::print_helpers::operator<<;
    out << histo.latencies;
    return out;
  }
};
class MessageHistogramCollector { class MessageHistogramCollector {
std::unordered_map<utils::TypeInfoRef, utils::Histogram, utils::TypeInfoHasher, utils::TypeInfoEqualTo> histograms_; std::unordered_map<utils::TypeInfoRef, utils::Histogram, utils::TypeInfoHasher, utils::TypeInfoEqualTo> histograms_;
@ -66,7 +96,7 @@ class MessageHistogramCollector {
histo.Measure(duration.count()); histo.Measure(duration.count());
} }
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies() { LatencyHistogramSummaries ResponseLatencies() {
std::unordered_map<std::string, LatencyHistogramSummary> ret{}; std::unordered_map<std::string, LatencyHistogramSummary> ret{};
for (const auto &[type_id, histo] : histograms_) { for (const auto &[type_id, histo] : histograms_) {
@ -90,7 +120,7 @@ class MessageHistogramCollector {
ret.emplace(demangled_name, latency_histogram_summary); ret.emplace(demangled_name, latency_histogram_summary);
} }
return ret; return LatencyHistogramSummaries{.latencies = ret};
} }
}; };

View File

@ -22,6 +22,8 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <boost/core/demangle.hpp>
#include "io/message_conversion.hpp" #include "io/message_conversion.hpp"
#include "io/simulator/simulator.hpp" #include "io/simulator/simulator.hpp"
#include "io/transport.hpp" #include "io/transport.hpp"
@ -109,6 +111,16 @@ utils::TypeInfoRef TypeInfoFor(const WriteResponse<WriteReturn> & /* write_respo
return typeid(WriteReturn); return typeid(WriteReturn);
} }
/// Returns the type info of the request's WriteOperation type.
/// Only the static type is used; the request value itself is not inspected.
template <class WriteOperation>
utils::TypeInfoRef TypeInfoFor(const WriteRequest<WriteOperation> & /* write_request */) {
return typeid(WriteOperation);
}
/// Overload for write requests whose operation is a std::variant. Being the more
/// specialized template, it is preferred over the generic WriteRequest overload.
/// Delegates to TypeInfoForVariant on the held operation (presumably yielding the
/// type info of the active alternative — confirm against TypeInfoForVariant).
template <class... WriteOperations>
utils::TypeInfoRef TypeInfoFor(const WriteRequest<std::variant<WriteOperations...>> &write_request) {
return TypeInfoForVariant(write_request.operation);
}
/// AppendRequest is a raft-level message that the Leader /// AppendRequest is a raft-level message that the Leader
/// periodically broadcasts to all Follower peers. This /// periodically broadcasts to all Follower peers. This
/// serves three main roles: /// serves three main roles:
@ -569,7 +581,7 @@ class Raft {
const Time now = io_.Now(); const Time now = io_.Now();
const Duration broadcast_timeout = RandomTimeout(kMinimumBroadcastTimeout, kMaximumBroadcastTimeout); const Duration broadcast_timeout = RandomTimeout(kMinimumBroadcastTimeout, kMaximumBroadcastTimeout);
if (now - leader.last_broadcast > broadcast_timeout) { if (now > leader.last_broadcast + broadcast_timeout) {
BroadcastAppendEntries(leader.followers); BroadcastAppendEntries(leader.followers);
leader.last_broadcast = now; leader.last_broadcast = now;
} }
@ -918,7 +930,9 @@ class Raft {
// only leaders actually handle replication requests from clients // only leaders actually handle replication requests from clients
std::optional<Role> Handle(Leader &leader, WriteRequest<WriteOperation> &&req, RequestId request_id, std::optional<Role> Handle(Leader &leader, WriteRequest<WriteOperation> &&req, RequestId request_id,
Address from_address) { Address from_address) {
Log("handling WriteRequest"); auto type_info = TypeInfoFor(req);
std::string demangled_name = boost::core::demangle(type_info.get().name());
Log("handling WriteRequest<" + demangled_name + ">");
// we are the leader. add item to log and send Append to peers // we are the leader. add item to log and send Append to peers
MG_ASSERT(state_.term >= LastLogTerm()); MG_ASSERT(state_.term >= LastLogTerm());

View File

@ -31,7 +31,7 @@ bool SimulatorHandle::ShouldShutDown() const {
return should_shut_down_; return should_shut_down_;
} }
std::unordered_map<std::string, LatencyHistogramSummary> SimulatorHandle::ResponseLatencies() { LatencyHistogramSummaries SimulatorHandle::ResponseLatencies() {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
return histograms_.ResponseLatencies(); return histograms_.ResponseLatencies();
} }
@ -108,9 +108,7 @@ bool SimulatorHandle::MaybeTickSimulator() {
stats_.dropped_messages++; stats_.dropped_messages++;
} }
PromiseKey promise_key{.requester_address = to_address, PromiseKey promise_key{.requester_address = to_address, .request_id = opaque_message.request_id};
.request_id = opaque_message.request_id,
.replier_address = opaque_message.from_address};
if (promises_.contains(promise_key)) { if (promises_.contains(promise_key)) {
// complete waiting promise if it's there // complete waiting promise if it's there

View File

@ -56,14 +56,14 @@ class SimulatorHandle {
std::uniform_int_distribution<int> drop_distrib_{0, 99}; std::uniform_int_distribution<int> drop_distrib_{0, 99};
SimulatorConfig config_; SimulatorConfig config_;
MessageHistogramCollector histograms_; MessageHistogramCollector histograms_;
RequestId request_id_counter_{0};
void TimeoutPromisesPastDeadline() { void TimeoutPromisesPastDeadline() {
const Time now = cluster_wide_time_microseconds_; const Time now = cluster_wide_time_microseconds_;
for (auto it = promises_.begin(); it != promises_.end();) { for (auto it = promises_.begin(); it != promises_.end();) {
auto &[promise_key, dop] = *it; auto &[promise_key, dop] = *it;
if (dop.deadline < now && config_.perform_timeouts) { if (dop.deadline < now && config_.perform_timeouts) {
spdlog::info("timing out request from requester {} to replier {}.", promise_key.requester_address.ToString(), spdlog::info("timing out request from requester {}.", promise_key.requester_address.ToString());
promise_key.replier_address.ToString());
std::move(dop).promise.TimeOut(); std::move(dop).promise.TimeOut();
it = promises_.erase(it); it = promises_.erase(it);
@ -78,7 +78,7 @@ class SimulatorHandle {
explicit SimulatorHandle(SimulatorConfig config) explicit SimulatorHandle(SimulatorConfig config)
: cluster_wide_time_microseconds_(config.start_time), rng_(config.rng_seed), config_(config) {} : cluster_wide_time_microseconds_(config.start_time), rng_(config.rng_seed), config_(config) {}
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies(); LatencyHistogramSummaries ResponseLatencies();
~SimulatorHandle() { ~SimulatorHandle() {
for (auto it = promises_.begin(); it != promises_.end();) { for (auto it = promises_.begin(); it != promises_.end();) {
@ -101,12 +101,17 @@ class SimulatorHandle {
bool ShouldShutDown() const; bool ShouldShutDown() const;
template <Message Request, Message Response> template <Message Request, Message Response>
void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request, ResponseFuture<Response> SubmitRequest(Address to_address, Address from_address, Request &&request, Duration timeout,
Duration timeout, ResponsePromise<Response> &&promise) { std::function<bool()> &&maybe_tick_simulator) {
auto type_info = TypeInfoFor(request); auto type_info = TypeInfoFor(request);
auto [future, promise] = memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(
std::forward<std::function<bool()>>(maybe_tick_simulator));
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
RequestId request_id = ++request_id_counter_;
const Time deadline = cluster_wide_time_microseconds_ + timeout; const Time deadline = cluster_wide_time_microseconds_ + timeout;
std::any message(request); std::any message(request);
@ -117,19 +122,24 @@ class SimulatorHandle {
.type_info = type_info}; .type_info = type_info};
in_flight_.emplace_back(std::make_pair(to_address, std::move(om))); in_flight_.emplace_back(std::make_pair(to_address, std::move(om)));
PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address}; PromiseKey promise_key{.requester_address = from_address, .request_id = request_id};
OpaquePromise opaque_promise(std::move(promise).ToUnique()); OpaquePromise opaque_promise(std::move(promise).ToUnique());
DeadlineAndOpaquePromise dop{ DeadlineAndOpaquePromise dop{
.requested_at = cluster_wide_time_microseconds_, .requested_at = cluster_wide_time_microseconds_,
.deadline = deadline, .deadline = deadline,
.promise = std::move(opaque_promise), .promise = std::move(opaque_promise),
}; };
MG_ASSERT(!promises_.contains(promise_key));
promises_.emplace(std::move(promise_key), std::move(dop)); promises_.emplace(std::move(promise_key), std::move(dop));
stats_.total_messages++; stats_.total_messages++;
stats_.total_requests++; stats_.total_requests++;
cv_.notify_all(); cv_.notify_all();
return std::move(future);
} }
template <Message... Ms> template <Message... Ms>

View File

@ -33,16 +33,11 @@ class SimulatorTransport {
: simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {} : simulator_handle_(simulator_handle), address_(address), rng_(std::mt19937{seed}) {}
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> Request(Address to_address, Address from_address, uint64_t request_id, RequestT request, ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, Duration timeout) {
Duration timeout) {
std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); }; std::function<bool()> maybe_tick_simulator = [this] { return simulator_handle_->MaybeTickSimulator(); };
auto [future, promise] =
memgraph::io::FuturePromisePairWithNotifier<ResponseResult<ResponseT>>(maybe_tick_simulator);
simulator_handle_->SubmitRequest(to_address, from_address, request_id, std::move(request), timeout, return simulator_handle_->template SubmitRequest<RequestT, ResponseT>(to_address, from_address, std::move(request),
std::move(promise)); timeout, std::move(maybe_tick_simulator));
return std::move(future);
} }
template <Message... Ms> template <Message... Ms>
@ -64,8 +59,6 @@ class SimulatorTransport {
return distrib(rng_); return distrib(rng_);
} }
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies() { LatencyHistogramSummaries ResponseLatencies() { return simulator_handle_->ResponseLatencies(); }
return simulator_handle_->ResponseLatencies();
}
}; };
}; // namespace memgraph::io::simulator }; // namespace memgraph::io::simulator

View File

@ -68,7 +68,6 @@ template <typename I>
class Io { class Io {
I implementation_; I implementation_;
Address address_; Address address_;
RequestId request_id_counter_ = 0;
Duration default_timeout_ = std::chrono::microseconds{100000}; Duration default_timeout_ = std::chrono::microseconds{100000};
public: public:
@ -84,20 +83,17 @@ class Io {
/// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients. /// Issue a request with an explicit timeout in microseconds provided. This tends to be used by clients.
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) { ResponseFuture<ResponseT> RequestWithTimeout(Address address, RequestT request, Duration timeout) {
const RequestId request_id = ++request_id_counter_;
const Address from_address = address_; const Address from_address = address_;
return implementation_.template Request<RequestT, ResponseT>(address, from_address, request_id, request, timeout); return implementation_.template Request<RequestT, ResponseT>(address, from_address, request, timeout);
} }
/// Issue a request that times out after the default timeout. This tends /// Issue a request that times out after the default timeout. This tends
/// to be used by clients. /// to be used by clients.
template <Message RequestT, Message ResponseT> template <Message RequestT, Message ResponseT>
ResponseFuture<ResponseT> Request(Address to_address, RequestT request) { ResponseFuture<ResponseT> Request(Address to_address, RequestT request) {
const RequestId request_id = ++request_id_counter_;
const Duration timeout = default_timeout_; const Duration timeout = default_timeout_;
const Address from_address = address_; const Address from_address = address_;
return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, request_id, return implementation_.template Request<RequestT, ResponseT>(to_address, from_address, std::move(request), timeout);
std::move(request), timeout);
} }
/// Wait for an explicit number of microseconds for a request of one of the /// Wait for an explicit number of microseconds for a request of one of the
@ -143,8 +139,6 @@ class Io {
Io<I> ForkLocal() { return Io(implementation_, address_.ForkUniqueAddress()); } Io<I> ForkLocal() { return Io(implementation_, address_.ForkUniqueAddress()); }
std::unordered_map<std::string, LatencyHistogramSummary> ResponseLatencies() { LatencyHistogramSummaries ResponseLatencies() { return implementation_.ResponseLatencies(); }
return implementation_.ResponseLatencies();
}
}; };
}; // namespace memgraph::io }; // namespace memgraph::io

View File

@ -11,7 +11,11 @@
#pragma once #pragma once
#include <algorithm>
#include <thread>
#include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/tcp.hpp>
#include "io/address.hpp" #include "io/address.hpp"
#include "storage/v3/property_value.hpp" #include "storage/v3/property_value.hpp"
#include "storage/v3/schemas.hpp" #include "storage/v3/schemas.hpp"
@ -37,6 +41,7 @@ struct MachineConfig {
bool is_query_engine; bool is_query_engine;
boost::asio::ip::address listen_ip; boost::asio::ip::address listen_ip;
uint16_t listen_port; uint16_t listen_port;
size_t shard_worker_threads = std::max(static_cast<unsigned int>(1), std::thread::hardware_concurrency());
}; };
} // namespace memgraph::machine_manager } // namespace memgraph::machine_manager

View File

@ -11,39 +11,43 @@
#pragma once #pragma once
#include <coordinator/coordinator_rsm.hpp> #include "coordinator/coordinator_rsm.hpp"
#include <io/message_conversion.hpp> #include "coordinator/coordinator_worker.hpp"
#include <io/messages.hpp> #include "io/message_conversion.hpp"
#include <io/rsm/rsm_client.hpp> #include "io/messages.hpp"
#include <io/time.hpp> #include "io/rsm/rsm_client.hpp"
#include <machine_manager/machine_config.hpp> #include "io/time.hpp"
#include <storage/v3/shard_manager.hpp> #include "machine_manager/machine_config.hpp"
#include "storage/v3/shard_manager.hpp"
namespace memgraph::machine_manager { namespace memgraph::machine_manager {
using memgraph::coordinator::Coordinator; using coordinator::Coordinator;
using memgraph::coordinator::CoordinatorReadRequests; using coordinator::CoordinatorReadRequests;
using memgraph::coordinator::CoordinatorReadResponses; using coordinator::CoordinatorReadResponses;
using memgraph::coordinator::CoordinatorRsm; using coordinator::CoordinatorRsm;
using memgraph::coordinator::CoordinatorWriteRequests; using coordinator::CoordinatorWriteRequests;
using memgraph::coordinator::CoordinatorWriteResponses; using coordinator::CoordinatorWriteResponses;
using memgraph::io::ConvertVariant; using coordinator::coordinator_worker::CoordinatorWorker;
using memgraph::io::Duration; using CoordinatorRouteMessage = coordinator::coordinator_worker::RouteMessage;
using memgraph::io::RequestId; using CoordinatorQueue = coordinator::coordinator_worker::Queue;
using memgraph::io::Time; using io::ConvertVariant;
using memgraph::io::messages::CoordinatorMessages; using io::Duration;
using memgraph::io::messages::ShardManagerMessages; using io::RequestId;
using memgraph::io::messages::ShardMessages; using io::Time;
using memgraph::io::messages::StorageReadRequest; using io::messages::CoordinatorMessages;
using memgraph::io::messages::StorageWriteRequest; using io::messages::ShardManagerMessages;
using memgraph::io::rsm::AppendRequest; using io::messages::ShardMessages;
using memgraph::io::rsm::AppendResponse; using io::messages::StorageReadRequest;
using memgraph::io::rsm::ReadRequest; using io::messages::StorageWriteRequest;
using memgraph::io::rsm::VoteRequest; using io::rsm::AppendRequest;
using memgraph::io::rsm::VoteResponse; using io::rsm::AppendResponse;
using memgraph::io::rsm::WriteRequest; using io::rsm::ReadRequest;
using memgraph::io::rsm::WriteResponse; using io::rsm::VoteRequest;
using memgraph::storage::v3::ShardManager; using io::rsm::VoteResponse;
using io::rsm::WriteRequest;
using io::rsm::WriteResponse;
using storage::v3::ShardManager;
/// The MachineManager is responsible for: /// The MachineManager is responsible for:
/// * starting the entire system and ensuring that high-level /// * starting the entire system and ensuring that high-level
@ -62,7 +66,9 @@ template <typename IoImpl>
class MachineManager { class MachineManager {
io::Io<IoImpl> io_; io::Io<IoImpl> io_;
MachineConfig config_; MachineConfig config_;
CoordinatorRsm<IoImpl> coordinator_; Address coordinator_address_;
CoordinatorQueue coordinator_queue_;
std::jthread coordinator_handle_;
ShardManager<IoImpl> shard_manager_; ShardManager<IoImpl> shard_manager_;
Time next_cron_ = Time::min(); Time next_cron_ = Time::min();
@ -72,10 +78,27 @@ class MachineManager {
MachineManager(io::Io<IoImpl> io, MachineConfig config, Coordinator coordinator) MachineManager(io::Io<IoImpl> io, MachineConfig config, Coordinator coordinator)
: io_(io), : io_(io),
config_(config), config_(config),
coordinator_{std::move(io.ForkLocal()), {}, std::move(coordinator)}, coordinator_address_(io.GetAddress().ForkUniqueAddress()),
shard_manager_{io.ForkLocal(), coordinator_.GetAddress()} {} shard_manager_{io.ForkLocal(), config.shard_worker_threads, coordinator_address_} {
auto coordinator_io = io.ForkLocal();
coordinator_io.SetAddress(coordinator_address_);
CoordinatorWorker coordinator_worker{coordinator_io, coordinator_queue_, coordinator};
coordinator_handle_ = std::jthread([coordinator = std::move(coordinator_worker)]() mutable { coordinator.Run(); });
}
Address CoordinatorAddress() { return coordinator_.GetAddress(); } MachineManager(MachineManager &&) noexcept = default;
MachineManager &operator=(MachineManager &&) noexcept = default;
MachineManager(const MachineManager &) = delete;
MachineManager &operator=(const MachineManager &) = delete;
~MachineManager() {
if (coordinator_handle_.joinable()) {
coordinator_queue_.Push(coordinator::coordinator_worker::ShutDown{});
coordinator_handle_.join();
}
}
Address CoordinatorAddress() { return coordinator_address_; }
void Run() { void Run() {
while (!io_.ShouldShutDown()) { while (!io_.ShouldShutDown()) {
@ -85,7 +108,7 @@ class MachineManager {
next_cron_ = Cron(); next_cron_ = Cron();
} }
Duration receive_timeout = next_cron_ - now; Duration receive_timeout = std::max(next_cron_, now) - now;
// Note: this parameter pack must be kept in-sync with the ReceiveWithTimeout parameter pack below // Note: this parameter pack must be kept in-sync with the ReceiveWithTimeout parameter pack below
using AllMessages = using AllMessages =
@ -113,7 +136,7 @@ class MachineManager {
spdlog::info("MM got message to {}", request_envelope.to_address.ToString()); spdlog::info("MM got message to {}", request_envelope.to_address.ToString());
// If message is for the coordinator, cast it to subset and pass it to the coordinator // If message is for the coordinator, cast it to subset and pass it to the coordinator
bool to_coordinator = coordinator_.GetAddress() == request_envelope.to_address; bool to_coordinator = coordinator_address_ == request_envelope.to_address;
if (to_coordinator) { if (to_coordinator) {
std::optional<CoordinatorMessages> conversion_attempt = std::optional<CoordinatorMessages> conversion_attempt =
ConvertVariant<AllMessages, ReadRequest<CoordinatorReadRequests>, AppendRequest<CoordinatorWriteRequests>, ConvertVariant<AllMessages, ReadRequest<CoordinatorReadRequests>, AppendRequest<CoordinatorWriteRequests>,
@ -126,8 +149,13 @@ class MachineManager {
CoordinatorMessages &&cm = std::move(conversion_attempt.value()); CoordinatorMessages &&cm = std::move(conversion_attempt.value());
coordinator_.Handle(std::forward<CoordinatorMessages>(cm), request_envelope.request_id, CoordinatorRouteMessage route_message{
request_envelope.from_address); .message = std::move(cm),
.request_id = request_envelope.request_id,
.to = request_envelope.to_address,
.from = request_envelope.from_address,
};
coordinator_queue_.Push(std::move(route_message));
continue; continue;
} }
@ -168,6 +196,7 @@ class MachineManager {
private: private:
Time Cron() { Time Cron() {
spdlog::info("running MachineManager::Cron, address {}", io_.GetAddress().ToString()); spdlog::info("running MachineManager::Cron, address {}", io_.GetAddress().ToString());
coordinator_queue_.Push(coordinator::coordinator_worker::Cron{});
return shard_manager_.Cron(); return shard_manager_.Cron();
} }
}; };

View File

@ -13,47 +13,50 @@
#include <queue> #include <queue>
#include <set> #include <set>
#include <unordered_map>
#include <boost/functional/hash.hpp>
#include <boost/uuid/uuid.hpp> #include <boost/uuid/uuid.hpp>
#include <coordinator/coordinator.hpp> #include "coordinator/coordinator.hpp"
#include <io/address.hpp>
#include <io/message_conversion.hpp>
#include <io/messages.hpp>
#include <io/rsm/raft.hpp>
#include <io/time.hpp>
#include <io/transport.hpp>
#include <query/v2/requests.hpp>
#include <storage/v3/shard.hpp>
#include <storage/v3/shard_rsm.hpp>
#include "coordinator/shard_map.hpp" #include "coordinator/shard_map.hpp"
#include "io/address.hpp"
#include "io/message_conversion.hpp"
#include "io/messages.hpp"
#include "io/rsm/raft.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
#include "query/v2/requests.hpp"
#include "storage/v3/config.hpp" #include "storage/v3/config.hpp"
#include "storage/v3/shard.hpp"
#include "storage/v3/shard_rsm.hpp"
#include "storage/v3/shard_worker.hpp"
namespace memgraph::storage::v3 { namespace memgraph::storage::v3 {
using boost::uuids::uuid; using boost::uuids::uuid;
using memgraph::coordinator::CoordinatorWriteRequests; using coordinator::CoordinatorWriteRequests;
using memgraph::coordinator::CoordinatorWriteResponses; using coordinator::CoordinatorWriteResponses;
using memgraph::coordinator::HeartbeatRequest; using coordinator::HeartbeatRequest;
using memgraph::coordinator::HeartbeatResponse; using coordinator::HeartbeatResponse;
using memgraph::io::Address; using io::Address;
using memgraph::io::Duration; using io::Duration;
using memgraph::io::Message; using io::Message;
using memgraph::io::RequestId; using io::RequestId;
using memgraph::io::ResponseFuture; using io::ResponseFuture;
using memgraph::io::Time; using io::Time;
using memgraph::io::messages::CoordinatorMessages; using io::messages::CoordinatorMessages;
using memgraph::io::messages::ShardManagerMessages; using io::messages::ShardManagerMessages;
using memgraph::io::messages::ShardMessages; using io::messages::ShardMessages;
using memgraph::io::rsm::Raft; using io::rsm::Raft;
using memgraph::io::rsm::WriteRequest; using io::rsm::WriteRequest;
using memgraph::io::rsm::WriteResponse; using io::rsm::WriteResponse;
using memgraph::msgs::ReadRequests; using msgs::ReadRequests;
using memgraph::msgs::ReadResponses; using msgs::ReadResponses;
using memgraph::msgs::WriteRequests; using msgs::WriteRequests;
using memgraph::msgs::WriteResponses; using msgs::WriteResponses;
using memgraph::storage::v3::ShardRsm; using storage::v3::ShardRsm;
using ShardManagerOrRsmMessage = std::variant<ShardMessages, ShardManagerMessages>; using ShardManagerOrRsmMessage = std::variant<ShardMessages, ShardManagerMessages>;
using TimeUuidPair = std::pair<Time, uuid>; using TimeUuidPair = std::pair<Time, uuid>;
@ -77,7 +80,71 @@ static_assert(kMinimumCronInterval < kMaximumCronInterval,
template <typename IoImpl> template <typename IoImpl>
class ShardManager { class ShardManager {
public: public:
ShardManager(io::Io<IoImpl> io, Address coordinator_leader) : io_(io), coordinator_leader_(coordinator_leader) {} ShardManager(io::Io<IoImpl> io, size_t shard_worker_threads, Address coordinator_leader)
: io_(io), coordinator_leader_(coordinator_leader) {
MG_ASSERT(shard_worker_threads >= 1);
for (int i = 0; i < shard_worker_threads; i++) {
shard_worker::Queue queue;
shard_worker::ShardWorker worker{io, queue};
auto worker_handle = std::jthread([worker = std::move(worker)]() mutable { worker.Run(); });
workers_.emplace_back(queue);
worker_handles_.emplace_back(std::move(worker_handle));
worker_rsm_counts_.emplace_back(0);
}
}
ShardManager(ShardManager &&) noexcept = default;
ShardManager &operator=(ShardManager &&) noexcept = default;
ShardManager(const ShardManager &) = delete;
ShardManager &operator=(const ShardManager &) = delete;
/// Enqueues a ShutDown message on every shard worker queue and drops our queue handles.
~ShardManager() {
  // Iterate by reference: a Queue is a shared_ptr handle, so iterating by value
  // would copy that handle once per worker for no benefit.
  for (auto &worker : workers_) {
    worker.Push(shard_worker::ShutDown{});
  }
  workers_.clear();
  // The jthread handles for our shard worker threads will be
  // blocked on implicitly when worker_handles_ is destroyed.
}
/// Returns the index of the worker that the RSM identified by `to` is mapped to.
///
/// If no mapping exists yet, one is created: the RSM is assigned to a worker with
/// the lowest number of existing mappings (on ties, the highest such index wins),
/// that worker's count is incremented, and the new mapping is recorded.
size_t UuidToWorkerIndex(const uuid &to) {
  // One lookup instead of contains() followed by at().
  if (const auto existing = rsm_worker_mapping_.find(to); existing != rsm_worker_mapping_.end()) {
    return existing->second;
  }
  // We will now create a mapping for this (probably new) shard
  // by choosing the worker with the lowest number of existing
  // mappings.
  size_t min_index = 0;
  size_t min_count = worker_rsm_counts_.at(min_index);
  // A size_t index avoids comparing a signed int against size(). Index 0 is
  // already the starting candidate, so the scan begins at 1.
  for (size_t i = 1; i < worker_rsm_counts_.size(); i++) {
    const size_t worker_count = worker_rsm_counts_[i];
    if (worker_count <= min_count) {
      min_count = worker_count;
      min_index = i;
    }
  }
  worker_rsm_counts_[min_index]++;
  rsm_worker_mapping_.emplace(to, min_index);
  return min_index;
}
/// Enqueues `message` on the queue of the worker at position `worker_index`.
void SendToWorkerByIndex(size_t worker_index, shard_worker::Message &&message) {
  auto &target_queue = workers_[worker_index];
  target_queue.Push(std::move(message));
}
/// Looks up (creating if necessary) the worker assigned to `to` and enqueues `message` on it.
void SendToWorkerByUuid(const uuid &to, shard_worker::Message &&message) {
  SendToWorkerByIndex(UuidToWorkerIndex(to), std::move(message));
}
/// Periodic protocol maintenance. Returns the time that Cron should be called again /// Periodic protocol maintenance. Returns the time that Cron should be called again
/// in the future. /// in the future.
@ -85,33 +152,23 @@ class ShardManager {
spdlog::info("running ShardManager::Cron, address {}", io_.GetAddress().ToString()); spdlog::info("running ShardManager::Cron, address {}", io_.GetAddress().ToString());
Time now = io_.Now(); Time now = io_.Now();
if (now >= next_cron_) { if (now >= next_reconciliation_) {
Reconciliation(); Reconciliation();
std::uniform_int_distribution time_distrib(kMinimumCronInterval.count(), kMaximumCronInterval.count()); std::uniform_int_distribution time_distrib(kMinimumCronInterval.count(), kMaximumCronInterval.count());
const auto rand = io_.Rand(time_distrib); const auto rand = io_.Rand(time_distrib);
next_cron_ = now + Duration{rand}; next_reconciliation_ = now + Duration{rand};
} }
if (!cron_schedule_.empty()) { for (auto &worker : workers_) {
const auto &[time, uuid] = cron_schedule_.top(); worker.Push(shard_worker::Cron{});
if (time <= now) {
auto &rsm = rsm_map_.at(uuid);
Time next_for_uuid = rsm.Cron();
cron_schedule_.pop();
cron_schedule_.push(std::make_pair(next_for_uuid, uuid));
const auto &[next_time, _uuid] = cron_schedule_.top();
return std::min(next_cron_, next_time);
}
} }
return next_cron_; Time next_worker_cron = now + std::chrono::milliseconds(500);
return std::min(next_worker_cron, next_reconciliation_);
} }
/// Returns the Address for our underlying Io implementation /// Returns the Address for our underlying Io implementation
@ -125,16 +182,21 @@ class ShardManager {
MG_ASSERT(address.last_known_port == to.last_known_port); MG_ASSERT(address.last_known_port == to.last_known_port);
MG_ASSERT(address.last_known_ip == to.last_known_ip); MG_ASSERT(address.last_known_ip == to.last_known_ip);
auto &rsm = rsm_map_.at(to.unique_id); SendToWorkerByUuid(to.unique_id, shard_worker::RouteMessage{
.message = std::move(sm),
rsm.Handle(std::forward<ShardMessages>(sm), request_id, from); .request_id = request_id,
.to = to,
.from = from,
});
} }
private: private:
io::Io<IoImpl> io_; io::Io<IoImpl> io_;
std::map<uuid, ShardRaft<IoImpl>> rsm_map_; std::vector<shard_worker::Queue> workers_;
std::priority_queue<std::pair<Time, uuid>, std::vector<std::pair<Time, uuid>>, std::greater<>> cron_schedule_; std::vector<std::jthread> worker_handles_;
Time next_cron_ = Time::min(); std::vector<size_t> worker_rsm_counts_;
std::unordered_map<uuid, size_t, boost::hash<boost::uuids::uuid>> rsm_worker_mapping_;
Time next_reconciliation_ = Time::min();
Address coordinator_leader_; Address coordinator_leader_;
std::optional<ResponseFuture<WriteResponse<CoordinatorWriteResponses>>> heartbeat_res_; std::optional<ResponseFuture<WriteResponse<CoordinatorWriteResponses>>> heartbeat_res_;
@ -188,39 +250,22 @@ class ShardManager {
} }
void EnsureShardsInitialized(HeartbeatResponse hr) { void EnsureShardsInitialized(HeartbeatResponse hr) {
for (const auto &shard_to_initialize : hr.shards_to_initialize) { for (const auto &to_init : hr.shards_to_initialize) {
InitializeRsm(shard_to_initialize); initialized_but_not_confirmed_rsm_.emplace(to_init.uuid);
initialized_but_not_confirmed_rsm_.emplace(shard_to_initialize.uuid);
}
}
/// Returns true if the RSM was able to be initialized, and false if it was already initialized if (rsm_worker_mapping_.contains(to_init.uuid)) {
void InitializeRsm(coordinator::ShardToInitialize to_init) {
if (rsm_map_.contains(to_init.uuid)) {
// it's not a bug for the coordinator to send us UUIDs that we have // it's not a bug for the coordinator to send us UUIDs that we have
// already created, because there may have been lag that caused // already created, because there may have been lag that caused
// the coordinator not to hear back from us. // the coordinator not to hear back from us.
return; return;
} }
auto rsm_io = io_.ForkLocal(); size_t worker_index = UuidToWorkerIndex(to_init.uuid);
auto io_addr = rsm_io.GetAddress();
io_addr.unique_id = to_init.uuid;
rsm_io.SetAddress(io_addr);
// TODO(tyler) get peers from Coordinator in HeartbeatResponse SendToWorkerByIndex(worker_index, to_init);
std::vector<Address> rsm_peers = {};
std::unique_ptr<Shard> shard = std::make_unique<Shard>(to_init.label_id, to_init.min_key, to_init.max_key, rsm_worker_mapping_.emplace(to_init.uuid, worker_index);
to_init.schema, to_init.config, to_init.id_to_names); }
ShardRsm rsm_state{std::move(shard)};
ShardRaft<IoImpl> rsm{std::move(rsm_io), rsm_peers, std::move(rsm_state)};
spdlog::info("SM created a new shard with UUID {}", to_init.uuid);
rsm_map_.emplace(to_init.uuid, std::move(rsm));
} }
}; };

View File

@ -0,0 +1,224 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <deque>
#include <memory>
#include <queue>
#include <variant>
#include <boost/uuid/uuid.hpp>
#include "coordinator/coordinator.hpp"
#include "coordinator/shard_map.hpp"
#include "io/address.hpp"
#include "io/future.hpp"
#include "io/messages.hpp"
#include "io/rsm/raft.hpp"
#include "io/time.hpp"
#include "io/transport.hpp"
#include "query/v2/requests.hpp"
#include "storage/v3/shard_rsm.hpp"
namespace memgraph::storage::v3::shard_worker {
/// Obligations:
/// * ShutDown
/// * Cron
/// * RouteMessage
/// * ShardToInitialize
using boost::uuids::uuid;
using coordinator::ShardToInitialize;
using io::Address;
using io::RequestId;
using io::Time;
using io::messages::ShardMessages;
using io::rsm::Raft;
using msgs::ReadRequests;
using msgs::ReadResponses;
using msgs::WriteRequests;
using msgs::WriteResponses;
using storage::v3::ShardRsm;
template <typename IoImpl>
using ShardRaft = Raft<IoImpl, ShardRsm, WriteRequests, WriteResponses, ReadRequests, ReadResponses>;
/// Asks a worker to stop: ShardWorker::Process returns false for it, which ends ShardWorker::Run.
struct ShutDown {};
/// Asks a worker to run its ShardWorker::Cron pass over the RSMs it owns.
struct Cron {};
/// A shard message addressed to one of the RSMs owned by a worker.
struct RouteMessage {
// payload passed to the target RSM's Handle()
ShardMessages message;
// forwarded unchanged to the target RSM's Handle()
RequestId request_id;
// destination address; its unique_id selects the RSM in the worker's rsm_map_
Address to;
// sender address, forwarded to the target RSM's Handle()
Address from;
};
/// Every kind of message a ShardWorker can pop from its Queue; ShardWorker::Run
/// dispatches on the held alternative with std::visit.
using Message = std::variant<ShutDown, Cron, ShardToInitialize, RouteMessage>;
/// The state behind a Queue handle; Queue holds it through a std::shared_ptr.
struct QueueInner {
// held by Queue::Push and Queue::Pop while they touch `queue`
std::mutex mu{};
// notified by Queue::Push after appending; Queue::Pop waits on it while `queue` is empty
std::condition_variable cv;
// TODO(tyler) handle simulator communication std::shared_ptr<std::atomic<int>> blocked;
// TODO(tyler) investigate using a priority queue that prioritizes messages in a way that
// improves overall QoS. For example, maybe we want to schedule raft Append messages
// ahead of Read messages or generally writes before reads for lowering the load on the
// overall system faster etc... When we do this, we need to make sure to avoid
// starvation by sometimes randomizing priorities, rather than following a strict
// prioritization.
// pending messages in FIFO order: Push appends to the back, Pop takes from the front
std::deque<Message> queue;
};
/// There are two reasons to implement our own Queue instead of using
/// one off-the-shelf:
/// 1. we will need to know in the simulator when all threads are waiting
/// 2. we will want to implement our own priority queue within this for QoS
class Queue {
  // Copies of a Queue share this state, so a producer and a consumer can each hold a handle.
  std::shared_ptr<QueueInner> inner_ = std::make_shared<QueueInner>();

 public:
  /// Appends `message` to the back of the queue and notifies all waiters.
  void Push(Message &&message) {
    MG_ASSERT(inner_.use_count() > 0);
    {
      std::lock_guard<std::mutex> guard(inner_->mu);
      inner_->queue.emplace_back(std::move(message));
    }  // lock dropped before notifying condition variable
    inner_->cv.notify_all();
  }

  /// Blocks until the queue is non-empty, then removes and returns the front message.
  Message Pop() {
    MG_ASSERT(inner_.use_count() > 0);
    std::unique_lock<std::mutex> guard(inner_->mu);
    // The predicate form re-checks emptiness after every wake-up.
    inner_->cv.wait(guard, [this] { return !inner_->queue.empty(); });
    Message next = std::move(inner_->queue.front());
    inner_->queue.pop_front();
    return next;
  }
};
/// A ShardWorker owns Raft<ShardRsm> instances and receives messages from the ShardManager.
template <class IoImpl>
class ShardWorker {
io::Io<IoImpl> io_;
Queue queue_;
std::priority_queue<std::pair<Time, uuid>, std::vector<std::pair<Time, uuid>>, std::greater<>> cron_schedule_;
Time next_cron_ = Time::min();
std::map<uuid, ShardRaft<IoImpl>> rsm_map_;
bool Process(ShutDown && /* shut_down */) { return false; }
bool Process(Cron && /* cron */) {
Cron();
return true;
}
bool Process(ShardToInitialize &&shard_to_initialize) {
InitializeRsm(std::forward<ShardToInitialize>(shard_to_initialize));
return true;
}
bool Process(RouteMessage &&route_message) {
auto &rsm = rsm_map_.at(route_message.to.unique_id);
rsm.Handle(std::move(route_message.message), route_message.request_id, route_message.from);
return true;
}
Time Cron() {
spdlog::info("running ShardWorker::Cron, address {}", io_.GetAddress().ToString());
Time now = io_.Now();
while (!cron_schedule_.empty()) {
const auto &[time, uuid] = cron_schedule_.top();
if (time <= now) {
auto &rsm = rsm_map_.at(uuid);
Time next_for_uuid = rsm.Cron();
cron_schedule_.pop();
cron_schedule_.push(std::make_pair(next_for_uuid, uuid));
} else {
return time;
}
}
return now + std::chrono::microseconds(1000);
}
void InitializeRsm(ShardToInitialize to_init) {
if (rsm_map_.contains(to_init.uuid)) {
// it's not a bug for the coordinator to send us UUIDs that we have
// already created, because there may have been lag that caused
// the coordinator not to hear back from us.
return;
}
auto rsm_io = io_.ForkLocal();
auto io_addr = rsm_io.GetAddress();
io_addr.unique_id = to_init.uuid;
rsm_io.SetAddress(io_addr);
// TODO(tyler) get peers from Coordinator in HeartbeatResponse
std::vector<Address> rsm_peers = {};
std::unique_ptr<Shard> shard = std::make_unique<Shard>(to_init.label_id, to_init.min_key, to_init.max_key,
to_init.schema, to_init.config, to_init.id_to_names);
ShardRsm rsm_state{std::move(shard)};
ShardRaft<IoImpl> rsm{std::move(rsm_io), rsm_peers, std::move(rsm_state)};
spdlog::info("SM created a new shard with UUID {}", to_init.uuid);
// perform an initial Cron call for the new RSM
Time next_cron = rsm.Cron();
cron_schedule_.push(std::make_pair(next_cron, to_init.uuid));
rsm_map_.emplace(to_init.uuid, std::move(rsm));
}
public:
ShardWorker(io::Io<IoImpl> io, Queue queue) : io_(io), queue_(queue) {}
ShardWorker(ShardWorker &&) noexcept = default;
ShardWorker &operator=(ShardWorker &&) noexcept = default;
ShardWorker(const ShardWorker &) = delete;
ShardWorker &operator=(const ShardWorker &) = delete;
~ShardWorker() = default;
void Run() {
while (true) {
Message message = queue_.Pop();
const bool should_continue =
std::visit([&](auto &&msg) { return Process(std::forward<decltype(msg)>(msg)); }, std::move(message));
if (!should_continue) {
return;
}
}
}
};
} // namespace memgraph::storage::v3::shard_worker

View File

@ -18,6 +18,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <rapidcheck.h> #include <rapidcheck.h>
#include <rapidcheck/gtest.h> #include <rapidcheck/gtest.h>
#include <spdlog/cfg/env.h>
#include "generated_operations.hpp" #include "generated_operations.hpp"
#include "io/simulator/simulator_config.hpp" #include "io/simulator/simulator_config.hpp"
@ -35,6 +36,8 @@ using storage::v3::kMaximumCronInterval;
RC_GTEST_PROP(RandomClusterConfig, HappyPath, (ClusterConfig cluster_config, NonEmptyOpVec ops)) { RC_GTEST_PROP(RandomClusterConfig, HappyPath, (ClusterConfig cluster_config, NonEmptyOpVec ops)) {
// TODO(tyler) set abort_time to something more restrictive than Time::max() // TODO(tyler) set abort_time to something more restrictive than Time::max()
spdlog::cfg::load_env_levels();
SimulatorConfig sim_config{ SimulatorConfig sim_config{
.drop_percent = 0, .drop_percent = 0,
.perform_timeouts = false, .perform_timeouts = false,

View File

@ -194,6 +194,22 @@ void ExecuteOp(msgs::ShardRequestManager<SimulatorTransport> &shard_request_mana
} }
} }
/// This struct exists as a way of detaching
/// a thread if something causes an uncaught
/// exception - because that thread would not
/// receive a ShutDown message otherwise, and
/// would cause the test to hang forever.
struct DetachIfDropped {
  // The thread handle to detach; it is referenced, not owned.
  std::jthread &handle;
  // Set to false to disarm the guard, so the destructor leaves the handle alone.
  bool detach = true;

  ~DetachIfDropped() {
    // Nothing to do when disarmed or when the handle is no longer joinable.
    if (!detach || !handle.joinable()) {
      return;
    }
    handle.detach();
  }
};
void RunClusterSimulation(const SimulatorConfig &sim_config, const ClusterConfig &cluster_config, void RunClusterSimulation(const SimulatorConfig &sim_config, const ClusterConfig &cluster_config,
const std::vector<Op> &ops) { const std::vector<Op> &ops) {
spdlog::info("========================== NEW SIMULATION =========================="); spdlog::info("========================== NEW SIMULATION ==========================");
@ -217,9 +233,7 @@ void RunClusterSimulation(const SimulatorConfig &sim_config, const ClusterConfig
auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1)); auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1));
// Need to detach this thread so that the destructor does not auto detach_on_error = DetachIfDropped{.handle = mm_thread_1};
// block before we can propagate assertion failures.
mm_thread_1.detach();
// TODO(tyler) clarify addresses of coordinator etc... as it's a mess // TODO(tyler) clarify addresses of coordinator etc... as it's a mess
@ -236,6 +250,11 @@ void RunClusterSimulation(const SimulatorConfig &sim_config, const ClusterConfig
std::visit([&](auto &o) { ExecuteOp(shard_request_manager, correctness_model, o); }, op.inner); std::visit([&](auto &o) { ExecuteOp(shard_request_manager, correctness_model, o); }, op.inner);
} }
// We have now completed our workload without failing any assertions, so we can
// disable detaching the worker thread, which will cause the mm_thread_1 jthread
// to be joined when this function returns.
detach_on_error.detach = false;
simulator.ShutDown(); simulator.ShutDown();
SimulatorStats stats = simulator.Stats(); SimulatorStats stats = simulator.Stats();

View File

@ -400,6 +400,6 @@ target_link_libraries(${test_prefix}pretty_print_ast_to_original_expression_test
add_unit_test(coordinator_shard_map.cpp) add_unit_test(coordinator_shard_map.cpp)
target_link_libraries(${test_prefix}coordinator_shard_map mg-coordinator) target_link_libraries(${test_prefix}coordinator_shard_map mg-coordinator)
# Tests for 1000 shards, 1000 creates, scan # Tests for many shards, many creates, scan
add_unit_test(1k_shards_1k_create_scanall.cpp) add_unit_test(high_density_shard_create_scan.cpp)
target_link_libraries(${test_prefix}1k_shards_1k_create_scanall mg-io mg-coordinator mg-storage-v3 mg-query-v2) target_link_libraries(${test_prefix}high_density_shard_create_scan mg-io mg-coordinator mg-storage-v3 mg-query-v2)

View File

@ -31,7 +31,6 @@
#include "machine_manager/machine_manager.hpp" #include "machine_manager/machine_manager.hpp"
#include "query/v2/requests.hpp" #include "query/v2/requests.hpp"
#include "query/v2/shard_request_manager.hpp" #include "query/v2/shard_request_manager.hpp"
#include "utils/print_helpers.hpp"
#include "utils/variant_helpers.hpp" #include "utils/variant_helpers.hpp"
namespace memgraph::tests::simulation { namespace memgraph::tests::simulation {
@ -82,13 +81,14 @@ struct ScanAll {
}; };
MachineManager<LocalTransport> MkMm(LocalSystem &local_system, std::vector<Address> coordinator_addresses, Address addr, MachineManager<LocalTransport> MkMm(LocalSystem &local_system, std::vector<Address> coordinator_addresses, Address addr,
ShardMap shard_map) { ShardMap shard_map, size_t shard_worker_threads) {
MachineConfig config{ MachineConfig config{
.coordinator_addresses = std::move(coordinator_addresses), .coordinator_addresses = std::move(coordinator_addresses),
.is_storage = true, .is_storage = true,
.is_coordinator = true, .is_coordinator = true,
.listen_ip = addr.last_known_ip, .listen_ip = addr.last_known_ip,
.listen_port = addr.last_known_port, .listen_port = addr.last_known_port,
.shard_worker_threads = shard_worker_threads,
}; };
Io<LocalTransport> io = local_system.Register(addr); Io<LocalTransport> io = local_system.Register(addr);
@ -124,7 +124,7 @@ void WaitForShardsToInitialize(CoordinatorClient<LocalTransport> &coordinator_cl
} }
} }
ShardMap TestShardMap(int n_splits, int replication_factor) { ShardMap TestShardMap(int shards, int replication_factor, int gap_between_shards) {
ShardMap sm{}; ShardMap sm{};
const auto label_name = std::string("test_label"); const auto label_name = std::string("test_label");
@ -147,8 +147,8 @@ ShardMap TestShardMap(int n_splits, int replication_factor) {
MG_ASSERT(label_id.has_value()); MG_ASSERT(label_id.has_value());
// split the shard at N split points // split the shard at N split points
for (int64_t i = 1; i < n_splits; ++i) { for (int64_t i = 1; i < shards; ++i) {
const auto key1 = memgraph::storage::v3::PropertyValue(i); const auto key1 = memgraph::storage::v3::PropertyValue(i * gap_between_shards);
const auto key2 = memgraph::storage::v3::PropertyValue(0); const auto key2 = memgraph::storage::v3::PropertyValue(0);
const auto split_point = {key1, key2}; const auto split_point = {key1, key2};
@ -208,7 +208,16 @@ void ExecuteOp(msgs::ShardRequestManager<LocalTransport> &shard_request_manager,
} }
} }
TEST(MachineManager, ManyShards) { void RunWorkload(int shards, int replication_factor, int create_ops, int scan_ops, int shard_worker_threads,
int gap_between_shards) {
// Log the workload parameters. Each format string needs a `{}` replacement
// field, otherwise the value passed after it never appears in the output.
spdlog::info("======================== NEW TEST ========================");
spdlog::info("shards: {}", shards);
spdlog::info("replication factor: {}", replication_factor);
spdlog::info("create ops: {}", create_ops);
spdlog::info("scan all ops: {}", scan_ops);
spdlog::info("shard worker threads: {}", shard_worker_threads);
spdlog::info("gap between shards: {}", gap_between_shards);
LocalSystem local_system; LocalSystem local_system;
auto cli_addr = Address::TestAddress(1); auto cli_addr = Address::TestAddress(1);
@ -221,19 +230,20 @@ TEST(MachineManager, ManyShards) {
machine_1_addr, machine_1_addr,
}; };
auto shard_splits = 1024; auto time_before_shard_map_creation = cli_io_2.Now();
auto replication_factor = 1; ShardMap initialization_sm = TestShardMap(shards, replication_factor, gap_between_shards);
auto create_ops = 1000; auto time_after_shard_map_creation = cli_io_2.Now();
ShardMap initialization_sm = TestShardMap(shard_splits, replication_factor); auto mm_1 = MkMm(local_system, coordinator_addresses, machine_1_addr, initialization_sm, shard_worker_threads);
auto mm_1 = MkMm(local_system, coordinator_addresses, machine_1_addr, initialization_sm);
Address coordinator_address = mm_1.CoordinatorAddress(); Address coordinator_address = mm_1.CoordinatorAddress();
auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1)); auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1));
CoordinatorClient<LocalTransport> coordinator_client(cli_io, coordinator_address, {coordinator_address}); CoordinatorClient<LocalTransport> coordinator_client(cli_io, coordinator_address, {coordinator_address});
auto time_before_shard_stabilization = cli_io_2.Now();
WaitForShardsToInitialize(coordinator_client); WaitForShardsToInitialize(coordinator_client);
auto time_after_shard_stabilization = cli_io_2.Now();
msgs::ShardRequestManager<LocalTransport> shard_request_manager(std::move(coordinator_client), std::move(cli_io)); msgs::ShardRequestManager<LocalTransport> shard_request_manager(std::move(coordinator_client), std::move(cli_io));
@ -241,18 +251,54 @@ TEST(MachineManager, ManyShards) {
auto correctness_model = std::set<CompoundKey>{}; auto correctness_model = std::set<CompoundKey>{};
auto time_before_creates = cli_io_2.Now();
for (int i = 0; i < create_ops; i++) { for (int i = 0; i < create_ops; i++) {
ExecuteOp(shard_request_manager, correctness_model, CreateVertex{.first = i, .second = i}); ExecuteOp(shard_request_manager, correctness_model, CreateVertex{.first = i, .second = i});
} }
auto time_after_creates = cli_io_2.Now();
for (int i = 0; i < scan_ops; i++) {
ExecuteOp(shard_request_manager, correctness_model, ScanAll{}); ExecuteOp(shard_request_manager, correctness_model, ScanAll{});
}
auto time_after_scan = cli_io_2.Now();
local_system.ShutDown(); local_system.ShutDown();
auto histo = cli_io_2.ResponseLatencies(); auto latencies = cli_io_2.ResponseLatencies();
using memgraph::utils::print_helpers::operator<<; spdlog::info("response latencies: \n{}", latencies.SummaryTable());
std::cout << "response latencies: " << histo << std::endl;
spdlog::info("serial time break-down: (μs)");
spdlog::info("{: >20}: {: >10}", "split shard map",
(time_after_shard_map_creation - time_before_shard_map_creation).count());
spdlog::info("{: >20}: {: >10}", "shard stabilization",
(time_after_shard_stabilization - time_before_shard_stabilization).count());
spdlog::info("{: >20}: {: >10}", "create nodes", (time_after_creates - time_before_creates).count());
spdlog::info("{: >20}: {: >10}", "scan nodes", (time_after_scan - time_after_creates).count());
std::cout << fmt::format("{} {} {}\n", shards, shard_worker_threads, (time_after_scan - time_after_creates).count());
}
TEST(MachineManager, ManyShards) {
auto shards_attempts = {1, 64};
auto shard_worker_thread_attempts = {1, 32};
auto replication_factor = 1;
auto create_ops = 128;
auto scan_ops = 1;
std::cout << "splits threads scan_all_microseconds\n";
for (const auto shards : shards_attempts) {
auto gap_between_shards = create_ops / shards;
for (const auto shard_worker_threads : shard_worker_thread_attempts) {
RunWorkload(shards, replication_factor, create_ops, scan_ops, shard_worker_threads, gap_between_shards);
}
}
} }
} // namespace memgraph::tests::simulation } // namespace memgraph::tests::simulation