diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 8dcfa41d2..f4e053e8e 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -649,6 +649,7 @@ class Match : public Clause { protected: Match(int uid) : Clause(uid) {} + Match(int uid, bool optional) : Clause(uid), optional_(optional) {} }; /** @brief Defines the order for sorting values (ascending or descending). */ diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index d522b4b6e..060e583bb 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -363,7 +363,9 @@ bool Expand::ExpandCursor::HandleExistingNode(const VertexAccessor new_node, NodeFilter::NodeFilter(const std::shared_ptr &input, Symbol input_symbol, const NodeAtom *node_atom) - : input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {} + : input_(input ? input : std::make_shared()), + input_symbol_(input_symbol), + node_atom_(node_atom) {} ACCEPT_WITH_INPUT(NodeFilter) @@ -410,7 +412,9 @@ bool NodeFilter::NodeFilterCursor::VertexPasses( EdgeFilter::EdgeFilter(const std::shared_ptr &input, Symbol input_symbol, const EdgeAtom *edge_atom) - : input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {} + : input_(input ? input : std::make_shared()), + input_symbol_(input_symbol), + edge_atom_(edge_atom) {} ACCEPT_WITH_INPUT(EdgeFilter) @@ -458,9 +462,10 @@ bool EdgeFilter::EdgeFilterCursor::EdgePasses(const EdgeAccessor &edge, return true; } -Filter::Filter(const std::shared_ptr &input_, - Expression *expression_) - : input_(input_), expression_(expression_) {} +Filter::Filter(const std::shared_ptr &input, + Expression *expression) + : input_(input ? input : std::make_shared()), + expression_(expression) {} ACCEPT_WITH_INPUT(Filter) diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index d8fbe36f1..780ac24f7 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -434,7 +434,7 @@ class NodeFilter : public LogicalOperator { public: /** @brief Construct @c NodeFilter. * - * @param input Required, preceding @c LogicalOperator. + * @param input Optional, preceding @c LogicalOperator. * @param input_symbol @c Symbol where the node to be filtered is stored. * @param node_atom @c NodeAtom with labels and properties to filter by. */ @@ -475,7 +475,7 @@ class EdgeFilter : public LogicalOperator { public: /** @brief Construct @c EdgeFilter. * - * @param input Required, preceding @c LogicalOperator. + * @param input Optional, preceding @c LogicalOperator. * @param input_symbol @c Symbol where the edge to be filtered is stored. * @param edge_atom @c EdgeAtom with edge types and properties to filter by. */ @@ -1231,6 +1231,10 @@ class Optional : public LogicalOperator { void Accept(LogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + auto input() const { return input_; } + auto optional() const { return optional_; } + const auto &optional_symbols() const { return optional_symbols_; } + private: const std::shared_ptr input_; const std::shared_ptr optional_; diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 46796c104..85170dff4 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -101,18 +101,36 @@ auto GenCreate(Create &create, LogicalOperator *input_op, return last_op; } +// Contextual information used for generating match operators. +struct MatchContext { + const SymbolTable &symbol_table; + // Already bound symbols, which are used to determine whether the operator + // should reference them or establish new. This is both read from and written + // to during generation. + std::unordered_set &bound_symbols; + // Determines whether the match should see the new graph state or not. + GraphView graph_view = GraphView::OLD; + // Symbols for edges established in match, used to ensure Cyphermorphism. + std::unordered_set edge_symbols; + // All the newly established symbols in match. + std::vector new_symbols; +}; + +// Generates operators for matching the given pattern and appends them to +// input_op. Fills the context with all the new symbols and edge symbols. auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set &bound_symbols, - std::vector &edge_symbols, - GraphView graph_view = GraphView::OLD) { + MatchContext &context) { + auto &bound_symbols = context.bound_symbols; + const auto &symbol_table = context.symbol_table; auto base = [&](NodeAtom *node) { LogicalOperator *last_op = input_op; // If the first atom binds a symbol, we generate a ScanAll which writes it. // Otherwise, someone else generates it (e.g. a previous ScanAll). - if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { + const auto &node_symbol = symbol_table.at(*node->identifier_); + if (BindSymbol(bound_symbols, node_symbol)) { last_op = new ScanAll(node, std::shared_ptr(last_op), - graph_view); + context.graph_view); + context.new_symbols.emplace_back(node_symbol); } // Even though we may skip generating ScanAll, we still want to add a filter // in case this atom adds more labels/properties for filtering. @@ -129,28 +147,36 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, // If the expand symbols were already bound, then we need to indicate // that they exist. The Expand will then check whether the pattern holds // instead of writing the expansion to symbols. + const auto &node_symbol = symbol_table.at(*node->identifier_); auto existing_node = false; - auto existing_edge = false; - if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { + if (!BindSymbol(bound_symbols, node_symbol)) { existing_node = true; + } else { + context.new_symbols.emplace_back(node_symbol); } const auto &edge_symbol = symbol_table.at(*edge->identifier_); + auto existing_edge = false; if (!BindSymbol(bound_symbols, edge_symbol)) { existing_edge = true; + } else { + context.new_symbols.emplace_back(edge_symbol); } - last_op = - new Expand(node, edge, std::shared_ptr(last_op), - input_symbol, existing_node, existing_edge, graph_view); + last_op = new Expand(node, edge, std::shared_ptr(last_op), + input_symbol, existing_node, existing_edge, + context.graph_view); if (!existing_edge) { // Ensure Cyphermorphism (different edge symbols always map to different // edges). - if (!edge_symbols.empty()) { + if (!context.edge_symbols.empty()) { last_op = new ExpandUniquenessFilter( std::shared_ptr(last_op), edge_symbol, - edge_symbols); + std::vector(context.edge_symbols.begin(), + context.edge_symbols.end())); } - edge_symbols.emplace_back(edge_symbol); } + // Insert edge_symbol after creating ExpandUniquenessFilter, so that we + // avoid filtering by the same edge we just expanded. + context.edge_symbols.insert(edge_symbol); if (!edge->edge_types_.empty() || !edge->properties_.empty()) { last_op = new EdgeFilter(std::shared_ptr(last_op), symbol_table.at(*edge->identifier_), edge); @@ -167,16 +193,22 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, auto GenMatch(Match &match, LogicalOperator *input_op, const SymbolTable &symbol_table, std::unordered_set &bound_symbols) { - auto last_op = input_op; - std::vector edge_symbols; + auto last_op = match.optional_ ? nullptr : input_op; + MatchContext context{symbol_table, bound_symbols}; for (auto pattern : match.patterns_) { - last_op = GenMatchForPattern(*pattern, last_op, symbol_table, bound_symbols, - edge_symbols); + last_op = GenMatchForPattern(*pattern, last_op, context); } if (match.where_) { last_op = new Filter(std::shared_ptr(last_op), match.where_->expression_); } + // Plan Optional after Filter. because with `OPTIONAL MATCH ... WHERE`, + // filtering is done while looking for the pattern. + if (match.optional_) { + last_op = new Optional(std::shared_ptr(input_op), + std::shared_ptr(last_op), + context.new_symbols); + } return last_op; } @@ -449,10 +481,8 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op, // Copy the bound symbol set, because we don't want to use the updated version // when generating the create part. std::unordered_set bound_symbols_copy(bound_symbols); - std::vector edge_symbols; - auto on_match = - GenMatchForPattern(*merge.pattern_, nullptr, symbol_table, - bound_symbols_copy, edge_symbols, GraphView::NEW); + MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW}; + auto on_match = GenMatchForPattern(*merge.pattern_, nullptr, context); // Use the original bound_symbols, so we fill it with new symbols. auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table, bound_symbols); diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 7af3d2463..d23f5945a 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -130,14 +130,13 @@ auto GetPattern(AstTreeStorage &storage, std::vector atoms) { } /// -/// This function creates an AST node which can store patterns and fills them -/// with given patterns. +/// This function fills an AST node which with given patterns. /// /// The function is most commonly used to create Match and Create clauses. /// template -auto GetWithPatterns(AstTreeStorage &storage, std::vector patterns) { - auto with_patterns = storage.Create(); +auto GetWithPatterns(TWithPatterns *with_patterns, + std::vector patterns) { with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); return with_patterns; @@ -346,11 +345,16 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, #define NODE(...) query::test_common::GetNode(storage, __VA_ARGS__) #define EDGE(...) query::test_common::GetEdge(storage, __VA_ARGS__) #define PATTERN(...) query::test_common::GetPattern(storage, {__VA_ARGS__}) -#define MATCH(...) \ - query::test_common::GetWithPatterns(storage, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + query::test_common::GetWithPatterns(storage.Create(true), \ + {__VA_ARGS__}) +#define MATCH(...) \ + query::test_common::GetWithPatterns(storage.Create(), \ + {__VA_ARGS__}) #define WHERE(expr) storage.Create((expr)) -#define CREATE(...) \ - query::test_common::GetWithPatterns(storage, {__VA_ARGS__}) +#define CREATE(...) \ + query::test_common::GetWithPatterns(storage.Create(), \ + {__VA_ARGS__}) #define IDENT(name) storage.Create((name)) #define LITERAL(val) storage.Create((val)) #define PROPERTY_LOOKUP(...) \ diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 9130c4217..dfd941c6b 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -67,6 +67,11 @@ class PlanChecker : public LogicalOperatorVisitor { op.input()->Accept(*this); return false; } + bool PreVisit(Optional &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } std::list checkers_; @@ -172,6 +177,20 @@ class ExpectMerge : public OpChecker { const std::list &on_create_; }; +class ExpectOptional : public OpChecker { + public: + ExpectOptional(const std::list &optional) + : optional_(optional) {} + + void ExpectOp(Optional &optional, const SymbolTable &symbol_table) override { + PlanChecker check_optional(optional_, symbol_table); + optional.optional()->Accept(check_optional); + } + + private: + const std::list &optional_; +}; + auto MakeSymbolTable(query::Query &query) { SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); @@ -383,7 +402,7 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) { CheckPlan(*query, ExpectScanAll(), ExpectExpand()); } -TEST(TestLogicalPlanner, MatchEdgeCycle) { +TEST(TestLogicalPlanner, MatchExistingEdge) { // Test MATCH (n) -[r]- (m) -[r]- (j) AstTreeStorage storage; auto query = QUERY( @@ -392,6 +411,17 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) { CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand()); } +TEST(TestLogicalPlanner, MultiMatchExistingEdgeOtherEdge) { + // Test MATCH (n) -[r]- (m) MATCH (m) -[r]- (j) -[e]- (l) + AstTreeStorage storage; + auto query = QUERY( + MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + MATCH(PATTERN(NODE("m"), EDGE("r"), NODE("j"), EDGE("e"), NODE("l")))); + // We need ExpandUniquenessFilter for edge `e` against `r` in second MATCH. + CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpand(), ExpectExpandUniquenessFilter()); +} + TEST(TestLogicalPlanner, MatchWithReturn) { // Test MATCH (old) WITH old AS new RETURN new AS new AstTreeStorage storage; @@ -624,4 +654,19 @@ TEST(TestLogicalPlanner, MatchMerge) { on_create.clear(); } +TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { + // Test MATCH (n) OPTIONAL MATCH (n) -[r]- (m) WHERE m.prop < 42 RETURN r + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto query = QUERY(MATCH(PATTERN(NODE("n"))), + OPTIONAL_MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + WHERE(LESS(PROPERTY_LOOKUP("m", prop), LITERAL(42))), + RETURN(IDENT("r"), AS("r"))); + std::list optional{new ExpectScanAll(), new ExpectExpand(), + new ExpectFilter()}; + CheckPlan(*query, ExpectScanAll(), ExpectOptional(optional), ExpectProduce()); +} + } // namespace