Add new unit test for the thrift handle. Implement hairpin optimization. Fill out basic thrift handle functionality
This commit is contained in:
parent
667275479d
commit
36432ce6d0
src/io
tests/unit
@ -36,6 +36,7 @@ using memgraph::io::Duration;
|
||||
using memgraph::io::OpaqueMessage;
|
||||
using memgraph::io::OpaquePromise;
|
||||
using memgraph::io::Time;
|
||||
using memgraph::io::TimedOut;
|
||||
|
||||
class SimulatorHandle {
|
||||
mutable std::mutex mu_{};
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
|
||||
#include "io/errors.hpp"
|
||||
#include "io/message_conversion.hpp"
|
||||
#include "io/transport.hpp"
|
||||
|
||||
@ -23,34 +24,61 @@ namespace memgraph::io::thrift {
|
||||
using memgraph::io::Address;
|
||||
using memgraph::io::OpaqueMessage;
|
||||
using memgraph::io::OpaquePromise;
|
||||
using memgraph::io::TimedOut;
|
||||
using RequestId = uint64_t;
|
||||
|
||||
class ThriftHandle {
|
||||
mutable std::mutex mu_{};
|
||||
mutable std::condition_variable cv_;
|
||||
const Address address_ = Address::TestAddress(0);
|
||||
|
||||
// the responses to requests that are being waited on
|
||||
std::map<RequestId, DeadlineAndOpaquePromise> promises_;
|
||||
std::map<PromiseKey, DeadlineAndOpaquePromise> promises_;
|
||||
|
||||
// messages that are sent to servers that may later receive them
|
||||
std::vector<OpaqueMessage> can_receive_;
|
||||
|
||||
// TODO(tyler) thrift clients for each outbound address combination
|
||||
std::map<Address, void *> clients_;
|
||||
// std::map<Address, void *> clients_;
|
||||
|
||||
// TODO(gabor) make this to a threadpool
|
||||
// uuid of the address -> port number where the given rsm is residing.
|
||||
// TODO(gabor) The RSM map should not be a part of this class.
|
||||
std::map<boost::uuids::uuid, uint16_t /*this should be the actual RSM*/> rsm_map_;
|
||||
// std::map<boost::uuids::uuid, uint16_t /*this should be the actual RSM*/> rsm_map_;
|
||||
|
||||
// this is duplicated between the ThriftTransport and here
|
||||
// because it's relatively simple and there's no need to
|
||||
// avoid the duplication as of the time of implementation.
|
||||
Time Now() const {
|
||||
auto nano_time = std::chrono::system_clock::now();
|
||||
return std::chrono::time_point_cast<std::chrono::microseconds>(nano_time);
|
||||
}
|
||||
|
||||
public:
|
||||
ThriftHandle(Address our_address) : address_(our_address) {}
|
||||
|
||||
template <Message M>
|
||||
void DeliverMessage(Address from_address, RequestId request_id, M &&message) {
|
||||
void DeliverMessage(Address to_address, Address from_address, RequestId request_id, M &&message) {
|
||||
std::any message_any(std::move(message));
|
||||
OpaqueMessage opaque_message{
|
||||
.from_address = from_address, .request_id = request_id, .message = std::move(message_any)};
|
||||
|
||||
PromiseKey promise_key{.requester_address = to_address,
|
||||
.request_id = opaque_message.request_id,
|
||||
.replier_address = opaque_message.from_address};
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
std::any message_any(std::move(message));
|
||||
OpaqueMessage om{.from_address = from_address, .request_id = request_id, .message = std::move(message_any)};
|
||||
can_receive_.emplace_back(std::move(om));
|
||||
|
||||
if (promises_.contains(promise_key)) {
|
||||
// complete waiting promise if it's there
|
||||
DeadlineAndOpaquePromise dop = std::move(promises_.at(promise_key));
|
||||
promises_.erase(promise_key);
|
||||
|
||||
dop.promise.Fill(std::move(opaque_message));
|
||||
} else {
|
||||
can_receive_.emplace_back(std::move(opaque_message));
|
||||
}
|
||||
} // lock dropped
|
||||
|
||||
cv_.notify_all();
|
||||
@ -59,33 +87,46 @@ class ThriftHandle {
|
||||
template <Message Request, Message Response>
|
||||
void SubmitRequest(Address to_address, Address from_address, RequestId request_id, Request &&request,
|
||||
Duration timeout, ResponsePromise<Response> &&promise) {
|
||||
// TODO(tyler) simular to simulator transport, add the promise to the promises_ map
|
||||
const Time deadline = Now() + timeout;
|
||||
Address our_address = address_;
|
||||
|
||||
Send(to_address, from_address, request_id, request);
|
||||
PromiseKey promise_key{.requester_address = from_address, .request_id = request_id, .replier_address = to_address};
|
||||
OpaquePromise opaque_promise(std::move(promise).ToUnique());
|
||||
DeadlineAndOpaquePromise dop{.deadline = deadline, .promise = std::move(opaque_promise)};
|
||||
promises_.emplace(std::move(promise_key), std::move(dop));
|
||||
|
||||
cv_.notify_all();
|
||||
|
||||
bool port_matches = to_address.last_known_port == our_address.last_known_port;
|
||||
bool ip_matches = to_address.last_known_ip == our_address.last_known_ip;
|
||||
|
||||
if (port_matches && ip_matches) {
|
||||
// hairpin routing optimization
|
||||
DeliverMessage(to_address, from_address, request_id, std::move(request));
|
||||
} else {
|
||||
// send using a thrift client to remove service
|
||||
Send(to_address, from_address, request_id, request);
|
||||
}
|
||||
}
|
||||
|
||||
template <Message... Ms>
|
||||
requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(const Address &receiver, Duration timeout) {
|
||||
requires(sizeof...(Ms) > 0) RequestResult<Ms...> Receive(Duration timeout) {
|
||||
// TODO(tyler) block for the specified duration on the Inbox's receipt of a message of this type.
|
||||
std::unique_lock lock(mu_);
|
||||
cv_.wait(lock, [this] { return !can_receive_.empty(); });
|
||||
|
||||
while (!can_receive_.empty()) {
|
||||
auto current_message = can_receive_.back();
|
||||
can_receive_.pop_back();
|
||||
|
||||
// Logic to determine who to send the message.
|
||||
//
|
||||
auto destination_id = current_message.to_address.unique_id;
|
||||
auto destination_port = rsm_map_.at(destination_id);
|
||||
|
||||
// Send it to the port of the destination -how?
|
||||
// TODO(tyler) search for item in can_receive_ that matches the desired types, rather
|
||||
// than asserting that the last item in can_rx matches.
|
||||
auto m_opt = std::move(current_message).Take<Ms...>();
|
||||
|
||||
return (std::move(m_opt));
|
||||
while (can_receive_.empty()) {
|
||||
std::cv_status cv_status_value = cv_.wait_for(lock, timeout);
|
||||
if (cv_status_value == std::cv_status::timeout) {
|
||||
return TimedOut{};
|
||||
}
|
||||
}
|
||||
|
||||
auto current_message = std::move(can_receive_.back());
|
||||
can_receive_.pop_back();
|
||||
|
||||
auto m_opt = std::move(current_message).Take<Ms...>();
|
||||
|
||||
return std::move(m_opt).value();
|
||||
}
|
||||
|
||||
template <Message M>
|
||||
|
@ -37,7 +37,7 @@ class ThriftTransport {
|
||||
|
||||
template <Message Request, Message Response>
|
||||
ResponseFuture<Response> Request(Address address, uint64_t request_id, Request request, Duration timeout) {
|
||||
auto [future, promise] = memgraph::io::FuturePromisePairWithNotifier<ResponseResult<Response>>();
|
||||
auto [future, promise] = memgraph::io::FuturePromisePair<ResponseResult<Response>>();
|
||||
|
||||
thrift_handle_->SubmitRequest(address, address_, request_id, std::move(request), timeout, std::move(promise));
|
||||
|
||||
|
@ -400,3 +400,7 @@ target_link_libraries(${test_prefix}future mg-io)
|
||||
# Test Thrift transport echo
|
||||
add_unit_test(thrift_transport_echo.cpp)
|
||||
target_link_libraries(${test_prefix}thrift_transport_echo mg-io fmt Threads::Threads FBThrift::thriftcpp2 mg-interface-echo-cpp2)
|
||||
|
||||
# Test Thrift transport echo
|
||||
add_unit_test(thrift_handle.cpp)
|
||||
target_link_libraries(${test_prefix}thrift_handle mg-io fmt Threads::Threads)
|
||||
|
93
tests/unit/thrift_handle.cpp
Normal file
93
tests/unit/thrift_handle.cpp
Normal file
@ -0,0 +1,93 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "io/address.hpp"
|
||||
#include "io/future.hpp"
|
||||
#include "io/thrift/thrift_handle.hpp"
|
||||
#include "io/transport.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
using memgraph::io::Address;
|
||||
using memgraph::io::Duration;
|
||||
using memgraph::io::FuturePromisePair;
|
||||
using memgraph::io::RequestEnvelope;
|
||||
using memgraph::io::ResponseEnvelope;
|
||||
using memgraph::io::ResponseResult;
|
||||
using memgraph::io::thrift::ThriftHandle;
|
||||
|
||||
struct TestMessage {
|
||||
int value;
|
||||
};
|
||||
|
||||
TEST(Thrift, ThriftHandleTimeout) {
|
||||
auto our_address = Address::TestAddress(0);
|
||||
auto handle = ThriftHandle{our_address};
|
||||
|
||||
// assert timeouts fire
|
||||
auto should_timeout = handle.Receive<TestMessage>(Duration{});
|
||||
MG_ASSERT(should_timeout.HasError());
|
||||
}
|
||||
|
||||
TEST(Thrift, ThriftHandleReceive) {
|
||||
auto our_address = Address::TestAddress(0);
|
||||
auto handle = ThriftHandle{our_address};
|
||||
|
||||
// assert we can send and receive
|
||||
auto to_address = Address::TestAddress(0);
|
||||
auto from_address = Address::TestAddress(1);
|
||||
auto request_id = 0;
|
||||
auto message = TestMessage{
|
||||
.value = 777,
|
||||
};
|
||||
|
||||
handle.DeliverMessage(to_address, from_address, request_id, std::move(message));
|
||||
|
||||
auto should_have_message = handle.Receive<TestMessage>(Duration{});
|
||||
MG_ASSERT(should_have_message.HasValue());
|
||||
|
||||
RequestEnvelope<TestMessage> re = should_have_message.GetValue();
|
||||
TestMessage request = std::get<TestMessage>(std::move(re.message));
|
||||
MG_ASSERT(request.value == 777);
|
||||
}
|
||||
|
||||
/// this test "sends" a TestMessage to a server and expects to receive
|
||||
/// a TestMessage back with the same value.
|
||||
TEST(Thrift, ThriftHandleRequestReceive) {
|
||||
// use the same address for now, to rely on loopback optimization
|
||||
auto our_address = Address::TestAddress(0);
|
||||
auto cli_address = our_address;
|
||||
auto srv_address = cli_address;
|
||||
|
||||
auto handle = ThriftHandle{our_address};
|
||||
|
||||
auto timeout = Duration{};
|
||||
auto request_id = 1;
|
||||
auto expected_value = 323;
|
||||
auto request = TestMessage{};
|
||||
request.value = expected_value;
|
||||
|
||||
auto [future, promise] = FuturePromisePair<ResponseResult<TestMessage>>();
|
||||
|
||||
handle.SubmitRequest(srv_address, cli_address, request_id, std::move(request), timeout, std::move(promise));
|
||||
|
||||
// TODO(tyler) do actual socket stuff in the future maybe
|
||||
|
||||
ResponseResult<TestMessage> response_result = std::move(future).Wait();
|
||||
MG_ASSERT(response_result.HasValue());
|
||||
ResponseEnvelope<TestMessage> response_envelope = response_result.GetValue();
|
||||
TestMessage response = response_envelope.message;
|
||||
MG_ASSERT(response.value == expected_value);
|
||||
}
|
Loading…
Reference in New Issue
Block a user