Add Cypher subqueries (#794) (#851)

Co-authored-by: Bruno Sačarić <bruno.sacaric@gmail.com>
This commit is contained in:
Josipmrden 2023-03-31 15:24:02 +02:00 committed by Marko Budiselic
parent 398503da7a
commit f5a49ed29f
38 changed files with 1771 additions and 176 deletions

View File

@ -267,4 +267,7 @@ constexpr utils::TypeInfo query::TransactionQueueQuery::kType{utils::TypeId::AST
"TransactionQueueQuery", &query::Query::kType};
constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType};
// Type descriptor of the CallSubquery AST node; its parent type descriptor is the one of Clause.
constexpr utils::TypeInfo query::CallSubquery::kType{utils::TypeId::AST_CALL_SUBQUERY, "CallSubquery",
&query::Clause::kType};
} // namespace memgraph

View File

@ -1616,6 +1616,8 @@ class NamedExpression : public memgraph::query::Tree,
int32_t token_position_{-1};
/// Symbol table position of the symbol this NamedExpression is mapped to.
int32_t symbol_pos_{-1};
/// True if the variable is aliased
bool is_aliased_{false};
NamedExpression *Clone(AstStorage *storage) const override {
NamedExpression *object = storage->Create<NamedExpression>();
@ -1623,6 +1625,7 @@ class NamedExpression : public memgraph::query::Tree,
object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
object->token_position_ = token_position_;
object->symbol_pos_ = symbol_pos_;
object->is_aliased_ = is_aliased_;
return object;
}
@ -1973,7 +1976,7 @@ class Query : public memgraph::query::Tree, public utils::Visitable<QueryVisitor
friend class AstStorage;
};
class CypherQuery : public memgraph::query::Query {
class CypherQuery : public memgraph::query::Query, public utils::Visitable<HierarchicalTreeVisitor> {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
@ -1982,6 +1985,17 @@ class CypherQuery : public memgraph::query::Query {
DEFVISITABLE(QueryVisitor<void>);
/// Hierarchical traversal of the query. When PreVisit allows descending, the
/// first single query is visited and then every union part, in order.
/// PostVisit is always invoked and its result is returned.
bool Accept(HierarchicalTreeVisitor &visitor) override {
  const bool descend = visitor.PreVisit(*this);
  if (descend) {
    single_query_->Accept(visitor);
    for (auto *union_part : cypher_unions_) {
      union_part->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
/// First and potentially only query.
memgraph::query::SingleQuery *single_query_{nullptr};
/// Contains remaining queries that should form a union with `single_query_`.
@ -3289,5 +3303,31 @@ class Exists : public memgraph::query::Expression {
friend class AstStorage;
};
/// AST node of the subquery clause. It wraps the Cypher query that the clause
/// executes.
class CallSubquery : public memgraph::query::Clause {
 public:
  static const utils::TypeInfo kType;
  const utils::TypeInfo &GetTypeInfo() const override { return kType; }

  CallSubquery() = default;

  /// Visits the wrapped query when PreVisit allows descending. PostVisit is
  /// always invoked and its result is returned.
  bool Accept(HierarchicalTreeVisitor &visitor) override {
    if (visitor.PreVisit(*this)) {
      cypher_query_->Accept(visitor);
    }
    return visitor.PostVisit(*this);
  }

  /// Query executed by this clause. Initialised to nullptr so that a
  /// default-constructed node never holds an indeterminate pointer; Clone()
  /// tests this member for null before dereferencing it.
  memgraph::query::CypherQuery *cypher_query_{nullptr};

  /// Creates a copy of this node inside `storage`; the wrapped query is cloned
  /// as well when it is set.
  CallSubquery *Clone(AstStorage *storage) const override {
    CallSubquery *object = storage->Create<CallSubquery>();
    object->cypher_query_ = cypher_query_ ? cypher_query_->Clone(storage) : nullptr;
    return object;
  }

 private:
  friend class AstStorage;
};
} // namespace query
} // namespace memgraph

View File

@ -1237,7 +1237,9 @@ cpp<#
(token-position :int32_t :initval -1 :scope :public
:documentation "This field contains token position of first token in named expression used to create name_. If NamedExpression object is not created from query or it is aliased leave this value at -1.")
(symbol-pos :int32_t :initval -1 :scope :public
:documentation "Symbol table position of the symbol this NamedExpression is mapped to."))
:documentation "Symbol table position of the symbol this NamedExpression is mapped to.")
(is-aliased :bool :initval "false" :scope :public
:documentation "True if the variable is aliased"))
(:public
#>cpp
using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept;
@ -1620,7 +1622,7 @@ cpp<#
(:clone :ignore-other-base-classes t)
(:type-info :ignore-other-base-classes t))
(lcp:define-class cypher-query (query)
(lcp:define-class cypher-query (query "::utils::Visitable<HierarchicalTreeVisitor>")
((single-query "SingleQuery *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "SingleQuery")
@ -1639,13 +1641,25 @@ cpp<#
CypherQuery() = default;
DEFVISITABLE(QueryVisitor<void>);
/// Hierarchical traversal of the query. When PreVisit allows descending, the
/// first single query is visited and then every union part, in order.
/// PostVisit is always invoked and its result is returned.
bool Accept(HierarchicalTreeVisitor &visitor) override {
  const bool descend = visitor.PreVisit(*this);
  if (descend) {
    single_query_->Accept(visitor);
    for (auto *union_part : cypher_unions_) {
      union_part->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(:serialize (:slk :ignore-other-base-classes t))
(:clone :ignore-other-base-classes t)
(:type-info :ignore-other-base-classes t))
(lcp:define-class explain-query (query)
((cypher-query "CypherQuery *" :initval "nullptr" :scope :public
@ -2787,5 +2801,28 @@ cpp<#
(:serialize (:slk))
(:clone))
;; AST node of the subquery clause; it wraps the Cypher query the clause executes.
(lcp:define-class call-subquery (clause)
  ;; The pointer slot gets an explicit nullptr initial value so that a
  ;; default-constructed node never carries an indeterminate pointer (the
  ;; generated Clone() tests the member for null).
  ((cypher-query "CypherQuery *" :initval "nullptr" :scope :public
                 :slk-save #'slk-save-ast-pointer
                 :slk-load (slk-load-ast-pointer "CypherQuery")))
  (:public
    #>cpp
    CallSubquery() = default;

    /// Visits the wrapped query when PreVisit allows descending. PostVisit is
    /// always invoked and its result is returned.
    bool Accept(HierarchicalTreeVisitor &visitor) override {
      if (visitor.PreVisit(*this)) {
        cypher_query_->Accept(visitor);
      }
      return visitor.PostVisit(*this);
    }
    cpp<#)
  (:private
    #>cpp
    friend class AstStorage;
    cpp<#)
  (:serialize (:slk))
  (:clone))
(lcp:pop-namespace) ;; namespace query
(lcp:pop-namespace) ;; namespace memgraph

View File

@ -95,6 +95,7 @@ class SettingQuery;
class VersionQuery;
class Foreach;
class ShowConfigQuery;
class CallSubquery;
class AnalyzeGraphQuery;
class TransactionQueueQuery;
class Exists;
@ -106,7 +107,7 @@ using TreeCompositeVisitor = utils::CompositeVisitor<
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach, Exists>;
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach, Exists, CallSubquery, CypherQuery>;
using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;

View File

@ -989,6 +989,10 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
calls_write_procedure = true;
has_update = true;
}
} else if (const auto *call_subquery = utils::Downcast<CallSubquery>(clause); call_subquery != nullptr) {
if (has_return) {
throw SemanticException("CALL can't be put after RETURN clause.");
}
} else if (utils::IsSubtype(clause_type, Unwind::kType)) {
check_write_procedure("UNWIND");
if (has_update || has_return) {
@ -1101,6 +1105,9 @@ antlrcpp::Any CypherMainVisitor::visitClause(MemgraphCypher::ClauseContext *ctx)
if (ctx->foreach ()) {
return static_cast<Clause *>(std::any_cast<Foreach *>(ctx->foreach ()->accept(this)));
}
if (ctx->callSubquery()) {
return static_cast<Clause *>(std::any_cast<CallSubquery *>(ctx->callSubquery()->accept(this)));
}
// TODO: implement other clauses.
throw utils::NotYetImplemented("clause '{}'", ctx->getText());
return 0;
@ -1596,6 +1603,7 @@ antlrcpp::Any CypherMainVisitor::visitReturnItem(MemgraphCypher::ReturnItemConte
named_expr->expression_ = std::any_cast<Expression *>(ctx->expression()->accept(this));
MG_ASSERT(named_expr->expression_);
if (ctx->variable()) {
named_expr->is_aliased_ = true;
named_expr->name_ = std::string(std::any_cast<std::string>(ctx->variable()->accept(this)));
users_identifiers.insert(named_expr->name_);
} else {
@ -2565,6 +2573,20 @@ antlrcpp::Any CypherMainVisitor::visitShowConfigQuery(MemgraphCypher::ShowConfig
return query_;
}
// Builds the CallSubquery AST node for a parsed subquery clause.
// Throws SyntaxException when the nested query specifies a memory limit.
antlrcpp::Any CypherMainVisitor::visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) {
  auto *subquery_clause = storage_->Create<CallSubquery>();

  MG_ASSERT(ctx->cypherQuery(), "Expected query inside subquery clause");
  auto *nested_query_ctx = ctx->cypherQuery();

  if (nested_query_ctx->queryMemoryLimit()) {
    throw SyntaxException("Memory limit cannot be set on subqueries!");
  }

  // Visiting the nested context yields the CypherQuery node wrapped by the clause.
  auto *nested_query = std::any_cast<CypherQuery *>(nested_query_ctx->accept(this));
  subquery_clause->cypher_query_ = nested_query;
  return subquery_clause;
}
LabelIx CypherMainVisitor::AddLabel(const std::string &name) { return storage_->GetLabelIx(name); }
PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return storage_->GetPropertyIx(name); }

View File

@ -922,6 +922,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitShowConfigQuery(MemgraphCypher::ShowConfigQueryContext *ctx) override;
/**
* @return CallSubquery*
*/
antlrcpp::Any visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) override;
public:
Query *query() { return query_; }
const static std::string kAnonPrefix;

View File

@ -176,6 +176,7 @@ clause : cypherMatch
| callProcedure
| loadCsv
| foreach
| callSubquery
;
updateClause : set
@ -188,6 +189,8 @@ updateClause : set
foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ;
callSubquery : CALL '{' cypherQuery '}' ;
streamQuery : checkStream
| createStream
| dropStream

View File

@ -36,6 +36,7 @@ BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ;
BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
BEFORE : B E F O R E ;
BOOTSTRAP_SERVERS : B O O T S T R A P UNDERSCORE S E R V E R S ;
CALL : C A L L ;
CHECK : C H E C K ;
CLEAR : C L E A R ;
COMMIT : C O M M I T ;

View File

@ -31,9 +31,9 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(AuthQuery & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::AUTH); }
void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(dynamic_cast<QueryVisitor &>(*this)); }
void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(dynamic_cast<QueryVisitor &>(*this)); }
void Visit(InfoQuery &info_query) override {
switch (info_query.info_type_) {

View File

@ -211,6 +211,33 @@ bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
return true;
}
// Pushes a new scope, flagged as belonging to a subquery, before the nested
// query is visited.
bool SymbolGenerator::PreVisit(CallSubquery & /*call_sub*/) {
  Scope subquery_scope{};
  subquery_scope.in_call_subquery = true;
  scopes_.push_back(subquery_scope);
  return true;
}
// Removes the scope of the subquery and, if a RETURN was visited inside it,
// makes the symbols of that scope available in the enclosing scope.
// Throws SemanticException when a symbol name already exists in the enclosing scope.
bool SymbolGenerator::PostVisit(CallSubquery & /*call_sub*/) {
  // The scope is copied and then popped, so its flags do not need to be reset.
  const auto finished_scope = scopes_.back();
  scopes_.pop_back();
  auto &enclosing_scope = scopes_.back();

  if (finished_scope.has_return) {
    for (const auto &[name, symbol] : finished_scope.symbols) {
      const auto clash = enclosing_scope.symbols.find(name);
      if (clash != enclosing_scope.symbols.end()) {
        throw SemanticException("Variable in subquery already declared in outer scope!");
      }
      enclosing_scope.symbols[name] = symbol;
    }
  }

  return true;
}
bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; }
bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
@ -224,6 +251,8 @@ bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
bool SymbolGenerator::PreVisit(Return &ret) {
auto &scope = scopes_.back();
scope.in_return = true;
scope.has_return = true;
VisitReturnBody(ret.body_);
scope.in_return = false;
return false; // We handled the traversal ourselves.
@ -470,6 +499,15 @@ bool SymbolGenerator::PostVisit(Exists & /*exists*/) {
return true;
}
// Inside the RETURN of a subquery scope, every returned expression must either
// be a plain identifier or carry an alias; otherwise SemanticException is thrown.
bool SymbolGenerator::PreVisit(NamedExpression &named_expression) {
  const auto &scope = scopes_.back();

  const bool in_subquery_return = scope.in_call_subquery && scope.in_return;
  if (!in_subquery_return) {
    return true;
  }

  const bool is_plain_identifier = utils::Downcast<Identifier>(named_expression.expression_) != nullptr;
  if (is_plain_identifier || named_expression.is_aliased_) {
    return true;
  }

  throw SemanticException("Expression returned from subquery must be aliased (use AS)!");
}
bool SymbolGenerator::PreVisit(SetProperty & /*set_property*/) {
auto &scope = scopes_.back();
scope.in_set_property = true;

View File

@ -50,6 +50,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(Create &) override;
bool PreVisit(CallProcedure &) override;
bool PostVisit(CallProcedure &) override;
bool PreVisit(CallSubquery & /*unused*/) override;
bool PostVisit(CallSubquery & /*unused*/) override;
bool PreVisit(LoadCsv &) override;
bool PostVisit(LoadCsv &) override;
bool PreVisit(Return &) override;
@ -83,6 +85,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PreVisit(Extract &) override;
bool PreVisit(Exists & /*exists*/) override;
bool PostVisit(Exists & /*exists*/) override;
bool PreVisit(NamedExpression & /*unused*/) override;
// Pattern and its subparts.
bool PreVisit(Pattern &) override;
@ -119,6 +122,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool in_foreach{false};
bool in_exists{false};
bool in_set_property{false};
bool in_call_subquery{false};
bool has_return{false};
// True when visiting a pattern atom (node or edge) identifier, which can be
// reused or created in the pattern itself.
bool in_pattern_atom_identifier{false};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -62,6 +62,8 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
static constexpr double kEdgeUniquenessFilter{1.5};
static constexpr double kUnwind{1.3};
static constexpr double kForeach{1.0};
static constexpr double kUnion{1.0};
static constexpr double kSubquery{1.0};
};
struct CardParam {
@ -212,6 +214,31 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
return true;
}
// Union: each branch is estimated by a separate estimator; the cardinality is
// scaled by the sum of both branch estimates and the union cost is charged.
// Returns false because the branches were already handled here.
bool PreVisit(Union &op) override {
  const double left_estimate = EstimateCostOnBranch(&op.left_op_);
  const double right_estimate = EstimateCostOnBranch(&op.right_op_);

  const double combined_estimate = left_estimate + right_estimate;
  cardinality_ *= combined_estimate;
  IncrementCost(CostParam::kUnion);

  return false;
}
// Apply: the input and the subquery branch are estimated separately; the
// cardinality is scaled by the product of both estimates and the subquery cost
// is charged. Returns false because the branches were already handled here.
bool PreVisit(Apply &op) override {
  double input_estimate = EstimateCostOnBranch(&op.input_);
  double subquery_estimate = EstimateCostOnBranch(&op.subquery_);

  // A zero estimate is replaced by 1 so that the product below never
  // collapses the cardinality to zero.
  if (input_estimate == 0) {
    input_estimate = 1;
  }
  if (subquery_estimate == 0) {
    subquery_estimate = 1;
  }

  cardinality_ *= input_estimate * subquery_estimate;
  IncrementCost(CostParam::kSubquery);

  return false;
}
bool Visit(Once &) override { return true; }
auto cost() const { return cost_; }
@ -232,6 +259,12 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor {
void IncrementCost(double param) { cost_ += param * cardinality_; }
double EstimateCostOnBranch(std::shared_ptr<LogicalOperator> *branch) {
CostEstimator<TDbAccessor> cost_estimator(db_accessor_, parameters);
(*branch)->Accept(cost_estimator);
return cost_estimator.cost();
}
// converts an optional ScanAll range bound into a property value
// if the bound is present and is a constant expression convertible to
// a property value. otherwise returns nullopt

View File

@ -117,6 +117,7 @@ extern const Event CallProcedureOperator;
extern const Event ForeachOperator;
extern const Event EmptyResultOperator;
extern const Event EvaluatePatternFilterOperator;
extern const Event ApplyOperator;
} // namespace EventCounter
namespace memgraph::query::plan {
@ -4070,8 +4071,14 @@ class DistinctCursor : public Cursor {
utils::pmr::vector<TypedValue> row(seen_rows_.get_allocator().GetMemoryResource());
row.reserve(self_.value_symbols_.size());
for (const auto &symbol : self_.value_symbols_) row.emplace_back(frame[symbol]);
if (seen_rows_.insert(std::move(row)).second) return true;
for (const auto &symbol : self_.value_symbols_) {
row.emplace_back(frame.at(symbol));
}
if (seen_rows_.insert(std::move(row)).second) {
return true;
}
}
}
@ -4796,4 +4803,72 @@ bool Foreach::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
return visitor.PostVisit(*this);
}
Apply::Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery,
bool subquery_has_return)
: input_(input ? input : std::make_shared<Once>()),
subquery_(subquery),
subquery_has_return_(subquery_has_return) {}
// Hierarchical traversal: when PreVisit allows descending, the input branch is
// visited first and the subquery branch only if the input visit returned true.
// PostVisit is always invoked and its result is returned.
bool Apply::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
  if (visitor.PreVisit(*this)) {
    const bool continue_with_subquery = input_->Accept(visitor);
    if (continue_with_subquery) {
      subquery_->Accept(visitor);
    }
  }
  return visitor.PostVisit(*this);
}
// Increments the ApplyOperator event counter and creates an ApplyCursor for
// this operator, passing `mem` both to the factory and to the cursor.
UniqueCursorPtr Apply::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::ApplyOperator);
return MakeUniqueCursorPtr<ApplyCursor>(mem, *this, mem);
}
Apply::ApplyCursor::ApplyCursor(const Apply &self, utils::MemoryResource *mem)
: self_(self),
input_(self.input_->MakeCursor(mem)),
subquery_(self.subquery_->MakeCursor(mem)),
subquery_has_return_(self.subquery_has_return_) {}
// Returns the symbols modified by the input branch followed by the symbols
// modified by the subquery branch.
std::vector<Symbol> Apply::ModifiedSymbols(const SymbolTable &table) const {
  auto modified = input_->ModifiedSymbols(table);
  const auto from_subquery = subquery_->ModifiedSymbols(table);
  for (const auto &symbol : from_subquery) {
    modified.push_back(symbol);
  }
  return modified;
}
// Produces the next row. A new input row is pulled only when `pull_input_` is
// set; the subquery cursor is then pulled against the current frame.
// Returns false once the input cursor is exhausted.
bool Apply::ApplyCursor::Pull(Frame &frame, ExecutionContext &context) {
  SCOPED_PROFILE_OP("Apply");

  for (;;) {
    if (pull_input_) {
      if (!input_->Pull(frame, context)) {
        return false;
      }
    }

    const bool subquery_produced_row = subquery_->Pull(frame, context);
    if (subquery_produced_row) {
      // Keep the current input row for the following Pull.
      pull_input_ = false;
      return true;
    }

    // The subquery cursor has nothing more for this input row: rewind it and
    // request a new input row for the next attempt.
    pull_input_ = true;
    subquery_->Reset();

    // Without the has-return flag the input row itself is emitted; with it,
    // the loop moves on to the next input row.
    if (!subquery_has_return_) {
      return true;
    }
  }
}
// Shuts down the input cursor and then the subquery cursor.
void Apply::ApplyCursor::Shutdown() {
input_->Shutdown();
subquery_->Shutdown();
}
// Resets both child cursors and restores `pull_input_` to its initial value,
// so that the next Pull starts by pulling an input row.
void Apply::ApplyCursor::Reset() {
input_->Reset();
subquery_->Reset();
pull_input_ = true;
}
} // namespace memgraph::query::plan

View File

@ -127,6 +127,7 @@ class LoadCsv;
class Foreach;
class EmptyResult;
class EvaluatePatternFilter;
class Apply;
using LogicalOperatorCompositeVisitor =
utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange,
@ -134,7 +135,7 @@ using LogicalOperatorCompositeVisitor =
ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit,
OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv,
Foreach, EmptyResult, EvaluatePatternFilter>;
Foreach, EmptyResult, EvaluatePatternFilter, Apply>;
using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
@ -1087,8 +1088,8 @@ class Produce : public memgraph::query::plan::LogicalOperator {
auto object = std::make_unique<Produce>();
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->named_expressions_.resize(named_expressions_.size());
for (auto i1 = 0; i1 < named_expressions_.size(); ++i1) {
object->named_expressions_[i1] = named_expressions_[i1] ? named_expressions_[i1]->Clone(storage) : nullptr;
for (auto i2 = 0; i2 < named_expressions_.size(); ++i2) {
object->named_expressions_[i2] = named_expressions_[i2] ? named_expressions_[i2]->Clone(storage) : nullptr;
}
return object;
}
@ -1136,8 +1137,8 @@ class Delete : public memgraph::query::plan::LogicalOperator {
auto object = std::make_unique<Delete>();
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->expressions_.resize(expressions_.size());
for (auto i2 = 0; i2 < expressions_.size(); ++i2) {
object->expressions_[i2] = expressions_[i2] ? expressions_[i2]->Clone(storage) : nullptr;
for (auto i3 = 0; i3 < expressions_.size(); ++i3) {
object->expressions_[i3] = expressions_[i3] ? expressions_[i3]->Clone(storage) : nullptr;
}
object->detach_ = detach_;
return object;
@ -1608,12 +1609,12 @@ class Aggregate : public memgraph::query::plan::LogicalOperator {
auto object = std::make_unique<Aggregate>();
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->aggregations_.resize(aggregations_.size());
for (auto i3 = 0; i3 < aggregations_.size(); ++i3) {
object->aggregations_[i3] = aggregations_[i3].Clone(storage);
for (auto i4 = 0; i4 < aggregations_.size(); ++i4) {
object->aggregations_[i4] = aggregations_[i4].Clone(storage);
}
object->group_by_.resize(group_by_.size());
for (auto i4 = 0; i4 < group_by_.size(); ++i4) {
object->group_by_[i4] = group_by_[i4] ? group_by_[i4]->Clone(storage) : nullptr;
for (auto i5 = 0; i5 < group_by_.size(); ++i5) {
object->group_by_[i5] = group_by_[i5] ? group_by_[i5]->Clone(storage) : nullptr;
}
object->remember_ = remember_;
return object;
@ -1814,8 +1815,8 @@ class OrderBy : public memgraph::query::plan::LogicalOperator {
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->compare_ = compare_;
object->order_by_.resize(order_by_.size());
for (auto i5 = 0; i5 < order_by_.size(); ++i5) {
object->order_by_[i5] = order_by_[i5] ? order_by_[i5]->Clone(storage) : nullptr;
for (auto i6 = 0; i6 < order_by_.size(); ++i6) {
object->order_by_[i6] = order_by_[i6] ? order_by_[i6]->Clone(storage) : nullptr;
}
object->output_symbols_ = output_symbols_;
return object;
@ -2204,8 +2205,8 @@ class CallProcedure : public memgraph::query::plan::LogicalOperator {
object->input_ = input_ ? input_->Clone(storage) : nullptr;
object->procedure_name_ = procedure_name_;
object->arguments_.resize(arguments_.size());
for (auto i6 = 0; i6 < arguments_.size(); ++i6) {
object->arguments_[i6] = arguments_[i6] ? arguments_[i6]->Clone(storage) : nullptr;
for (auto i7 = 0; i7 < arguments_.size(); ++i7) {
object->arguments_[i7] = arguments_[i7] ? arguments_[i7]->Clone(storage) : nullptr;
}
object->result_fields_ = result_fields_;
object->result_symbols_ = result_symbols_;
@ -2291,6 +2292,53 @@ class Foreach : public memgraph::query::plan::LogicalOperator {
}
};
/// Applies symbols from both output branches.
///
/// The operator owns two branches: `input_` and `subquery_`. The
/// `subquery_has_return_` flag is copied into the cursor and decides what the
/// cursor does once the subquery branch stops producing rows.
class Apply : public memgraph::query::plan::LogicalOperator {
 public:
  static const utils::TypeInfo kType;
  const utils::TypeInfo &GetTypeInfo() const override { return kType; }

  Apply() {}

  Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery,
        bool subquery_has_return);
  bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
  UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
  std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;

  bool HasSingleInput() const override { return true; }
  std::shared_ptr<LogicalOperator> input() const override { return input_; }
  void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }

  std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
  std::shared_ptr<memgraph::query::plan::LogicalOperator> subquery_;
  /// Has an explicit default so that an operator created through the default
  /// constructor never holds an indeterminate value; the default matches the
  /// one used by ApplyCursor.
  bool subquery_has_return_{true};

  /// Deep-copies the operator: both branches are cloned when they are set and
  /// the flag is copied.
  std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
    auto object = std::make_unique<Apply>();
    object->input_ = input_ ? input_->Clone(storage) : nullptr;
    object->subquery_ = subquery_ ? subquery_->Clone(storage) : nullptr;
    object->subquery_has_return_ = subquery_has_return_;
    return object;
  }

 private:
  /// Cursor that drives the input and the subquery branch of an Apply operator.
  class ApplyCursor : public Cursor {
   public:
    ApplyCursor(const Apply &, utils::MemoryResource *);
    bool Pull(Frame &, ExecutionContext &) override;
    void Shutdown() override;
    void Reset() override;

   private:
    const Apply &self_;
    UniqueCursorPtr input_;
    UniqueCursorPtr subquery_;
    /// True when the next Pull has to fetch a new row from `input_` first.
    bool pull_input_{true};
    bool subquery_has_return_{true};
  };
};
} // namespace plan
} // namespace query
} // namespace memgraph

View File

@ -134,6 +134,7 @@ class LoadCsv;
class Foreach;
class EmptyResult;
class EvaluatePatternFilter;
class Apply;
using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<
Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel,
@ -142,7 +143,8 @@ using LogicalOperatorCompositeVisitor = utils::CompositeVisitor<
Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete,
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge,
Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach, EmptyResult, EvaluatePatternFilter>;
Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, Foreach, EmptyResult,
EvaluatePatternFilter, Apply>;
using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>;
@ -2381,6 +2383,52 @@ clauses.
(:serialize (:slk))
(:clone))
;; Logical operator with two branches: the input and the subquery.
(lcp:define-class apply (logical-operator)
  ((input "std::shared_ptr<LogicalOperator>" :scope :public
          :slk-save #'slk-save-operator-pointer
          :slk-load #'slk-load-operator-pointer)
   (subquery "std::shared_ptr<LogicalOperator>" :scope :public
             :slk-save #'slk-save-operator-pointer
             :slk-load #'slk-load-operator-pointer)
   ;; Explicit initial value so that an operator created through the default
   ;; constructor never holds an indeterminate flag; it matches the default
   ;; used by ApplyCursor below.
   (subquery-has-return "bool" :initval "true" :scope :public))
  (:documentation "Applies symbols from both output branches.")
  (:public
    #>cpp
    Apply() {}

    Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery, bool subquery_has_return);
    bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
    UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
    std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;

    bool HasSingleInput() const override { return true; }
    std::shared_ptr<LogicalOperator> input() const override { return input_; }
    void set_input(std::shared_ptr<LogicalOperator> input) override {
      input_ = input;
    }
    cpp<#)
  (:private
    #>cpp
    /// Cursor that drives the input and the subquery branch of an Apply operator.
    class ApplyCursor : public Cursor {
     public:
      ApplyCursor(const Apply &, utils::MemoryResource *);
      bool Pull(Frame &, ExecutionContext &) override;
      void Shutdown() override;
      void Reset() override;

     private:
      const Apply &self_;
      UniqueCursorPtr input_;
      UniqueCursorPtr subquery_;
      /// True when the next Pull has to fetch a new row from `input_` first.
      bool pull_input_{true};
      bool subquery_has_return_{true};
    };
    cpp<#)
  (:serialize (:slk))
  (:clone))
(lcp:pop-namespace) ;; plan
(lcp:pop-namespace) ;; query
(lcp:pop-namespace) ;; memgraph

View File

@ -145,4 +145,7 @@ constexpr utils::TypeInfo query::plan::LoadCsv::kType{utils::TypeId::LOAD_CSV, "
constexpr utils::TypeInfo query::plan::Foreach::kType{utils::TypeId::FOREACH, "Foreach",
&query::plan::LogicalOperator::kType};
// Type descriptor of the Apply logical operator; its parent type descriptor is the one of LogicalOperator.
constexpr utils::TypeInfo query::plan::Apply::kType{utils::TypeId::APPLY, "Apply",
&query::plan::LogicalOperator::kType};
} // namespace memgraph

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -50,23 +50,6 @@ class PostProcessor final {
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan, TVertexCounts *vertex_counts) {
return query::plan::EstimatePlanCost(vertex_counts, parameters_, *plan);
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op,
std::unique_ptr<LogicalOperator> last_op, const Tree &combinator,
TPlanningContext *context) {
if (const auto *union_ = utils::Downcast<const CypherUnion>(&combinator)) {
return std::unique_ptr<LogicalOperator>(
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table));
}
throw utils::NotYetImplemented("query combinator");
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
auto output_symbols = last_op->OutputSymbols(*context->symbol_table);
return std::make_unique<Distinct>(std::move(last_op), output_symbols);
}
};
/// @brief Generates the LogicalOperator tree for a single query and returns the
@ -82,10 +65,9 @@ class PostProcessor final {
/// @sa RuleBasedPlanner
/// @sa VariableStartPlanner
template <template <class> class TPlanner, class TDbAccessor>
auto MakeLogicalPlanForSingleQuery(std::vector<SingleQueryPart> single_query_parts,
PlanningContext<TDbAccessor> *context) {
auto MakeLogicalPlanForSingleQuery(QueryParts query_parts, PlanningContext<TDbAccessor> *context) {
context->bound_symbols.clear();
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(single_query_parts);
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(query_parts);
}
/// Generates the LogicalOperator tree and returns the resulting plan.
@ -103,48 +85,34 @@ template <class TPlanningContext, class TPlanPostProcess>
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process, bool use_variable_planner) {
auto query_parts = CollectQueryParts(*context->symbol_table, *context->ast_storage, context->query);
auto &vertex_counts = *context->db;
double total_cost = 0;
double total_cost = std::numeric_limits<double>::max();
using ProcessedPlan = typename TPlanPostProcess::ProcessedPlan;
ProcessedPlan last_plan;
ProcessedPlan plan_with_least_cost;
for (const auto &query_part : query_parts.query_parts) {
std::optional<ProcessedPlan> curr_plan;
double min_cost = std::numeric_limits<double>::max();
if (use_variable_planner) {
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_part.single_query_parts, context);
for (auto plan : plans) {
// Plans are generated lazily and the current plan will disappear, so
// it's ok to move it.
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
if (!curr_plan || cost < min_cost) {
curr_plan.emplace(std::move(rewritten_plan));
min_cost = cost;
}
}
} else {
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_part.single_query_parts, context);
std::optional<ProcessedPlan> curr_plan;
if (use_variable_planner) {
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_parts, context);
for (auto plan : plans) {
// Plans are generated lazily and the current plan will disappear, so
// it's ok to move it.
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
min_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
curr_plan.emplace(std::move(rewritten_plan));
}
total_cost += min_cost;
if (query_part.query_combinator) {
last_plan = post_process->MergeWithCombinator(std::move(*curr_plan), std::move(last_plan),
*query_part.query_combinator, context);
} else {
last_plan = std::move(*curr_plan);
double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
if (!curr_plan || cost < total_cost) {
curr_plan.emplace(std::move(rewritten_plan));
total_cost = cost;
}
}
} else {
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_parts, context);
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
total_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
curr_plan.emplace(std::move(rewritten_plan));
}
if (query_parts.distinct) {
last_plan = post_process->MakeDistinct(std::move(last_plan), context);
}
plan_with_least_cost = std::move(*curr_plan);
return std::make_pair(std::move(last_plan), total_cost);
return std::make_pair(std::move(plan_with_least_cost), total_cost);
}
template <class TPlanningContext>

View File

@ -584,6 +584,9 @@ std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table,
if (auto *merge = utils::Downcast<query::Merge>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back());
} else if (auto *call_subquery = utils::Downcast<query::CallSubquery>(clause)) {
query_part->subqueries.emplace_back(
std::make_shared<QueryParts>(CollectQueryParts(symbol_table, storage, call_subquery->cypher_query_)));
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
ParseForeach(*foreach, *query_part, storage, symbol_table);
} else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::Unwind::kType) ||

View File

@ -398,6 +398,9 @@ struct Matching {
std::unordered_set<Symbol> expansion_symbols{};
};
// TODO: Having to forward-declare this here is clumsy; usually only the struct definition would be in the header.
struct QueryParts;
struct FilterMatching : Matching {
/// Type of pattern filter
PatternFilterType type;
@ -449,6 +452,9 @@ struct SingleQueryPart {
std::vector<Matching> merge_matching{};
/// @brief All the remaining clauses (without @c Match).
std::vector<Clause *> remaining_clauses{};
/// All the subqueries in this query part, listed in the order in which they
/// are called.
std::vector<std::shared_ptr<QueryParts>> subqueries{};
};
/// Holds query parts of a single query together with the optional information

View File

@ -261,6 +261,13 @@ bool PlanPrinter::PreVisit(query::plan::Filter &op) {
op.input_->Accept(*this);
return false;
}
/// Prints an `Apply` operator: first its own line, then the subquery plan via
/// `Branch`, and finally the operator's input. Both children are visited here.
bool PlanPrinter::PreVisit(query::plan::Apply &op) {
  const auto print_operator_name = [](auto &stream) { stream << "* Apply"; };
  WithPrintLn(print_operator_name);

  // The subquery is rendered as a side branch before continuing down the input.
  Branch(*op.subquery_);
  op.input_->Accept(*this);
  return false;
}
#undef PRE_VISIT
bool PlanPrinter::DefaultPreVisit() {
@ -954,6 +961,20 @@ bool PlanToJsonVisitor::PreVisit(EvaluatePatternFilter &op) {
return false;
}
/// Serializes an `Apply` operator into a JSON object with the keys "name",
/// "input" and "subquery", and stores that object in `output_`.
bool PlanToJsonVisitor::PreVisit(Apply &op) {
  json node;
  node["name"] = "Apply";

  // Each child is visited in turn and its result is collected with PopOutput().
  op.input_->Accept(*this);
  node["input"] = PopOutput();

  op.subquery_->Accept(*this);
  node["subquery"] = PopOutput();

  output_ = std::move(node);
  return false;
}
} // namespace impl
} // namespace memgraph::query::plan

View File

@ -95,6 +95,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
bool PreVisit(CallProcedure &) override;
bool PreVisit(LoadCsv &) override;
bool PreVisit(Foreach &) override;
bool PreVisit(Apply & /*unused*/) override;
bool Visit(Once &) override;
@ -190,6 +191,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
bool PreVisit(EvaluatePatternFilter & /*op*/) override;
bool PreVisit(EdgeUniquenessFilter &) override;
bool PreVisit(Cartesian &) override;
bool PreVisit(Apply & /*unused*/) override;
bool PreVisit(ScanAll &) override;
bool PreVisit(ScanAllByLabel &) override;

View File

@ -465,6 +465,18 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
return true;
}
// Visits an Apply operator: records it in prev_ops_, walks its input with this
// visitor and hands the subquery branch to RewriteBranch.
bool PreVisit(Apply &op) override {
prev_ops_.push_back(&op);
op.input()->Accept(*this);
RewriteBranch(&op.subquery_);
// Both children were handled explicitly above.
// NOTE(review): presumably returning false stops the default descent into the children - confirm.
return false;
}
// Removes the last entry of prev_ops_ once the Apply operator has been visited.
// NOTE(review): presumably this undoes the push_back done in PreVisit(Apply &) - confirm.
bool PostVisit(Apply & /*op*/) override {
prev_ops_.pop_back();
return true;
}
std::shared_ptr<LogicalOperator> new_root_;
private:

View File

@ -515,6 +515,22 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filt
return filter_expr;
}
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
SymbolTable &symbol_table, AstStorage &storage) {
const auto &query = single_query_parts[0];
if (!query.matching.expansions.empty() || query.remaining_clauses.empty()) {
return {};
}
if (std::unordered_set<Symbol> bound_symbols; auto *with = utils::Downcast<query::With>(query.remaining_clauses[0])) {
auto input_op = impl::GenWith(*with, nullptr, symbol_table, false, bound_symbols, storage);
return bound_symbols;
}
return {};
}
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {

View File

@ -49,6 +49,7 @@ struct PlanningContext {
/// write) the first `n`, but the latter `n` would only read the already
/// written information.
std::unordered_set<Symbol> bound_symbols{};
bool is_write_query{false};
};
template <class TDbAccessor>
@ -84,6 +85,11 @@ Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &, AstSto
/// Checks if the filters has all the bound symbols to be included in the current part of the query
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter);
// Returns the set of symbols for the subquery that are actually referenced from the outer scope and
// used in the subquery.
std::unordered_set<Symbol> GetSubqueryBoundSymbols(const std::vector<SingleQueryPart> &single_query_parts,
SymbolTable &symbol_table, AstStorage &storage);
/// Utility function for iterating pattern atoms and accumulating a result.
///
/// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore,
@ -164,86 +170,96 @@ class RuleBasedPlanner {
/// tree.
using PlanResult = std::unique_ptr<LogicalOperator>;
/// @brief Generates the operator tree based on explicitly set rules.
PlanResult Plan(const std::vector<SingleQueryPart> &query_parts) {
PlanResult Plan(const QueryParts &query_parts) {
auto &context = *context_;
std::unique_ptr<LogicalOperator> input_op;
// Set to true if a query command writes to the database.
bool is_write = false;
for (const auto &query_part : query_parts) {
MatchContext match_ctx{query_part.matching, *context.symbol_table, context.bound_symbols};
input_op = PlanMatching(match_ctx, std::move(input_op));
for (const auto &matching : query_part.optional_matching) {
MatchContext opt_ctx{matching, *context.symbol_table, context.bound_symbols};
std::unique_ptr<LogicalOperator> final_plan;
std::vector<Symbol> bound_symbols(context_->bound_symbols.begin(), context_->bound_symbols.end());
auto once_with_symbols = std::make_unique<Once>(bound_symbols);
for (const auto &query_part : query_parts.query_parts) {
std::unique_ptr<LogicalOperator> input_op;
auto match_op = PlanMatching(opt_ctx, std::move(once_with_symbols));
if (match_op) {
input_op = std::make_unique<Optional>(std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
}
}
uint64_t merge_id = 0;
for (const auto &clause : query_part.remaining_clauses) {
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
input_op = GenMerge(*merge, std::move(input_op), query_part.merge_matching[merge_id++]);
// Treat MERGE clause as write, because we do not know if it will
// create anything.
is_write = true;
} else if (auto *with = utils::Downcast<query::With>(clause)) {
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
is_write = true;
input_op = std::move(op);
} else if (auto *unwind = utils::Downcast<query::Unwind>(clause)) {
const auto &symbol = context.symbol_table->at(*unwind->named_expression_);
context.bound_symbols.insert(symbol);
input_op =
std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol);
context.is_write_query = false;
for (const auto &single_query_part : query_part.single_query_parts) {
input_op = HandleMatching(std::move(input_op), single_query_part, *context.symbol_table, context.bound_symbols);
} else if (auto *call_proc = utils::Downcast<query::CallProcedure>(clause)) {
std::vector<Symbol> result_symbols;
result_symbols.reserve(call_proc->result_identifiers_.size());
for (const auto *ident : call_proc->result_identifiers_) {
const auto &sym = context.symbol_table->at(*ident);
context.bound_symbols.insert(sym);
result_symbols.push_back(sym);
uint64_t merge_id = 0;
uint64_t subquery_id = 0;
for (const auto &clause : single_query_part.remaining_clauses) {
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, context.is_write_query,
context.bound_symbols, *context.ast_storage);
} else if (auto *merge = utils::Downcast<query::Merge>(clause)) {
input_op = GenMerge(*merge, std::move(input_op), single_query_part.merge_matching[merge_id++]);
// Treat MERGE clause as write, because we do not know if it will
// create anything.
context.is_write_query = true;
} else if (auto *with = utils::Downcast<query::With>(clause)) {
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, context.is_write_query,
context.bound_symbols, *context.ast_storage);
// WITH clause advances the command, so reset the flag.
context.is_write_query = false;
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
context.is_write_query = true;
input_op = std::move(op);
} else if (auto *unwind = utils::Downcast<query::Unwind>(clause)) {
const auto &symbol = context.symbol_table->at(*unwind->named_expression_);
context.bound_symbols.insert(symbol);
input_op =
std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol);
} else if (auto *call_proc = utils::Downcast<query::CallProcedure>(clause)) {
std::vector<Symbol> result_symbols;
result_symbols.reserve(call_proc->result_identifiers_.size());
for (const auto *ident : call_proc->result_identifiers_) {
const auto &sym = context.symbol_table->at(*ident);
context.bound_symbols.insert(sym);
result_symbols.push_back(sym);
}
// TODO: When we add support for write and eager procedures, we will
// need to plan this operator with Accumulate and pass in
// storage::View::NEW.
input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_, call_proc->is_write_);
} else if (auto *load_csv = utils::Downcast<query::LoadCsv>(clause)) {
const auto &row_sym = context.symbol_table->at(*load_csv->row_var_);
context.bound_symbols.insert(row_sym);
input_op =
std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym);
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
context.is_write_query = true;
input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols,
single_query_part, merge_id);
} else if (auto *call_sub = utils::Downcast<query::CallSubquery>(clause)) {
input_op = HandleSubquery(std::move(input_op), single_query_part.subqueries[subquery_id++],
*context.symbol_table, *context_->ast_storage);
} else {
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
}
// TODO: When we add support for write and eager procedures, we will
// need to plan this operator with Accumulate and pass in
// storage::View::NEW.
input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_, call_proc->is_write_);
} else if (auto *load_csv = utils::Downcast<query::LoadCsv>(clause)) {
const auto &row_sym = context.symbol_table->at(*load_csv->row_var_);
context.bound_symbols.insert(row_sym);
input_op =
std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym);
} else if (auto *foreach = utils::Downcast<query::Foreach>(clause)) {
is_write = true;
input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols,
query_part, merge_id);
} else {
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
}
}
// Is this the only situation that should be covered
if (input_op->OutputSymbols(*context.symbol_table).empty()) {
input_op = std::make_unique<EmptyResult>(std::move(input_op));
}
if (query_part.query_combinator) {
final_plan = MergeWithCombinator(std::move(input_op), std::move(final_plan), *query_part.query_combinator);
} else {
final_plan = std::move(input_op);
}
}
// Is this the only situation that should be covered
if (input_op->OutputSymbols(*context.symbol_table).empty()) {
input_op = std::make_unique<EmptyResult>(std::move(input_op));
if (query_parts.distinct) {
final_plan = MakeDistinct(std::move(final_plan));
}
return input_op;
return final_plan;
}
private:
@ -255,6 +271,26 @@ class RuleBasedPlanner {
storage::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) { return context_->db->NameToEdgeType(edge_type.name); }
std::unique_ptr<LogicalOperator> HandleMatching(std::unique_ptr<LogicalOperator> last_op,
const SingleQueryPart &single_query_part, SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
MatchContext match_ctx{single_query_part.matching, symbol_table, bound_symbols};
last_op = PlanMatching(match_ctx, std::move(last_op));
for (const auto &matching : single_query_part.optional_matching) {
MatchContext opt_ctx{matching, symbol_table, bound_symbols};
std::vector<Symbol> bound_symbols(context_->bound_symbols.begin(), context_->bound_symbols.end());
auto once_with_symbols = std::make_unique<Once>(bound_symbols);
auto match_op = PlanMatching(opt_ctx, std::move(once_with_symbols));
if (match_op) {
last_op = std::make_unique<Optional>(std::move(last_op), std::move(match_op), opt_ctx.new_symbols);
}
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenCreate(Create &create, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
@ -597,6 +633,36 @@ class RuleBasedPlanner {
symbol);
}
std::unique_ptr<LogicalOperator> HandleSubquery(std::unique_ptr<LogicalOperator> last_op,
std::shared_ptr<QueryParts> subquery, SymbolTable &symbol_table,
AstStorage &storage) {
std::unordered_set<Symbol> outer_scope_bound_symbols;
outer_scope_bound_symbols.insert(std::make_move_iterator(context_->bound_symbols.begin()),
std::make_move_iterator(context_->bound_symbols.end()));
context_->bound_symbols =
impl::GetSubqueryBoundSymbols(subquery->query_parts[0].single_query_parts, symbol_table, storage);
auto subquery_op = Plan(*subquery);
context_->bound_symbols.clear();
context_->bound_symbols.insert(std::make_move_iterator(outer_scope_bound_symbols.begin()),
std::make_move_iterator(outer_scope_bound_symbols.end()));
auto subquery_has_return = true;
if (subquery_op->GetTypeInfo() == EmptyResult::kType) {
subquery_has_return = false;
}
last_op = std::make_unique<Apply>(std::move(last_op), std::move(subquery_op), subquery_has_return);
if (context_->is_write_query) {
last_op = std::make_unique<Accumulate>(std::move(last_op), last_op->ModifiedSymbols(symbol_table), true);
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op,
const std::unordered_set<Symbol> &bound_symbols, Filters &filters,
AstStorage &storage, const SymbolTable &symbol_table) {
@ -654,6 +720,20 @@ class RuleBasedPlanner {
return operators;
}
/// Combines the plans of two query parts according to `combinator`.
///
/// Only `CypherUnion` combinators are handled; the plans are then joined with
/// `impl::GenUnion`, with `last_op` on the left and `curr_op` on the right.
///
/// @throws utils::NotYetImplemented for any other kind of combinator.
std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op,
                                                     std::unique_ptr<LogicalOperator> last_op,
                                                     const Tree &combinator) {
  const auto *cypher_union = utils::Downcast<const CypherUnion>(&combinator);
  if (!cypher_union) {
    throw utils::NotYetImplemented("This type of merging queries is not yet implemented!");
  }
  return impl::GenUnion(*cypher_union, std::move(last_op), std::move(curr_op), *context_->symbol_table);
}
/// Wraps `last_op` in a `Distinct` operator over the symbols that `last_op`
/// reports as its output.
std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op) {
  // Query the output symbols first, while `last_op` is still owned here.
  auto output_symbols = last_op->OutputSymbols(*context_->symbol_table);
  return std::make_unique<Distinct>(std::move(last_op), std::move(output_symbols));
}
};
} // namespace memgraph::query::plan

