From 820c943e8a734a3f392c1628297071a5dd453392 Mon Sep 17 00:00:00 2001
From: gvolfing <gabor.volfinger@memgraph.io>
Date: Sun, 21 Aug 2022 20:20:52 +0200
Subject: [PATCH] Add UberService and skeleton for ThriftHandle::Send()

---
 src/interface/echo.thrift        | 14 +++---
 src/interface/ubermessage.thrift |  9 ++++
 src/io/thrift/thrift_handle.hpp  | 82 ++++++++++++++++++++++++++++++--
 tests/unit/thrift_handle.cpp     |  2 +
 4 files changed, 96 insertions(+), 11 deletions(-)

diff --git a/src/interface/echo.thrift b/src/interface/echo.thrift
index 2eb3fad24..ea625bde7 100644
--- a/src/interface/echo.thrift
+++ b/src/interface/echo.thrift
@@ -2,15 +2,15 @@ struct EchoMessage {
     1: binary message;
 }
 
-struct Address{
-    1: string unique_id;
-    2: string last_known_ip;
-    3: i32 last_known_port;
-}
+//struct Address{
+//    1: string unique_id;
+//    2: string last_known_ip;
+//    3: i32 last_known_port;
+//}
 
 struct CompoundMessage{
-    1: Address to_address
-    2: Address from_address
+//    1: Address to_address
+//    2: Address from_address
     3: binary message
 }
 
diff --git a/src/interface/ubermessage.thrift b/src/interface/ubermessage.thrift
index 82ef1ed53..f781a63fe 100644
--- a/src/interface/ubermessage.thrift
+++ b/src/interface/ubermessage.thrift
@@ -1,5 +1,12 @@
 //include "address.thrift"
 
+// TODO(gvolfing) remove this once the include problem is resolved
+struct Address{
+    1: string unique_id;
+    2: string last_known_ip;
+    3: i32 last_known_port;
+}
+
 struct HeartbeatRequest {
   1: bool test;
 }
