diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 2a9fb289f..19fb34d3a 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -1105,8 +1105,15 @@ class ExpressionEvaluator : public ExpressionVisitor { } } - TypedValue Visit(PatternComprehension & /*pattern_comprehension*/) override { - throw utils::NotYetImplemented("Expression evaluator can not handle pattern comprehension."); + TypedValue Visit(PatternComprehension &pattern_comprehension) override { + TypedValue &frame_pattern_comprehension_value = frame_->at(symbol_table_->at(pattern_comprehension)); + if (!frame_pattern_comprehension_value.IsList()) [[unlikely]] { + throw QueryRuntimeException( + "Unexpected behavior: Pattern Comprehension expected a list, got {}. Please report the problem on GitHub " + "issues", + frame_pattern_comprehension_value.type()); + } + return frame_pattern_comprehension_value; } private: diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 29f64f950..eb9953e2f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -138,6 +138,7 @@ extern const Event EvaluatePatternFilterOperator; extern const Event ApplyOperator; extern const Event IndexedJoinOperator; extern const Event HashJoinOperator; +extern const Event RollUpApplyOperator; } // namespace memgraph::metrics namespace memgraph::query::plan { @@ -5741,16 +5742,15 @@ UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr(mem, *this, mem); } -RollUpApply::RollUpApply(const std::shared_ptr &input, - std::shared_ptr &&second_branch) - : input_(input), list_collection_branch_(second_branch) {} +RollUpApply::RollUpApply(const std::shared_ptr &&input, + std::shared_ptr &&list_collection_branch, + const std::vector &list_collection_symbols, Symbol result_symbol) + : input_(input), + list_collection_branch_(list_collection_branch), + list_collection_symbols_(list_collection_symbols), + result_symbol_(result_symbol) {} -std::vector RollUpApply::OutputSymbols(const SymbolTable & /*symbol_table*/) const { - std::vector symbols; - return symbols; -} - -std::vector RollUpApply::ModifiedSymbols(const SymbolTable &table) const { return OutputSymbols(table); } +std::vector RollUpApply::ModifiedSymbols(const SymbolTable &table) const { return {result_symbol_}; } bool RollUpApply::Accept(HierarchicalLogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { @@ -5762,4 +5762,68 @@ bool RollUpApply::Accept(HierarchicalLogicalOperatorVisitor &visitor) { return visitor.PostVisit(*this); } +namespace { + +class RollUpApplyCursor : public Cursor { + public: + RollUpApplyCursor(const RollUpApply &self, utils::MemoryResource *mem) + : self_(self), + input_cursor_(self.input_->MakeCursor(mem)), + list_collection_cursor_(self_.list_collection_branch_->MakeCursor(mem)) { + MG_ASSERT(input_cursor_ != nullptr, "RollUpApplyCursor: Missing left operator cursor."); + MG_ASSERT(list_collection_cursor_ != nullptr, "RollUpApplyCursor: Missing right operator cursor."); + MG_ASSERT(self_.list_collection_symbols_.size() == 1U, "Expected a single list collection symbol."); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + OOMExceptionEnabler oom_exception; + SCOPED_PROFILE_OP_BY_REF(self_); + + auto clear_frame_change_collector = [&context](const auto &symbols) { + for (const auto &symbol : symbols) { + if (context.frame_change_collector && context.frame_change_collector->IsKeyTracked(symbol.name())) { + context.frame_change_collector->ResetTrackingValue(symbol.name()); + } + } + }; + + TypedValue result(std::vector(), context.evaluation_context.memory); + if (input_cursor_->Pull(frame, context)) { + while (list_collection_cursor_->Pull(frame, context)) { + // collect values from the list collection branch + for (const auto &output_symbol : self_.list_collection_symbols_) { + result.ValueList().emplace_back(frame[output_symbol]); + } + } + clear_frame_change_collector(self_.list_collection_symbols_); + frame[self_.result_symbol_] = result; + } else { + return false; + } + + return true; + } + + void Shutdown() override { + input_cursor_->Shutdown(); + list_collection_cursor_->Shutdown(); + } + + void Reset() override { + input_cursor_->Reset(); + list_collection_cursor_->Reset(); + } + + private: + const RollUpApply &self_; + const UniqueCursorPtr input_cursor_; + const UniqueCursorPtr list_collection_cursor_; +}; +} // namespace + +UniqueCursorPtr RollUpApply::MakeCursor(utils::MemoryResource *mem) const { + memgraph::metrics::IncrementCounter(memgraph::metrics::RollUpApplyOperator); + return MakeUniqueCursorPtr(mem, *this, mem); +} + } // namespace memgraph::query::plan diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 5a8ef0625..fa736cd86 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -2682,28 +2682,30 @@ class RollUpApply : public memgraph::query::plan::LogicalOperator { const utils::TypeInfo &GetTypeInfo() const override { return kType; } RollUpApply() = default; - RollUpApply(const std::shared_ptr &input, std::shared_ptr &&second_branch); + RollUpApply(const std::shared_ptr &&input, std::shared_ptr &&list_collection_branch, + const std::vector &list_collection_symbols, Symbol result_symbol); bool HasSingleInput() const override { return false; } std::shared_ptr input() const override { return input_; } void set_input(std::shared_ptr input) override { input_ = input; } bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; - UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override { - throw utils::NotYetImplemented("Execution of Pattern comprehension is currently unsupported."); - } - std::vector OutputSymbols(const SymbolTable &) const override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; std::vector ModifiedSymbols(const SymbolTable &) const override; std::unique_ptr Clone(AstStorage *storage) const override { auto object = std::make_unique(); object->input_ = input_ ? input_->Clone(storage) : nullptr; object->list_collection_branch_ = list_collection_branch_ ? list_collection_branch_->Clone(storage) : nullptr; + object->list_collection_symbols_ = list_collection_symbols_; + object->result_symbol_ = result_symbol_; return object; } std::shared_ptr input_; std::shared_ptr list_collection_branch_; + std::vector list_collection_symbols_; + Symbol result_symbol_; }; } // namespace plan diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index 2c783fa15..823850b91 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -702,6 +702,7 @@ void PatternVisitor::Visit(PatternComprehension &op) { AddMatching({op.pattern_}, op.filter_, symbol_table_, storage_, matching); matching.result_expr = storage_.Create(symbol_table_.at(op).name(), op.resultExpr_); matching.result_expr->MapTo(symbol_table_.at(op)); + matching.result_symbol = symbol_table_.at(op); pattern_comprehension_matchings_.push_back(std::move(matching)); } diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 5d4e2e8d2..4616064b3 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -503,6 +503,7 @@ inline auto Filters::IdFilters(const Symbol &symbol) const -> std::vectorAccept(*this); + RewriteBranch(&op.list_collection_branch_); + return false; + } + + bool PostVisit(RollUpApply &) override { + prev_ops_.pop_back(); + return true; + } + std::shared_ptr new_root_; private: diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 54b5c3834..c6209ede1 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -19,6 +19,7 @@ #include #include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol_table.hpp" #include "query/plan/operator.hpp" #include "query/plan/preprocess.hpp" #include "utils/algorithm.hpp" @@ -43,8 +44,7 @@ namespace { class ReturnBodyContext : public HierarchicalTreeVisitor { public: ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set &bound_symbols, - AstStorage &storage, std::unordered_map> pc_ops, - Where *where = nullptr) + AstStorage &storage, PatternComprehensionDataMap &pc_ops, Where *where = nullptr) : body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) { // Collect symbols from named expressions. output_symbols_.reserve(body_.named_expressions.size()); @@ -57,9 +57,11 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { output_symbols_.emplace_back(symbol_table_.at(*named_expr)); named_expr->Accept(*this); named_expressions_.emplace_back(named_expr); - if (pattern_comprehension_) { + + // Pattern comprehension can be filled during named expression traversion + if (has_pattern_comprehension()) { if (auto it = pc_ops.find(named_expr->name_); it != pc_ops.end()) { - pattern_comprehension_op_ = std::move(it->second); + pattern_comprehension_data_.op = std::move(it->second.op); pc_ops.erase(it); } else { throw utils::NotYetImplemented("Operation on top of pattern comprehension"); @@ -399,18 +401,19 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { } bool PreVisit(PatternComprehension & /*unused*/) override { - pattern_compression_aggregations_start_index_ = has_aggregation_.size(); + aggregations_start_index_ = has_aggregation_.size(); return true; } bool PostVisit(PatternComprehension &pattern_comprehension) override { bool has_aggr = false; - for (auto i = has_aggregation_.size(); i > pattern_compression_aggregations_start_index_; --i) { + for (auto i = has_aggregation_.size(); i > aggregations_start_index_; --i) { has_aggr |= has_aggregation_.back(); has_aggregation_.pop_back(); } has_aggregation_.emplace_back(has_aggr); - pattern_comprehension_ = &pattern_comprehension; + pattern_comprehension_data_.pattern_comprehension = &pattern_comprehension; + pattern_comprehension_data_.result_symbol = symbol_table_.at(pattern_comprehension); return true; } @@ -468,9 +471,11 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { // named_expressions. const auto &output_symbols() const { return output_symbols_; } - const auto *pattern_comprehension() const { return pattern_comprehension_; } + const bool has_pattern_comprehension() const { return pattern_comprehension_data_.pattern_comprehension != nullptr; } - std::shared_ptr pattern_comprehension_op() const { return pattern_comprehension_op_; } + const PatternComprehensionData pattern_comprehension_data() const { return pattern_comprehension_data_; } + + const SymbolTable &symbol_table() const { return symbol_table_; } private: const ReturnBody &body_; @@ -493,9 +498,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { // group by it. std::list has_aggregation_; std::vector named_expressions_; - PatternComprehension *pattern_comprehension_ = nullptr; - std::shared_ptr pattern_comprehension_op_; - size_t pattern_compression_aggregations_start_index_ = 0; + PatternComprehensionData pattern_comprehension_data_; + size_t aggregations_start_index_ = 0; }; std::unique_ptr GenReturnBody(std::unique_ptr input_op, bool advance_command, @@ -514,8 +518,11 @@ std::unique_ptr GenReturnBody(std::unique_ptr last_op = std::make_unique(std::move(last_op), body.aggregations(), body.group_by(), remember); } - if (body.pattern_comprehension()) { - last_op = std::make_unique(std::move(last_op), body.pattern_comprehension_op()); + if (body.has_pattern_comprehension()) { + auto list_collection_data = body.pattern_comprehension_data(); + auto list_collection_symbols = list_collection_data.op->ModifiedSymbols(body.symbol_table()); + last_op = std::make_unique(std::move(last_op), std::move(list_collection_data.op), + list_collection_symbols, list_collection_data.result_symbol); } last_op = std::make_unique(std::move(last_op), body.named_expressions()); @@ -580,9 +587,9 @@ Expression *ExtractFilters(const std::unordered_set &bound_symbols, Filt return filter_expr; } -std::unordered_set GetSubqueryBoundSymbols( - const std::vector &single_query_parts, SymbolTable &symbol_table, AstStorage &storage, - std::unordered_map> pc_ops) { +std::unordered_set GetSubqueryBoundSymbols(const std::vector &single_query_parts, + SymbolTable &symbol_table, AstStorage &storage, + PatternComprehensionDataMap &pc_ops) { const auto &query = single_query_parts[0]; if (!query.matching.expansions.empty() || query.remaining_clauses.empty()) { @@ -622,7 +629,7 @@ std::unique_ptr GenNamedPaths(std::unique_ptr std::unique_ptr GenReturn(Return &ret, std::unique_ptr input_op, SymbolTable &symbol_table, bool is_write, const std::unordered_set &bound_symbols, AstStorage &storage, - std::unordered_map> pc_ops) { + PatternComprehensionDataMap &pc_ops) { // Similar to WITH clause, but we want to accumulate when the query writes to // the database. This way we handle the case when we want to return // expressions with the latest updated results. For example, `MATCH (n) -- () @@ -638,7 +645,7 @@ std::unique_ptr GenReturn(Return &ret, std::unique_ptr GenWith(With &with, std::unique_ptr input_op, SymbolTable &symbol_table, bool is_write, std::unordered_set &bound_symbols, AstStorage &storage, - std::unordered_map> pc_ops) { + PatternComprehensionDataMap &pc_ops) { // WITH clause is Accumulate/Aggregate (advance_command) + Produce and // optional Filter. In case of update and aggregation, we want to accumulate // first, so that when aggregating, we get the latest results. Similar to diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 52281de60..1930f81f1 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -27,6 +27,18 @@ namespace memgraph::query::plan { +struct PatternComprehensionData { + PatternComprehensionData() = default; + + PatternComprehensionData(std::shared_ptr lop, Symbol res_symbol) + : op(std::move(lop)), result_symbol(res_symbol) {} + + PatternComprehension *pattern_comprehension = nullptr; + std::shared_ptr op; + Symbol result_symbol; +}; +using PatternComprehensionDataMap = std::unordered_map; + /// @brief Context which contains variables commonly used during planning. template struct PlanningContext { @@ -88,9 +100,9 @@ bool HasBoundFilterSymbols(const std::unordered_set &bound_symbols, cons // Returns the set of symbols for the subquery that are actually referenced from the outer scope and // used in the subquery. -std::unordered_set GetSubqueryBoundSymbols( - const std::vector &single_query_parts, SymbolTable &symbol_table, AstStorage &storage, - std::unordered_map> pc_ops); +std::unordered_set GetSubqueryBoundSymbols(const std::vector &single_query_parts, + SymbolTable &symbol_table, AstStorage &storage, + PatternComprehensionDataMap &pc_ops); Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table); Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table); @@ -145,12 +157,12 @@ std::unique_ptr GenNamedPaths(std::unique_ptr std::unique_ptr GenReturn(Return &ret, std::unique_ptr input_op, SymbolTable &symbol_table, bool is_write, const std::unordered_set &bound_symbols, AstStorage &storage, - std::unordered_map> pc_ops); + PatternComprehensionDataMap &pc_ops); std::unique_ptr GenWith(With &with, std::unique_ptr input_op, SymbolTable &symbol_table, bool is_write, std::unordered_set &bound_symbols, AstStorage &storage, - std::unordered_map> pc_ops); + PatternComprehensionDataMap &pc_ops); std::unique_ptr GenUnion(const CypherUnion &cypher_union, std::shared_ptr left_op, std::shared_ptr right_op, SymbolTable &symbol_table); @@ -194,7 +206,7 @@ class RuleBasedPlanner { uint64_t merge_id = 0; uint64_t subquery_id = 0; - std::unordered_map> pattern_comprehension_ops; + PatternComprehensionDataMap pattern_comprehension_ops; if (single_query_part.pattern_comprehension_matchings.size() > 1) { throw utils::NotYetImplemented("Multiple pattern comprehensions."); @@ -204,7 +216,8 @@ class RuleBasedPlanner { MatchContext match_ctx{matching.second, *context.symbol_table, context.bound_symbols}; new_input = PlanMatching(match_ctx, std::move(new_input)); new_input = std::make_unique(std::move(new_input), std::vector{matching.second.result_expr}); - pattern_comprehension_ops.emplace(matching.first, std::move(new_input)); + pattern_comprehension_ops.emplace( + matching.first, PatternComprehensionData(std::move(new_input), matching.second.result_symbol)); } for (const auto &clause : single_query_part.remaining_clauses) { @@ -875,9 +888,9 @@ class RuleBasedPlanner { symbol); } - std::unique_ptr HandleSubquery( - std::unique_ptr last_op, std::shared_ptr subquery, SymbolTable &symbol_table, - AstStorage &storage, std::unordered_map> pc_ops) { + std::unique_ptr HandleSubquery(std::unique_ptr last_op, + std::shared_ptr subquery, SymbolTable &symbol_table, + AstStorage &storage, PatternComprehensionDataMap &pc_ops) { std::unordered_set outer_scope_bound_symbols; outer_scope_bound_symbols.insert(std::make_move_iterator(context_->bound_symbols.begin()), std::make_move_iterator(context_->bound_symbols.end())); diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index 54ff4ed5c..da396ff56 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -57,6 +57,7 @@ M(ApplyOperator, Operator, "Number of times ApplyOperator operator was used.") \ M(IndexedJoinOperator, Operator, "Number of times IndexedJoin operator was used.") \ M(HashJoinOperator, Operator, "Number of times HashJoin operator was used.") \ + M(RollUpApplyOperator, Operator, "Number of times RollUpApply operator was used.") \ \ M(ActiveLabelIndices, Index, "Number of active label indices in the system.") \ M(ActiveLabelPropertyIndices, Index, "Number of active label property indices in the system.") \ diff --git a/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature b/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature index a6a4b15d2..9e3f06b34 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/list_operations.feature @@ -287,10 +287,20 @@ Feature: List operators MATCH (keanu:Person {name: 'Keanu Reeves'}) RETURN [(keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years """ - Then an error should be raised -# Then the result should be: -# | years | -# | [2021,2003,2003,1999] | + Then the result should be: + | years | + | [2003, 2003, 1999, 2021] | + + Scenario: List pattern comprehension and property + Given graph "graph_keanu" + When executing query: + """ + MATCH (keanu:Person {name: 'Keanu Reeves'}) + RETURN [(keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years, keanu.name + """ + Then the result should be: + | years | keanu.name | + | [2003, 2003, 1999, 2021] | "Keanu Reeves" | Scenario: Multiple entries with list pattern comprehension Given graph "graph_keanu" @@ -299,7 +309,11 @@ Feature: List operators MATCH (n:Person) RETURN n.name, [(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years """ - Then an error should be raised + Then the result should be: + | n.name | years | + | "Keanu Reeves" | [2003, 2003, 1999, 2021] | + | "Carrie-Anne Moss" | [2003,1999] | + | "Laurence Fishburne" | [1999] | Scenario: Multiple list pattern comprehensions in Return Given graph "graph_keanu"