View File

@ -313,25 +313,70 @@ class VariableStartPlanner {
// Generates different, equivalent query parts by taking different graph
// matching routes for each query part.
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts, const SymbolTable &symbol_table) {
std::vector<impl::VaryQueryPartMatching> alternative_query_parts;
alternative_query_parts.reserve(query_parts.size());
for (const auto &query_part : query_parts) {
alternative_query_parts.emplace_back(impl::VaryQueryPartMatching(query_part, symbol_table));
auto VaryQueryMatching(const QueryParts &query_parts, const SymbolTable &symbol_table) {
std::vector<impl::VaryQueryPartMatching> varying_query_matchings;
auto single_query_parts = ExtractSingleQueryParts(std::make_unique<QueryParts>(query_parts));
for (const auto &single_query_part : single_query_parts) {
varying_query_matchings.emplace_back(single_query_part, symbol_table);
}
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)), 0UL, FLAGS_query_max_plans);
return iter::slice(MakeCartesianProduct(std::move(varying_query_matchings)), 0UL, FLAGS_query_max_plans);
}
/// Flattens `query_parts` into a list of single query parts.
///
/// Parts are emitted in pre-order: every single query part is followed by the
/// recursively flattened parts of each of its subqueries, in subquery order.
///
/// @param query_parts query parts to flatten; must not be null.
/// @return copies of all single query parts, nested subqueries included.
std::vector<SingleQueryPart> ExtractSingleQueryParts(const std::shared_ptr<QueryParts> query_parts) {
  std::vector<SingleQueryPart> results;
  for (const auto &query_part : query_parts->query_parts) {
    for (const auto &single_query_part : query_part.single_query_parts) {
      results.push_back(single_query_part);
      for (const auto &subquery : single_query_part.subqueries) {
        // Deliberately non-const: a const vector would make the move iterators
        // below copy every element instead of moving it.
        auto subquery_results = ExtractSingleQueryParts(subquery);
        results.insert(results.end(), std::make_move_iterator(subquery_results.begin()),
                       std::make_move_iterator(subquery_results.end()));
      }
    }
  }
  return results;
}
/// Rebuilds a `QueryParts` tree with the shape of `old_query_parts`, taking each
/// single query part from the flat list `single_query_parts_variation`.
///
/// The flat list is consumed in pre-order: a single query part first, then
/// (recursively) the parts of each of its subqueries.
///
/// @param old_query_parts tree whose structure is reproduced.
/// @param single_query_parts_variation flat list of replacement parts; it must
///        hold at least as many elements as the tree has single query parts.
/// @param index position of the next unread element of the flat list; advanced
///        by one for every part consumed, across recursive calls.
/// @return the reconstructed query parts.
QueryParts ReconstructQueryParts(const QueryParts &old_query_parts,
                                 const std::vector<SingleQueryPart> &single_query_parts_variation, uint64_t &index) {
  auto reconstructed_query_parts = old_query_parts;

  // Unsigned counters avoid comparing a signed int against size().
  for (size_t i = 0; i < old_query_parts.query_parts.size(); ++i) {
    const auto &old_query_part = old_query_parts.query_parts[i];
    for (size_t j = 0; j < old_query_part.single_query_parts.size(); ++j) {
      const auto &old_single_query_part = old_query_part.single_query_parts[j];

      auto &new_single_query_part = reconstructed_query_parts.query_parts[i].single_query_parts[j];
      new_single_query_part = single_query_parts_variation[index++];

      // Subqueries are rebuilt from the old tree so that each one receives its
      // own reconstructed copy.
      for (size_t k = 0; k < old_single_query_part.subqueries.size(); ++k) {
        const auto &old_subquery = old_single_query_part.subqueries[k];
        new_single_query_part.subqueries[k] =
            std::make_shared<QueryParts>(ReconstructQueryParts(*old_subquery, single_query_parts_variation, index));
      }
    }
  }

  return reconstructed_query_parts;
}
public:
explicit VariableStartPlanner(TPlanningContext *context) : context_(context) {}
/// @brief Generate multiple plans by varying the order of graph traversal.
auto Plan(const std::vector<SingleQueryPart> &query_parts) {
auto Plan(const QueryParts &query_parts) {
return iter::imap(
[context = context_](const auto &alternative_query_parts) {
[context = context_, old_query_parts = query_parts, this](const auto &alternative_query_parts) {
uint64_t index = 0;
auto reconstructed_query_parts = ReconstructQueryParts(old_query_parts, alternative_query_parts, index);
RuleBasedPlanner<TPlanningContext> rule_planner(context);
context->bound_symbols.clear();
return rule_planner.Plan(alternative_query_parts);
return rule_planner.Plan(reconstructed_query_parts);
},
VaryQueryMatching(query_parts, *context_->symbol_table));
}
@ -339,7 +384,7 @@ class VariableStartPlanner {
/// @brief The result of plan generation is an iterable of roots to multiple
/// generated operator trees.
using PlanResult = typename std::result_of<decltype (&VariableStartPlanner<TPlanningContext>::Plan)(
VariableStartPlanner<TPlanningContext>, std::vector<SingleQueryPart> &)>::type;
VariableStartPlanner<TPlanningContext>, QueryParts &)>::type;
};
} // namespace memgraph::query::plan

