diff --git a/src/communication/rpc/client.cpp b/src/communication/rpc/client.cpp index cc0005951..60dec2914 100644 --- a/src/communication/rpc/client.cpp +++ b/src/communication/rpc/client.cpp @@ -2,7 +2,9 @@ namespace communication::rpc { -Client::Client(const io::network::Endpoint &endpoint) : endpoint_(endpoint) {} +Client::Client(const io::network::Endpoint &endpoint, + communication::ClientContext *context) + : endpoint_(endpoint), context_(context) {} void Client::Abort() { if (!client_) return; diff --git a/src/communication/rpc/client.hpp b/src/communication/rpc/client.hpp index 9bfd29ab7..2ac3e0845 100644 --- a/src/communication/rpc/client.hpp +++ b/src/communication/rpc/client.hpp @@ -19,7 +19,8 @@ namespace communication::rpc { /// Client is thread safe, but it is recommended to use thread_local clients. class Client { public: - explicit Client(const io::network::Endpoint &endpoint); + Client(const io::network::Endpoint &endpoint, + communication::ClientContext *context); /// Call a previously defined and registered RPC call. This function can /// initiate only one request at a time. The call blocks until a response is @@ -61,7 +62,7 @@ class Client { // Connect to the remote server. if (!client_) { - client_.emplace(&context_); + client_.emplace(context_); if (!client_->Connect(endpoint_)) { DLOG(ERROR) << "Couldn't connect to remote address " << endpoint_; client_ = std::nullopt; @@ -121,8 +122,7 @@ class Client { private: io::network::Endpoint endpoint_; - // TODO (mferencevic): currently the RPC client is hardcoded not to use SSL - communication::ClientContext context_; + communication::ClientContext *context_; std::optional<communication::Client> client_; std::mutex mutex_; diff --git a/src/communication/rpc/client_pool.hpp b/src/communication/rpc/client_pool.hpp index 64ee9eb6f..6d447029d 100644 --- a/src/communication/rpc/client_pool.hpp +++ b/src/communication/rpc/client_pool.hpp @@ -14,8 +14,9 @@ namespace communication::rpc { */ class ClientPool { public: - explicit ClientPool(const io::network::Endpoint &endpoint) - : endpoint_(endpoint) {} + ClientPool(const io::network::Endpoint &endpoint, + communication::ClientContext *context) + : endpoint_(endpoint), context_(context) {} template <class TRequestResponse, class... Args> typename TRequestResponse::Response Call(Args &&... args) { @@ -40,7 +41,7 @@ class ClientPool { std::unique_lock<std::mutex> lock(mutex_); if (unused_clients_.empty()) { - client = std::make_unique<Client>(endpoint_); + client = std::make_unique<Client>(endpoint_, context_); } else { client = std::move(unused_clients_.top()); unused_clients_.pop(); @@ -55,6 +56,7 @@ class ClientPool { } io::network::Endpoint endpoint_; + communication::ClientContext *context_; std::mutex mutex_; std::stack<std::unique_ptr<Client>> unused_clients_; diff --git a/src/communication/rpc/server.cpp b/src/communication/rpc/server.cpp index 4f92b98a6..96049960b 100644 --- a/src/communication/rpc/server.cpp +++ b/src/communication/rpc/server.cpp @@ -3,20 +3,15 @@ namespace communication::rpc { Server::Server(const io::network::Endpoint &endpoint, - size_t workers_count) - : server_(endpoint, this, &context_, -1, "RPC", workers_count) {} + communication::ServerContext *context, size_t workers_count) + : server_(endpoint, this, context, -1, context->use_ssl() ? "RPCS" : "RPC", + workers_count) {} -bool Server::Start() { - return server_.Start(); -} +bool Server::Start() { return server_.Start(); } -void Server::Shutdown() { - server_.Shutdown(); -} +void Server::Shutdown() { server_.Shutdown(); } -void Server::AwaitShutdown() { - server_.AwaitShutdown(); -} +void Server::AwaitShutdown() { server_.AwaitShutdown(); } const io::network::Endpoint &Server::endpoint() const { return server_.endpoint(); diff --git a/src/communication/rpc/server.hpp b/src/communication/rpc/server.hpp index e56ea6319..ede066400 100644 --- a/src/communication/rpc/server.hpp +++ b/src/communication/rpc/server.hpp @@ -15,6 +15,7 @@ namespace communication::rpc { class Server { public: Server(const io::network::Endpoint &endpoint, + communication::ServerContext *context, size_t workers_count = std::thread::hardware_concurrency()); Server(const Server &) = delete; Server(Server &&) = delete; @@ -88,8 +89,6 @@ class Server { std::map<uint64_t, RpcCallback> callbacks_; std::map<uint64_t, RpcExtendedCallback> extended_callbacks_; - // TODO (mferencevic): currently the RPC server is hardcoded not to use SSL - communication::ServerContext context_; communication::Server<Session, Server> server_; }; // namespace communication::rpc diff --git a/src/distributed/coordination.cpp b/src/distributed/coordination.cpp index 96a8bfcc0..a0ef1cb9c 100644 --- a/src/distributed/coordination.cpp +++ b/src/distributed/coordination.cpp @@ -10,7 +10,7 @@ Coordination::Coordination(const io::network::Endpoint &worker_endpoint, int worker_id, const io::network::Endpoint &master_endpoint, int server_workers_count, int client_workers_count) - : server_(worker_endpoint, server_workers_count), + : server_(worker_endpoint, &server_context_, server_workers_count), thread_pool_(client_workers_count, "RPC client") { if (worker_id != 0) { // The master is always worker 0. @@ -64,7 +64,7 @@ communication::rpc::ClientPool *Coordination::GetClientPool(int worker_id) { return &client_pools_ .emplace(std::piecewise_construct, std::forward_as_tuple(worker_id), - std::forward_as_tuple(endpoint)) + std::forward_as_tuple(endpoint, &client_context_)) .first->second; } diff --git a/src/distributed/coordination.hpp b/src/distributed/coordination.hpp index 47f7eaa9f..9cd3d8a86 100644 --- a/src/distributed/coordination.hpp +++ b/src/distributed/coordination.hpp @@ -95,6 +95,8 @@ class Coordination { /// Gets a worker name for the given endpoint. std::string GetWorkerName(const io::network::Endpoint &endpoint); + // TODO(mferencevic): distributed is currently hardcoded not to use SSL + communication::ServerContext server_context_; communication::rpc::Server server_; std::atomic<bool> cluster_alive_{true}; @@ -103,6 +105,8 @@ class Coordination { std::unordered_map<int, io::network::Endpoint> workers_; mutable std::mutex lock_; + // TODO(mferencevic): distributed is currently hardcoded not to use SSL + communication::ClientContext client_context_; std::unordered_map<int, communication::rpc::ClientPool> client_pools_; utils::ThreadPool thread_pool_; }; diff --git a/src/memgraph_ha.cpp b/src/memgraph_ha.cpp index 43fac609c..59790c4c7 100644 --- a/src/memgraph_ha.cpp +++ b/src/memgraph_ha.cpp @@ -27,8 +27,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800, "Time in seconds after which inactive sessions will be " "closed.", FLAG_IN_RANGE(1, INT32_MAX)); -DEFINE_string(cert_file, "", "Certificate file to use."); -DEFINE_string(key_file, "", "Key file to use."); +DEFINE_string(cert_file, "", "Certificate file to use (Bolt)."); +DEFINE_string(key_file, "", "Key file to use (Bolt)."); using ServerT = communication::Server<BoltSession, SessionData>; using communication::ServerContext; diff --git a/src/raft/coordination.cpp b/src/raft/coordination.cpp index 746923809..f0f490c1e 100644 --- a/src/raft/coordination.cpp +++ b/src/raft/coordination.cpp @@ -1,10 +1,14 @@ #include "raft/coordination.hpp" +#include <gflags/gflags.h> #include <json/json.hpp> #include "utils/file.hpp" #include "utils/string.hpp" +DEFINE_string(rpc_cert_file, "", "Certificate file to use (RPC)."); +DEFINE_string(rpc_key_file, "", "Key file to use (RPC)."); + namespace raft { namespace fs = std::filesystem; @@ -46,15 +50,23 @@ std::unordered_map<uint16_t, io::network::Endpoint> LoadNodesFromFile( Coordination::Coordination( uint16_t node_id, std::unordered_map<uint16_t, io::network::Endpoint> all_nodes) - : node_id_(node_id), - cluster_size_(all_nodes.size()), - server_(all_nodes[node_id], all_nodes.size() * 2) { + : node_id_(node_id), cluster_size_(all_nodes.size()) { + // Create and initialize all server elements. + if (!FLAGS_rpc_cert_file.empty() && !FLAGS_rpc_key_file.empty()) { + server_context_.emplace(FLAGS_rpc_key_file, FLAGS_rpc_cert_file); + } else { + server_context_.emplace(); + } + server_.emplace(all_nodes[node_id_], &server_context_.value(), + all_nodes.size() * 2); + // Create all client elements. endpoints_.resize(cluster_size_); clients_.resize(cluster_size_); client_locks_.resize(cluster_size_); // Initialize all client elements. + client_context_.emplace(server_context_->use_ssl()); for (uint16_t i = 1; i <= cluster_size_; ++i) { auto it = all_nodes.find(i); if (it == all_nodes.end()) { @@ -93,7 +105,7 @@ uint16_t Coordination::GetAllNodeCount() { return cluster_size_; } uint16_t Coordination::GetOtherNodeCount() { return cluster_size_ - 1; } -bool Coordination::Start() { return server_.Start(); } +bool Coordination::Start() { return server_->Start(); } void Coordination::AwaitShutdown( std::function<void(void)> call_before_shutdown) { @@ -106,8 +118,8 @@ void Coordination::AwaitShutdown( call_before_shutdown(); // Shutdown our RPC server. - server_.Shutdown(); - server_.AwaitShutdown(); + server_->Shutdown(); + server_->AwaitShutdown(); } void Coordination::Shutdown() { alive_.store(false); } diff --git a/src/raft/coordination.hpp b/src/raft/coordination.hpp index 9495076ed..61b75b117 100644 --- a/src/raft/coordination.hpp +++ b/src/raft/coordination.hpp @@ -80,7 +80,8 @@ class Coordination final { if (!client) { const auto &endpoint = endpoints_[other_id - 1]; - client = std::make_unique<communication::rpc::Client>(endpoint); + client = std::make_unique<communication::rpc::Client>( + endpoint, &client_context_.value()); } try { @@ -95,7 +96,7 @@ class Coordination final { /// Registers a RPC call on this node. template <class TRequestResponse> void Register(std::function<void(slk::Reader *, slk::Builder *)> callback) { - server_.Register<TRequestResponse>(callback); + server_->Register<TRequestResponse>(callback); } /// Registers an extended RPC call on this node. @@ -103,7 +104,7 @@ class Coordination final { void Register(std::function<void(const io::network::Endpoint &, slk::Reader *, slk::Builder *)> callback) { - server_.Register<TRequestResponse>(callback); + server_->Register<TRequestResponse>(callback); } /// Starts the coordination and its servers. @@ -121,8 +122,10 @@ class Coordination final { uint16_t node_id_; uint16_t cluster_size_; - communication::rpc::Server server_; + std::optional<communication::ServerContext> server_context_; + std::optional<communication::rpc::Server> server_; + std::optional<communication::ClientContext> client_context_; std::vector<io::network::Endpoint> endpoints_; std::vector<std::unique_ptr<communication::rpc::Client>> clients_; std::vector<std::unique_ptr<std::mutex>> client_locks_; diff --git a/src/stats/stats.cpp b/src/stats/stats.cpp index f91dd7020..43236f18c 100644 --- a/src/stats/stats.cpp +++ b/src/stats/stats.cpp @@ -46,7 +46,9 @@ void StatsDispatchMain(const io::network::Endpoint &endpoint) { LOG(INFO) << "Stats dispatcher thread started"; utils::ThreadSetName("Stats dispatcher"); - communication::rpc::Client client(endpoint); + // TODO(mferencevic): stats are currently hardcoded not to use SSL + communication::ClientContext client_context; + communication::rpc::Client client(endpoint, &client_context); BatchStatsReq batch_request; batch_request.requests.reserve(MAX_BATCH_SIZE); diff --git a/tests/benchmark/rpc.cpp b/tests/benchmark/rpc.cpp index d2d8cd9a9..fe54ab75b 100644 --- a/tests/benchmark/rpc.cpp +++ b/tests/benchmark/rpc.cpp @@ -44,10 +44,15 @@ const int kThreadsNum = 16; DEFINE_string(server_address, "127.0.0.1", "Server address"); DEFINE_int32(server_port, 0, "Server port"); +DEFINE_string(server_cert_file, "", "Server SSL certificate file"); +DEFINE_string(server_key_file, "", "Server SSL key file"); +DEFINE_bool(benchmark_use_ssl, false, "Set to true to benchmark using SSL"); DEFINE_bool(run_server, true, "Set to false to use external server"); DEFINE_bool(run_benchmark, true, "Set to false to only run server"); +std::optional<communication::ServerContext> server_context; std::optional<communication::rpc::Server> server; +std::optional<communication::ClientContext> client_context; std::optional<communication::rpc::Client> clients[kThreadsNum]; std::optional<communication::rpc::ClientPool> client_pool; std::optional<utils::ThreadPool> thread_pool; @@ -105,9 +110,15 @@ int main(int argc, char **argv) { google::InitGoogleLogging(argv[0]); if (FLAGS_run_server) { + if (!FLAGS_server_cert_file.empty() && !FLAGS_server_key_file.empty()) { + FLAGS_benchmark_use_ssl = true; + server_context.emplace(FLAGS_server_key_file, FLAGS_server_cert_file); + } else { + server_context.emplace(); + } server.emplace( io::network::Endpoint(FLAGS_server_address, FLAGS_server_port), - kThreadsNum); + &server_context.value(), kThreadsNum); server->Register<Echo>([](const auto &req_reader, auto *res_builder) { EchoMessage res; @@ -127,8 +138,10 @@ int main(int argc, char **argv) { endpoint = io::network::Endpoint(FLAGS_server_address, FLAGS_server_port); } + client_context.emplace(FLAGS_benchmark_use_ssl); + for (int i = 0; i < kThreadsNum; ++i) { - clients[i].emplace(endpoint); + clients[i].emplace(endpoint, &client_context.value()); clients[i]->Call<Echo>("init"); } @@ -137,7 +150,7 @@ int main(int argc, char **argv) { // of making connections to the server during the benchmark here we // simultaneously call the Echo RPC on the client pool to make the client // pool connect to the server `kThreadsNum` times. - client_pool.emplace(endpoint); + client_pool.emplace(endpoint, &client_context.value()); std::thread threads[kThreadsNum]; for (int i = 0; i < kThreadsNum; ++i) { threads[i] = diff --git a/tests/unit/concurrent_id_mapper_distributed.cpp b/tests/unit/concurrent_id_mapper_distributed.cpp index 5d0162e42..4e85985b3 100644 --- a/tests/unit/concurrent_id_mapper_distributed.cpp +++ b/tests/unit/concurrent_id_mapper_distributed.cpp @@ -15,6 +15,7 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test { protected: TestMasterCoordination coordination_; + communication::ClientContext client_context_; std::optional<communication::rpc::ClientPool> master_client_pool_; std::optional<storage::MasterConcurrentIdMapper<TId>> master_mapper_; std::optional<storage::WorkerConcurrentIdMapper<TId>> worker_mapper_; @@ -22,7 +23,8 @@ class DistributedConcurrentIdMapperTest : public ::testing::Test { void SetUp() override { master_mapper_.emplace(&coordination_); coordination_.Start(); - master_client_pool_.emplace(coordination_.GetServerEndpoint()); + master_client_pool_.emplace(coordination_.GetServerEndpoint(), + &client_context_); worker_mapper_.emplace(&master_client_pool_.value()); } void TearDown() override { diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index fe4e09db4..1fe992c8b 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -54,7 +54,8 @@ void EchoMessage::Save(const EchoMessage &obj, slk::Builder *builder) { } TEST(Rpc, Call) { - Server server({"127.0.0.1", 0}); + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); server.Register<Sum>([](auto *req_reader, auto *res_builder) { SumReq req; slk::Load(&req, req_reader); @@ -64,7 +65,8 @@ TEST(Rpc, Call) { ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); - Client client(server.endpoint()); + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); auto sum = client.Call<Sum>(10, 20); EXPECT_EQ(sum.sum, 30); @@ -73,7 +75,8 @@ TEST(Rpc, Call) { } TEST(Rpc, Abort) { - Server server({"127.0.0.1", 0}); + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); server.Register<Sum>([](auto *req_reader, auto *res_builder) { SumReq req; slk::Load(&req, req_reader); @@ -84,7 +87,8 @@ TEST(Rpc, Abort) { ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); - Client client(server.endpoint()); + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); std::thread thread([&client]() { std::this_thread::sleep_for(100ms); @@ -104,7 +108,8 @@ TEST(Rpc, Abort) { } TEST(Rpc, ClientPool) { - Server server({"127.0.0.1", 0}); + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); server.Register<Sum>([](const auto &req_reader, auto *res_builder) { SumReq req; Load(&req, req_reader); @@ -115,7 +120,8 @@ TEST(Rpc, ClientPool) { ASSERT_TRUE(server.Start()); std::this_thread::sleep_for(100ms); - Client client(server.endpoint()); + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); // These calls should take more than 400ms because we're using a regular // client @@ -136,7 +142,8 @@ TEST(Rpc, ClientPool) { EXPECT_GE(t1.Elapsed(), 400ms); - ClientPool pool(server.endpoint()); + communication::ClientContext pool_context; + ClientPool pool(server.endpoint(), &pool_context); // These calls shouldn't take much more that 100ms because they execute in // parallel @@ -159,7 +166,8 @@ TEST(Rpc, ClientPool) { } TEST(Rpc, LargeMessage) { - Server server({"127.0.0.1", 0}); + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage res; slk::Load(&res, req_reader); @@ -170,7 +178,8 @@ TEST(Rpc, LargeMessage) { std::string testdata(100000, 'a'); - Client client(server.endpoint()); + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); auto echo = client.Call<Echo>(testdata); EXPECT_EQ(echo.data, testdata); @@ -179,7 +188,8 @@ TEST(Rpc, LargeMessage) { } TEST(Rpc, JumboMessage) { - Server server({"127.0.0.1", 0}); + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); server.Register<Echo>([](auto *req_reader, auto *res_builder) { EchoMessage res; slk::Load(&res, req_reader); @@ -191,7 +201,8 @@ TEST(Rpc, JumboMessage) { // NOLINTNEXTLINE (bugprone-string-constructor) std::string testdata(10000000, 'a'); - Client client(server.endpoint()); + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); auto echo = client.Call<Echo>(testdata); EXPECT_EQ(echo.data, testdata); diff --git a/tools/src/mg_statsd/main.cpp b/tools/src/mg_statsd/main.cpp index a4804291f..e9024ed0b 100644 --- a/tools/src/mg_statsd/main.cpp +++ b/tools/src/mg_statsd/main.cpp @@ -32,7 +32,10 @@ std::string GraphiteFormat(const stats::StatsReq &req) { int main(int argc, char *argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, true); - communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port}); + // TODO(mferencevic): stats are currently hardcoded not to use SSL + communication::ServerContext server_context; + communication::rpc::Server server({FLAGS_interface, (uint16_t)FLAGS_port}, + &server_context); io::network::Socket graphite_socket;