diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index cc9c19351..353431151 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -894,6 +894,7 @@ class Merge : public Clause { for (auto &set : on_create_) { set->Accept(visitor); } + visitor.PostVisit(*this); } } @@ -905,6 +906,28 @@ class Merge : public Clause { Merge(int uid) : Clause(uid) {} }; +class Unwind : public Clause { + friend class AstTreeStorage; + + public: + void Accept(TreeVisitorBase &visitor) override { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + named_expression_->Accept(visitor); + visitor.PostVisit(*this); + } + } + + NamedExpression *const named_expression_ = nullptr; + + protected: + Unwind(int uid, NamedExpression *named_expression) + : Clause(uid), named_expression_(named_expression) { + debug_assert(named_expression, + "Unwind cannot take nullptr for named_expression") + } +}; + // It would be better to call this AstTree, but we already have a class Tree, // which could be renamed to Node or AstTreeNode, but we also have a class // called NodeAtom... diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index bbcd278ff..f6f0c64a0 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -46,6 +46,7 @@ class SetLabels; class RemoveProperty; class RemoveLabels; class Merge; +class Unwind; using TreeVisitorBase = ::utils::Visitor< Query, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, @@ -55,5 +56,7 @@ using TreeVisitorBase = ::utils::Visitor< UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, PrimitiveLiteral, ListLiteral, PropertyLookup, Aggregation, Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, - SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge>; -} + SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, + Unwind>; + +} // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 291a58f55..a966ef7a7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -42,7 +42,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( bool has_update = false; bool has_return = false; for (Clause *clause : query_->clauses_) { - if (dynamic_cast(clause)) { + if (dynamic_cast(clause) || dynamic_cast(clause)) { if (has_update || has_return) { throw SemanticException("Match can't be after return or update clause"); } @@ -121,6 +121,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(CypherParser::ClauseContext *ctx) { if (ctx->merge()) { return static_cast(ctx->merge()->accept(this).as()); } + if (ctx->unwind()) { + return static_cast(ctx->unwind()->accept(this).as()); + } // TODO: implement other clauses. throw utils::NotYetImplemented(); return 0; @@ -965,4 +968,12 @@ antlrcpp::Any CypherMainVisitor::visitMerge(CypherParser::MergeContext *ctx) { return merge; } +antlrcpp::Any CypherMainVisitor::visitUnwind(CypherParser::UnwindContext *ctx) { + auto *named_expr = storage_.Create(); + named_expr->expression_ = ctx->expression()->accept(this); + named_expr->name_ = + std::string(ctx->variable()->accept(this).as()); + return storage_.Create(named_expr); +} + } // namespace query::frontend diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index e62b4f9b6..f82ed9aca 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -480,6 +480,11 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { */ antlrcpp::Any visitMerge(CypherParser::MergeContext *ctx) override; + /** + * @return Unwind* + */ + antlrcpp::Any visitUnwind(CypherParser::UnwindContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index e0fc5612e..923818566 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -970,6 +970,8 @@ TEST(CypherMainVisitorTest, ClausesOrdering) { ASSERT_THROW(AstGenerator("RETURN 1 MERGE (n)"), SemanticException); ASSERT_THROW(AstGenerator("RETURN 1 WITH n AS m RETURN 1"), SemanticException); + ASSERT_THROW(AstGenerator("RETURN 1 AS n UNWIND n AS x RETURN x"), + SemanticException); AstGenerator("CREATE (n)"); ASSERT_THROW(AstGenerator("SET n:x MATCH (n) RETURN n"), SemanticException); @@ -988,6 +990,12 @@ TEST(CypherMainVisitorTest, ClausesOrdering) { AstGenerator("WITH 1 AS n RETURN n"); AstGenerator("WITH 1 AS n SET n += m"); AstGenerator("WITH 1 AS n MATCH (n) RETURN n"); + + ASSERT_THROW(AstGenerator("UNWIND [1,2,3] AS x"), SemanticException); + ASSERT_THROW(AstGenerator("CREATE (n) UNWIND [1,2,3] AS x RETURN x"), + SemanticException); + AstGenerator("UNWIND [1,2,3] AS x CREATE (n) RETURN x"); + AstGenerator("CREATE (n) WITH n UNWIND [1,2,3] AS x RETURN x"); } TEST(CypherMainVisitorTest, Merge) { @@ -1005,4 +1013,24 @@ TEST(CypherMainVisitorTest, Merge) { ASSERT_EQ(merge->on_create_.size(), 1U); EXPECT_TRUE(dynamic_cast(merge->on_create_[0])); } + +TEST(CypherMainVisitorTest, Unwind) { + AstGenerator ast_generator("UNWIND [1,2,3] AS elem RETURN elem"); + auto *query = ast_generator.query_; + ASSERT_EQ(query->clauses_.size(), 2U); + auto *unwind = dynamic_cast(query->clauses_[0]); + ASSERT_TRUE(unwind); + auto *ret = dynamic_cast(query->clauses_[1]); + EXPECT_TRUE(ret); + ASSERT_TRUE(unwind->named_expression_); + EXPECT_EQ(unwind->named_expression_->name_, "elem"); + auto *expr = unwind->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast(expr)); +} + +TEST(CypherMainVisitorTest, UnwindWithoutAsError) { + EXPECT_THROW(AstGenerator("UNWIND [1,2,3] RETURN 42"), SyntaxException); +} + } diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index 84f0ebbae..5e9bacab2 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -490,11 +490,11 @@ TEST(QueryPlan, Unwind) { std::vector{"bla"}}); auto x = symbol_table.CreateSymbol("x"); - auto unwind_0 = std::make_shared(nullptr, input_expr, x); + auto unwind_0 = std::make_shared(nullptr, input_expr, x); auto x_expr = IDENT("x"); symbol_table[*x_expr] = x; auto y = symbol_table.CreateSymbol("y"); - auto unwind_1 = std::make_shared(unwind_0, x_expr, y); + auto unwind_1 = std::make_shared(unwind_0, x_expr, y); auto x_ne = NEXPR("x", x_expr); symbol_table[*x_ne] = symbol_table.CreateSymbol("x_ne");