View File

@ -52,6 +52,7 @@
M(CallProcedureOperator, "Number of times CallProcedure operator was used.") \
M(ForeachOperator, "Number of times Foreach operator was used.") \
M(EvaluatePatternFilterOperator, "Number of times EvaluatePatternFilter operator was used.") \
M(ApplyOperator, "Number of times ApplyOperator operator was used.") \
\
M(FailedQuery, "Number of times executing a query failed.") \
M(LabelIndexCreated, "Number of times a label index was created.") \

View File

@ -64,6 +64,7 @@ enum class TypeId : uint64_t {
CALL_PROCEDURE,
LOAD_CSV,
FOREACH,
APPLY,
// Replication
REP_APPEND_DELTAS_REQ,
@ -179,6 +180,8 @@ enum class TypeId : uint64_t {
AST_ANALYZE_GRAPH_QUERY,
AST_TRANSACTION_QUEUE_QUERY,
AST_EXISTS,
AST_CALL_SUBQUERY,
// Symbol
SYMBOL,
};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -58,9 +58,8 @@ static void BM_PlanChainedMatches(benchmark::State &state) {
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
auto plans = memgraph::query::plan::MakeLogicalPlanForSingleQuery<memgraph::query::plan::VariableStartPlanner>(
single_query_parts, &ctx);
query_parts, &ctx);
for (const auto &plan : plans) {
// Exhaust through all generated plans, since they are lazily generated.
benchmark::DoNotOptimize(plan.get());
@ -129,9 +128,8 @@ static void BM_PlanAndEstimateIndexedMatching(benchmark::State &state) {
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
auto plans = memgraph::query::plan::MakeLogicalPlanForSingleQuery<memgraph::query::plan::VariableStartPlanner>(
single_query_parts, &ctx);
query_parts, &ctx);
for (auto plan : plans) {
memgraph::query::plan::EstimatePlanCost(&dba, parameters, *plan);
}
@ -160,9 +158,8 @@ static void BM_PlanAndEstimateIndexedMatchingWithCachedCounts(benchmark::State &
if (query_parts.query_parts.size() == 0) {
std::exit(EXIT_FAILURE);
}
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
auto plans = memgraph::query::plan::MakeLogicalPlanForSingleQuery<memgraph::query::plan::VariableStartPlanner>(
single_query_parts, &ctx);
query_parts, &ctx);
for (auto plan : plans) {
memgraph::query::plan::EstimatePlanCost(&vertex_counts, parameters, *plan);
}

View File

@ -0,0 +1,408 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
Feature: Subqueries
Behaviour tests for memgraph CALL clause which contains a subquery
Scenario: Subquery without bounded symbols
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
MATCH (n:Label1)-[:TYPE]->(m:Label2)
RETURN m
}
RETURN m.prop;
"""
Then the result should be:
| m.prop |
| 2 |
Scenario: Subquery without bounded symbols and without match
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
CALL {
MATCH (n:Label1)-[:TYPE]->(m:Label2)
RETURN m
}
RETURN m.prop;
"""
Then the result should be:
| m.prop |
| 2 |
Scenario: Subquery returning primitive
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
CALL {
MATCH (n:Label1)-[:TYPE]->(m:Label2)
RETURN m.prop AS prop
}
RETURN prop;
"""
Then the result should be:
| prop |
| 2 |
Scenario: Subquery returning 2 values
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
MATCH (m:Label1)-[:TYPE]->(o:Label2)
RETURN m, o
}
RETURN m.prop, o.prop;
"""
Then the result should be:
| m.prop | o.prop |
| 1 | 2 |
Scenario: Subquery returning nothing because match did not find any results
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label3)
CALL {
MATCH (m:Label1)-[:TYPE]->(:Label2)
RETURN m
}
RETURN m.prop;
"""
Then the result should be empty
Scenario: Subquery returning a multiple of results since we join elements from basic query and the subquery
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
MATCH (m)
RETURN m
}
RETURN n.prop, m.prop
ORDER BY n.prop, m.prop;
"""
Then the result should be:
| n.prop | m.prop |
| 1 | 1 |
| 1 | 2 |
Scenario: Subquery returning a cartesian product
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n)
CALL {
MATCH (m)
RETURN m
}
RETURN n.prop, m.prop
ORDER BY n.prop, m.prop;
"""
Then the result should be:
| n.prop | m.prop |
| 1 | 1 |
| 1 | 2 |
| 2 | 1 |
| 2 | 2 |
Scenario: Subquery with bounded symbols
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
WITH n
MATCH (n)-[:TYPE]->(m:Label2)
RETURN m
}
RETURN m.prop;
"""
Then the result should be:
| m.prop |
| 2 |
Scenario: Subquery with invalid bounded symbols
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
WITH o
MATCH (o)-[:TYPE]->(m:Label2)
RETURN m
}
RETURN m.prop;
"""
Then an error should be raised
Scenario: Subquery returning primitive but not aliased
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
WITH n
MATCH (n)-[:TYPE]->(m:Label2)
RETURN m.prop
}
RETURN n;
"""
Then an error should be raised
Scenario: Subquery returning one primitive and others aliased
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1)
CALL {
WITH n
MATCH (o)-[:TYPE]->(m:Label2)
RETURN m.prop, o
}
RETURN n;
"""
Then an error should be raised
Scenario: Subquery returning already declared variable in outer scope
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n:Label1), (m:Label2)
CALL {
WITH n
MATCH (n:Label1)-[:TYPE]->(m:Label2)
RETURN m
}
RETURN n;
"""
Then an error should be raised
Scenario: Subquery after subquery
Given an empty graph
And having executed
"""
CREATE (:Label1 {prop: 1})-[:TYPE]->(:Label2 {prop: 2})
"""
When executing query:
"""
MATCH (n)
CALL {
MATCH (m)
RETURN m
}
CALL {
MATCH (o)
RETURN o
}
RETURN n.prop, m.prop, o.prop
ORDER BY n.prop, m.prop, o.prop;
"""
Then the result should be:
| n.prop | m.prop | o.prop |
| 1 | 1 | 1 |
| 1 | 1 | 2 |
| 1 | 2 | 1 |
| 1 | 2 | 2 |
| 2 | 1 | 1 |
| 2 | 1 | 2 |
| 2 | 2 | 1 |
| 2 | 2 | 2 |
Scenario: Subquery with union
Given an empty graph
And having executed
"""
CREATE (:Person {figure: "grandpa"})<-[:CHILD_OF]-(:Person {figure: "dad"})-[:PARENT_OF]->(:Person {figure: "child"})
"""
When executing query:
"""
MATCH (p:Person {figure: "dad"})
CALL {
WITH p
OPTIONAL MATCH (p)-[:CHILD_OF]->(other:Person)
RETURN other
UNION
WITH p
OPTIONAL MATCH (p)-[:PARENT_OF]->(other:Person)
RETURN other
} RETURN DISTINCT p.figure, count(other) as cnt;
"""
Then the result should be:
| p.figure | cnt |
| 'dad' | 2 |
Scenario: Subquery cloning nodes
Given an empty graph
And having executed
"""
CREATE (:Person {name: "Alen"}), (:Person {name: "Bruce"})
"""
When executing query:
"""
MATCH (p:Person)
CALL {
WITH p
UNWIND range (1, 3) AS i
CREATE (n:Person {name: p.name})
RETURN n
}
RETURN n;
"""
Then the result should be:
| n |
| (:Person {name: 'Alen'}) |
| (:Person {name: 'Alen'}) |
| (:Person {name: 'Alen'}) |
| (:Person {name: 'Bruce'}) |
| (:Person {name: 'Bruce'}) |
| (:Person {name: 'Bruce'}) |
Scenario: Subquery in subquery
Given an empty graph
And having executed
"""
CREATE (:Label {id: 1}), (:Label {id: 2})
"""
When executing query:
"""
MATCH (p:Label)
CALL {
MATCH (r:Label)
CALL {
MATCH (s:Label)
RETURN s
}
RETURN r, s
}
RETURN p.id, r.id, s.id;
"""
Then the result should be:
| p.id | r.id | s.id |
| 1 | 1 | 1 |
| 1 | 1 | 2 |
| 1 | 2 | 1 |
| 1 | 2 | 2 |
| 2 | 1 | 1 |
| 2 | 1 | 2 |
| 2 | 2 | 1 |
| 2 | 2 | 2 |
Scenario: Counter inside subquery
Given an empty graph
And having executed
"""
CREATE (:Counter {count: 0})
"""
When executing query:
"""
UNWIND [0, 1, 2] AS x
CALL {
MATCH (n:Counter)
SET n.count = n.count + 1
RETURN n.count AS innerCount
}
WITH innerCount
MATCH (n:Counter)
RETURN innerCount, n.count AS totalCount
"""
Then the result should be:
| innerCount | totalCount |
| 1 | 3 |
| 2 | 3 |
| 3 | 3 |
Scenario: Advance command on multiple subqueries
Given an empty graph
When executing query:
"""
CALL {
CREATE (create_node:Movie {title: "Forrest Gump"})
}
CALL {
MATCH (n) RETURN n
}
RETURN n.title AS title;
"""
Then the result should be:
| title |
| 'Forrest Gump' |
Scenario: Advance command on multiple subqueries with manual accumulate
Given an empty graph
When executing query:
"""
CALL {
CREATE (create_node:Movie {title: "Forrest Gump"})
RETURN create_node
}
WITH create_node
CALL {
MATCH (n) RETURN n
}
RETURN n.title AS title;
"""
Then the result should be:
| title |
| 'Forrest Gump' |

