Add request streaming support to the RPC client

Summary:
This change only adds streaming support to the client request. The client
response, server request and server response are still handled only when all of
the data is received.

Reviewers: buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2807
This commit is contained in:
Matej Ferencevic 2020-05-26 12:47:25 +02:00
parent aaf0c1ca08
commit dd9180da32
3 changed files with 212 additions and 53 deletions

View File

@ -22,18 +22,104 @@ class Client {
Client(const io::network::Endpoint &endpoint,
communication::ClientContext *context);
/// Call a previously defined and registered RPC call. This function can
/// initiate only one request at a time. The call blocks until a response is
/// received.
/// Object used to handle streaming of request data to the RPC server.
template <class TRequestResponse>
class StreamHandler {
private:
friend class Client;
StreamHandler(
Client *self, std::unique_lock<std::mutex> &&guard,
std::function<typename TRequestResponse::Response(slk::Reader *)>
res_load)
: self_(self),
guard_(std::move(guard)),
req_builder_([&](const uint8_t *data, size_t size, bool have_more) {
if (!self_->client_->Write(data, size, have_more))
throw RpcFailedException(self_->endpoint_);
}),
res_load_(res_load) {}
public:
StreamHandler(StreamHandler &&) noexcept = default;
StreamHandler &operator=(StreamHandler &&) noexcept = default;
StreamHandler(const StreamHandler &) = delete;
StreamHandler &operator=(const StreamHandler &) = delete;
~StreamHandler() {}
slk::Builder *GetBuilder() { return &req_builder_; }
typename TRequestResponse::Response AwaitResponse() {
auto res_type = TRequestResponse::Response::kType;
// Finalize the request.
req_builder_.Finalize();
// Receive the response.
uint64_t response_data_size = 0;
while (true) {
auto ret = slk::CheckStreamComplete(self_->client_->GetData(),
self_->client_->GetDataSize());
if (ret.status == slk::StreamStatus::INVALID) {
throw RpcFailedException(self_->endpoint_);
} else if (ret.status == slk::StreamStatus::PARTIAL) {
if (!self_->client_->Read(
ret.stream_size - self_->client_->GetDataSize(),
/* exactly_len = */ false)) {
throw RpcFailedException(self_->endpoint_);
}
} else {
response_data_size = ret.stream_size;
break;
}
}
// Load the response.
slk::Reader res_reader(self_->client_->GetData(), response_data_size);
utils::OnScopeExit res_cleanup([&, response_data_size] {
self_->client_->ShiftData(response_data_size);
});
uint64_t res_id = 0;
slk::Load(&res_id, &res_reader);
// Check the response ID.
if (res_id != res_type.id) {
LOG(ERROR) << "Message response was of unexpected type";
self_->client_ = std::nullopt;
throw RpcFailedException(self_->endpoint_);
}
VLOG(12) << "[RpcClient] received " << res_type.name;
return res_load_(&res_reader);
}
private:
Client *self_;
std::unique_lock<std::mutex> guard_;
slk::Builder req_builder_;
std::function<typename TRequestResponse::Response(slk::Reader *)> res_load_;
};
/// Stream a previously defined and registered RPC call. This function can
/// initiate only one request at a time. The call returns a `StreamHandler`
/// object that can be used to send additional data to the request (with the
/// automatically sent `TRequestResponse::Request` object) and await until the
/// response is received from the server.
///
/// @returns TRequestResponse::Response object that was specified to be
/// returned by the RPC call
/// @returns StreamHandler<TRequestResponse> object that is used to handle
/// streaming of additional data to
/// the client and to await the
/// response from the server
/// @throws RpcFailedException if an error was occurred while executing the
/// RPC call (eg. connection failed, remote end
/// died, etc.)
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response Call(Args &&... args) {
return CallWithLoad<TRequestResponse>(
StreamHandler<TRequestResponse> Stream(Args &&... args) {
return StreamWithLoad<TRequestResponse>(
[](auto *reader) {
typename TRequestResponse::Response response;
TRequestResponse::Response::Load(&response, reader);
@ -42,17 +128,16 @@ class Client {
std::forward<Args>(args)...);
}
/// Same as `Call` but the first argument is a response loading function.
/// Same as `Stream` but the first argument is a response loading function.
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response CallWithLoad(
StreamHandler<TRequestResponse> StreamWithLoad(
std::function<typename TRequestResponse::Response(slk::Reader *)> load,
Args &&... args) {
typename TRequestResponse::Request request(std::forward<Args>(args)...);
auto req_type = TRequestResponse::Request::kType;
auto res_type = TRequestResponse::Response::kType;
VLOG(12) << "[RpcClient] sent " << req_type.name;
std::lock_guard<std::mutex> guard(mutex_);
std::unique_lock<std::mutex> guard(mutex_);
// Check if the connection is broken (if we haven't used the client for a
// long time the server could have died).
@ -70,51 +155,39 @@ class Client {
}
}
// Create the stream handler.
StreamHandler<TRequestResponse> handler(this, std::move(guard), load);
// Build and send the request.
slk::Builder req_builder(
[&](const uint8_t *data, size_t size, bool have_more) {
client_->Write(data, size, have_more);
});
slk::Save(req_type.id, &req_builder);
TRequestResponse::Request::Save(request, &req_builder);
req_builder.Finalize();
slk::Save(req_type.id, handler.GetBuilder());
TRequestResponse::Request::Save(request, handler.GetBuilder());
// Receive response.
uint64_t response_data_size = 0;
while (true) {
auto ret =
slk::CheckStreamComplete(client_->GetData(), client_->GetDataSize());
if (ret.status == slk::StreamStatus::INVALID) {
throw RpcFailedException(endpoint_);
} else if (ret.status == slk::StreamStatus::PARTIAL) {
if (!client_->Read(ret.stream_size - client_->GetDataSize(),
/* exactly_len = */ false)) {
throw RpcFailedException(endpoint_);
}
} else {
response_data_size = ret.stream_size;
break;
}
}
// Return the handler to the user.
return std::move(handler);
}
// Load the response.
slk::Reader res_reader(client_->GetData(), response_data_size);
utils::OnScopeExit res_cleanup(
[&, response_data_size] { client_->ShiftData(response_data_size); });
/// Call a previously defined and registered RPC call. This function can
/// initiate only one request at a time. The call blocks until a response is
/// received.
///
/// @returns TRequestResponse::Response object that was specified to be
/// returned by the RPC call
/// @throws RpcFailedException if an error was occurred while executing the
/// RPC call (eg. connection failed, remote end
/// died, etc.)
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response Call(Args &&... args) {
auto stream = Stream<TRequestResponse>(std::forward<Args>(args)...);
return stream.AwaitResponse();
}
uint64_t res_id = 0;
slk::Load(&res_id, &res_reader);
// Check response ID.
if (res_id != res_type.id) {
LOG(ERROR) << "Message response was of unexpected type";
client_ = std::nullopt;
throw RpcFailedException(endpoint_);
}
VLOG(12) << "[RpcClient] received " << res_type.name;
return load(&res_reader);
/// Same as `Call` but the first argument is a response loading function.
template <class TRequestResponse, class... Args>
typename TRequestResponse::Response CallWithLoad(
std::function<typename TRequestResponse::Response(slk::Reader *)> load,
Args &&... args) {
auto stream = StreamWithLoad(load, std::forward<Args>(args)...);
return stream.AwaitResponse();
}
/// Call this function from another thread to abort a pending RPC call.

View File

@ -85,7 +85,6 @@ void Reader::GetSegment(bool should_be_final) {
throw SlkReaderException("Size data missing in SLK stream!");
}
memcpy(&len, data_ + pos_, sizeof(SegmentSize));
pos_ += sizeof(SegmentSize);
if (should_be_final && len != 0) {
throw SlkReaderException(
@ -96,6 +95,10 @@ void Reader::GetSegment(bool should_be_final) {
"Got an empty SLK segment when expecting a non-empty segment!");
}
// The position is incremented after the checks above so that the new
// segment can be reread if some of the above checks fail.
pos_ += sizeof(SegmentSize);
if (pos_ + len > size_) {
throw SlkReaderException("There isn't enough data in the SLK stream!");
}

View File

@ -208,3 +208,86 @@ TEST(Rpc, JumboMessage) {
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, Stream) {
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage req;
slk::Load(&req, req_reader);
std::string payload;
slk::Load(&payload, req_reader);
EchoMessage res(req.data + payload);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto stream = client.Stream<Echo>("hello");
slk::Save("world", stream.GetBuilder());
auto echo = stream.AwaitResponse();
EXPECT_EQ(echo.data, "helloworld");
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, StreamLarge) {
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage req;
slk::Load(&req, req_reader);
std::string payload;
slk::Load(&payload, req_reader);
EchoMessage res(req.data + payload);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
std::string testdata1(50000, 'a');
std::string testdata2(50000, 'b');
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto stream = client.Stream<Echo>(testdata1);
slk::Save(testdata2, stream.GetBuilder());
auto echo = stream.AwaitResponse();
EXPECT_EQ(echo.data, testdata1 + testdata2);
server.Shutdown();
server.AwaitShutdown();
}
TEST(Rpc, StreamJumbo) {
communication::ServerContext server_context;
Server server({"127.0.0.1", 0}, &server_context);
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
EchoMessage req;
slk::Load(&req, req_reader);
std::string payload;
slk::Load(&payload, req_reader);
EchoMessage res(req.data + payload);
slk::Save(res, res_builder);
});
ASSERT_TRUE(server.Start());
std::this_thread::sleep_for(100ms);
// NOLINTNEXTLINE (bugprone-string-constructor)
std::string testdata1(5000000, 'a');
// NOLINTNEXTLINE (bugprone-string-constructor)
std::string testdata2(5000000, 'b');
communication::ClientContext client_context;
Client client(server.endpoint(), &client_context);
auto stream = client.Stream<Echo>(testdata1);
slk::Save(testdata2, stream.GetBuilder());
auto echo = stream.AwaitResponse();
EXPECT_EQ(echo.data, testdata1 + testdata2);
server.Shutdown();
server.AwaitShutdown();
}