Plan aggregation in WITH and RETURN clauses
Summary: Generate symbols for aggregation results. Plan aggregation in WITH clause. Plan aggregation in RETURN clause. Extract handling write clauses to a function. Reviewers: mislav.bradac, florijan Reviewed By: mislav.bradac, florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D236
This commit is contained in:
parent
6cb1cdc607
commit
355b9a9b9a
@ -61,6 +61,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
|
||||
return last;
|
||||
}
|
||||
|
||||
using TreeVisitorBase::PreVisit;
|
||||
using TreeVisitorBase::Visit;
|
||||
using TreeVisitorBase::PostVisit;
|
||||
|
||||
@ -147,12 +148,14 @@ class ExpressionEvaluator : public TreeVisitorBase {
|
||||
result_stack_.push_back(literal.value_);
|
||||
}
|
||||
|
||||
void Visit(Aggregation &aggregation) override {
|
||||
bool PreVisit(Aggregation &aggregation) override {
|
||||
auto value = frame_[symbol_table_.at(aggregation)];
|
||||
// Aggregation is probably always simple type, but let's switch accessor
|
||||
// just to be sure.
|
||||
SwitchAccessors(value);
|
||||
result_stack_.emplace_back(std::move(value));
|
||||
// Prevent evaluation of expressions inside the aggregation.
|
||||
return false;
|
||||
}
|
||||
|
||||
void PostVisit(Function &function) override {
|
||||
|
@ -833,6 +833,9 @@ class Aggregate : public LogicalOperator {
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
const auto &aggregations() const { return aggregations_; }
|
||||
const auto &group_by() const { return group_by_; }
|
||||
|
||||
private:
|
||||
const std::shared_ptr<LogicalOperator> input_;
|
||||
const std::vector<Element> aggregations_;
|
||||
|
@ -60,7 +60,7 @@ auto ReducePattern(
|
||||
}
|
||||
|
||||
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
auto base = [&](NodeAtom *node) -> LogicalOperator * {
|
||||
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
|
||||
@ -91,7 +91,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
}
|
||||
|
||||
auto GenCreate(Create &create, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
auto last_op = input_op;
|
||||
for (auto pattern : create.patterns_) {
|
||||
@ -102,7 +102,7 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
|
||||
}
|
||||
|
||||
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols,
|
||||
std::vector<Symbol> &edge_symbols) {
|
||||
auto base = [&](NodeAtom *node) {
|
||||
@ -162,7 +162,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
}
|
||||
|
||||
auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
auto last_op = input_op;
|
||||
std::vector<Symbol> edge_symbols;
|
||||
@ -177,20 +177,87 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Ast tree visitor which collects all the symbols referenced by identifiers.
|
||||
class SymbolCollector : public TreeVisitorBase {
|
||||
// Ast tree visitor which collects the context for a return body. The return
|
||||
// body are the named expressions found in WITH and RETURN clauses. The
|
||||
// collected context consists of used symbols, aggregations and group by named
|
||||
// expressions.
|
||||
class ReturnBodyContext : public TreeVisitorBase {
|
||||
public:
|
||||
SymbolCollector(const SymbolTable &symbol_table)
|
||||
ReturnBodyContext(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
|
||||
using TreeVisitorBase::PreVisit;
|
||||
using TreeVisitorBase::Visit;
|
||||
using TreeVisitorBase::PostVisit;
|
||||
|
||||
void Visit(Literal &) override { has_aggregation_.emplace_back(false); }
|
||||
|
||||
void Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
has_aggregation_.emplace_back(false);
|
||||
}
|
||||
|
||||
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
|
||||
void PostVisit(BinaryOperator &op) override { \
|
||||
/* has_aggregation_ stack is reversed, last result is from the 2nd \
|
||||
* expression. */ \
|
||||
bool aggr2 = has_aggregation_.back(); \
|
||||
has_aggregation_.pop_back(); \
|
||||
bool aggr1 = has_aggregation_.back(); \
|
||||
has_aggregation_.pop_back(); \
|
||||
bool has_aggr = aggr1 || aggr2; \
|
||||
if (has_aggr) { \
|
||||
/* Group by the expression which does not contain aggregation. */ \
|
||||
/* Possible optimization is to ignore constant value expressions */ \
|
||||
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
|
||||
} \
|
||||
/* Propagate that this whole expression may contain an aggregation. */ \
|
||||
has_aggregation_.emplace_back(has_aggr); \
|
||||
}
|
||||
|
||||
VISIT_BINARY_OPERATOR(OrOperator)
|
||||
VISIT_BINARY_OPERATOR(XorOperator)
|
||||
VISIT_BINARY_OPERATOR(AndOperator)
|
||||
VISIT_BINARY_OPERATOR(AdditionOperator)
|
||||
VISIT_BINARY_OPERATOR(SubtractionOperator)
|
||||
VISIT_BINARY_OPERATOR(MultiplicationOperator)
|
||||
VISIT_BINARY_OPERATOR(DivisionOperator)
|
||||
VISIT_BINARY_OPERATOR(ModOperator)
|
||||
VISIT_BINARY_OPERATOR(NotEqualOperator)
|
||||
VISIT_BINARY_OPERATOR(EqualOperator)
|
||||
VISIT_BINARY_OPERATOR(LessOperator)
|
||||
VISIT_BINARY_OPERATOR(GreaterOperator)
|
||||
VISIT_BINARY_OPERATOR(LessEqualOperator)
|
||||
VISIT_BINARY_OPERATOR(GreaterEqualOperator)
|
||||
|
||||
#undef VISIT_BINARY_OPERATOR
|
||||
|
||||
void PostVisit(Aggregation &aggr) override {
|
||||
// Aggregation contains a virtual symbol, where the result will be stored.
|
||||
const auto &symbol = symbol_table_.at(aggr);
|
||||
aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol);
|
||||
has_aggregation_.back() = true;
|
||||
// Possible optimization is to skip remembering symbols inside aggregation.
|
||||
// If and when implementing this, don't forget that Accumulate needs *all*
|
||||
// the symbols, including those inside aggregation.
|
||||
}
|
||||
|
||||
void PostVisit(NamedExpression &named_expr) override {
|
||||
if (!has_aggregation_.back()) {
|
||||
group_by_.emplace_back(named_expr.expression_);
|
||||
}
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
|
||||
// Set of symbols used inside the visited expressions outside of aggregation
|
||||
// expression.
|
||||
const auto &symbols() const { return symbols_; }
|
||||
// List of aggregation elements found in expressions.
|
||||
const auto &aggregations() const { return aggregations_; }
|
||||
// When there is at least one aggregation element, all the non-aggregate (sub)
|
||||
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
|
||||
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
|
||||
const auto &group_by() const { return group_by_; }
|
||||
|
||||
private:
|
||||
// Calculates the Symbol hash based on its position.
|
||||
@ -202,27 +269,55 @@ class SymbolCollector : public TreeVisitorBase {
|
||||
|
||||
const SymbolTable &symbol_table_;
|
||||
std::unordered_set<Symbol, SymbolHash> symbols_;
|
||||
std::vector<Aggregate::Element> aggregations_;
|
||||
std::vector<Expression *> group_by_;
|
||||
// Flag indicating whether an expression contains an aggregation.
|
||||
std::list<bool> has_aggregation_;
|
||||
};
|
||||
|
||||
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
const std::vector<NamedExpression *> &named_expressions,
|
||||
const SymbolTable &symbol_table, bool accumulate = false) {
|
||||
ReturnBodyContext context(symbol_table);
|
||||
// Generate context for all named expressions.
|
||||
for (auto &named_expr : named_expressions) {
|
||||
named_expr->Accept(context);
|
||||
}
|
||||
auto symbols =
|
||||
std::vector<Symbol>(context.symbols().begin(), context.symbols().end());
|
||||
auto last_op = input_op;
|
||||
if (accumulate) {
|
||||
// We only advance the command in Accumulate. This is done for WITH clause,
|
||||
// when the first part updated the database. RETURN clause may only need an
|
||||
// accumulation after updates, without advancing the command.
|
||||
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op), symbols,
|
||||
advance_command);
|
||||
}
|
||||
if (!context.aggregations().empty()) {
|
||||
last_op =
|
||||
new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
|
||||
context.aggregations(), context.group_by(), symbols);
|
||||
}
|
||||
return new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
named_expressions);
|
||||
}
|
||||
|
||||
auto GenWith(With &with, LogicalOperator *input_op,
|
||||
const query::SymbolTable &symbol_table) {
|
||||
const SymbolTable &symbol_table, bool is_write) {
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
|
||||
// optional Filter.
|
||||
if (with.distinct_) {
|
||||
// TODO: Plan disctint with, when operator available.
|
||||
// TODO: Plan distinct with, when operator available.
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
// WITH clause is Accumulate/Aggregate (advance_command) + Produce.
|
||||
SymbolCollector symbol_collector(symbol_table);
|
||||
// Collect used symbols so that accumulate doesn't copy the whole frame.
|
||||
for (auto &named_expr : with.named_expressions_) {
|
||||
named_expr->expression_->Accept(symbol_collector);
|
||||
}
|
||||
auto symbols = symbol_collector.symbols();
|
||||
// TODO: Check whether we need aggregate instead of accumulate.
|
||||
// In case of update and aggregation, we want to accumulate first, so that
|
||||
// when aggregating, we get the latest results. Similar to RETURN clause.
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
LogicalOperator *last_op =
|
||||
new Accumulate(std::shared_ptr<LogicalOperator>(input_op),
|
||||
std::vector<Symbol>(symbols.begin(), symbols.end()), true);
|
||||
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
with.named_expressions_);
|
||||
GenReturnBody(input_op, advance_command, with.named_expressions_,
|
||||
symbol_table, accumulate);
|
||||
if (with.where_) {
|
||||
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
with.where_->expression_);
|
||||
@ -230,55 +325,80 @@ auto GenWith(With &with, LogicalOperator *input_op,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenReturn(Return &ret, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table, bool is_write) {
|
||||
// Similar to WITH clause, but we want to accumulate and advance command when
|
||||
// the query writes to the database. This way we handle the case when we want
|
||||
// to return expressions with the latest updated results. For example,
|
||||
// `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same
|
||||
// `n` multiple 'k' times, we want to return 'k' results where the property
|
||||
// value is the same, final result of 'k' increments.
|
||||
bool accumulate = is_write;
|
||||
bool advance_command = false;
|
||||
return GenReturnBody(input_op, advance_command, ret.named_expressions_,
|
||||
symbol_table, accumulate);
|
||||
}
|
||||
|
||||
// Generate an operator for a clause which writes to the database. If the clause
|
||||
// isn't handled, returns nullptr.
|
||||
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
if (auto *create = dynamic_cast<Create *>(clause)) {
|
||||
return GenCreate(*create, input_op, symbol_table, bound_symbols);
|
||||
} else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
|
||||
return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
|
||||
del->expressions_, del->detach_);
|
||||
} else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) {
|
||||
return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
|
||||
set->property_lookup_, set->expression_);
|
||||
} else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) {
|
||||
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
|
||||
: plan::SetProperties::Op::REPLACE;
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, set->expression_, op);
|
||||
} else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) {
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, set->labels_);
|
||||
} else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) {
|
||||
return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
|
||||
rem->property_lookup_);
|
||||
} else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) {
|
||||
const auto &input_symbol = symbol_table.at(*rem->identifier_);
|
||||
return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, rem->labels_);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
query::Query &query, const query::SymbolTable &symbol_table) {
|
||||
// TODO: Extract functions and state into a class with methods. Possibly a
|
||||
// visitor or similar to avoid all those dynamic casts.
|
||||
LogicalOperator *input_op = nullptr;
|
||||
// bound_symbols set is used to differentiate cycles in pattern matching, so
|
||||
// that the operator can be correctly initialized whether to read the symbol
|
||||
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
|
||||
// `n`, but the latter `n` would only read the already written information.
|
||||
std::unordered_set<int> bound_symbols;
|
||||
// Set to true if a query command performs a writes to the database.
|
||||
bool is_write = false;
|
||||
LogicalOperator *input_op = nullptr;
|
||||
for (auto &clause : query.clauses_) {
|
||||
auto *clause_ptr = clause;
|
||||
if (auto *match = dynamic_cast<Match *>(clause_ptr)) {
|
||||
// Clauses which read from the database.
|
||||
if (auto *match = dynamic_cast<Match *>(clause)) {
|
||||
input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
|
||||
} else if (auto *ret = dynamic_cast<Return *>(clause_ptr)) {
|
||||
input_op = new Produce(std::shared_ptr<LogicalOperator>(input_op),
|
||||
ret->named_expressions_);
|
||||
} else if (auto *create = dynamic_cast<Create *>(clause_ptr)) {
|
||||
input_op = GenCreate(*create, input_op, symbol_table, bound_symbols);
|
||||
} else if (auto *del = dynamic_cast<query::Delete *>(clause_ptr)) {
|
||||
input_op = new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
|
||||
del->expressions_, del->detach_);
|
||||
} else if (auto *set = dynamic_cast<query::SetProperty *>(clause_ptr)) {
|
||||
input_op =
|
||||
new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
|
||||
set->property_lookup_, set->expression_);
|
||||
} else if (auto *set = dynamic_cast<query::SetProperties *>(clause_ptr)) {
|
||||
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
|
||||
: plan::SetProperties::Op::REPLACE;
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
input_op =
|
||||
new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, set->expression_, op);
|
||||
} else if (auto *set = dynamic_cast<query::SetLabels *>(clause_ptr)) {
|
||||
const auto &input_symbol = symbol_table.at(*set->identifier_);
|
||||
input_op = new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, set->labels_);
|
||||
} else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause_ptr)) {
|
||||
input_op = new plan::RemoveProperty(
|
||||
std::shared_ptr<LogicalOperator>(input_op), rem->property_lookup_);
|
||||
} else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause_ptr)) {
|
||||
const auto &input_symbol = symbol_table.at(*rem->identifier_);
|
||||
input_op =
|
||||
new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
|
||||
input_symbol, rem->labels_);
|
||||
} else if (auto *with = dynamic_cast<query::With *>(clause_ptr)) {
|
||||
input_op = GenWith(*with, input_op, symbol_table);
|
||||
} else if (auto *ret = dynamic_cast<Return *>(clause)) {
|
||||
input_op = GenReturn(*ret, input_op, symbol_table, is_write);
|
||||
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
|
||||
input_op = GenWith(*with, input_op, symbol_table, is_write);
|
||||
// WITH clause advances the command, so reset the flag.
|
||||
is_write = false;
|
||||
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,
|
||||
bound_symbols)) {
|
||||
is_write = true;
|
||||
input_op = op;
|
||||
} else {
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
|
@ -108,6 +108,20 @@ void SymbolGenerator::Visit(Identifier &ident) {
|
||||
symbol_table_[ident] = symbol;
|
||||
}
|
||||
|
||||
void SymbolGenerator::Visit(Aggregation &aggr) {
|
||||
// Create a virtual symbol for aggregation result.
|
||||
symbol_table_[aggr] = symbol_table_.CreateSymbol("");
|
||||
if (scope_.in_aggregation) {
|
||||
throw SemanticException(
|
||||
"Using aggregate functions inside aggregate functions is not allowed");
|
||||
}
|
||||
scope_.in_aggregation = true;
|
||||
}
|
||||
|
||||
void SymbolGenerator::PostVisit(Aggregation &aggr) {
|
||||
scope_.in_aggregation = false;
|
||||
}
|
||||
|
||||
// Pattern and its subparts.
|
||||
|
||||
void SymbolGenerator::Visit(Pattern &pattern) {
|
||||
|
@ -24,23 +24,25 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
using TreeVisitorBase::PostVisit;
|
||||
|
||||
// Clauses
|
||||
void Visit(Create &create) override;
|
||||
void PostVisit(Create &create) override;
|
||||
void PostVisit(Return &ret) override;
|
||||
void Visit(With &with) override;
|
||||
void PostVisit(With &with) override;
|
||||
void Visit(Where &where) override;
|
||||
void Visit(Create &) override;
|
||||
void PostVisit(Create &) override;
|
||||
void PostVisit(Return &) override;
|
||||
void Visit(With &) override;
|
||||
void PostVisit(With &) override;
|
||||
void Visit(Where &) override;
|
||||
|
||||
// Expressions
|
||||
void Visit(Identifier &ident) override;
|
||||
void Visit(Identifier &) override;
|
||||
void Visit(Aggregation &) override;
|
||||
void PostVisit(Aggregation &) override;
|
||||
|
||||
// Pattern and its subparts.
|
||||
void Visit(Pattern &pattern) override;
|
||||
void PostVisit(Pattern &pattern) override;
|
||||
void Visit(NodeAtom &node_atom) override;
|
||||
void PostVisit(NodeAtom &node_atom) override;
|
||||
void Visit(EdgeAtom &edge_atom) override;
|
||||
void PostVisit(EdgeAtom &edge_atom) override;
|
||||
void Visit(Pattern &) override;
|
||||
void PostVisit(Pattern &) override;
|
||||
void Visit(NodeAtom &) override;
|
||||
void PostVisit(NodeAtom &) override;
|
||||
void Visit(EdgeAtom &) override;
|
||||
void PostVisit(EdgeAtom &) override;
|
||||
|
||||
private:
|
||||
// Scope stores the state of where we are when visiting the AST and a map of
|
||||
@ -56,6 +58,7 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
bool in_node_atom{false};
|
||||
bool in_edge_atom{false};
|
||||
bool in_property_map{false};
|
||||
bool in_aggregation{false};
|
||||
// Pointer to With clause if we are inside it, otherwise nullptr.
|
||||
With *with{nullptr};
|
||||
std::map<std::string, Symbol> symbols;
|
||||
|
@ -270,3 +270,5 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
|
||||
// Various operators
|
||||
#define ADD(expr1, expr2) storage.Create<query::AdditionOperator>((expr1), (expr2))
|
||||
#define LESS(expr1, expr2) storage.Create<query::LessOperator>((expr1), (expr2))
|
||||
#define SUM(expr) \
|
||||
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::SUM)
|
||||
|
@ -281,3 +281,17 @@ TEST(ExpressionEvaluator, Function) {
|
||||
ASSERT_THROW(op->Accept(eval.eval), QueryRuntimeException);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, Aggregation) {
|
||||
AstTreeStorage storage;
|
||||
auto aggr = storage.Create<Aggregation>(storage.Create<Literal>(42),
|
||||
Aggregation::Op::COUNT);
|
||||
SymbolTable symbol_table;
|
||||
auto aggr_sym = symbol_table.CreateSymbol("aggr");
|
||||
symbol_table[*aggr] = aggr_sym;
|
||||
Frame frame{symbol_table.max_position()};
|
||||
frame[aggr_sym] = TypedValue(1);
|
||||
ExpressionEvaluator eval{frame, symbol_table};
|
||||
aggr->Accept(eval);
|
||||
EXPECT_EQ(eval.PopBack().Value<int64_t>(), 1);
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <list>
|
||||
#include <typeinfo>
|
||||
#include <tuple>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
@ -20,69 +21,137 @@ using Direction = query::EdgeAtom::Direction;
|
||||
|
||||
namespace {
|
||||
|
||||
class BaseOpChecker {
|
||||
public:
|
||||
virtual ~BaseOpChecker() {}
|
||||
|
||||
virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0;
|
||||
};
|
||||
|
||||
template <class TOp>
|
||||
class OpChecker : public BaseOpChecker {
|
||||
public:
|
||||
void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override {
|
||||
auto *expected_op = dynamic_cast<TOp *>(&op);
|
||||
ASSERT_TRUE(expected_op);
|
||||
ExpectOp(*expected_op, symbol_table);
|
||||
}
|
||||
|
||||
virtual void ExpectOp(TOp &op, const SymbolTable &) {}
|
||||
};
|
||||
|
||||
using ExpectCreateNode = OpChecker<CreateNode>;
|
||||
using ExpectCreateExpand = OpChecker<CreateExpand>;
|
||||
using ExpectDelete = OpChecker<Delete>;
|
||||
using ExpectScanAll = OpChecker<ScanAll>;
|
||||
using ExpectExpand = OpChecker<Expand>;
|
||||
using ExpectNodeFilter = OpChecker<NodeFilter>;
|
||||
using ExpectEdgeFilter = OpChecker<EdgeFilter>;
|
||||
using ExpectFilter = OpChecker<Filter>;
|
||||
using ExpectProduce = OpChecker<Produce>;
|
||||
using ExpectSetProperty = OpChecker<SetProperty>;
|
||||
using ExpectSetProperties = OpChecker<SetProperties>;
|
||||
using ExpectSetLabels = OpChecker<SetLabels>;
|
||||
using ExpectRemoveProperty = OpChecker<RemoveProperty>;
|
||||
using ExpectRemoveLabels = OpChecker<RemoveLabels>;
|
||||
template <class TAccessor>
|
||||
using ExpectExpandUniquenessFilter =
|
||||
OpChecker<ExpandUniquenessFilter<TAccessor>>;
|
||||
using ExpectAccumulate = OpChecker<Accumulate>;
|
||||
|
||||
class ExpectAggregate : public OpChecker<Aggregate> {
|
||||
public:
|
||||
ExpectAggregate() = default;
|
||||
ExpectAggregate(const std::vector<query::Aggregation *> &aggregations,
|
||||
const std::unordered_set<query::Expression *> &group_by)
|
||||
: aggregations_(aggregations), group_by_(group_by) {}
|
||||
|
||||
void ExpectOp(Aggregate &op, const SymbolTable &symbol_table) override {
|
||||
auto aggr_it = aggregations_.begin();
|
||||
for (const auto &aggr_elem : op.aggregations()) {
|
||||
ASSERT_NE(aggr_it, aggregations_.end());
|
||||
auto aggr = *aggr_it++;
|
||||
auto expected =
|
||||
std::make_tuple(aggr->expression_, aggr->op_, symbol_table.at(*aggr));
|
||||
EXPECT_EQ(expected, aggr_elem);
|
||||
}
|
||||
EXPECT_EQ(aggr_it, aggregations_.end());
|
||||
auto got_group_by = std::unordered_set<query::Expression *>(
|
||||
op.group_by().begin(), op.group_by().end());
|
||||
EXPECT_EQ(group_by_, got_group_by);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<query::Aggregation *> aggregations_;
|
||||
const std::unordered_set<query::Expression *> group_by_;
|
||||
};
|
||||
|
||||
class PlanChecker : public LogicalOperatorVisitor {
|
||||
public:
|
||||
using LogicalOperatorVisitor::Visit;
|
||||
using LogicalOperatorVisitor::PostVisit;
|
||||
|
||||
PlanChecker(const std::list<size_t> &types) : types_(types) {}
|
||||
PlanChecker(const std::list<BaseOpChecker *> &checkers,
|
||||
const SymbolTable &symbol_table)
|
||||
: checkers_(checkers), symbol_table_(symbol_table) {}
|
||||
|
||||
void Visit(CreateNode &op) override { AssertType(op); }
|
||||
void Visit(CreateExpand &op) override { AssertType(op); }
|
||||
void Visit(Delete &op) override { AssertType(op); }
|
||||
void Visit(ScanAll &op) override { AssertType(op); }
|
||||
void Visit(Expand &op) override { AssertType(op); }
|
||||
void Visit(NodeFilter &op) override { AssertType(op); }
|
||||
void Visit(EdgeFilter &op) override { AssertType(op); }
|
||||
void Visit(Filter &op) override { AssertType(op); }
|
||||
void Visit(Produce &op) override { AssertType(op); }
|
||||
void Visit(SetProperty &op) override { AssertType(op); }
|
||||
void Visit(SetProperties &op) override { AssertType(op); }
|
||||
void Visit(SetLabels &op) override { AssertType(op); }
|
||||
void Visit(RemoveProperty &op) override { AssertType(op); }
|
||||
void Visit(RemoveLabels &op) override { AssertType(op); }
|
||||
void Visit(CreateNode &op) override { CheckOp(op); }
|
||||
void Visit(CreateExpand &op) override { CheckOp(op); }
|
||||
void Visit(Delete &op) override { CheckOp(op); }
|
||||
void Visit(ScanAll &op) override { CheckOp(op); }
|
||||
void Visit(Expand &op) override { CheckOp(op); }
|
||||
void Visit(NodeFilter &op) override { CheckOp(op); }
|
||||
void Visit(EdgeFilter &op) override { CheckOp(op); }
|
||||
void Visit(Filter &op) override { CheckOp(op); }
|
||||
void Visit(Produce &op) override { CheckOp(op); }
|
||||
void Visit(SetProperty &op) override { CheckOp(op); }
|
||||
void Visit(SetProperties &op) override { CheckOp(op); }
|
||||
void Visit(SetLabels &op) override { CheckOp(op); }
|
||||
void Visit(RemoveProperty &op) override { CheckOp(op); }
|
||||
void Visit(RemoveLabels &op) override { CheckOp(op); }
|
||||
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
|
||||
AssertType(op);
|
||||
CheckOp(op);
|
||||
}
|
||||
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override {
|
||||
AssertType(op);
|
||||
}
|
||||
void Visit(Accumulate &op) override { AssertType(op); }
|
||||
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
|
||||
void Visit(Accumulate &op) override { CheckOp(op); }
|
||||
void Visit(Aggregate &op) override { CheckOp(op); }
|
||||
|
||||
std::list<size_t> types_;
|
||||
std::list<BaseOpChecker *> checkers_;
|
||||
|
||||
private:
|
||||
void AssertType(const LogicalOperator &op) {
|
||||
ASSERT_FALSE(types_.empty());
|
||||
ASSERT_EQ(types_.back(), typeid(op).hash_code());
|
||||
types_.pop_back();
|
||||
void CheckOp(LogicalOperator &op) {
|
||||
ASSERT_FALSE(checkers_.empty());
|
||||
checkers_.back()->CheckOp(op, symbol_table_);
|
||||
checkers_.pop_back();
|
||||
}
|
||||
|
||||
const SymbolTable &symbol_table_;
|
||||
};
|
||||
|
||||
template <class... TOps>
|
||||
auto CheckPlan(query::Query &query) {
|
||||
template <class... TChecker>
|
||||
auto CheckPlan(query::Query &query, TChecker... checker) {
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query.Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(query, symbol_table);
|
||||
std::list<size_t> type_hashes{typeid(TOps).hash_code()...};
|
||||
PlanChecker plan_checker(type_hashes);
|
||||
std::list<BaseOpChecker *> checkers{&checker...};
|
||||
PlanChecker plan_checker(checkers, symbol_table);
|
||||
plan->Accept(plan_checker);
|
||||
EXPECT_TRUE(plan_checker.types_.empty());
|
||||
EXPECT_TRUE(plan_checker.checkers_.empty());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchNodeReturn) {
|
||||
// Test MATCH (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan<ScanAll, Produce>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateNodeReturn) {
|
||||
// Test CREATE (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan<CreateNode, Produce>(*query);
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateExpand) {
|
||||
@ -93,14 +162,14 @@ TEST(TestLogicalPlanner, CreateExpand) {
|
||||
auto relationship = dba->edge_type("relationship");
|
||||
auto query = QUERY(CREATE(PATTERN(
|
||||
NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m"))));
|
||||
CheckPlan<CreateNode, CreateExpand>(*query);
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateMultipleNode) {
|
||||
// Test CREATE (n), (m)
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m"))));
|
||||
CheckPlan<CreateNode, CreateNode>(*query);
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectCreateNode());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateNodeExpandNode) {
|
||||
@ -112,7 +181,8 @@ TEST(TestLogicalPlanner, CreateNodeExpandNode) {
|
||||
auto query = QUERY(CREATE(
|
||||
PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m")),
|
||||
PATTERN(NODE("l"))));
|
||||
CheckPlan<CreateNode, CreateExpand, CreateNode>(*query);
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(),
|
||||
ExpectCreateNode());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchCreateExpand) {
|
||||
@ -125,7 +195,7 @@ TEST(TestLogicalPlanner, MatchCreateExpand) {
|
||||
QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
CREATE(PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT),
|
||||
NODE("m"))));
|
||||
CheckPlan<ScanAll, CreateExpand>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectCreateExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchLabeledNodes) {
|
||||
@ -136,7 +206,7 @@ TEST(TestLogicalPlanner, MatchLabeledNodes) {
|
||||
auto label = dba->label("label");
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan<ScanAll, NodeFilter, Produce>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectNodeFilter(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchPathReturn) {
|
||||
@ -148,7 +218,8 @@ TEST(TestLogicalPlanner, MatchPathReturn) {
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("n"), EDGE("r", relationship), NODE("m"))),
|
||||
RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan<ScanAll, Expand, EdgeFilter, Produce>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectEdgeFilter(),
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWhereReturn) {
|
||||
@ -160,14 +231,14 @@ TEST(TestLogicalPlanner, MatchWhereReturn) {
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))),
|
||||
RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan<ScanAll, Filter, Produce>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectFilter(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchDelete) {
|
||||
// Test MATCH (n) DELETE n
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")));
|
||||
CheckPlan<ScanAll, Delete>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectDelete());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchNodeSet) {
|
||||
@ -180,7 +251,8 @@ TEST(TestLogicalPlanner, MatchNodeSet) {
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)),
|
||||
SET("n", IDENT("n")), SET("n", {label}));
|
||||
CheckPlan<ScanAll, SetProperty, SetProperties, SetLabels>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(),
|
||||
ExpectSetLabels());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchRemove) {
|
||||
@ -192,7 +264,8 @@ TEST(TestLogicalPlanner, MatchRemove) {
|
||||
auto label = dba->label("label");
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label}));
|
||||
CheckPlan<ScanAll, RemoveProperty, RemoveLabels>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectRemoveProperty(),
|
||||
ExpectRemoveLabels());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchMultiPattern) {
|
||||
@ -202,8 +275,8 @@ TEST(TestLogicalPlanner, MatchMultiPattern) {
|
||||
PATTERN(NODE("j"), EDGE("e"), NODE("i"))));
|
||||
// We expect the expansions after the first to have a uniqueness filter in a
|
||||
// single MATCH clause.
|
||||
CheckPlan<ScanAll, Expand, ScanAll, Expand,
|
||||
ExpandUniquenessFilter<EdgeAccessor>>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(),
|
||||
ExpectExpand(), ExpectExpandUniquenessFilter<EdgeAccessor>());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchMultiPatternSameStart) {
|
||||
@ -213,7 +286,7 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameStart) {
|
||||
MATCH(PATTERN(NODE("n")), PATTERN(NODE("n"), EDGE("e"), NODE("m"))));
|
||||
// We expect the second pattern to generate only an Expand, since another
|
||||
// ScanAll would be redundant.
|
||||
CheckPlan<ScanAll, Expand>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) {
|
||||
@ -224,8 +297,8 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) {
|
||||
// We expect the second pattern to generate only an Expand. Another
|
||||
// ScanAll would be redundant, as it would generate the nodes obtained from
|
||||
// expansion. Additionally, a uniqueness filter is expected.
|
||||
CheckPlan<ScanAll, Expand, Expand, ExpandUniquenessFilter<EdgeAccessor>>(
|
||||
*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(),
|
||||
ExpectExpandUniquenessFilter<EdgeAccessor>());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MultiMatch) {
|
||||
@ -236,8 +309,9 @@ TEST(TestLogicalPlanner, MultiMatch) {
|
||||
MATCH(PATTERN(NODE("j"), EDGE("e"), NODE("i"), EDGE("f"), NODE("h"))));
|
||||
// Multiple MATCH clauses form a Cartesian product, so the uniqueness should
|
||||
// not cross MATCH boundaries.
|
||||
CheckPlan<ScanAll, Expand, ScanAll, Expand, Expand,
|
||||
ExpandUniquenessFilter<EdgeAccessor>>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(),
|
||||
ExpectExpand(), ExpectExpand(),
|
||||
ExpectExpandUniquenessFilter<EdgeAccessor>());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MultiMatchSameStart) {
|
||||
@ -247,7 +321,7 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) {
|
||||
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))));
|
||||
// Similar to MatchMultiPatternSameStart, we expect only Expand from second
|
||||
// MATCH clause.
|
||||
CheckPlan<ScanAll, Expand>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchEdgeCycle) {
|
||||
@ -256,7 +330,7 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
|
||||
auto query = QUERY(
|
||||
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"), EDGE("r"), NODE("j"))));
|
||||
// There is no ExpandUniquenessFilter for referencing the same edge.
|
||||
CheckPlan<ScanAll, Expand, Expand>(*query);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithReturn) {
|
||||
@ -264,7 +338,8 @@ TEST(TestLogicalPlanner, MatchWithReturn) {
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
|
||||
RETURN(IDENT("new"), AS("new")));
|
||||
CheckPlan<ScanAll, Accumulate, Produce, Produce>(*query);
|
||||
// No accumulation since we only do reads.
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithWhereReturn) {
|
||||
@ -276,7 +351,9 @@ TEST(TestLogicalPlanner, MatchWithWhereReturn) {
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))),
|
||||
RETURN(IDENT("new"), AS("new")));
|
||||
CheckPlan<ScanAll, Accumulate, Produce, Filter, Produce>(*query);
|
||||
// No accumulation since we only do reads.
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectFilter(),
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateMultiExpand) {
|
||||
@ -289,7 +366,56 @@ TEST(TestLogicalPlanner, CreateMultiExpand) {
|
||||
auto query = QUERY(
|
||||
CREATE(PATTERN(NODE("n"), EDGE("r", r, Direction::RIGHT), NODE("m")),
|
||||
PATTERN(NODE("n"), EDGE("p", p, Direction::RIGHT), NODE("l"))));
|
||||
CheckPlan<CreateNode, CreateExpand, CreateExpand>(*query);
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(),
|
||||
ExpectCreateExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithSumWhereReturn) {
|
||||
// Test MATCH (n) WITH SUM(n.prop) + 42 AS sum WHERE sum < 42
|
||||
// RETURN sum AS result
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto literal = LITERAL(42);
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")),
|
||||
WHERE(LESS(IDENT("sum"), LITERAL(42))),
|
||||
RETURN(IDENT("sum"), AS("result")));
|
||||
auto aggr = ExpectAggregate({sum}, {literal});
|
||||
CheckPlan(*query, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(),
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchReturnSum) {
|
||||
// Test MATCH (n) RETURN SUM(n.prop1) AS sum, n.prop2 AS group
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop1 = dba->property("prop1");
|
||||
auto prop2 = dba->property("prop2");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop1));
|
||||
auto n_prop2 = PROPERTY_LOOKUP("n", prop2);
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
RETURN(sum, AS("sum"), n_prop2, AS("group")));
|
||||
auto aggr = ExpectAggregate({sum}, {n_prop2});
|
||||
CheckPlan(*query, ExpectScanAll(), aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateWithSum) {
|
||||
// Test CREATE (n) WITH SUM(n.prop) AS sum
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")));
|
||||
auto aggr = ExpectAggregate({sum}, {});
|
||||
// We expect both the accumulation and aggregation because the part before
|
||||
// WITH updates the database.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr,
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -358,6 +358,7 @@ TEST(TestSymbolGenerator, CreateMultiExpand) {
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
EXPECT_EQ(symbol_table.max_position(), 5);
|
||||
auto n1 = symbol_table.at(*node_n1->identifier_);
|
||||
auto n2 = symbol_table.at(*node_n2->identifier_);
|
||||
EXPECT_EQ(n1, n2);
|
||||
@ -408,4 +409,42 @@ TEST(TestSymbolGenerator, CreateExpandProperty) {
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchReturnSum) {
|
||||
// Test MATCH (n) RETURN SUM(n.prop) + 42 AS result
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto node = NODE("n");
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto as_result = AS("result");
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(node)), RETURN(ADD(sum, LITERAL(42)), as_result));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
// 3 symbols for: 'n', 'sum' and 'result'.
|
||||
EXPECT_EQ(symbol_table.max_position(), 3);
|
||||
auto node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto sum_symbol = symbol_table.at(*sum);
|
||||
EXPECT_NE(node_symbol, sum_symbol);
|
||||
auto result_symbol = symbol_table.at(*as_result);
|
||||
EXPECT_NE(result_symbol, node_symbol);
|
||||
EXPECT_NE(result_symbol, sum_symbol);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, NestedAggregation) {
|
||||
// Test MATCH (n) RETURN SUM(42 + SUM(n.prop)) AS s
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(
|
||||
MATCH(PATTERN(NODE("n"))),
|
||||
RETURN(SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop)))), AS("s")));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user