From 60f4db2b9f4ebd60069e91b357cf6b1b63bcf393 Mon Sep 17 00:00:00 2001 From: Mislav Bradac Date: Tue, 5 Dec 2017 13:41:51 +0100 Subject: [PATCH] Add first version of message passing and rpc Reviewers: mtomic, buda Reviewed By: mtomic, buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1002 --- CMakeLists.txt | 1 - src/CMakeLists.txt | 9 +- src/communication/messaging/distributed.cpp | 69 +++++++++ src/communication/messaging/distributed.hpp | 124 +++++++++++++++ src/communication/messaging/local.cpp | 61 ++++++++ src/communication/messaging/local.hpp | 109 +++++++++++++ src/communication/messaging/protocol.cpp | 115 ++++++++++++++ src/communication/messaging/protocol.hpp | 106 +++++++++++++ .../reactor/reactor_distributed.hpp | 2 +- src/communication/rpc/rpc.cpp | 146 ++++++++++++++++++ src/communication/rpc/rpc.hpp | 91 +++++++++++ src/data_structures/queue.hpp | 35 +++-- src/io/network/socket_event_dispatcher.hpp | 20 ++- tests/unit/messaging_distributed.cpp | 97 ++++++++++++ tests/unit/messaging_local.cpp | 72 +++++++++ tests/unit/queue.cpp | 10 +- tests/unit/rpc.cpp | 80 ++++++++++ tools/rpcgen | 62 ++++++++ 18 files changed, 1187 insertions(+), 22 deletions(-) create mode 100644 src/communication/messaging/distributed.cpp create mode 100644 src/communication/messaging/distributed.hpp create mode 100644 src/communication/messaging/local.cpp create mode 100644 src/communication/messaging/local.hpp create mode 100644 src/communication/messaging/protocol.cpp create mode 100644 src/communication/messaging/protocol.hpp create mode 100644 src/communication/rpc/rpc.cpp create mode 100644 src/communication/rpc/rpc.hpp create mode 100644 tests/unit/messaging_distributed.cpp create mode 100644 tests/unit/messaging_local.cpp create mode 100644 tests/unit/rpc.cpp create mode 100755 tools/rpcgen diff --git a/CMakeLists.txt b/CMakeLists.txt index 7abfb936d..9243570ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -170,7 +170,6 @@ target_link_libraries(antlr_opencypher_parser_lib antlr4) include_directories(src) add_subdirectory(src) - # ----------------------------------------------------------------------------- # Optional subproject configuration ------------------------------------------- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 19411e39b..6a64e48a6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,10 +4,14 @@ set(memgraph_src_files communication/bolt/v1/decoder/decoded_value.cpp communication/bolt/v1/session.cpp - communication/reactor/protocol.cpp + communication/messaging/distributed.cpp + communication/messaging/local.cpp + communication/messaging/protocol.cpp communication/raft/raft.cpp - communication/reactor/reactor_local.cpp + communication/reactor/protocol.cpp communication/reactor/reactor_distributed.cpp + communication/reactor/reactor_local.cpp + communication/rpc/rpc.cpp data_structures/concurrent/skiplist_gc.cpp database/graph_db.cpp database/graph_db_config.cpp @@ -78,6 +82,7 @@ set_target_properties(${MEMGRAPH_BUILD_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) # Strip the executable in release build. string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) + if (lower_build_type STREQUAL "release") add_custom_command(TARGET ${MEMGRAPH_BUILD_NAME} POST_BUILD COMMAND strip -s $ diff --git a/src/communication/messaging/distributed.cpp b/src/communication/messaging/distributed.cpp new file mode 100644 index 000000000..032b5f081 --- /dev/null +++ b/src/communication/messaging/distributed.cpp @@ -0,0 +1,69 @@ +#include "communication/messaging/distributed.hpp" + +namespace communication::messaging { + +System::System(const std::string &address, uint16_t port) + : address_(address), port_(port) { + // Numbers of worker are quite arbitrary at the point. + StartClient(4); + StartServer(4); +} + +System::~System() { + for (size_t i = 0; i < pool_.size(); ++i) { + pool_[i].join(); + } + thread_.join(); +} + +void System::Shutdown() { + queue_.Shutdown(); + server_->Shutdown(); +} + +void System::StartClient(int worker_count) { + LOG(INFO) << "Starting " << worker_count << " client workers"; + for (int i = 0; i < worker_count; ++i) { + pool_.push_back(std::thread([this]() { + while (true) { + auto message = queue_.AwaitPop(); + if (message == std::experimental::nullopt) break; + SendMessage(message->address, message->port, message->channel, + std::move(message->message)); + } + })); + } +} + +void System::StartServer(int worker_count) { + if (server_ != nullptr) { + LOG(FATAL) << "Tried to start a running server!"; + } + + // Initialize endpoint. + Endpoint endpoint; + try { + endpoint = Endpoint(address_.c_str(), port_); + } catch (io::network::NetworkEndpointException &e) { + LOG(FATAL) << e.what(); + } + // Initialize server. + server_ = std::make_unique(endpoint, protocol_data_); + + // Start server. + thread_ = std::thread( + [worker_count, this]() { this->server_->Start(worker_count); }); +} + +std::shared_ptr System::Open(const std::string &name) { + return system_.Open(name); +} + +Writer::Writer(System &system, const std::string &address, uint16_t port, + const std::string &name) + : system_(system), address_(address), port_(port), name_(name) {} + +void Writer::Send(std::unique_ptr message) { + system_.queue_.Emplace(address_, port_, name_, std::move(message)); +} +} diff --git a/src/communication/messaging/distributed.hpp b/src/communication/messaging/distributed.hpp new file mode 100644 index 000000000..b1d1983ff --- /dev/null +++ b/src/communication/messaging/distributed.hpp @@ -0,0 +1,124 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "communication/messaging/local.hpp" +#include "data_structures/queue.hpp" +#include "protocol.hpp" + +#include "cereal/archives/binary.hpp" +#include "cereal/types/base_class.hpp" +#include "cereal/types/memory.hpp" +#include "cereal/types/polymorphic.hpp" +#include "cereal/types/string.hpp" +#include "cereal/types/utility.hpp" +#include "cereal/types/vector.hpp" + +#include "communication/server.hpp" +#include "threading/sync/spinlock.hpp" + +namespace communication::messaging { + +class System; + +// Writes message to remote event stream. +class Writer { + public: + Writer(System &system, const std::string &address, uint16_t port, + const std::string &name); + Writer(const Writer &) = delete; + void operator=(const Writer &) = delete; + Writer(Writer &&) = delete; + void operator=(Writer &&) = delete; + + template + void Send(Args &&... args) { + Send(std::unique_ptr( + std::make_unique(std::forward(args)...))); + } + + void Send(std::unique_ptr message); + + private: + System &system_; + std::string address_; + uint16_t port_; + std::string name_; +}; + +class System { + public: + friend class Writer; + + System(const std::string &address, uint16_t port); + System(const System &) = delete; + System(System &&) = delete; + System &operator=(const System &) = delete; + System &operator=(System &&) = delete; + ~System(); + + std::shared_ptr Open(const std::string &name); + void Shutdown(); + + const std::string &address() const { return address_; } + uint16_t port() const { return port_; } + + private: + using Endpoint = io::network::NetworkEndpoint; + using Socket = Socket; + using ServerT = communication::Server; + + struct NetworkMessage { + NetworkMessage() {} + + NetworkMessage(const std::string &address, uint16_t port, + const std::string &channel, + std::unique_ptr &&message) + : address(address), + port(port), + channel(channel), + message(std::move(message)) {} + + NetworkMessage(NetworkMessage &&nm) = default; + NetworkMessage &operator=(NetworkMessage &&nm) = default; + + std::string address; + uint16_t port = 0; + std::string channel; + std::unique_ptr message; + }; + + /** Start a threadpool that dispatches the messages from the (outgoing) queue + * to the sockets */ + void StartClient(int worker_count); + + /** Start a threadpool that relays the messages from the sockets to the + * LocalEventStreams */ + void StartServer(int workers_count); + + // Client variables. + std::vector pool_; + Queue queue_; + + // Server variables. + std::thread thread_; + SessionData protocol_data_; + std::unique_ptr server_{nullptr}; + std::string address_; + uint16_t port_; + + LocalSystem &system_ = protocol_data_.system; +}; +} // namespace communication::messaging diff --git a/src/communication/messaging/local.cpp b/src/communication/messaging/local.cpp new file mode 100644 index 000000000..38d3b2e5b --- /dev/null +++ b/src/communication/messaging/local.cpp @@ -0,0 +1,61 @@ +#include "communication/messaging/local.hpp" + +#include "fmt/format.h" +#include "glog/logging.h" + +namespace communication::messaging { + +std::shared_ptr LocalSystem::Open(const std::string &name) { + std::unique_lock guard(mutex_); + // TODO: It would be better to use std::make_shared here, but we can't since + // constructor is private. + std::shared_ptr stream(new EventStream(*this, name)); + auto got = channels_.emplace(name, stream); + CHECK(got.second) << fmt::format("Stream with name {} already exists", name); + return stream; +} + +std::shared_ptr LocalSystem::Resolve(const std::string &name) { + std::unique_lock guard(mutex_); + auto it = channels_.find(name); + if (it == channels_.end()) return nullptr; + return it->second.lock(); +} + +void LocalSystem::Remove(const std::string &name) { + std::unique_lock guard(mutex_); + auto it = channels_.find(name); + CHECK(it != channels_.end()) << "Trying to delete nonexisting stream"; + channels_.erase(it); +} + +void LocalWriter::Send(std::unique_ptr message) { + // TODO: We could add caching to LocalWriter so that we don't need to acquire + // lock on system every time we want to send a message. This can be + // accomplished by storing weak_ptr to EventStream. + auto stream = system_.Resolve(name_); + if (!stream) return; + stream->Push(std::move(message)); +} + +EventStream::~EventStream() { system_.Remove(name_); } + +std::unique_ptr EventStream::Poll() { + auto opt_message = queue_.MaybePop(); + if (opt_message == std::experimental::nullopt) return nullptr; + return std::move(*opt_message); +} + +void EventStream::Push(std::unique_ptr message) { + queue_.Push(std::move(message)); +} + +std::unique_ptr EventStream::Await( + std::chrono::system_clock::duration timeout) { + auto opt_message = queue_.AwaitPop(timeout); + if (opt_message == std::experimental::nullopt) return nullptr; + return std::move(*opt_message); +}; + +void EventStream::Shutdown() { queue_.Shutdown(); } +} diff --git a/src/communication/messaging/local.hpp b/src/communication/messaging/local.hpp new file mode 100644 index 000000000..66cd21e60 --- /dev/null +++ b/src/communication/messaging/local.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include + +#include "cereal/types/memory.hpp" + +#include "data_structures/queue.hpp" + +namespace communication::messaging { + +/** + * Base class for messages. + */ +class Message { + public: + virtual ~Message() {} + + template + void serialize(Archive &) {} + + /** + * Run-time type identification that is used for callbacks. + * + * Warning: this works because of the virtual destructor, don't remove it from + * this class + */ + std::type_index type_index() const { return typeid(*this); } +}; + +class EventStream; +class LocalWriter; + +class LocalSystem { + public: + friend class EventStream; + friend class LocalWriter; + + LocalSystem() = default; + LocalSystem(const LocalSystem &) = delete; + LocalSystem(LocalSystem &&) = delete; + LocalSystem &operator=(const LocalSystem &) = delete; + LocalSystem &operator=(LocalSystem &&) = delete; + + std::shared_ptr Open(const std::string &name); + + private: + std::shared_ptr Resolve(const std::string &name); + + void Remove(const std::string &name); + + std::mutex mutex_; + std::unordered_map> channels_; +}; + +class LocalWriter { + public: + LocalWriter(LocalSystem &system, const std::string &name) + : system_(system), name_(name) {} + LocalWriter(const LocalWriter &) = delete; + void operator=(const LocalWriter &) = delete; + LocalWriter(LocalWriter &&) = delete; + void operator=(LocalWriter &&) = delete; + + template + void Send(Args &&... args) { + Send(std::unique_ptr( + std::make_unique(std::forward(args)...))); + } + + void Send(std::unique_ptr message); + + private: + LocalSystem &system_; + std::string name_; +}; + +class EventStream { + public: + friend class LocalWriter; + friend class LocalSystem; + + EventStream(const EventStream &) = delete; + void operator=(const EventStream &) = delete; + EventStream(EventStream &&) = delete; + void operator=(EventStream &&) = delete; + ~EventStream(); + + std::unique_ptr Poll(); + std::unique_ptr Await( + std::chrono::system_clock::duration timeout = + std::chrono::system_clock::duration::max()); + void Shutdown(); + + const std::string &name() const { return name_; } + + private: + EventStream(LocalSystem &system, const std::string &name) + : system_(system), name_(name) {} + + void Push(std::unique_ptr message); + + LocalSystem &system_; + std::string name_; + Queue> queue_; +}; +} // namespace communication::messaging diff --git a/src/communication/messaging/protocol.cpp b/src/communication/messaging/protocol.cpp new file mode 100644 index 000000000..13e15b3a9 --- /dev/null +++ b/src/communication/messaging/protocol.cpp @@ -0,0 +1,115 @@ +#include + +#include "communication/messaging/distributed.hpp" +#include "communication/messaging/local.hpp" +#include "communication/messaging/protocol.hpp" + +#include "fmt/format.h" +#include "glog/logging.h" + +namespace communication::messaging { + +Session::Session(Socket &&socket, SessionData &data) + : socket_(std::move(socket)), system_(data.system) {} + +bool Session::Alive() const { return alive_; } + +std::string Session::GetStringAndShift(SizeT len) { + std::string ret(reinterpret_cast(buffer_.data()), len); + buffer_.Shift(len); + return ret; +} + +void Session::Execute() { + if (buffer_.size() < sizeof(SizeT)) return; + SizeT len_channel = GetLength(); + if (buffer_.size() < 2 * sizeof(SizeT) + len_channel) return; + SizeT len_data = GetLength(sizeof(SizeT) + len_channel); + if (buffer_.size() < 2 * sizeof(SizeT) + len_data + len_channel) return; + + // Remove the length bytes from the buffer. + buffer_.Shift(sizeof(SizeT)); + auto channel = GetStringAndShift(len_channel); + buffer_.Shift(sizeof(SizeT)); + + // TODO: check for exceptions + std::istringstream stream; + stream.str(std::string(reinterpret_cast(buffer_.data()), len_data)); + ::cereal::BinaryInputArchive iarchive{stream}; + std::unique_ptr message{nullptr}; + iarchive(message); + buffer_.Shift(len_data); + + LocalWriter writer(system_, channel); + writer.Send(std::move(message)); +} + +StreamBuffer Session::Allocate() { return buffer_.Allocate(); } + +void Session::Written(size_t len) { buffer_.Written(len); } + +void Session::Close() { + DLOG(INFO) << "Closing session"; + this->socket_.Close(); +} + +SizeT Session::GetLength(int offset) { + SizeT ret = *reinterpret_cast(buffer_.data() + offset); + return ret; +} + +bool SendLength(Socket &socket, SizeT length) { + return socket.Write(reinterpret_cast(&length), sizeof(SizeT)); +} + +void SendMessage(const std::string &address, uint16_t port, + const std::string &channel, std::unique_ptr message) { + CHECK(message) << "Trying to send nullptr instead of message"; + + // Initialize endpoint. + Endpoint endpoint; + try { + endpoint = Endpoint(address.c_str(), port); + } catch (io::network::NetworkEndpointException &e) { + LOG(ERROR) << "Address {} is invalid!"; + return; + } + + // Initialize socket. + Socket socket; + if (!socket.Connect(endpoint)) { + LOG(INFO) << "Couldn't connect to remote address: " << address << ":" + << port; + return; + } + + if (!SendLength(socket, channel.size())) { + LOG(INFO) << "Couldn't send channel size!"; + return; + } + if (!socket.Write(channel)) { + LOG(INFO) << "Couldn't send channel data!"; + return; + } + + // Serialize and send message + std::ostringstream stream; + ::cereal::BinaryOutputArchive oarchive(stream); + oarchive(message); + + const std::string &buffer = stream.str(); + int64_t message_size = 2 * sizeof(SizeT) + buffer.size() + channel.size(); + CHECK(message_size <= kMaxMessageSize) << fmt::format( + "Trying to send message of size {}, max message size is {}", message_size, + kMaxMessageSize); + + if (!SendLength(socket, buffer.size())) { + LOG(INFO) << "Couldn't send message size!"; + return; + } + if (!socket.Write(buffer)) { + LOG(INFO) << "Couldn't send message data!"; + return; + } +} +} diff --git a/src/communication/messaging/protocol.hpp b/src/communication/messaging/protocol.hpp new file mode 100644 index 000000000..3e8d6da9b --- /dev/null +++ b/src/communication/messaging/protocol.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include + +#include "communication/bolt/v1/decoder/buffer.hpp" +#include "communication/messaging/local.hpp" +#include "io/network/epoll.hpp" +#include "io/network/network_endpoint.hpp" +#include "io/network/socket.hpp" +#include "io/network/stream_buffer.hpp" + +/** + * @brief Protocol + * + * Has classes and functions that implement server and client sides of our + * messaging protocol. + * + * Message layout: SizeT channel_size, channel_size characters channel, + * SizeT message_size, message_size bytes serialized_message + */ +namespace communication::messaging { + +class Message; + +using Endpoint = io::network::NetworkEndpoint; +using Socket = io::network::Socket; +using StreamBuffer = io::network::StreamBuffer; + +// This buffer should be larger than the largest serialized message. +const int64_t kMaxMessageSize = 262144; +using Buffer = bolt::Buffer; +using SizeT = uint16_t; + +/** + * Distributed Protocol Data + */ +struct SessionData { + LocalSystem system; +}; + +/** + * Distributed Protocol Session + * + * This class is responsible for handling a single client connection. + */ +class Session { + public: + Session(Socket &&socket, SessionData &data); + + int Id() const { return socket_.fd(); } + + /** + * Returns the protocol alive state + */ + bool Alive() const; + + /** + * Executes the protocol after data has been read into the buffer. + * Goes through the protocol states in order to execute commands from the + * client. + */ + void Execute(); + + /** + * Allocates data from the internal buffer. + * Used in the underlying network stack to asynchronously read data + * from the client. + * @returns a StreamBuffer to the allocated internal data buffer + */ + StreamBuffer Allocate(); + + /** + * Notifies the internal buffer of written data. + * Used in the underlying network stack to notify the internal buffer + * how many bytes of data have been written. + * @param len how many data was written to the buffer + */ + void Written(size_t len); + + bool TimedOut() { return false; } + + /** + * Closes the session (client socket). + */ + void Close(); + + Socket socket_; + LocalSystem &system_; + + std::chrono::time_point last_event_time_; + + private: + SizeT GetLength(int offset = 0); + std::string GetStringAndShift(SizeT len); + + bool alive_{true}; + Buffer buffer_; +}; + +/** + * Distributed Protocol Send Message + */ +void SendMessage(const std::string &address, uint16_t port, + const std::string &channel, std::unique_ptr message); +} diff --git a/src/communication/reactor/reactor_distributed.hpp b/src/communication/reactor/reactor_distributed.hpp index 79b766bed..41f84eb55 100644 --- a/src/communication/reactor/reactor_distributed.hpp +++ b/src/communication/reactor/reactor_distributed.hpp @@ -96,7 +96,7 @@ class Network { break; } } - queue_.Signal(); + queue_.Shutdown(); for (size_t i = 0; i < pool_.size(); ++i) { pool_[i].join(); } diff --git a/src/communication/rpc/rpc.cpp b/src/communication/rpc/rpc.cpp new file mode 100644 index 000000000..44bc94317 --- /dev/null +++ b/src/communication/rpc/rpc.cpp @@ -0,0 +1,146 @@ +#include +#include + +#include "communication/rpc/rpc.hpp" + +namespace communication::rpc { + +const char kProtocolStreamPrefix[] = "rpc-"; + +std::string UniqueId() { + static thread_local std::mt19937 pseudo_rand_gen{std::random_device{}()}; + const char kCharset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const auto kMaxIndex = (sizeof(kCharset) - 1); + static thread_local std::uniform_int_distribution<> rand_dist{0, kMaxIndex}; + + std::string id; + std::generate_n(std::back_inserter(id), 20, + [&] { return kCharset[rand_dist(pseudo_rand_gen)]; }); + return id; +} + +class Request : public messaging::Message { + public: + Request(const std::string &address, uint16_t port, const std::string &stream, + std::unique_ptr message) + : address_(address), + port_(port), + stream_(stream), + message_id_(UniqueId()), + message_(std::move(message)) {} + + const std::string &address() const { return address_; } + uint16_t port() const { return port_; } + const std::string &stream() const { return stream_; } + const std::string &message_id() const { return message_id_; } + const messaging::Message &message() const { return *message_; } + + template + void serialize(Archive &ar) { + ar(cereal::virtual_base_class(this), address_, port_, + stream_, message_id_, message_); + } + + protected: + friend class cereal::access; + Request() {} // Cereal needs access to a default constructor. + + std::string address_; + uint16_t port_; + std::string stream_; + std::string message_id_; + std::unique_ptr message_; +}; + +class Response : public messaging::Message { + public: + explicit Response(const std::string &message_id, + std::unique_ptr message) + : message_id_(message_id), message_(std::move(message)) {} + + template + void serialize(Archive &ar) { + ar(cereal::virtual_base_class(this), message_id_, + message_); + } + + const auto &message_id() const { return message_id_; } + auto &message() { return message_; } + + protected: + Response() {} // Cereal needs access to a default constructor. + friend class cereal::access; + std::string message_id_; + std::unique_ptr message_; +}; + +Client::Client(messaging::System &system, const std::string &address, + uint16_t port, const std::string &name) + : system_(system), + writer_(system, address, port, kProtocolStreamPrefix + name), + stream_(system.Open(UniqueId())) {} + +// Because of the way Call is implemented it can fail without reporting (it will +// just block indefinately). This is why you always need to provide reasonable +// timeout when calling it. +// TODO: Make Call use same connection for request and respone and as soon as +// connection drop return nullptr. +std::unique_ptr Client::Call( + std::chrono::system_clock::duration timeout, + std::unique_ptr message) { + auto request = std::make_unique(system_.address(), system_.port(), + stream_->name(), std::move(message)); + auto message_id = request->message_id(); + writer_.Send(std::move(request)); + + auto now = std::chrono::system_clock::now(); + auto until = now + timeout; + + while (true) { + auto message = stream_->Await(until - std::chrono::system_clock::now()); + if (!message) break; // Client was either signaled or timeout was reached. + auto *response = dynamic_cast(message.get()); + if (!response) { + LOG(ERROR) << "Message received by rpc client is not a response"; + continue; + } + if (response->message_id() != message_id) { + // This can happen if some stale response arrives after we issued a new + // request. + continue; + } + return std::move(response->message()); + } + return nullptr; +} + +Server::Server(messaging::System &system, const std::string &name) + : system_(system), stream_(system.Open(kProtocolStreamPrefix + name)) {} + +void Server::Start() { + // TODO: Add logging. + while (alive_) { + auto message = stream_->Await(); + if (!message) continue; + auto *request = dynamic_cast(message.get()); + if (!request) continue; + auto &real_request = request->message(); + auto it = callbacks_.find(real_request.type_index()); + if (it == callbacks_.end()) continue; + auto response = it->second(real_request); + messaging::Writer writer(system_, request->address(), request->port(), + request->stream()); + writer.Send(request->message_id(), std::move(response)); + } +} + +void Server::Shutdown() { + alive_ = false; + stream_->Shutdown(); +} +} +CEREAL_REGISTER_TYPE(communication::rpc::Request); +CEREAL_REGISTER_TYPE(communication::rpc::Response); diff --git a/src/communication/rpc/rpc.hpp b/src/communication/rpc/rpc.hpp new file mode 100644 index 000000000..52afd0bee --- /dev/null +++ b/src/communication/rpc/rpc.hpp @@ -0,0 +1,91 @@ +#include + +#include "communication/messaging/distributed.hpp" + +namespace communication::rpc { + +template +struct RequestResponse { + using Request = TRequest; + using Response = TResponse; +}; + +// Client is not thread safe. +class Client { + public: + Client(messaging::System &system, const std::string &address, uint16_t port, + const std::string &name); + + // Call function can initiate only one request at the time. Function blocks + // until there is a response or timeout was reached. If timeout was reached + // nullptr is returned. + template + std::unique_ptr Call( + std::chrono::system_clock::duration timeout, Args &&... args) { + using Req = typename TRequestResponse::Request; + using Res = typename TRequestResponse::Response; + static_assert(std::is_base_of::value, + "TRequestResponse::Request must be derived from Message"); + static_assert(std::is_base_of::value, + "TRequestResponse::Response must be derived from Message"); + auto response = + Call(timeout, std::unique_ptr( + std::make_unique(std::forward(args)...))); + auto *real_response = dynamic_cast(response.get()); + if (!real_response && response) { + LOG(ERROR) << "Message response was of unexpected type"; + return nullptr; + } + response.release(); + return std::unique_ptr(real_response); + } + + private: + std::unique_ptr Call( + std::chrono::system_clock::duration timeout, + std::unique_ptr message); + + messaging::System &system_; + messaging::Writer writer_; + std::shared_ptr stream_; +}; + +class Server { + public: + Server(messaging::System &system, const std::string &name); + + template + void Register( + std::function( + const typename TRequestResponse::Request &)> + callback) { + static_assert(std::is_base_of::value, + "TRequestResponse::Request must be derived from Message"); + static_assert(std::is_base_of::value, + "TRequestResponse::Response must be derived from Message"); + auto got = callbacks_.emplace( + typeid(typename TRequestResponse::Request), + [callback = callback](const messaging::Message &base_message) { + const auto &message = + dynamic_cast( + base_message); + return callback(message); + }); + CHECK(got.second) << "Callback for that message type already registered"; + } + + void Start(); + void Shutdown(); + + private: + messaging::System &system_; + std::shared_ptr stream_; + std::unordered_map( + const messaging::Message &)>> + callbacks_; + std::atomic alive_{true}; +}; +} diff --git a/src/data_structures/queue.hpp b/src/data_structures/queue.hpp index a13dc98a9..228e531ac 100644 --- a/src/data_structures/queue.hpp +++ b/src/data_structures/queue.hpp @@ -7,6 +7,11 @@ #include #include #include +#include + +#include "glog/logging.h" + +using namespace std::literals::chrono_literals; // Thread safe queue. Probably doesn't perform very well, but it works. template @@ -18,8 +23,6 @@ class Queue { Queue(Queue &&) = delete; Queue &operator=(Queue &&) = delete; - ~Queue() { Signal(); } - void Push(T x) { std::unique_lock guard(mutex_); queue_.emplace(std::move(x)); @@ -46,12 +49,19 @@ class Queue { } // Block until there is an element in the queue and then pop it from the queue - // and return it. Function can return nullopt only if Queue is signaled via - // Signal function. - std::experimental::optional AwaitPop() { + // and return it. Function can return nullopt if Queue is signaled via + // Shutdown function or if there is no element to pop after timeout elapses. + std::experimental::optional AwaitPop( + std::chrono::system_clock::duration timeout = + std::chrono::system_clock::duration::max()) { std::unique_lock guard(mutex_); - cvar_.wait(guard, [this] { return !queue_.empty() || signaled_; }); - if (queue_.empty()) return std::experimental::nullopt; + auto now = std::chrono::system_clock::now(); + auto until = std::chrono::system_clock::time_point::max() - timeout > now + ? now + timeout + : std::chrono::system_clock::time_point::max(); + cvar_.wait_until(guard, until, + [this] { return !queue_.empty() || !alive_; }); + if (queue_.empty() || !alive_) return std::experimental::nullopt; std::experimental::optional x(std::move(queue_.front())); queue_.pop(); return x; @@ -66,14 +76,17 @@ class Queue { return x; } - // Notify all threads waiting on conditional variable to stop waiting. - void Signal() { - signaled_ = true; + // Notify all threads waiting on conditional variable to stop waiting. New + // threads that try to Await will not block. + void Shutdown() { + std::unique_lock guard(mutex_); + alive_ = false; + guard.unlock(); cvar_.notify_all(); } private: - std::atomic signaled_{false}; + bool alive_ = true; std::queue queue_; std::condition_variable cvar_; mutable std::mutex mutex_; diff --git a/src/io/network/socket_event_dispatcher.hpp b/src/io/network/socket_event_dispatcher.hpp index 9ae019222..ac3ffdc55 100644 --- a/src/io/network/socket_event_dispatcher.hpp +++ b/src/io/network/socket_event_dispatcher.hpp @@ -36,13 +36,23 @@ class SocketEventDispatcher { // Go through all events and process them in order. for (int i = 0; i < n; ++i) { auto &event = events_[i]; - Listener &listener = *reinterpret_cast(event.data.ptr); - // TODO: revise this. Reported events will be combined so continue is not - // probably what we want to do. + // Even though it is possible for multiple events to be reported we handle + // only one of them. Since we use epoll in level triggered mode + // unprocessed events will be reported next time we call epoll_wait. This + // kind of processing events is safer since callbacks can destroy listener + // and calling next callback on listener object will segfault. More subtle + // bugs are also possible: one callback can handle multiple events + // (maybe even in a subtle implicit way) and then we don't want to call + // multiple callbacks since we are not sure if those are valid anymore. try { - // Hangup event. + if (event.events & EPOLLIN) { + // We have some data waiting to be read. + listener.OnData(); + continue; + } + if (event.events & EPOLLRDHUP) { listener.OnClose(); continue; @@ -54,8 +64,6 @@ class SocketEventDispatcher { continue; } - // We have some data waiting to be read. - listener.OnData(); } catch (const std::exception &e) { listener.OnException(e); } diff --git a/tests/unit/messaging_distributed.cpp b/tests/unit/messaging_distributed.cpp new file mode 100644 index 000000000..d8e878f95 --- /dev/null +++ b/tests/unit/messaging_distributed.cpp @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "communication/messaging/distributed.hpp" +#include "gtest/gtest.h" + +using namespace communication::messaging; +using namespace std::literals::chrono_literals; + +struct MessageInt : public Message { + MessageInt() {} // cereal needs this + MessageInt(int x) : x(x) {} + int x; + + template + void serialize(Archive &ar) { + ar(cereal::virtual_base_class(this), x); + } +}; +CEREAL_REGISTER_TYPE(MessageInt); + +#define GET_X(p) dynamic_cast((p).get())->x + +/** + * Test do the services start up without crashes. + */ +TEST(SimpleTests, StartAndShutdown) { + System system("127.0.0.1", 10000); + // do nothing + std::this_thread::sleep_for(500ms); + system.Shutdown(); +} + +TEST(Messaging, Pop) { + System master_system("127.0.0.1", 10000); + System slave_system("127.0.0.1", 10001); + auto stream = master_system.Open("main"); + Writer writer(slave_system, "127.0.0.1", 10000, "main"); + std::this_thread::sleep_for(100ms); + + EXPECT_EQ(stream->Poll(), nullptr); + writer.Send(10); + EXPECT_EQ(GET_X(stream->Await()), 10); + master_system.Shutdown(); + slave_system.Shutdown(); +} + +TEST(Messaging, Await) { + System master_system("127.0.0.1", 10000); + System slave_system("127.0.0.1", 10001); + auto stream = master_system.Open("main"); + Writer writer(slave_system, "127.0.0.1", 10000, "main"); + std::this_thread::sleep_for(100ms); + + std::thread t([&] { + std::this_thread::sleep_for(100ms); + stream->Shutdown(); + std::this_thread::sleep_for(100ms); + writer.Send(20); + }); + + EXPECT_EQ(stream->Poll(), nullptr); + EXPECT_EQ(stream->Await(), nullptr); + t.join(); + master_system.Shutdown(); + slave_system.Shutdown(); +} + +TEST(Messaging, RecreateChannelAfterClosing) { + System master_system("127.0.0.1", 10000); + System slave_system("127.0.0.1", 10001); + auto stream = master_system.Open("main"); + Writer writer(slave_system, "127.0.0.1", 10000, "main"); + std::this_thread::sleep_for(100ms); + + writer.Send(10); + EXPECT_EQ(GET_X(stream->Await()), 10); + + stream = nullptr; + writer.Send(20); + std::this_thread::sleep_for(100ms); + + stream = master_system.Open("main"); + std::this_thread::sleep_for(100ms); + EXPECT_EQ(stream->Poll(), nullptr); + writer.Send(30); + EXPECT_EQ(GET_X(stream->Await()), 30); + + master_system.Shutdown(); + slave_system.Shutdown(); +} diff --git a/tests/unit/messaging_local.cpp b/tests/unit/messaging_local.cpp new file mode 100644 index 000000000..6719567cb --- /dev/null +++ b/tests/unit/messaging_local.cpp @@ -0,0 +1,72 @@ +#include +#include +#include + +#include "communication/messaging/local.hpp" +#include "gtest/gtest.h" + +using namespace std::literals::chrono_literals; +using namespace communication::messaging; + +struct MessageInt : public Message { + MessageInt(int xx) : x(xx) {} + int x; +}; + +#define GET_X(p) dynamic_cast((p).get())->x + +TEST(LocalMessaging, Pop) { + LocalSystem system; + auto stream = system.Open("main"); + LocalWriter writer(system, "main"); + + EXPECT_EQ(stream->Poll(), nullptr); + writer.Send(10); + EXPECT_EQ(GET_X(stream->Poll()), 10); +} + +TEST(LocalMessaging, Await) { + LocalSystem system; + auto stream = system.Open("main"); + LocalWriter writer(system, "main"); + + std::thread t([&] { + std::this_thread::sleep_for(100ms); + stream->Shutdown(); + std::this_thread::sleep_for(100ms); + writer.Send(20); + }); + + EXPECT_EQ(stream->Poll(), nullptr); + EXPECT_EQ(stream->Await(), nullptr); + t.join(); +} + +TEST(LocalMessaging, AwaitTimeout) { + LocalSystem system; + auto stream = system.Open("main"); + + EXPECT_EQ(stream->Poll(), nullptr); + EXPECT_EQ(stream->Await(100ms), nullptr); +} + +TEST(LocalMessaging, RecreateChannelAfterClosing) { + LocalSystem system; + auto stream = system.Open("main"); + LocalWriter writer(system, "main"); + + writer.Send(10); + EXPECT_EQ(GET_X(stream->Poll()), 10); + + stream = nullptr; + writer.Send(20); + stream = system.Open("main"); + EXPECT_EQ(stream->Poll(), nullptr); + writer.Send(30); + EXPECT_EQ(GET_X(stream->Poll()), 30); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/unit/queue.cpp b/tests/unit/queue.cpp index 310781821..17091a044 100644 --- a/tests/unit/queue.cpp +++ b/tests/unit/queue.cpp @@ -79,6 +79,7 @@ TEST(Queue, AwaitPop) { }); EXPECT_EQ(*q.AwaitPop(), 1); + std::this_thread::sleep_for(1000ms); EXPECT_EQ(*q.AwaitPop(), 2); EXPECT_EQ(*q.AwaitPop(), 3); EXPECT_EQ(*q.AwaitPop(), 4); @@ -86,12 +87,19 @@ TEST(Queue, AwaitPop) { std::thread t2([&] { std::this_thread::sleep_for(100ms); - q.Signal(); + q.Shutdown(); }); + std::this_thread::sleep_for(200ms); EXPECT_EQ(q.AwaitPop(), std::experimental::nullopt); t2.join(); } +TEST(Queue, AwaitPopTimeout) { + std::this_thread::sleep_for(1000ms); + Queue q; + EXPECT_EQ(q.AwaitPop(100ms), std::experimental::nullopt); +} + TEST(Queue, Concurrent) { Queue q; diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp new file mode 100644 index 000000000..385ad4e86 --- /dev/null +++ b/tests/unit/rpc.cpp @@ -0,0 +1,80 @@ +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include + +#include "communication/messaging/distributed.hpp" +#include "communication/rpc/rpc.hpp" +#include "gtest/gtest.h" + +using communication::messaging::System; +using communication::messaging::Message; +using namespace communication::rpc; +using namespace std::literals::chrono_literals; + +struct SumReq : public Message { + SumReq() {} // cereal needs this + SumReq(int x, int y) : x(x), y(y) {} + int x; + int y; + + template + void serialize(Archive &ar) { + ar(cereal::virtual_base_class(this), x, y); + } +}; +CEREAL_REGISTER_TYPE(SumReq); + +struct SumRes : public Message { + SumRes() {} // cereal needs this + SumRes(int sum) : sum(sum) {} + int sum; + + template + void serialize(Archive &ar) { + ar(cereal::virtual_base_class(this), sum); + } +}; +CEREAL_REGISTER_TYPE(SumRes); +using Sum = RequestResponse; + +TEST(Rpc, Call) { + System server_system("127.0.0.1", 10000); + Server server(server_system, "main"); + server.Register([](const SumReq &request) { + return std::make_unique(request.x + request.y); + }); + std::thread server_thread([&] { server.Start(); }); + std::this_thread::sleep_for(100ms); + + System client_system("127.0.0.1", 10001); + Client client(client_system, "127.0.0.1", 10000, "main"); + auto sum = client.Call(300ms, 10, 20); + EXPECT_EQ(sum->sum, 30); + + server.Shutdown(); + server_thread.join(); + server_system.Shutdown(); + client_system.Shutdown(); +} + +TEST(Rpc, Timeout) { + System server_system("127.0.0.1", 10000); + Server server(server_system, "main"); + server.Register([](const SumReq &request) { + std::this_thread::sleep_for(300ms); + return std::make_unique(request.x + request.y); + }); + std::thread server_thread([&] { server.Start(); }); + std::this_thread::sleep_for(100ms); + + System client_system("127.0.0.1", 10001); + Client client(client_system, "127.0.0.1", 10000, "main"); + auto sum = client.Call(100ms, 10, 20); + EXPECT_FALSE(sum); + + server.Shutdown(); + server_thread.join(); + server_system.Shutdown(); + client_system.Shutdown(); +} diff --git a/tools/rpcgen b/tools/rpcgen new file mode 100755 index 000000000..cc8ab2638 --- /dev/null +++ b/tools/rpcgen @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# To run from vim add to your .vimrc something similar to: +# command -nargs=* Rpcgen :r !/home/mislav/code/memgraph/tools/rpcgen + + +import sys + +USAGE = "\n\nUsage:\n" \ + "./rpcgen request_response_name request_args -- response_args\n" \ + "Arguments should be seperated with minus sign (-).\n" \ + "Example: ./rpcgen Sum int x - int y -- int sum\n\n" + +assert len(sys.argv) >= 3, "Too few arguments.\n" + USAGE + +request_response_name = sys.argv[1] +args_string = " ".join(sys.argv[2:]) +split = args_string.split("--") +assert len(split) == 2, "Arguments should contain one -- separator.\n" + USAGE +request_args, response_args = split + +def generate(message_name, args): + def process_arg(arg): + arg = arg.strip() + assert arg, "Each arg should be non empty string.\n" + USAGE + for i in range(len(arg) - 1, -1, -1): + if not arg[i].isalpha() and arg[i] != "_": break + assert i != -1 and i != len(arg), "Each string separated with - " \ + "should contain type and variable name.\n" + USAGE + typ = arg[:i+1].strip() + name = arg[i+1:].strip() + return typ, name + types, names = zip(*map(process_arg, args.split("-"))) + + return \ +""" +struct {message_name} : public Message {{ + {message_name}() {{}} // cereal needs this + {message_name}({constructor_args}): {init_list} {{}} + {members} + + template + void serialize(Archive &ar) {{ + ar(cereal::virtual_base_class(this), {serialize_args}); + }} +}}; +CEREAL_REGISTER_TYPE({message_name});""" \ + .format(message_name=message_name, + constructor_args=",".join( + map(lambda x: x[0] + " " + x[1], zip(types, names))), + init_list=", ".join(map(lambda x: "{x}({x})".format(x=x), names)), + members="\n".join(map(lambda x: + "{} {};".format(x[0], x[1]), zip(types, names))), + serialize_args=", ".join(names)) + +request_name = request_response_name + "Req" +response_name = request_response_name + "Res" +req_class = generate(request_name, request_args) +res_class = generate(response_name, response_args) +print(req_class) +print(res_class) +print("using {} = RequestResponse<{}, {}>;".format( + request_response_name, request_name, response_name))