diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index b7105b9e3..3431dc896 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -851,6 +851,31 @@ class RemoveLabels : public Clause { : Clause(uid), identifier_(identifier), labels_(labels) {} }; +class Merge : public Clause { + friend class AstTreeStorage; + + public: + void Accept(TreeVisitorBase &visitor) override { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + pattern_->Accept(visitor); + for (auto &set : on_match_) { + set->Accept(visitor); + } + for (auto &set : on_create_) { + set->Accept(visitor); + } + } + } + + Pattern *pattern_ = nullptr; + std::vector on_match_; + std::vector on_create_; + + protected: + Merge(int uid) : Clause(uid) {} +}; + // It would be better to call this AstTree, but we already have a class Tree, // which could be renamed to Node or AstTreeNode, but we also have a class // called NodeAtom... diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index e69d9ba47..5eb2ab325 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -44,6 +44,7 @@ class SetProperties; class SetLabels; class RemoveProperty; class RemoveLabels; +class Merge; using TreeVisitorBase = ::utils::Visitor< Query, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, @@ -53,5 +54,5 @@ using TreeVisitorBase = ::utils::Visitor< UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, Literal, PropertyLookup, Aggregation, Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, - RemoveProperty, RemoveLabels>; + RemoveProperty, RemoveLabels, Merge>; } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index a28fba1a1..4d2b345e5 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -51,7 +51,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( dynamic_cast(clause) || dynamic_cast(clause) || dynamic_cast(clause) || - dynamic_cast(clause)) { + dynamic_cast(clause) || + dynamic_cast(clause)) { if (has_return) { throw SemanticException("Update clauses can't be after return"); } @@ -116,6 +117,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(CypherParser::ClauseContext *ctx) { if (ctx->with()) { return static_cast(ctx->with()->accept(this).as()); } + if (ctx->merge()) { + return static_cast(ctx->merge()->accept(this).as()); + } // TODO: implement other clauses. throw utils::NotYetImplemented(); return 0; @@ -943,4 +947,20 @@ antlrcpp::Any CypherMainVisitor::visitWith(CypherParser::WithContext *ctx) { } return with; } + +antlrcpp::Any CypherMainVisitor::visitMerge(CypherParser::MergeContext *ctx) { + auto *merge = storage_.Create(); + merge->pattern_ = ctx->patternPart()->accept(this); + for (auto &merge_action : ctx->mergeAction()) { + auto set = merge_action->set()->accept(this).as>(); + if (merge_action->MATCH()) { + merge->on_match_.insert(merge->on_match_.end(), set.begin(), set.end()); + } else { + debug_assert(merge_action->CREATE(), "Expected ON MATCH or ON CREATE"); + merge->on_create_.insert(merge->on_create_.end(), set.begin(), set.end()); + } + } + return merge; } + +} // namespace query::frontend diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 2c532f789..865c31494 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -470,6 +470,11 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { */ antlrcpp::Any visitWith(CypherParser::WithContext *ctx) override; + /** + * @return Merge* + */ + antlrcpp::Any visitMerge(CypherParser::MergeContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index a32eb5a29..950dce860 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -924,6 +924,7 @@ TEST(CypherMainVisitorTest, ClausesOrdering) { ASSERT_THROW(AstGenerator("RETURN 1 RETURN 1"), SemanticException); ASSERT_THROW(AstGenerator("RETURN 1 MATCH (n) RETURN n"), SemanticException); ASSERT_THROW(AstGenerator("RETURN 1 DELETE n"), SemanticException); + ASSERT_THROW(AstGenerator("RETURN 1 MERGE (n)"), SemanticException); ASSERT_THROW(AstGenerator("RETURN 1 WITH n AS m RETURN 1"), SemanticException); @@ -945,4 +946,21 @@ TEST(CypherMainVisitorTest, ClausesOrdering) { AstGenerator("WITH 1 AS n SET n += m"); AstGenerator("WITH 1 AS n MATCH (n) RETURN n"); } + +TEST(CypherMainVisitorTest, Merge) { + AstGenerator ast_generator( + "MERGE (a) -[:r]- (b) ON MATCH SET a.x = b.x " + "ON CREATE SET b :label ON MATCH SET b = a"); + auto *query = ast_generator.query_; + ASSERT_EQ(query->clauses_.size(), 1U); + auto *merge = dynamic_cast(query->clauses_[0]); + ASSERT_TRUE(merge); + EXPECT_TRUE(dynamic_cast(merge->pattern_)); + ASSERT_EQ(merge->on_match_.size(), 2U); + EXPECT_TRUE(dynamic_cast(merge->on_match_[0])); + EXPECT_TRUE(dynamic_cast(merge->on_match_[1])); + ASSERT_EQ(merge->on_create_.size(), 1U); + EXPECT_TRUE(dynamic_cast(merge->on_create_[0])); +} + }