diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 4f28efc3c..5be5130dd 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -30,31 +30,63 @@ auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, return CreateSymbol(name, type); } -void SymbolGenerator::BindNamedExpressionSymbols( - const std::vector &named_expressions) { - std::unordered_set seen_names; - for (auto &named_expr : named_expressions) { - // Improvement would be to infer the type of the expression. +void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) { + for (auto &expr : body.named_expressions) { + expr->Accept(*this); + } + // WITH/RETURN clause removes declarations of all the previous variables and + // declares only those established through named expressions. New declarations + // must not be visible inside named expressions themselves. + bool removed_old_names = false; + if ((!where && body.order_by.empty()) || scope_.has_aggregation) { + // WHERE and ORDER BY need to see both the old and new symbols, unless we + // have an aggregation. Therefore, we can clear the symbols immediately if + // there is neither ORDER BY nor WHERE, or we have an aggregation. + scope_.symbols.clear(); + removed_old_names = true; + } + // Create symbols for named expressions. + std::unordered_set new_names; + for (auto &named_expr : body.named_expressions) { const auto &name = named_expr->name_; - if (!seen_names.insert(name).second) { + if (!new_names.insert(name).second) { throw SemanticException( "Multiple results with the same name '{}' are not allowed.", name); } + // An improvement would be to infer the type of the expression, so that the + // new symbol would have a more specific type. 
symbol_table_[*named_expr] = CreateSymbol(name); } -} - -void SymbolGenerator::VisitSkipAndLimit(Expression *skip, Expression *limit) { - if (skip) { + scope_.in_order_by = true; + for (const auto &order_pair : body.order_by) { + order_pair.second->Accept(*this); + } + scope_.in_order_by = false; + if (body.skip) { scope_.in_skip = true; - skip->Accept(*this); + body.skip->Accept(*this); scope_.in_skip = false; } - if (limit) { + if (body.limit) { scope_.in_limit = true; - limit->Accept(*this); + body.limit->Accept(*this); scope_.in_limit = false; } + if (where) where->Accept(*this); + if (!removed_old_names) { + // We have an ORDER BY or WHERE, but no aggregation, which means we didn't + // clear the old symbols, so do it now. We cannot just call clear, because + // we've added new symbols. + for (auto sym_it = scope_.symbols.begin(); + sym_it != scope_.symbols.end();) { + if (new_names.find(sym_it->first) == new_names.end()) { + sym_it = scope_.symbols.erase(sym_it); + } else { + sym_it++; + } + } + } + scope_.has_aggregation = false; } // Clauses @@ -64,33 +96,21 @@ void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; } bool SymbolGenerator::PreVisit(Return &ret) { scope_.in_return = true; - for (auto &expr : ret.body_.named_expressions) { - expr->Accept(*this); - } - // Named expressions establish bindings for expressions which come after - // return, but not for the expressions contained inside. - BindNamedExpressionSymbols(ret.body_.named_expressions); - VisitSkipAndLimit(ret.body_.skip, ret.body_.limit); + VisitReturnBody(ret.body_); scope_.in_return = false; return false; // We handled the traversal ourselves. 
} bool SymbolGenerator::PreVisit(With &with) { scope_.in_with = true; - for (auto &expr : with.body_.named_expressions) { - expr->Accept(*this); - } + VisitReturnBody(with.body_, with.where_); scope_.in_with = false; - // WITH clause removes declarations of all the previous variables and declares - // only those established through named expressions. New declarations must not - // be visible inside named expressions themselves. - scope_.symbols.clear(); - BindNamedExpressionSymbols(with.body_.named_expressions); - VisitSkipAndLimit(with.body_.skip, with.body_.limit); - if (with.where_) with.where_->Accept(*this); return false; // We handled the traversal ourselves. } +void SymbolGenerator::Visit(Where &) { scope_.in_where = true; } +void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; } + // Expressions void SymbolGenerator::Visit(Identifier &ident) { @@ -133,7 +153,8 @@ void SymbolGenerator::Visit(Aggregation &aggr) { // Check if the aggregation can be used in this context. This check should // probably move to a separate phase, which checks if the query is well // formed. - if (!scope_.in_return && !scope_.in_with) { + if ((!scope_.in_return && !scope_.in_with) || scope_.in_order_by || + scope_.in_skip || scope_.in_limit || scope_.in_where) { throw SemanticException( "Aggregation functions are only allowed in WITH and RETURN"); } @@ -146,6 +167,7 @@ void SymbolGenerator::Visit(Aggregation &aggr) { // Currently, we only have aggregation operators which return numbers. 
symbol_table_[aggr] = symbol_table_.CreateSymbol("", Symbol::Type::Number); scope_.in_aggregation = true; + scope_.has_aggregation = true; } void SymbolGenerator::PostVisit(Aggregation &aggr) { diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 7f410c7c1..b6ba6ec34 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -29,6 +29,8 @@ class SymbolGenerator : public TreeVisitorBase { void PostVisit(Create &) override; bool PreVisit(Return &) override; bool PreVisit(With &) override; + void Visit(Where &) override; + void PostVisit(Where &) override; // Expressions void Visit(Identifier &) override; @@ -62,6 +64,11 @@ class SymbolGenerator : public TreeVisitorBase { bool in_with{false}; bool in_skip{false}; bool in_limit{false}; + bool in_order_by{false}; + bool in_where{false}; + // True if the return/with contains an aggregation in any named expression. + bool has_aggregation{false}; + // Map from variable names to symbols. 
std::map symbols; }; @@ -77,10 +84,7 @@ class SymbolGenerator : public TreeVisitorBase { auto GetOrCreateSymbol(const std::string &name, Symbol::Type type = Symbol::Type::Any); - void BindNamedExpressionSymbols( - const std::vector &named_expressions); - - void VisitSkipAndLimit(Expression *skip, Expression *limit); + void VisitReturnBody(ReturnBody &body, Where *where = nullptr); SymbolTable &symbol_table_; Scope scope_; diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index b40500e55..76856728a 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -13,10 +13,17 @@ class Symbol { enum class Type { Any, Vertex, Edge, Path, Number }; static std::string TypeToString(Type type) { - const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"}; + const char *enum_string[] = {"Any", "Vertex", "Edge", "Path", "Number"}; return enum_string[static_cast(type)]; } + // Calculates the Symbol hash based on its position. 
+ struct Hash { + size_t operator()(const Symbol &symbol) const { + return std::hash{}(symbol.position_); + } + }; + Symbol() {} Symbol(const std::string &name, int position, Type type = Type::Any) : name_(name), position_(position), type_(type) {} diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index de072474c..a4f5eeebe 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -42,25 +42,32 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, Frame frame(symbol_table.max_position()); std::vector header; + bool is_return = false; + std::vector output_symbols; if (auto produce = dynamic_cast(logical_plan.get())) { - // top level node in the operator tree is a produce (return) + is_return = true; + // collect the symbols from the return clause + for (auto named_expression : produce->named_expressions()) + output_symbols.emplace_back(symbol_table[*named_expression]); + } else if (auto order_by = + dynamic_cast(logical_plan.get())) { + is_return = true; + output_symbols = order_by->output_symbols(); + } + if (is_return) { + // top level node in the operator tree is a produce/order_by (return) // so stream out results // generate header - for (auto named_expression : produce->named_expressions()) - header.push_back(named_expression->name_); + for (const auto &symbol : output_symbols) header.push_back(symbol.name_); stream.Header(header); - // collect the symbols from the return clause - std::vector symbols; - for (auto named_expression : produce->named_expressions()) - symbols.emplace_back(symbol_table[*named_expression]); - // stream out results - auto cursor = produce->MakeCursor(db_accessor); + auto cursor = logical_plan->MakeCursor(db_accessor); while (cursor->Pull(frame, symbol_table)) { std::vector values; - for (auto &symbol : symbols) values.emplace_back(frame[symbol]); + for (const auto &symbol : output_symbols) + values.emplace_back(frame[symbol]); stream.Result(values); } } else if 
(dynamic_cast(logical_plan.get()) || @@ -72,7 +79,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, dynamic_cast(logical_plan.get()) || dynamic_cast(logical_plan.get())) { stream.Header(header); - auto cursor = logical_plan.get()->MakeCursor(db_accessor); + auto cursor = logical_plan->MakeCursor(db_accessor); while (cursor->Pull(frame, symbol_table)) continue; } else { throw QueryRuntimeException("Unknown top level LogicalOp"); diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 0ee5a75e3..42c1b8ba1 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1224,9 +1224,9 @@ void Limit::LimitCursor::Reset() { } OrderBy::OrderBy(const std::shared_ptr &input, - const std::vector> order_by, - const std::vector remember) - : input_(input), remember_(remember) { + const std::vector> &order_by, + const std::vector &output_symbols) + : input_(input), output_symbols_(output_symbols) { // split the order_by vector into two vectors of orderings and expressions std::vector ordering; ordering.reserve(order_by.size()); @@ -1259,12 +1259,12 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame, order_by.emplace_back(evaluator.PopBack()); } - // collect the remember elements - std::list remember; - for (const Symbol &remember_sym : self_.remember_) - remember.emplace_back(frame[remember_sym]); + // collect the output elements + std::list output; + for (const Symbol &output_sym : self_.output_symbols_) + output.emplace_back(frame[output_sym]); - cache_.emplace_back(order_by, remember); + cache_.emplace_back(order_by, output); } std::sort(cache_.begin(), cache_.end(), @@ -1278,13 +1278,13 @@ bool OrderBy::OrderByCursor::Pull(Frame &frame, if (cache_it_ == cache_.end()) return false; - // place the remembered values on the frame - debug_assert(self_.remember_.size() == cache_it_->second.size(), - "Number of values does not match the number of remember symbols " + // place the output values on the frame + 
debug_assert(self_.output_symbols_.size() == cache_it_->second.size(), + "Number of values does not match the number of output symbols " "in OrderBy"); - auto remember_sym_it = self_.remember_.begin(); - for (const TypedValue &remember : cache_it_->second) - frame[*remember_sym_it++] = remember; + auto output_sym_it = self_.output_symbols_.begin(); + for (const TypedValue &output : cache_it_->second) + frame[*output_sym_it++] = output; cache_it_++; return true; diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index b7df641d3..7113ee9f5 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -827,6 +827,8 @@ class Accumulate : public LogicalOperator { void Accept(LogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + const auto &symbols() const { return symbols_; }; + private: const std::shared_ptr input_; const std::vector symbols_; @@ -1066,11 +1068,13 @@ class Limit : public LogicalOperator { class OrderBy : public LogicalOperator { public: OrderBy(const std::shared_ptr &input, - const std::vector> order_by, - const std::vector remember); + const std::vector> &order_by, + const std::vector &output_symbols); void Accept(LogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + const auto &output_symbols() const { return output_symbols_; } + private: // custom Comparator type for comparing lists of TypedValues // does lexicographical ordering of elements based on the above @@ -1091,7 +1095,7 @@ class OrderBy : public LogicalOperator { const std::shared_ptr input_; TypedValueListCompare compare_; std::vector order_by_; - const std::vector remember_; + const std::vector output_symbols_; // custom comparison for TypedValue objects // behaves generally like Neo's ORDER BY comparison operator: diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 2104a1671..5225bc724 100644 --- a/src/query/plan/planner.cpp +++ 
b/src/query/plan/planner.cpp @@ -1,5 +1,6 @@ #include "query/plan/planner.hpp" +#include #include #include @@ -184,17 +185,37 @@ auto GenMatch(Match &match, LogicalOperator *input_op, // * flag whether the results need to be DISTINCT; // * optional SKIP expression; // * optional LIMIT expression and -// * optional ORDER BY expression. +// * optional ORDER BY expressions. // // In addition to the above, we collect information on used symbols, // aggregations and expressions used for group by. class ReturnBodyContext : public TreeVisitorBase { public: - ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table) - : body_(body), symbol_table_(symbol_table) { + ReturnBodyContext(const ReturnBody &body, const SymbolTable &symbol_table, + Where *where = nullptr) + : body_(body), symbol_table_(symbol_table), where_(where) { + // Collect symbols from named expressions. + output_symbols_.reserve(body_.named_expressions.size()); for (auto &named_expr : body_.named_expressions) { + output_symbols_.emplace_back(symbol_table_.at(*named_expr)); named_expr->Accept(*this); } + if (aggregations_.empty()) { + // Visit order_by and where if we do not have aggregations. This way we + // prevent collecting group_by expressions from order_by and where, which + // would be very wrong. When we have aggregation, order_by and where can + // only use new symbols (ensured in semantic analysis), so we don't care + // about collecting used_symbols. Also, semantic analysis should + // have prevented any aggregations from appearing here. 
+ for (const auto &order_pair : body.order_by) { + order_pair.second->Accept(*this); + } + if (where) { + where->Accept(*this); + } + debug_assert(aggregations_.empty(), + "Unexpected aggregations in ORDER BY or WHERE"); + } } using TreeVisitorBase::PreVisit; @@ -204,7 +225,13 @@ class ReturnBodyContext : public TreeVisitorBase { void Visit(Literal &) override { has_aggregation_.emplace_back(false); } void Visit(Identifier &ident) override { - symbols_.insert(symbol_table_.at(ident)); + const auto &symbol = symbol_table_.at(ident); + if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) == + output_symbols_.end()) { + // Don't pick up new symbols, even though they may be used in ORDER BY or + // WHERE. + used_symbols_.insert(symbol); + } has_aggregation_.emplace_back(false); } @@ -217,7 +244,7 @@ class ReturnBodyContext : public TreeVisitorBase { bool aggr1 = has_aggregation_.back(); \ has_aggregation_.pop_back(); \ bool has_aggr = aggr1 || aggr2; \ - if (has_aggr) { \ + if (has_aggr && !(aggr1 && aggr2)) { \ /* Group by the expression which does not contain aggregation. */ \ /* Possible optimization is to ignore constant value expressions */ \ group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \ @@ -264,13 +291,18 @@ class ReturnBodyContext : public TreeVisitorBase { bool distinct() const { return body_.distinct; } // Named expressions which are used to produce results. const auto &named_expressions() const { return body_.named_expressions; } + // Pairs of (Ordering, Expression *) for sorting results. + const auto &order_by() const { return body_.order_by; } // Optional expression which determines how many results to skip. auto *skip() const { return body_.skip; } // Optional expression which determines how many results to produce. auto *limit() const { return body_.limit; } + // Optional Where clause for filtering. 
+ const auto *where() const { return where_; } // Set of symbols used inside the visited expressions outside of aggregation - // expression. - const auto &symbols() const { return symbols_; } + // expression. This only includes old symbols, even though new ones may have + // been used in ORDER BY or WHERE. + const auto &used_symbols() const { return used_symbols_; } // List of aggregation elements found in expressions. const auto &aggregations() const { return aggregations_; } // When there is at least one aggregation element, all the non-aggregate (sub) @@ -278,17 +310,16 @@ class ReturnBodyContext : public TreeVisitorBase { // AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`. const auto &group_by() const { return group_by_; } - private: - // Calculates the Symbol hash based on its position. - struct SymbolHash { - size_t operator()(const Symbol &symbol) const { - return std::hash{}(symbol.position_); - } - }; + // All symbols generated by named expressions. They are collected in order of + // named_expressions. + const auto &output_symbols() const { return output_symbols_; } + private: const ReturnBody &body_; const SymbolTable &symbol_table_; - std::unordered_set symbols_; + const Where *const where_ = nullptr; + std::unordered_set used_symbols_; + std::vector output_symbols_; std::vector aggregations_; std::vector group_by_; // Flag indicating whether an expression contains an aggregation. @@ -314,8 +345,8 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command, // TODO: Plan with distinct, when operator available. 
throw utils::NotYetImplemented(); } - auto symbols = - std::vector(body.symbols().begin(), body.symbols().end()); + std::vector used_symbols(body.used_symbols().begin(), + body.used_symbols().end()); auto last_op = input_op; if (body.aggregations().empty()) { // In case when we have SKIP/LIMIT and we don't perform aggregations, we @@ -329,18 +360,30 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command, // We only advance the command in Accumulate. This is done for WITH clause, // when the first part updated the database. RETURN clause may only need an // accumulation after updates, without advancing the command. - last_op = new Accumulate(std::shared_ptr(last_op), symbols, - advance_command); + last_op = new Accumulate(std::shared_ptr(last_op), + used_symbols, advance_command); } if (!body.aggregations().empty()) { // When we have aggregation, SKIP/LIMIT should always come after it. last_op = GenSkipLimit( new Aggregate(std::shared_ptr(last_op), - body.aggregations(), body.group_by(), symbols), + body.aggregations(), body.group_by(), used_symbols), body); } - return new Produce(std::shared_ptr(last_op), - body.named_expressions()); + last_op = new Produce(std::shared_ptr(last_op), + body.named_expressions()); + // Where may see new symbols so it comes after we generate Produce. + if (body.where()) { + last_op = new Filter(std::shared_ptr(last_op), + body.where()->expression_); + } + // Like Where, OrderBy can read from symbols established by named expressions + // in Produce, so it must come after it. + if (!body.order_by().empty()) { + last_op = new OrderBy(std::shared_ptr(last_op), + body.order_by(), body.output_symbols()); + } + return last_op; } auto GenWith(With &with, LogicalOperator *input_op, @@ -353,17 +396,13 @@ auto GenWith(With &with, LogicalOperator *input_op, bool accumulate = is_write; // No need to advance the command if we only performed reads. 
bool advance_command = is_write; - ReturnBodyContext body(with.body_, symbol_table); + ReturnBodyContext body(with.body_, symbol_table, with.where_); LogicalOperator *last_op = GenReturnBody(input_op, advance_command, body, accumulate); // Reset bound symbols, so that only those in WITH are exposed. bound_symbols.clear(); - for (auto &named_expr : with.body_.named_expressions) { - BindSymbol(bound_symbols, symbol_table.at(*named_expr)); - } - if (with.where_) { - last_op = new Filter(std::shared_ptr(last_op), - with.where_->expression_); + for (const auto &symbol : body.output_symbols()) { + BindSymbol(bound_symbols, symbol); } return last_op; } diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index b27b092ad..c1c6f3125 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -1,18 +1,74 @@ +/// +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstTreeStorage storage; // Macros rely on storage being in scope. +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. +/// + +#include +#include + #include "database/graph_db_datatypes.hpp" namespace query { namespace test_common { -// Custom types for SKIP and LIMIT and expressions, so that they can be used to -// resolve function calls. +// Custom types for ORDER BY, SKIP and LIMIT and expressions, so that they can +// be used to resolve function calls. 
+struct OrderBy { + std::vector> expressions; +}; struct Skip { - query::Expression *expression = nullptr; + Expression *expression = nullptr; }; struct Limit { - query::Expression *expression = nullptr; + Expression *expression = nullptr; }; +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, + Ordering ordering = Ordering::ASC) { + order_by.expressions.emplace_back(ordering, expression); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, + T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +/// +template +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + /// /// Create PropertyLookup with given name and property. /// @@ -117,86 +173,77 @@ auto GetQuery(AstTreeStorage &storage, Clause *clause, T *... clauses) { return GetQuery(storage, clauses...); } -/// -/// Create the return clause with given named expressions. -/// -auto GetReturn(Return *ret, NamedExpression *named_expr) { - ret->body_.named_expressions.emplace_back(named_expr); - return ret; +// Helper functions for constructing RETURN and WITH clauses. 
+void FillReturnBody(ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); } -auto GetReturn(Return *ret, Skip skip, Limit limit = Limit{}) { - ret->body_.skip = skip.expression; - ret->body_.limit = limit.expression; - return ret; +void FillReturnBody(ReturnBody &body, Limit limit) { + body.limit = limit.expression; } -auto GetReturn(Return *ret, Limit limit) { - ret->body_.limit = limit.expression; - return ret; +void FillReturnBody(ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; } -auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr) { +void FillReturnBody(ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(ReturnBody &body, OrderBy order_by, Skip skip, + Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(ReturnBody &body, Expression *expr, + NamedExpression *named_expr) { // This overload supports `RETURN(expr, AS(name))` construct, since // NamedExpression does not inherit Expression. named_expr->expression_ = expr; - ret->body_.named_expressions.emplace_back(named_expr); - return ret; + body.named_expressions.emplace_back(named_expr); } template -auto GetReturn(Return *ret, Expression *expr, NamedExpression *named_expr, - T... rest) { +void FillReturnBody(ReturnBody &body, Expression *expr, + NamedExpression *named_expr, T... rest) { named_expr->expression_ = expr; - ret->body_.named_expressions.emplace_back(named_expr); - return GetReturn(ret, rest...); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(body, rest...); } template -auto GetReturn(Return *ret, NamedExpression *named_expr, T... 
rest) { - ret->body_.named_expressions.emplace_back(named_expr); - return GetReturn(ret, rest...); -} -template -auto GetReturn(AstTreeStorage &storage, T... exprs) { - auto ret = storage.Create(); - return GetReturn(ret, exprs...); +void FillReturnBody(ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(body, rest...); } /// -/// Create the with clause with given named expressions. +/// Create the return clause with given expressions. /// -auto GetWith(With *with, NamedExpression *named_expr) { - with->body_.named_expressions.emplace_back(named_expr); - return with; -} -auto GetWith(With *with, Skip skip, Limit limit = {}) { - with->body_.skip = skip.expression; - with->body_.limit = limit.expression; - return with; -} -auto GetWith(With *with, Limit limit) { - with->body_.limit = limit.expression; - return with; -} -auto GetWith(With *with, Expression *expr, NamedExpression *named_expr) { - // This overload supports `RETURN(expr, AS(name))` construct, since - // NamedExpression does not inherit Expression. - named_expr->expression_ = expr; - with->body_.named_expressions.emplace_back(named_expr); - return with; -} +/// The supported expression combination of arguments is: +/// +/// (NamedExpression | (Expression NamedExpression))+ [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. +/// +/// @sa GetWith template -auto GetWith(With *with, Expression *expr, NamedExpression *named_expr, - T... rest) { - named_expr->expression_ = expr; - with->body_.named_expressions.emplace_back(named_expr); - return GetWith(with, rest...); -} -template -auto GetWith(With *with, NamedExpression *named_expr, T... 
rest) { - with->body_.named_expressions.emplace_back(named_expr); - return GetWith(with, rest...); +auto GetReturn(AstTreeStorage &storage, T... exprs) { + auto ret = storage.Create(); + FillReturnBody(ret->body_, exprs...); + return ret; } + +/// +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn template auto GetWith(AstTreeStorage &storage, T... exprs) { auto with = storage.Create(); - return GetWith(with, exprs...); + FillReturnBody(with->body_, exprs...); + return with; } /// @@ -288,8 +335,11 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name, #define AS(name) storage.Create((name)) #define RETURN(...) query::test_common::GetReturn(storage, __VA_ARGS__) #define WITH(...) query::test_common::GetWith(storage, __VA_ARGS__) -#define SKIP(expr) query::test_common::Skip{(expr)} -#define LIMIT(expr) query::test_common::Limit{(expr)} +#define ORDER_BY(...) query::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + query::test_common::Skip { (expr) } +#define LIMIT(expr) \ + query::test_common::Limit { (expr) } #define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__}) #define DETACH_DELETE(...) \ query::test_common::GetDelete(storage, {__VA_ARGS__}, true) @@ -297,7 +347,10 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name, #define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__) #define QUERY(...) 
query::test_common::GetQuery(storage, __VA_ARGS__) // Various operators -#define ADD(expr1, expr2) storage.Create((expr1), (expr2)) +#define ADD(expr1, expr2) \ + storage.Create((expr1), (expr2)) #define LESS(expr1, expr2) storage.Create((expr1), (expr2)) #define SUM(expr) \ storage.Create((expr), query::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create((expr), query::Aggregation::Op::COUNT) diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index d7c165cc0..9abb0dc19 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -15,6 +15,7 @@ using namespace query::plan; using query::AstTreeStorage; +using query::Symbol; using query::SymbolTable; using query::SymbolGenerator; using Direction = query::EdgeAtom::Direction; @@ -57,13 +58,27 @@ using ExpectRemoveLabels = OpChecker; template using ExpectExpandUniquenessFilter = OpChecker>; -using ExpectAccumulate = OpChecker; using ExpectSkip = OpChecker; using ExpectLimit = OpChecker; +using ExpectOrderBy = OpChecker; + +class ExpectAccumulate : public OpChecker { + public: + ExpectAccumulate(const std::unordered_set &symbols) + : symbols_(symbols) {} + + void ExpectOp(Accumulate &op, const SymbolTable &symbol_table) override { + std::unordered_set got_symbols(op.symbols().begin(), + op.symbols().end()); + EXPECT_EQ(symbols_, got_symbols); + } + + private: + const std::unordered_set symbols_; +}; class ExpectAggregate : public OpChecker { public: - ExpectAggregate() = default; ExpectAggregate(const std::vector &aggregations, const std::unordered_set &group_by) : aggregations_(aggregations), group_by_(group_by) {} @@ -119,6 +134,7 @@ class PlanChecker : public LogicalOperatorVisitor { void Visit(Aggregate &op) override { CheckOp(op); } void Visit(Skip &op) override { CheckOp(op); } void Visit(Limit &op) override { CheckOp(op); } + void Visit(OrderBy &op) override { CheckOp(op); } std::list checkers_; @@ -132,18 +148,29 @@ class PlanChecker : public 
LogicalOperatorVisitor { const SymbolTable &symbol_table_; }; -template -auto CheckPlan(query::Query &query, TChecker... checker) { +auto MakeSymbolTable(query::Query &query) { SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); query.Accept(symbol_generator); - auto plan = MakeLogicalPlan(query, symbol_table); + return symbol_table; +} + +template +auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, + TChecker... checker) { std::list checkers{&checker...}; PlanChecker plan_checker(checkers, symbol_table); - plan->Accept(plan_checker); + plan.Accept(plan_checker); EXPECT_TRUE(plan_checker.checkers_.empty()); } +template +auto CheckPlan(query::Query &query, TChecker... checker) { + auto symbol_table = MakeSymbolTable(query); + auto plan = MakeLogicalPlan(query, symbol_table); + CheckPlan(*plan, symbol_table, checker...); +} + TEST(TestLogicalPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n AS n AstTreeStorage storage; @@ -154,8 +181,12 @@ TEST(TestLogicalPlanner, MatchNodeReturn) { TEST(TestLogicalPlanner, CreateNodeReturn) { // Test CREATE (n) RETURN n AS n AstTreeStorage storage; - auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n"))); - CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), ExpectProduce()); + auto ident_n = IDENT("n"); + auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n"))); + auto symbol_table = MakeSymbolTable(*query); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); + auto plan = MakeLogicalPlan(*query, symbol_table); + CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce()); } TEST(TestLogicalPlanner, CreateExpand) { @@ -413,12 +444,16 @@ TEST(TestLogicalPlanner, CreateWithSum) { auto dba = dbms.active(); auto prop = dba->property("prop"); AstTreeStorage storage; - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto n_prop = PROPERTY_LOOKUP("n", prop); + auto sum = SUM(n_prop); auto query = QUERY(CREATE(PATTERN(NODE("n"))), 
WITH(sum, AS("sum"))); + auto symbol_table = MakeSymbolTable(*query); + auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); + auto plan = MakeLogicalPlan(*query, symbol_table); // We expect both the accumulation and aggregation because the part before // WITH updates the database. - CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, + CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce()); } @@ -449,15 +484,19 @@ TEST(TestLogicalPlanner, MatchReturnSkipLimit) { TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { // Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1 AstTreeStorage storage; + auto ident_n = IDENT("n"); auto query = QUERY(CREATE(PATTERN(NODE("n"))), - WITH(IDENT("n"), AS("m"), SKIP(LITERAL(2))), + WITH(ident_n, AS("m"), SKIP(LITERAL(2))), RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1)))); + auto symbol_table = MakeSymbolTable(*query); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); + auto plan = MakeLogicalPlan(*query, symbol_table); // Since we have a write query, we need to have Accumulate, so Skip and Limit // need to come before it. This is a bit different than Neo4j, which optimizes // WITH followed by RETURN as a single RETURN clause. This would cause the // Limit operator to also appear before Accumulate, thus changing the // behaviour. We've decided to diverge from Neo4j here, for consistency sake. 
- CheckPlan(*query, ExpectCreateNode(), ExpectSkip(), ExpectAccumulate(), + CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectSkip(), acc, ExpectProduce(), ExpectLimit(), ExpectProduce()); } @@ -467,14 +506,69 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { auto dba = dbms.active(); auto prop = dba->property("prop"); AstTreeStorage storage; - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto n_prop = PROPERTY_LOOKUP("n", prop); + auto sum = SUM(n_prop); auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); + auto symbol_table = MakeSymbolTable(*query); + auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); + auto plan = MakeLogicalPlan(*query, symbol_table); // We have a write query and aggregation, therefore Skip and Limit should come // after Accumulate and Aggregate. - CheckPlan(*query, ExpectCreateNode(), ExpectAccumulate(), aggr, ExpectSkip(), + CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectSkip(), ExpectLimit(), ExpectProduce()); } +TEST(TestLogicalPlanner, MatchReturnOrderBy) { + // Test MATCH (n) RETURN n ORDER BY n.prop + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto ret = RETURN(IDENT("n"), AS("n"), ORDER_BY(PROPERTY_LOOKUP("n", prop))); + auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret); + CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); +} + +TEST(TestLogicalPlanner, CreateWithOrderByWhere) { + // Test CREATE (n) -[r :r]-> (m) + // WITH n AS new ORDER BY new.prop, r.prop WHERE m.prop < 42 + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + auto r_type = dba->edge_type("r"); + AstTreeStorage storage; + auto ident_n = IDENT("n"); + auto new_prop = PROPERTY_LOOKUP("new", prop); + auto r_prop = PROPERTY_LOOKUP("r", prop); + auto m_prop = PROPERTY_LOOKUP("m", prop); + auto query = + 
QUERY(CREATE(PATTERN(NODE("n"), EDGE("r", r_type, Direction::RIGHT), + NODE("m"))), + WITH(ident_n, AS("new"), ORDER_BY(new_prop, r_prop)), + WHERE(LESS(m_prop, LITERAL(42)))); + auto symbol_table = MakeSymbolTable(*query); + // Since this is a write query, we expect to accumulate to old used symbols. + auto acc = ExpectAccumulate({ + symbol_table.at(*ident_n), // `n` in WITH + symbol_table.at(*r_prop->expression_), // `r` in ORDER BY + symbol_table.at(*m_prop->expression_), // `m` in WHERE + }); + auto plan = MakeLogicalPlan(*query, symbol_table); + CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, + ExpectProduce(), ExpectFilter(), ExpectOrderBy()); +} + +TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { + // Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result + AstTreeStorage storage; + auto sum = SUM(LITERAL(1)); + auto count = COUNT(LITERAL(2)); + auto query = + QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))); + auto aggr = ExpectAggregate({sum, count}, {}); + CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy()); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 9abccd5c5..8d01d5f79 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -334,13 +334,14 @@ TEST(TestSymbolGenerator, MatchWithWhere) { } TEST(TestSymbolGenerator, MatchWithWhereUnbound) { - // Test MATCH (old) WITH old AS n WHERE old.prop < 42 + // Test MATCH (old) WITH COUNT(old) AS c WHERE old.prop < 42 Dbms dbms; auto dba = dbms.active(); auto prop = dba->property("prop"); AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("n")), - WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42)))); + auto query = + QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old")), AS("c")), + WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42)))); SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); 
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError); @@ -585,4 +586,67 @@ TEST(TestSymbolGenerator, SkipLimitIdentifier) { } } +TEST(TestSymbolGenerator, OrderBy) { + // Test MATCH (old) RETURN old AS new ORDER BY COUNT(1) + { + AstTreeStorage storage; + auto query = + QUERY(MATCH(PATTERN(NODE("old"))), + RETURN(IDENT("old"), AS("new"), ORDER_BY(COUNT(LITERAL(1))))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); + } + // Test MATCH (old) RETURN COUNT(old) AS new ORDER BY old + { + AstTreeStorage storage; + auto query = + QUERY(MATCH(PATTERN(NODE("old"))), + RETURN(COUNT(IDENT("old")), AS("new"), ORDER_BY(IDENT("old")))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError); + } + // Test MATCH (old) RETURN COUNT(old) AS new ORDER BY new + { + AstTreeStorage storage; + auto node = NODE("old"); + auto ident_old = IDENT("old"); + auto as_new = AS("new"); + auto ident_new = IDENT("new"); + auto query = QUERY(MATCH(PATTERN(node)), + RETURN(COUNT(ident_old), as_new, ORDER_BY(ident_new))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for `old`, `count(old)` and `new` + EXPECT_EQ(symbol_table.max_position(), 3); + auto old = symbol_table.at(*node->identifier_); + EXPECT_EQ(old, symbol_table.at(*ident_old)); + auto new_sym = symbol_table.at(*as_new); + EXPECT_NE(old, new_sym); + EXPECT_EQ(new_sym, symbol_table.at(*ident_new)); + } + // Test MATCH (old) RETURN old AS new ORDER BY old + { + AstTreeStorage storage; + auto node = NODE("old"); + auto ident_old = IDENT("old"); + auto as_new = AS("new"); + auto by_old = IDENT("old"); + auto query = QUERY(MATCH(PATTERN(node)), + RETURN(ident_old, as_new, ORDER_BY(by_old))); + SymbolTable symbol_table; + SymbolGenerator 
symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for `old` and `new` + EXPECT_EQ(symbol_table.max_position(), 2); + auto old = symbol_table.at(*node->identifier_); + EXPECT_EQ(old, symbol_table.at(*ident_old)); + EXPECT_EQ(old, symbol_table.at(*by_old)); + auto new_sym = symbol_table.at(*as_new); + EXPECT_NE(old, new_sym); + } +} + }