From ac5c6bf0e846b4f420e919f2fd11655020cdb9ff Mon Sep 17 00:00:00 2001
From: Matija Santl <matija.santl@memgraph.com>
Date: Tue, 15 Jan 2019 16:07:16 +0100
Subject: [PATCH] Add multi-threaded benchmark client for HA

Summary:
There are some serious speedups when doing parallel writes.

Results on my machine (4 cores):
```
duration 6.73173
executed_writes 15003
write_per_second 2228.7
```

Reviewers: ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1807
---
 tests/feature_benchmark/ha/benchmark.cpp | 116 ++++++++++++++---------
 1 file changed, 72 insertions(+), 44 deletions(-)

diff --git a/tests/feature_benchmark/ha/benchmark.cpp b/tests/feature_benchmark/ha/benchmark.cpp
index 18d12879e..2e32f0168 100644
--- a/tests/feature_benchmark/ha/benchmark.cpp
+++ b/tests/feature_benchmark/ha/benchmark.cpp
@@ -1,5 +1,6 @@
 #include <atomic>
 #include <chrono>
+#include <experimental/optional>
 #include <fstream>
 #include <thread>
 
@@ -24,15 +25,54 @@ DEFINE_int64(query_count, 0, "How many queries should we execute.");
 DEFINE_int64(timeout, 60, "How many seconds should the benchmark wait.");
 DEFINE_string(output_file, "", "Output file where the results should be.");
 
+std::experimental::optional<io::network::Endpoint> GetLeaderEndpoint() {
+  for (int retry = 0; retry < 10; ++retry) {
+    for (int i = 0; i < FLAGS_cluster_size; ++i) {
+      try {
+        communication::ClientContext context(FLAGS_use_ssl);
+        communication::bolt::Client client(&context);
+
+        uint16_t port = FLAGS_port + i;
+        io::network::Endpoint endpoint{FLAGS_address, port};
+
+        client.Connect(endpoint, FLAGS_username, FLAGS_password);
+        client.Execute("MATCH (n) RETURN n", {});
+        client.Close();
+
+        // If we succeeded with the above query, we found the current leader.
+        return std::experimental::make_optional(endpoint);
+
+      } catch (const communication::bolt::ClientQueryException &) {
+        // This one is not the leader, continue.
+        continue;
+      } catch (const communication::bolt::ClientFatalException &) {
+        // This one seems to be down, continue.
+        continue;
+      }
+    }
+
+    LOG(INFO) << "Couldn't find Raft cluster leader, retrying...";
+    std::this_thread::sleep_for(1s);
+  }
+
+  return std::experimental::nullopt;
+}
+
 int main(int argc, char **argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
   google::SetUsageMessage("Memgraph HA benchmark client");
   google::InitGoogleLogging(argv[0]);
 
-  int64_t query_counter = 0;
+  std::atomic<int64_t> query_counter{0};
   std::atomic<bool> timeout_reached{false};
   std::atomic<bool> benchmark_finished{false};
 
+  auto leader_endpoint = GetLeaderEndpoint();
+  if (!leader_endpoint) {
+    LOG(ERROR) << "Couldn't find Raft cluster leader!";
+    return 1;
+  }
+
   // Kickoff a thread that will timeout after FLAGS_timeout seconds
   std::thread timeout_thread_ =
       std::thread([&timeout_reached, &benchmark_finished]() {
@@ -45,62 +85,50 @@ int main(int argc, char **argv) {
         timeout_reached.store(true);
       });
 
-  double duration = 0;
-  double write_per_second = 0;
+  std::vector<std::thread> threads;
 
-  bool successful = false;
-  for (int retry = 0; !successful && retry < 10; ++retry) {
-    for (int i = 0; !successful && i < FLAGS_cluster_size; ++i) {
-      try {
-        communication::ClientContext context(FLAGS_use_ssl);
-        communication::bolt::Client client(&context);
+  for (int i = 0; i < std::thread::hardware_concurrency(); ++i) {
+    threads.emplace_back(
+        [endpoint = *leader_endpoint, &timeout_reached, &query_counter]() {
+          communication::ClientContext context(FLAGS_use_ssl);
+          communication::bolt::Client client(&context);
+          client.Connect(endpoint, FLAGS_username, FLAGS_password);
 
-        uint16_t port = FLAGS_port + i;
-        io::network::Endpoint endpoint{FLAGS_address, port};
-        client.Connect(endpoint, FLAGS_username, FLAGS_password);
+          while (query_counter.load() < FLAGS_query_count) {
+            if (timeout_reached.load()) break;
 
-        utils::Timer timer;
-        for (int k = 0; k < FLAGS_query_count; ++k) {
-          client.Execute("CREATE (:Node)", {});
-          query_counter++;
-
-          if (timeout_reached.load()) break;
-        }
-
-        duration = timer.Elapsed().count();
-        successful = true;
-
-      } catch (const communication::bolt::ClientQueryException &) {
-        // This one is not the leader, continue.
-        continue;
-      } catch (const communication::bolt::ClientFatalException &) {
-        // This one seems to be down, continue.
-        continue;
-      }
-
-      if (timeout_reached.load()) break;
-    }
-
-    if (timeout_reached.load()) break;
-    if (!successful) {
-      LOG(INFO) << "Couldn't find Raft cluster leader, retrying...";
-      std::this_thread::sleep_for(1s);
-    }
+            try {
+              client.Execute("CREATE (:Node)", {});
+              query_counter.fetch_add(1);
+            } catch (const communication::bolt::ClientQueryException &e) {
+              LOG(WARNING) << e.what();
+              break;
+            } catch (const communication::bolt::ClientFatalException &e) {
+              LOG(WARNING) << e.what();
+              break;
+            }
+          }
+        });
   }
 
+  utils::Timer timer;
+  int64_t query_offset = query_counter.load();
+
+  for (auto &t : threads) {
+    if (t.joinable()) t.join();
+  }
+
+  double duration = timer.Elapsed().count();
+  double write_per_second = (query_counter - query_offset) / duration;
+
   benchmark_finished.store(true);
   if (timeout_thread_.joinable()) timeout_thread_.join();
 
-  if (successful) {
-    write_per_second = query_counter / duration;
-  }
-
   std::ofstream output(FLAGS_output_file);
   output << "duration " << duration << std::endl;
   output << "executed_writes " << query_counter << std::endl;
   output << "write_per_second " << write_per_second << std::endl;
   output.close();
 
-  if (!successful) return 1;
   return 0;
 }