Add first version of message passing and rpc
Reviewers: mtomic, buda Reviewed By: mtomic, buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1002
This commit is contained in:
parent
8dcd8e1012
commit
60f4db2b9f
@ -170,7 +170,6 @@ target_link_libraries(antlr_opencypher_parser_lib antlr4)
|
||||
|
||||
include_directories(src)
|
||||
add_subdirectory(src)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Optional subproject configuration -------------------------------------------
|
||||
|
@ -4,10 +4,14 @@
|
||||
set(memgraph_src_files
|
||||
communication/bolt/v1/decoder/decoded_value.cpp
|
||||
communication/bolt/v1/session.cpp
|
||||
communication/reactor/protocol.cpp
|
||||
communication/messaging/distributed.cpp
|
||||
communication/messaging/local.cpp
|
||||
communication/messaging/protocol.cpp
|
||||
communication/raft/raft.cpp
|
||||
communication/reactor/reactor_local.cpp
|
||||
communication/reactor/protocol.cpp
|
||||
communication/reactor/reactor_distributed.cpp
|
||||
communication/reactor/reactor_local.cpp
|
||||
communication/rpc/rpc.cpp
|
||||
data_structures/concurrent/skiplist_gc.cpp
|
||||
database/graph_db.cpp
|
||||
database/graph_db_config.cpp
|
||||
@ -78,6 +82,7 @@ set_target_properties(${MEMGRAPH_BUILD_NAME} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
# Strip the executable in release build.
|
||||
string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
|
||||
|
||||
if (lower_build_type STREQUAL "release")
|
||||
add_custom_command(TARGET ${MEMGRAPH_BUILD_NAME} POST_BUILD
|
||||
COMMAND strip -s $<TARGET_FILE:${MEMGRAPH_BUILD_NAME}>
|
||||
|
69
src/communication/messaging/distributed.cpp
Normal file
69
src/communication/messaging/distributed.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
#include "communication/messaging/distributed.hpp"
|
||||
|
||||
namespace communication::messaging {
|
||||
|
||||
System::System(const std::string &address, uint16_t port)
|
||||
: address_(address), port_(port) {
|
||||
// Numbers of worker are quite arbitrary at the point.
|
||||
StartClient(4);
|
||||
StartServer(4);
|
||||
}
|
||||
|
||||
System::~System() {
|
||||
for (size_t i = 0; i < pool_.size(); ++i) {
|
||||
pool_[i].join();
|
||||
}
|
||||
thread_.join();
|
||||
}
|
||||
|
||||
void System::Shutdown() {
|
||||
queue_.Shutdown();
|
||||
server_->Shutdown();
|
||||
}
|
||||
|
||||
void System::StartClient(int worker_count) {
|
||||
LOG(INFO) << "Starting " << worker_count << " client workers";
|
||||
for (int i = 0; i < worker_count; ++i) {
|
||||
pool_.push_back(std::thread([this]() {
|
||||
while (true) {
|
||||
auto message = queue_.AwaitPop();
|
||||
if (message == std::experimental::nullopt) break;
|
||||
SendMessage(message->address, message->port, message->channel,
|
||||
std::move(message->message));
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
void System::StartServer(int worker_count) {
|
||||
if (server_ != nullptr) {
|
||||
LOG(FATAL) << "Tried to start a running server!";
|
||||
}
|
||||
|
||||
// Initialize endpoint.
|
||||
Endpoint endpoint;
|
||||
try {
|
||||
endpoint = Endpoint(address_.c_str(), port_);
|
||||
} catch (io::network::NetworkEndpointException &e) {
|
||||
LOG(FATAL) << e.what();
|
||||
}
|
||||
// Initialize server.
|
||||
server_ = std::make_unique<ServerT>(endpoint, protocol_data_);
|
||||
|
||||
// Start server.
|
||||
thread_ = std::thread(
|
||||
[worker_count, this]() { this->server_->Start(worker_count); });
|
||||
}
|
||||
|
||||
std::shared_ptr<EventStream> System::Open(const std::string &name) {
|
||||
return system_.Open(name);
|
||||
}
|
||||
|
||||
Writer::Writer(System &system, const std::string &address, uint16_t port,
|
||||
const std::string &name)
|
||||
: system_(system), address_(address), port_(port), name_(name) {}
|
||||
|
||||
void Writer::Send(std::unique_ptr<Message> message) {
|
||||
system_.queue_.Emplace(address_, port_, name_, std::move(message));
|
||||
}
|
||||
}
|
124
src/communication/messaging/distributed.hpp
Normal file
124
src/communication/messaging/distributed.hpp
Normal file
@ -0,0 +1,124 @@
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <stdexcept>
|
||||
#include <tuple>
|
||||
#include <typeindex>
|
||||
#include <utility>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include "communication/messaging/local.hpp"
|
||||
#include "data_structures/queue.hpp"
|
||||
#include "protocol.hpp"
|
||||
|
||||
#include "cereal/archives/binary.hpp"
|
||||
#include "cereal/types/base_class.hpp"
|
||||
#include "cereal/types/memory.hpp"
|
||||
#include "cereal/types/polymorphic.hpp"
|
||||
#include "cereal/types/string.hpp"
|
||||
#include "cereal/types/utility.hpp"
|
||||
#include "cereal/types/vector.hpp"
|
||||
|
||||
#include "communication/server.hpp"
|
||||
#include "threading/sync/spinlock.hpp"
|
||||
|
||||
namespace communication::messaging {
|
||||
|
||||
class System;
|
||||
|
||||
// Writes message to remote event stream.
|
||||
class Writer {
|
||||
public:
|
||||
Writer(System &system, const std::string &address, uint16_t port,
|
||||
const std::string &name);
|
||||
Writer(const Writer &) = delete;
|
||||
void operator=(const Writer &) = delete;
|
||||
Writer(Writer &&) = delete;
|
||||
void operator=(Writer &&) = delete;
|
||||
|
||||
template <typename TMessage, typename... Args>
|
||||
void Send(Args &&... args) {
|
||||
Send(std::unique_ptr<Message>(
|
||||
std::make_unique<TMessage>(std::forward<Args>(args)...)));
|
||||
}
|
||||
|
||||
void Send(std::unique_ptr<Message> message);
|
||||
|
||||
private:
|
||||
System &system_;
|
||||
std::string address_;
|
||||
uint16_t port_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
class System {
|
||||
public:
|
||||
friend class Writer;
|
||||
|
||||
System(const std::string &address, uint16_t port);
|
||||
System(const System &) = delete;
|
||||
System(System &&) = delete;
|
||||
System &operator=(const System &) = delete;
|
||||
System &operator=(System &&) = delete;
|
||||
~System();
|
||||
|
||||
std::shared_ptr<EventStream> Open(const std::string &name);
|
||||
void Shutdown();
|
||||
|
||||
const std::string &address() const { return address_; }
|
||||
uint16_t port() const { return port_; }
|
||||
|
||||
private:
|
||||
using Endpoint = io::network::NetworkEndpoint;
|
||||
using Socket = Socket;
|
||||
using ServerT = communication::Server<Session, SessionData>;
|
||||
|
||||
struct NetworkMessage {
|
||||
NetworkMessage() {}
|
||||
|
||||
NetworkMessage(const std::string &address, uint16_t port,
|
||||
const std::string &channel,
|
||||
std::unique_ptr<Message> &&message)
|
||||
: address(address),
|
||||
port(port),
|
||||
channel(channel),
|
||||
message(std::move(message)) {}
|
||||
|
||||
NetworkMessage(NetworkMessage &&nm) = default;
|
||||
NetworkMessage &operator=(NetworkMessage &&nm) = default;
|
||||
|
||||
std::string address;
|
||||
uint16_t port = 0;
|
||||
std::string channel;
|
||||
std::unique_ptr<Message> message;
|
||||
};
|
||||
|
||||
/** Start a threadpool that dispatches the messages from the (outgoing) queue
|
||||
* to the sockets */
|
||||
void StartClient(int worker_count);
|
||||
|
||||
/** Start a threadpool that relays the messages from the sockets to the
|
||||
* LocalEventStreams */
|
||||
void StartServer(int workers_count);
|
||||
|
||||
// Client variables.
|
||||
std::vector<std::thread> pool_;
|
||||
Queue<NetworkMessage> queue_;
|
||||
|
||||
// Server variables.
|
||||
std::thread thread_;
|
||||
SessionData protocol_data_;
|
||||
std::unique_ptr<ServerT> server_{nullptr};
|
||||
std::string address_;
|
||||
uint16_t port_;
|
||||
|
||||
LocalSystem &system_ = protocol_data_.system;
|
||||
};
|
||||
} // namespace communication::messaging
|
61
src/communication/messaging/local.cpp
Normal file
61
src/communication/messaging/local.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
#include "communication/messaging/local.hpp"
|
||||
|
||||
#include "fmt/format.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace communication::messaging {
|
||||
|
||||
std::shared_ptr<EventStream> LocalSystem::Open(const std::string &name) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
// TODO: It would be better to use std::make_shared here, but we can't since
|
||||
// constructor is private.
|
||||
std::shared_ptr<EventStream> stream(new EventStream(*this, name));
|
||||
auto got = channels_.emplace(name, stream);
|
||||
CHECK(got.second) << fmt::format("Stream with name {} already exists", name);
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::shared_ptr<EventStream> LocalSystem::Resolve(const std::string &name) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
auto it = channels_.find(name);
|
||||
if (it == channels_.end()) return nullptr;
|
||||
return it->second.lock();
|
||||
}
|
||||
|
||||
void LocalSystem::Remove(const std::string &name) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
auto it = channels_.find(name);
|
||||
CHECK(it != channels_.end()) << "Trying to delete nonexisting stream";
|
||||
channels_.erase(it);
|
||||
}
|
||||
|
||||
void LocalWriter::Send(std::unique_ptr<Message> message) {
|
||||
// TODO: We could add caching to LocalWriter so that we don't need to acquire
|
||||
// lock on system every time we want to send a message. This can be
|
||||
// accomplished by storing weak_ptr to EventStream.
|
||||
auto stream = system_.Resolve(name_);
|
||||
if (!stream) return;
|
||||
stream->Push(std::move(message));
|
||||
}
|
||||
|
||||
EventStream::~EventStream() { system_.Remove(name_); }
|
||||
|
||||
std::unique_ptr<Message> EventStream::Poll() {
|
||||
auto opt_message = queue_.MaybePop();
|
||||
if (opt_message == std::experimental::nullopt) return nullptr;
|
||||
return std::move(*opt_message);
|
||||
}
|
||||
|
||||
void EventStream::Push(std::unique_ptr<Message> message) {
|
||||
queue_.Push(std::move(message));
|
||||
}
|
||||
|
||||
std::unique_ptr<Message> EventStream::Await(
|
||||
std::chrono::system_clock::duration timeout) {
|
||||
auto opt_message = queue_.AwaitPop(timeout);
|
||||
if (opt_message == std::experimental::nullopt) return nullptr;
|
||||
return std::move(*opt_message);
|
||||
};
|
||||
|
||||
void EventStream::Shutdown() { queue_.Shutdown(); }
|
||||
}
|
109
src/communication/messaging/local.hpp
Normal file
109
src/communication/messaging/local.hpp
Normal file
@ -0,0 +1,109 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <typeindex>
|
||||
|
||||
#include "cereal/types/memory.hpp"
|
||||
|
||||
#include "data_structures/queue.hpp"
|
||||
|
||||
namespace communication::messaging {
|
||||
|
||||
/**
|
||||
* Base class for messages.
|
||||
*/
|
||||
class Message {
|
||||
public:
|
||||
virtual ~Message() {}
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &) {}
|
||||
|
||||
/**
|
||||
* Run-time type identification that is used for callbacks.
|
||||
*
|
||||
* Warning: this works because of the virtual destructor, don't remove it from
|
||||
* this class
|
||||
*/
|
||||
std::type_index type_index() const { return typeid(*this); }
|
||||
};
|
||||
|
||||
class EventStream;
|
||||
class LocalWriter;
|
||||
|
||||
class LocalSystem {
|
||||
public:
|
||||
friend class EventStream;
|
||||
friend class LocalWriter;
|
||||
|
||||
LocalSystem() = default;
|
||||
LocalSystem(const LocalSystem &) = delete;
|
||||
LocalSystem(LocalSystem &&) = delete;
|
||||
LocalSystem &operator=(const LocalSystem &) = delete;
|
||||
LocalSystem &operator=(LocalSystem &&) = delete;
|
||||
|
||||
std::shared_ptr<EventStream> Open(const std::string &name);
|
||||
|
||||
private:
|
||||
std::shared_ptr<EventStream> Resolve(const std::string &name);
|
||||
|
||||
void Remove(const std::string &name);
|
||||
|
||||
std::mutex mutex_;
|
||||
std::unordered_map<std::string, std::weak_ptr<EventStream>> channels_;
|
||||
};
|
||||
|
||||
class LocalWriter {
|
||||
public:
|
||||
LocalWriter(LocalSystem &system, const std::string &name)
|
||||
: system_(system), name_(name) {}
|
||||
LocalWriter(const LocalWriter &) = delete;
|
||||
void operator=(const LocalWriter &) = delete;
|
||||
LocalWriter(LocalWriter &&) = delete;
|
||||
void operator=(LocalWriter &&) = delete;
|
||||
|
||||
template <typename TMessage, typename... Args>
|
||||
void Send(Args &&... args) {
|
||||
Send(std::unique_ptr<Message>(
|
||||
std::make_unique<TMessage>(std::forward<Args>(args)...)));
|
||||
}
|
||||
|
||||
void Send(std::unique_ptr<Message> message);
|
||||
|
||||
private:
|
||||
LocalSystem &system_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
class EventStream {
|
||||
public:
|
||||
friend class LocalWriter;
|
||||
friend class LocalSystem;
|
||||
|
||||
EventStream(const EventStream &) = delete;
|
||||
void operator=(const EventStream &) = delete;
|
||||
EventStream(EventStream &&) = delete;
|
||||
void operator=(EventStream &&) = delete;
|
||||
~EventStream();
|
||||
|
||||
std::unique_ptr<Message> Poll();
|
||||
std::unique_ptr<Message> Await(
|
||||
std::chrono::system_clock::duration timeout =
|
||||
std::chrono::system_clock::duration::max());
|
||||
void Shutdown();
|
||||
|
||||
const std::string &name() const { return name_; }
|
||||
|
||||
private:
|
||||
EventStream(LocalSystem &system, const std::string &name)
|
||||
: system_(system), name_(name) {}
|
||||
|
||||
void Push(std::unique_ptr<Message> message);
|
||||
|
||||
LocalSystem &system_;
|
||||
std::string name_;
|
||||
Queue<std::unique_ptr<Message>> queue_;
|
||||
};
|
||||
} // namespace communication::messaging
|
115
src/communication/messaging/protocol.cpp
Normal file
115
src/communication/messaging/protocol.cpp
Normal file
@ -0,0 +1,115 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "communication/messaging/distributed.hpp"
|
||||
#include "communication/messaging/local.hpp"
|
||||
#include "communication/messaging/protocol.hpp"
|
||||
|
||||
#include "fmt/format.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
namespace communication::messaging {
|
||||
|
||||
Session::Session(Socket &&socket, SessionData &data)
|
||||
: socket_(std::move(socket)), system_(data.system) {}
|
||||
|
||||
bool Session::Alive() const { return alive_; }
|
||||
|
||||
std::string Session::GetStringAndShift(SizeT len) {
|
||||
std::string ret(reinterpret_cast<char *>(buffer_.data()), len);
|
||||
buffer_.Shift(len);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void Session::Execute() {
|
||||
if (buffer_.size() < sizeof(SizeT)) return;
|
||||
SizeT len_channel = GetLength();
|
||||
if (buffer_.size() < 2 * sizeof(SizeT) + len_channel) return;
|
||||
SizeT len_data = GetLength(sizeof(SizeT) + len_channel);
|
||||
if (buffer_.size() < 2 * sizeof(SizeT) + len_data + len_channel) return;
|
||||
|
||||
// Remove the length bytes from the buffer.
|
||||
buffer_.Shift(sizeof(SizeT));
|
||||
auto channel = GetStringAndShift(len_channel);
|
||||
buffer_.Shift(sizeof(SizeT));
|
||||
|
||||
// TODO: check for exceptions
|
||||
std::istringstream stream;
|
||||
stream.str(std::string(reinterpret_cast<char *>(buffer_.data()), len_data));
|
||||
::cereal::BinaryInputArchive iarchive{stream};
|
||||
std::unique_ptr<Message> message{nullptr};
|
||||
iarchive(message);
|
||||
buffer_.Shift(len_data);
|
||||
|
||||
LocalWriter writer(system_, channel);
|
||||
writer.Send(std::move(message));
|
||||
}
|
||||
|
||||
StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
|
||||
|
||||
void Session::Written(size_t len) { buffer_.Written(len); }
|
||||
|
||||
void Session::Close() {
|
||||
DLOG(INFO) << "Closing session";
|
||||
this->socket_.Close();
|
||||
}
|
||||
|
||||
SizeT Session::GetLength(int offset) {
|
||||
SizeT ret = *reinterpret_cast<SizeT *>(buffer_.data() + offset);
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool SendLength(Socket &socket, SizeT length) {
|
||||
return socket.Write(reinterpret_cast<uint8_t *>(&length), sizeof(SizeT));
|
||||
}
|
||||
|
||||
void SendMessage(const std::string &address, uint16_t port,
|
||||
const std::string &channel, std::unique_ptr<Message> message) {
|
||||
CHECK(message) << "Trying to send nullptr instead of message";
|
||||
|
||||
// Initialize endpoint.
|
||||
Endpoint endpoint;
|
||||
try {
|
||||
endpoint = Endpoint(address.c_str(), port);
|
||||
} catch (io::network::NetworkEndpointException &e) {
|
||||
LOG(ERROR) << "Address {} is invalid!";
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize socket.
|
||||
Socket socket;
|
||||
if (!socket.Connect(endpoint)) {
|
||||
LOG(INFO) << "Couldn't connect to remote address: " << address << ":"
|
||||
<< port;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!SendLength(socket, channel.size())) {
|
||||
LOG(INFO) << "Couldn't send channel size!";
|
||||
return;
|
||||
}
|
||||
if (!socket.Write(channel)) {
|
||||
LOG(INFO) << "Couldn't send channel data!";
|
||||
return;
|
||||
}
|
||||
|
||||
// Serialize and send message
|
||||
std::ostringstream stream;
|
||||
::cereal::BinaryOutputArchive oarchive(stream);
|
||||
oarchive(message);
|
||||
|
||||
const std::string &buffer = stream.str();
|
||||
int64_t message_size = 2 * sizeof(SizeT) + buffer.size() + channel.size();
|
||||
CHECK(message_size <= kMaxMessageSize) << fmt::format(
|
||||
"Trying to send message of size {}, max message size is {}", message_size,
|
||||
kMaxMessageSize);
|
||||
|
||||
if (!SendLength(socket, buffer.size())) {
|
||||
LOG(INFO) << "Couldn't send message size!";
|
||||
return;
|
||||
}
|
||||
if (!socket.Write(buffer)) {
|
||||
LOG(INFO) << "Couldn't send message data!";
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
106
src/communication/messaging/protocol.hpp
Normal file
106
src/communication/messaging/protocol.hpp
Normal file
@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
|
||||
#include "communication/bolt/v1/decoder/buffer.hpp"
|
||||
#include "communication/messaging/local.hpp"
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/network_endpoint.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "io/network/stream_buffer.hpp"
|
||||
|
||||
/**
|
||||
* @brief Protocol
|
||||
*
|
||||
* Has classes and functions that implement server and client sides of our
|
||||
* messaging protocol.
|
||||
*
|
||||
* Message layout: SizeT channel_size, channel_size characters channel,
|
||||
* SizeT message_size, message_size bytes serialized_message
|
||||
*/
|
||||
namespace communication::messaging {
|
||||
|
||||
class Message;
|
||||
|
||||
using Endpoint = io::network::NetworkEndpoint;
|
||||
using Socket = io::network::Socket;
|
||||
using StreamBuffer = io::network::StreamBuffer;
|
||||
|
||||
// This buffer should be larger than the largest serialized message.
|
||||
const int64_t kMaxMessageSize = 262144;
|
||||
using Buffer = bolt::Buffer<kMaxMessageSize>;
|
||||
using SizeT = uint16_t;
|
||||
|
||||
/**
|
||||
* Distributed Protocol Data
|
||||
*/
|
||||
struct SessionData {
|
||||
LocalSystem system;
|
||||
};
|
||||
|
||||
/**
|
||||
* Distributed Protocol Session
|
||||
*
|
||||
* This class is responsible for handling a single client connection.
|
||||
*/
|
||||
class Session {
|
||||
public:
|
||||
Session(Socket &&socket, SessionData &data);
|
||||
|
||||
int Id() const { return socket_.fd(); }
|
||||
|
||||
/**
|
||||
* Returns the protocol alive state
|
||||
*/
|
||||
bool Alive() const;
|
||||
|
||||
/**
|
||||
* Executes the protocol after data has been read into the buffer.
|
||||
* Goes through the protocol states in order to execute commands from the
|
||||
* client.
|
||||
*/
|
||||
void Execute();
|
||||
|
||||
/**
|
||||
* Allocates data from the internal buffer.
|
||||
* Used in the underlying network stack to asynchronously read data
|
||||
* from the client.
|
||||
* @returns a StreamBuffer to the allocated internal data buffer
|
||||
*/
|
||||
StreamBuffer Allocate();
|
||||
|
||||
/**
|
||||
* Notifies the internal buffer of written data.
|
||||
* Used in the underlying network stack to notify the internal buffer
|
||||
* how many bytes of data have been written.
|
||||
* @param len how many data was written to the buffer
|
||||
*/
|
||||
void Written(size_t len);
|
||||
|
||||
bool TimedOut() { return false; }
|
||||
|
||||
/**
|
||||
* Closes the session (client socket).
|
||||
*/
|
||||
void Close();
|
||||
|
||||
Socket socket_;
|
||||
LocalSystem &system_;
|
||||
|
||||
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
|
||||
|
||||
private:
|
||||
SizeT GetLength(int offset = 0);
|
||||
std::string GetStringAndShift(SizeT len);
|
||||
|
||||
bool alive_{true};
|
||||
Buffer buffer_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Distributed Protocol Send Message
|
||||
*/
|
||||
void SendMessage(const std::string &address, uint16_t port,
|
||||
const std::string &channel, std::unique_ptr<Message> message);
|
||||
}
|
@ -96,7 +96,7 @@ class Network {
|
||||
break;
|
||||
}
|
||||
}
|
||||
queue_.Signal();
|
||||
queue_.Shutdown();
|
||||
for (size_t i = 0; i < pool_.size(); ++i) {
|
||||
pool_[i].join();
|
||||
}
|
||||
|
146
src/communication/rpc/rpc.cpp
Normal file
146
src/communication/rpc/rpc.cpp
Normal file
@ -0,0 +1,146 @@
|
||||
#include <iterator>
|
||||
#include <random>
|
||||
|
||||
#include "communication/rpc/rpc.hpp"
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
const char kProtocolStreamPrefix[] = "rpc-";
|
||||
|
||||
std::string UniqueId() {
|
||||
static thread_local std::mt19937 pseudo_rand_gen{std::random_device{}()};
|
||||
const char kCharset[] =
|
||||
"0123456789"
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
"abcdefghijklmnopqrstuvwxyz";
|
||||
const auto kMaxIndex = (sizeof(kCharset) - 1);
|
||||
static thread_local std::uniform_int_distribution<> rand_dist{0, kMaxIndex};
|
||||
|
||||
std::string id;
|
||||
std::generate_n(std::back_inserter(id), 20,
|
||||
[&] { return kCharset[rand_dist(pseudo_rand_gen)]; });
|
||||
return id;
|
||||
}
|
||||
|
||||
class Request : public messaging::Message {
|
||||
public:
|
||||
Request(const std::string &address, uint16_t port, const std::string &stream,
|
||||
std::unique_ptr<Message> message)
|
||||
: address_(address),
|
||||
port_(port),
|
||||
stream_(stream),
|
||||
message_id_(UniqueId()),
|
||||
message_(std::move(message)) {}
|
||||
|
||||
const std::string &address() const { return address_; }
|
||||
uint16_t port() const { return port_; }
|
||||
const std::string &stream() const { return stream_; }
|
||||
const std::string &message_id() const { return message_id_; }
|
||||
const messaging::Message &message() const { return *message_; }
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {
|
||||
ar(cereal::virtual_base_class<messaging::Message>(this), address_, port_,
|
||||
stream_, message_id_, message_);
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class cereal::access;
|
||||
Request() {} // Cereal needs access to a default constructor.
|
||||
|
||||
std::string address_;
|
||||
uint16_t port_;
|
||||
std::string stream_;
|
||||
std::string message_id_;
|
||||
std::unique_ptr<messaging::Message> message_;
|
||||
};
|
||||
|
||||
class Response : public messaging::Message {
|
||||
public:
|
||||
explicit Response(const std::string &message_id,
|
||||
std::unique_ptr<messaging::Message> message)
|
||||
: message_id_(message_id), message_(std::move(message)) {}
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {
|
||||
ar(cereal::virtual_base_class<messaging::Message>(this), message_id_,
|
||||
message_);
|
||||
}
|
||||
|
||||
const auto &message_id() const { return message_id_; }
|
||||
auto &message() { return message_; }
|
||||
|
||||
protected:
|
||||
Response() {} // Cereal needs access to a default constructor.
|
||||
friend class cereal::access;
|
||||
std::string message_id_;
|
||||
std::unique_ptr<messaging::Message> message_;
|
||||
};
|
||||
|
||||
Client::Client(messaging::System &system, const std::string &address,
|
||||
uint16_t port, const std::string &name)
|
||||
: system_(system),
|
||||
writer_(system, address, port, kProtocolStreamPrefix + name),
|
||||
stream_(system.Open(UniqueId())) {}
|
||||
|
||||
// Because of the way Call is implemented it can fail without reporting (it will
|
||||
// just block indefinately). This is why you always need to provide reasonable
|
||||
// timeout when calling it.
|
||||
// TODO: Make Call use same connection for request and respone and as soon as
|
||||
// connection drop return nullptr.
|
||||
std::unique_ptr<messaging::Message> Client::Call(
|
||||
std::chrono::system_clock::duration timeout,
|
||||
std::unique_ptr<messaging::Message> message) {
|
||||
auto request = std::make_unique<Request>(system_.address(), system_.port(),
|
||||
stream_->name(), std::move(message));
|
||||
auto message_id = request->message_id();
|
||||
writer_.Send(std::move(request));
|
||||
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto until = now + timeout;
|
||||
|
||||
while (true) {
|
||||
auto message = stream_->Await(until - std::chrono::system_clock::now());
|
||||
if (!message) break; // Client was either signaled or timeout was reached.
|
||||
auto *response = dynamic_cast<Response *>(message.get());
|
||||
if (!response) {
|
||||
LOG(ERROR) << "Message received by rpc client is not a response";
|
||||
continue;
|
||||
}
|
||||
if (response->message_id() != message_id) {
|
||||
// This can happen if some stale response arrives after we issued a new
|
||||
// request.
|
||||
continue;
|
||||
}
|
||||
return std::move(response->message());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Server::Server(messaging::System &system, const std::string &name)
|
||||
: system_(system), stream_(system.Open(kProtocolStreamPrefix + name)) {}
|
||||
|
||||
void Server::Start() {
|
||||
// TODO: Add logging.
|
||||
while (alive_) {
|
||||
auto message = stream_->Await();
|
||||
if (!message) continue;
|
||||
auto *request = dynamic_cast<Request *>(message.get());
|
||||
if (!request) continue;
|
||||
auto &real_request = request->message();
|
||||
auto it = callbacks_.find(real_request.type_index());
|
||||
if (it == callbacks_.end()) continue;
|
||||
auto response = it->second(real_request);
|
||||
messaging::Writer writer(system_, request->address(), request->port(),
|
||||
request->stream());
|
||||
writer.Send<Response>(request->message_id(), std::move(response));
|
||||
}
|
||||
}
|
||||
|
||||
void Server::Shutdown() {
|
||||
alive_ = false;
|
||||
stream_->Shutdown();
|
||||
}
|
||||
}
|
||||
CEREAL_REGISTER_TYPE(communication::rpc::Request);
|
||||
CEREAL_REGISTER_TYPE(communication::rpc::Response);
|
91
src/communication/rpc/rpc.hpp
Normal file
91
src/communication/rpc/rpc.hpp
Normal file
@ -0,0 +1,91 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "communication/messaging/distributed.hpp"
|
||||
|
||||
namespace communication::rpc {
|
||||
|
||||
template <typename TRequest, typename TResponse>
|
||||
struct RequestResponse {
|
||||
using Request = TRequest;
|
||||
using Response = TResponse;
|
||||
};
|
||||
|
||||
// Client is not thread safe.
|
||||
class Client {
|
||||
public:
|
||||
Client(messaging::System &system, const std::string &address, uint16_t port,
|
||||
const std::string &name);
|
||||
|
||||
// Call function can initiate only one request at the time. Function blocks
|
||||
// until there is a response or timeout was reached. If timeout was reached
|
||||
// nullptr is returned.
|
||||
template <typename TRequestResponse, typename... Args>
|
||||
std::unique_ptr<typename TRequestResponse::Response> Call(
|
||||
std::chrono::system_clock::duration timeout, Args &&... args) {
|
||||
using Req = typename TRequestResponse::Request;
|
||||
using Res = typename TRequestResponse::Response;
|
||||
static_assert(std::is_base_of<messaging::Message, Req>::value,
|
||||
"TRequestResponse::Request must be derived from Message");
|
||||
static_assert(std::is_base_of<messaging::Message, Res>::value,
|
||||
"TRequestResponse::Response must be derived from Message");
|
||||
auto response =
|
||||
Call(timeout, std::unique_ptr<messaging::Message>(
|
||||
std::make_unique<Req>(std::forward<Args>(args)...)));
|
||||
auto *real_response = dynamic_cast<Res *>(response.get());
|
||||
if (!real_response && response) {
|
||||
LOG(ERROR) << "Message response was of unexpected type";
|
||||
return nullptr;
|
||||
}
|
||||
response.release();
|
||||
return std::unique_ptr<Res>(real_response);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<messaging::Message> Call(
|
||||
std::chrono::system_clock::duration timeout,
|
||||
std::unique_ptr<messaging::Message> message);
|
||||
|
||||
messaging::System &system_;
|
||||
messaging::Writer writer_;
|
||||
std::shared_ptr<messaging::EventStream> stream_;
|
||||
};
|
||||
|
||||
class Server {
|
||||
public:
|
||||
Server(messaging::System &system, const std::string &name);
|
||||
|
||||
template <typename TRequestResponse>
|
||||
void Register(
|
||||
std::function<std::unique_ptr<typename TRequestResponse::Response>(
|
||||
const typename TRequestResponse::Request &)>
|
||||
callback) {
|
||||
static_assert(std::is_base_of<messaging::Message,
|
||||
typename TRequestResponse::Request>::value,
|
||||
"TRequestResponse::Request must be derived from Message");
|
||||
static_assert(std::is_base_of<messaging::Message,
|
||||
typename TRequestResponse::Response>::value,
|
||||
"TRequestResponse::Response must be derived from Message");
|
||||
auto got = callbacks_.emplace(
|
||||
typeid(typename TRequestResponse::Request),
|
||||
[callback = callback](const messaging::Message &base_message) {
|
||||
const auto &message =
|
||||
dynamic_cast<const typename TRequestResponse::Request &>(
|
||||
base_message);
|
||||
return callback(message);
|
||||
});
|
||||
CHECK(got.second) << "Callback for that message type already registered";
|
||||
}
|
||||
|
||||
void Start();
|
||||
void Shutdown();
|
||||
|
||||
private:
|
||||
messaging::System &system_;
|
||||
std::shared_ptr<messaging::EventStream> stream_;
|
||||
std::unordered_map<std::type_index,
|
||||
std::function<std::unique_ptr<messaging::Message>(
|
||||
const messaging::Message &)>>
|
||||
callbacks_;
|
||||
std::atomic<bool> alive_{true};
|
||||
};
|
||||
}
|
@ -7,6 +7,11 @@
|
||||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
|
||||
#include "glog/logging.h"
|
||||
|
||||
using namespace std::literals::chrono_literals;
|
||||
|
||||
// Thread safe queue. Probably doesn't perform very well, but it works.
|
||||
template <typename T>
|
||||
@ -18,8 +23,6 @@ class Queue {
|
||||
Queue(Queue &&) = delete;
|
||||
Queue &operator=(Queue &&) = delete;
|
||||
|
||||
~Queue() { Signal(); }
|
||||
|
||||
void Push(T x) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
queue_.emplace(std::move(x));
|
||||
@ -46,12 +49,19 @@ class Queue {
|
||||
}
|
||||
|
||||
// Block until there is an element in the queue and then pop it from the queue
|
||||
// and return it. Function can return nullopt only if Queue is signaled via
|
||||
// Signal function.
|
||||
std::experimental::optional<T> AwaitPop() {
|
||||
// and return it. Function can return nullopt if Queue is signaled via
|
||||
// Shutdown function or if there is no element to pop after timeout elapses.
|
||||
std::experimental::optional<T> AwaitPop(
|
||||
std::chrono::system_clock::duration timeout =
|
||||
std::chrono::system_clock::duration::max()) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
cvar_.wait(guard, [this] { return !queue_.empty() || signaled_; });
|
||||
if (queue_.empty()) return std::experimental::nullopt;
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto until = std::chrono::system_clock::time_point::max() - timeout > now
|
||||
? now + timeout
|
||||
: std::chrono::system_clock::time_point::max();
|
||||
cvar_.wait_until(guard, until,
|
||||
[this] { return !queue_.empty() || !alive_; });
|
||||
if (queue_.empty() || !alive_) return std::experimental::nullopt;
|
||||
std::experimental::optional<T> x(std::move(queue_.front()));
|
||||
queue_.pop();
|
||||
return x;
|
||||
@ -66,14 +76,17 @@ class Queue {
|
||||
return x;
|
||||
}
|
||||
|
||||
// Notify all threads waiting on conditional variable to stop waiting.
|
||||
void Signal() {
|
||||
signaled_ = true;
|
||||
// Notify all threads waiting on conditional variable to stop waiting. New
|
||||
// threads that try to Await will not block.
|
||||
void Shutdown() {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
alive_ = false;
|
||||
guard.unlock();
|
||||
cvar_.notify_all();
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic<bool> signaled_{false};
|
||||
bool alive_ = true;
|
||||
std::queue<T> queue_;
|
||||
std::condition_variable cvar_;
|
||||
mutable std::mutex mutex_;
|
||||
|
@ -36,13 +36,23 @@ class SocketEventDispatcher {
|
||||
// Go through all events and process them in order.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
auto &event = events_[i];
|
||||
|
||||
Listener &listener = *reinterpret_cast<Listener *>(event.data.ptr);
|
||||
|
||||
// TODO: revise this. Reported events will be combined so continue is not
|
||||
// probably what we want to do.
|
||||
// Even though it is possible for multiple events to be reported we handle
|
||||
// only one of them. Since we use epoll in level triggered mode
|
||||
// unprocessed events will be reported next time we call epoll_wait. This
|
||||
// kind of processing events is safer since callbacks can destroy listener
|
||||
// and calling next callback on listener object will segfault. More subtle
|
||||
// bugs are also possible: one callback can handle multiple events
|
||||
// (maybe even in a subtle implicit way) and then we don't want to call
|
||||
// multiple callbacks since we are not sure if those are valid anymore.
|
||||
try {
|
||||
// Hangup event.
|
||||
if (event.events & EPOLLIN) {
|
||||
// We have some data waiting to be read.
|
||||
listener.OnData();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (event.events & EPOLLRDHUP) {
|
||||
listener.OnClose();
|
||||
continue;
|
||||
@ -54,8 +64,6 @@ class SocketEventDispatcher {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We have some data waiting to be read.
|
||||
listener.OnData();
|
||||
} catch (const std::exception &e) {
|
||||
listener.OnException(e);
|
||||
}
|
||||
|
97
tests/unit/messaging_distributed.cpp
Normal file
97
tests/unit/messaging_distributed.cpp
Normal file
@ -0,0 +1,97 @@
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "communication/messaging/distributed.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace communication::messaging;
|
||||
using namespace std::literals::chrono_literals;
|
||||
|
||||
struct MessageInt : public Message {
|
||||
MessageInt() {} // cereal needs this
|
||||
MessageInt(int x) : x(x) {}
|
||||
int x;
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {
|
||||
ar(cereal::virtual_base_class<Message>(this), x);
|
||||
}
|
||||
};
|
||||
CEREAL_REGISTER_TYPE(MessageInt);
|
||||
|
||||
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
|
||||
|
||||
/**
|
||||
* Test do the services start up without crashes.
|
||||
*/
|
||||
TEST(SimpleTests, StartAndShutdown) {
|
||||
System system("127.0.0.1", 10000);
|
||||
// do nothing
|
||||
std::this_thread::sleep_for(500ms);
|
||||
system.Shutdown();
|
||||
}
|
||||
|
||||
TEST(Messaging, Pop) {
|
||||
System master_system("127.0.0.1", 10000);
|
||||
System slave_system("127.0.0.1", 10001);
|
||||
auto stream = master_system.Open("main");
|
||||
Writer writer(slave_system, "127.0.0.1", 10000, "main");
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
writer.Send<MessageInt>(10);
|
||||
EXPECT_EQ(GET_X(stream->Await()), 10);
|
||||
master_system.Shutdown();
|
||||
slave_system.Shutdown();
|
||||
}
|
||||
|
||||
TEST(Messaging, Await) {
|
||||
System master_system("127.0.0.1", 10000);
|
||||
System slave_system("127.0.0.1", 10001);
|
||||
auto stream = master_system.Open("main");
|
||||
Writer writer(slave_system, "127.0.0.1", 10000, "main");
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
std::thread t([&] {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
stream->Shutdown();
|
||||
std::this_thread::sleep_for(100ms);
|
||||
writer.Send<MessageInt>(20);
|
||||
});
|
||||
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
EXPECT_EQ(stream->Await(), nullptr);
|
||||
t.join();
|
||||
master_system.Shutdown();
|
||||
slave_system.Shutdown();
|
||||
}
|
||||
|
||||
TEST(Messaging, RecreateChannelAfterClosing) {
|
||||
System master_system("127.0.0.1", 10000);
|
||||
System slave_system("127.0.0.1", 10001);
|
||||
auto stream = master_system.Open("main");
|
||||
Writer writer(slave_system, "127.0.0.1", 10000, "main");
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
writer.Send<MessageInt>(10);
|
||||
EXPECT_EQ(GET_X(stream->Await()), 10);
|
||||
|
||||
stream = nullptr;
|
||||
writer.Send<MessageInt>(20);
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
stream = master_system.Open("main");
|
||||
std::this_thread::sleep_for(100ms);
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
writer.Send<MessageInt>(30);
|
||||
EXPECT_EQ(GET_X(stream->Await()), 30);
|
||||
|
||||
master_system.Shutdown();
|
||||
slave_system.Shutdown();
|
||||
}
|
72
tests/unit/messaging_local.cpp
Normal file
72
tests/unit/messaging_local.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "communication/messaging/local.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std::literals::chrono_literals;
|
||||
using namespace communication::messaging;
|
||||
|
||||
struct MessageInt : public Message {
|
||||
MessageInt(int xx) : x(xx) {}
|
||||
int x;
|
||||
};
|
||||
|
||||
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
|
||||
|
||||
TEST(LocalMessaging, Pop) {
|
||||
LocalSystem system;
|
||||
auto stream = system.Open("main");
|
||||
LocalWriter writer(system, "main");
|
||||
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
writer.Send<MessageInt>(10);
|
||||
EXPECT_EQ(GET_X(stream->Poll()), 10);
|
||||
}
|
||||
|
||||
TEST(LocalMessaging, Await) {
|
||||
LocalSystem system;
|
||||
auto stream = system.Open("main");
|
||||
LocalWriter writer(system, "main");
|
||||
|
||||
std::thread t([&] {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
stream->Shutdown();
|
||||
std::this_thread::sleep_for(100ms);
|
||||
writer.Send<MessageInt>(20);
|
||||
});
|
||||
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
EXPECT_EQ(stream->Await(), nullptr);
|
||||
t.join();
|
||||
}
|
||||
|
||||
TEST(LocalMessaging, AwaitTimeout) {
|
||||
LocalSystem system;
|
||||
auto stream = system.Open("main");
|
||||
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
EXPECT_EQ(stream->Await(100ms), nullptr);
|
||||
}
|
||||
|
||||
TEST(LocalMessaging, RecreateChannelAfterClosing) {
|
||||
LocalSystem system;
|
||||
auto stream = system.Open("main");
|
||||
LocalWriter writer(system, "main");
|
||||
|
||||
writer.Send<MessageInt>(10);
|
||||
EXPECT_EQ(GET_X(stream->Poll()), 10);
|
||||
|
||||
stream = nullptr;
|
||||
writer.Send<MessageInt>(20);
|
||||
stream = system.Open("main");
|
||||
EXPECT_EQ(stream->Poll(), nullptr);
|
||||
writer.Send<MessageInt>(30);
|
||||
EXPECT_EQ(GET_X(stream->Poll()), 30);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -79,6 +79,7 @@ TEST(Queue, AwaitPop) {
|
||||
});
|
||||
|
||||
EXPECT_EQ(*q.AwaitPop(), 1);
|
||||
std::this_thread::sleep_for(1000ms);
|
||||
EXPECT_EQ(*q.AwaitPop(), 2);
|
||||
EXPECT_EQ(*q.AwaitPop(), 3);
|
||||
EXPECT_EQ(*q.AwaitPop(), 4);
|
||||
@ -86,12 +87,19 @@ TEST(Queue, AwaitPop) {
|
||||
|
||||
std::thread t2([&] {
|
||||
std::this_thread::sleep_for(100ms);
|
||||
q.Signal();
|
||||
q.Shutdown();
|
||||
});
|
||||
std::this_thread::sleep_for(200ms);
|
||||
EXPECT_EQ(q.AwaitPop(), std::experimental::nullopt);
|
||||
t2.join();
|
||||
}
|
||||
|
||||
TEST(Queue, AwaitPopTimeout) {
|
||||
std::this_thread::sleep_for(1000ms);
|
||||
Queue<int> q;
|
||||
EXPECT_EQ(q.AwaitPop(100ms), std::experimental::nullopt);
|
||||
}
|
||||
|
||||
TEST(Queue, Concurrent) {
|
||||
Queue<int> q;
|
||||
|
||||
|
80
tests/unit/rpc.cpp
Normal file
80
tests/unit/rpc.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "communication/messaging/distributed.hpp"
|
||||
#include "communication/rpc/rpc.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using communication::messaging::System;
|
||||
using communication::messaging::Message;
|
||||
using namespace communication::rpc;
|
||||
using namespace std::literals::chrono_literals;
|
||||
|
||||
struct SumReq : public Message {
|
||||
SumReq() {} // cereal needs this
|
||||
SumReq(int x, int y) : x(x), y(y) {}
|
||||
int x;
|
||||
int y;
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {
|
||||
ar(cereal::virtual_base_class<Message>(this), x, y);
|
||||
}
|
||||
};
|
||||
CEREAL_REGISTER_TYPE(SumReq);
|
||||
|
||||
struct SumRes : public Message {
|
||||
SumRes() {} // cereal needs this
|
||||
SumRes(int sum) : sum(sum) {}
|
||||
int sum;
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {
|
||||
ar(cereal::virtual_base_class<Message>(this), sum);
|
||||
}
|
||||
};
|
||||
CEREAL_REGISTER_TYPE(SumRes);
|
||||
using Sum = RequestResponse<SumReq, SumRes>;
|
||||
|
||||
TEST(Rpc, Call) {
|
||||
System server_system("127.0.0.1", 10000);
|
||||
Server server(server_system, "main");
|
||||
server.Register<Sum>([](const SumReq &request) {
|
||||
return std::make_unique<SumRes>(request.x + request.y);
|
||||
});
|
||||
std::thread server_thread([&] { server.Start(); });
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
System client_system("127.0.0.1", 10001);
|
||||
Client client(client_system, "127.0.0.1", 10000, "main");
|
||||
auto sum = client.Call<Sum>(300ms, 10, 20);
|
||||
EXPECT_EQ(sum->sum, 30);
|
||||
|
||||
server.Shutdown();
|
||||
server_thread.join();
|
||||
server_system.Shutdown();
|
||||
client_system.Shutdown();
|
||||
}
|
||||
|
||||
TEST(Rpc, Timeout) {
|
||||
System server_system("127.0.0.1", 10000);
|
||||
Server server(server_system, "main");
|
||||
server.Register<Sum>([](const SumReq &request) {
|
||||
std::this_thread::sleep_for(300ms);
|
||||
return std::make_unique<SumRes>(request.x + request.y);
|
||||
});
|
||||
std::thread server_thread([&] { server.Start(); });
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
System client_system("127.0.0.1", 10001);
|
||||
Client client(client_system, "127.0.0.1", 10000, "main");
|
||||
auto sum = client.Call<Sum>(100ms, 10, 20);
|
||||
EXPECT_FALSE(sum);
|
||||
|
||||
server.Shutdown();
|
||||
server_thread.join();
|
||||
server_system.Shutdown();
|
||||
client_system.Shutdown();
|
||||
}
|
62
tools/rpcgen
Executable file
62
tools/rpcgen
Executable file
@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env python3
|
||||
# To run from vim add to your .vimrc something similar to:
|
||||
# command -nargs=* Rpcgen :r !/home/mislav/code/memgraph/tools/rpcgen <args>
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
USAGE = "\n\nUsage:\n" \
|
||||
"./rpcgen request_response_name request_args -- response_args\n" \
|
||||
"Arguments should be seperated with minus sign (-).\n" \
|
||||
"Example: ./rpcgen Sum int x - int y -- int sum\n\n"
|
||||
|
||||
assert len(sys.argv) >= 3, "Too few arguments.\n" + USAGE
|
||||
|
||||
request_response_name = sys.argv[1]
|
||||
args_string = " ".join(sys.argv[2:])
|
||||
split = args_string.split("--")
|
||||
assert len(split) == 2, "Arguments should contain one -- separator.\n" + USAGE
|
||||
request_args, response_args = split
|
||||
|
||||
def generate(message_name, args):
|
||||
def process_arg(arg):
|
||||
arg = arg.strip()
|
||||
assert arg, "Each arg should be non empty string.\n" + USAGE
|
||||
for i in range(len(arg) - 1, -1, -1):
|
||||
if not arg[i].isalpha() and arg[i] != "_": break
|
||||
assert i != -1 and i != len(arg), "Each string separated with - " \
|
||||
"should contain type and variable name.\n" + USAGE
|
||||
typ = arg[:i+1].strip()
|
||||
name = arg[i+1:].strip()
|
||||
return typ, name
|
||||
types, names = zip(*map(process_arg, args.split("-")))
|
||||
|
||||
return \
|
||||
"""
|
||||
struct {message_name} : public Message {{
|
||||
{message_name}() {{}} // cereal needs this
|
||||
{message_name}({constructor_args}): {init_list} {{}}
|
||||
{members}
|
||||
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar) {{
|
||||
ar(cereal::virtual_base_class<Message>(this), {serialize_args});
|
||||
}}
|
||||
}};
|
||||
CEREAL_REGISTER_TYPE({message_name});""" \
|
||||
.format(message_name=message_name,
|
||||
constructor_args=",".join(
|
||||
map(lambda x: x[0] + " " + x[1], zip(types, names))),
|
||||
init_list=", ".join(map(lambda x: "{x}({x})".format(x=x), names)),
|
||||
members="\n".join(map(lambda x:
|
||||
"{} {};".format(x[0], x[1]), zip(types, names))),
|
||||
serialize_args=", ".join(names))
|
||||
|
||||
request_name = request_response_name + "Req"
|
||||
response_name = request_response_name + "Res"
|
||||
req_class = generate(request_name, request_args)
|
||||
res_class = generate(response_name, response_args)
|
||||
print(req_class)
|
||||
print(res_class)
|
||||
print("using {} = RequestResponse<{}, {}>;".format(
|
||||
request_response_name, request_name, response_name))
|
Loading…
Reference in New Issue
Block a user