memgraph/src/query/plan/rule_based_planner.cpp
2024-03-07 18:41:02 +04:00

671 lines
29 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/plan/rule_based_planner.hpp"
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <stack>
#include <unordered_set>
#include "query/frontend/ast/ast.hpp"
#include "query/plan/operator.hpp"
#include "query/plan/preprocess.hpp"
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::plan {
namespace {
// Ast tree visitor which collects the context for a return body.
// The return body of WITH and RETURN clauses consists of:
//
// * named expressions (used to produce results);
// * flag whether the results need to be DISTINCT;
// * optional SKIP expression;
// * optional LIMIT expression and
// * optional ORDER BY expressions.
//
// In addition to the above, we collect information on used symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public HierarchicalTreeVisitor {
 public:
  // Visits the named expressions of `body` (and, when there are no
  // aggregations, also its ORDER BY expressions and the optional `where`)
  // and records output symbols, used symbols, aggregations and group-by
  // expressions.
  //
  // `pc_ops` maps a named expression's name to an operator prepared for its
  // pattern comprehension. It is taken by value because matched entries are
  // erased while visiting.
  //
  // Throws utils::NotYetImplemented if a pattern comprehension has been seen
  // and the named expression currently being processed has no entry in
  // `pc_ops`.
  ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
                    AstStorage &storage, std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops,
                    Where *where = nullptr)
      : body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
    // Collect symbols from named expressions.
    output_symbols_.reserve(body_.named_expressions.size());
    if (body.all_identifiers) {
      // Expand '*' to expressions and symbols first, so that their results come
      // before regular named expressions.
      ExpandUserSymbols();
    }
    for (auto &named_expr : body_.named_expressions) {
      output_symbols_.emplace_back(symbol_table_.at(*named_expr));
      named_expr->Accept(*this);
      named_expressions_.emplace_back(named_expr);
      // NOTE(review): pattern_comprehension_ is never reset inside this loop,
      // so once one named expression contained a pattern comprehension, every
      // following named expression must also have an entry in pc_ops or the
      // constructor throws. Confirm this is the intended limitation.
      if (pattern_comprehension_) {
        if (auto it = pc_ops.find(named_expr->name_); it != pc_ops.end()) {
          pattern_comprehension_op_ = std::move(it->second);
          pc_ops.erase(it);
        } else {
          throw utils::NotYetImplemented("Operation on top of pattern comprehension");
        }
      }
    }
    if (!aggregations_.empty()) {
      // Collect symbols used in group by expressions.
      UsedSymbolsCollector collector(symbol_table_);
      for (auto &group_by : group_by_) {
        group_by->Accept(collector);
      }
      group_by_used_symbols_ = collector.symbols_;
    } else {
      // Visit order_by and where if we do not have aggregations. This way we
      // prevent collecting group_by expressions from order_by and where, which
      // would be very wrong. When we have aggregation, order_by and where can
      // only use new symbols (ensured in semantic analysis), so we don't care
      // about collecting used_symbols. Also, semantic analysis should
      // have prevented any aggregations from appearing here.
      for (const auto &order_pair : body.order_by) {
        order_pair.expression->Accept(*this);
      }
      if (where) {
        where->Accept(*this);
      }
      MG_ASSERT(aggregations_.empty(), "Unexpected aggregations in ORDER BY or WHERE");
    }
  }

  using HierarchicalTreeVisitor::PostVisit;
  using HierarchicalTreeVisitor::PreVisit;
  using HierarchicalTreeVisitor::Visit;

  // A literal never contains an aggregation.
  bool Visit(PrimitiveLiteral &) override {
    has_aggregation_.emplace_back(false);
    return true;
  }

 private:
  // Replaces the `count` most recently pushed has_aggregation_ flags with a
  // single flag that is true if any of them was true. With `count == 0` a
  // single `false` flag is pushed. The caller must make sure that at least
  // `count` flags are present.
  void FoldTopFlags(size_t count) {
    bool has_aggr = false;
    for (size_t i = 0; i < count; ++i) {
      has_aggr = has_aggr || has_aggregation_.back();
      has_aggregation_.pop_back();
    }
    has_aggregation_.emplace_back(has_aggr);
  }

  // Shared PostVisit logic for list, map and map projection literals.
  // `iterator_to_expression` turns an iterator over `literal.elements_` into
  // the Expression * of that element.
  template <typename TLiteral, typename TIteratorToExpression>
  void PostVisitCollectionLiteral(TLiteral &literal, TIteratorToExpression iterator_to_expression) {
    // If there is an aggregation in the list, and there are group-bys, then we
    // need to add the group-bys manually. If there are no aggregations, the
    // whole list will be added as a group-by.
    std::vector<Expression *> literal_group_by;
    bool has_aggr = false;
    auto elements_it = literal.elements_.begin();
    // Step back over the flags of the elements. The size is converted to the
    // signed difference type before negating; negating the unsigned size
    // directly only works through wrap-around conversion.
    auto it = has_aggregation_.end();
    std::advance(it, -static_cast<std::list<bool>::difference_type>(literal.elements_.size()));
    if (literal.GetTypeInfo() == MapProjectionLiteral::kType) {
      // Erase the map variable. Grammar-wise, its a variable and thus never has aggregations.
      std::advance(it, -1);
      it = has_aggregation_.erase(it);
    }
    // Consume one flag per element, in element order.
    while (it != has_aggregation_.end()) {
      if (*it) {
        has_aggr = true;
      } else {
        literal_group_by.emplace_back(iterator_to_expression(elements_it));
      }
      ++elements_it;
      it = has_aggregation_.erase(it);
    }
    has_aggregation_.emplace_back(has_aggr);
    if (has_aggr) {
      group_by_.insert(group_by_.end(), literal_group_by.begin(), literal_group_by.end());
    }
  }

 public:
  bool PostVisit(ListLiteral &list_literal) override {
    MG_ASSERT(list_literal.elements_.size() <= has_aggregation_.size(),
              "Expected as many has_aggregation_ flags as there are list "
              "elements.");
    PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; });
    return true;
  }

  bool PostVisit(MapLiteral &map_literal) override {
    MG_ASSERT(map_literal.elements_.size() <= has_aggregation_.size(),
              "Expected as many has_aggregation_ flags as there are map elements.");
    PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
    return true;
  }

  bool PostVisit(MapProjectionLiteral &map_projection_literal) override {
    MG_ASSERT(map_projection_literal.elements_.size() <= has_aggregation_.size(),
              "Expected as many has_aggregation_ flags as there are map elements.");
    PostVisitCollectionLiteral(map_projection_literal, [](auto it) { return it->second; });
    return true;
  }

  bool PostVisit(All &all) override {
    // Remove the symbol which is bound by all, because we are only interested
    // in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*all.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ALL arguments");
    FoldTopFlags(3U);
    return true;
  }

  bool PostVisit(Single &single) override {
    // Remove the symbol which is bound by single, because we are only
    // interested in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*single.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for SINGLE arguments");
    FoldTopFlags(3U);
    return true;
  }

  bool PostVisit(Any &any) override {
    // Remove the symbol which is bound by any, because we are only interested
    // in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*any.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ANY arguments");
    FoldTopFlags(3U);
    return true;
  }

  bool PostVisit(None &none) override {
    // Remove the symbol which is bound by none, because we are only interested
    // in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*none.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for NONE arguments");
    FoldTopFlags(3U);
    return true;
  }

  bool PostVisit(Reduce &reduce) override {
    // Remove the symbols bound by reduce, because we are only interested
    // in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
    used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 5U, "Expected 5 has_aggregation_ flags for REDUCE arguments");
    FoldTopFlags(5U);
    return true;
  }

  bool PostVisit(Coalesce &coalesce) override {
    // Report the number of flags that was expected (one per argument), not
    // the number that happens to be on the stack.
    MG_ASSERT(has_aggregation_.size() >= coalesce.expressions_.size(),
              "Expected >= {} has_aggregation_ flags for COALESCE arguments", coalesce.expressions_.size());
    FoldTopFlags(coalesce.expressions_.size());
    return true;
  }

  bool PostVisit(Extract &extract) override {
    // Remove the symbol bound by extract, because we are only interested
    // in free (unbound) symbols.
    used_symbols_.erase(symbol_table_.at(*extract.identifier_));
    MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for EXTRACT arguments");
    FoldTopFlags(3U);
    return true;
  }

  bool Visit(Identifier &ident) override {
    const auto &symbol = symbol_table_.at(ident);
    if (!utils::Contains(output_symbols_, symbol)) {
      // Don't pick up new symbols, even though they may be used in ORDER BY or
      // WHERE.
      used_symbols_.insert(symbol);
    }
    has_aggregation_.emplace_back(false);
    return true;
  }

  // Visits the children manually (and returns false to stop the default
  // traversal) because the bounds are optional and the list expression needs
  // to be grouped by when only a bound contains an aggregation.
  bool PreVisit(ListSlicingOperator &list_slicing) override {
    list_slicing.list_->Accept(*this);
    bool list_has_aggr = has_aggregation_.back();
    has_aggregation_.pop_back();
    bool has_aggr = list_has_aggr;
    if (list_slicing.lower_bound_) {
      list_slicing.lower_bound_->Accept(*this);
      has_aggr = has_aggr || has_aggregation_.back();
      has_aggregation_.pop_back();
    }
    if (list_slicing.upper_bound_) {
      list_slicing.upper_bound_->Accept(*this);
      has_aggr = has_aggr || has_aggregation_.back();
      has_aggregation_.pop_back();
    }
    if (has_aggr && !list_has_aggr) {
      // We need to group by the list expression, because it didn't have an
      // aggregation inside.
      group_by_.emplace_back(list_slicing.list_);
    }
    has_aggregation_.emplace_back(has_aggr);
    return false;
  }

  // Visits condition, then-branch and else-branch manually and returns false
  // to stop the default traversal.
  bool PreVisit(IfOperator &if_operator) override {
    if_operator.condition_->Accept(*this);
    bool has_aggr = has_aggregation_.back();
    has_aggregation_.pop_back();
    if_operator.then_expression_->Accept(*this);
    has_aggr = has_aggr || has_aggregation_.back();
    has_aggregation_.pop_back();
    if_operator.else_expression_->Accept(*this);
    has_aggr = has_aggr || has_aggregation_.back();
    has_aggregation_.pop_back();
    has_aggregation_.emplace_back(has_aggr);
    // TODO: Once we allow aggregations here, insert appropriate stuff in
    // group_by.
    MG_ASSERT(!has_aggr, "Currently aggregations in CASE are not allowed");
    return false;
  }

  bool PostVisit(Function &function) override {
    MG_ASSERT(function.arguments_.size() <= has_aggregation_.size(),
              "Expected as many has_aggregation_ flags as there are "
              "function arguments.");
    // One flag per argument is on top of the stack; collapse them into one.
    FoldTopFlags(function.arguments_.size());
    return true;
  }

