diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 79c8b0db0..79561507f 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -114,6 +114,14 @@ void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; } void SymbolGenerator::Visit(Merge &) { scope_.in_merge = true; } void SymbolGenerator::PostVisit(Merge &) { scope_.in_merge = false; } +void SymbolGenerator::PostVisit(Unwind &unwind) { + const auto &name = unwind.named_expression_->name_; + if (HasSymbol(name)) { + throw RedeclareVariableError(name); + } + symbol_table_[*unwind.named_expression_] = CreateSymbol(name); +} + // Expressions void SymbolGenerator::Visit(Identifier &ident) { diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 977e6c955..4adfd9ffb 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -33,6 +33,7 @@ class SymbolGenerator : public TreeVisitorBase { void PostVisit(Where &) override; void Visit(Merge &) override; void PostVisit(Merge &) override; + void PostVisit(Unwind &) override; // Expressions void Visit(Identifier &) override; diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 515fa9122..820a57f9d 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -533,6 +533,10 @@ std::unique_ptr MakeLogicalPlan( bound_symbols)) { is_write = true; input_op = op; + } else if (auto *unwind = dynamic_cast(clause)) { + input_op = new plan::Unwind(std::shared_ptr(input_op), + unwind->named_expression_->expression_, + symbol_table.at(*unwind->named_expression_)); } else { throw utils::NotYetImplemented(); } diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 4ad14b4ac..730a2b72c 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -252,6 +252,17 @@ auto GetWith(AstTreeStorage &storage, T... exprs) { return with; } +/// +/// Create the UNWIND clause with given named expression. +/// +auto GetUnwind(AstTreeStorage &storage, NamedExpression *named_expr) { + return storage.Create(named_expr); +} +auto GetUnwind(AstTreeStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + /// /// Create the delete clause with given named expressions. /// @@ -357,6 +368,9 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, {__VA_ARGS__}) #define IDENT(name) storage.Create((name)) #define LITERAL(val) storage.Create((val)) +#define LIST(...) \ + storage.Create( \ + std::vector{__VA_ARGS__}) #define PROPERTY_LOOKUP(...) \ query::test_common::GetPropertyLookup(storage, __VA_ARGS__) #define NEXPR(name, expr) storage.Create((name), (expr)) @@ -366,6 +380,7 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, #define AS(name) storage.Create((name)) #define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__) #define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__) +#define UNWIND(...) query::test_common::GetUnwind(storage, __VA_ARGS__) #define ORDER_BY(...) query::test_common::GetOrderBy(__VA_ARGS__) #define SKIP(expr) \ query::test_common::Skip { (expr) } diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index dfd941c6b..2a7f9a089 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -72,6 +72,7 @@ class PlanChecker : public LogicalOperatorVisitor { op.input()->Accept(*this); return false; } + void Visit(Unwind &op) override { CheckOp(op); } std::list checkers_; @@ -117,6 +118,7 @@ using ExpectExpandUniquenessFilter = using ExpectSkip = OpChecker; using ExpectLimit = OpChecker; using ExpectOrderBy = OpChecker; +using ExpectUnwind = OpChecker; class ExpectAccumulate : public OpChecker { public: @@ -669,4 +671,13 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { CheckPlan(*query, ExpectScanAll(), ExpectOptional(optional), ExpectProduce()); } +TEST(TestLogicalPlanner, MatchUnwindReturn) { + // Test MATCH (n) UNWIND [1,2,3] AS x RETURN n AS n, x AS x + AstTreeStorage storage; + auto query = QUERY(MATCH(PATTERN(NODE("n"))), + UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), + RETURN(IDENT("n"), AS("n"), IDENT("x"), AS("x"))); + CheckPlan(*query, ExpectScanAll(), ExpectUnwind(), ExpectProduce()); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 2d37f45f9..14c4857b4 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -724,4 +724,41 @@ TEST(TestSymbolGenerator, MergeOnMatchOnCreate) { EXPECT_EQ(m, symbol_table.at(*m_prop->expression_)); } +TEST(TestSymbolGenerator, WithUnwindRedeclareReturn) { + // Test WITH [1, 2] AS list UNWIND list AS list RETURN list + AstTreeStorage storage; + auto query = QUERY(WITH(LIST(LITERAL(1), LITERAL(2)), AS("list")), + UNWIND(IDENT("list"), AS("list")), + RETURN(IDENT("list"), AS("list"))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), RedeclareVariableError); +} + +TEST(TestSymbolGenerator, WithUnwindReturn) { + // WITH [1, 2] AS list UNWIND list AS elem RETURN list AS list, elem AS elem + AstTreeStorage storage; + auto with_as_list = AS("list"); + auto unwind = UNWIND(IDENT("list"), AS("elem")); + auto ret_list = IDENT("list"); + auto ret_as_list = AS("list"); + auto ret_elem = IDENT("elem"); + auto ret_as_elem = AS("elem"); + auto query = QUERY(WITH(LIST(LITERAL(1), LITERAL(2)), with_as_list), unwind, + RETURN(ret_list, ret_as_list, ret_elem, ret_as_elem)); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for: `list`, `elem`, `AS list`, `AS elem` + EXPECT_EQ(symbol_table.max_position(), 4); + const auto &list = symbol_table.at(*with_as_list); + EXPECT_EQ(list, symbol_table.at(*unwind->named_expression_->expression_)); + const auto &elem = symbol_table.at(*unwind->named_expression_); + EXPECT_NE(list, elem); + EXPECT_EQ(list, symbol_table.at(*ret_list)); + EXPECT_NE(list, symbol_table.at(*ret_as_list)); + EXPECT_EQ(elem, symbol_table.at(*ret_elem)); + EXPECT_NE(elem, symbol_table.at(*ret_as_elem)); +} + } // namespace