diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index aaea154b9..089a28576 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -267,4 +267,7 @@ constexpr utils::TypeInfo query::TransactionQueueQuery::kType{utils::TypeId::AST "TransactionQueueQuery", &query::Query::kType}; constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType}; + +constexpr utils::TypeInfo query::CallSubquery::kType{utils::TypeId::AST_CALL_SUBQUERY, "CallSubquery", + &query::Clause::kType}; } // namespace memgraph diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 1332d117b..7d7e27278 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1616,6 +1616,8 @@ class NamedExpression : public memgraph::query::Tree, int32_t token_position_{-1}; /// Symbol table position of the symbol this NamedExpression is mapped to. int32_t symbol_pos_{-1}; + /// True if the variable is aliased + bool is_aliased_{false}; NamedExpression *Clone(AstStorage *storage) const override { NamedExpression *object = storage->Create(); @@ -1623,6 +1625,7 @@ class NamedExpression : public memgraph::query::Tree, object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; object->token_position_ = token_position_; object->symbol_pos_ = symbol_pos_; + object->is_aliased_ = is_aliased_; return object; } @@ -1973,7 +1976,7 @@ class Query : public memgraph::query::Tree, public utils::Visitable { public: static const utils::TypeInfo kType; const utils::TypeInfo &GetTypeInfo() const override { return kType; } @@ -1982,6 +1985,17 @@ class CypherQuery : public memgraph::query::Query { DEFVISITABLE(QueryVisitor); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + single_query_->Accept(visitor); + for (auto *cypher_union : cypher_unions_) { + cypher_union->Accept(visitor); + } + } + + return visitor.PostVisit(*this); + } + /// First and potentially only query. memgraph::query::SingleQuery *single_query_{nullptr}; /// Contains remaining queries that should form and union with `single_query_`. @@ -3289,5 +3303,31 @@ class Exists : public memgraph::query::Expression { friend class AstStorage; }; +class CallSubquery : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + CallSubquery() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + cypher_query_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::CypherQuery *cypher_query_; + + CallSubquery *Clone(AstStorage *storage) const override { + CallSubquery *object = storage->Create(); + object->cypher_query_ = cypher_query_ ? cypher_query_->Clone(storage) : nullptr; + return object; + } + + private: + friend class AstStorage; +}; + } // namespace query } // namespace memgraph diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index c6ed71bbc..63924f007 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -1237,7 +1237,9 @@ cpp<# (token-position :int32_t :initval -1 :scope :public :documentation "This field contains token position of first token in named expression used to create name_. If NamedExpression object is not created from query or it is aliased leave this value at -1.") (symbol-pos :int32_t :initval -1 :scope :public - :documentation "Symbol table position of the symbol this NamedExpression is mapped to.")) + :documentation "Symbol table position of the symbol this NamedExpression is mapped to.") + (is-aliased :bool :initval "false" :scope :public + :documentation "True if the variable is aliased")) (:public #>cpp using utils::Visitable>::Accept; @@ -1620,7 +1622,7 @@ cpp<# (:clone :ignore-other-base-classes t) (:type-info :ignore-other-base-classes t)) -(lcp:define-class cypher-query (query) +(lcp:define-class cypher-query (query "::utils::Visitable") ((single-query "SingleQuery *" :initval "nullptr" :scope :public :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "SingleQuery") @@ -1639,13 +1641,25 @@ cpp<# CypherQuery() = default; DEFVISITABLE(QueryVisitor); + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + single_query_->Accept(visitor); + for (auto *cypher_union : cypher_unions_) { + cypher_union->Accept(visitor); + } + } + + return visitor.PostVisit(*this); + } cpp<#) (:private #>cpp friend class AstStorage; cpp<#) - (:serialize (:slk)) - (:clone)) + (:serialize (:slk :ignore-other-base-classes t)) + (:clone :ignore-other-base-classes t) + (:type-info :ignore-other-base-classes t)) (lcp:define-class explain-query (query) ((cypher-query "CypherQuery *" :initval "nullptr" :scope :public @@ -2787,5 +2801,28 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class call-subquery (clause) + ((cypher-query "CypherQuery *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "CypherQuery"))) + (:public + #>cpp + CallSubquery() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + cypher_query_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 2b28008bc..39119f6ff 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -95,6 +95,7 @@ class SettingQuery; class VersionQuery; class Foreach; class ShowConfigQuery; +class CallSubquery; class AnalyzeGraphQuery; class TransactionQueueQuery; class Exists; @@ -106,7 +107,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor< ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, - RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach, Exists>; + RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach, Exists, CallSubquery, CypherQuery>; using TreeLeafVisitor = utils::LeafVisitor; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index ad92cd1cb..0d417e35d 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -989,6 +989,10 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon calls_write_procedure = true; has_update = true; } + } else if (const auto *call_subquery = utils::Downcast(clause); call_subquery != nullptr) { + if (has_return) { + throw SemanticException("CALL can't be put after RETURN clause."); + } } else if (utils::IsSubtype(clause_type, Unwind::kType)) { check_write_procedure("UNWIND"); if (has_update || has_return) { @@ -1101,6 +1105,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(MemgraphCypher::ClauseContext *ctx) if (ctx->foreach ()) { return static_cast(std::any_cast(ctx->foreach ()->accept(this))); } + if (ctx->callSubquery()) { + return static_cast(std::any_cast(ctx->callSubquery()->accept(this))); + } // TODO: implement other clauses. throw utils::NotYetImplemented("clause '{}'", ctx->getText()); return 0; @@ -1596,6 +1603,7 @@ antlrcpp::Any CypherMainVisitor::visitReturnItem(MemgraphCypher::ReturnItemConte named_expr->expression_ = std::any_cast(ctx->expression()->accept(this)); MG_ASSERT(named_expr->expression_); if (ctx->variable()) { + named_expr->is_aliased_ = true; named_expr->name_ = std::string(std::any_cast(ctx->variable()->accept(this))); users_identifiers.insert(named_expr->name_); } else { @@ -2565,6 +2573,20 @@ antlrcpp::Any CypherMainVisitor::visitShowConfigQuery(MemgraphCypher::ShowConfig return query_; } +antlrcpp::Any CypherMainVisitor::visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) { + auto *call_subquery = storage_->Create(); + + MG_ASSERT(ctx->cypherQuery(), "Expected query inside subquery clause"); + + if (ctx->cypherQuery()->queryMemoryLimit()) { + throw SyntaxException("Memory limit cannot be set on subqueries!"); + } + + call_subquery->cypher_query_ = std::any_cast(ctx->cypherQuery()->accept(this)); + + return call_subquery; +} + LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 6995adeea..7e1c17a1a 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -922,6 +922,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitShowConfigQuery(MemgraphCypher::ShowConfigQueryContext *ctx) override; + /** + * @return CallSubquery* + */ + antlrcpp::Any visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 5f16a99d3..04e6a7aba 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -176,6 +176,7 @@ clause : cypherMatch | callProcedure | loadCsv | foreach + | callSubquery ; updateClause : set @@ -188,6 +189,8 @@ updateClause : set foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ; +callSubquery : CALL '{' cypherQuery '}' ; + streamQuery : checkStream | createStream | dropStream diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index e80fa0ae7..d081419a7 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -36,6 +36,7 @@ BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ; BATCH_SIZE : B A T C H UNDERSCORE S I Z E ; BEFORE : B E F O R E ; BOOTSTRAP_SERVERS : B O O T S T R A P UNDERSCORE S E R V E R S ; +CALL : C A L L ; CHECK : C H E C K ; CLEAR : C L E A R ; COMMIT : C O M M I T ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index b7dd1dc3c..6e90bd3a3 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -31,9 +31,9 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis void Visit(AuthQuery & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::AUTH); } - void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); } + void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(dynamic_cast(*this)); } - void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); } + void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(dynamic_cast(*this)); } void Visit(InfoQuery &info_query) override { switch (info_query.info_type_) { diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index d5450392b..c92db6f2c 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -211,6 +211,33 @@ bool SymbolGenerator::PostVisit(CallProcedure &call_proc) { return true; } +bool SymbolGenerator::PreVisit(CallSubquery & /*call_sub*/) { + scopes_.emplace_back(Scope{.in_call_subquery = true}); + return true; +} + +bool SymbolGenerator::PostVisit(CallSubquery & /*call_sub*/) { + // no need to set the flag to true as we are popping the scope + auto subquery_scope = scopes_.back(); + scopes_.pop_back(); + auto &main_query_scope = scopes_.back(); + + if (!subquery_scope.has_return) { + return true; + } + + // append symbols returned in from subquery to outer scope + for (const auto &[symbol_name, symbol] : subquery_scope.symbols) { + if (main_query_scope.symbols.find(symbol_name) != main_query_scope.symbols.end()) { + throw SemanticException("Variable in subquery already declared in outer scope!"); + } + + main_query_scope.symbols[symbol_name] = symbol; + } + + return true; +} + bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; } bool SymbolGenerator::PostVisit(LoadCsv &load_csv) { @@ -224,6 +251,8 @@ bool SymbolGenerator::PostVisit(LoadCsv &load_csv) { bool SymbolGenerator::PreVisit(Return &ret) { auto &scope = scopes_.back(); scope.in_return = true; + scope.has_return = true; + VisitReturnBody(ret.body_); scope.in_return = false; return false; // We handled the traversal ourselves. @@ -470,6 +499,15 @@ bool SymbolGenerator::PostVisit(Exists & /*exists*/) { return true; } +bool SymbolGenerator::PreVisit(NamedExpression &named_expression) { + if (auto &scope = scopes_.back(); scope.in_call_subquery && scope.in_return && + !utils::Downcast(named_expression.expression_) && + !named_expression.is_aliased_) { + throw SemanticException("Expression returned from subquery must be aliased (use AS)!"); + } + return true; +} + bool SymbolGenerator::PreVisit(SetProperty & /*set_property*/) { auto &scope = scopes_.back(); scope.in_set_property = true; diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 556c6dab4..25b8dc648 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -50,6 +50,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool PostVisit(Create &) override; bool PreVisit(CallProcedure &) override; bool PostVisit(CallProcedure &) override; + bool PreVisit(CallSubquery & /*unused*/) override; + bool PostVisit(CallSubquery & /*unused*/) override; bool PreVisit(LoadCsv &) override; bool PostVisit(LoadCsv &) override; bool PreVisit(Return &) override; @@ -83,6 +85,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool PreVisit(Extract &) override; bool PreVisit(Exists & /*exists*/) override; bool PostVisit(Exists & /*exists*/) override; + bool PreVisit(NamedExpression & /*unused*/) override; // Pattern and its subparts. bool PreVisit(Pattern &) override; @@ -119,6 +122,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool in_foreach{false}; bool in_exists{false}; bool in_set_property{false}; + bool in_call_subquery{false}; + bool has_return{false}; // True when visiting a pattern atom (node or edge) identifier, which can be // reused or created in the pattern itself. bool in_pattern_atom_identifier{false}; diff --git a/src/query/plan/cost_estimator.hpp b/src/query/plan/cost_estimator.hpp index 0b3e7f867..f9b71b0d8 100644 --- a/src/query/plan/cost_estimator.hpp +++ b/src/query/plan/cost_estimator.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -62,6 +62,8 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { static constexpr double kEdgeUniquenessFilter{1.5}; static constexpr double kUnwind{1.3}; static constexpr double kForeach{1.0}; + static constexpr double kUnion{1.0}; + static constexpr double kSubquery{1.0}; }; struct CardParam { @@ -212,6 +214,31 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(Union &op) override { + double left_cost = EstimateCostOnBranch(&op.left_op_); + double right_cost = EstimateCostOnBranch(&op.right_op_); + + // the number of hits in the previous operator should be the joined number of results of both parts of the union + cardinality_ *= (left_cost + right_cost); + IncrementCost(CostParam::kUnion); + + return false; + } + + bool PreVisit(Apply &op) override { + double input_cost = EstimateCostOnBranch(&op.input_); + double subquery_cost = EstimateCostOnBranch(&op.subquery_); + + // if the query is a unit subquery, we don't want the cost to be zero but 1xN + input_cost = input_cost == 0 ? 1 : input_cost; + subquery_cost = subquery_cost == 0 ? 1 : subquery_cost; + + cardinality_ *= input_cost * subquery_cost; + IncrementCost(CostParam::kSubquery); + + return false; + } + bool Visit(Once &) override { return true; } auto cost() const { return cost_; } @@ -232,6 +259,12 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { void IncrementCost(double param) { cost_ += param * cardinality_; } + double EstimateCostOnBranch(std::shared_ptr *branch) { + CostEstimator cost_estimator(db_accessor_, parameters); + (*branch)->Accept(cost_estimator); + return cost_estimator.cost(); + } + // converts an optional ScanAll range bound into a property value // if the bound is present and is a constant expression convertible to // a property value. otherwise returns nullopt diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index f9e8b7b45..2203f36aa 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -117,6 +117,7 @@ extern const Event CallProcedureOperator; extern const Event ForeachOperator; extern const Event EmptyResultOperator; extern const Event EvaluatePatternFilterOperator; +extern const Event ApplyOperator; } // namespace EventCounter namespace memgraph::query::plan { @@ -4070,8 +4071,14 @@ class DistinctCursor : public Cursor { utils::pmr::vector row(seen_rows_.get_allocator().GetMemoryResource()); row.reserve(self_.value_symbols_.size()); - for (const auto &symbol : self_.value_symbols_) row.emplace_back(frame[symbol]); - if (seen_rows_.insert(std::move(row)).second) return true; + + for (const auto &symbol : self_.value_symbols_) { + row.emplace_back(frame.at(symbol)); + } + + if (seen_rows_.insert(std::move(row)).second) { + return true; + } } } @@ -4796,4 +4803,72 @@ bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) { return visitor.PostVisit(*this); } +Apply::Apply(const std::shared_ptr input, const std::shared_ptr subquery, + bool subquery_has_return) + : input_(input ? input : std::make_shared()), + subquery_(subquery), + subquery_has_return_(subquery_has_return) {} + +bool Apply::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + input_->Accept(visitor) && subquery_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + +UniqueCursorPtr Apply::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ApplyOperator); + + return MakeUniqueCursorPtr(mem, *this, mem); +} + +Apply::ApplyCursor::ApplyCursor(const Apply &self, utils::MemoryResource *mem) + : self_(self), + input_(self.input_->MakeCursor(mem)), + subquery_(self.subquery_->MakeCursor(mem)), + subquery_has_return_(self.subquery_has_return_) {} + +std::vector Apply::ModifiedSymbols(const SymbolTable &table) const { + // Since Apply is the Cartesian product, modified symbols are combined from + // both execution branches. + auto symbols = input_->ModifiedSymbols(table); + auto subquery_symbols = subquery_->ModifiedSymbols(table); + symbols.insert(symbols.end(), subquery_symbols.begin(), subquery_symbols.end()); + return symbols; +} + +bool Apply::ApplyCursor::Pull(Frame &frame, ExecutionContext &context) { + SCOPED_PROFILE_OP("Apply"); + + while (true) { + if (pull_input_ && !input_->Pull(frame, context)) { + return false; + }; + + if (subquery_->Pull(frame, context)) { + // if successful, next Pull from this should not pull_input_ + pull_input_ = false; + return true; + } + // failed to pull from subquery cursor + // skip that row + pull_input_ = true; + subquery_->Reset(); + + // don't skip row if no rows are returned from subquery, return input_ rows + if (!subquery_has_return_) return true; + } +} + +void Apply::ApplyCursor::Shutdown() { + input_->Shutdown(); + subquery_->Shutdown(); +} + +void Apply::ApplyCursor::Reset() { + input_->Reset(); + subquery_->Reset(); + pull_input_ = true; +} + } // namespace memgraph::query::plan diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 08ced64aa..9c0a0c831 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -127,6 +127,7 @@ class LoadCsv; class Foreach; class EmptyResult; class EvaluatePatternFilter; +class Apply; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor; + Foreach, EmptyResult, EvaluatePatternFilter, Apply>; using LogicalOperatorLeafVisitor = utils::LeafVisitor; @@ -1087,8 +1088,8 @@ class Produce : public memgraph::query::plan::LogicalOperator { auto object = std::make_unique(); object->input_ = input_ ? input_->Clone(storage) : nullptr; object->named_expressions_.resize(named_expressions_.size()); - for (auto i1 = 0; i1 < named_expressions_.size(); ++i1) { - object->named_expressions_[i1] = named_expressions_[i1] ? named_expressions_[i1]->Clone(storage) : nullptr; + for (auto i2 = 0; i2 < named_expressions_.size(); ++i2) { + object->named_expressions_[i2] = named_expressions_[i2] ? named_expressions_[i2]->Clone(storage) : nullptr; } return object; } @@ -1136,8 +1137,8 @@ class Delete : public memgraph::query::plan::LogicalOperator { auto object = std::make_unique(); object->input_ = input_ ? input_->Clone(storage) : nullptr; object->expressions_.resize(expressions_.size()); - for (auto i2 = 0; i2 < expressions_.size(); ++i2) { - object->expressions_[i2] = expressions_[i2] ? expressions_[i2]->Clone(storage) : nullptr; + for (auto i3 = 0; i3 < expressions_.size(); ++i3) { + object->expressions_[i3] = expressions_[i3] ? expressions_[i3]->Clone(storage) : nullptr; } object->detach_ = detach_; return object; @@ -1608,12 +1609,12 @@ class Aggregate : public memgraph::query::plan::LogicalOperator { auto object = std::make_unique(); object->input_ = input_ ? input_->Clone(storage) : nullptr; object->aggregations_.resize(aggregations_.size()); - for (auto i3 = 0; i3 < aggregations_.size(); ++i3) { - object->aggregations_[i3] = aggregations_[i3].Clone(storage); + for (auto i4 = 0; i4 < aggregations_.size(); ++i4) { + object->aggregations_[i4] = aggregations_[i4].Clone(storage); } object->group_by_.resize(group_by_.size()); - for (auto i4 = 0; i4 < group_by_.size(); ++i4) { - object->group_by_[i4] = group_by_[i4] ? group_by_[i4]->Clone(storage) : nullptr; + for (auto i5 = 0; i5 < group_by_.size(); ++i5) { + object->group_by_[i5] = group_by_[i5] ? group_by_[i5]->Clone(storage) : nullptr; } object->remember_ = remember_; return object; @@ -1814,8 +1815,8 @@ class OrderBy : public memgraph::query::plan::LogicalOperator { object->input_ = input_ ? input_->Clone(storage) : nullptr; object->compare_ = compare_; object->order_by_.resize(order_by_.size()); - for (auto i5 = 0; i5 < order_by_.size(); ++i5) { - object->order_by_[i5] = order_by_[i5] ? order_by_[i5]->Clone(storage) : nullptr; + for (auto i6 = 0; i6 < order_by_.size(); ++i6) { + object->order_by_[i6] = order_by_[i6] ? order_by_[i6]->Clone(storage) : nullptr; } object->output_symbols_ = output_symbols_; return object; @@ -2204,8 +2205,8 @@ class CallProcedure : public memgraph::query::plan::LogicalOperator { object->input_ = input_ ? input_->Clone(storage) : nullptr; object->procedure_name_ = procedure_name_; object->arguments_.resize(arguments_.size()); - for (auto i6 = 0; i6 < arguments_.size(); ++i6) { - object->arguments_[i6] = arguments_[i6] ? arguments_[i6]->Clone(storage) : nullptr; + for (auto i7 = 0; i7 < arguments_.size(); ++i7) { + object->arguments_[i7] = arguments_[i7] ? arguments_[i7]->Clone(storage) : nullptr; } object->result_fields_ = result_fields_; object->result_symbols_ = result_symbols_; @@ -2291,6 +2292,53 @@ class Foreach : public memgraph::query::plan::LogicalOperator { } }; +/// Applies symbols from both output branches. +class Apply : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Apply() {} + + Apply(const std::shared_ptr input, const std::shared_ptr subquery, + bool subquery_has_return); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr input() const override { return input_; } + void set_input(std::shared_ptr input) override { input_ = input; } + + std::shared_ptr input_; + std::shared_ptr subquery_; + bool subquery_has_return_; + + std::unique_ptr Clone(AstStorage *storage) const override { + auto object = std::make_unique(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->subquery_ = subquery_ ? subquery_->Clone(storage) : nullptr; + object->subquery_has_return_ = subquery_has_return_; + return object; + } + + private: + class ApplyCursor : public Cursor { + public: + ApplyCursor(const Apply &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Apply &self_; + UniqueCursorPtr input_; + UniqueCursorPtr subquery_; + bool pull_input_{true}; + bool subquery_has_return_{true}; + }; +}; + } // namespace plan } // namespace query } // namespace memgraph diff --git a/src/query/plan/operator.lcp b/src/query/plan/operator.lcp index 24e89b82e..e57560347 100644 --- a/src/query/plan/operator.lcp +++ b/src/query/plan/operator.lcp @@ -134,6 +134,7 @@ class LoadCsv; class Foreach; class EmptyResult; class EvaluatePatternFilter; +class Apply; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, @@ -142,7 +143,8 @@ using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, - Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach, EmptyResult, EvaluatePatternFilter>; + Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach, EmptyResult, + EvaluatePatternFilter, Apply>; using LogicalOperatorLeafVisitor = utils::LeafVisitor; @@ -2381,6 +2383,52 @@ clauses. (:serialize (:slk)) (:clone)) +(lcp:define-class apply (logical-operator) + ((input "std::shared_ptr" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (subquery "std::shared_ptr" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (subquery-has-return "bool" :scope :public)) + + (:documentation "Applies symbols from both output branches.") + + (:public + #>cpp + Apply() {} + + Apply(const std::shared_ptr input, const std::shared_ptr subquery, bool subquery_has_return); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr input() const override { return input_; } + void set_input(std::shared_ptr input) override { + input_ = input; + } + cpp<#) + (:private + #>cpp + class ApplyCursor : public Cursor { + public: + ApplyCursor(const Apply &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Apply &self_; + UniqueCursorPtr input_; + UniqueCursorPtr subquery_; + bool pull_input_{true}; + bool subquery_has_return_{true}; + }; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; plan (lcp:pop-namespace) ;; query (lcp:pop-namespace) ;; memgraph diff --git a/src/query/plan/operator_type_info.cpp b/src/query/plan/operator_type_info.cpp index c20dd1e77..efedc9b04 100644 --- a/src/query/plan/operator_type_info.cpp +++ b/src/query/plan/operator_type_info.cpp @@ -145,4 +145,7 @@ constexpr utils::TypeInfo query::plan::LoadCsv::kType{utils::TypeId::LOAD_CSV, " constexpr utils::TypeInfo query::plan::Foreach::kType{utils::TypeId::FOREACH, "Foreach", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Apply::kType{utils::TypeId::APPLY, "Apply", + &query::plan::LogicalOperator::kType}; } // namespace memgraph diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp index 85ee16fca..cfc89ec73 100644 --- a/src/query/plan/planner.hpp +++ b/src/query/plan/planner.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -50,23 +50,6 @@ class PostProcessor final { double EstimatePlanCost(const std::unique_ptr &plan, TVertexCounts *vertex_counts) { return query::plan::EstimatePlanCost(vertex_counts, parameters_, *plan); } - - template - std::unique_ptr MergeWithCombinator(std::unique_ptr curr_op, - std::unique_ptr last_op, const Tree &combinator, - TPlanningContext *context) { - if (const auto *union_ = utils::Downcast(&combinator)) { - return std::unique_ptr( - impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table)); - } - throw utils::NotYetImplemented("query combinator"); - } - - template - std::unique_ptr MakeDistinct(std::unique_ptr last_op, TPlanningContext *context) { - auto output_symbols = last_op->OutputSymbols(*context->symbol_table); - return std::make_unique(std::move(last_op), output_symbols); - } }; /// @brief Generates the LogicalOperator tree for a single query and returns the @@ -82,10 +65,9 @@ class PostProcessor final { /// @sa RuleBasedPlanner /// @sa VariableStartPlanner template