Plan Skip and Limit after OrderBy

Summary:
Revive the OutputSymbols method.
Use OutputSymbols to stream results in Interpret.

Reviewers: mislav.bradac, buda, florijan

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D318
This commit is contained in:
Teon Banek 2017-04-26 16:12:39 +02:00
parent 8fa574026e
commit f0aaca4a1a
5 changed files with 74 additions and 62 deletions

View File

@ -42,21 +42,10 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
Frame frame(symbol_table.max_position()); Frame frame(symbol_table.max_position());
std::vector<std::string> header; std::vector<std::string> header;
bool is_return = false; std::vector<Symbol> output_symbols(logical_plan->OutputSymbols(symbol_table));
std::vector<Symbol> output_symbols; if (!output_symbols.empty()) {
if (auto produce = dynamic_cast<plan::Produce *>(logical_plan.get())) { // Since we have output symbols, this means that the query contains RETURN
is_return = true; // clause, so stream out the results.
// collect the symbols from the return clause
for (auto named_expression : produce->named_expressions())
output_symbols.emplace_back(symbol_table[*named_expression]);
} else if (auto order_by =
dynamic_cast<plan::OrderBy *>(logical_plan.get())) {
is_return = true;
output_symbols = order_by->output_symbols();
}
if (is_return) {
// top level node in the operator tree is a produce/order_by (return)
// so stream out results
// generate header // generate header
for (const auto &symbol : output_symbols) header.push_back(symbol.name_); for (const auto &symbol : output_symbols) header.push_back(symbol.name_);
@ -97,4 +86,5 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
summary["type"] = "rw"; summary["type"] = "rw";
stream.Summary(summary); stream.Summary(summary);
} }
}
} // namespace query

View File

@ -17,8 +17,7 @@
} \ } \
} }
namespace query { namespace query::plan {
namespace plan {
void Once::Accept(LogicalOperatorVisitor &visitor) { void Once::Accept(LogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) { if (visitor.PreVisit(*this)) {
@ -193,9 +192,8 @@ ScanAll::ScanAllCursor::ScanAllCursor(const ScanAll &self, GraphDbAccessor &db)
// once this GraphDbAccessor API is available // once this GraphDbAccessor API is available
vertices_(db.vertices()), vertices_(db.vertices()),
vertices_it_(vertices_.end()) { vertices_it_(vertices_.end()) {
if (self.graph_view_ == GraphView::NEW) if (self.graph_view_ == GraphView::NEW) throw utils::NotYetImplemented();
throw utils::NotYetImplemented(); }
}
bool ScanAll::ScanAllCursor::Pull(Frame &frame, bool ScanAll::ScanAllCursor::Pull(Frame &frame,
const SymbolTable &symbol_table) { const SymbolTable &symbol_table) {
@ -499,6 +497,14 @@ std::unique_ptr<Cursor> Produce::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<ProduceCursor>(*this, db); return std::make_unique<ProduceCursor>(*this, db);
} }
std::vector<Symbol> Produce::OutputSymbols(const SymbolTable &symbol_table) {
std::vector<Symbol> symbols;
for (const auto &named_expr : named_expressions_) {
symbols.emplace_back(symbol_table.at(*named_expr));
}
return symbols;
}
const std::vector<NamedExpression *> &Produce::named_expressions() { const std::vector<NamedExpression *> &Produce::named_expressions() {
return named_expressions_; return named_expressions_;
} }
@ -1161,6 +1167,11 @@ std::unique_ptr<Cursor> Skip::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<SkipCursor>(*this, db); return std::make_unique<SkipCursor>(*this, db);
} }
std::vector<Symbol> Skip::OutputSymbols(const SymbolTable &symbol_table) {
  // Skip only forwards whatever its input reports, so that the symbols of a
  // potential Produce further down the chain reach the caller unchanged.
  auto forwarded = input_->OutputSymbols(symbol_table);
  return forwarded;
}
Skip::SkipCursor::SkipCursor(Skip &self, GraphDbAccessor &db) Skip::SkipCursor::SkipCursor(Skip &self, GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
@ -1203,6 +1214,11 @@ std::unique_ptr<Cursor> Limit::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<LimitCursor>(*this, db); return std::make_unique<LimitCursor>(*this, db);
} }
std::vector<Symbol> Limit::OutputSymbols(const SymbolTable &symbol_table) {
  // Limit only forwards whatever its input reports, so that the symbols of a
  // potential Produce further down the chain reach the caller unchanged.
  auto forwarded = input_->OutputSymbols(symbol_table);
  return forwarded;
}
Limit::LimitCursor::LimitCursor(Limit &self, GraphDbAccessor &db) Limit::LimitCursor::LimitCursor(Limit &self, GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
@ -1257,6 +1273,11 @@ std::unique_ptr<Cursor> OrderBy::MakeCursor(GraphDbAccessor &db) {
return std::make_unique<OrderByCursor>(*this, db); return std::make_unique<OrderByCursor>(*this, db);
} }
std::vector<Symbol> OrderBy::OutputSymbols(const SymbolTable &symbol_table) {
  // OrderBy only forwards whatever its input reports, so that the symbols of
  // a potential Produce further down the chain reach the caller unchanged.
  auto forwarded = input_->OutputSymbols(symbol_table);
  return forwarded;
}
OrderBy::OrderByCursor::OrderByCursor(OrderBy &self, GraphDbAccessor &db) OrderBy::OrderByCursor::OrderByCursor(OrderBy &self, GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
@ -1446,5 +1467,4 @@ void Merge::MergeCursor::Reset() {
merge_create_cursor_->Reset(); merge_create_cursor_->Reset();
} }
} // namespace plan } // namespace query::plan
} // namespace query

View File

@ -97,6 +97,21 @@ class LogicalOperator : public ::utils::Visitable<LogicalOperatorVisitor> {
* @param GraphDbAccessor Used to perform operations on the database. * @param GraphDbAccessor Used to perform operations on the database.
*/ */
virtual std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) = 0; virtual std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) = 0;
/** @brief Return the @c Symbol vector where the results will be stored.
 *
 * At the moment, output symbols are generated only by the @c Produce
 * operator. @c Skip, @c Limit and @c OrderBy forward the symbols of
 * @c Produce (when it exists as their input operator). In the future, this
 * method may be changed to return the symbols that are set in the operator
 * itself.
 *
 * @param SymbolTable used to find symbols for expressions.
 * @return std::vector<Symbol> used for results; empty by default.
 */
