diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dff8140db..f5b7859d8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,6 +16,8 @@ set(memgraph_src_files database/graph_db_config.cpp database/graph_db_accessor.cpp database/state_delta.cpp + distributed/coordination_master.cpp + distributed/coordination_worker.cpp durability/paths.cpp durability/recovery.cpp durability/snapshooter.cpp diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 4f1209e13..a5268c5cb 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -2,13 +2,6 @@ #include "glog/logging.h" -#include "io/network/epoll.hpp" -#include "io/network/socket.hpp" -#include "io/network/stream_buffer.hpp" - -#include "query/interpreter.hpp" -#include "transactions/transaction.hpp" - #include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp" #include "communication/bolt/v1/decoder/decoder.hpp" @@ -19,18 +12,25 @@ #include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" +#include "database/graph_db.hpp" +#include "io/network/epoll.hpp" +#include "io/network/socket.hpp" +#include "io/network/stream_buffer.hpp" +#include "query/interpreter.hpp" +#include "transactions/transaction.hpp" DECLARE_int32(session_inactivity_timeout); namespace communication::bolt { -/** - * Bolt SessionData - * - * This class is responsible for holding references to Dbms and Interpreter - * that are passed through the network server and worker to the session. - */ +/** Encapsulates Dbms and Interpreter that are passed through the network server + * and worker to the session. */ struct SessionData { + /** Constructs a SessionData object. + * @param args - Arguments forwarded to the GraphDb constructor. */ + template <typename... TArgs> + SessionData(TArgs &&... args) : db(std::forward<TArgs>(args)...) {} + GraphDb db; query::Interpreter interpreter; }; diff --git a/src/communication/messaging/distributed.cpp b/src/communication/messaging/distributed.cpp index 7da42de34..195bf7df8 100644 --- a/src/communication/messaging/distributed.cpp +++ b/src/communication/messaging/distributed.cpp @@ -9,6 +9,9 @@ System::System(const std::string &address, uint16_t port) StartServer(4); } +System::System(const io::network::NetworkEndpoint &endpoint) + : System(endpoint.address(), endpoint.port()) {} + System::~System() { for (size_t i = 0; i < pool_.size(); ++i) { pool_[i].join(); diff --git a/src/communication/messaging/distributed.hpp b/src/communication/messaging/distributed.hpp index 9e0b0c579..c185a5e46 100644 --- a/src/communication/messaging/distributed.hpp +++ b/src/communication/messaging/distributed.hpp @@ -27,6 +27,7 @@ #include "cereal/types/vector.hpp" #include "communication/server.hpp" +#include "io/network/network_endpoint.hpp" #include "threading/sync/spinlock.hpp" namespace communication::messaging { @@ -59,10 +60,13 @@ class Writer { }; class System { + using Endpoint = io::network::NetworkEndpoint; + public: friend class Writer; System(const std::string &address, uint16_t port); + System(const Endpoint &endpoint); System(const System &) = delete; System(System &&) = delete; System &operator=(const System &) = delete; @@ -75,7 +79,6 @@ class System { const io::network::NetworkEndpoint &endpoint() const { return endpoint_; } private: - using Endpoint = io::network::NetworkEndpoint; using Socket = Socket; using ServerT = communication::Server<Session, SessionData>; diff --git a/src/database/creation_exception.hpp b/src/database/creation_exception.hpp deleted file mode 100644 index 5ff397e16..000000000 --- a/src/database/creation_exception.hpp +++ /dev/null @@ -1,17 +0,0 @@ -// -// Copyright 2017 Memgraph -// Created by Florijan Stamenkovic on 03.02.17. -// - -#pragma once - -#include "utils/exceptions.hpp" - -/** - * Thrown when something (Edge or a Vertex) can not - * be created. Typically due to database overload. - */ -class CreationException : public utils::StacktraceException { - public: - using utils::StacktraceException::StacktraceException; -}; diff --git a/src/database/graph_db.cpp b/src/database/graph_db.cpp index 6b4a13fff..340449542 100644 --- a/src/database/graph_db.cpp +++ b/src/database/graph_db.cpp @@ -3,45 +3,83 @@ #include <glog/logging.h> -#include "database/creation_exception.hpp" #include "database/graph_db.hpp" #include "database/graph_db_accessor.hpp" #include "durability/paths.hpp" #include "durability/recovery.hpp" #include "durability/snapshooter.hpp" -#include "storage/edge.hpp" -#include "storage/garbage_collector.hpp" +#include "storage/concurrent_id_mapper_master.hpp" +#include "storage/concurrent_id_mapper_worker.hpp" #include "transactions/engine_master.hpp" +#include "transactions/engine_worker.hpp" #include "utils/timer.hpp" namespace fs = std::experimental::filesystem; -GraphDb::GraphDb(GraphDb::Config config) +#define INIT_MAPPERS(type, ...) \ + labels_ = std::make_unique<type<GraphDbTypes::Label>>(__VA_ARGS__); \ + edge_types_ = std::make_unique<type<GraphDbTypes::EdgeType>>(__VA_ARGS__); \ + properties_ = std::make_unique<type<GraphDbTypes::Property>>(__VA_ARGS__); + +GraphDb::GraphDb(Config config) : GraphDb(config, 0) { + tx_engine_ = std::make_unique<tx::MasterEngine>(); + INIT_MAPPERS(storage::SingleNodeConcurrentIdMapper); + Start(); +} + +GraphDb::GraphDb(communication::messaging::System &system, + distributed::MasterCoordination &master, Config config) + : GraphDb(config, 0) { + tx_engine_ = std::make_unique<tx::MasterEngine>(system); + INIT_MAPPERS(storage::MasterConcurrentIdMapper, system); + get_endpoint_ = [&master](int worker_id) { + return master.GetEndpoint(worker_id); + }; + Start(); +} + +GraphDb::GraphDb(communication::messaging::System &system, int worker_id, + distributed::WorkerCoordination &worker, + Endpoint master_endpoint, Config config) + : GraphDb(config, worker_id) { + tx_engine_ = std::make_unique<tx::WorkerEngine>(system, master_endpoint); + INIT_MAPPERS(storage::WorkerConcurrentIdMapper, system, master_endpoint); + get_endpoint_ = [&worker](int worker_id) { + return worker.GetEndpoint(worker_id); + }; + Start(); +} + +#undef INIT_MAPPERS + +GraphDb::GraphDb(Config config, int worker_id) : config_(config), - tx_engine_(new tx::MasterEngine()), + worker_id_(worker_id), gc_vertices_(vertices_, vertex_record_deleter_, vertex_version_list_deleter_), gc_edges_(edges_, edge_record_deleter_, edge_version_list_deleter_), - wal_{config.durability_directory, config.durability_enabled} { + wal_{config.durability_directory, config.durability_enabled} {} + +void GraphDb::Start() { // Pause of -1 means we shouldn't run the GC. - if (config.gc_cycle_sec != -1) { - gc_scheduler_.Run(std::chrono::seconds(config.gc_cycle_sec), + if (config_.gc_cycle_sec != -1) { + gc_scheduler_.Run(std::chrono::seconds(config_.gc_cycle_sec), [this]() { CollectGarbage(); }); } // If snapshots are enabled we need the durability dir. - if (config.durability_enabled) - durability::CheckDurabilityDir(config.durability_directory); + if (config_.durability_enabled) + durability::CheckDurabilityDir(config_.durability_directory); - if (config.db_recover_on_startup) - durability::Recover(config.durability_directory, *this); - if (config.durability_enabled) wal_.Enable(); + if (config_.db_recover_on_startup) + durability::Recover(config_.durability_directory, *this); + if (config_.durability_enabled) wal_.Enable(); StartSnapshooting(); - if (config.query_execution_time_sec != -1) { + if (config_.query_execution_time_sec != -1) { transaction_killer_.Run( std::chrono::seconds( - std::max(1, std::min(5, config.query_execution_time_sec / 4))), + std::max(1, std::min(5, config_.query_execution_time_sec / 4))), [this]() { tx_engine_->LocalForEachActiveTransaction([this](tx::Transaction &t) { if (t.creation_time() + @@ -103,8 +141,8 @@ void GraphDb::CollectGarbage() { // the ID of the oldest active transaction (or next active, if there // are no currently active). That's legal because that was the // last possible transaction that could have obtained pointers - // to those records. New snapshot can be used, different than one used for - // first two phases of gc. + // to those records. New snapshot can be used, different than one used + // for the first two phases of gc. utils::Timer x; const auto snapshot = tx_engine_->GlobalGcSnapshot(); edge_record_deleter_.FreeExpiredObjects(snapshot.back()); @@ -131,8 +169,8 @@ GraphDb::~GraphDb() { // Stop the gc scheduler to not run into race conditions for deletions. gc_scheduler_.Stop(); - // Stop the snapshot creator to avoid snapshooting while database is beeing - // deleted. + // Stop the snapshot creator to avoid snapshooting while database is + // being deleted. snapshot_creator_.Stop(); // Stop transaction killer. @@ -157,7 +195,8 @@ GraphDb::~GraphDb() { for (auto &id_vlist : vertices_.access()) delete id_vlist.second; for (auto &id_vlist : edges_.access()) delete id_vlist.second; - // Free expired records with the maximal possible id from all the deleters. + // Free expired records with the maximal possible id from all the + // deleters. edge_record_deleter_.FreeExpiredObjects(tx::Transaction::MaxId()); vertex_record_deleter_.FreeExpiredObjects(tx::Transaction::MaxId()); edge_version_list_deleter_.FreeExpiredObjects(tx::Transaction::MaxId()); diff --git a/src/database/graph_db.hpp b/src/database/graph_db.hpp index fd5adc2a8..43c5f78be 100644 --- a/src/database/graph_db.hpp +++ b/src/database/graph_db.hpp @@ -1,19 +1,25 @@ #pragma once +#include <memory> +#include <mutex> + #include "cppitertools/filter.hpp" #include "cppitertools/imap.hpp" #include "data_structures/concurrent/concurrent_map.hpp" #include "data_structures/concurrent/concurrent_set.hpp" -#include "data_structures/concurrent/skiplist.hpp" #include "database/graph_db_datatypes.hpp" #include "database/indexes/key_index.hpp" #include "database/indexes/label_property_index.hpp" +#include "distributed/coordination_master.hpp" +#include "distributed/coordination_worker.hpp" #include "durability/wal.hpp" +#include "io/network/network_endpoint.hpp" #include "mvcc/version_list.hpp" #include "storage/concurrent_id_mapper.hpp" #include "storage/concurrent_id_mapper_master.hpp" #include "storage/concurrent_id_mapper_single_node.hpp" +#include "storage/concurrent_id_mapper_worker.hpp" #include "storage/deferred_deleter.hpp" #include "storage/edge.hpp" #include "storage/garbage_collector.hpp" @@ -48,6 +54,8 @@ * -> CRASH */ class GraphDb { + using Endpoint = io::network::NetworkEndpoint; + public: /// GraphDb configuration. Initialized from flags, but modifiable. struct Config { @@ -65,15 +73,30 @@ class GraphDb { int query_execution_time_sec; }; - explicit GraphDb(Config config = Config{}); + /** Single-node GraphDb ctor. */ + GraphDb(Config config = Config{}); + + /** Distributed master GraphDb ctor. */ + GraphDb(communication::messaging::System &system, + distributed::MasterCoordination &master, Config config = Config()); + + /** Distributed worker GraphDb ctor. */ + GraphDb(communication::messaging::System &system, int worker_id, + distributed::WorkerCoordination &worker, Endpoint master_endpoint, + Config config = Config()); + + private: + // Private ctor used by other ctors. */ + GraphDb(Config config, int worker_id); + + public: /** Delete all vertices and edges and free all deferred deleters. */ ~GraphDb(); - /** Database object can't be copied. */ GraphDb(const GraphDb &db) = delete; - GraphDb(GraphDb &&other) = default; - GraphDb &operator=(const GraphDb &other) = default; - GraphDb &operator=(GraphDb &&other) = default; + GraphDb(GraphDb &&other) = delete; + GraphDb &operator=(const GraphDb &other) = delete; + GraphDb &operator=(GraphDb &&other) = delete; /** Stop all transactions and set is_accepting_transactions_ to false. */ void Shutdown(); @@ -86,8 +109,6 @@ class GraphDb { private: friend class GraphDbAccessor; - void StartSnapshooting(); - Config config_; /** Transaction engine related to this database. Master instance if this @@ -120,19 +141,11 @@ class GraphDb { // Id to value mappers. // TODO this should be also garbage collected - std::unique_ptr<storage::ConcurrentIdMapper<GraphDbTypes::Label, std::string>> - labels_{new storage::SingleNodeConcurrentIdMapper<GraphDbTypes::Label, - std::string>}; - std::unique_ptr< - storage::ConcurrentIdMapper<GraphDbTypes::EdgeType, std::string>> - edge_types_{ - new storage::SingleNodeConcurrentIdMapper<GraphDbTypes::EdgeType, - std::string>}; - std::unique_ptr< - storage::ConcurrentIdMapper<GraphDbTypes::Property, std::string>> - properties_{ - new storage::SingleNodeConcurrentIdMapper<GraphDbTypes::Property, - std::string>}; + std::unique_ptr<storage::ConcurrentIdMapper<GraphDbTypes::Label>> labels_; + std::unique_ptr<storage::ConcurrentIdMapper<GraphDbTypes::EdgeType>> + edge_types_; + std::unique_ptr<storage::ConcurrentIdMapper<GraphDbTypes::Property>> + properties_; // indexes KeyIndex<GraphDbTypes::Label, Vertex> labels_index_; @@ -152,4 +165,13 @@ class GraphDb { // DB level global counters, used in the "counter" function. ConcurrentMap<std::string, std::atomic<int64_t>> counters_; + + // Returns Endpoint info for worker ID. Different implementation in master vs. + // worker. Unused in single-node version. + std::function<io::network::NetworkEndpoint(int)> get_endpoint_; + + // Starts DB operations once all members have been constructed. + void Start(); + // Starts periodically generating database snapshots. + void StartSnapshooting(); }; diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index af5ea91df..6ffc9adb4 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -1,6 +1,5 @@ #include "glog/logging.h" -#include "database/creation_exception.hpp" #include "database/graph_db_accessor.hpp" #include "storage/edge.hpp" #include "storage/edge_accessor.hpp" diff --git a/src/database/graph_db_accessor.hpp b/src/database/graph_db_accessor.hpp index 3f27670f8..a69352e42 100644 --- a/src/database/graph_db_accessor.hpp +++ b/src/database/graph_db_accessor.hpp @@ -612,9 +612,9 @@ class GraphDbAccessor { /** Casts the DB's engine to MasterEngine and returns it. If the DB's engine * is RemoteEngine, this function will crash MG. */ tx::MasterEngine &MasterEngine() { - auto *local_engine = dynamic_cast<tx::MasterEngine *>(db_.tx_engine_.get()); - DCHECK(local_engine) << "Asked for MasterEngine on distributed worker"; - return *local_engine; + auto *master_engine = dynamic_cast<tx::MasterEngine *>(db_.tx_engine_.get()); + DCHECK(master_engine) << "Asked for MasterEngine on distributed worker"; + return *master_engine; } GraphDb &db_; diff --git a/src/database/graph_db_config.cpp b/src/database/graph_db_config.cpp index d22941b13..2ece07751 100644 --- a/src/database/graph_db_config.cpp +++ b/src/database/graph_db_config.cpp @@ -13,6 +13,8 @@ namespace fs = std::experimental::filesystem; // community config. Should we set the default here to true? On some other // points the tech docs are consistent with community config, and not with these // defaults. + +// Durability flags. DEFINE_bool(durability_enabled, false, "If durability (database persistence) should be enabled"); DEFINE_string( @@ -26,6 +28,8 @@ DEFINE_VALIDATED_int32( DEFINE_int32(snapshot_max_retained, -1, "Number of retained snapshots, -1 means without limit."); DEFINE_bool(snapshot_on_exit, false, "Snapshot on exiting the database."); + +// Misc flags. DEFINE_int32(gc_cycle_sec, 30, "Amount of time between starts of two cleaning cycles in seconds. " "-1 to turn off."); @@ -34,11 +38,13 @@ DEFINE_int32(query_execution_time_sec, 180, "limit will be aborted. Value of -1 means no limit."); GraphDb::Config::Config() + // Durability flags. : durability_enabled{FLAGS_durability_enabled}, durability_directory{FLAGS_durability_directory}, db_recover_on_startup{FLAGS_db_recover_on_startup}, snapshot_cycle_sec{FLAGS_snapshot_cycle_sec}, snapshot_max_retained{FLAGS_snapshot_max_retained}, snapshot_on_exit{FLAGS_snapshot_on_exit}, + // Misc flags. gc_cycle_sec{FLAGS_gc_cycle_sec}, query_execution_time_sec{FLAGS_query_execution_time_sec} {} diff --git a/src/distributed/coordination_master.cpp b/src/distributed/coordination_master.cpp new file mode 100644 index 000000000..dd5080388 --- /dev/null +++ b/src/distributed/coordination_master.cpp @@ -0,0 +1,54 @@ +#include "distributed/coordination_master.hpp" +#include "distributed/coordination_rpc_messages.hpp" + +namespace distributed { + +MasterCoordination::MasterCoordination(communication::messaging::System &system) + : system_(system), server_(system, kCoordinationServerName) { + server_.Register<RegisterWorkerRpc>([this](const RegisterWorkerReq &req) { + auto worker_id = RegisterWorker(req.desired_worker_id, req.endpoint); + return std::make_unique<RegisterWorkerRes>(worker_id); + }); + server_.Register<GetEndpointRpc>([this](const GetEndpointReq &req) { + return std::make_unique<GetEndpointRes>(GetEndpoint(req.member)); + }); + server_.Start(); +} + +int MasterCoordination::RegisterWorker(int desired_worker_id, + Endpoint endpoint) { + std::lock_guard<std::mutex> guard(lock_); + int worker_id = desired_worker_id; + // Check if the desired ID is available. + if (workers_.find(worker_id) != workers_.end()) { + if (desired_worker_id >= 0) + LOG(WARNING) << "Unable to assign requested ID (" << worker_id + << ") to worker at \"" << endpoint.address() << ":" + << endpoint.port() << "\""; + worker_id = 1; + } + // Look for the next ID that's not used. + while (workers_.find(worker_id) != workers_.end()) ++worker_id; + workers_.emplace(worker_id, endpoint); + return worker_id; +} + +void MasterCoordination::Shutdown() { + std::lock_guard<std::mutex> guard(lock_); + for (const auto &kv : workers_) { + communication::rpc::Client client(system_, kv.second, + kCoordinationServerName); + auto result = client.Call<StopWorkerRpc>(100ms); + CHECK(result) << "Failed to shut down worker: " << kv.first; + } + server_.Shutdown(); +} + +Endpoint MasterCoordination::GetEndpoint(int worker_id) const { + std::lock_guard<std::mutex> guard(lock_); + auto found = workers_.find(worker_id); + CHECK(found != workers_.end()) << "No endpoint registered for worker id: " + << worker_id; + return found->second; +} +} // namespace distributed diff --git a/src/distributed/coordination_master.hpp b/src/distributed/coordination_master.hpp new file mode 100644 index 000000000..6f1b17688 --- /dev/null +++ b/src/distributed/coordination_master.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include <mutex> +#include <unordered_map> + +#include "communication/messaging/distributed.hpp" +#include "communication/rpc/rpc.hpp" +#include "io/network/network_endpoint.hpp" + +namespace distributed { +using Endpoint = io::network::NetworkEndpoint; + +/** Handles worker registration, getting of other workers' endpoints and + * coordinated shutdown in a distributed memgraph. Master side. */ +class MasterCoordination { + /** + * Registers a new worker with this master server. Notifies all the known + * workers of the new worker. + * + * @param desired_worker_id - The ID the worker would like to have. Set to + * -1 if the worker doesn't care. Does not guarantee that the desired ID will + * be returned, it is possible it's already occupied. If that's an error (for + * example in recovery), the worker should handle it as such. + * @return The assigned ID for the worker asking to become registered. + */ + int RegisterWorker(int desired_worker_id, Endpoint endpoint); + + public: + MasterCoordination(communication::messaging::System &system); + + /** Shuts down all the workers and this master server. */ + void Shutdown(); + + /** Returns the Endpoint for the given worker_id. */ + Endpoint GetEndpoint(int worker_id) const; + + private: + communication::messaging::System &system_; + communication::rpc::Server server_; + // Most master functions aren't thread-safe. + mutable std::mutex lock_; + std::unordered_map<int, Endpoint> workers_; +}; +} // namespace distributed diff --git a/src/distributed/coordination_rpc_messages.hpp b/src/distributed/coordination_rpc_messages.hpp new file mode 100644 index 000000000..0b181c11d --- /dev/null +++ b/src/distributed/coordination_rpc_messages.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "communication/messaging/local.hpp" +#include "communication/rpc/rpc.hpp" +#include "io/network/network_endpoint.hpp" +#include "utils/rpc_pimp.hpp" + +namespace distributed { + +const std::string kCoordinationServerName = "CoordinationRpc"; + +using communication::messaging::Message; +using Endpoint = io::network::NetworkEndpoint; + +struct RegisterWorkerReq : public Message { + RegisterWorkerReq() {} + // Set desired_worker_id to -1 to get an automatically assigned ID. + RegisterWorkerReq(int desired_worker_id, const Endpoint &endpoint) + : desired_worker_id(desired_worker_id), endpoint(endpoint) {} + int desired_worker_id; + Endpoint endpoint; + + template <class Archive> + void serialize(Archive &ar) { + ar(cereal::virtual_base_class<Message>(this), desired_worker_id, endpoint); + } +}; + +RPC_SINGLE_MEMBER_MESSAGE(RegisterWorkerRes, int); +RPC_SINGLE_MEMBER_MESSAGE(GetEndpointReq, int); +RPC_SINGLE_MEMBER_MESSAGE(GetEndpointRes, Endpoint); +RPC_NO_MEMBER_MESSAGE(StopWorkerReq); +RPC_NO_MEMBER_MESSAGE(StopWorkerRes); + +using RegisterWorkerRpc = + communication::rpc::RequestResponse<RegisterWorkerReq, RegisterWorkerRes>; +using GetEndpointRpc = + communication::rpc::RequestResponse<GetEndpointReq, GetEndpointRes>; +using StopWorkerRpc = + communication::rpc::RequestResponse<StopWorkerReq, StopWorkerRes>; + +} // namespace distributed + +CEREAL_REGISTER_TYPE(distributed::RegisterWorkerReq); +CEREAL_REGISTER_TYPE(distributed::RegisterWorkerRes); +CEREAL_REGISTER_TYPE(distributed::GetEndpointReq); +CEREAL_REGISTER_TYPE(distributed::GetEndpointRes); +CEREAL_REGISTER_TYPE(distributed::StopWorkerReq); +CEREAL_REGISTER_TYPE(distributed::StopWorkerRes); diff --git a/src/distributed/coordination_worker.cpp b/src/distributed/coordination_worker.cpp new file mode 100644 index 000000000..22bc73cb5 --- /dev/null +++ b/src/distributed/coordination_worker.cpp @@ -0,0 +1,55 @@ +#include <condition_variable> +#include <mutex> + +#include "distributed/coordination_worker.hpp" + +namespace distributed { + +WorkerCoordination::WorkerCoordination(communication::messaging::System &system, + const Endpoint &master_endpoint) + : system_(system), + client_(system_, master_endpoint, kCoordinationServerName), + server_(system_, kCoordinationServerName) {} + +int WorkerCoordination::RegisterWorker(int desired_worker_id) { + auto result = client_.Call<RegisterWorkerRpc>(300ms, desired_worker_id, + system_.endpoint()); + CHECK(result) << "Failed to RegisterWorker with the master"; + return result->member; +} + +Endpoint WorkerCoordination::GetEndpoint(int worker_id) { + auto accessor = endpoint_cache_.access(); + auto found = accessor.find(worker_id); + if (found != accessor.end()) return found->second; + auto result = client_.Call<GetEndpointRpc>(300ms, worker_id); + CHECK(result) << "Failed to GetEndpoint from the master"; + accessor.insert(worker_id, result->member); + return result->member; +} + +void WorkerCoordination::WaitForShutdown() { + std::mutex mutex; + std::condition_variable cv; + bool shutdown = false; + + server_.Register<StopWorkerRpc>([&](const StopWorkerReq &) { + std::unique_lock<std::mutex> lk(mutex); + shutdown = true; + lk.unlock(); + cv.notify_one(); + return std::make_unique<StopWorkerRes>(); + }); + server_.Start(); + + std::unique_lock<std::mutex> lk(mutex); + cv.wait(lk, [&shutdown] { return shutdown; }); + // Sleep to allow the server to return the StopWorker response. This is + // necessary because Shutdown will most likely be called after this function. + // TODO (review): Should we call server_.Shutdown() here? Not the usual + // convention, but maybe better... + std::this_thread::sleep_for(100ms); +}; + +void WorkerCoordination::Shutdown() { server_.Shutdown(); } +} // namespace distributed diff --git a/src/distributed/coordination_worker.hpp b/src/distributed/coordination_worker.hpp new file mode 100644 index 000000000..dedaac16d --- /dev/null +++ b/src/distributed/coordination_worker.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "data_structures/concurrent/concurrent_map.hpp" +#include "distributed/coordination_rpc_messages.hpp" +#include "io/network/network_endpoint.hpp" + +namespace distributed { +using Endpoint = io::network::NetworkEndpoint; + +/** Handles worker registration, getting of other workers' endpoints and + * coordinated shutdown in a distributed memgraph. Worker side. */ +class WorkerCoordination { + public: + WorkerCoordination(communication::messaging::System &system, + const Endpoint &master_endpoint); + + /** + * Registers a worker with the master. + * + * @param worker_id - Desired ID. If -1, or if the desired ID is already + * taken, the worker gets the next available ID. + */ + int RegisterWorker(int desired_worker_id = -1); + + /** Gets the endpoint for the given worker ID from the master. */ + Endpoint GetEndpoint(int worker_id); + + /** Starts listening for a remote shutdown command (issued by the master). + * Blocks the calling thread until that has finished. */ + void WaitForShutdown(); + + /** Shuts the RPC server down. */ + void Shutdown(); + + private: + communication::messaging::System &system_; + communication::rpc::Client client_; + communication::rpc::Server server_; + ConcurrentMap<int, Endpoint> endpoint_cache_; +}; +} // namespace distributed diff --git a/src/io/network/network_endpoint.cpp b/src/io/network/network_endpoint.cpp index 3d1446b44..ce596f0d6 100644 --- a/src/io/network/network_endpoint.cpp +++ b/src/io/network/network_endpoint.cpp @@ -1,9 +1,10 @@ -#include "io/network/network_endpoint.hpp" +#include <arpa/inet.h> +#include <netdb.h> +#include <algorithm> #include "glog/logging.h" -#include <arpa/inet.h> -#include <netdb.h> +#include "io/network/network_endpoint.hpp" namespace io::network { @@ -42,4 +43,10 @@ NetworkEndpoint::NetworkEndpoint(const std::string &addr, NetworkEndpoint::NetworkEndpoint(const std::string &addr, uint16_t port) : NetworkEndpoint(addr.c_str(), std::to_string(port)) {} + +bool NetworkEndpoint::operator==(const NetworkEndpoint &other) const { + return std::equal(std::begin(address_), std::end(address_), + std::begin(other.address_)) && + port_ == other.port_ && family_ == other.family_; } +} // namespace io::network diff --git a/src/io/network/network_endpoint.hpp b/src/io/network/network_endpoint.hpp index 4796a0347..0b739c5c6 100644 --- a/src/io/network/network_endpoint.hpp +++ b/src/io/network/network_endpoint.hpp @@ -25,6 +25,14 @@ class NetworkEndpoint { uint16_t port() const { return port_; } unsigned char family() const { return family_; } + /** Required for cereal serialization. */ + template <class Archive> + void serialize(Archive &archive) { + archive(address_, port_str_, port_, family_); + } + + bool operator==(const NetworkEndpoint &other) const; + private: char address_[INET6_ADDRSTRLEN]; char port_str_[6]; diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index 7e01a9cb8..1d929bea1 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -6,8 +6,11 @@ #include <glog/logging.h> #include "communication/bolt/v1/session.hpp" +#include "communication/messaging/distributed.hpp" #include "communication/server.hpp" #include "config.hpp" +#include "distributed/coordination_master.hpp" +#include "distributed/coordination_worker.hpp" #include "io/network/network_endpoint.hpp" #include "io/network/network_error.hpp" #include "io/network/socket.hpp" @@ -27,6 +30,7 @@ using SessionT = communication::bolt::Session<Socket>; using ResultStreamT = SessionT::ResultStreamT; using ServerT = communication::Server<SessionT, SessionData>; +// General purpose flags. DEFINE_string(interface, "0.0.0.0", "Communication interface on which to listen."); DEFINE_string(port, "7687", "Communication port on which to listen."); @@ -42,6 +46,31 @@ DEFINE_uint64(memory_warning_threshold, 1024, "less available RAM available it will log a warning. Set to 0 to " "disable."); +// Distributed flags. +DEFINE_HIDDEN_bool( + master, false, + "If this Memgraph server is the master in a distributed deployment."); +DEFINE_HIDDEN_string(master_host, "0.0.0.0", + "For master node indicates the host served on. For worker " + "node indicates the master location."); +DEFINE_VALIDATED_HIDDEN_int32( + master_port, 0, + "For master node the port on which to serve. For " + "worker node indicates the master's port.", + FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max())); +DEFINE_HIDDEN_bool( + worker, false, + "If this Memgraph server is a worker in a distributed deployment."); +DEFINE_HIDDEN_string(worker_host, "0.0.0.0", + "For worker node indicates the host served on. For master " + "node this flag is not used."); +DEFINE_VALIDATED_HIDDEN_int32( + worker_port, 0, + "For master node it's unused. For worker node " + "indicates the port on which to serve. If zero (default value), a port is " + "chosen at random. Sent to the master when registring worker node.", + FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max())); + // Needed to correctly handle memgraph destruction from a signal handler. // Without having some sort of a flag, it is possible that a signal is handled // when we are exiting main, inside destructors of GraphDb and similar. The @@ -49,41 +78,9 @@ DEFINE_uint64(memory_warning_threshold, 1024, // half destructed state, causing invalid memory access and crash. volatile sig_atomic_t is_shutting_down = 0; -int main(int argc, char **argv) { - google::SetUsageMessage("Memgraph database server"); - gflags::SetVersionString(version_string); - - // Load config before parsing arguments, so that flags from the command line - // overwrite the config. - LoadConfig(); - gflags::ParseCommandLineFlags(&argc, &argv, true); - - google::InitGoogleLogging(argv[0]); - google::SetLogDestination(google::INFO, FLAGS_log_file.c_str()); - google::SetLogSymlink(google::INFO, FLAGS_log_link_basename.c_str()); - - // Unhandled exception handler init. - std::set_terminate(&terminate_handler); - - // Initialize bolt session data (GraphDb and Interpreter). - SessionData session_data; - - // Initialize endpoint. - NetworkEndpoint endpoint(FLAGS_interface, FLAGS_port); - - // Initialize server. - ServerT server(endpoint, session_data); - - // Handler for regular termination signals - auto shutdown = [&server, &session_data]() { - if (is_shutting_down) return; - is_shutting_down = 1; - // Server needs to be shutdown first and then the database. This prevents a - // race condition when a transaction is accepted during server shutdown. - server.Shutdown(); - session_data.db.Shutdown(); - }; - +// Registers the given shutdown function with the appropriate signal handlers. +// See implementation for details. +void InitSignalHandlers(const std::function<void()> &shutdown) { // Prevent handling shutdown inside a shutdown. For example, SIGINT handler // being interrupted by SIGTERM before is_shutting_down is set, thus causing // double shutdown. @@ -104,8 +101,9 @@ int main(int argc, char **argv) { CHECK(SignalHandler::RegisterHandler(Signal::User1, []() { google::CloseLogDestination(google::INFO); })) << "Unable to register SIGUSR1 handler!"; +} - // Start memory warning logger. +void StartMemWarningLogger() { Scheduler mem_log_scheduler; if (FLAGS_memory_warning_threshold > 0) { mem_log_scheduler.Run(std::chrono::seconds(3), [] { @@ -115,9 +113,104 @@ int main(int argc, char **argv) { << " MB left."; }); } +} + +void MasterMain() { + google::SetUsageMessage("Memgraph distributed master"); + // RPC for worker registration, shutdown and endpoint info exchange. + communication::messaging::System system(FLAGS_master_host, FLAGS_master_port); + distributed::MasterCoordination master(system); + + // Bolt server stuff. + SessionData session_data{system, master}; + NetworkEndpoint endpoint(FLAGS_interface, FLAGS_port); + ServerT server(endpoint, session_data); + + // Handler for regular termination signals + auto shutdown = [&server, &session_data, &master, &system] { + if (is_shutting_down) return; + is_shutting_down = 1; + // Server needs to be shutdown first and then the database. This prevents a + // race condition when a transaction is accepted during server shutdown. + server.Shutdown(); + session_data.db.Shutdown(); + master.Shutdown(); + system.Shutdown(); + + }; + InitSignalHandlers(shutdown); + + StartMemWarningLogger(); - // Start worker threads. server.Start(FLAGS_num_workers); +} +void WorkerMain() { + google::SetUsageMessage("Memgraph distributed worker"); + // RPC for worker registration, shutdown and endpoint info exchange. + communication::messaging::System system(FLAGS_worker_host, FLAGS_worker_port); + io::network::NetworkEndpoint master_endpoint{ + FLAGS_master_host, static_cast<uint16_t>(FLAGS_master_port)}; + distributed::WorkerCoordination worker(system, master_endpoint); + auto worker_id = worker.RegisterWorker(); + + // The GraphDb destructor shuts some RPC down. Ensure correct ordering. + { + GraphDb db{system, worker_id, worker, master_endpoint}; + query::Interpreter interpreter; + StartMemWarningLogger(); + // Wait for the shutdown command from the master. + worker.WaitForShutdown(); + } + + worker.Shutdown(); + system.Shutdown(); +} + +void SingleNodeMain() { + google::SetUsageMessage("Memgraph single-node database server"); + SessionData session_data; + NetworkEndpoint endpoint(FLAGS_interface, FLAGS_port); + ServerT server(endpoint, session_data); + + // Handler for regular termination signals + auto shutdown = [&server, &session_data] { + if (is_shutting_down) return; + is_shutting_down = 1; + // Server needs to be shutdown first and then the database. This prevents a + // race condition when a transaction is accepted during server shutdown. + server.Shutdown(); + session_data.db.Shutdown(); + }; + InitSignalHandlers(shutdown); + + StartMemWarningLogger(); + + server.Start(FLAGS_num_workers); +} + +int main(int argc, char **argv) { + gflags::SetVersionString(version_string); + + // Load config before parsing arguments, so that flags from the command line + // overwrite the config. + LoadConfig(); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + google::InitGoogleLogging(argv[0]); + google::SetLogDestination(google::INFO, FLAGS_log_file.c_str()); + google::SetLogSymlink(google::INFO, FLAGS_log_link_basename.c_str()); + + // Unhandled exception handler init. + std::set_terminate(&terminate_handler); + + CHECK(!(FLAGS_master && FLAGS_worker)) + << "Can't run Memgraph as worker and master at the same time"; + if (FLAGS_master) + MasterMain(); + else if (FLAGS_worker) + WorkerMain(); + else + SingleNodeMain(); return 0; } diff --git a/src/storage/concurrent_id_mapper.hpp b/src/storage/concurrent_id_mapper.hpp index 93952ad7d..5d27b0321 100644 --- a/src/storage/concurrent_id_mapper.hpp +++ b/src/storage/concurrent_id_mapper.hpp @@ -1,5 +1,7 @@ #pragma once +#include <string> + namespace storage { /** @@ -8,13 +10,12 @@ namespace storage { * for the master (single source of truth) and worker (must query master). * Both implementations must be concurrent. * - * @TParam TId - ID type. Must expose `::TStorage`. - * @TParam TRecord - Value type. + * @TParam TId - One of GraphDb types (Label, EdgeType, Property). */ -template <typename TId, typename TValue> +template <typename TId> class ConcurrentIdMapper { public: - virtual TId value_to_id(const TValue &value) = 0; - virtual const TValue &id_to_value(const TId &id) = 0; + virtual TId value_to_id(const std::string &value) = 0; + virtual const std::string &id_to_value(const TId &id) = 0; }; -} // namespace storage +} // namespace storage diff --git a/src/storage/concurrent_id_mapper_master.hpp b/src/storage/concurrent_id_mapper_master.hpp index a5ad2f46d..388fa98c7 100644 --- a/src/storage/concurrent_id_mapper_master.hpp +++ b/src/storage/concurrent_id_mapper_master.hpp @@ -12,7 +12,7 @@ namespace storage { /** Master implementation of ConcurrentIdMapper. */ template <typename TId> class MasterConcurrentIdMapper - : public SingleNodeConcurrentIdMapper<TId, std::string> { + : public SingleNodeConcurrentIdMapper<TId> { public: MasterConcurrentIdMapper(communication::messaging::System &system); diff --git a/src/storage/concurrent_id_mapper_single_node.hpp b/src/storage/concurrent_id_mapper_single_node.hpp index 186fd6146..beb43b2d5 100644 --- a/src/storage/concurrent_id_mapper_single_node.hpp +++ b/src/storage/concurrent_id_mapper_single_node.hpp @@ -6,12 +6,12 @@ namespace storage { /** SingleNode implementation of ConcurrentIdMapper. */ -template <typename TId, typename TValue> -class SingleNodeConcurrentIdMapper : public ConcurrentIdMapper<TId, TValue> { +template <typename TId> +class SingleNodeConcurrentIdMapper : public ConcurrentIdMapper<TId> { using StorageT = typename TId::StorageT; public: - TId value_to_id(const TValue &value) override { + TId value_to_id(const std::string &value) override { auto value_to_id_acc = value_to_id_.access(); auto found = value_to_id_acc.find(value); TId inserted_id(0); @@ -34,7 +34,7 @@ class SingleNodeConcurrentIdMapper : public ConcurrentIdMapper<TId, TValue> { return inserted_id; } - const TValue &id_to_value(const TId &id) override { + const std::string &id_to_value(const TId &id) override { const auto id_to_value_acc = id_to_value_.access(); auto result = id_to_value_acc.find(id); DCHECK(result != id_to_value_acc.end()); @@ -42,8 +42,8 @@ class SingleNodeConcurrentIdMapper : public ConcurrentIdMapper<TId, TValue> { } private: - ConcurrentMap<TValue, TId> value_to_id_; - ConcurrentMap<TId, TValue> id_to_value_; + ConcurrentMap<std::string, TId> value_to_id_; + ConcurrentMap<TId, std::string> id_to_value_; std::atomic<StorageT> id_{0}; }; } // namespace storage diff --git a/src/storage/concurrent_id_mapper_worker.hpp b/src/storage/concurrent_id_mapper_worker.hpp index 431ee7871..31561f8c8 100644 --- a/src/storage/concurrent_id_mapper_worker.hpp +++ b/src/storage/concurrent_id_mapper_worker.hpp @@ -10,7 +10,7 @@ namespace storage { /** Worker implementation of ConcurrentIdMapper. */ template <typename TId> -class WorkerConcurrentIdMapper : public ConcurrentIdMapper<TId, std::string> { +class WorkerConcurrentIdMapper : public ConcurrentIdMapper<TId> { // Makes an appropriate RPC call for the current TId type and the given value. TId RpcValueToId(const std::string &value); diff --git a/src/transactions/engine_master.cpp b/src/transactions/engine_master.cpp index 5d362cee4..c354cfac6 100644 --- a/src/transactions/engine_master.cpp +++ b/src/transactions/engine_master.cpp @@ -8,6 +8,10 @@ namespace tx { +MasterEngine::MasterEngine(communication::messaging::System &system) { + StartServer(system); +} + MasterEngine::~MasterEngine() { if (rpc_server_) StopServer(); } diff --git a/src/transactions/engine_master.hpp b/src/transactions/engine_master.hpp index 9c26a031b..defcf7760 100644 --- a/src/transactions/engine_master.hpp +++ b/src/transactions/engine_master.hpp @@ -28,6 +28,10 @@ class TransactionError : public utils::BasicException { */ class MasterEngine : public Engine { public: + MasterEngine() = default; + /** Constructs a master engine and calls StartServer() */ + MasterEngine(communication::messaging::System &system); + /** Stops the tx server if it's running. */ ~MasterEngine(); diff --git a/src/transactions/engine_worker.cpp b/src/transactions/engine_worker.cpp index 007f7fdd1..196cdcda3 100644 --- a/src/transactions/engine_worker.cpp +++ b/src/transactions/engine_worker.cpp @@ -10,9 +10,8 @@ static const auto kRpcTimeout = 100ms; } WorkerEngine::WorkerEngine(communication::messaging::System &system, - const std::string &tx_server_host, - uint16_t tx_server_port) - : rpc_client_(system, tx_server_host, tx_server_port, "tx_engine") {} + const io::network::NetworkEndpoint &endpoint) + : rpc_client_(system, endpoint, "tx_engine") {} Transaction *WorkerEngine::LocalBegin(transaction_id_t tx_id) { auto accessor = active_.access(); @@ -33,8 +32,8 @@ CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const { // If we don't know the transaction to be commited nor aborted, ask the // master about it and update the local commit log. if (!(info.is_aborted() || info.is_committed())) { - // @review: this version of Call is just used because Info has no default - // constructor. + // @review: this version of Call is just used because Info has no + // default constructor. info = rpc_client_.Call<ClogInfoRpc>(kRpcTimeout, tid)->member; DCHECK(info.is_committed() || info.is_aborted()) << "It is expected that the transaction is not running anymore. This " diff --git a/src/transactions/engine_worker.hpp b/src/transactions/engine_worker.hpp index 98dc3ac6f..01b261f0a 100644 --- a/src/transactions/engine_worker.hpp +++ b/src/transactions/engine_worker.hpp @@ -6,6 +6,7 @@ #include "communication/messaging/distributed.hpp" #include "communication/rpc/rpc.hpp" #include "data_structures/concurrent/concurrent_map.hpp" +#include "io/network/network_endpoint.hpp" #include "transactions/commit_log.hpp" #include "transactions/engine.hpp" #include "transactions/transaction.hpp" @@ -15,7 +16,7 @@ namespace tx { class WorkerEngine : public Engine { public: WorkerEngine(communication::messaging::System &system, - const std::string &tx_server_host, uint16_t tx_server_port); + const io::network::NetworkEndpoint &endpoint); Transaction *LocalBegin(transaction_id_t tx_id); diff --git a/src/utils/signals/handler.hpp b/src/utils/signals/handler.hpp index 66df2a5ce..d75881a7a 100644 --- a/src/utils/signals/handler.hpp +++ b/src/utils/signals/handler.hpp @@ -6,8 +6,6 @@ #include <utility> #include <vector> -using Function = std::function<void()>; - // TODO: align bits so signals can be combined // Signal::Terminate | Signal::Interupt enum class Signal : int { @@ -28,7 +26,7 @@ class SignalHandler { public: /// Install a signal handler. - static bool RegisterHandler(Signal signal, Function func) { + static bool RegisterHandler(Signal signal, std::function<void()> func) { sigset_t signal_mask; sigemptyset(&signal_mask); return RegisterHandler(signal, func, signal_mask); @@ -37,7 +35,7 @@ class SignalHandler { /// Like RegisterHandler, but takes a `signal_mask` argument for blocking /// signals during execution of the handler. `signal_mask` should be created /// using `sigemptyset` and `sigaddset` functions from `<signal.h>`. - static bool RegisterHandler(Signal signal, Function func, + static bool RegisterHandler(Signal signal, std::function<void()> func, sigset_t signal_mask) { int signal_number = static_cast<int>(signal); handlers_[signal_number] = func; diff --git a/tests/unit/concurrent_id_mapper_single_node.cpp b/tests/unit/concurrent_id_mapper_single_node.cpp index dfc70a4d5..ae67e7436 100644 --- a/tests/unit/concurrent_id_mapper_single_node.cpp +++ b/tests/unit/concurrent_id_mapper_single_node.cpp @@ -1,3 +1,4 @@ +#include <map> #include <thread> #include <vector> @@ -7,72 +8,47 @@ #include "database/graph_db_datatypes.hpp" #include "storage/concurrent_id_mapper_single_node.hpp" -const int THREAD_NUM = 20; -const int VALUE_MAX = 50; - using Id = GraphDbTypes::Label; -using Mapper = storage::SingleNodeConcurrentIdMapper<Id, int>; +using Mapper = storage::SingleNodeConcurrentIdMapper<Id>; TEST(ConcurrentIdMapper, SameValueGivesSameId) { Mapper mapper; - EXPECT_EQ(mapper.value_to_id(1), mapper.value_to_id(1)); + EXPECT_EQ(mapper.value_to_id("a"), mapper.value_to_id("a")); } TEST(ConcurrentIdMapper, IdToValue) { Mapper mapper; - auto value = 1; + std::string value = "a"; auto id = mapper.value_to_id(value); EXPECT_EQ(value, mapper.id_to_value(id)); } TEST(ConcurrentIdMapper, TwoValuesTwoIds) { Mapper mapper; - EXPECT_NE(mapper.value_to_id(1), mapper.value_to_id(2)); + EXPECT_NE(mapper.value_to_id("a"), mapper.value_to_id("b")); } TEST(ConcurrentIdMapper, SameIdReturnedMultipleThreads) { - std::vector<std::thread> threads; - Mapper mapper; - std::vector<std::vector<Id>> thread_value_ids(THREAD_NUM); + const int thread_count = 20; + std::vector<std::string> values; + for (int i = 0; i < 50; ++i) values.emplace_back("value" + std::to_string(i)); - std::atomic<int> current_value{0}; - std::atomic<int> current_value_insertion_count{0}; - - // Try to insert every value from [0, VALUE_MAX] by multiple threads in the - // same time - for (int i = 0; i < THREAD_NUM; ++i) { - threads.push_back(std::thread([&mapper, &thread_value_ids, ¤t_value, - ¤t_value_insertion_count, i]() { - int last = -1; - while (current_value <= VALUE_MAX) { - while (last == current_value) continue; - auto id = mapper.value_to_id(current_value.load()); - thread_value_ids[i].push_back(id); - // Also check that reverse mapping exists after method exits - EXPECT_EQ(mapper.id_to_value(id), current_value.load()); - last = current_value; - current_value_insertion_count.fetch_add(1); - } - })); - } - // Increment current_value when all threads finish inserting it and getting an - // id for it - threads.push_back( - std::thread([¤t_value, ¤t_value_insertion_count]() { - while (current_value.load() <= VALUE_MAX) { - while (current_value_insertion_count.load() != THREAD_NUM) continue; - current_value_insertion_count.store(0); - current_value.fetch_add(1); + // Perform the whole test a number of times since it's stochastic (we're + // trying to detect bad behavior in parallel execution). + for (int loop_ind = 0; loop_ind < 20; ++loop_ind) { + Mapper mapper; + std::vector<std::map<Id, std::string>> mappings(thread_count); + std::vector<std::thread> threads; + for (int thread_ind = 0; thread_ind < thread_count; ++thread_ind) { + threads.emplace_back([&mapper, &mappings, &values, thread_ind] { + auto &mapping = mappings[thread_ind]; + for (auto &value : values) { + mapping.emplace(mapper.value_to_id(value), value); } - })); - for (auto &thread : threads) thread.join(); - - // For every value inserted, each thread should have the same id - for (int i = 0; i < THREAD_NUM; ++i) - for (int j = 0; j < THREAD_NUM; ++j) - EXPECT_EQ(thread_value_ids[i], thread_value_ids[j]); - - // Each value should have a unique id - std::set<Id> ids(thread_value_ids[0].begin(), thread_value_ids[0].end()); - EXPECT_EQ(ids.size(), thread_value_ids[0].size()); + }); + } + for (auto &thread : threads) thread.join(); + EXPECT_EQ(mappings[0].size(), values.size()); + for (auto &mapping : mappings) EXPECT_EQ(mapping, mappings[0]); + } } diff --git a/tests/unit/distributed_coordination.cpp b/tests/unit/distributed_coordination.cpp new file mode 100644 index 000000000..904a71759 --- /dev/null +++ b/tests/unit/distributed_coordination.cpp @@ -0,0 +1,91 @@ +#include <experimental/optional> +#include <memory> +#include <thread> +#include <vector> + +#include "gtest/gtest.h" + +#include "communication/messaging/distributed.hpp" +#include "distributed/coordination_master.hpp" +#include "distributed/coordination_worker.hpp" +#include "io/network/network_endpoint.hpp" + +using communication::messaging::System; +using namespace distributed; + +const int kWorkerCount = 5; +const std::string kLocal = "127.0.0.1"; + +class WorkerInThread { + public: + WorkerInThread(io::network::NetworkEndpoint master_endpoint, + int desired_id = -1) { + worker_thread_ = std::thread([this, master_endpoint, desired_id] { + system_.emplace(kLocal, 0); + coord_.emplace(*system_, master_endpoint); + worker_id_ = coord_->RegisterWorker(desired_id); + coord_->WaitForShutdown(); + coord_->Shutdown(); + system_->Shutdown(); + }); + } + + int worker_id() const { return worker_id_; } + auto endpoint() const { return system_->endpoint(); } + auto worker_endpoint(int worker_id) { return coord_->GetEndpoint(worker_id); } + void join() { worker_thread_.join(); } + + private: + std::thread worker_thread_; + std::experimental::optional<System> system_; + std::experimental::optional<WorkerCoordination> coord_; + std::atomic<int> worker_id_{0}; +}; + +TEST(Distributed, Coordination) { + System master_system(kLocal, 0); + MasterCoordination master_coord(master_system); + + std::vector<std::unique_ptr<WorkerInThread>> workers; + for (int i = 0; i < kWorkerCount; ++i) + workers.emplace_back( + std::make_unique<WorkerInThread>(master_system.endpoint())); + + // Wait till all the workers are safely initialized. + std::this_thread::sleep_for(300ms); + + // Expect that all workers have a different ID. + std::unordered_set<int> worker_ids; + for (const auto &w : workers) worker_ids.insert(w->worker_id()); + EXPECT_EQ(worker_ids.size(), kWorkerCount); + + // Check endpoints. + for (auto &w1 : workers) { + for (auto &w2 : workers) { + EXPECT_EQ(w1->worker_endpoint(w2->worker_id()), w2->endpoint()); + } + } + + // Coordinated shutdown. + master_coord.Shutdown(); + master_system.Shutdown(); + for (auto &worker : workers) worker->join(); +} + +TEST(Distributed, DesiredAndUniqueId) { + System master_system(kLocal, 0); + MasterCoordination master_coord(master_system); + + WorkerInThread w1(master_system.endpoint(), 42); + std::this_thread::sleep_for(200ms); + WorkerInThread w2(master_system.endpoint(), 42); + std::this_thread::sleep_for(200ms); + + EXPECT_EQ(w1.worker_id(), 42); + EXPECT_NE(w2.worker_id(), 42); + + master_coord.Shutdown(); + w1.join(); + w2.join(); + master_system.Shutdown(); +} diff --git a/tests/unit/transaction_engine_worker.cpp b/tests/unit/transaction_engine_worker.cpp index 155d38d3a..838b17eb3 100644 --- a/tests/unit/transaction_engine_worker.cpp +++ b/tests/unit/transaction_engine_worker.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "communication/messaging/distributed.hpp" +#include "io/network/network_endpoint.hpp" #include "transactions/engine_master.hpp" #include "transactions/engine_worker.hpp" @@ -17,7 +18,7 @@ class WorkerEngineTest : public testing::Test { MasterEngine master_; System worker_system_{local, 0}; - WorkerEngine worker_{worker_system_, local, master_system_.endpoint().port()}; + WorkerEngine worker_{worker_system_, master_system_.endpoint()}; void SetUp() override { master_.StartServer(master_system_); } void TearDown() override { diff --git a/tools/apollo/build_debug b/tools/apollo/build_debug index 631f92c53..a7ebfe516 100644 --- a/tools/apollo/build_debug +++ b/tools/apollo/build_debug @@ -24,4 +24,4 @@ cd ../tools ./setup cd apollo -./generate debug +TIMEOUT=300 ./generate debug diff --git a/tools/apollo/build_diff b/tools/apollo/build_diff index 9c6d8c3df..67376d8b7 100644 --- a/tools/apollo/build_diff +++ b/tools/apollo/build_diff @@ -34,7 +34,7 @@ TIMEOUT=1000 make -j$THREADS memgraph memgraph__macro_benchmark # Install tools, because they may be needed to run some benchmarks and tests. cd ../../memgraph/tools -./setup +TIMEOUT=300 ./setup cd apollo -./generate diff +TIMEOUT=300 ./generate diff diff --git a/tools/apollo/build_release b/tools/apollo/build_release index 7b32f71ba..b6383d0cd 100644 --- a/tools/apollo/build_release +++ b/tools/apollo/build_release @@ -27,4 +27,4 @@ cd ../../tools ./setup cd apollo -./generate release +TIMEOUT=300 ./generate release