List Pattern Comprehension planner (#1686)

This commit is contained in:
Aidar Samerkhanov 2024-03-07 18:41:02 +04:00 committed by GitHub
parent 02325f8673
commit a099417c56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 460 additions and 57 deletions

View File

@ -21,6 +21,7 @@
#include "query/interpret/awesome_memgraph_functions.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/property_value.hpp"
#include "utils/exceptions.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph::query {
@ -3586,7 +3587,7 @@ class PatternComprehension : public memgraph::query::Expression {
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
if (variable_) {
variable_->Accept(visitor);
throw utils::NotYetImplemented("Variable in pattern comprehension.");
}
pattern_->Accept(visitor);
if (filter_) {
@ -3615,7 +3616,8 @@ class PatternComprehension : public memgraph::query::Expression {
int32_t symbol_pos_{-1};
PatternComprehension *Clone(AstStorage *storage) const override {
PatternComprehension *object = storage->Create<PatternComprehension>();
auto *object = storage->Create<PatternComprehension>();
object->variable_ = variable_ ? variable_->Clone(storage) : nullptr;
object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr;
object->filter_ = filter_ ? filter_->Clone(storage) : nullptr;
object->resultExpr_ = resultExpr_ ? resultExpr_->Clone(storage) : nullptr;
@ -3625,7 +3627,8 @@ class PatternComprehension : public memgraph::query::Expression {
}
protected:
PatternComprehension(Identifier *variable, Pattern *pattern) : variable_(variable), pattern_(pattern) {}
PatternComprehension(Identifier *variable, Pattern *pattern, Where *filter, Expression *resultExpr)
: variable_(variable), pattern_(pattern), filter_(filter), resultExpr_(resultExpr) {}
private:
friend class AstStorage;

View File

@ -721,6 +721,32 @@ bool SymbolGenerator::PostVisit(EdgeAtom &) {
return true;
}
/// Validates the scope a pattern comprehension appears in and maps the
/// comprehension to a freshly created anonymous symbol.
///
/// @param pc pattern comprehension that is about to be traversed.
/// @return true, so the traversal continues into the comprehension.
/// @throws utils::NotYetImplemented if the comprehension is used inside a SET
///         property, WITH, REDUCE or IF context.
bool SymbolGenerator::PreVisit(PatternComprehension &pc) {
  const auto &scope = scopes_.back();

  // Reject every context in which pattern comprehension is rejected by this pass.
  if (scope.in_set_property) {
    throw utils::NotYetImplemented("Pattern Comprehension cannot be used within SET clause!");
  }
  if (scope.in_with) {
    throw utils::NotYetImplemented("Pattern Comprehension cannot be used within WITH!");
  }
  if (scope.in_reduce) {
    throw utils::NotYetImplemented("Pattern Comprehension cannot be used within REDUCE!");
  }
  if (scope.num_if_operators) {
    throw utils::NotYetImplemented("IF operator cannot be used with Pattern Comprehension!");
  }

  // The comprehension itself has no user-visible name, so it gets an anonymous symbol.
  const auto &symbol = CreateAnonymousSymbol();
  pc.MapTo(symbol);
  return true;
}
/// Nothing to finalize once a pattern comprehension has been traversed.
bool SymbolGenerator::PostVisit(PatternComprehension & /*pc*/) {
  return true;
}
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
auto &scope = scopes_.back();
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -97,6 +97,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(NodeAtom &) override;
bool PreVisit(EdgeAtom &) override;
bool PostVisit(EdgeAtom &) override;
bool PreVisit(PatternComprehension &) override;
bool PostVisit(PatternComprehension &) override;
private:
// Scope stores the state of where we are when visiting the AST and a map of

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -206,6 +206,14 @@ class PlanHintsProvider final : public HierarchicalLogicalOperatorVisitor {
bool PostVisit(IndexedJoin & /*unused*/) override { return true; }
/// Visits the input branch and then the list collection branch by hand.
/// Returns false so that the operator's own Accept does not descend into the
/// branches a second time.
bool PreVisit(RollUpApply &op) override {
  auto input_branch = op.input();
  input_branch->Accept(*this);

  auto &list_branch = op.list_collection_branch_;
  list_branch->Accept(*this);

  return false;
}
/// Nothing to do after a RollUpApply has been handled.
bool PostVisit(RollUpApply & /*unused*/) override {
  return true;
}
private:
const SymbolTable &symbol_table_;
std::vector<std::string> hints_;

View File

@ -5624,4 +5624,25 @@ UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const {
return MakeUniqueCursorPtr<HashJoinCursor>(mem, *this, mem);
}
/// Builds a RollUpApply from its input operator and the branch stored as
/// `list_collection_branch_`.
///
/// @param input operator stored as `input_`.
/// @param second_branch operator stored as `list_collection_branch_`; it is
///        taken by rvalue reference and moved into the member. A named rvalue
///        reference is an lvalue, so without std::move the shared_ptr would be
///        copied instead.
RollUpApply::RollUpApply(const std::shared_ptr<LogicalOperator> &input,
                         std::shared_ptr<LogicalOperator> &&second_branch)
    : input_(input), list_collection_branch_(std::move(second_branch)) {}
/// Returns an empty vector: this operator reports no output symbols.
std::vector<Symbol> RollUpApply::OutputSymbols(const SymbolTable & /*symbol_table*/) const {
  return {};
}
/// The modified symbols are exactly the output symbols of this operator.
std::vector<Symbol> RollUpApply::ModifiedSymbols(const SymbolTable &table) const {
  auto modified = OutputSymbols(table);
  return modified;
}
/// Hierarchical visitor entry point.
///
/// When PreVisit returns true, the input branch is visited first and the list
/// collection branch is visited only if the input branch's Accept returned
/// true. PostVisit is always called and its result is returned.
///
/// @throws utils::NotYetImplemented if either branch is null while descending.
bool RollUpApply::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (!visitor.PreVisit(*this)) {
    return visitor.PostVisit(*this);
  }

  const bool has_both_branches = input_ && list_collection_branch_;
  if (!has_both_branches) {
    throw utils::NotYetImplemented("One of the branches in pattern comprehension is null! Please contact support.");
  }

  if (input_->Accept(visitor)) {
    list_collection_branch_->Accept(visitor);
  }
  return visitor.PostVisit(*this);
}
} // namespace memgraph::query::plan

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -130,6 +130,7 @@ class EvaluatePatternFilter;
class Apply;
class IndexedJoin;
class HashJoin;
class RollUpApply;
using LogicalOperatorCompositeVisitor =
utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange,
@ -137,7 +138,7 @@ using LogicalOperatorCompositeVisitor =
ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit,
OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv,
Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin>;
Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>;
using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
@ -2634,5 +2635,38 @@ class HashJoin : public memgraph::query::plan::LogicalOperator {
}
};
/// RollUpApply operator is used to execute an expression which takes as input a pattern,
/// and returns a list with content from the matched pattern
/// It's used for a pattern expression or pattern comprehension in a query.
class RollUpApply : public memgraph::query::plan::LogicalOperator {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
RollUpApply() = default;
RollUpApply(const std::shared_ptr<LogicalOperator> &input, std::shared_ptr<LogicalOperator> &&second_branch);
bool HasSingleInput() const override { return false; }
std::shared_ptr<LogicalOperator> input() const override { return input_; }
void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override {
throw utils::NotYetImplemented("Execution of Pattern comprehension is currently unsupported.");
}
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<RollUpApply>();
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->list_collection_branch_ = list_collection_branch_ ? list_collection_branch_->Clone(storage) : nullptr;
return object;
}
std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
std::shared_ptr<memgraph::query::plan::LogicalOperator> list_collection_branch_;
};
} // namespace plan
} // namespace memgraph::query

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -154,4 +154,7 @@ constexpr utils::TypeInfo query::plan::IndexedJoin::kType{utils::TypeId::INDEXED
constexpr utils::TypeInfo query::plan::HashJoin::kType{utils::TypeId::HASH_JOIN, "HashJoin",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::RollUpApply::kType{utils::TypeId::ROLLUP_APPLY, "RollUpApply",
&query::plan::LogicalOperator::kType};
} // namespace memgraph

View File

@ -632,20 +632,20 @@ void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &stor
// If there are any pattern filters, we add those as well
for (auto &filter : matching.filters) {
PatternFilterVisitor visitor(symbol_table, storage);
PatternVisitor visitor(symbol_table, storage);
filter.expression->Accept(visitor);
filter.matchings = visitor.getMatchings();
filter.matchings = visitor.getFilterMatchings();
}
}
PatternFilterVisitor::PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage)
PatternVisitor::PatternVisitor(SymbolTable &symbol_table, AstStorage &storage)
: symbol_table_(symbol_table), storage_(storage) {}
PatternFilterVisitor::PatternFilterVisitor(const PatternFilterVisitor &) = default;
PatternFilterVisitor::PatternFilterVisitor(PatternFilterVisitor &&) noexcept = default;
PatternFilterVisitor::~PatternFilterVisitor() = default;
PatternVisitor::PatternVisitor(const PatternVisitor &) = default;
PatternVisitor::PatternVisitor(PatternVisitor &&) noexcept = default;
PatternVisitor::~PatternVisitor() = default;
void PatternFilterVisitor::Visit(Exists &op) {
void PatternVisitor::Visit(Exists &op) {
std::vector<Pattern *> patterns;
patterns.push_back(op.pattern_);
@ -655,10 +655,10 @@ void PatternFilterVisitor::Visit(Exists &op) {
filter_matching.type = PatternFilterType::EXISTS;
filter_matching.symbol = std::make_optional<Symbol>(symbol_table_.at(op));
matchings_.push_back(std::move(filter_matching));
filter_matchings_.push_back(std::move(filter_matching));
}
std::vector<FilterMatching> PatternFilterVisitor::getMatchings() { return matchings_; }
std::vector<FilterMatching> PatternVisitor::getFilterMatchings() { return filter_matchings_; }
static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage,
SymbolTable &symbol_table) {
@ -672,6 +672,30 @@ static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, A
}
}
/// Collects the pattern comprehension matchings of a RETURN clause.
///
/// Every named expression of the RETURN body is visited, and each matching
/// found in it is registered in `matchings` under the name of that named
/// expression.
///
/// @param ret RETURN clause whose named expressions are inspected.
/// @param storage AST storage handed to the visitor.
/// @param symbol_table symbol table handed to the visitor.
/// @param matchings output map from named expression name to its matching;
///        an entry that already exists for a name is left untouched.
static void ParseReturn(query::Return &ret, AstStorage &storage, SymbolTable &symbol_table,
                        std::unordered_map<std::string, PatternComprehensionMatching> &matchings) {
  for (auto *expr : ret.body_.named_expressions) {
    // A fresh visitor per expression: the visitor accumulates every matching
    // it has seen, so a shared instance would re-register the matchings of
    // earlier expressions under the name of each later expression.
    PatternVisitor visitor(symbol_table, storage);
    expr->Accept(visitor);

    auto expression_matchings = visitor.getPatternComprehensionMatchings();
    for (auto &matching : expression_matchings) {
      // `expression_matchings` is a local copy, so its elements can be moved from.
      matchings.emplace(expr->name_, std::move(matching));
    }
  }
}
/// A named expression is only a wrapper: forward the visit to the expression it names.
void PatternVisitor::Visit(NamedExpression &op) {
  op.expression_->Accept(*this);
}
/// Builds a PatternComprehensionMatching for `op` and appends it to the
/// collected pattern comprehension matchings.
///
/// The result expression of the comprehension is wrapped in a new
/// NamedExpression that is named after, and mapped to, the symbol of `op`.
void PatternVisitor::Visit(PatternComprehension &op) {
  PatternComprehensionMatching pc_matching;
  AddMatching({op.pattern_}, op.filter_, symbol_table_, storage_, pc_matching);

  // Looked up only after AddMatching, which receives the symbol table by non-const reference.
  const auto &comprehension_symbol = symbol_table_.at(op);
  auto *named_result = storage_.Create<NamedExpression>(comprehension_symbol.name(), op.resultExpr_);
  named_result->MapTo(comprehension_symbol);
  pc_matching.result_expr = named_result;

  pattern_comprehension_matchings_.push_back(std::move(pc_matching));
}
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
@ -703,7 +727,8 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table,
// This query part is done, continue with a new one.
query_parts.emplace_back(SingleQueryPart{});
query_part = &query_parts.back();
} else if (utils::IsSubtype(*clause, Return::kType)) {
} else if (auto *ret = utils::Downcast<Return>(clause)) {
ParseReturn(*ret, storage, symbol_table, query_part->pattern_comprehension_matchings);
return query_parts;
}
}

