diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index e734912da..6e959fc95 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2629,5 +2629,39 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class foreach (clause) + ((named_expression "NamedExpression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (clauses "std::vector<Clause *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Clause"))) + (:public + #>cpp + Foreach() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + named_expression_->Accept(visitor); + for (auto &clause : clauses_) { + clause->Accept(visitor); + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Foreach(NamedExpression *expression, std::vector<Clause *> clauses) + : named_expression_(expression), clauses_(clauses) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query (lcp:pop-namespace) ;; namespace memgraph diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index c5182479c..0e4a6012c 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -93,6 +93,7 @@ class CreateSnapshotQuery; class StreamQuery; class SettingQuery; class VersionQuery; +class Foreach; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -101,7 +102,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor< ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, - RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv>; + RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach>; using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 9e7f3c5d4..2544ea663 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -37,6 +37,7 @@ #include "utils/exceptions.hpp" #include "utils/logging.hpp" #include "utils/string.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::query::frontend { @@ -956,7 +957,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon utils::IsSubtype(clause_type, SetProperty::kType) || utils::IsSubtype(clause_type, SetProperties::kType) || utils::IsSubtype(clause_type, SetLabels::kType) || utils::IsSubtype(clause_type, RemoveProperty::kType) || - utils::IsSubtype(clause_type, RemoveLabels::kType) || utils::IsSubtype(clause_type, Merge::kType)) { + utils::IsSubtype(clause_type, RemoveLabels::kType) || utils::IsSubtype(clause_type, Merge::kType) || + utils::IsSubtype(clause_type, Foreach::kType)) { if (has_return) { throw SemanticException("Update clause can't be used after RETURN."); } @@ -1036,6 +1038,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(MemgraphCypher::ClauseContext *ctx) if (ctx->loadCsv()) { return static_cast<Clause *>(ctx->loadCsv()->accept(this).as<LoadCsv *>()); } + if (ctx->foreach ()) { + return static_cast<Clause *>(ctx->foreach ()->accept(this).as<Foreach *>()); + } // TODO: implement other clauses. throw utils::NotYetImplemented("clause '{}'", ctx->getText()); return 0; @@ -2283,6 +2288,37 @@ antlrcpp::Any CypherMainVisitor::visitFilterExpression(MemgraphCypher::FilterExp return 0; } +antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ctx) { + auto *for_each = storage_->Create<Foreach>(); + + auto *named_expr = storage_->Create<NamedExpression>(); + named_expr->expression_ = ctx->expression()->accept(this); + named_expr->name_ = std::string(ctx->variable()->accept(this).as<std::string>()); + for_each->named_expression_ = named_expr; + + for (auto *update_clause_ctx : ctx->updateClause()) { + if (auto *set = update_clause_ctx->set(); set) { + auto set_items = visitSet(set).as<std::vector<Clause *>>(); + std::copy(set_items.begin(), set_items.end(), std::back_inserter(for_each->clauses_)); + } else if (auto *remove = update_clause_ctx->remove(); remove) { + auto remove_items = visitRemove(remove).as<std::vector<Clause *>>(); + std::copy(remove_items.begin(), remove_items.end(), std::back_inserter(for_each->clauses_)); + } else if (auto *merge = update_clause_ctx->merge(); merge) { + for_each->clauses_.push_back(visitMerge(merge).as<Merge *>()); + } else if (auto *create = update_clause_ctx->create(); create) { + for_each->clauses_.push_back(visitCreate(create).as<Create *>()); + } else if (auto *cypher_delete = update_clause_ctx->cypherDelete(); cypher_delete) { + for_each->clauses_.push_back(visitCypherDelete(cypher_delete).as<Delete *>()); + } else { + auto *nested_for_each = update_clause_ctx->foreach (); + MG_ASSERT(nested_for_each != nullptr, "Unexpected clause in FOREACH"); + for_each->clauses_.push_back(visitForeach(nested_for_each).as<Foreach *>()); + } + } + + return for_each; +} + LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); } PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 4d01e7773..2a6b8ff5e 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -844,6 +844,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override; + /** + * @return Foreach* + */ + antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index e4c13c342..76fd3038a 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -47,6 +47,7 @@ memgraphCypherKeyword : cypherKeyword | DUMP | EXECUTE | FOR + | FOREACH | FREE | FROM | GLOBAL @@ -163,8 +164,19 @@ clause : cypherMatch | cypherReturn | callProcedure | loadCsv + | foreach ; +updateClause : set + | remove + | create + | merge + | cypherDelete + | foreach + ; + +foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ; + streamQuery : checkStream | createStream | dropStream diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 7e95c295f..55e5d53a2 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -54,6 +54,7 @@ DUMP : D U M P ; DURABILITY : D U R A B I L I T Y ; EXECUTE : E X E C U T E ; FOR : F O R ; +FOREACH : F O R E A C H; FREE : F R E E ; FREE_MEMORY : F R E E UNDERSCORE M E M O R Y ; FROM : F R O M ; diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 7c9fa97ca..0df1f3771 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -15,7 +15,9 @@ #include "query/frontend/semantic/symbol_generator.hpp" +#include <algorithm> #include <optional> +#include <ranges> #include <unordered_set> #include <variant> @@ -39,35 +41,56 @@ std::unordered_map<std::string, Identifier *> GeneratePredefinedIdentifierMap( } // namespace SymbolGenerator::SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers) - : symbol_table_(symbol_table), predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)} {} + : symbol_table_(symbol_table), + predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)}, + scopes_(1, Scope()) {} -auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) { - auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position); - scope_.symbols[name] = symbol; - return symbol; -} - -auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) { - auto search = scope_.symbols.find(name); - if (search != scope_.symbols.end()) { - auto symbol = search->second; +std::optional<Symbol> SymbolGenerator::FindSymbolInScope(const std::string &name, const Scope &scope, + Symbol::Type type) { + if (auto it = scope.symbols.find(name); it != scope.symbols.end()) { + const auto &symbol = it->second; // Unless we have `ANY` type, check that types match. if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) { throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type)); } - return search->second; + return symbol; + } + return std::nullopt; +} + +auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) { + auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position); + scopes_.back().symbols[name] = symbol; + return symbol; +} + +auto SymbolGenerator::GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type) { + auto &scope = scopes_.back(); + if (auto maybe_symbol = FindSymbolInScope(name, scope, type); maybe_symbol) { + return *maybe_symbol; + } + return CreateSymbol(name, user_declared, type); +} + +auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) { + // NOLINTNEXTLINE + for (auto scope = scopes_.rbegin(); scope != scopes_.rend(); ++scope) { + if (auto maybe_symbol = FindSymbolInScope(name, *scope, type); maybe_symbol) { + return *maybe_symbol; + } } return CreateSymbol(name, user_declared, type); } void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { + auto &scope = scopes_.back(); for (auto &expr : body.named_expressions) { expr->Accept(*this); } std::vector<Symbol> user_symbols; if (body.all_identifiers) { // Carry over user symbols because '*' appeared. - for (auto sym_pair : scope_.symbols) { + for (const auto &sym_pair : scope.symbols) { if (!sym_pair.second.user_declared()) { continue; } @@ -81,18 +104,18 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { // declares only those established through named expressions. New declarations // must not be visible inside named expressions themselves. bool removed_old_names = false; - if ((!where && body.order_by.empty()) || scope_.has_aggregation) { + if ((!where && body.order_by.empty()) || scope.has_aggregation) { // WHERE and ORDER BY need to see both the old and new symbols, unless we // have an aggregation. Therefore, we can clear the symbols immediately if // there is neither ORDER BY nor WHERE, or we have an aggregation. - scope_.symbols.clear(); + scope.symbols.clear(); removed_old_names = true; } // Create symbols for named expressions. std::unordered_set<std::string> new_names; for (const auto &user_sym : user_symbols) { new_names.insert(user_sym.name()); - scope_.symbols[user_sym.name()] = user_sym; + scope.symbols[user_sym.name()] = user_sym; } for (auto &named_expr : body.named_expressions) { const auto &name = named_expr->name_; @@ -103,35 +126,35 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { // new symbol would have a more specific type. named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_)); } - scope_.in_order_by = true; + scope.in_order_by = true; for (const auto &order_pair : body.order_by) { order_pair.expression->Accept(*this); } - scope_.in_order_by = false; + scope.in_order_by = false; if (body.skip) { - scope_.in_skip = true; + scope.in_skip = true; body.skip->Accept(*this); - scope_.in_skip = false; + scope.in_skip = false; } if (body.limit) { - scope_.in_limit = true; + scope.in_limit = true; body.limit->Accept(*this); - scope_.in_limit = false; + scope.in_limit = false; } if (where) where->Accept(*this); if (!removed_old_names) { // We have an ORDER BY or WHERE, but no aggregation, which means we didn't // clear the old symbols, so do it now. We cannot just call clear, because // we've added new symbols. - for (auto sym_it = scope_.symbols.begin(); sym_it != scope_.symbols.end();) { + for (auto sym_it = scope.symbols.begin(); sym_it != scope.symbols.end();) { if (new_names.find(sym_it->first) == new_names.end()) { - sym_it = scope_.symbols.erase(sym_it); + sym_it = scope.symbols.erase(sym_it); } else { sym_it++; } } } - scope_.has_aggregation = false; + scopes_.back().has_aggregation = false; } // Query @@ -145,7 +168,7 @@ bool SymbolGenerator::PreVisit(SingleQuery &) { // Union bool SymbolGenerator::PreVisit(CypherUnion &) { - scope_ = Scope(); + scopes_.back() = Scope(); return true; } @@ -166,11 +189,11 @@ bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) { // Clauses bool SymbolGenerator::PreVisit(Create &) { - scope_.in_create = true; + scopes_.back().in_create = true; return true; } bool SymbolGenerator::PostVisit(Create &) { - scope_.in_create = false; + scopes_.back().in_create = false; return true; } @@ -183,7 +206,7 @@ bool SymbolGenerator::PreVisit(CallProcedure &call_proc) { bool SymbolGenerator::PostVisit(CallProcedure &call_proc) { for (auto *ident : call_proc.result_identifiers_) { - if (HasSymbol(ident->name_)) { + if (HasSymbolLocalScope(ident->name_)) { throw RedeclareVariableError(ident->name_); } ident->MapTo(CreateSymbol(ident->name_, true)); @@ -194,7 +217,7 @@ bool SymbolGenerator::PostVisit(CallProcedure &call_proc) { bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; } bool SymbolGenerator::PostVisit(LoadCsv &load_csv) { - if (HasSymbol(load_csv.row_var_->name_)) { + if (HasSymbolLocalScope(load_csv.row_var_->name_)) { throw RedeclareVariableError(load_csv.row_var_->name_); } load_csv.row_var_->MapTo(CreateSymbol(load_csv.row_var_->name_, true)); @@ -202,45 +225,47 @@ bool SymbolGenerator::PostVisit(LoadCsv &load_csv) { } bool SymbolGenerator::PreVisit(Return &ret) { - scope_.in_return = true; + auto &scope = scopes_.back(); + scope.in_return = true; VisitReturnBody(ret.body_); - scope_.in_return = false; + scope.in_return = false; return false; // We handled the traversal ourselves. } bool SymbolGenerator::PostVisit(Return &) { - for (const auto &name_symbol : scope_.symbols) curr_return_names_.insert(name_symbol.first); + for (const auto &name_symbol : scopes_.back().symbols) curr_return_names_.insert(name_symbol.first); return true; } bool SymbolGenerator::PreVisit(With &with) { - scope_.in_with = true; + auto &scope = scopes_.back(); + scope.in_with = true; VisitReturnBody(with.body_, with.where_); - scope_.in_with = false; + scope.in_with = false; return false; // We handled the traversal ourselves. } bool SymbolGenerator::PreVisit(Where &) { - scope_.in_where = true; + scopes_.back().in_where = true; return true; } bool SymbolGenerator::PostVisit(Where &) { - scope_.in_where = false; + scopes_.back().in_where = false; return true; } bool SymbolGenerator::PreVisit(Merge &) { - scope_.in_merge = true; + scopes_.back().in_merge = true; return true; } bool SymbolGenerator::PostVisit(Merge &) { - scope_.in_merge = false; + scopes_.back().in_merge = false; return true; } bool SymbolGenerator::PostVisit(Unwind &unwind) { const auto &name = unwind.named_expression_->name_; - if (HasSymbol(name)) { + if (HasSymbolLocalScope(name)) { throw RedeclareVariableError(name); } unwind.named_expression_->MapTo(CreateSymbol(name, true)); @@ -248,55 +273,70 @@ bool SymbolGenerator::PostVisit(Unwind &unwind) { } bool SymbolGenerator::PreVisit(Match &) { - scope_.in_match = true; + scopes_.back().in_match = true; return true; } bool SymbolGenerator::PostVisit(Match &) { - scope_.in_match = false; + auto &scope = scopes_.back(); + scope.in_match = false; // Check variables in property maps after visiting Match, so that they can // reference symbols out of bind order. - for (auto &ident : scope_.identifiers_in_match) { - if (!HasSymbol(ident->name_) && !ConsumePredefinedIdentifier(ident->name_)) + for (auto &ident : scope.identifiers_in_match) { + if (!HasSymbolLocalScope(ident->name_) && !ConsumePredefinedIdentifier(ident->name_)) throw UnboundVariableError(ident->name_); - ident->MapTo(scope_.symbols[ident->name_]); + ident->MapTo(scope.symbols[ident->name_]); } - scope_.identifiers_in_match.clear(); + scope.identifiers_in_match.clear(); + return true; +} + +bool SymbolGenerator::PreVisit(Foreach &for_each) { + const auto &name = for_each.named_expression_->name_; + scopes_.emplace_back(Scope()); + scopes_.back().in_foreach = true; + for_each.named_expression_->MapTo( + CreateSymbol(name, true, Symbol::Type::ANY, for_each.named_expression_->token_position_)); + return true; +} +bool SymbolGenerator::PostVisit([[maybe_unused]] Foreach &for_each) { + scopes_.pop_back(); return true; } // Expressions SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) { - if (scope_.in_skip || scope_.in_limit) { - throw SemanticException("Variables are not allowed in {}.", scope_.in_skip ? "SKIP" : "LIMIT"); + auto &scope = scopes_.back(); + if (scope.in_skip || scope.in_limit) { + throw SemanticException("Variables are not allowed in {}.", scope.in_skip ? "SKIP" : "LIMIT"); } Symbol symbol; - if (scope_.in_pattern && !(scope_.in_node_atom || scope_.visiting_edge)) { + if (scope.in_pattern && !(scope.in_node_atom || scope.visiting_edge)) { // If we are in the pattern, and outside of a node or an edge, the // identifier is the pattern name. - symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::PATH); - } else if (scope_.in_pattern && scope_.in_pattern_atom_identifier) { + symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, Symbol::Type::PATH); + } else if (scope.in_pattern && scope.in_pattern_atom_identifier) { // Patterns used to create nodes and edges cannot redeclare already // established bindings. Declaration only happens in single node // patterns and in edge patterns. OpenCypher example, // `MATCH (n) CREATE (n)` should throw an error that `n` is already // declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed, // since `n` now references the bound node instead of declaring it. - if ((scope_.in_create_node || scope_.in_create_edge) && HasSymbol(ident.name_)) { + if ((scope.in_create_node || scope.in_create_edge) && HasSymbolLocalScope(ident.name_)) { throw RedeclareVariableError(ident.name_); } auto type = Symbol::Type::VERTEX; - if (scope_.visiting_edge) { + if (scope.visiting_edge) { // Edge referencing is not allowed (like in Neo4j): // `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r` is not allowed. - if (HasSymbol(ident.name_)) { + if (HasSymbolLocalScope(ident.name_)) { throw RedeclareVariableError(ident.name_); } - type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE; + type = scope.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE; } - symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type); - } else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier && scope_.in_match) { - if (scope_.in_edge_range && scope_.visiting_edge->identifier_->name_ == ident.name_) { + symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, type); + } else if (scope.in_pattern && !scope.in_pattern_atom_identifier && scope.in_match) { + if (scope.in_edge_range && scope.visiting_edge->identifier_->name_ == ident.name_) { // Prevent variable path bounds to reference the identifier which is bound // by the variable path itself. throw UnboundVariableError(ident.name_); @@ -304,30 +344,30 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) { // Variables in property maps or bounds of variable length path during MATCH // can reference symbols bound later in the same MATCH. We collect them // here, so that they can be checked after visiting Match. - scope_.identifiers_in_match.emplace_back(&ident); + scope.identifiers_in_match.emplace_back(&ident); } else { // Everything else references a bound symbol. if (!HasSymbol(ident.name_) && !ConsumePredefinedIdentifier(ident.name_)) throw UnboundVariableError(ident.name_); - symbol = scope_.symbols[ident.name_]; + symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::ANY); } ident.MapTo(symbol); return true; } bool SymbolGenerator::PreVisit(Aggregation &aggr) { + auto &scope = scopes_.back(); // Check if the aggregation can be used in this context. This check should // probably move to a separate phase, which checks if the query is well // formed. - if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by || scope_.in_skip || scope_.in_limit || - scope_.in_where) { + if ((!scope.in_return && !scope.in_with) || scope.in_order_by || scope.in_skip || scope.in_limit || scope.in_where) { throw SemanticException("Aggregation functions are only allowed in WITH and RETURN."); } - if (scope_.in_aggregation) { + if (scope.in_aggregation) { throw SemanticException( "Using aggregation functions inside aggregation functions is not " "allowed."); } - if (scope_.num_if_operators) { + if (scope.num_if_operators) { // Neo allows aggregations here and produces very interesting behaviors. // To simplify implementation at this moment we decided to completely // disallow aggregations inside of the CASE. @@ -341,23 +381,23 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) { // Currently, we only have aggregation operators which return numbers. auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_); aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER)); - scope_.in_aggregation = true; - scope_.has_aggregation = true; + scope.in_aggregation = true; + scope.has_aggregation = true; return true; } bool SymbolGenerator::PostVisit(Aggregation &) { - scope_.in_aggregation = false; + scopes_.back().in_aggregation = false; return true; } bool SymbolGenerator::PreVisit(IfOperator &) { - ++scope_.num_if_operators; + ++scopes_.back().num_if_operators; return true; } bool SymbolGenerator::PostVisit(IfOperator &) { - --scope_.num_if_operators; + --scopes_.back().num_if_operators; return true; } @@ -401,33 +441,36 @@ bool SymbolGenerator::PreVisit(Extract &extract) { // Pattern and its subparts. bool SymbolGenerator::PreVisit(Pattern &pattern) { - scope_.in_pattern = true; - if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) { + auto &scope = scopes_.back(); + scope.in_pattern = true; + if ((scope.in_create || scope.in_merge) && pattern.atoms_.size() == 1U) { MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern"); - scope_.in_create_node = true; + scope.in_create_node = true; } return true; } bool SymbolGenerator::PostVisit(Pattern &) { - scope_.in_pattern = false; - scope_.in_create_node = false; + auto &scope = scopes_.back(); + scope.in_pattern = false; + scope.in_create_node = false; return true; } bool SymbolGenerator::PreVisit(NodeAtom &node_atom) { - auto check_node_semantic = [&node_atom, this](const bool props_or_labels) { + auto &scope = scopes_.back(); + auto check_node_semantic = [&node_atom, &scope, this](const bool props_or_labels) { const auto &node_name = node_atom.identifier_->name_; - if ((scope_.in_create || scope_.in_merge) && props_or_labels && HasSymbol(node_name)) { + if ((scope.in_create || scope.in_merge) && props_or_labels && HasSymbolLocalScope(node_name)) { throw SemanticException("Cannot create node '" + node_name + "' with labels or properties, because it is already declared."); } - scope_.in_pattern_atom_identifier = true; + scope.in_pattern_atom_identifier = true; node_atom.identifier_->Accept(*this); - scope_.in_pattern_atom_identifier = false; + scope.in_pattern_atom_identifier = false; }; - scope_.in_node_atom = true; + scope.in_node_atom = true; if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node_atom.properties_)) { bool props_or_labels = !properties->empty() || !node_atom.labels_.empty(); @@ -447,20 +490,21 @@ bool SymbolGenerator::PreVisit(NodeAtom &node_atom) { } bool SymbolGenerator::PostVisit(NodeAtom &) { - scope_.in_node_atom = false; + scopes_.back().in_node_atom = false; return true; } bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { - scope_.visiting_edge = &edge_atom; - if (scope_.in_create || scope_.in_merge) { - scope_.in_create_edge = true; + auto &scope = scopes_.back(); + scope.visiting_edge = &edge_atom; + if (scope.in_create || scope.in_merge) { + scope.in_create_edge = true; if (edge_atom.edge_types_.size() != 1U) { throw SemanticException( "A single relationship type must be specified " "when creating an edge."); } - if (scope_.in_create && // Merge allows bidirectionality + if (scope.in_create && // Merge allows bidirectionality edge_atom.direction_ == EdgeAtom::Direction::BOTH) { throw SemanticException( "Bidirectional relationship are not supported " @@ -480,15 +524,15 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { std::get<ParameterLookup *>(edge_atom.properties_)->Accept(*this); } if (edge_atom.IsVariable()) { - scope_.in_edge_range = true; + scope.in_edge_range = true; if (edge_atom.lower_bound_) { edge_atom.lower_bound_->Accept(*this); } if (edge_atom.upper_bound_) { edge_atom.upper_bound_->Accept(*this); } - scope_.in_edge_range = false; - scope_.in_pattern = false; + scope.in_edge_range = false; + scope.in_pattern = false; if (edge_atom.filter_lambda_.expression) { VisitWithIdentifiers(edge_atom.filter_lambda_.expression, {edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node}); @@ -505,34 +549,36 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { VisitWithIdentifiers(edge_atom.weight_lambda_.expression, {edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node}); } - scope_.in_pattern = true; + scope.in_pattern = true; } - scope_.in_pattern_atom_identifier = true; + scope.in_pattern_atom_identifier = true; edge_atom.identifier_->Accept(*this); - scope_.in_pattern_atom_identifier = false; + scope.in_pattern_atom_identifier = false; if (edge_atom.total_weight_) { - if (HasSymbol(edge_atom.total_weight_->name_)) { + if (HasSymbolLocalScope(edge_atom.total_weight_->name_)) { throw RedeclareVariableError(edge_atom.total_weight_->name_); } - edge_atom.total_weight_->MapTo(GetOrCreateSymbol(edge_atom.total_weight_->name_, - edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER)); + edge_atom.total_weight_->MapTo(GetOrCreateSymbolLocalScope( + edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER)); } return false; } bool SymbolGenerator::PostVisit(EdgeAtom &) { - scope_.visiting_edge = nullptr; - scope_.in_create_edge = false; + auto &scope = scopes_.back(); + scope.visiting_edge = nullptr; + scope.in_create_edge = false; return true; } void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) { + auto &scope = scopes_.back(); std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols; // Collect previous symbols if they exist. for (const auto &identifier : identifiers) { std::optional<Symbol> prev_symbol; - auto prev_symbol_it = scope_.symbols.find(identifier->name_); - if (prev_symbol_it != scope_.symbols.end()) { + auto prev_symbol_it = scope.symbols.find(identifier->name_); + if (prev_symbol_it != scope.symbols.end()) { prev_symbol = prev_symbol_it->second; } identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_)); @@ -545,14 +591,20 @@ void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<I const auto &prev_symbol = prev.first; const auto &identifier = prev.second; if (prev_symbol) { - scope_.symbols[identifier->name_] = *prev_symbol; + scope.symbols[identifier->name_] = *prev_symbol; } else { - scope_.symbols.erase(identifier->name_); + scope.symbols.erase(identifier->name_); } } } -bool SymbolGenerator::HasSymbol(const std::string &name) { return scope_.symbols.find(name) != scope_.symbols.end(); } +bool SymbolGenerator::HasSymbol(const std::string &name) const { + return std::ranges::any_of(scopes_, [&name](const auto &scope) { return scope.symbols.contains(name); }); +} + +bool SymbolGenerator::HasSymbolLocalScope(const std::string &name) const { + return scopes_.back().symbols.contains(name); +} bool SymbolGenerator::ConsumePredefinedIdentifier(const std::string &name) { auto it = predefined_identifiers_.find(name); diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index fd57175bc..8703fd4e6 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -15,6 +15,9 @@ #pragma once +#include <optional> +#include <vector> + #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" @@ -59,6 +62,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool PostVisit(Unwind &) override; bool PreVisit(Match &) override; bool PostVisit(Match &) override; + bool PreVisit(Foreach &) override; + bool PostVisit(Foreach &) override; // Expressions ReturnType Visit(Identifier &) override; @@ -107,6 +112,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool in_order_by{false}; bool in_where{false}; bool in_match{false}; + bool in_foreach{false}; // True when visiting a pattern atom (node or edge) identifier, which can be // reused or created in the pattern itself. bool in_pattern_atom_identifier{false}; @@ -125,7 +131,10 @@ class SymbolGenerator : public HierarchicalTreeVisitor { int num_if_operators{0}; }; - bool HasSymbol(const std::string &name); + static std::optional<Symbol> FindSymbolInScope(const std::string &name, const Scope &scope, Symbol::Type type); + + bool HasSymbol(const std::string &name) const; + bool HasSymbolLocalScope(const std::string &name) const; // @return true if it added a predefined identifier with that name bool ConsumePredefinedIdentifier(const std::string &name); @@ -135,9 +144,10 @@ class SymbolGenerator : public HierarchicalTreeVisitor { auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY, int token_position = -1); + auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY); // Returns the symbol by name. If the mapping already exists, checks if the // types match. Otherwise, returns a new symbol. - auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY); + auto GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY); void VisitReturnBody(ReturnBody &body, Where *where = nullptr); @@ -148,7 +158,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { // Identifiers which are injected from outside the query. Each identifier // is mapped by its name. std::unordered_map<std::string, Identifier *> predefined_identifiers_; - Scope scope_; + std::vector<Scope> scopes_; std::unordered_set<std::string> prev_return_names_; std::unordered_set<std::string> curr_return_names_; }; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 0dd697c06..42b7b4aeb 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -204,7 +204,8 @@ const trie::Trie kKeywords = {"union", "pulsar", "service_url", "version", - "websocket"}; + "websocket" + "foreach"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts( diff --git a/src/query/plan/cost_estimator.hpp b/src/query/plan/cost_estimator.hpp index b9f26db00..0b3e7f867 100644 --- a/src/query/plan/cost_estimator.hpp +++ b/src/query/plan/cost_estimator.hpp @@ -61,6 +61,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { static constexpr double kFilter{1.5}; static constexpr double kEdgeUniquenessFilter{1.5}; static constexpr double kUnwind{1.3}; + static constexpr double kForeach{1.0}; }; struct CardParam { @@ -72,6 +73,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { struct MiscParam { static constexpr double kUnwindNoLiteral{10.0}; + static constexpr double kForeachNoLiteral{10.0}; }; using HierarchicalLogicalOperatorVisitor::PostVisit; @@ -193,6 +195,23 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { return true; } + bool PostVisit(Foreach &foreach) override { + // Foreach cost depends both on the number elements in the list that get unwound + // as well as the total clauses that get called for each unwounded element. + // First estimate cardinality and then increment the cost. + + double foreach_elements{0}; + if (auto *literal = utils::Downcast<query::ListLiteral>(foreach.expression_)) { + foreach_elements = literal->elements_.size(); + } else { + foreach_elements = MiscParam::kForeachNoLiteral; + } + + cardinality_ *= foreach_elements; + IncrementCost(CostParam::kForeach); + return true; + } + bool Visit(Once &) override { return true; } auto cost() const { return cost_; } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index c3f0afbf4..5349812ec 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -45,6 +45,7 @@ #include "utils/fnv.hpp" #include "utils/likely.hpp" #include "utils/logging.hpp" +#include "utils/memory.hpp" #include "utils/pmr/unordered_map.hpp" #include "utils/pmr/unordered_set.hpp" #include "utils/pmr/vector.hpp" @@ -105,6 +106,7 @@ extern const Event DistinctOperator; extern const Event UnionOperator; extern const Event CartesianOperator; extern const Event CallProcedureOperator; +extern const Event ForeachOperator; } // namespace EventCounter namespace memgraph::query::plan { @@ -4024,4 +4026,85 @@ UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr<LoadCsvCursor>(mem, this, mem); }; +class ForeachCursor : public Cursor { + public: + explicit ForeachCursor(const Foreach &foreach, utils::MemoryResource *mem) + : loop_variable_symbol_(foreach.loop_variable_symbol_), + input_(foreach.input_->MakeCursor(mem)), + updates_(foreach.update_clauses_->MakeCursor(mem)), + expression(foreach.expression_) {} + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (!input_->Pull(frame, context)) { + return false; + } + + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::NEW); + TypedValue expr_result = expression->Accept(evaluator); + + if (expr_result.IsNull()) { + return true; + } + + if (!expr_result.IsList()) { + throw QueryRuntimeException("FOREACH expression must resolve to a list, but got '{}'.", expr_result.type()); + } + + const auto &cache_ = expr_result.ValueList(); + for (const auto &index : cache_) { + frame[loop_variable_symbol_] = index; + while (updates_->Pull(frame, context)) { + } + ResetUpdates(); + } + + return true; + } + + void Shutdown() override { input_->Shutdown(); } + + void ResetUpdates() { updates_->Reset(); } + + void Reset() override { + input_->Reset(); + ResetUpdates(); + } + + private: + const Symbol loop_variable_symbol_; + const UniqueCursorPtr input_; + const UniqueCursorPtr updates_; + Expression *expression; + const char *op_name_{"Foreach"}; +}; + +Foreach::Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> updates, Expression *expr, + Symbol loop_variable_symbol) + : input_(input ? std::move(input) : std::make_shared<Once>()), + update_clauses_(std::move(updates)), + expression_(expr), + loop_variable_symbol_(loop_variable_symbol) {} + +UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ForeachOperator); + return MakeUniqueCursorPtr<ForeachCursor>(mem, *this, mem); +} + +std::vector<Symbol> Foreach::ModifiedSymbols(const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.emplace_back(loop_variable_symbol_); + return symbols; +} + +bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + input_->Accept(visitor); + update_clauses_->Accept(visitor); + } + return visitor.PostVisit(*this); +} + } // namespace memgraph::query::plan diff --git a/src/query/plan/operator.lcp b/src/query/plan/operator.lcp index 346036df3..1a3b4b43e 100644 --- a/src/query/plan/operator.lcp +++ b/src/query/plan/operator.lcp @@ -131,6 +131,7 @@ class Union; class Cartesian; class CallProcedure; class LoadCsv; +class Foreach; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, @@ -139,7 +140,7 @@ using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, - Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv>; + Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach>; using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>; @@ -2261,6 +2262,42 @@ at once. Instead, each call of the callback should return a single row of the ta (:serialize (:slk)) (:clone)) +(lcp:define-class foreach (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (update-clauses "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (expression "Expression *" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (loop-variable-symbol "Symbol" :scope :public)) + + (:documentation + "Iterates over a collection of elements and applies one or more update +clauses. +") + (:public + #>cpp + Foreach() = default; + Foreach(std::shared_ptr<LogicalOperator> input, + std::shared_ptr<LogicalOperator> updates, + Expression *named_expr, + Symbol loop_variable_symbol); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = std::move(input); + } + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; plan (lcp:pop-namespace) ;; query (lcp:pop-namespace) ;; memgraph diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index ca5300b53..cd0ebd965 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -17,8 +17,10 @@ #include <variant> #include "query/exceptions.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/plan/preprocess.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::query::plan { @@ -526,6 +528,18 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_ // as `expr1 < n.prop AND n.prop < expr2`. } +static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage, + SymbolTable &symbol_table) { + for (auto *clause : foreach.clauses_) { + if (auto *merge = utils::Downcast<query::Merge>(clause)) { + query_part.merge_matching.emplace_back(Matching{}); + AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part.merge_matching.back()); + } else if (auto *nested = utils::Downcast<query::Foreach>(clause)) { + ParseForeach(*nested, query_part, storage, symbol_table); + } + } +} + // Converts a Query to multiple QueryParts. In the process new Ast nodes may be // created, e.g. filter expressions. std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage, @@ -546,6 +560,8 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, if (auto *merge = utils::Downcast<query::Merge>(clause)) { query_part->merge_matching.emplace_back(Matching{}); AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back()); + } else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) { + ParseForeach(*foreach, *query_part, storage, symbol_table); } else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::Unwind::kType) || utils::IsSubtype(*clause, query::CallProcedure::kType) || utils::IsSubtype(*clause, query::LoadCsv::kType)) { diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index d567ff070..35a3002ec 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -306,6 +306,10 @@ struct Matching { /// will produce the second `merge_matching` element. This way, if someone /// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is /// in the same order as their respective `merge_matching` elements. +/// An exception to the above rule is Foreach. Its update clauses will not be contained in +/// the `remaining_clauses`, but rather inside the foreach itself. The order guarantee is not +/// violated because the update clauses of the foreach are immediately processed in +/// the `RuleBasedPlanner` as if as they were pushed into the `remaining_clauses`. struct SingleQueryPart { /// @brief All `MATCH` clauses merged into one @c Matching. Matching matching; @@ -320,6 +324,10 @@ struct SingleQueryPart { /// /// Since @c Merge is contained in `remaining_clauses`, this vector contains /// matching in the same order as @c Merge appears. + // + /// Foreach @c does not violate this gurantee. However, update clauses are not stored + /// in the `remaining_clauses` but rather in the `Foreach` itself and are guranteed + /// to be processed in the same order by the semantics of the `RuleBasedPlanner`. std::vector<Matching> merge_matching{}; /// @brief All the remaining clauses (without @c Match). std::vector<Clause *> remaining_clauses{}; diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index fea2b4dd9..b44b8986c 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -241,6 +241,12 @@ bool PlanPrinter::PreVisit(query::plan::Cartesian &op) { return false; } +bool PlanPrinter::PreVisit(query::plan::Foreach &op) { + WithPrintLn([](auto &out) { out << "* Foreach"; }); + Branch(*op.update_clauses_); + op.input_->Accept(*this); + return false; +} #undef PRE_VISIT bool PlanPrinter::DefaultPreVisit() { @@ -883,6 +889,21 @@ bool PlanToJsonVisitor::PreVisit(Cartesian &op) { output_ = std::move(self); return false; } +bool PlanToJsonVisitor::PreVisit(Foreach &op) { + json self; + self["name"] = "Foreach"; + self["loop_variable_symbol"] = ToJson(op.loop_variable_symbol_); + self["expression"] = ToJson(op.expression_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + op.update_clauses_->Accept(*this); + self["update_clauses"] = PopOutput(); + + output_ = std::move(self); + return false; +} } // namespace impl diff --git a/src/query/plan/pretty_print.hpp b/src/query/plan/pretty_print.hpp index fe2954954..44ae102a5 100644 --- a/src/query/plan/pretty_print.hpp +++ b/src/query/plan/pretty_print.hpp @@ -92,6 +92,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(Unwind &) override; bool PreVisit(CallProcedure &) override; bool PreVisit(LoadCsv &) override; + bool PreVisit(Foreach &) override; bool Visit(Once &) override; @@ -204,6 +205,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(Union &) override; bool PreVisit(Unwind &) override; + bool PreVisit(Foreach &) override; bool PreVisit(CallProcedure &) override; bool PreVisit(LoadCsv &) override; diff --git a/src/query/plan/read_write_type_checker.cpp b/src/query/plan/read_write_type_checker.cpp index a2e65f295..1d2e752d4 100644 --- a/src/query/plan/read_write_type_checker.cpp +++ b/src/query/plan/read_write_type_checker.cpp @@ -79,6 +79,11 @@ bool ReadWriteTypeChecker::PreVisit(CallProcedure &op) { return true; } +bool ReadWriteTypeChecker::PreVisit([[maybe_unused]] Foreach &op) { + UpdateType(RWType::RW); + return false; +} + #undef PRE_VISIT bool ReadWriteTypeChecker::Visit(Once &op) { return false; } diff --git a/src/query/plan/read_write_type_checker.hpp b/src/query/plan/read_write_type_checker.hpp index 95b014caa..8b7f53987 100644 --- a/src/query/plan/read_write_type_checker.hpp +++ b/src/query/plan/read_write_type_checker.hpp @@ -84,6 +84,7 @@ class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(Unwind &) override; bool PreVisit(CallProcedure &) override; + bool PreVisit(Foreach &) override; bool Visit(Once &) override; diff --git a/src/query/plan/rewrite/index_lookup.hpp b/src/query/plan/rewrite/index_lookup.hpp index 3330011e0..791379018 100644 --- a/src/query/plan/rewrite/index_lookup.hpp +++ b/src/query/plan/rewrite/index_lookup.hpp @@ -433,6 +433,16 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(Foreach &op) override { + prev_ops_.push_back(&op); + return false; + } + + bool PostVisit(Foreach &) override { + prev_ops_.pop_back(); + return true; + } + std::shared_ptr<LogicalOperator> new_root_; private: diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index d769b33e1..8854392e6 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -22,6 +22,7 @@ #include "query/plan/operator.hpp" #include "query/plan/preprocess.hpp" #include "utils/logging.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::query::plan { @@ -223,6 +224,10 @@ class RuleBasedPlanner { input_op = std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_, load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym); + } else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) { + is_write = true; + input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols, + query_part, merge_id); } else { throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name); } @@ -530,6 +535,27 @@ class RuleBasedPlanner { } return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create)); } + + std::unique_ptr<LogicalOperator> HandleForeachClause(query::Foreach *foreach, + std::unique_ptr<LogicalOperator> input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols, + const SingleQueryPart &query_part, uint64_t &merge_id) { + const auto &symbol = symbol_table.at(*foreach->named_expression_); + bound_symbols.insert(symbol); + std::unique_ptr<LogicalOperator> op = std::make_unique<plan::Once>(); + for (auto *clause : foreach->clauses_) { + if (auto *nested_for_each = utils::Downcast<query::Foreach>(clause)) { + op = HandleForeachClause(nested_for_each, std::move(op), symbol_table, bound_symbols, query_part, merge_id); + } else if (auto *merge = utils::Downcast<query::Merge>(clause)) { + op = GenMerge(*merge, std::move(op), query_part.merge_matching[merge_id++]); + } else { + op = HandleWriteClause(clause, op, symbol_table, bound_symbols); + } + } + return std::make_unique<plan::Foreach>(std::move(input_op), std::move(op), foreach->named_expression_->expression_, + symbol); + } }; } // namespace memgraph::query::plan diff --git a/src/query/plan/scoped_profile.hpp b/src/query/plan/scoped_profile.hpp index 969c777fd..e825384d1 100644 --- a/src/query/plan/scoped_profile.hpp +++ b/src/query/plan/scoped_profile.hpp @@ -71,9 +71,9 @@ class ScopedProfile { private: query::ExecutionContext *context_; - ProfilingStats *root_; - ProfilingStats *stats_; - unsigned long long start_time_; + ProfilingStats *root_{nullptr}; + ProfilingStats *stats_{nullptr}; + unsigned long long start_time_{0}; }; } // namespace memgraph::query::plan diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index bcbb1a166..634b6eae2 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -49,6 +49,7 @@ M(UnionOperator, "Number of times Union operator was used.") \ M(CartesianOperator, "Number of times Cartesian operator was used.") \ M(CallProcedureOperator, "Number of times CallProcedure operator was used.") \ + M(ForeachOperator, "Number of times Foreach operator was used.") \ \ M(FailedQuery, "Number of times executing a query failed.") \ M(LabelIndexCreated, "Number of times a label index was created.") \ diff --git a/tests/benchmark/query/execution.cpp b/tests/benchmark/query/execution.cpp index 6bc4b9b01..90d12e78f 100644 --- a/tests/benchmark/query/execution.cpp +++ b/tests/benchmark/query/execution.cpp @@ -499,4 +499,39 @@ BENCHMARK_TEMPLATE(Unwind, MonotonicBufferResource) BENCHMARK_TEMPLATE(Unwind, PoolResource)->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})->Unit(benchmark::kMicrosecond); +template <class TMemory> +// NOLINTNEXTLINE(google-runtime-references) +static void Foreach(benchmark::State &state) { + memgraph::query::AstStorage ast; + memgraph::storage::Storage db; + memgraph::query::SymbolTable symbol_table; + auto list_sym = symbol_table.CreateSymbol("list", false); + auto *list_expr = ast.Create<memgraph::query::Identifier>("list")->MapTo(list_sym); + auto out_sym = symbol_table.CreateSymbol("out", false); + auto create_node = + std::make_shared<memgraph::query::plan::CreateNode>(nullptr, memgraph::query::plan::NodeCreationInfo{}); + auto foreach = std::make_shared<memgraph::query::plan::Foreach>(nullptr, std::move(create_node), list_expr, out_sym); + + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + TMemory per_pull_memory; + memgraph::query::EvaluationContext evaluation_context{per_pull_memory.get()}; + while (state.KeepRunning()) { + memgraph::query::ExecutionContext execution_context{&dba, symbol_table, evaluation_context}; + TMemory memory; + memgraph::query::Frame frame(symbol_table.max_position(), memory.get()); + frame[list_sym] = memgraph::query::TypedValue(std::vector<memgraph::query::TypedValue>(state.range(1))); + auto cursor = foreach->MakeCursor(memory.get()); + while (cursor->Pull(frame, execution_context)) per_pull_memory.Reset(); + } + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK_TEMPLATE(Foreach, PoolResource)->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})->Unit(benchmark::kMicrosecond); +BENCHMARK_TEMPLATE(Foreach, MonotonicBufferResource) + ->Ranges({{4, 1U << 7U}, {512, 1U << 13U}}) + ->Unit(benchmark::kMicrosecond); + +BENCHMARK_TEMPLATE(Foreach, PoolResource)->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})->Unit(benchmark::kMicrosecond); + BENCHMARK_MAIN(); diff --git a/tests/gql_behave/tests/memgraph_V1/features/foreach.feature b/tests/gql_behave/tests/memgraph_V1/features/foreach.feature new file mode 100644 index 000000000..ce1d9c932 --- /dev/null +++ b/tests/gql_behave/tests/memgraph_V1/features/foreach.feature @@ -0,0 +1,273 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +Feature: Foreach + Behaviour tests for memgraph FOREACH clause + + Scenario: Foreach create + Given an empty graph + And having executed + """ + FOREACH( i IN [1, 2, 3] | CREATE (n {age : i})) + """ + When executing query: + """ + MATCH (n) RETURN n.age + """ + Then the result should be: + | n.age | + | 1 | + | 2 | + | 3 | + And no side effects + + Scenario: Foreach Foreach create + Given an empty graph + And having executed + """ + FOREACH( i IN [1, 2, 3] | CREATE (n {age : i})) FOREACH( i in [4, 5, 6] | CREATE (n {age : i})) + """ + When executing query: + """ + MATCH (n) RETURN n.age + """ + Then the result should be: + | n.age | + | 1 | + | 2 | + | 3 | + | 4 | + | 5 | + | 6 | + And no side effects + + Scenario: Foreach shadowing + Given an empty graph + And having executed + """ + FOREACH( i IN [1] | FOREACH( i in [2, 3, 4] | CREATE (n {age : i}))) + """ + When executing query: + """ + MATCH (n) RETURN n.age + """ + Then the result should be: + | n.age | + | 2 | + | 3 | + | 4 | + And no side effects + + Scenario: Foreach shadowing in create + Given an empty graph + And having executed + """ + FOREACH (i IN [1] | FOREACH (j IN [3,4] | CREATE (i {prop: j}))); + """ + When executing query: + """ + MATCH (n) RETURN n.prop + """ + Then the result should be: + | n.prop | + | 3 | + | 4 | + And no side effects + + Scenario: Foreach set + Given an empty graph + And having executed + """ + CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false }) + """ + And having executed + """ + MATCH p=(n1)-[*]->(n2) + FOREACH (n IN nodes(p) | SET n.marked = true) + """ + When executing query: + """ + MATCH (n) + RETURN n.marked + """ + Then the result should be: + | n.marked | + | true | + | true | + And no side effects + + Scenario: Foreach remove + Given an empty graph + And having executed + """ + CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false }) + """ + And having executed + """ + MATCH p=(n1)-[*]->(n2) + FOREACH (n IN nodes(p) | REMOVE n.marked) + """ + When executing query: + """ + MATCH (n) + RETURN n; + """ + Then the result should be: + | n | + | () | + | () | + And no side effects + + Scenario: Foreach delete + Given an empty graph + And having executed + """ + CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false }) + """ + And having executed + """ + MATCH p=(n1)-[*]->(n2) + FOREACH (n IN nodes(p) | DETACH delete n) + """ + When executing query: + """ + MATCH (n) + RETURN n; + """ + Then the result should be: + | | + And no side effects + + Scenario: Foreach merge + Given an empty graph + And having executed + """ + FOREACH (i IN [1, 2, 3] | MERGE (n { age : i })) + """ + When executing query: + """ + MATCH (n) + RETURN n.age; + """ + Then the result should be: + | n.age | + | 1 | + | 2 | + | 3 | + And no side effects + + Scenario: Foreach nested + Given an empty graph + And having executed + """ + FOREACH (i IN [1, 2, 3] | FOREACH( j IN [1] | CREATE (k { prop : j }))) + """ + When executing query: + """ + MATCH (n) + RETURN n.prop; + """ + Then the result should be: + | n.prop | + | 1 | + | 1 | + | 1 | + + Scenario: Foreach multiple update clauses + Given an empty graph + And having executed + """ + CREATE (n1 { marked1: false, marked2: false })-[:RELATES]->(n2 { marked1: false, marked2: false }) + """ + And having executed + """ + MATCH p=(n1)-[*]->(n2) + FOREACH (n IN nodes(p) | SET n.marked1 = true SET n.marked2 = true) + """ + When executing query: + """ + MATCH (n) + RETURN n + """ + Then the result should be: + | n | + | ({marked1: true, marked2: true}) | + | ({marked1: true, marked2: true}) | + And no side effects + + Scenario: Foreach multiple nested update clauses + Given an empty graph + And having executed + """ + CREATE (n1 { marked1: false, marked2: false })-[:RELATES]->(n2 { marked1: false, marked2: false }) + """ + And having executed + """ + MATCH p=(n1)-[*]->(n2) + FOREACH (n IN nodes(p) | FOREACH (j IN [1] | SET n.marked1 = true SET n.marked2 = true)) + """ + When executing query: + """ + MATCH (n) + RETURN n + """ + Then the result should be: + | n | + | ({marked1: true, marked2: true}) | + | ({marked1: true, marked2: true}) | + And no side effects + + Scenario: Foreach match foreach return + Given an empty graph + And having executed + """ + CREATE (n {prop: [[], [1,2]]}); + """ + When executing query: + """ + MATCH (n) FOREACH (i IN n.prop | CREATE (:V { i: i})) RETURN n; + """ + Then the result should be: + | n | + | ({prop: [[], [1, 2]]}) | + And no side effects + + Scenario: Foreach on null value + Given an empty graph + And having executed + """ + CREATE (n); + """ + When executing query: + """ + MATCH (n) FOREACH (i IN n.prop | CREATE (:V { i: i})); + """ + Then the result should be: + | | + And no side effects + + Scenario: Foreach nested merge + Given an empty graph + And having executed + """ + FOREACH(i in [1, 2, 3] | foreach(j in [1] | MERGE (n { age : i }))); + """ + When executing query: + """ + MATCH (n) + RETURN n + """ + Then the result should be: + | n | + | ({age: 1}) | + | ({age: 2}) | + | ({age: 3}) | + And no side effects diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 061ee90e1..3800a9e01 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -4090,3 +4090,89 @@ TEST_P(CypherMainVisitorTest, VersionQuery) { TestInvalidQuery("SHOW VERSIONS", ast_generator); ASSERT_NO_THROW(ast_generator.ParseQuery("SHOW VERSION")); } + +TEST_P(CypherMainVisitorTest, ForeachThrow) { + auto &ast_generator = *GetParam(); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | UNWIND [1,2,3] AS j CREATE (n))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] CREATE (:Foo {prop : i}))"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | MATCH (n)"), SyntaxException); + EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN x | MATCH (n)"), SyntaxException); +} + +TEST_P(CypherMainVisitorTest, Foreach) { + auto &ast_generator = *GetParam(); + // CREATE + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("FOREACH (age IN [1, 2, 3] | CREATE (m:Age {amount: age}))")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *foreach = dynamic_cast<Foreach *>(single_query->clauses_[0]); + ASSERT_TRUE(foreach); + ASSERT_TRUE(foreach->named_expression_); + EXPECT_EQ(foreach->named_expression_->name_, "age"); + auto *expr = foreach->named_expression_->expression_; + ASSERT_TRUE(expr); + ASSERT_TRUE(dynamic_cast<ListLiteral *>(expr)); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Create *>(clauses.front())); + } + // SET + { + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses.front())); + } + // REMOVE + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | REMOVE i.prop)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<RemoveProperty *>(clauses.front())); + } + // MERGE + { + // merge works as create here + auto *query = + dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN [1, 2, 3] | MERGE (n {no : i}))")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Merge *>(clauses.front())); + } + // CYPHER DELETE + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("FOREACH (i IN nodes(path) | DETACH DELETE i)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Delete *>(clauses.front())); + } + // nested FOREACH + { + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery( + "FOREACH (i IN nodes(path) | FOREACH (age IN i.list | CREATE (m:Age {amount: age})))")); + + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 1); + ASSERT_TRUE(dynamic_cast<Foreach *>(clauses.front())); + } + // Multiple update clauses + { + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("FOREACH (i IN nodes(path) | SET i.checkpoint = true REMOVE i.prop)")); + auto *foreach = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]); + const auto &clauses = foreach->clauses_; + ASSERT_TRUE(clauses.size() == 2); + ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses.front())); + ASSERT_TRUE(dynamic_cast<RemoveProperty *>(*++clauses.begin())); + } +} diff --git a/tests/unit/plan_pretty_print.cpp b/tests/unit/plan_pretty_print.cpp index 4444bd653..0148a3082 100644 --- a/tests/unit/plan_pretty_print.cpp +++ b/tests/unit/plan_pretty_print.cpp @@ -911,3 +911,34 @@ TEST_F(PrintToJsonTest, CallProcedure) { "result_symbols" : ["name_alias", "signature_alias"] })sep"); } + +TEST_F(PrintToJsonTest, Foreach) { + Symbol x = GetSymbol("x"); + std::shared_ptr<LogicalOperator> create = + std::make_shared<CreateNode>(nullptr, NodeCreationInfo{GetSymbol("node"), {dba.NameToLabel("Label1")}, {}}); + std::shared_ptr<LogicalOperator> foreach = + std::make_shared<plan::Foreach>(nullptr, std::move(create), LIST(LITERAL(1)), x); + + Check(foreach.get(), R"sep( + { + "expression": "(ListLiteral [1])", + "input": { + "name": "Once" + }, + "name": "Foreach", + "loop_variable_symbol": "x", + "update_clauses": { + "input": { + "name": "Once" + }, + "name": "CreateNode", + "node_info": { + "labels": [ + "Label1" + ], + "properties": null, + "symbol": "node" + } + } + })sep"); +} diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 60c1685d8..b0dd41a72 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -463,6 +463,11 @@ auto GetCallProcedure(AstStorage &storage, std::string procedure_name, return call_procedure; } +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::Clause *> &clauses) { + return storage.Create<query::Foreach>(named_expr, clauses); +} + } // namespace test_common } // namespace memgraph::query @@ -526,13 +531,14 @@ auto GetCallProcedure(AstStorage &storage, std::string procedure_name, memgraph::query::test_common::OnCreate { \ std::vector<memgraph::query::Clause *> { __VA_ARGS__ } \ } -#define CREATE_INDEX_ON(label, property) \ +#define CREATE_INDEX_ON(label, property) \ storage.Create<memgraph::query::IndexQuery>(memgraph::query::IndexQuery::Action::CREATE, (label), \ - std::vector<memgraph::query::PropertyIx>{(property)}) + std::vector<memgraph::query::PropertyIx>{(property)}) #define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__) #define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__) #define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__) #define UNION_ALL(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(false), __VA_ARGS__) +#define FOREACH(...) memgraph::query::test_common::GetForeach(storage, __VA_ARGS__) // Various operators #define NOT(expr) storage.Create<memgraph::query::NotOperator>((expr)) #define UPLUS(expr) storage.Create<memgraph::query::UnaryPlusOperator>((expr)) diff --git a/tests/unit/query_cost_estimator.cpp b/tests/unit/query_cost_estimator.cpp index b73b24e6f..5adc6f88d 100644 --- a/tests/unit/query_cost_estimator.cpp +++ b/tests/unit/query_cost_estimator.cpp @@ -181,6 +181,19 @@ TEST_F(QueryCostEstimator, ExpandVariable) { EXPECT_COST(CardParam::kExpandVariable * CostParam::kExpandVariable); } +TEST_F(QueryCostEstimator, ForeachListLiteral) { + constexpr size_t list_expr_sz = 10; + std::shared_ptr<LogicalOperator> create = std::make_shared<CreateNode>(std::make_shared<Once>(), NodeCreationInfo{}); + MakeOp<memgraph::query::plan::Foreach>( + last_op_, create, storage_.Create<ListLiteral>(std::vector<Expression *>(list_expr_sz, nullptr)), NextSymbol()); + EXPECT_COST(CostParam::kForeach * list_expr_sz); +} + +TEST_F(QueryCostEstimator, Foreach) { + std::shared_ptr<LogicalOperator> create = std::make_shared<CreateNode>(std::make_shared<Once>(), NodeCreationInfo{}); + MakeOp<memgraph::query::plan::Foreach>(last_op_, create, storage_.Create<Identifier>(), NextSymbol()); + EXPECT_COST(CostParam::kForeach * MiscParam::kForeachNoLiteral); +} // Helper for testing an operations cost and cardinality. // Only for operations that first increment cost, then modify cardinality. // Intentially a macro (instead of function) for better test feedback. diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index e5cc29296..93d2f33c7 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -90,7 +90,6 @@ void DeleteListContent(std::list<BaseOpChecker *> *list) { delete ptr; } } - TYPED_TEST_CASE(TestPlanner, PlannerTypes); TYPED_TEST(TestPlanner, MatchNodeReturn) { @@ -1555,4 +1554,57 @@ TYPED_TEST(TestPlanner, LabelPropertyInListWhereLabelPropertyOnRight) { } } +TYPED_TEST(TestPlanner, Foreach) { + AstStorage storage; + FakeDbAccessor dba; + { + auto *i = NEXPR("i", IDENT("i")); + auto *query = QUERY(SINGLE_QUERY(FOREACH(i, {CREATE(PATTERN(NODE("n")))}))); + auto create = ExpectCreateNode(); + std::list<BaseOpChecker *> updates{&create}; + std::list<BaseOpChecker *> input; + CheckPlan<TypeParam>(query, storage, ExpectForeach(input, updates)); + } + { + auto *i = NEXPR("i", IDENT("i")); + auto *query = QUERY(SINGLE_QUERY(FOREACH(i, {DELETE(IDENT("i"))}))); + auto del = ExpectDelete(); + std::list<BaseOpChecker *> updates{&del}; + std::list<BaseOpChecker *> input; + CheckPlan<TypeParam>(query, storage, ExpectForeach({input}, updates)); + } + { + auto prop = dba.Property("prop"); + auto *i = NEXPR("i", IDENT("i")); + auto *query = QUERY(SINGLE_QUERY(FOREACH(i, {SET(PROPERTY_LOOKUP("i", prop), LITERAL(10))}))); + auto set_prop = ExpectSetProperty(); + std::list<BaseOpChecker *> updates{&set_prop}; + std::list<BaseOpChecker *> input; + CheckPlan<TypeParam>(query, storage, ExpectForeach({input}, updates)); + } + { + auto *i = NEXPR("i", IDENT("i")); + auto *j = NEXPR("j", IDENT("j")); + auto *query = QUERY(SINGLE_QUERY(FOREACH(i, {FOREACH(j, {CREATE(PATTERN(NODE("n"))), DELETE(IDENT("i"))})}))); + auto create = ExpectCreateNode(); + auto del = ExpectDelete(); + std::list<BaseOpChecker *> input; + std::list<BaseOpChecker *> nested_updates{{&create, &del}}; + auto nested_foreach = ExpectForeach(input, nested_updates); + std::list<BaseOpChecker *> updates{&nested_foreach}; + CheckPlan<TypeParam>(query, storage, ExpectForeach(input, updates)); + } + { + auto *i = NEXPR("i", IDENT("i")); + auto *j = NEXPR("j", IDENT("j")); + auto create = ExpectCreateNode(); + std::list<BaseOpChecker *> empty; + std::list<BaseOpChecker *> updates{&create}; + auto input_op = ExpectForeach(empty, updates); + std::list<BaseOpChecker *> input{&input_op}; + auto *query = + QUERY(SINGLE_QUERY(FOREACH(i, {CREATE(PATTERN(NODE("n")))}), FOREACH(j, {CREATE(PATTERN(NODE("n")))}))); + CheckPlan<TypeParam>(query, storage, ExpectForeach(input, updates)); + } +} } // namespace diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index a454370a0..335b6ab2b 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -90,6 +90,11 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { } PRE_VISIT(Unwind); PRE_VISIT(Distinct); + + bool PreVisit(Foreach &op) override { + CheckOp(op); + return false; + } bool Visit(Once &) override { // Ignore checking Once, it is implicitly at the end. @@ -150,6 +155,25 @@ using ExpectOrderBy = OpChecker<OrderBy>; using ExpectUnwind = OpChecker<Unwind>; using ExpectDistinct = OpChecker<Distinct>; +class ExpectForeach : public OpChecker<Foreach> { + public: + ExpectForeach(const std::list<BaseOpChecker *> &input, const std::list<BaseOpChecker *> &updates) + : input_(input), updates_(updates) {} + + void ExpectOp(Foreach &foreach, const SymbolTable &symbol_table) override { + PlanChecker check_input(input_, symbol_table); + foreach + .input_->Accept(check_input); + PlanChecker check_updates(updates_, symbol_table); + foreach + .update_clauses_->Accept(check_updates); + } + + private: + std::list<BaseOpChecker *> input_; + std::list<BaseOpChecker *> updates_; +}; + class ExpectExpandVariable : public OpChecker<ExpandVariable> { public: void ExpectOp(ExpandVariable &op, const SymbolTable &) override { diff --git a/tests/unit/query_plan_read_write_typecheck.cpp b/tests/unit/query_plan_read_write_typecheck.cpp index b68202a86..673620ff9 100644 --- a/tests/unit/query_plan_read_write_typecheck.cpp +++ b/tests/unit/query_plan_read_write_typecheck.cpp @@ -247,3 +247,9 @@ TEST_F(ReadWriteTypeCheckTest, ConstructNamedPath) { CheckPlanType(last_op.get(), RWType::R); } + +TEST_F(ReadWriteTypeCheckTest, Foreach) { + Symbol x = GetSymbol("x"); + std::shared_ptr<LogicalOperator> foreach = std::make_shared<plan::Foreach>(nullptr, nullptr, nullptr, x); + CheckPlanType(foreach.get(), RWType::RW); +} diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 643d8175d..45e60811a 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -1157,3 +1157,23 @@ TEST_F(TestSymbolGenerator, PredefinedIdentifiers) { query = QUERY(SINGLE_QUERY(unwind, CREATE(PATTERN(node)))); ASSERT_THROW(memgraph::query::MakeSymbolTable(query, {first_op}), SemanticException); } + +TEST_F(TestSymbolGenerator, Foreach) { + auto *i = NEXPR("i", IDENT("i")); + auto query = QUERY(SINGLE_QUERY(FOREACH(i, {CREATE(PATTERN(NODE("n")))}), RETURN("n"))); + EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError); + + query = QUERY(SINGLE_QUERY(FOREACH(i, {CREATE(PATTERN(NODE("n")))}), FOREACH(i, {CREATE(PATTERN(NODE("v")))}))); + auto symbol_table = memgraph::query::MakeSymbolTable(query); + ASSERT_EQ(symbol_table.max_position(), 6); + + query = QUERY(SINGLE_QUERY(FOREACH(i, {FOREACH(i, {CREATE(PATTERN(NODE("i")))})}))); + EXPECT_THROW(memgraph::query::MakeSymbolTable(query), RedeclareVariableError); + + query = QUERY(SINGLE_QUERY(FOREACH(i, {FOREACH(i, {CREATE(PATTERN(NODE("v")))})}))); + symbol_table = memgraph::query::MakeSymbolTable(query); + ASSERT_EQ(symbol_table.max_position(), 4); + + query = QUERY(SINGLE_QUERY(FOREACH(i, {CREATE(PATTERN(NODE("n")))}), RETURN("i"))); + EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError); +}