From dd9180da3255db7974755503d8e160a68cd01ac7 Mon Sep 17 00:00:00 2001 From: Matej Ferencevic Date: Tue, 26 May 2020 12:47:25 +0200 Subject: [PATCH] Add request streaming support to the RPC client Summary: This change only adds streaming support to the client request. The client response, server request and server response are still handled only when all of the data is received. Reviewers: buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2807 --- src/rpc/client.hpp | 177 +++++++++++++++++++++++++++++++------------- src/slk/streams.cpp | 5 +- tests/unit/rpc.cpp | 83 +++++++++++++++++++++ 3 files changed, 212 insertions(+), 53 deletions(-) diff --git a/src/rpc/client.hpp b/src/rpc/client.hpp index 452f8ddb0..60472968e 100644 --- a/src/rpc/client.hpp +++ b/src/rpc/client.hpp @@ -22,18 +22,104 @@ class Client { Client(const io::network::Endpoint &endpoint, communication::ClientContext *context); - /// Call a previously defined and registered RPC call. This function can - /// initiate only one request at a time. The call blocks until a response is - /// received. + /// Object used to handle streaming of request data to the RPC server. + template + class StreamHandler { + private: + friend class Client; + + StreamHandler( + Client *self, std::unique_lock &&guard, + std::function + res_load) + : self_(self), + guard_(std::move(guard)), + req_builder_([&](const uint8_t *data, size_t size, bool have_more) { + if (!self_->client_->Write(data, size, have_more)) + throw RpcFailedException(self_->endpoint_); + }), + res_load_(res_load) {} + + public: + StreamHandler(StreamHandler &&) noexcept = default; + StreamHandler &operator=(StreamHandler &&) noexcept = default; + + StreamHandler(const StreamHandler &) = delete; + StreamHandler &operator=(const StreamHandler &) = delete; + + ~StreamHandler() {} + + slk::Builder *GetBuilder() { return &req_builder_; } + + typename TRequestResponse::Response AwaitResponse() { + auto res_type = TRequestResponse::Response::kType; + + // Finalize the request. + req_builder_.Finalize(); + + // Receive the response. + uint64_t response_data_size = 0; + while (true) { + auto ret = slk::CheckStreamComplete(self_->client_->GetData(), + self_->client_->GetDataSize()); + if (ret.status == slk::StreamStatus::INVALID) { + throw RpcFailedException(self_->endpoint_); + } else if (ret.status == slk::StreamStatus::PARTIAL) { + if (!self_->client_->Read( + ret.stream_size - self_->client_->GetDataSize(), + /* exactly_len = */ false)) { + throw RpcFailedException(self_->endpoint_); + } + } else { + response_data_size = ret.stream_size; + break; + } + } + + // Load the response. + slk::Reader res_reader(self_->client_->GetData(), response_data_size); + utils::OnScopeExit res_cleanup([&, response_data_size] { + self_->client_->ShiftData(response_data_size); + }); + + uint64_t res_id = 0; + slk::Load(&res_id, &res_reader); + + // Check the response ID. + if (res_id != res_type.id) { + LOG(ERROR) << "Message response was of unexpected type"; + self_->client_ = std::nullopt; + throw RpcFailedException(self_->endpoint_); + } + + VLOG(12) << "[RpcClient] received " << res_type.name; + + return res_load_(&res_reader); + } + + private: + Client *self_; + std::unique_lock guard_; + slk::Builder req_builder_; + std::function res_load_; + }; + + /// Stream a previously defined and registered RPC call. This function can + /// initiate only one request at a time. The call returns a `StreamHandler` + /// object that can be used to send additional data to the request (with the + /// automatically sent `TRequestResponse::Request` object) and await until the + /// response is received from the server. /// - /// @returns TRequestResponse::Response object that was specified to be - /// returned by the RPC call + /// @returns StreamHandler object that is used to handle + /// streaming of additional data to + /// the client and to await the + /// response from the server /// @throws RpcFailedException if an error was occurred while executing the /// RPC call (eg. connection failed, remote end /// died, etc.) template - typename TRequestResponse::Response Call(Args &&... args) { - return CallWithLoad( + StreamHandler Stream(Args &&... args) { + return StreamWithLoad( [](auto *reader) { typename TRequestResponse::Response response; TRequestResponse::Response::Load(&response, reader); @@ -42,17 +128,16 @@ class Client { std::forward(args)...); } - /// Same as `Call` but the first argument is a response loading function. + /// Same as `Stream` but the first argument is a response loading function. template - typename TRequestResponse::Response CallWithLoad( + StreamHandler StreamWithLoad( std::function load, Args &&... args) { typename TRequestResponse::Request request(std::forward(args)...); auto req_type = TRequestResponse::Request::kType; - auto res_type = TRequestResponse::Response::kType; VLOG(12) << "[RpcClient] sent " << req_type.name; - std::lock_guard guard(mutex_); + std::unique_lock guard(mutex_); // Check if the connection is broken (if we haven't used the client for a // long time the server could have died). @@ -70,51 +155,39 @@ class Client { } } + // Create the stream handler. + StreamHandler handler(this, std::move(guard), load); + // Build and send the request. - slk::Builder req_builder( - [&](const uint8_t *data, size_t size, bool have_more) { - client_->Write(data, size, have_more); - }); - slk::Save(req_type.id, &req_builder); - TRequestResponse::Request::Save(request, &req_builder); - req_builder.Finalize(); + slk::Save(req_type.id, handler.GetBuilder()); + TRequestResponse::Request::Save(request, handler.GetBuilder()); - // Receive response. - uint64_t response_data_size = 0; - while (true) { - auto ret = - slk::CheckStreamComplete(client_->GetData(), client_->GetDataSize()); - if (ret.status == slk::StreamStatus::INVALID) { - throw RpcFailedException(endpoint_); - } else if (ret.status == slk::StreamStatus::PARTIAL) { - if (!client_->Read(ret.stream_size - client_->GetDataSize(), - /* exactly_len = */ false)) { - throw RpcFailedException(endpoint_); - } - } else { - response_data_size = ret.stream_size; - break; - } - } + // Return the handler to the user. + return std::move(handler); + } - // Load the response. - slk::Reader res_reader(client_->GetData(), response_data_size); - utils::OnScopeExit res_cleanup( - [&, response_data_size] { client_->ShiftData(response_data_size); }); + /// Call a previously defined and registered RPC call. This function can + /// initiate only one request at a time. The call blocks until a response is + /// received. + /// + /// @returns TRequestResponse::Response object that was specified to be + /// returned by the RPC call + /// @throws RpcFailedException if an error was occurred while executing the + /// RPC call (eg. connection failed, remote end + /// died, etc.) + template + typename TRequestResponse::Response Call(Args &&... args) { + auto stream = Stream(std::forward(args)...); + return stream.AwaitResponse(); + } - uint64_t res_id = 0; - slk::Load(&res_id, &res_reader); - - // Check response ID. - if (res_id != res_type.id) { - LOG(ERROR) << "Message response was of unexpected type"; - client_ = std::nullopt; - throw RpcFailedException(endpoint_); - } - - VLOG(12) << "[RpcClient] received " << res_type.name; - - return load(&res_reader); + /// Same as `Call` but the first argument is a response loading function. + template + typename TRequestResponse::Response CallWithLoad( + std::function load, + Args &&... args) { + auto stream = StreamWithLoad(load, std::forward(args)...); + return stream.AwaitResponse(); } /// Call this function from another thread to abort a pending RPC call. diff --git a/src/slk/streams.cpp b/src/slk/streams.cpp index ed7cb822f..ee7d5a549 100644 --- a/src/slk/streams.cpp +++ b/src/slk/streams.cpp @@ -85,7 +85,6 @@ void Reader::GetSegment(bool should_be_final) { throw SlkReaderException("Size data missing in SLK stream!"); } memcpy(&len, data_ + pos_, sizeof(SegmentSize)); - pos_ += sizeof(SegmentSize); if (should_be_final && len != 0) { throw SlkReaderException( @@ -96,6 +95,10 @@ void Reader::GetSegment(bool should_be_final) { "Got an empty SLK segment when expecting a non-empty segment!"); } + // The position is incremented after the checks above so that the new + // segment can be reread if some of the above checks fail. + pos_ += sizeof(SegmentSize); + if (pos_ + len > size_) { throw SlkReaderException("There isn't enough data in the SLK stream!"); } diff --git a/tests/unit/rpc.cpp b/tests/unit/rpc.cpp index 05e3655d7..b8bd1235c 100644 --- a/tests/unit/rpc.cpp +++ b/tests/unit/rpc.cpp @@ -208,3 +208,86 @@ TEST(Rpc, JumboMessage) { server.Shutdown(); server.AwaitShutdown(); } + +TEST(Rpc, Stream) { + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); + server.Register([](auto *req_reader, auto *res_builder) { + EchoMessage req; + slk::Load(&req, req_reader); + std::string payload; + slk::Load(&payload, req_reader); + EchoMessage res(req.data + payload); + slk::Save(res, res_builder); + }); + ASSERT_TRUE(server.Start()); + std::this_thread::sleep_for(100ms); + + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); + auto stream = client.Stream("hello"); + slk::Save("world", stream.GetBuilder()); + auto echo = stream.AwaitResponse(); + EXPECT_EQ(echo.data, "helloworld"); + + server.Shutdown(); + server.AwaitShutdown(); +} + +TEST(Rpc, StreamLarge) { + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); + server.Register([](auto *req_reader, auto *res_builder) { + EchoMessage req; + slk::Load(&req, req_reader); + std::string payload; + slk::Load(&payload, req_reader); + EchoMessage res(req.data + payload); + slk::Save(res, res_builder); + }); + ASSERT_TRUE(server.Start()); + std::this_thread::sleep_for(100ms); + + std::string testdata1(50000, 'a'); + std::string testdata2(50000, 'b'); + + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); + auto stream = client.Stream(testdata1); + slk::Save(testdata2, stream.GetBuilder()); + auto echo = stream.AwaitResponse(); + EXPECT_EQ(echo.data, testdata1 + testdata2); + + server.Shutdown(); + server.AwaitShutdown(); +} + +TEST(Rpc, StreamJumbo) { + communication::ServerContext server_context; + Server server({"127.0.0.1", 0}, &server_context); + server.Register([](auto *req_reader, auto *res_builder) { + EchoMessage req; + slk::Load(&req, req_reader); + std::string payload; + slk::Load(&payload, req_reader); + EchoMessage res(req.data + payload); + slk::Save(res, res_builder); + }); + ASSERT_TRUE(server.Start()); + std::this_thread::sleep_for(100ms); + + // NOLINTNEXTLINE (bugprone-string-constructor) + std::string testdata1(5000000, 'a'); + // NOLINTNEXTLINE (bugprone-string-constructor) + std::string testdata2(5000000, 'b'); + + communication::ClientContext client_context; + Client client(server.endpoint(), &client_context); + auto stream = client.Stream(testdata1); + slk::Save(testdata2, stream.GetBuilder()); + auto echo = stream.AwaitResponse(); + EXPECT_EQ(echo.data, testdata1 + testdata2); + + server.Shutdown(); + server.AwaitShutdown(); +}