Plan aggregation in WITH and RETURN clauses

Summary:
Generate symbols for aggregation results.
Plan aggregation in WITH clause.
Plan aggregation in RETURN clause.
Extract handling write clauses to a function.

Reviewers: mislav.bradac, florijan

Reviewed By: mislav.bradac, florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D236
This commit is contained in:
Teon Banek 2017-04-12 12:58:10 +02:00
parent 6cb1cdc607
commit 355b9a9b9a
9 changed files with 452 additions and 128 deletions

View File

@ -61,6 +61,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
return last;
}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
@ -147,12 +148,14 @@ class ExpressionEvaluator : public TreeVisitorBase {
result_stack_.push_back(literal.value_);
}
void Visit(Aggregation &aggregation) override {
bool PreVisit(Aggregation &aggregation) override {
auto value = frame_[symbol_table_.at(aggregation)];
// Aggregation is probably always simple type, but let's switch accessor
// just to be sure.
SwitchAccessors(value);
result_stack_.emplace_back(std::move(value));
// Prevent evaluation of expressions inside the aggregation.
return false;
}
void PostVisit(Function &function) override {

View File

@ -833,6 +833,9 @@ class Aggregate : public LogicalOperator {
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
// Read-only access to the aggregation elements held by this operator.
const auto &aggregations() const { return aggregations_; }
// Read-only access to the group by expressions held by this operator.
const auto &group_by() const { return group_by_; }
private:
const std::shared_ptr<LogicalOperator> input_;
const std::vector<Element> aggregations_;

View File

@ -60,7 +60,7 @@ auto ReducePattern(
}
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
const query::SymbolTable &symbol_table,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
auto base = [&](NodeAtom *node) -> LogicalOperator * {
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
@ -91,7 +91,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
}
auto GenCreate(Create &create, LogicalOperator *input_op,
const query::SymbolTable &symbol_table,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
auto last_op = input_op;
for (auto pattern : create.patterns_) {
@ -102,7 +102,7 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
}
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
const query::SymbolTable &symbol_table,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols,
std::vector<Symbol> &edge_symbols) {
auto base = [&](NodeAtom *node) {
@ -162,7 +162,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
}
auto GenMatch(Match &match, LogicalOperator *input_op,
const query::SymbolTable &symbol_table,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
auto last_op = input_op;
std::vector<Symbol> edge_symbols;
@ -177,20 +177,87 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
return last_op;
}
// Ast tree visitor which collects all the symbols referenced by identifiers.
class SymbolCollector : public TreeVisitorBase {
// Ast tree visitor which collects the context for a return body. The return
// body consists of the named expressions found in WITH and RETURN clauses. The
// collected context consists of used symbols, aggregations and the expressions
// used for grouping.
class ReturnBodyContext : public TreeVisitorBase {
public:
SymbolCollector(const SymbolTable &symbol_table)
ReturnBodyContext(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit;
void Visit(Literal &) override { has_aggregation_.emplace_back(false); }
void Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
has_aggregation_.emplace_back(false);
}
// Generates the PostVisit override for a binary operator. It expects that
// visiting each of the two operands left exactly one flag on the
// has_aggregation_ stack; the two flags are popped and replaced by a single
// flag describing the whole binary expression.
//
// When exactly one operand contains an aggregation, the other operand is
// recorded as a group by expression. When both operands contain an
// aggregation (e.g. `sum(a) + sum(b)`), nothing is recorded, because neither
// side is a valid grouping expression.
#define VISIT_BINARY_OPERATOR(BinaryOperator)                                \
  void PostVisit(BinaryOperator &op) override {                              \
    /* has_aggregation_ stack is reversed, last result is from the 2nd      \
     * expression. */                                                        \
    bool aggr2 = has_aggregation_.back();                                    \
    has_aggregation_.pop_back();                                             \
    bool aggr1 = has_aggregation_.back();                                    \
    has_aggregation_.pop_back();                                             \
    if (aggr1 != aggr2) {                                                    \
      /* Group by the expression which does not contain aggregation. */     \
      /* Possible optimization is to ignore constant value expressions */   \
      group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_);     \
    }                                                                        \
    /* Propagate that this whole expression may contain an aggregation. */  \
    has_aggregation_.emplace_back(aggr1 || aggr2);                           \
  }
VISIT_BINARY_OPERATOR(OrOperator)
VISIT_BINARY_OPERATOR(XorOperator)
VISIT_BINARY_OPERATOR(AndOperator)
VISIT_BINARY_OPERATOR(AdditionOperator)
VISIT_BINARY_OPERATOR(SubtractionOperator)
VISIT_BINARY_OPERATOR(MultiplicationOperator)
VISIT_BINARY_OPERATOR(DivisionOperator)
VISIT_BINARY_OPERATOR(ModOperator)
VISIT_BINARY_OPERATOR(NotEqualOperator)
VISIT_BINARY_OPERATOR(EqualOperator)
VISIT_BINARY_OPERATOR(LessOperator)
VISIT_BINARY_OPERATOR(GreaterOperator)
VISIT_BINARY_OPERATOR(LessEqualOperator)
VISIT_BINARY_OPERATOR(GreaterEqualOperator)
#undef VISIT_BINARY_OPERATOR
void PostVisit(Aggregation &aggr) override {
  // The symbol table maps the aggregation to a virtual symbol, under which
  // the aggregation result will be stored. Remember it together with the
  // aggregated expression and the aggregation operation.
  aggregations_.emplace_back(aggr.expression_, aggr.op_,
                             symbol_table_.at(aggr));
  // Overwrite the flag on top of the stack (left by the aggregated
  // expression) to mark that this expression contains an aggregation.
  has_aggregation_.back() = true;
  // Possible optimization is to skip remembering symbols inside aggregation.
  // If and when implementing this, don't forget that Accumulate needs *all*
  // the symbols, including those inside aggregation.
}
void PostVisit(NamedExpression &named_expr) override {
  // Consume the flag describing the expression of this named expression.
  const bool contains_aggregation = has_aggregation_.back();
  has_aggregation_.pop_back();
  if (contains_aggregation) return;
  // A named expression without any aggregation is used for grouping.
  group_by_.emplace_back(named_expr.expression_);
}
// Set of symbols used inside the visited expressions, including the symbols
// used inside aggregation expressions.
const auto &symbols() const { return symbols_; }
// List of aggregation elements found in expressions.
const auto &aggregations() const { return aggregations_; }
// When there is at least one aggregation element, all the non-aggregate (sub)
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
const auto &group_by() const { return group_by_; }
private:
// Calculates the Symbol hash based on its position.
@ -202,27 +269,55 @@ class SymbolCollector : public TreeVisitorBase {
const SymbolTable &symbol_table_;
std::unordered_set<Symbol, SymbolHash> symbols_;
std::vector<Aggregate::Element> aggregations_;
std::vector<Expression *> group_by_;
// Flag indicating whether an expression contains an aggregation.
std::list<bool> has_aggregation_;
};
// Plans the body shared by WITH and RETURN clauses. On top of `input_op` it
// stacks an optional Accumulate (when `accumulate` is set), an optional
// Aggregate (when any named expression contains an aggregation) and finally a
// Produce of `named_expressions`, which is returned.
auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
                   const std::vector<NamedExpression *> &named_expressions,
                   const SymbolTable &symbol_table, bool accumulate = false) {
  // Visit all named expressions to collect the used symbols, aggregations and
  // group by expressions.
  ReturnBodyContext body_context(symbol_table);
  for (auto *expression : named_expressions) {
    expression->Accept(body_context);
  }
  const auto &collected = body_context.symbols();
  std::vector<Symbol> used_symbols(collected.begin(), collected.end());
  LogicalOperator *current = input_op;
  if (accumulate) {
    // We only advance the command in Accumulate. This is done for WITH clause,
    // when the first part updated the database. RETURN clause may only need an
    // accumulation after updates, without advancing the command.
    current = new Accumulate(std::shared_ptr<LogicalOperator>(current),
                             used_symbols, advance_command);
  }
  const auto &aggregations = body_context.aggregations();
  if (!aggregations.empty()) {
    current = new Aggregate(std::shared_ptr<LogicalOperator>(current),
                            aggregations, body_context.group_by(),
                            used_symbols);
  }
  return new Produce(std::shared_ptr<LogicalOperator>(current),
                     named_expressions);
}
auto GenWith(With &with, LogicalOperator *input_op,
const query::SymbolTable &symbol_table) {
const SymbolTable &symbol_table, bool is_write) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter.
if (with.distinct_) {
// TODO: Plan disctint with, when operator available.
// TODO: Plan distinct with, when operator available.
throw NotYetImplemented();
}
// WITH clause is Accumulate/Aggregate (advance_command) + Produce.
SymbolCollector symbol_collector(symbol_table);
// Collect used symbols so that accumulate doesn't copy the whole frame.
for (auto &named_expr : with.named_expressions_) {
named_expr->expression_->Accept(symbol_collector);
}
auto symbols = symbol_collector.symbols();
// TODO: Check whether we need aggregate instead of accumulate.
// In case of update and aggregation, we want to accumulate first, so that
// when aggregating, we get the latest results. Similar to RETURN clause.
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
LogicalOperator *last_op =
new Accumulate(std::shared_ptr<LogicalOperator>(input_op),
std::vector<Symbol>(symbols.begin(), symbols.end()), true);
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
with.named_expressions_);
GenReturnBody(input_op, advance_command, with.named_expressions_,
symbol_table, accumulate);
if (with.where_) {
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
with.where_->expression_);
@ -230,55 +325,80 @@ auto GenWith(With &with, LogicalOperator *input_op,
return last_op;
}
// Plans the RETURN clause by delegating to GenReturnBody. `is_write` tells
// whether a clause preceding this RETURN wrote to the database.
auto GenReturn(Return &ret, LogicalOperator *input_op,
const SymbolTable &symbol_table, bool is_write) {
// Similar to WITH clause, but here we only accumulate when the query writes
// to the database; the command is not advanced. This way we handle the case
// when we want to return expressions with the latest updated results. For
// example, `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If the
// same `n` is matched 'k' times, we want to return 'k' results where the
// property value is the same, final result of 'k' increments.
bool accumulate = is_write;
// RETURN never requests advancing the command.
bool advance_command = false;
return GenReturnBody(input_op, advance_command, ret.named_expressions_,
symbol_table, accumulate);
}
// Plans a single clause which writes to the database (CREATE, DELETE, the SET
// variants and the REMOVE variants) on top of `input_op`. Returns the newly
// created operator, or nullptr when `clause` is none of the handled clauses.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
                                   const SymbolTable &symbol_table,
                                   std::unordered_set<int> &bound_symbols) {
  // Wraps the input operator for the operator being created. This must only
  // be invoked by a branch which handles the clause, so that an unhandled
  // clause leaves `input_op` untouched.
  auto wrap_input = [input_op]() {
    return std::shared_ptr<LogicalOperator>(input_op);
  };
  if (auto *create = dynamic_cast<Create *>(clause)) {
    return GenCreate(*create, input_op, symbol_table, bound_symbols);
  }
  if (auto *del = dynamic_cast<query::Delete *>(clause)) {
    return new plan::Delete(wrap_input(), del->expressions_, del->detach_);
  }
  if (auto *set_prop = dynamic_cast<query::SetProperty *>(clause)) {
    return new plan::SetProperty(wrap_input(), set_prop->property_lookup_,
                                 set_prop->expression_);
  }
  if (auto *set_props = dynamic_cast<query::SetProperties *>(clause)) {
    auto op = plan::SetProperties::Op::REPLACE;
    if (set_props->update_) {
      op = plan::SetProperties::Op::UPDATE;
    }
    const auto &input_symbol = symbol_table.at(*set_props->identifier_);
    return new plan::SetProperties(wrap_input(), input_symbol,
                                   set_props->expression_, op);
  }
  if (auto *set_labels = dynamic_cast<query::SetLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*set_labels->identifier_);
    return new plan::SetLabels(wrap_input(), input_symbol,
                               set_labels->labels_);
  }
  if (auto *rem_prop = dynamic_cast<query::RemoveProperty *>(clause)) {
    return new plan::RemoveProperty(wrap_input(), rem_prop->property_lookup_);
  }
  if (auto *rem_labels = dynamic_cast<query::RemoveLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*rem_labels->identifier_);
    return new plan::RemoveLabels(wrap_input(), input_symbol,
                                  rem_labels->labels_);
  }
  // Not a write clause handled here.
  return nullptr;
}
} // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
query::Query &query, const query::SymbolTable &symbol_table) {
// TODO: Extract functions and state into a class with methods. Possibly a
// visitor or similar to avoid all those dynamic casts.
LogicalOperator *input_op = nullptr;
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<int> bound_symbols;
// Set to true if a query command writes to the database.
bool is_write = false;
LogicalOperator *input_op = nullptr;
for (auto &clause : query.clauses_) {
auto *clause_ptr = clause;
if (auto *match = dynamic_cast<Match *>(clause_ptr)) {
// Clauses which read from the database.
if (auto *match = dynamic_cast<Match *>(clause)) {
input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
} else if (auto *ret = dynamic_cast<Return *>(clause_ptr)) {
input_op = new Produce(std::shared_ptr<LogicalOperator>(input_op),
ret->named_expressions_);
} else if (auto *create = dynamic_cast<Create *>(clause_ptr)) {
input_op = GenCreate(*create, input_op, symbol_table, bound_symbols);
} else if (auto *del = dynamic_cast<query::Delete *>(clause_ptr)) {
input_op = new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
del->expressions_, del->detach_);
} else if (auto *set = dynamic_cast<query::SetProperty *>(clause_ptr)) {
input_op =
new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
set->property_lookup_, set->expression_);
} else if (auto *set = dynamic_cast<query::SetProperties *>(clause_ptr)) {
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
: plan::SetProperties::Op::REPLACE;
const auto &input_symbol = symbol_table.at(*set->identifier_);
input_op =
new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->expression_, op);
} else if (auto *set = dynamic_cast<query::SetLabels *>(clause_ptr)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
input_op = new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->labels_);
} else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause_ptr)) {
input_op = new plan::RemoveProperty(
std::shared_ptr<LogicalOperator>(input_op), rem->property_lookup_);
} else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause_ptr)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
input_op =
new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, rem->labels_);
} else if (auto *with = dynamic_cast<query::With *>(clause_ptr)) {
input_op = GenWith(*with, input_op, symbol_table);
} else if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, symbol_table, is_write);
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op = GenWith(*with, input_op, symbol_table, is_write);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,
bound_symbols)) {
is_write = true;
input_op = op;
} else {
throw NotYetImplemented();
}