View File

@ -453,7 +453,7 @@ auto MakeLogicalPlans(memgraph::query::CypherQuery *query, memgraph::query::AstS
memgraph::query::Parameters parameters;
memgraph::query::plan::PostProcessor post_process(parameters);
auto plans = memgraph::query::plan::MakeLogicalPlanForSingleQuery<memgraph::query::plan::VariableStartPlanner>(
query_parts.query_parts.at(0).single_query_parts, &ctx);
query_parts, &ctx);
for (auto plan : plans) {
memgraph::query::AstStorage ast_copy;
auto unoptimized_plan = plan->Clone(&ast_copy);

View File

@ -4379,3 +4379,77 @@ TEST_P(CypherMainVisitorTest, Exists) {
ASSERT_TRUE(node);
}
}
// A memory limit inside a CALL subquery must be rejected with a syntax error.
TEST_P(CypherMainVisitorTest, CallSubqueryThrow) {
  TestInvalidQueryWithMessage<SyntaxException>("MATCH (n) CALL { MATCH (m) RETURN m QUERY MEMORY UNLIMITED } RETURN n",
                                               *GetParam(), "Memory limit cannot be set on subqueries!");
}
// Checks the AST produced for CALL subqueries: a plain subquery, subqueries
// combined with UNION / UNION ALL, and a subquery nested in another subquery.
// Every dynamic_cast result is asserted before it is dereferenced, so a parser
// regression fails the test instead of crashing the test binary.
TEST_P(CypherMainVisitorTest, CallSubquery) {
  auto &ast_generator = *GetParam();

  {
    const auto *query =
        dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("MATCH (n) CALL { MATCH (m) RETURN m } RETURN n, m"));
    ASSERT_TRUE(query);
    ASSERT_GE(query->single_query_->clauses_.size(), 2U);
    const auto *call_subquery = dynamic_cast<CallSubquery *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(call_subquery);
    const auto *subquery = dynamic_cast<CypherQuery *>(call_subquery->cypher_query_);
    ASSERT_TRUE(subquery);
    const auto *match = dynamic_cast<Match *>(subquery->single_query_->clauses_[0]);
    ASSERT_TRUE(match);
  }

  {
    const auto *query = dynamic_cast<CypherQuery *>(
        ast_generator.ParseQuery("MATCH (n) CALL { MATCH (m) RETURN (m) UNION MATCH (m) RETURN m } RETURN n, m"));
    ASSERT_TRUE(query);
    ASSERT_GE(query->single_query_->clauses_.size(), 2U);
    const auto *call_subquery = dynamic_cast<CallSubquery *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(call_subquery);
    const auto *subquery = dynamic_cast<CypherQuery *>(call_subquery->cypher_query_);
    ASSERT_TRUE(subquery);
    const auto *match = dynamic_cast<Match *>(subquery->single_query_->clauses_[0]);
    ASSERT_TRUE(match);

    const auto &unions = subquery->cypher_unions_;
    ASSERT_EQ(unions.size(), 1U);
  }

  {
    const auto *query = dynamic_cast<CypherQuery *>(
        ast_generator.ParseQuery("MATCH (n) CALL { MATCH (m) RETURN (m) UNION ALL MATCH (m) RETURN m } RETURN n, m"));
    ASSERT_TRUE(query);
    ASSERT_GE(query->single_query_->clauses_.size(), 2U);
    const auto *call_subquery = dynamic_cast<CallSubquery *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(call_subquery);
    const auto *subquery = dynamic_cast<CypherQuery *>(call_subquery->cypher_query_);
    ASSERT_TRUE(subquery);
    const auto *match = dynamic_cast<Match *>(subquery->single_query_->clauses_[0]);
    ASSERT_TRUE(match);

    const auto &unions = subquery->cypher_unions_;
    ASSERT_EQ(unions.size(), 1U);
  }

  {
    const auto *query = dynamic_cast<CypherQuery *>(
        ast_generator.ParseQuery("MATCH (n) CALL { MATCH (m) CALL { MATCH (o) RETURN o} RETURN m, o } RETURN n, m, o"));
    ASSERT_TRUE(query);
    ASSERT_GE(query->single_query_->clauses_.size(), 2U);
    const auto *call_subquery = dynamic_cast<CallSubquery *>(query->single_query_->clauses_[1]);
    ASSERT_TRUE(call_subquery);
    const auto *subquery = dynamic_cast<CypherQuery *>(call_subquery->cypher_query_);
    ASSERT_TRUE(subquery);
    ASSERT_GE(subquery->single_query_->clauses_.size(), 2U);
    const auto *match = dynamic_cast<Match *>(subquery->single_query_->clauses_[0]);
    ASSERT_TRUE(match);

    const auto *nested_subquery = dynamic_cast<CallSubquery *>(subquery->single_query_->clauses_[1]);
    ASSERT_TRUE(nested_subquery);
    const auto *nested_cypher = dynamic_cast<CypherQuery *>(nested_subquery->cypher_query_);
    ASSERT_TRUE(nested_cypher);
    const auto *nested_match = dynamic_cast<Match *>(nested_cypher->single_query_->clauses_[0]);
    ASSERT_TRUE(nested_match);
  }
}

