diff --git a/src/communication/rpc/rpc.cpp b/src/communication/rpc/rpc.cpp index a6ce67660..2bcbbfcd5 100644 --- a/src/communication/rpc/rpc.cpp +++ b/src/communication/rpc/rpc.cpp @@ -109,24 +109,29 @@ Server::Server(messaging::System &system, const std::string &name) void Server::Start() { // TODO: Add logging. - while (alive_) { - auto message = stream_->Await(); - if (!message) continue; - auto *request = dynamic_cast<Request *>(message.get()); - if (!request) continue; - auto &real_request = request->message(); - auto it = callbacks_.find(real_request.type_index()); - if (it == callbacks_.end()) continue; - auto response = it->second(real_request); - messaging::Writer writer(system_, request->address(), request->port(), - request->stream()); - writer.Send<Response>(request->message_id(), std::move(response)); - } + CHECK(started_ == false) << "Server can't be started multiple times"; + started_ = true; + running_thread_ = std::thread([this]() { + while (alive_) { + auto message = stream_->Await(); + if (!message) continue; + auto *request = dynamic_cast<Request *>(message.get()); + if (!request) continue; + auto &real_request = request->message(); + auto it = callbacks_.find(real_request.type_index()); + if (it == callbacks_.end()) continue; + auto response = it->second(real_request); + messaging::Writer writer(system_, request->address(), request->port(), + request->stream()); + writer.Send<Response>(request->message_id(), std::move(response)); + } + }); } void Server::Shutdown() { alive_ = false; stream_->Shutdown(); + if (running_thread_.joinable()) running_thread_.join(); } } // namespace communication::rpc CEREAL_REGISTER_TYPE(communication::rpc::Request); diff --git a/src/communication/rpc/rpc.hpp b/src/communication/rpc/rpc.hpp index 5c6c2ebec..98b186762 100644 --- a/src/communication/rpc/rpc.hpp +++ b/src/communication/rpc/rpc.hpp @@ -92,5 +92,8 @@ class Server { const messaging::Message &)>> callbacks_; std::atomic<bool> alive_{true}; + + std::thread running_thread_; + bool started_{false}; }; } // namespace communication::rpc diff --git a/src/transactions/engine_master.cpp b/src/transactions/engine_master.cpp index c5af8afaa..5d362cee4 100644 --- a/src/transactions/engine_master.cpp +++ b/src/transactions/engine_master.cpp @@ -125,15 +125,12 @@ void MasterEngine::StartServer(communication::messaging::System &system) { return std::make_unique<IsActiveRes>(GlobalIsActive(req.member)); }); - rpc_server_thread_ = std::thread([this] { rpc_server_->Start(); }); + rpc_server_->Start(); } void MasterEngine::StopServer() { CHECK(rpc_server_) << "Can't stop a server that's not running"; rpc_server_->Shutdown(); - if (rpc_server_thread_.joinable()) { - rpc_server_thread_.join(); - } rpc_server_ = std::experimental::nullopt; } } // namespace tx diff --git a/src/transactions/engine_master.hpp b/src/transactions/engine_master.hpp index 3ff9a5c1e..9c26a031b 100644 --- a/src/transactions/engine_master.hpp +++ b/src/transactions/engine_master.hpp @@ -81,6 +81,5 @@ class MasterEngine : public Engine { // Optional RPC server, only used in distributed, not in single_node. std::experimental::optional<communication::rpc::Server> rpc_server_; - std::thread rpc_server_thread_; }; } // namespace tx diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index 2f73a5d23..ff6871ba5 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -44,7 +44,7 @@ TEST(Rpc, Call) { server.Register<Sum>([](const SumReq &request) { return std::make_unique<SumRes>(request.x + request.y); }); - std::thread server_thread([&] { server.Start(); }); + server.Start(); std::this_thread::sleep_for(100ms); System client_system("127.0.0.1", 0); @@ -54,7 +54,6 @@ TEST(Rpc, Call) { EXPECT_EQ(sum->sum, 30); server.Shutdown(); - server_thread.join(); server_system.Shutdown(); client_system.Shutdown(); } @@ -66,7 +65,7 @@ TEST(Rpc, Timeout) { std::this_thread::sleep_for(300ms); return std::make_unique<SumRes>(request.x + request.y); }); - std::thread server_thread([&] { server.Start(); }); + server.Start(); std::this_thread::sleep_for(100ms); System client_system("127.0.0.1", 0); @@ -76,7 +75,6 @@ TEST(Rpc, Timeout) { EXPECT_FALSE(sum); server.Shutdown(); - server_thread.join(); server_system.Shutdown(); client_system.Shutdown(); }