Implement distributed aware OrderBy operator
Summary: Extracted `TypedValueVectorCompare` and `RemotePuller` from operators so it can be reused. The new `PullRemoteOrerBy` operator pulls one result from each worker and one from master, relies on the fact that workers/master returned sorted results, returns the next one in order, and pulls the source of that result to get the next one. Depends on D1215 that (at the moment) is still in review. Reviewers: florijan, teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1221
This commit is contained in:
parent
cbc9420c17
commit
b9d61a0127
@ -213,4 +213,70 @@ void ReconstructTypedValue(TypedValue &value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool TypedValueVectorCompare::operator()(
|
||||
const std::vector<TypedValue> &c1,
|
||||
const std::vector<TypedValue> &c2) const {
|
||||
// ordering is invalid if there are more elements in the collections
|
||||
// then there are in the ordering_ vector
|
||||
DCHECK(c1.size() <= ordering_.size() && c2.size() <= ordering_.size())
|
||||
<< "Collections contain more elements then there are orderings";
|
||||
|
||||
auto c1_it = c1.begin();
|
||||
auto c2_it = c2.begin();
|
||||
auto ordering_it = ordering_.begin();
|
||||
for (; c1_it != c1.end() && c2_it != c2.end();
|
||||
c1_it++, c2_it++, ordering_it++) {
|
||||
if (TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC;
|
||||
if (TypedValueCompare(*c2_it, *c1_it))
|
||||
return *ordering_it == Ordering::DESC;
|
||||
}
|
||||
|
||||
// at least one collection is exhausted
|
||||
// c1 is less then c2 iff c1 reached the end but c2 didn't
|
||||
return (c1_it == c1.end()) && (c2_it != c2.end());
|
||||
}
|
||||
|
||||
bool TypedValueVectorCompare::TypedValueCompare(const TypedValue &a,
|
||||
const TypedValue &b) const {
|
||||
// in ordering null comes after everything else
|
||||
// at the same time Null is not less that null
|
||||
// first deal with Null < Whatever case
|
||||
if (a.IsNull()) return false;
|
||||
// now deal with NotNull < Null case
|
||||
if (b.IsNull()) return true;
|
||||
|
||||
// comparisons are from this point legal only between values of
|
||||
// the same type, or int+float combinations
|
||||
if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric())))
|
||||
throw QueryRuntimeException(
|
||||
"Can't compare value of type {} to value of type {}", a.type(),
|
||||
b.type());
|
||||
|
||||
switch (a.type()) {
|
||||
case TypedValue::Type::Bool:
|
||||
return !a.Value<bool>() && b.Value<bool>();
|
||||
case TypedValue::Type::Int:
|
||||
if (b.type() == TypedValue::Type::Double)
|
||||
return a.Value<int64_t>() < b.Value<double>();
|
||||
else
|
||||
return a.Value<int64_t>() < b.Value<int64_t>();
|
||||
case TypedValue::Type::Double:
|
||||
if (b.type() == TypedValue::Type::Int)
|
||||
return a.Value<double>() < b.Value<int64_t>();
|
||||
else
|
||||
return a.Value<double>() < b.Value<double>();
|
||||
case TypedValue::Type::String:
|
||||
return a.Value<std::string>() < b.Value<std::string>();
|
||||
case TypedValue::Type::List:
|
||||
case TypedValue::Type::Map:
|
||||
case TypedValue::Type::Vertex:
|
||||
case TypedValue::Type::Edge:
|
||||
case TypedValue::Type::Path:
|
||||
throw QueryRuntimeException(
|
||||
"Comparison is not defined for values of type {}", a.type());
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled comparison for types";
|
||||
}
|
||||
}
|
||||
} // namespace query
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "boost/serialization/serialization.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
|
||||
namespace query {
|
||||
@ -33,4 +35,35 @@ enum class GraphView { AS_IS, OLD, NEW };
|
||||
* @returns - If the reconstruction succeeded.
|
||||
*/
|
||||
void ReconstructTypedValue(TypedValue &value);
|
||||
|
||||
// Custom Comparator type for comparing vectors of TypedValues.
|
||||
//
|
||||
// Does lexicographical ordering of elements based on the above
|
||||
// defined TypedValueCompare, and also accepts a vector of Orderings
|
||||
// the define how respective elements compare.
|
||||
class TypedValueVectorCompare {
|
||||
public:
|
||||
TypedValueVectorCompare() {}
|
||||
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering)
|
||||
: ordering_(ordering) {}
|
||||
bool operator()(const std::vector<TypedValue> &c1,
|
||||
const std::vector<TypedValue> &c2) const;
|
||||
|
||||
private:
|
||||
std::vector<Ordering> ordering_;
|
||||
|
||||
friend class boost::serialization::access;
|
||||
|
||||
template <class TArchive>
|
||||
void serialize(TArchive &ar, const unsigned int) {
|
||||
ar &ordering_;
|
||||
}
|
||||
// Custom comparison for TypedValue objects.
|
||||
//
|
||||
// Behaves generally like Neo's ORDER BY comparison operator:
|
||||
// - null is greater than anything else
|
||||
// - primitives compare naturally, only implicit cast is int->double
|
||||
// - (list, map, path, vertex, edge) can't compare to anything
|
||||
bool TypedValueCompare(const TypedValue &a, const TypedValue &b) const;
|
||||
};
|
||||
}
|
||||
|
@ -397,10 +397,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
|
||||
context.symbol_table_, db, graph_view_);
|
||||
auto convert = [&evaluator](const auto &bound)
|
||||
-> std::experimental::optional<utils::Bound<PropertyValue>> {
|
||||
if (!bound) return std::experimental::nullopt;
|
||||
return std::experimental::make_optional(utils::Bound<PropertyValue>(
|
||||
bound.value().value()->Accept(evaluator), bound.value().type()));
|
||||
};
|
||||
if (!bound) return std::experimental::nullopt;
|
||||
return std::experimental::make_optional(utils::Bound<PropertyValue>(
|
||||
bound.value().value()->Accept(evaluator), bound.value().type()));
|
||||
};
|
||||
return db.Vertices(label_, property_, convert(lower_bound()),
|
||||
convert(upper_bound()), graph_view_ == GraphView::NEW);
|
||||
};
|
||||
@ -1216,9 +1216,8 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
|
||||
self_.graph_view_);
|
||||
// For the given (vertex, edge, vertex) tuple checks if they satisfy the
|
||||
// "where" condition. if so, places them in the priority queue.
|
||||
auto expand_pair = [this, &evaluator, &frame](VertexAccessor from,
|
||||
EdgeAccessor edge,
|
||||
VertexAccessor vertex) {
|
||||
auto expand_pair = [this, &evaluator, &frame](
|
||||
VertexAccessor from, EdgeAccessor edge, VertexAccessor vertex) {
|
||||
SwitchAccessor(edge, self_.graph_view_);
|
||||
SwitchAccessor(vertex, self_.graph_view_);
|
||||
|
||||
@ -2567,72 +2566,6 @@ void OrderBy::OrderByCursor::Reset() {
|
||||
cache_it_ = cache_.begin();
|
||||
}
|
||||
|
||||
bool OrderBy::TypedValueCompare(const TypedValue &a, const TypedValue &b) {
|
||||
// in ordering null comes after everything else
|
||||
// at the same time Null is not less that null
|
||||
// first deal with Null < Whatever case
|
||||
if (a.IsNull()) return false;
|
||||
// now deal with NotNull < Null case
|
||||
if (b.IsNull()) return true;
|
||||
|
||||
// comparisons are from this point legal only between values of
|
||||
// the same type, or int+float combinations
|
||||
if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric())))
|
||||
throw QueryRuntimeException(
|
||||
"Can't compare value of type {} to value of type {}", a.type(),
|
||||
b.type());
|
||||
|
||||
switch (a.type()) {
|
||||
case TypedValue::Type::Bool:
|
||||
return !a.Value<bool>() && b.Value<bool>();
|
||||
case TypedValue::Type::Int:
|
||||
if (b.type() == TypedValue::Type::Double)
|
||||
return a.Value<int64_t>() < b.Value<double>();
|
||||
else
|
||||
return a.Value<int64_t>() < b.Value<int64_t>();
|
||||
case TypedValue::Type::Double:
|
||||
if (b.type() == TypedValue::Type::Int)
|
||||
return a.Value<double>() < b.Value<int64_t>();
|
||||
else
|
||||
return a.Value<double>() < b.Value<double>();
|
||||
case TypedValue::Type::String:
|
||||
return a.Value<std::string>() < b.Value<std::string>();
|
||||
case TypedValue::Type::List:
|
||||
case TypedValue::Type::Map:
|
||||
case TypedValue::Type::Vertex:
|
||||
case TypedValue::Type::Edge:
|
||||
case TypedValue::Type::Path:
|
||||
throw QueryRuntimeException(
|
||||
"Comparison is not defined for values of type {}", a.type());
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled comparison for types";
|
||||
}
|
||||
}
|
||||
|
||||
bool OrderBy::TypedValueVectorCompare::operator()(
|
||||
const std::vector<TypedValue> &c1,
|
||||
const std::vector<TypedValue> &c2) const {
|
||||
// ordering is invalid if there are more elements in the collections
|
||||
// then there are in the ordering_ vector
|
||||
DCHECK(c1.size() <= ordering_.size() && c2.size() <= ordering_.size())
|
||||
<< "Collections contain more elements then there are orderings";
|
||||
|
||||
auto c1_it = c1.begin();
|
||||
auto c2_it = c2.begin();
|
||||
auto ordering_it = ordering_.begin();
|
||||
for (; c1_it != c1.end() && c2_it != c2.end();
|
||||
c1_it++, c2_it++, ordering_it++) {
|
||||
if (OrderBy::TypedValueCompare(*c1_it, *c2_it))
|
||||
return *ordering_it == Ordering::ASC;
|
||||
if (OrderBy::TypedValueCompare(*c2_it, *c1_it))
|
||||
return *ordering_it == Ordering::DESC;
|
||||
}
|
||||
|
||||
// at least one collection is exhausted
|
||||
// c1 is less then c2 iff c1 reached the end but c2 didn't
|
||||
return (c1_it == c1.end()) && (c2_it != c2.end());
|
||||
}
|
||||
|
||||
Merge::Merge(const std::shared_ptr<LogicalOperator> &input,
|
||||
const std::shared_ptr<LogicalOperator> &merge_match,
|
||||
const std::shared_ptr<LogicalOperator> &merge_create)
|
||||
@ -2985,10 +2918,6 @@ void Union::UnionCursor::Reset() {
|
||||
right_cursor_->Reset();
|
||||
}
|
||||
|
||||
PullRemote::PullRemote(const std::shared_ptr<LogicalOperator> &input,
|
||||
int64_t plan_id, const std::vector<Symbol> &symbols)
|
||||
: input_(input), plan_id_(plan_id), symbols_(symbols) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(PullRemote);
|
||||
|
||||
std::vector<Symbol> PullRemote::OutputSymbols(const SymbolTable &table) const {
|
||||
@ -3005,34 +2934,103 @@ std::vector<Symbol> PullRemote::ModifiedSymbols(
|
||||
return symbols;
|
||||
}
|
||||
|
||||
PullRemote::PullRemoteCursor::PullRemoteCursor(const PullRemote &self,
|
||||
database::GraphDbAccessor &db)
|
||||
: self_(self),
|
||||
db_(db),
|
||||
input_cursor_(self.input_ ? self.input_->MakeCursor(db) : nullptr) {
|
||||
worker_ids_ = db_.db().remote_pull_clients().GetWorkerIds();
|
||||
// Remove master from the worker ids list.
|
||||
worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0));
|
||||
std::vector<Symbol> Synchronize::ModifiedSymbols(
|
||||
const SymbolTable &table) const {
|
||||
auto symbols = input_->ModifiedSymbols(table);
|
||||
if (pull_remote_) {
|
||||
auto pull_symbols = pull_remote_->ModifiedSymbols(table);
|
||||
symbols.insert(symbols.end(), pull_symbols.begin(), pull_symbols.end());
|
||||
}
|
||||
return symbols;
|
||||
}
|
||||
|
||||
bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) {
|
||||
auto insert_future_for_worker = [&](int worker_id) {
|
||||
remote_pulls_[worker_id] = db_.db().remote_pull_clients().RemotePull(
|
||||
db_, worker_id, self_.plan_id(), context.parameters_, self_.symbols(),
|
||||
false);
|
||||
};
|
||||
bool Synchronize::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
input_->Accept(visitor) && pull_remote_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
if (!remote_pulls_initialized_) {
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
insert_future_for_worker(worker_id);
|
||||
}
|
||||
remote_pulls_initialized_ = true;
|
||||
std::vector<Symbol> Cartesian::ModifiedSymbols(const SymbolTable &table) const {
|
||||
auto symbols = left_op_->ModifiedSymbols(table);
|
||||
auto right = right_op_->ModifiedSymbols(table);
|
||||
symbols.insert(symbols.end(), right.begin(), right.end());
|
||||
return symbols;
|
||||
}
|
||||
|
||||
bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
left_op_->Accept(visitor) && right_op_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
PullRemoteOrderBy::PullRemoteOrderBy(
|
||||
const std::shared_ptr<LogicalOperator> &input, int64_t plan_id,
|
||||
const std::vector<std::pair<Ordering, Expression *>> &order_by,
|
||||
const std::vector<Symbol> &symbols)
|
||||
: input_(input), plan_id_(plan_id), symbols_(symbols) {
|
||||
CHECK(input_ != nullptr)
|
||||
<< "PullRemoteOrderBy should always be constructed with input!";
|
||||
std::vector<Ordering> ordering;
|
||||
ordering.reserve(order_by.size());
|
||||
order_by_.reserve(order_by.size());
|
||||
for (const auto &ordering_expression_pair : order_by) {
|
||||
ordering.emplace_back(ordering_expression_pair.first);
|
||||
order_by_.emplace_back(ordering_expression_pair.second);
|
||||
}
|
||||
compare_ = TypedValueVectorCompare(ordering);
|
||||
}
|
||||
|
||||
ACCEPT_WITH_INPUT(PullRemoteOrderBy);
|
||||
|
||||
std::vector<Symbol> PullRemoteOrderBy::OutputSymbols(
|
||||
const SymbolTable &table) const {
|
||||
return input_->OutputSymbols(table);
|
||||
}
|
||||
|
||||
std::vector<Symbol> PullRemoteOrderBy::ModifiedSymbols(
|
||||
const SymbolTable &table) const {
|
||||
return input_->ModifiedSymbols(table);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/** Helper class that wraps remote pulling for cursors that handle results from
|
||||
* distributed workers.
|
||||
*/
|
||||
class RemotePuller {
|
||||
public:
|
||||
RemotePuller(database::GraphDbAccessor &db,
|
||||
const std::vector<Symbol> &symbols, int64_t plan_id)
|
||||
: db_(db), symbols_(symbols), plan_id_(plan_id) {
|
||||
worker_ids_ = db_.db().remote_pull_clients().GetWorkerIds();
|
||||
// Remove master from the worker ids list.
|
||||
worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0));
|
||||
}
|
||||
|
||||
bool have_remote_results = false;
|
||||
while (!have_remote_results && !worker_ids_.empty()) {
|
||||
void Initialize(Context &context) {
|
||||
if (!remote_pulls_initialized_) {
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
UpdateRemotePullForWorker(worker_id, context);
|
||||
}
|
||||
remote_pulls_initialized_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void Update(Context &context) {
|
||||
// If we don't have results for a worker, check if his remote pull
|
||||
// finished and save results locally.
|
||||
|
||||
auto move_frames = [this](int worker_id, auto remote_results) {
|
||||
remote_results_[worker_id] = std::move(remote_results.frames);
|
||||
// Since we return and remove results from the back of the vector,
|
||||
// reverse the results so the first to return is on the end of the
|
||||
// vector.
|
||||
std::reverse(remote_results_[worker_id].begin(),
|
||||
remote_results_[worker_id].end());
|
||||
};
|
||||
|
||||
for (auto &worker_id : worker_ids_) {
|
||||
if (!remote_results_[worker_id].empty()) continue;
|
||||
|
||||
@ -3045,12 +3043,12 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) {
|
||||
auto remote_results = remote_pull.get();
|
||||
switch (remote_results.pull_state) {
|
||||
case distributed::RemotePullState::CURSOR_EXHAUSTED:
|
||||
remote_results_[worker_id] = std::move(remote_results.frames);
|
||||
move_frames(worker_id, remote_results);
|
||||
remote_pulls_.erase(found_it);
|
||||
break;
|
||||
case distributed::RemotePullState::CURSOR_IN_PROGRESS:
|
||||
remote_results_[worker_id] = std::move(remote_results.frames);
|
||||
insert_future_for_worker(worker_id);
|
||||
move_frames(worker_id, remote_results);
|
||||
UpdateRemotePullForWorker(worker_id, context);
|
||||
break;
|
||||
case distributed::RemotePullState::SERIALIZATION_ERROR:
|
||||
throw mvcc::SerializationError(
|
||||
@ -3068,94 +3066,133 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) {
|
||||
"Query runtime error occurred duing PullRemote !");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get locally stored results from workers in a round-robin fasion.
|
||||
int num_workers = worker_ids_.size();
|
||||
for (int i = 0; i < num_workers; ++i) {
|
||||
int worker_id_index =
|
||||
(last_pulled_worker_id_index_ + i + 1) % num_workers;
|
||||
int worker_id = worker_ids_[worker_id_index];
|
||||
auto Workers() { return worker_ids_; }
|
||||
|
||||
if (!remote_results_[worker_id].empty()) {
|
||||
last_pulled_worker_id_index_ = worker_id_index;
|
||||
have_remote_results = true;
|
||||
break;
|
||||
int GetWorkerId(int worker_id_index) { return worker_ids_[worker_id_index]; }
|
||||
|
||||
size_t WorkerCount() { return worker_ids_.size(); }
|
||||
|
||||
void ClearWorkers() { worker_ids_.clear(); }
|
||||
|
||||
bool HasPendingPulls() { return !remote_pulls_.empty(); }
|
||||
|
||||
bool HasPendingPullFromWorker(int worker_id) {
|
||||
return remote_pulls_.find(worker_id) != remote_pulls_.end();
|
||||
}
|
||||
|
||||
bool HasResultsFromWorker(int worker_id) {
|
||||
return !remote_results_[worker_id].empty();
|
||||
}
|
||||
|
||||
std::vector<query::TypedValue> PopResultFromWorker(int worker_id) {
|
||||
auto result = remote_results_[worker_id].back();
|
||||
remote_results_[worker_id].pop_back();
|
||||
|
||||
// Remove the worker if we exhausted all locally stored results and there
|
||||
// are no more pending remote pulls for that worker.
|
||||
if (remote_results_[worker_id].empty() &&
|
||||
remote_pulls_.find(worker_id) == remote_pulls_.end()) {
|
||||
worker_ids_.erase(
|
||||
std::find(worker_ids_.begin(), worker_ids_.end(), worker_id));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
database::GraphDbAccessor &db_;
|
||||
std::vector<Symbol> symbols_;
|
||||
int64_t plan_id_;
|
||||
std::unordered_map<int, std::future<distributed::RemotePullData>>
|
||||
remote_pulls_;
|
||||
std::unordered_map<int, std::vector<std::vector<query::TypedValue>>>
|
||||
remote_results_;
|
||||
std::vector<int> worker_ids_;
|
||||
bool remote_pulls_initialized_ = false;
|
||||
|
||||
void UpdateRemotePullForWorker(int worker_id, Context &context) {
|
||||
remote_pulls_[worker_id] = db_.db().remote_pull_clients().RemotePull(
|
||||
db_, worker_id, plan_id_, context.parameters_, symbols_, false);
|
||||
}
|
||||
};
|
||||
|
||||
class PullRemoteCursor : public Cursor {
|
||||
public:
|
||||
PullRemoteCursor(const PullRemote &self, database::GraphDbAccessor &db)
|
||||
: self_(self),
|
||||
input_cursor_(self.input() ? self.input()->MakeCursor(db) : nullptr),
|
||||
remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {}
|
||||
|
||||
bool Pull(Frame &frame, Context &context) override {
|
||||
remote_puller_.Initialize(context);
|
||||
|
||||
bool have_remote_results = false;
|
||||
while (!have_remote_results && remote_puller_.WorkerCount() > 0) {
|
||||
remote_puller_.Update(context);
|
||||
|
||||
// Get locally stored results from workers in a round-robin fasion.
|
||||
int num_workers = remote_puller_.WorkerCount();
|
||||
for (int i = 0; i < num_workers; ++i) {
|
||||
int worker_id_index =
|
||||
(last_pulled_worker_id_index_ + i + 1) % num_workers;
|
||||
int worker_id = remote_puller_.GetWorkerId(worker_id_index);
|
||||
|
||||
if (remote_puller_.HasResultsFromWorker(worker_id)) {
|
||||
last_pulled_worker_id_index_ = worker_id_index;
|
||||
have_remote_results = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!have_remote_results) {
|
||||
if (!remote_puller_.HasPendingPulls()) {
|
||||
remote_puller_.ClearWorkers();
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are no remote results available, try to pull and return
|
||||
// local results.
|
||||
if (input_cursor_ && input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If there aren't any local/remote results available, sleep.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(FLAGS_remote_pull_sleep));
|
||||
}
|
||||
}
|
||||
|
||||
// No more remote results, make sure local results get exhausted.
|
||||
if (!have_remote_results) {
|
||||
if (remote_pulls_.empty()) {
|
||||
worker_ids_.clear();
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are no remote results available, try to pull and return local
|
||||
// results.
|
||||
if (input_cursor_ && input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If there aren't any local/remote results available, sleep.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(FLAGS_remote_pull_sleep));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// No more remote results, make sure local results get exhausted.
|
||||
if (!have_remote_results) {
|
||||
if (input_cursor_ && input_cursor_->Pull(frame, context)) {
|
||||
return true;
|
||||
{
|
||||
int worker_id = remote_puller_.GetWorkerId(last_pulled_worker_id_index_);
|
||||
auto result = remote_puller_.PopResultFromWorker(worker_id);
|
||||
for (size_t i = 0; i < self_.symbols().size(); ++i) {
|
||||
frame[self_.symbols()[i]] = std::move(result[i]);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
int pull_from_worker_id = worker_ids_[last_pulled_worker_id_index_];
|
||||
{
|
||||
auto &result = remote_results_[pull_from_worker_id].back();
|
||||
for (size_t i = 0; i < self_.symbols().size(); ++i) {
|
||||
frame[self_.symbols()[i]] = std::move(result[i]);
|
||||
}
|
||||
}
|
||||
remote_results_[pull_from_worker_id].resize(
|
||||
remote_results_[pull_from_worker_id].size() - 1);
|
||||
|
||||
// Remove the worker if we exhausted all locally stored results and there
|
||||
// are no more pending remote pulls for that worker.
|
||||
if (remote_results_[pull_from_worker_id].empty() &&
|
||||
remote_pulls_.find(pull_from_worker_id) == remote_pulls_.end()) {
|
||||
worker_ids_.erase(worker_ids_.begin() + last_pulled_worker_id_index_);
|
||||
void Reset() override {
|
||||
throw QueryRuntimeException("Unsupported: Reset during PullRemote!");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
private:
|
||||
const PullRemote &self_;
|
||||
const std::unique_ptr<Cursor> input_cursor_;
|
||||
RemotePuller remote_puller_;
|
||||
int last_pulled_worker_id_index_ = 0;
|
||||
};
|
||||
|
||||
void PullRemote::PullRemoteCursor::Reset() {
|
||||
throw QueryRuntimeException("Unsupported: Reset during PullRemote!");
|
||||
}
|
||||
|
||||
std::unique_ptr<Cursor> PullRemote::MakeCursor(
|
||||
database::GraphDbAccessor &db) const {
|
||||
return std::make_unique<PullRemote::PullRemoteCursor>(*this, db);
|
||||
}
|
||||
|
||||
bool Synchronize::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
input_->Accept(visitor) && pull_remote_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
std::vector<Symbol> Synchronize::ModifiedSymbols(
|
||||
const SymbolTable &table) const {
|
||||
auto symbols = input_->ModifiedSymbols(table);
|
||||
if (pull_remote_) {
|
||||
auto pull_symbols = pull_remote_->ModifiedSymbols(table);
|
||||
symbols.insert(symbols.end(), pull_symbols.begin(), pull_symbols.end());
|
||||
}
|
||||
return symbols;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class SynchronizeCursor : public Cursor {
|
||||
public:
|
||||
SynchronizeCursor(const Synchronize &self, database::GraphDbAccessor &db)
|
||||
@ -3365,30 +3402,148 @@ class CartesianCursor : public Cursor {
|
||||
bool cartesian_pull_initialized_{false};
|
||||
};
|
||||
|
||||
class PullRemoteOrderByCursor : public Cursor {
|
||||
public:
|
||||
PullRemoteOrderByCursor(const PullRemoteOrderBy &self,
|
||||
database::GraphDbAccessor &db)
|
||||
: self_(self),
|
||||
db_(db),
|
||||
input_(self.input()->MakeCursor(db)),
|
||||
remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {}
|
||||
|
||||
bool Pull(Frame &frame, Context &context) {
|
||||
ExpressionEvaluator evaluator(frame, context.parameters_,
|
||||
context.symbol_table_, db_);
|
||||
|
||||
auto evaluate_result = [this, &evaluator]() {
|
||||
std::vector<TypedValue> order_by;
|
||||
order_by.reserve(self_.order_by().size());
|
||||
for (auto expression_ptr : self_.order_by()) {
|
||||
order_by.emplace_back(expression_ptr->Accept(evaluator));
|
||||
}
|
||||
return order_by;
|
||||
};
|
||||
|
||||
auto restore_frame = [&frame,
|
||||
this](const std::vector<TypedValue> &restore_from) {
|
||||
for (size_t i = 0; i < restore_from.size(); ++i) {
|
||||
frame[self_.symbols()[i]] = restore_from[i];
|
||||
}
|
||||
};
|
||||
|
||||
if (!merge_initialized_) {
|
||||
remote_puller_.Initialize(context);
|
||||
missing_results_from_ = remote_puller_.Workers();
|
||||
missing_master_result_ = true;
|
||||
merge_initialized_ = true;
|
||||
}
|
||||
|
||||
if (missing_master_result_) {
|
||||
if (input_->Pull(frame, context)) {
|
||||
std::vector<TypedValue> output;
|
||||
output.reserve(self_.symbols().size());
|
||||
for (const Symbol &symbol : self_.symbols()) {
|
||||
output.emplace_back(frame[symbol]);
|
||||
}
|
||||
|
||||
merge_.push_back(MergeResultItem{std::experimental::nullopt, output,
|
||||
evaluate_result()});
|
||||
}
|
||||
missing_master_result_ = false;
|
||||
}
|
||||
|
||||
while (!missing_results_from_.empty()) {
|
||||
remote_puller_.Update(context);
|
||||
|
||||
bool has_all_result = true;
|
||||
for (auto &worker_id : missing_results_from_) {
|
||||
if (!remote_puller_.HasResultsFromWorker(worker_id) &&
|
||||
remote_puller_.HasPendingPullFromWorker(worker_id)) {
|
||||
has_all_result = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_all_result) {
|
||||
// If we don't have results from all workers, sleep before continuing.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(FLAGS_remote_pull_sleep));
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto &worker_id : missing_results_from_) {
|
||||
// It is possible that the workers remote pull finished but it didn't
|
||||
// return any results. In that case, just skip it.
|
||||
if (!remote_puller_.HasResultsFromWorker(worker_id)) continue;
|
||||
auto remote_result = remote_puller_.PopResultFromWorker(worker_id);
|
||||
restore_frame(remote_result);
|
||||
merge_.push_back(
|
||||
MergeResultItem{worker_id, remote_result, evaluate_result()});
|
||||
}
|
||||
|
||||
missing_results_from_.clear();
|
||||
}
|
||||
|
||||
if (merge_.empty()) return false;
|
||||
|
||||
auto result_it = std::min_element(
|
||||
merge_.begin(), merge_.end(), [this](const auto &lhs, const auto &rhs) {
|
||||
return self_.compare()(lhs.order_by, rhs.order_by);
|
||||
});
|
||||
|
||||
restore_frame(result_it->remote_result);
|
||||
|
||||
if (result_it->worker_id) {
|
||||
missing_results_from_.push_back(result_it->worker_id.value());
|
||||
} else {
|
||||
missing_master_result_ = true;
|
||||
}
|
||||
|
||||
merge_.erase(result_it);
|
||||
return true;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
throw QueryRuntimeException("Unsupported: Reset during PullRemoteOrderBy!");
|
||||
}
|
||||
|
||||
private:
|
||||
struct MergeResultItem {
|
||||
std::experimental::optional<int> worker_id;
|
||||
std::vector<TypedValue> remote_result;
|
||||
std::vector<TypedValue> order_by;
|
||||
};
|
||||
|
||||
const PullRemoteOrderBy &self_;
|
||||
database::GraphDbAccessor &db_;
|
||||
std::unique_ptr<Cursor> input_;
|
||||
RemotePuller remote_puller_;
|
||||
std::vector<MergeResultItem> merge_;
|
||||
std::vector<int> missing_results_from_;
|
||||
bool missing_master_result_ = false;
|
||||
bool merge_initialized_ = false;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Cursor> PullRemote::MakeCursor(
|
||||
database::GraphDbAccessor &db) const {
|
||||
return std::make_unique<PullRemoteCursor>(*this, db);
|
||||
}
|
||||
|
||||
std::unique_ptr<Cursor> Synchronize::MakeCursor(
|
||||
database::GraphDbAccessor &db) const {
|
||||
return std::make_unique<SynchronizeCursor>(*this, db);
|
||||
}
|
||||
|
||||
bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
left_op_->Accept(visitor) && right_op_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
std::unique_ptr<Cursor> Cartesian::MakeCursor(
|
||||
database::GraphDbAccessor &db) const {
|
||||
return std::make_unique<CartesianCursor>(*this, db);
|
||||
}
|
||||
|
||||
std::vector<Symbol> Cartesian::ModifiedSymbols(const SymbolTable &table) const {
|
||||
auto symbols = left_op_->ModifiedSymbols(table);
|
||||
auto right = right_op_->ModifiedSymbols(table);
|
||||
symbols.insert(symbols.end(), right.begin(), right.end());
|
||||
return symbols;
|
||||
std::unique_ptr<Cursor> PullRemoteOrderBy::MakeCursor(
|
||||
database::GraphDbAccessor &db) const {
|
||||
return std::make_unique<PullRemoteOrderByCursor>(*this, db);
|
||||
}
|
||||
|
||||
} // namespace query::plan
|
||||
@ -3428,3 +3583,4 @@ BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Union);
|
||||
BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::PullRemote);
|
||||
BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Synchronize);
|
||||
BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Cartesian);
|
||||
BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::PullRemoteOrderBy);
|
||||
|
@ -101,6 +101,7 @@ class Union;
|
||||
class PullRemote;
|
||||
class Synchronize;
|
||||
class Cartesian;
|
||||
class PullRemoteOrderBy;
|
||||
|
||||
using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor<
|
||||
Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel,
|
||||
@ -110,7 +111,7 @@ using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor<
|
||||
ExpandUniquenessFilter<VertexAccessor>,
|
||||
ExpandUniquenessFilter<EdgeAccessor>, Accumulate, Aggregate, Skip, Limit,
|
||||
OrderBy, Merge, Optional, Unwind, Distinct, Union, PullRemote, Synchronize,
|
||||
Cartesian>;
|
||||
Cartesian, PullRemoteOrderBy>;
|
||||
|
||||
using LogicalOperatorLeafVisitor = ::utils::LeafVisitor<Once, CreateIndex>;
|
||||
|
||||
@ -1946,29 +1947,6 @@ class OrderBy : public LogicalOperator {
|
||||
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
|
||||
|
||||
private:
|
||||
// custom Comparator type for comparing vectors of TypedValues
|
||||
// does lexicographical ordering of elements based on the above
|
||||
// defined TypedValueCompare, and also accepts a vector of Orderings
|
||||
// the define how respective elements compare
|
||||
class TypedValueVectorCompare {
|
||||
public:
|
||||
TypedValueVectorCompare() {}
|
||||
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering)
|
||||
: ordering_(ordering) {}
|
||||
bool operator()(const std::vector<TypedValue> &c1,
|
||||
const std::vector<TypedValue> &c2) const;
|
||||
|
||||
private:
|
||||
std::vector<Ordering> ordering_;
|
||||
|
||||
friend class boost::serialization::access;
|
||||
|
||||
template <class TArchive>
|
||||
void serialize(TArchive &ar, const unsigned int) {
|
||||
ar &ordering_;
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
TypedValueVectorCompare compare_;
|
||||
std::vector<Expression *> order_by_;
|
||||
@ -1976,13 +1954,6 @@ class OrderBy : public LogicalOperator {
|
||||
|
||||
OrderBy() {}
|
||||
|
||||
// custom comparison for TypedValue objects
|
||||
// behaves generally like Neo's ORDER BY comparison operator:
|
||||
// - null is greater than anything else
|
||||
// - primitives compare naturally, only implicit cast is int->double
|
||||
// - (list, map, path, vertex, edge) can't compare to anything
|
||||
static bool TypedValueCompare(const TypedValue &a, const TypedValue &b);
|
||||
|
||||
class OrderByCursor : public Cursor {
|
||||
public:
|
||||
OrderByCursor(const OrderBy &self, database::GraphDbAccessor &db);
|
||||
@ -2364,13 +2335,15 @@ class Union : public LogicalOperator {
|
||||
class PullRemote : public LogicalOperator {
|
||||
public:
|
||||
PullRemote(const std::shared_ptr<LogicalOperator> &input, int64_t plan_id,
|
||||
const std::vector<Symbol> &symbols);
|
||||
const std::vector<Symbol> &symbols)
|
||||
: input_(input), plan_id_(plan_id), symbols_(symbols) {}
|
||||
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(
|
||||
database::GraphDbAccessor &db) const override;
|
||||
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
|
||||
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
|
||||
|
||||
auto input() const { return input_; }
|
||||
const auto &symbols() const { return symbols_; }
|
||||
auto plan_id() const { return plan_id_; }
|
||||
|
||||
@ -2381,25 +2354,6 @@ class PullRemote : public LogicalOperator {
|
||||
|
||||
PullRemote() {}
|
||||
|
||||
class PullRemoteCursor : public Cursor {
|
||||
public:
|
||||
PullRemoteCursor(const PullRemote &self, database::GraphDbAccessor &db);
|
||||
bool Pull(Frame &, Context &) override;
|
||||
void Reset() override;
|
||||
|
||||
private:
|
||||
const PullRemote &self_;
|
||||
database::GraphDbAccessor &db_;
|
||||
const std::unique_ptr<Cursor> input_cursor_;
|
||||
std::unordered_map<int, std::future<distributed::RemotePullData>>
|
||||
remote_pulls_;
|
||||
std::unordered_map<int, std::vector<std::vector<query::TypedValue>>>
|
||||
remote_results_;
|
||||
std::vector<int> worker_ids_;
|
||||
int last_pulled_worker_id_index_ = 0;
|
||||
bool remote_pulls_initialized_ = false;
|
||||
};
|
||||
|
||||
friend class boost::serialization::access;
|
||||
template <class TArchive>
|
||||
void serialize(TArchive &ar, const unsigned int) {
|
||||
@ -2510,6 +2464,67 @@ class Cartesian : public LogicalOperator {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Operator that merges distributed OrderBy operators.
|
||||
*
|
||||
* Instead of using a regular OrderBy on master (which would collect all remote
|
||||
* results and order them), we can have each worker do an OrderBy locally and
|
||||
* have the master rely on the fact that the results are ordered and merge them
|
||||
* by having only one result from each worker.
|
||||
*/
|
||||
class PullRemoteOrderBy : public LogicalOperator {
|
||||
public:
|
||||
PullRemoteOrderBy(
|
||||
const std::shared_ptr<LogicalOperator> &input, int64_t plan_id,
|
||||
const std::vector<std::pair<Ordering, Expression *>> &order_by,
|
||||
const std::vector<Symbol> &symbols);
|
||||
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(
|
||||
database::GraphDbAccessor &db) const override;
|
||||
|
||||
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
|
||||
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
|
||||
|
||||
auto input() const { return input_; }
|
||||
auto plan_id() const { return plan_id_; }
|
||||
const auto &symbols() const { return symbols_; }
|
||||
auto order_by() const { return order_by_; }
|
||||
auto compare() const { return compare_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
int64_t plan_id_ = 0;
|
||||
std::vector<Symbol> symbols_;
|
||||
std::vector<Expression *> order_by_;
|
||||
TypedValueVectorCompare compare_;
|
||||
|
||||
PullRemoteOrderBy() {}
|
||||
|
||||
friend class boost::serialization::access;
|
||||
|
||||
BOOST_SERIALIZATION_SPLIT_MEMBER();
|
||||
|
||||
template <class TArchive>
|
||||
void save(TArchive &ar, const unsigned int) const {
|
||||
ar &boost::serialization::base_object<LogicalOperator>(*this);
|
||||
ar &input_;
|
||||
ar &plan_id_;
|
||||
ar &symbols_;
|
||||
SavePointers(ar, order_by_);
|
||||
ar &compare_;
|
||||
}
|
||||
|
||||
template <class TArchive>
|
||||
void load(TArchive &ar, const unsigned int) {
|
||||
ar &boost::serialization::base_object<LogicalOperator>(*this);
|
||||
ar &input_;
|
||||
ar &plan_id_;
|
||||
ar &symbols_;
|
||||
LoadPointers(ar, order_by_);
|
||||
ar &compare_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace plan
|
||||
} // namespace query
|
||||
|
||||
@ -2547,3 +2562,4 @@ BOOST_CLASS_EXPORT_KEY(query::plan::Union);
|
||||
BOOST_CLASS_EXPORT_KEY(query::plan::PullRemote);
|
||||
BOOST_CLASS_EXPORT_KEY(query::plan::Synchronize);
|
||||
BOOST_CLASS_EXPORT_KEY(query::plan::Cartesian);
|
||||
BOOST_CLASS_EXPORT_KEY(query::plan::PullRemoteOrderBy);
|
||||
|
@ -262,3 +262,61 @@ TEST_F(DistributedGraphDbTest, Create) {
|
||||
EXPECT_GT(VertexCount(worker(1)), 200);
|
||||
EXPECT_GT(VertexCount(worker(2)), 200);
|
||||
}
|
||||
|
||||
TEST_F(DistributedGraphDbTest, PullRemoteOrderBy) {
|
||||
// Create some data on the master and both workers.
|
||||
storage::Property prop;
|
||||
{
|
||||
GraphDbAccessor dba{master()};
|
||||
auto tx_id = dba.transaction_id();
|
||||
GraphDbAccessor dba1{worker(1), tx_id};
|
||||
GraphDbAccessor dba2{worker(2), tx_id};
|
||||
prop = dba.Property("prop");
|
||||
auto add_data = [prop](GraphDbAccessor &dba, int value) {
|
||||
dba.InsertVertex().PropsSet(prop, value);
|
||||
};
|
||||
|
||||
std::vector<int> data;
|
||||
for (int i = 0; i < 300; ++i) data.push_back(i);
|
||||
std::random_shuffle(data.begin(), data.end());
|
||||
|
||||
for (int i = 0; i < 100; ++i) add_data(dba, data[i]);
|
||||
for (int i = 100; i < 200; ++i) add_data(dba1, data[i]);
|
||||
for (int i = 200; i < 300; ++i) add_data(dba2, data[i]);
|
||||
|
||||
dba.Commit();
|
||||
}
|
||||
|
||||
auto &db = master();
|
||||
GraphDbAccessor dba{db};
|
||||
Context ctx{dba};
|
||||
SymbolGenerator symbol_generator{ctx.symbol_table_};
|
||||
AstTreeStorage storage;
|
||||
|
||||
// Query plan for: MATCH (n) RETURN n.prop ORDER BY n.prop;
|
||||
auto n = MakeScanAll(storage, ctx.symbol_table_, "n");
|
||||
auto n_p = PROPERTY_LOOKUP("n", prop);
|
||||
ctx.symbol_table_[*n_p->expression_] = n.sym_;
|
||||
auto order_by = std::make_shared<plan::OrderBy>(
|
||||
n.op_,
|
||||
std::vector<std::pair<Ordering, Expression *>>{{Ordering::ASC, n_p}},
|
||||
std::vector<Symbol>{n.sym_});
|
||||
|
||||
const int plan_id = 42;
|
||||
master().plan_dispatcher().DispatchPlan(plan_id, order_by, ctx.symbol_table_);
|
||||
|
||||
auto pull_remote_order_by = std::make_shared<plan::PullRemoteOrderBy>(
|
||||
order_by, plan_id,
|
||||
std::vector<std::pair<Ordering, Expression *>>{{Ordering::ASC, n_p}},
|
||||
std::vector<Symbol>{n.sym_});
|
||||
|
||||
auto n_p_ne = NEXPR("n.prop", n_p);
|
||||
ctx.symbol_table_[*n_p_ne] = ctx.symbol_table_.CreateSymbol("n.prop", true);
|
||||
auto produce = MakeProduce(pull_remote_order_by, n_p_ne);
|
||||
auto results = CollectProduce(produce.get(), ctx.symbol_table_, dba);
|
||||
|
||||
ASSERT_EQ(results.size(), 300);
|
||||
for (int j = 0; j < 300; ++j) {
|
||||
EXPECT_TRUE(TypedValue::BoolEqual{}(results[j][0], j));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user