View File

@ -153,19 +153,20 @@ struct Expansion {
ExpansionGroupId expansion_group_id = ExpansionGroupId();
};
struct PatternComprehensionMatching;
struct FilterMatching;
enum class PatternFilterType { EXISTS };
/// Collects matchings from filters that include patterns
class PatternFilterVisitor : public ExpressionVisitor<void> {
/// Collects matchings that include patterns
class PatternVisitor : public ExpressionVisitor<void> {
public:
explicit PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage);
PatternFilterVisitor(const PatternFilterVisitor &);
PatternFilterVisitor &operator=(const PatternFilterVisitor &) = delete;
PatternFilterVisitor(PatternFilterVisitor &&) noexcept;
PatternFilterVisitor &operator=(PatternFilterVisitor &&) noexcept = delete;
~PatternFilterVisitor() override;
explicit PatternVisitor(SymbolTable &symbol_table, AstStorage &storage);
PatternVisitor(const PatternVisitor &);
PatternVisitor &operator=(const PatternVisitor &) = delete;
PatternVisitor(PatternVisitor &&) noexcept;
PatternVisitor &operator=(PatternVisitor &&) noexcept = delete;
~PatternVisitor() override;
using ExpressionVisitor<void>::Visit;
@ -233,18 +234,24 @@ class PatternFilterVisitor : public ExpressionVisitor<void> {
void Visit(PropertyLookup &op) override{};
void Visit(AllPropertiesLookup &op) override{};
void Visit(ParameterLookup &op) override{};
void Visit(NamedExpression &op) override{};
void Visit(RegexMatch &op) override{};
void Visit(PatternComprehension &op) override{};
void Visit(NamedExpression &op) override;
void Visit(PatternComprehension &op) override;
std::vector<FilterMatching> getMatchings();
std::vector<FilterMatching> getFilterMatchings();
/// Returns a copy of all pattern comprehension matchings collected so far.
std::vector<PatternComprehensionMatching> getPatternComprehensionMatchings() {
  auto collected = pattern_comprehension_matchings_;
  return collected;
}
SymbolTable &symbol_table_;
AstStorage &storage_;
private:
/// Collection of matchings in the filter expression being analyzed.
std::vector<FilterMatching> matchings_;
std::vector<FilterMatching> filter_matchings_;
/// Collection of matchings in the pattern comprehension being analyzed.
std::vector<PatternComprehensionMatching> pattern_comprehension_matchings_;
};
/// Stores the symbols and expression used to filter a property.
@ -495,6 +502,11 @@ inline auto Filters::IdFilters(const Symbol &symbol) const -> std::vector<Filter
return filters;
}
/// A Matching extended with the named expression that holds the result of a
/// pattern comprehension.
struct PatternComprehensionMatching : Matching {
  /// Named expression for the comprehension's result; null until assigned.
  NamedExpression *result_expr{nullptr};
};
/// @brief Represents a read (+ write) part of a query. Parts are split on
/// `WITH` clauses.
///
@ -537,6 +549,14 @@ struct SingleQueryPart {
/// in the `remaining_clauses` but rather in the `Foreach` itself and are guaranteed
/// to be processed in the same order by the semantics of the `RuleBasedPlanner`.
std::vector<Matching> merge_matching{};
/// @brief @c NamedExpression name to @c PatternComprehensionMatching for each pattern comprehension.
///
/// Storing the normalized pattern of a @c PatternComprehension does not preclude storing the
/// @c PatternComprehension clause itself inside `remaining_clauses`. The reason is that we
/// need to have access to other parts of the clause, such as pattern, filter clauses.
std::unordered_map<std::string, PatternComprehensionMatching> pattern_comprehension_matchings{};
/// @brief All the remaining clauses (without @c Match).
std::vector<Clause *> remaining_clauses{};
/// The subqueries vector are all the subqueries in this query part ordered in a list by

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -143,6 +143,13 @@ bool PlanPrinter::PreVisit(query::plan::Union &op) {
return false;
}
/// Prints the RollUpApply line, then the list collection branch via Branch(),
/// and finally visits the input branch. Returns false because both branches
/// are handled here.
bool PlanPrinter::PreVisit(query::plan::RollUpApply &op) {
  WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });

  auto &list_branch = *op.list_collection_branch_;
  Branch(list_branch);

  auto &input_branch = *op.input_;
  input_branch.Accept(*this);

  return false;
}
bool PlanPrinter::PreVisit(query::plan::CallProcedure &op) {
WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });
return true;

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -91,6 +91,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
bool PreVisit(OrderBy &) override;
bool PreVisit(Distinct &) override;
bool PreVisit(Union &) override;
bool PreVisit(RollUpApply &) override;
bool PreVisit(Unwind &) override;
bool PreVisit(CallProcedure &) override;

