diff --git a/src/communication/rpc/rpc.cpp b/src/communication/rpc/rpc.cpp index b799119d4..a6ce67660 100644 --- a/src/communication/rpc/rpc.cpp +++ b/src/communication/rpc/rpc.cpp @@ -128,6 +128,6 @@ void Server::Shutdown() { alive_ = false; stream_->Shutdown(); } -} +} // namespace communication::rpc CEREAL_REGISTER_TYPE(communication::rpc::Request); CEREAL_REGISTER_TYPE(communication::rpc::Response); diff --git a/src/communication/rpc/rpc.hpp b/src/communication/rpc/rpc.hpp index 526c9f1ff..5c6c2ebec 100644 --- a/src/communication/rpc/rpc.hpp +++ b/src/communication/rpc/rpc.hpp @@ -12,7 +12,7 @@ struct RequestResponse { using Response = TResponse; }; -// Client is not thread safe. +// Client is thread safe. class Client { public: Client(messaging::System &system, const std::string &address, uint16_t port, @@ -30,6 +30,7 @@ class Client { "TRequestResponse::Request must be derived from Message"); static_assert(std::is_base_of::value, "TRequestResponse::Response must be derived from Message"); + std::lock_guard lock(lock_); auto response = Call(timeout, std::unique_ptr( std::make_unique(std::forward(args)...))); @@ -50,6 +51,7 @@ class Client { messaging::System &system_; messaging::Writer writer_; std::shared_ptr stream_; + std::mutex lock_; }; class Server { @@ -68,8 +70,9 @@ class Server { typename TRequestResponse::Response>::value, "TRequestResponse::Response must be derived from Message"); auto got = callbacks_.emplace( - typeid(typename TRequestResponse::Request), - [callback = callback](const messaging::Message &base_message) { + typeid(typename TRequestResponse::Request), [callback = callback]( + const messaging::Message + &base_message) { const auto &message = dynamic_cast( base_message); @@ -90,4 +93,4 @@ class Server { callbacks_; std::atomic alive_{true}; }; -} +} // namespace communication::rpc diff --git a/src/transactions/engine_worker.cpp b/src/transactions/engine_worker.cpp index 8c2c1f180..007f7fdd1 100644 --- a/src/transactions/engine_worker.cpp +++ b/src/transactions/engine_worker.cpp @@ -15,7 +15,6 @@ WorkerEngine::WorkerEngine(communication::messaging::System &system, : rpc_client_(system, tx_server_host, tx_server_port, "tx_engine") {} Transaction *WorkerEngine::LocalBegin(transaction_id_t tx_id) { - std::lock_guard guard(rpc_client_lock_); auto accessor = active_.access(); auto found = accessor.find(tx_id); if (found != accessor.end()) return found->second; @@ -30,7 +29,6 @@ Transaction *WorkerEngine::LocalBegin(transaction_id_t tx_id) { } CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const { - std::lock_guard guard(rpc_client_lock_); auto info = clog_.fetch_info(tid); // If we don't know the transaction to be commited nor aborted, ask the // master about it and update the local commit log. @@ -50,18 +48,15 @@ CommitLog::Info WorkerEngine::Info(transaction_id_t tid) const { } Snapshot WorkerEngine::GlobalGcSnapshot() { - std::lock_guard guard(rpc_client_lock_); return std::move(rpc_client_.Call(kRpcTimeout)->member); } Snapshot WorkerEngine::GlobalActiveTransactions() { - std::lock_guard guard(rpc_client_lock_); return std::move( rpc_client_.Call(kRpcTimeout)->member); } bool WorkerEngine::GlobalIsActive(transaction_id_t tid) const { - std::lock_guard guard(rpc_client_lock_); return rpc_client_.Call(kRpcTimeout, tid)->member; } diff --git a/src/transactions/engine_worker.hpp b/src/transactions/engine_worker.hpp index 17c903c12..98dc3ac6f 100644 --- a/src/transactions/engine_worker.hpp +++ b/src/transactions/engine_worker.hpp @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include "communication/messaging/distributed.hpp" #include "communication/rpc/rpc.hpp" @@ -36,6 +36,5 @@ class WorkerEngine : public Engine { // Communication to the transactional master. mutable communication::rpc::Client rpc_client_; - mutable std::mutex rpc_client_lock_; }; } // namespace tx diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index fac41a30a..2f73a5d23 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -7,8 +7,8 @@ #include "communication/rpc/rpc.hpp" #include "gtest/gtest.h" -using communication::messaging::System; using communication::messaging::Message; +using communication::messaging::System; using namespace communication::rpc; using namespace std::literals::chrono_literals;