diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 097864d7a..468990547 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -878,8 +878,10 @@ class EdgeAtom : public PatternAtom { edge_atom->properties_[property.first] = property.second->Clone(storage); } edge_atom->has_range_ = has_range_; - edge_atom->lower_bound_ = lower_bound_ ? lower_bound_->Clone(storage) : nullptr; - edge_atom->upper_bound_ = upper_bound_ ? upper_bound_->Clone(storage) : nullptr; + edge_atom->lower_bound_ = + lower_bound_ ? lower_bound_->Clone(storage) : nullptr; + edge_atom->upper_bound_ = + upper_bound_ ? upper_bound_->Clone(storage) : nullptr; return edge_atom; } @@ -897,6 +899,49 @@ class EdgeAtom : public PatternAtom { : PatternAtom(uid, identifier), direction_(direction) {} }; +class BreadthFirstAtom : public EdgeAtom { + // TODO: Reconsider inheriting from EdgeAtom, since only `direction_` is used. + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && + traversed_edge_identifier_->Accept(visitor) && + next_node_identifier_->Accept(visitor) && + filter_expression_->Accept(visitor) && max_depth_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + BreadthFirstAtom *Clone(AstTreeStorage &storage) const override { + return storage.Create( + identifier_->Clone(storage), direction_, + traversed_edge_identifier_->Clone(storage), + next_node_identifier_->Clone(storage), + filter_expression_->Clone(storage), max_depth_->Clone(storage)); + } + + Identifier *traversed_edge_identifier_ = nullptr; + Identifier *next_node_identifier_ = nullptr; + // Expression which evaluates to true in order to continue the BFS. + Expression *filter_expression_ = nullptr; + Expression *max_depth_ = nullptr; + + protected: + using EdgeAtom::EdgeAtom; + BreadthFirstAtom(int uid, Identifier *identifier, Direction direction, + Identifier *traversed_edge_identifier, + Identifier *next_node_identifier, + Expression *filter_expression, Expression *max_depth) + : EdgeAtom(uid, identifier, direction), + traversed_edge_identifier_(traversed_edge_identifier), + next_node_identifier_(next_node_identifier), + filter_expression_(filter_expression), + max_depth_(max_depth) {} +}; + class Clause : public Tree { friend class AstTreeStorage; diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index a73db0202..68f982a71 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -21,6 +21,7 @@ class With; class Pattern; class NodeAtom; class EdgeAtom; +class BreadthFirstAtom; class PrimitiveLiteral; class ListLiteral; class OrOperator; @@ -65,8 +66,8 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation, Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, - Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, - RemoveLabels, Merge, Unwind>; + BreadthFirstAtom, Delete, Where, SetProperty, SetProperties, SetLabels, + RemoveProperty, RemoveLabels, Merge, Unwind>; using TreeLeafVisitor = ::utils::LeafVisitor; @@ -90,7 +91,8 @@ using TreeVisitor = ::utils::Visitor< ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation, Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, - Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, - RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>; + BreadthFirstAtom, Delete, Where, SetProperty, SetProperties, SetLabels, + RemoveProperty, RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, + CreateIndex>; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 5dff49bbd..52753404c 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -413,8 +413,26 @@ antlrcpp::Any CypherMainVisitor::visitPatternElementChain( antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( CypherParser::RelationshipPatternContext *ctx) { - auto *edge = storage_.Create(); - if (ctx->relationshipDetail()) { + auto *edge = ctx->bfsDetail() ? storage_.Create() + : storage_.Create(); + if (ctx->bfsDetail()) { + if (ctx->bfsDetail()->bfs_variable) { + std::string variable = ctx->bfsDetail()->bfs_variable->accept(this); + edge->identifier_ = storage_.Create(variable); + users_identifiers.insert(variable); + } + auto *bf_atom = dynamic_cast(edge); + std::string traversed_edge_variable = + ctx->bfsDetail()->traversed_edge->accept(this); + bf_atom->traversed_edge_identifier_ = + storage_.Create(traversed_edge_variable); + std::string next_node_variable = ctx->bfsDetail()->next_node->accept(this); + bf_atom->next_node_identifier_ = + storage_.Create(next_node_variable); + bf_atom->filter_expression_ = + ctx->bfsDetail()->expression()[0]->accept(this); + bf_atom->max_depth_ = ctx->bfsDetail()->expression()[1]->accept(this); + } else if (ctx->relationshipDetail()) { if (ctx->relationshipDetail()->variable()) { std::string variable = ctx->relationshipDetail()->variable()->accept(this); @@ -466,6 +484,12 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipDetail( return 0; } +antlrcpp::Any CypherMainVisitor::visitBfsDetail( + CypherParser::BfsDetailContext *) { + debug_assert(false, "Should never be called. See documentation in hpp."); + return 0; +} + antlrcpp::Any CypherMainVisitor::visitRelationshipTypes( CypherParser::RelationshipTypesContext *ctx) { std::vector types; diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 37944aecf..b148872d7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -268,6 +268,12 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { antlrcpp::Any visitRelationshipDetail( CypherParser::RelationshipDetailContext *ctx) override; + /** + * This should never be called. Everything is done directly in + * visitRelationshipPattern. + */ + antlrcpp::Any visitBfsDetail(CypherParser::BfsDetailContext *ctx) override; + /** * @return vector */ diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index eda75ffa0..e060e8dc2 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -110,12 +110,14 @@ nodePattern : '(' SP? ( variable SP? )? ( nodeLabels SP? )? ( properties SP? )? patternElementChain : relationshipPattern SP? nodePattern ; -relationshipPattern : ( leftArrowHead SP? dash SP? relationshipDetail? SP? dash SP? rightArrowHead ) - | ( leftArrowHead SP? dash SP? relationshipDetail? SP? dash ) - | ( dash SP? relationshipDetail? SP? dash SP? rightArrowHead ) - | ( dash SP? relationshipDetail? SP? dash ) +relationshipPattern : ( leftArrowHead SP? dash SP? ( bfsDetail | relationshipDetail )? SP? dash SP? rightArrowHead ) + | ( leftArrowHead SP? dash SP? ( bfsDetail | relationshipDetail )? SP? dash ) + | ( dash SP? ( bfsDetail | relationshipDetail )? SP? dash SP? rightArrowHead ) + | ( dash SP? ( bfsDetail | relationshipDetail )? SP? dash ) ; +bfsDetail : BFS SP? ( '[' SP? ( bfs_variable=variable SP? )? ']' )? SP? '(' SP? traversed_edge=variable SP? ',' SP? next_node=variable SP? '|' SP? expression SP? ',' SP? expression SP? ')' ; + relationshipDetail : '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( rangeLiteral SP? )? properties SP? ']' | '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( rangeLiteral SP? )? ( properties SP? )? ']' ; @@ -445,6 +447,8 @@ FALSE : ( 'F' | 'f' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ; INDEX : ( 'I' | 'i') ( 'N' | 'n' ) ( 'D' | 'd' ) ( 'E' | 'e' ) ( 'X' | 'x' ) ; +BFS : ( 'B' | 'b' ) ( 'F' | 'f' ) ( 'S' | 's' ) ; + UnescapedSymbolicName : IdentifierStart ( IdentifierPart )* ; /** diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 25866932a..510dbaefd 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -42,6 +42,7 @@ class ExpressionEvaluator : public TreeVisitor { BLOCK_VISIT(Pattern); BLOCK_VISIT(NodeAtom); BLOCK_VISIT(EdgeAtom); + BLOCK_VISIT(BreadthFirstAtom); BLOCK_VISIT(Delete); BLOCK_VISIT(Where); BLOCK_VISIT(SetProperty); diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 61e5d5c78..e398b76e4 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -1400,4 +1400,26 @@ TYPED_TEST(CypherMainVisitorTest, ReturnAll) { EXPECT_TRUE(eq); } +TYPED_TEST(CypherMainVisitorTest, MatchBfsReturn) { + TypeParam ast_generator( + "MATCH (n) -bfs[r](e, n|e.prop = 42, 10)-> (m) RETURN r"); + auto *query = ast_generator.query_; + ASSERT_EQ(query->clauses_.size(), 2U); + auto *match = dynamic_cast(query->clauses_[0]); + ASSERT_TRUE(match); + ASSERT_EQ(match->patterns_.size(), 1U); + ASSERT_EQ(match->patterns_[0]->atoms_.size(), 3U); + auto *bfs = dynamic_cast(match->patterns_[0]->atoms_[1]); + ASSERT_TRUE(bfs); + EXPECT_EQ(bfs->direction_, EdgeAtom::Direction::OUT); + EXPECT_EQ(bfs->identifier_->name_, "r"); + EXPECT_EQ(bfs->traversed_edge_identifier_->name_, "e"); + EXPECT_EQ(bfs->next_node_identifier_->name_, "n"); + auto *max_depth = dynamic_cast(bfs->max_depth_); + ASSERT_TRUE(max_depth); + EXPECT_EQ(max_depth->value_.Value(), 10U); + auto *eq = dynamic_cast(bfs->filter_expression_); + ASSERT_TRUE(eq); +} + }