Implement foreach clause (#351)
This commit is contained in:
parent
c8dbaf5979
commit
ea2806bd57
@ -2629,5 +2629,39 @@ cpp<#
|
||||
(:serialize (:slk))
|
||||
(:clone))
|
||||
|
||||
;; AST node for the FOREACH clause. It stores a named expression together with
;; the clauses that belong to the FOREACH body.
(lcp:define-class foreach (clause)
  ((named_expression "NamedExpression *" :initval "nullptr" :scope :public
                     :slk-save #'slk-save-ast-pointer
                     :slk-load (slk-load-ast-pointer "Expression"))
   (clauses "std::vector<Clause *>"
            :scope :public
            :slk-save #'slk-save-ast-vector
            :slk-load (slk-load-ast-vector "Clause")))
  (:public
    #>cpp
    Foreach() = default;

    /// Visits the named expression first and then every stored clause in
    /// order, but only when the visitor's PreVisit asks to descend.
    /// PostVisit is called in either case and its result is returned.
    bool Accept(HierarchicalTreeVisitor &visitor) override {
      if (!visitor.PreVisit(*this)) {
        return visitor.PostVisit(*this);
      }
      named_expression_->Accept(visitor);
      for (auto *child : clauses_) {
        child->Accept(visitor);
      }
      return visitor.PostVisit(*this);
    }
    cpp<#)
  (:protected
    #>cpp
    /// Stores the given named expression and a copy of the clause list.
    Foreach(NamedExpression *expression, std::vector<Clause *> clauses)
        : named_expression_(expression), clauses_(clauses) {}
    cpp<#)
  (:private
    #>cpp
    // AstStorage is declared a friend so it can reach the protected constructor.
    friend class AstStorage;
    cpp<#)
  (:serialize (:slk))
  (:clone))
|
||||
|
||||
(lcp:pop-namespace) ;; namespace query
|
||||
(lcp:pop-namespace) ;; namespace memgraph
|
||||
|
@ -93,6 +93,7 @@ class CreateSnapshotQuery;
|
||||
class StreamQuery;
|
||||
class SettingQuery;
|
||||
class VersionQuery;
|
||||
class Foreach;
|
||||
|
||||
using TreeCompositeVisitor = utils::CompositeVisitor<
|
||||
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
|
||||
@ -101,7 +102,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor<
|
||||
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
|
||||
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
|
||||
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
|
||||
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv>;
|
||||
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach>;
|
||||
|
||||
using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
|
||||
|
||||
|
@ -37,6 +37,7 @@
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/string.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
|
||||
namespace memgraph::query::frontend {
|
||||
|
||||
@ -956,7 +957,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
|
||||
utils::IsSubtype(clause_type, SetProperty::kType) ||
|
||||
utils::IsSubtype(clause_type, SetProperties::kType) || utils::IsSubtype(clause_type, SetLabels::kType) ||
|
||||
utils::IsSubtype(clause_type, RemoveProperty::kType) ||
|
||||
utils::IsSubtype(clause_type, RemoveLabels::kType) || utils::IsSubtype(clause_type, Merge::kType)) {
|
||||
utils::IsSubtype(clause_type, RemoveLabels::kType) || utils::IsSubtype(clause_type, Merge::kType) ||
|
||||
utils::IsSubtype(clause_type, Foreach::kType)) {
|
||||
if (has_return) {
|
||||
throw SemanticException("Update clause can't be used after RETURN.");
|
||||
}
|
||||
@ -1036,6 +1038,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(MemgraphCypher::ClauseContext *ctx)
|
||||
if (ctx->loadCsv()) {
|
||||
return static_cast<Clause *>(ctx->loadCsv()->accept(this).as<LoadCsv *>());
|
||||
}
|
||||
if (ctx->foreach ()) {
|
||||
return static_cast<Clause *>(ctx->foreach ()->accept(this).as<Foreach *>());
|
||||
}
|
||||
// TODO: implement other clauses.
|
||||
throw utils::NotYetImplemented("clause '{}'", ctx->getText());
|
||||
return 0;
|
||||
@ -2283,6 +2288,37 @@ antlrcpp::Any CypherMainVisitor::visitFilterExpression(MemgraphCypher::FilterExp
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Builds a Foreach AST node from a parsed FOREACH clause.
///
/// The variable name and the visited expression are stored in a freshly created
/// NamedExpression. Each update clause of the body is then converted and
/// appended to `clauses_` in source order; SET and REMOVE yield several clauses
/// each, the remaining kinds yield exactly one.
///
/// @return Foreach*
antlrcpp::Any CypherMainVisitor::visitForeach(MemgraphCypher::ForeachContext *ctx) {
  auto *for_each = storage_->Create<Foreach>();

  auto *binding = storage_->Create<NamedExpression>();
  binding->expression_ = ctx->expression()->accept(this);
  binding->name_ = ctx->variable()->accept(this).as<std::string>();
  for_each->named_expression_ = binding;

  auto &body = for_each->clauses_;
  for (auto *update_ctx : ctx->updateClause()) {
    if (auto *set_ctx = update_ctx->set()) {
      const auto expanded = visitSet(set_ctx).as<std::vector<Clause *>>();
      body.insert(body.end(), expanded.begin(), expanded.end());
      continue;
    }
    if (auto *remove_ctx = update_ctx->remove()) {
      const auto expanded = visitRemove(remove_ctx).as<std::vector<Clause *>>();
      body.insert(body.end(), expanded.begin(), expanded.end());
      continue;
    }
    if (auto *merge_ctx = update_ctx->merge()) {
      body.push_back(visitMerge(merge_ctx).as<Merge *>());
      continue;
    }
    if (auto *create_ctx = update_ctx->create()) {
      body.push_back(visitCreate(create_ctx).as<Create *>());
      continue;
    }
    if (auto *delete_ctx = update_ctx->cypherDelete()) {
      body.push_back(visitCypherDelete(delete_ctx).as<Delete *>());
      continue;
    }
    // Every other alternative has been ruled out, so this must be a nested FOREACH.
    auto *nested_ctx = update_ctx->foreach ();
    MG_ASSERT(nested_ctx != nullptr, "Unexpected clause in FOREACH");
    body.push_back(visitForeach(nested_ctx).as<Foreach *>());
  }

  return for_each;
}
|
||||
|
||||
LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); }
|
||||
|
||||
PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); }
|
||||
|
@ -844,6 +844,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override;
|
||||
|
||||
/**
|
||||
* @return Foreach*
|
||||
*/
|
||||
antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override;
|
||||
|
||||
public:
|
||||
Query *query() { return query_; }
|
||||
const static std::string kAnonPrefix;
|
||||
|
@ -47,6 +47,7 @@ memgraphCypherKeyword : cypherKeyword
|
||||
| DUMP
|
||||
| EXECUTE
|
||||
| FOR
|
||||
| FOREACH
|
||||
| FREE
|
||||
| FROM
|
||||
| GLOBAL
|
||||
@ -163,8 +164,19 @@ clause : cypherMatch
|
||||
| cypherReturn
|
||||
| callProcedure
|
||||
| loadCsv
|
||||
| foreach
|
||||
;
|
||||
|
||||
// Alternatives accepted where the `foreach` rule expects `updateClause+`.
updateClause : set
             | remove
             | create
             | merge
             | cypherDelete
             | foreach
             ;
|
||||
|
||||
// FOREACH '(' variable IN expression '|' one or more update clauses ')'
foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ;
|
||||
|
||||
streamQuery : checkStream
|
||||
| createStream
|
||||
| dropStream
|
||||
|
@ -54,6 +54,7 @@ DUMP : D U M P ;
|
||||
DURABILITY : D U R A B I L I T Y ;
|
||||
EXECUTE : E X E C U T E ;
|
||||
FOR : F O R ;
|
||||
FOREACH : F O R E A C H;
|
||||
FREE : F R E E ;
|
||||
FREE_MEMORY : F R E E UNDERSCORE M E M O R Y ;
|
||||
FROM : F R O M ;
|
||||
|
@ -15,7 +15,9 @@
|
||||
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
#include <ranges>
|
||||
#include <unordered_set>
|
||||
#include <variant>
|
||||
|
||||
@ -39,35 +41,56 @@ std::unordered_map<std::string, Identifier *> GeneratePredefinedIdentifierMap(
|
||||
} // namespace
|
||||
|
||||
SymbolGenerator::SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers)
|
||||
: symbol_table_(symbol_table), predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)} {}
|
||||
: symbol_table_(symbol_table),
|
||||
predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)},
|
||||
scopes_(1, Scope()) {}
|
||||
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
|
||||
auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position);
|
||||
scope_.symbols[name] = symbol;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
|
||||
auto search = scope_.symbols.find(name);
|
||||
if (search != scope_.symbols.end()) {
|
||||
auto symbol = search->second;
|
||||
std::optional<Symbol> SymbolGenerator::FindSymbolInScope(const std::string &name, const Scope &scope,
|
||||
Symbol::Type type) {
|
||||
if (auto it = scope.symbols.find(name); it != scope.symbols.end()) {
|
||||
const auto &symbol = it->second;
|
||||
// Unless we have `ANY` type, check that types match.
|
||||
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) {
|
||||
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type));
|
||||
}
|
||||
return search->second;
|
||||
return symbol;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
|
||||
auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position);
|
||||
scopes_.back().symbols[name] = symbol;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
auto SymbolGenerator::GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type) {
|
||||
auto &scope = scopes_.back();
|
||||
if (auto maybe_symbol = FindSymbolInScope(name, scope, type); maybe_symbol) {
|
||||
return *maybe_symbol;
|
||||
}
|
||||
return CreateSymbol(name, user_declared, type);
|
||||
}
|
||||
|
||||
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
|
||||
// NOLINTNEXTLINE
|
||||
for (auto scope = scopes_.rbegin(); scope != scopes_.rend(); ++scope) {
|
||||
if (auto maybe_symbol = FindSymbolInScope(name, *scope, type); maybe_symbol) {
|
||||
return *maybe_symbol;
|
||||
}
|
||||
}
|
||||
return CreateSymbol(name, user_declared, type);
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
auto &scope = scopes_.back();
|
||||
for (auto &expr : body.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
std::vector<Symbol> user_symbols;
|
||||
if (body.all_identifiers) {
|
||||
// Carry over user symbols because '*' appeared.
|
||||
for (auto sym_pair : scope_.symbols) {
|
||||
for (const auto &sym_pair : scope.symbols) {
|
||||
if (!sym_pair.second.user_declared()) {
|
||||
continue;
|
||||
}
|
||||
@ -81,18 +104,18 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
// declares only those established through named expressions. New declarations
|
||||
// must not be visible inside named expressions themselves.
|
||||
bool removed_old_names = false;
|
||||
if ((!where && body.order_by.empty()) || scope_.has_aggregation) {
|
||||
if ((!where && body.order_by.empty()) || scope.has_aggregation) {
|
||||
// WHERE and ORDER BY need to see both the old and new symbols, unless we
|
||||
// have an aggregation. Therefore, we can clear the symbols immediately if
|
||||
// there is neither ORDER BY nor WHERE, or we have an aggregation.
|
||||
scope_.symbols.clear();
|
||||
scope.symbols.clear();
|
||||
removed_old_names = true;
|
||||
}
|
||||
// Create symbols for named expressions.
|
||||
std::unordered_set<std::string> new_names;
|
||||
for (const auto &user_sym : user_symbols) {
|
||||
new_names.insert(user_sym.name());
|
||||
scope_.symbols[user_sym.name()] = user_sym;
|
||||
scope.symbols[user_sym.name()] = user_sym;
|
||||
}
|
||||
for (auto &named_expr : body.named_expressions) {
|
||||
const auto &name = named_expr->name_;
|
||||
@ -103,35 +126,35 @@ void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
// new symbol would have a more specific type.
|
||||
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_));
|
||||
}
|
||||
scope_.in_order_by = true;
|
||||
scope.in_order_by = true;
|
||||
for (const auto &order_pair : body.order_by) {
|
||||
order_pair.expression->Accept(*this);
|
||||
}
|
||||
scope_.in_order_by = false;
|
||||
scope.in_order_by = false;
|
||||
if (body.skip) {
|
||||
scope_.in_skip = true;
|
||||
scope.in_skip = true;
|
||||
body.skip->Accept(*this);
|
||||
scope_.in_skip = false;
|
||||
scope.in_skip = false;
|
||||
}
|
||||
if (body.limit) {
|
||||
scope_.in_limit = true;
|
||||
scope.in_limit = true;
|
||||
body.limit->Accept(*this);
|
||||
scope_.in_limit = false;
|
||||
scope.in_limit = false;
|
||||
}
|
||||
if (where) where->Accept(*this);
|
||||
if (!removed_old_names) {
|
||||
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
|
||||
// clear the old symbols, so do it now. We cannot just call clear, because
|
||||
// we've added new symbols.
|
||||
for (auto sym_it = scope_.symbols.begin(); sym_it != scope_.symbols.end();) {
|
||||
for (auto sym_it = scope.symbols.begin(); sym_it != scope.symbols.end();) {
|
||||
if (new_names.find(sym_it->first) == new_names.end()) {
|
||||
sym_it = scope_.symbols.erase(sym_it);
|
||||
sym_it = scope.symbols.erase(sym_it);
|
||||
} else {
|
||||
sym_it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
scope_.has_aggregation = false;
|
||||
scopes_.back().has_aggregation = false;
|
||||
}
|
||||
|
||||
// Query
|
||||
@ -145,7 +168,7 @@ bool SymbolGenerator::PreVisit(SingleQuery &) {
|
||||
// Union
|
||||
|
||||
bool SymbolGenerator::PreVisit(CypherUnion &) {
|
||||
scope_ = Scope();
|
||||
scopes_.back() = Scope();
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -166,11 +189,11 @@ bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) {
|
||||
// Clauses
|
||||
|
||||
bool SymbolGenerator::PreVisit(Create &) {
|
||||
scope_.in_create = true;
|
||||
scopes_.back().in_create = true;
|
||||
return true;
|
||||
}
|
||||
bool SymbolGenerator::PostVisit(Create &) {
|
||||
scope_.in_create = false;
|
||||
scopes_.back().in_create = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -183,7 +206,7 @@ bool SymbolGenerator::PreVisit(CallProcedure &call_proc) {
|
||||
|
||||
bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
|
||||
for (auto *ident : call_proc.result_identifiers_) {
|
||||
if (HasSymbol(ident->name_)) {
|
||||
if (HasSymbolLocalScope(ident->name_)) {
|
||||
throw RedeclareVariableError(ident->name_);
|
||||
}
|
||||
ident->MapTo(CreateSymbol(ident->name_, true));
|
||||
@ -194,7 +217,7 @@ bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
|
||||
bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; }
|
||||
|
||||
bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
|
||||
if (HasSymbol(load_csv.row_var_->name_)) {
|
||||
if (HasSymbolLocalScope(load_csv.row_var_->name_)) {
|
||||
throw RedeclareVariableError(load_csv.row_var_->name_);
|
||||
}
|
||||
load_csv.row_var_->MapTo(CreateSymbol(load_csv.row_var_->name_, true));
|
||||
@ -202,45 +225,47 @@ bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(Return &ret) {
|
||||
scope_.in_return = true;
|
||||
auto &scope = scopes_.back();
|
||||
scope.in_return = true;
|
||||
VisitReturnBody(ret.body_);
|
||||
scope_.in_return = false;
|
||||
scope.in_return = false;
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(Return &) {
|
||||
for (const auto &name_symbol : scope_.symbols) curr_return_names_.insert(name_symbol.first);
|
||||
for (const auto &name_symbol : scopes_.back().symbols) curr_return_names_.insert(name_symbol.first);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(With &with) {
|
||||
scope_.in_with = true;
|
||||
auto &scope = scopes_.back();
|
||||
scope.in_with = true;
|
||||
VisitReturnBody(with.body_, with.where_);
|
||||
scope_.in_with = false;
|
||||
scope.in_with = false;
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(Where &) {
|
||||
scope_.in_where = true;
|
||||
scopes_.back().in_where = true;
|
||||
return true;
|
||||
}
|
||||
bool SymbolGenerator::PostVisit(Where &) {
|
||||
scope_.in_where = false;
|
||||
scopes_.back().in_where = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(Merge &) {
|
||||
scope_.in_merge = true;
|
||||
scopes_.back().in_merge = true;
|
||||
return true;
|
||||
}
|
||||
bool SymbolGenerator::PostVisit(Merge &) {
|
||||
scope_.in_merge = false;
|
||||
scopes_.back().in_merge = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(Unwind &unwind) {
|
||||
const auto &name = unwind.named_expression_->name_;
|
||||
if (HasSymbol(name)) {
|
||||
if (HasSymbolLocalScope(name)) {
|
||||
throw RedeclareVariableError(name);
|
||||
}
|
||||
unwind.named_expression_->MapTo(CreateSymbol(name, true));
|
||||
@ -248,55 +273,70 @@ bool SymbolGenerator::PostVisit(Unwind &unwind) {
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(Match &) {
|
||||
scope_.in_match = true;
|
||||
scopes_.back().in_match = true;
|
||||
return true;
|
||||
}
|
||||
bool SymbolGenerator::PostVisit(Match &) {
|
||||
scope_.in_match = false;
|
||||
auto &scope = scopes_.back();
|
||||
scope.in_match = false;
|
||||
// Check variables in property maps after visiting Match, so that they can
|
||||
// reference symbols out of bind order.
|
||||
for (auto &ident : scope_.identifiers_in_match) {
|
||||
if (!HasSymbol(ident->name_) && !ConsumePredefinedIdentifier(ident->name_))
|
||||
for (auto &ident : scope.identifiers_in_match) {
|
||||
if (!HasSymbolLocalScope(ident->name_) && !ConsumePredefinedIdentifier(ident->name_))
|
||||
throw UnboundVariableError(ident->name_);
|
||||
ident->MapTo(scope_.symbols[ident->name_]);
|
||||
ident->MapTo(scope.symbols[ident->name_]);
|
||||
}
|
||||
scope_.identifiers_in_match.clear();
|
||||
scope.identifiers_in_match.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Opens a fresh scope for the FOREACH, flags it with `in_foreach`, and declares
/// the named expression's name as a user-declared symbol of type ANY in it.
///
/// NOTE(review): the scope is pushed before the traversal continues into the
/// named expression, so identifiers in the iterated expression are presumably
/// resolved with the loop variable already visible -- confirm this is intended.
bool SymbolGenerator::PreVisit(Foreach &for_each) {
  auto *loop_binding = for_each.named_expression_;
  auto &body_scope = scopes_.emplace_back();
  body_scope.in_foreach = true;
  loop_binding->MapTo(CreateSymbol(loop_binding->name_, true, Symbol::Type::ANY, loop_binding->token_position_));
  return true;
}
|
||||
bool SymbolGenerator::PostVisit([[maybe_unused]] Foreach &for_each) {
|
||||
scopes_.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Expressions
|
||||
|
||||
SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
|
||||
if (scope_.in_skip || scope_.in_limit) {
|
||||
throw SemanticException("Variables are not allowed in {}.", scope_.in_skip ? "SKIP" : "LIMIT");
|
||||
auto &scope = scopes_.back();
|
||||
if (scope.in_skip || scope.in_limit) {
|
||||
throw SemanticException("Variables are not allowed in {}.", scope.in_skip ? "SKIP" : "LIMIT");
|
||||
}
|
||||
Symbol symbol;
|
||||
if (scope_.in_pattern && !(scope_.in_node_atom || scope_.visiting_edge)) {
|
||||
if (scope.in_pattern && !(scope.in_node_atom || scope.visiting_edge)) {
|
||||
// If we are in the pattern, and outside of a node or an edge, the
|
||||
// identifier is the pattern name.
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::PATH);
|
||||
} else if (scope_.in_pattern && scope_.in_pattern_atom_identifier) {
|
||||
symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, Symbol::Type::PATH);
|
||||
} else if (scope.in_pattern && scope.in_pattern_atom_identifier) {
|
||||
// Patterns used to create nodes and edges cannot redeclare already
|
||||
// established bindings. Declaration only happens in single node
|
||||
// patterns and in edge patterns. OpenCypher example,
|
||||
// `MATCH (n) CREATE (n)` should throw an error that `n` is already
|
||||
// declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
|
||||
// since `n` now references the bound node instead of declaring it.
|
||||
if ((scope_.in_create_node || scope_.in_create_edge) && HasSymbol(ident.name_)) {
|
||||
if ((scope.in_create_node || scope.in_create_edge) && HasSymbolLocalScope(ident.name_)) {
|
||||
throw RedeclareVariableError(ident.name_);
|
||||
}
|
||||
auto type = Symbol::Type::VERTEX;
|
||||
if (scope_.visiting_edge) {
|
||||
if (scope.visiting_edge) {
|
||||
// Edge referencing is not allowed (like in Neo4j):
|
||||
// `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r` is not allowed.
|
||||
if (HasSymbol(ident.name_)) {
|
||||
if (HasSymbolLocalScope(ident.name_)) {
|
||||
throw RedeclareVariableError(ident.name_);
|
||||
}
|
||||
type = scope_.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
|
||||
type = scope.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
|
||||
}
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, type);
|
||||
} else if (scope_.in_pattern && !scope_.in_pattern_atom_identifier && scope_.in_match) {
|
||||
if (scope_.in_edge_range && scope_.visiting_edge->identifier_->name_ == ident.name_) {
|
||||
symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, type);
|
||||
} else if (scope.in_pattern && !scope.in_pattern_atom_identifier && scope.in_match) {
|
||||
if (scope.in_edge_range && scope.visiting_edge->identifier_->name_ == ident.name_) {
|
||||
// Prevent variable path bounds to reference the identifier which is bound
|
||||
// by the variable path itself.
|
||||
throw UnboundVariableError(ident.name_);
|
||||
@ -304,30 +344,30 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
|
||||
// Variables in property maps or bounds of variable length path during MATCH
|
||||
// can reference symbols bound later in the same MATCH. We collect them
|
||||
// here, so that they can be checked after visiting Match.
|
||||
scope_.identifiers_in_match.emplace_back(&ident);
|
||||
scope.identifiers_in_match.emplace_back(&ident);
|
||||
} else {
|
||||
// Everything else references a bound symbol.
|
||||
if (!HasSymbol(ident.name_) && !ConsumePredefinedIdentifier(ident.name_)) throw UnboundVariableError(ident.name_);
|
||||
symbol = scope_.symbols[ident.name_];
|
||||
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::ANY);
|
||||
}
|
||||
ident.MapTo(symbol);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(Aggregation &aggr) {
|
||||
auto &scope = scopes_.back();
|
||||
// Check if the aggregation can be used in this context. This check should
|
||||
// probably move to a separate phase, which checks if the query is well
|
||||
// formed.
|
||||
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by || scope_.in_skip || scope_.in_limit ||
|
||||
scope_.in_where) {
|
||||
if ((!scope.in_return && !scope.in_with) || scope.in_order_by || scope.in_skip || scope.in_limit || scope.in_where) {
|
||||
throw SemanticException("Aggregation functions are only allowed in WITH and RETURN.");
|
||||
}
|
||||
if (scope_.in_aggregation) {
|
||||
if (scope.in_aggregation) {
|
||||
throw SemanticException(
|
||||
"Using aggregation functions inside aggregation functions is not "
|
||||
"allowed.");
|
||||
}
|
||||
if (scope_.num_if_operators) {
|
||||
if (scope.num_if_operators) {
|
||||
// Neo allows aggregations here and produces very interesting behaviors.
|
||||
// To simplify implementation at this moment we decided to completely
|
||||
// disallow aggregations inside of the CASE.
|
||||
@ -341,23 +381,23 @@ bool SymbolGenerator::PreVisit(Aggregation &aggr) {
|
||||
// Currently, we only have aggregation operators which return numbers.
|
||||
auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
|
||||
aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER));
|
||||
scope_.in_aggregation = true;
|
||||
scope_.has_aggregation = true;
|
||||
scope.in_aggregation = true;
|
||||
scope.has_aggregation = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(Aggregation &) {
|
||||
scope_.in_aggregation = false;
|
||||
scopes_.back().in_aggregation = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(IfOperator &) {
|
||||
++scope_.num_if_operators;
|
||||
++scopes_.back().num_if_operators;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(IfOperator &) {
|
||||
--scope_.num_if_operators;
|
||||
--scopes_.back().num_if_operators;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -401,33 +441,36 @@ bool SymbolGenerator::PreVisit(Extract &extract) {
|
||||
// Pattern and its subparts.
|
||||
|
||||
bool SymbolGenerator::PreVisit(Pattern &pattern) {
|
||||
scope_.in_pattern = true;
|
||||
if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
|
||||
auto &scope = scopes_.back();
|
||||
scope.in_pattern = true;
|
||||
if ((scope.in_create || scope.in_merge) && pattern.atoms_.size() == 1U) {
|
||||
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern");
|
||||
scope_.in_create_node = true;
|
||||
scope.in_create_node = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(Pattern &) {
|
||||
scope_.in_pattern = false;
|
||||
scope_.in_create_node = false;
|
||||
auto &scope = scopes_.back();
|
||||
scope.in_pattern = false;
|
||||
scope.in_create_node = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
|
||||
auto check_node_semantic = [&node_atom, this](const bool props_or_labels) {
|
||||
auto &scope = scopes_.back();
|
||||
auto check_node_semantic = [&node_atom, &scope, this](const bool props_or_labels) {
|
||||
const auto &node_name = node_atom.identifier_->name_;
|
||||
if ((scope_.in_create || scope_.in_merge) && props_or_labels && HasSymbol(node_name)) {
|
||||
if ((scope.in_create || scope.in_merge) && props_or_labels && HasSymbolLocalScope(node_name)) {
|
||||
throw SemanticException("Cannot create node '" + node_name +
|
||||
"' with labels or properties, because it is already declared.");
|
||||
}
|
||||
scope_.in_pattern_atom_identifier = true;
|
||||
scope.in_pattern_atom_identifier = true;
|
||||
node_atom.identifier_->Accept(*this);
|
||||
scope_.in_pattern_atom_identifier = false;
|
||||
scope.in_pattern_atom_identifier = false;
|
||||
};
|
||||
|
||||
scope_.in_node_atom = true;
|
||||
scope.in_node_atom = true;
|
||||
if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node_atom.properties_)) {
|
||||
bool props_or_labels = !properties->empty() || !node_atom.labels_.empty();
|
||||
|
||||
@ -447,20 +490,21 @@ bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(NodeAtom &) {
|
||||
scope_.in_node_atom = false;
|
||||
scopes_.back().in_node_atom = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
|
||||
scope_.visiting_edge = &edge_atom;
|
||||
if (scope_.in_create || scope_.in_merge) {
|
||||
scope_.in_create_edge = true;
|
||||
auto &scope = scopes_.back();
|
||||
scope.visiting_edge = &edge_atom;
|
||||
if (scope.in_create || scope.in_merge) {
|
||||
scope.in_create_edge = true;
|
||||
if (edge_atom.edge_types_.size() != 1U) {
|
||||
throw SemanticException(
|
||||
"A single relationship type must be specified "
|
||||
"when creating an edge.");
|
||||
}
|
||||
if (scope_.in_create && // Merge allows bidirectionality
|
||||
if (scope.in_create && // Merge allows bidirectionality
|
||||
edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
|
||||
throw SemanticException(
|
||||
"Bidirectional relationship are not supported "
|
||||
@ -480,15 +524,15 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
|
||||
std::get<ParameterLookup *>(edge_atom.properties_)->Accept(*this);
|
||||
}
|
||||
if (edge_atom.IsVariable()) {
|
||||
scope_.in_edge_range = true;
|
||||
scope.in_edge_range = true;
|
||||
if (edge_atom.lower_bound_) {
|
||||
edge_atom.lower_bound_->Accept(*this);
|
||||
}
|
||||
if (edge_atom.upper_bound_) {
|
||||
edge_atom.upper_bound_->Accept(*this);
|
||||
}
|
||||
scope_.in_edge_range = false;
|
||||
scope_.in_pattern = false;
|
||||
scope.in_edge_range = false;
|
||||
scope.in_pattern = false;
|
||||
if (edge_atom.filter_lambda_.expression) {
|
||||
VisitWithIdentifiers(edge_atom.filter_lambda_.expression,
|
||||
{edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node});
|
||||
@ -505,34 +549,36 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
|
||||
VisitWithIdentifiers(edge_atom.weight_lambda_.expression,
|
||||
{edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node});
|
||||
}
|
||||
scope_.in_pattern = true;
|
||||
scope.in_pattern = true;
|
||||
}
|
||||
scope_.in_pattern_atom_identifier = true;
|
||||
scope.in_pattern_atom_identifier = true;
|
||||
edge_atom.identifier_->Accept(*this);
|
||||
scope_.in_pattern_atom_identifier = false;
|
||||
scope.in_pattern_atom_identifier = false;
|
||||
if (edge_atom.total_weight_) {
|
||||
if (HasSymbol(edge_atom.total_weight_->name_)) {
|
||||
if (HasSymbolLocalScope(edge_atom.total_weight_->name_)) {
|
||||
throw RedeclareVariableError(edge_atom.total_weight_->name_);
|
||||
}
|
||||
edge_atom.total_weight_->MapTo(GetOrCreateSymbol(edge_atom.total_weight_->name_,
|
||||
edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER));
|
||||
edge_atom.total_weight_->MapTo(GetOrCreateSymbolLocalScope(
|
||||
edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(EdgeAtom &) {
|
||||
scope_.visiting_edge = nullptr;
|
||||
scope_.in_create_edge = false;
|
||||
auto &scope = scopes_.back();
|
||||
scope.visiting_edge = nullptr;
|
||||
scope.in_create_edge = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
|
||||
auto &scope = scopes_.back();
|
||||
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;
|
||||
// Collect previous symbols if they exist.
|
||||
for (const auto &identifier : identifiers) {
|
||||
std::optional<Symbol> prev_symbol;
|
||||
auto prev_symbol_it = scope_.symbols.find(identifier->name_);
|
||||
if (prev_symbol_it != scope_.symbols.end()) {
|
||||
auto prev_symbol_it = scope.symbols.find(identifier->name_);
|
||||
if (prev_symbol_it != scope.symbols.end()) {
|
||||
prev_symbol = prev_symbol_it->second;
|
||||
}
|
||||
identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_));
|
||||
@ -545,14 +591,20 @@ void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<I
|
||||
const auto &prev_symbol = prev.first;
|
||||
const auto &identifier = prev.second;
|
||||
if (prev_symbol) {
|
||||
scope_.symbols[identifier->name_] = *prev_symbol;
|
||||
scope.symbols[identifier->name_] = *prev_symbol;
|
||||
} else {
|
||||
scope_.symbols.erase(identifier->name_);
|
||||
scope.symbols.erase(identifier->name_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SymbolGenerator::HasSymbol(const std::string &name) { return scope_.symbols.find(name) != scope_.symbols.end(); }
|
||||
bool SymbolGenerator::HasSymbol(const std::string &name) const {
|
||||
return std::ranges::any_of(scopes_, [&name](const auto &scope) { return scope.symbols.contains(name); });
|
||||
}
|
||||
|
||||
bool SymbolGenerator::HasSymbolLocalScope(const std::string &name) const {
|
||||
return scopes_.back().symbols.contains(name);
|
||||
}
|
||||
|
||||
bool SymbolGenerator::ConsumePredefinedIdentifier(const std::string &name) {
|
||||
auto it = predefined_identifiers_.find(name);
|
||||
|
@ -15,6 +15,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
@ -59,6 +62,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
bool PostVisit(Unwind &) override;
|
||||
bool PreVisit(Match &) override;
|
||||
bool PostVisit(Match &) override;
|
||||
bool PreVisit(Foreach &) override;
|
||||
bool PostVisit(Foreach &) override;
|
||||
|
||||
// Expressions
|
||||
ReturnType Visit(Identifier &) override;
|
||||
@ -107,6 +112,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
bool in_order_by{false};
|
||||
bool in_where{false};
|
||||
bool in_match{false};
|
||||
bool in_foreach{false};
|
||||
// True when visiting a pattern atom (node or edge) identifier, which can be
|
||||
// reused or created in the pattern itself.
|
||||
bool in_pattern_atom_identifier{false};
|
||||
@ -125,7 +131,10 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
int num_if_operators{0};
|
||||
};
|
||||
|
||||
bool HasSymbol(const std::string &name);
|
||||
static std::optional<Symbol> FindSymbolInScope(const std::string &name, const Scope &scope, Symbol::Type type);
|
||||
|
||||
bool HasSymbol(const std::string &name) const;
|
||||
bool HasSymbolLocalScope(const std::string &name) const;
|
||||
|
||||
// @return true if it added a predefined identifier with that name
|
||||
bool ConsumePredefinedIdentifier(const std::string &name);
|
||||
@ -135,9 +144,10 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
|
||||
int token_position = -1);
|
||||
|
||||
auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
|
||||
// Returns the symbol by name. If the mapping already exists, checks if the
|
||||
// types match. Otherwise, returns a new symbol.
|
||||
auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
|
||||
auto GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
|
||||
|
||||
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
|
||||
|
||||
@ -148,7 +158,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
// Identifiers which are injected from outside the query. Each identifier
|
||||
// is mapped by its name.
|
||||
std::unordered_map<std::string, Identifier *> predefined_identifiers_;
|
||||
Scope scope_;
|
||||
std::vector<Scope> scopes_;
|
||||
std::unordered_set<std::string> prev_return_names_;
|
||||
std::unordered_set<std::string> curr_return_names_;
|
||||
};
|
||||
|
@ -204,7 +204,8 @@ const trie::Trie kKeywords = {"union",
|
||||
"pulsar",
|
||||
"service_url",
|
||||
"version",
|
||||
"websocket"};
|
||||
// Every entry needs its own comma: adjacent string literals are concatenated by the compiler.
"websocket",
|
||||
"foreach"};
|
||||
|
||||
// Unicode codepoints that are allowed at the start of the unescaped name.
|
||||
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(
|
||||
|
@ -61,6 +61,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
static constexpr double kFilter{1.5};
|
||||
static constexpr double kEdgeUniquenessFilter{1.5};
|
||||
static constexpr double kUnwind{1.3};
|
||||
static constexpr double kForeach{1.0};
|
||||
};
|
||||
|
||||
struct CardParam {
|
||||
@ -72,6 +73,7 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
|
||||
struct MiscParam {
|
||||
static constexpr double kUnwindNoLiteral{10.0};
|
||||
static constexpr double kForeachNoLiteral{10.0};
|
||||
};
|
||||
|
||||
using HierarchicalLogicalOperatorVisitor::PostVisit;
|
||||
@ -193,6 +195,23 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PostVisit(Foreach &foreach) override {
  // The cost of a Foreach depends on the number of elements in the list being
  // iterated as well as on the clauses executed for each element.
  // Cardinality is scaled first, then the operator's own cost is added.
  // A list literal gives the exact element count; any other expression falls
  // back to the MiscParam::kForeachNoLiteral estimate.
  const auto *list_literal = utils::Downcast<query::ListLiteral>(foreach.expression_);
  const double element_count =
      list_literal ? static_cast<double>(list_literal->elements_.size()) : MiscParam::kForeachNoLiteral;

  cardinality_ *= element_count;
  IncrementCost(CostParam::kForeach);
  return true;
}
|
||||
|
||||
bool Visit(Once &) override { return true; }
|
||||
|
||||
auto cost() const { return cost_; }
|
||||
|
@ -45,6 +45,7 @@
|
||||
#include "utils/fnv.hpp"
|
||||
#include "utils/likely.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/memory.hpp"
|
||||
#include "utils/pmr/unordered_map.hpp"
|
||||
#include "utils/pmr/unordered_set.hpp"
|
||||
#include "utils/pmr/vector.hpp"
|
||||
@ -105,6 +106,7 @@ extern const Event DistinctOperator;
|
||||
extern const Event UnionOperator;
|
||||
extern const Event CartesianOperator;
|
||||
extern const Event CallProcedureOperator;
|
||||
extern const Event ForeachOperator;
|
||||
} // namespace EventCounter
|
||||
|
||||
namespace memgraph::query::plan {
|
||||
@ -4024,4 +4026,85 @@ UniqueCursorPtr LoadCsv::MakeCursor(utils::MemoryResource *mem) const {
|
||||
return MakeUniqueCursorPtr<LoadCsvCursor>(mem, this, mem);
|
||||
};
|
||||
|
||||
class ForeachCursor : public Cursor {
|
||||
public:
|
||||
explicit ForeachCursor(const Foreach &foreach, utils::MemoryResource *mem)
|
||||
: loop_variable_symbol_(foreach.loop_variable_symbol_),
|
||||
input_(foreach.input_->MakeCursor(mem)),
|
||||
updates_(foreach.update_clauses_->MakeCursor(mem)),
|
||||
expression(foreach.expression_) {}
|
||||
|
||||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||||
SCOPED_PROFILE_OP(op_name_);
|
||||
|
||||
if (!input_->Pull(frame, context)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||||
storage::View::NEW);
|
||||
TypedValue expr_result = expression->Accept(evaluator);
|
||||
|
||||
if (expr_result.IsNull()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!expr_result.IsList()) {
|
||||
throw QueryRuntimeException("FOREACH expression must resolve to a list, but got '{}'.", expr_result.type());
|
||||
}
|
||||
|
||||
const auto &cache_ = expr_result.ValueList();
|
||||
for (const auto &index : cache_) {
|
||||
frame[loop_variable_symbol_] = index;
|
||||
while (updates_->Pull(frame, context)) {
|
||||
}
|
||||
ResetUpdates();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void Shutdown() override { input_->Shutdown(); }
|
||||
|
||||
void ResetUpdates() { updates_->Reset(); }
|
||||
|
||||
void Reset() override {
|
||||
input_->Reset();
|
||||
ResetUpdates();
|
||||
}
|
||||
|
||||
private:
|
||||
const Symbol loop_variable_symbol_;
|
||||
const UniqueCursorPtr input_;
|
||||
const UniqueCursorPtr updates_;
|
||||
Expression *expression;
|
||||
const char *op_name_{"Foreach"};
|
||||
};
|
||||
|
||||
/// Builds a Foreach operator.
/// @param input   operator producing the rows to iterate for; a `Once` is used when null.
/// @param updates operator chain executed for each list element.
/// @param expr    expression evaluated to obtain the list.
/// @param loop_variable_symbol symbol the current element is bound to.
Foreach::Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> updates, Expression *expr,
                 Symbol loop_variable_symbol)
    : input_(input ? std::move(input) : std::make_shared<Once>()),
      update_clauses_(std::move(updates)),
      expression_(expr),
      // Taken by value, so move it into place instead of copying a second time.
      loop_variable_symbol_(std::move(loop_variable_symbol)) {}
|
||||
|
||||
// Creates the cursor for this operator and bumps the ForeachOperator event counter.
UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const {
  EventCounter::IncrementCounter(EventCounter::ForeachOperator);
  auto cursor = MakeUniqueCursorPtr<ForeachCursor>(mem, *this, mem);
  return cursor;
}
|
||||
|
||||
std::vector<Symbol> Foreach::ModifiedSymbols(const SymbolTable &table) const {
|
||||
auto symbols = input_->ModifiedSymbols(table);
|
||||
symbols.emplace_back(loop_variable_symbol_);
|
||||
return symbols;
|
||||
}
|
||||
|
||||
// Visitor entry point: the input and the update-clause chain are traversed only
// when PreVisit returns true; PostVisit is always called and supplies the result.
bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  const bool descend = visitor.PreVisit(*this);
  if (descend) {
    input_->Accept(visitor);
    update_clauses_->Accept(visitor);
  }
  return visitor.PostVisit(*this);
}
|
||||
|
||||
} // namespace memgraph::query::plan
|
||||
|
@ -131,6 +131,7 @@ class Union;
|
||||
class Cartesian;
|
||||
class CallProcedure;
|
||||
class LoadCsv;
|
||||
class Foreach;
|
||||
|
||||
using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<
|
||||
Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel,
|
||||
@ -139,7 +140,7 @@ using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<
|
||||
Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete,
|
||||
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
|
||||
EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge,
|
||||
Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv>;
|
||||
Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach>;
|
||||
|
||||
using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
|
||||
|
||||
@ -2261,6 +2262,42 @@ at once. Instead, each call of the callback should return a single row of the ta
|
||||
(:serialize (:slk))
|
||||
(:clone))
|
||||
|
||||
;; Logical operator for the FOREACH clause. `input` feeds rows, `update-clauses`
;; is the operator chain run per list element, `expression` yields the list and
;; `loop-variable-symbol` is the symbol each element is bound to.
(lcp:define-class foreach (logical-operator)
  ((input "std::shared_ptr<LogicalOperator>" :scope :public
          :slk-save #'slk-save-operator-pointer
          :slk-load #'slk-load-operator-pointer)
   (update-clauses "std::shared_ptr<LogicalOperator>" :scope :public
                   :slk-save #'slk-save-operator-pointer
                   :slk-load #'slk-load-operator-pointer)
   (expression "Expression *" :scope :public
               :slk-save #'slk-save-ast-pointer
               :slk-load (slk-load-ast-pointer "Expression"))
   (loop-variable-symbol "Symbol" :scope :public))

  (:documentation
   "Iterates over a collection of elements and applies one or more update
clauses.
")
  (:public
    #>cpp
    Foreach() = default;
    // `expr` is the plain list expression (not a NamedExpression); the name
    // matches the out-of-line constructor definition.
    Foreach(std::shared_ptr<LogicalOperator> input,
            std::shared_ptr<LogicalOperator> updates,
            Expression *expr,
            Symbol loop_variable_symbol);

    bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
    UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
    std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
    bool HasSingleInput() const override { return true; }
    std::shared_ptr<LogicalOperator> input() const override { return input_; }
    void set_input(std::shared_ptr<LogicalOperator> input) override {
      input_ = std::move(input);
    }
    cpp<#)
  (:serialize (:slk))
  (:clone))
|
||||
|
||||
(lcp:pop-namespace) ;; plan
|
||||
(lcp:pop-namespace) ;; query
|
||||
(lcp:pop-namespace) ;; memgraph
|
||||
|
@ -17,8 +17,10 @@
|
||||
#include <variant>
|
||||
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/ast/ast_visitor.hpp"
|
||||
#include "query/plan/preprocess.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
|
||||
namespace memgraph::query::plan {
|
||||
|
||||
@ -526,6 +528,18 @@ void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_
|
||||
// as `expr1 < n.prop AND n.prop < expr2`.
|
||||
}
|
||||
|
||||
// Walks the update clauses of `foreach` in order and appends one Matching to
// `query_part.merge_matching` for every MERGE encountered, recursing into
// nested FOREACH clauses. Other clause types are ignored here.
static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage,
                         SymbolTable &symbol_table) {
  for (auto *update_clause : foreach.clauses_) {
    if (auto *inner_foreach = utils::Downcast<query::Foreach>(update_clause)) {
      ParseForeach(*inner_foreach, query_part, storage, symbol_table);
      continue;
    }
    auto *merge_clause = utils::Downcast<query::Merge>(update_clause);
    if (!merge_clause) {
      continue;
    }
    auto &matching = query_part.merge_matching.emplace_back(Matching{});
    AddMatching({merge_clause->pattern_}, nullptr, symbol_table, storage, matching);
  }
}
|
||||
|
||||
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
|
||||
// created, e.g. filter expressions.
|
||||
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
|
||||
@ -546,6 +560,8 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table,
|
||||
if (auto *merge = utils::Downcast<query::Merge>(clause)) {
|
||||
query_part->merge_matching.emplace_back(Matching{});
|
||||
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back());
|
||||
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
|
||||
ParseForeach(*foreach, *query_part, storage, symbol_table);
|
||||
} else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::Unwind::kType) ||
|
||||
utils::IsSubtype(*clause, query::CallProcedure::kType) ||
|
||||
utils::IsSubtype(*clause, query::LoadCsv::kType)) {
|
||||
|
@ -306,6 +306,10 @@ struct Matching {
|
||||
/// will produce the second `merge_matching` element. This way, if someone
|
||||
/// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
|
||||
/// in the same order as their respective `merge_matching` elements.
|
||||
/// An exception to the above rule is Foreach. Its update clauses will not be contained in
|
||||
/// the `remaining_clauses`, but rather inside the foreach itself. The order guarantee is not
|
||||
/// violated because the update clauses of the foreach are immediately processed in
|
||||
/// the `RuleBasedPlanner` as if they were pushed into the `remaining_clauses`.
|
||||
struct SingleQueryPart {
|
||||
/// @brief All `MATCH` clauses merged into one @c Matching.
|
||||
Matching matching;
|
||||
@ -320,6 +324,10 @@ struct SingleQueryPart {
|
||||
///
|
||||
/// Since @c Merge is contained in `remaining_clauses`, this vector contains
|
||||
/// matching in the same order as @c Merge appears.
|
||||
///
|
||||
/// @c Foreach does not violate this guarantee. However, its update clauses are not stored
|
||||
/// in the `remaining_clauses` but rather in the `Foreach` itself and are guaranteed
|
||||
/// to be processed in the same order by the semantics of the `RuleBasedPlanner`.
|
||||
std::vector<Matching> merge_matching{};
|
||||
/// @brief All the remaining clauses (without @c Match).
|
||||
std::vector<Clause *> remaining_clauses{};
|
||||
|
@ -241,6 +241,12 @@ bool PlanPrinter::PreVisit(query::plan::Cartesian &op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Prints the "* Foreach" line, hands the update-clause chain to Branch() and
// then continues with the input operator. Returns false because both children
// are visited here explicitly.
bool PlanPrinter::PreVisit(query::plan::Foreach &op) {
  WithPrintLn([](auto &stream) { stream << "* Foreach"; });
  Branch(*op.update_clauses_);
  op.input_->Accept(*this);
  return false;
}
|
||||
#undef PRE_VISIT
|
||||
|
||||
bool PlanPrinter::DefaultPreVisit() {
|
||||
@ -883,6 +889,21 @@ bool PlanToJsonVisitor::PreVisit(Cartesian &op) {
|
||||
output_ = std::move(self);
|
||||
return false;
|
||||
}
|
||||
// Serializes a Foreach operator to JSON: its name, loop variable symbol and
// expression, plus the JSON of the input and of the update-clause chain.
// Returns false because both children are visited here explicitly.
bool PlanToJsonVisitor::PreVisit(Foreach &op) {
  json node;
  node["name"] = "Foreach";
  node["loop_variable_symbol"] = ToJson(op.loop_variable_symbol_);
  node["expression"] = ToJson(op.expression_);

  // Each child is visited and its result collected via PopOutput() before the next one.
  op.input_->Accept(*this);
  node["input"] = PopOutput();

  op.update_clauses_->Accept(*this);
  node["update_clauses"] = PopOutput();

  output_ = std::move(node);
  return false;
}
|
||||
|
||||
} // namespace impl
|
||||
|
||||
|
@ -92,6 +92,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
bool PreVisit(Unwind &) override;
|
||||
bool PreVisit(CallProcedure &) override;
|
||||
bool PreVisit(LoadCsv &) override;
|
||||
bool PreVisit(Foreach &) override;
|
||||
|
||||
bool Visit(Once &) override;
|
||||
|
||||
@ -204,6 +205,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
bool PreVisit(Union &) override;
|
||||
|
||||
bool PreVisit(Unwind &) override;
|
||||
bool PreVisit(Foreach &) override;
|
||||
bool PreVisit(CallProcedure &) override;
|
||||
bool PreVisit(LoadCsv &) override;
|
||||
|
||||
|
@ -79,6 +79,11 @@ bool ReadWriteTypeChecker::PreVisit(CallProcedure &op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// A Foreach operator is always classified as RWType::RW; its children are not
// visited (false is returned).
bool ReadWriteTypeChecker::PreVisit([[maybe_unused]] Foreach &op) {
  UpdateType(RWType::RW);
  return false;
}
|
||||
|
||||
#undef PRE_VISIT
|
||||
|
||||
// Once does not change the detected type; always returns false.
bool ReadWriteTypeChecker::Visit(Once &op) {
  return false;
}
|
||||
|
@ -84,6 +84,7 @@ class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
|
||||
bool PreVisit(Unwind &) override;
|
||||
bool PreVisit(CallProcedure &) override;
|
||||
bool PreVisit(Foreach &) override;
|
||||
|
||||
bool Visit(Once &) override;
|
||||
|
||||
|
@ -433,6 +433,16 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Records the Foreach on the prev_ops_ stack and returns false, so its
// children are not traversed by this rewriter.
bool PreVisit(Foreach &op) override {
  auto *current = &op;
  prev_ops_.push_back(current);
  return false;
}
|
||||
|
||||
// Removes the entry pushed by the matching PreVisit(Foreach &).
bool PostVisit(Foreach &) override {
  prev_ops_.pop_back();
  return true;
}
|
||||
|
||||
std::shared_ptr<LogicalOperator> new_root_;
|
||||
|
||||
private:
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "query/plan/operator.hpp"
|
||||
#include "query/plan/preprocess.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
|
||||
namespace memgraph::query::plan {
|
||||
|
||||
@ -223,6 +224,10 @@ class RuleBasedPlanner {
|
||||
input_op =
|
||||
std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
|
||||
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym);
|
||||
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
|
||||
is_write = true;
|
||||
input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols,
|
||||
query_part, merge_id);
|
||||
} else {
|
||||
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
|
||||
}
|
||||
@ -530,6 +535,27 @@ class RuleBasedPlanner {
|
||||
}
|
||||
return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create));
|
||||
}
|
||||
|
||||
/// Plans a FOREACH clause. The loop variable symbol is added to
/// `bound_symbols`, the update clauses are chained on top of a `Once`
/// operator, and the result is wrapped in a `plan::Foreach` over `input_op`.
/// Nested FOREACH clauses are planned recursively; each MERGE consumes the
/// next entry of `query_part.merge_matching`, advancing `merge_id`.
std::unique_ptr<LogicalOperator> HandleForeachClause(query::Foreach *foreach,
                                                     std::unique_ptr<LogicalOperator> input_op,
                                                     const SymbolTable &symbol_table,
                                                     std::unordered_set<Symbol> &bound_symbols,
                                                     const SingleQueryPart &query_part, uint64_t &merge_id) {
  const auto &loop_symbol = symbol_table.at(*foreach->named_expression_);
  bound_symbols.insert(loop_symbol);

  std::unique_ptr<LogicalOperator> update_chain = std::make_unique<plan::Once>();
  for (auto *update_clause : foreach->clauses_) {
    if (auto *merge = utils::Downcast<query::Merge>(update_clause)) {
      const auto &matching = query_part.merge_matching[merge_id++];
      update_chain = GenMerge(*merge, std::move(update_chain), matching);
      continue;
    }
    if (auto *inner_foreach = utils::Downcast<query::Foreach>(update_clause)) {
      update_chain = HandleForeachClause(inner_foreach, std::move(update_chain), symbol_table, bound_symbols,
                                         query_part, merge_id);
      continue;
    }
    update_chain = HandleWriteClause(update_clause, update_chain, symbol_table, bound_symbols);
  }

  auto *list_expression = foreach->named_expression_->expression_;
  return std::make_unique<plan::Foreach>(std::move(input_op), std::move(update_chain), list_expression, loop_symbol);
}
|
||||
};
|
||||
|
||||
} // namespace memgraph::query::plan
|
||||
|
@ -71,9 +71,9 @@ class ScopedProfile {
|
||||
|
||||
private:
|
||||
query::ExecutionContext *context_;
|
||||
ProfilingStats *root_;
|
||||
ProfilingStats *stats_;
|
||||
unsigned long long start_time_;
|
||||
ProfilingStats *root_{nullptr};
|
||||
ProfilingStats *stats_{nullptr};
|
||||
unsigned long long start_time_{0};
|
||||
};
|
||||
|
||||
} // namespace memgraph::query::plan
|
||||
|
@ -49,6 +49,7 @@
|
||||
M(UnionOperator, "Number of times Union operator was used.") \
|
||||
M(CartesianOperator, "Number of times Cartesian operator was used.") \
|
||||
M(CallProcedureOperator, "Number of times CallProcedure operator was used.") \
|
||||
M(ForeachOperator, "Number of times Foreach operator was used.") \
|
||||
\
|
||||
M(FailedQuery, "Number of times executing a query failed.") \
|
||||
M(LabelIndexCreated, "Number of times a label index was created.") \
|
||||
|
@ -499,4 +499,39 @@ BENCHMARK_TEMPLATE(Unwind, MonotonicBufferResource)
|
||||
|
||||
BENCHMARK_TEMPLATE(Unwind, PoolResource)->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})->Unit(benchmark::kMicrosecond);
|
||||
|
||||
template <class TMemory>
|
||||
// NOLINTNEXTLINE(google-runtime-references)
|
||||
static void Foreach(benchmark::State &state) {
|
||||
memgraph::query::AstStorage ast;
|
||||
memgraph::storage::Storage db;
|
||||
memgraph::query::SymbolTable symbol_table;
|
||||
auto list_sym = symbol_table.CreateSymbol("list", false);
|
||||
auto *list_expr = ast.Create<memgraph::query::Identifier>("list")->MapTo(list_sym);
|
||||
auto out_sym = symbol_table.CreateSymbol("out", false);
|
||||
auto create_node =
|
||||
std::make_shared<memgraph::query::plan::CreateNode>(nullptr, memgraph::query::plan::NodeCreationInfo{});
|
||||
auto foreach = std::make_shared<memgraph::query::plan::Foreach>(nullptr, std::move(create_node), list_expr, out_sym);
|
||||
|
||||
auto storage_dba = db.Access();
|
||||
memgraph::query::DbAccessor dba(&storage_dba);
|
||||
TMemory per_pull_memory;
|
||||
memgraph::query::EvaluationContext evaluation_context{per_pull_memory.get()};
|
||||
while (state.KeepRunning()) {
|
||||
memgraph::query::ExecutionContext execution_context{&dba, symbol_table, evaluation_context};
|
||||
TMemory memory;
|
||||
memgraph::query::Frame frame(symbol_table.max_position(), memory.get());
|
||||
frame[list_sym] = memgraph::query::TypedValue(std::vector<memgraph::query::TypedValue>(state.range(1)));
|
||||
auto cursor = foreach->MakeCursor(memory.get());
|
||||
while (cursor->Pull(frame, execution_context)) per_pull_memory.Reset();
|
||||
}
|
||||
state.SetItemsProcessed(state.iterations());
|
||||
}
|
||||
|
||||
// Register the Foreach benchmark once per memory resource; PoolResource was
// previously listed twice, which only ran the same configuration a second time.
BENCHMARK_TEMPLATE(Foreach, MonotonicBufferResource)
    ->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})
    ->Unit(benchmark::kMicrosecond);

BENCHMARK_TEMPLATE(Foreach, PoolResource)->Ranges({{4, 1U << 7U}, {512, 1U << 13U}})->Unit(benchmark::kMicrosecond);
|
||||
|
||||
BENCHMARK_MAIN();
|
||||
|
273
tests/gql_behave/tests/memgraph_V1/features/foreach.feature
Normal file
273
tests/gql_behave/tests/memgraph_V1/features/foreach.feature
Normal file
@ -0,0 +1,273 @@
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
Feature: Foreach
|
||||
Behaviour tests for memgraph FOREACH clause
|
||||
|
||||
Scenario: Foreach create
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH( i IN [1, 2, 3] | CREATE (n {age : i}))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) RETURN n.age
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.age |
|
||||
| 1 |
|
||||
| 2 |
|
||||
| 3 |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach Foreach create
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH( i IN [1, 2, 3] | CREATE (n {age : i})) FOREACH( i in [4, 5, 6] | CREATE (n {age : i}))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) RETURN n.age
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.age |
|
||||
| 1 |
|
||||
| 2 |
|
||||
| 3 |
|
||||
| 4 |
|
||||
| 5 |
|
||||
| 6 |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach shadowing
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH( i IN [1] | FOREACH( i in [2, 3, 4] | CREATE (n {age : i})))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) RETURN n.age
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.age |
|
||||
| 2 |
|
||||
| 3 |
|
||||
| 4 |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach shadowing in create
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH (i IN [1] | FOREACH (j IN [3,4] | CREATE (i {prop: j})));
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) RETURN n.prop
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.prop |
|
||||
| 3 |
|
||||
| 4 |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach set
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false })
|
||||
"""
|
||||
And having executed
|
||||
"""
|
||||
MATCH p=(n1)-[*]->(n2)
|
||||
FOREACH (n IN nodes(p) | SET n.marked = true)
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n.marked
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.marked |
|
||||
| true |
|
||||
| true |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach remove
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false })
|
||||
"""
|
||||
And having executed
|
||||
"""
|
||||
MATCH p=(n1)-[*]->(n2)
|
||||
FOREACH (n IN nodes(p) | REMOVE n.marked)
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n;
|
||||
"""
|
||||
Then the result should be:
|
||||
| n |
|
||||
| () |
|
||||
| () |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach delete
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n1 { marked: false })-[:RELATES]->(n2 { marked: false })
|
||||
"""
|
||||
And having executed
|
||||
"""
|
||||
MATCH p=(n1)-[*]->(n2)
|
||||
FOREACH (n IN nodes(p) | DETACH delete n)
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n;
|
||||
"""
|
||||
Then the result should be:
|
||||
| |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach merge
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH (i IN [1, 2, 3] | MERGE (n { age : i }))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n.age;
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.age |
|
||||
| 1 |
|
||||
| 2 |
|
||||
| 3 |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach nested
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH (i IN [1, 2, 3] | FOREACH( j IN [1] | CREATE (k { prop : j })))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n.prop;
|
||||
"""
|
||||
Then the result should be:
|
||||
| n.prop |
|
||||
| 1 |
|
||||
| 1 |
|
||||
| 1 |
|
||||
|
||||
Scenario: Foreach multiple update clauses
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n1 { marked1: false, marked2: false })-[:RELATES]->(n2 { marked1: false, marked2: false })
|
||||
"""
|
||||
And having executed
|
||||
"""
|
||||
MATCH p=(n1)-[*]->(n2)
|
||||
FOREACH (n IN nodes(p) | SET n.marked1 = true SET n.marked2 = true)
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n
|
||||
"""
|
||||
Then the result should be:
|
||||
| n |
|
||||
| ({marked1: true, marked2: true}) |
|
||||
| ({marked1: true, marked2: true}) |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach multiple nested update clauses
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n1 { marked1: false, marked2: false })-[:RELATES]->(n2 { marked1: false, marked2: false })
|
||||
"""
|
||||
And having executed
|
||||
"""
|
||||
MATCH p=(n1)-[*]->(n2)
|
||||
FOREACH (n IN nodes(p) | FOREACH (j IN [1] | SET n.marked1 = true SET n.marked2 = true))
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n
|
||||
"""
|
||||
Then the result should be:
|
||||
| n |
|
||||
| ({marked1: true, marked2: true}) |
|
||||
| ({marked1: true, marked2: true}) |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach match foreach return
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n {prop: [[], [1,2]]});
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) FOREACH (i IN n.prop | CREATE (:V { i: i})) RETURN n;
|
||||
"""
|
||||
Then the result should be:
|
||||
| n |
|
||||
| ({prop: [[], [1, 2]]}) |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach on null value
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
CREATE (n);
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) FOREACH (i IN n.prop | CREATE (:V { i: i}));
|
||||
"""
|
||||
Then the result should be:
|
||||
| |
|
||||
And no side effects
|
||||
|
||||
Scenario: Foreach nested merge
|
||||
Given an empty graph
|
||||
And having executed
|
||||
"""
|
||||
FOREACH(i in [1, 2, 3] | foreach(j in [1] | MERGE (n { age : i })));
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n
|
||||
"""
|
||||
Then the result should be:
|
||||
| n |
|
||||
| ({age: 1}) |
|
||||
| ({age: 2}) |
|
||||
| ({age: 3}) |
|
||||
And no side effects
|
@ -4090,3 +4090,89 @@ TEST_P(CypherMainVisitorTest, VersionQuery) {
|
||||
TestInvalidQuery("SHOW VERSIONS", ast_generator);
|
||||
ASSERT_NO_THROW(ast_generator.ParseQuery("SHOW VERSION"));
|
||||
}
|
||||
|
||||
// FOREACH queries that must be rejected by the parser.
TEST_P(CypherMainVisitorTest, ForeachThrow) {
  auto &ast_generator = *GetParam();
  // UNWIND inside the FOREACH body.
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | UNWIND [1,2,3] AS j CREATE (n))"), SyntaxException);
  // Missing '|' separator between the list and the body.
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] CREATE (:Foo {prop : i}))"), SyntaxException);
  // MATCH inside the body, with the closing parenthesis missing as well.
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | MATCH (n)"), SyntaxException);
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN x | MATCH (n)"), SyntaxException);
  // Same MATCH cases with balanced parentheses, so the rejection can only be
  // caused by MATCH itself and not by the unterminated FOREACH.
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN [1, 2] | MATCH (n))"), SyntaxException);
  EXPECT_THROW(ast_generator.ParseQuery("FOREACH(i IN x | MATCH (n))"), SyntaxException);
}
|
||||
|
||||
// Verifies the AST produced for FOREACH: the loop's named expression and the
// type of every clause in its body, for each body clause kind exercised below.
TEST_P(CypherMainVisitorTest, Foreach) {
  auto &ast_generator = *GetParam();

  // Parses `text`, asserts that it is a single query consisting of exactly one
  // FOREACH clause and stores that clause in `out`. Every intermediate pointer
  // is asserted so a parser regression fails the test instead of dereferencing
  // a null pointer. Must be invoked through ASSERT_NO_FATAL_FAILURE so that a
  // failed assertion inside the helper also aborts the calling test.
  const auto parse_foreach = [&ast_generator](const char *text, Foreach *&out) {
    out = nullptr;
    auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(text));
    ASSERT_TRUE(query);
    ASSERT_TRUE(query->single_query_);
    ASSERT_EQ(query->single_query_->clauses_.size(), 1U);
    out = dynamic_cast<Foreach *>(query->single_query_->clauses_[0]);
    ASSERT_TRUE(out);
  };

  // CREATE
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(
        parse_foreach("FOREACH (age IN [1, 2, 3] | CREATE (m:Age {amount: age}))", foreach_clause));
    ASSERT_TRUE(foreach_clause->named_expression_);
    EXPECT_EQ(foreach_clause->named_expression_->name_, "age");
    auto *expr = foreach_clause->named_expression_->expression_;
    ASSERT_TRUE(expr);
    ASSERT_TRUE(dynamic_cast<ListLiteral *>(expr));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<Create *>(clauses.front()));
  }
  // SET
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(parse_foreach("FOREACH (i IN nodes(path) | SET i.checkpoint = true)", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses.front()));
  }
  // REMOVE
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(parse_foreach("FOREACH (i IN nodes(path) | REMOVE i.prop)", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<RemoveProperty *>(clauses.front()));
  }
  // MERGE
  {
    // merge works as create here
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(parse_foreach("FOREACH (i IN [1, 2, 3] | MERGE (n {no : i}))", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<Merge *>(clauses.front()));
  }
  // CYPHER DELETE
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(parse_foreach("FOREACH (i IN nodes(path) | DETACH DELETE i)", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<Delete *>(clauses.front()));
  }
  // nested FOREACH
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(parse_foreach(
        "FOREACH (i IN nodes(path) | FOREACH (age IN i.list | CREATE (m:Age {amount: age})))", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 1U);
    ASSERT_TRUE(dynamic_cast<Foreach *>(clauses.front()));
  }
  // Multiple update clauses
  {
    Foreach *foreach_clause = nullptr;
    ASSERT_NO_FATAL_FAILURE(
        parse_foreach("FOREACH (i IN nodes(path) | SET i.checkpoint = true REMOVE i.prop)", foreach_clause));
    const auto &clauses = foreach_clause->clauses_;
    ASSERT_EQ(clauses.size(), 2U);
    ASSERT_TRUE(dynamic_cast<SetProperty *>(clauses[0]));
    ASSERT_TRUE(dynamic_cast<RemoveProperty *>(clauses[1]));
  }
}
|
||||
|
@ -911,3 +911,34 @@ TEST_F(PrintToJsonTest, CallProcedure) {
|
||||
"result_symbols" : ["name_alias", "signature_alias"]
|
||||
})sep");
|
||||
}
|
||||
|
||||
// Checks the JSON printed for a plan::Foreach operator that has no explicit
// input and whose update branch is a single CreateNode.
TEST_F(PrintToJsonTest, Foreach) {
  const Symbol loop_symbol = GetSymbol("x");
  std::shared_ptr<LogicalOperator> update_branch =
      std::make_shared<CreateNode>(nullptr, NodeCreationInfo{GetSymbol("node"), {dba.NameToLabel("Label1")}, {}});
  std::shared_ptr<LogicalOperator> foreach_op =
      std::make_shared<plan::Foreach>(nullptr, std::move(update_branch), LIST(LITERAL(1)), loop_symbol);

  Check(foreach_op.get(), R"sep(
{
"expression": "(ListLiteral [1])",
"input": {
"name": "Once"
},
"name": "Foreach",
"loop_variable_symbol": "x",
"update_clauses": {
"input": {
"name": "Once"
},
"name": "CreateNode",
"node_info": {
"labels": [
"Label1"
],
"properties": null,
"symbol": "node"
}
}
})sep");
}
|
||||
|
@ -463,6 +463,11 @@ auto GetCallProcedure(AstStorage &storage, std::string procedure_name,
|
||||
return call_procedure;
|
||||
}
|
||||
|
||||
/// Create the FOREACH clause with given named expression.
|
||||
auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::Clause *> &clauses) {
|
||||
return storage.Create<query::Foreach>(named_expr, clauses);
|
||||
}
|
||||
|
||||
} // namespace test_common
|
||||
|
||||
} // namespace memgraph::query
|
||||
@ -526,13 +531,14 @@ auto GetCallProcedure(AstStorage &storage, std::string procedure_name,
|
||||
memgraph::query::test_common::OnCreate { \
|
||||
std::vector<memgraph::query::Clause *> { __VA_ARGS__ } \
|
||||
}
|
||||
#define CREATE_INDEX_ON(label, property) \
|
||||
#define CREATE_INDEX_ON(label, property) \
|
||||
storage.Create<memgraph::query::IndexQuery>(memgraph::query::IndexQuery::Action::CREATE, (label), \
|
||||
std::vector<memgraph::query::PropertyIx>{(property)})
|
||||
std::vector<memgraph::query::PropertyIx>{(property)})
|
||||
#define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__)
|
||||
#define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__)
|
||||
#define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__)
|
||||
#define UNION_ALL(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(false), __VA_ARGS__)
|
||||
// Creates a FOREACH clause through test_common::GetForeach. Requires a variable named `storage` to be in scope at the
// point of use; the macro arguments (named expression, then the body clauses) are forwarded after it.
#define FOREACH(...) memgraph::query::test_common::GetForeach(storage, __VA_ARGS__)
|
||||
// Various operators
|
||||
#define NOT(expr) storage.Create<memgraph::query::NotOperator>((expr))
|
||||
#define UPLUS(expr) storage.Create<memgraph::query::UnaryPlusOperator>((expr))
|
||||
|
@ -181,6 +181,19 @@ TEST_F(QueryCostEstimator, ExpandVariable) {
|
||||
EXPECT_COST(CardParam::kExpandVariable * CostParam::kExpandVariable);
|
||||
}
|
||||
|
||||
// When the Foreach expression is a list literal, the expected cost is
// CostParam::kForeach multiplied by the number of elements in the literal.
TEST_F(QueryCostEstimator, ForeachListLiteral) {
  constexpr size_t kListSize = 10;
  std::shared_ptr<LogicalOperator> update_branch =
      std::make_shared<CreateNode>(std::make_shared<Once>(), NodeCreationInfo{});
  // Only the element count matters here, so the list is filled with null expressions.
  auto *list_literal = storage_.Create<ListLiteral>(std::vector<Expression *>(kListSize, nullptr));
  MakeOp<memgraph::query::plan::Foreach>(last_op_, update_branch, list_literal, NextSymbol());
  EXPECT_COST(CostParam::kForeach * kListSize);
}
|
||||
|
||||
// When the Foreach expression is not a list literal (an Identifier here), the
// expected cost is CostParam::kForeach multiplied by MiscParam::kForeachNoLiteral.
TEST_F(QueryCostEstimator, Foreach) {
  std::shared_ptr<LogicalOperator> update_branch =
      std::make_shared<CreateNode>(std::make_shared<Once>(), NodeCreationInfo{});
  auto *non_literal_expr = storage_.Create<Identifier>();
  MakeOp<memgraph::query::plan::Foreach>(last_op_, update_branch, non_literal_expr, NextSymbol());
  EXPECT_COST(CostParam::kForeach * MiscParam::kForeachNoLiteral);
}
|
||||
// Helper for testing an operation's cost and cardinality.
|
||||
// Only for operations that first increment cost, then modify cardinality.
|
||||
// Intentionally a macro (instead of function) for better test feedback.
|
||||
|
@ -90,7 +90,6 @@ void DeleteListContent(std::list<BaseOpChecker *> *list) {
|
||||
delete ptr;
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_CASE(TestPlanner, PlannerTypes);
|
||||
|
||||
TYPED_TEST(TestPlanner, MatchNodeReturn) {
|
||||
@ -1555,4 +1554,57 @@ TYPED_TEST(TestPlanner, LabelPropertyInListWhereLabelPropertyOnRight) {
|
||||
}
|
||||
}
|
||||
|
||||
// Plans FOREACH with a CREATE, DELETE and SET body, with a nested FOREACH, and
// with two consecutive FOREACH clauses, checking both the input branch and the
// update branch of every resulting Foreach operator.
TYPED_TEST(TestPlanner, Foreach) {
  AstStorage storage;
  FakeDbAccessor dba;
  // Expected input checkers for a FOREACH that has no clause in front of it.
  const std::list<BaseOpChecker *> no_input;

  // Body: CREATE
  {
    auto *loop_i = NEXPR("i", IDENT("i"));
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {CREATE(PATTERN(NODE("n")))})));
    auto expect_create = ExpectCreateNode();
    const std::list<BaseOpChecker *> body{&expect_create};
    CheckPlan<TypeParam>(query, storage, ExpectForeach(no_input, body));
  }
  // Body: DELETE
  {
    auto *loop_i = NEXPR("i", IDENT("i"));
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {DELETE(IDENT("i"))})));
    auto expect_delete = ExpectDelete();
    const std::list<BaseOpChecker *> body{&expect_delete};
    CheckPlan<TypeParam>(query, storage, ExpectForeach(no_input, body));
  }
  // Body: SET property
  {
    auto prop = dba.Property("prop");
    auto *loop_i = NEXPR("i", IDENT("i"));
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {SET(PROPERTY_LOOKUP("i", prop), LITERAL(10))})));
    auto expect_set_property = ExpectSetProperty();
    const std::list<BaseOpChecker *> body{&expect_set_property};
    CheckPlan<TypeParam>(query, storage, ExpectForeach(no_input, body));
  }
  // Body: nested FOREACH holding CREATE followed by DELETE
  {
    auto *loop_i = NEXPR("i", IDENT("i"));
    auto *loop_j = NEXPR("j", IDENT("j"));
    auto *query =
        QUERY(SINGLE_QUERY(FOREACH(loop_i, {FOREACH(loop_j, {CREATE(PATTERN(NODE("n"))), DELETE(IDENT("i"))})})));
    auto expect_create = ExpectCreateNode();
    auto expect_delete = ExpectDelete();
    const std::list<BaseOpChecker *> inner_body{&expect_create, &expect_delete};
    auto expect_inner_foreach = ExpectForeach(no_input, inner_body);
    const std::list<BaseOpChecker *> outer_body{&expect_inner_foreach};
    CheckPlan<TypeParam>(query, storage, ExpectForeach(no_input, outer_body));
  }
  // Two consecutive FOREACH clauses: the first one is the input of the second.
  {
    auto *loop_i = NEXPR("i", IDENT("i"));
    auto *loop_j = NEXPR("j", IDENT("j"));
    auto expect_create = ExpectCreateNode();
    const std::list<BaseOpChecker *> body{&expect_create};
    auto expect_first_foreach = ExpectForeach(no_input, body);
    const std::list<BaseOpChecker *> preceding{&expect_first_foreach};
    auto *query = QUERY(
        SINGLE_QUERY(FOREACH(loop_i, {CREATE(PATTERN(NODE("n")))}), FOREACH(loop_j, {CREATE(PATTERN(NODE("n")))})));
    CheckPlan<TypeParam>(query, storage, ExpectForeach(preceding, body));
  }
}
|
||||
} // namespace
|
||||
|
@ -90,6 +90,11 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
}
|
||||
PRE_VISIT(Unwind);
|
||||
PRE_VISIT(Distinct);
|
||||
|
||||
bool PreVisit(Foreach &op) override {
|
||||
CheckOp(op);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Visit(Once &) override {
|
||||
// Ignore checking Once, it is implicitly at the end.
|
||||
@ -150,6 +155,25 @@ using ExpectOrderBy = OpChecker<OrderBy>;
|
||||
using ExpectUnwind = OpChecker<Unwind>;
|
||||
using ExpectDistinct = OpChecker<Distinct>;
|
||||
|
||||
/// Checker for a Foreach operator. The operator's input branch and its
/// update-clauses branch are each visited with a separate PlanChecker built
/// from the corresponding list of expected checkers.
class ExpectForeach : public OpChecker<Foreach> {
 public:
  /// @param input   checkers expected for the operator's `input_` branch
  /// @param updates checkers expected for the operator's `update_clauses_` branch
  ExpectForeach(const std::list<BaseOpChecker *> &input, const std::list<BaseOpChecker *> &updates)
      : input_(input), updates_(updates) {}

  void ExpectOp(Foreach &op, const SymbolTable &symbol_table) override {
    PlanChecker input_checker(input_, symbol_table);
    op.input_->Accept(input_checker);

    PlanChecker updates_checker(updates_, symbol_table);
    op.update_clauses_->Accept(updates_checker);
  }

 private:
  std::list<BaseOpChecker *> input_;
  std::list<BaseOpChecker *> updates_;
};
|
||||
|
||||
class ExpectExpandVariable : public OpChecker<ExpandVariable> {
|
||||
public:
|
||||
void ExpectOp(ExpandVariable &op, const SymbolTable &) override {
|
||||
|
@ -247,3 +247,9 @@ TEST_F(ReadWriteTypeCheckTest, ConstructNamedPath) {
|
||||
|
||||
CheckPlanType(last_op.get(), RWType::R);
|
||||
}
|
||||
|
||||
// A plan consisting of a lone Foreach operator (null input, null update branch,
// null expression) is expected to be reported as RWType::RW.
TEST_F(ReadWriteTypeCheckTest, Foreach) {
  const Symbol loop_symbol = GetSymbol("x");
  std::shared_ptr<LogicalOperator> foreach_op =
      std::make_shared<plan::Foreach>(nullptr, nullptr, nullptr, loop_symbol);
  CheckPlanType(foreach_op.get(), RWType::RW);
}
|
||||
|
@ -1157,3 +1157,23 @@ TEST_F(TestSymbolGenerator, PredefinedIdentifiers) {
|
||||
query = QUERY(SINGLE_QUERY(unwind, CREATE(PATTERN(node))));
|
||||
ASSERT_THROW(memgraph::query::MakeSymbolTable(query, {first_op}), SemanticException);
|
||||
}
|
||||
|
||||
// Symbol generation for FOREACH. The same named expression `i` is shared by
// every query below, exactly one query per scoped case.
TEST_F(TestSymbolGenerator, Foreach) {
  auto *loop_i = NEXPR("i", IDENT("i"));

  // Returning a node created inside the FOREACH body: unbound variable.
  {
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {CREATE(PATTERN(NODE("n")))}), RETURN("n")));
    EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError);
  }

  // Two consecutive FOREACH clauses using the same loop variable name.
  {
    auto *query = QUERY(
        SINGLE_QUERY(FOREACH(loop_i, {CREATE(PATTERN(NODE("n")))}), FOREACH(loop_i, {CREATE(PATTERN(NODE("v")))})));
    auto symbol_table = memgraph::query::MakeSymbolTable(query);
    ASSERT_EQ(symbol_table.max_position(), 6);
  }

  // Nested FOREACH whose body creates a node named like the loop variable: redeclaration.
  {
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {FOREACH(loop_i, {CREATE(PATTERN(NODE("i")))})})));
    EXPECT_THROW(memgraph::query::MakeSymbolTable(query), RedeclareVariableError);
  }

  // Nested FOREACH whose body creates a node with a different name.
  {
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {FOREACH(loop_i, {CREATE(PATTERN(NODE("v")))})})));
    auto symbol_table = memgraph::query::MakeSymbolTable(query);
    ASSERT_EQ(symbol_table.max_position(), 4);
  }

  // Returning the loop variable after the FOREACH: unbound variable.
  {
    auto *query = QUERY(SINGLE_QUERY(FOREACH(loop_i, {CREATE(PATTERN(NODE("n")))}), RETURN("i")));
    EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError);
  }
}
|
||||
|
Loading…
Reference in New Issue
Block a user