View File

@ -464,6 +464,24 @@ auto GetCallProcedure(AstStorage &storage, std::string procedure_name,
return call_procedure;
}
/// Create a CALL subquery clause from a single query.
///
/// The single query is wrapped in a freshly created CypherQuery, which becomes the body of the
/// returned CallSubquery. Both new nodes are created through `storage`.
///
/// @param storage AST storage used to create the CallSubquery and the wrapping CypherQuery.
/// @param subquery Single query to use as the subquery body; stored as a plain pointer.
/// @return Pointer to the created CallSubquery.
auto GetCallSubquery(AstStorage &storage, SingleQuery *subquery) {
  auto *call_subquery = storage.Create<memgraph::query::CallSubquery>();
  auto *query = storage.Create<CypherQuery>();
  // Plain pointer assignments: std::move on a raw pointer is just a copy and would
  // misleadingly suggest an ownership transfer.
  query->single_query_ = subquery;
  call_subquery->cypher_query_ = query;
  return call_subquery;
}
/// Create a CALL subquery clause whose body is an already built CypherQuery
/// (for example one that contains unions).
///
/// @param storage AST storage used to create the CallSubquery.
/// @param subquery Cypher query to use as the subquery body; stored as a plain pointer.
/// @return Pointer to the created CallSubquery.
auto GetCallSubquery(AstStorage &storage, CypherQuery *subquery) {
  auto *call_subquery = storage.Create<memgraph::query::CallSubquery>();
  // Plain pointer assignment: std::move on a raw pointer is just a copy.
  call_subquery->cypher_query_ = subquery;
  return call_subquery;
}
/// Build a FOREACH clause in `storage` from the given named expression and list of clauses.
auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::Clause *> &clauses) {
  auto *foreach_clause = storage.Create<query::Foreach>(named_expr, clauses);
  return foreach_clause;
}
@ -593,3 +611,4 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
(labels), (edgeTypes))
#define DROP_USER(usernames) storage.Create<memgraph::query::DropUser>((usernames))
#define CALL_PROCEDURE(...) memgraph::query::test_common::GetCallProcedure(storage, __VA_ARGS__)
// Expands to test_common::GetCallSubquery(storage, ...); a variable named `storage`
// must therefore be visible at the expansion site.
#define CALL_SUBQUERY(...) memgraph::query::test_common::GetCallSubquery(storage, __VA_ARGS__)

