RPC refactor

Summary:
Start removal of old logic
Remove more obsolete classes
Move Message class to RPC
Remove client logic from system
Remove messaging namespace
Move protocol from messaging to rpc
Move System from messaging to rpc
Remove unnecessary namespace
Remove System from RPC Client
Split Client and Server into separate files
Start implementing new client logic
First semi-working state
Changed network protocol layout
Rewrite client
Fix client receive bug
Cleanup code of debug lines
Migrate to accessors
Migrate back to binary boost archives
Remove debug logging from server
Disable timeout test
Reduce message_id from uint64_t to uint32_t
Add multiple workers to server
Fix compiler warnings
Apply clang-format

Reviewers: teon.banek, florijan, dgleich, buda, mtomic

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1129
This commit is contained in:
Matej Ferencevic 2018-01-24 12:16:14 +01:00
parent 31aded2dae
commit fc20ddcd25
72 changed files with 856 additions and 1310 deletions

View File

@ -4,10 +4,9 @@
set(memgraph_src_files
communication/bolt/v1/decoder/decoded_value.cpp
communication/bolt/v1/session.cpp
communication/messaging/distributed.cpp
communication/messaging/local.cpp
communication/messaging/protocol.cpp
communication/rpc/rpc.cpp
communication/rpc/client.cpp
communication/rpc/protocol.cpp
communication/rpc/server.cpp
data_structures/concurrent/skiplist_gc.cpp
database/config.cpp
database/counters.cpp

View File

@ -194,6 +194,14 @@ class Session {
db_accessor_ = nullptr;
}
TSocket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
// TODO: Rethink if there is a way to hide some members. At the momement all
// of them are public.
TSocket socket_;

View File

@ -1,53 +0,0 @@
#include "communication/messaging/distributed.hpp"
namespace communication::messaging {
System::System(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {
// Numbers of workers is quite arbitrary at this point.
StartClient(4);
StartServer(4);
}
System::~System() {
queue_.Shutdown();
for (size_t i = 0; i < pool_.size(); ++i) {
pool_[i].join();
}
}
void System::StartClient(int worker_count) {
LOG(INFO) << "Starting " << worker_count << " client workers";
for (int i = 0; i < worker_count; ++i) {
pool_.push_back(std::thread([this]() {
while (true) {
auto message = queue_.AwaitPop();
if (message == std::experimental::nullopt) break;
SendMessage(message->endpoint, message->channel,
std::move(message->message));
}
}));
}
}
void System::StartServer(int worker_count) {
if (server_ != nullptr) {
LOG(FATAL) << "Tried to start a running server!";
}
// Initialize server.
server_ = std::make_unique<ServerT>(endpoint_, protocol_data_, worker_count);
endpoint_ = server_->endpoint();
}
std::shared_ptr<EventStream> System::Open(const std::string &name) {
return system_.Open(name);
}
Writer::Writer(System &system, const Endpoint &endpoint,
const std::string &name)
: system_(system), endpoint_(endpoint), name_(name) {}
void Writer::Send(std::unique_ptr<Message> message) {
system_.queue_.Emplace(endpoint_, name_, std::move(message));
}
} // namespace communication::messaging

View File

@ -1,105 +0,0 @@
#pragma once
#include <cassert>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <tuple>
#include <typeindex>
#include <utility>
#include <gflags/gflags.h>
#include "communication/messaging/local.hpp"
#include "data_structures/queue.hpp"
#include "protocol.hpp"
#include "communication/server.hpp"
#include "io/network/endpoint.hpp"
#include "threading/sync/spinlock.hpp"
namespace communication::messaging {
class System;
// Writes message to remote event stream.
class Writer {
public:
Writer(System &system, const Endpoint &endpoint, const std::string &name);
Writer(const Writer &) = delete;
void operator=(const Writer &) = delete;
Writer(Writer &&) = delete;
void operator=(Writer &&) = delete;
template <typename TMessage, typename... Args>
void Send(Args &&... args) {
Send(std::unique_ptr<Message>(
std::make_unique<TMessage>(std::forward<Args>(args)...)));
}
void Send(std::unique_ptr<Message> message);
private:
System &system_;
Endpoint endpoint_;
std::string name_;
};
class System {
public:
friend class Writer;
System(const Endpoint &endpoint);
System(const System &) = delete;
System(System &&) = delete;
System &operator=(const System &) = delete;
System &operator=(System &&) = delete;
~System();
std::shared_ptr<EventStream> Open(const std::string &name);
const Endpoint &endpoint() const { return endpoint_; }
private:
using Socket = Socket;
using ServerT = communication::Server<Session, SessionData>;
struct NetworkMessage {
NetworkMessage() {}
NetworkMessage(const Endpoint &endpoint, const std::string &channel,
std::unique_ptr<Message> &&message)
: endpoint(endpoint), channel(channel), message(std::move(message)) {}
NetworkMessage(NetworkMessage &&nm) = default;
NetworkMessage &operator=(NetworkMessage &&nm) = default;
Endpoint endpoint;
std::string channel;
std::unique_ptr<Message> message;
};
/** Start a threadpool that dispatches the messages from the (outgoing) queue
* to the sockets */
void StartClient(int worker_count);
/** Start a threadpool that relays the messages from the sockets to the
* LocalEventStreams */
void StartServer(int workers_count);
// Client variables.
std::vector<std::thread> pool_;
Queue<NetworkMessage> queue_;
// Server variables.
SessionData protocol_data_;
std::unique_ptr<ServerT> server_{nullptr};
Endpoint endpoint_;
LocalSystem &system_ = protocol_data_.system;
};
} // namespace communication::messaging

View File

@ -1,62 +0,0 @@
#include "communication/messaging/local.hpp"
#include "fmt/format.h"
#include "glog/logging.h"
namespace communication::messaging {
std::shared_ptr<EventStream> LocalSystem::Open(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
// TODO: It would be better to use std::make_shared here, but we can't since
// constructor is private.
std::shared_ptr<EventStream> stream(new EventStream(*this, name));
auto got = channels_.emplace(name, stream);
CHECK(got.second) << fmt::format("Stream with name {} already exists", name);
return stream;
}
std::shared_ptr<EventStream> LocalSystem::Resolve(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = channels_.find(name);
if (it == channels_.end()) return nullptr;
return it->second.lock();
}
void LocalSystem::Remove(const std::string &name) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = channels_.find(name);
CHECK(it != channels_.end()) << "Trying to delete nonexisting stream";
channels_.erase(it);
}
void LocalWriter::Send(std::unique_ptr<Message> message) {
// TODO: We could add caching to LocalWriter so that we don't need to acquire
// lock on system every time we want to send a message. This can be
// accomplished by storing weak_ptr to EventStream.
auto stream = system_.Resolve(name_);
if (!stream) return;
stream->Push(std::move(message));
}
EventStream::~EventStream() { system_.Remove(name_); }
std::unique_ptr<Message> EventStream::Poll() {
auto opt_message = queue_.MaybePop();
if (opt_message == std::experimental::nullopt) return nullptr;
return std::move(*opt_message);
}
void EventStream::Push(std::unique_ptr<Message> message) {
queue_.Push(std::move(message));
}
std::unique_ptr<Message> EventStream::Await(
std::chrono::system_clock::duration timeout) {
auto opt_message = queue_.AwaitPop(timeout);
if (opt_message == std::experimental::nullopt) return nullptr;
return std::move(*opt_message);
};
void EventStream::Shutdown() { queue_.Shutdown(); }
} // namespace communication::messaging

View File

@ -1,114 +0,0 @@
#pragma once
#include <mutex>
#include <string>
#include <type_traits>
#include <typeindex>
#include <unordered_map>
#include "boost/serialization/access.hpp"
#include "data_structures/queue.hpp"
namespace communication::messaging {
/**
* Base class for messages.
*/
class Message {
public:
virtual ~Message() {}
/**
* Run-time type identification that is used for callbacks.
*
* Warning: this works because of the virtual destructor, don't remove it from
* this class
*/
std::type_index type_index() const { return typeid(*this); }
private:
friend boost::serialization::access;
template <class TArchive>
void serialize(TArchive &, unsigned int) {}
};
class EventStream;
class LocalWriter;
class LocalSystem {
public:
friend class EventStream;
friend class LocalWriter;
LocalSystem() = default;
LocalSystem(const LocalSystem &) = delete;
LocalSystem(LocalSystem &&) = delete;
LocalSystem &operator=(const LocalSystem &) = delete;
LocalSystem &operator=(LocalSystem &&) = delete;
std::shared_ptr<EventStream> Open(const std::string &name);
private:
std::shared_ptr<EventStream> Resolve(const std::string &name);
void Remove(const std::string &name);
std::mutex mutex_;
std::unordered_map<std::string, std::weak_ptr<EventStream>> channels_;
};
class LocalWriter {
public:
LocalWriter(LocalSystem &system, const std::string &name)
: system_(system), name_(name) {}
LocalWriter(const LocalWriter &) = delete;
void operator=(const LocalWriter &) = delete;
LocalWriter(LocalWriter &&) = delete;
void operator=(LocalWriter &&) = delete;
template <typename TMessage, typename... Args>
void Send(Args &&... args) {
Send(std::unique_ptr<Message>(
std::make_unique<TMessage>(std::forward<Args>(args)...)));
}
void Send(std::unique_ptr<Message> message);
private:
LocalSystem &system_;
std::string name_;
};
class EventStream {
public:
friend class LocalWriter;
friend class LocalSystem;
EventStream(const EventStream &) = delete;
void operator=(const EventStream &) = delete;
EventStream(EventStream &&) = delete;
void operator=(EventStream &&) = delete;
~EventStream();
std::unique_ptr<Message> Poll();
std::unique_ptr<Message> Await(
std::chrono::system_clock::duration timeout =
std::chrono::system_clock::duration::max());
void Shutdown();
const std::string &name() const { return name_; }
private:
EventStream(LocalSystem &system, const std::string &name)
: system_(system), name_(name) {}
void Push(std::unique_ptr<Message> message);
LocalSystem &system_;
std::string name_;
Queue<std::unique_ptr<Message>> queue_;
};
} // namespace communication::messaging

View File

@ -1,137 +0,0 @@
#include <sstream>
#include <unordered_map>
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/unique_ptr.hpp"
#include "fmt/format.h"
#include "glog/logging.h"
#include "communication/messaging/distributed.hpp"
#include "communication/messaging/local.hpp"
#include "communication/messaging/protocol.hpp"
#include "communication/rpc/messages-inl.hpp"
namespace communication::messaging {
Session::Session(Socket &&socket, SessionData &data)
: socket_(std::move(socket)), system_(data.system) {}
bool Session::Alive() const { return alive_; }
std::string Session::GetStringAndShift(SizeT len) {
std::string ret(reinterpret_cast<char *>(buffer_.data()), len);
buffer_.Shift(len);
return ret;
}
void Session::Execute() {
if (buffer_.size() < sizeof(SizeT)) return;
SizeT len_channel = GetLength();
if (buffer_.size() < 2 * sizeof(SizeT) + len_channel) return;
SizeT len_data = GetLength(sizeof(SizeT) + len_channel);
if (buffer_.size() < 2 * sizeof(SizeT) + len_data + len_channel) return;
// Remove the length bytes from the buffer.
buffer_.Shift(sizeof(SizeT));
auto channel = GetStringAndShift(len_channel);
buffer_.Shift(sizeof(SizeT));
// TODO: check for exceptions
std::stringstream stream;
stream.str(std::string(reinterpret_cast<char *>(buffer_.data()), len_data));
boost::archive::binary_iarchive archive(stream);
std::unique_ptr<Message> message{nullptr};
archive >> message;
buffer_.Shift(len_data);
LocalWriter writer(system_, channel);
writer.Send(std::move(message));
}
StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
void Session::Written(size_t len) { buffer_.Written(len); }
void Session::Close() {
DLOG(INFO) << "Closing session";
this->socket_.Close();
}
SizeT Session::GetLength(int offset) {
SizeT ret = *reinterpret_cast<SizeT *>(buffer_.data() + offset);
return ret;
}
bool SendLength(Socket &socket, SizeT length) {
return socket.Write(reinterpret_cast<uint8_t *>(&length), sizeof(SizeT));
}
struct PairHash {
public:
template <typename T, typename U>
std::size_t operator()(const std::pair<T, U> &x) const {
return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
}
};
void SendMessage(const Endpoint &endpoint, const std::string &channel,
std::unique_ptr<Message> message) {
static thread_local std::unordered_map<std::pair<std::string, uint16_t>,
Socket, PairHash>
cache;
CHECK(message) << "Trying to send nullptr instead of message";
auto it = cache.find({endpoint.address(), endpoint.port()});
if (it == cache.end()) {
Socket socket;
if (!socket.Connect(endpoint)) {
LOG(INFO) << "Couldn't connect to endpoint: " << endpoint;
return;
}
socket.SetKeepAlive();
it =
cache
.emplace(std::piecewise_construct,
std::forward_as_tuple(endpoint.address(), endpoint.port()),
std::forward_as_tuple(std::move(socket)))
.first;
}
auto &socket = it->second;
if (!SendLength(socket, channel.size())) {
LOG(INFO) << "Couldn't send channel size!";
cache.erase(it);
return;
}
if (!socket.Write(channel)) {
LOG(INFO) << "Couldn't send channel data!";
cache.erase(it);
return;
}
// Serialize and send message
std::stringstream stream;
boost::archive::binary_oarchive archive(stream);
archive << message;
const std::string &buffer = stream.str();
int64_t message_size = 2 * sizeof(SizeT) + buffer.size() + channel.size();
CHECK(message_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", message_size,
kMaxMessageSize);
if (!SendLength(socket, buffer.size())) {
LOG(INFO) << "Couldn't send message size!";
cache.erase(it);
return;
}
if (!socket.Write(buffer)) {
LOG(INFO) << "Couldn't send message data!";
cache.erase(it);
return;
}
}
} // namespace communication::messaging

View File

@ -3,7 +3,7 @@
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/raft/raft.hpp"
namespace communication::raft {
@ -11,7 +11,7 @@ namespace communication::raft {
enum class RpcType { REQUEST_VOTE, APPEND_ENTRIES };
template <class State>
struct PeerRpcRequest : public messaging::Message {
struct PeerRpcRequest : public rpc::Message {
RpcType type;
RequestVoteRequest request_vote;
AppendEntriesRequest<State> append_entries;
@ -21,14 +21,14 @@ struct PeerRpcRequest : public messaging::Message {
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<messaging::Message>(*this);
ar &boost::serialization::base_object<rpc::Message>(*this);
ar &type;
ar &request_vote;
ar &append_entries;
}
};
struct PeerRpcReply : public messaging::Message {
struct PeerRpcReply : public rpc::Message {
RpcType type;
RequestVoteReply request_vote;
AppendEntriesReply append_entries;
@ -38,7 +38,7 @@ struct PeerRpcReply : public messaging::Message {
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<messaging::Message>(*this);
ar &boost::serialization::base_object<rpc::Message>(*this);
ar &type;
ar &request_vote;
ar &append_entries;

View File

@ -273,8 +273,7 @@ void RaftMemberImpl<State>::RequestVote(const std::string &peer_id,
/* Release lock before issuing RPC and waiting for response. */
/* TODO(mtomic): Revise how this will work with RPC cancellation. */
lock.unlock();
bool ok =
network_.SendRequestVote(peer_id, request, reply, config_.rpc_timeout);
bool ok = network_.SendRequestVote(peer_id, request, reply);
lock.lock();
/* TODO(mtomic): Maybe implement exponential backoff. */
@ -377,8 +376,7 @@ void RaftMemberImpl<State>::AppendEntries(const std::string &peer_id,
/* Release lock before issuing RPC and waiting for response. */
/* TODO(mtomic): Revise how this will work with RPC cancellation. */
lock.unlock();
bool ok =
network_.SendAppendEntries(peer_id, request, reply, config_.rpc_timeout);
bool ok = network_.SendAppendEntries(peer_id, request, reply);
lock.lock();
/* TODO(mtomic): Maybe implement exponential backoff. */

View File

@ -115,17 +115,15 @@ class RaftNetworkInterface {
virtual ~RaftNetworkInterface() = default;
/* These function return false if RPC failed for some reason (e.g. cannot
* establish connection, request timeout or request cancelled). Otherwise
* establish connection or request cancelled). Otherwise
* `reply` contains response from peer. */
virtual bool SendRequestVote(const MemberId &recipient,
const RequestVoteRequest &request,
RequestVoteReply &reply,
std::chrono::milliseconds timeout) = 0;
RequestVoteReply &reply) = 0;
virtual bool SendAppendEntries(const MemberId &recipient,
const AppendEntriesRequest<State> &request,
AppendEntriesReply &reply,
std::chrono::milliseconds timeout) = 0;
AppendEntriesReply &reply) = 0;
/* This will be called once the RaftMember is ready to start receiving RPCs.
*/
@ -155,7 +153,6 @@ struct RaftConfig {
std::chrono::milliseconds leader_timeout_min;
std::chrono::milliseconds leader_timeout_max;
std::chrono::milliseconds heartbeat_interval;
std::chrono::milliseconds rpc_timeout;
std::chrono::milliseconds rpc_backoff;
};

View File

@ -4,10 +4,10 @@
#include "glog/logging.h"
#include "communication/messaging/distributed.hpp"
#include "communication/raft/network_common.hpp"
#include "communication/raft/raft.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp"
#include "io/network/endpoint.hpp"
/* Implementation of `RaftNetworkInterface` using RPC. Raft RPC requests and
@ -22,13 +22,12 @@ namespace communication::raft {
const char *kRaftChannelName = "raft-peer-rpc-channel";
template <class State>
using PeerProtocol =
communication::rpc::RequestResponse<PeerRpcRequest<State>, PeerRpcReply>;
using PeerProtocol = rpc::RequestResponse<PeerRpcRequest<State>, PeerRpcReply>;
template <class State>
class RpcNetwork : public RaftNetworkInterface<State> {
public:
RpcNetwork(communication::messaging::System &system,
RpcNetwork(rpc::System &system,
std::unordered_map<std::string, io::network::Endpoint> directory)
: system_(system),
directory_(std::move(directory)),
@ -57,15 +56,14 @@ class RpcNetwork : public RaftNetworkInterface<State> {
virtual bool SendRequestVote(const MemberId &recipient,
const RequestVoteRequest &request,
RequestVoteReply &reply,
std::chrono::milliseconds timeout) override {
RequestVoteReply &reply) override {
PeerRpcRequest<State> req;
PeerRpcReply rep;
req.type = RpcType::REQUEST_VOTE;
req.request_vote = request;
if (!SendRpc(recipient, req, rep, timeout)) {
if (!SendRpc(recipient, req, rep)) {
return false;
}
@ -75,15 +73,14 @@ class RpcNetwork : public RaftNetworkInterface<State> {
virtual bool SendAppendEntries(const MemberId &recipient,
const AppendEntriesRequest<State> &request,
AppendEntriesReply &reply,
std::chrono::milliseconds timeout) override {
AppendEntriesReply &reply) override {
PeerRpcRequest<State> req;
PeerRpcReply rep;
req.type = RpcType::APPEND_ENTRIES;
req.append_entries = request;
if (!SendRpc(recipient, req, rep, timeout)) {
if (!SendRpc(recipient, req, rep)) {
return false;
}
@ -93,9 +90,9 @@ class RpcNetwork : public RaftNetworkInterface<State> {
private:
bool SendRpc(const MemberId &recipient, const PeerRpcRequest<State> &request,
PeerRpcReply &reply, std::chrono::milliseconds timeout) {
PeerRpcReply &reply) {
auto &client = GetClient(recipient);
auto response = client.template Call<PeerProtocol<State>>(timeout, request);
auto response = client.template Call<PeerProtocol<State>>(request);
if (!response) {
return false;
@ -109,17 +106,17 @@ class RpcNetwork : public RaftNetworkInterface<State> {
auto it = clients_.find(id);
if (it == clients_.end()) {
auto ne = directory_[id];
it = clients_.try_emplace(id, system_, ne, kRaftChannelName).first;
it = clients_.try_emplace(id, ne, kRaftChannelName).first;
}
return it->second;
}
communication::messaging::System &system_;
rpc::System &system_;
// TODO(mtomic): how to update and distribute this?
std::unordered_map<MemberId, io::network::Endpoint> directory_;
rpc::Server server_;
std::unordered_map<MemberId, communication::rpc::Client> clients_;
std::unordered_map<MemberId, rpc::Client> clients_;
};
} // namespace communication::raft

View File

@ -54,15 +54,13 @@ class NoOpNetworkInterface : public RaftNetworkInterface<State> {
~NoOpNetworkInterface() {}
virtual bool SendRequestVote(const MemberId &, const RequestVoteRequest &,
RequestVoteReply &,
std::chrono::milliseconds) override {
RequestVoteReply &) override {
return false;
}
virtual bool SendAppendEntries(const MemberId &,
const AppendEntriesRequest<State> &,
AppendEntriesReply &,
std::chrono::milliseconds) override {
AppendEntriesReply &) override {
return false;
}
@ -80,8 +78,7 @@ class NextReplyNetworkInterface : public RaftNetworkInterface<State> {
virtual bool SendRequestVote(const MemberId &,
const RequestVoteRequest &request,
RequestVoteReply &reply,
std::chrono::milliseconds) override {
RequestVoteReply &reply) override {
PeerRpcRequest<State> req;
req.type = RpcType::REQUEST_VOTE;
req.request_vote = request;
@ -97,8 +94,7 @@ class NextReplyNetworkInterface : public RaftNetworkInterface<State> {
virtual bool SendAppendEntries(const MemberId &,
const AppendEntriesRequest<State> &request,
AppendEntriesReply &reply,
std::chrono::milliseconds) override {
AppendEntriesReply &reply) override {
PeerRpcRequest<State> req;
req.type = RpcType::APPEND_ENTRIES;
req.append_entries = request;

View File

@ -0,0 +1,130 @@
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "boost/serialization/export.hpp"
#include "boost/serialization/unique_ptr.hpp"
#include "communication/rpc/client.hpp"
namespace communication::rpc {
Client::Client(const io::network::Endpoint &endpoint,
const std::string &service_name)
: endpoint_(endpoint), service_name_(service_name) {}
std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) {
std::lock_guard<std::mutex> guard(mutex_);
uint32_t request_id = ++next_message_id_;
// Check if the connection is broken (if we haven't used the client for a
// long time the server could have died).
if (socket_ && socket_->ErrorStatus()) {
socket_ = std::experimental::nullopt;
}
// Connect to the remote server.
if (!socket_) {
socket_.emplace();
received_bytes_ = 0;
if (!socket_->Connect(endpoint_)) {
LOG(ERROR) << "Couldn't connect to remote address: " << endpoint_;
socket_ = std::experimental::nullopt;
return nullptr;
}
socket_->SetKeepAlive();
// Send service name size.
MessageSize service_len = service_name_.size();
if (!socket_->Write(reinterpret_cast<uint8_t *>(&service_len),
sizeof(MessageSize))) {
LOG(ERROR) << "Couldn't send service name size!";
socket_ = std::experimental::nullopt;
return nullptr;
}
// Send service name.
if (!socket_->Write(service_name_)) {
LOG(ERROR) << "Couldn't send service name!";
socket_ = std::experimental::nullopt;
return nullptr;
}
}
// Send current request ID.
if (!socket_->Write(reinterpret_cast<uint8_t *>(&request_id),
sizeof(uint32_t))) {
LOG(ERROR) << "Couldn't send request ID!";
socket_ = std::experimental::nullopt;
return nullptr;
}
// Serialize and send request.
std::stringstream request_stream;
boost::archive::binary_oarchive request_archive(request_stream);
request_archive << request;
const std::string &request_buffer = request_stream.str();
MessageSize request_data_size = request_buffer.size();
int64_t request_size = sizeof(uint32_t) + request_data_size;
CHECK(request_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", request_size,
kMaxMessageSize);
if (!socket_->Write(reinterpret_cast<uint8_t *>(&request_data_size),
sizeof(MessageSize))) {
LOG(ERROR) << "Couldn't send request size!";
socket_ = std::experimental::nullopt;
return nullptr;
}
if (!socket_->Write(request_buffer)) {
LOG(INFO) << "Couldn't send request data!";
socket_ = std::experimental::nullopt;
return nullptr;
}
// Receive response.
while (true) {
auto received = socket_->Read(buffer_.data() + received_bytes_,
buffer_.size() - received_bytes_);
if (received <= 0) {
socket_ = std::experimental::nullopt;
return nullptr;
}
received_bytes_ += received;
if (received_bytes_ < sizeof(uint32_t) + sizeof(MessageSize)) continue;
uint32_t response_id = *reinterpret_cast<uint32_t *>(buffer_.data());
MessageSize response_data_size =
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
size_t response_size =
sizeof(uint32_t) + sizeof(MessageSize) + response_data_size;
if (received_bytes_ < response_size) continue;
std::stringstream response_stream;
response_stream.str(
std::string(reinterpret_cast<char *>(buffer_.data() + sizeof(uint32_t) +
sizeof(MessageSize)),
response_data_size));
boost::archive::binary_iarchive response_archive(response_stream);
std::unique_ptr<Message> response;
response_archive >> response;
std::copy(buffer_.begin() + response_size,
buffer_.begin() + received_bytes_, buffer_.begin());
received_bytes_ -= response_size;
if (response_id != request_id) {
// This can happen if some stale response arrives after we issued a new
// request.
continue;
}
return response;
}
}
} // namespace communication::rpc

View File

@ -0,0 +1,59 @@
#pragma once
#include <experimental/optional>
#include <memory>
#include <mutex>
#include <glog/logging.h>
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/socket.hpp"
namespace communication::rpc {
// Client is thread safe, but it is recommended to use thread_local clients.
class Client {
public:
Client(const io::network::Endpoint &endpoint, const std::string &name);
// Call function can initiate only one request at the time. Function blocks
// until there is a response. If there was an error nullptr is returned.
template <typename TRequestResponse, typename... Args>
std::unique_ptr<typename TRequestResponse::Response> Call(Args &&... args) {
using Req = typename TRequestResponse::Request;
using Res = typename TRequestResponse::Response;
static_assert(std::is_base_of<Message, Req>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(std::is_base_of<Message, Res>::value,
"TRequestResponse::Response must be derived from Message");
auto response = Call(std::unique_ptr<Message>(
std::make_unique<Req>(std::forward<Args>(args)...)));
auto *real_response = dynamic_cast<Res *>(response.get());
if (!real_response && response) {
// Since message_id was checked in private Call function, this means
// something is very wrong (probably on the server side).
LOG(ERROR) << "Message response was of unexpected type";
socket_ = std::experimental::nullopt;
return nullptr;
}
response.release();
return std::unique_ptr<Res>(real_response);
}
private:
std::unique_ptr<Message> Call(std::unique_ptr<Message> request);
io::network::Endpoint endpoint_;
std::string service_name_;
std::experimental::optional<io::network::Socket> socket_;
uint32_t next_message_id_{0};
std::array<uint8_t, kMaxMessageSize> buffer_;
size_t received_bytes_{0};
std::mutex mutex_;
};
} // namespace communication::rpc

View File

@ -0,0 +1,76 @@
#pragma once
#include <memory>
#include <type_traits>
#include <typeindex>
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
namespace communication::rpc {
// This buffer should be larger than the largest serialized message.
const uint64_t kMaxMessageSize = 262144;
using MessageSize = uint16_t;
/**
* Base class for messages.
*/
class Message {
public:
virtual ~Message() {}
/**
* Run-time type identification that is used for callbacks.
*
* Warning: this works because of the virtual destructor, don't remove it from
* this class
*/
std::type_index type_index() const { return typeid(*this); }
private:
friend boost::serialization::access;
template <class TArchive>
void serialize(TArchive &, unsigned int) {}
};
template <typename TRequest, typename TResponse>
struct RequestResponse {
using Request = TRequest;
using Response = TResponse;
};
} // namespace communication::rpc
// RPC Pimp
#define RPC_NO_MEMBER_MESSAGE(name) \
struct name : public communication::rpc::Message { \
name() {} \
\
private: \
friend class boost::serialization::access; \
\
template <class TArchive> \
void serialize(TArchive &ar, unsigned int) { \
ar &boost::serialization::base_object<communication::rpc::Message>( \
*this); \
} \
};
#define RPC_SINGLE_MEMBER_MESSAGE(name, type) \
struct name : public communication::rpc::Message { \
name() {} \
name(const type &member) : member(member) {} \
type member; \
\
private: \
friend class boost::serialization::access; \
\
template <class TArchive> \
void serialize(TArchive &ar, unsigned int) { \
ar &boost::serialization::base_object<communication::rpc::Message>( \
*this); \
ar &member; \
} \
};

View File

@ -0,0 +1,95 @@
#include <sstream>
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/unique_ptr.hpp"
#include "fmt/format.h"
#include "glog/logging.h"
#include "communication/rpc/messages-inl.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/protocol.hpp"
#include "communication/rpc/server.hpp"
namespace communication::rpc {
Session::Session(Socket &&socket, System &system)
: socket_(std::make_shared<Socket>(std::move(socket))), system_(system) {}
bool Session::Alive() const { return alive_; }
void Session::Execute() {
if (!handshake_done_) {
if (buffer_.size() < sizeof(MessageSize)) return;
MessageSize service_len = *reinterpret_cast<MessageSize *>(buffer_.data());
if (buffer_.size() < sizeof(MessageSize) + service_len) return;
service_name_ = std::string(
reinterpret_cast<char *>(buffer_.data() + sizeof(MessageSize)),
service_len);
buffer_.Shift(sizeof(MessageSize) + service_len);
handshake_done_ = true;
}
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize)) return;
uint32_t message_id = *reinterpret_cast<uint32_t *>(buffer_.data());
MessageSize message_len =
*reinterpret_cast<MessageSize *>(buffer_.data() + sizeof(uint32_t));
if (buffer_.size() < sizeof(uint32_t) + sizeof(MessageSize) + message_len)
return;
// TODO (mferencevic): check for exceptions
std::stringstream stream;
stream.str(
std::string(reinterpret_cast<char *>(buffer_.data() + sizeof(uint32_t) +
sizeof(MessageSize)),
message_len));
boost::archive::binary_iarchive archive(stream);
std::unique_ptr<Message> message;
archive >> message;
buffer_.Shift(sizeof(uint32_t) + sizeof(MessageSize) + message_len);
system_.AddTask(socket_, service_name_, message_id, std::move(message));
}
StreamBuffer Session::Allocate() { return buffer_.Allocate(); }
void Session::Written(size_t len) { buffer_.Written(len); }
void Session::Close() { DLOG(INFO) << "Closing session"; }
bool SendLength(Socket &socket, MessageSize length) {
return socket.Write(reinterpret_cast<uint8_t *>(&length),
sizeof(MessageSize));
}
void SendMessage(Socket &socket, uint32_t message_id,
std::unique_ptr<Message> &message) {
CHECK(message) << "Trying to send nullptr instead of message";
// Serialize and send message
std::stringstream stream;
boost::archive::binary_oarchive archive(stream);
archive << message;
const std::string &buffer = stream.str();
int64_t message_size = sizeof(MessageSize) + buffer.size();
CHECK(message_size <= kMaxMessageSize) << fmt::format(
"Trying to send message of size {}, max message size is {}", message_size,
kMaxMessageSize);
if (!socket.Write(reinterpret_cast<uint8_t *>(&message_id),
sizeof(uint32_t))) {
LOG(WARNING) << "Couldn't send message id!";
return;
}
if (!SendLength(socket, buffer.size())) {
LOG(WARNING) << "Couldn't send message size!";
return;
}
if (!socket.Write(buffer)) {
LOG(WARNING) << "Couldn't send message data!";
return;
}
}
} // namespace communication::rpc

View File

@ -2,42 +2,34 @@
#include <chrono>
#include <cstdint>
#include <memory>
#include "communication/bolt/v1/decoder/buffer.hpp"
#include "communication/messaging/local.hpp"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
/**
* @brief Protocol
*
* Has classes and functions that implement server and client sides of our
* messaging protocol.
* Has classes and functions that implement the server side of our
* RPC protocol.
*
* Message layout: SizeT channel_size, channel_size characters channel,
* SizeT message_size, message_size bytes serialized_message
* Handshake layout: MessageSize service_size, service_size characters service
*
* Message layout: uint32_t message_id, MessageSize message_size,
* message_size bytes serialized_message
*/
namespace communication::messaging {
class Message;
namespace communication::rpc {
using Endpoint = io::network::Endpoint;
using Socket = io::network::Socket;
using StreamBuffer = io::network::StreamBuffer;
// This buffer should be larger than the largest serialized message.
const int64_t kMaxMessageSize = 262144;
using Buffer = bolt::Buffer<kMaxMessageSize>;
using SizeT = uint16_t;
/**
* Distributed Protocol Data
*/
struct SessionData {
LocalSystem system;
};
// Forward declaration of class System
class System;
/**
* Distributed Protocol Session
@ -46,9 +38,9 @@ struct SessionData {
*/
class Session {
public:
Session(Socket &&socket, SessionData &data);
Session(Socket &&socket, System &system);
int Id() const { return socket_.fd(); }
int Id() const { return socket_->fd(); }
/**
* Returns the protocol alive state
@ -85,22 +77,31 @@ class Session {
*/
void Close();
Socket socket_;
LocalSystem &system_;
Socket &socket() { return *socket_; }
std::chrono::time_point<std::chrono::steady_clock> last_event_time_;
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
private:
SizeT GetLength(int offset = 0);
std::string GetStringAndShift(SizeT len);
std::shared_ptr<Socket> socket_;
std::chrono::time_point<std::chrono::steady_clock> last_event_time_ =
std::chrono::steady_clock::now();
System &system_;
std::string service_name_;
bool handshake_done_{false};
bool alive_{true};
Buffer buffer_;
};
/**
* Distributed Protocol Send Message
* Distributed Protocol Server Send Message
*/
void SendMessage(const Endpoint &endpoint, const std::string &channel,
std::unique_ptr<Message> message);
} // namespace communication::messaging
void SendMessage(Socket &socket, uint32_t message_id,
std::unique_ptr<Message> &message);
} // namespace communication::rpc

View File

@ -1,147 +0,0 @@
#include <iterator>
#include <random>
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "boost/serialization/export.hpp"
#include "boost/serialization/unique_ptr.hpp"
#include "communication/rpc/rpc.hpp"
#include "io/network/endpoint.hpp"
#include "utils/string.hpp"
namespace communication::rpc {
const char kProtocolStreamPrefix[] = "rpc-";
using Endpoint = io::network::Endpoint;
class Request : public messaging::Message {
public:
Request(const Endpoint &endpoint, const std::string &stream,
std::unique_ptr<Message> message)
: endpoint_(endpoint),
stream_(stream),
message_id_(utils::RandomString(20)),
message_(std::move(message)) {}
const Endpoint &endpoint() const { return endpoint_; }
const std::string &stream() const { return stream_; }
const std::string &message_id() const { return message_id_; }
const messaging::Message &message() const { return *message_; }
private:
friend class boost::serialization::access;
Request() {} // Needed for serialization.
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<messaging::Message>(*this);
ar &endpoint_;
ar &stream_;
ar &message_id_;
ar &message_;
}
io::network::Endpoint endpoint_;
std::string stream_;
std::string message_id_;
std::unique_ptr<messaging::Message> message_;
};
class Response : public messaging::Message {
public:
explicit Response(const std::string &message_id,
std::unique_ptr<messaging::Message> message)
: message_id_(message_id), message_(std::move(message)) {}
const auto &message_id() const { return message_id_; }
auto &message() { return message_; }
private:
friend class boost::serialization::access;
Response() {} // Needed for serialization.
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<Message>(*this);
ar &message_id_;
ar &message_;
}
std::string message_id_;
std::unique_ptr<messaging::Message> message_;
};
Client::Client(messaging::System &system, const io::network::Endpoint &endpoint,
const std::string &name)
: system_(system),
writer_(system, endpoint, kProtocolStreamPrefix + name),
stream_(system.Open(utils::RandomString(20))) {}
// Because of the way Call is implemented it can fail without reporting (it will
// just block indefinately). This is why you always need to provide reasonable
// timeout when calling it.
// TODO: Make Call use same connection for request and respone and as soon as
// connection drop return nullptr.
std::unique_ptr<messaging::Message> Client::Call(
std::chrono::system_clock::duration timeout,
std::unique_ptr<messaging::Message> message) {
auto request = std::make_unique<Request>(system_.endpoint(), stream_->name(),
std::move(message));
auto message_id = request->message_id();
writer_.Send(std::move(request));
auto now = std::chrono::system_clock::now();
auto until = now + timeout;
while (true) {
auto message = stream_->Await(until - std::chrono::system_clock::now());
if (!message) break; // Client was either signaled or timeout was reached.
auto *response = dynamic_cast<Response *>(message.get());
if (!response) {
LOG(ERROR) << "Message received by rpc client is not a response";
continue;
}
if (response->message_id() != message_id) {
// This can happen if some stale response arrives after we issued a new
// request.
continue;
}
return std::move(response->message());
}
return nullptr;
}
Server::Server(messaging::System &system, const std::string &name)
: system_(system), stream_(system.Open(kProtocolStreamPrefix + name)) {
// TODO: Add logging.
running_thread_ = std::thread([this]() {
while (alive_) {
auto message = stream_->Await();
if (!message) continue;
auto *request = dynamic_cast<Request *>(message.get());
if (!request) continue;
auto &real_request = request->message();
auto callbacks_accessor = callbacks_.access();
auto it = callbacks_accessor.find(real_request.type_index());
if (it == callbacks_accessor.end()) continue;
auto response = it->second(real_request);
messaging::Writer writer(system_, request->endpoint(), request->stream());
writer.Send<Response>(request->message_id(), std::move(response));
}
});
}
Server::~Server() {
alive_ = false;
stream_->Shutdown();
if (running_thread_.joinable()) running_thread_.join();
}
} // namespace communication::rpc
BOOST_CLASS_EXPORT(communication::rpc::Request);
BOOST_CLASS_EXPORT(communication::rpc::Response);

View File

@ -1,100 +0,0 @@
#pragma once
#include <type_traits>
#include "communication/messaging/distributed.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "io/network/endpoint.hpp"
namespace communication::rpc {
template <typename TRequest, typename TResponse>
struct RequestResponse {
using Request = TRequest;
using Response = TResponse;
};
// Client is thread safe.
class Client {
public:
Client(messaging::System &system, const io::network::Endpoint &endpoint,
const std::string &name);
// Call function can initiate only one request at the time. Function blocks
// until there is a response or timeout was reached. If timeout was reached
// nullptr is returned.
template <typename TRequestResponse, typename... Args>
std::unique_ptr<typename TRequestResponse::Response> Call(
std::chrono::system_clock::duration timeout, Args &&... args) {
using Req = typename TRequestResponse::Request;
using Res = typename TRequestResponse::Response;
static_assert(std::is_base_of<messaging::Message, Req>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(std::is_base_of<messaging::Message, Res>::value,
"TRequestResponse::Response must be derived from Message");
std::lock_guard<std::mutex> lock(lock_);
auto response =
Call(timeout, std::unique_ptr<messaging::Message>(
std::make_unique<Req>(std::forward<Args>(args)...)));
auto *real_response = dynamic_cast<Res *>(response.get());
if (!real_response && response) {
LOG(ERROR) << "Message response was of unexpected type";
return nullptr;
}
response.release();
return std::unique_ptr<Res>(real_response);
}
private:
std::unique_ptr<messaging::Message> Call(
std::chrono::system_clock::duration timeout,
std::unique_ptr<messaging::Message> message);
messaging::System &system_;
messaging::Writer writer_;
std::shared_ptr<messaging::EventStream> stream_;
std::mutex lock_;
};
class Server {
public:
Server(messaging::System &system, const std::string &name);
~Server();
template <typename TRequestResponse>
void Register(
std::function<std::unique_ptr<typename TRequestResponse::Response>(
const typename TRequestResponse::Request &)>
callback) {
static_assert(std::is_base_of<messaging::Message,
typename TRequestResponse::Request>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(std::is_base_of<messaging::Message,
typename TRequestResponse::Response>::value,
"TRequestResponse::Response must be derived from Message");
auto callbacks_accessor = callbacks_.access();
auto got = callbacks_accessor.insert(
typeid(typename TRequestResponse::Request), [callback = callback](
const messaging::Message
&base_message) {
const auto &message =
dynamic_cast<const typename TRequestResponse::Request &>(
base_message);
return callback(message);
});
CHECK(got.second) << "Callback for that message type already registered";
}
private:
messaging::System &system_;
std::shared_ptr<messaging::EventStream> stream_;
ConcurrentMap<std::type_index,
std::function<std::unique_ptr<messaging::Message>(
const messaging::Message &)>>
callbacks_;
std::atomic<bool> alive_{true};
std::thread running_thread_;
};
} // namespace communication::rpc

View File

@ -0,0 +1,71 @@
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "boost/serialization/export.hpp"
#include "boost/serialization/unique_ptr.hpp"
#include "communication/rpc/server.hpp"
namespace communication::rpc {
System::System(const io::network::Endpoint &endpoint, const size_t worker_count)
: server_(endpoint, *this, worker_count) {}
System::~System() {}
void System::AddTask(std::shared_ptr<Socket> socket, const std::string &service,
uint64_t message_id, std::unique_ptr<Message> message) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = services_.find(service);
if (it == services_.end()) return;
it->second->queue_.Emplace(std::move(socket), message_id, std::move(message));
}
void System::Add(Server &server) {
std::unique_lock<std::mutex> guard(mutex_);
auto got = services_.emplace(server.service_name(), &server);
CHECK(got.second) << fmt::format("Server with name {} already exists",
server.service_name());
}
void System::Remove(const Server &server) {
std::unique_lock<std::mutex> guard(mutex_);
auto it = services_.find(server.service_name());
CHECK(it != services_.end()) << "Trying to delete nonexisting server";
services_.erase(it);
}
Server::Server(System &system, const std::string &service_name,
int workers_count)
: system_(system), service_name_(service_name) {
system_.Add(*this);
for (int i = 0; i < workers_count; ++i) {
threads_.push_back(std::thread([this]() {
// TODO: Add logging.
while (alive_) {
auto task = queue_.AwaitPop();
if (!task) continue;
auto socket = std::move(std::get<0>(*task));
auto message_id = std::get<1>(*task);
auto message = std::move(std::get<2>(*task));
auto callbacks_accessor = callbacks_.access();
auto it = callbacks_accessor.find(message->type_index());
if (it == callbacks_accessor.end()) continue;
auto response = it->second(*(message.get()));
SendMessage(*socket, message_id, response);
}
}));
}
}
Server::~Server() {
alive_.store(false);
queue_.Shutdown();
for (auto &thread : threads_) {
if (thread.joinable()) thread.join();
}
system_.Remove(*this);
}
} // namespace communication::rpc

View File

@ -0,0 +1,94 @@
#pragma once
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "communication/rpc/messages.hpp"
#include "communication/rpc/protocol.hpp"
#include "communication/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "data_structures/queue.hpp"
#include "io/network/endpoint.hpp"
namespace communication::rpc {
// Forward declaration of Server class
class Server;
class System {
public:
System(const io::network::Endpoint &endpoint, const size_t worker_count = 4);
System(const System &) = delete;
System(System &&) = delete;
System &operator=(const System &) = delete;
System &operator=(System &&) = delete;
~System();
const io::network::Endpoint &endpoint() const { return server_.endpoint(); }
private:
using ServerT = communication::Server<Session, System>;
friend class Session;
friend class Server;
/** Start a threadpool that relays the messages from the sockets to the
* LocalEventStreams */
void StartServer(int workers_count);
void AddTask(std::shared_ptr<Socket> socket, const std::string &service,
uint64_t message_id, std::unique_ptr<Message> message);
void Add(Server &server);
void Remove(const Server &server);
std::mutex mutex_;
// Service name to its server mapping.
std::unordered_map<std::string, Server *> services_;
ServerT server_;
};
class Server {
public:
Server(System &system, const std::string &name, int workers_count = 4);
~Server();
template <typename TRequestResponse>
void Register(
std::function<std::unique_ptr<typename TRequestResponse::Response>(
const typename TRequestResponse::Request &)>
callback) {
static_assert(
std::is_base_of<Message, typename TRequestResponse::Request>::value,
"TRequestResponse::Request must be derived from Message");
static_assert(
std::is_base_of<Message, typename TRequestResponse::Response>::value,
"TRequestResponse::Response must be derived from Message");
auto callbacks_accessor = callbacks_.access();
auto got = callbacks_accessor.insert(
typeid(typename TRequestResponse::Request), [callback = callback](
const Message
&base_message) {
const auto &message =
dynamic_cast<const typename TRequestResponse::Request &>(
base_message);
return callback(message);
});
CHECK(got.second) << "Callback for that message type already registered";
}
const std::string &service_name() const { return service_name_; }
private:
friend class System;
System &system_;
Queue<std::tuple<std::shared_ptr<Socket>, uint64_t, std::unique_ptr<Message>>>
queue_;
std::string service_name_;
ConcurrentMap<std::type_index,
std::function<std::unique_ptr<Message>(const Message &)>>
callbacks_;
std::atomic<bool> alive_{true};
std::vector<std::thread> threads_;
};
} // namespace communication::rpc

View File

@ -45,8 +45,8 @@ class Server {
* Constructs and binds server to endpoint, operates on session data and
* invokes n workers
*/
Server(const io::network::Endpoint &endpoint,
TSessionData &session_data, size_t n)
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
size_t n)
: session_data_(session_data) {
// Without server we can't continue with application so we can just
// terminate here.
@ -107,10 +107,10 @@ class Server {
}
private:
class ConnectionAcceptor : public io::network::BaseListener {
class ConnectionAcceptor {
public:
ConnectionAcceptor(Socket &socket, Server<TSession, TSessionData> &server)
: io::network::BaseListener(socket), server_(server) {}
: socket_(socket), server_(server) {}
void OnData() {
DCHECK(server_.idx_ < server_.workers_.size()) << "Invalid worker id.";
@ -124,6 +124,15 @@ class Server {
server_.idx_ = (server_.idx_ + 1) % server_.workers_.size();
}
void OnClose() { socket_.Close(); }
void OnException(const std::exception &e) {
LOG(FATAL) << "Exception was thrown while processing event on socket "
<< socket_.fd() << " with message: " << e.what();
}
void OnError() { LOG(FATAL) << "Error on server side occured in epoll"; }
private:
// Accepts connection on socket_ and configures new connections. If done
// successfuly new socket (connection) is returner, nullopt otherwise.
@ -145,6 +154,7 @@ class Server {
return s;
}
Socket &socket_;
Server<TSession, TSessionData> &server_;
};

View File

@ -84,25 +84,23 @@ class Worker {
private:
// TODO: Think about ownership. Who should own socket session,
// SessionSocketListener or Worker?
class SessionSocketListener : public io::network::BaseListener {
class SessionSocketListener {
public:
SessionSocketListener(Socket &&socket,
Worker<TSession, TSessionData> &worker)
: BaseListener(session_.socket_),
session_(std::move(socket), worker.session_data_),
worker_(worker) {}
: session_(std::move(socket), worker.session_data_), worker_(worker) {}
auto &session() { return session_; }
const auto &session() const { return session_; }
const auto &TimedOut() const { return session_.TimedOut(); }
void OnData() {
session_.last_event_time_ = std::chrono::steady_clock::now();
session_.RefreshLastEventTime(std::chrono::steady_clock::now());
DLOG(INFO) << "On data";
// allocate the buffer to fill the data
auto buf = session_.Allocate();
// read from the buffer at most buf.len bytes
int len = session_.socket_.Read(buf.data, buf.len);
int len = session_.socket().Read(buf.data, buf.len);
// check for read errors
if (len == -1) {
@ -131,15 +129,15 @@ class Worker {
<< e.what();
OnError();
}
session_.last_event_time_ = std::chrono::steady_clock::now();
session_.RefreshLastEventTime(std::chrono::steady_clock::now());
}
// TODO: Remove duplication in next three functions.
void OnError() {
LOG(ERROR) << fmt::format(
"Error occured in session associated with {}:{}",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port());
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
CloseSession();
}
@ -147,24 +145,24 @@ class Worker {
LOG(ERROR) << fmt::format(
"Exception was thrown while processing event in session associated "
"with {}:{} with message: {}",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port(), e.what());
session_.socket().endpoint().address(),
session_.socket().endpoint().port(), e.what());
CloseSession();
}
void OnSessionAndTxTimeout() {
LOG(WARNING) << fmt::format(
"Session or transaction associated with {}:{} timed out.",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port());
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
// TODO: report to client what happend.
CloseSession();
}
void OnClose() {
LOG(INFO) << fmt::format("Client {}:{} closed the connection.",
session_.socket_.endpoint().address(),
session_.socket_.endpoint().port());
session_.socket().endpoint().address(),
session_.socket().endpoint().port());
CloseSession();
}

View File

@ -8,7 +8,6 @@
namespace database {
const std::string kCountersRpc = "CountersRpc";
const auto kCountersRpcTimeout = 300ms;
RPC_SINGLE_MEMBER_MESSAGE(CountersGetReq, std::string);
RPC_SINGLE_MEMBER_MESSAGE(CountersGetRes, int64_t);
@ -33,7 +32,7 @@ void SingleNodeCounters::Set(const std::string &name, int64_t value) {
if (!name_counter_pair.second) name_counter_pair.first->second.store(value);
}
MasterCounters::MasterCounters(communication::messaging::System &system)
MasterCounters::MasterCounters(communication::rpc::System &system)
: rpc_server_(system, kCountersRpc) {
rpc_server_.Register<CountersGetRpc>([this](const CountersGetReq &req) {
return std::make_unique<CountersGetRes>(Get(req.member));
@ -44,20 +43,18 @@ MasterCounters::MasterCounters(communication::messaging::System &system)
});
}
WorkerCounters::WorkerCounters(
communication::messaging::System &system,
const io::network::Endpoint &master_endpoint)
: rpc_client_(system, master_endpoint, kCountersRpc) {}
WorkerCounters::WorkerCounters(const io::network::Endpoint &master_endpoint)
: rpc_client_(master_endpoint, kCountersRpc) {}
int64_t WorkerCounters::Get(const std::string &name) {
auto response = rpc_client_.Call<CountersGetRpc>(kCountersRpcTimeout, name);
auto response = rpc_client_.Call<CountersGetRpc>(name);
CHECK(response) << "CountersGetRpc - failed to get response from master";
return response->member;
}
void WorkerCounters::Set(const std::string &name, int64_t value) {
auto response = rpc_client_.Call<CountersSetRpc>(
kCountersRpcTimeout, CountersSetReqData{name, value});
auto response =
rpc_client_.Call<CountersSetRpc>(CountersSetReqData{name, value});
CHECK(response) << "CountersSetRpc - failed to get response from master";
}

View File

@ -4,10 +4,10 @@
#include <cstdint>
#include <string>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "utils/rpc_pimp.hpp"
namespace database {
@ -45,7 +45,7 @@ class SingleNodeCounters : public Counters {
/** Implementation for distributed master. */
class MasterCounters : public SingleNodeCounters {
public:
MasterCounters(communication::messaging::System &system);
MasterCounters(communication::rpc::System &system);
private:
communication::rpc::Server rpc_server_;
@ -54,8 +54,7 @@ class MasterCounters : public SingleNodeCounters {
/** Implementation for distributed worker. */
class WorkerCounters : public Counters {
public:
WorkerCounters(communication::messaging::System &system,
const io::network::Endpoint &master_endpoint);
WorkerCounters(const io::network::Endpoint &master_endpoint);
int64_t Get(const std::string &name) override;
void Set(const std::string &name, int64_t value) override;

View File

@ -1,6 +1,6 @@
#include "glog/logging.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/server.hpp"
#include "database/graph_db.hpp"
#include "distributed/coordination_master.hpp"
#include "distributed/coordination_worker.hpp"
@ -114,16 +114,15 @@ class Master : public PrivateBase {
LOG(FATAL) << "Plan Consumer not available in single-node.";
}
communication::messaging::System system_{config_.master_endpoint};
communication::rpc::System system_{config_.master_endpoint};
tx::MasterEngine tx_engine_{system_, &wal_};
StorageGc storage_gc_{storage_, tx_engine_, config_.gc_cycle_sec};
distributed::MasterCoordination coordination_{system_};
TypemapPack<MasterConcurrentIdMapper> typemap_pack_{system_};
database::MasterCounters counters_{system_};
distributed::RemoteDataRpcServer remote_data_server_{*this, system_};
distributed::RemoteDataRpcClients remote_data_clients_{system_,
coordination_};
distributed::PlanDispatcher plan_dispatcher_{system_, coordination_};
distributed::RemoteDataRpcClients remote_data_clients_{coordination_};
distributed::PlanDispatcher plan_dispatcher_{coordination_};
};
class Worker : public PrivateBase {
@ -142,17 +141,15 @@ class Worker : public PrivateBase {
LOG(FATAL) << "Plan Dispatcher not available in single-node.";
}
communication::messaging::System system_{config_.worker_endpoint};
communication::rpc::System system_{config_.worker_endpoint};
distributed::WorkerCoordination coordination_{system_,
config_.master_endpoint};
tx::WorkerEngine tx_engine_{system_, config_.master_endpoint};
tx::WorkerEngine tx_engine_{config_.master_endpoint};
StorageGc storage_gc_{storage_, tx_engine_, config_.gc_cycle_sec};
TypemapPack<WorkerConcurrentIdMapper> typemap_pack_{system_,
config_.master_endpoint};
database::WorkerCounters counters_{system_, config_.master_endpoint};
TypemapPack<WorkerConcurrentIdMapper> typemap_pack_{config_.master_endpoint};
database::WorkerCounters counters_{config_.master_endpoint};
distributed::RemoteDataRpcServer remote_data_server_{*this, system_};
distributed::RemoteDataRpcClients remote_data_clients_{system_,
coordination_};
distributed::RemoteDataRpcClients remote_data_clients_{coordination_};
distributed::PlanConsumer plan_consumer_{system_};
};

View File

@ -3,8 +3,8 @@
namespace distributed {
MasterCoordination::MasterCoordination(communication::messaging::System &system)
: system_(system), server_(system, kCoordinationServerName) {
MasterCoordination::MasterCoordination(communication::rpc::System &system)
: server_(system, kCoordinationServerName) {
// The master is always worker 0.
workers_.emplace(0, system.endpoint());
@ -43,9 +43,8 @@ MasterCoordination::~MasterCoordination() {
for (const auto &kv : workers_) {
// Skip master (self).
if (kv.first == 0) continue;
communication::rpc::Client client(system_, kv.second,
kCoordinationServerName);
auto result = client.Call<StopWorkerRpc>(100ms);
communication::rpc::Client client(kv.second, kCoordinationServerName);
auto result = client.Call<StopWorkerRpc>();
CHECK(result) << "Failed to shut down worker: " << kv.first;
}
}

View File

@ -3,8 +3,8 @@
#include <mutex>
#include <unordered_map>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp"
#include "distributed/coordination.hpp"
#include "io/network/endpoint.hpp"
@ -27,7 +27,7 @@ class MasterCoordination : public Coordination {
int RegisterWorker(int desired_worker_id, Endpoint endpoint);
public:
explicit MasterCoordination(communication::messaging::System &system);
explicit MasterCoordination(communication::rpc::System &system);
/** Shuts down all the workers and this master server. */
~MasterCoordination();
@ -39,7 +39,6 @@ class MasterCoordination : public Coordination {
std::vector<int> GetWorkerIds() override;
private:
communication::messaging::System &system_;
communication::rpc::Server server_;
// Most master functions aren't thread-safe.
mutable std::mutex lock_;

View File

@ -3,16 +3,14 @@
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "communication/messaging/local.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "io/network/endpoint.hpp"
#include "utils/rpc_pimp.hpp"
namespace distributed {
const std::string kCoordinationServerName = "CoordinationRpc";
using communication::messaging::Message;
using communication::rpc::Message;
using Endpoint = io::network::Endpoint;
struct RegisterWorkerReq : public Message {

View File

@ -7,15 +7,17 @@
namespace distributed {
WorkerCoordination::WorkerCoordination(communication::messaging::System &system,
using namespace std::literals::chrono_literals;
WorkerCoordination::WorkerCoordination(communication::rpc::System &system,
const Endpoint &master_endpoint)
: system_(system),
client_(system_, master_endpoint, kCoordinationServerName),
client_(master_endpoint, kCoordinationServerName),
server_(system_, kCoordinationServerName) {}
int WorkerCoordination::RegisterWorker(int desired_worker_id) {
auto result = client_.Call<RegisterWorkerRpc>(300ms, desired_worker_id,
system_.endpoint());
auto result =
client_.Call<RegisterWorkerRpc>(desired_worker_id, system_.endpoint());
CHECK(result) << "Failed to RegisterWorker with the master";
return result->member;
}
@ -24,7 +26,7 @@ Endpoint WorkerCoordination::GetEndpoint(int worker_id) {
auto accessor = endpoint_cache_.access();
auto found = accessor.find(worker_id);
if (found != accessor.end()) return found->second;
auto result = client_.Call<GetEndpointRpc>(300ms, worker_id);
auto result = client_.Call<GetEndpointRpc>(worker_id);
CHECK(result) << "Failed to GetEndpoint from the master";
accessor.insert(worker_id, result->member);
return result->member;

View File

@ -1,5 +1,7 @@
#pragma once
#include "communication/rpc/client.hpp"
#include "communication/rpc/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "distributed/coordination.hpp"
#include "distributed/coordination_rpc_messages.hpp"
@ -12,7 +14,7 @@ class WorkerCoordination : public Coordination {
using Endpoint = io::network::Endpoint;
public:
WorkerCoordination(communication::messaging::System &system,
WorkerCoordination(communication::rpc::System &system,
const Endpoint &master_endpoint);
/**
@ -35,7 +37,7 @@ class WorkerCoordination : public Coordination {
void WaitForShutdown();
private:
communication::messaging::System &system_;
communication::rpc::System &system_;
communication::rpc::Client client_;
communication::rpc::Server server_;
mutable ConcurrentMap<int, Endpoint> endpoint_cache_;

View File

@ -2,7 +2,7 @@
namespace distributed {
PlanConsumer::PlanConsumer(communication::messaging::System &system)
PlanConsumer::PlanConsumer(communication::rpc::System &system)
: server_(system, kDistributedPlanServerName) {
server_.Register<DistributedPlanRpc>([this](const DispatchPlanReq &req) {
plan_cache_.access().insert(req.plan_id_,

View File

@ -1,5 +1,6 @@
#pragma once
#include "communication/rpc/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "distributed/plan_rpc_messages.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
@ -12,7 +13,7 @@ namespace distributed {
*/
class PlanConsumer {
public:
explicit PlanConsumer(communication::messaging::System &system);
explicit PlanConsumer(communication::rpc::System &system);
/**
* Return cached plan and symbol table for a given plan id.

View File

@ -2,9 +2,8 @@
namespace distributed {
PlanDispatcher::PlanDispatcher(communication::messaging::System &system,
Coordination &coordination)
: clients_(system, coordination, kDistributedPlanServerName) {}
PlanDispatcher::PlanDispatcher(Coordination &coordination)
: clients_(coordination, kDistributedPlanServerName) {}
void PlanDispatcher::DispatchPlan(
int64_t plan_id, std::shared_ptr<query::plan::LogicalOperator> plan,
@ -12,7 +11,7 @@ void PlanDispatcher::DispatchPlan(
auto futures = clients_.ExecuteOnWorkers<void>(
0, [plan_id, &plan, &symbol_table](communication::rpc::Client &client) {
auto result =
client.Call<DistributedPlanRpc>(300ms, plan_id, plan, symbol_table);
client.Call<DistributedPlanRpc>(plan_id, plan, symbol_table);
CHECK(result) << "Failed to dispatch plan to worker";
});

View File

@ -1,6 +1,5 @@
#pragma once
#include "communication/rpc/rpc.hpp"
#include "distributed/coordination.hpp"
#include "distributed/plan_rpc_messages.hpp"
#include "distributed/rpc_worker_clients.hpp"
@ -14,8 +13,7 @@ namespace distributed {
*/
class PlanDispatcher {
public:
explicit PlanDispatcher(communication::messaging::System &system,
Coordination &coordination);
explicit PlanDispatcher(Coordination &coordination);
/**
* Synchronously dispatch a plan to all workers and wait for their

View File

@ -3,18 +3,16 @@
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "communication/messaging/local.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/plan/operator.hpp"
#include "utils/rpc_pimp.hpp"
namespace distributed {
const std::string kDistributedPlanServerName = "DistributedPlanRpc";
using communication::messaging::Message;
using communication::rpc::Message;
using SymbolTable = query::SymbolTable;
using AstTreeStorage = query::AstTreeStorage;

View File

@ -3,7 +3,7 @@
#include <mutex>
#include <utility>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/client.hpp"
#include "distributed/coordination.hpp"
#include "distributed/remote_data_rpc_messages.hpp"
#include "distributed/rpc_worker_clients.hpp"
@ -17,9 +17,8 @@ class RemoteDataRpcClients {
using Client = communication::rpc::Client;
public:
RemoteDataRpcClients(communication::messaging::System &system,
Coordination &coordination)
: clients_(system, coordination, kRemoteDataRpcName) {}
RemoteDataRpcClients(Coordination &coordination)
: clients_(coordination, kRemoteDataRpcName) {}
/// Returns a remote worker's data for the given params. That worker must own
/// the vertex for the given id, and that vertex must be visible in given
@ -28,7 +27,7 @@ class RemoteDataRpcClients {
tx::transaction_id_t tx_id,
gid::Gid gid) {
auto response = clients_.GetClient(worker_id).Call<RemoteVertexRpc>(
kRemoteDataRpcTimeout, TxGidPair{tx_id, gid});
TxGidPair{tx_id, gid});
return std::move(response->name_output_);
}
@ -38,7 +37,7 @@ class RemoteDataRpcClients {
std::unique_ptr<Edge> RemoteEdge(int worker_id, tx::transaction_id_t tx_id,
gid::Gid gid) {
auto response = clients_.GetClient(worker_id).Call<RemoteEdgeRpc>(
kRemoteDataRpcTimeout, TxGidPair{tx_id, gid});
TxGidPair{tx_id, gid});
return std::move(response->name_output_);
}

View File

@ -3,18 +3,15 @@
#include <memory>
#include <string>
#include "communication/messaging/local.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "distributed/serialization.hpp"
#include "storage/edge.hpp"
#include "storage/gid.hpp"
#include "storage/vertex.hpp"
#include "transactions/type.hpp"
#include "utils/rpc_pimp.hpp"
namespace distributed {
const std::string kRemoteDataRpcName = "RemoteDataRpc";
const auto kRemoteDataRpcTimeout = 100ms;
struct TxGidPair {
tx::transaction_id_t tx_id;
@ -31,7 +28,7 @@ struct TxGidPair {
};
#define MAKE_RESPONSE(type, name) \
class Remote##type##Res : public communication::messaging::Message { \
class Remote##type##Res : public communication::rpc::Message { \
public: \
Remote##type##Res() {} \
Remote##type##Res(const type *name, int worker_id) \
@ -40,14 +37,14 @@ struct TxGidPair {
template <class TArchive> \
void save(TArchive &ar, unsigned int) const { \
ar << boost::serialization::base_object< \
const communication::messaging::Message>(*this); \
const communication::rpc::Message>(*this); \
Save##type(ar, *name_input_, worker_id_); \
} \
\
template <class TArchive> \
void load(TArchive &ar, unsigned int) { \
ar >> boost::serialization::base_object< \
communication::messaging::Message>(*this); \
ar >> boost::serialization::base_object<communication::rpc::Message>( \
*this); \
auto v = Load##type(ar); \
v.swap(name_output_); \
} \

View File

@ -2,8 +2,7 @@
#include <memory>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/server.hpp"
#include "database/graph_db.hpp"
#include "database/graph_db_accessor.hpp"
#include "distributed/remote_data_rpc_messages.hpp"
@ -18,7 +17,7 @@ class RemoteDataRpcServer {
// invalidation.
public:
RemoteDataRpcServer(database::GraphDb &db,
communication::messaging::System &system)
communication::rpc::System &system)
: db_(db), system_(system) {
rpc_server_.Register<RemoteVertexRpc>([this](const RemoteVertexReq &req) {
database::GraphDbAccessor dba(db_, req.member.tx_id);
@ -38,7 +37,7 @@ class RemoteDataRpcServer {
private:
database::GraphDb &db_;
communication::messaging::System &system_;
communication::rpc::System &system_;
communication::rpc::Server rpc_server_{system_, kRemoteDataRpcName};
};
} // namespace distributed

View File

@ -4,8 +4,7 @@
#include <future>
#include <type_traits>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "distributed/coordination.hpp"
namespace distributed {
@ -14,12 +13,9 @@ namespace distributed {
* Thread safe. */
class RpcWorkerClients {
public:
RpcWorkerClients(communication::messaging::System &system,
Coordination &coordination,
RpcWorkerClients(Coordination &coordination,
const std::string &rpc_client_name)
: system_(system),
coordination_(coordination),
rpc_client_name_(rpc_client_name) {}
: coordination_(coordination), rpc_client_name_(rpc_client_name) {}
RpcWorkerClients(const RpcWorkerClients &) = delete;
RpcWorkerClients(RpcWorkerClients &&) = delete;
@ -31,9 +27,8 @@ class RpcWorkerClients {
auto found = clients_.find(worker_id);
if (found != clients_.end()) return found->second;
return clients_
.emplace(
std::piecewise_construct, std::forward_as_tuple(worker_id),
std::forward_as_tuple(system_, coordination_.GetEndpoint(worker_id),
.emplace(std::piecewise_construct, std::forward_as_tuple(worker_id),
std::forward_as_tuple(coordination_.GetEndpoint(worker_id),
rpc_client_name_))
.first->second;
}
@ -66,7 +61,6 @@ class RpcWorkerClients {
}
private:
communication::messaging::System &system_;
// TODO make Coordination const, it's member GetEndpoint must be const too.
Coordination &coordination_;
const std::string rpc_client_name_;

View File

@ -153,6 +153,14 @@ void Socket::SetTimeout(long sec, long usec) {
<< "Can't set socket timeout";
}
int Socket::ErrorStatus() const {
int optval;
socklen_t optlen = sizeof(optval);
auto status = getsockopt(socket_, SOL_SOCKET, SO_ERROR, &optval, &optlen);
CHECK(!status) << "getsockopt failed";
return optval;
}
bool Socket::Listen(int backlog) { return listen(socket_, backlog) == 0; }
std::experimental::optional<Socket> Socket::Accept() {

View File

@ -104,6 +104,11 @@ class Socket {
*/
void SetTimeout(long sec, long usec);
/**
* Checks if there are any errors on a socket. Returns 0 if there are none.
*/
int ErrorStatus() const;
/**
* Returns the socket file descriptor.
*/
@ -157,8 +162,7 @@ class Socket {
int Read(void *buffer, size_t len);
private:
Socket(int fd, const Endpoint &endpoint)
: socket_(fd), endpoint_(endpoint) {}
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}
int socket_ = -1;
Endpoint endpoint_;

View File

@ -81,28 +81,4 @@ class SocketEventDispatcher {
Epoll::Event events_[kMaxEvents];
};
/**
* Implements Listener concept, suitable for inheritance.
*/
class BaseListener {
public:
explicit BaseListener(Socket &socket) : socket_(socket) {}
void OnClose() { socket_.Close(); }
// If server is listening on socket and there is incoming connection OnData
// event will be triggered.
void OnData() {}
void OnException(const std::exception &e) {
LOG(FATAL) << "Exception was thrown while processing event on socket "
<< socket_.fd() << " with message: " << e.what();
}
void OnError() { LOG(FATAL) << "Error on server side occured in epoll"; }
protected:
Socket &socket_;
};
} // namespace io::network

View File

@ -6,7 +6,6 @@
#include <glog/logging.h>
#include "communication/bolt/v1/session.hpp"
#include "communication/messaging/distributed.hpp"
#include "communication/server.hpp"
#include "config.hpp"
#include "database/graph_db.hpp"

View File

@ -31,7 +31,7 @@ ID_VALUE_RPC_CALLS(Property)
template <typename TId>
MasterConcurrentIdMapper<TId>::MasterConcurrentIdMapper(
communication::messaging::System &system)
communication::rpc::System &system)
// We have to make sure our rpc server name is unique with regards to type.
// Otherwise we will try to reuse the same rpc server name for different
// types (Label/EdgeType/Property)

View File

@ -2,8 +2,7 @@
#include <experimental/optional>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "storage/concurrent_id_mapper_single_node.hpp"
@ -13,7 +12,7 @@ namespace storage {
template <typename TId>
class MasterConcurrentIdMapper : public SingleNodeConcurrentIdMapper<TId> {
public:
explicit MasterConcurrentIdMapper(communication::messaging::System &system);
explicit MasterConcurrentIdMapper(communication::rpc::System &system);
private:
communication::rpc::Server rpc_server_;

View File

@ -2,17 +2,15 @@
#include <chrono>
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "storage/types.hpp"
#include "transactions/commit_log.hpp"
#include "transactions/snapshot.hpp"
#include "transactions/type.hpp"
#include "utils/rpc_pimp.hpp"
namespace storage {
const std::string kConcurrentIdMapperRpc = "ConcurrentIdMapper";
const auto kConcurrentIdMapperRpcTimeout = 300ms;
namespace impl {

View File

@ -10,16 +10,14 @@ namespace storage {
template <> \
type WorkerConcurrentIdMapper<type>::RpcValueToId( \
const std::string &value) { \
auto response = \
rpc_client_.Call<type##IdRpc>(kConcurrentIdMapperRpcTimeout, value); \
auto response = rpc_client_.Call<type##IdRpc>(value); \
CHECK(response) << ("Failed to obtain " #type " from master"); \
return response->member; \
} \
\
template <> \
std::string WorkerConcurrentIdMapper<type>::RpcIdToValue(type id) { \
auto response = \
rpc_client_.Call<Id##type##Rpc>(kConcurrentIdMapperRpcTimeout, id); \
auto response = rpc_client_.Call<Id##type##Rpc>(id); \
CHECK(response) << ("Failed to obtain " #type " value from master"); \
return response->member; \
}
@ -33,10 +31,8 @@ ID_VALUE_RPC_CALLS(Property)
template <typename TId>
WorkerConcurrentIdMapper<TId>::WorkerConcurrentIdMapper(
communication::messaging::System &system,
const io::network::Endpoint &master_endpoint)
: rpc_client_(system, master_endpoint, impl::RpcServerNameFromType<TId>()) {
}
: rpc_client_(master_endpoint, impl::RpcServerNameFromType<TId>()) {}
template <typename TId>
TId WorkerConcurrentIdMapper<TId>::value_to_id(const std::string &value) {

View File

@ -1,7 +1,6 @@
#pragma once
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "io/network/endpoint.hpp"
#include "storage/concurrent_id_mapper.hpp"
@ -18,8 +17,7 @@ class WorkerConcurrentIdMapper : public ConcurrentIdMapper<TId> {
std::string RpcIdToValue(TId id);
public:
WorkerConcurrentIdMapper(communication::messaging::System &system,
const io::network::Endpoint &master_endpoint);
WorkerConcurrentIdMapper(const io::network::Endpoint &master_endpoint);
TId value_to_id(const std::string &value) override;
const std::string &id_to_value(const TId &id) override;

View File

@ -9,7 +9,7 @@
namespace tx {
MasterEngine::MasterEngine(communication::messaging::System &system,
MasterEngine::MasterEngine(communication::rpc::System &system,
durability::WriteAheadLog *wal)
: SingleNodeEngine(wal), rpc_server_(system, kTransactionEngineRpc) {
rpc_server_.Register<SnapshotRpc>([this](const SnapshotReq &req) {
@ -20,7 +20,7 @@ MasterEngine::MasterEngine(communication::messaging::System &system,
});
rpc_server_.Register<GcSnapshotRpc>(
[this](const communication::messaging::Message &) {
[this](const communication::rpc::Message &) {
return std::make_unique<SnapshotRes>(GlobalGcSnapshot());
});
@ -29,7 +29,7 @@ MasterEngine::MasterEngine(communication::messaging::System &system,
});
rpc_server_.Register<ActiveTransactionsRpc>(
[this](const communication::messaging::Message &) {
[this](const communication::rpc::Message &) {
return std::make_unique<SnapshotRes>(GlobalActiveTransactions());
});

View File

@ -1,7 +1,6 @@
#pragma once
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/server.hpp"
#include "transactions/engine_single_node.hpp"
namespace tx {
@ -14,7 +13,7 @@ class MasterEngine : public SingleNodeEngine {
* @param wal - Optional. If present, the Engine will write tx
* Begin/Commit/Abort atomically (while under lock).
*/
MasterEngine(communication::messaging::System &system,
MasterEngine(communication::rpc::System &system,
durability::WriteAheadLog *wal = nullptr);
private:

View File

@ -1,10 +1,9 @@
#pragma once
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
#include "transactions/commit_log.hpp"
#include "transactions/snapshot.hpp"
#include "transactions/type.hpp"
#include "utils/rpc_pimp.hpp"
namespace tx {

View File

@ -4,8 +4,6 @@
#include <experimental/optional>
#include <unordered_map>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "durability/wal.hpp"
#include "threading/sync/spinlock.hpp"
#include "transactions/commit_log.hpp"

View File

@ -5,13 +5,9 @@
#include "utils/atomic.hpp"
namespace tx {
namespace {
static const auto kRpcTimeout = 100ms;
} // namespace
WorkerEngine::WorkerEngine(communication::messaging::System &system,
const io::network::Endpoint &endpoint)
: rpc_client_(system, endpoint, kTransactionEngineRpc) {}
WorkerEngine::WorkerEngine(const io::network::Endpoint &endpoint)
: rpc_client_(endpoint, kTransactionEngineRpc) {}
CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const {
auto info = clog_.fetch_info(tid);
@ -20,7 +16,7 @@ CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const {
if (!(info.is_aborted() || info.is_committed())) {
// @review: this version of Call is just used because Info has no
// default constructor.
info = rpc_client_.Call<ClogInfoRpc>(kRpcTimeout, tid)->member;
info = rpc_client_.Call<ClogInfoRpc>(tid)->member;
DCHECK(info.is_committed() || info.is_aborted())
<< "It is expected that the transaction is not running anymore. This "
"function should be used only after the snapshot of the current "
@ -33,16 +29,15 @@ CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const {
}
Snapshot WorkerEngine::GlobalGcSnapshot() {
return std::move(rpc_client_.Call<GcSnapshotRpc>(kRpcTimeout)->member);
return std::move(rpc_client_.Call<GcSnapshotRpc>()->member);
}
Snapshot WorkerEngine::GlobalActiveTransactions() {
return std::move(
rpc_client_.Call<ActiveTransactionsRpc>(kRpcTimeout)->member);
return std::move(rpc_client_.Call<ActiveTransactionsRpc>()->member);
}
bool WorkerEngine::GlobalIsActive(transaction_id_t tid) const {
return rpc_client_.Call<IsActiveRpc>(kRpcTimeout, tid)->member;
return rpc_client_.Call<IsActiveRpc>(tid)->member;
}
tx::transaction_id_t WorkerEngine::LocalLast() const { return local_last_; }
@ -57,8 +52,7 @@ tx::Transaction *WorkerEngine::RunningTransaction(tx::transaction_id_t tx_id) {
auto found = accessor.find(tx_id);
if (found != accessor.end()) return found->second;
Snapshot snapshot(
std::move(rpc_client_.Call<SnapshotRpc>(kRpcTimeout, tx_id)->member));
Snapshot snapshot(std::move(rpc_client_.Call<SnapshotRpc>(tx_id)->member));
auto insertion =
accessor.insert(tx_id, new Transaction(tx_id, snapshot, *this));
utils::EnsureAtomicGe(local_last_, tx_id);

View File

@ -3,8 +3,7 @@
#include <atomic>
#include <mutex>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "io/network/endpoint.hpp"
#include "transactions/commit_log.hpp"
@ -17,8 +16,7 @@ namespace tx {
* source of truth) to obtain transactional info. Caches most info locally. */
class WorkerEngine : public Engine {
public:
WorkerEngine(communication::messaging::System &system,
const io::network::Endpoint &endpoint);
WorkerEngine(const io::network::Endpoint &endpoint);
CommitLog::Info Info(transaction_id_t tid) const override;
Snapshot GlobalGcSnapshot() override;

View File

@ -1,36 +0,0 @@
#pragma once
#include "boost/serialization/base_object.hpp"
#include "communication/messaging/local.hpp"
#define RPC_NO_MEMBER_MESSAGE(name) \
struct name : public communication::messaging::Message { \
name() {} \
\
private: \
friend class boost::serialization::access; \
\
template <class TArchive> \
void serialize(TArchive &ar, unsigned int) { \
ar &boost::serialization::base_object< \
communication::messaging::Message>(*this); \
} \
};
#define RPC_SINGLE_MEMBER_MESSAGE(name, type) \
struct name : public communication::messaging::Message { \
name() {} \
name(const type &member) : member(member) {} \
type member; \
\
private: \
friend class boost::serialization::access; \
\
template <class TArchive> \
void serialize(TArchive &ar, unsigned int) { \
ar &boost::serialization::base_object< \
communication::messaging::Message>(*this); \
ar &member; \
} \
};

View File

@ -58,6 +58,14 @@ class TestSession {
this->socket_.Close();
}
Socket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
communication::bolt::Buffer<SIZE * 2> buffer_;
Socket socket_;
io::network::Epoll::Event event_;

View File

@ -45,6 +45,14 @@ class TestSession {
void Close() { this->socket_.Close(); }
Socket &socket() { return socket_; }
void RefreshLastEventTime(
const std::chrono::time_point<std::chrono::steady_clock>
&last_event_time) {
last_event_time_ = last_event_time;
}
Socket socket_;
communication::bolt::Buffer<> buffer_;
io::network::Epoll::Event event_;

View File

@ -6,21 +6,16 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "io/network/endpoint.hpp"
#include "messages.hpp"
#include "utils/network.hpp"
#include "utils/signals/handler.hpp"
#include "utils/terminate_handler.hpp"
using communication::messaging::Message;
using communication::messaging::System;
using namespace communication::rpc;
using namespace std::literals::chrono_literals;
DEFINE_string(interface, "127.0.0.1", "Client system interface.");
DEFINE_int32(port, 8020, "Client system port.");
DEFINE_string(server_interface, "127.0.0.1",
"Server interface on which to communicate.");
DEFINE_int32(server_port, 8010, "Server port on which to communicate.");
@ -34,9 +29,7 @@ int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
// Initialize client.
System client_system(io::network::Endpoint(FLAGS_interface, FLAGS_port));
Client client(
client_system,
io::network::Endpoint(utils::ResolveHostname(FLAGS_server_interface),
FLAGS_server_port),
"main");
@ -47,7 +40,7 @@ int main(int argc, char **argv) {
// in correct order.
for (int i = 1; i <= 100; ++i) {
LOG(INFO) << fmt::format("Apennding value: {}", i);
auto result_tuple = client.Call<AppendEntry>(300ms, i);
auto result_tuple = client.Call<AppendEntry>(i);
if (!result_tuple) {
LOG(INFO) << "Request unsuccessful";
// Try to resend value

View File

@ -5,14 +5,11 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/server.hpp"
#include "messages.hpp"
#include "utils/signals/handler.hpp"
#include "utils/terminate_handler.hpp"
using communication::messaging::Message;
using communication::messaging::System;
using namespace communication::rpc;
using namespace std::literals::chrono_literals;

View File

@ -3,11 +3,10 @@
#include "boost/serialization/base_object.hpp"
#include "boost/serialization/export.hpp"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/messages.hpp"
using boost::serialization::base_object;
using communication::messaging::Message;
using communication::rpc::Message;
using namespace communication::rpc;
struct AppendEntryReq : public Message {

View File

@ -6,11 +6,12 @@
#include "boost/archive/text_oarchive.hpp"
#include "boost/serialization/export.hpp"
#include "communication/messaging/distributed.hpp"
#include "communication/raft/rpc.hpp"
#include "communication/raft/storage/file.hpp"
#include "communication/raft/test_utils.hpp"
using namespace std::literals::chrono_literals;
namespace raft = communication::raft;
using io::network::Endpoint;
@ -41,11 +42,11 @@ int main(int argc, char *argv[]) {
{"b", Endpoint("127.0.0.1", 12346)},
{"c", Endpoint("127.0.0.1", 12347)}};
communication::messaging::System my_system(directory[FLAGS_member_id]);
communication::rpc::System my_system(directory[FLAGS_member_id]);
RpcNetwork<DummyState> network(my_system, directory);
raft::SimpleFileStorage<DummyState> storage(FLAGS_log_dir);
raft::RaftConfig config{{"a", "b", "c"}, 150ms, 300ms, 70ms, 60ms, 30ms};
raft::RaftConfig config{{"a", "b", "c"}, 150ms, 300ms, 70ms, 30ms};
{
raft::RaftMember<DummyState> raft_member(network, storage, FLAGS_member_id,

View File

@ -13,10 +13,10 @@
database::SingleNode db; \
SessionData session_data{db}; \
SessionT session(std::move(socket), session_data); \
std::vector<uint8_t> &output = session.socket_.output;
std::vector<uint8_t> &output = session.socket().output;
using communication::bolt::State;
using communication::bolt::SessionData;
using communication::bolt::State;
using SessionT = communication::bolt::Session<TestSocket>;
using ResultStreamT = SessionT::ResultStreamT;
@ -89,7 +89,7 @@ void ExecuteHandshake(SessionT &session, std::vector<uint8_t> &output) {
session.Execute();
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
@ -109,7 +109,7 @@ void ExecuteCommand(SessionT &session, const uint8_t *data, size_t len,
void ExecuteInit(SessionT &session, std::vector<uint8_t> &output) {
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, init_resp, 7);
}
@ -152,7 +152,7 @@ TEST(BoltSession, HandshakeWrongPreamble) {
session.Execute();
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
PrintOutput(output);
CheckFailureMessage(output);
}
@ -166,25 +166,25 @@ TEST(BoltSession, HandshakeInTwoPackets) {
session.Execute();
ASSERT_EQ(session.state_, State::Handshake);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
memcpy(buff.data + 10, handshake_req + 10, 10);
session.Written(10);
session.Execute();
ASSERT_EQ(session.state_, State::Init);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
}
TEST(BoltSession, HandshakeWriteFail) {
INIT_VARS;
session.socket_.SetWriteSuccess(false);
session.socket().SetWriteSuccess(false);
ExecuteCommand(session, handshake_req, sizeof(handshake_req), false);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -199,7 +199,7 @@ TEST(BoltSession, InitWrongSignature) {
ExecuteCommand(session, run_req_header, sizeof(run_req_header));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -212,7 +212,7 @@ TEST(BoltSession, InitWrongMarker) {
ExecuteCommand(session, data, 2);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -227,7 +227,7 @@ TEST(BoltSession, InitMissingData) {
ExecuteCommand(session, init_req, len[i]);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -235,11 +235,11 @@ TEST(BoltSession, InitMissingData) {
TEST(BoltSession, InitWriteFail) {
INIT_VARS;
ExecuteHandshake(session, output);
session.socket_.SetWriteSuccess(false);
session.socket().SetWriteSuccess(false);
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
@ -260,7 +260,7 @@ TEST(BoltSession, ExecuteRunWrongMarker) {
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -276,7 +276,7 @@ TEST(BoltSession, ExecuteRunMissingData) {
ExecuteCommand(session, run_req_header, len[i]);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -289,17 +289,17 @@ TEST(BoltSession, ExecuteRunBasicException) {
ExecuteHandshake(session, output);
ExecuteInit(session, output);
session.socket_.SetWriteSuccess(i == 0);
session.socket().SetWriteSuccess(i == 0);
WriteRunRequest(session, "MATCH (omnom");
session.Execute();
if (i == 0) {
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
CheckFailureMessage(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -333,7 +333,7 @@ TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}
@ -346,11 +346,11 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) {
ExecuteHandshake(session, output);
ExecuteInit(session, output);
session.socket_.SetWriteSuccess(i == 0);
session.socket().SetWriteSuccess(i == 0);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
if (i == 0) {
CheckFailureMessage(output);
} else {
@ -376,17 +376,17 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) {
if (j == 1) output.clear();
session.socket_.SetWriteSuccess(j == 0);
session.socket().SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -401,7 +401,7 @@ TEST(BoltSession, ExecuteInvalidMessage) {
ExecuteCommand(session, init_req, sizeof(init_req));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -418,7 +418,7 @@ TEST(BoltSession, ErrorIgnoreMessage) {
output.clear();
session.socket_.SetWriteSuccess(i == 0);
session.socket().SetWriteSuccess(i == 0);
ExecuteCommand(session, init_req, sizeof(init_req));
// assert that all data from the init message was cleaned up
@ -426,11 +426,11 @@ TEST(BoltSession, ErrorIgnoreMessage) {
if (i == 0) {
ASSERT_EQ(session.state_, State::ErrorIdle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
CheckOutput(output, ignored_resp, sizeof(ignored_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -448,7 +448,7 @@ TEST(BoltSession, ErrorRunAfterRun) {
output.clear();
session.socket_.SetWriteSuccess(true);
session.socket().SetWriteSuccess(true);
// Session holds results of last run.
ASSERT_EQ(session.state_, State::Result);
@ -458,7 +458,7 @@ TEST(BoltSession, ErrorRunAfterRun) {
session.Execute();
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
}
TEST(BoltSession, ErrorCantCleanup) {
@ -476,7 +476,7 @@ TEST(BoltSession, ErrorCantCleanup) {
ExecuteCommand(session, init_req, sizeof(init_req) - 10);
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -496,7 +496,7 @@ TEST(BoltSession, ErrorWrongMarker) {
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -517,7 +517,7 @@ TEST(BoltSession, ErrorOK) {
output.clear();
session.socket_.SetWriteSuccess(j == 0);
session.socket().SetWriteSuccess(j == 0);
ExecuteCommand(session, dataset[i], 2);
// assert that all data from the init message was cleaned up
@ -525,11 +525,11 @@ TEST(BoltSession, ErrorOK) {
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
}
}
@ -552,7 +552,7 @@ TEST(BoltSession, ErrorMissingData) {
ExecuteCommand(session, data, sizeof(data));
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
@ -566,7 +566,7 @@ TEST(BoltSession, MultipleChunksInOneExecute) {
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
PrintOutput(output);
// Count chunks in output
@ -598,7 +598,7 @@ TEST(BoltSession, PartialChunk) {
session.Execute();
ASSERT_EQ(session.state_, State::Idle);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
ASSERT_EQ(output.size(), 0);
WriteChunkTail(session);
@ -606,7 +606,7 @@ TEST(BoltSession, PartialChunk) {
session.Execute();
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
ASSERT_GT(output.size(), 0);
PrintOutput(output);
}
@ -658,7 +658,7 @@ TEST(BoltSession, ExplicitTransactionValidQueries) {
ASSERT_FALSE(session.db_accessor_);
CheckSuccessMessage(output);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
}
}
@ -707,20 +707,20 @@ TEST(BoltSession, ExplicitTransactionInvalidQuery) {
if (transaction_end == "ROLLBACK") {
ASSERT_EQ(session.state_, State::Result);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
CheckSuccessMessage(output);
ExecuteCommand(session, pullall_req, sizeof(pullall_req));
session.Execute();
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.db_accessor_);
ASSERT_TRUE(session.socket_.IsOpen());
ASSERT_TRUE(session.socket().IsOpen());
CheckSuccessMessage(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_FALSE(session.db_accessor_);
ASSERT_FALSE(session.socket_.IsOpen());
ASSERT_FALSE(session.socket().IsOpen());
CheckFailureMessage(output);
}
}

View File

@ -2,7 +2,7 @@
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/server.hpp"
#include "storage/concurrent_id_mapper_master.hpp"
#include "storage/concurrent_id_mapper_worker.hpp"
#include "storage/types.hpp"
@ -12,16 +12,15 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
const std::string kLocal{"127.0.0.1"};
protected:
communication::messaging::System master_system_{{kLocal, 0}};
communication::rpc::System master_system_{{kLocal, 0}};
std::experimental::optional<storage::MasterConcurrentIdMapper<TId>>
master_mapper_;
communication::messaging::System worker_system_{{kLocal, 0}};
std::experimental::optional<storage::WorkerConcurrentIdMapper<TId>>
worker_mapper_;
void SetUp() override {
master_mapper_.emplace(master_system_);
worker_mapper_.emplace(worker_system_, master_system_.endpoint());
worker_mapper_.emplace(master_system_.endpoint());
}
void TearDown() override {
worker_mapper_ = std::experimental::nullopt;

View File

@ -1,19 +1,16 @@
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/server.hpp"
#include "database/counters.hpp"
const std::string kLocal = "127.0.0.1";
TEST(CountersDistributed, All) {
communication::messaging::System master_sys({kLocal, 0});
communication::rpc::System master_sys({kLocal, 0});
database::MasterCounters master(master_sys);
communication::messaging::System w1_sys({kLocal, 0});
database::WorkerCounters w1(w1_sys, master_sys.endpoint());
communication::messaging::System w2_sys({kLocal, 0});
database::WorkerCounters w2(w2_sys, master_sys.endpoint());
database::WorkerCounters w1(master_sys.endpoint());
database::WorkerCounters w2(master_sys.endpoint());
EXPECT_EQ(w1.Get("a"), 0);
EXPECT_EQ(w1.Get("a"), 1);

View File

@ -7,13 +7,14 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/server.hpp"
#include "distributed/coordination_master.hpp"
#include "distributed/coordination_worker.hpp"
#include "io/network/endpoint.hpp"
using communication::messaging::System;
using communication::rpc::System;
using namespace distributed;
using namespace std::literals::chrono_literals;
const int kWorkerCount = 5;
const std::string kLocal = "127.0.0.1";

View File

@ -3,7 +3,6 @@
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "database/graph_db.hpp"
#include "distributed/coordination.hpp"
#include "distributed/coordination_master.hpp"
@ -104,7 +103,7 @@ TEST_F(DistributedGraphDbTest, TxEngine) {
EXPECT_EQ(worker2().tx_engine().RunningTransaction(tx2->id_)->snapshot(),
tx2->snapshot());
::testing::FLAGS_gtest_death_test_style = "fast";
::testing::FLAGS_gtest_death_test_style = "threadsafe";
EXPECT_DEATH(worker2().tx_engine().RunningTransaction(123), "");
}

View File

@ -1,101 +0,0 @@
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <future>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/archive/text_iarchive.hpp"
#include "boost/archive/text_oarchive.hpp"
#include "boost/serialization/access.hpp"
#include "boost/serialization/base_object.hpp"
#include "boost/serialization/export.hpp"
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
using namespace communication::messaging;
using namespace std::literals::chrono_literals;
struct MessageInt : public Message {
MessageInt(int x) : x(x) {}
int x;
private:
friend class boost::serialization::access;
MessageInt() {} // Needed for serialization
template <class TArchive>
void serialize(TArchive &ar, unsigned int) {
ar &boost::serialization::base_object<Message>(*this);
ar &x;
}
};
BOOST_CLASS_EXPORT(MessageInt);
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
/**
* Test do the services start up without crashes.
*/
TEST(SimpleTests, StartAndShutdown) {
System system({"127.0.0.1", 0});
// do nothing
std::this_thread::sleep_for(500ms);
}
TEST(Messaging, Pop) {
System master_system({"127.0.0.1", 0});
System slave_system({"127.0.0.1", 0});
auto stream = master_system.Open("main");
Writer writer(slave_system, master_system.endpoint(), "main");
std::this_thread::sleep_for(100ms);
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Await()), 10);
}
TEST(Messaging, Await) {
System master_system({"127.0.0.1", 0});
System slave_system({"127.0.0.1", 0});
auto stream = master_system.Open("main");
Writer writer(slave_system, master_system.endpoint(), "main");
std::this_thread::sleep_for(100ms);
std::thread t([&] {
std::this_thread::sleep_for(100ms);
stream->Shutdown();
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(20);
});
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(), nullptr);
t.join();
}
TEST(Messaging, RecreateChannelAfterClosing) {
System master_system({"127.0.0.1", 0});
System slave_system({"127.0.0.1", 0});
auto stream = master_system.Open("main");
Writer writer(slave_system, master_system.endpoint(), "main");
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Await()), 10);
stream = nullptr;
writer.Send<MessageInt>(20);
std::this_thread::sleep_for(100ms);
stream = master_system.Open("main");
std::this_thread::sleep_for(100ms);
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(30);
EXPECT_EQ(GET_X(stream->Await()), 30);
}

View File

@ -1,72 +0,0 @@
#include <chrono>
#include <iostream>
#include <thread>
#include "communication/messaging/local.hpp"
#include "gtest/gtest.h"
using namespace std::literals::chrono_literals;
using namespace communication::messaging;
struct MessageInt : public Message {
MessageInt(int xx) : x(xx) {}
int x;
};
#define GET_X(p) dynamic_cast<MessageInt *>((p).get())->x
TEST(LocalMessaging, Pop) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Poll()), 10);
}
TEST(LocalMessaging, Await) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
std::thread t([&] {
std::this_thread::sleep_for(100ms);
stream->Shutdown();
std::this_thread::sleep_for(100ms);
writer.Send<MessageInt>(20);
});
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(), nullptr);
t.join();
}
TEST(LocalMessaging, AwaitTimeout) {
LocalSystem system;
auto stream = system.Open("main");
EXPECT_EQ(stream->Poll(), nullptr);
EXPECT_EQ(stream->Await(100ms), nullptr);
}
TEST(LocalMessaging, RecreateChannelAfterClosing) {
LocalSystem system;
auto stream = system.Open("main");
LocalWriter writer(system, "main");
writer.Send<MessageInt>(10);
EXPECT_EQ(GET_X(stream->Poll()), 10);
stream = nullptr;
writer.Send<MessageInt>(20);
stream = system.Open("main");
EXPECT_EQ(stream->Poll(), nullptr);
writer.Send<MessageInt>(30);
EXPECT_EQ(GET_X(stream->Poll()), 30);
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -18,11 +18,11 @@ using namespace communication::raft::test_utils;
using communication::raft::impl::RaftMemberImpl;
using communication::raft::impl::RaftMode;
const RaftConfig test_config1{{"a"}, 150ms, 300ms, 70ms, 60ms, 30ms};
const RaftConfig test_config2{{"a", "b"}, 150ms, 300ms, 70ms, 60ms, 30ms};
const RaftConfig test_config3{{"a", "b", "c"}, 150ms, 300ms, 70ms, 60ms, 30ms};
const RaftConfig test_config1{{"a"}, 150ms, 300ms, 70ms, 30ms};
const RaftConfig test_config2{{"a", "b"}, 150ms, 300ms, 70ms, 30ms};
const RaftConfig test_config3{{"a", "b", "c"}, 150ms, 300ms, 70ms, 30ms};
const RaftConfig test_config5{
{"a", "b", "c", "d", "e"}, 150ms, 300ms, 70ms, 30ms, 30ms};
{"a", "b", "c", "d", "e"}, 150ms, 300ms, 70ms, 30ms};
class RaftMemberImplTest : public ::testing::Test {
public:

View File

@ -10,12 +10,11 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/rpc.hpp"
#include "communication/rpc/client.hpp"
#include "communication/rpc/messages.hpp"
#include "communication/rpc/server.hpp"
#include "gtest/gtest.h"
using communication::messaging::Message;
using communication::messaging::System;
using namespace communication::rpc;
using namespace std::literals::chrono_literals;
@ -62,12 +61,12 @@ TEST(Rpc, Call) {
});
std::this_thread::sleep_for(100ms);
System client_system({"127.0.0.1", 0});
Client client(client_system, server_system.endpoint(), "main");
auto sum = client.Call<Sum>(300ms, 10, 20);
Client client(server_system.endpoint(), "main");
auto sum = client.Call<Sum>(10, 20);
EXPECT_EQ(sum->sum, 30);
}
/* TODO (mferencevic): enable when async calls are implemented!
TEST(Rpc, Timeout) {
System server_system({"127.0.0.1", 0});
Server server(server_system, "main");
@ -77,8 +76,8 @@ TEST(Rpc, Timeout) {
});
std::this_thread::sleep_for(100ms);
System client_system({"127.0.0.1", 0});
Client client(client_system, server_system.endpoint(), "main");
Client client(server_system.endpoint(), "main");
auto sum = client.Call<Sum>(100ms, 10, 20);
EXPECT_FALSE(sum);
}
*/

View File

@ -2,14 +2,14 @@
#include "gtest/gtest.h"
#include "communication/messaging/distributed.hpp"
#include "communication/rpc/server.hpp"
#include "io/network/endpoint.hpp"
#include "transactions/engine_master.hpp"
#include "transactions/engine_rpc_messages.hpp"
#include "transactions/engine_worker.hpp"
using namespace tx;
using namespace communication::messaging;
using namespace communication::rpc;
class WorkerEngineTest : public testing::Test {
protected:
@ -18,8 +18,7 @@ class WorkerEngineTest : public testing::Test {
System master_system_{{local, 0}};
MasterEngine master_{master_system_};
System worker_system_{{local, 0}};
WorkerEngine worker_{worker_system_, master_system_.endpoint()};
WorkerEngine worker_{master_system_.endpoint()};
};
TEST_F(WorkerEngineTest, RunningTransaction) {