#define VISIT_BINARY_OPERATOR(BinaryOperator)                                                \
  bool PostVisit(BinaryOperator &op) override {                                              \
    MG_ASSERT(has_aggregation_.size() >= 2U, "Expected at least 2 has_aggregation_ flags."); \
    /* has_aggregation_ stack is reversed, last result is from the 2nd */                    \
    /* expression. */                                                                        \
    bool aggr2 = has_aggregation_.back();                                                    \
    has_aggregation_.pop_back();                                                             \
    bool aggr1 = has_aggregation_.back();                                                    \
    has_aggregation_.pop_back();                                                             \
    bool has_aggr = aggr1 || aggr2;                                                          \
    if (has_aggr && !(aggr1 && aggr2)) {                                                     \
      /* Group by the expression which does not contain aggregation. */                      \
      /* Possible optimization is to ignore constant value expressions */                    \
      group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_);                     \
    }                                                                                        \
    /* Propagate that this whole expression may contain an aggregation. */                   \
    has_aggregation_.emplace_back(has_aggr);                                                 \
    return true;                                                                             \
  }

  VISIT_BINARY_OPERATOR(OrOperator)
  VISIT_BINARY_OPERATOR(XorOperator)
  VISIT_BINARY_OPERATOR(AndOperator)
  VISIT_BINARY_OPERATOR(AdditionOperator)
  VISIT_BINARY_OPERATOR(SubtractionOperator)
  VISIT_BINARY_OPERATOR(MultiplicationOperator)
  VISIT_BINARY_OPERATOR(DivisionOperator)
  VISIT_BINARY_OPERATOR(ModOperator)
  VISIT_BINARY_OPERATOR(NotEqualOperator)
  VISIT_BINARY_OPERATOR(EqualOperator)
  VISIT_BINARY_OPERATOR(LessOperator)
  VISIT_BINARY_OPERATOR(GreaterOperator)
  VISIT_BINARY_OPERATOR(LessEqualOperator)
  VISIT_BINARY_OPERATOR(GreaterEqualOperator)
  VISIT_BINARY_OPERATOR(InListOperator)
  VISIT_BINARY_OPERATOR(SubscriptOperator)

