Add request streaming support to the RPC client
Summary: This change only adds streaming support to the client request. The client response, server request and server response are still handled only when all of the data is received. Reviewers: buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2807
This commit is contained in:
parent
aaf0c1ca08
commit
dd9180da32
@ -22,18 +22,104 @@ class Client {
|
||||
Client(const io::network::Endpoint &endpoint,
|
||||
communication::ClientContext *context);
|
||||
|
||||
/// Call a previously defined and registered RPC call. This function can
|
||||
/// initiate only one request at a time. The call blocks until a response is
|
||||
/// received.
|
||||
/// Object used to handle streaming of request data to the RPC server.
|
||||
template <class TRequestResponse>
|
||||
class StreamHandler {
|
||||
private:
|
||||
friend class Client;
|
||||
|
||||
StreamHandler(
|
||||
Client *self, std::unique_lock<std::mutex> &&guard,
|
||||
std::function<typename TRequestResponse::Response(slk::Reader *)>
|
||||
res_load)
|
||||
: self_(self),
|
||||
guard_(std::move(guard)),
|
||||
req_builder_([&](const uint8_t *data, size_t size, bool have_more) {
|
||||
if (!self_->client_->Write(data, size, have_more))
|
||||
throw RpcFailedException(self_->endpoint_);
|
||||
}),
|
||||
res_load_(res_load) {}
|
||||
|
||||
public:
|
||||
StreamHandler(StreamHandler &&) noexcept = default;
|
||||
StreamHandler &operator=(StreamHandler &&) noexcept = default;
|
||||
|
||||
StreamHandler(const StreamHandler &) = delete;
|
||||
StreamHandler &operator=(const StreamHandler &) = delete;
|
||||
|
||||
~StreamHandler() {}
|
||||
|
||||
slk::Builder *GetBuilder() { return &req_builder_; }
|
||||
|
||||
typename TRequestResponse::Response AwaitResponse() {
|
||||
auto res_type = TRequestResponse::Response::kType;
|
||||
|
||||
// Finalize the request.
|
||||
req_builder_.Finalize();
|
||||
|
||||
// Receive the response.
|
||||
uint64_t response_data_size = 0;
|
||||
while (true) {
|
||||
auto ret = slk::CheckStreamComplete(self_->client_->GetData(),
|
||||
self_->client_->GetDataSize());
|
||||
if (ret.status == slk::StreamStatus::INVALID) {
|
||||
throw RpcFailedException(self_->endpoint_);
|
||||
} else if (ret.status == slk::StreamStatus::PARTIAL) {
|
||||
if (!self_->client_->Read(
|
||||
ret.stream_size - self_->client_->GetDataSize(),
|
||||
/* exactly_len = */ false)) {
|
||||
throw RpcFailedException(self_->endpoint_);
|
||||
}
|
||||
} else {
|
||||
response_data_size = ret.stream_size;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Load the response.
|
||||
slk::Reader res_reader(self_->client_->GetData(), response_data_size);
|
||||
utils::OnScopeExit res_cleanup([&, response_data_size] {
|
||||
self_->client_->ShiftData(response_data_size);
|
||||
});
|
||||
|
||||
uint64_t res_id = 0;
|
||||
slk::Load(&res_id, &res_reader);
|
||||
|
||||
// Check the response ID.
|
||||
if (res_id != res_type.id) {
|
||||
LOG(ERROR) << "Message response was of unexpected type";
|
||||
self_->client_ = std::nullopt;
|
||||
throw RpcFailedException(self_->endpoint_);
|
||||
}
|
||||
|
||||
VLOG(12) << "[RpcClient] received " << res_type.name;
|
||||
|
||||
return res_load_(&res_reader);
|
||||
}
|
||||
|
||||
private:
|
||||
Client *self_;
|
||||
std::unique_lock<std::mutex> guard_;
|
||||
slk::Builder req_builder_;
|
||||
std::function<typename TRequestResponse::Response(slk::Reader *)> res_load_;
|
||||
};
|
||||
|
||||
/// Stream a previously defined and registered RPC call. This function can
|
||||
/// initiate only one request at a time. The call returns a `StreamHandler`
|
||||
/// object that can be used to send additional data to the request (with the
|
||||
/// automatically sent `TRequestResponse::Request` object) and await until the
|
||||
/// response is received from the server.
|
||||
///
|
||||
/// @returns TRequestResponse::Response object that was specified to be
|
||||
/// returned by the RPC call
|
||||
/// @returns StreamHandler<TRequestResponse> object that is used to handle
|
||||
/// streaming of additional data to
|
||||
/// the client and to await the
|
||||
/// response from the server
|
||||
/// @throws RpcFailedException if an error was occurred while executing the
|
||||
/// RPC call (eg. connection failed, remote end
|
||||
/// died, etc.)
|
||||
template <class TRequestResponse, class... Args>
|
||||
typename TRequestResponse::Response Call(Args &&... args) {
|
||||
return CallWithLoad<TRequestResponse>(
|
||||
StreamHandler<TRequestResponse> Stream(Args &&... args) {
|
||||
return StreamWithLoad<TRequestResponse>(
|
||||
[](auto *reader) {
|
||||
typename TRequestResponse::Response response;
|
||||
TRequestResponse::Response::Load(&response, reader);
|
||||
@ -42,17 +128,16 @@ class Client {
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Same as `Call` but the first argument is a response loading function.
|
||||
/// Same as `Stream` but the first argument is a response loading function.
|
||||
template <class TRequestResponse, class... Args>
|
||||
typename TRequestResponse::Response CallWithLoad(
|
||||
StreamHandler<TRequestResponse> StreamWithLoad(
|
||||
std::function<typename TRequestResponse::Response(slk::Reader *)> load,
|
||||
Args &&... args) {
|
||||
typename TRequestResponse::Request request(std::forward<Args>(args)...);
|
||||
auto req_type = TRequestResponse::Request::kType;
|
||||
auto res_type = TRequestResponse::Response::kType;
|
||||
VLOG(12) << "[RpcClient] sent " << req_type.name;
|
||||
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
// Check if the connection is broken (if we haven't used the client for a
|
||||
// long time the server could have died).
|
||||
@ -70,51 +155,39 @@ class Client {
|
||||
}
|
||||
}
|
||||
|
||||
// Create the stream handler.
|
||||
StreamHandler<TRequestResponse> handler(this, std::move(guard), load);
|
||||
|
||||
// Build and send the request.
|
||||
slk::Builder req_builder(
|
||||
[&](const uint8_t *data, size_t size, bool have_more) {
|
||||
client_->Write(data, size, have_more);
|
||||
});
|
||||
slk::Save(req_type.id, &req_builder);
|
||||
TRequestResponse::Request::Save(request, &req_builder);
|
||||
req_builder.Finalize();
|
||||
slk::Save(req_type.id, handler.GetBuilder());
|
||||
TRequestResponse::Request::Save(request, handler.GetBuilder());
|
||||
|
||||
// Receive response.
|
||||
uint64_t response_data_size = 0;
|
||||
while (true) {
|
||||
auto ret =
|
||||
slk::CheckStreamComplete(client_->GetData(), client_->GetDataSize());
|
||||
if (ret.status == slk::StreamStatus::INVALID) {
|
||||
throw RpcFailedException(endpoint_);
|
||||
} else if (ret.status == slk::StreamStatus::PARTIAL) {
|
||||
if (!client_->Read(ret.stream_size - client_->GetDataSize(),
|
||||
/* exactly_len = */ false)) {
|
||||
throw RpcFailedException(endpoint_);
|
||||
}
|
||||
} else {
|
||||
response_data_size = ret.stream_size;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Return the handler to the user.
|
||||
return std::move(handler);
|
||||
}
|
||||
|
||||
// Load the response.
|
||||
slk::Reader res_reader(client_->GetData(), response_data_size);
|
||||
utils::OnScopeExit res_cleanup(
|
||||
[&, response_data_size] { client_->ShiftData(response_data_size); });
|
||||
/// Call a previously defined and registered RPC call. This function can
|
||||
/// initiate only one request at a time. The call blocks until a response is
|
||||
/// received.
|
||||
///
|
||||
/// @returns TRequestResponse::Response object that was specified to be
|
||||
/// returned by the RPC call
|
||||
/// @throws RpcFailedException if an error was occurred while executing the
|
||||
/// RPC call (eg. connection failed, remote end
|
||||
/// died, etc.)
|
||||
template <class TRequestResponse, class... Args>
|
||||
typename TRequestResponse::Response Call(Args &&... args) {
|
||||
auto stream = Stream<TRequestResponse>(std::forward<Args>(args)...);
|
||||
return stream.AwaitResponse();
|
||||
}
|
||||
|
||||
uint64_t res_id = 0;
|
||||
slk::Load(&res_id, &res_reader);
|
||||
|
||||
// Check response ID.
|
||||
if (res_id != res_type.id) {
|
||||
LOG(ERROR) << "Message response was of unexpected type";
|
||||
client_ = std::nullopt;
|
||||
throw RpcFailedException(endpoint_);
|
||||
}
|
||||
|
||||
VLOG(12) << "[RpcClient] received " << res_type.name;
|
||||
|
||||
return load(&res_reader);
|
||||
/// Same as `Call` but the first argument is a response loading function.
|
||||
template <class TRequestResponse, class... Args>
|
||||
typename TRequestResponse::Response CallWithLoad(
|
||||
std::function<typename TRequestResponse::Response(slk::Reader *)> load,
|
||||
Args &&... args) {
|
||||
auto stream = StreamWithLoad(load, std::forward<Args>(args)...);
|
||||
return stream.AwaitResponse();
|
||||
}
|
||||
|
||||
/// Call this function from another thread to abort a pending RPC call.
|
||||
|
@ -85,7 +85,6 @@ void Reader::GetSegment(bool should_be_final) {
|
||||
throw SlkReaderException("Size data missing in SLK stream!");
|
||||
}
|
||||
memcpy(&len, data_ + pos_, sizeof(SegmentSize));
|
||||
pos_ += sizeof(SegmentSize);
|
||||
|
||||
if (should_be_final && len != 0) {
|
||||
throw SlkReaderException(
|
||||
@ -96,6 +95,10 @@ void Reader::GetSegment(bool should_be_final) {
|
||||
"Got an empty SLK segment when expecting a non-empty segment!");
|
||||
}
|
||||
|
||||
// The position is incremented after the checks above so that the new
|
||||
// segment can be reread if some of the above checks fail.
|
||||
pos_ += sizeof(SegmentSize);
|
||||
|
||||
if (pos_ + len > size_) {
|
||||
throw SlkReaderException("There isn't enough data in the SLK stream!");
|
||||
}
|
||||
|
@ -208,3 +208,86 @@ TEST(Rpc, JumboMessage) {
|
||||
server.Shutdown();
|
||||
server.AwaitShutdown();
|
||||
}
|
||||
|
||||
TEST(Rpc, Stream) {
|
||||
communication::ServerContext server_context;
|
||||
Server server({"127.0.0.1", 0}, &server_context);
|
||||
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
|
||||
EchoMessage req;
|
||||
slk::Load(&req, req_reader);
|
||||
std::string payload;
|
||||
slk::Load(&payload, req_reader);
|
||||
EchoMessage res(req.data + payload);
|
||||
slk::Save(res, res_builder);
|
||||
});
|
||||
ASSERT_TRUE(server.Start());
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
communication::ClientContext client_context;
|
||||
Client client(server.endpoint(), &client_context);
|
||||
auto stream = client.Stream<Echo>("hello");
|
||||
slk::Save("world", stream.GetBuilder());
|
||||
auto echo = stream.AwaitResponse();
|
||||
EXPECT_EQ(echo.data, "helloworld");
|
||||
|
||||
server.Shutdown();
|
||||
server.AwaitShutdown();
|
||||
}
|
||||
|
||||
TEST(Rpc, StreamLarge) {
|
||||
communication::ServerContext server_context;
|
||||
Server server({"127.0.0.1", 0}, &server_context);
|
||||
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
|
||||
EchoMessage req;
|
||||
slk::Load(&req, req_reader);
|
||||
std::string payload;
|
||||
slk::Load(&payload, req_reader);
|
||||
EchoMessage res(req.data + payload);
|
||||
slk::Save(res, res_builder);
|
||||
});
|
||||
ASSERT_TRUE(server.Start());
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
std::string testdata1(50000, 'a');
|
||||
std::string testdata2(50000, 'b');
|
||||
|
||||
communication::ClientContext client_context;
|
||||
Client client(server.endpoint(), &client_context);
|
||||
auto stream = client.Stream<Echo>(testdata1);
|
||||
slk::Save(testdata2, stream.GetBuilder());
|
||||
auto echo = stream.AwaitResponse();
|
||||
EXPECT_EQ(echo.data, testdata1 + testdata2);
|
||||
|
||||
server.Shutdown();
|
||||
server.AwaitShutdown();
|
||||
}
|
||||
|
||||
TEST(Rpc, StreamJumbo) {
|
||||
communication::ServerContext server_context;
|
||||
Server server({"127.0.0.1", 0}, &server_context);
|
||||
server.Register<Echo>([](auto *req_reader, auto *res_builder) {
|
||||
EchoMessage req;
|
||||
slk::Load(&req, req_reader);
|
||||
std::string payload;
|
||||
slk::Load(&payload, req_reader);
|
||||
EchoMessage res(req.data + payload);
|
||||
slk::Save(res, res_builder);
|
||||
});
|
||||
ASSERT_TRUE(server.Start());
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
// NOLINTNEXTLINE (bugprone-string-constructor)
|
||||
std::string testdata1(5000000, 'a');
|
||||
// NOLINTNEXTLINE (bugprone-string-constructor)
|
||||
std::string testdata2(5000000, 'b');
|
||||
|
||||
communication::ClientContext client_context;
|
||||
Client client(server.endpoint(), &client_context);
|
||||
auto stream = client.Stream<Echo>(testdata1);
|
||||
slk::Save(testdata2, stream.GetBuilder());
|
||||
auto echo = stream.AwaitResponse();
|
||||
EXPECT_EQ(echo.data, testdata1 + testdata2);
|
||||
|
||||
server.Shutdown();
|
||||
server.AwaitShutdown();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user