Plan OrderBy

Summary:
Support OrderBy in test macros.
Test planning OrderBy.
Handle symbol visibility for ORDER BY and WHERE.
Add Hash struct to Symbol.
Collect used symbols in ORDER BY and WHERE.

Reviewers: mislav.bradac, florijan

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D307
This commit is contained in:
Teon Banek 2017-04-24 13:51:16 +02:00
parent fe36835519
commit e53e232e49
10 changed files with 471 additions and 177 deletions

View File

@ -30,31 +30,63 @@ auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
return CreateSymbol(name, type);
}
void SymbolGenerator::BindNamedExpressionSymbols(
const std::vector<NamedExpression *> &named_expressions) {
std::unordered_set<std::string> seen_names;
for (auto &named_expr : named_expressions) {
// Improvement would be to infer the type of the expression.
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
for (auto &expr : body.named_expressions) {
expr->Accept(*this);
}
// WITH/RETURN clause removes declarations of all the previous variables and
// declares only those established through named expressions. New declarations
// must not be visible inside named expressions themselves.
bool removed_old_names = false;
if ((!where && body.order_by.empty()) || scope_.has_aggregation) {
// WHERE and ORDER BY need to see both the old and new symbols, unless we
// have an aggregation. Therefore, we can clear the symbols immediately if
// there is neither ORDER BY nor WHERE, or we have an aggregation.
scope_.symbols.clear();
removed_old_names = true;
}
// Create symbols for named expressions.
std::unordered_set<std::string> new_names;
for (auto &named_expr : body.named_expressions) {
const auto &name = named_expr->name_;
if (!seen_names.insert(name).second) {
if (!new_names.insert(name).second) {
throw SemanticException(
"Multiple results with the same name '{}' are not allowed.", name);
}
// An improvement would be to infer the type of the expression, so that the
// new symbol would have a more specific type.
symbol_table_[*named_expr] = CreateSymbol(name);
}
}
void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) {
if (skip) {
scope_.in_order_by = true;
for (const auto &order_pair : body.order_by) {
order_pair.second->Accept(*this);
}
scope_.in_order_by = false;
if (body.skip) {
scope_.in_skip = true;
skip->Accept(*this);
body.skip->Accept(*this);
scope_.in_skip = false;
}
if (limit) {
if (body.limit) {
scope_.in_limit = true;
limit->Accept(*this);
body.limit->Accept(*this);
scope_.in_limit = false;
}
if (where) where->Accept(*this);
if (!removed_old_names) {
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
// clear the old symbols, so do it now. We cannot just call clear, because
// we've added new symbols.
for (auto sym_it = scope_.symbols.begin();
sym_it != scope_.symbols.end();) {
if (new_names.find(sym_it->first) == new_names.end()) {
sym_it = scope_.symbols.erase(sym_it);
} else {
sym_it++;
}
}
}
scope_.has_aggregation = false;
}
// Clauses
@ -64,33 +96,21 @@ void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
bool SymbolGenerator::PreVisit(Return &ret) {
scope_.in_return = true;
for (auto &expr : ret.body_.named_expressions) {
expr->Accept(*this);
}
// Named expressions establish bindings for expressions which come after
// return, but not for the expressions contained inside.
BindNamedExpressionSymbols(ret.body_.named_expressions);
VisitSkipAndLimit(ret.body_.skip, ret.body_.limit);
VisitReturnBody(ret.body_);
scope_.in_return = false;
return false; // We handled the traversal ourselves.
}
bool SymbolGenerator::PreVisit(With &with) {
scope_.in_with = true;
for (auto &expr : with.body_.named_expressions) {
expr->Accept(*this);
}
VisitReturnBody(with.body_, with.where_);
scope_.in_with = false;
// WITH clause removes declarations of all the previous variables and declares
// only those established through named expressions. New declarations must not
// be visible inside named expressions themselves.
scope_.symbols.clear();
BindNamedExpressionSymbols(with.body_.named_expressions);
VisitSkipAndLimit(with.body_.skip, with.body_.limit);
if (with.where_) with.where_->Accept(*this);
return false; // We handled the traversal ourselves.
}
void SymbolGenerator::Visit(Where &) { scope_.in_where = true; }
void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; }
// Expressions
void SymbolGenerator::Visit(Identifier &ident) {
@ -133,7 +153,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
// Check if the aggregation can be used in this context. This check should
// probably move to a separate phase, which checks if the query is well
// formed.
if (!scope_.in_return && !scope_.in_with) {
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by ||
scope_.in_skip || scope_.in_limit || scope_.in_where) {
throw SemanticException(
"Aggregation functions are only allowed in WITH and RETURN");
}
@ -146,6 +167,7 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
// Currently, we only have aggregation operators which return numbers.
symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number);
scope_.in_aggregation = true;
scope_.has_aggregation = true;
}
void SymbolGenerator::PostVisit(Aggregation &aggr) {

View File

@ -29,6 +29,8 @@ class SymbolGenerator : public TreeVisitorBase {
void PostVisit(Create &) override;
bool PreVisit(Return &) override;
bool PreVisit(With &) override;
void Visit(Where &) override;
void PostVisit(Where &) override;
// Expressions
void Visit(Identifier &) override;
@ -62,6 +64,11 @@ class SymbolGenerator : public TreeVisitorBase {
bool in_with{false};
bool in_skip{false};
bool in_limit{false};
bool in_order_by{false};
bool in_where{false};
// True if the return/with contains an aggregation in any named expression.
bool has_aggregation{false};
// Map from variable names to symbols.
std::map<std::string, Symbol> symbols;
};
@ -77,10 +84,7 @@ class SymbolGenerator : public TreeVisitorBase {
auto GetOrCreateSymbol(const std::string &name,
Symbol::Type type = Symbol::Type::Any);
void BindNamedExpressionSymbols(
const std::vector<NamedExpression *> &named_expressions);
void VisitSkipAndLimit(Expression *skip, Expression *limit);
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
SymbolTable &symbol_table_;
Scope scope_;

View File

@ -13,10 +13,17 @@ class Symbol {
enum class Type { Any, Vertex, Edge, Path, Number };
static std::string TypeToString(Type type) {
const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"};
const char *enum_string[] = {"Any", "Vertex", "Edge", "Path", "Number"};
return enum_string[static_cast<int>(type)];
}
// Calculates the Symbol hash based on its position.
struct Hash {
size_t operator()(const Symbol &symbol) const {
return std::hash<int>{}(symbol.position_);
}
};
Symbol() {}
Symbol(const std::string &name, int position, Type type = Type::Any)
: name_(name), position_(position), type_(type) {}

View File

@ -42,25 +42,32 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
Frame frame(symbol_table.max_position());
std::vector<std::string> header;
bool is_return = false;
std::vector<Symbol> output_symbols;
if (auto produce = dynamic_cast<plan::Produce *>(logical_plan.get())) {
// top level node in the operator tree is a produce (return)
is_return = true;
// collect the symbols from the return clause
for (auto named_expression : produce->named_expressions())
output_symbols.emplace_back(symbol_table[*named_expression]);
} else if (auto order_by =
dynamic_cast<plan::OrderBy *>(logical_plan.get())) {
is_return = true;
output_symbols = order_by->output_symbols();
}
if (is_return) {
// top level node in the operator tree is a produce/order_by (return)
// so stream out results
// generate header
for (auto named_expression : produce->named_expressions())
header.push_back(named_expression->name_);
for (const auto &symbol : output_symbols) header.push_back(symbol.name_);
stream.Header(header);
// collect the symbols from the return clause
std::vector<Symbol> symbols;
for (auto named_expression : produce->named_expressions())
symbols.emplace_back(symbol_table[*named_expression]);
// stream out results
auto cursor = produce->MakeCursor(db_accessor);
auto cursor = logical_plan->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) {
std::vector<TypedValue> values;
for (auto &symbol : symbols) values.emplace_back(frame[symbol]);
for (const auto &symbol : output_symbols)
values.emplace_back(frame[symbol]);
stream.Result(values);
}
} else if (dynamic_cast<plan::CreateNode *>(logical_plan.get()) ||
@ -72,7 +79,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
dynamic_cast<plan::RemoveLabels *>(logical_plan.get()) ||
dynamic_cast<plan::Delete *>(logical_plan.get())) {
stream.Header(header);
auto cursor = logical_plan.get()->MakeCursor(db_accessor);
auto cursor = logical_plan->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) continue;
} else {
throw QueryRuntimeException("Unknown top level LogicalOp");

View File

@ -1224,9 +1224,9 @@ void Limit::LimitCursor::Reset() {
}
OrderBy::OrderBy(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::pair<Ordering, Expression *>> order_by,
const std::vector<Symbol> remember)
: input_(input), remember_(remember) {
const std::vector<std::pair<Ordering, Expression *>> &order_by,
const std::vector<Symbol> &output_symbols)
: input_(input), output_symbols_(output_symbols) {
// split the order_by vector into two vectors of orderings and expressions
std::vector<Ordering> ordering;
ordering.reserve(order_by.size());
@ -1259,12 +1259,12 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame,
order_by.emplace_back(evaluator.PopBack());
}
// collect the remember elements
std::list<TypedValue> remember;
for (const Symbol &remember_sym : self_.remember_)
remember.emplace_back(frame[remember_sym]);
// collect the output elements
std::list<TypedValue> output;
for (const Symbol &output_sym : self_.output_symbols_)
output.emplace_back(frame[output_sym]);
cache_.emplace_back(order_by, remember);
cache_.emplace_back(order_by, output);
}
std::sort(cache_.begin(), cache_.end(),
@ -1278,13 +1278,13 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame,
if (cache_it_ == cache_.end()) return false;
// place the remembered values on the frame
debug_assert(self_.remember_.size() == cache_it_->second.size(),
"Number of values does not match the number of remember symbols "
// place the output values on the frame
debug_assert(self_.output_symbols_.size() == cache_it_->second.size(),
"Number of values does not match the number of output symbols "
"in OrderBy");
auto remember_sym_it = self_.remember_.begin();
for (const TypedValue &remember : cache_it_->second)
frame[*remember_sym_it++] = remember;
auto output_sym_it = self_.output_symbols_.begin();
for (const TypedValue &output : cache_it_->second)
frame[*output_sym_it++] = output;
cache_it_++;
return true;

View File

@ -827,6 +827,8 @@ class Accumulate : public LogicalOperator {
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
const auto &symbols() const { return symbols_; };
private:
const std::shared_ptr<LogicalOperator> input_;
const std::vector<Symbol> symbols_;
@ -1066,11 +1068,13 @@ class Limit : public LogicalOperator {
class OrderBy : public LogicalOperator {
public:
OrderBy(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::pair<Ordering, Expression *>> order_by,
const std::vector<Symbol> remember);
const std::vector<std::pair<Ordering, Expression *>> &order_by,
const std::vector<Symbol> &output_symbols);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
const auto &output_symbols() const { return output_symbols_; }
private:
// custom Comparator type for comparing lists of TypedValues
// does lexicographical ordering of elements based on the above
@ -1091,7 +1095,7 @@ class OrderBy : public LogicalOperator {
const std::shared_ptr<LogicalOperator> input_;
TypedValueListCompare compare_;
std::vector<Expression *> order_by_;
const std::vector<Symbol> remember_;
const std::vector<Symbol> output_symbols_;
// custom comparison for TypedValue objects
// behaves generally like Neo's ORDER BY comparison operator:

View File

@ -1,5 +1,6 @@
#include "query/plan/planner.hpp"
#include <algorithm>
#include <functional>
#include <unordered_set>
@ -184,17 +185,37 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
// * flag whether the results need to be DISTINCT;
// * optional SKIP expression;
// * optional LIMIT expression and
// * optional ORDER BY expression.
// * optional ORDER BY expressions.
//
// In addition to the above, we collect information on used symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public TreeVisitorBase {
public:
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table)
: body_(body), symbol_table_(symbol_table) {
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table,
Where *where = nullptr)
: body_(body), symbol_table_(symbol_table), where_(where) {
// Collect symbols from named expressions.
output_symbols_.reserve(body_.named_expressions.size());
for (auto &named_expr : body_.named_expressions) {
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
named_expr->Accept(*this);
}
if (aggregations_.empty()) {
// Visit order_by and where if we do not have aggregations. This way we
// prevent collecting group_by expressions from order_by and where, which
// would be very wrong. When we have aggregation, order_by and where can
// only use new symbols (ensured in semantic analysis), so we don't care
// about collecting used_symbols. Also, semantic analysis should
// have prevented any aggregations from appearing here.
for (const auto &order_pair : body.order_by) {
order_pair.second->Accept(*this);
}
if (where) {
where->Accept(*this);
}
debug_assert(aggregations_.empty(),
"Unexpected aggregations in ORDER BY or WHERE");
}
}
using TreeVisitorBase::PreVisit;
@ -204,7 +225,13 @@ class ReturnBodyContext : public TreeVisitorBase {
void Visit(Literal &) override { has_aggregation_.emplace_back(false); }
void Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
const auto &symbol = symbol_table_.at(ident);
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
output_symbols_.end()) {
// Don't pick up new symbols, even though they may be used in ORDER BY or
// WHERE.
used_symbols_.insert(symbol);
}
has_aggregation_.emplace_back(false);
}
@ -217,7 +244,7 @@ class ReturnBodyContext : public TreeVisitorBase {
bool aggr1 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool has_aggr = aggr1 || aggr2; \
if (has_aggr) { \
if (has_aggr && !(aggr1 && aggr2)) { \
/* Group by the expression which does not contain aggregation. */ \
/* Possible optimization is to ignore constant value expressions */ \
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
@ -264,13 +291,18 @@ class ReturnBodyContext : public TreeVisitorBase {
bool distinct() const { return body_.distinct; }
// Named expressions which are used to produce results.
const auto &named_expressions() const { return body_.named_expressions; }
// Pairs of (Ordering, Expression *) for sorting results.
const auto &order_by() const { return body_.order_by; }
// Optional expression which determines how many results to skip.
auto *skip() const { return body_.skip; }
// Optional expression which determines how many results to produce.
auto *limit() const { return body_.limit; }
// Optional Where clause for filtering.
const auto *where() const { return where_; }
// Set of symbols used inside the visited expressions outside of aggregation
// expression.
const auto &symbols() const { return symbols_; }
// expression. These only includes old symbols, even though new ones may have
// been used in ORDER BY or WHERE.
const auto &used_symbols() const { return used_symbols_; }
// List of aggregation elements found in expressions.
const auto &aggregations() const { return aggregations_; }
// When there is at least one aggregation element, all the non-aggregate (sub)
@ -278,17 +310,16 @@ class ReturnBodyContext : public TreeVisitorBase {
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
const auto &group_by() const { return group_by_; }
private:
// Calculates the Symbol hash based on its position.
struct SymbolHash {
size_t operator()(const Symbol &symbol) const {
return std::hash<int>{}(symbol.position_);
}
};
// All symbols generated by named expressions. They are collected in order of
// named_expressions.
const auto &output_symbols() const { return output_symbols_; }
private:
const ReturnBody &body_;
const SymbolTable &symbol_table_;
std::unordered_set<Symbol, SymbolHash> symbols_;
const Where *const where_ = nullptr;
std::unordered_set<Symbol, Symbol::Hash> used_symbols_;
std::vector<Symbol> output_symbols_;
std::vector<Aggregate::Element> aggregations_;
std::vector<Expression *> group_by_;
// Flag indicating whether an expression contains an aggregation.
@ -314,8 +345,8 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
// TODO: Plan with distinct, when operator available.
throw utils::NotYetImplemented();
}
auto symbols =
std::vector<Symbol>(body.symbols().begin(), body.symbols().end());
std::vector<Symbol> used_symbols(body.used_symbols().begin(),
body.used_symbols().end());
auto last_op = input_op;
if (body.aggregations().empty()) {
// In case when we have SKIP/LIMIT and we don't perform aggregations, we
@ -329,18 +360,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
// We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an
// accumulation after updates, without advancing the command.
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op), symbols,
advance_command);
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op),
used_symbols, advance_command);
}
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
last_op = GenSkipLimit(
new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
body.aggregations(), body.group_by(), symbols),
body.aggregations(), body.group_by(), used_symbols),
body);
}
return new Produce(std::shared_ptr<LogicalOperator>(last_op),
body.named_expressions());
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
body.named_expressions());
// Where may see new symbols so it comes after we generate Produce.
if (body.where()) {
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
body.where()->expression_);
}
// Like Where, OrderBy can read from symbols established by named expressions
// in Produce, so it must come after it.
if (!body.order_by().empty()) {
last_op = new OrderBy(std::shared_ptr<LogicalOperator>(last_op),
body.order_by(), body.output_symbols());
}
return last_op;
}
auto GenWith(With &with, LogicalOperator *input_op,
@ -353,17 +396,13 @@ auto GenWith(With &with, LogicalOperator *input_op,
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table);
ReturnBodyContext body(with.body_, symbol_table, with.where_);
LogicalOperator *last_op =
GenReturnBody(input_op, advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (auto &named_expr : with.body_.named_expressions) {
BindSymbol(bound_symbols, symbol_table.at(*named_expr));
}
if (with.where_) {
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
with.where_->expression_);
for (const auto &symbol : body.output_symbols()) {
BindSymbol(bound_symbols, symbol);
}
return last_op;
}

View File

@ -1,18 +1,74 @@
///
/// @file
/// This file provides macros for easier construction of openCypher query AST.
/// The usage of macros is very similar to how one would write openCypher. For
/// example:
///
/// AstTreeStorage storage; // Macros rely on storage being in scope.
///
/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))),
/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))),
/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"),
/// ORDER_BY(IDENT("sum")),
/// SKIP(ADD(LITERAL(1), LITERAL(2)))));
///
/// Each of the macros is accompanied by a function. The functions use overload
/// resolution and template magic to provide a type safe way of constructing
/// queries. Although the functions can be used by themselves, it is more
/// convenient to use the macros.
///
#include <utility>
#include <vector>
#include "database/graph_db_datatypes.hpp"
namespace query {
namespace test_common {
// Custom types for SKIP and LIMIT and expressions, so that they can be used to
// resolve function calls.
// Custom types for ORDER BY, SKIP and LIMIT and expressions, so that they can
// be used to resolve function calls.
struct OrderBy {
std::vector<std::pair<Ordering, Expression *>> expressions;
};
struct Skip {
query::Expression *expression = nullptr;
Expression *expression = nullptr;
};
struct Limit {
query::Expression *expression = nullptr;
Expression *expression = nullptr;
};
// Helper functions for filling the OrderBy with expressions.
auto FillOrderBy(OrderBy &order_by, Expression *expression,
Ordering ordering = Ordering::ASC) {
order_by.expressions.emplace_back(ordering, expression);
}
template <class... T>
auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering,
T... rest) {
FillOrderBy(order_by, expression, ordering);
FillOrderBy(order_by, rest...);
}
template <class... T>
auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) {
FillOrderBy(order_by, expression);
FillOrderBy(order_by, rest...);
}
///
/// Create OrderBy expressions.
///
/// The supported combination of arguments is: (Expression, [Ordering])+
/// Since the Ordering is optional, by default it is ascending.
///
template <class... T>
auto GetOrderBy(T... exprs) {
OrderBy order_by;
FillOrderBy(order_by, exprs...);
return order_by;
}
///
/// Create PropertyLookup with given name and property.
///
@ -117,86 +173,77 @@ auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
return GetQuery(storage, clauses...);
}
///
/// Create the return clause with given named expressions.
///
auto GetReturn(Return *ret, NamedExpression *named_expr) {
ret->body_.named_expressions.emplace_back(named_expr);
return ret;
// Helper functions for constructing RETURN and WITH clauses.
void FillReturnBody(ReturnBody &body, NamedExpression *named_expr) {
body.named_expressions.emplace_back(named_expr);
}
auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) {
ret->body_.skip = skip.expression;
ret->body_.limit = limit.expression;
return ret;
void FillReturnBody(ReturnBody &body, Limit limit) {
body.limit = limit.expression;
}
auto GetReturn(Return *ret, Limit limit) {
ret->body_.limit = limit.expression;
return ret;
void FillReturnBody(ReturnBody &body, Skip skip, Limit limit = Limit{}) {
body.skip = skip.expression;
body.limit = limit.expression;
}
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
void FillReturnBody(ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) {
body.order_by = order_by.expressions;
body.limit = limit.expression;
}
void FillReturnBody(ReturnBody &body, OrderBy order_by, Skip skip,
Limit limit = Limit{}) {
body.order_by = order_by.expressions;
body.skip = skip.expression;
body.limit = limit.expression;
}
void FillReturnBody(ReturnBody &body, Expression *expr,
NamedExpression *named_expr) {
// This overload supports `RETURN(expr, AS(name))` construct, since
// NamedExpression does not inherit Expression.
named_expr->expression_ = expr;
ret->body_.named_expressions.emplace_back(named_expr);
return ret;
body.named_expressions.emplace_back(named_expr);
}
template <class... T>
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr,
T... rest) {
void FillReturnBody(ReturnBody &body, Expression *expr,
NamedExpression *named_expr, T... rest) {
named_expr->expression_ = expr;
ret->body_.named_expressions.emplace_back(named_expr);
return GetReturn(ret, rest...);
body.named_expressions.emplace_back(named_expr);
FillReturnBody(body, rest...);
}
template <class... T>
auto GetReturn(Return *ret, NamedExpression *named_expr, T... rest) {
ret->body_.named_expressions.emplace_back(named_expr);
return GetReturn(ret, rest...);
}
template <class... T>
auto GetReturn(AstTreeStorage &storage, T... exprs) {
auto ret = storage.Create<Return>();
return GetReturn(ret, exprs...);
void FillReturnBody(ReturnBody &body, NamedExpression *named_expr, T... rest) {
body.named_expressions.emplace_back(named_expr);
FillReturnBody(body, rest...);
}
///
/// Create the with clause with given named expressions.
/// Create the return clause with given expressions.
///
auto GetWith(With *with, NamedExpression *named_expr) {
with->body_.named_expressions.emplace_back(named_expr);
return with;
}
auto GetWith(With *with, Skip skip, Limit limit = {}) {
with->body_.skip = skip.expression;
with->body_.limit = limit.expression;
return with;
}
auto GetWith(With *with, Limit limit) {
with->body_.limit = limit.expression;
return with;
}
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
// This overload supports `RETURN(expr, AS(name))` construct, since
// NamedExpression does not inherit Expression.
named_expr->expression_ = expr;
with->body_.named_expressions.emplace_back(named_expr);
return with;
}
/// The supported expression combination of arguments is:
///
/// (NamedExpression | (Expression NamedExpression))+ [OrderBy] [Skip] [Limit]
///
/// When the pair (Expression NamedExpression) is given, the Expression will be
/// moved inside the NamedExpression. This is done, so that the constructs like
/// RETURN(expr, AS("name"), ...) are supported.
///
/// @sa GetWith
template <class... T>
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
T... rest) {
named_expr->expression_ = expr;
with->body_.named_expressions.emplace_back(named_expr);
return GetWith(with, rest...);
}
template <class... T>
auto GetWith(With *with, NamedExpression *named_expr, T... rest) {
with->body_.named_expressions.emplace_back(named_expr);
return GetWith(with, rest...);
auto GetReturn(AstTreeStorage &storage, T... exprs) {
auto ret = storage.Create<Return>();
FillReturnBody(ret->body_, exprs...);
return ret;
}
///
/// Create the with clause with given expressions.
///
/// The supported expression combination is the same as for @c GetReturn.
///
/// @sa GetReturn
template <class... T>
auto GetWith(AstTreeStorage &storage, T... exprs) {
auto with = storage.Create<With>();
return GetWith(with, exprs...);
FillReturnBody(with->body_, exprs...);
return with;
}
///
@ -288,8 +335,11 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
#define AS(name) storage.Create<query::NamedExpression>((name))
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
#define SKIP(expr) query::test_common::Skip{(expr)}
#define LIMIT(expr) query::test_common::Limit{(expr)}
#define ORDER_BY(...) query::test_common::GetOrderBy(__VA_ARGS__)
#define SKIP(expr) \
query::test_common::Skip { (expr) }
#define LIMIT(expr) \
query::test_common::Limit { (expr) }
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
#define DETACH_DELETE(...) \
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
@ -297,7 +347,10 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
#define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__)
#define QUERY(...) query::test_common::GetQuery(storage, __VA_ARGS__)
// Various operators
#define ADD(expr1, expr2) storage.Create<query::AdditionOperator>((expr1), (expr2))
#define ADD(expr1, expr2) \
storage.Create<query::AdditionOperator>((expr1), (expr2))
#define LESS(expr1, expr2) storage.Create<query::LessOperator>((expr1), (expr2))
#define SUM(expr) \
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::SUM)
#define COUNT(expr) \
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::COUNT)

