diff --git a/tests/feature_benchmark/ha/benchmark.cpp b/tests/feature_benchmark/ha/benchmark.cpp index 18d12879e..2e32f0168 100644 --- a/tests/feature_benchmark/ha/benchmark.cpp +++ b/tests/feature_benchmark/ha/benchmark.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -24,15 +25,54 @@ DEFINE_int64(query_count, 0, "How many queries should we execute."); DEFINE_int64(timeout, 60, "How many seconds should the benchmark wait."); DEFINE_string(output_file, "", "Output file where the results should be."); +std::experimental::optional GetLeaderEndpoint() { + for (int retry = 0; retry < 10; ++retry) { + for (int i = 0; i < FLAGS_cluster_size; ++i) { + try { + communication::ClientContext context(FLAGS_use_ssl); + communication::bolt::Client client(&context); + + uint16_t port = FLAGS_port + i; + io::network::Endpoint endpoint{FLAGS_address, port}; + + client.Connect(endpoint, FLAGS_username, FLAGS_password); + client.Execute("MATCH (n) RETURN n", {}); + client.Close(); + + // If we succeeded with the above query, we found the current leader. + return std::experimental::make_optional(endpoint); + + } catch (const communication::bolt::ClientQueryException &) { + // This one is not the leader, continue. + continue; + } catch (const communication::bolt::ClientFatalException &) { + // This one seems to be down, continue. + continue; + } + } + + LOG(INFO) << "Couldn't find Raft cluster leader, retrying..."; + std::this_thread::sleep_for(1s); + } + + return std::experimental::nullopt; +} + int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::SetUsageMessage("Memgraph HA benchmark client"); google::InitGoogleLogging(argv[0]); - int64_t query_counter = 0; + std::atomic query_counter{0}; std::atomic timeout_reached{false}; std::atomic benchmark_finished{false}; + auto leader_endpoint = GetLeaderEndpoint(); + if (!leader_endpoint) { + LOG(ERROR) << "Couldn't find Raft cluster leader!"; + return 1; + } + // Kickoff a thread that will timeout after FLAGS_timeout seconds std::thread timeout_thread_ = std::thread([&timeout_reached, &benchmark_finished]() { @@ -45,62 +85,50 @@ int main(int argc, char **argv) { timeout_reached.store(true); }); - double duration = 0; - double write_per_second = 0; + std::vector threads; - bool successful = false; - for (int retry = 0; !successful && retry < 10; ++retry) { - for (int i = 0; !successful && i < FLAGS_cluster_size; ++i) { - try { - communication::ClientContext context(FLAGS_use_ssl); - communication::bolt::Client client(&context); + for (int i = 0; i < std::thread::hardware_concurrency(); ++i) { + threads.emplace_back( + [endpoint = *leader_endpoint, &timeout_reached, &query_counter]() { + communication::ClientContext context(FLAGS_use_ssl); + communication::bolt::Client client(&context); + client.Connect(endpoint, FLAGS_username, FLAGS_password); - uint16_t port = FLAGS_port + i; - io::network::Endpoint endpoint{FLAGS_address, port}; - client.Connect(endpoint, FLAGS_username, FLAGS_password); + while (query_counter.load() < FLAGS_query_count) { + if (timeout_reached.load()) break; - utils::Timer timer; - for (int k = 0; k < FLAGS_query_count; ++k) { - client.Execute("CREATE (:Node)", {}); - query_counter++; - - if (timeout_reached.load()) break; - } - - duration = timer.Elapsed().count(); - successful = true; - - } catch (const communication::bolt::ClientQueryException &) { - // This one is not the leader, continue. - continue; - } catch (const communication::bolt::ClientFatalException &) { - // This one seems to be down, continue. - continue; - } - - if (timeout_reached.load()) break; - } - - if (timeout_reached.load()) break; - if (!successful) { - LOG(INFO) << "Couldn't find Raft cluster leader, retrying..."; - std::this_thread::sleep_for(1s); - } + try { + client.Execute("CREATE (:Node)", {}); + query_counter.fetch_add(1); + } catch (const communication::bolt::ClientQueryException &e) { + LOG(WARNING) << e.what(); + break; + } catch (const communication::bolt::ClientFatalException &e) { + LOG(WARNING) << e.what(); + break; + } + } + }); } + utils::Timer timer; + int64_t query_offset = query_counter.load(); + + for (auto &t : threads) { + if (t.joinable()) t.join(); + } + + double duration = timer.Elapsed().count(); + double write_per_second = (query_counter - query_offset) / duration; + benchmark_finished.store(true); if (timeout_thread_.joinable()) timeout_thread_.join(); - if (successful) { - write_per_second = query_counter / duration; - } - std::ofstream output(FLAGS_output_file); output << "duration " << duration << std::endl; output << "executed_writes " << query_counter << std::endl; output << "write_per_second " << write_per_second << std::endl; output.close(); - if (!successful) return 1; return 0; }