Add command id to remote produce

Summary:
A command id is necessary in remote produce to identify an ongoing pull,
because a single transaction can run multiple commands that all share the
same plan id and tx id.

Reviewers: teon.banek, mtomic, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1386
This commit is contained in:
Matija Santl 2018-04-30 09:33:09 +02:00
parent 91e38f6413
commit f872c93ad1
13 changed files with 185 additions and 107 deletions

View File

@ -9,6 +9,8 @@
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/decoder/decoded_value.hpp"
#include "communication/bolt/v1/state.hpp"
#include "database/graph_db.hpp"
#include "distributed/pull_rpc_clients.hpp"
#include "query/exceptions.hpp"
#include "query/typed_value.hpp"
#include "utils/exceptions.hpp"
@ -120,6 +122,13 @@ State HandleRun(TSession &session, State state, Marker marker) {
return State::Result;
}
session.db_accessor_->AdvanceCommand();
if (session.db_.type() == database::GraphDb::Type::DISTRIBUTED_MASTER) {
auto tx_id = session.db_accessor_->transaction_id();
auto futures =
session.db_.pull_clients().NotifyAllTransactionCommandAdvanced(
tx_id);
for (auto &future : futures) future.wait();
}
}
auto &params_map = params.ValueMap();

View File

@ -112,7 +112,7 @@ void ProduceRpcServer::FinishAndClearOngoingProducePlans(
tx::TransactionId tx_id) {
std::lock_guard<std::mutex> guard{ongoing_produces_lock_};
for (auto it = ongoing_produces_.begin(); it != ongoing_produces_.end();) {
if (it->first.first == tx_id) {
if (std::get<0>(it->first) == tx_id) {
it = ongoing_produces_.erase(it);
} else {
++it;
@ -122,9 +122,9 @@ void ProduceRpcServer::FinishAndClearOngoingProducePlans(
ProduceRpcServer::OngoingProduce &ProduceRpcServer::GetOngoingProduce(
const PullReq &req) {
auto key_pair = std::make_pair(req.tx_id, req.plan_id);
auto key_tuple = std::make_tuple(req.tx_id, req.command_id, req.plan_id);
std::lock_guard<std::mutex> guard{ongoing_produces_lock_};
auto found = ongoing_produces_.find(key_pair);
auto found = ongoing_produces_.find(key_tuple);
if (found != ongoing_produces_.end()) {
return found->second;
}
@ -135,7 +135,7 @@ ProduceRpcServer::OngoingProduce &ProduceRpcServer::GetOngoingProduce(
}
auto &plan_pack = plan_consumer_.PlanForId(req.plan_id);
return ongoing_produces_
.emplace(std::piecewise_construct, std::forward_as_tuple(key_pair),
.emplace(std::piecewise_construct, std::forward_as_tuple(key_tuple),
std::forward_as_tuple(db_, req.tx_id, plan_pack.plan,
plan_pack.symbol_table, req.params,
req.symbols))

View File

@ -70,8 +70,11 @@ class ProduceRpcServer {
private:
std::mutex ongoing_produces_lock_;
/// Mapping of (tx id, plan id) to OngoingProduce.
std::map<std::pair<tx::TransactionId, int64_t>, OngoingProduce>
/// Mapping of (tx id, command id, plan id) to OngoingProduce.
/// The command_id should be the command_id at the initialization of a cursor
/// that can call ProduceRpcServer.
std::map<std::tuple<tx::TransactionId, tx::CommandId, int64_t>,
OngoingProduce>
ongoing_produces_;
database::GraphDb &db_;
communication::rpc::Server &produce_rpc_server_;

View File

@ -38,11 +38,13 @@ enum class PullState {
struct PullReq : public communication::rpc::Message {
PullReq() {}
PullReq(tx::TransactionId tx_id, tx::Snapshot tx_snapshot, int64_t plan_id,
const Parameters &params, std::vector<query::Symbol> symbols,
bool accumulate, int batch_size, bool send_old, bool send_new)
tx::CommandId command_id, const Parameters &params,
std::vector<query::Symbol> symbols, bool accumulate, int batch_size,
bool send_old, bool send_new)
: tx_id(tx_id),
tx_snapshot(tx_snapshot),
plan_id(plan_id),
command_id(command_id),
params(params),
symbols(symbols),
accumulate(accumulate),
@ -53,6 +55,7 @@ struct PullReq : public communication::rpc::Message {
tx::TransactionId tx_id;
tx::Snapshot tx_snapshot;
int64_t plan_id;
tx::CommandId command_id;
Parameters params;
std::vector<query::Symbol> symbols;
bool accumulate;
@ -70,6 +73,7 @@ struct PullReq : public communication::rpc::Message {
ar << tx_id;
ar << tx_snapshot;
ar << plan_id;
ar << command_id;
ar << params.size();
for (auto &kv : params) {
ar << kv.first;
@ -89,6 +93,7 @@ struct PullReq : public communication::rpc::Message {
ar >> tx_id;
ar >> tx_snapshot;
ar >> plan_id;
ar >> command_id;
size_t params_size;
ar >> params_size;
for (size_t i = 0; i < params_size; ++i) {

View File

@ -9,14 +9,15 @@ namespace distributed {
utils::Future<PullData> PullRpcClients::Pull(
database::GraphDbAccessor &dba, int worker_id, int64_t plan_id,
const Parameters &params, const std::vector<query::Symbol> &symbols,
bool accumulate, int batch_size) {
tx::CommandId command_id, const Parameters &params,
const std::vector<query::Symbol> &symbols, bool accumulate,
int batch_size) {
return clients_.ExecuteOnWorker<PullData>(
worker_id, [&dba, plan_id, params, symbols, accumulate,
worker_id, [&dba, plan_id, command_id, params, symbols, accumulate,
batch_size](ClientPool &client_pool) {
auto result = client_pool.Call<PullRpc>(
dba.transaction_id(), dba.transaction().snapshot(), plan_id, params,
symbols, accumulate, batch_size, true, true);
dba.transaction_id(), dba.transaction().snapshot(), plan_id,
command_id, params, symbols, accumulate, batch_size, true, true);
auto handle_vertex = [&dba](auto &v) {
dba.db()
@ -61,8 +62,7 @@ utils::Future<PullData> PullRpcClients::Pull(
}
std::vector<utils::Future<void>>
PullRpcClients::NotifyAllTransactionCommandAdvanced(
tx::TransactionId tx_id) {
PullRpcClients::NotifyAllTransactionCommandAdvanced(tx::TransactionId tx_id) {
return clients_.ExecuteOnWorkers<void>(0, [tx_id](auto &client) {
auto res = client.template Call<TransactionCommandAdvancedRpc>(tx_id);
CHECK(res) << "TransactionCommandAdvanceRpc failed";

View File

@ -23,14 +23,15 @@ class PullRpcClients {
PullRpcClients(RpcWorkerClients &clients) : clients_(clients) {}
/// Calls a remote pull asynchronously. IMPORTANT: take care not to call this
/// function for the same (tx_id, worker_id, plan_id) before the previous call
/// has ended.
/// function for the same (tx_id, worker_id, plan_id, command_id) before the
/// previous call has ended.
///
/// @todo: it might be cleaner to split Pull into {InitRemoteCursor,
/// Pull, RemoteAccumulate}, but that's a lot of refactoring and more
/// RPC calls.
utils::Future<PullData> Pull(database::GraphDbAccessor &dba, int worker_id,
int64_t plan_id, const Parameters &params,
int64_t plan_id, tx::CommandId command_id,
const Parameters &params,
const std::vector<query::Symbol> &symbols,
bool accumulate,
int batch_size = kDefaultBatchSize);

View File

@ -412,32 +412,32 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
-> std::experimental::optional<decltype(
db.Vertices(label_, property_, std::experimental::nullopt,
std::experimental::nullopt, false))> {
ExpressionEvaluator evaluator(frame, context.parameters_,
context.symbol_table_, db, graph_view_);
auto convert = [&evaluator](const auto &bound)
-> std::experimental::optional<utils::Bound<PropertyValue>> {
if (!bound) return std::experimental::nullopt;
auto value = bound->value()->Accept(evaluator);
try {
ExpressionEvaluator evaluator(frame, context.parameters_,
context.symbol_table_, db, graph_view_);
auto convert = [&evaluator](const auto &bound)
-> std::experimental::optional<utils::Bound<PropertyValue>> {
if (!bound) return std::experimental::nullopt;
auto value = bound->value()->Accept(evaluator);
try {
return std::experimental::make_optional(
utils::Bound<PropertyValue>(value, bound->type()));
} catch (const TypedValueException &) {
throw QueryRuntimeException(
"'{}' cannot be used as a property value.", value.type());
}
};
auto maybe_lower = convert(lower_bound());
auto maybe_upper = convert(upper_bound());
// If any bound is null, then the comparison would result in nulls. This
// is treated as not satisfying the filter, so return no vertices.
if (maybe_lower && maybe_lower->value().IsNull())
return std::experimental::nullopt;
if (maybe_upper && maybe_upper->value().IsNull())
return std::experimental::nullopt;
return std::experimental::make_optional(
utils::Bound<PropertyValue>(value, bound->type()));
} catch (const TypedValueException &) {
throw QueryRuntimeException("'{}' cannot be used as a property value.",
value.type());
}
};
auto maybe_lower = convert(lower_bound());
auto maybe_upper = convert(upper_bound());
// If any bound is null, then the comparison would result in nulls. This
// is treated as not satisfying the filter, so return no vertices.
if (maybe_lower && maybe_lower->value().IsNull())
return std::experimental::nullopt;
if (maybe_upper && maybe_upper->value().IsNull())
return std::experimental::nullopt;
return std::experimental::make_optional(
db.Vertices(label_, property_, maybe_lower, maybe_upper,
graph_view_ == GraphView::NEW));
};
db.Vertices(label_, property_, maybe_lower, maybe_upper,
graph_view_ == GraphView::NEW));
};
return std::make_unique<ScanAllCursor<decltype(vertices)>>(
output_symbol_, input_->MakeCursor(db), std::move(vertices), db);
}
@ -460,18 +460,18 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyValue::MakeCursor(
auto vertices = [this, &db](Frame &frame, Context &context)
-> std::experimental::optional<decltype(
db.Vertices(label_, property_, TypedValue::Null, false))> {
ExpressionEvaluator evaluator(frame, context.parameters_,
context.symbol_table_, db, graph_view_);
auto value = expression_->Accept(evaluator);
if (value.IsNull()) return std::experimental::nullopt;
try {
return std::experimental::make_optional(
db.Vertices(label_, property_, value, graph_view_ == GraphView::NEW));
} catch (const TypedValueException &) {
throw QueryRuntimeException("'{}' cannot be used as a property value.",
value.type());
}
};
ExpressionEvaluator evaluator(frame, context.parameters_,
context.symbol_table_, db, graph_view_);
auto value = expression_->Accept(evaluator);
if (value.IsNull()) return std::experimental::nullopt;
try {
return std::experimental::make_optional(db.Vertices(
label_, property_, value, graph_view_ == GraphView::NEW));
} catch (const TypedValueException &) {
throw QueryRuntimeException(
"'{}' cannot be used as a property value.", value.type());
}
};
return std::make_unique<ScanAllCursor<decltype(vertices)>>(
output_symbol_, input_->MakeCursor(db), std::move(vertices), db);
}
@ -1367,8 +1367,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
// For the given (edge, vertex, weight, depth) tuple checks if they
// satisfy the "where" condition. If so, places them in the priority queue.
auto expand_pair = [this, &evaluator, &frame, &create_state](
EdgeAccessor edge, VertexAccessor vertex,
double weight, int depth) {
EdgeAccessor edge, VertexAccessor vertex, double weight, int depth) {
SwitchAccessor(edge, self_.graph_view_);
SwitchAccessor(vertex, self_.graph_view_);
@ -3193,14 +3192,17 @@ std::vector<Symbol> PullRemoteOrderBy::ModifiedSymbols(
namespace {
/** Helper class that wraps remote pulling for cursors that handle results
* from distributed workers.
/** Helper class that wraps remote pulling for cursors that handle results from
* distributed workers.
*
* The command_id should be the command_id at the initialization of a cursor.
*/
class RemotePuller {
public:
RemotePuller(database::GraphDbAccessor &db,
const std::vector<Symbol> &symbols, int64_t plan_id)
: db_(db), symbols_(symbols), plan_id_(plan_id) {
const std::vector<Symbol> &symbols, int64_t plan_id,
tx::CommandId command_id)
: db_(db), symbols_(symbols), plan_id_(plan_id), command_id_(command_id) {
worker_ids_ = db_.db().pull_clients().GetWorkerIds();
// Remove master from the worker ids list.
worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0));
@ -3209,7 +3211,7 @@ class RemotePuller {
void Initialize(Context &context) {
if (!remote_pulls_initialized_) {
VLOG(10) << "[RemotePuller] [" << context.db_accessor_.transaction_id()
<< "] [" << plan_id_ << "] initialized";
<< "] [" << plan_id_ << "] [" << command_id_ << "] initialized";
for (auto &worker_id : worker_ids_) {
UpdatePullForWorker(worker_id, context);
}
@ -3223,7 +3225,8 @@ class RemotePuller {
auto move_frames = [this, &context](int worker_id, auto remote_results) {
VLOG(10) << "[RemotePuller] [" << context.db_accessor_.transaction_id()
<< "] [" << plan_id_ << "] received results from " << worker_id;
<< "] [" << plan_id_ << "] [" << command_id_
<< "] received results from " << worker_id;
remote_results_[worker_id] = std::move(remote_results.frames);
// Since we return and remove results from the back of the vector,
// reverse the results so the first to return is on the end of the
@ -3246,14 +3249,16 @@ class RemotePuller {
case distributed::PullState::CURSOR_EXHAUSTED:
VLOG(10) << "[RemotePuller] ["
<< context.db_accessor_.transaction_id() << "] [" << plan_id_
<< "] cursor exhausted from " << worker_id;
<< "] [" << command_id_ << "] cursor exhausted from "
<< worker_id;
move_frames(worker_id, remote_results);
remote_pulls_.erase(found_it);
break;
case distributed::PullState::CURSOR_IN_PROGRESS:
VLOG(10) << "[RemotePuller] ["
<< context.db_accessor_.transaction_id() << "] [" << plan_id_
<< "] cursor in progress from " << worker_id;
<< "] [" << command_id_ << "] cursor in progress from "
<< worker_id;
move_frames(worker_id, remote_results);
UpdatePullForWorker(worker_id, context);
break;
@ -3316,6 +3321,7 @@ class RemotePuller {
database::GraphDbAccessor &db_;
std::vector<Symbol> symbols_;
int64_t plan_id_;
tx::CommandId command_id_;
std::unordered_map<int, utils::Future<distributed::PullData>> remote_pulls_;
std::unordered_map<int, std::vector<std::vector<query::TypedValue>>>
remote_results_;
@ -3323,8 +3329,9 @@ class RemotePuller {
bool remote_pulls_initialized_ = false;
void UpdatePullForWorker(int worker_id, Context &context) {
remote_pulls_[worker_id] = db_.db().pull_clients().Pull(
db_, worker_id, plan_id_, context.parameters_, symbols_, false);
remote_pulls_[worker_id] =
db_.db().pull_clients().Pull(db_, worker_id, plan_id_, command_id_,
context.parameters_, symbols_, false);
}
};
@ -3333,7 +3340,9 @@ class PullRemoteCursor : public Cursor {
PullRemoteCursor(const PullRemote &self, database::GraphDbAccessor &db)
: self_(self),
input_cursor_(self.input() ? self.input()->MakeCursor(db) : nullptr),
remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {}
command_id_(db.transaction().cid()),
remote_puller_(
RemotePuller(db, self.symbols(), self.plan_id(), command_id_)) {}
bool Pull(Frame &frame, Context &context) override {
if (context.db_accessor_.should_abort()) throw HintedAbortError();
@ -3369,13 +3378,15 @@ class PullRemoteCursor : public Cursor {
if (input_cursor_ && input_cursor_->Pull(frame, context)) {
VLOG(10) << "[PullRemoteCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] producing local results ";
<< self_.plan_id() << "] [" << command_id_
<< "] producing local results ";
return true;
}
VLOG(10) << "[PullRemoteCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] no results available, sleeping ";
<< self_.plan_id() << "] [" << command_id_
<< "] no results available, sleeping ";
// If there aren't any local/remote results available, sleep.
std::this_thread::sleep_for(
std::chrono::microseconds(FLAGS_remote_pull_sleep_micros));
@ -3387,7 +3398,8 @@ class PullRemoteCursor : public Cursor {
if (input_cursor_ && input_cursor_->Pull(frame, context)) {
VLOG(10) << "[PullRemoteCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] producing local results ";
<< self_.plan_id() << "] [" << command_id_
<< "] producing local results ";
return true;
}
return false;
@ -3397,8 +3409,8 @@ class PullRemoteCursor : public Cursor {
int worker_id = remote_puller_.GetWorkerId(last_pulled_worker_id_index_);
VLOG(10) << "[PullRemoteCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] producing results from worker "
<< worker_id;
<< self_.plan_id() << "] [" << command_id_
<< "] producing results from worker " << worker_id;
auto result = remote_puller_.PopResultFromWorker(worker_id);
for (size_t i = 0; i < self_.symbols().size(); ++i) {
frame[self_.symbols()[i]] = std::move(result[i]);
@ -3414,6 +3426,7 @@ class PullRemoteCursor : public Cursor {
private:
const PullRemote &self_;
const std::unique_ptr<Cursor> input_cursor_;
tx::CommandId command_id_;
RemotePuller remote_puller_;
int last_pulled_worker_id_index_ = 0;
};
@ -3424,8 +3437,8 @@ class SynchronizeCursor : public Cursor {
: self_(self),
input_cursor_(self.input()->MakeCursor(db)),
pull_remote_cursor_(
self.pull_remote() ? self.pull_remote()->MakeCursor(db) : nullptr) {
}
self.pull_remote() ? self.pull_remote()->MakeCursor(db) : nullptr),
command_id_(db.transaction().cid()) {}
bool Pull(Frame &frame, Context &context) override {
if (!initial_pull_done_) {
@ -3469,6 +3482,7 @@ class SynchronizeCursor : public Cursor {
const std::unique_ptr<Cursor> pull_remote_cursor_;
bool initial_pull_done_{false};
std::vector<std::vector<TypedValue>> local_frames_;
tx::CommandId command_id_;
void InitialPull(Frame &frame, Context &context) {
VLOG(10) << "[SynchronizeCursor] [" << context.db_accessor_.transaction_id()
@ -3482,7 +3496,8 @@ class SynchronizeCursor : public Cursor {
if (worker_id == db.WorkerId()) continue;
worker_accumulations.emplace_back(db.pull_clients().Pull(
context.db_accessor_, worker_id, self_.pull_remote()->plan_id(),
context.parameters_, self_.pull_remote()->symbols(), true, 0));
command_id_, context.parameters_, self_.pull_remote()->symbols(),
true, 0));
}
}
@ -3653,7 +3668,9 @@ class PullRemoteOrderByCursor : public Cursor {
database::GraphDbAccessor &db)
: self_(self),
input_(self.input()->MakeCursor(db)),
remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {}
command_id_(db.transaction().cid()),
remote_puller_(
RemotePuller(db, self.symbols(), self.plan_id(), command_id_)) {}
bool Pull(Frame &frame, Context &context) {
if (context.db_accessor_.should_abort()) throw HintedAbortError();
@ -3680,7 +3697,7 @@ class PullRemoteOrderByCursor : public Cursor {
if (!merge_initialized_) {
VLOG(10) << "[PullRemoteOrderBy] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] initialize";
<< self_.plan_id() << "] [" << command_id_ << "] initialize";
remote_puller_.Initialize(context);
missing_results_from_ = remote_puller_.Workers();
missing_master_result_ = true;
@ -3717,7 +3734,8 @@ class PullRemoteOrderByCursor : public Cursor {
if (!has_all_result) {
VLOG(10) << "[PullRemoteOrderByCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] missing results, sleep";
<< self_.plan_id() << "] [" << command_id_
<< "] missing results, sleep";
// If we don't have results from all workers, sleep before continuing.
std::this_thread::sleep_for(
std::chrono::microseconds(FLAGS_remote_pull_sleep_micros));
@ -3749,13 +3767,15 @@ class PullRemoteOrderByCursor : public Cursor {
if (result_it->worker_id) {
VLOG(10) << "[PullRemoteOrderByCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] producing results from worker "
<< self_.plan_id() << "] [" << command_id_
<< "] producing results from worker "
<< result_it->worker_id.value();
missing_results_from_.push_back(result_it->worker_id.value());
} else {
VLOG(10) << "[PullRemoteOrderByCursor] ["
<< context.db_accessor_.transaction_id() << "] ["
<< self_.plan_id() << "] producing local results";
<< self_.plan_id() << "] [" << command_id_
<< "] producing local results";
missing_master_result_ = true;
}
@ -3776,6 +3796,7 @@ class PullRemoteOrderByCursor : public Cursor {
const PullRemoteOrderBy &self_;
std::unique_ptr<Cursor> input_;
tx::CommandId command_id_;
RemotePuller remote_puller_;
std::vector<MergeResultItem> merge_;
std::vector<int> missing_results_from_;

View File

@ -577,12 +577,12 @@ std::unique_ptr<LogicalOperator> GenReturn(
Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstTreeStorage &storage) {
// Similar to WITH clause, but we want to accumulate and advance command when
// the query writes to the database. This way we handle the case when we want
// to return expressions with the latest updated results. For example,
// `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same
// `n` multiple 'k' times, we want to return 'k' results where the property
// value is the same, final result of 'k' increments.
// Similar to WITH clause, but we want to accumulate when the query writes to
// the database. This way we handle the case when we want to return
// expressions with the latest updated results. For example, `MATCH (n) -- ()
// SET n.prop = n.prop + 1 RETURN n.prop`. If we match same `n` multiple 'k'
// times, we want to return 'k' results where the property value is the same,
// final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);

View File

@ -122,9 +122,7 @@ void WorkerEngine::LocalForEachActiveTransaction(
for (auto pair : active_.access()) f(*pair.second);
}
TransactionId WorkerEngine::LocalOldestActive() const {
return oldest_active_;
}
TransactionId WorkerEngine::LocalOldestActive() const { return oldest_active_; }
Transaction *WorkerEngine::RunningTransaction(TransactionId tx_id) {
auto accessor = active_.access();
@ -151,8 +149,7 @@ Transaction *WorkerEngine::RunningTransaction(TransactionId tx_id,
return insertion.first->second;
}
void WorkerEngine::ClearTransactionalCache(
TransactionId oldest_active) const {
void WorkerEngine::ClearTransactionalCache(TransactionId oldest_active) const {
auto access = active_.access();
for (auto kv : access) {
if (kv.first < oldest_active) {

View File

@ -68,7 +68,6 @@ class WorkerEngine : public Engine {
// Updates the oldest active transaction to the one from the snapshot. If the
// snapshot is empty, it's set to the given alternative.
void UpdateOldestActive(const Snapshot &snapshot,
TransactionId alternative);
void UpdateOldestActive(const Snapshot &snapshot, TransactionId alternative);
};
} // namespace tx

View File

@ -6,6 +6,7 @@
#include "database/graph_db.hpp"
#include "distributed/plan_consumer.hpp"
#include "distributed/pull_rpc_clients.hpp"
#include "distributed_common.hpp"
#include "query/interpreter.hpp"
#include "query_common.hpp"
@ -31,15 +32,20 @@ class DistributedInterpretationTest : public DistributedGraphDbTest {
DistributedGraphDbTest::TearDown();
}
auto Run(const std::string &query) {
auto RunWithDba(const std::string &query, GraphDbAccessor &dba) {
std::map<std::string, query::TypedValue> params = {};
GraphDbAccessor dba(master());
ResultStreamFaker result;
interpreter_.value()(query, dba, params, false).PullAll(result);
dba.Commit();
return result.GetResults();
}
auto Run(const std::string &query) {
GraphDbAccessor dba(master());
auto results = RunWithDba(query, dba);
dba.Commit();
return results;
}
private:
std::experimental::optional<query::Interpreter> interpreter_;
};
@ -270,6 +276,36 @@ TEST_F(DistributedInterpretationTest, ConcurrentPlanExpiration) {
for (auto &t : counters) t.join();
}
// Runs the same counting query twice through one accessor, once before and
// once after advancing the command, and expects both runs to report every
// vertex inserted up front.
TEST_F(DistributedInterpretationTest, OngoingProduceKeyTest) {
  // Number of vertices inserted into each of the three database instances
  // (master, worker 1, worker 2).
  const int vertices_per_instance = 10;
  for (int round = 0; round < vertices_per_instance; ++round) {
    InsertVertex(master());
    InsertVertex(worker(1));
    InsertVertex(worker(2));
  }

  GraphDbAccessor dba(master());
  auto first_count = RunWithDba("MATCH (n) RETURN count(n)", dba);
  dba.AdvanceCommand();
  auto second_count = RunWithDba("MATCH (n) RETURN count(n)", dba);

  // Both results are checked only after both queries have run.
  const int expected_total = 3 * vertices_per_instance;
  ASSERT_EQ(first_count[0][0].ValueInt(), expected_total);
  ASSERT_EQ(second_count[0][0].ValueInt(), expected_total);
}
// Creates vertices in one command, advances the command on the master
// accessor, notifies the workers, and then checks that a count query run
// through the same accessor returns 10.
TEST_F(DistributedInterpretationTest, AdvanceCommandOnWorkers) {
  GraphDbAccessor dba(master());
  RunWithDba("UNWIND RANGE(1, 10) as x CREATE (:A {id: x})", dba);
  dba.AdvanceCommand();

  // Notify the workers that the command advanced and block until every
  // returned future has completed.
  const auto tx_id = dba.transaction_id();
  auto notifications =
      dba.db().pull_clients().NotifyAllTransactionCommandAdvanced(tx_id);
  for (auto &notification : notifications) {
    notification.wait();
  }

  auto vertex_count = RunWithDba("MATCH (n) RETURN count(n)", dba);
  ASSERT_EQ(vertex_count[0][0].ValueInt(), 10);
}
int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
::testing::InitGoogleTest(&argc, argv);

View File

@ -56,12 +56,13 @@ TEST_F(DistributedGraphDbTest, PullProduceRpc) {
const int plan_id = 42;
master().plan_dispatcher().DispatchPlan(plan_id, produce, ctx.symbol_table_);
tx::CommandId command_id = dba.transaction().cid();
Parameters params;
std::vector<query::Symbol> symbols{ctx.symbol_table_[*x_ne]};
auto remote_pull = [this, &params, &symbols](GraphDbAccessor &dba,
int worker_id) {
return master().pull_clients().Pull(dba, worker_id, plan_id, params,
symbols, false, 3);
auto remote_pull = [this, &command_id, &params, &symbols](
GraphDbAccessor &dba, int worker_id) {
return master().pull_clients().Pull(dba, worker_id, plan_id, command_id,
params, symbols, false, 3);
};
auto expect_first_batch = [](auto &batch) {
EXPECT_EQ(batch.pull_state, distributed::PullState::CURSOR_IN_PROGRESS);
@ -174,13 +175,14 @@ TEST_F(DistributedGraphDbTest, PullProduceRpcWithGraphElements) {
const int plan_id = 42;
master().plan_dispatcher().DispatchPlan(plan_id, produce, ctx.symbol_table_);
tx::CommandId command_id = dba.transaction().cid();
Parameters params;
std::vector<query::Symbol> symbols{ctx.symbol_table_[*return_n_r],
ctx.symbol_table_[*return_m], p_sym};
auto remote_pull = [this, &params, &symbols](GraphDbAccessor &dba,
int worker_id) {
return master().pull_clients().Pull(dba, worker_id, plan_id, params,
symbols, false, 3);
auto remote_pull = [this, &command_id, &params, &symbols](
GraphDbAccessor &dba, int worker_id) {
return master().pull_clients().Pull(dba, worker_id, plan_id, command_id,
params, symbols, false, 3);
};
auto future_w1_results = remote_pull(dba, 1);
auto future_w2_results = remote_pull(dba, 2);
@ -346,13 +348,14 @@ TEST_F(DistributedTransactionTimeout, Timeout) {
const int plan_id = 42;
master().plan_dispatcher().DispatchPlan(plan_id, produce, ctx.symbol_table_);
tx::CommandId command_id = dba.transaction().cid();
Parameters params;
std::vector<query::Symbol> symbols{ctx.symbol_table_[*output]};
auto remote_pull = [this, &params, &symbols, &dba]() {
auto remote_pull = [this, &command_id, &params, &symbols, &dba]() {
return master()
.pull_clients()
.Pull(dba, 1, plan_id, params, symbols, false, 1)
.Pull(dba, 1, plan_id, command_id, params, symbols, false, 1)
.get()
.pull_state;
};

View File

@ -1,3 +1,5 @@
#include <mutex>
#include "boost/archive/binary_iarchive.hpp"
#include "boost/archive/binary_oarchive.hpp"
#include "boost/serialization/export.hpp"
@ -50,6 +52,7 @@ class RpcWorkerClientsTest : public ::testing::Test {
workers_server_.back()->Register<distributed::IncrementCounterRpc>(
[this, i](const distributed::IncrementCounterReq &) {
std::unique_lock<std::mutex> lock(mutex_);
workers_cnt_[i]++;
return std::make_unique<distributed::IncrementCounterRes>();
});
@ -76,6 +79,7 @@ class RpcWorkerClientsTest : public ::testing::Test {
std::vector<std::unique_ptr<distributed::WorkerCoordination>> workers_coord_;
std::vector<std::unique_ptr<distributed::ClusterDiscoveryWorker>>
cluster_discovery_;
std::mutex mutex_;
std::unordered_map<int, int> workers_cnt_;
communication::rpc::Server master_server_{kLocalHost};