View File

@ -595,6 +595,18 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
return true;
}
/// Records `op` in `prev_ops_`, visits its input branch and rewrites its list
/// collection branch. Returns false because both branches are handled here.
bool PreVisit(RollUpApply &op) override {
  prev_ops_.push_back(&op);

  auto input_branch = op.input();
  input_branch->Accept(*this);

  auto *list_branch = &op.list_collection_branch_;
  RewriteBranch(list_branch);

  return false;
}
/// Removes the entry that PreVisit(RollUpApply &) pushed onto `prev_ops_`.
bool PostVisit(RollUpApply & /*unused*/) override {
  prev_ops_.pop_back();
  return true;
}
std::shared_ptr<LogicalOperator> new_root_;
private:

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -455,6 +455,18 @@ class JoinRewriter final : public HierarchicalLogicalOperatorVisitor {
return true;
}
/// Records `op` in `prev_ops_`, visits its input branch and rewrites its list
/// collection branch. Returns false because both branches are handled here.
bool PreVisit(RollUpApply &op) override {
  prev_ops_.push_back(&op);

  auto input_branch = op.input();
  input_branch->Accept(*this);

  auto *list_branch = &op.list_collection_branch_;
  RewriteBranch(list_branch);

  return false;
}
/// Removes the entry that PreVisit(RollUpApply &) pushed onto `prev_ops_`.
bool PostVisit(RollUpApply & /*unused*/) override {
  prev_ops_.pop_back();
  return true;
}
std::shared_ptr<LogicalOperator> new_root_;
private:

View File

@ -14,9 +14,12 @@
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <stack>
#include <unordered_set>
#include "query/frontend/ast/ast.hpp"
#include "query/plan/operator.hpp"
#include "query/plan/preprocess.hpp"
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
@ -40,7 +43,8 @@ namespace {
class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
AstStorage &storage, Where *where = nullptr)
AstStorage &storage, std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops,
Where *where = nullptr)
: body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
// Collect symbols from named expressions.
output_symbols_.reserve(body_.named_expressions.size());
@ -53,6 +57,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
named_expr->Accept(*this);
named_expressions_.emplace_back(named_expr);
if (pattern_comprehension_) {
if (auto it = pc_ops.find(named_expr->name_); it != pc_ops.end()) {
pattern_comprehension_op_ = std::move(it->second);
pc_ops.erase(it);
} else {
throw utils::NotYetImplemented("Operation on top of pattern comprehension");
}
}
}
// Collect symbols used in group by expressions.
if (!aggregations_.empty()) {
@ -386,8 +398,20 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(PatternComprehension & /*unused*/) override {
throw utils::NotYetImplemented("Planner can not handle pattern comprehension.");
/// Remembers how many aggregation flags exist before the comprehension is
/// traversed, so that PostVisit can fold the ones added during the traversal.
bool PreVisit(PatternComprehension & /*unused*/) override {
  const auto flags_before_traversal = has_aggregation_.size();
  pattern_compression_aggregations_start_index_ = flags_before_traversal;
  return true;
}
/// Folds every aggregation flag pushed since the matching PreVisit into a
/// single flag (true if any of them was true) and records the visited
/// comprehension in `pattern_comprehension_`.
bool PostVisit(PatternComprehension &pattern_comprehension) override {
  bool any_aggregation = false;
  // Pop flags until only those present before the comprehension remain.
  while (has_aggregation_.size() > pattern_compression_aggregations_start_index_) {
    if (has_aggregation_.back()) {
      any_aggregation = true;
    }
    has_aggregation_.pop_back();
  }
  has_aggregation_.emplace_back(any_aggregation);

  pattern_comprehension_ = &pattern_comprehension;
  return true;
}
// Creates NamedExpression with an Identifier for each user declared symbol.
@ -444,6 +468,10 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// named_expressions.
const auto &output_symbols() const { return output_symbols_; }
const auto *pattern_comprehension() const { return pattern_comprehension_; }
std::shared_ptr<LogicalOperator> pattern_comprehension_op() const { return pattern_comprehension_op_; }
private:
const ReturnBody &body_;
SymbolTable &symbol_table_;
@ -465,10 +493,13 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// group by it.
std::list<bool> has_aggregation_;
std::vector<NamedExpression *> named_expressions_;
PatternComprehension *pattern_comprehension_ = nullptr;
std::shared_ptr<LogicalOperator> pattern_comprehension_op_;
size_t pattern_compression_aggregations_start_index_ = 0;
};
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate = false) {
const ReturnBodyContext &body, bool accumulate) {
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
auto last_op = std::move(input_op);
if (accumulate) {
@ -482,6 +513,11 @@ std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator>
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
}
if (body.pattern_comprehension()) {
last_op = std::make_unique<RollUpApply>(std::move(last_op), body.pattern_comprehension_op());
}
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
if (body.distinct()) {
@ -506,6 +542,7 @@ std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator>
last_op = std::make_unique<Filter>(std::move(last_op), std::vector<std::shared_ptr<LogicalOperator>>{},
body.where()->expression_);
}
return last_op;
}
@ -543,8 +580,9 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filt
return filter_expr;
}
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
SymbolTable &symbol_table, AstStorage &storage) {
std::unordered_set<Symbol> GetSubqueryBoundSymbols(
const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
const auto &query = single_query_parts[0];
if (!query.matching.expansions.empty() || query.remaining_clauses.empty()) {
@ -552,7 +590,7 @@ std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQuery
}
if (std::unordered_set<Symbol> bound_symbols; auto *with = utils::Downcast<query::With>(query.remaining_clauses[0])) {
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage);
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage, pc_ops);
return bound_symbols;
}
@ -583,7 +621,8 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator>
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
// Similar to WITH clause, but we want to accumulate when the query writes to
// the database. This way we handle the case when we want to return
// expressions with the latest updated results. For example, `MATCH (n) -- ()
@ -592,13 +631,14 @@ std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalO
// final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage, pc_ops);
return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
}
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
@ -606,7 +646,7 @@ std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOper
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_);
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, pc_ops, with.where_);
auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();

View File

@ -21,6 +21,7 @@
#include "query/frontend/ast/ast_visitor.hpp"
#include "query/plan/operator.hpp"
#include "query/plan/preprocess.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/typeinfo.hpp"
@ -87,8 +88,9 @@ bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, cons
// Returns the set of symbols for the subquery that are actually referenced from the outer scope and
// used in the subquery.
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
SymbolTable &symbol_table, AstStorage &storage);
std::unordered_set<Symbol> GetSubqueryBoundSymbols(
const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table);
Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table);
@ -142,11 +144,13 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator>
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
@ -190,11 +194,24 @@ class RuleBasedPlanner {
uint64_t merge_id = 0;
uint64_t subquery_id = 0;
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pattern_comprehension_ops;
if (single_query_part.pattern_comprehension_matchings.size() > 1) {
throw utils::NotYetImplemented("Multiple pattern comprehensions.");
}
for (const auto &matching : single_query_part.pattern_comprehension_matchings) {
std::unique_ptr<LogicalOperator> new_input;
MatchContext match_ctx{matching.second, *context.symbol_table, context.bound_symbols};
new_input = PlanMatching(match_ctx, std::move(new_input));
new_input = std::make_unique<Produce>(std::move(new_input), std::vector{matching.second.result_expr});
pattern_comprehension_ops.emplace(matching.first, std::move(new_input));
}
for (const auto &clause : single_query_part.remaining_clauses) {
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, context.is_write_query,
context.bound_symbols, *context.ast_storage);
context.bound_symbols, *context.ast_storage, pattern_comprehension_ops);
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
input_op = GenMerge(*merge, std::move(input_op), single_query_part.merge_matching[merge_id++]);
// Treat MERGE clause as write, because we do not know if it will
@ -202,7 +219,7 @@ class RuleBasedPlanner {
context.is_write_query = true;
} else if (auto *with = utils::Downcast<query::With>(clause)) {
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, context.is_write_query,
context.bound_symbols, *context.ast_storage);
context.bound_symbols, *context.ast_storage, pattern_comprehension_ops);
// WITH clause advances the command, so reset the flag.
context.is_write_query = false;
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
@ -241,7 +258,7 @@ class RuleBasedPlanner {
single_query_part, merge_id);
} else if (auto *call_sub = utils::Downcast<query::CallSubquery>(clause)) {
input_op = HandleSubquery(std::move(input_op), single_query_part.subqueries[subquery_id++],
*context.symbol_table, *context_->ast_storage);
*context.symbol_table, *context_->ast_storage, pattern_comprehension_ops);
} else {
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
}
@ -860,15 +877,15 @@ class RuleBasedPlanner {
symbol);
}
std::unique_ptr<LogicalOperator> HandleSubquery(std::unique_ptr<LogicalOperator> last_op,
std::shared_ptr<QueryParts> subquery, SymbolTable &symbol_table,
AstStorage &storage) {
std::unique_ptr<LogicalOperator> HandleSubquery(
std::unique_ptr<LogicalOperator> last_op, std::shared_ptr<QueryParts> subquery, SymbolTable &symbol_table,
AstStorage &storage, std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
std::unordered_set<Symbol> outer_scope_bound_symbols;
outer_scope_bound_symbols.insert(std::make_move_iterator(context_->bound_symbols.begin()),
std::make_move_iterator(context_->bound_symbols.end()));
context_->bound_symbols =
impl::GetSubqueryBoundSymbols(subquery->query_parts[0].single_query_parts, symbol_table, storage);
impl::GetSubqueryBoundSymbols(subquery->query_parts[0].single_query_parts, symbol_table, storage, pc_ops);
auto subquery_op = Plan(*subquery);

View File

@ -68,6 +68,7 @@ enum class TypeId : uint64_t {
APPLY,
INDEXED_JOIN,
HASH_JOIN,
ROLLUP_APPLY,
// Replication
// NOTE: these NEED to be stable in the 2000+ range (see rpc version)

View File

@ -291,3 +291,45 @@ Feature: List operators
# Then the result should be:
# | years |
# | [2021,2003,2003,1999] |
Scenario: Multiple entries with list pattern comprehension
# One pattern comprehension per matched Person: it reuses the outer variable n, filters on the movie
# title and projects the release year.
# NOTE(review): an error is expected here, presumably because pattern comprehension cannot be executed
# end-to-end yet -- confirm, and replace with the real expected result once it can.
Given graph "graph_keanu"
When executing query:
"""
MATCH (n:Person)
RETURN n.name, [(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years
"""
Then an error should be raised
Scenario: Multiple list pattern comprehensions in Return
# Two pattern comprehensions in the same RETURN, both over the same pattern and filter but projecting
# different properties (released, title).
# NOTE(review): an error is expected here, presumably because pattern comprehension cannot be executed
# end-to-end yet -- confirm.
Given graph "graph_keanu"
When executing query:
"""
MATCH (n:Person)
RETURN n.name,
[(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years,
[(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.title] AS titles
"""
Then an error should be raised
Scenario: Function inside pattern comprehension's expression
# The comprehension binds the matched pattern to a named path (p = ...) and applies functions to it in
# the result expression.
# NOTE(review): the expected error presumably comes from named path variables inside a pattern
# comprehension not being implemented yet -- confirm.
Given graph "graph_keanu"
When executing query:
"""
MATCH (keanu:Person {name: 'Keanu Reeves'})
RETURN [p = (keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | size(nodes(p))] AS nodes
"""
Then an error should be raised
Scenario: Multiple list pattern comprehensions in With
# Two pattern comprehensions inside a WITH clause whose results are consumed by the following RETURN.
# NOTE(review): an error is expected here, presumably because pattern comprehension cannot be executed
# end-to-end yet -- confirm.
Given graph "graph_keanu"
When executing query:
"""
MATCH (n) WHERE size(n.name) > 5
WITH
n AS actor,
[(n)-->(m) WHERE m.released > 2000 | m.title] AS titles,
[(n)-->(m) WHERE m.released > 2000 | m.released] AS years
RETURN actor.name, years, titles;
"""
Then an error should be raised

View File

@ -1,5 +1,7 @@
CREATE
(keanu:Person {name: 'Keanu Reeves'}),
(trinity:Person {name: 'Carrie-Anne Moss'}),
(morpheus:Person {name: 'Laurence Fishburne'}),
(johnnyMnemonic:Movie {title: 'Johnny Mnemonic', released: 1995}),
(theMatrixRevolutions:Movie {title: 'The Matrix Revolutions', released: 2003}),
(theMatrixReloaded:Movie {title: 'The Matrix Reloaded', released: 2003}),
@ -13,4 +15,7 @@ CREATE
(keanu)-[:ACTED_IN]->(theReplacements),
(keanu)-[:ACTED_IN]->(theMatrix),
(keanu)-[:ACTED_IN]->(theDevilsAdvocate),
(keanu)-[:ACTED_IN]->(theMatrixResurrections);
(keanu)-[:ACTED_IN]->(theMatrixResurrections),
(trinity)-[:ACTED_IN]->(theMatrix),
(trinity)-[:ACTED_IN]->(theMatrixReloaded),
(morpheus)-[:ACTED_IN]->(theMatrix);

View File

@ -4624,3 +4624,101 @@ TEST_P(CypherMainVisitorTest, CallSubquery) {
ASSERT_TRUE(nested_match);
}
}
// Verifies the AST produced for a pattern comprehension in a RETURN clause. Three query shapes are parsed:
//   1. no path variable, no WHERE filter;
//   2. no path variable, with a WHERE filter;
//   3. a named path variable (`p = ...`) together with a WHERE filter.
// For each, the PatternComprehension node's variable_, pattern_, filter_ and resultExpr_ members are checked.
// Every pointer obtained through dynamic_cast or from the AST is asserted before it is dereferenced, and every
// container is size-checked before it is indexed, so a parser regression shows up as a test failure rather than
// a crash of the whole test binary.
TEST_P(CypherMainVisitorTest, PatternComprehension) {
  auto &ast_generator = *GetParam();
  {
    // Shape 1: [(n)-->(b) | b.val]
    const auto *query =
        dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (n) RETURN [(n)-->(b) | b.val] AS res;"));
    ASSERT_TRUE(query);
    ASSERT_TRUE(query->single_query_);
    // MATCH and RETURN.
    ASSERT_EQ(query->single_query_->clauses_.size(), 2U);
    const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(ret);
    ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
    const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
    ASSERT_TRUE(pc);

    // Check for variable_
    EXPECT_EQ(pc->variable_, nullptr);

    // Check for pattern_: node, edge, node.
    const auto *pattern = pc->pattern_;
    ASSERT_TRUE(pattern);
    ASSERT_EQ(pattern->atoms_.size(), 3U);
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[0]));
    ASSERT_TRUE(dynamic_cast<EdgeAtom *>(pattern->atoms_[1]));
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[2]));

    // Check for filter_
    EXPECT_EQ(pc->filter_, nullptr);

    // Check for resultExpr_
    ASSERT_TRUE(pc->resultExpr_);
  }
  {
    // Shape 2: [(n)-->(b) WHERE b.id=1 | b.val]
    const auto *query = dynamic_cast<CypherQuery *>(
        ast_generator.ParseQuery("MATCH (n) RETURN [(n)-->(b) WHERE b.id=1 | b.val] AS res;"));
    ASSERT_TRUE(query);
    ASSERT_TRUE(query->single_query_);
    // MATCH and RETURN.
    ASSERT_EQ(query->single_query_->clauses_.size(), 2U);
    const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(ret);
    ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
    const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
    ASSERT_TRUE(pc);

    // Check for variable_
    EXPECT_EQ(pc->variable_, nullptr);

    // Check for pattern_: node, edge, node.
    const auto *pattern = pc->pattern_;
    ASSERT_TRUE(pattern);
    ASSERT_EQ(pattern->atoms_.size(), 3U);
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[0]));
    ASSERT_TRUE(dynamic_cast<EdgeAtom *>(pattern->atoms_[1]));
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[2]));

    // Check for filter_
    const auto *filter = pc->filter_;
    ASSERT_TRUE(filter);
    ASSERT_TRUE(filter->expression_);

    // Check for resultExpr_
    ASSERT_TRUE(pc->resultExpr_);
  }
  {
    // Shape 3: [p = (n)-->(b) WHERE b.id=1 | b.val]
    const auto *query = dynamic_cast<CypherQuery *>(
        ast_generator.ParseQuery("MATCH (n) RETURN [p = (n)-->(b) WHERE b.id=1 | b.val] AS res;"));
    ASSERT_TRUE(query);
    ASSERT_TRUE(query->single_query_);
    // MATCH and RETURN.
    ASSERT_EQ(query->single_query_->clauses_.size(), 2U);
    const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(ret);
    ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
    const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
    ASSERT_TRUE(pc);

    // Check for variable_
    ASSERT_TRUE(pc->variable_);

    // Check for pattern_: node, edge, node.
    const auto *pattern = pc->pattern_;
    ASSERT_TRUE(pattern);
    ASSERT_EQ(pattern->atoms_.size(), 3U);
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[0]));
    ASSERT_TRUE(dynamic_cast<EdgeAtom *>(pattern->atoms_[1]));
    ASSERT_TRUE(dynamic_cast<NodeAtom *>(pattern->atoms_[2]));

    // Check for filter_
    const auto *filter = pc->filter_;
    ASSERT_TRUE(filter);
    ASSERT_TRUE(filter->expression_);

    // Check for resultExpr_
    ASSERT_TRUE(pc->resultExpr_);
  }
}

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -635,3 +635,5 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
#define DROP_USER(usernames) storage.Create<memgraph::query::DropUser>((usernames))
#define CALL_PROCEDURE(...) memgraph::query::test_common::GetCallProcedure(storage, __VA_ARGS__)
#define CALL_SUBQUERY(...) memgraph::query::test_common::GetCallSubquery(this->storage, __VA_ARGS__)
// Creates a memgraph::query::PatternComprehension AST node through `this->storage`, forwarding the four
// arguments to Create<>() in order: path variable, pattern, WHERE filter, result expression.
// Because it expands to `this->storage.template Create<...>`, it can only be used inside a member function of
// a class that has a `storage` member (the `template` keyword covers the case where that member is dependent).
// NOTE(review): variable and filter appear to be optional (nullptr) -- confirm against the AST class.
#define PATTERN_COMPREHENSION(variable, pattern, filter, resultExpr) \
this->storage.template Create<memgraph::query::PatternComprehension>(variable, pattern, filter, resultExpr)

View File

@ -1442,3 +1442,27 @@ TYPED_TEST(TestSymbolGenerator, PropertyCachingMixedLookups2) {
ASSERT_TRUE(prop3_eval_mode == PropertyLookup::EvaluationMode::GET_ALL_PROPERTIES);
ASSERT_TRUE(prop4_eval_mode == PropertyLookup::EvaluationMode::GET_ALL_PROPERTIES);
}
// Builds a query whose RETURN holds a pattern comprehension (no path variable, no WHERE filter), runs symbol
// generation over it, and checks both the highest symbol position in the resulting table and the number of
// symbols a UsedSymbolsCollector gathers when visiting the comprehension.
TYPED_TEST(TestSymbolGenerator, PatternComprehension) {
  auto prop = this->dba.NameToProperty("prop");
  // MATCH (n) RETURN [(n)-[edge]-(m) | m.prop] AS alias
  // (the edge is built with EdgeAtom::Direction::BOTH, hence no arrowhead)
  auto query = QUERY(SINGLE_QUERY(
      MATCH(PATTERN(NODE("n"))),
      RETURN(NEXPR("alias", PATTERN_COMPREHENSION(nullptr,
                                                  PATTERN(NODE("n"), EDGE("edge", EdgeAtom::Direction::BOTH, {}, false),
                                                          NODE("m", std::nullopt, false)),
                                                  nullptr, PROPERTY_LOOKUP(this->dba, "m", prop))))));
  auto symbol_table = MakeSymbolTable(query);
  ASSERT_EQ(symbol_table.max_position(), 7);

  memgraph::query::plan::UsedSymbolsCollector collector(symbol_table);
  // Assert each downcast before use so that a wrong AST shape fails the test instead of dereferencing nullptr.
  auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
  ASSERT_TRUE(ret);
  auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
  ASSERT_TRUE(pc);
  pc->Accept(collector);
  // n, edge, m, Path
  ASSERT_EQ(collector.symbols_.size(), 4);
}