diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index c66b541ad..7945913ca 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -530,7 +530,7 @@ std::vector<Expansion> NormalizePatterns( collector.symbols_.erase( symbol_table.at(*bf_atom->next_node_identifier_)); } - expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, + expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node}); }; for (const auto &pattern : patterns) { diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 672f028e6..82fb3a5b7 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -15,6 +15,8 @@ struct Expansion { /// Direction of the edge, it may be flipped compared to original /// @c EdgeAtom during plan generation. EdgeAtom::Direction direction = EdgeAtom::Direction::BOTH; + /// True if the direction and nodes were flipped. + bool is_flipped = false; /// Set of symbols found inside the range expressions of a variable path edge. std::unordered_set<Symbol> symbols_in_range{}; /// Optional node at the other end of an edge. If the expansion @@ -489,8 +491,8 @@ class RuleBasedPlanner { bound_symbols, node_symbol, all_filters, storage); last_op = new ExpandVariable( node_symbol, edge_symbol, expansion.direction, - expansion.direction != expansion.edge->direction_, - expansion.edge->lower_bound_, expansion.edge->upper_bound_, + expansion.is_flipped, expansion.edge->lower_bound_, + expansion.edge->upper_bound_, std::shared_ptr<LogicalOperator>(last_op), node1_symbol, existing_node, existing_edge, match_context.graph_view, filter_expr); diff --git a/src/query/plan/variable_start_planner.cpp b/src/query/plan/variable_start_planner.cpp index 0c7e638bf..53d4a9b28 100644 --- a/src/query/plan/variable_start_planner.cpp +++ b/src/query/plan/variable_start_planner.cpp @@ -96,6 +96,7 @@ void AddNextExpansions( if (!dynamic_cast<BreadthFirstAtom *>(expansion.edge)) { // BFS must *not* be flipped. Doing that changes the BFS results. std::swap(expansion.node1, expansion.node2); + expansion.is_flipped = true; if (expansion.direction != EdgeAtom::Direction::BOTH) { expansion.direction = expansion.direction == EdgeAtom::Direction::IN ? EdgeAtom::Direction::OUT diff --git a/tests/unit/query_variable_start_planner.cpp b/tests/unit/query_variable_start_planner.cpp index db8d70dc2..894be87ca 100644 --- a/tests/unit/query_variable_start_planner.cpp +++ b/tests/unit/query_variable_start_planner.cpp @@ -263,6 +263,33 @@ TEST(TestVariableStartPlanner, MatchVariableExpandReferenceNode) { }); } +TEST(TestVariableStartPlanner, MatchVariableExpandBoth) { + Dbms dbms; + auto dba = dbms.active(); + auto id = dba->Property("id"); + // Graph (v1 {id:1}) -[:r1]-> (v2) -[:r2]-> (v3) + auto v1 = dba->InsertVertex(); + v1.PropsSet(id, 1); + auto v2 = dba->InsertVertex(); + auto v3 = dba->InsertVertex(); + auto r1 = dba->InsertEdge(v1, v2, dba->EdgeType("r1")); + auto r2 = dba->InsertEdge(v2, v3, dba->EdgeType("r2")); + dba->AdvanceCommand(); + // Test MATCH (n {id:1}) -[r*]- (m) RETURN r + AstTreeStorage storage; + auto edge = EDGE("r", Direction::BOTH); + edge->has_range_ = true; + auto node_n = NODE("n"); + node_n->properties_[std::make_pair("id", id)] = LITERAL(1); + QUERY(MATCH(PATTERN(node_n, edge, NODE("m"))), RETURN("r")); + // We expect to get a single column with the following rows: + TypedValue r1_list(std::vector<TypedValue>{r1}); // [r1] + TypedValue r1_r2_list(std::vector<TypedValue>{r1, r2}); // [r1, r2] + CheckPlansProduce(2, storage, *dba, [&](const auto &results) { + AssertRows(results, {{r1_list}, {r1_r2_list}}); + }); +} + TEST(TestVariableStartPlanner, MatchBfs) { Dbms dbms; auto dba = dbms.active();