Remove ConcurrentMap from RPC
Reviewers: teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1662
This commit is contained in:
parent
0829b2bb90
commit
f7d1050a9d
@ -9,10 +9,6 @@ set(communication_src_files
|
||||
rpc/protocol.cpp
|
||||
rpc/server.cpp)
|
||||
|
||||
# TODO: Extract data_structures to library
|
||||
set(communication_src_files ${communication_src_files}
|
||||
${CMAKE_SOURCE_DIR}/src/data_structures/concurrent/skiplist_gc.cpp)
|
||||
|
||||
define_add_capnp(communication_src_files communication_capnp_files)
|
||||
|
||||
add_capnp(rpc/messages.capnp)
|
||||
|
@ -44,14 +44,16 @@ void Session::Execute() {
|
||||
// callback fills the message data
|
||||
auto response_builder = response_message.initRoot<capnp::Message>();
|
||||
|
||||
auto callbacks_accessor = server_->callbacks_.access();
|
||||
auto it = callbacks_accessor.find(request.getTypeId());
|
||||
if (it == callbacks_accessor.end()) {
|
||||
// Access to `callbacks_` and `extended_callbacks_` is done here without
|
||||
// acquiring the `mutex_` because we don't allow RPC registration after the
|
||||
// server was started so those two maps will never be updated when we `find`
|
||||
// over them.
|
||||
auto it = server_->callbacks_.find(request.getTypeId());
|
||||
if (it == server_->callbacks_.end()) {
|
||||
// We couldn't find a regular callback to call, try to find an extended
|
||||
// callback to call.
|
||||
auto extended_callbacks_accessor = server_->extended_callbacks_.access();
|
||||
auto extended_it = extended_callbacks_accessor.find(request.getTypeId());
|
||||
if (extended_it == extended_callbacks_accessor.end()) {
|
||||
auto extended_it = server_->extended_callbacks_.find(request.getTypeId());
|
||||
if (extended_it == server_->extended_callbacks_.end()) {
|
||||
// Throw exception to close the socket and cleanup the session.
|
||||
throw SessionException(
|
||||
"Session trying to execute an unregistered RPC call!");
|
||||
@ -74,7 +76,7 @@ void Session::Execute() {
|
||||
|
||||
MessageSize input_stream_size = response_bytes.size();
|
||||
if (!output_stream_->Write(reinterpret_cast<uint8_t *>(&input_stream_size),
|
||||
sizeof(MessageSize), true)) {
|
||||
sizeof(MessageSize), true)) {
|
||||
throw SessionException("Couldn't send response size!");
|
||||
}
|
||||
if (!output_stream_->Write(response_bytes.begin(), response_bytes.size())) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "capnp/any.h"
|
||||
@ -9,8 +10,6 @@
|
||||
#include "communication/rpc/messages.hpp"
|
||||
#include "communication/rpc/protocol.hpp"
|
||||
#include "communication/server.hpp"
|
||||
#include "data_structures/concurrent/concurrent_map.hpp"
|
||||
#include "data_structures/queue.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "utils/demangle.hpp"
|
||||
|
||||
@ -36,6 +35,8 @@ class Server {
|
||||
void(const typename TRequestResponse::Request::Capnp::Reader &,
|
||||
typename TRequestResponse::Response::Capnp::Builder *)>
|
||||
callback) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
|
||||
RpcCallback rpc;
|
||||
rpc.req_type = TRequestResponse::Request::TypeInfo;
|
||||
rpc.res_type = TRequestResponse::Response::TypeInfo;
|
||||
@ -51,16 +52,12 @@ class Server {
|
||||
callback(req_data, &res_builder);
|
||||
};
|
||||
|
||||
auto extended_callbacks_accessor = extended_callbacks_.access();
|
||||
if (extended_callbacks_accessor.find(
|
||||
TRequestResponse::Request::TypeInfo.id) !=
|
||||
extended_callbacks_accessor.end()) {
|
||||
if (extended_callbacks_.find(TRequestResponse::Request::TypeInfo.id) !=
|
||||
extended_callbacks_.end()) {
|
||||
LOG(FATAL) << "Callback for that message type already registered!";
|
||||
}
|
||||
|
||||
auto callbacks_accessor = callbacks_.access();
|
||||
auto got =
|
||||
callbacks_accessor.insert(TRequestResponse::Request::TypeInfo.id, rpc);
|
||||
auto got = callbacks_.insert({TRequestResponse::Request::TypeInfo.id, rpc});
|
||||
CHECK(got.second) << "Callback for that message type already registered";
|
||||
VLOG(12) << "[RpcServer] register " << rpc.req_type.name << " -> "
|
||||
<< rpc.res_type.name;
|
||||
@ -72,6 +69,8 @@ class Server {
|
||||
const typename TRequestResponse::Request::Capnp::Reader &,
|
||||
typename TRequestResponse::Response::Capnp::Builder *)>
|
||||
callback) {
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
CHECK(!server_.IsRunning()) << "You can't register RPCs when the server is running!";
|
||||
RpcExtendedCallback rpc;
|
||||
rpc.req_type = TRequestResponse::Request::TypeInfo;
|
||||
rpc.res_type = TRequestResponse::Response::TypeInfo;
|
||||
@ -88,33 +87,18 @@ class Server {
|
||||
callback(endpoint, req_data, &res_builder);
|
||||
};
|
||||
|
||||
auto callbacks_accessor = callbacks_.access();
|
||||
if (callbacks_accessor.find(TRequestResponse::Request::TypeInfo.id) !=
|
||||
callbacks_accessor.end()) {
|
||||
if (callbacks_.find(TRequestResponse::Request::TypeInfo.id) !=
|
||||
callbacks_.end()) {
|
||||
LOG(FATAL) << "Callback for that message type already registered!";
|
||||
}
|
||||
|
||||
auto extended_callbacks_accessor = extended_callbacks_.access();
|
||||
auto got = extended_callbacks_accessor.insert(
|
||||
TRequestResponse::Request::TypeInfo.id, rpc);
|
||||
auto got =
|
||||
extended_callbacks_.insert({TRequestResponse::Request::TypeInfo.id, rpc});
|
||||
CHECK(got.second) << "Callback for that message type already registered";
|
||||
VLOG(12) << "[RpcServer] register " << rpc.req_type.name << " -> "
|
||||
<< rpc.res_type.name;
|
||||
}
|
||||
|
||||
template <typename TRequestResponse>
|
||||
void UnRegister() {
|
||||
const MessageType &type = TRequestResponse::Request::TypeInfo;
|
||||
auto callbacks_accessor = callbacks_.access();
|
||||
auto deleted = callbacks_accessor.remove(type.id);
|
||||
if (!deleted) {
|
||||
auto extended_callbacks_accessor = extended_callbacks_.access();
|
||||
auto extended_deleted = extended_callbacks_accessor.remove(type.id);
|
||||
CHECK(extended_deleted)
|
||||
<< "Trying to remove unknown message type callback";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
friend class Session;
|
||||
|
||||
@ -135,10 +119,10 @@ class Server {
|
||||
MessageType res_type;
|
||||
};
|
||||
|
||||
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
|
||||
ConcurrentMap<uint64_t, RpcExtendedCallback> extended_callbacks_;
|
||||
std::mutex lock_;
|
||||
std::map<uint64_t, RpcCallback> callbacks_;
|
||||
std::map<uint64_t, RpcExtendedCallback> extended_callbacks_;
|
||||
|
||||
std::mutex mutex_;
|
||||
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
|
||||
communication::ServerContext context_;
|
||||
communication::Server<Session, Server> server_;
|
||||
|
@ -127,6 +127,9 @@ class Server final {
|
||||
listener_.AwaitShutdown();
|
||||
}
|
||||
|
||||
/// Returns `true` if the server was started
|
||||
bool IsRunning() { return alive_; }
|
||||
|
||||
private:
|
||||
void AcceptConnection() {
|
||||
// Accept a connection from a socket.
|
||||
|
@ -23,16 +23,16 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
|
||||
worker_mapper_;
|
||||
|
||||
void SetUp() override {
|
||||
master_mapper_.emplace(&coordination_);
|
||||
coordination_.Start();
|
||||
master_client_pool_.emplace(coordination_.GetServerEndpoint());
|
||||
master_mapper_.emplace(&coordination_);
|
||||
worker_mapper_.emplace(&master_client_pool_.value());
|
||||
}
|
||||
void TearDown() override {
|
||||
worker_mapper_ = std::experimental::nullopt;
|
||||
master_mapper_ = std::experimental::nullopt;
|
||||
master_client_pool_ = std::experimental::nullopt;
|
||||
coordination_.Stop();
|
||||
master_mapper_ = std::experimental::nullopt;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -89,9 +89,9 @@ TEST_F(Distributed, Coordination) {
|
||||
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
|
||||
|
||||
MasterCoordination master_coord({kLocal, 0});
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
ASSERT_TRUE(master_coord.Start());
|
||||
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
|
||||
for (int i = 1; i <= kWorkerCount; ++i)
|
||||
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
|
||||
@ -120,9 +120,9 @@ TEST_F(Distributed, DesiredAndUniqueId) {
|
||||
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
|
||||
|
||||
MasterCoordination master_coord({kLocal, 0});
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
ASSERT_TRUE(master_coord.Start());
|
||||
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
|
||||
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
|
||||
master_coord.GetServerEndpoint(), tmp_dir("worker42"), 42));
|
||||
@ -143,9 +143,9 @@ TEST_F(Distributed, CoordinationWorkersId) {
|
||||
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
|
||||
|
||||
MasterCoordination master_coord({kLocal, 0});
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
ASSERT_TRUE(master_coord.Start());
|
||||
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
|
||||
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
|
||||
master_coord.GetServerEndpoint(), tmp_dir("worker42"), 42));
|
||||
@ -169,9 +169,9 @@ TEST_F(Distributed, ClusterDiscovery) {
|
||||
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
|
||||
|
||||
MasterCoordination master_coord({kLocal, 0});
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
ASSERT_TRUE(master_coord.Start());
|
||||
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
std::vector<int> ids;
|
||||
int worker_count = 10;
|
||||
|
||||
@ -200,9 +200,9 @@ TEST_F(Distributed, KeepsTrackOfRecovered) {
|
||||
std::vector<std::unique_ptr<WorkerCoordinationInThread>> workers;
|
||||
|
||||
MasterCoordination master_coord({kLocal, 0});
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
ASSERT_TRUE(master_coord.Start());
|
||||
master_coord.SetRecoveredSnapshot(std::experimental::nullopt);
|
||||
ClusterDiscoveryMaster master_discovery_(&master_coord, tmp_dir("master"));
|
||||
int worker_count = 10;
|
||||
for (int i = 1; i <= worker_count; ++i) {
|
||||
workers.emplace_back(std::make_unique<WorkerCoordinationInThread>(
|
||||
|
@ -19,15 +19,13 @@ class WorkerEngineTest : public testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
master_coordination_ = std::make_unique<TestMasterCoordination>();
|
||||
master_coordination_->Start();
|
||||
|
||||
master_ = std::make_unique<EngineMaster>(master_coordination_.get());
|
||||
master_coordination_->Start();
|
||||
|
||||
worker_coordination_ = std::make_unique<TestWorkerCoordination>(
|
||||
master_coordination_->GetServerEndpoint(), 1);
|
||||
worker_coordination_->Start();
|
||||
|
||||
worker_ = std::make_unique<EngineWorker>(worker_coordination_.get());
|
||||
worker_coordination_->Start();
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
|
Loading…
Reference in New Issue
Block a user