View File

@ -194,6 +194,37 @@ TEST_F(QueryCostEstimator, Foreach) {
MakeOp<memgraph::query::plan::Foreach>(last_op_, create, storage_.Create<Identifier>(), NextSymbol());
EXPECT_COST(CostParam::kForeach * MiscParam::kForeachNoLiteral);
}
// Apply whose input and subquery branches are both a ScanAll: the expected cost is
// kSubquery multiplied by the vertex count squared.
TEST_F(QueryCostEstimator, SubqueryCartesian) {
  const int vertex_count = 4;
  AddVertices(vertex_count, 0, 0);

  std::shared_ptr<LogicalOperator> outer_branch = std::make_shared<ScanAll>(std::make_shared<Once>(), NextSymbol());
  std::shared_ptr<LogicalOperator> subquery_branch =
      std::make_shared<ScanAll>(std::make_shared<Once>(), NextSymbol());
  MakeOp<memgraph::query::plan::Apply>(outer_branch, subquery_branch, true);

  EXPECT_COST(CostParam::kSubquery * vertex_count * vertex_count);
}
// Apply whose input branch is a bare Once and whose subquery branch is a ScanAll:
// the expected cost is kSubquery multiplied by the vertex count.
TEST_F(QueryCostEstimator, UnitSubquery) {
  const int vertex_count = 4;
  AddVertices(vertex_count, 0, 0);

  std::shared_ptr<LogicalOperator> unit_input = std::make_shared<Once>();
  std::shared_ptr<LogicalOperator> subquery_branch =
      std::make_shared<ScanAll>(std::make_shared<Once>(), NextSymbol());
  MakeOp<memgraph::query::plan::Apply>(unit_input, subquery_branch, true);

  EXPECT_COST(CostParam::kSubquery * vertex_count);
}
// Union of two ScanAll branches: the expected cost is kUnion multiplied by the
// sum of the vertex counts of both sides.
TEST_F(QueryCostEstimator, Union) {
  const int vertex_count = 4;
  AddVertices(vertex_count, 0, 0);

  std::vector<Symbol> merged_symbols{NextSymbol()};
  std::shared_ptr<LogicalOperator> lhs = std::make_shared<ScanAll>(std::make_shared<Once>(), NextSymbol());
  std::shared_ptr<LogicalOperator> rhs = std::make_shared<ScanAll>(std::make_shared<Once>(), NextSymbol());
  MakeOp<memgraph::query::plan::Union>(lhs, rhs, merged_symbols, lhs->OutputSymbols(symbol_table_),
                                       rhs->OutputSymbols(symbol_table_));

  EXPECT_COST(CostParam::kUnion * (vertex_count + vertex_count));
}
// Helper for testing an operation's cost and cardinality.
// Only for operations that first increment cost, then modify cardinality.
// Intentionally a macro (instead of a function) for better test feedback.

View File

@ -38,6 +38,7 @@ namespace memgraph::query {
using namespace memgraph::query::plan;
using memgraph::query::AstStorage;
using memgraph::query::CypherUnion;
using memgraph::query::SingleQuery;
using memgraph::query::Symbol;
using memgraph::query::SymbolGenerator;
@ -51,10 +52,10 @@ namespace {
class Planner {
public:
template <class TDbAccessor>
Planner(std::vector<SingleQueryPart> single_query_parts, PlanningContext<TDbAccessor> context) {
Planner(QueryParts query_parts, PlanningContext<TDbAccessor> context) {
memgraph::query::Parameters parameters;
PostProcessor post_processor(parameters);
plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(single_query_parts, &context);
plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_parts, &context);
plan_ = post_processor.Rewrite(std::move(plan_), &context);
}
@ -1850,4 +1851,92 @@ TYPED_TEST(TestPlanner, Exists) {
DeleteListContent(&pattern_filter_without_types);
}
}
// Checks the operator tree the planner produces for queries containing CALL { ... } subqueries.
// Each case builds the AST by hand, plans it, and compares the plan against the expected checker
// sequence; the expected subquery branch is passed to ExpectApply as its own checker list.
// Checker objects are allocated with `new` and released with DeleteListContent at the end of each case.
TYPED_TEST(TestPlanner, Subqueries) {
AstStorage storage;
FakeDbAccessor dba;
// MATCH (m) CALL { MATCH (n) RETURN n } RETURN m, n
{
auto *subquery = SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"));
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), CALL_SUBQUERY(subquery), RETURN("m", "n")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
std::list<BaseOpChecker *> subquery_plan{new ExpectScanAll(), new ExpectProduce()};
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectApply(subquery_plan), ExpectProduce());
DeleteListContent(&subquery_plan);
}
// MATCH (m) CALL { MATCH (n)-[r]->(m) RETURN n } RETURN m, n
{
auto *subquery = SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r", Direction::OUT), NODE("m"))), RETURN("n"));
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), CALL_SUBQUERY(subquery), RETURN("m", "n")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
std::list<BaseOpChecker *> subquery_plan{new ExpectScanAll(), new ExpectExpand(), new ExpectProduce()};
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectApply(subquery_plan), ExpectProduce());
DeleteListContent(&subquery_plan);
}
// MATCH (n) CALL { MATCH (p)-[r]->(s) WHERE s.prop = 2 RETURN (p) } RETURN n, p
{
auto property = dba.Property("prop");
auto *subquery = SINGLE_QUERY(MATCH(PATTERN(NODE("p"), EDGE("r", Direction::OUT), NODE("s"))),
WHERE(EQ(PROPERTY_LOOKUP("s", property), LITERAL(2))), RETURN("p"));
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(subquery), RETURN("n", "p")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
std::list<BaseOpChecker *> subquery_plan{new ExpectScanAll(), new ExpectExpand(), new ExpectFilter(),
new ExpectProduce()};
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectApply(subquery_plan), ExpectProduce());
DeleteListContent(&subquery_plan);
}
// MATCH (m) CALL { MATCH (n) CALL { MATCH (o) RETURN o } RETURN n, o } RETURN m, n, o
{
auto *subquery_inside_subquery = SINGLE_QUERY(MATCH(PATTERN(NODE("o"))), RETURN("o"));
auto *subquery = SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(subquery_inside_subquery), RETURN("n", "o"));
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), CALL_SUBQUERY(subquery), RETURN("m", "n", "o")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
std::list<BaseOpChecker *> subquery_inside_subquery_plan{new ExpectScanAll(), new ExpectProduce()};
std::list<BaseOpChecker *> subquery_plan{new ExpectScanAll(), new ExpectApply(subquery_inside_subquery_plan),
new ExpectProduce()};
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectApply(subquery_plan), ExpectProduce());
DeleteListContent(&subquery_plan);
DeleteListContent(&subquery_inside_subquery_plan);
}
// MATCH (m) CALL { MATCH (n) RETURN n UNION ALL MATCH (n) RETURN n } RETURN m, n
{
auto *subquery = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n")),
UNION_ALL(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))));
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), CALL_SUBQUERY(subquery), RETURN("m", "n")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
std::list<BaseOpChecker *> left_subquery_part{new ExpectScanAll(), new ExpectProduce()};
std::list<BaseOpChecker *> right_subquery_part{new ExpectScanAll(), new ExpectProduce()};
std::list<BaseOpChecker *> subquery_plan{new ExpectUnion(left_subquery_part, right_subquery_part)};
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectApply(subquery_plan), ExpectProduce());
DeleteListContent(&subquery_plan);
DeleteListContent(&left_subquery_part);
DeleteListContent(&right_subquery_part);
}
}
} // namespace

View File

