Revamp MgFuture, continue threading logic through Simulator

This commit is contained in:
Tyler Neely 2022-07-05 15:45:59 +00:00
parent eb4ca543ea
commit 6debc9e7d8
6 changed files with 128 additions and 62 deletions

View File

@ -3,7 +3,9 @@ set(io_v3_sources
errors.hpp
future.hpp
transport.hpp
simulator.hpp)
notifier.hpp
simulator.hpp
simulator_handle.hpp)
find_package(fmt REQUIRED)
find_package(Threads REQUIRED)

View File

@ -21,7 +21,6 @@
#include "utils/logging.hpp"
#include "errors.hpp"
#include "simulator_handle.hpp"
template <typename T>
class MgPromise;
@ -33,15 +32,17 @@ template <typename T>
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair();
template <typename T>
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair(SimulatorHandle);
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePairWithNotifier(std::function<void()>);
template <typename T>
class Shared {
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair<T>();
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePairWithNotifier<T>(std::function<void()>);
friend MgPromise<T>;
friend MgFuture<T>;
public:
Shared(std::function<void()> simulator_notifier) : simulator_notifier_(simulator_notifier) {}
Shared() = default;
Shared(Shared &&) = delete;
Shared &operator=(Shared &&) = delete;
@ -55,7 +56,7 @@ class Shared {
while (!item_) {
waiting_ = true;
if (simulator_handle_) {
if (simulator_notifier_) {
// We can't hold our own lock while notifying
// the simulator because notifying the simulator
// involves acquiring the simulator's mutex
@ -67,7 +68,7 @@ class Shared {
// so we have to get out of its way to avoid
// a cyclical deadlock.
lock.unlock();
(*simulator_handle_)->NotifySimulator();
(*simulator_notifier_)();
lock.lock();
if (item_) {
// item may have been filled while we
@ -112,13 +113,13 @@ class Shared {
std::optional<T> item_;
bool consumed_;
bool waiting_;
std::optional<std::shared_ptr<SimulatorHandle>> simulator_handle_;
std::optional<std::function<void()>> simulator_notifier_;
};
template <typename T>
class MgFuture {
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair<T>();
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair<T>(SimulatorHandle);
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePairWithNotifier<T>(std::function<void()>);
public:
MgFuture(MgFuture &&old) {
@ -154,7 +155,7 @@ class MgFuture {
template <typename T>
class MgPromise {
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair<T>();
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair<T>(SimulatorHandle);
friend std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePairWithNotifier<T>(std::function<void()>);
public:
MgPromise(MgPromise &&old) {
@ -191,6 +192,7 @@ class MgPromise {
template <typename T>
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair() {
std::shared_ptr<Shared<T>> shared = std::make_shared<Shared<T>>();
MgFuture<T> future = MgFuture<T>(shared);
MgPromise<T> promise = MgPromise<T>(shared);
@ -198,8 +200,11 @@ std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair() {
}
template <typename T>
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePair(SimulatorHandle simulator_handle) {
auto [future, promise] = FuturePromisePair<T>();
future.simulator_handle_ = simulator_handle;
std::pair<MgFuture<T>, MgPromise<T>> FuturePromisePairWithNotifier(std::function<void()> simulator_notifier) {
std::shared_ptr<Shared<T>> shared = std::make_shared<Shared<T>>(simulator_notifier);
MgFuture<T> future = MgFuture<T>(shared);
MgPromise<T> promise = MgPromise<T>(shared);
return std::make_pair(std::move(future), std::move(promise));
}

View File

@ -40,22 +40,27 @@ class SimulatorTransport {
template <Message Request, Message Response>
ResponseFuture<Response> RequestTimeout(Address address, uint64_t request_id, Request request,
uint64_t timeout_microseconds) {
std::abort();
std::function<void()> notifier = [=] { simulator_handle_->NotifySimulator(); };
auto [future, promise] = FuturePromisePairWithNotifier<ResponseResult<Response>>(notifier);
return std::move(future);
}
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout(uint64_t timeout_microseconds) {
std::abort();
}
/*
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout(uint64_t timeout_microseconds) {
return simulator_handle_->template ReceiveTimeout<Ms...>(timeout_microseconds);
}
template <Message M>
void Send(Address address, uint64_t request_id, M message) {
std::abort();
}
template <Message M>
void Send(Address address, uint64_t request_id, M message) {
return simulator_handle_->template Send<M>(address, request_id, message);
}
*/
std::time_t Now() { std::abort(); }
std::time_t Now() { return std::time(nullptr); }
bool ShouldShutDown() { std::abort(); }
bool ShouldShutDown() { return simulator_handle_->ShouldShutDown(); }
private:
std::shared_ptr<SimulatorHandle> simulator_handle_;

View File

@ -11,7 +11,27 @@
#pragma once
#include <map>
#include "address.hpp"
#include "transport.hpp"
struct OpaqueMessage {
Address address;
uint64_t request_id;
std::unique_ptr<std::any> message;
};
struct PromiseKey {
Address requester;
uint64_t request_id;
Address replier;
};
struct OpaquePromise {
time_t deadline;
std::unique_ptr<std::any> promise;
};
class SimulatorHandle {
public:
@ -20,8 +40,35 @@ class SimulatorHandle {
cv_sim_.notify_all();
}
bool ShouldShutDown() {
std::unique_lock<std::mutex> lock(mu_);
return shut_down_;
}
template <Message Request, Message Response>
void SubmitRequest(Address address, uint64_t request_id, Request request, uint64_t timeout_microseconds,
MgPromise<ResponseResult<Response>> promise) {
std::unique_lock<std::mutex> lock(mu_);
}
/*
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout(uint64_t timeout_microseconds) {
std::abort();
}
template <Message M>
void Send(Address address, uint64_t request_id, M message) {
std::abort();
}
*/
private:
std::mutex mu_;
std::condition_variable cv_sim_;
std::condition_variable cv_srv_;
std::map<Address, std::vector<OpaqueMessage>> in_flight_;
std::map<PromiseKey, OpaquePromise> promises;
std::map<Address, OpaqueMessage> can_receive;
bool shut_down_;
};

View File

@ -22,6 +22,8 @@
using memgraph::utils::BasicResult;
class SimulatorHandle;
template <typename T>
concept Message = requires(T a, uint8_t *ptr, size_t len) {
// These are placeholders and will be replaced
@ -63,12 +65,14 @@ class Io {
default_timeout_microseconds_ = timeout_microseconds;
}
template <Message Request, Message Response>
ResponseFuture<Response> RequestTimeout(Address address, Request request, uint64_t timeout_microseconds) {
uint64_t request_id = ++request_id_counter_;
return implementation_.template RequestTimeout<Request, Response>(address, request_id, request,
timeout_microseconds);
}
/*
template <Message Request, Message Response>
ResponseFuture<Response> RequestTimeout(Address address, Request request, uint64_t timeout_microseconds) {
uint64_t request_id = ++request_id_counter_;
return implementation_.template RequestTimeout<Request, Response>(address, request_id, request,
timeout_microseconds);
}
*/
template <Message Request, Message Response>
ResponseFuture<Response> RequestTimeout(Address address, Request request) {
@ -78,21 +82,22 @@ class Io {
timeout_microseconds);
}
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout(uint64_t timeout_microseconds) {
return implementation_.template ReceiveTimeout<Ms...>(timeout_microseconds);
}
/*
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout(uint64_t timeout_microseconds) {
return implementation_.template ReceiveTimeout<Ms...>(timeout_microseconds);
}
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout() {
uint64_t timeout_microseconds = default_timeout_microseconds_;
return implementation_.template ReceiveTimeout<Ms...>(timeout_microseconds);
}
template <Message M>
void Send(Address address, uint64_t request_id, M message) {
return implementation_.template Send<M>(address, request_id, message);
}
template <Message... Ms>
RequestResult<Ms...> ReceiveTimeout() {
uint64_t timeout_microseconds = default_timeout_microseconds_;
return implementation_.template ReceiveTimeout<Ms...>(timeout_microseconds);
}
template <Message M>
void Send(Address address, uint64_t request_id, M message) {
return implementation_.template Send<M>(address, request_id, message);
}
*/
std::time_t Now() { return implementation_.Now(); }

View File

@ -12,10 +12,11 @@
//#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "io/v3/simulator.hpp"
#include "io/v3/transport.hpp"
#include "utils/logging.hpp"
//#include "io/v3/transport.hpp"
//#include "utils/logging.hpp"
struct Request {
std::string data;
@ -35,28 +36,29 @@ struct Response {
int main() {
auto simulator = Simulator();
auto addr_1 = Address::TestAddress(1);
auto addr_2 = Address::TestAddress(2);
auto cli_addr = Address::TestAddress(1);
auto srv_addr = Address::TestAddress(2);
auto sim_io_1 = simulator.Register(addr_1, true);
auto sim_io_2 = simulator.Register(addr_2, true);
Io<SimulatorTransport> cli_io = simulator.Register(cli_addr, true);
// Io<SimulatorTransport> srv_io = simulator.Register(srv_addr, true);
// send request
auto response_future = sim_io_1.RequestTimeout<Request, Response>(addr_2, Request{});
Request cli_req;
ResponseFuture<Response> response_future = cli_io.template RequestTimeout<Request, Response>(srv_addr, cli_req);
// receive request
RequestResult<Request> request_result = sim_io_2.ReceiveTimeout<Request>();
auto req_envelope = request_result.GetValue();
Request req = std::get<Request>(req_envelope.message);
auto srv_res = Response{req.data};
// send response
sim_io_2.Send(req_envelope.from, req_envelope.request_id, srv_res);
// receive response
auto response_result = response_future.Wait();
auto res = response_result.GetValue();
// // receive request
// RequestResult<Request> request_result = sim_io_2.template ReceiveTimeout<Request>();
// auto req_envelope = request_result.GetValue();
// Request req = std::get<Request>(req_envelope.message);
//
// auto srv_res = Response{req.data};
//
// // send response
// sim_io_2.Send(req_envelope.from, req_envelope.request_id, srv_res);
//
// // receive response
// auto response_result = response_future.Wait();
// auto res = response_result.GetValue();
return 0;
}