Add SSL support to HA RPC

Summary: RPC clients and servers now take an explicit communication context
through their constructors instead of owning a hardcoded plaintext one. The
HA (Raft) coordination builds an SSL context from the new --rpc_cert_file and
--rpc_key_file flags; all other RPC users are updated to pass explicit
plaintext contexts.

Reviewers: msantl, teon.banek
Reviewed By: msantl
Subscribers: pullbot
Differential Revision: https://phabricator.memgraph.io/D2055

Parent: 4e927a52e1
Commit: 6c49e6de02

@@ -2,7 +2,9 @@
 
 namespace communication::rpc {
 
-Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {}
+Client::Client(const io::network::Endpoint &endpoint,
+               communication::ClientContext *context)
+    : endpoint_(endpoint), context_(context) {}
 
 void Client::Abort() {
   if (!client_) return;
@@ -19,7 +19,8 @@ namespace communication::rpc {
 /// Client is thread safe, but it is recommended to use thread_local clients.
 class Client {
  public:
-  explicit Client(const io::network::Endpoint &endpoint);
+  Client(const io::network::Endpoint &endpoint,
+         communication::ClientContext *context);
 
   /// Call a previously defined and registered RPC call. This function can
   /// initiate only one request at a time. The call blocks until a response is
@@ -61,7 +62,7 @@ class Client {
 
     // Connect to the remote server.
    if (!client_) {
-      client_.emplace(&context_);
+      client_.emplace(context_);
      if (!client_->Connect(endpoint_)) {
        DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
        client_ = std::nullopt;
@@ -121,8 +122,7 @@ class Client {
 
  private:
  io::network::Endpoint endpoint_;
-  // TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
-  communication::ClientContext context_;
+  communication::ClientContext *context_;
  std::optional<communication::Client> client_;
 
  std::mutex mutex_;
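
The hunks above change the RPC Client API: the client no longer owns a
hardcoded plaintext ClientContext, it borrows one through a constructor
pointer. A minimal sketch of a call site under the new API (the include path
is an assumption, not part of this diff):

    #include "communication/rpc/client.hpp"  // assumed include path

    void Example(const io::network::Endpoint &endpoint) {
      // A default-constructed context is plaintext; ClientContext(true)
      // requests SSL, as the benchmark further down does via a flag.
      communication::ClientContext context;
      communication::rpc::Client client(endpoint, &context);
      // The context must outlive the client: only the pointer is stored.
    }
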
@@ -14,8 +14,9 @@ namespace communication::rpc {
  */
 class ClientPool {
  public:
-  explicit ClientPool(const io::network::Endpoint &endpoint)
-      : endpoint_(endpoint) {}
+  ClientPool(const io::network::Endpoint &endpoint,
+             communication::ClientContext *context)
+      : endpoint_(endpoint), context_(context) {}
 
  template <class TRequestResponse, class... Args>
  typename TRequestResponse::Response Call(Args &&... args) {
@@ -40,7 +41,7 @@ class ClientPool {
 
      std::unique_lock<std::mutex> lock(mutex_);
      if (unused_clients_.empty()) {
-        client = std::make_unique<Client>(endpoint_);
+        client = std::make_unique<Client>(endpoint_, context_);
      } else {
        client = std::move(unused_clients_.top());
        unused_clients_.pop();
@@ -55,6 +56,7 @@ class ClientPool {
  }
 
  io::network::Endpoint endpoint_;
+  communication::ClientContext *context_;
 
  std::mutex mutex_;
  std::stack<std::unique_ptr<Client>> unused_clients_;
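
ClientPool mirrors the Client change: the pool stores the same borrowed
ClientContext pointer and forwards it to every Client it creates on demand,
so all pooled connections share one TLS configuration. A short sketch,
assuming the Sum RPC defined in the unit tests further down:

    // One context shared by every client the pool spawns.
    communication::ClientContext context(/* use_ssl = */ true);
    communication::rpc::ClientPool pool(endpoint, &context);
    auto res = pool.Call<Sum>(10, 20);  // res.sum == 30
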
@@ -3,20 +3,15 @@
 namespace communication::rpc {
 
 Server::Server(const io::network::Endpoint &endpoint,
-               size_t workers_count)
-    : server_(endpoint, this, &context_, -1, "RPC", workers_count) {}
+               communication::ServerContext *context, size_t workers_count)
+    : server_(endpoint, this, context, -1, context->use_ssl() ? "RPCS" : "RPC",
+              workers_count) {}
 
-bool Server::Start() {
-  return server_.Start();
-}
+bool Server::Start() { return server_.Start(); }
 
-void Server::Shutdown() {
-  server_.Shutdown();
-}
+void Server::Shutdown() { server_.Shutdown(); }
 
-void Server::AwaitShutdown() {
-  server_.AwaitShutdown();
-}
+void Server::AwaitShutdown() { server_.AwaitShutdown(); }
 
 const io::network::Endpoint &Server::endpoint() const {
   return server_.endpoint();
@@ -15,6 +15,7 @@ namespace communication::rpc {
 class Server {
  public:
  Server(const io::network::Endpoint &endpoint,
+         communication::ServerContext *context,
         size_t workers_count = std::thread::hardware_concurrency());
  Server(const Server &) = delete;
  Server(Server &&) = delete;
@@ -88,8 +89,6 @@ class Server {
  std::map<uint64_t, RpcCallback> callbacks_;
  std::map<uint64_t, RpcExtendedCallback> extended_callbacks_;
 
-  // TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
-  communication::ServerContext context_;
  communication::Server<Session, Server> server_;
};  // namespace communication::rpc
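
On the server side the ServerContext likewise moves out of the class and in
through the constructor, and the service name handed to the underlying
communication::Server becomes "RPCS" instead of "RPC" when the context has
SSL enabled, so thread names and logs show whether a listener is encrypted.
A sketch of both modes (file names are made up; the key-then-cert argument
order matches the raft coordination code below):

    // Plaintext RPC server.
    communication::ServerContext plain_context;
    communication::rpc::Server plain_server({"127.0.0.1", 0}, &plain_context);

    // SSL RPC server; use_ssl() is true, so it is named "RPCS".
    communication::ServerContext ssl_context("rpc.key", "rpc.cert");
    communication::rpc::Server ssl_server({"127.0.0.1", 0}, &ssl_context);
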
@@ -10,7 +10,7 @@ Coordination::Coordination(const io::network::Endpoint &worker_endpoint,
                            int worker_id,
                            const io::network::Endpoint &master_endpoint,
                            int server_workers_count, int client_workers_count)
-    : server_(worker_endpoint, server_workers_count),
+    : server_(worker_endpoint, &server_context_, server_workers_count),
      thread_pool_(client_workers_count, "RPC client") {
  if (worker_id != 0) {
    // The master is always worker 0.
@@ -64,7 +64,7 @@ communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) {
  return &client_pools_
              .emplace(std::piecewise_construct,
                       std::forward_as_tuple(worker_id),
-                       std::forward_as_tuple(endpoint))
+                       std::forward_as_tuple(endpoint, &client_context_))
              .first->second;
 }
 
@@ -95,6 +95,8 @@ class Coordination {
  /// Gets a worker name for the given endpoint.
  std::string GetWorkerName(const io::network::Endpoint &endpoint);
 
+  // TODO(mferencevic): distributed is currently hardcoded not to use SSL
+  communication::ServerContext server_context_;
  communication::rpc::Server server_;
 
  std::atomic<bool> cluster_alive_{true};
@@ -103,6 +105,8 @@ class Coordination {
  std::unordered_map<int, io::network::Endpoint> workers_;
  mutable std::mutex lock_;
 
+  // TODO(mferencevic): distributed is currently hardcoded not to use SSL
+  communication::ClientContext client_context_;
  std::unordered_map<int, communication::rpc::ClientPool> client_pools_;
  utils::ThreadPool thread_pool_;
 };
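
The distributed Coordination keeps both contexts as default-constructed
members, so distributed RPC stays plaintext for now (hence the two TODOs);
only the wiring changes. Note that server_context_ is declared immediately
before server_: class members are initialized in declaration order, and the
server's constructor receives a pointer to the context, so the context must
already exist. Reduced to its essentials:

    class CoordinationSketch {
     public:
      explicit CoordinationSketch(const io::network::Endpoint &endpoint)
          // server_context_ is already constructed when server_ is built,
          // because it is declared first.
          : server_(endpoint, &server_context_) {}

     private:
      communication::ServerContext server_context_;  // declared first
      communication::rpc::Server server_;            // initialized second
    };
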
@@ -27,8 +27,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
                        "Time in seconds after which inactive sessions will be "
                        "closed.",
                        FLAG_IN_RANGE(1, INT32_MAX));
-DEFINE_string(cert_file, "", "Certificate file to use.");
-DEFINE_string(key_file, "", "Key file to use.");
+DEFINE_string(cert_file, "", "Certificate file to use (Bolt).");
+DEFINE_string(key_file, "", "Key file to use (Bolt).");
 
 using ServerT = communication::Server<BoltSession, SessionData>;
 using communication::ServerContext;
@@ -1,10 +1,14 @@
 #include "raft/coordination.hpp"
 
+#include <gflags/gflags.h>
 #include <json/json.hpp>
 
 #include "utils/file.hpp"
 #include "utils/string.hpp"
 
+DEFINE_string(rpc_cert_file, "", "Certificate file to use (RPC).");
+DEFINE_string(rpc_key_file, "", "Key file to use (RPC).");
+
 namespace raft {
 
 namespace fs = std::filesystem;
@@ -46,15 +50,23 @@ std::unordered_map<uint16_t, io::network::Endpoint> LoadNodesFromFile(
 Coordination::Coordination(
     uint16_t node_id,
     std::unordered_map<uint16_t, io::network::Endpoint> all_nodes)
-    : node_id_(node_id),
-      cluster_size_(all_nodes.size()),
-      server_(all_nodes[node_id], all_nodes.size() * 2) {
+    : node_id_(node_id), cluster_size_(all_nodes.size()) {
+  // Create and initialize all server elements.
+  if (!FLAGS_rpc_cert_file.empty() && !FLAGS_rpc_key_file.empty()) {
+    server_context_.emplace(FLAGS_rpc_key_file, FLAGS_rpc_cert_file);
+  } else {
+    server_context_.emplace();
+  }
+  server_.emplace(all_nodes[node_id_], &server_context_.value(),
+                  all_nodes.size() * 2);
+
  // Create all client elements.
  endpoints_.resize(cluster_size_);
  clients_.resize(cluster_size_);
  client_locks_.resize(cluster_size_);
 
  // Initialize all client elements.
+  client_context_.emplace(server_context_->use_ssl());
  for (uint16_t i = 1; i <= cluster_size_; ++i) {
    auto it = all_nodes.find(i);
    if (it == all_nodes.end()) {
@@ -93,7 +105,7 @@ uint16_t Coordination::GetAllNodeCount() { return cluster_size_; }
 
 uint16_t Coordination::GetOtherNodeCount() { return cluster_size_ - 1; }
 
-bool Coordination::Start() { return server_.Start(); }
+bool Coordination::Start() { return server_->Start(); }
 
 void Coordination::AwaitShutdown(
     std::function<void(void)> call_before_shutdown) {
@@ -106,8 +118,8 @@ void Coordination::AwaitShutdown(
  call_before_shutdown();
 
  // Shutdown our RPC server.
-  server_.Shutdown();
-  server_.AwaitShutdown();
+  server_->Shutdown();
+  server_->AwaitShutdown();
 }
 
 void Coordination::Shutdown() { alive_.store(false); }
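
This is the core HA change: the Raft coordination now reads --rpc_cert_file
and --rpc_key_file, builds an SSL ServerContext when both are set, and
derives the client context from the server's setting, so Raft RPC clients
speak SSL exactly when the server does. Two details worth noting: the
ServerContext constructor takes the key file first and the certificate
second, and setting only one of the two flags silently falls back to
plaintext. The selection logic reduces to:

    if (!FLAGS_rpc_cert_file.empty() && !FLAGS_rpc_key_file.empty()) {
      server_context_.emplace(FLAGS_rpc_key_file, FLAGS_rpc_cert_file);
    } else {
      server_context_.emplace();  // plaintext, even if one flag was set
    }
    client_context_.emplace(server_context_->use_ssl());  // clients follow
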
@@ -80,7 +80,8 @@ class Coordination final {
 
    if (!client) {
      const auto &endpoint = endpoints_[other_id - 1];
-      client = std::make_unique<communication::rpc::Client>(endpoint);
+      client = std::make_unique<communication::rpc::Client>(
+          endpoint, &client_context_.value());
    }
 
    try {
@@ -95,7 +96,7 @@ class Coordination final {
  /// Registers a RPC call on this node.
  template <class TRequestResponse>
  void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) {
-    server_.Register<TRequestResponse>(callback);
+    server_->Register<TRequestResponse>(callback);
  }
 
  /// Registers an extended RPC call on this node.
@@ -103,7 +104,7 @@ class Coordination final {
  void Register(std::function<void(const io::network::Endpoint &, slk::Reader *,
                                   slk::Builder *)>
                    callback) {
-    server_.Register<TRequestResponse>(callback);
+    server_->Register<TRequestResponse>(callback);
  }
 
  /// Starts the coordination and its servers.
@@ -121,8 +122,10 @@ class Coordination final {
  uint16_t node_id_;
  uint16_t cluster_size_;
 
-  communication::rpc::Server server_;
+  std::optional<communication::ServerContext> server_context_;
+  std::optional<communication::rpc::Server> server_;
 
+  std::optional<communication::ClientContext> client_context_;
  std::vector<io::network::Endpoint> endpoints_;
  std::vector<std::unique_ptr<communication::rpc::Client>> clients_;
  std::vector<std::unique_ptr<std::mutex>> client_locks_;
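
In the header, server_ changes from a direct member to std::optional because
its constructor argument (the context) is only chosen at runtime, after flag
parsing; the constructor body emplaces the context first and the server
second, and every use switches from server_.X() to server_->X(). The same
deferred-construction pattern in isolation:

    std::optional<communication::ServerContext> server_context_;
    std::optional<communication::rpc::Server> server_;
    // ... later, in the constructor body:
    server_context_.emplace();  // or emplace(key_file, cert_file)
    server_.emplace(endpoint, &server_context_.value());
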
@@ -46,7 +46,9 @@ void StatsDispatchMain(const io::network::Endpoint &endpoint) {
  LOG(INFO) << "Stats dispatcher thread started";
  utils::ThreadSetName("Stats dispatcher");
 
-  communication::rpc::Client client(endpoint);
+  // TODO(mferencevic): stats are currently hardcoded not to use SSL
+  communication::ClientContext client_context;
+  communication::rpc::Client client(endpoint, &client_context);
 
  BatchStatsReq batch_request;
  batch_request.requests.reserve(MAX_BATCH_SIZE);
@@ -44,10 +44,15 @@ const int kThreadsNum = 16;
 
 DEFINE_string(server_address, "127.0.0.1", "Server address");
 DEFINE_int32(server_port, 0, "Server port");
+DEFINE_string(server_cert_file, "", "Server SSL certificate file");
+DEFINE_string(server_key_file, "", "Server SSL key file");
+DEFINE_bool(benchmark_use_ssl, false, "Set to true to benchmark using SSL");
 DEFINE_bool(run_server, true, "Set to false to use external server");
 DEFINE_bool(run_benchmark, true, "Set to false to only run server");
 
+std::optional<communication::ServerContext> server_context;
 std::optional<communication::rpc::Server> server;
+std::optional<communication::ClientContext> client_context;
 std::optional<communication::rpc::Client> clients[kThreadsNum];
 std::optional<communication::rpc::ClientPool> client_pool;
 std::optional<utils::ThreadPool> thread_pool;
@@ -105,9 +110,15 @@ int main(int argc, char **argv) {
  google::InitGoogleLogging(argv[0]);
 
  if (FLAGS_run_server) {
+    if (!FLAGS_server_cert_file.empty() && !FLAGS_server_key_file.empty()) {
+      FLAGS_benchmark_use_ssl = true;
+      server_context.emplace(FLAGS_server_key_file, FLAGS_server_cert_file);
+    } else {
+      server_context.emplace();
+    }
    server.emplace(
        io::network::Endpoint(FLAGS_server_address, FLAGS_server_port),
-        kThreadsNum);
+        &server_context.value(), kThreadsNum);
 
    server->Register<Echo>([](const auto &req_reader, auto *res_builder) {
      EchoMessage res;
@@ -127,8 +138,10 @@ int main(int argc, char **argv) {
    endpoint = io::network::Endpoint(FLAGS_server_address, FLAGS_server_port);
  }
 
+  client_context.emplace(FLAGS_benchmark_use_ssl);
+
  for (int i = 0; i < kThreadsNum; ++i) {
-    clients[i].emplace(endpoint);
+    clients[i].emplace(endpoint, &client_context.value());
    clients[i]->Call<Echo>("init");
  }
 
@@ -137,7 +150,7 @@ int main(int argc, char **argv) {
  // of making connections to the server during the benchmark here we
  // simultaneously call the Echo RPC on the client pool to make the client
  // pool connect to the server `kThreadsNum` times.
-  client_pool.emplace(endpoint);
+  client_pool.emplace(endpoint, &client_context.value());
  std::thread threads[kThreadsNum];
  for (int i = 0; i < kThreadsNum; ++i) {
    threads[i] =
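
The benchmark gains enough knobs to measure SSL overhead. With the built-in
server, passing --server_cert_file and --server_key_file switches it to SSL
and forces --benchmark_use_ssl on; against an external SSL server
(--run_server=false), --benchmark_use_ssl must be set by hand. All sixteen
direct clients and the client pool then share a single context keyed off
that flag, as in this reduced wiring from the hunks above:

    client_context.emplace(FLAGS_benchmark_use_ssl);
    for (int i = 0; i < kThreadsNum; ++i) {
      clients[i].emplace(endpoint, &client_context.value());
      clients[i]->Call<Echo>("init");  // connect before the timed section
    }
    client_pool.emplace(endpoint, &client_context.value());
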
@@ -15,6 +15,7 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
 
  protected:
  TestMasterCoordination coordination_;
+  communication::ClientContext client_context_;
  std::optional<communication::rpc::ClientPool> master_client_pool_;
  std::optional<storage::MasterConcurrentIdMapper<TId>> master_mapper_;
  std::optional<storage::WorkerConcurrentIdMapper<TId>> worker_mapper_;
@@ -22,7 +23,8 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test {
  void SetUp() override {
    master_mapper_.emplace(&coordination_);
    coordination_.Start();
-    master_client_pool_.emplace(coordination_.GetServerEndpoint());
+    master_client_pool_.emplace(coordination_.GetServerEndpoint(),
+                                &client_context_);
    worker_mapper_.emplace(&master_client_pool_.value());
  }
  void TearDown() override {
@@ -54,7 +54,8 @@ void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) {
 }
 
 TEST(Rpc, Call) {
-  Server server({"127.0.0.1", 0});
+  communication::ServerContext server_context;
+  Server server({"127.0.0.1", 0}, &server_context);
  server.Register<Sum>([](auto *req_reader, auto *res_builder) {
    SumReq req;
    slk::Load(&req, req_reader);
@@ -64,7 +65,8 @@ TEST(Rpc, Call) {
  ASSERT_TRUE(server.Start());
  std::this_thread::sleep_for(100ms);
 
-  Client client(server.endpoint());
+  communication::ClientContext client_context;
+  Client client(server.endpoint(), &client_context);
  auto sum = client.Call<Sum>(10, 20);
  EXPECT_EQ(sum.sum, 30);
 
@@ -73,7 +75,8 @@ TEST(Rpc, Call) {
 }
 
 TEST(Rpc, Abort) {
-  Server server({"127.0.0.1", 0});
+  communication::ServerContext server_context;
+  Server server({"127.0.0.1", 0}, &server_context);
  server.Register<Sum>([](auto *req_reader, auto *res_builder) {
    SumReq req;
    slk::Load(&req, req_reader);
@@ -84,7 +87,8 @@ TEST(Rpc, Abort) {
  ASSERT_TRUE(server.Start());
  std::this_thread::sleep_for(100ms);
 
-  Client client(server.endpoint());
+  communication::ClientContext client_context;
+  Client client(server.endpoint(), &client_context);
 
  std::thread thread([&client]() {
    std::this_thread::sleep_for(100ms);
@@ -104,7 +108,8 @@ TEST(Rpc, Abort) {
 }
 
 TEST(Rpc, ClientPool) {
-  Server server({"127.0.0.1", 0});
+  communication::ServerContext server_context;
+  Server server({"127.0.0.1", 0}, &server_context);
  server.Register<Sum>([](const auto &req_reader, auto *res_builder) {
    SumReq req;
    Load(&req, req_reader);
@@ -115,7 +120,8 @@ TEST(Rpc, ClientPool) {
  ASSERT_TRUE(server.Start());
  std::this_thread::sleep_for(100ms);
 
-  Client client(server.endpoint());
+  communication::ClientContext client_context;
+  Client client(server.endpoint(), &client_context);
 
  // These calls should take more than 400ms because we're using a regular
  // client
@@ -136,7 +142,8 @@ TEST(Rpc, ClientPool) {
 
  EXPECT_GE(t1.Elapsed(), 400ms);
 
-  ClientPool pool(server.endpoint());
+  communication::ClientContext pool_context;
+  ClientPool pool(server.endpoint(), &pool_context);
 
  // These calls shouldn't take much more that 100ms because they execute in
  // parallel
@@ -159,7 +166,8 @@ TEST(Rpc, ClientPool) {
 }
 
 TEST(Rpc, LargeMessage) {
-  Server server({"127.0.0.1", 0});
+  communication::ServerContext server_context;
+  Server server({"127.0.0.1", 0}, &server_context);
  server.Register<Echo>([](auto *req_reader, auto *res_builder) {
    EchoMessage res;
    slk::Load(&res, req_reader);
@@ -170,7 +178,8 @@ TEST(Rpc, LargeMessage) {
 
  std::string testdata(100000, 'a');
 
-  Client client(server.endpoint());
+  communication::ClientContext client_context;
+  Client client(server.endpoint(), &client_context);
  auto echo = client.Call<Echo>(testdata);
  EXPECT_EQ(echo.data, testdata);
 
@@ -179,7 +188,8 @@ TEST(Rpc, LargeMessage) {
 }
 
 TEST(Rpc, JumboMessage) {
-  Server server({"127.0.0.1", 0});
+  communication::ServerContext server_context;
+  Server server({"127.0.0.1", 0}, &server_context);
  server.Register<Echo>([](auto *req_reader, auto *res_builder) {
    EchoMessage res;
    slk::Load(&res, req_reader);
@@ -191,7 +201,8 @@ TEST(Rpc, JumboMessage) {
  // NOLINTNEXTLINE (bugprone-string-constructor)
  std::string testdata(10000000, 'a');
 
-  Client client(server.endpoint());
+  communication::ClientContext client_context;
+  Client client(server.endpoint(), &client_context);
  auto echo = client.Call<Echo>(testdata);
  EXPECT_EQ(echo.data, testdata);
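
Every updated test builds default (plaintext) contexts, so the tests keep
exercising the old behavior through the new API and SSL coverage comes only
from the benchmark. A hypothetical SSL variant of the Call test (not part of
this commit; the key/cert paths are made up, and the handler body follows
the existing tests):

    TEST(Rpc, CallSsl) {
      communication::ServerContext server_context("test.key", "test.cert");
      Server server({"127.0.0.1", 0}, &server_context);
      server.Register<Sum>([](auto *req_reader, auto *res_builder) {
        SumReq req;
        slk::Load(&req, req_reader);
        // ... compute and save the Sum response, as in TEST(Rpc, Call).
      });
      ASSERT_TRUE(server.Start());

      communication::ClientContext client_context(/* use_ssl = */ true);
      Client client(server.endpoint(), &client_context);
      EXPECT_EQ(client.Call<Sum>(10, 20).sum, 30);

      server.Shutdown();
      server.AwaitShutdown();
    }
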
@@ -32,7 +32,10 @@ std::string GraphiteFormat(const stats::StatsReq &req) {
 int main(int argc, char *argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
 
-  communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port});
+  // TODO(mferencevic): stats are currently hardcoded not to use SSL
+  communication::ServerContext server_context;
+  communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port},
+                                    &server_context);
 
  io::network::Socket graphite_socket;