@ -114,6 +114,17 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor {
return false;
}
// Checks the Apply operator itself via CheckOp, then continues checking along its input branch.
// The subquery branch is not walked from here; ExpectApply::ExpectOp visits apply.subquery_ with
// its own PlanChecker. Returns false.
// NOTE(review): false presumably tells the visitor not to descend into the children on its own -
// confirm against HierarchicalLogicalOperatorVisitor.
bool PreVisit(Apply &op) override {
CheckOp(op);
op.input()->Accept(*this);
return false;
}
// Checks the Union operator itself via CheckOp and returns false without visiting either branch
// from here; ExpectUnion::ExpectOp visits left_op_ and right_op_ with separate PlanChecker instances.
// NOTE(review): false presumably tells the visitor not to descend into the children on its own -
// confirm against HierarchicalLogicalOperatorVisitor.
bool PreVisit(Union &op) override {
CheckOp(op);
return false;
}
PRE_VISIT(CallProcedure);
#undef PRE_VISIT
@ -199,6 +210,36 @@ class ExpectForeach : public OpChecker<Foreach> {
std::list<BaseOpChecker *> updates_;
};
/// Checker for an Apply operator: runs a dedicated PlanChecker, built from the expected
/// checker list given at construction, over the operator's subquery branch.
/// The checker pointers are stored as given; this class does not delete them.
class ExpectApply : public OpChecker<Apply> {
 public:
  ExpectApply(const std::list<BaseOpChecker *> &subquery) : expected_subquery_ops_(subquery) {}

  void ExpectOp(Apply &op, const SymbolTable &symbol_table) override {
    PlanChecker subquery_checker(expected_subquery_ops_, symbol_table);
    op.subquery_->Accept(subquery_checker);
  }

 private:
  std::list<BaseOpChecker *> expected_subquery_ops_;
};
/// Checker for a Union operator: verifies the left and the right branch of the union against
/// their own expected checker lists, each with a separate PlanChecker.
/// The checker pointers are stored as given; this class does not delete them.
class ExpectUnion : public OpChecker<Union> {
 public:
  /// @param left Expected checkers for the union's left branch.
  /// @param right Expected checkers for the union's right branch.
  ExpectUnion(const std::list<BaseOpChecker *> &left, const std::list<BaseOpChecker *> &right)
      : left_(left), right_(right) {}

  void ExpectOp(Union &union_op, const SymbolTable &symbol_table) override {
    PlanChecker check_left_op(left_, symbol_table);
    union_op.left_op_->Accept(check_left_op);
    // The right branch must be validated against `right_`. Checking it against `left_` only
    // passes by accident when both branches happen to have the same shape.
    PlanChecker check_right_op(right_, symbol_table);
    union_op.right_op_->Accept(check_right_op);
  }

 private:
  std::list<BaseOpChecker *> left_;
  std::list<BaseOpChecker *> right_;
};
class ExpectExpandVariable : public OpChecker<ExpandVariable> {
public:
void ExpectOp(ExpandVariable &op, const SymbolTable &) override {
@ -426,8 +467,7 @@ template <class TPlanner, class TDbAccessor>
TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) {
auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba);
auto query_parts = CollectQueryParts(symbol_table, storage, query);
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
return TPlanner(single_query_parts, planning_context);
return TPlanner(query_parts, planning_context);
}
class FakeDbAccessor {

View File

@ -3676,3 +3676,273 @@ TEST_F(ExistsFixture, DoubleFilters) {
EXPECT_EQ(1, TestDoubleExists("l1", EdgeAtom::Direction::BOTH, {}, {}, true));
EXPECT_EQ(1, TestDoubleExists("l1", EdgeAtom::Direction::BOTH, {}, {}, false));
}
// Fixture for the Apply (subquery) operator tests.
// It owns a storage holding two vertices joined by one edge, (:l1)-[:Edge]->(:l2); v1 carries
// property = 1, v2 carries property = 2 and the edge carries property = 1.
// Members are initialized in declaration order, and later members use earlier ones
// (the accessors use `db`, the vertices and the edge use `dba`), so the order must not be changed.
class SubqueriesFeature : public testing::Test {
protected:
memgraph::storage::Storage db;
memgraph::storage::Storage::Accessor storage_dba{db.Access()};
memgraph::query::DbAccessor dba{&storage_dba};
AstStorage storage;
SymbolTable symbol_table;
// Name / id pair of the property that the tests filter on.
std::pair<std::string, memgraph::storage::PropertyId> prop = PROPERTY_PAIR("property");
memgraph::query::VertexAccessor v1{dba.InsertVertex()};
memgraph::query::VertexAccessor v2{dba.InsertVertex()};
memgraph::storage::EdgeTypeId edge_type{db.NameToEdgeType("Edge")};
memgraph::query::EdgeAccessor r1{*dba.InsertEdge(&v1, &v2, edge_type)};
// Adds the labels and property values to the graph elements created above.
void SetUp() override {
// (:l1)-[:Edge]->(:l2)
ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l1")).HasValue());
ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l2")).HasValue());
ASSERT_TRUE(v1.SetProperty(prop.second, memgraph::storage::PropertyValue(1)).HasValue());
ASSERT_TRUE(v2.SetProperty(prop.second, memgraph::storage::PropertyValue(2)).HasValue());
ASSERT_TRUE(r1.SetProperty(prop.second, memgraph::storage::PropertyValue(1)).HasValue());
// NOTE(review): presumably switches the license checker into a test mode - confirm.
memgraph::license::global_license_checker.EnableTesting();
// NOTE(review): presumably makes the writes above visible to the operators under test - confirm.
dba.AdvanceCommand();
}
};
TEST_F(SubqueriesFeature, BasicCartesian) {
  // Plan under test: MATCH (n) CALL { MATCH (m) RETURN m } RETURN n, m

  // Outer branch: scan n.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Subquery branch: scan m and produce it.
  auto inner_scan = MakeScanAll(storage, symbol_table, "m");
  auto *inner_named =
      NEXPR("m", IDENT("m")->MapTo(inner_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));
  auto subquery_root = MakeProduce(inner_scan.op_, inner_named);

  auto apply_op = std::make_shared<Apply>(outer_scan.op_, subquery_root, true);
  auto plan_root = MakeProduce(apply_op, outer_named, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  // Two vertices on each side of the cartesian product.
  EXPECT_EQ(rows.size(), 4);
}
TEST_F(SubqueriesFeature, BasicCartesianWithFilter) {
  // Plan under test: MATCH (n) WHERE n.property = 2 CALL { MATCH (m) RETURN m } RETURN n, m

  // Outer branch: scan n and keep only rows where the property equals 2.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_predicate = AND(storage.Create<LabelsTest>(outer_scan.node_->identifier_, outer_scan.node_->labels_),
                              EQ(PROPERTY_LOOKUP(outer_scan.node_->identifier_, prop), LITERAL(2)));
  auto outer_filter =
      std::make_shared<Filter>(outer_scan.op_, std::vector<std::shared_ptr<LogicalOperator>>{}, outer_predicate);
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Subquery branch: unfiltered scan of m.
  auto inner_scan = MakeScanAll(storage, symbol_table, "m");
  auto *inner_named =
      NEXPR("m", IDENT("m")->MapTo(inner_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));
  auto subquery_root = MakeProduce(inner_scan.op_, inner_named);

  auto apply_op = std::make_shared<Apply>(outer_filter, subquery_root, true);
  auto plan_root = MakeProduce(apply_op, outer_named, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  // One matching outer vertex, two subquery vertices.
  EXPECT_EQ(rows.size(), 2);
}
TEST_F(SubqueriesFeature, BasicCartesianWithFilterInsideSubquery) {
  // MATCH (n) CALL { MATCH (m) WHERE m.property = 2 RETURN m } RETURN n, m
  auto n = MakeScanAll(storage, symbol_table, "n");
  auto return_n = NEXPR("n", IDENT("n")->MapTo(n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  auto m = MakeScanAll(storage, symbol_table, "m");
  // The filter lives inside the subquery, so it must be expressed over the subquery's own
  // variable `m`. Building it over `n` (as a copy of the outer-filter test would) yields the same
  // row count on this data set but does not exercise filtering of the subquery's rows.
  auto *filter_expr = AND(storage.Create<LabelsTest>(m.node_->identifier_, m.node_->labels_),
                          EQ(PROPERTY_LOOKUP(m.node_->identifier_, prop), LITERAL(2)));
  auto filter = std::make_shared<Filter>(m.op_, std::vector<std::shared_ptr<LogicalOperator>>{}, filter_expr);
  auto return_m = NEXPR("m", IDENT("m")->MapTo(m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));
  auto produce_subquery = MakeProduce(filter, return_m);

  auto apply = std::make_shared<Apply>(n.op_, produce_subquery, true);
  auto produce = MakeProduce(apply, return_n, return_m);

  auto context = MakeContext(storage, symbol_table, &dba);
  auto results = CollectProduce(*produce, &context);
  // Two outer vertices, each paired with the single subquery vertex whose property equals 2.
  EXPECT_EQ(results.size(), 2);
}
TEST_F(SubqueriesFeature, BasicCartesianWithFilterNoResults) {
  // Plan under test: MATCH (n) WHERE n.property = 3 CALL { MATCH (m) RETURN m } RETURN n, m

  // Outer branch: no vertex in the fixture has the property set to 3.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_predicate = AND(storage.Create<LabelsTest>(outer_scan.node_->identifier_, outer_scan.node_->labels_),
                              EQ(PROPERTY_LOOKUP(outer_scan.node_->identifier_, prop), LITERAL(3)));
  auto outer_filter =
      std::make_shared<Filter>(outer_scan.op_, std::vector<std::shared_ptr<LogicalOperator>>{}, outer_predicate);
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Subquery branch: unfiltered scan of m.
  auto inner_scan = MakeScanAll(storage, symbol_table, "m");
  auto *inner_named =
      NEXPR("m", IDENT("m")->MapTo(inner_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));
  auto subquery_root = MakeProduce(inner_scan.op_, inner_named);

  auto apply_op = std::make_shared<Apply>(outer_filter, subquery_root, true);
  auto plan_root = MakeProduce(apply_op, outer_named, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  EXPECT_EQ(rows.size(), 0);
}
TEST_F(SubqueriesFeature, SubqueryInsideSubqueryCartesian) {
  // Plan under test:
  //   MATCH (n) CALL { MATCH (m) CALL { MATCH (o) RETURN o } RETURN m, o } RETURN n, m, o

  // Outermost branch.
  auto scan_n = MakeScanAll(storage, symbol_table, "n");
  auto *named_n =
      NEXPR("n", IDENT("n")->MapTo(scan_n.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Middle subquery.
  auto scan_m = MakeScanAll(storage, symbol_table, "m");
  auto *named_m =
      NEXPR("m", IDENT("m")->MapTo(scan_m.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));

  // Innermost subquery.
  auto scan_o = MakeScanAll(storage, symbol_table, "o");
  auto *named_o =
      NEXPR("o", IDENT("o")->MapTo(scan_o.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_3", true));
  auto innermost_root = MakeProduce(scan_o.op_, named_o);

  auto nested_apply = std::make_shared<Apply>(scan_m.op_, innermost_root, true);
  auto middle_root = MakeProduce(nested_apply, named_o, named_m);

  auto top_apply = std::make_shared<Apply>(scan_n.op_, middle_root, true);
  auto plan_root = MakeProduce(top_apply, named_n, named_m, named_o);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  // Two vertices at each of the three nesting levels.
  EXPECT_EQ(rows.size(), 8);
}
TEST_F(SubqueriesFeature, UnitSubquery) {
  // Plan under test: CALL { MATCH (o) RETURN o } RETURN o
  // The Apply input is a bare Once, i.e. there is no outer MATCH.
  auto unit_input = std::make_shared<Once>();

  auto inner_scan = MakeScanAll(storage, symbol_table, "o");
  auto *inner_named =
      NEXPR("o", IDENT("o")->MapTo(inner_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_3", true));
  auto subquery_root = MakeProduce(inner_scan.op_, inner_named);

  auto apply_op = std::make_shared<Apply>(unit_input, subquery_root, true);
  auto plan_root = MakeProduce(apply_op, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  EXPECT_EQ(rows.size(), 2);
}
TEST_F(SubqueriesFeature, SubqueryWithBoundedSymbol) {
  // Plan under test: MATCH (n) CALL { WITH n MATCH (n)-[]->(m) RETURN m } RETURN n, m

  // Outer branch: scan n.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Subquery branch: a Produce standing in for `WITH n`, followed by an outgoing expansion from n.
  auto subquery_start = std::make_shared<Once>();
  auto with_produce = MakeProduce(subquery_start, outer_named);
  auto expansion = MakeExpand(storage, symbol_table, with_produce, outer_scan.sym_, "r", EdgeAtom::Direction::OUT, {},
                              "m", false, memgraph::storage::View::OLD);
  auto *inner_named = NEXPR("m", IDENT("m")->MapTo(expansion.node_sym_))
                          ->MapTo(symbol_table.CreateSymbol("named_expression_3", true));
  auto subquery_root = MakeProduce(expansion.op_, inner_named);

  auto apply_op = std::make_shared<Apply>(outer_scan.op_, subquery_root, true);
  auto plan_root = MakeProduce(apply_op, outer_named, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  // The fixture contains exactly one edge.
  EXPECT_EQ(rows.size(), 1);
}
TEST_F(SubqueriesFeature, SubqueryWithUnionAll) {
  // Plan under test: MATCH (n) CALL { MATCH (m) RETURN m UNION ALL MATCH (m) RETURN m } RETURN n, m

  // Outer branch: scan n.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Left and right side of the union; both produce the same named expression.
  auto lhs_scan = MakeScanAll(storage, symbol_table, "m");
  auto *inner_named =
      NEXPR("m", IDENT("m")->MapTo(lhs_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_2", true));
  auto lhs_produce = MakeProduce(lhs_scan.op_, inner_named);
  auto rhs_scan = MakeScanAll(storage, symbol_table, "m");
  auto rhs_produce = MakeProduce(rhs_scan.op_, inner_named);

  auto union_op = std::make_shared<Union>(lhs_produce, rhs_produce, std::vector<Symbol>{lhs_scan.sym_},
                                          lhs_produce->OutputSymbols(symbol_table),
                                          rhs_produce->OutputSymbols(symbol_table));

  auto apply_op = std::make_shared<Apply>(outer_scan.op_, union_op, true);
  auto plan_root = MakeProduce(apply_op, outer_named, inner_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  // Two outer vertices, each paired with 2 + 2 rows from the UNION ALL.
  EXPECT_EQ(rows.size(), 8);
}
TEST_F(SubqueriesFeature, SubqueryWithUnion) {
  // Plan under test: MATCH (n) CALL { MATCH (m) RETURN m UNION MATCH (m) RETURN m } RETURN n
  // UNION (without ALL) is modelled as Union followed by Distinct.

  // Outer branch: scan n.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Left and right side of the union; both produce the same named expression.
  auto lhs_scan = MakeScanAll(storage, symbol_table, "m");
  auto union_result_symbol = symbol_table.CreateSymbol("named_expression_2", true);
  auto *inner_named = NEXPR("m", IDENT("m")->MapTo(lhs_scan.sym_))->MapTo(union_result_symbol);
  auto lhs_produce = MakeProduce(lhs_scan.op_, inner_named);
  auto rhs_scan = MakeScanAll(storage, symbol_table, "m");
  auto rhs_produce = MakeProduce(rhs_scan.op_, inner_named);

  auto union_op = std::make_shared<Union>(lhs_produce, rhs_produce, std::vector<Symbol>{union_result_symbol},
                                          lhs_produce->OutputSymbols(symbol_table),
                                          rhs_produce->OutputSymbols(symbol_table));
  auto distinct_symbols = union_op->OutputSymbols(symbol_table);
  auto distinct_op = std::make_shared<Distinct>(union_op, std::vector<Symbol>{distinct_symbols});

  auto apply_op = std::make_shared<Apply>(outer_scan.op_, distinct_op, true);
  auto plan_root = MakeProduce(apply_op, outer_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  EXPECT_EQ(rows.size(), 4);
}
TEST_F(SubqueriesFeature, SubqueriesWithForeach) {
  // Plan under test: MATCH (n) CALL { FOREACH (i IN [1, 2, 3, 4, 5] | CREATE (n)) } RETURN n
  // The subquery returns nothing, so its plan ends in EmptyResult and Apply is built with `false`.

  // Outer branch: scan n.
  auto outer_scan = MakeScanAll(storage, symbol_table, "n");
  auto *outer_named =
      NEXPR("n", IDENT("n")->MapTo(outer_scan.sym_))->MapTo(symbol_table.CreateSymbol("named_expression_1", true));

  // Update branch of the FOREACH: create one node per iteration.
  auto create_input = std::make_shared<Once>();
  NodeCreationInfo created_node;
  created_node.symbol = symbol_table.CreateSymbol("n", true);
  auto create_op = std::make_shared<plan::CreateNode>(create_input, created_node);

  // FOREACH over a five-element literal list.
  auto foreach_input = std::make_shared<Once>();
  auto loop_symbol = symbol_table.CreateSymbol("i", true);
  auto loop_values = LIST(LITERAL(1), LITERAL(2), LITERAL(3), LITERAL(4), LITERAL(5));
  auto foreach_op = std::make_shared<plan::Foreach>(foreach_input, create_op, loop_values, loop_symbol);
  auto subquery_root = std::make_shared<EmptyResult>(foreach_op);

  auto apply_op = std::make_shared<Apply>(outer_scan.op_, subquery_root, false);
  auto plan_root = MakeProduce(apply_op, outer_named);

  auto context = MakeContext(storage, symbol_table, &dba);
  const auto rows = CollectProduce(*plan_root, &context);
  EXPECT_EQ(rows.size(), 2);
}

View File

@ -1214,3 +1214,42 @@ TEST_F(TestSymbolGenerator, Exists) {
auto symbol = *collector.symbols_.begin();
ASSERT_EQ(symbol.name_, "n");
}
// Symbol generation for CALL { ... } subqueries: four invalid queries that must raise a
// SemanticException, followed by two valid queries whose symbol table size is checked.
TEST_F(TestSymbolGenerator, Subqueries) {
  // MATCH (n) CALL { MATCH (n) RETURN n } RETURN n
  // Rejected: the subquery returns `n`, which is already defined in the outer scope.
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n")));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n")));
    EXPECT_THROW(MakeSymbolTable(outer), SemanticException);
  }

  // MATCH (n) CALL { MATCH (m) RETURN m.prop } RETURN n
  // Rejected: m.prop has to be aliased before it is returned from a subquery.
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m.prop")));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n")));
    EXPECT_THROW(MakeSymbolTable(outer), SemanticException);
  }

  // MATCH (n) CALL { MATCH (m) RETURN m, m.prop } RETURN n
  // Rejected for the same reason, with the unaliased expression in second position.
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m", "m.prop")));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n")));
    EXPECT_THROW(MakeSymbolTable(outer), SemanticException);
  }

  // MATCH (n) CALL { MATCH (m) RETURN m.prop, m } RETURN n
  // Rejected for the same reason, with the unaliased expression in first position.
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m.prop", "m")));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n")));
    EXPECT_THROW(MakeSymbolTable(outer), SemanticException);
  }

