Check for aggregation in all Expression types

Summary:
This fixes an issue when aggregations and/or group by expressions
weren't picked up from certain operators. In addition to that, we would
segfault in cases when the `has_aggregation_` is empty. For example,
function calls without arguments: `RETURN PI()`.

Test aggregations inside some operators

Reviewers: florijan, mislav.bradac, buda

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D425
This commit is contained in:
Teon Banek 2017-06-05 11:50:47 +02:00
parent b2c523b93a
commit 7278bdff94
3 changed files with 127 additions and 4 deletions

View File

@ -313,8 +313,18 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PreVisit(ListLiteral &) override {
has_aggregation_.emplace_back(false);
bool PostVisit(ListLiteral &list_literal) override {
debug_assert(
list_literal.elements_.size() <= has_aggregation_.size(),
"Expected has_aggregation_ flags as much as there are list elements.");
bool has_aggr = false;
auto it = has_aggregation_.end();
std::advance(it, -list_literal.elements_.size());
while (it != has_aggregation_.end()) {
has_aggr = has_aggr || *it;
it = has_aggregation_.erase(it);
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
@ -330,10 +340,51 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PreVisit(ListSlicingOperator &list_slicing) override {
list_slicing.list_->Accept(*this);
bool list_has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
bool has_aggr = list_has_aggr;
if (list_slicing.lower_bound_) {
list_slicing.lower_bound_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
if (list_slicing.upper_bound_) {
list_slicing.upper_bound_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
if (has_aggr && !list_has_aggr) {
// We need to group by the list expression, because it didn't have an
// aggregation inside.
group_by_.emplace_back(list_slicing.list_);
}
has_aggregation_.emplace_back(has_aggr);
return false;
}
bool PostVisit(Function &function) override {
debug_assert(function.arguments_.size() <= has_aggregation_.size(),
"Expected has_aggregation_ flags as much as there are "
"function arguments.");
bool has_aggr = false;
auto it = has_aggregation_.end();
std::advance(it, -function.arguments_.size());
while (it != has_aggregation_.end()) {
has_aggr = has_aggr || *it;
it = has_aggregation_.erase(it);
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
bool PostVisit(BinaryOperator &op) override { \
/* has_aggregation_ stack is reversed, last result is from the 2nd \
* expression. */ \
debug_assert(has_aggregation_.size() >= 2U, \
"Expected at least 2 has_aggregation_ flags."); \
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
/* expression. */ \
bool aggr2 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool aggr1 = has_aggregation_.back(); \
@ -352,6 +403,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
VISIT_BINARY_OPERATOR(OrOperator)
VISIT_BINARY_OPERATOR(XorOperator)
VISIT_BINARY_OPERATOR(AndOperator)
VISIT_BINARY_OPERATOR(FilterAndOperator)
VISIT_BINARY_OPERATOR(AdditionOperator)
VISIT_BINARY_OPERATOR(SubtractionOperator)
VISIT_BINARY_OPERATOR(MultiplicationOperator)
@ -363,6 +415,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
VISIT_BINARY_OPERATOR(GreaterOperator)
VISIT_BINARY_OPERATOR(LessEqualOperator)
VISIT_BINARY_OPERATOR(GreaterEqualOperator)
VISIT_BINARY_OPERATOR(InListOperator)
VISIT_BINARY_OPERATOR(ListIndexingOperator)
#undef VISIT_BINARY_OPERATOR
@ -383,6 +437,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
}
bool PostVisit(NamedExpression &named_expr) override {
debug_assert(has_aggregation_.size() == 1U,
"Expected to reduce has_aggregation_ to single boolean.");
if (!has_aggregation_.back()) {
group_by_.emplace_back(named_expr.expression_);
}

View File

@ -25,6 +25,8 @@
#include "database/graph_db_datatypes.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpret/awesome_memgraph_functions.hpp"
#include "utils/string.hpp"
namespace query {
@ -465,3 +467,11 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
#define EQ(expr1, expr2) storage.Create<query::EqualOperator>((expr1), (expr2))
#define AND(expr1, expr2) storage.Create<query::AndOperator>((expr1), (expr2))
#define OR(expr1, expr2) storage.Create<query::OrOperator>((expr1), (expr2))
// Function call
#define FN(function_name, ...) \
storage.Create<query::Function>( \
query::NameToFunction(utils::ToUpperCase(function_name)), \
std::vector<query::Expression *>{__VA_ARGS__})
// List slicing
#define SLICE(list, lower_bound, upper_bound) \
storage.Create<query::ListSlicingOperator>(list, lower_bound, upper_bound)

View File

@ -848,4 +848,61 @@ TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) {
ExpectProduce());
}
TEST(TestLogicalPlanner, FunctionAggregationReturn) {
// Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by
AstTreeStorage storage;
auto sum = SUM(LITERAL(2));
auto group_by_literal = LITERAL(42);
QUERY(
RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")));
auto aggr = ExpectAggregate({sum}, {group_by_literal});
CheckPlan(storage, aggr, ExpectProduce());
}
TEST(TestLogicalPlanner, FunctionWithoutArguments) {
// Test RETURN pi() AS pi
AstTreeStorage storage;
QUERY(RETURN(FN("pi"), AS("pi")));
CheckPlan(storage, ExpectProduce());
}
TEST(TestLogicalPlanner, ListLiteralAggregationReturn) {
// Test RETURN [SUM(2)] AS result, 42 AS group_by
AstTreeStorage storage;
auto sum = SUM(LITERAL(2));
auto group_by_literal = LITERAL(42);
QUERY(RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by")));
auto aggr = ExpectAggregate({sum}, {group_by_literal});
CheckPlan(storage, aggr, ExpectProduce());
}
TEST(TestLogicalPlanner, EmptyListIndexAggregation) {
// Test RETURN [][SUM(2)] AS result, 42 AS group_by
AstTreeStorage storage;
auto sum = SUM(LITERAL(2));
auto empty_list = LIST();
auto group_by_literal = LITERAL(42);
QUERY(RETURN(storage.Create<query::ListIndexingOperator>(empty_list, sum),
AS("result"), group_by_literal, AS("group_by")));
// We expect to group by '42' and the empty list, because it is a
// sub-expression of a binary operator which contains an aggregation. This is
// similar to grouping by '1' in `RETURN 1 + SUM(2)`.
auto aggr = ExpectAggregate({sum}, {empty_list, group_by_literal});
CheckPlan(storage, aggr, ExpectProduce());
}
TEST(TestLogicalPlanner, ListSliceAggregationReturn) {
// Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by
AstTreeStorage storage;
auto sum = SUM(LITERAL(2));
auto list = LIST(LITERAL(1), LITERAL(2));
auto group_by_literal = LITERAL(42);
QUERY(RETURN(SLICE(list, LITERAL(0), sum), AS("result"), group_by_literal,
AS("group_by")));
// Similarly to EmptyListIndexAggregation test, we expect grouping by list and
// '42', because slicing is an operator.
auto aggr = ExpectAggregate({sum}, {list, group_by_literal});
CheckPlan(storage, aggr, ExpectProduce());
}
} // namespace