From b778c54d74c97008a0654d916fb29eabfb42db55 Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Tue, 18 Apr 2017 13:45:52 +0200 Subject: [PATCH] Reset bound symbols after planning WITH Reviewers: florijan, mislav.bradac Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D290 --- src/query/plan/planner.cpp | 11 +++++++++-- tests/unit/query_planner.cpp | 13 +++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 477bb5298..d065f8ca7 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -303,7 +303,8 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command, } auto GenWith(With &with, LogicalOperator *input_op, - const SymbolTable &symbol_table, bool is_write) { + const SymbolTable &symbol_table, bool is_write, + std::unordered_set &bound_symbols) { // WITH clause is Accumulate/Aggregate (advance_command) + Produce and // optional Filter. if (with.distinct_) { @@ -318,6 +319,11 @@ auto GenWith(With &with, LogicalOperator *input_op, LogicalOperator *last_op = GenReturnBody(input_op, advance_command, with.named_expressions_, symbol_table, accumulate); + // Reset bound symbols, so that only those in WITH are exposed. + bound_symbols.clear(); + for (auto &named_expr : with.named_expressions_) { + BindSymbol(bound_symbols, symbol_table.at(*named_expr)); + } if (with.where_) { last_op = new Filter(std::shared_ptr(last_op), with.where_->expression_); @@ -392,7 +398,8 @@ std::unique_ptr MakeLogicalPlan( } else if (auto *ret = dynamic_cast(clause)) { input_op = GenReturn(*ret, input_op, symbol_table, is_write); } else if (auto *with = dynamic_cast(clause)) { - input_op = GenWith(*with, input_op, symbol_table, is_write); + input_op = + GenWith(*with, input_op, symbol_table, is_write, bound_symbols); // WITH clause advances the command, so reset the flag. is_write = false; } else if (auto *op = HandleWriteClause(clause, input_op, symbol_table, diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index c999760e4..077c11a48 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -418,4 +418,17 @@ TEST(TestLogicalPlanner, CreateWithSum) { ExpectProduce()); } +TEST(TestLogicalPlanner, MatchWithCreate) { + // Test MATCH (n) WITH n AS a CREATE (a) -[r :r]-> (b) + Dbms dbms; + auto dba = dbms.active(); + auto r_type = dba->edge_type("r"); + AstTreeStorage storage; + auto query = + QUERY(MATCH(PATTERN(NODE("n"))), WITH(IDENT("n"), AS("a")), + CREATE(PATTERN(NODE("a"), EDGE("r", r_type, Direction::RIGHT), + NODE("b")))); + CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); +} + } // namespace