Use custom future that waits on destruct

Reviewers: teon.banek, dgleich

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1286
This commit is contained in:
florijan 2018-03-08 16:57:15 +01:00
parent 848749cf25
commit 42ca81eb01
9 changed files with 123 additions and 22 deletions

View File

@ -177,7 +177,7 @@ void GraphDbAccessor::BuildIndex(storage::Label label,
// CreateIndex.
GraphDbAccessor dba(db_);
std::experimental::optional<std::vector<std::future<bool>>>
std::experimental::optional<std::vector<utils::Future<bool>>>
index_rpc_completions;
// Notify all workers to start building an index if we are the master since

View File

@ -10,6 +10,7 @@
#include "query/frontend/semantic/symbol.hpp"
#include "query/parameters.hpp"
#include "transactions/type.hpp"
#include "utils/future.hpp"
namespace distributed {
@ -30,7 +31,7 @@ class RemotePullRpcClients {
/// @todo: it might be cleaner to split RemotePull into {InitRemoteCursor,
/// RemotePull, RemoteAccumulate}, but that's a lot of refactoring and more
/// RPC calls.
std::future<RemotePullData> RemotePull(
utils::Future<RemotePullData> RemotePull(
database::GraphDbAccessor &dba, int worker_id, int64_t plan_id,
const Parameters &params, const std::vector<query::Symbol> &symbols,
bool accumulate, int batch_size = kDefaultBatchSize) {
@ -86,7 +87,7 @@ class RemotePullRpcClients {
auto GetWorkerIds() { return clients_.GetWorkerIds(); }
std::vector<std::future<void>> NotifyAllTransactionCommandAdvanced(
std::vector<utils::Future<void>> NotifyAllTransactionCommandAdvanced(
tx::transaction_id_t tx_id) {
return clients_.ExecuteOnWorkers<void>(0, [tx_id](auto &client) {
client.template Call<TransactionCommandAdvancedRpc>(tx_id);

View File

@ -12,6 +12,7 @@
#include "storage/gid.hpp"
#include "storage/types.hpp"
#include "transactions/type.hpp"
#include "utils/future.hpp"
namespace distributed {
@ -133,7 +134,7 @@ class RemoteUpdatesRpcClients {
/// Calls for all the workers (except the given one) to apply their updates
/// and returns the future results.
std::vector<std::future<RemoteUpdateResult>> RemoteUpdateApplyAll(
std::vector<utils::Future<RemoteUpdateResult>> RemoteUpdateApplyAll(
int skip_worker_id, tx::transaction_id_t tx_id) {
return worker_clients_.ExecuteOnWorkers<RemoteUpdateResult>(
skip_worker_id, [tx_id](auto &client) {

View File

@ -1,7 +1,6 @@
#pragma once
#include <functional>
#include <future>
#include <type_traits>
#include <unordered_map>
@ -10,6 +9,7 @@
#include "distributed/index_rpc_messages.hpp"
#include "storage/types.hpp"
#include "transactions/transaction.hpp"
#include "utils/future.hpp"
namespace distributed {
@ -37,16 +37,17 @@ class RpcWorkerClients {
auto GetWorkerIds() { return coordination_.GetWorkerIds(); }
/** Asynchronously executes the given function on the rpc client for the
* given worker id. Returns an `std::future` of the given `execute` function's
* given worker id. Returns a `utils::Future` of the given `execute`
* function's return type. */
template <typename TResult>
auto ExecuteOnWorker(
int worker_id,
std::function<TResult(communication::rpc::ClientPool &)> execute) {
auto &client_pool = GetClientPool(worker_id);
return std::async(std::launch::async, [execute, &client_pool]() {
return execute(client_pool);
});
return utils::make_future(
std::async(std::launch::async,
[execute, &client_pool]() { return execute(client_pool); }));
}
/** Asynchronously executes the `execute` function on all worker rpc clients
@ -56,7 +57,7 @@ class RpcWorkerClients {
auto ExecuteOnWorkers(
int skip_worker_id,
std::function<TResult(communication::rpc::ClientPool &)> execute) {
std::vector<std::future<TResult>> futures;
std::vector<utils::Future<TResult>> futures;
for (auto &worker_id : coordination_.GetWorkerIds()) {
if (worker_id == skip_worker_id) continue;
futures.emplace_back(std::move(ExecuteOnWorker(worker_id, execute)));
@ -93,5 +94,4 @@ class IndexRpcClients {
private:
RpcWorkerClients &clients_;
};
} // namespace distributed

View File

@ -26,6 +26,7 @@
#include "query/path.hpp"
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
#include "utils/future.hpp"
DEFINE_HIDDEN_int32(remote_pull_sleep_micros, 10,
"Sleep between remote result pulling in microseconds");
@ -408,10 +409,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
context.symbol_table_, db, graph_view_);
auto convert = [&evaluator](const auto &bound)
-> std::experimental::optional<utils::Bound<PropertyValue>> {
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
return db.Vertices(label_, property_, convert(lower_bound()),
convert(upper_bound()), graph_view_ == GraphView::NEW);
};
@ -3058,7 +3059,7 @@ class RemotePuller {
if (found_it == remote_pulls_.end()) continue;
auto &remote_pull = found_it->second;
if (!utils::IsFutureReady(remote_pull)) continue;
if (!remote_pull.IsReady()) continue;
auto remote_results = remote_pull.get();
switch (remote_results.pull_state) {
@ -3129,7 +3130,7 @@ class RemotePuller {
database::GraphDbAccessor &db_;
std::vector<Symbol> symbols_;
int64_t plan_id_;
std::unordered_map<int, std::future<distributed::RemotePullData>>
std::unordered_map<int, utils::Future<distributed::RemotePullData>>
remote_pulls_;
std::unordered_map<int, std::vector<std::vector<query::TypedValue>>>
remote_results_;
@ -3268,7 +3269,8 @@ class SynchronizeCursor : public Cursor {
auto &db = context.db_accessor_.db();
// Tell all workers to accumulate, only if there is a remote pull.
std::vector<std::future<distributed::RemotePullData>> worker_accumulations;
std::vector<utils::Future<distributed::RemotePullData>>
worker_accumulations;
if (pull_remote_cursor_) {
for (auto worker_id : db.remote_pull_clients().GetWorkerIds()) {
if (worker_id == db.WorkerId()) continue;

45
src/utils/future.hpp Normal file
View File

@ -0,0 +1,45 @@
#pragma once
/// @file
#include <future>
namespace utils {
/// Wraps an `std::future` object to ensure that the wrapped `std::future` is
/// waited on before the wrapper lets go of it: both upon destruction and when
/// move assignment replaces a still-valid future.
///
/// Like `std::future`, the wrapper is move-only.
template <typename TResult>
class Future {
 public:
  /// Creates a wrapper that holds no shared state (`valid()` returns false).
  Future() = default;

  /// Takes ownership of the given `std::future`. Deliberately not `explicit`,
  /// so an `std::future` keeps converting implicitly to a `Future`.
  Future(std::future<TResult> future) : future_(std::move(future)) {}

  Future(const Future &) = delete;
  Future(Future &&) = default;
  Future &operator=(const Future &) = delete;

  /// Move assignment. If this wrapper currently holds a valid future, that
  /// future is waited on first; a defaulted assignment would release it
  /// without waiting, defeating the purpose of this class. Afterwards this
  /// wrapper holds the future previously held by `other`.
  Future &operator=(Future &&other) noexcept {
    if (this != &other) {
      WaitIfValid();
      future_ = std::move(other.future_);
    }
    return *this;
  }

  /// Waits on the wrapped future, if it is valid.
  ~Future() { WaitIfValid(); }

  /// Returns true if the future has the result available. NOTE: The behaviour
  /// is undefined if future isn't valid, i.e. `future.valid() == false`.
  bool IsReady() const {
    // A zero timeout turns `wait_for` into a non-blocking poll.
    auto status = future_.wait_for(std::chrono::seconds(0));
    return status == std::future_status::ready;
  }

  /// Forwards to `std::future::get`: blocks until the result is available and
  /// returns it (or rethrows the stored exception). The wrapped future is no
  /// longer valid afterwards.
  auto get() { return future_.get(); }

  /// Forwards to `std::future::wait`: blocks until the result is available.
  auto wait() const { return future_.wait(); }

  /// Forwards to `std::future::valid`.
  auto valid() const { return future_.valid(); }

 private:
  /// Blocks until the wrapped future is ready; no-op if it holds no state.
  void WaitIfValid() const {
    if (future_.valid()) future_.wait();
  }

  std::future<TResult> future_;
};
/// Convenience factory that wraps the given `std::future` in a `Future`,
/// letting `TResult` be deduced from the argument.
template <typename TResult>
Future<TResult> make_future(std::future<TResult> future) {
  Future<TResult> wrapped(std::move(future));
  return wrapped;
}
}

View File

@ -28,11 +28,14 @@ class DistributedGraphDbTest : public ::testing::Test {
};
protected:
virtual int QueryExecutionTimeSec(int) { return 180; }
void SetUp() override {
const auto kInitTime = 200ms;
database::Config master_config;
master_config.master_endpoint = {kLocal, 0};
master_config.query_execution_time_sec = QueryExecutionTimeSec(0);
master_ = std::make_unique<database::Master>(master_config);
std::this_thread::sleep_for(kInitTime);
@ -41,6 +44,7 @@ class DistributedGraphDbTest : public ::testing::Test {
config.worker_id = worker_id;
config.master_endpoint = master_->endpoint();
config.worker_endpoint = {kLocal, 0};
config.query_execution_time_sec = QueryExecutionTimeSec(worker_id);
return config;
};

View File

@ -6,6 +6,7 @@
#include "query/interpreter.hpp"
#include "query_common.hpp"
#include "query_plan_common.hpp"
#include "utils/timer.hpp"
using namespace distributed;
using namespace database;
@ -168,3 +169,53 @@ TEST_F(DistributedInterpretationTest, Cartesian) {
ASSERT_THAT(got, testing::UnorderedElementsAreArray(expected));
}
/// Fixture that configures a query execution time of 3 seconds for id 2 and
/// of 1 second for every other id.
class TestQueryWaitsOnFutures : public DistributedInterpretationTest {
 protected:
  int QueryExecutionTimeSec(int worker_id) override {
    if (worker_id == 2) return 3;
    return 1;
  }
};
// Measures how long an expensive query takes to return, first with a dense
// graph only on worker 1 and then also on worker 2 (which the fixture
// configures with a longer execution time). Presumably the query is expected
// to exceed its execution time limit, which is why anything it throws is
// ignored and only the elapsed time is checked.
TEST_F(TestQueryWaitsOnFutures, Test) {
  const int kVertexCount = 10;

  // Inserts kVertexCount vertices into `db`, adds an "et" edge for every
  // ordered pair of them (including a vertex paired with itself) and commits.
  auto make_fully_connected = [this](database::GraphDb &db) {
    database::GraphDbAccessor accessor(db);
    std::vector<VertexAccessor> nodes;
    for (int created = 0; created < kVertexCount; ++created) {
      nodes.emplace_back(accessor.InsertVertex());
    }
    auto edge_type = accessor.EdgeType("et");
    for (auto &source : nodes) {
      for (auto &target : nodes) {
        accessor.InsertEdge(source, target, edge_type);
      }
    }
    accessor.Commit();
  };

  // Runs the expensive traversal, swallowing any exception, and returns the
  // time the call took as reported by utils::Timer (treated as seconds by the
  // expectations below).
  auto timed_query = [this]() -> double {
    utils::Timer timer;
    try {
      Run("MATCH ()--()--()--()--()--()--() RETURN count(1)");
    } catch (...) {
      // Deliberately ignored: only the duration matters here.
    }
    return timer.Elapsed().count();
  };

  make_fully_connected(worker(1));
  ASSERT_EQ(VertexCount(worker(1)), kVertexCount);
  ASSERT_EQ(EdgeCount(worker(1)), kVertexCount * kVertexCount);
  {
    const double elapsed = timed_query();
    EXPECT_GT(elapsed, 1);
    EXPECT_LT(elapsed, 2);
  }

  make_fully_connected(worker(2));
  ASSERT_EQ(VertexCount(worker(2)), kVertexCount);
  ASSERT_EQ(EdgeCount(worker(2)), kVertexCount * kVertexCount);
  {
    const double elapsed = timed_query();
    EXPECT_GT(elapsed, 3);
  }
}

View File

@ -325,10 +325,7 @@ TEST_F(DistributedGraphDbTest, PullRemoteOrderBy) {
class DistributedTransactionTimeout : public DistributedGraphDbTest {
protected:
void SetUp() override {
FLAGS_query_execution_time_sec = 1;
DistributedGraphDbTest::SetUp();
}
int QueryExecutionTimeSec(int) override { return 1; }
};
TEST_F(DistributedTransactionTimeout, Timeout) {