  // MATCH (n) CALL { MATCH (m) RETURN m } RETURN n, m
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m")));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n", "m")));
    auto table = MakeSymbolTable(outer);
    ASSERT_EQ(table.max_position(), 7);
  }

  // MATCH (n) CALL { MATCH (m) RETURN m UNION MATCH (m) RETURN m } RETURN n, m
  {
    auto *inner = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m")),
                        UNION(SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m"))));
    auto *outer = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner), RETURN("n", "m")));
    auto table = MakeSymbolTable(outer);
    ASSERT_EQ(table.max_position(), 11);
  }
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -75,8 +75,7 @@ void CheckPlansProduce(size_t expected_plan_count, memgraph::query::CypherQuery
auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba);
auto query_parts = CollectQueryParts(symbol_table, storage, query);
EXPECT_TRUE(query_parts.query_parts.size() > 0);
auto single_query_parts = query_parts.query_parts.at(0).single_query_parts;
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(single_query_parts, &planning_context);
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_parts, &planning_context);
EXPECT_EQ(std::distance(plans.begin(), plans.end()), expected_plan_count);
for (const auto &plan : plans) {
auto *produce = dynamic_cast<Produce *>(plan.get());
@ -330,4 +329,114 @@ TEST(TestVariableStartPlanner, MatchBfs) {
CheckPlansProduce(2, query, storage, &dba, [&](const auto &results) { AssertRows(results, {{r1_list}}, dba); });
}
// MATCH (n) CALL { MATCH (m) RETURN m } RETURN n, m on a graph with two unconnected vertices:
// a single plan is expected and it must yield the full 2 x 2 cartesian product.
TEST(TestVariableStartPlanner, TestBasicSubquery) {
  memgraph::storage::Storage db;
  auto storage_accessor = db.Access();
  memgraph::query::DbAccessor dba(&storage_accessor);
  AstStorage storage;

  auto first_vertex = dba.InsertVertex();
  auto second_vertex = dba.InsertVertex();
  dba.AdvanceCommand();

  auto *inner_query = SINGLE_QUERY(MATCH(PATTERN(NODE("m"))), RETURN("m"));
  auto *outer_query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), CALL_SUBQUERY(inner_query), RETURN("n", "m")));

  CheckPlansProduce(1, outer_query, storage, &dba, [&](const auto &rows) {
    AssertRows(rows,
               {{TypedValue(first_vertex), TypedValue(first_vertex)},
                {TypedValue(first_vertex), TypedValue(second_vertex)},
                {TypedValue(second_vertex), TypedValue(first_vertex)},
                {TypedValue(second_vertex), TypedValue(second_vertex)}},
               dba);
  });
}
// MATCH (m1)-[r1]->(n1) CALL { MATCH (m2)-[r2]->(n2) RETURN m2 } RETURN m1, m2 on a graph with a
// single edge: four plans are expected and each must yield the one row (source, source).
TEST(TestVariableStartPlanner, TestBasicSubqueryWithMatching) {
  memgraph::storage::Storage db;
  auto storage_accessor = db.Access();
  memgraph::query::DbAccessor dba(&storage_accessor);
  AstStorage storage;

  auto source_vertex = dba.InsertVertex();
  auto target_vertex = dba.InsertVertex();
  ASSERT_TRUE(dba.InsertEdge(&source_vertex, &target_vertex, dba.NameToEdgeType("r1")).HasValue());
  dba.AdvanceCommand();

  auto *inner_query =
      SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))), RETURN("m2"));
  auto *outer_query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m1"), EDGE("r1", EdgeAtom::Direction::OUT), NODE("n1"))),
                                         CALL_SUBQUERY(inner_query), RETURN("m1", "m2")));

  CheckPlansProduce(4, outer_query, storage, &dba, [&](const auto &rows) {
    AssertRows(rows, {{TypedValue(source_vertex), TypedValue(source_vertex)}}, dba);
  });
}
// Outer edge match combined with a subquery that is a UNION ALL of two identical edge matches:
// eight plans are expected and each must yield the (source, target) row twice.
TEST(TestVariableStartPlanner, TestSubqueryWithUnion) {
  memgraph::storage::Storage db;
  auto storage_accessor = db.Access();
  memgraph::query::DbAccessor dba(&storage_accessor);
  AstStorage storage;

  auto id_property = dba.NameToProperty("id");
  auto source_vertex = dba.InsertVertex();
  ASSERT_TRUE(source_vertex.SetProperty(id_property, memgraph::storage::PropertyValue(1)).HasValue());
  auto target_vertex = dba.InsertVertex();
  ASSERT_TRUE(target_vertex.SetProperty(id_property, memgraph::storage::PropertyValue(2)).HasValue());
  ASSERT_TRUE(dba.InsertEdge(&source_vertex, &target_vertex, dba.NameToEdgeType("r1")).HasValue());
  dba.AdvanceCommand();

  auto *inner_query =
      QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))), RETURN("n2")),
            UNION_ALL(SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))),
                                   RETURN("n2"))));
  auto *outer_query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m1"), EDGE("r1", EdgeAtom::Direction::OUT), NODE("n1"))),
                                         CALL_SUBQUERY(inner_query), RETURN("m1", "n2")));

  CheckPlansProduce(8, outer_query, storage, &dba, [&](const auto &rows) {
    AssertRows(rows,
               {{TypedValue(source_vertex), TypedValue(target_vertex)},
                {TypedValue(source_vertex), TypedValue(target_vertex)}},
               dba);
  });
}
// Outer edge match combined with a subquery that is a UNION ALL of three identical edge matches:
// sixteen plans are expected and each must yield the (source, target) row three times.
TEST(TestVariableStartPlanner, TestSubqueryWithTripleUnion) {
  memgraph::storage::Storage db;
  auto storage_accessor = db.Access();
  memgraph::query::DbAccessor dba(&storage_accessor);
  AstStorage storage;

  auto id_property = dba.NameToProperty("id");
  auto source_vertex = dba.InsertVertex();
  ASSERT_TRUE(source_vertex.SetProperty(id_property, memgraph::storage::PropertyValue(1)).HasValue());
  auto target_vertex = dba.InsertVertex();
  ASSERT_TRUE(target_vertex.SetProperty(id_property, memgraph::storage::PropertyValue(2)).HasValue());
  ASSERT_TRUE(dba.InsertEdge(&source_vertex, &target_vertex, dba.NameToEdgeType("r1")).HasValue());
  dba.AdvanceCommand();

  auto *inner_query =
      QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))), RETURN("n2")),
            UNION_ALL(SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))),
                                   RETURN("n2"))),
            UNION_ALL(SINGLE_QUERY(MATCH(PATTERN(NODE("m2"), EDGE("r2", EdgeAtom::Direction::OUT), NODE("n2"))),
                                   RETURN("n2"))));
  auto *outer_query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("m1"), EDGE("r1", EdgeAtom::Direction::OUT), NODE("n1"))),
                                         CALL_SUBQUERY(inner_query), RETURN("m1", "n2")));

  CheckPlansProduce(16, outer_query, storage, &dba, [&](const auto &rows) {
    AssertRows(rows,
               {{TypedValue(source_vertex), TypedValue(target_vertex)},
                {TypedValue(source_vertex), TypedValue(target_vertex)},
                {TypedValue(source_vertex), TypedValue(target_vertex)}},
               dba);
  });
}
} // namespace