diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index 969ed35cf..be5473f4d 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -177,7 +177,7 @@ void GraphDbAccessor::BuildIndex(storage::Label label, // CreateIndex. GraphDbAccessor dba(db_); - std::experimental::optional<std::vector<std::future<bool>>> + std::experimental::optional<std::vector<utils::Future<bool>>> index_rpc_completions; // Notify all workers to start building an index if we are the master since diff --git a/src/distributed/remote_pull_rpc_clients.hpp b/src/distributed/remote_pull_rpc_clients.hpp index 88422a0a6..691625e74 100644 --- a/src/distributed/remote_pull_rpc_clients.hpp +++ b/src/distributed/remote_pull_rpc_clients.hpp @@ -10,6 +10,7 @@ #include "query/frontend/semantic/symbol.hpp" #include "query/parameters.hpp" #include "transactions/type.hpp" +#include "utils/future.hpp" namespace distributed { @@ -30,7 +31,7 @@ class RemotePullRpcClients { /// @todo: it might be cleaner to split RemotePull into {InitRemoteCursor, /// RemotePull, RemoteAccumulate}, but that's a lot of refactoring and more /// RPC calls. - std::future<RemotePullData> RemotePull( + utils::Future<RemotePullData> RemotePull( database::GraphDbAccessor &dba, int worker_id, int64_t plan_id, const Parameters ¶ms, const std::vector<query::Symbol> &symbols, bool accumulate, int batch_size = kDefaultBatchSize) { @@ -86,7 +87,7 @@ class RemotePullRpcClients { auto GetWorkerIds() { return clients_.GetWorkerIds(); } - std::vector<std::future<void>> NotifyAllTransactionCommandAdvanced( + std::vector<utils::Future<void>> NotifyAllTransactionCommandAdvanced( tx::transaction_id_t tx_id) { return clients_.ExecuteOnWorkers<void>(0, [tx_id](auto &client) { client.template Call<TransactionCommandAdvancedRpc>(tx_id); diff --git a/src/distributed/remote_updates_rpc_clients.hpp b/src/distributed/remote_updates_rpc_clients.hpp index cb720ae84..f7b803698 100644 --- a/src/distributed/remote_updates_rpc_clients.hpp +++ b/src/distributed/remote_updates_rpc_clients.hpp @@ -12,6 +12,7 @@ #include "storage/gid.hpp" #include "storage/types.hpp" #include "transactions/type.hpp" +#include "utils/future.hpp" namespace distributed { @@ -133,7 +134,7 @@ class RemoteUpdatesRpcClients { /// Calls for all the workers (except the given one) to apply their updates /// and returns the future results. - std::vector<std::future<RemoteUpdateResult>> RemoteUpdateApplyAll( + std::vector<utils::Future<RemoteUpdateResult>> RemoteUpdateApplyAll( int skip_worker_id, tx::transaction_id_t tx_id) { return worker_clients_.ExecuteOnWorkers<RemoteUpdateResult>( skip_worker_id, [tx_id](auto &client) { diff --git a/src/distributed/rpc_worker_clients.hpp b/src/distributed/rpc_worker_clients.hpp index 2c1c9bb7e..215d94a14 100644 --- a/src/distributed/rpc_worker_clients.hpp +++ b/src/distributed/rpc_worker_clients.hpp @@ -1,7 +1,6 @@ #pragma once #include <functional> -#include <future> #include <type_traits> #include <unordered_map> @@ -10,6 +9,7 @@ #include "distributed/index_rpc_messages.hpp" #include "storage/types.hpp" #include "transactions/transaction.hpp" +#include "utils/future.hpp" namespace distributed { @@ -37,16 +37,17 @@ class RpcWorkerClients { auto GetWorkerIds() { return coordination_.GetWorkerIds(); } /** Asynchroniously executes the given function on the rpc client for the - * given worker id. Returns an `std::future` of the given `execute` function's + * given worker id. Returns an `utils::Future` of the given `execute` + * function's * return type. */ template <typename TResult> auto ExecuteOnWorker( int worker_id, std::function<TResult(communication::rpc::ClientPool &)> execute) { auto &client_pool = GetClientPool(worker_id); - return std::async(std::launch::async, [execute, &client_pool]() { - return execute(client_pool); - }); + return utils::make_future( + std::async(std::launch::async, + [execute, &client_pool]() { return execute(client_pool); })); } /** Asynchroniously executes the `execute` function on all worker rpc clients @@ -56,7 +57,7 @@ class RpcWorkerClients { auto ExecuteOnWorkers( int skip_worker_id, std::function<TResult(communication::rpc::ClientPool &)> execute) { - std::vector<std::future<TResult>> futures; + std::vector<utils::Future<TResult>> futures; for (auto &worker_id : coordination_.GetWorkerIds()) { if (worker_id == skip_worker_id) continue; futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute))); @@ -93,5 +94,4 @@ class IndexRpcClients { private: RpcWorkerClients &clients_; }; - } // namespace distributed diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 6e5e53116..204e2321f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -26,6 +26,7 @@ #include "query/path.hpp" #include "utils/algorithm.hpp" #include "utils/exceptions.hpp" +#include "utils/future.hpp" DEFINE_HIDDEN_int32(remote_pull_sleep_micros, 10, "Sleep between remote result pulling in microseconds"); @@ -408,10 +409,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor( context.symbol_table_, db, graph_view_); auto convert = [&evaluator](const auto &bound) -> std::experimental::optional<utils::Bound<PropertyValue>> { - if (!bound) return std::experimental::nullopt; - return std::experimental::make_optional(utils::Bound<PropertyValue>( - bound.value().value()->Accept(evaluator), bound.value().type())); - }; + if (!bound) return std::experimental::nullopt; + return std::experimental::make_optional(utils::Bound<PropertyValue>( + bound.value().value()->Accept(evaluator), bound.value().type())); + }; return db.Vertices(label_, property_, convert(lower_bound()), convert(upper_bound()), graph_view_ == GraphView::NEW); }; @@ -3058,7 +3059,7 @@ class RemotePuller { if (found_it == remote_pulls_.end()) continue; auto &remote_pull = found_it->second; - if (!utils::IsFutureReady(remote_pull)) continue; + if (!remote_pull.IsReady()) continue; auto remote_results = remote_pull.get(); switch (remote_results.pull_state) { @@ -3129,7 +3130,7 @@ class RemotePuller { database::GraphDbAccessor &db_; std::vector<Symbol> symbols_; int64_t plan_id_; - std::unordered_map<int, std::future<distributed::RemotePullData>> + std::unordered_map<int, utils::Future<distributed::RemotePullData>> remote_pulls_; std::unordered_map<int, std::vector<std::vector<query::TypedValue>>> remote_results_; @@ -3268,7 +3269,8 @@ class SynchronizeCursor : public Cursor { auto &db = context.db_accessor_.db(); // Tell all workers to accumulate, only if there is a remote pull. - std::vector<std::future<distributed::RemotePullData>> worker_accumulations; + std::vector<utils::Future<distributed::RemotePullData>> + worker_accumulations; if (pull_remote_cursor_) { for (auto worker_id : db.remote_pull_clients().GetWorkerIds()) { if (worker_id == db.WorkerId()) continue; diff --git a/src/utils/future.hpp b/src/utils/future.hpp new file mode 100644 index 000000000..03105ba8b --- /dev/null +++ b/src/utils/future.hpp @@ -0,0 +1,45 @@ +#pragma once +/// @file + +#include <future> + +namespace utils { + +/// Wraps an `std::future` object to ensure that upon destruction the +/// `std::future` is waited on. +template <typename TResult> +class Future { + public: + Future() {} + Future(std::future<TResult> future) : future_(std::move(future)) {} + + Future(const Future &) = delete; + Future(Future &&) = default; + Future &operator=(const Future &) = delete; + Future &operator=(Future &&) = default; + + ~Future() { + if (future_.valid()) future_.wait(); + } + + /// Returns true if the future has the result available. NOTE: The behaviour + /// is undefined if future isn't valid, i.e. `future.valid() == false`. + bool IsReady() const { + auto status = future_.wait_for(std::chrono::seconds(0)); + return status == std::future_status::ready; + } + + auto get() { return future_.get(); } + auto wait() { return future_.wait(); } + auto valid() { return future_.valid(); } + + private: + std::future<TResult> future_; +}; + +/// Creates a `Future` from the given `std::future`. +template <typename TResult> +Future<TResult> make_future(std::future<TResult> future) { + return Future<TResult>(std::move(future)); +} +} diff --git a/tests/unit/distributed_common.hpp b/tests/unit/distributed_common.hpp index 546ecafc2..1d6fa0a2a 100644 --- a/tests/unit/distributed_common.hpp +++ b/tests/unit/distributed_common.hpp @@ -28,11 +28,14 @@ class DistributedGraphDbTest : public ::testing::Test { }; protected: + virtual int QueryExecutionTimeSec(int) { return 180; } + void SetUp() override { const auto kInitTime = 200ms; database::Config master_config; master_config.master_endpoint = {kLocal, 0}; + master_config.query_execution_time_sec = QueryExecutionTimeSec(0); master_ = std::make_unique<database::Master>(master_config); std::this_thread::sleep_for(kInitTime); @@ -41,6 +44,7 @@ class DistributedGraphDbTest : public ::testing::Test { config.worker_id = worker_id; config.master_endpoint = master_->endpoint(); config.worker_endpoint = {kLocal, 0}; + config.query_execution_time_sec = QueryExecutionTimeSec(worker_id); return config; }; diff --git a/tests/unit/distributed_interpretation.cpp b/tests/unit/distributed_interpretation.cpp index f09287732..03a108ea7 100644 --- a/tests/unit/distributed_interpretation.cpp +++ b/tests/unit/distributed_interpretation.cpp @@ -6,6 +6,7 @@ #include "query/interpreter.hpp" #include "query_common.hpp" #include "query_plan_common.hpp" +#include "utils/timer.hpp" using namespace distributed; using namespace database; @@ -168,3 +169,53 @@ TEST_F(DistributedInterpretationTest, Cartesian) { ASSERT_THAT(got, testing::UnorderedElementsAreArray(expected)); } + +class TestQueryWaitsOnFutures : public DistributedInterpretationTest { + protected: + int QueryExecutionTimeSec(int worker_id) override { + return worker_id == 2 ? 3 : 1; + } +}; + +TEST_F(TestQueryWaitsOnFutures, Test) { + const int kVertexCount = 10; + auto make_fully_connected = [this](database::GraphDb &db) { + database::GraphDbAccessor dba(db); + std::vector<VertexAccessor> vertices; + for (int i = 0; i < kVertexCount; ++i) + vertices.emplace_back(dba.InsertVertex()); + auto et = dba.EdgeType("et"); + for (auto &from : vertices) + for (auto &to : vertices) dba.InsertEdge(from, to, et); + dba.Commit(); + }; + + make_fully_connected(worker(1)); + ASSERT_EQ(VertexCount(worker(1)), kVertexCount); + ASSERT_EQ(EdgeCount(worker(1)), kVertexCount * kVertexCount); + + { + utils::Timer timer; + try { + Run("MATCH ()--()--()--()--()--()--() RETURN count(1)"); + } catch (...) { + } + double seconds = timer.Elapsed().count(); + EXPECT_GT(seconds, 1); + EXPECT_LT(seconds, 2); + } + + make_fully_connected(worker(2)); + ASSERT_EQ(VertexCount(worker(2)), kVertexCount); + ASSERT_EQ(EdgeCount(worker(2)), kVertexCount * kVertexCount); + + { + utils::Timer timer; + try { + Run("MATCH ()--()--()--()--()--()--() RETURN count(1)"); + } catch (...) { + } + double seconds = timer.Elapsed().count(); + EXPECT_GT(seconds, 3); + } +} diff --git a/tests/unit/distributed_query_plan.cpp b/tests/unit/distributed_query_plan.cpp index 1b1cd7d65..0f248042f 100644 --- a/tests/unit/distributed_query_plan.cpp +++ b/tests/unit/distributed_query_plan.cpp @@ -325,10 +325,7 @@ TEST_F(DistributedGraphDbTest, PullRemoteOrderBy) { class DistributedTransactionTimeout : public DistributedGraphDbTest { protected: - void SetUp() override { - FLAGS_query_execution_time_sec = 1; - DistributedGraphDbTest::SetUp(); - } + int QueryExecutionTimeSec(int) override { return 1; } }; TEST_F(DistributedTransactionTimeout, Timeout) {