@@ -38,6 +45,8 @@ union HighLevelUnion {
 struct UberMessage {
     //1: address.Address to_address;
     //2: address.Address from_address;
+    1: Address to_address;
+    2: Address from_address;
     3: i64 request_id;
     4: HighLevelUnion high_level_union;
 }
diff --git a/src/io/thrift/thrift_handle.hpp b/src/io/thrift/thrift_handle.hpp
index 1b433e094..2f435b04f 100644
--- a/src/io/thrift/thrift_handle.hpp
+++ b/src/io/thrift/thrift_handle.hpp
@@ -13,25 +13,45 @@
 
 #include <condition_variable>
 #include <map>
+#include <memory>
 #include <mutex>
 
+#include <boost/asio/ip/tcp.hpp>
+#include <boost/lexical_cast.hpp>
+
+// #include <folly/init/Init.h>
+// #include <folly/io/SocketOptionMap.h>
+// #include <folly/io/async/AsyncServerSocket.h>
+// #include <folly/net/NetworkSocket.h>
+// #include <thrift/lib/cpp2/async/HeaderClientChannel.h>
+// #include <thrift/lib/cpp2/server/ThriftServer.h>
+
+// From generated code
+#include "interface/gen-cpp2/UberServer.h"
+#include "interface/gen-cpp2/UberServerAsyncClient.h"
+
 #include "io/errors.hpp"
 #include "io/message_conversion.hpp"
 #include "io/transport.hpp"
 
 namespace memgraph::io::thrift {
 
+using namespace apache::thrift;
+// using namespace cpp2;
+using namespace folly;
+
 using memgraph::io::Address;
 using memgraph::io::OpaqueMessage;
 using memgraph::io::OpaquePromise;
 using memgraph::io::TimedOut;
 using RequestId = uint64_t;
-
 class ThriftHandle {
   mutable std::mutex mu_{};
   mutable std::condition_variable cv_;
   const Address address_ = Address::TestAddress(0);
 
+  // EventBase base_;
+
   // the responses to requests that are being waited on
   std::map<PromiseKey, DeadlineAndOpaquePromise> promises_;
 
@@ -39,7 +59,7 @@ class ThriftHandle {
   std::vector<OpaqueMessage> can_receive_;
 
   // TODO(tyler) thrift clients for each outbound address combination
-  // std::map<Address, void *> clients_;
+  std::map<Address, cpp2::UberServerAsyncClient> clients_;
 
   // TODO(gabor) make this to a threadpool
   // uuid of the address -> port number where the given rsm is residing.
@@ -47,7 +67,7 @@ class ThriftHandle {
   // std::map<boost::uuids::uuid, uint16_t /*this should be the actual RSM*/> rsm_map_;
 
  public:
-  ThriftHandle(Address our_address) : address_(our_address) {}
+  explicit ThriftHandle(Address our_address) : address_(our_address) {}
 
   Time Now() const {
     auto nano_time = std::chrono::system_clock::now();
@@ -126,7 +146,7 @@ class ThriftHandle {
 
       Duration relative_timeout = timeout - elapsed;
 
-      std::cv_status cv_status_value = cv_.wait_for(lock, relative_timeout);
+      auto cv_status_value = cv_.wait_for(lock, relative_timeout);
 
       if (cv_status_value == std::cv_status::timeout) {
         return TimedOut{};
@@ -144,6 +164,60 @@ class ThriftHandle {
   template <Message M>
   void Send(Address to_address, Address from_address, RequestId request_id, M message) {
     // TODO(tyler) call thrift client for address (or create one if it doesn't exist yet)
+
+    // if(clients_.contains(to_address))
+    // {
+    //   const auto &client = clients_[to_address];
+    //   client.sync_Send(message);
+    // }
+    // else{
+    //   // maybe make this into a member var
+    //   const auto& other_ip = to_address.last_known_ip.to_string();
+    //   const auto& other_port = to_address.last_known_port;
+    //   auto socket(folly::AsyncSocket::newSocket(&base_, other_ip, other_port));
+    //   auto client_channel = HeaderClientChannel::newChannel(std::move(socket));
+    //   // Create a client object
+    //   EchoAsyncClient client(std::move(client_channel));
+
+    // client.sync_Send(message);
+    // }
+  }
+};
+
+class UberMessageService final : cpp2::UberServerSvIf {
+  std::shared_ptr<ThriftHandle> handle_;
+
+  memgraph::io::Address convertToMgAddress(const cpp2::Address address) {
+    memgraph::io::Address ret_address;
+    ret_address = {.unique_id{boost::lexical_cast<boost::uuids::uuid>(address.get_unique_id())},
+                   .last_known_ip{boost::asio::ip::make_address(address.get_last_known_ip())},
+                   .last_known_port = static_cast<uint16_t>(address.get_last_known_port())};
+    return ret_address;
+  }
+
+ public:
+  explicit UberMessageService(std::shared_ptr<ThriftHandle> handle) : handle_{handle} {}
+
+  void ReceiveUberMessage(const cpp2::UberMessage &uber_message) override {
+    const auto &to_address = uber_message.get_to_address();
+    const auto &from_address = uber_message.get_from_address();
+    const auto &request_id = uber_message.get_request_id();
+    auto message = uber_message.get_high_level_union();
+
+    const auto mg_to_address = convertToMgAddress(to_address);
+    const auto mg_from_address = convertToMgAddress(from_address);
+    // Castint int64_t -> uint64_t
+    // FBThrift only provides us with signed integers. If someone
+    // wishes to use signed integers then the go to solution seems to
+    // be to use the one-bigger signed version. Unfortunately FBThrift
+    // does not provide a uint128_t so we have to use the 64 bit one
+    // for now.
+    // TODO(gvolfing) Investigate and try to get around this problem
+    // with Varint or some other Thrift type.
+    const auto mg_request_id = static_cast<uint64_t>(request_id);
+
+    // Transform high_level_union into something usable if needed(?).
+    handle_->DeliverMessage(mg_to_address, mg_from_address, mg_request_id, std::move(message));
   }
 };
 
diff --git a/tests/unit/thrift_handle.cpp b/tests/unit/thrift_handle.cpp
index 3896321bb..1f36b9dcc 100644
--- a/tests/unit/thrift_handle.cpp
+++ b/tests/unit/thrift_handle.cpp
@@ -30,9 +30,11 @@ using memgraph::io::ResponseResult;
 using memgraph::io::Time;
 using memgraph::io::thrift::ThriftHandle;
 
+namespace {
 struct TestMessage {
   int value;
 };
+}  // namespace
 
 TEST(Thrift, ThriftHandleTimeout) {
   auto our_address = Address::TestAddress(0);