#undef VISIT_BINARY_OPERATOR

  bool PostVisit(Aggregation &aggr) override {
    // Aggregation contains a virtual symbol, where the result will be stored.
    const auto &symbol = symbol_table_.at(aggr);
    aggregations_.emplace_back(
        Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol, aggr.distinct_});
    // Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
    // two expressions, so we can have 0, 1 or 2 elements on the
    // has_aggregation_stack for this Aggregation expression.
    if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back();
    if (aggr.expression1_)
      has_aggregation_.back() = true;
    else
      has_aggregation_.emplace_back(true);
    // Possible optimization is to skip remembering symbols inside aggregation.
    // If and when implementing this, don't forget that Accumulate needs *all*
    // the symbols, including those inside aggregation.
    return true;
  }

  // A named expression without any aggregation is used for grouping as a
  // whole. Its single remaining flag is consumed here.
  bool PostVisit(NamedExpression &named_expr) override {
    MG_ASSERT(has_aggregation_.size() == 1U, "Expected to reduce has_aggregation_ to single boolean.");
    if (!has_aggregation_.back()) {
      group_by_.emplace_back(named_expr.expression_);
    }
    has_aggregation_.pop_back();
    return true;
  }

  bool Visit(ParameterLookup & /*unused*/) override {
    has_aggregation_.emplace_back(false);
    return true;
  }

  // Merges the two topmost flags into one.
  bool PostVisit(RegexMatch & /*unused*/) override {
    MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments");
    bool has_aggr = has_aggregation_.back();
    has_aggregation_.pop_back();
    has_aggregation_.back() |= has_aggr;
    return true;
  }

  // Remembers how many flags were on the stack before the pattern
  // comprehension was entered, so that PostVisit can fold everything pushed
  // since then.
  bool PreVisit(PatternComprehension & /*unused*/) override {
    pattern_comprehension_aggregations_start_index_ = has_aggregation_.size();
    return true;
  }

  // Folds all flags pushed since the matching PreVisit into a single flag and
  // records the visited pattern comprehension.
  bool PostVisit(PatternComprehension &pattern_comprehension) override {
    bool has_aggr = false;
    for (auto i = has_aggregation_.size(); i > pattern_comprehension_aggregations_start_index_; --i) {
      has_aggr |= has_aggregation_.back();
      has_aggregation_.pop_back();
    }
    has_aggregation_.emplace_back(has_aggr);
    pattern_comprehension_ = &pattern_comprehension;
    return true;
  }

  // Creates NamedExpression with an Identifier for each user declared symbol.
  // This should be used when body.all_identifiers is true, to generate
  // expressions for Produce operator.
  void ExpandUserSymbols() {
    MG_ASSERT(named_expressions_.empty(), "ExpandUserSymbols should be first to fill named_expressions_");
    MG_ASSERT(output_symbols_.empty(), "ExpandUserSymbols should be first to fill output_symbols_");
    for (const auto &symbol : bound_symbols_) {
      if (!symbol.user_declared()) {
        continue;
      }
      auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol);
      auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
      // Fill output expressions and symbols with expanded identifiers.
      named_expressions_.emplace_back(named_expr);
      output_symbols_.emplace_back(symbol);
      used_symbols_.insert(symbol);
      // Don't forget to group by expanded identifiers.
      group_by_.emplace_back(ident);
    }
    // Cypher RETURN/WITH * expects to expand '*' sorted by name.
    std::sort(output_symbols_.begin(), output_symbols_.end(),
              [](const auto &a, const auto &b) { return a.name() < b.name(); });
    std::sort(named_expressions_.begin(), named_expressions_.end(),
              [](const auto &a, const auto &b) { return a->name_ < b->name_; });
  }

  // If true, results need to be distinct.
  bool distinct() const { return body_.distinct; }
  // Named expressions which are used to produce results.
  const auto &named_expressions() const { return named_expressions_; }
  // Pairs of (Ordering, Expression *) for sorting results.
  const auto &order_by() const { return body_.order_by; }
  // Optional expression which determines how many results to skip.
  auto *skip() const { return body_.skip; }
  // Optional expression which determines how many results to produce.
  auto *limit() const { return body_.limit; }
  // Optional Where clause for filtering.
  const auto *where() const { return where_; }
  // Set of symbols used inside the visited expressions, including the inside of
  // aggregation expression. These only includes old symbols, even though new
  // ones may have been used in ORDER BY or WHERE.
  const auto &used_symbols() const { return used_symbols_; }
  // List of aggregation elements found in expressions.
  const auto &aggregations() const { return aggregations_; }
  // When there is at least one aggregation element, all the non-aggregate (sub)
  // expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
  // AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
  const auto &group_by() const { return group_by_; }
  // Set of symbols used in group by expressions.
  const auto &group_by_used_symbols() const { return group_by_used_symbols_; }
  // All symbols generated by named expressions. They are collected in order of
  // named_expressions.
  const auto &output_symbols() const { return output_symbols_; }
  // Last pattern comprehension seen while visiting, or nullptr if none.
  const auto *pattern_comprehension() const { return pattern_comprehension_; }
  // Operator taken out of `pc_ops` for the pattern comprehension; may be null.
  std::shared_ptr<LogicalOperator> pattern_comprehension_op() const { return pattern_comprehension_op_; }

 private:
  const ReturnBody &body_;
  SymbolTable &symbol_table_;
  const std::unordered_set<Symbol> &bound_symbols_;
  AstStorage &storage_;
  const Where *const where_ = nullptr;
  std::unordered_set<Symbol> used_symbols_;
  std::vector<Symbol> output_symbols_;
  std::vector<Aggregate::Element> aggregations_;
  std::vector<Expression *> group_by_;
  std::unordered_set<Symbol> group_by_used_symbols_;
  // Flag stack indicating whether an expression contains an aggregation. A
  // stack is needed to address the case where one child sub-expression has
  // an aggregation, while the other child does not.
  // For example, the AST (+ (sum x) y) is as follows:
  //   * (sum x) -- Has an aggregation.
  //   * y -- Doesn't, we need to group by this.
  //   * (+ (sum x) y) -- The whole expression has an aggregation, so we don't
  //                      group by it.
  std::list<bool> has_aggregation_;
  std::vector<NamedExpression *> named_expressions_;
  PatternComprehension *pattern_comprehension_ = nullptr;
  std::shared_ptr<LogicalOperator> pattern_comprehension_op_;
  // Size of has_aggregation_ at the moment the most recent pattern
  // comprehension was entered.
  size_t pattern_comprehension_aggregations_start_index_ = 0;
};
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate) {
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
auto last_op = std::move(input_op);
if (accumulate) {
// We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an
// accumulation after updates, without advancing the command.
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols, advance_command);
}
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
}
if (body.pattern_comprehension()) {
last_op = std::make_unique<RollUpApply>(std::move(last_op), body.pattern_comprehension_op());
}
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
if (body.distinct()) {
last_op = std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
}
// Like Where, OrderBy can read from symbols established by named expressions
// in Produce, so it must come after it.
if (!body.order_by().empty()) {
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(), body.output_symbols());
}
// Finally, Skip and Limit must come after OrderBy.
if (body.skip()) {
last_op = std::make_unique<Skip>(std::move(last_op), body.skip());
}
// Limit is always after Skip.
if (body.limit()) {
last_op = std::make_unique<Limit>(std::move(last_op), body.limit());
}
// Where may see new symbols so it comes after we generate Produce and in
// general, comes after any OrderBy, Skip or Limit.
if (body.where()) {
last_op = std::make_unique<Filter>(std::move(last_op), std::vector<std::shared_ptr<LogicalOperator>>{},
body.where()->expression_);
}
return last_op;
}
} // namespace
namespace impl {
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter) {
return std::ranges::all_of(
filter.used_symbols.begin(), filter.used_symbols.end(),
[&bound_symbols](const auto &symbol) { return bound_symbols.find(symbol) != bound_symbols.end(); });
}
// Removes from `filters` every filter whose used symbols are all bound and
// returns their conjunction, built with BoolJoin<AndOperator>. Returns nullptr
// when no filter could be extracted.
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage) {
  std::vector<FilterInfo> ready_filters{};
  auto cursor = filters.begin();
  while (cursor != filters.end()) {
    if (!HasBoundFilterSymbols(bound_symbols, *cursor)) {
      ++cursor;
      continue;
    }
    ready_filters.emplace_back(*cursor);
    cursor = filters.erase(cursor);
  }

  // Put pattern ( exists() ) filters last in the AND chain: if any of the
  // earlier filters evaluates to false, the pattern filter does not need to
  // be evaluated at all. std::partition does not promise to keep the relative
  // order inside the two groups.
  const auto is_not_pattern = [](const FilterInfo &filter_info) {
    return filter_info.type != FilterInfo::Type::Pattern;
  };
  std::partition(ready_filters.begin(), ready_filters.end(), is_not_pattern);

  Expression *joined = nullptr;
  for (auto &ready_filter : ready_filters) {
    joined = impl::BoolJoin<AndOperator>(storage, joined, ready_filter.expression);
  }
  return joined;
}
std::unordered_set<Symbol> GetSubqueryBoundSymbols(
const std::vector<SingleQueryPart> &single_query_parts, SymbolTable &symbol_table, AstStorage &storage,
std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
const auto &query = single_query_parts[0];
if (!query.matching.expansions.empty() || query.remaining_clauses.empty()) {
return {};
}
if (std::unordered_set<Symbol> bound_symbols; auto *with = utils::Downcast<query::With>(query.remaining_clauses[0])) {
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage, pc_ops);
return bound_symbols;
}
return {};
}
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {
auto all_are_bound = [&bound_symbols](const std::vector<Symbol> &syms) {
for (const auto &sym : syms)
if (bound_symbols.find(sym) == bound_symbols.end()) return false;
return true;
};
for (auto named_path_it = named_paths.begin(); named_path_it != named_paths.end();) {
if (all_are_bound(named_path_it->second)) {
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), named_path_it->first,
std::move(named_path_it->second));
bound_symbols.insert(named_path_it->first);
named_path_it = named_paths.erase(named_path_it);
} else {
++named_path_it;
}
}
return last_op;
}
// Plans a RETURN clause on top of `input_op`.
//
// Similar to WITH clause, but we want to accumulate when the query writes to
// the database. This way we handle the case when we want to return
// expressions with the latest updated results. For example, `MATCH (n) -- ()
// SET n.prop = n.prop + 1 RETURN n.prop`. If we match same `n` multiple 'k'
// times, we want to return 'k' results where the property value is the same,
// final result of 'k' increments.
//
// RETURN never advances the command.
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
                                           SymbolTable &symbol_table, bool is_write,
                                           const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
                                           std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
  const bool accumulate = is_write;
  const bool advance_command = false;
  // pc_ops is already a private copy; move it instead of copying it again.
  ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage, std::move(pc_ops));
  return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
}
// Plans a WITH clause on top of `input_op` and narrows `bound_symbols` to the
// symbols the clause outputs.
//
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
// RETURN clause.
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
                                         SymbolTable &symbol_table, bool is_write,
                                         std::unordered_set<Symbol> &bound_symbols, AstStorage &storage,
                                         std::unordered_map<std::string, std::shared_ptr<LogicalOperator>> pc_ops) {
  const bool accumulate = is_write;
  // No need to advance the command if we only performed reads.
  const bool advance_command = is_write;
  // pc_ops is already a private copy; move it instead of copying it again.
  ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, std::move(pc_ops), with.where_);
  auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
  // Reset bound symbols, so that only those in WITH are exposed. The output
  // symbols live in their own container inside `body`, so clearing the set
  // here does not affect them.
  bound_symbols.clear();
  const auto &outputs = body.output_symbols();
  bound_symbols.insert(outputs.begin(), outputs.end());
  return last_op;
}
// Creates a Union operator over `left_op` and `right_op`, passing the union
// symbols stored in `cypher_union` together with the output symbols reported
// by each side.
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
                                          std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) {
  auto left_symbols = left_op->OutputSymbols(symbol_table);
  auto right_symbols = right_op->OutputSymbols(symbol_table);
  return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_, left_symbols, right_symbols);
}
// Looks up the symbol mapped to the identifier of the node `atom`.
Symbol GetSymbol(NodeAtom *atom, const SymbolTable &symbol_table) {
  const auto &node_identifier = *atom->identifier_;
  return symbol_table.at(node_identifier);
}
// Looks up the symbol mapped to the identifier of the edge `atom`.
Symbol GetSymbol(EdgeAtom *atom, const SymbolTable &symbol_table) {
  const auto &edge_identifier = *atom->identifier_;
  return symbol_table.at(edge_identifier);
}
} // namespace impl
} // namespace memgraph::query::plan