Add SSL support to HA RPC

Reviewers: msantl, teon.banek

Reviewed By: msantl

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2055
This commit is contained in:
Matej Ferencevic 2019-05-20 10:38:57 +02:00
parent 4e927a52e1
commit 6c49e6de02
15 changed files with 100 additions and 52 deletions

View File

@ -2,7 +2,9 @@
namespace communication::rpc {
Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {}
Client::Client(const io::network::Endpoint &endpoint,
communication::ClientContext *context)
: endpoint_(endpoint), context_(context) {}
void Client::Abort() {
if (!client_) return;

View File

@ -19,7 +19,8 @@ namespace communication::rpc {
/// Client is thread safe, but it is recommended to use thread_local clients.
class Client {
public:
explicit Client(const io::network::Endpoint &endpoint);
Client(const io::network::Endpoint &endpoint,
communication::ClientContext *context);
/// Call a previously defined and registered RPC call. This function can
/// initiate only one request at a time. The call blocks until a response is
@ -61,7 +62,7 @@ class Client {
// Connect to the remote server.
if (!client_) {
client_.emplace(&context_);
client_.emplace(context_);
if (!client_->Connect(endpoint_)) {
DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::nullopt;
@ -121,8 +122,7 @@ class Client {
private:
io::network::Endpoint endpoint_;
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
communication::ClientContext context_;
communication::ClientContext *context_;
std::optional<communication::Client> client_;
std::mutex mutex_;

View File

@ -14,8 +14,9 @@ namespace communication::rpc {
*/
class ClientPool {
public:
explicit ClientPool(const io::network::Endpoint &endpoint)
: endpoint_(endpoint) {}
ClientPool(const io::network::Endpoint &endpoint,
communication::ClientContext *context)
: endpoint_(endpoint), context_(context) {}
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response Call(Args &&... args) {
@ -40,7 +41,7 @@ class ClientPool {
std::unique_lock<std::mutex> lock(mutex_);
if (unused_clients_.empty()) {
client = std::make_unique<Client>(endpoint_);
client = std::make_unique<Client>(endpoint_, context_);
} else {
client = std::move(unused_clients_.top());
unused_clients_.pop();
@ -55,6 +56,7 @@ class ClientPool {
}
io::network::Endpoint endpoint_;
communication::ClientContext *context_;
std::mutex mutex_;
std::stack<std::unique_ptr<Client>> unused_clients_;

View File

@ -3,20 +3,15 @@
namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count)
: server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
communication::ServerContext *context, size_t workers_count)
: server_(endpoint, this, context, -1, context->use_ssl() ? "RPCS" : "RPC",
workers_count) {}
bool Server::Start() {
return server_.Start();
}
bool Server::Start() { return server_.Start(); }
void Server::Shutdown() {
server_.Shutdown();
}
void Server::Shutdown() { server_.Shutdown(); }
void Server::AwaitShutdown() {
server_.AwaitShutdown();
}
void Server::AwaitShutdown() { server_.AwaitShutdown(); }
const io::network::Endpoint &Server::endpoint() const {
return server_.endpoint();

View File

@ -15,6 +15,7 @@ namespace communication::rpc {
class Server {
public:
Server(const io::network::Endpoint &endpoint,
communication::ServerContext *context,
size_t workers_count = std::thread::hardware_concurrency());
Server(const Server &) = delete;
Server(Server &&) = delete;
@ -88,8 +89,6 @@ class Server {
std::map<uint64_t, RpcCallback> callbacks_;
std::map<uint64_t, RpcExtendedCallback> extended_callbacks_;
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
communication::ServerContext context_;
communication::Server<Session, Server> server_;
}; // namespace communication::rpc

View File

@ -10,7 +10,7 @@ Coordination::Coordination(const io::network::Endpoint &worker_endpoint,
int worker_id,
const io::network::Endpoint &master_endpoint,
int server_workers_count, int client_workers_count)
: server_(worker_endpoint, server_workers_count),
: server_(worker_endpoint, &server_context_, server_workers_count),
thread_pool_(client_workers_count, "RPC client") {
if (worker_id != 0) {
// The master is always worker 0.
@ -64,7 +64,7 @@ communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
return &client_pools_
.emplace(std::piecewise_construct,
std::forward_as_tuple(worker_id),
std::forward_as_tuple(endpoint))
std::forward_as_tuple(endpoint, &client_context_))
.first->second;
}

View File

@ -95,6 +95,8 @@ class Coordination {
/// Gets a worker name for the given endpoint.
std::string GetWorkerName(const io::network::Endpoint &endpoint);
// TODO(mferencevic): distributed is currently hardcoded not to use SSL
communication::ServerContext server_context_;
communication::rpc::Server server_;
std::atomic<bool> cluster_alive_{true};
@ -103,6 +105,8 @@ class Coordination {
std::unordered_map<int, io::network::Endpoint> workers_;
mutable std::mutex lock_;
// TODO(mferencevic): distributed is currently hardcoded not to use SSL
communication::ClientContext client_context_;
std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
utils::ThreadPool thread_pool_;
};

View File

@ -27,8 +27,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be "
"closed.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
DEFINE_string(cert_file, "", "Certificate file to use (Bolt).");
DEFINE_string(key_file, "", "Key file to use (Bolt).");
using ServerT = communication::Server<BoltSession, SessionData>;
using communication::ServerContext;

View File

@ -1,10 +1,14 @@
#include "raft/coordination.hpp"
#include <gflags/gflags.h>
#include <json/json.hpp>
#include "utils/file.hpp"
#include "utils/string.hpp"
DEFINE_string(rpc_cert_file, "", "Certificate file to use (RPC).");
DEFINE_string(rpc_key_file, "", "Key file to use (RPC).");
namespace raft {
namespace fs = std::filesystem;
@ -46,15 +50,23 @@ std::unordered_map<uint16_t, io::network::Endpoint> LoadNodesFromFile(
Coordination::Coordination(
uint16_t node_id,
std::unordered_map<uint16_t, io::network::Endpoint> all_nodes)
: node_id_(node_id),
cluster_size_(all_nodes.size()),
server_(all_nodes[node_id], all_nodes.size() * 2) {
: node_id_(node_id), cluster_size_(all_nodes.size()) {
// Create and initialize all server elements.
if (!FLAGS_rpc_cert_file.empty() && !FLAGS_rpc_key_file.empty()) {
server_context_.emplace(FLAGS_rpc_key_file, FLAGS_rpc_cert_file);
} else {
server_context_.emplace();
}
server_.emplace(all_nodes[node_id_], &server_context_.value(),
all_nodes.size() * 2);
// Create all client elements.
endpoints_.resize(cluster_size_);
clients_.resize(cluster_size_);
client_locks_.resize(cluster_size_);
// Initialize all client elements.
client_context_.emplace(server_context_->use_ssl());
for (uint16_t i = 1; i <= cluster_size_; ++i) {
auto it = all_nodes.find(i);
if (it == all_nodes.end()) {
@ -93,7 +105,7 @@ uint16_t Coordination::GetAllNodeCount() { return cluster_size_; }
uint16_t Coordination::GetOtherNodeCount() { return cluster_size_ - 1; }
bool Coordination::Start() { return server_.Start(); }
bool Coordination::Start() { return server_->Start(); }
void Coordination::AwaitShutdown(
std::function<void(void)> call_before_shutdown) {
@ -106,8 +118,8 @@ void Coordination::AwaitShutdown(
call_before_shutdown();
// Shutdown our RPC server.
server_.Shutdown();
server_.AwaitShutdown();
server_->Shutdown();
server_->AwaitShutdown();
}
void Coordination::Shutdown() { alive_.store(false); }

View File

@ -80,7 +80,8 @@ class Coordination final {
if (!client) {
const auto &endpoint = endpoints_[other_id - 1];
client = std::make_unique<communication::rpc::Client>(endpoint);
client = std::make_unique<communication::rpc::Client>(
endpoint, &client_context_.value());
}
try {
@ -95,7 +96,7 @@ class Coordination final {
/// Registers a RPC call on this node.
template <class TRequestResponse>
void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) {
server_.Register<TRequestResponse>(callback);
server_->Register<TRequestResponse>(callback);
}
/// Registers an extended RPC call on this node.
@ -103,7 +104,7 @@ class Coordination final {
void Register(std::function<void(const io::network::Endpoint &, slk::Reader *,
slk::Builder *)>
callback) {
server_.Register<TRequestResponse>(callback);
server_->Register<TRequestResponse>(callback);
}
/// Starts the coordination and its servers.
@ -121,8 +122,10 @@ class Coordination final {
uint16_t node_id_;
uint16_t cluster_size_;
communication::rpc::Server server_;
std::optional<communication::ServerContext> server_context_;
std::optional<communication::rpc::Server> server_;
std::optional<communication::ClientContext> client_context_;
std::vector<io::network::Endpoint> endpoints_;
std::vector<std::unique_ptr<communication::rpc::Client>> clients_;
std::vector<std::unique_ptr<std::mutex>> client_locks_;

View File

@ -46,7 +46,9 @@ void StatsDispatchMain(const io::network::Endpoint &endpoint) {
LOG(INFO) << "Stats dispatcher thread started";
utils::ThreadSetName("Stats dispatcher");
communication::rpc::Client client(endpoint);
// TODO(mferencevic): stats are currently hardcoded not to use SSL
communication::ClientContext client_context;
communication::rpc::Client client(endpoint, &client_context);
BatchStatsReq batch_request;
batch_request.requests.reserve(MAX_BATCH_SIZE);

View File

@ -44,10 +44,15 @@ const int kThreadsNum = 16;
DEFINE_string(server_address, "127.0.0.1", "Server address");
DEFINE_int32(server_port, 0, "Server port");
DEFINE_string(server_cert_file, "", "Server SSL certificate file");
DEFINE_string(server_key_file, "", "Server SSL key file");
DEFINE_bool(benchmark_use_ssl, false, "Set to true to benchmark using SSL");
DEFINE_bool(run_server, true, "Set to false to use external server");
DEFINE_bool(run_benchmark, true, "Set to false to only run server");
std::optional<communication::ServerContext> server_context;
std::optional<communication::rpc::Server> server;
std::optional<communication::ClientContext> client_context;
std::optional<communication::rpc::Client> clients[kThreadsNum];
std::optional<communication::rpc::ClientPool> client_pool;
std::optional<utils::ThreadPool> thread_pool;
@ -105,9 +110,15 @@ int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
if (FLAGS_run_server) {
if (!FLAGS_server_cert_file.empty() && !FLAGS_server_key_file.empty()) {
FLAGS_benchmark_use_ssl = true;
server_context.emplace(FLAGS_server_key_file, FLAGS_server_cert_file);
} else {
server_context.emplace();
}
server.emplace(
io::network::Endpoint(FLAGS_server_address, FLAGS_server_port),
kThreadsNum);
&server_context.value(), kThreadsNum);
server->Register<Echo>([](const auto &req_reader, auto *res_builder) {
EchoMessage res;
@ -127,8 +138,10 @@ int main(int argc, char **argv) {
endpoint = io::network::Endpoint(FLAGS_server_address, FLAGS_server_port);
}
client_context.emplace(FLAGS_benchmark_use_ssl);
for (int i = 0; i < kThreadsNum; ++i) {
clients[i].emplace(endpoint);
clients[i].emplace(endpoint, &client_context.value());
clients[i]->Call<Echo>("init");
}
@ -137,7 +150,7 @@ int main(int argc, char **argv) {
// of making connections to the server during the benchmark here we
// simultaneously call the Echo RPC on the client pool to make the client
// pool connect to the server `kThreadsNum` times.
client_pool.emplace(endpoint);
client_pool.emplace(endpoint, &client_context.value());
std::thread threads[kThreadsNum];
for (int i = 0; i < kThreadsNum; ++i) {
threads[i] =

View File

@ -15,6 +15,7 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
protected:
TestMasterCoordination coordination_;
communication::ClientContext client_context_;
std::optional<communication::rpc::ClientPool> master_client_pool_;
std::optional<storage::MasterConcurrentIdMapper<TId>> master_mapper_;
std::optional<storage::WorkerConcurrentIdMapper<TId>> worker_mapper_;
@ -22,7 +23,8 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
void SetUp() override {
master_mapper_.emplace(&coordination_);
coordination_.Start();
master_client_pool_.emplace(coordination_.GetServerEndpoint());
master_client_pool_.emplace(coordination_.GetServerEndpoint(),
&client_context_);
worker_mapper_.emplace(&master_client_pool_.value());
}
void TearDown() override {

View File

@ -54,7 +54,8 @@ void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) {
}
TEST(Rpc, Call) {
Server server({"127.0.0.1", 0});
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Sum>([](auto *req_reader, auto *res_builder) {
SumReq req;
slk::Load(&req, req_reader);
@ -64,7 +65,8 @@ TEST(Rpc, Call) {
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
Client client(server.endpoint());
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto sum = client.Call<Sum>(10, 20);
EXPECT_EQ(sum.sum, 30);
@ -73,7 +75,8 @@ TEST(Rpc, Call) {
}
TEST(Rpc, Abort) {
Server server({"127.0.0.1", 0});
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Sum>([](auto *req_reader, auto *res_builder) {
SumReq req;
slk::Load(&req, req_reader);
@ -84,7 +87,8 @@ TEST(Rpc, Abort) {
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
Client client(server.endpoint());
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
std::thread thread([&client]() {
std::this_thread::sleep_for(100ms);
@ -104,7 +108,8 @@ TEST(Rpc, Abort) {
}
TEST(Rpc, ClientPool) {
Server server({"127.0.0.1", 0});
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Sum>([](const auto &req_reader, auto *res_builder) {
SumReq req;
Load(&req, req_reader);
@ -115,7 +120,8 @@ TEST(Rpc, ClientPool) {
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
Client client(server.endpoint());
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
// These calls should take more than 400ms because we're using a regular
// client
@ -136,7 +142,8 @@ TEST(Rpc, ClientPool) {
EXPECT_GE(t1.Elapsed(), 400ms);
ClientPool pool(server.endpoint());
communication::ClientContext pool_context;
ClientPool pool(server.endpoint(), &pool_context);
// These calls shouldn't take much more that 100ms because they execute in
// parallel
@ -159,7 +166,8 @@ TEST(Rpc, ClientPool) {
}
TEST(Rpc, LargeMessage) {
Server server({"127.0.0.1", 0});
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage res;
slk::Load(&res, req_reader);
@ -170,7 +178,8 @@ TEST(Rpc, LargeMessage) {
std::string testdata(100000, 'a');
Client client(server.endpoint());
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto echo = client.Call<Echo>(testdata);
EXPECT_EQ(echo.data, testdata);
@ -179,7 +188,8 @@ TEST(Rpc, LargeMessage) {
}
TEST(Rpc, JumboMessage) {
Server server({"127.0.0.1", 0});
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage res;
slk::Load(&res, req_reader);
@ -191,7 +201,8 @@ TEST(Rpc, JumboMessage) {
// NOLINTNEXTLINE (bugprone-string-constructor)
std::string testdata(10000000, 'a');
Client client(server.endpoint());
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto echo = client.Call<Echo>(testdata);
EXPECT_EQ(echo.data, testdata);

View File

@ -32,7 +32,10 @@ std::string GraphiteFormat(const stats::StatsReq &req) {
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port});
// TODO(mferencevic): stats are currently hardcoded not to use SSL
communication::ServerContext server_context;
communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port},
&server_context);
io::network::Socket graphite_socket;