View File

@ -108,6 +108,20 @@ void SymbolGenerator::Visit(Identifier &ident) {
symbol_table_[ident] = symbol;
}
void SymbolGenerator::Visit(Aggregation &aggr) {
  // The aggregation result is stored under its own unnamed (virtual) symbol.
  symbol_table_[aggr] = symbol_table_.CreateSymbol("");
  // Entering an aggregation while already inside one means the aggregations
  // are nested, which is rejected.
  const bool already_in_aggregation = scope_.in_aggregation;
  if (already_in_aggregation) {
    throw SemanticException(
        "Using aggregate functions inside aggregate functions is not allowed");
  }
  scope_.in_aggregation = true;
}
void SymbolGenerator::PostVisit(Aggregation &aggr) {
scope_.in_aggregation = false;
}
// Pattern and its subparts.
void SymbolGenerator::Visit(Pattern &pattern) {

View File

@ -24,23 +24,25 @@ class SymbolGenerator : public TreeVisitorBase {
using TreeVisitorBase::PostVisit;
// Clauses
void Visit(Create &create) override;
void PostVisit(Create &create) override;
void PostVisit(Return &ret) override;
void Visit(With &with) override;
void PostVisit(With &with) override;
void Visit(Where &where) override;
void Visit(Create &) override;
void PostVisit(Create &) override;
void PostVisit(Return &) override;
void Visit(With &) override;
void PostVisit(With &) override;
void Visit(Where &) override;
// Expressions
void Visit(Identifier &ident) override;
void Visit(Identifier &) override;
void Visit(Aggregation &) override;
void PostVisit(Aggregation &) override;
// Pattern and its subparts.
void Visit(Pattern &pattern) override;
void PostVisit(Pattern &pattern) override;
void Visit(NodeAtom &node_atom) override;
void PostVisit(NodeAtom &node_atom) override;
void Visit(EdgeAtom &edge_atom) override;
void PostVisit(EdgeAtom &edge_atom) override;
void Visit(Pattern &) override;
void PostVisit(Pattern &) override;
void Visit(NodeAtom &) override;
void PostVisit(NodeAtom &) override;
void Visit(EdgeAtom &) override;
void PostVisit(EdgeAtom &) override;
private:
// Scope stores the state of where we are when visiting the AST and a map of
@ -56,6 +58,7 @@ class SymbolGenerator : public TreeVisitorBase {
bool in_node_atom{false};
bool in_edge_atom{false};
bool in_property_map{false};
bool in_aggregation{false};
// Pointer to With clause if we are inside it, otherwise nullptr.
With *with{nullptr};
std::map<std::string, Symbol> symbols;

View File

@ -270,3 +270,5 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
// Various operators
#define ADD(expr1, expr2) storage.Create<query::AdditionOperator>((expr1), (expr2))
#define LESS(expr1, expr2) storage.Create<query::LessOperator>((expr1), (expr2))
// Creates a query::Aggregation with Op::SUM over `expr`. Like the macros
// above, it expands to a call on a variable named `storage`, which must be in
// scope at the point of use.
#define SUM(expr) \
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::SUM)

View File

@ -281,3 +281,17 @@ TEST(ExpressionEvaluator, Function) {
ASSERT_THROW(op->Accept(eval.eval), QueryRuntimeException);
}
}
TEST(ExpressionEvaluator, Aggregation) {
  // Evaluating an aggregation must yield the value stored in the frame under
  // the aggregation's symbol.
  AstTreeStorage storage;
  auto counted = storage.Create<Literal>(42);
  auto count = storage.Create<Aggregation>(counted, Aggregation::Op::COUNT);
  // Map the aggregation to a symbol and store a value for it on the frame.
  SymbolTable symbol_table;
  auto count_symbol = symbol_table.CreateSymbol("aggr");
  symbol_table[*count] = count_symbol;
  Frame frame{symbol_table.max_position()};
  frame[count_symbol] = TypedValue(1);
  ExpressionEvaluator eval{frame, symbol_table};
  count->Accept(eval);
  EXPECT_EQ(eval.PopBack().Value<int64_t>(), 1);
}

View File

@ -1,5 +1,6 @@
#include <list>
#include <typeinfo>
#include <tuple>
#include <unordered_set>
#include "gtest/gtest.h"
@ -20,69 +21,137 @@ using Direction = query::EdgeAtom::Direction;
namespace {
// Interface for checking a single logical operator of a plan.
class BaseOpChecker {
 public:
  virtual ~BaseOpChecker() = default;

  // Checks the given operator; the symbol table of the planned query is
  // provided for checkers which need to look up symbols.
  virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0;
};
// Checker which asserts that the operator is of type TOp and then forwards
// the downcast operator to ExpectOp for additional, type specific checks.
template <class TOp>
class OpChecker : public BaseOpChecker {
 public:
  void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override {
    // dynamic_cast yields nullptr when `op` is not a TOp.
    TOp *expected_op = dynamic_cast<TOp *>(&op);
    ASSERT_TRUE(expected_op);
    ExpectOp(*expected_op, symbol_table);
  }

  // Hook for subclasses; by default nothing besides the type is checked.
  virtual void ExpectOp(TOp &, const SymbolTable &) {}
};
// Checkers which only verify the type of the operator, one alias per logical
// operator type.
using ExpectCreateNode = OpChecker<CreateNode>;
using ExpectCreateExpand = OpChecker<CreateExpand>;
using ExpectDelete = OpChecker<Delete>;
using ExpectScanAll = OpChecker<ScanAll>;
using ExpectExpand = OpChecker<Expand>;
using ExpectNodeFilter = OpChecker<NodeFilter>;
using ExpectEdgeFilter = OpChecker<EdgeFilter>;
using ExpectFilter = OpChecker<Filter>;
using ExpectProduce = OpChecker<Produce>;
using ExpectSetProperty = OpChecker<SetProperty>;
using ExpectSetProperties = OpChecker<SetProperties>;
using ExpectSetLabels = OpChecker<SetLabels>;
using ExpectRemoveProperty = OpChecker<RemoveProperty>;
using ExpectRemoveLabels = OpChecker<RemoveLabels>;
// The uniqueness filter is itself a template, parameterized by accessor type.
template <class TAccessor>
using ExpectExpandUniquenessFilter =
OpChecker<ExpandUniquenessFilter<TAccessor>>;
using ExpectAccumulate = OpChecker<Accumulate>;
// Checker for the Aggregate operator. In addition to the operator type, it
// verifies the aggregation elements (in order) and the group by expressions
// (regardless of order).
class ExpectAggregate : public OpChecker<Aggregate> {
 public:
  ExpectAggregate() = default;
  ExpectAggregate(const std::vector<query::Aggregation *> &aggregations,
                  const std::unordered_set<query::Expression *> &group_by)
      : aggregations_(aggregations), group_by_(group_by) {}

  void ExpectOp(Aggregate &op, const SymbolTable &symbol_table) override {
    // Walk the expected aggregations in lockstep with the operator's elements.
    auto aggr_it = aggregations_.begin();
    for (const auto &aggr_elem : op.aggregations()) {
      ASSERT_NE(aggr_it, aggregations_.end());
      auto aggr = *aggr_it;
      ++aggr_it;
      // Each element is compared against the aggregated expression, the
      // aggregation operation and the symbol assigned to the aggregation.
      auto expected =
          std::make_tuple(aggr->expression_, aggr->op_, symbol_table.at(*aggr));
      EXPECT_EQ(expected, aggr_elem);
    }
    // All of the expected aggregations must have been consumed.
    EXPECT_EQ(aggr_it, aggregations_.end());
    // Group by expressions are compared as sets, so their order is ignored.
    const auto &op_group_by = op.group_by();
    std::unordered_set<query::Expression *> got_group_by(op_group_by.begin(),
                                                         op_group_by.end());
    EXPECT_EQ(group_by_, got_group_by);
  }

 private:
  const std::vector<query::Aggregation *> aggregations_;
  const std::unordered_set<query::Expression *> group_by_;
};
class PlanChecker : public LogicalOperatorVisitor {
public:
using LogicalOperatorVisitor::Visit;
using LogicalOperatorVisitor::PostVisit;
PlanChecker(const std::list<size_t> &types) : types_(types) {}
PlanChecker(const std::list<BaseOpChecker *> &checkers,
const SymbolTable &symbol_table)
: checkers_(checkers), symbol_table_(symbol_table) {}
void Visit(CreateNode &op) override { AssertType(op); }
void Visit(CreateExpand &op) override { AssertType(op); }
void Visit(Delete &op) override { AssertType(op); }
void Visit(ScanAll &op) override { AssertType(op); }
void Visit(Expand &op) override { AssertType(op); }
void Visit(NodeFilter &op) override { AssertType(op); }
void Visit(EdgeFilter &op) override { AssertType(op); }
void Visit(Filter &op) override { AssertType(op); }
void Visit(Produce &op) override { AssertType(op); }
void Visit(SetProperty &op) override { AssertType(op); }
void Visit(SetProperties &op) override { AssertType(op); }
void Visit(SetLabels &op) override { AssertType(op); }
void Visit(RemoveProperty &op) override { AssertType(op); }
void Visit(RemoveLabels &op) override { AssertType(op); }
void Visit(CreateNode &op) override { CheckOp(op); }
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(NodeFilter &op) override { CheckOp(op); }
void Visit(EdgeFilter &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
AssertType(op);
CheckOp(op);
}
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override {
AssertType(op);
}
void Visit(Accumulate &op) override { AssertType(op); }
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
std::list<size_t> types_;
std::list<BaseOpChecker *> checkers_;
private:
void AssertType(const LogicalOperator &op) {
ASSERT_FALSE(types_.empty());
ASSERT_EQ(types_.back(), typeid(op).hash_code());
types_.pop_back();
void CheckOp(LogicalOperator &op) {
ASSERT_FALSE(checkers_.empty());
checkers_.back()->CheckOp(op, symbol_table_);
checkers_.pop_back();
}
const SymbolTable &symbol_table_;
};
template <class... TOps>
auto CheckPlan(query::Query &query) {
template <class... TChecker>
auto CheckPlan(query::Query &query, TChecker... checker) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query.Accept(symbol_generator);
auto plan = MakeLogicalPlan(query, symbol_table);
std::list<size_t> type_hashes{typeid(TOps).hash_code()...};
PlanChecker plan_checker(type_hashes);
std::list<BaseOpChecker *> checkers{&checker...};
PlanChecker plan_checker(checkers, symbol_table);
plan->Accept(plan_checker);
EXPECT_TRUE(plan_checker.types_.empty());
EXPECT_TRUE(plan_checker.checkers_.empty());
}
TEST(TestLogicalPlanner, MatchNodeReturn) {
// Test MATCH (n) RETURN n AS n
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
CheckPlan<ScanAll, Produce>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectProduce());
}
TEST(TestLogicalPlanner, CreateNodeReturn) {
// Test CREATE (n) RETURN n AS n
AstTreeStorage storage;
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
CheckPlan<CreateNode, Produce>(*query);
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), ExpectProduce());
}
TEST(TestLogicalPlanner, CreateExpand) {
@ -93,14 +162,14 @@ TEST(TestLogicalPlanner, CreateExpand) {
auto relationship = dba->edge_type("relationship");
auto query = QUERY(CREATE(PATTERN(
NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m"))));
CheckPlan<CreateNode, CreateExpand>(*query);
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand());
}
TEST(TestLogicalPlanner, CreateMultipleNode) {
// Test CREATE (n), (m)
AstTreeStorage storage;
auto query = QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m"))));
CheckPlan<CreateNode, CreateNode>(*query);
CheckPlan(*query, ExpectCreateNode(), ExpectCreateNode());
}
TEST(TestLogicalPlanner, CreateNodeExpandNode) {
@ -112,7 +181,8 @@ TEST(TestLogicalPlanner, CreateNodeExpandNode) {
auto query = QUERY(CREATE(
PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m")),
PATTERN(NODE("l"))));
CheckPlan<CreateNode, CreateExpand, CreateNode>(*query);
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(),
ExpectCreateNode());
}
TEST(TestLogicalPlanner, MatchCreateExpand) {
@ -125,7 +195,7 @@ TEST(TestLogicalPlanner, MatchCreateExpand) {
QUERY(MATCH(PATTERN(NODE("n"))),
CREATE(PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT),
NODE("m"))));
CheckPlan<ScanAll, CreateExpand>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectCreateExpand());
}
TEST(TestLogicalPlanner, MatchLabeledNodes) {
@ -136,7 +206,7 @@ TEST(TestLogicalPlanner, MatchLabeledNodes) {
auto label = dba->label("label");
auto query =
QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n")));
CheckPlan<ScanAll, NodeFilter, Produce>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectNodeFilter(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchPathReturn) {
@ -148,7 +218,8 @@ TEST(TestLogicalPlanner, MatchPathReturn) {
auto query =
QUERY(MATCH(PATTERN(NODE("n"), EDGE("r", relationship), NODE("m"))),
RETURN(IDENT("n"), AS("n")));
CheckPlan<ScanAll, Expand, EdgeFilter, Produce>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectEdgeFilter(),
ExpectProduce());
}
TEST(TestLogicalPlanner, MatchWhereReturn) {
@ -160,14 +231,14 @@ TEST(TestLogicalPlanner, MatchWhereReturn) {
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))),
RETURN(IDENT("n"), AS("n")));
CheckPlan<ScanAll, Filter, Produce>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectFilter(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchDelete) {
// Test MATCH (n) DELETE n
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")));
CheckPlan<ScanAll, Delete>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectDelete());
}
TEST(TestLogicalPlanner, MatchNodeSet) {
@ -180,7 +251,8 @@ TEST(TestLogicalPlanner, MatchNodeSet) {
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)),
SET("n", IDENT("n")), SET("n", {label}));
CheckPlan<ScanAll, SetProperty, SetProperties, SetLabels>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(),
ExpectSetLabels());
}
TEST(TestLogicalPlanner, MatchRemove) {
@ -192,7 +264,8 @@ TEST(TestLogicalPlanner, MatchRemove) {
auto label = dba->label("label");
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label}));
CheckPlan<ScanAll, RemoveProperty, RemoveLabels>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectRemoveProperty(),
ExpectRemoveLabels());
}
TEST(TestLogicalPlanner, MatchMultiPattern) {
@ -202,8 +275,8 @@ TEST(TestLogicalPlanner, MatchMultiPattern) {
PATTERN(NODE("j"), EDGE("e"), NODE("i"))));
// We expect the expansions after the first to have a uniqueness filter in a
// single MATCH clause.
CheckPlan<ScanAll, Expand, ScanAll, Expand,
ExpandUniquenessFilter<EdgeAccessor>>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(),
ExpectExpand(), ExpectExpandUniquenessFilter<EdgeAccessor>());
}
TEST(TestLogicalPlanner, MatchMultiPatternSameStart) {
@ -213,7 +286,7 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameStart) {
MATCH(PATTERN(NODE("n")), PATTERN(NODE("n"), EDGE("e"), NODE("m"))));
// We expect the second pattern to generate only an Expand, since another
// ScanAll would be redundant.
CheckPlan<ScanAll, Expand>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
}
TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) {
@ -224,8 +297,8 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) {
// We expect the second pattern to generate only an Expand. Another
// ScanAll would be redundant, as it would generate the nodes obtained from
// expansion. Additionally, a uniqueness filter is expected.
CheckPlan<ScanAll, Expand, Expand, ExpandUniquenessFilter<EdgeAccessor>>(
*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(),
ExpectExpandUniquenessFilter<EdgeAccessor>());
}
TEST(TestLogicalPlanner, MultiMatch) {
@ -236,8 +309,9 @@ TEST(TestLogicalPlanner, MultiMatch) {
MATCH(PATTERN(NODE("j"), EDGE("e"), NODE("i"), EDGE("f"), NODE("h"))));
// Multiple MATCH clauses form a Cartesian product, so the uniqueness should
// not cross MATCH boundaries.
CheckPlan<ScanAll, Expand, ScanAll, Expand, Expand,
ExpandUniquenessFilter<EdgeAccessor>>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(),
ExpectExpand(), ExpectExpand(),
ExpectExpandUniquenessFilter<EdgeAccessor>());
}
TEST(TestLogicalPlanner, MultiMatchSameStart) {
@ -247,7 +321,7 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) {
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))));
// Similar to MatchMultiPatternSameStart, we expect only Expand from second
// MATCH clause.
CheckPlan<ScanAll, Expand>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
}
TEST(TestLogicalPlanner, MatchEdgeCycle) {
@ -256,7 +330,7 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
auto query = QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"), EDGE("r"), NODE("j"))));
// There is no ExpandUniquenessFilter for referencing the same edge.
CheckPlan<ScanAll, Expand, Expand>(*query);
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand());
}
TEST(TestLogicalPlanner, MatchWithReturn) {
@ -264,7 +338,8 @@ TEST(TestLogicalPlanner, MatchWithReturn) {
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
RETURN(IDENT("new"), AS("new")));
CheckPlan<ScanAll, Accumulate, Produce, Produce>(*query);
// No accumulation since we only do reads.
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchWithWhereReturn) {
@ -276,7 +351,9 @@ TEST(TestLogicalPlanner, MatchWithWhereReturn) {
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")),
WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))),
RETURN(IDENT("new"), AS("new")));
CheckPlan<ScanAll, Accumulate, Produce, Filter, Produce>(*query);
// No accumulation since we only do reads.
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectFilter(),
ExpectProduce());
}
TEST(TestLogicalPlanner, CreateMultiExpand) {
@ -289,7 +366,56 @@ TEST(TestLogicalPlanner, CreateMultiExpand) {
auto query = QUERY(
CREATE(PATTERN(NODE("n"), EDGE("r", r, Direction::RIGHT), NODE("m")),
PATTERN(NODE("n"), EDGE("p", p, Direction::RIGHT), NODE("l"))));
CheckPlan<CreateNode, CreateExpand, CreateExpand>(*query);
CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(),
ExpectCreateExpand());
}
TEST(TestLogicalPlanner, MatchWithSumWhereReturn) {
  // Test MATCH (n) WITH SUM(n.prop) + 42 AS sum WHERE sum < 42
  //      RETURN sum AS result
  Dbms dbms;
  auto dba = dbms.active();
  auto prop = dba->property("prop");
  AstTreeStorage storage;
  auto sum_aggregation = SUM(PROPERTY_LOOKUP("n", prop));
  auto summand = LITERAL(42);
  auto query = QUERY(MATCH(PATTERN(NODE("n"))),
                     WITH(ADD(sum_aggregation, summand), AS("sum")),
                     WHERE(LESS(IDENT("sum"), LITERAL(42))),
                     RETURN(IDENT("sum"), AS("result")));
  // The literal added to the sum is the only non-aggregated expression, so it
  // is expected as the group by expression.
  auto expect_aggregate = ExpectAggregate({sum_aggregation}, {summand});
  CheckPlan(*query, ExpectScanAll(), expect_aggregate, ExpectProduce(),
            ExpectFilter(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchReturnSum) {
  // Test MATCH (n) RETURN SUM(n.prop1) AS sum, n.prop2 AS group
  Dbms dbms;
  auto dba = dbms.active();
  auto summed_prop = dba->property("prop1");
  auto grouping_prop = dba->property("prop2");
  AstTreeStorage storage;
  auto sum_aggregation = SUM(PROPERTY_LOOKUP("n", summed_prop));
  auto grouping_lookup = PROPERTY_LOOKUP("n", grouping_prop);
  auto query =
      QUERY(MATCH(PATTERN(NODE("n"))),
            RETURN(sum_aggregation, AS("sum"), grouping_lookup, AS("group")));
  // The named expression without an aggregation is expected to be grouped by.
  auto expect_aggregate =
      ExpectAggregate({sum_aggregation}, {grouping_lookup});
  CheckPlan(*query, ExpectScanAll(), expect_aggregate, ExpectProduce());
}
TEST(TestLogicalPlanner, CreateWithSum) {
  // Test CREATE (n) WITH SUM(n.prop) AS sum
  Dbms dbms;
  auto dba = dbms.active();
  auto prop = dba->property("prop");
  AstTreeStorage storage;
  auto sum_aggregation = SUM(PROPERTY_LOOKUP("n", prop));
  auto query =
      QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum_aggregation, AS("sum")));
  // There is nothing besides the aggregation, so no group by is expected.
  auto expect_aggregate = ExpectAggregate({sum_aggregation}, {});
  // The part before WITH updates the database, therefore the accumulation is
  // expected in addition to the aggregation.
  CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), expect_aggregate,
            ExpectProduce());
}
} // namespace