View File

@ -15,6 +15,7 @@
using namespace query::plan;
using query::AstTreeStorage;
using query::Symbol;
using query::SymbolTable;
using query::SymbolGenerator;
using Direction = query::EdgeAtom::Direction;
@ -57,13 +58,27 @@ using ExpectRemoveLabels = OpChecker<RemoveLabels>;
template <class TAccessor>
using ExpectExpandUniquenessFilter =
OpChecker<ExpandUniquenessFilter<TAccessor>>;
using ExpectAccumulate = OpChecker<Accumulate>;
using ExpectSkip = OpChecker<Skip>;
using ExpectLimit = OpChecker<Limit>;
using ExpectOrderBy = OpChecker<OrderBy>;
class ExpectAccumulate : public OpChecker<Accumulate> {
public:
ExpectAccumulate(const std::unordered_set<Symbol, Symbol::Hash> &symbols)
: symbols_(symbols) {}
void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override {
std::unordered_set<Symbol, Symbol::Hash> got_symbols(op.symbols().begin(),
op.symbols().end());
EXPECT_EQ(symbols_, got_symbols);
}
private:
const std::unordered_set<Symbol, Symbol::Hash> symbols_;
};
class ExpectAggregate : public OpChecker<Aggregate> {
public:
ExpectAggregate() = default;
ExpectAggregate(const std::vector<query::Aggregation *> &aggregations,
const std::unordered_set<query::Expression *> &group_by)
: aggregations_(aggregations), group_by_(group_by) {}
@ -119,6 +134,7 @@ class PlanChecker : public LogicalOperatorVisitor {
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
std::list<BaseOpChecker *> checkers_;
@ -132,18 +148,29 @@ class PlanChecker : public LogicalOperatorVisitor {
const SymbolTable &symbol_table_;
};
template <class... TChecker>
auto CheckPlan(query::Query &query, TChecker... checker) {
auto MakeSymbolTable(query::Query &query) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query.Accept(symbol_generator);
auto plan = MakeLogicalPlan(query, symbol_table);
return symbol_table;
}
template <class... TChecker>
auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table,
TChecker... checker) {
std::list<BaseOpChecker *> checkers{&checker...};
PlanChecker plan_checker(checkers, symbol_table);
plan->Accept(plan_checker);
plan.Accept(plan_checker);
EXPECT_TRUE(plan_checker.checkers_.empty());
}
template <class... TChecker>
auto CheckPlan(query::Query &query, TChecker... checker) {
auto symbol_table = MakeSymbolTable(query);
auto plan = MakeLogicalPlan(query, symbol_table);
CheckPlan(*plan, symbol_table, checker...);
}
TEST(TestLogicalPlanner, MatchNodeReturn) {
// Test MATCH (n) RETURN n AS n
AstTreeStorage storage;
@ -154,8 +181,12 @@ TEST(TestLogicalPlanner, MatchNodeReturn) {
TEST(TestLogicalPlanner, CreateNodeReturn) {
// Test CREATE (n) RETURN n AS n
AstTreeStorage storage;
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), ExpectProduce());
auto ident_n = IDENT("n");
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n")));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(*query, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce());
}
TEST(TestLogicalPlanner, CreateExpand) {
@ -413,12 +444,16 @@ TEST(TestLogicalPlanner, CreateWithSum) {
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto n_prop = PROPERTY_LOOKUP("n", prop);
auto sum = SUM(n_prop);
auto query = QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(*query, symbol_table);
// We expect both the accumulation and aggregation because the part before
// WITH updates the database.
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr,
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr,
ExpectProduce());
}
@ -449,15 +484,19 @@ TEST(TestLogicalPlanner, MatchReturnSkipLimit) {
TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
// Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1
AstTreeStorage storage;
auto ident_n = IDENT("n");
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))),
WITH(ident_n, AS("m"), SKIP(LITERAL(2))),
RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1))));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(*query, symbol_table);
// Since we have a write query, we need to have Accumulate, so Skip and Limit
// need to come before it. This is a bit different than Neo4j, which optimizes
// WITH followed by RETURN as a single RETURN clause. This would cause the
// Limit operator to also appear before Accumulate, thus changing the
// behaviour. We've decided to diverge from Neo4j here, for consistency sake.
CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(),
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectSkip(), acc,
ExpectProduce(), ExpectLimit(), ExpectProduce());
}
@ -467,14 +506,69 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto n_prop = PROPERTY_LOOKUP("n", prop);
auto sum = SUM(n_prop);
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(*query, symbol_table);
// We have a write query and aggregation, therefore Skip and Limit should come
// after Accumulate and Aggregate.
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(),
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectSkip(),
ExpectLimit(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchReturnOrderBy) {
// Test MATCH (n) RETURN n ORDER BY n.prop
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto ret = RETURN(IDENT("n"), AS("n"), ORDER_BY(PROPERTY_LOOKUP("n", prop)));
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectOrderBy());
}
TEST(TestLogicalPlanner, CreateWithOrderByWhere) {
// Test CREATE (n) -[r :r]-> (m)
// WITH n AS new ORDER BY new.prop, r.prop WHERE m.prop < 42
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
auto r_type = dba->edge_type("r");
AstTreeStorage storage;
auto ident_n = IDENT("n");
auto new_prop = PROPERTY_LOOKUP("new", prop);
auto r_prop = PROPERTY_LOOKUP("r", prop);
auto m_prop = PROPERTY_LOOKUP("m", prop);
auto query =
QUERY(CREATE(PATTERN(NODE("n"), EDGE("r", r_type, Direction::RIGHT),
NODE("m"))),
WITH(ident_n, AS("new"), ORDER_BY(new_prop, r_prop)),
WHERE(LESS(m_prop, LITERAL(42))));
auto symbol_table = MakeSymbolTable(*query);
// Since this is a write query, we expect to accumulate to old used symbols.
auto acc = ExpectAccumulate({
symbol_table.at(*ident_n), // `n` in WITH
symbol_table.at(*r_prop->expression_), // `r` in ORDER BY
symbol_table.at(*m_prop->expression_), // `m` in WHERE
});
auto plan = MakeLogicalPlan(*query, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc,
ExpectProduce(), ExpectFilter(), ExpectOrderBy());
}
TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) {
// Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result
AstTreeStorage storage;
auto sum = SUM(LITERAL(1));
auto count = COUNT(LITERAL(2));
auto query =
QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))));
auto aggr = ExpectAggregate({sum, count}, {});
CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy());
}
} // namespace

