diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index e5a14d6c7..d6a6e9a70 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -634,6 +634,16 @@ std::vector<Expansion> NormalizePatterns( if (edge->upper_bound_) { edge->upper_bound_->Accept(collector); } + if (auto *bf_atom = dynamic_cast<BreadthFirstAtom *>(edge)) { + // Get used symbols inside bfs filter expression and max depth. + bf_atom->filter_expression_->Accept(collector); + bf_atom->max_depth_->Accept(collector); + // Remove symbols which are bound by the bfs itself. + collector.symbols_.erase( + symbol_table.at(*bf_atom->traversed_edge_identifier_)); + collector.symbols_.erase( + symbol_table.at(*bf_atom->next_node_identifier_)); + } expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, collector.symbols_, current_node}); }; @@ -837,7 +847,18 @@ LogicalOperator *PlanMatching(const Matching &matching, } else { context.new_symbols.emplace_back(edge_symbol); } - if (expansion.edge->has_range_) { + if (auto *bf_atom = dynamic_cast<BreadthFirstAtom *>(expansion.edge)) { + const auto &traversed_edge_symbol = + symbol_table.at(*bf_atom->traversed_edge_identifier_); + const auto &next_node_symbol = + symbol_table.at(*bf_atom->next_node_identifier_); + last_op = new ExpandBreadthFirst( + node_symbol, edge_symbol, expansion.direction, bf_atom->max_depth_, + next_node_symbol, traversed_edge_symbol, + bf_atom->filter_expression_, + std::shared_ptr<LogicalOperator>(last_op), node1_symbol, + existing_node, context.graph_view); + } else if (expansion.edge->has_range_) { last_op = new ExpandVariable( node_symbol, edge_symbol, expansion.direction, expansion.edge->lower_bound_, expansion.edge->upper_bound_, diff --git a/src/query/plan/variable_start_planner.cpp b/src/query/plan/variable_start_planner.cpp index 04b7540b9..2372c865d 100644 --- a/src/query/plan/variable_start_planner.cpp +++ b/src/query/plan/variable_start_planner.cpp @@ -78,11 +78,13 @@ auto NextExpansion(const SymbolTable &symbol_table, if (expanded_symbols.find(node1_symbol) != expanded_symbols.end()) { return expansion_it; } + // Try expanding from node2 by flipping the expansion. auto *node2 = expansion_it->node2; if (node2 && expanded_symbols.find(symbol_table.at(*node2->identifier_)) != - expanded_symbols.end()) { - // We need to flip the expansion, since we want to expand from node2. + expanded_symbols.end() && + // BFS must *not* be flipped. Doing that changes the BFS results. + !dynamic_cast<BreadthFirstAtom *>(expansion_it->edge)) { std::swap(expansion_it->node2, expansion_it->node1); if (expansion_it->direction != EdgeAtom::Direction::BOTH) { expansion_it->direction = diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/match.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/match.feature index 502288597..00051c45d 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/match.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/match.feature @@ -484,3 +484,32 @@ Feature: Match Then the result should be: | n.a | m.a | | 1 | 2 | + + Scenario: Test match BFS depth blocked + Given an empty graph + And having executed: + """ + CREATE (n {a:'0'}) -[:r]-> ({a:'1.1'}) -[:r]-> ({a:'2.1'}), (n) -[:r]-> ({a:'1.2'}) + """ + When executing query: + """ + MATCH (n {a:'0'}) -bfs(e, m| true, 1)-> (m) RETURN n.a, m.a + """ + Then the result should be: + | n.a | m.a | + | '0' | '1.1' | + | '0' | '1.2' | + + Scenario: Test match BFS filtered + Given an empty graph + And having executed: + """ + CREATE (n {a:'0'}) -[:r]-> ({a:'1.1'}) -[:r]-> ({a:'2.1'}), (n) -[:r]-> ({a:'1.2'}) + """ + When executing query: + """ + MATCH (n {a:'0'}) -bfs(e, m| m.a = '1.1' OR m.a = '0', 10)-> (m) RETURN n.a, m.a + """ + Then the result should be: + | n.a | m.a | + | '0' | '1.1' | diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 09d091bba..4ed729992 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -55,6 +55,7 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor { PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); + PRE_VISIT(ExpandBreadthFirst); PRE_VISIT(Filter); PRE_VISIT(Produce); PRE_VISIT(SetProperty); @@ -124,6 +125,7 @@ using ExpectScanAll = OpChecker<ScanAll>; using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>; using ExpectExpand = OpChecker<Expand>; using ExpectExpandVariable = OpChecker<ExpandVariable>; +using ExpectExpandBreadthFirst = OpChecker<ExpandBreadthFirst>; using ExpectFilter = OpChecker<Filter>; using ExpectProduce = OpChecker<Produce>; using ExpectSetProperty = OpChecker<SetProperty>; @@ -1264,4 +1266,15 @@ TEST(TestLogicalPlanner, UnwindMatchVariable) { ExpectProduce()); } +TEST(TestLogicalPlanner, MatchBreadthFirst) { + // Test MATCH (n) -bfs[r](r, n|n, 10)-> (m) RETURN r + AstTreeStorage storage; + auto *bfs = storage.Create<query::BreadthFirstAtom>( + IDENT("r"), Direction::OUT, IDENT("r"), IDENT("n"), IDENT("n"), + LITERAL(10)); + QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r")); + CheckPlan(storage, ExpectScanAll(), ExpectExpandBreadthFirst(), + ExpectProduce()); +} + } // namespace diff --git a/tools/gdb-plugins/operator_tree.py b/tools/gdb-plugins/operator_tree.py index 0f89eea3a..650a6b030 100644 --- a/tools/gdb-plugins/operator_tree.py +++ b/tools/gdb-plugins/operator_tree.py @@ -1,4 +1,5 @@ import io +import re import gdb @@ -10,12 +11,26 @@ def _logical_operator_type(): return gdb.lookup_type('query::plan::LogicalOperator') -def _shared_ptr_pointee(shared_ptr): - '''Returns the address of the pointed to object inside shared_ptr.''' - # This function may not be needed when gdb adds dereferencing shared_ptr - # via Python API. +# Pattern for matching std::unique_ptr<T, Deleter> and std::shared_ptr<T> +_SMART_PTR_TYPE_PATTERN = \ + re.compile('^std::(unique|shared)_ptr<(?P<pointee_type>[\w:]*)') + + +def _is_smart_ptr(maybe_smart_ptr, type_name=None): + if maybe_smart_ptr.type.name is None: + return False + match = _SMART_PTR_TYPE_PATTERN.match(maybe_smart_ptr.type.name) + if match is None or type_name is None: + return bool(match) + return type_name == match.group('pointee_type') + + +def _smart_ptr_pointee(smart_ptr): + '''Returns the address of the pointed to object in shared_ptr/unique_ptr.''' + # This function may not be needed when gdb adds dereferencing + # shared_ptr/unique_ptr via Python API. with io.StringIO() as string_io: - print(shared_ptr, file=string_io) + print(smart_ptr, file=string_io) addr = string_io.getvalue().split()[-1] return int(addr, base=16) @@ -24,7 +39,7 @@ def _get_operator_input(operator): '''Returns the input operator of given operator, if it has any.''' if 'input_' not in [f.name for f in operator.type.fields()]: return None - input_addr = _shared_ptr_pointee(operator['input_']) + input_addr = _smart_ptr_pointee(operator['input_']) if input_addr == 0: return None pointer_type = _logical_operator_type().pointer() @@ -36,19 +51,30 @@ class PrintOperatorTree(gdb.Command): '''Print the tree of logical operators from the expression.''' def __init__(self): super(PrintOperatorTree, self).__init__("print-operator-tree", - gdb.COMMAND_USER) + gdb.COMMAND_USER, + gdb.COMPLETE_EXPRESSION) def invoke(self, argument, from_tty): try: operator = gdb.parse_and_eval(argument) except gdb.error as e: raise gdb.GdbError(*e.args) + logical_operator_type = _logical_operator_type() + if _is_smart_ptr(operator, 'query::plan::LogicalOperator'): + pointee = gdb.Value(_smart_ptr_pointee(operator)) + if pointee == 0: + raise gdb.GdbError("Expected a '%s', but got nullptr" % + logical_operator_type) + operator = \ + pointee.cast(logical_operator_type.pointer()).dereference() + elif operator.type == logical_operator_type.pointer(): + operator = operator.dereference() # Currently, gdb doesn't provide API to check if the dynamic_type is # subtype of a base type. So, this check will fail, for example if we # get 'query::plan::ScanAll'. The user can avoid this by up-casting. - if operator.type != _logical_operator_type(): + if operator.type != logical_operator_type: raise gdb.GdbError("Expected a '%s', but got '%s'" % - (_logical_operator_type(), operator.type)) + (logical_operator_type, operator.type)) next_op = operator.cast(operator.dynamic_type) tree = [] while next_op is not None: