List Pattern Comprehension planner (#1686)
This commit is contained in:
parent
02325f8673
commit
a099417c56
@ -21,6 +21,7 @@
|
||||
#include "query/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "storage/v2/property_value.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
|
||||
namespace memgraph::query {
|
||||
@ -3586,7 +3587,7 @@ class PatternComprehension : public memgraph::query::Expression {
|
||||
bool Accept(HierarchicalTreeVisitor &visitor) override {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
if (variable_) {
|
||||
variable_->Accept(visitor);
|
||||
throw utils::NotYetImplemented("Variable in pattern comprehension.");
|
||||
}
|
||||
pattern_->Accept(visitor);
|
||||
if (filter_) {
|
||||
@ -3615,7 +3616,8 @@ class PatternComprehension : public memgraph::query::Expression {
|
||||
int32_t symbol_pos_{-1};
|
||||
|
||||
PatternComprehension *Clone(AstStorage *storage) const override {
|
||||
PatternComprehension *object = storage->Create<PatternComprehension>();
|
||||
auto *object = storage->Create<PatternComprehension>();
|
||||
object->variable_ = variable_ ? variable_->Clone(storage) : nullptr;
|
||||
object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr;
|
||||
object->filter_ = filter_ ? filter_->Clone(storage) : nullptr;
|
||||
object->resultExpr_ = resultExpr_ ? resultExpr_->Clone(storage) : nullptr;
|
||||
@ -3625,7 +3627,8 @@ class PatternComprehension : public memgraph::query::Expression {
|
||||
}
|
||||
|
||||
protected:
|
||||
PatternComprehension(Identifier *variable, Pattern *pattern) : variable_(variable), pattern_(pattern) {}
|
||||
PatternComprehension(Identifier *variable, Pattern *pattern, Where *filter, Expression *resultExpr)
|
||||
: variable_(variable), pattern_(pattern), filter_(filter), resultExpr_(resultExpr) {}
|
||||
|
||||
private:
|
||||
friend class AstStorage;
|
||||
|
@ -721,6 +721,32 @@ bool SymbolGenerator::PostVisit(EdgeAtom &) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(PatternComprehension &pc) {
|
||||
auto &scope = scopes_.back();
|
||||
|
||||
if (scope.in_set_property) {
|
||||
throw utils::NotYetImplemented("Pattern Comprehension cannot be used within SET clause.!");
|
||||
}
|
||||
|
||||
if (scope.in_with) {
|
||||
throw utils::NotYetImplemented("Pattern Comprehension cannot be used within WITH!");
|
||||
}
|
||||
|
||||
if (scope.in_reduce) {
|
||||
throw utils::NotYetImplemented("Pattern Comprehension cannot be used within REDUCE!");
|
||||
}
|
||||
|
||||
if (scope.num_if_operators) {
|
||||
throw utils::NotYetImplemented("IF operator cannot be used with Pattern Comprehension!");
|
||||
}
|
||||
|
||||
const auto &symbol = CreateAnonymousSymbol();
|
||||
pc.MapTo(symbol);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PostVisit(PatternComprehension & /*pc*/) { return true; }
|
||||
|
||||
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
|
||||
auto &scope = scopes_.back();
|
||||
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -97,6 +97,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
bool PostVisit(NodeAtom &) override;
|
||||
bool PreVisit(EdgeAtom &) override;
|
||||
bool PostVisit(EdgeAtom &) override;
|
||||
bool PreVisit(PatternComprehension &) override;
|
||||
bool PostVisit(PatternComprehension &) override;
|
||||
|
||||
private:
|
||||
// Scope stores the state of where we are when visiting the AST and a map of
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -206,6 +206,14 @@ class PlanHintsProvider final : public HierarchicalLogicalOperatorVisitor {
|
||||
|
||||
bool PostVisit(IndexedJoin & /*unused*/) override { return true; }
|
||||
|
||||
bool PreVisit(RollUpApply &op) override {
|
||||
op.input()->Accept(*this);
|
||||
op.list_collection_branch_->Accept(*this);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostVisit(RollUpApply & /*unused*/) override { return true; }
|
||||
|
||||
private:
|
||||
const SymbolTable &symbol_table_;
|
||||
std::vector<std::string> hints_;
|
||||
|
@ -5624,4 +5624,25 @@ UniqueCursorPtr HashJoin::MakeCursor(utils::MemoryResource *mem) const {
|
||||
return MakeUniqueCursorPtr<HashJoinCursor>(mem, *this, mem);
|
||||
}
|
||||
|
||||
RollUpApply::RollUpApply(const std::shared_ptr<LogicalOperator> &input,
|
||||
std::shared_ptr<LogicalOperator> &&second_branch)
|
||||
: input_(input), list_collection_branch_(second_branch) {}
|
||||
|
||||
std::vector<Symbol> RollUpApply::OutputSymbols(const SymbolTable & /*symbol_table*/) const {
|
||||
std::vector<Symbol> symbols;
|
||||
return symbols;
|
||||
}
|
||||
|
||||
std::vector<Symbol> RollUpApply::ModifiedSymbols(const SymbolTable &table) const { return OutputSymbols(table); }
|
||||
|
||||
bool RollUpApply::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
if (!input_ || !list_collection_branch_) {
|
||||
throw utils::NotYetImplemented("One of the branches in pattern comprehension is null! Please contact support.");
|
||||
}
|
||||
input_->Accept(visitor) && list_collection_branch_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
} // namespace memgraph::query::plan
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -130,6 +130,7 @@ class EvaluatePatternFilter;
|
||||
class Apply;
|
||||
class IndexedJoin;
|
||||
class HashJoin;
|
||||
class RollUpApply;
|
||||
|
||||
using LogicalOperatorCompositeVisitor =
|
||||
utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange,
|
||||
@ -137,7 +138,7 @@ using LogicalOperatorCompositeVisitor =
|
||||
ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels,
|
||||
RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit,
|
||||
OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv,
|
||||
Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin>;
|
||||
Foreach, EmptyResult, EvaluatePatternFilter, Apply, IndexedJoin, HashJoin, RollUpApply>;
|
||||
|
||||
using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
|
||||
|
||||
@ -2634,5 +2635,38 @@ class HashJoin : public memgraph::query::plan::LogicalOperator {
|
||||
}
|
||||
};
|
||||
|
||||
/// RollUpApply operator is used to execute an expression which takes as input a pattern,
|
||||
/// and returns a list with content from the matched pattern
|
||||
/// It's used for a pattern expression or pattern comprehension in a query.
|
||||
class RollUpApply : public memgraph::query::plan::LogicalOperator {
|
||||
public:
|
||||
static const utils::TypeInfo kType;
|
||||
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
|
||||
|
||||
RollUpApply() = default;
|
||||
RollUpApply(const std::shared_ptr<LogicalOperator> &input, std::shared_ptr<LogicalOperator> &&second_branch);
|
||||
|
||||
bool HasSingleInput() const override { return false; }
|
||||
std::shared_ptr<LogicalOperator> input() const override { return input_; }
|
||||
void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
|
||||
|
||||
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
|
||||
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override {
|
||||
throw utils::NotYetImplemented("Execution of Pattern comprehension is currently unsupported.");
|
||||
}
|
||||
std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
|
||||
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
|
||||
|
||||
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
|
||||
auto object = std::make_unique<RollUpApply>();
|
||||
object->input_ = input_ ? input_->Clone(storage) : nullptr;
|
||||
object->list_collection_branch_ = list_collection_branch_ ? list_collection_branch_->Clone(storage) : nullptr;
|
||||
return object;
|
||||
}
|
||||
|
||||
std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
|
||||
std::shared_ptr<memgraph::query::plan::LogicalOperator> list_collection_branch_;
|
||||
};
|
||||
|
||||
} // namespace plan
|
||||
} // namespace memgraph::query
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -154,4 +154,7 @@ constexpr utils::TypeInfo query::plan::IndexedJoin::kType{utils::TypeId::INDEXED
|
||||
|
||||
constexpr utils::TypeInfo query::plan::HashJoin::kType{utils::TypeId::HASH_JOIN, "HashJoin",
|
||||
&query::plan::LogicalOperator::kType};
|
||||
|
||||
constexpr utils::TypeInfo query::plan::RollUpApply::kType{utils::TypeId::ROLLUP_APPLY, "RollUpApply",
|
||||
&query::plan::LogicalOperator::kType};
|
||||
} // namespace memgraph
|
||||
|
@ -632,20 +632,20 @@ void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &stor
|
||||
|
||||
// If there are any pattern filters, we add those as well
|
||||
for (auto &filter : matching.filters) {
|
||||
PatternFilterVisitor visitor(symbol_table, storage);
|
||||
PatternVisitor visitor(symbol_table, storage);
|
||||
|
||||
filter.expression->Accept(visitor);
|
||||
filter.matchings = visitor.getMatchings();
|
||||
filter.matchings = visitor.getFilterMatchings();
|
||||
}
|
||||
}
|
||||
|
||||
PatternFilterVisitor::PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage)
|
||||
PatternVisitor::PatternVisitor(SymbolTable &symbol_table, AstStorage &storage)
|
||||
: symbol_table_(symbol_table), storage_(storage) {}
|
||||
PatternFilterVisitor::PatternFilterVisitor(const PatternFilterVisitor &) = default;
|
||||
PatternFilterVisitor::PatternFilterVisitor(PatternFilterVisitor &&) noexcept = default;
|
||||
PatternFilterVisitor::~PatternFilterVisitor() = default;
|
||||
PatternVisitor::PatternVisitor(const PatternVisitor &) = default;
|
||||
PatternVisitor::PatternVisitor(PatternVisitor &&) noexcept = default;
|
||||
PatternVisitor::~PatternVisitor() = default;
|
||||
|
||||
void PatternFilterVisitor::Visit(Exists &op) {
|
||||
void PatternVisitor::Visit(Exists &op) {
|
||||
std::vector<Pattern *> patterns;
|
||||
patterns.push_back(op.pattern_);
|
||||
|
||||
@ -655,10 +655,10 @@ void PatternFilterVisitor::Visit(Exists &op) {
|
||||
filter_matching.type = PatternFilterType::EXISTS;
|
||||
filter_matching.symbol = std::make_optional<Symbol>(symbol_table_.at(op));
|
||||
|
||||
matchings_.push_back(std::move(filter_matching));
|
||||
filter_matchings_.push_back(std::move(filter_matching));
|
||||
}
|
||||
|
||||
std::vector<FilterMatching> PatternFilterVisitor::getMatchings() { return matchings_; }
|
||||
std::vector<FilterMatching> PatternVisitor::getFilterMatchings() { return filter_matchings_; }
|
||||
|
||||
static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage,
|
||||
SymbolTable &symbol_table) {
|
||||
@ -672,6 +672,30 @@ static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, A
|
||||
}
|
||||
}
|
||||
|
||||
static void ParseReturn(query::Return &ret, AstStorage &storage, SymbolTable &symbol_table,
|
||||
std::unordered_map<std::string, PatternComprehensionMatching> &matchings) {
|
||||
PatternVisitor visitor(symbol_table, storage);
|
||||
|
||||
for (auto *expr : ret.body_.named_expressions) {
|
||||
expr->Accept(visitor);
|
||||
auto pattern_comprehension_matchings = visitor.getPatternComprehensionMatchings();
|
||||
for (auto &matching : pattern_comprehension_matchings) {
|
||||
matchings.emplace(expr->name_, matching);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PatternVisitor::Visit(NamedExpression &op) { op.expression_->Accept(*this); }
|
||||
|
||||
void PatternVisitor::Visit(PatternComprehension &op) {
|
||||
PatternComprehensionMatching matching;
|
||||
AddMatching({op.pattern_}, op.filter_, symbol_table_, storage_, matching);
|
||||
matching.result_expr = storage_.Create<NamedExpression>(symbol_table_.at(op).name(), op.resultExpr_);
|
||||
matching.result_expr->MapTo(symbol_table_.at(op));
|
||||
|
||||
pattern_comprehension_matchings_.push_back(std::move(matching));
|
||||
}
|
||||
|
||||
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
|
||||
// created, e.g. filter expressions.
|
||||
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
|
||||
@ -703,7 +727,8 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table,
|
||||
// This query part is done, continue with a new one.
|
||||
query_parts.emplace_back(SingleQueryPart{});
|
||||
query_part = &query_parts.back();
|
||||
} else if (utils::IsSubtype(*clause, Return::kType)) {
|
||||
} else if (auto *ret = utils::Downcast<Return>(clause)) {
|
||||
ParseReturn(*ret, storage, symbol_table, query_part->pattern_comprehension_matchings);
|
||||
return query_parts;
|
||||
}
|
||||
}
|
||||
|
@ -153,19 +153,20 @@ struct Expansion {
|
||||
ExpansionGroupId expansion_group_id = ExpansionGroupId();
|
||||
};
|
||||
|
||||
struct PatternComprehensionMatching;
|
||||
struct FilterMatching;
|
||||
|
||||
enum class PatternFilterType { EXISTS };
|
||||
|
||||
/// Collects matchings from filters that include patterns
|
||||
class PatternFilterVisitor : public ExpressionVisitor<void> {
|
||||
/// Collects matchings that include patterns
|
||||
class PatternVisitor : public ExpressionVisitor<void> {
|
||||
public:
|
||||
explicit PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage);
|
||||
PatternFilterVisitor(const PatternFilterVisitor &);
|
||||
PatternFilterVisitor &operator=(const PatternFilterVisitor &) = delete;
|
||||
PatternFilterVisitor(PatternFilterVisitor &&) noexcept;
|
||||
PatternFilterVisitor &operator=(PatternFilterVisitor &&) noexcept = delete;
|
||||
~PatternFilterVisitor() override;
|
||||
explicit PatternVisitor(SymbolTable &symbol_table, AstStorage &storage);
|
||||
PatternVisitor(const PatternVisitor &);
|
||||
PatternVisitor &operator=(const PatternVisitor &) = delete;
|
||||
PatternVisitor(PatternVisitor &&) noexcept;
|
||||
PatternVisitor &operator=(PatternVisitor &&) noexcept = delete;
|
||||
~PatternVisitor() override;
|
||||
|
||||
using ExpressionVisitor<void>::Visit;
|
||||
|
||||
@ -233,18 +234,24 @@ class PatternFilterVisitor : public ExpressionVisitor<void> {
|
||||
void Visit(PropertyLookup &op) override{};
|
||||
void Visit(AllPropertiesLookup &op) override{};
|
||||
void Visit(ParameterLookup &op) override{};
|
||||
void Visit(NamedExpression &op) override{};
|
||||
void Visit(RegexMatch &op) override{};
|
||||
void Visit(PatternComprehension &op) override{};
|
||||
void Visit(NamedExpression &op) override;
|
||||
void Visit(PatternComprehension &op) override;
|
||||
|
||||
std::vector<FilterMatching> getMatchings();
|
||||
std::vector<FilterMatching> getFilterMatchings();
|
||||
std::vector<PatternComprehensionMatching> getPatternComprehensionMatchings() {
|
||||
return pattern_comprehension_matchings_;
|
||||
}
|
||||
|
||||
SymbolTable &symbol_table_;
|
||||
AstStorage &storage_;
|
||||
|
||||
private:
|
||||
/// Collection of matchings in the filter expression being analyzed.
|
||||
std::vector<FilterMatching> matchings_;
|
||||
std::vector<FilterMatching> filter_matchings_;
|
||||
|
||||
/// Collection of matchings in the pattern comprehension being analyzed.
|
||||
std::vector<PatternComprehensionMatching> pattern_comprehension_matchings_;
|
||||
};
|
||||
|
||||
/// Stores the symbols and expression used to filter a property.
|
||||
@ -495,6 +502,11 @@ inline auto Filters::IdFilters(const Symbol &symbol) const -> std::vector<Filter
|
||||
return filters;
|
||||
}
|
||||
|
||||
struct PatternComprehensionMatching : Matching {
|
||||
/// Pattern comprehension result named expression
|
||||
NamedExpression *result_expr = nullptr;
|
||||
};
|
||||
|
||||
/// @brief Represents a read (+ write) part of a query. Parts are split on
|
||||
/// `WITH` clauses.
|
||||
///
|
||||
@ -537,6 +549,14 @@ struct SingleQueryPart {
|
||||
/// in the `remaining_clauses` but rather in the `Foreach` itself and are guranteed
|
||||
/// to be processed in the same order by the semantics of the `RuleBasedPlanner`.
|
||||
std::vector<Matching> merge_matching{};
|
||||
|
||||
/// @brief @c NamedExpression name to @c PatternComprehensionMatching for each pattern comprehension.
|
||||
///
|
||||
/// Storing the normalized pattern of a @c PatternComprehension does not preclude storing the
|
||||
/// @c PatternComprehension clause itself inside `remaining_clauses`. The reason is that we
|
||||
/// need to have access to other parts of the clause, such as pattern, filter clauses.
|
||||
std::unordered_map<std::string, PatternComprehensionMatching> pattern_comprehension_matchings{};
|
||||
|
||||
/// @brief All the remaining clauses (without @c Match).
|
||||
std::vector<Clause *> remaining_clauses{};
|
||||
/// The subqueries vector are all the subqueries in this query part ordered in a list by
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -143,6 +143,13 @@ bool PlanPrinter::PreVisit(query::plan::Union &op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlanPrinter::PreVisit(query::plan::RollUpApply &op) {
|
||||
WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });
|
||||
Branch(*op.list_collection_branch_);
|
||||
op.input_->Accept(*this);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PlanPrinter::PreVisit(query::plan::CallProcedure &op) {
|
||||
WithPrintLn([&op](auto &out) { out << "* " << op.ToString(); });
|
||||
return true;
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -91,6 +91,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
|
||||
bool PreVisit(OrderBy &) override;
|
||||
bool PreVisit(Distinct &) override;
|
||||
bool PreVisit(Union &) override;
|
||||
bool PreVisit(RollUpApply &) override;
|
||||
|
||||
bool PreVisit(Unwind &) override;
|
||||
bool PreVisit(CallProcedure &) override;
|
||||
|
@ -595,6 +595,18 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PreVisit(RollUpApply &op) override {
|
||||
prev_ops_.push_back(&op);
|
||||
op.input()->Accept(*this);
|
||||
RewriteBranch(&op.list_collection_branch_);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostVisit(RollUpApply &) override {
|
||||
prev_ops_.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<LogicalOperator> new_root_;
|
||||
|
||||
private:
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -455,6 +455,18 @@ class JoinRewriter final : public HierarchicalLogicalOperatorVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PreVisit(RollUpApply &op) override {
|
||||
prev_ops_.push_back(&op);
|
||||
op.input()->Accept(*this);
|
||||
RewriteBranch(&op.list_collection_branch_);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostVisit(RollUpApply &) override {
|
||||
prev_ops_.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<LogicalOperator> new_root_;
|
||||
|
||||
private:
|
||||
|
@ -14,9 +14,12 @@
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/plan/operator.hpp"
|
||||
#include "query/plan/preprocess.hpp"
|
||||
#include "utils/algorithm.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
@ -40,7 +43,8 @@ namespace {
|
||||
class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
|
||||
AstStorage &storage, Where *where = nullptr)
|
||||
AstStorage &storage, std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops,
|
||||
Where *where = nullptr)
|
||||
: body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
|
||||
// Collect symbols from named expressions.
|
||||
output_symbols_.reserve(body_.named_expressions.size());
|
||||
@ -53,6 +57,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
|
||||
named_expr->Accept(*this);
|
||||
named_expressions_.emplace_back(named_expr);
|
||||
if (pattern_comprehension_) {
|
||||
if (auto it = pc_ops.find(named_expr->name_); it != pc_ops.end()) {
|
||||
pattern_comprehension_op_ = std::move(it->second);
|
||||
pc_ops.erase(it);
|
||||
} else {
|
||||
throw utils::NotYetImplemented("Operation on top of pattern comprehension");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Collect symbols used in group by expressions.
|
||||
if (!aggregations_.empty()) {
|
||||
@ -386,8 +398,20 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PostVisit(PatternComprehension & /*unused*/) override {
|
||||
throw utils::NotYetImplemented("Planner can not handle pattern comprehension.");
|
||||
bool PreVisit(PatternComprehension & /*unused*/) override {
|
||||
pattern_compression_aggregations_start_index_ = has_aggregation_.size();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PostVisit(PatternComprehension &pattern_comprehension) override {
|
||||
bool has_aggr = false;
|
||||
for (auto i = has_aggregation_.size(); i > pattern_compression_aggregations_start_index_; --i) {
|
||||
has_aggr |= has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
pattern_comprehension_ = &pattern_comprehension;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Creates NamedExpression with an Identifier for each user declared symbol.
|
||||
@ -444,6 +468,10 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// named_expressions.
|
||||
const auto &output_symbols() const { return output_symbols_; }
|
||||
|
||||
const auto *pattern_comprehension() const { return pattern_comprehension_; }
|
||||
|
||||
std::shared_ptr<LogicalOperator> pattern_comprehension_op() const { return pattern_comprehension_op_; }
|
||||
|
||||
private:
|
||||
const ReturnBody &body_;
|
||||
SymbolTable &symbol_table_;
|
||||
@ -465,10 +493,13 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// group by it.
|
||||
std::list<bool> has_aggregation_;
|
||||
std::vector<NamedExpression *> named_expressions_;
|
||||
PatternComprehension *pattern_comprehension_ = nullptr;
|
||||
std::shared_ptr<LogicalOperator> pattern_comprehension_op_;
|
||||
size_t pattern_compression_aggregations_start_index_ = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
|
||||
const ReturnBodyContext &body, bool accumulate = false) {
|
||||
const ReturnBodyContext &body, bool accumulate) {
|
||||
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
|
||||
auto last_op = std::move(input_op);
|
||||
if (accumulate) {
|
||||
@ -482,6 +513,11 @@ std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator>
|
||||
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
|
||||
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
|
||||
}
|
||||
|
||||
if (body.pattern_comprehension()) {
|
||||
last_op = std::make_unique<RollUpApply>(std::move(last_op), body.pattern_comprehension_op());
|
||||
}
|
||||
|
||||
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
|
||||
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
|
||||
if (body.distinct()) {
|
||||
@ -506,6 +542,7 @@ std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator>
|
||||
last_op = std::make_unique<Filter>(std::move(last_op), std::vector<std::shared_ptr<LogicalOperator>>{},
|
||||
body.where()->expression_);
|
||||
}
|
||||
|
||||
return last_op;
|
||||
}
|
||||
|
||||
@ -543,8 +580,9 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filt
|
||||
return filter_expr;
|
||||
}
|
||||
|
||||
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
|
||||
SymbolTable &symbol_table, AstStorage &storage) {
|
||||
std::unordered_set<Symbol> GetSubqueryBoundSymbols(
|
||||
const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
|
||||
const auto &query = single_query_parts[0];
|
||||
|
||||
if (!query.matching.expansions.empty() || query.remaining_clauses.empty()) {
|
||||
@ -552,7 +590,7 @@ std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQuery
|
||||
}
|
||||
|
||||
if (std::unordered_set<Symbol> bound_symbols; auto *with = utils::Downcast<query::With>(query.remaining_clauses[0])) {
|
||||
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage);
|
||||
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage, pc_ops);
|
||||
return bound_symbols;
|
||||
}
|
||||
|
||||
@ -583,7 +621,8 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator>
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
|
||||
// Similar to WITH clause, but we want to accumulate when the query writes to
|
||||
// the database. This way we handle the case when we want to return
|
||||
// expressions with the latest updated results. For example, `MATCH (n) -- ()
|
||||
@ -592,13 +631,14 @@ std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalO
|
||||
// final result of 'k' increments.
|
||||
bool accumulate = is_write;
|
||||
bool advance_command = false;
|
||||
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
|
||||
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage, pc_ops);
|
||||
return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
|
||||
// optional Filter. In case of update and aggregation, we want to accumulate
|
||||
// first, so that when aggregating, we get the latest results. Similar to
|
||||
@ -606,7 +646,7 @@ std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOper
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_);
|
||||
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, pc_ops, with.where_);
|
||||
auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
|
||||
// Reset bound symbols, so that only those in WITH are exposed.
|
||||
bound_symbols.clear();
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "query/frontend/ast/ast_visitor.hpp"
|
||||
#include "query/plan/operator.hpp"
|
||||
#include "query/plan/preprocess.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
|
||||
@ -87,8 +88,9 @@ bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, cons
|
||||
|
||||
// Returns the set of symbols for the subquery that are actually referenced from the outer scope and
|
||||
// used in the subquery.
|
||||
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
|
||||
SymbolTable &symbol_table, AstStorage &storage);
|
||||
std::unordered_set<Symbol> GetSubqueryBoundSymbols(
|
||||
const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
|
||||
|
||||
Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table);
|
||||
Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table);
|
||||
@ -142,11 +144,13 @@ std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator>
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
|
||||
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
|
||||
SymbolTable &symbol_table, bool is_write,
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
|
||||
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops);
|
||||
|
||||
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
|
||||
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
|
||||
@ -190,11 +194,24 @@ class RuleBasedPlanner {
|
||||
uint64_t merge_id = 0;
|
||||
uint64_t subquery_id = 0;
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pattern_comprehension_ops;
|
||||
|
||||
if (single_query_part.pattern_comprehension_matchings.size() > 1) {
|
||||
throw utils::NotYetImplemented("Multiple pattern comprehensions.");
|
||||
}
|
||||
for (const auto &matching : single_query_part.pattern_comprehension_matchings) {
|
||||
std::unique_ptr<LogicalOperator> new_input;
|
||||
MatchContext match_ctx{matching.second, *context.symbol_table, context.bound_symbols};
|
||||
new_input = PlanMatching(match_ctx, std::move(new_input));
|
||||
new_input = std::make_unique<Produce>(std::move(new_input), std::vector{matching.second.result_expr});
|
||||
pattern_comprehension_ops.emplace(matching.first, std::move(new_input));
|
||||
}
|
||||
|
||||
for (const auto &clause : single_query_part.remaining_clauses) {
|
||||
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
|
||||
if (auto *ret = utils::Downcast<Return>(clause)) {
|
||||
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, context.is_write_query,
|
||||
context.bound_symbols, *context.ast_storage);
|
||||
context.bound_symbols, *context.ast_storage, pattern_comprehension_ops);
|
||||
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
|
||||
input_op = GenMerge(*merge, std::move(input_op), single_query_part.merge_matching[merge_id++]);
|
||||
// Treat MERGE clause as write, because we do not know if it will
|
||||
@ -202,7 +219,7 @@ class RuleBasedPlanner {
|
||||
context.is_write_query = true;
|
||||
} else if (auto *with = utils::Downcast<query::With>(clause)) {
|
||||
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, context.is_write_query,
|
||||
context.bound_symbols, *context.ast_storage);
|
||||
context.bound_symbols, *context.ast_storage, pattern_comprehension_ops);
|
||||
// WITH clause advances the command, so reset the flag.
|
||||
context.is_write_query = false;
|
||||
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
|
||||
@ -241,7 +258,7 @@ class RuleBasedPlanner {
|
||||
single_query_part, merge_id);
|
||||
} else if (auto *call_sub = utils::Downcast<query::CallSubquery>(clause)) {
|
||||
input_op = HandleSubquery(std::move(input_op), single_query_part.subqueries[subquery_id++],
|
||||
*context.symbol_table, *context_->ast_storage);
|
||||
*context.symbol_table, *context_->ast_storage, pattern_comprehension_ops);
|
||||
} else {
|
||||
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
|
||||
}
|
||||
@ -860,15 +877,15 @@ class RuleBasedPlanner {
|
||||
symbol);
|
||||
}
|
||||
|
||||
std::unique_ptr<LogicalOperator> HandleSubquery(std::unique_ptr<LogicalOperator> last_op,
|
||||
std::shared_ptr<QueryParts> subquery, SymbolTable &symbol_table,
|
||||
AstStorage &storage) {
|
||||
std::unique_ptr<LogicalOperator> HandleSubquery(
|
||||
std::unique_ptr<LogicalOperator> last_op, std::shared_ptr<QueryParts> subquery, SymbolTable &symbol_table,
|
||||
AstStorage &storage, std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
|
||||
std::unordered_set<Symbol> outer_scope_bound_symbols;
|
||||
outer_scope_bound_symbols.insert(std::make_move_iterator(context_->bound_symbols.begin()),
|
||||
std::make_move_iterator(context_->bound_symbols.end()));
|
||||
|
||||
context_->bound_symbols =
|
||||
impl::GetSubqueryBoundSymbols(subquery->query_parts[0].single_query_parts, symbol_table, storage);
|
||||
impl::GetSubqueryBoundSymbols(subquery->query_parts[0].single_query_parts, symbol_table, storage, pc_ops);
|
||||
|
||||
auto subquery_op = Plan(*subquery);
|
||||
|
||||
|
@ -68,6 +68,7 @@ enum class TypeId : uint64_t {
|
||||
APPLY,
|
||||
INDEXED_JOIN,
|
||||
HASH_JOIN,
|
||||
ROLLUP_APPLY,
|
||||
|
||||
// Replication
|
||||
// NOTE: these NEED to be stable in the 2000+ range (see rpc version)
|
||||
|
@ -291,3 +291,45 @@ Feature: List operators
|
||||
# Then the result should be:
|
||||
# | years |
|
||||
# | [2021,2003,2003,1999] |
|
||||
|
||||
Scenario: Multiple entries with list pattern comprehension
|
||||
Given graph "graph_keanu"
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n:Person)
|
||||
RETURN n.name, [(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
||||
Scenario: Multiple list pattern comprehensions in Return
|
||||
Given graph "graph_keanu"
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n:Person)
|
||||
RETURN n.name,
|
||||
[(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.released] AS years,
|
||||
[(n)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | b.title] AS titles
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
||||
Scenario: Function inside pattern comprehension's expression
|
||||
Given graph "graph_keanu"
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (keanu:Person {name: 'Keanu Reeves'})
|
||||
RETURN [p = (keanu)-->(b:Movie) WHERE b.title CONTAINS 'Matrix' | size(nodes(p))] AS nodes
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
||||
Scenario: Multiple list pattern comprehensions in With
|
||||
Given graph "graph_keanu"
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (n) WHERE size(n.name) > 5
|
||||
WITH
|
||||
n AS actor,
|
||||
[(n)-->(m) WHERE m.released > 2000 | m.title] AS titles,
|
||||
[(n)-->(m) WHERE m.released > 2000 | m.released] AS years
|
||||
RETURN actor.name, years, titles;
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
@ -1,5 +1,7 @@
|
||||
CREATE
|
||||
(keanu:Person {name: 'Keanu Reeves'}),
|
||||
(trinity:Person {name: 'Carrie-Anne Moss'}),
|
||||
(morpheus:Person {name: 'Laurence Fishburne'}),
|
||||
(johnnyMnemonic:Movie {title: 'Johnny Mnemonic', released: 1995}),
|
||||
(theMatrixRevolutions:Movie {title: 'The Matrix Revolutions', released: 2003}),
|
||||
(theMatrixReloaded:Movie {title: 'The Matrix Reloaded', released: 2003}),
|
||||
@ -13,4 +15,7 @@ CREATE
|
||||
(keanu)-[:ACTED_IN]->(theReplacements),
|
||||
(keanu)-[:ACTED_IN]->(theMatrix),
|
||||
(keanu)-[:ACTED_IN]->(theDevilsAdvocate),
|
||||
(keanu)-[:ACTED_IN]->(theMatrixResurrections);
|
||||
(keanu)-[:ACTED_IN]->(theMatrixResurrections),
|
||||
(trinity)-[:ACTED_IN]->(theMatrix),
|
||||
(trinity)-[:ACTED_IN]->(theMatrixReloaded),
|
||||
(morpheus)-[:ACTED_IN]->(theMatrix);
|
||||
|
@ -4624,3 +4624,101 @@ TEST_P(CypherMainVisitorTest, CallSubquery) {
|
||||
ASSERT_TRUE(nested_match);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CypherMainVisitorTest, PatternComprehension) {
|
||||
auto &ast_generator = *GetParam();
|
||||
{
|
||||
const auto *query =
|
||||
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (n) RETURN [(n)-->(b) | b.val] AS res;"));
|
||||
const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
|
||||
|
||||
const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(pc);
|
||||
|
||||
// Check for variable_
|
||||
EXPECT_EQ(pc->variable_, nullptr);
|
||||
|
||||
// Check for pattern_
|
||||
const auto pattern = pc->pattern_;
|
||||
ASSERT_TRUE(pattern->atoms_.size() == 3);
|
||||
|
||||
const auto *node1 = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
|
||||
const auto *edge = dynamic_cast<EdgeAtom *>(pattern->atoms_[1]);
|
||||
const auto *node2 = dynamic_cast<NodeAtom *>(pattern->atoms_[2]);
|
||||
|
||||
ASSERT_TRUE(node1);
|
||||
ASSERT_TRUE(edge);
|
||||
ASSERT_TRUE(node2);
|
||||
|
||||
// Check for filter_
|
||||
EXPECT_EQ(pc->filter_, nullptr);
|
||||
|
||||
// Check for resultExpr_
|
||||
const auto *result_expr = pc->resultExpr_;
|
||||
ASSERT_TRUE(result_expr);
|
||||
}
|
||||
{
|
||||
const auto *query = dynamic_cast<CypherQuery *>(
|
||||
ast_generator.ParseQuery("MATCH (n) RETURN [(n)-->(b) WHERE b.id=1 | b.val] AS res;"));
|
||||
const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
|
||||
|
||||
const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(pc);
|
||||
|
||||
// Check for variable_
|
||||
EXPECT_EQ(pc->variable_, nullptr);
|
||||
|
||||
// Check for pattern_
|
||||
const auto pattern = pc->pattern_;
|
||||
ASSERT_TRUE(pattern->atoms_.size() == 3);
|
||||
|
||||
const auto *node1 = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
|
||||
const auto *edge = dynamic_cast<EdgeAtom *>(pattern->atoms_[1]);
|
||||
const auto *node2 = dynamic_cast<NodeAtom *>(pattern->atoms_[2]);
|
||||
|
||||
ASSERT_TRUE(node1);
|
||||
ASSERT_TRUE(edge);
|
||||
ASSERT_TRUE(node2);
|
||||
|
||||
// Check for filter_
|
||||
const auto *filter = pc->filter_;
|
||||
ASSERT_TRUE(filter);
|
||||
ASSERT_TRUE(filter->expression_);
|
||||
|
||||
// Check for resultExpr_
|
||||
const auto *result_expr = pc->resultExpr_;
|
||||
ASSERT_TRUE(result_expr);
|
||||
}
|
||||
{
|
||||
const auto *query = dynamic_cast<CypherQuery *>(
|
||||
ast_generator.ParseQuery("MATCH (n) RETURN [p = (n)-->(b) WHERE b.id=1 | b.val] AS res;"));
|
||||
const auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
|
||||
|
||||
const auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(pc);
|
||||
|
||||
// Check for variable_
|
||||
ASSERT_TRUE(pc->variable_);
|
||||
|
||||
// Check for pattern_
|
||||
const auto pattern = pc->pattern_;
|
||||
ASSERT_TRUE(pattern->atoms_.size() == 3);
|
||||
|
||||
const auto *node1 = dynamic_cast<NodeAtom *>(pattern->atoms_[0]);
|
||||
const auto *edge = dynamic_cast<EdgeAtom *>(pattern->atoms_[1]);
|
||||
const auto *node2 = dynamic_cast<NodeAtom *>(pattern->atoms_[2]);
|
||||
|
||||
ASSERT_TRUE(node1);
|
||||
ASSERT_TRUE(edge);
|
||||
ASSERT_TRUE(node2);
|
||||
|
||||
// Check for filter_
|
||||
const auto *filter = pc->filter_;
|
||||
ASSERT_TRUE(filter);
|
||||
ASSERT_TRUE(filter->expression_);
|
||||
|
||||
// Check for resultExpr_
|
||||
const auto *result_expr = pc->resultExpr_;
|
||||
ASSERT_TRUE(result_expr);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
// Copyright 2024 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -635,3 +635,5 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
|
||||
#define DROP_USER(usernames) storage.Create<memgraph::query::DropUser>((usernames))
|
||||
#define CALL_PROCEDURE(...) memgraph::query::test_common::GetCallProcedure(storage, __VA_ARGS__)
|
||||
#define CALL_SUBQUERY(...) memgraph::query::test_common::GetCallSubquery(this->storage, __VA_ARGS__)
|
||||
#define PATTERN_COMPREHENSION(variable, pattern, filter, resultExpr) \
|
||||
this->storage.template Create<memgraph::query::PatternComprehension>(variable, pattern, filter, resultExpr)
|
||||
|
@ -1442,3 +1442,27 @@ TYPED_TEST(TestSymbolGenerator, PropertyCachingMixedLookups2) {
|
||||
ASSERT_TRUE(prop3_eval_mode == PropertyLookup::EvaluationMode::GET_ALL_PROPERTIES);
|
||||
ASSERT_TRUE(prop4_eval_mode == PropertyLookup::EvaluationMode::GET_ALL_PROPERTIES);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestSymbolGenerator, PatternComprehension) {
|
||||
auto prop = this->dba.NameToProperty("prop");
|
||||
|
||||
// MATCH (n) RETURN [(n)-[edge]->(m) | m.prop] AS alias
|
||||
auto query = QUERY(SINGLE_QUERY(
|
||||
MATCH(PATTERN(NODE("n"))),
|
||||
RETURN(NEXPR("alias", PATTERN_COMPREHENSION(nullptr,
|
||||
PATTERN(NODE("n"), EDGE("edge", EdgeAtom::Direction::BOTH, {}, false),
|
||||
NODE("m", std::nullopt, false)),
|
||||
nullptr, PROPERTY_LOOKUP(this->dba, "m", prop))))));
|
||||
|
||||
auto symbol_table = MakeSymbolTable(query);
|
||||
ASSERT_EQ(symbol_table.max_position(), 7);
|
||||
|
||||
memgraph::query::plan::UsedSymbolsCollector collector(symbol_table);
|
||||
auto *ret = dynamic_cast<Return *>(query->single_query_->clauses_[1]);
|
||||
auto *pc = dynamic_cast<PatternComprehension *>(ret->body_.named_expressions[0]->expression_);
|
||||
|
||||
pc->Accept(collector);
|
||||
|
||||
// n, edge, m, Path
|
||||
ASSERT_EQ(collector.symbols_.size(), 4);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user