Check for aggregation in all Expression types
Summary: This fixes an issue when aggregations and/or group by expressions weren't picked up from certain operators. In addition to that, we would segfault in cases when the `has_aggregation_` is empty. For example, function calls without arguments: `RETURN PI()`. Test aggregations inside some operators Reviewers: florijan, mislav.bradac, buda Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D425
This commit is contained in:
parent
b2c523b93a
commit
7278bdff94
@ -313,8 +313,18 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PreVisit(ListLiteral &) override {
|
||||
has_aggregation_.emplace_back(false);
|
||||
bool PostVisit(ListLiteral &list_literal) override {
|
||||
debug_assert(
|
||||
list_literal.elements_.size() <= has_aggregation_.size(),
|
||||
"Expected has_aggregation_ flags as much as there are list elements.");
|
||||
bool has_aggr = false;
|
||||
auto it = has_aggregation_.end();
|
||||
std::advance(it, -list_literal.elements_.size());
|
||||
while (it != has_aggregation_.end()) {
|
||||
has_aggr = has_aggr || *it;
|
||||
it = has_aggregation_.erase(it);
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -330,10 +340,51 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PreVisit(ListSlicingOperator &list_slicing) override {
|
||||
list_slicing.list_->Accept(*this);
|
||||
bool list_has_aggr = has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
bool has_aggr = list_has_aggr;
|
||||
if (list_slicing.lower_bound_) {
|
||||
list_slicing.lower_bound_->Accept(*this);
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
if (list_slicing.upper_bound_) {
|
||||
list_slicing.upper_bound_->Accept(*this);
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
if (has_aggr && !list_has_aggr) {
|
||||
// We need to group by the list expression, because it didn't have an
|
||||
// aggregation inside.
|
||||
group_by_.emplace_back(list_slicing.list_);
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PostVisit(Function &function) override {
|
||||
debug_assert(function.arguments_.size() <= has_aggregation_.size(),
|
||||
"Expected has_aggregation_ flags as much as there are "
|
||||
"function arguments.");
|
||||
bool has_aggr = false;
|
||||
auto it = has_aggregation_.end();
|
||||
std::advance(it, -function.arguments_.size());
|
||||
while (it != has_aggregation_.end()) {
|
||||
has_aggr = has_aggr || *it;
|
||||
it = has_aggregation_.erase(it);
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
return true;
|
||||
}
|
||||
|
||||
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
|
||||
bool PostVisit(BinaryOperator &op) override { \
|
||||
/* has_aggregation_ stack is reversed, last result is from the 2nd \
|
||||
* expression. */ \
|
||||
debug_assert(has_aggregation_.size() >= 2U, \
|
||||
"Expected at least 2 has_aggregation_ flags."); \
|
||||
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
|
||||
/* expression. */ \
|
||||
bool aggr2 = has_aggregation_.back(); \
|
||||
has_aggregation_.pop_back(); \
|
||||
bool aggr1 = has_aggregation_.back(); \
|
||||
@ -352,6 +403,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
VISIT_BINARY_OPERATOR(OrOperator)
|
||||
VISIT_BINARY_OPERATOR(XorOperator)
|
||||
VISIT_BINARY_OPERATOR(AndOperator)
|
||||
VISIT_BINARY_OPERATOR(FilterAndOperator)
|
||||
VISIT_BINARY_OPERATOR(AdditionOperator)
|
||||
VISIT_BINARY_OPERATOR(SubtractionOperator)
|
||||
VISIT_BINARY_OPERATOR(MultiplicationOperator)
|
||||
@ -363,6 +415,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
VISIT_BINARY_OPERATOR(GreaterOperator)
|
||||
VISIT_BINARY_OPERATOR(LessEqualOperator)
|
||||
VISIT_BINARY_OPERATOR(GreaterEqualOperator)
|
||||
VISIT_BINARY_OPERATOR(InListOperator)
|
||||
VISIT_BINARY_OPERATOR(ListIndexingOperator)
|
||||
|
||||
#undef VISIT_BINARY_OPERATOR
|
||||
|
||||
@ -383,6 +437,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
}
|
||||
|
||||
bool PostVisit(NamedExpression &named_expr) override {
|
||||
debug_assert(has_aggregation_.size() == 1U,
|
||||
"Expected to reduce has_aggregation_ to single boolean.");
|
||||
if (!has_aggregation_.back()) {
|
||||
group_by_.emplace_back(named_expr.expression_);
|
||||
}
|
||||
|
@ -25,6 +25,8 @@
|
||||
|
||||
#include "database/graph_db_datatypes.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
@ -465,3 +467,11 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
|
||||
#define EQ(expr1, expr2) storage.Create<query::EqualOperator>((expr1), (expr2))
|
||||
#define AND(expr1, expr2) storage.Create<query::AndOperator>((expr1), (expr2))
|
||||
#define OR(expr1, expr2) storage.Create<query::OrOperator>((expr1), (expr2))
|
||||
// Function call
|
||||
#define FN(function_name, ...) \
|
||||
storage.Create<query::Function>( \
|
||||
query::NameToFunction(utils::ToUpperCase(function_name)), \
|
||||
std::vector<query::Expression *>{__VA_ARGS__})
|
||||
// List slicing
|
||||
#define SLICE(list, lower_bound, upper_bound) \
|
||||
storage.Create<query::ListSlicingOperator>(list, lower_bound, upper_bound)
|
||||
|
@ -848,4 +848,61 @@ TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) {
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, FunctionAggregationReturn) {
|
||||
// Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LITERAL(2));
|
||||
auto group_by_literal = LITERAL(42);
|
||||
QUERY(
|
||||
RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")));
|
||||
auto aggr = ExpectAggregate({sum}, {group_by_literal});
|
||||
CheckPlan(storage, aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, FunctionWithoutArguments) {
|
||||
// Test RETURN pi() AS pi
|
||||
AstTreeStorage storage;
|
||||
QUERY(RETURN(FN("pi"), AS("pi")));
|
||||
CheckPlan(storage, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, ListLiteralAggregationReturn) {
|
||||
// Test RETURN [SUM(2)] AS result, 42 AS group_by
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LITERAL(2));
|
||||
auto group_by_literal = LITERAL(42);
|
||||
QUERY(RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by")));
|
||||
auto aggr = ExpectAggregate({sum}, {group_by_literal});
|
||||
CheckPlan(storage, aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, EmptyListIndexAggregation) {
|
||||
// Test RETURN [][SUM(2)] AS result, 42 AS group_by
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LITERAL(2));
|
||||
auto empty_list = LIST();
|
||||
auto group_by_literal = LITERAL(42);
|
||||
QUERY(RETURN(storage.Create<query::ListIndexingOperator>(empty_list, sum),
|
||||
AS("result"), group_by_literal, AS("group_by")));
|
||||
// We expect to group by '42' and the empty list, because it is a
|
||||
// sub-expression of a binary operator which contains an aggregation. This is
|
||||
// similar to grouping by '1' in `RETURN 1 + SUM(2)`.
|
||||
auto aggr = ExpectAggregate({sum}, {empty_list, group_by_literal});
|
||||
CheckPlan(storage, aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, ListSliceAggregationReturn) {
|
||||
// Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LITERAL(2));
|
||||
auto list = LIST(LITERAL(1), LITERAL(2));
|
||||
auto group_by_literal = LITERAL(42);
|
||||
QUERY(RETURN(SLICE(list, LITERAL(0), sum), AS("result"), group_by_literal,
|
||||
AS("group_by")));
|
||||
// Similarly to EmptyListIndexAggregation test, we expect grouping by list and
|
||||
// '42', because slicing is an operator.
|
||||
auto aggr = ExpectAggregate({sum}, {list, group_by_literal});
|
||||
CheckPlan(storage, aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user