Add first version of message passing and rpc

Reviewers: mtomic, buda

Reviewed By: mtomic, buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1002
This commit is contained in:
Mislav Bradac 2017-12-05 13:41:51 +01:00
parent 8dcd8e1012
commit 60f4db2b9f
18 changed files with 1187 additions and 22 deletions

View File

@ -170,7 +170,6 @@ target_link_libraries(antlr_opencypher_parser_lib antlr4)
include_directories(src)
add_subdirectory(src)
# -----------------------------------------------------------------------------
# Optional subproject configuration -------------------------------------------

View File

@ -4,10 +4,14 @@
set(memgraph_src_files
communication/bolt/v1/decoder/decoded_value.cpp
communication/bolt/v1/session.cpp
communication/reactor/protocol.cpp
communication/messaging/distributed.cpp
communication/messaging/local.cpp
communication/messaging/protocol.cpp
communication/raft/raft.cpp
communication/reactor/reactor_local.cpp
communication/reactor/protocol.cpp
communication/reactor/reactor_distributed.cpp
communication/reactor/reactor_local.cpp
communication/rpc/rpc.cpp
data_structures/concurrent/skiplist_gc.cpp
database/graph_db.cpp
database/graph_db_config.cpp
@ -78,6 +82,7 @@ set_target_properties(${MEMGRAPH_BUILD_NAME} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
# Strip the executable in release build.
string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
if (lower_build_type STREQUAL "release")
add_custom_command(TARGET ${MEMGRAPH_BUILD_NAME} POST_BUILD
COMMAND strip -s $<TARGET_FILE:${MEMGRAPH_BUILD_NAME}>

View File

@ -0,0 +1,69 @@
#include "communication/messaging/distributed.hpp"
namespace communication::messaging {
System::System(const std::string &address, uint16_t port)
: address_(address), port_(port) {
// Numbers of worker are quite arbitrary at the point.
StartClient(4);
StartServer(4);
}
System::~System() {
for (size_t i = 0; i < pool_.size(); ++i) {
pool_[i].join();
}
thread_.join();
}
void System::Shutdown() {
queue_.Shutdown();
server_->Shutdown();
}
void System::StartClient(int worker_count) {
LOG(INFO) << "Starting " << worker_count << " client workers";
for (int i = 0; i < worker_count; ++i) {
pool_.push_back(std::thread([this]() {
while (true) {
auto message = queue_.AwaitPop();
if (message == std::experimental::nullopt) break;
SendMessage(message->address, message->port, message->channel,
std::move(message->message));
}
}));
}
}
void System::StartServer(int worker_count) {
if (server_ != nullptr) {
LOG(FATAL) << "Tried to start a running server!";
}
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(address_.c_str(), port_);
} catch (io::network::NetworkEndpointException &e) {
LOG(FATAL) << e.what();
}
// Initialize server.
server_ = std::make_unique<ServerT>(endpoint, protocol_data_);
// Start server.
thread_ = std::thread(
[worker_count, this]() { this->server_->Start(worker_count); });
}
std::shared_ptr<EventStream> System::Open(const std::string &name) {
return system_.Open(name);
}
Writer::Writer(System &system, const std::string &address, uint16_t port,
const std::string &name)
: system_(system), address_(address), port_(port), name_(name) {}
void Writer::Send(std::unique_ptr<Message> message) {
system_.queue_.Emplace(address_, port_, name_, std::move(message));
}
}

View File

@ -0,0 +1,124 @@
#pragma once
#include <cassert>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <tuple>
#include <typeindex>
#include <utility>
#include <gflags/gflags.h>
#include "communication/messaging/local.hpp"
#include "data_structures/queue.hpp"
#include "protocol.hpp"
#include "cereal/archives/binary.hpp"
#include "cereal/types/base_class.hpp"
#include "cereal/types/memory.hpp"
#include "cereal/types/polymorphic.hpp"
#include "cereal/types/string.hpp"
#include "cereal/types/utility.hpp"
#include "cereal/types/vector.hpp"
#include "communication/server.hpp"
#include "threading/sync/spinlock.hpp"
namespace communication::messaging {
class System;
// Writes message to remote event stream.
class Writer {
public:
Writer(System &system, const std::string &address, uint16_t port,
const std::string &name);
Writer(const Writer &) = delete;
void operator=(const Writer &) = delete;
Writer(Writer &&) = delete;
void operator=(Writer &&) = delete;
template <typename TMessage, typename... Args>
void Send(Args &&... args) {
Send(std::unique_ptr<Message>(
std::make_unique<TMessage>(std::forward<Args>(args)...)));
}
void Send(std::unique_ptr<Message> message);
private:
System &system_;
std::string address_;
uint16_t port_;
std::string name_;
};
class System {
public:
friend class Writer;
System(const std::string &address, uint16_t port);
System(const System &) = delete;
System(System &&) = delete;
System &operator=(const System &) = delete;
System &operator=(System &&) = delete;
~System();
std::shared_ptr<EventStream> Open(const std::string &name);
void Shutdown();
const std::string &address() const { return address_; }
uint16_t port() const { return port_; }
private:
using Endpoint = io::network::NetworkEndpoint;
using Socket = Socket;
using ServerT = communication::Server<Session, SessionData>;
struct NetworkMessage {
NetworkMessage() {}
NetworkMessage(const std::string &address, uint16_t port,
const std::string &channel,
std::unique_ptr<Message> &&message)
: address(address),
port(port),
channel(channel),
message(std::move(message)) {}
NetworkMessage(NetworkMessage &&nm) = default;
NetworkMessage &operator=(NetworkMessage &&nm) = default;
std::string address;
uint16_t port = 0;
std::string channel;
std::unique_ptr<Message> message;
};
/** Start a threadpool that dispatches the messages from the (outgoing) queue
* to the sockets */
void StartClient(int worker_count);
/** Start a threadpool that relays the messages from the sockets to the
* LocalEventStreams */
void StartServer(int workers_count);
// Client variables.
std::vector<std::thread> pool_;
Queue<NetworkMessage> queue_;
// Server variables.
std::thread thread_;
SessionData protocol_data_;
std::unique_ptr<ServerT> server_{nullptr};
std::string address_;
uint16_t port_;
LocalSystem &system_ = protocol_data_.system;
};
} // namespace communication::messaging

View File

@ -0,0 +1,61 @@
#include "communication/messaging/local.hpp"
#include "fmt/format.h"
#include "glog/logging.h"
namespace communication::messaging {
std::shared_ptr<EventStream> LocalSystem::Open(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
// TODO: It would be better to use std::make_shared here, but we can't since
// constructor is private.
std::shared_ptr<EventStream> stream(new EventStream(*this, name));
auto got = channels_.emplace(name, stream);
CHECK(got.second) << fmt::format("Stream with name {} already exists", name);
return stream;
}
std::shared_ptr<EventStream> LocalSystem::Resolve(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = channels_.find(name);
if (it == channels_.end()) return nullptr;
return it->second.lock();
}
void LocalSystem::Remove(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = channels_.find(name);
CHECK(it != channels_.end()) << "Trying to delete nonexisting stream";
channels_.erase(it);
}
void LocalWriter::Send(std::unique_ptr<Message> message) {
// TODO: We could add caching to LocalWriter so that we don't need to acquire
// lock on system every time we want to send a message. This can be
// accomplished by storing weak_ptr to EventStream.
auto stream = system_.Resolve(name_);
if (!stream) return;
stream->Push(std::move(message));
}
EventStream::~EventStream() { system_.Remove(name_); }
std::unique_ptr<Message> EventStream::Poll() {
auto opt_message = queue_.MaybePop();
if (opt_message == std::experimental::nullopt) return nullptr;
return std::move(*opt_message);
}
void EventStream::Push(std::unique_ptr<Message> message) {
queue_.Push(std::move(message));
}
std::unique_ptr<Message> EventStream::Await(
std::chrono::system_clock::duration timeout) {
auto opt_message = queue_.AwaitPop(timeout);
if (opt_message == std::experimental::nullopt) return nullptr;
return std::move(*opt_message);
};
void EventStream::Shutdown() { queue_.Shutdown(); }
}

View File

@ -0,0 +1,109 @@
#pragma once
#include <mutex>
#include <string>
#include <type_traits>
#include <typeindex>
#include "cereal/types/memory.hpp"
#include "data_structures/queue.hpp"
namespace communication::messaging {
/**
* Base class for messages.
*/
class Message {
public:
virtual ~Message() {}
template <class Archive>
void serialize(Archive &) {}
/**
* Run-time type identification that is used for callbacks.
*
* Warning: this works because of the virtual destructor, don't remove it from
* this class
*/
std::type_index type_index() const { return typeid(*this); }
};
class EventStream;
class LocalWriter;
class LocalSystem {
public:
friend class EventStream;
friend class LocalWriter;
LocalSystem() = default;
LocalSystem(const LocalSystem &) = delete;
LocalSystem(LocalSystem &&) = delete;
LocalSystem &operator=(const LocalSystem &) = delete;
LocalSystem &operator=(LocalSystem &&) = delete;
std::shared_ptr<EventStream> Open(const std::string &name);
private:
std::shared_ptr<EventStream> Resolve(const std::string &name);
void Remove(const std::string &name);
std::mutex mutex_;
std::unordered_map<std::string, std::weak_ptr<EventStream>> channels_;
};
class LocalWriter {
public:
LocalWriter(LocalSystem &system, const std::string &name)
: system_(system), name_(name) {}
LocalWriter(const LocalWriter &) = delete;
void operator=(const LocalWriter &) = delete;
LocalWriter(LocalWriter &&) = delete;
void operator=(LocalWriter &&) = delete;
template <typename TMessage, typename... Args>
void Send(Args &&... args) {
Send(std::unique_ptr<Message>(
std::make_unique<TMessage>(std::forward<Args>(args)...)));
}
void Send(std::unique_ptr<Message> message);
private:
LocalSystem &system_;
std::string name_;
};
class EventStream {
public:
friend class LocalWriter;
friend class LocalSystem;
EventStream(const EventStream &) = delete;
void operator=(const EventStream &) = delete;
EventStream(EventStream &&) = delete;
void operator=(EventStream &&) = delete;
~EventStream();
std::unique_ptr<Message> Poll();
std::unique_ptr<Message> Await(
std::chrono::system_clock::duration timeout =
std::chrono::system_clock::duration::max());
void Shutdown();
const std::string &name() const { return name_; }
private:
EventStream(LocalSystem &system, const std::string &name)
: system_(system), name_(name) {}
void Push(std::unique_ptr<Message> message);
LocalSystem &system_;
std::string name_;
Queue<std::unique_ptr<Message>> queue_;
};
} // namespace communication::messaging

View File

@ -0,0 +1,115 @@
#include <sstream>
#include "communication/messaging/distributed.hpp"
#include "communication/messaging/local.hpp"
#include "communication/messaging/protocol.hpp"
#include "fmt/format.h"
#include "glog/logging.h"
namespace communication::messaging {
Session::Session(Socket &&socket, SessionData &data)
: socket_(std::move(socket)), system_(data.system) {}
bool Session::Alive() const { return alive_; }
std::string Session::GetStringAndShift(SizeT len) {
std::string ret(reinterpret_cast<char *>(buffer_.data()), len);
buffer_.Shift(len);
return ret;
}
void Session::Execute() {
if (buffer_.size() < sizeof(SizeT)) return;
SizeT len_channel = GetLength();
if (buffer_.size() < 2 * sizeof(SizeT) + len_channel) return;
SizeT len_data = GetLength(sizeof(SizeT) + len_channel);
if (buffer_.size() < 2 * sizeof(SizeT) + len_data + len_channel) return;
// Remove the length bytes from the buffer.
buffer_.Shift(sizeof(SizeT));
auto channel = GetStringAndShift(len_channel);
buffer_.Shift(sizeof(SizeT));
// TODO: check for exceptions
std::istringstream stream;
stream.str(std::string(reinterpret_cast<char *>(buffer_.data()), len_data));
::cereal::BinaryInputArchive iarchive{stream};
std::unique_ptr<Message> message{nullptr};
iarchive(message);
buffer_.Shift(len_data);
LocalWriter writer(system_, channel);
writer.Send(std::move(message));
}
StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
void Session::Written(size_t len) { buffer_.Written(len); }
void Session::Close() {
DLOG(INFO) << "Closing session";
this->socket_.Close();
}
SizeT Session::GetLength(int offset) {
SizeT ret = *reinterpret_cast<SizeT *>(buffer_.data() + offset);
return ret;
}
bool SendLength(Socket &socket, SizeT length) {
return socket.Write(reinterpret_cast<uint8_t *>(&length), sizeof(SizeT));
}
void SendMessage(const std::string &address, uint16_t port,
const std::string &channel, std::unique_ptr<Message> message) {
CHECK(message) << "Trying to send nullptr instead of message";
// Initialize endpoint.
Endpoint endpoint;
try {
endpoint = Endpoint(address.c_str(), port);
} catch (io::network::NetworkEndpointException &e) {
LOG(ERROR) << "Address {} is invalid!";
return;
}
// Initialize socket.
Socket socket;
if (!socket.Connect(endpoint)) {
LOG(INFO) << "Couldn't connect to remote address: " << address << ":"
<< port;
return;
}
if (!SendLength(socket, channel.size())) {
LOG(INFO) << "Couldn't send channel size!";
return;
}
if (!socket.Write(channel)) {
LOG(INFO) << "Couldn't send channel data!";
return;
}
// Serialize and send message
std::ostringstream stream;
::cereal::BinaryOutputArchive oarchive(stream);
oarchive(message);
const std::string &buffer = stream.str();
int64_t message_size = 2 * sizeof(SizeT) + buffer.size() + channel.size();
CHECK(message_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", message_size,
kMaxMessageSize);
if (!SendLength(socket, buffer.size())) {
LOG(INFO) << "Couldn't send message size!";
return;
}
if (!socket.Write(buffer)) {
LOG(INFO) << "Couldn't send message data!";
return;
}
}
}

View File

@ -0,0 +1,106 @@
#pragma once
#include <chrono>
#include <cstdint>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/messaging/local.hpp"
#include "io/network/epoll.hpp"
#include "io/network/network_endpoint.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
/**
* @brief Protocol
*
* Has classes and functions that implement server and client sides of our
* messaging protocol.
*
* Message layout: SizeT channel_size, channel_size characters channel,
* SizeT message_size, message_size bytes serialized_message
*/
namespace communication::messaging {
class Message;
using Endpoint = io::network::NetworkEndpoint;
using Socket = io::network::Socket;
using StreamBuffer = io::network::StreamBuffer;
// This buffer should be larger than the largest serialized message.
const int64_t kMaxMessageSize = 262144;
using Buffer = bolt::Buffer<kMaxMessageSize>;
using SizeT = uint16_t;
/**
* Distributed Protocol Data
*/
struct SessionData {
LocalSystem system;
};
/**
* Distributed Protocol Session
*
* This class is responsible for handling a single client connection.
*/
class Session {
public:
Session(Socket &&socket, SessionData &data);
int Id() const { return socket_.fd(); }
/**
* Returns the protocol alive state
*/
bool Alive() const;
/**
* Executes the protocol after data has been read into the buffer.
* Goes through the protocol states in order to execute commands from the
* client.
*/
void Execute();
/**
* Allocates data from the internal buffer.
* Used in the underlying network stack to asynchronously read data
* from the client.
* @returns a StreamBuffer to the allocated internal data buffer
*/
StreamBuffer Allocate();
/**
* Notifies the internal buffer of written data.
* Used in the underlying network stack to notify the internal buffer
* how many bytes of data have been written.
* @param len how many data was written to the buffer
*/
void Written(size_t len);
bool TimedOut() { return false; }
/**
* Closes the session (client socket).
*/
void Close();
Socket socket_;
LocalSystem &system_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
private:
SizeT GetLength(int offset = 0);
std::string GetStringAndShift(SizeT len);
bool alive_{true};
Buffer buffer_;
};
/**
* Distributed Protocol Send Message
*/
void SendMessage(const std::string &address, uint16_t port,
const std::string &channel, std::unique_ptr<Message> message);
}

View File

@ -96,7 +96,7 @@ class Network {
break;
}
}
queue_.Signal();
queue_.Shutdown();
for (size_t i = 0; i < pool_.size(); ++i) {
pool_[i].join();
}

View File

@ -0,0 +1,146 @@
#include <iterator>
#include <random>
#include "communication/rpc/rpc.hpp"
namespace communication::rpc {
const char kProtocolStreamPrefix[] = "rpc-";
std::string UniqueId() {
static thread_local std::mt19937 pseudo_rand_gen{std::random_device{}()};
const char kCharset[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
const auto kMaxIndex = (sizeof(kCharset) - 1);
static thread_local std::uniform_int_distribution<> rand_dist{0, kMaxIndex};
std::string id;
std::generate_n(std::back_inserter(id), 20,
[&] { return kCharset[rand_dist(pseudo_rand_gen)]; });
return id;
}
class Request : public messaging::Message {
public:
Request(const std::string &address, uint16_t port, const std::string &stream,
std::unique_ptr<Message> message)
: address_(address),
port_(port),
stream_(stream),
message_id_(UniqueId()),
message_(std::move(message)) {}
const std::string &address() const { return address_; }
uint16_t port() const { return port_; }
const std::string &stream() const { return stream_; }
const std::string &message_id() const { return message_id_; }
const messaging::Message &message() const { return *message_; }
template <class Archive>
void serialize(Archive &ar) {
ar(cereal::virtual_base_class<messaging::Message>(this), address_, port_,
stream_, message_id_, message_);
}
protected:
friend class cereal::access;
Request() {} // Cereal needs access to a default constructor.
std::string address_;
uint16_t port_;
std::string stream_;
std::string message_id_;
std::unique_ptr<messaging::Message> message_;
};
class Response : public messaging::Message {
public:
explicit Response(const std::string &message_id,
std::unique_ptr<messaging::Message> message)
: message_id_(message_id), message_(std::move(message)) {}
template <class Archive>
void serialize(Archive &ar) {
ar(cereal::virtual_base_class<messaging::Message>(this), message_id_,
message_);
}
const auto &message_id() const { return message_id_; }
auto &message() { return message_; }
protected:
Response() {} // Cereal needs access to a default constructor.
friend class cereal::access;
std::string message_id_;
std::unique_ptr<messaging::Message> message_;
};
Client::Client(messaging::System &system, const std::string &address,
uint16_t port, const std::string &name)
: system_(system),
writer_(system, address, port, kProtocolStreamPrefix + name),
stream_(system.Open(UniqueId())) {}
// Because of the way Call is implemented it can fail without reporting (it will
// just block indefinately). This is why you always need to provide reasonable
// timeout when calling it.
// TODO: Make Call use same connection for request and respone and as soon as
// connection drop return nullptr.
std::unique_ptr<messaging::Message> Client::Call(
std::chrono::system_clock::duration timeout,
std::unique_ptr<messaging::Message> message) {
auto request = std::make_unique<Request>(system_.address(), system_.port(),
stream_->name(), std::move(message));
auto message_id = request->message_id();
writer_.Send(std::move(request));
auto now = std::chrono::system_clock::now();
auto until = now + timeout;
while (true) {
auto message = stream_->Await(until - std::chrono::system_clock::now());
if (!message) break; // Client was either signaled or timeout was reached.
auto *response = dynamic_cast<Response *>(message.get());
if (!response) {
LOG(ERROR) << "Message received by rpc client is not a response";
continue;
}
if (response->message_id() != message_id) {
// This can happen if some stale response arrives after we issued a new
// request.
continue;
}
return std::move(response->message());
}
return nullptr;
}
Server::Server(messaging::System &system, const std::string &name)
: system_(system), stream_(system.Open(kProtocolStreamPrefix + name)) {}
void Server::Start() {
// TODO: Add logging.
while (alive_) {
auto message = stream_->Await();
if (!message) continue;
auto *request = dynamic_cast<Request *>(message.get());
if (!request) continue;
auto &real_request = request->message();
auto it = callbacks_.find(real_request.type_index());
if (it == callbacks_.end()) continue;
auto response = it->second(real_request);
messaging::Writer writer(system_, request->address(), request->port(),
request->stream());
writer.Send<Response>(request->message_id(), std::move(response));
}
}
void Server::Shutdown() {
alive_ = false;
stream_->Shutdown();
}
}
CEREAL_REGISTER_TYPE(communication::rpc::Request);
CEREAL_REGISTER_TYPE(communication::rpc::Response);

View File

@ -0,0 +1,91 @@
#include <type_traits>
#include "communication/messaging/distributed.hpp"
namespace communication::rpc {
template <typename TRequest, typename TResponse>
struct RequestResponse {
using Request = TRequest;
using Response = TResponse;
};
// Client is not thread safe.
class Client {
public:
Client(messaging::System &system, const std::string &address, uint16_t port,
const std::string &name);
// Call function can initiate only one request at the time. Function blocks
// until there is a response or timeout was reached. If timeout was reached
// nullptr is returned.
template <typename TRequestResponse, typename... Args>
std::unique_ptr<typename TRequestResponse::Response> Call(
std::chrono::system_clock::duration timeout, Args &&... args) {
using Req = typename TRequestResponse::Request;
using Res = typename TRequestResponse::Response;
static_assert(std::is_base_of<messaging::Message, Req>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(std::is_base_of<messaging::Message, Res>::value,
"TRequestResponse::Response must be derived from Message");
auto response =
Call(timeout, std::unique_ptr<messaging::Message>(
std::make_unique<Req>(std::forward<Args>(args)...)));
auto *real_response = dynamic_cast<Res *>(response.get());
if (!real_response && response) {
LOG(ERROR) << "Message response was of unexpected type";
return nullptr;
}
response.release();
return std::unique_ptr<Res>(real_response);
}
private:
std::unique_ptr<messaging::Message> Call(
std::chrono::system_clock::duration timeout,
std::unique_ptr<messaging::Message> message);
messaging::System &system_;
messaging::Writer writer_;
std::shared_ptr<messaging::EventStream> stream_;
};
class Server {
public:
Server(messaging::System &system, const std::string &name);
template <typename TRequestResponse>
void Register(
std::function<std::unique_ptr<typename TRequestResponse::Response>(
const typename TRequestResponse::Request &)>
callback) {
static_assert(std::is_base_of<messaging::Message,
typename TRequestResponse::Request>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(std::is_base_of<messaging::Message,
typename TRequestResponse::Response>::value,
"TRequestResponse::Response must be derived from Message");
auto got = callbacks_.emplace(
typeid(typename TRequestResponse::Request),
[callback = callback](const messaging::Message &base_message) {
const auto &message =
dynamic_cast<const typename TRequestResponse::Request &>(
base_message);
return callback(message);
});
CHECK(got.second) << "Callback for that message type already registered";
}
void Start();
void Shutdown();
private:
messaging::System &system_;
std::shared_ptr<messaging::EventStream> stream_;
std::unordered_map<std::type_index,
std::function<std::unique_ptr<messaging::Message>(
const messaging::Message &)>>
callbacks_;
std::atomic<bool> alive_{true};
};
}

View File

@ -7,6 +7,11 @@
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include "glog/logging.h"
using namespace std::literals::chrono_literals;
// Thread safe queue. Probably doesn't perform very well, but it works.
template <typename T>
@ -18,8 +23,6 @@ class Queue {
Queue(Queue &&) = delete;
Queue &operator=(Queue &&) = delete;
~Queue() { Signal(); }
void Push(T x) {
std::unique_lock<std::mutex> guard(mutex_);
queue_.emplace(std::move(x));
@ -46,12 +49,19 @@ class Queue {
}
// Block until there is an element in the queue and then pop it from the queue
// and return it. Function can return nullopt only if Queue is signaled via
// Signal function.
std::experimental::optional<T> AwaitPop() {
// and return it. Function can return nullopt if Queue is signaled via
// Shutdown function or if there is no element to pop after timeout elapses.
std::experimental::optional<T> AwaitPop(
std::chrono::system_clock::duration timeout =
std::chrono::system_clock::duration::max()) {
std::unique_lock<std::mutex> guard(mutex_);
cvar_.wait(guard, [this] { return !queue_.empty() || signaled_; });
if (queue_.empty()) return std::experimental::nullopt;
auto now = std::chrono::system_clock::now();
auto until = std::chrono::system_clock::time_point::max() - timeout > now
? now + timeout
: std::chrono::system_clock::time_point::max();
cvar_.wait_until(guard, until,
[this] { return !queue_.empty() || !alive_; });
if (queue_.empty() || !alive_) return std::experimental::nullopt;
std::experimental::optional<T> x(std::move(queue_.front()));
queue_.pop();
return x;
@ -66,14 +76,17 @@ class Queue {
return x;
}
// Notify all threads waiting on conditional variable to stop waiting.
void Signal() {
signaled_ = true;
// Notify all threads waiting on conditional variable to stop waiting. New
// threads that try to Await will not block.
void Shutdown() {
std::unique_lock<std::mutex> guard(mutex_);
alive_ = false;
guard.unlock();
cvar_.notify_all();
}
private:
std::atomic<bool> signaled_{false};
bool alive_ = true;
std::queue<T> queue_;
std::condition_variable cvar_;
mutable std::mutex mutex_;

View File

@ -36,13 +36,23 @@ class SocketEventDispatcher {
// Go through all events and process them in order.
for (int i = 0; i < n; ++i) {
auto &event = events_[i];
Listener &listener = *reinterpret_cast<Listener *>(event.data.ptr);
// TODO: revise this. Reported events will be combined so continue is not
// probably what we want to do.
// Even though it is possible for multiple events to be reported we handle
// only one of them. Since we use epoll in level triggered mode
// unprocessed events will be reported next time we call epoll_wait. This
// kind of processing events is safer since callbacks can destroy listener
// and calling next callback on listener object will segfault. More subtle
// bugs are also possible: one callback can handle multiple events
// (maybe even in a subtle implicit way) and then we don't want to call
// multiple callbacks since we are not sure if those are valid anymore.
try {
// Hangup event.
if (event.events & EPOLLIN) {
// We have some data waiting to be read.
listener.OnData();
continue;
}
if (event.events & EPOLLRDHUP) {
listener.OnClose();
continue;
@ -54,8 +64,6 @@ class SocketEventDispatcher {
continue;
}
// We have some data waiting to be read.
listener.OnData();
} catch (const std::exception &e) {
listener.OnException(e);
}

View File

@ -0,0 +1,97 @@
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <future>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include "communication/messaging/distributed.hpp"
#include "gtest/gtest.h"
using namespace communication::messaging;
using namespace std::literals::chrono_literals;
struct MessageInt : public Message {
MessageInt() {} // cereal needs this
MessageInt(int x) : x(x) {}
int x;
template <class Archive>
void serialize(Archive &ar) {
ar(cereal::virtual_base_class<Message>(this), x);
}
};
CEREAL_REGISTER_TYPE(MessageInt);
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
/**
* Test do the services start up without crashes.
*/
TEST(SimpleTests, StartAndShutdown) {
System system("127.0.0.1", 10000);
// do nothing
std::this_thread::sleep_for(500ms);
system.Shutdown();
}
TEST(Messaging, Pop) {
System master_system("127.0.0.1", 10000);
System slave_system("127.0.0.1", 10001);
auto stream = master_system.Open("main");
Writer writer(slave_system, "127.0.0.1", 10000, "main");
std::this_thread::sleep_for(100ms);
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Await()), 10);
master_system.Shutdown();
slave_system.Shutdown();
}
TEST(Messaging, Await) {
System master_system("127.0.0.1", 10000);
System slave_system("127.0.0.1", 10001);
auto stream = master_system.Open("main");
Writer writer(slave_system, "127.0.0.1", 10000, "main");
std::this_thread::sleep_for(100ms);
std::thread t([&] {
std::this_thread::sleep_for(100ms);
stream->Shutdown();
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(20);
});
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(), nullptr);
t.join();
master_system.Shutdown();
slave_system.Shutdown();
}
TEST(Messaging, RecreateChannelAfterClosing) {
System master_system("127.0.0.1", 10000);
System slave_system("127.0.0.1", 10001);
auto stream = master_system.Open("main");
Writer writer(slave_system, "127.0.0.1", 10000, "main");
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Await()), 10);
stream = nullptr;
writer.Send<MessageInt>(20);
std::this_thread::sleep_for(100ms);
stream = master_system.Open("main");
std::this_thread::sleep_for(100ms);
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(30);
EXPECT_EQ(GET_X(stream->Await()), 30);
master_system.Shutdown();
slave_system.Shutdown();
}

View File

@ -0,0 +1,72 @@
#include <chrono>
#include <iostream>
#include <thread>
#include "communication/messaging/local.hpp"
#include "gtest/gtest.h"
using namespace std::literals::chrono_literals;
using namespace communication::messaging;
struct MessageInt : public Message {
MessageInt(int xx) : x(xx) {}
int x;
};
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
TEST(LocalMessaging, Pop) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Poll()), 10);
}
TEST(LocalMessaging, Await) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
std::thread t([&] {
std::this_thread::sleep_for(100ms);
stream->Shutdown();
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(20);
});
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(), nullptr);
t.join();
}
TEST(LocalMessaging, AwaitTimeout) {
LocalSystem system;
auto stream = system.Open("main");
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(100ms), nullptr);
}
TEST(LocalMessaging, RecreateChannelAfterClosing) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Poll()), 10);
stream = nullptr;
writer.Send<MessageInt>(20);
stream = system.Open("main");
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(30);
EXPECT_EQ(GET_X(stream->Poll()), 30);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -79,6 +79,7 @@ TEST(Queue, AwaitPop) {
});
EXPECT_EQ(*q.AwaitPop(), 1);
std::this_thread::sleep_for(1000ms);
EXPECT_EQ(*q.AwaitPop(), 2);
EXPECT_EQ(*q.AwaitPop(), 3);
EXPECT_EQ(*q.AwaitPop(), 4);
@ -86,12 +87,19 @@ TEST(Queue, AwaitPop) {
std::thread t2([&] {
std::this_thread::sleep_for(100ms);
q.Signal();
q.Shutdown();
});
std::this_thread::sleep_for(200ms);
EXPECT_EQ(q.AwaitPop(), std::experimental::nullopt);
t2.join();
}
TEST(Queue, AwaitPopTimeout) {
std::this_thread::sleep_for(1000ms);
Queue<int> q;
EXPECT_EQ(q.AwaitPop(100ms), std::experimental::nullopt);
}
TEST(Queue, Concurrent) {
Queue<int> q;

80
tests/unit/rpc.cpp Normal file
View File

@ -0,0 +1,80 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <thread>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "gtest/gtest.h"
using communication::messaging::System;
using communication::messaging::Message;
using namespace communication::rpc;
using namespace std::literals::chrono_literals;
struct SumReq : public Message {
SumReq() {} // cereal needs this
SumReq(int x, int y) : x(x), y(y) {}
int x;
int y;
template <class Archive>
void serialize(Archive &ar) {
ar(cereal::virtual_base_class<Message>(this), x, y);
}
};
CEREAL_REGISTER_TYPE(SumReq);
struct SumRes : public Message {
SumRes() {} // cereal needs this
SumRes(int sum) : sum(sum) {}
int sum;
template <class Archive>
void serialize(Archive &ar) {
ar(cereal::virtual_base_class<Message>(this), sum);
}
};
CEREAL_REGISTER_TYPE(SumRes);
using Sum = RequestResponse<SumReq, SumRes>;
TEST(Rpc, Call) {
System server_system("127.0.0.1", 10000);
Server server(server_system, "main");
server.Register<Sum>([](const SumReq &request) {
return std::make_unique<SumRes>(request.x + request.y);
});
std::thread server_thread([&] { server.Start(); });
std::this_thread::sleep_for(100ms);
System client_system("127.0.0.1", 10001);
Client client(client_system, "127.0.0.1", 10000, "main");
auto sum = client.Call<Sum>(300ms, 10, 20);
EXPECT_EQ(sum->sum, 30);
server.Shutdown();
server_thread.join();
server_system.Shutdown();
client_system.Shutdown();
}
TEST(Rpc, Timeout) {
System server_system("127.0.0.1", 10000);
Server server(server_system, "main");
server.Register<Sum>([](const SumReq &request) {
std::this_thread::sleep_for(300ms);
return std::make_unique<SumRes>(request.x + request.y);
});
std::thread server_thread([&] { server.Start(); });
std::this_thread::sleep_for(100ms);
System client_system("127.0.0.1", 10001);
Client client(client_system, "127.0.0.1", 10000, "main");
auto sum = client.Call<Sum>(100ms, 10, 20);
EXPECT_FALSE(sum);
server.Shutdown();
server_thread.join();
server_system.Shutdown();
client_system.Shutdown();
}

62
tools/rpcgen Executable file
View File

@ -0,0 +1,62 @@
#!/usr/bin/env python3
# To run from vim add to your .vimrc something similar to:
# command -nargs=* Rpcgen :r !/home/mislav/code/memgraph/tools/rpcgen <args>
import sys
USAGE = "\n\nUsage:\n" \
"./rpcgen request_response_name request_args -- response_args\n" \
"Arguments should be seperated with minus sign (-).\n" \
"Example: ./rpcgen Sum int x - int y -- int sum\n\n"
assert len(sys.argv) >= 3, "Too few arguments.\n" + USAGE
request_response_name = sys.argv[1]
args_string = " ".join(sys.argv[2:])
split = args_string.split("--")
assert len(split) == 2, "Arguments should contain one -- separator.\n" + USAGE
request_args, response_args = split
def generate(message_name, args):
def process_arg(arg):
arg = arg.strip()
assert arg, "Each arg should be non empty string.\n" + USAGE
for i in range(len(arg) - 1, -1, -1):
if not arg[i].isalpha() and arg[i] != "_": break
assert i != -1 and i != len(arg), "Each string separated with - " \
"should contain type and variable name.\n" + USAGE
typ = arg[:i+1].strip()
name = arg[i+1:].strip()
return typ, name
types, names = zip(*map(process_arg, args.split("-")))
return \
"""
struct {message_name} : public Message {{
{message_name}() {{}} // cereal needs this
{message_name}({constructor_args}): {init_list} {{}}
{members}
template <class Archive>
void serialize(Archive &ar) {{
ar(cereal::virtual_base_class<Message>(this), {serialize_args});
}}
}};
CEREAL_REGISTER_TYPE({message_name});""" \
.format(message_name=message_name,
constructor_args=",".join(
map(lambda x: x[0] + " " + x[1], zip(types, names))),
init_list=", ".join(map(lambda x: "{x}({x})".format(x=x), names)),
members="\n".join(map(lambda x:
"{} {};".format(x[0], x[1]), zip(types, names))),
serialize_args=", ".join(names))
request_name = request_response_name + "Req"
response_name = request_response_name + "Res"
req_class = generate(request_name, request_args)
res_class = generate(response_name, response_args)
print(req_class)
print(res_class)
print("using {} = RequestResponse<{}, {}>;".format(
request_response_name, request_name, response_name))