Plan OrderBy
Summary: Support OrderBy in test macros. Test planning OrderBy. Handle symbol visibility for ORDER BY and WHERE. Add Hash struct to Symbol. Collect used symbols in ORDER BY and WHERE. Reviewers: mislav.bradac, florijan Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D307
This commit is contained in:
parent
fe36835519
commit
e53e232e49
@ -30,31 +30,63 @@ auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
|
||||
return CreateSymbol(name, type);
|
||||
}
|
||||
|
||||
void SymbolGenerator::BindNamedExpressionSymbols(
|
||||
const std::vector<NamedExpression *> &named_expressions) {
|
||||
std::unordered_set<std::string> seen_names;
|
||||
for (auto &named_expr : named_expressions) {
|
||||
// Improvement would be to infer the type of the expression.
|
||||
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
|
||||
for (auto &expr : body.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
// WITH/RETURN clause removes declarations of all the previous variables and
|
||||
// declares only those established through named expressions. New declarations
|
||||
// must not be visible inside named expressions themselves.
|
||||
bool removed_old_names = false;
|
||||
if ((!where && body.order_by.empty()) || scope_.has_aggregation) {
|
||||
// WHERE and ORDER BY need to see both the old and new symbols, unless we
|
||||
// have an aggregation. Therefore, we can clear the symbols immediately if
|
||||
// there is neither ORDER BY nor WHERE, or we have an aggregation.
|
||||
scope_.symbols.clear();
|
||||
removed_old_names = true;
|
||||
}
|
||||
// Create symbols for named expressions.
|
||||
std::unordered_set<std::string> new_names;
|
||||
for (auto &named_expr : body.named_expressions) {
|
||||
const auto &name = named_expr->name_;
|
||||
if (!seen_names.insert(name).second) {
|
||||
if (!new_names.insert(name).second) {
|
||||
throw SemanticException(
|
||||
"Multiple results with the same name '{}' are not allowed.", name);
|
||||
}
|
||||
// An improvement would be to infer the type of the expression, so that the
|
||||
// new symbol would have a more specific type.
|
||||
symbol_table_[*named_expr] = CreateSymbol(name);
|
||||
}
|
||||
}
|
||||
|
||||
void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) {
|
||||
if (skip) {
|
||||
scope_.in_order_by = true;
|
||||
for (const auto &order_pair : body.order_by) {
|
||||
order_pair.second->Accept(*this);
|
||||
}
|
||||
scope_.in_order_by = false;
|
||||
if (body.skip) {
|
||||
scope_.in_skip = true;
|
||||
skip->Accept(*this);
|
||||
body.skip->Accept(*this);
|
||||
scope_.in_skip = false;
|
||||
}
|
||||
if (limit) {
|
||||
if (body.limit) {
|
||||
scope_.in_limit = true;
|
||||
limit->Accept(*this);
|
||||
body.limit->Accept(*this);
|
||||
scope_.in_limit = false;
|
||||
}
|
||||
if (where) where->Accept(*this);
|
||||
if (!removed_old_names) {
|
||||
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
|
||||
// clear the old symbols, so do it now. We cannot just call clear, because
|
||||
// we've added new symbols.
|
||||
for (auto sym_it = scope_.symbols.begin();
|
||||
sym_it != scope_.symbols.end();) {
|
||||
if (new_names.find(sym_it->first) == new_names.end()) {
|
||||
sym_it = scope_.symbols.erase(sym_it);
|
||||
} else {
|
||||
sym_it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
scope_.has_aggregation = false;
|
||||
}
|
||||
|
||||
// Clauses
|
||||
@ -64,33 +96,21 @@ void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
|
||||
|
||||
bool SymbolGenerator::PreVisit(Return &ret) {
|
||||
scope_.in_return = true;
|
||||
for (auto &expr : ret.body_.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
// Named expressions establish bindings for expressions which come after
|
||||
// return, but not for the expressions contained inside.
|
||||
BindNamedExpressionSymbols(ret.body_.named_expressions);
|
||||
VisitSkipAndLimit(ret.body_.skip, ret.body_.limit);
|
||||
VisitReturnBody(ret.body_);
|
||||
scope_.in_return = false;
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(With &with) {
|
||||
scope_.in_with = true;
|
||||
for (auto &expr : with.body_.named_expressions) {
|
||||
expr->Accept(*this);
|
||||
}
|
||||
VisitReturnBody(with.body_, with.where_);
|
||||
scope_.in_with = false;
|
||||
// WITH clause removes declarations of all the previous variables and declares
|
||||
// only those established through named expressions. New declarations must not
|
||||
// be visible inside named expressions themselves.
|
||||
scope_.symbols.clear();
|
||||
BindNamedExpressionSymbols(with.body_.named_expressions);
|
||||
VisitSkipAndLimit(with.body_.skip, with.body_.limit);
|
||||
if (with.where_) with.where_->Accept(*this);
|
||||
return false; // We handled the traversal ourselves.
|
||||
}
|
||||
|
||||
void SymbolGenerator::Visit(Where &) { scope_.in_where = true; }
|
||||
void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; }
|
||||
|
||||
// Expressions
|
||||
|
||||
void SymbolGenerator::Visit(Identifier &ident) {
|
||||
@ -133,7 +153,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
|
||||
// Check if the aggregation can be used in this context. This check should
|
||||
// probably move to a separate phase, which checks if the query is well
|
||||
// formed.
|
||||
if (!scope_.in_return && !scope_.in_with) {
|
||||
if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by ||
|
||||
scope_.in_skip || scope_.in_limit || scope_.in_where) {
|
||||
throw SemanticException(
|
||||
"Aggregation functions are only allowed in WITH and RETURN");
|
||||
}
|
||||
@ -146,6 +167,7 @@ void SymbolGenerator::Visit(Aggregation &aggr) {
|
||||
// Currently, we only have aggregation operators which return numbers.
|
||||
symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number);
|
||||
scope_.in_aggregation = true;
|
||||
scope_.has_aggregation = true;
|
||||
}
|
||||
|
||||
void SymbolGenerator::PostVisit(Aggregation &aggr) {
|
||||
|
@ -29,6 +29,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
void PostVisit(Create &) override;
|
||||
bool PreVisit(Return &) override;
|
||||
bool PreVisit(With &) override;
|
||||
void Visit(Where &) override;
|
||||
void PostVisit(Where &) override;
|
||||
|
||||
// Expressions
|
||||
void Visit(Identifier &) override;
|
||||
@ -62,6 +64,11 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
bool in_with{false};
|
||||
bool in_skip{false};
|
||||
bool in_limit{false};
|
||||
bool in_order_by{false};
|
||||
bool in_where{false};
|
||||
// True if the return/with contains an aggregation in any named expression.
|
||||
bool has_aggregation{false};
|
||||
// Map from variable names to symbols.
|
||||
std::map<std::string, Symbol> symbols;
|
||||
};
|
||||
|
||||
@ -77,10 +84,7 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
auto GetOrCreateSymbol(const std::string &name,
|
||||
Symbol::Type type = Symbol::Type::Any);
|
||||
|
||||
void BindNamedExpressionSymbols(
|
||||
const std::vector<NamedExpression *> &named_expressions);
|
||||
|
||||
void VisitSkipAndLimit(Expression *skip, Expression *limit);
|
||||
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
|
||||
|
||||
SymbolTable &symbol_table_;
|
||||
Scope scope_;
|
||||
|
@ -13,10 +13,17 @@ class Symbol {
|
||||
enum class Type { Any, Vertex, Edge, Path, Number };
|
||||
|
||||
static std::string TypeToString(Type type) {
|
||||
const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"};
|
||||
const char *enum_string[] = {"Any", "Vertex", "Edge", "Path", "Number"};
|
||||
return enum_string[static_cast<int>(type)];
|
||||
}
|
||||
|
||||
// Calculates the Symbol hash based on its position.
|
||||
struct Hash {
|
||||
size_t operator()(const Symbol &symbol) const {
|
||||
return std::hash<int>{}(symbol.position_);
|
||||
}
|
||||
};
|
||||
|
||||
Symbol() {}
|
||||
Symbol(const std::string &name, int position, Type type = Type::Any)
|
||||
: name_(name), position_(position), type_(type) {}
|
||||
|
@ -42,25 +42,32 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
|
||||
Frame frame(symbol_table.max_position());
|
||||
|
||||
std::vector<std::string> header;
|
||||
bool is_return = false;
|
||||
std::vector<Symbol> output_symbols;
|
||||
if (auto produce = dynamic_cast<plan::Produce *>(logical_plan.get())) {
|
||||
// top level node in the operator tree is a produce (return)
|
||||
is_return = true;
|
||||
// collect the symbols from the return clause
|
||||
for (auto named_expression : produce->named_expressions())
|
||||
output_symbols.emplace_back(symbol_table[*named_expression]);
|
||||
} else if (auto order_by =
|
||||
dynamic_cast<plan::OrderBy *>(logical_plan.get())) {
|
||||
is_return = true;
|
||||
output_symbols = order_by->output_symbols();
|
||||
}
|
||||
if (is_return) {
|
||||
// top level node in the operator tree is a produce/order_by (return)
|
||||
// so stream out results
|
||||
|
||||
// generate header
|
||||
for (auto named_expression : produce->named_expressions())
|
||||
header.push_back(named_expression->name_);
|
||||
for (const auto &symbol : output_symbols) header.push_back(symbol.name_);
|
||||
stream.Header(header);
|
||||
|
||||
// collect the symbols from the return clause
|
||||
std::vector<Symbol> symbols;
|
||||
for (auto named_expression : produce->named_expressions())
|
||||
symbols.emplace_back(symbol_table[*named_expression]);
|
||||
|
||||
// stream out results
|
||||
auto cursor = produce->MakeCursor(db_accessor);
|
||||
auto cursor = logical_plan->MakeCursor(db_accessor);
|
||||
while (cursor->Pull(frame, symbol_table)) {
|
||||
std::vector<TypedValue> values;
|
||||
for (auto &symbol : symbols) values.emplace_back(frame[symbol]);
|
||||
for (const auto &symbol : output_symbols)
|
||||
values.emplace_back(frame[symbol]);
|
||||
stream.Result(values);
|
||||
}
|
||||
} else if (dynamic_cast<plan::CreateNode *>(logical_plan.get()) ||
|
||||
@ -72,7 +79,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
|
||||
dynamic_cast<plan::RemoveLabels *>(logical_plan.get()) ||
|
||||
dynamic_cast<plan::Delete *>(logical_plan.get())) {
|
||||
stream.Header(header);
|
||||
auto cursor = logical_plan.get()->MakeCursor(db_accessor);
|
||||
auto cursor = logical_plan->MakeCursor(db_accessor);
|
||||
while (cursor->Pull(frame, symbol_table)) continue;
|
||||
} else {
|
||||
throw QueryRuntimeException("Unknown top level LogicalOp");
|
||||
|
@ -1224,9 +1224,9 @@ void Limit::LimitCursor::Reset() {
|
||||
}
|
||||
|
||||
OrderBy::OrderBy(const std::shared_ptr<LogicalOperator> &input,
|
||||
const std::vector<std::pair<Ordering, Expression *>> order_by,
|
||||
const std::vector<Symbol> remember)
|
||||
: input_(input), remember_(remember) {
|
||||
const std::vector<std::pair<Ordering, Expression *>> &order_by,
|
||||
const std::vector<Symbol> &output_symbols)
|
||||
: input_(input), output_symbols_(output_symbols) {
|
||||
// split the order_by vector into two vectors of orderings and expressions
|
||||
std::vector<Ordering> ordering;
|
||||
ordering.reserve(order_by.size());
|
||||
@ -1259,12 +1259,12 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame,
|
||||
order_by.emplace_back(evaluator.PopBack());
|
||||
}
|
||||
|
||||
// collect the remember elements
|
||||
std::list<TypedValue> remember;
|
||||
for (const Symbol &remember_sym : self_.remember_)
|
||||
remember.emplace_back(frame[remember_sym]);
|
||||
// collect the output elements
|
||||
std::list<TypedValue> output;
|
||||
for (const Symbol &output_sym : self_.output_symbols_)
|
||||
output.emplace_back(frame[output_sym]);
|
||||
|
||||
cache_.emplace_back(order_by, remember);
|
||||
cache_.emplace_back(order_by, output);
|
||||
}
|
||||
|
||||
std::sort(cache_.begin(), cache_.end(),
|
||||
@ -1278,13 +1278,13 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame,
|
||||
|
||||
if (cache_it_ == cache_.end()) return false;
|
||||
|
||||
// place the remembered values on the frame
|
||||
debug_assert(self_.remember_.size() == cache_it_->second.size(),
|
||||
"Number of values does not match the number of remember symbols "
|
||||
// place the output values on the frame
|
||||
debug_assert(self_.output_symbols_.size() == cache_it_->second.size(),
|
||||
"Number of values does not match the number of output symbols "
|
||||
"in OrderBy");
|
||||
auto remember_sym_it = self_.remember_.begin();
|
||||
for (const TypedValue &remember : cache_it_->second)
|
||||
frame[*remember_sym_it++] = remember;
|
||||
auto output_sym_it = self_.output_symbols_.begin();
|
||||
for (const TypedValue &output : cache_it_->second)
|
||||
frame[*output_sym_it++] = output;
|
||||
|
||||
cache_it_++;
|
||||
return true;
|
||||
|
@ -827,6 +827,8 @@ class Accumulate : public LogicalOperator {
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
const auto &symbols() const { return symbols_; };
|
||||
|
||||
private:
|
||||
const std::shared_ptr<LogicalOperator> input_;
|
||||
const std::vector<Symbol> symbols_;
|
||||
@ -1066,11 +1068,13 @@ class Limit : public LogicalOperator {
|
||||
class OrderBy : public LogicalOperator {
|
||||
public:
|
||||
OrderBy(const std::shared_ptr<LogicalOperator> &input,
|
||||
const std::vector<std::pair<Ordering, Expression *>> order_by,
|
||||
const std::vector<Symbol> remember);
|
||||
const std::vector<std::pair<Ordering, Expression *>> &order_by,
|
||||
const std::vector<Symbol> &output_symbols);
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
const auto &output_symbols() const { return output_symbols_; }
|
||||
|
||||
private:
|
||||
// custom Comparator type for comparing lists of TypedValues
|
||||
// does lexicographical ordering of elements based on the above
|
||||
@ -1091,7 +1095,7 @@ class OrderBy : public LogicalOperator {
|
||||
const std::shared_ptr<LogicalOperator> input_;
|
||||
TypedValueListCompare compare_;
|
||||
std::vector<Expression *> order_by_;
|
||||
const std::vector<Symbol> remember_;
|
||||
const std::vector<Symbol> output_symbols_;
|
||||
|
||||
// custom comparison for TypedValue objects
|
||||
// behaves generally like Neo's ORDER BY comparison operator:
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "query/plan/planner.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
@ -184,17 +185,37 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
// * flag whether the results need to be DISTINCT;
|
||||
// * optional SKIP expression;
|
||||
// * optional LIMIT expression and
|
||||
// * optional ORDER BY expression.
|
||||
// * optional ORDER BY expressions.
|
||||
//
|
||||
// In addition to the above, we collect information on used symbols,
|
||||
// aggregations and expressions used for group by.
|
||||
class ReturnBodyContext : public TreeVisitorBase {
|
||||
public:
|
||||
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table)
|
||||
: body_(body), symbol_table_(symbol_table) {
|
||||
ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table,
|
||||
Where *where = nullptr)
|
||||
: body_(body), symbol_table_(symbol_table), where_(where) {
|
||||
// Collect symbols from named expressions.
|
||||
output_symbols_.reserve(body_.named_expressions.size());
|
||||
for (auto &named_expr : body_.named_expressions) {
|
||||
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
|
||||
named_expr->Accept(*this);
|
||||
}
|
||||
if (aggregations_.empty()) {
|
||||
// Visit order_by and where if we do not have aggregations. This way we
|
||||
// prevent collecting group_by expressions from order_by and where, which
|
||||
// would be very wrong. When we have aggregation, order_by and where can
|
||||
// only use new symbols (ensured in semantic analysis), so we don't care
|
||||
// about collecting used_symbols. Also, semantic analysis should
|
||||
// have prevented any aggregations from appearing here.
|
||||
for (const auto &order_pair : body.order_by) {
|
||||
order_pair.second->Accept(*this);
|
||||
}
|
||||
if (where) {
|
||||
where->Accept(*this);
|
||||
}
|
||||
debug_assert(aggregations_.empty(),
|
||||
"Unexpected aggregations in ORDER BY or WHERE");
|
||||
}
|
||||
}
|
||||
|
||||
using TreeVisitorBase::PreVisit;
|
||||
@ -204,7 +225,13 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
void Visit(Literal &) override { has_aggregation_.emplace_back(false); }
|
||||
|
||||
void Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
const auto &symbol = symbol_table_.at(ident);
|
||||
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
|
||||
output_symbols_.end()) {
|
||||
// Don't pick up new symbols, even though they may be used in ORDER BY or
|
||||
// WHERE.
|
||||
used_symbols_.insert(symbol);
|
||||
}
|
||||
has_aggregation_.emplace_back(false);
|
||||
}
|
||||
|
||||
@ -217,7 +244,7 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
bool aggr1 = has_aggregation_.back(); \
|
||||
has_aggregation_.pop_back(); \
|
||||
bool has_aggr = aggr1 || aggr2; \
|
||||
if (has_aggr) { \
|
||||
if (has_aggr && !(aggr1 && aggr2)) { \
|
||||
/* Group by the expression which does not contain aggregation. */ \
|
||||
/* Possible optimization is to ignore constant value expressions */ \
|
||||
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
|
||||
@ -264,13 +291,18 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
bool distinct() const { return body_.distinct; }
|
||||
// Named expressions which are used to produce results.
|
||||
const auto &named_expressions() const { return body_.named_expressions; }
|
||||
// Pairs of (Ordering, Expression *) for sorting results.
|
||||
const auto &order_by() const { return body_.order_by; }
|
||||
// Optional expression which determines how many results to skip.
|
||||
auto *skip() const { return body_.skip; }
|
||||
// Optional expression which determines how many results to produce.
|
||||
auto *limit() const { return body_.limit; }
|
||||
// Optional Where clause for filtering.
|
||||
const auto *where() const { return where_; }
|
||||
// Set of symbols used inside the visited expressions outside of aggregation
|
||||
// expression.
|
||||
const auto &symbols() const { return symbols_; }
|
||||
// expression. These only includes old symbols, even though new ones may have
|
||||
// been used in ORDER BY or WHERE.
|
||||
const auto &used_symbols() const { return used_symbols_; }
|
||||
// List of aggregation elements found in expressions.
|
||||
const auto &aggregations() const { return aggregations_; }
|
||||
// When there is at least one aggregation element, all the non-aggregate (sub)
|
||||
@ -278,17 +310,16 @@ class ReturnBodyContext : public TreeVisitorBase {
|
||||
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
|
||||
const auto &group_by() const { return group_by_; }
|
||||
|
||||
private:
|
||||
// Calculates the Symbol hash based on its position.
|
||||
struct SymbolHash {
|
||||
size_t operator()(const Symbol &symbol) const {
|
||||
return std::hash<int>{}(symbol.position_);
|
||||
}
|
||||
};
|
||||
// All symbols generated by named expressions. They are collected in order of
|
||||
// named_expressions.
|
||||
const auto &output_symbols() const { return output_symbols_; }
|
||||
|
||||
private:
|
||||
const ReturnBody &body_;
|
||||
const SymbolTable &symbol_table_;
|
||||
std::unordered_set<Symbol, SymbolHash> symbols_;
|
||||
const Where *const where_ = nullptr;
|
||||
std::unordered_set<Symbol, Symbol::Hash> used_symbols_;
|
||||
std::vector<Symbol> output_symbols_;
|
||||
std::vector<Aggregate::Element> aggregations_;
|
||||
std::vector<Expression *> group_by_;
|
||||
// Flag indicating whether an expression contains an aggregation.
|
||||
@ -314,8 +345,8 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
// TODO: Plan with distinct, when operator available.
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
auto symbols =
|
||||
std::vector<Symbol>(body.symbols().begin(), body.symbols().end());
|
||||
std::vector<Symbol> used_symbols(body.used_symbols().begin(),
|
||||
body.used_symbols().end());
|
||||
auto last_op = input_op;
|
||||
if (body.aggregations().empty()) {
|
||||
// In case when we have SKIP/LIMIT and we don't perform aggregations, we
|
||||
@ -329,18 +360,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
// We only advance the command in Accumulate. This is done for WITH clause,
|
||||
// when the first part updated the database. RETURN clause may only need an
|
||||
// accumulation after updates, without advancing the command.
|
||||
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op), symbols,
|
||||
advance_command);
|
||||
last_op = new Accumulate(std::shared_ptr<LogicalOperator>(last_op),
|
||||
used_symbols, advance_command);
|
||||
}
|
||||
if (!body.aggregations().empty()) {
|
||||
// When we have aggregation, SKIP/LIMIT should always come after it.
|
||||
last_op = GenSkipLimit(
|
||||
new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.aggregations(), body.group_by(), symbols),
|
||||
body.aggregations(), body.group_by(), used_symbols),
|
||||
body);
|
||||
}
|
||||
return new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.named_expressions());
|
||||
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.named_expressions());
|
||||
// Where may see new symbols so it comes after we generate Produce.
|
||||
if (body.where()) {
|
||||
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.where()->expression_);
|
||||
}
|
||||
// Like Where, OrderBy can read from symbols established by named expressions
|
||||
// in Produce, so it must come after it.
|
||||
if (!body.order_by().empty()) {
|
||||
last_op = new OrderBy(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.order_by(), body.output_symbols());
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenWith(With &with, LogicalOperator *input_op,
|
||||
@ -353,17 +396,13 @@ auto GenWith(With &with, LogicalOperator *input_op,
|
||||
bool accumulate = is_write;
|
||||
// No need to advance the command if we only performed reads.
|
||||
bool advance_command = is_write;
|
||||
ReturnBodyContext body(with.body_, symbol_table);
|
||||
ReturnBodyContext body(with.body_, symbol_table, with.where_);
|
||||
LogicalOperator *last_op =
|
||||
GenReturnBody(input_op, advance_command, body, accumulate);
|
||||
// Reset bound symbols, so that only those in WITH are exposed.
|
||||
bound_symbols.clear();
|
||||
for (auto &named_expr : with.body_.named_expressions) {
|
||||
BindSymbol(bound_symbols, symbol_table.at(*named_expr));
|
||||
}
|
||||
if (with.where_) {
|
||||
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
with.where_->expression_);
|
||||
for (const auto &symbol : body.output_symbols()) {
|
||||
BindSymbol(bound_symbols, symbol);
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
@ -1,18 +1,74 @@
|
||||
///
|
||||
/// @file
|
||||
/// This file provides macros for easier construction of openCypher query AST.
|
||||
/// The usage of macros is very similar to how one would write openCypher. For
|
||||
/// example:
|
||||
///
|
||||
/// AstTreeStorage storage; // Macros rely on storage being in scope.
|
||||
///
|
||||
/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))),
|
||||
/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))),
|
||||
/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"),
|
||||
/// ORDER_BY(IDENT("sum")),
|
||||
/// SKIP(ADD(LITERAL(1), LITERAL(2)))));
|
||||
///
|
||||
/// Each of the macros is accompanied by a function. The functions use overload
|
||||
/// resolution and template magic to provide a type safe way of constructing
|
||||
/// queries. Although the functions can be used by themselves, it is more
|
||||
/// convenient to use the macros.
|
||||
///
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "database/graph_db_datatypes.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
namespace test_common {
|
||||
|
||||
// Custom types for SKIP and LIMIT and expressions, so that they can be used to
|
||||
// resolve function calls.
|
||||
// Custom types for ORDER BY, SKIP and LIMIT and expressions, so that they can
|
||||
// be used to resolve function calls.
|
||||
struct OrderBy {
|
||||
std::vector<std::pair<Ordering, Expression *>> expressions;
|
||||
};
|
||||
struct Skip {
|
||||
query::Expression *expression = nullptr;
|
||||
Expression *expression = nullptr;
|
||||
};
|
||||
struct Limit {
|
||||
query::Expression *expression = nullptr;
|
||||
Expression *expression = nullptr;
|
||||
};
|
||||
|
||||
// Helper functions for filling the OrderBy with expressions.
|
||||
auto FillOrderBy(OrderBy &order_by, Expression *expression,
|
||||
Ordering ordering = Ordering::ASC) {
|
||||
order_by.expressions.emplace_back(ordering, expression);
|
||||
}
|
||||
template <class... T>
|
||||
auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering,
|
||||
T... rest) {
|
||||
FillOrderBy(order_by, expression, ordering);
|
||||
FillOrderBy(order_by, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) {
|
||||
FillOrderBy(order_by, expression);
|
||||
FillOrderBy(order_by, rest...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create OrderBy expressions.
|
||||
///
|
||||
/// The supported combination of arguments is: (Expression, [Ordering])+
|
||||
/// Since the Ordering is optional, by default it is ascending.
|
||||
///
|
||||
template <class... T>
|
||||
auto GetOrderBy(T... exprs) {
|
||||
OrderBy order_by;
|
||||
FillOrderBy(order_by, exprs...);
|
||||
return order_by;
|
||||
}
|
||||
|
||||
///
|
||||
/// Create PropertyLookup with given name and property.
|
||||
///
|
||||
@ -117,86 +173,77 @@ auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) {
|
||||
return GetQuery(storage, clauses...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the return clause with given named expressions.
|
||||
///
|
||||
auto GetReturn(Return *ret, NamedExpression *named_expr) {
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return ret;
|
||||
// Helper functions for constructing RETURN and WITH clauses.
|
||||
void FillReturnBody(ReturnBody &body, NamedExpression *named_expr) {
|
||||
body.named_expressions.emplace_back(named_expr);
|
||||
}
|
||||
auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) {
|
||||
ret->body_.skip = skip.expression;
|
||||
ret->body_.limit = limit.expression;
|
||||
return ret;
|
||||
void FillReturnBody(ReturnBody &body, Limit limit) {
|
||||
body.limit = limit.expression;
|
||||
}
|
||||
auto GetReturn(Return *ret, Limit limit) {
|
||||
ret->body_.limit = limit.expression;
|
||||
return ret;
|
||||
void FillReturnBody(ReturnBody &body, Skip skip, Limit limit = Limit{}) {
|
||||
body.skip = skip.expression;
|
||||
body.limit = limit.expression;
|
||||
}
|
||||
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) {
|
||||
void FillReturnBody(ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) {
|
||||
body.order_by = order_by.expressions;
|
||||
body.limit = limit.expression;
|
||||
}
|
||||
void FillReturnBody(ReturnBody &body, OrderBy order_by, Skip skip,
|
||||
Limit limit = Limit{}) {
|
||||
body.order_by = order_by.expressions;
|
||||
body.skip = skip.expression;
|
||||
body.limit = limit.expression;
|
||||
}
|
||||
void FillReturnBody(ReturnBody &body, Expression *expr,
|
||||
NamedExpression *named_expr) {
|
||||
// This overload supports `RETURN(expr, AS(name))` construct, since
|
||||
// NamedExpression does not inherit Expression.
|
||||
named_expr->expression_ = expr;
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return ret;
|
||||
body.named_expressions.emplace_back(named_expr);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr,
|
||||
T... rest) {
|
||||
void FillReturnBody(ReturnBody &body, Expression *expr,
|
||||
NamedExpression *named_expr, T... rest) {
|
||||
named_expr->expression_ = expr;
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetReturn(ret, rest...);
|
||||
body.named_expressions.emplace_back(named_expr);
|
||||
FillReturnBody(body, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(Return *ret, NamedExpression *named_expr, T... rest) {
|
||||
ret->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetReturn(ret, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetReturn(AstTreeStorage &storage, T... exprs) {
|
||||
auto ret = storage.Create<Return>();
|
||||
return GetReturn(ret, exprs...);
|
||||
void FillReturnBody(ReturnBody &body, NamedExpression *named_expr, T... rest) {
|
||||
body.named_expressions.emplace_back(named_expr);
|
||||
FillReturnBody(body, rest...);
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the with clause with given named expressions.
|
||||
/// Create the return clause with given expressions.
|
||||
///
|
||||
auto GetWith(With *with, NamedExpression *named_expr) {
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Skip skip, Limit limit = {}) {
|
||||
with->body_.skip = skip.expression;
|
||||
with->body_.limit = limit.expression;
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Limit limit) {
|
||||
with->body_.limit = limit.expression;
|
||||
return with;
|
||||
}
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) {
|
||||
// This overload supports `RETURN(expr, AS(name))` construct, since
|
||||
// NamedExpression does not inherit Expression.
|
||||
named_expr->expression_ = expr;
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return with;
|
||||
}
|
||||
/// The supported expression combination of arguments is:
|
||||
///
|
||||
/// (NamedExpression | (Expression NamedExpression))+ [OrderBy] [Skip] [Limit]
|
||||
///
|
||||
/// When the pair (Expression NamedExpression) is given, the Expression will be
|
||||
/// moved inside the NamedExpression. This is done, so that the constructs like
|
||||
/// RETURN(expr, AS("name"), ...) are supported.
|
||||
///
|
||||
/// @sa GetWith
|
||||
template <class... T>
|
||||
auto GetWith(With *with, Expression *expr, NamedExpression *named_expr,
|
||||
T... rest) {
|
||||
named_expr->expression_ = expr;
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
}
|
||||
template <class... T>
|
||||
auto GetWith(With *with, NamedExpression *named_expr, T... rest) {
|
||||
with->body_.named_expressions.emplace_back(named_expr);
|
||||
return GetWith(with, rest...);
|
||||
auto GetReturn(AstTreeStorage &storage, T... exprs) {
|
||||
auto ret = storage.Create<Return>();
|
||||
FillReturnBody(ret->body_, exprs...);
|
||||
return ret;
|
||||
}
|
||||
|
||||
///
|
||||
/// Create the with clause with given expressions.
|
||||
///
|
||||
/// The supported expression combination is the same as for @c GetReturn.
|
||||
///
|
||||
/// @sa GetReturn
|
||||
template <class... T>
|
||||
auto GetWith(AstTreeStorage &storage, T... exprs) {
|
||||
auto with = storage.Create<With>();
|
||||
return GetWith(with, exprs...);
|
||||
FillReturnBody(with->body_, exprs...);
|
||||
return with;
|
||||
}
|
||||
|
||||
///
|
||||
@ -288,8 +335,11 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
|
||||
#define AS(name) storage.Create<query::NamedExpression>((name))
|
||||
#define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__)
|
||||
#define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__)
|
||||
#define SKIP(expr) query::test_common::Skip{(expr)}
|
||||
#define LIMIT(expr) query::test_common::Limit{(expr)}
|
||||
#define ORDER_BY(...) query::test_common::GetOrderBy(__VA_ARGS__)
|
||||
#define SKIP(expr) \
|
||||
query::test_common::Skip { (expr) }
|
||||
#define LIMIT(expr) \
|
||||
query::test_common::Limit { (expr) }
|
||||
#define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
|
||||
#define DETACH_DELETE(...) \
|
||||
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
|
||||
@ -297,7 +347,10 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
|
||||
#define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__)
|
||||
#define QUERY(...) query::test_common::GetQuery(storage, __VA_ARGS__)
|
||||
// Various operators
|
||||
#define ADD(expr1, expr2) storage.Create<query::AdditionOperator>((expr1), (expr2))
|
||||
#define ADD(expr1, expr2) \
|
||||
storage.Create<query::AdditionOperator>((expr1), (expr2))
|
||||
#define LESS(expr1, expr2) storage.Create<query::LessOperator>((expr1), (expr2))
|
||||
#define SUM(expr) \
|
||||
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::SUM)
|
||||
#define COUNT(expr) \
|
||||
storage.Create<query::Aggregation>((expr), query::Aggregation::Op::COUNT)
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
using namespace query::plan;
|
||||
using query::AstTreeStorage;
|
||||
using query::Symbol;
|
||||
using query::SymbolTable;
|
||||
using query::SymbolGenerator;
|
||||
using Direction = query::EdgeAtom::Direction;
|
||||
@ -57,13 +58,27 @@ using ExpectRemoveLabels = OpChecker<RemoveLabels>;
|
||||
template <class TAccessor>
|
||||
using ExpectExpandUniquenessFilter =
|
||||
OpChecker<ExpandUniquenessFilter<TAccessor>>;
|
||||
using ExpectAccumulate = OpChecker<Accumulate>;
|
||||
using ExpectSkip = OpChecker<Skip>;
|
||||
using ExpectLimit = OpChecker<Limit>;
|
||||
using ExpectOrderBy = OpChecker<OrderBy>;
|
||||
|
||||
class ExpectAccumulate : public OpChecker<Accumulate> {
|
||||
public:
|
||||
ExpectAccumulate(const std::unordered_set<Symbol, Symbol::Hash> &symbols)
|
||||
: symbols_(symbols) {}
|
||||
|
||||
void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override {
|
||||
std::unordered_set<Symbol, Symbol::Hash> got_symbols(op.symbols().begin(),
|
||||
op.symbols().end());
|
||||
EXPECT_EQ(symbols_, got_symbols);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::unordered_set<Symbol, Symbol::Hash> symbols_;
|
||||
};
|
||||
|
||||
class ExpectAggregate : public OpChecker<Aggregate> {
|
||||
public:
|
||||
ExpectAggregate() = default;
|
||||
ExpectAggregate(const std::vector<query::Aggregation *> &aggregations,
|
||||
const std::unordered_set<query::Expression *> &group_by)
|
||||
: aggregations_(aggregations), group_by_(group_by) {}
|
||||
@ -119,6 +134,7 @@ class PlanChecker : public LogicalOperatorVisitor {
|
||||
void Visit(Aggregate &op) override { CheckOp(op); }
|
||||
void Visit(Skip &op) override { CheckOp(op); }
|
||||
void Visit(Limit &op) override { CheckOp(op); }
|
||||
void Visit(OrderBy &op) override { CheckOp(op); }
|
||||
|
||||
std::list<BaseOpChecker *> checkers_;
|
||||
|
||||
@ -132,18 +148,29 @@ class PlanChecker : public LogicalOperatorVisitor {
|
||||
const SymbolTable &symbol_table_;
|
||||
};
|
||||
|
||||
template <class... TChecker>
|
||||
auto CheckPlan(query::Query &query, TChecker... checker) {
|
||||
auto MakeSymbolTable(query::Query &query) {
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query.Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(query, symbol_table);
|
||||
return symbol_table;
|
||||
}
|
||||
|
||||
template <class... TChecker>
|
||||
auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table,
|
||||
TChecker... checker) {
|
||||
std::list<BaseOpChecker *> checkers{&checker...};
|
||||
PlanChecker plan_checker(checkers, symbol_table);
|
||||
plan->Accept(plan_checker);
|
||||
plan.Accept(plan_checker);
|
||||
EXPECT_TRUE(plan_checker.checkers_.empty());
|
||||
}
|
||||
|
||||
template <class... TChecker>
|
||||
auto CheckPlan(query::Query &query, TChecker... checker) {
|
||||
auto symbol_table = MakeSymbolTable(query);
|
||||
auto plan = MakeLogicalPlan(query, symbol_table);
|
||||
CheckPlan(*plan, symbol_table, checker...);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchNodeReturn) {
|
||||
// Test MATCH (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
@ -154,8 +181,12 @@ TEST(TestLogicalPlanner, MatchNodeReturn) {
|
||||
TEST(TestLogicalPlanner, CreateNodeReturn) {
|
||||
// Test CREATE (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), ExpectProduce());
|
||||
auto ident_n = IDENT("n");
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n")));
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateExpand) {
|
||||
@ -413,12 +444,16 @@ TEST(TestLogicalPlanner, CreateWithSum) {
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto n_prop = PROPERTY_LOOKUP("n", prop);
|
||||
auto sum = SUM(n_prop);
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")));
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
|
||||
auto aggr = ExpectAggregate({sum}, {});
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
// We expect both the accumulation and aggregation because the part before
|
||||
// WITH updates the database.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr,
|
||||
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr,
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
@ -449,15 +484,19 @@ TEST(TestLogicalPlanner, MatchReturnSkipLimit) {
|
||||
TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
|
||||
// Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1
|
||||
AstTreeStorage storage;
|
||||
auto ident_n = IDENT("n");
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
|
||||
WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))),
|
||||
WITH(ident_n, AS("m"), SKIP(LITERAL(2))),
|
||||
RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1))));
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
// Since we have a write query, we need to have Accumulate, so Skip and Limit
|
||||
// need to come before it. This is a bit different than Neo4j, which optimizes
|
||||
// WITH followed by RETURN as a single RETURN clause. This would cause the
|
||||
// Limit operator to also appear before Accumulate, thus changing the
|
||||
// behaviour. We've decided to diverge from Neo4j here, for consistency sake.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(),
|
||||
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectSkip(), acc,
|
||||
ExpectProduce(), ExpectLimit(), ExpectProduce());
|
||||
}
|
||||
|
||||
@ -467,14 +506,69 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
|
||||
auto n_prop = PROPERTY_LOOKUP("n", prop);
|
||||
auto sum = SUM(n_prop);
|
||||
auto query = QUERY(CREATE(PATTERN(NODE("n"))),
|
||||
RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))));
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
|
||||
auto aggr = ExpectAggregate({sum}, {});
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
// We have a write query and aggregation, therefore Skip and Limit should come
|
||||
// after Accumulate and Aggregate.
|
||||
CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(),
|
||||
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectSkip(),
|
||||
ExpectLimit(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchReturnOrderBy) {
|
||||
// Test MATCH (n) RETURN n ORDER BY n.prop
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto ret = RETURN(IDENT("n"), AS("n"), ORDER_BY(PROPERTY_LOOKUP("n", prop)));
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectOrderBy());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateWithOrderByWhere) {
|
||||
// Test CREATE (n) -[r :r]-> (m)
|
||||
// WITH n AS new ORDER BY new.prop, r.prop WHERE m.prop < 42
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
auto r_type = dba->edge_type("r");
|
||||
AstTreeStorage storage;
|
||||
auto ident_n = IDENT("n");
|
||||
auto new_prop = PROPERTY_LOOKUP("new", prop);
|
||||
auto r_prop = PROPERTY_LOOKUP("r", prop);
|
||||
auto m_prop = PROPERTY_LOOKUP("m", prop);
|
||||
auto query =
|
||||
QUERY(CREATE(PATTERN(NODE("n"), EDGE("r", r_type, Direction::RIGHT),
|
||||
NODE("m"))),
|
||||
WITH(ident_n, AS("new"), ORDER_BY(new_prop, r_prop)),
|
||||
WHERE(LESS(m_prop, LITERAL(42))));
|
||||
auto symbol_table = MakeSymbolTable(*query);
|
||||
// Since this is a write query, we expect to accumulate to old used symbols.
|
||||
auto acc = ExpectAccumulate({
|
||||
symbol_table.at(*ident_n), // `n` in WITH
|
||||
symbol_table.at(*r_prop->expression_), // `r` in ORDER BY
|
||||
symbol_table.at(*m_prop->expression_), // `m` in WHERE
|
||||
});
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc,
|
||||
ExpectProduce(), ExpectFilter(), ExpectOrderBy());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) {
|
||||
// Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LITERAL(1));
|
||||
auto count = COUNT(LITERAL(2));
|
||||
auto query =
|
||||
QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))));
|
||||
auto aggr = ExpectAggregate({sum, count}, {});
|
||||
CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -334,13 +334,14 @@ TEST(TestSymbolGenerator, MatchWithWhere) {
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchWithWhereUnbound) {
|
||||
// Test MATCH (old) WITH old AS n WHERE old.prop < 42
|
||||
// Test MATCH (old) WITH COUNT(old) AS c WHERE old.prop < 42
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old")), AS("c")),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
|
||||
@ -585,4 +586,67 @@ TEST(TestSymbolGenerator, SkipLimitIdentifier) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, OrderBy) {
|
||||
// Test MATCH (old) RETURN old AS new ORDER BY COUNT(1)
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("old"))),
|
||||
RETURN(IDENT("old"), AS("new"), ORDER_BY(COUNT(LITERAL(1)))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
// Test MATCH (old) RETURN COUNT(old) AS new ORDER BY old
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto query =
|
||||
QUERY(MATCH(PATTERN(NODE("old"))),
|
||||
RETURN(COUNT(IDENT("old")), AS("new"), ORDER_BY(IDENT("old"))));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
|
||||
}
|
||||
// Test MATCH (old) RETURN COUNT(old) AS new ORDER BY new
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto node = NODE("old");
|
||||
auto ident_old = IDENT("old");
|
||||
auto as_new = AS("new");
|
||||
auto ident_new = IDENT("new");
|
||||
auto query = QUERY(MATCH(PATTERN(node)),
|
||||
RETURN(COUNT(ident_old), as_new, ORDER_BY(ident_new)));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
// Symbols for `old`, `count(old)` and `new`
|
||||
EXPECT_EQ(symbol_table.max_position(), 3);
|
||||
auto old = symbol_table.at(*node->identifier_);
|
||||
EXPECT_EQ(old, symbol_table.at(*ident_old));
|
||||
auto new_sym = symbol_table.at(*as_new);
|
||||
EXPECT_NE(old, new_sym);
|
||||
EXPECT_EQ(new_sym, symbol_table.at(*ident_new));
|
||||
}
|
||||
// Test MATCH (old) RETURN old AS new ORDER BY old
|
||||
{
|
||||
AstTreeStorage storage;
|
||||
auto node = NODE("old");
|
||||
auto ident_old = IDENT("old");
|
||||
auto as_new = AS("new");
|
||||
auto by_old = IDENT("old");
|
||||
auto query = QUERY(MATCH(PATTERN(node)),
|
||||
RETURN(ident_old, as_new, ORDER_BY(by_old)));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
// Symbols for `old` and `new`
|
||||
EXPECT_EQ(symbol_table.max_position(), 2);
|
||||
auto old = symbol_table.at(*node->identifier_);
|
||||
EXPECT_EQ(old, symbol_table.at(*ident_old));
|
||||
EXPECT_EQ(old, symbol_table.at(*by_old));
|
||||
auto new_sym = symbol_table.at(*as_new);
|
||||
EXPECT_NE(old, new_sym);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user