virtual std::vector<Symbol> OutputSymbols(const SymbolTable &) {
  // Operators that do not override this method report no output symbols.
  return {};
}
virtual ~LogicalOperator() {} virtual ~LogicalOperator() {}
}; };
@ -538,6 +553,7 @@ class Produce : public LogicalOperator {
const std::vector<NamedExpression *> named_expressions); const std::vector<NamedExpression *> named_expressions);
void Accept(LogicalOperatorVisitor &visitor) override; void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
const std::vector<NamedExpression *> &named_expressions(); const std::vector<NamedExpression *> &named_expressions();
private: private:
@ -1009,6 +1025,7 @@ class Skip : public LogicalOperator {
Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression); Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override; void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
private: private:
const std::shared_ptr<LogicalOperator> input_; const std::shared_ptr<LogicalOperator> input_;
@ -1051,6 +1068,7 @@ class Limit : public LogicalOperator {
Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression); Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
void Accept(LogicalOperatorVisitor &visitor) override; void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
private: private:
const std::shared_ptr<LogicalOperator> input_; const std::shared_ptr<LogicalOperator> input_;
@ -1091,6 +1109,7 @@ class OrderBy : public LogicalOperator {
const std::vector<Symbol> &output_symbols); const std::vector<Symbol> &output_symbols);
void Accept(LogicalOperatorVisitor &visitor) override; void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
std::vector<Symbol> OutputSymbols(const SymbolTable &) override;
const auto &output_symbols() const { return output_symbols_; } const auto &output_symbols() const { return output_symbols_; }

View File

