// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // License, and you may not use this file except in compliance with the Business Source License. // // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. #pragma once #include #include #include #include "communication/client.hpp" #include "io/network/endpoint.hpp" #include "rpc/exceptions.hpp" #include "rpc/messages.hpp" #include "slk/serialization.hpp" #include "slk/streams.hpp" #include "utils/logging.hpp" #include "utils/on_scope_exit.hpp" #include "utils/typeinfo.hpp" namespace memgraph::rpc { /// Client is thread safe, but it is recommended to use thread_local clients. class Client { public: Client(io::network::Endpoint endpoint, communication::ClientContext *context); /// Object used to handle streaming of request data to the RPC server. template class StreamHandler { private: friend class Client; StreamHandler(Client *self, std::unique_lock &&guard, std::function res_load) : self_(self), guard_(std::move(guard)), req_builder_([self](const uint8_t *data, size_t size, bool have_more) { if (!self->client_->Write(data, size, have_more)) throw RpcFailedException(self->endpoint_); }), res_load_(res_load) {} public: StreamHandler(StreamHandler &&) noexcept = default; StreamHandler &operator=(StreamHandler &&) noexcept = default; StreamHandler(const StreamHandler &) = delete; StreamHandler &operator=(const StreamHandler &) = delete; ~StreamHandler() {} slk::Builder *GetBuilder() { return &req_builder_; } typename TRequestResponse::Response AwaitResponse() { auto res_type = TRequestResponse::Response::kType; // Finalize the request. req_builder_.Finalize(); // Receive the response. uint64_t response_data_size = 0; while (true) { auto ret = slk::CheckStreamComplete(self_->client_->GetData(), self_->client_->GetDataSize()); if (ret.status == slk::StreamStatus::INVALID) { throw RpcFailedException(self_->endpoint_); } else if (ret.status == slk::StreamStatus::PARTIAL) { if (!self_->client_->Read(ret.stream_size - self_->client_->GetDataSize(), /* exactly_len = */ false)) { throw RpcFailedException(self_->endpoint_); } } else { response_data_size = ret.stream_size; break; } } // Load the response. slk::Reader res_reader(self_->client_->GetData(), response_data_size); utils::OnScopeExit res_cleanup([&, response_data_size] { self_->client_->ShiftData(response_data_size); }); utils::TypeId res_id{utils::TypeId::UNKNOWN}; slk::Load(&res_id, &res_reader); // Check the response ID. if (res_id != res_type.id && res_id != utils::TypeId::UNKNOWN) { spdlog::error("Message response was of unexpected type"); self_->client_ = std::nullopt; throw RpcFailedException(self_->endpoint_); } SPDLOG_TRACE("[RpcClient] received {}", res_type.name); return res_load_(&res_reader); } private: Client *self_; std::unique_lock guard_; slk::Builder req_builder_; std::function res_load_; }; /// Stream a previously defined and registered RPC call. This function can /// initiate only one request at a time. The call returns a `StreamHandler` /// object that can be used to send additional data to the request (with the /// automatically sent `TRequestResponse::Request` object) and await until the /// response is received from the server. /// /// @returns StreamHandler object that is used to handle /// streaming of additional data to /// the client and to await the /// response from the server /// @throws RpcFailedException if an error was occurred while executing the /// RPC call (eg. connection failed, remote end /// died, etc.) template StreamHandler Stream(Args &&...args) { return StreamWithLoad( [](auto *reader) { typename TRequestResponse::Response response; TRequestResponse::Response::Load(&response, reader); return response; }, std::forward(args)...); } /// Same as `Stream` but the first argument is a response loading function. template StreamHandler StreamWithLoad(std::function load, Args &&...args) { typename TRequestResponse::Request request(std::forward(args)...); auto req_type = TRequestResponse::Request::kType; SPDLOG_TRACE("[RpcClient] sent {}", req_type.name); std::unique_lock guard(mutex_); // Check if the connection is broken (if we haven't used the client for a // long time the server could have died). if (client_ && client_->ErrorStatus()) { client_ = std::nullopt; } // Connect to the remote server. if (!client_) { client_.emplace(context_); if (!client_->Connect(endpoint_)) { SPDLOG_ERROR("Couldn't connect to remote address {}", endpoint_); client_ = std::nullopt; throw RpcFailedException(endpoint_); } } // Create the stream handler. StreamHandler handler(this, std::move(guard), load); // Build and send the request. slk::Save(req_type.id, handler.GetBuilder()); TRequestResponse::Request::Save(request, handler.GetBuilder()); // Return the handler to the user. return std::move(handler); } /// Call a previously defined and registered RPC call. This function can /// initiate only one request at a time. The call blocks until a response is /// received. /// /// @returns TRequestResponse::Response object that was specified to be /// returned by the RPC call /// @throws RpcFailedException if an error was occurred while executing the /// RPC call (eg. connection failed, remote end /// died, etc.) template typename TRequestResponse::Response Call(Args &&...args) { auto stream = Stream(std::forward(args)...); return stream.AwaitResponse(); } /// Same as `Call` but the first argument is a response loading function. template typename TRequestResponse::Response CallWithLoad( std::function load, Args &&...args) { auto stream = StreamWithLoad(load, std::forward(args)...); return stream.AwaitResponse(); } /// Call this function from another thread to abort a pending RPC call. void Abort(); auto Endpoint() const -> io::network::Endpoint const & { return endpoint_; } private: io::network::Endpoint endpoint_; communication::ClientContext *context_; std::optional client_; std::mutex mutex_; }; } // namespace memgraph::rpc