View File

@ -358,6 +358,7 @@ TEST(TestSymbolGenerator, CreateMultiExpand) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
EXPECT_EQ(symbol_table.max_position(), 5);
auto n1 = symbol_table.at(*node_n1->identifier_);
auto n2 = symbol_table.at(*node_n2->identifier_);
EXPECT_EQ(n1, n2);
@ -408,4 +409,42 @@ TEST(TestSymbolGenerator, CreateExpandProperty) {
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
TEST(TestSymbolGenerator, MatchReturnSum) {
  // Test MATCH (n) RETURN SUM(n.prop) + 42 AS result
  Dbms dbms;
  auto dba = dbms.active();
  auto prop = dba->property("prop");
  AstTreeStorage storage;
  auto node_n = NODE("n");
  auto sum_aggregation = SUM(PROPERTY_LOOKUP("n", prop));
  auto result_name = AS("result");
  auto query = QUERY(MATCH(PATTERN(node_n)),
                     RETURN(ADD(sum_aggregation, LITERAL(42)), result_name));
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  query->Accept(symbol_generator);
  // 3 symbols for: 'n', 'sum' and 'result'.
  EXPECT_EQ(symbol_table.max_position(), 3);
  // All three symbols must be pairwise different.
  auto node_symbol = symbol_table.at(*node_n->identifier_);
  auto sum_symbol = symbol_table.at(*sum_aggregation);
  EXPECT_NE(node_symbol, sum_symbol);
  auto result_symbol = symbol_table.at(*result_name);
  EXPECT_NE(result_symbol, node_symbol);
  EXPECT_NE(result_symbol, sum_symbol);
}
TEST(TestSymbolGenerator, NestedAggregation) {
  // Test MATCH (n) RETURN SUM(42 + SUM(n.prop)) AS s
  Dbms dbms;
  auto dba = dbms.active();
  auto prop = dba->property("prop");
  AstTreeStorage storage;
  auto nested_sum =
      SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop))));
  auto query = QUERY(MATCH(PATTERN(NODE("n"))), RETURN(nested_sum, AS("s")));
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  // An aggregation inside of another aggregation must be rejected.
  EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
}