diff --git a/src/io/simulator/simulator_handle.hpp b/src/io/simulator/simulator_handle.hpp index 4434ee207..3fbdf2056 100644 --- a/src/io/simulator/simulator_handle.hpp +++ b/src/io/simulator/simulator_handle.hpp @@ -36,6 +36,7 @@ using memgraph::io::Duration; using memgraph::io::OpaqueMessage; using memgraph::io::OpaquePromise; using memgraph::io::Time; +using memgraph::io::TimedOut; class SimulatorHandle { mutable std::mutex mu_{}; diff --git a/src/io/thrift/thrift_handle.hpp b/src/io/thrift/thrift_handle.hpp index 2694e9cce..cb4e44661 100644 --- a/src/io/thrift/thrift_handle.hpp +++ b/src/io/thrift/thrift_handle.hpp @@ -15,6 +15,7 @@ #include <map> #include <mutex> +#include "io/errors.hpp" #include "io/message_conversion.hpp" #include "io/transport.hpp" @@ -23,34 +24,61 @@ namespace memgraph::io::thrift { using memgraph::io::Address; using memgraph::io::OpaqueMessage; using memgraph::io::OpaquePromise; +using memgraph::io::TimedOut; using RequestId = uint64_t; class ThriftHandle { mutable std::mutex mu_{}; mutable std::condition_variable cv_; + const Address address_ = Address::TestAddress(0); // the responses to requests that are being waited on - std::map<RequestId, DeadlineAndOpaquePromise> promises_; + std::map<PromiseKey, DeadlineAndOpaquePromise> promises_; // messages that are sent to servers that may later receive them std::vector<OpaqueMessage> can_receive_; // TODO(tyler) thrift clients for each outbound address combination - std::map<Address, void *> clients_; + // std::map<Address, void *> clients_; // TODO(gabor) make this to a threadpool // uuid of the address -> port number where the given rsm is residing. // TODO(gabor) The RSM map should not be a part of this class. - std::map<boost::uuids::uuid, uint16_t /*this should be the actual RSM*/> rsm_map_; + // std::map<boost::uuids::uuid, uint16_t /*this should be the actual RSM*/> rsm_map_; + + // this is duplicated between the ThriftTransport and here + // because it's relatively simple and there's no need to + // avoid the duplication as of the time of implementation. + Time Now() const { + auto nano_time = std::chrono::system_clock::now(); + return std::chrono::time_point_cast<std::chrono::microseconds>(nano_time); + } public: + ThriftHandle(Address our_address) : address_(our_address) {} + template <Message M> - void DeliverMessage(Address from_address, RequestId request_id, M &&message) { + void DeliverMessage(Address to_address, Address from_address, RequestId request_id, M &&message) { + std::any message_any(std::move(message)); + OpaqueMessage opaque_message{ + .from_address = from_address, .request_id = request_id, .message = std::move(message_any)}; + + PromiseKey promise_key{.requester_address = to_address, + .request_id = opaque_message.request_id, + .replier_address = opaque_message.from_address}; + { std::unique_lock<std::mutex> lock(mu_); - std::any message_any(std::move(message)); - OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message_any)}; - can_receive_.emplace_back(std::move(om)); + + if (promises_.contains(promise_key)) { + // complete waiting promise if it's there + DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key)); + promises_.erase(promise_key); + + dop.promise.Fill(std::move(opaque_message)); + } else { + can_receive_.emplace_back(std::move(opaque_message)); + } } // lock dropped cv_.notify_all(); @@ -59,33 +87,46 @@ class ThriftHandle { template <Message Request, Message Response> void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request, Duration timeout, ResponsePromise<Response> &&promise) { - // TODO(tyler) simular to simulator transport, add the promise to the promises_ map + const Time deadline = Now() + timeout; + Address our_address = address_; - Send(to_address, from_address, request_id, request); + PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address}; + OpaquePromise opaque_promise(std::move(promise).ToUnique()); + DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)}; + promises_.emplace(std::move(promise_key), std::move(dop)); + + cv_.notify_all(); + + bool port_matches = to_address.last_known_port == our_address.last_known_port; + bool ip_matches = to_address.last_known_ip == our_address.last_known_ip; + + if (port_matches && ip_matches) { + // hairpin routing optimization + DeliverMessage(to_address, from_address, request_id, std::move(request)); + } else { + // send using a thrift client to remove service + Send(to_address, from_address, request_id, request); + } } template <Message... Ms> - requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(const Address &receiver, Duration timeout) { + requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) { // TODO(tyler) block for the specified duration on the Inbox's receipt of a message of this type. std::unique_lock lock(mu_); - cv_.wait(lock, [this] { return !can_receive_.empty(); }); - while (!can_receive_.empty()) { - auto current_message = can_receive_.back(); - can_receive_.pop_back(); - - // Logic to determine who to send the message. - // - auto destination_id = current_message.to_address.unique_id; - auto destination_port = rsm_map_.at(destination_id); - - // Send it to the port of the destination -how? - // TODO(tyler) search for item in can_receive_ that matches the desired types, rather - // than asserting that the last item in can_rx matches. - auto m_opt = std::move(current_message).Take<Ms...>(); - - return (std::move(m_opt)); + while (can_receive_.empty()) { + std::cv_status cv_status_value = cv_.wait_for(lock, timeout); + if (cv_status_value == std::cv_status::timeout) { + return TimedOut{}; + } } + + auto current_message = std::move(can_receive_.back()); + can_receive_.pop_back(); + + auto m_opt = std::move(current_message).Take<Ms...>(); + + return std::move(m_opt).value(); } template <Message M> diff --git a/src/io/thrift/thrift_transport.hpp b/src/io/thrift/thrift_transport.hpp index 19aa248b8..5d491e0de 100644 --- a/src/io/thrift/thrift_transport.hpp +++ b/src/io/thrift/thrift_transport.hpp @@ -37,7 +37,7 @@ class ThriftTransport { template <Message Request, Message Response> ResponseFuture<Response> Request(Address address, uint64_t request_id, Request request, Duration timeout) { - auto [future, promise] = memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>(); + auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<Response>>(); thrift_handle_->SubmitRequest(address, address_, request_id, std::move(request), timeout, std::move(promise)); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index f4d71fa73..36c97abbc 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -400,3 +400,7 @@ target_link_libraries(${test_prefix}future mg-io) # Test Thrift transport echo add_unit_test(thrift_transport_echo.cpp) target_link_libraries(${test_prefix}thrift_transport_echo mg-io fmt Threads::Threads FBThrift::thriftcpp2 mg-interface-echo-cpp2) + +# Test Thrift transport echo +add_unit_test(thrift_handle.cpp) +target_link_libraries(${test_prefix}thrift_handle mg-io fmt Threads::Threads) diff --git a/tests/unit/thrift_handle.cpp b/tests/unit/thrift_handle.cpp new file mode 100644 index 000000000..da7758fb3 --- /dev/null +++ b/tests/unit/thrift_handle.cpp @@ -0,0 +1,93 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <string> +#include <thread> + +#include "gtest/gtest.h" + +#include "io/address.hpp" +#include "io/future.hpp" +#include "io/thrift/thrift_handle.hpp" +#include "io/transport.hpp" +#include "utils/logging.hpp" + +using memgraph::io::Address; +using memgraph::io::Duration; +using memgraph::io::FuturePromisePair; +using memgraph::io::RequestEnvelope; +using memgraph::io::ResponseEnvelope; +using memgraph::io::ResponseResult; +using memgraph::io::thrift::ThriftHandle; + +struct TestMessage { + int value; +}; + +TEST(Thrift, ThriftHandleTimeout) { + auto our_address = Address::TestAddress(0); + auto handle = ThriftHandle{our_address}; + + // assert timeouts fire + auto should_timeout = handle.Receive<TestMessage>(Duration{}); + MG_ASSERT(should_timeout.HasError()); +} + +TEST(Thrift, ThriftHandleReceive) { + auto our_address = Address::TestAddress(0); + auto handle = ThriftHandle{our_address}; + + // assert we can send and receive + auto to_address = Address::TestAddress(0); + auto from_address = Address::TestAddress(1); + auto request_id = 0; + auto message = TestMessage{ + .value = 777, + }; + + handle.DeliverMessage(to_address, from_address, request_id, std::move(message)); + + auto should_have_message = handle.Receive<TestMessage>(Duration{}); + MG_ASSERT(should_have_message.HasValue()); + + RequestEnvelope<TestMessage> re = should_have_message.GetValue(); + TestMessage request = std::get<TestMessage>(std::move(re.message)); + MG_ASSERT(request.value == 777); +} + +/// this test "sends" a TestMessage to a server and expects to receive +/// a TestMessage back with the same value. +TEST(Thrift, ThriftHandleRequestReceive) { + // use the same address for now, to rely on loopback optimization + auto our_address = Address::TestAddress(0); + auto cli_address = our_address; + auto srv_address = cli_address; + + auto handle = ThriftHandle{our_address}; + + auto timeout = Duration{}; + auto request_id = 1; + auto expected_value = 323; + auto request = TestMessage{}; + request.value = expected_value; + + auto [future, promise] = FuturePromisePair<ResponseResult<TestMessage>>(); + + handle.SubmitRequest(srv_address, cli_address, request_id, std::move(request), timeout, std::move(promise)); + + // TODO(tyler) do actual socket stuff in the future maybe + + ResponseResult<TestMessage> response_result = std::move(future).Wait(); + MG_ASSERT(response_result.HasValue()); + ResponseEnvelope<TestMessage> response_envelope = response_result.GetValue(); + TestMessage response = response_envelope.message; + MG_ASSERT(response.value == expected_value); +}