diff --git a/src/communication/rpc/client.cpp b/src/communication/rpc/client.cpp index 8133b70f0..b000f459d 100644 --- a/src/communication/rpc/client.cpp +++ b/src/communication/rpc/client.cpp @@ -127,4 +127,12 @@ std::unique_ptr<Message> Client::Call(std::unique_ptr<Message> request) { } } +void Client::Abort() { + if (!socket_) return; + // We need to call Shutdown on the socket to abort any pending read or + // write operations. + socket_->Shutdown(); + socket_ = std::experimental::nullopt; +} + } // namespace communication::rpc diff --git a/src/communication/rpc/client.hpp b/src/communication/rpc/client.hpp index 7e5448109..598b5745b 100644 --- a/src/communication/rpc/client.hpp +++ b/src/communication/rpc/client.hpp @@ -41,6 +41,9 @@ class Client { return std::unique_ptr<Res>(real_response); } + // Call this function from another thread to abort a pending RPC call. + void Abort(); + private: std::unique_ptr<Message> Call(std::unique_ptr<Message> request); diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index f6697fa5e..bafa7b80c 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -50,6 +50,11 @@ void Socket::Close() { socket_ = -1; } +void Socket::Shutdown() { + if (socket_ == -1) return; + shutdown(socket_, SHUT_RDWR); +} + bool Socket::IsOpen() const { return socket_ != -1; } bool Socket::Connect(const Endpoint &endpoint) { diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index 78490e206..8704bfe2d 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -28,6 +28,11 @@ class Socket { */ void Close(); + /** + * Shutdown the socket if it is open. + */ + void Shutdown(); + /** * Checks whether the socket is open. * diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index dc3b1e40a..6dc973d66 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -14,6 +14,7 @@ #include "communication/rpc/messages.hpp" #include "communication/rpc/server.hpp" #include "gtest/gtest.h" +#include "utils/timer.hpp" using namespace communication::rpc; using namespace std::literals::chrono_literals; @@ -66,18 +67,27 @@ TEST(Rpc, Call) { EXPECT_EQ(sum->sum, 30); } -/* TODO (mferencevic): enable when async calls are implemented! -TEST(Rpc, Timeout) { +TEST(Rpc, Abort) { System server_system({"127.0.0.1", 0}); Server server(server_system, "main"); server.Register<Sum>([](const SumReq &request) { - std::this_thread::sleep_for(300ms); + std::this_thread::sleep_for(500ms); return std::make_unique<SumRes>(request.x + request.y); }); std::this_thread::sleep_for(100ms); Client client(server_system.endpoint(), "main"); - auto sum = client.Call<Sum>(100ms, 10, 20); - EXPECT_FALSE(sum); + + std::thread thread([&client]() { + std::this_thread::sleep_for(100ms); + LOG(INFO) << "Shutting down the connection!"; + client.Abort(); + }); + + utils::Timer timer; + auto sum = client.Call<Sum>(10, 20); + EXPECT_EQ(sum, nullptr); + EXPECT_LT(timer.Elapsed(), 200ms); + + thread.join(); } -*/