View File

@ -334,13 +334,14 @@ TEST(TestSymbolGenerator, MatchWithWhere) {
}
TEST(TestSymbolGenerator, MatchWithWhereUnbound) {
// Test MATCH (old) WITH old AS n WHERE old.prop < 42
// Test MATCH (old) WITH COUNT(old) AS c WHERE old.prop < 42
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
auto query =
QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old")), AS("c")),
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
@ -585,4 +586,67 @@ TEST(TestSymbolGenerator, SkipLimitIdentifier) {
}
}
TEST(TestSymbolGenerator, OrderBy) {
// Test MATCH (old) RETURN old AS new ORDER BY COUNT(1)
{
AstTreeStorage storage;
auto query =
QUERY(MATCH(PATTERN(NODE("old"))),
RETURN(IDENT("old"), AS("new"), ORDER_BY(COUNT(LITERAL(1)))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
// Test MATCH (old) RETURN COUNT(old) AS new ORDER BY old
{
AstTreeStorage storage;
auto query =
QUERY(MATCH(PATTERN(NODE("old"))),
RETURN(COUNT(IDENT("old")), AS("new"), ORDER_BY(IDENT("old"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
// Test MATCH (old) RETURN COUNT(old) AS new ORDER BY new
{
AstTreeStorage storage;
auto node = NODE("old");
auto ident_old = IDENT("old");
auto as_new = AS("new");
auto ident_new = IDENT("new");
auto query = QUERY(MATCH(PATTERN(node)),
RETURN(COUNT(ident_old), as_new, ORDER_BY(ident_new)));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for `old`, `count(old)` and `new`
EXPECT_EQ(symbol_table.max_position(), 3);
auto old = symbol_table.at(*node->identifier_);
EXPECT_EQ(old, symbol_table.at(*ident_old));
auto new_sym = symbol_table.at(*as_new);
EXPECT_NE(old, new_sym);
EXPECT_EQ(new_sym, symbol_table.at(*ident_new));
}
// Test MATCH (old) RETURN old AS new ORDER BY old
{
AstTreeStorage storage;
auto node = NODE("old");
auto ident_old = IDENT("old");
auto as_new = AS("new");
auto by_old = IDENT("old");
auto query = QUERY(MATCH(PATTERN(node)),
RETURN(ident_old, as_new, ORDER_BY(by_old)));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for `old` and `new`
EXPECT_EQ(symbol_table.max_position(), 2);
auto old = symbol_table.at(*node->identifier_);
EXPECT_EQ(old, symbol_table.at(*ident_old));
EXPECT_EQ(old, symbol_table.at(*by_old));
auto new_sym = symbol_table.at(*as_new);
EXPECT_NE(old, new_sym);
}
}
}