Remove ConcurrentMap from RPC

Reviewers: teon.banek

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1662
This commit is contained in:
Matej Ferencevic 2018-10-16 12:14:48 +02:00
parent 0829b2bb90
commit f7d1050a9d
7 changed files with 37 additions and 54 deletions

View File

@ -9,10 +9,6 @@ set(communication_src_files
rpc/protocol.cpp
rpc/server.cpp)
# TODO: Extract data_structures to library
set(communication_src_files ${communication_src_files}
${CMAKE_SOURCE_DIR}/src/data_structures/concurrent/skiplist_gc.cpp)
define_add_capnp(communication_src_files communication_capnp_files)
add_capnp(rpc/messages.capnp)

View File

@ -44,14 +44,16 @@ void Session::Execute() {
// callback fills the message data
auto response_builder = response_message.initRoot<capnp::Message>();
auto callbacks_accessor = server_->callbacks_.access();
auto it = callbacks_accessor.find(request.getTypeId());
if (it == callbacks_accessor.end()) {
// Access to `callbacks_` and `extended_callbacks_` is done here without
// acquiring the `mutex_` because we don't allow RPC registration after the
// server was started so those two maps will never be updated when we `find`
// over them.
auto it = server_->callbacks_.find(request.getTypeId());
if (it == server_->callbacks_.end()) {
// We couldn't find a regular callback to call, try to find an extended
// callback to call.
auto extended_callbacks_accessor = server_->extended_callbacks_.access();
auto extended_it = extended_callbacks_accessor.find(request.getTypeId());
if (extended_it == extended_callbacks_accessor.end()) {
auto extended_it = server_->extended_callbacks_.find(request.getTypeId());
if (extended_it == server_->extended_callbacks_.end()) {
// Throw exception to close the socket and cleanup the session.
throw SessionException(
"Session trying to execute an unregistered RPC call!");
@ -74,7 +76,7 @@ void Session::Execute() {
MessageSize input_stream_size = response_bytes.size();
if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
sizeof(MessageSize), true)) {
sizeof(MessageSize), true)) {
throw SessionException("Couldn't send response size!");
}
if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {

View File

@ -1,6 +1,7 @@
#pragma once
#include <unordered_map>
#include <map>
#include <mutex>
#include <vector>
#include "capnp/any.h"
@ -9,8 +10,6 @@
#include "communication/rpc/messages.hpp"
#include "communication/rpc/protocol.hpp"
#include "communication/server.hpp"
#include "data_structures/concurrent/concurrent_map.hpp"
#include "data_structures/queue.hpp"
#include "io/network/endpoint.hpp"
#include "utils/demangle.hpp"
@ -36,6 +35,8 @@ class Server {
void(const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
callback) {
std::lock_guard<std::mutex> guard(lock_);
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
RpcCallback rpc;
rpc.req_type = TRequestResponse::Request::TypeInfo;
rpc.res_type = TRequestResponse::Response::TypeInfo;
@ -51,16 +52,12 @@ class Server {
callback(req_data, &res_builder);
};
auto extended_callbacks_accessor = extended_callbacks_.access();
if (extended_callbacks_accessor.find(
TRequestResponse::Request::TypeInfo.id) !=
extended_callbacks_accessor.end()) {
if (extended_callbacks_.find(TRequestResponse::Request::TypeInfo.id) !=
extended_callbacks_.end()) {
LOG(FATAL) << "Callback for that message type already registered!";
}
auto callbacks_accessor = callbacks_.access();
auto got =
callbacks_accessor.insert(TRequestResponse::Request::TypeInfo.id, rpc);
auto got = callbacks_.insert({TRequestResponse::Request::TypeInfo.id, rpc});
CHECK(got.second) << "Callback for that message type already registered";
VLOG(12) << "[RpcServer] register " << rpc.req_type.name << " -> "
<< rpc.res_type.name;
@ -72,6 +69,8 @@ class Server {
const typename TRequestResponse::Request::Capnp::Reader &,
typename TRequestResponse::Response::Capnp::Builder *)>
callback) {
std::lock_guard<std::mutex> guard(lock_);
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
RpcExtendedCallback rpc;
rpc.req_type = TRequestResponse::Request::TypeInfo;
rpc.res_type = TRequestResponse::Response::TypeInfo;
@ -88,33 +87,18 @@ class Server {
callback(endpoint, req_data, &res_builder);
};
auto callbacks_accessor = callbacks_.access();
if (callbacks_accessor.find(TRequestResponse::Request::TypeInfo.id) !=
callbacks_accessor.end()) {
if (callbacks_.find(TRequestResponse::Request::TypeInfo.id) !=
callbacks_.end()) {
LOG(FATAL) << "Callback for that message type already registered!";
}
auto extended_callbacks_accessor = extended_callbacks_.access();
auto got = extended_callbacks_accessor.insert(
TRequestResponse::Request::TypeInfo.id, rpc);
auto got =
extended_callbacks_.insert({TRequestResponse::Request::TypeInfo.id, rpc});
CHECK(got.second) << "Callback for that message type already registered";
VLOG(12) << "[RpcServer] register " << rpc.req_type.name << " -> "
<< rpc.res_type.name;
}
template <typename TRequestResponse>
void UnRegister() {
const MessageType &type = TRequestResponse::Request::TypeInfo;
auto callbacks_accessor = callbacks_.access();
auto deleted = callbacks_accessor.remove(type.id);
if (!deleted) {
auto extended_callbacks_accessor = extended_callbacks_.access();
auto extended_deleted = extended_callbacks_accessor.remove(type.id);
CHECK(extended_deleted)
<< "Trying to remove unknown message type callback";
}
}
private:
friend class Session;
@ -135,10 +119,10 @@ class Server {
MessageType res_type;
};
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
ConcurrentMap<uint64_t, RpcExtendedCallback> extended_callbacks_;
std::mutex lock_;
std::map<uint64_t, RpcCallback> callbacks_;
std::map<uint64_t, RpcExtendedCallback> extended_callbacks_;
std::mutex mutex_;
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
communication::ServerContext context_;
communication::Server<Session, Server> server_;

View File

@ -127,6 +127,9 @@ class Server final {
listener_.AwaitShutdown();
}
/// Returns `true` if the server was started
bool IsRunning() { return alive_; }
private:
void AcceptConnection() {
// Accept a connection from a socket.

View File

@ -23,16 +23,16 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
worker_mapper_;
void SetUp() override {
master_mapper_.emplace(&coordination_);
coordination_.Start();
master_client_pool_.emplace(coordination_.GetServerEndpoint());
master_mapper_.emplace(&coordination_);
worker_mapper_.emplace(&master_client_pool_.value());
}
void TearDown() override {
worker_mapper_ = std::experimental::nullopt;
master_mapper_ = std::experimental::nullopt;
master_client_pool_ = std::experimental::nullopt;
coordination_.Stop();
master_mapper_ = std::experimental::nullopt;
}
};

View File

@ -89,9 +89,9 @@ TEST_F(Distributed, Coordination) {
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
MasterCoordination master_coord({kLocal, 0});
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
ASSERT_TRUE(master_coord.Start());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
for (int i = 1; i <= kWorkerCount; ++i)
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
@ -120,9 +120,9 @@ TEST_F(Distributed, DesiredAndUniqueId) {
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
MasterCoordination master_coord({kLocal, 0});
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
ASSERT_TRUE(master_coord.Start());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_coord.GetServerEndpoint(), tmp_dir("worker42"), 42));
@ -143,9 +143,9 @@ TEST_F(Distributed, CoordinationWorkersId) {
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
MasterCoordination master_coord({kLocal, 0});
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
ASSERT_TRUE(master_coord.Start());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
master_coord.GetServerEndpoint(), tmp_dir("worker42"), 42));
@ -169,9 +169,9 @@ TEST_F(Distributed, ClusterDiscovery) {
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
MasterCoordination master_coord({kLocal, 0});
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
ASSERT_TRUE(master_coord.Start());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
std::vector<int> ids;
int worker_count = 10;
@ -200,9 +200,9 @@ TEST_F(Distributed, KeepsTrackOfRecovered) {
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
MasterCoordination master_coord({kLocal, 0});
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
ASSERT_TRUE(master_coord.Start());
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
int worker_count = 10;
for (int i = 1; i <= worker_count; ++i) {
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(

View File

@ -19,15 +19,13 @@ class WorkerEngineTest : public testing::Test {
protected:
void SetUp() override {
master_coordination_ = std::make_unique<TestMasterCoordination>();
master_coordination_->Start();
master_ = std::make_unique<EngineMaster>(master_coordination_.get());
master_coordination_->Start();
worker_coordination_ = std::make_unique<TestWorkerCoordination>(
master_coordination_->GetServerEndpoint(), 1);
worker_coordination_->Start();
worker_ = std::make_unique<EngineWorker>(worker_coordination_.get());
worker_coordination_->Start();
}
void TearDown() override {