From 55dc08fc303e5979fe924712770c99e43b1d63e2 Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Thu, 20 Apr 2017 11:20:20 +0200 Subject: [PATCH] Plan Skip and Limit operators Summary: Support SKIP and LIMIT macros in tests. Test planning Skip and Limit. Prevent variables in SKIP and LIMIT. Reviewers: mislav.bradac, florijan Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D296 --- .../frontend/semantic/symbol_generator.cpp | 29 +++++- .../frontend/semantic/symbol_generator.hpp | 7 +- src/query/plan/planner.cpp | 93 +++++++++++++------ tests/unit/query_common.hpp | 41 ++++++-- tests/unit/query_planner.cpp | 46 +++++++++ tests/unit/query_semantic.cpp | 30 ++++++ 6 files changed, 207 insertions(+), 39 deletions(-) diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 9b5d2c82e..4f28efc3c 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -44,17 +44,35 @@ void SymbolGenerator::BindNamedExpressionSymbols( } } +void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) { + if (skip) { + scope_.in_skip = true; + skip->Accept(*this); + scope_.in_skip = false; + } + if (limit) { + scope_.in_limit = true; + limit->Accept(*this); + scope_.in_limit = false; + } +} + // Clauses void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; } void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; } -void SymbolGenerator::Visit(Return &ret) { scope_.in_return = true; } -void SymbolGenerator::PostVisit(Return &ret) { +bool SymbolGenerator::PreVisit(Return &ret) { + scope_.in_return = true; + for (auto &expr : ret.body_.named_expressions) { + expr->Accept(*this); + } // Named expressions establish bindings for expressions which come after // return, but not for the expressions contained inside. 
BindNamedExpressionSymbols(ret.body_.named_expressions); + VisitSkipAndLimit(ret.body_.skip, ret.body_.limit); scope_.in_return = false; + return false; // We handled the traversal ourselves. } bool SymbolGenerator::PreVisit(With &with) { @@ -68,6 +86,7 @@ bool SymbolGenerator::PreVisit(With &with) { // be visible inside named expressions themselves. scope_.symbols.clear(); BindNamedExpressionSymbols(with.body_.named_expressions); + VisitSkipAndLimit(with.body_.skip, with.body_.limit); if (with.where_) with.where_->Accept(*this); return false; // We handled the traversal ourselves. } @@ -75,6 +94,10 @@ bool SymbolGenerator::PreVisit(With &with) { // Expressions void SymbolGenerator::Visit(Identifier &ident) { + if (scope_.in_skip || scope_.in_limit) { + throw SemanticException("Variables are not allowed in {}", + scope_.in_skip ? "SKIP" : "LIMIT"); + } Symbol symbol; if (scope_.in_pattern && !scope_.in_property_map) { // Patterns can bind new symbols or reference already bound. But there @@ -89,7 +112,7 @@ void SymbolGenerator::Visit(Identifier &ident) { // `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r`, which would // usually raise redeclaration of `r`. 
if ((scope_.in_create_node || scope_.in_create_edge) && - HasSymbol(ident.name_)) { + HasSymbol(ident.name_)) { // Case 1) throw RedeclareVariableError(ident.name_); } diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 8e3df829a..7f410c7c1 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -27,8 +27,7 @@ class SymbolGenerator : public TreeVisitorBase { // Clauses void Visit(Create &) override; void PostVisit(Create &) override; - void Visit(Return &) override; - void PostVisit(Return &) override; + bool PreVisit(Return &) override; bool PreVisit(With &) override; // Expressions @@ -61,6 +60,8 @@ class SymbolGenerator : public TreeVisitorBase { bool in_aggregation{false}; bool in_return{false}; bool in_with{false}; + bool in_skip{false}; + bool in_limit{false}; std::map symbols; }; @@ -79,6 +80,8 @@ class SymbolGenerator : public TreeVisitorBase { void BindNamedExpressionSymbols( const std::vector &named_expressions); + void VisitSkipAndLimit(Expression *skip, Expression *limit); + SymbolTable &symbol_table_; Scope scope_; }; diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index b9facec61..2104a1671 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -177,14 +177,25 @@ auto GenMatch(Match &match, LogicalOperator *input_op, return last_op; } -// Ast tree visitor which collects the context for a return body. The return -// body are the named expressions found in WITH and RETURN clauses. The -// collected context consists of used symbols, aggregations and group by named -// expressions. +// Ast tree visitor which collects the context for a return body. 
+// The return body of WITH and RETURN clauses consists of: +// +// * named expressions (used to produce results); +// * flag whether the results need to be DISTINCT; +// * optional SKIP expression; +// * optional LIMIT expression and +// * optional ORDER BY expression. +// +// In addition to the above, we collect information on used symbols, +// aggregations and expressions used for group by. class ReturnBodyContext : public TreeVisitorBase { public: - ReturnBodyContext(const SymbolTable &symbol_table) - : symbol_table_(symbol_table) {} + ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table) + : body_(body), symbol_table_(symbol_table) { + for (auto &named_expr : body_.named_expressions) { + named_expr->Accept(*this); + } + } using TreeVisitorBase::PreVisit; using TreeVisitorBase::Visit; @@ -249,6 +260,14 @@ class ReturnBodyContext : public TreeVisitorBase { has_aggregation_.pop_back(); } + // If true, results need to be distinct. + bool distinct() const { return body_.distinct; } + // Named expressions which are used to produce results. + const auto &named_expressions() const { return body_.named_expressions; } + // Optional expression which determines how many results to skip. + auto *skip() const { return body_.skip; } + // Optional expression which determines how many results to produce. + auto *limit() const { return body_.limit; } // Set of symbols used inside the visited expressions outside of aggregation // expression. const auto &symbols() const { return symbols_; } @@ -267,6 +286,7 @@ class ReturnBodyContext : public TreeVisitorBase { } }; + const ReturnBody &body_; const SymbolTable &symbol_table_; std::unordered_set symbols_; std::vector aggregations_; @@ -275,17 +295,36 @@ class ReturnBodyContext : public TreeVisitorBase { std::list has_aggregation_; }; +auto GenSkipLimit(LogicalOperator *input_op, const ReturnBodyContext &body) { + auto last_op = input_op; + // SKIP is always before LIMIT clause. 
+ if (body.skip()) { + last_op = new Skip(std::shared_ptr(last_op), body.skip()); + } + if (body.limit()) { + last_op = + new Limit(std::shared_ptr(last_op), body.limit()); + } + return last_op; +} + auto GenReturnBody(LogicalOperator *input_op, bool advance_command, - const std::vector &named_expressions, - const SymbolTable &symbol_table, bool accumulate = false) { - ReturnBodyContext context(symbol_table); - // Generate context for all named expressions. - for (auto &named_expr : named_expressions) { - named_expr->Accept(context); + const ReturnBodyContext &body, bool accumulate = false) { + if (body.distinct()) { + // TODO: Plan with distinct, when operator available. + throw utils::NotYetImplemented(); } auto symbols = - std::vector(context.symbols().begin(), context.symbols().end()); + std::vector(body.symbols().begin(), body.symbols().end()); auto last_op = input_op; + if (body.aggregations().empty()) { + // In case when we have SKIP/LIMIT and we don't perform aggregations, we + // want to put them before (optional) accumulation. This way we ensure that + // write part of the query will be limited. + // For example, `MATCH (n) SET n.x = n.x + 1 RETURN n LIMIT 1` should + // increment `n.x` only once. + last_op = GenSkipLimit(last_op, body); + } if (accumulate) { // We only advance the command in Accumulate. This is done for WITH clause, // when the first part updated the database. RETURN clause may only need an @@ -293,32 +332,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command, last_op = new Accumulate(std::shared_ptr(last_op), symbols, advance_command); } - if (!context.aggregations().empty()) { - last_op = + if (!body.aggregations().empty()) { + // When we have aggregation, SKIP/LIMIT should always come after it. 
+ last_op = GenSkipLimit( new Aggregate(std::shared_ptr(last_op), - context.aggregations(), context.group_by(), symbols); + body.aggregations(), body.group_by(), symbols), + body); } return new Produce(std::shared_ptr(last_op), - named_expressions); + body.named_expressions()); } auto GenWith(With &with, LogicalOperator *input_op, const SymbolTable &symbol_table, bool is_write, std::unordered_set &bound_symbols) { // WITH clause is Accumulate/Aggregate (advance_command) + Produce and - // optional Filter. - if (with.body_.distinct) { - // TODO: Plan distinct with, when operator available. - throw utils::NotYetImplemented(); - } - // In case of update and aggregation, we want to accumulate first, so that - // when aggregating, we get the latest results. Similar to RETURN clause. + // optional Filter. In case of update and aggregation, we want to accumulate + // first, so that when aggregating, we get the latest results. Similar to + // RETURN clause. bool accumulate = is_write; // No need to advance the command if we only performed reads. bool advance_command = is_write; + ReturnBodyContext body(with.body_, symbol_table); LogicalOperator *last_op = - GenReturnBody(input_op, advance_command, with.body_.named_expressions, - symbol_table, accumulate); + GenReturnBody(input_op, advance_command, body, accumulate); // Reset bound symbols, so that only those in WITH are exposed. bound_symbols.clear(); for (auto &named_expr : with.body_.named_expressions) { @@ -341,8 +378,8 @@ auto GenReturn(Return &ret, LogicalOperator *input_op, // value is the same, final result of 'k' increments. bool accumulate = is_write; bool advance_command = false; - return GenReturnBody(input_op, advance_command, ret.body_.named_expressions, - symbol_table, accumulate); + ReturnBodyContext body(ret.body_, symbol_table); + return GenReturnBody(input_op, advance_command, body, accumulate); } // Generate an operator for a clause which writes to the database. 
If the clause diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index f3a2a211c..b27b092ad 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -4,6 +4,15 @@ namespace query { namespace test_common { +// Custom types for SKIP and LIMIT expressions, so that they can be used to +// resolve overloaded function calls. +struct Skip { + query::Expression *expression = nullptr; +}; +struct Limit { + query::Expression *expression = nullptr; +}; + /// /// Create PropertyLookup with given name and property. /// @@ -115,6 +124,15 @@ auto GetReturn(Return *ret, NamedExpression *named_expr) { ret->body_.named_expressions.emplace_back(named_expr); return ret; } +auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) { + ret->body_.skip = skip.expression; + ret->body_.limit = limit.expression; + return ret; +} +auto GetReturn(Return *ret, Limit limit) { + ret->body_.limit = limit.expression; + return ret; +} auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) { // This overload supports `RETURN(expr, AS(name))` construct, since // NamedExpression does not inherit Expression. @@ -124,18 +142,18 @@ auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) { } template auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr, - T *... rest) { + T... rest) { named_expr->expression_ = expr; ret->body_.named_expressions.emplace_back(named_expr); return GetReturn(ret, rest...); } template -auto GetReturn(Return *ret, NamedExpression *named_expr, T *... rest) { +auto GetReturn(Return *ret, NamedExpression *named_expr, T... rest) { ret->body_.named_expressions.emplace_back(named_expr); return GetReturn(ret, rest...); } template -auto GetReturn(AstTreeStorage &storage, T *... exprs) { +auto GetReturn(AstTreeStorage &storage, T... 
exprs) { auto ret = storage.Create(); return GetReturn(ret, exprs...); } @@ -147,6 +165,15 @@ auto GetWith(With *with, NamedExpression *named_expr) { with->body_.named_expressions.emplace_back(named_expr); return with; } +auto GetWith(With *with, Skip skip, Limit limit = {}) { + with->body_.skip = skip.expression; + with->body_.limit = limit.expression; + return with; +} +auto GetWith(With *with, Limit limit) { + with->body_.limit = limit.expression; + return with; +} auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) { // This overload supports `RETURN(expr, AS(name))` construct, since // NamedExpression does not inherit Expression. @@ -156,18 +183,18 @@ auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) { } template auto GetWith(With *with, Expression *expr, NamedExpression *named_expr, - T *... rest) { + T... rest) { named_expr->expression_ = expr; with->body_.named_expressions.emplace_back(named_expr); return GetWith(with, rest...); } template -auto GetWith(With *with, NamedExpression *named_expr, T *... rest) { +auto GetWith(With *with, NamedExpression *named_expr, T... rest) { with->body_.named_expressions.emplace_back(named_expr); return GetWith(with, rest...); } template -auto GetWith(AstTreeStorage &storage, T *... exprs) { +auto GetWith(AstTreeStorage &storage, T... exprs) { auto with = storage.Create(); return GetWith(with, exprs...); } @@ -261,6 +288,8 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name, #define AS(name) storage.Create((name)) #define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__) #define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__) +#define SKIP(expr) query::test_common::Skip{(expr)} +#define LIMIT(expr) query::test_common::Limit{(expr)} #define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__}) #define DETACH_DELETE(...) 
\ query::test_common::GetDelete(storage, {__VA_ARGS__}, true) diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 077c11a48..d7c165cc0 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -58,6 +58,8 @@ template using ExpectExpandUniquenessFilter = OpChecker>; using ExpectAccumulate = OpChecker; +using ExpectSkip = OpChecker; +using ExpectLimit = OpChecker; class ExpectAggregate : public OpChecker { public: @@ -115,6 +117,8 @@ class PlanChecker : public LogicalOperatorVisitor { void Visit(ExpandUniquenessFilter &op) override { CheckOp(op); } void Visit(Accumulate &op) override { CheckOp(op); } void Visit(Aggregate &op) override { CheckOp(op); } + void Visit(Skip &op) override { CheckOp(op); } + void Visit(Limit &op) override { CheckOp(op); } std::list checkers_; @@ -431,4 +435,46 @@ TEST(TestLogicalPlanner, MatchWithCreate) { CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); } +TEST(TestLogicalPlanner, MatchReturnSkipLimit) { + // Test MATCH (n) RETURN n SKIP 2 LIMIT 1 + AstTreeStorage storage; + auto query = + QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); + // A simple Skip and Limit combo which should come before Produce. + CheckPlan(*query, ExpectScanAll(), ExpectSkip(), ExpectLimit(), + ExpectProduce()); +} + +TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { + // Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1 + AstTreeStorage storage; + auto query = QUERY(CREATE(PATTERN(NODE("n"))), + WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))), + RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1)))); + // Since we have a write query, we need to have Accumulate, so Skip and Limit + // need to come before it. This is a bit different than Neo4j, which optimizes + // WITH followed by RETURN as a single RETURN clause. This would cause the + // Limit operator to also appear before Accumulate, thus changing the + // behaviour. 
We've decided to diverge from Neo4j here, for consistency sake. + CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(), + ExpectProduce(), ExpectLimit(), ExpectProduce()); +} + +TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { + // Test CREATE (n) RETURN SUM(n.prop) AS s SKIP 2 LIMIT 1 + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto query = QUERY(CREATE(PATTERN(NODE("n"))), + RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); + auto aggr = ExpectAggregate({sum}, {}); + // We have a write query and aggregation, therefore Skip and Limit should come + // after Accumulate and Aggregate. + CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(), + ExpectLimit(), ExpectProduce()); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 67a2c3b3a..9abccd5c5 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -555,4 +555,34 @@ TEST(TestSymbolGenerator, SameResults) { } } +TEST(TestSymbolGenerator, SkipLimitIdentifier) { + // Test MATCH (old) WITH old AS new SKIP old + { + AstTreeStorage storage; + auto query = QUERY(MATCH(PATTERN(NODE("old"))), + WITH(IDENT("old"), AS("new"), SKIP(IDENT("old")))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); + } + // Test MATCH (old) WITH old AS new SKIP new + { + AstTreeStorage storage; + auto query = QUERY(MATCH(PATTERN(NODE("old"))), + WITH(IDENT("old"), AS("new"), SKIP(IDENT("new")))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); + } + // Test MATCH (n) RETURN n AS n LIMIT n + { + AstTreeStorage storage; + auto query = QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(IDENT("n"), AS("n"), LIMIT(IDENT("n")))); + 
SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); + } +} + }