From b9d61a0127fc3668d4d9229e976f6d02dbc8c2f9 Mon Sep 17 00:00:00 2001 From: Matija Santl Date: Mon, 19 Feb 2018 15:07:02 +0100 Subject: [PATCH] Implement distributed aware OrderBy operator Summary: Extracted `TypedValueVectorCompare` and `RemotePuller` from operators so it can be reused. The new `PullRemoteOrerBy` operator pulls one result from each worker and one from master, relies on the fact that workers/master returned sorted results, returns the next one in order, and pulls the source of that result to get the next one. Depends on D1215 that (at the moment) is still in review. Reviewers: florijan, teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1221 --- src/query/common.cpp | 66 ++++ src/query/common.hpp | 33 ++ src/query/plan/operator.cpp | 524 +++++++++++++++++--------- src/query/plan/operator.hpp | 118 +++--- tests/unit/distributed_query_plan.cpp | 58 +++ 5 files changed, 564 insertions(+), 235 deletions(-) diff --git a/src/query/common.cpp b/src/query/common.cpp index e58d6a0e4..37fa9f7d7 100644 --- a/src/query/common.cpp +++ b/src/query/common.cpp @@ -213,4 +213,70 @@ void ReconstructTypedValue(TypedValue &value) { break; } } + +bool TypedValueVectorCompare::operator()( + const std::vector &c1, + const std::vector &c2) const { + // ordering is invalid if there are more elements in the collections + // then there are in the ordering_ vector + DCHECK(c1.size() <= ordering_.size() && c2.size() <= ordering_.size()) + << "Collections contain more elements then there are orderings"; + + auto c1_it = c1.begin(); + auto c2_it = c2.begin(); + auto ordering_it = ordering_.begin(); + for (; c1_it != c1.end() && c2_it != c2.end(); + c1_it++, c2_it++, ordering_it++) { + if (TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC; + if (TypedValueCompare(*c2_it, *c1_it)) + return *ordering_it == Ordering::DESC; + } + + // at least one collection is exhausted + // c1 is less then c2 iff c1 reached the end but c2 didn't + return (c1_it == c1.end()) && (c2_it != c2.end()); +} + +bool TypedValueVectorCompare::TypedValueCompare(const TypedValue &a, + const TypedValue &b) const { + // in ordering null comes after everything else + // at the same time Null is not less that null + // first deal with Null < Whatever case + if (a.IsNull()) return false; + // now deal with NotNull < Null case + if (b.IsNull()) return true; + + // comparisons are from this point legal only between values of + // the same type, or int+float combinations + if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric()))) + throw QueryRuntimeException( + "Can't compare value of type {} to value of type {}", a.type(), + b.type()); + + switch (a.type()) { + case TypedValue::Type::Bool: + return !a.Value() && b.Value(); + case TypedValue::Type::Int: + if (b.type() == TypedValue::Type::Double) + return a.Value() < b.Value(); + else + return a.Value() < b.Value(); + case TypedValue::Type::Double: + if (b.type() == TypedValue::Type::Int) + return a.Value() < b.Value(); + else + return a.Value() < b.Value(); + case TypedValue::Type::String: + return a.Value() < b.Value(); + case TypedValue::Type::List: + case TypedValue::Type::Map: + case TypedValue::Type::Vertex: + case TypedValue::Type::Edge: + case TypedValue::Type::Path: + throw QueryRuntimeException( + "Comparison is not defined for values of type {}", a.type()); + default: + LOG(FATAL) << "Unhandled comparison for types"; + } +} } // namespace query diff --git a/src/query/common.hpp b/src/query/common.hpp index d98c2ba4e..b818e20da 100644 --- a/src/query/common.hpp +++ b/src/query/common.hpp @@ -3,6 +3,8 @@ #include #include +#include "boost/serialization/serialization.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/typed_value.hpp" namespace query { @@ -33,4 +35,35 @@ enum class GraphView { AS_IS, OLD, NEW }; * @returns - If the reconstruction succeeded. */ void ReconstructTypedValue(TypedValue &value); + +// Custom Comparator type for comparing vectors of TypedValues. +// +// Does lexicographical ordering of elements based on the above +// defined TypedValueCompare, and also accepts a vector of Orderings +// the define how respective elements compare. +class TypedValueVectorCompare { + public: + TypedValueVectorCompare() {} + explicit TypedValueVectorCompare(const std::vector &ordering) + : ordering_(ordering) {} + bool operator()(const std::vector &c1, + const std::vector &c2) const; + + private: + std::vector ordering_; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &ordering_; + } + // Custom comparison for TypedValue objects. + // + // Behaves generally like Neo's ORDER BY comparison operator: + // - null is greater than anything else + // - primitives compare naturally, only implicit cast is int->double + // - (list, map, path, vertex, edge) can't compare to anything + bool TypedValueCompare(const TypedValue &a, const TypedValue &b) const; +}; } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 1dfe21c91..98e61c699 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -397,10 +397,10 @@ std::unique_ptr ScanAllByLabelPropertyRange::MakeCursor( context.symbol_table_, db, graph_view_); auto convert = [&evaluator](const auto &bound) -> std::experimental::optional> { - if (!bound) return std::experimental::nullopt; - return std::experimental::make_optional(utils::Bound( - bound.value().value()->Accept(evaluator), bound.value().type())); - }; + if (!bound) return std::experimental::nullopt; + return std::experimental::make_optional(utils::Bound( + bound.value().value()->Accept(evaluator), bound.value().type())); + }; return db.Vertices(label_, property_, convert(lower_bound()), convert(upper_bound()), graph_view_ == GraphView::NEW); }; @@ -1216,9 +1216,8 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { self_.graph_view_); // For the given (vertex, edge, vertex) tuple checks if they satisfy the // "where" condition. if so, places them in the priority queue. - auto expand_pair = [this, &evaluator, &frame](VertexAccessor from, - EdgeAccessor edge, - VertexAccessor vertex) { + auto expand_pair = [this, &evaluator, &frame]( + VertexAccessor from, EdgeAccessor edge, VertexAccessor vertex) { SwitchAccessor(edge, self_.graph_view_); SwitchAccessor(vertex, self_.graph_view_); @@ -2567,72 +2566,6 @@ void OrderBy::OrderByCursor::Reset() { cache_it_ = cache_.begin(); } -bool OrderBy::TypedValueCompare(const TypedValue &a, const TypedValue &b) { - // in ordering null comes after everything else - // at the same time Null is not less that null - // first deal with Null < Whatever case - if (a.IsNull()) return false; - // now deal with NotNull < Null case - if (b.IsNull()) return true; - - // comparisons are from this point legal only between values of - // the same type, or int+float combinations - if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric()))) - throw QueryRuntimeException( - "Can't compare value of type {} to value of type {}", a.type(), - b.type()); - - switch (a.type()) { - case TypedValue::Type::Bool: - return !a.Value() && b.Value(); - case TypedValue::Type::Int: - if (b.type() == TypedValue::Type::Double) - return a.Value() < b.Value(); - else - return a.Value() < b.Value(); - case TypedValue::Type::Double: - if (b.type() == TypedValue::Type::Int) - return a.Value() < b.Value(); - else - return a.Value() < b.Value(); - case TypedValue::Type::String: - return a.Value() < b.Value(); - case TypedValue::Type::List: - case TypedValue::Type::Map: - case TypedValue::Type::Vertex: - case TypedValue::Type::Edge: - case TypedValue::Type::Path: - throw QueryRuntimeException( - "Comparison is not defined for values of type {}", a.type()); - default: - LOG(FATAL) << "Unhandled comparison for types"; - } -} - -bool OrderBy::TypedValueVectorCompare::operator()( - const std::vector &c1, - const std::vector &c2) const { - // ordering is invalid if there are more elements in the collections - // then there are in the ordering_ vector - DCHECK(c1.size() <= ordering_.size() && c2.size() <= ordering_.size()) - << "Collections contain more elements then there are orderings"; - - auto c1_it = c1.begin(); - auto c2_it = c2.begin(); - auto ordering_it = ordering_.begin(); - for (; c1_it != c1.end() && c2_it != c2.end(); - c1_it++, c2_it++, ordering_it++) { - if (OrderBy::TypedValueCompare(*c1_it, *c2_it)) - return *ordering_it == Ordering::ASC; - if (OrderBy::TypedValueCompare(*c2_it, *c1_it)) - return *ordering_it == Ordering::DESC; - } - - // at least one collection is exhausted - // c1 is less then c2 iff c1 reached the end but c2 didn't - return (c1_it == c1.end()) && (c2_it != c2.end()); -} - Merge::Merge(const std::shared_ptr &input, const std::shared_ptr &merge_match, const std::shared_ptr &merge_create) @@ -2985,10 +2918,6 @@ void Union::UnionCursor::Reset() { right_cursor_->Reset(); } -PullRemote::PullRemote(const std::shared_ptr &input, - int64_t plan_id, const std::vector &symbols) - : input_(input), plan_id_(plan_id), symbols_(symbols) {} - ACCEPT_WITH_INPUT(PullRemote); std::vector PullRemote::OutputSymbols(const SymbolTable &table) const { @@ -3005,34 +2934,103 @@ std::vector PullRemote::ModifiedSymbols( return symbols; } -PullRemote::PullRemoteCursor::PullRemoteCursor(const PullRemote &self, - database::GraphDbAccessor &db) - : self_(self), - db_(db), - input_cursor_(self.input_ ? self.input_->MakeCursor(db) : nullptr) { - worker_ids_ = db_.db().remote_pull_clients().GetWorkerIds(); - // Remove master from the worker ids list. - worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0)); +std::vector Synchronize::ModifiedSymbols( + const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + if (pull_remote_) { + auto pull_symbols = pull_remote_->ModifiedSymbols(table); + symbols.insert(symbols.end(), pull_symbols.begin(), pull_symbols.end()); + } + return symbols; } -bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { - auto insert_future_for_worker = [&](int worker_id) { - remote_pulls_[worker_id] = db_.db().remote_pull_clients().RemotePull( - db_, worker_id, self_.plan_id(), context.parameters_, self_.symbols(), - false); - }; +bool Synchronize::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + input_->Accept(visitor) && pull_remote_->Accept(visitor); + } + return visitor.PostVisit(*this); +} - if (!remote_pulls_initialized_) { - for (auto &worker_id : worker_ids_) { - insert_future_for_worker(worker_id); - } - remote_pulls_initialized_ = true; +std::vector Cartesian::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = left_op_->ModifiedSymbols(table); + auto right = right_op_->ModifiedSymbols(table); + symbols.insert(symbols.end(), right.begin(), right.end()); + return symbols; +} + +bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + left_op_->Accept(visitor) && right_op_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +PullRemoteOrderBy::PullRemoteOrderBy( + const std::shared_ptr &input, int64_t plan_id, + const std::vector> &order_by, + const std::vector &symbols) + : input_(input), plan_id_(plan_id), symbols_(symbols) { + CHECK(input_ != nullptr) + << "PullRemoteOrderBy should always be constructed with input!"; + std::vector ordering; + ordering.reserve(order_by.size()); + order_by_.reserve(order_by.size()); + for (const auto &ordering_expression_pair : order_by) { + ordering.emplace_back(ordering_expression_pair.first); + order_by_.emplace_back(ordering_expression_pair.second); + } + compare_ = TypedValueVectorCompare(ordering); +} + +ACCEPT_WITH_INPUT(PullRemoteOrderBy); + +std::vector PullRemoteOrderBy::OutputSymbols( + const SymbolTable &table) const { + return input_->OutputSymbols(table); +} + +std::vector PullRemoteOrderBy::ModifiedSymbols( + const SymbolTable &table) const { + return input_->ModifiedSymbols(table); +} + +namespace { + +/** Helper class that wraps remote pulling for cursors that handle results from + * distributed workers. + */ +class RemotePuller { + public: + RemotePuller(database::GraphDbAccessor &db, + const std::vector &symbols, int64_t plan_id) + : db_(db), symbols_(symbols), plan_id_(plan_id) { + worker_ids_ = db_.db().remote_pull_clients().GetWorkerIds(); + // Remove master from the worker ids list. + worker_ids_.erase(std::find(worker_ids_.begin(), worker_ids_.end(), 0)); } - bool have_remote_results = false; - while (!have_remote_results && !worker_ids_.empty()) { + void Initialize(Context &context) { + if (!remote_pulls_initialized_) { + for (auto &worker_id : worker_ids_) { + UpdateRemotePullForWorker(worker_id, context); + } + remote_pulls_initialized_ = true; + } + } + + void Update(Context &context) { // If we don't have results for a worker, check if his remote pull // finished and save results locally. + + auto move_frames = [this](int worker_id, auto remote_results) { + remote_results_[worker_id] = std::move(remote_results.frames); + // Since we return and remove results from the back of the vector, + // reverse the results so the first to return is on the end of the + // vector. + std::reverse(remote_results_[worker_id].begin(), + remote_results_[worker_id].end()); + }; + for (auto &worker_id : worker_ids_) { if (!remote_results_[worker_id].empty()) continue; @@ -3045,12 +3043,12 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { auto remote_results = remote_pull.get(); switch (remote_results.pull_state) { case distributed::RemotePullState::CURSOR_EXHAUSTED: - remote_results_[worker_id] = std::move(remote_results.frames); + move_frames(worker_id, remote_results); remote_pulls_.erase(found_it); break; case distributed::RemotePullState::CURSOR_IN_PROGRESS: - remote_results_[worker_id] = std::move(remote_results.frames); - insert_future_for_worker(worker_id); + move_frames(worker_id, remote_results); + UpdateRemotePullForWorker(worker_id, context); break; case distributed::RemotePullState::SERIALIZATION_ERROR: throw mvcc::SerializationError( @@ -3068,94 +3066,133 @@ bool PullRemote::PullRemoteCursor::Pull(Frame &frame, Context &context) { "Query runtime error occurred duing PullRemote !"); } } + } - // Get locally stored results from workers in a round-robin fasion. - int num_workers = worker_ids_.size(); - for (int i = 0; i < num_workers; ++i) { - int worker_id_index = - (last_pulled_worker_id_index_ + i + 1) % num_workers; - int worker_id = worker_ids_[worker_id_index]; + auto Workers() { return worker_ids_; } - if (!remote_results_[worker_id].empty()) { - last_pulled_worker_id_index_ = worker_id_index; - have_remote_results = true; - break; + int GetWorkerId(int worker_id_index) { return worker_ids_[worker_id_index]; } + + size_t WorkerCount() { return worker_ids_.size(); } + + void ClearWorkers() { worker_ids_.clear(); } + + bool HasPendingPulls() { return !remote_pulls_.empty(); } + + bool HasPendingPullFromWorker(int worker_id) { + return remote_pulls_.find(worker_id) != remote_pulls_.end(); + } + + bool HasResultsFromWorker(int worker_id) { + return !remote_results_[worker_id].empty(); + } + + std::vector PopResultFromWorker(int worker_id) { + auto result = remote_results_[worker_id].back(); + remote_results_[worker_id].pop_back(); + + // Remove the worker if we exhausted all locally stored results and there + // are no more pending remote pulls for that worker. + if (remote_results_[worker_id].empty() && + remote_pulls_.find(worker_id) == remote_pulls_.end()) { + worker_ids_.erase( + std::find(worker_ids_.begin(), worker_ids_.end(), worker_id)); + } + + return result; + } + + private: + database::GraphDbAccessor &db_; + std::vector symbols_; + int64_t plan_id_; + std::unordered_map> + remote_pulls_; + std::unordered_map>> + remote_results_; + std::vector worker_ids_; + bool remote_pulls_initialized_ = false; + + void UpdateRemotePullForWorker(int worker_id, Context &context) { + remote_pulls_[worker_id] = db_.db().remote_pull_clients().RemotePull( + db_, worker_id, plan_id_, context.parameters_, symbols_, false); + } +}; + +class PullRemoteCursor : public Cursor { + public: + PullRemoteCursor(const PullRemote &self, database::GraphDbAccessor &db) + : self_(self), + input_cursor_(self.input() ? self.input()->MakeCursor(db) : nullptr), + remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {} + + bool Pull(Frame &frame, Context &context) override { + remote_puller_.Initialize(context); + + bool have_remote_results = false; + while (!have_remote_results && remote_puller_.WorkerCount() > 0) { + remote_puller_.Update(context); + + // Get locally stored results from workers in a round-robin fasion. + int num_workers = remote_puller_.WorkerCount(); + for (int i = 0; i < num_workers; ++i) { + int worker_id_index = + (last_pulled_worker_id_index_ + i + 1) % num_workers; + int worker_id = remote_puller_.GetWorkerId(worker_id_index); + + if (remote_puller_.HasResultsFromWorker(worker_id)) { + last_pulled_worker_id_index_ = worker_id_index; + have_remote_results = true; + break; + } + } + + if (!have_remote_results) { + if (!remote_puller_.HasPendingPulls()) { + remote_puller_.ClearWorkers(); + break; + } + + // If there are no remote results available, try to pull and return + // local results. + if (input_cursor_ && input_cursor_->Pull(frame, context)) { + return true; + } + + // If there aren't any local/remote results available, sleep. + std::this_thread::sleep_for( + std::chrono::milliseconds(FLAGS_remote_pull_sleep)); } } + // No more remote results, make sure local results get exhausted. if (!have_remote_results) { - if (remote_pulls_.empty()) { - worker_ids_.clear(); - break; - } - - // If there are no remote results available, try to pull and return local - // results. if (input_cursor_ && input_cursor_->Pull(frame, context)) { return true; } - - // If there aren't any local/remote results available, sleep. - std::this_thread::sleep_for( - std::chrono::milliseconds(FLAGS_remote_pull_sleep)); + return false; } - } - // No more remote results, make sure local results get exhausted. - if (!have_remote_results) { - if (input_cursor_ && input_cursor_->Pull(frame, context)) { - return true; + { + int worker_id = remote_puller_.GetWorkerId(last_pulled_worker_id_index_); + auto result = remote_puller_.PopResultFromWorker(worker_id); + for (size_t i = 0; i < self_.symbols().size(); ++i) { + frame[self_.symbols()[i]] = std::move(result[i]); + } } - return false; + return true; } - int pull_from_worker_id = worker_ids_[last_pulled_worker_id_index_]; - { - auto &result = remote_results_[pull_from_worker_id].back(); - for (size_t i = 0; i < self_.symbols().size(); ++i) { - frame[self_.symbols()[i]] = std::move(result[i]); - } - } - remote_results_[pull_from_worker_id].resize( - remote_results_[pull_from_worker_id].size() - 1); - - // Remove the worker if we exhausted all locally stored results and there - // are no more pending remote pulls for that worker. - if (remote_results_[pull_from_worker_id].empty() && - remote_pulls_.find(pull_from_worker_id) == remote_pulls_.end()) { - worker_ids_.erase(worker_ids_.begin() + last_pulled_worker_id_index_); + void Reset() override { + throw QueryRuntimeException("Unsupported: Reset during PullRemote!"); } - return true; -} + private: + const PullRemote &self_; + const std::unique_ptr input_cursor_; + RemotePuller remote_puller_; + int last_pulled_worker_id_index_ = 0; +}; -void PullRemote::PullRemoteCursor::Reset() { - throw QueryRuntimeException("Unsupported: Reset during PullRemote!"); -} - -std::unique_ptr PullRemote::MakeCursor( - database::GraphDbAccessor &db) const { - return std::make_unique(*this, db); -} - -bool Synchronize::Accept(HierarchicalLogicalOperatorVisitor &visitor) { - if (visitor.PreVisit(*this)) { - input_->Accept(visitor) && pull_remote_->Accept(visitor); - } - return visitor.PostVisit(*this); -} - -std::vector Synchronize::ModifiedSymbols( - const SymbolTable &table) const { - auto symbols = input_->ModifiedSymbols(table); - if (pull_remote_) { - auto pull_symbols = pull_remote_->ModifiedSymbols(table); - symbols.insert(symbols.end(), pull_symbols.begin(), pull_symbols.end()); - } - return symbols; -} - -namespace { class SynchronizeCursor : public Cursor { public: SynchronizeCursor(const Synchronize &self, database::GraphDbAccessor &db) @@ -3365,30 +3402,148 @@ class CartesianCursor : public Cursor { bool cartesian_pull_initialized_{false}; }; +class PullRemoteOrderByCursor : public Cursor { + public: + PullRemoteOrderByCursor(const PullRemoteOrderBy &self, + database::GraphDbAccessor &db) + : self_(self), + db_(db), + input_(self.input()->MakeCursor(db)), + remote_puller_(RemotePuller(db, self.symbols(), self.plan_id())) {} + + bool Pull(Frame &frame, Context &context) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); + + auto evaluate_result = [this, &evaluator]() { + std::vector order_by; + order_by.reserve(self_.order_by().size()); + for (auto expression_ptr : self_.order_by()) { + order_by.emplace_back(expression_ptr->Accept(evaluator)); + } + return order_by; + }; + + auto restore_frame = [&frame, + this](const std::vector &restore_from) { + for (size_t i = 0; i < restore_from.size(); ++i) { + frame[self_.symbols()[i]] = restore_from[i]; + } + }; + + if (!merge_initialized_) { + remote_puller_.Initialize(context); + missing_results_from_ = remote_puller_.Workers(); + missing_master_result_ = true; + merge_initialized_ = true; + } + + if (missing_master_result_) { + if (input_->Pull(frame, context)) { + std::vector output; + output.reserve(self_.symbols().size()); + for (const Symbol &symbol : self_.symbols()) { + output.emplace_back(frame[symbol]); + } + + merge_.push_back(MergeResultItem{std::experimental::nullopt, output, + evaluate_result()}); + } + missing_master_result_ = false; + } + + while (!missing_results_from_.empty()) { + remote_puller_.Update(context); + + bool has_all_result = true; + for (auto &worker_id : missing_results_from_) { + if (!remote_puller_.HasResultsFromWorker(worker_id) && + remote_puller_.HasPendingPullFromWorker(worker_id)) { + has_all_result = false; + break; + } + } + + if (!has_all_result) { + // If we don't have results from all workers, sleep before continuing. + std::this_thread::sleep_for( + std::chrono::milliseconds(FLAGS_remote_pull_sleep)); + continue; + } + + for (auto &worker_id : missing_results_from_) { + // It is possible that the workers remote pull finished but it didn't + // return any results. In that case, just skip it. + if (!remote_puller_.HasResultsFromWorker(worker_id)) continue; + auto remote_result = remote_puller_.PopResultFromWorker(worker_id); + restore_frame(remote_result); + merge_.push_back( + MergeResultItem{worker_id, remote_result, evaluate_result()}); + } + + missing_results_from_.clear(); + } + + if (merge_.empty()) return false; + + auto result_it = std::min_element( + merge_.begin(), merge_.end(), [this](const auto &lhs, const auto &rhs) { + return self_.compare()(lhs.order_by, rhs.order_by); + }); + + restore_frame(result_it->remote_result); + + if (result_it->worker_id) { + missing_results_from_.push_back(result_it->worker_id.value()); + } else { + missing_master_result_ = true; + } + + merge_.erase(result_it); + return true; + } + + void Reset() { + throw QueryRuntimeException("Unsupported: Reset during PullRemoteOrderBy!"); + } + + private: + struct MergeResultItem { + std::experimental::optional worker_id; + std::vector remote_result; + std::vector order_by; + }; + + const PullRemoteOrderBy &self_; + database::GraphDbAccessor &db_; + std::unique_ptr input_; + RemotePuller remote_puller_; + std::vector merge_; + std::vector missing_results_from_; + bool missing_master_result_ = false; + bool merge_initialized_ = false; +}; + } // namespace +std::unique_ptr PullRemote::MakeCursor( + database::GraphDbAccessor &db) const { + return std::make_unique(*this, db); +} + std::unique_ptr Synchronize::MakeCursor( database::GraphDbAccessor &db) const { return std::make_unique(*this, db); } -bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) { - if (visitor.PreVisit(*this)) { - left_op_->Accept(visitor) && right_op_->Accept(visitor); - } - return visitor.PostVisit(*this); -} - std::unique_ptr Cartesian::MakeCursor( database::GraphDbAccessor &db) const { return std::make_unique(*this, db); } -std::vector Cartesian::ModifiedSymbols(const SymbolTable &table) const { - auto symbols = left_op_->ModifiedSymbols(table); - auto right = right_op_->ModifiedSymbols(table); - symbols.insert(symbols.end(), right.begin(), right.end()); - return symbols; +std::unique_ptr PullRemoteOrderBy::MakeCursor( + database::GraphDbAccessor &db) const { + return std::make_unique(*this, db); } } // namespace query::plan @@ -3428,3 +3583,4 @@ BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Union); BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::PullRemote); BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Synchronize); BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::Cartesian); +BOOST_CLASS_EXPORT_IMPLEMENT(query::plan::PullRemoteOrderBy); diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 5d5d5f0e2..b88d942ee 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -101,6 +101,7 @@ class Union; class PullRemote; class Synchronize; class Cartesian; +class PullRemoteOrderBy; using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, @@ -110,7 +111,7 @@ using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor< ExpandUniquenessFilter, ExpandUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct, Union, PullRemote, Synchronize, - Cartesian>; + Cartesian, PullRemoteOrderBy>; using LogicalOperatorLeafVisitor = ::utils::LeafVisitor; @@ -1946,29 +1947,6 @@ class OrderBy : public LogicalOperator { void set_input(std::shared_ptr input) { input_ = input; } private: - // custom Comparator type for comparing vectors of TypedValues - // does lexicographical ordering of elements based on the above - // defined TypedValueCompare, and also accepts a vector of Orderings - // the define how respective elements compare - class TypedValueVectorCompare { - public: - TypedValueVectorCompare() {} - explicit TypedValueVectorCompare(const std::vector &ordering) - : ordering_(ordering) {} - bool operator()(const std::vector &c1, - const std::vector &c2) const; - - private: - std::vector ordering_; - - friend class boost::serialization::access; - - template - void serialize(TArchive &ar, const unsigned int) { - ar &ordering_; - } - }; - std::shared_ptr input_; TypedValueVectorCompare compare_; std::vector order_by_; @@ -1976,13 +1954,6 @@ class OrderBy : public LogicalOperator { OrderBy() {} - // custom comparison for TypedValue objects - // behaves generally like Neo's ORDER BY comparison operator: - // - null is greater than anything else - // - primitives compare naturally, only implicit cast is int->double - // - (list, map, path, vertex, edge) can't compare to anything - static bool TypedValueCompare(const TypedValue &a, const TypedValue &b); - class OrderByCursor : public Cursor { public: OrderByCursor(const OrderBy &self, database::GraphDbAccessor &db); @@ -2364,13 +2335,15 @@ class Union : public LogicalOperator { class PullRemote : public LogicalOperator { public: PullRemote(const std::shared_ptr &input, int64_t plan_id, - const std::vector &symbols); + const std::vector &symbols) + : input_(input), plan_id_(plan_id), symbols_(symbols) {} bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor( database::GraphDbAccessor &db) const override; std::vector OutputSymbols(const SymbolTable &) const override; std::vector ModifiedSymbols(const SymbolTable &) const override; + auto input() const { return input_; } const auto &symbols() const { return symbols_; } auto plan_id() const { return plan_id_; } @@ -2381,25 +2354,6 @@ class PullRemote : public LogicalOperator { PullRemote() {} - class PullRemoteCursor : public Cursor { - public: - PullRemoteCursor(const PullRemote &self, database::GraphDbAccessor &db); - bool Pull(Frame &, Context &) override; - void Reset() override; - - private: - const PullRemote &self_; - database::GraphDbAccessor &db_; - const std::unique_ptr input_cursor_; - std::unordered_map> - remote_pulls_; - std::unordered_map>> - remote_results_; - std::vector worker_ids_; - int last_pulled_worker_id_index_ = 0; - bool remote_pulls_initialized_ = false; - }; - friend class boost::serialization::access; template void serialize(TArchive &ar, const unsigned int) { @@ -2510,6 +2464,67 @@ class Cartesian : public LogicalOperator { } }; +/** + * Operator that merges distributed OrderBy operators. + * + * Instead of using a regular OrderBy on master (which would collect all remote + * results and order them), we can have each worker do an OrderBy locally and + * have the master rely on the fact that the results are ordered and merge them + * by having only one result from each worker. + */ +class PullRemoteOrderBy : public LogicalOperator { + public: + PullRemoteOrderBy( + const std::shared_ptr &input, int64_t plan_id, + const std::vector> &order_by, + const std::vector &symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + std::unique_ptr MakeCursor( + database::GraphDbAccessor &db) const override; + + std::vector ModifiedSymbols(const SymbolTable &) const override; + std::vector OutputSymbols(const SymbolTable &) const override; + + auto input() const { return input_; } + auto plan_id() const { return plan_id_; } + const auto &symbols() const { return symbols_; } + auto order_by() const { return order_by_; } + auto compare() const { return compare_; } + + private: + std::shared_ptr input_; + int64_t plan_id_ = 0; + std::vector symbols_; + std::vector order_by_; + TypedValueVectorCompare compare_; + + PullRemoteOrderBy() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &plan_id_; + ar &symbols_; + SavePointers(ar, order_by_); + ar &compare_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &plan_id_; + ar &symbols_; + LoadPointers(ar, order_by_); + ar &compare_; + } +}; + } // namespace plan } // namespace query @@ -2547,3 +2562,4 @@ BOOST_CLASS_EXPORT_KEY(query::plan::Union); BOOST_CLASS_EXPORT_KEY(query::plan::PullRemote); BOOST_CLASS_EXPORT_KEY(query::plan::Synchronize); BOOST_CLASS_EXPORT_KEY(query::plan::Cartesian); +BOOST_CLASS_EXPORT_KEY(query::plan::PullRemoteOrderBy); diff --git a/tests/unit/distributed_query_plan.cpp b/tests/unit/distributed_query_plan.cpp index 839a74af6..45f1da614 100644 --- a/tests/unit/distributed_query_plan.cpp +++ b/tests/unit/distributed_query_plan.cpp @@ -262,3 +262,61 @@ TEST_F(DistributedGraphDbTest, Create) { EXPECT_GT(VertexCount(worker(1)), 200); EXPECT_GT(VertexCount(worker(2)), 200); } + +TEST_F(DistributedGraphDbTest, PullRemoteOrderBy) { + // Create some data on the master and both workers. + storage::Property prop; + { + GraphDbAccessor dba{master()}; + auto tx_id = dba.transaction_id(); + GraphDbAccessor dba1{worker(1), tx_id}; + GraphDbAccessor dba2{worker(2), tx_id}; + prop = dba.Property("prop"); + auto add_data = [prop](GraphDbAccessor &dba, int value) { + dba.InsertVertex().PropsSet(prop, value); + }; + + std::vector data; + for (int i = 0; i < 300; ++i) data.push_back(i); + std::random_shuffle(data.begin(), data.end()); + + for (int i = 0; i < 100; ++i) add_data(dba, data[i]); + for (int i = 100; i < 200; ++i) add_data(dba1, data[i]); + for (int i = 200; i < 300; ++i) add_data(dba2, data[i]); + + dba.Commit(); + } + + auto &db = master(); + GraphDbAccessor dba{db}; + Context ctx{dba}; + SymbolGenerator symbol_generator{ctx.symbol_table_}; + AstTreeStorage storage; + + // Query plan for: MATCH (n) RETURN n.prop ORDER BY n.prop; + auto n = MakeScanAll(storage, ctx.symbol_table_, "n"); + auto n_p = PROPERTY_LOOKUP("n", prop); + ctx.symbol_table_[*n_p->expression_] = n.sym_; + auto order_by = std::make_shared( + n.op_, + std::vector>{{Ordering::ASC, n_p}}, + std::vector{n.sym_}); + + const int plan_id = 42; + master().plan_dispatcher().DispatchPlan(plan_id, order_by, ctx.symbol_table_); + + auto pull_remote_order_by = std::make_shared( + order_by, plan_id, + std::vector>{{Ordering::ASC, n_p}}, + std::vector{n.sym_}); + + auto n_p_ne = NEXPR("n.prop", n_p); + ctx.symbol_table_[*n_p_ne] = ctx.symbol_table_.CreateSymbol("n.prop", true); + auto produce = MakeProduce(pull_remote_order_by, n_p_ne); + auto results = CollectProduce(produce.get(), ctx.symbol_table_, dba); + + ASSERT_EQ(results.size(), 300); + for (int j = 0; j < 300; ++j) { + EXPECT_TRUE(TypedValue::BoolEqual{}(results[j][0], j)); + } +}