Plan PullRemoteOrderBy

Summary:
During distributed execution, OrderBy is split across workers and the
master gets to merge those results via PullRemoteOrderBy. Since this
operator may be an input to almost any other operator, virtual accessors
to `input` have been added in LogicalOperator.

Depends on D1221

Reviewers: florijan, msantl, buda

Reviewed By: msantl

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1232
This commit is contained in:
Teon Banek 2018-02-23 15:06:31 +01:00
parent 2a68543f3e
commit 30d2cfb9db
6 changed files with 329 additions and 75 deletions

View File

@ -49,6 +49,8 @@ class TypedValueVectorCompare {
bool operator()(const std::vector<TypedValue> &c1,
const std::vector<TypedValue> &c2) const;
const auto &ordering() const { return ordering_; }
private:
std::vector<Ordering> ordering_;

View File

@ -280,8 +280,7 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
}
// OrderBy is an associative operator, this means we can do ordering
// on workers and then merge the results on master. This requires a more
// involved solution, so for now treat OrderBy just like Skip.
// on workers and then merge the results on master.
bool PreVisit(OrderBy &order_by) override {
prev_ops_.push_back(&order_by);
return true;
@ -290,12 +289,37 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
prev_ops_.pop_back();
// TODO: Associative combination of OrderBy
if (ShouldSplit()) {
auto input = order_by.input();
auto pull_id = AddWorkerPlan(input);
Split(order_by,
std::make_shared<PullRemote>(
input, pull_id,
input->OutputSymbols(distributed_plan_.symbol_table)));
std::unordered_set<Symbol> pull_symbols(order_by.output_symbols().begin(),
order_by.output_symbols().end());
// Pull symbols need to also include those used in order by expressions.
// For example, `RETURN n AS m ORDER BY n.prop`, output symbols will
// contain `m`, while we also need to pull `n`.
// TODO: Consider creating a virtual symbol for expressions like `n.prop`
// and sending them instead. It's possible that the evaluated expression
// requires less network traffic than sending the value of the used symbol
// `n` itself.
for (const auto &expr : order_by.order_by()) {
UsedSymbolsCollector collector(distributed_plan_.symbol_table);
expr->Accept(collector);
pull_symbols.insert(collector.symbols_.begin(),
collector.symbols_.end());
}
// Create a copy of OrderBy but with added symbols used in expressions, so
// that they can be pulled.
std::vector<std::pair<Ordering, Expression *>> ordering;
ordering.reserve(order_by.order_by().size());
for (int i = 0; i < order_by.order_by().size(); ++i) {
ordering.emplace_back(order_by.compare().ordering()[i],
order_by.order_by()[i]);
}
auto worker_plan = std::make_shared<OrderBy>(
order_by.input(), ordering,
std::vector<Symbol>(pull_symbols.begin(), pull_symbols.end()));
auto pull_id = AddWorkerPlan(worker_plan);
auto merge_op = std::make_unique<PullRemoteOrderBy>(
worker_plan, pull_id, ordering,
std::vector<Symbol>(pull_symbols.begin(), pull_symbols.end()));
SplitOnPrevious(std::move(merge_op));
}
return true;
}
@ -487,9 +511,7 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
pull_op, master_aggrs, aggr_op.group_by(), aggr_op.remember());
// Make our master Aggregate into Produce + Aggregate
auto master_plan = std::make_unique<Produce>(master_aggr_op, produce_exprs);
auto produce = dynamic_cast<Produce *>(prev_ops_.back());
DCHECK(produce) << "Expected Aggregate is directly below Produce";
Split(*produce, std::move(master_plan));
SplitOnPrevious(std::move(master_plan));
return true;
}
@ -548,19 +570,9 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
pull_remote =
std::make_shared<PullRemote>(nullptr, pull_id, acc.symbols());
}
auto sync = std::make_shared<Synchronize>(acc.input(), pull_remote,
auto sync = std::make_unique<Synchronize>(acc.input(), pull_remote,
acc.advance_command());
auto *prev_op = prev_ops_.back();
// Wire the previous operator (on master) into our synchronization operator.
// TODO: Find a better way to replace the previous operation's input than
// using dynamic casting.
if (auto *produce = dynamic_cast<Produce *>(prev_op)) {
Split(*produce, sync);
} else if (auto *aggr_op = dynamic_cast<Aggregate *>(prev_op)) {
Split(*aggr_op, sync);
} else {
throw utils::NotYetImplemented("distributed planning");
}
SplitOnPrevious(std::move(sync));
needs_synchronize_ = false;
return true;
}
@ -703,8 +715,21 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
// Installs `merge_op` as the input of `master_op` and sets `on_master_`.
//
// @param master_op operator whose input gets replaced; it must provide
//                  `set_input(std::shared_ptr<LogicalOperator>)`.
// @param merge_op  operator to install as the new input.
// @throws utils::NotYetImplemented if `on_master_` is already set when this
//         is called.
template <class TOp>
void Split(TOp &master_op, std::shared_ptr<LogicalOperator> merge_op) {
  if (on_master_) throw utils::NotYetImplemented("distributed planning");
  // The input only needs to be set once. `merge_op` is our own by-value copy,
  // so it can be handed over without another reference count bump.
  master_op.set_input(std::move(merge_op));
  on_master_ = true;
}
// Places `merge_op` above the most recently visited operator in `prev_ops_`.
// When no operator was visited before, `merge_op` itself becomes
// `distributed_plan_.master_plan`.
//
// @throws utils::NotYetImplemented if `on_master_` is already set, or if the
//         previous operator reports `HasSingleInput() == false`.
void SplitOnPrevious(std::unique_ptr<LogicalOperator> merge_op) {
  if (on_master_) throw utils::NotYetImplemented("distributed planning");
  if (!prev_ops_.empty()) {
    auto *parent_op = prev_ops_.back();
    if (parent_op->HasSingleInput()) {
      // `set_input` may only be used on operators with a single input.
      Split(*parent_op, std::move(merge_op));
      return;
    }
    throw utils::NotYetImplemented("distributed planning");
  }
  // Nothing above us, so the merge operator is the root of the master plan.
  distributed_plan_.master_plan = std::move(merge_op);
  on_master_ = true;
}
int64_t AddWorkerPlan(const std::shared_ptr<LogicalOperator> &worker_plan) {

View File

@ -40,6 +40,15 @@ DEFINE_HIDDEN_int32(remote_pull_sleep, 1,
return visitor.PostVisit(*this); \
}
// Defines the single-input accessors for an operator class which does not
// have exactly one input operator:
//   * `class_name::HasSingleInput` returns false;
//   * `class_name::input` and `class_name::set_input` log a fatal error which
//     names the operator via the stringified `class_name`.
// The class is expected to declare these three member functions itself.
// NOTE(review): `input` has no return statement, so this relies on
// `LOG(FATAL)` never returning -- confirm against the logging library.
#define WITHOUT_SINGLE_INPUT(class_name) \
bool class_name::HasSingleInput() const { return false; } \
std::shared_ptr<LogicalOperator> class_name::input() const { \
LOG(FATAL) << "Operator " << #class_name << " has no single input!"; \
} \
void class_name::set_input(std::shared_ptr<LogicalOperator>) { \
LOG(FATAL) << "Operator " << #class_name << " has no single input!"; \
}
namespace query::plan {
namespace {
@ -95,6 +104,8 @@ std::unique_ptr<Cursor> Once::MakeCursor(database::GraphDbAccessor &) const {
return std::make_unique<OnceCursor>();
}
WITHOUT_SINGLE_INPUT(Once);
void Once::OnceCursor::Reset() { did_pull_ = false; }
CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input,
@ -397,10 +408,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
context.symbol_table_, db, graph_view_);
auto convert = [&evaluator](const auto &bound)
-> std::experimental::optional<utils::Bound<PropertyValue>> {
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
return db.Vertices(label_, property_, convert(lower_bound()),
convert(upper_bound()), graph_view_ == GraphView::NEW);
};
@ -2820,6 +2831,8 @@ bool CreateIndex::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
return visitor.Visit(*this);
}
WITHOUT_SINGLE_INPUT(CreateIndex);
class CreateIndexCursor : public Cursor {
public:
CreateIndexCursor(const CreateIndex &self, database::GraphDbAccessor &db)
@ -2884,6 +2897,8 @@ std::vector<Symbol> Union::ModifiedSymbols(const SymbolTable &) const {
return union_symbols_;
}
WITHOUT_SINGLE_INPUT(Union);
Union::UnionCursor::UnionCursor(const Union &self,
database::GraphDbAccessor &db)
: self_(self),
@ -2965,6 +2980,8 @@ bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
return visitor.PostVisit(*this);
}
WITHOUT_SINGLE_INPUT(Cartesian);
PullRemoteOrderBy::PullRemoteOrderBy(
const std::shared_ptr<LogicalOperator> &input, int64_t plan_id,
const std::vector<std::pair<Ordering, Expression *>> &order_by,

View File

@ -146,7 +146,7 @@ class LogicalOperator
* database.
*/
virtual std::unique_ptr<Cursor> MakeCursor(
database::GraphDbAccessor &db) const = 0;
database::GraphDbAccessor &) const = 0;
/** Return @c Symbol vector where the query results will be stored.
*
@ -179,6 +179,24 @@ class LogicalOperator
*/
virtual std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const = 0;
/**
* Returns true if the operator takes only one input operator.
* NOTE: When this method returns true, you may use `input` and `set_input`
* methods.
*/
virtual bool HasSingleInput() const = 0;
/**
* Returns the input operator if it has any.
* NOTE: This should only be called if `HasSingleInput() == true`.
*/
virtual std::shared_ptr<LogicalOperator> input() const = 0;
/**
* Set a different input on this operator.
* NOTE: This should only be called if `HasSingleInput() == true`.
*/
virtual void set_input(std::shared_ptr<LogicalOperator>) = 0;
private:
friend class boost::serialization::access;
@ -208,6 +226,10 @@ class Once : public LogicalOperator {
return {};
}
bool HasSingleInput() const override;
std::shared_ptr<LogicalOperator> input() const override;
void set_input(std::shared_ptr<LogicalOperator>) override;
private:
class OnceCursor : public Cursor {
public:
@ -252,8 +274,12 @@ class CreateNode : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto on_random_worker() const { return on_random_worker_; }
void set_on_random_worker(bool v) { on_random_worker_ = v; }
@ -332,8 +358,11 @@ class CreateExpand : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
// info on what's getting expanded
@ -429,8 +458,12 @@ class ScanAll : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto output_symbol() const { return output_symbol_; }
auto graph_view() const { return graph_view_; }
@ -781,6 +814,12 @@ class Expand : public LogicalOperator, public ExpandCommon {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
class ExpandCursor : public Cursor {
public:
ExpandCursor(const Expand &self, database::GraphDbAccessor &db);
@ -913,6 +952,12 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto type() const { return type_; }
private:
@ -978,7 +1023,12 @@ class ConstructNamedPath : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
const auto &input() const { return input_; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const auto &path_symbol() const { return path_symbol_; }
const auto &path_elements() const { return path_elements_; }
@ -1017,6 +1067,12 @@ class Filter : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
Expression *expression_;
@ -1075,9 +1131,13 @@ class Produce : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const std::vector<NamedExpression *> &named_expressions();
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
private:
std::shared_ptr<LogicalOperator> input_;
@ -1130,8 +1190,11 @@ class Delete : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1191,8 +1254,11 @@ class SetProperty : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1265,8 +1331,11 @@ class SetProperties : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1335,8 +1404,11 @@ class SetLabels : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1380,8 +1452,11 @@ class RemoveProperty : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1436,8 +1511,11 @@ class RemoveLabels : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1502,6 +1580,12 @@ class ExpandUniquenessFilter : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
Symbol expand_symbol_;
@ -1568,7 +1652,12 @@ class Accumulate : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const auto &symbols() const { return symbols_; };
auto advance_command() const { return advance_command_; }
@ -1671,11 +1760,15 @@ class Aggregate : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const auto &aggregations() const { return aggregations_; }
const auto &group_by() const { return group_by_; }
const auto &remember() const { return remember_; }
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
private:
std::shared_ptr<LogicalOperator> input_;
@ -1808,8 +1901,11 @@ class Skip : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1876,8 +1972,11 @@ class Limit : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -1942,9 +2041,15 @@ class OrderBy : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const auto &order_by() const { return order_by_; }
const auto &compare() const { return compare_; }
const auto &output_symbols() const { return output_symbols_; }
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
private:
std::shared_ptr<LogicalOperator> input_;
@ -2021,7 +2126,15 @@ class Merge : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
// TODO: Consider whether we want to treat Merge as having single input. It
// makes sense that we do, because other branches are executed depending on
// the input.
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto merge_match() const { return merge_match_; }
auto merge_create() const { return merge_create_; }
@ -2080,7 +2193,13 @@ class Optional : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
// TODO: Consider whether we want to treat Optional as having single input.
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto optional() const { return optional_; }
const auto &optional_symbols() const { return optional_symbols_; }
@ -2135,6 +2254,12 @@ class Unwind : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
Expression *input_expression() const { return input_expression_; }
private:
@ -2199,8 +2324,11 @@ class Distinct : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) { input_ = input; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
private:
std::shared_ptr<LogicalOperator> input_;
@ -2255,6 +2383,10 @@ class CreateIndex : public LogicalOperator {
return {};
}
bool HasSingleInput() const override;
std::shared_ptr<LogicalOperator> input() const override;
void set_input(std::shared_ptr<LogicalOperator>) override;
auto label() const { return label_; }
auto property() const { return property_; }
@ -2294,6 +2426,10 @@ class Union : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override;
std::shared_ptr<LogicalOperator> input() const override;
void set_input(std::shared_ptr<LogicalOperator>) override;
private:
std::shared_ptr<LogicalOperator> left_op_, right_op_;
std::vector<Symbol> union_symbols_, left_symbols_, right_symbols_;
@ -2343,7 +2479,12 @@ class PullRemote : public LogicalOperator {
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
const auto &symbols() const { return symbols_; }
auto plan_id() const { return plan_id_; }
@ -2401,7 +2542,12 @@ class Synchronize : public LogicalOperator {
return input_->OutputSymbols(symbol_table);
}
auto input() const { return input_; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto pull_remote() const { return pull_remote_; }
auto advance_command() const { return advance_command_; }
@ -2440,6 +2586,10 @@ class Cartesian : public LogicalOperator {
database::GraphDbAccessor &db) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
bool HasSingleInput() const override;
std::shared_ptr<LogicalOperator> input() const override;
void set_input(std::shared_ptr<LogicalOperator>) override;
auto left_op() const { return left_op_; }
auto left_symbols() const { return left_symbols_; }
auto right_op() const { return right_op_; }
@ -2485,11 +2635,16 @@ class PullRemoteOrderBy : public LogicalOperator {
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
auto input() const { return input_; }
bool HasSingleInput() const override { return true; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override {
input_ = input;
}
auto plan_id() const { return plan_id_; }
const auto &symbols() const { return symbols_; }
auto order_by() const { return order_by_; }
auto compare() const { return compare_; }
const auto &compare() const { return compare_; }
private:
std::shared_ptr<LogicalOperator> input_;

View File

@ -474,7 +474,17 @@ class PlanPrinter : public query::plan::HierarchicalLogicalOperatorVisitor {
PRE_VISIT(Skip);
PRE_VISIT(Limit);
PRE_VISIT(OrderBy);
// Prints a line for the OrderBy operator in the form
// `* OrderBy {sym1, sym2, ...}`, listing the names of `op.output_symbols()`.
// Always returns true.
bool PreVisit(query::plan::OrderBy &op) override {
  auto print_symbol_name = [](auto &stream, const auto &symbol) {
    stream << symbol.name();
  };
  WithPrintLn([&op, print_symbol_name](auto &stream) {
    stream << "* OrderBy {";
    utils::PrintIterable(stream, op.output_symbols(), ", ", print_symbol_name);
    stream << "}";
  });
  return true;
}
bool PreVisit(query::plan::Merge &op) override {
WithPrintLn([](auto &out) { out << "* Merge"; });
@ -535,6 +545,22 @@ class PlanPrinter : public query::plan::HierarchicalLogicalOperatorVisitor {
op.left_op()->Accept(*this);
return false;
}
// Prints a line for the PullRemoteOrderBy operator in the form
// `* PullRemoteOrderBy {sym1, sym2, ...}`, listing the names of
// `op.symbols()`. It is followed by a `|\` line and a `* workers` line, the
// latter printed with `depth_` temporarily increased by one.
// Always returns true.
bool PreVisit(query::plan::PullRemoteOrderBy &op) override {
  auto print_symbol_name = [](auto &stream, const auto &symbol) {
    stream << symbol.name();
  };
  WithPrintLn([&op, print_symbol_name](auto &stream) {
    stream << "* PullRemoteOrderBy {";
    utils::PrintIterable(stream, op.symbols(), ", ", print_symbol_name);
    stream << "}";
  });
  WithPrintLn([](auto &stream) { stream << "|\\"; });
  // The workers line belongs to the side branch, hence one level deeper.
  ++depth_;
  WithPrintLn([](auto &stream) { stream << "* workers"; });
  --depth_;
  return true;
}
#undef PRE_VISIT
private:

View File

@ -125,6 +125,8 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor {
CheckOp(op);
return false;
}
PRE_VISIT(PullRemoteOrderBy);
#undef PRE_VISIT
std::list<BaseOpChecker *> checkers_;
@ -445,6 +447,19 @@ class ExpectCreateNode : public OpChecker<CreateNode> {
bool on_random_worker_ = false;
};
// Checker which expects a PullRemoteOrderBy operator whose `symbols()`
// contain exactly the given symbols, in any order.
class ExpectPullRemoteOrderBy : public OpChecker<PullRemoteOrderBy> {
 public:
  // The symbols are taken by const reference, so they are copied only once
  // (into `symbols_`) instead of once per by-value parameter as well.
  ExpectPullRemoteOrderBy(const std::vector<Symbol> &symbols)
      : symbols_(symbols) {}

  // Verifies that the symbols of `op` match the expected ones, ignoring order.
  void ExpectOp(PullRemoteOrderBy &op, const SymbolTable &) override {
    EXPECT_THAT(op.symbols(), testing::UnorderedElementsAreArray(symbols_));
  }

 private:
  // Symbols the checked operator is expected to have.
  std::vector<Symbol> symbols_;
};
auto MakeSymbolTable(query::Query &query) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
@ -1253,23 +1268,30 @@ TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) {
}
TYPED_TEST(TestPlanner, MatchReturnOrderBy) {
// Test MATCH (n) RETURN n ORDER BY n.prop
// Test MATCH (n) RETURN n AS m ORDER BY n.prop
database::SingleNode db;
database::GraphDbAccessor dba(db);
auto prop = dba.Property("prop");
AstTreeStorage storage;
auto *as_n = NEXPR("n", IDENT("n"));
auto ret = RETURN(as_n, ORDER_BY(PROPERTY_LOOKUP("n", prop)));
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret));
auto *as_m = NEXPR("m", IDENT("n"));
auto *node_n = NODE("n");
auto ret = RETURN(as_m, ORDER_BY(PROPERTY_LOOKUP("n", prop)));
QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n)), ret));
auto symbol_table = MakeSymbolTable(*storage.query());
auto planner = MakePlanner<TypeParam>(db, storage, symbol_table);
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectProduce(),
ExpectOrderBy());
ExpectPullRemote pull({symbol_table.at(*as_n)});
ExpectPullRemoteOrderBy pull_order_by(
{symbol_table.at(*as_m), symbol_table.at(*node_n->identifier_)});
auto expected = ExpectDistributed(
MakeCheckers(ExpectScanAll(), ExpectProduce(), pull, ExpectOrderBy()),
MakeCheckers(ExpectScanAll(), ExpectProduce()));
MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectOrderBy(),
pull_order_by),
MakeCheckers(ExpectScanAll(), ExpectProduce(), ExpectOrderBy()));
CheckDistributedPlan(planner.plan(), symbol_table, expected);
// Even though last operator pulls and orders by `m` and `n`, we expect only
// `m` as the output of the query execution.
EXPECT_THAT(planner.plan().OutputSymbols(symbol_table),
testing::UnorderedElementsAre(symbol_table.at(*as_m)));
}
TYPED_TEST(TestPlanner, CreateWithOrderByWhere) {
@ -1300,6 +1322,10 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) {
CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(),
ExpectCreateExpand(), acc, ExpectProduce(), ExpectOrderBy(),
ExpectFilter());
auto expected = ExpectDistributed(MakeCheckers(
ExpectCreateNode(true), ExpectCreateExpand(), ExpectSynchronize(true),
ExpectProduce(), ExpectOrderBy(), ExpectFilter()));
CheckDistributedPlan(planner.plan(), symbol_table, expected);
}
TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) {
@ -1311,6 +1337,9 @@ TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) {
RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))));
auto aggr = ExpectAggregate({sum, count}, {});
CheckPlan<TypeParam>(storage, aggr, ExpectProduce(), ExpectOrderBy());
auto expected =
ExpectDistributed(MakeCheckers(aggr, ExpectProduce(), ExpectOrderBy()));
CheckDistributedPlan<TypeParam>(storage, expected);
}
TYPED_TEST(TestPlanner, MatchMerge) {