@ -328,19 +328,6 @@ class ReturnBodyContext : public TreeVisitorBase {
std::list<bool> has_aggregation_; std::list<bool> has_aggregation_;
}; };
// Wraps `input_op` in a Skip and/or a Limit operator, depending on which of
// the two expressions `body` provides. Skip is applied first, so that Limit
// (when present) ends up on top of it. Returns the last operator of the
// resulting chain, or `input_op` itself when neither expression is set.
auto GenSkipLimit(LogicalOperator *input_op, const ReturnBodyContext &body) {
  LogicalOperator *tail = input_op;
  // SKIP clause always precedes the LIMIT clause.
  if (auto skip_expr = body.skip()) {
    tail = new Skip(std::shared_ptr<LogicalOperator>(tail), skip_expr);
  }
  if (auto limit_expr = body.limit()) {
    tail = new Limit(std::shared_ptr<LogicalOperator>(tail), limit_expr);
  }
  return tail;
}
auto GenReturnBody(LogicalOperator *input_op, bool advance_command, auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate = false) { const ReturnBodyContext &body, bool accumulate = false) {
if (body.distinct()) { if (body.distinct()) {
@ -350,14 +337,6 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
std::vector<Symbol> used_symbols(body.used_symbols().begin(), std::vector<Symbol> used_symbols(body.used_symbols().begin(),
body.used_symbols().end()); body.used_symbols().end());
auto last_op = input_op; auto last_op = input_op;
if (body.aggregations().empty()) {
// In case when we have SKIP/LIMIT and we don't perform aggregations, we
// want to put them before (optional) accumulation. This way we ensure that
// write part of the query will be limited.
// For example, `MATCH (n) SET n.x = n.x + 1 RETURN n LIMIT 1` should
// increment `n.x` only once.
last_op = GenSkipLimit(last_op, body);
}
if (accumulate) { if (accumulate) {
// We only advance the command in Accumulate. This is done for WITH clause, // We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an // when the first part updated the database. RETURN clause may only need an
@ -367,10 +346,8 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
} }
if (!body.aggregations().empty()) { if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it. // When we have aggregation, SKIP/LIMIT should always come after it.
last_op = GenSkipLimit( last_op = new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
new Aggregate(std::shared_ptr<LogicalOperator>(last_op), body.aggregations(), body.group_by(), used_symbols);
body.aggregations(), body.group_by(), used_symbols),
body);
} }
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op), last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
body.named_expressions()); body.named_expressions());
@ -385,6 +362,15 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
last_op = new OrderBy(std::shared_ptr<LogicalOperator>(last_op), last_op = new OrderBy(std::shared_ptr<LogicalOperator>(last_op),
body.order_by(), body.output_symbols()); body.order_by(), body.output_symbols());
} }
// Finally, Skip and Limit must come after OrderBy.
if (body.skip()) {
last_op = new Skip(std::shared_ptr<LogicalOperator>(last_op), body.skip());
}
// Limit is always after Skip.
if (body.limit()) {
last_op =
new Limit(std::shared_ptr<LogicalOperator>(last_op), body.limit());
}
return last_op; return last_op;
} }

View File

@ -500,9 +500,8 @@ TEST(TestLogicalPlanner, MatchReturnSkipLimit) {
auto query = auto query =
QUERY(MATCH(PATTERN(NODE("n"))), QUERY(MATCH(PATTERN(NODE("n"))),
RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
// A simple Skip and Limit combo which should come before Produce. CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectSkip(),
CheckPlan(*query, ExpectScanAll(), ExpectSkip(), ExpectLimit(), ExpectLimit());
ExpectProduce());
} }
TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
@ -515,13 +514,13 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
auto symbol_table = MakeSymbolTable(*query); auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(*query, symbol_table); auto plan = MakeLogicalPlan(*query, symbol_table);
// Since we have a write query, we need to have Accumulate, so Skip and Limit // Since we have a write query, we need to have Accumulate. This is a bit
// need to come before it. This is a bit different than Neo4j, which optimizes // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a
// WITH followed by RETURN as a single RETURN clause. This would cause the // single RETURN clause and then moves Skip and Limit before Accumulate. This
// Limit operator to also appear before Accumulate, thus changing the // causes different behaviour. A newer version of Neo4j does the same thing as
// behaviour. We've decided to diverge from Neo4j here, for consistency sake. // us here (but who knows if they change it again).
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectSkip(), acc, CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce(),
ExpectProduce(), ExpectLimit(), ExpectProduce()); ExpectSkip(), ExpectProduce(), ExpectLimit());
} }
TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
@ -538,10 +537,8 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
auto aggr = ExpectAggregate({sum}, {}); auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(*query, symbol_table); auto plan = MakeLogicalPlan(*query, symbol_table);
// We have a write query and aggregation, therefore Skip and Limit should come CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(),
// after Accumulate and Aggregate. ExpectSkip(), ExpectLimit());
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectSkip(),
ExpectLimit(), ExpectProduce());
} }
TEST(TestLogicalPlanner, MatchReturnOrderBy) { TEST(TestLogicalPlanner, MatchReturnOrderBy) {