Implement collect aggregation
Reviewers: teon.banek, florijan Reviewed By: teon.banek, florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D386
This commit is contained in:
parent
28fbeb8e9d
commit
f9cd87bb46
@ -593,12 +593,13 @@ class Aggregation : public UnaryOperator {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
public:
|
||||
enum class Op { COUNT, MIN, MAX, SUM, AVG };
|
||||
enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT };
|
||||
static const constexpr char *const kCount = "COUNT";
|
||||
static const constexpr char *const kMin = "MIN";
|
||||
static const constexpr char *const kMax = "MAX";
|
||||
static const constexpr char *const kSum = "SUM";
|
||||
static const constexpr char *const kAvg = "AVG";
|
||||
static const constexpr char *const kCollect = "COLLECT";
|
||||
|
||||
DEFVISITABLE(TreeVisitor<TypedValue>);
|
||||
bool Accept(HierarchicalTreeVisitor &visitor) override {
|
||||
|
@ -788,6 +788,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
|
||||
}
|
||||
if (function_name == Aggregation::kCollect) {
|
||||
return static_cast<Expression *>(storage_.Create<Aggregation>(
|
||||
expressions[0], Aggregation::Op::COLLECT));
|
||||
}
|
||||
}
|
||||
auto function = NameToFunction(function_name);
|
||||
if (!function) throw SemanticException();
|
||||
|
@ -938,10 +938,13 @@ void Aggregate::AggregateCursor::EnsureInitialized(
|
||||
if (agg_value.values_.size() > 0) return;
|
||||
|
||||
for (const auto &agg_elem : self_.aggregations_) {
|
||||
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT)
|
||||
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) {
|
||||
agg_value.values_.emplace_back(TypedValue(0));
|
||||
else
|
||||
} else if (std::get<1>(agg_elem) == Aggregation::Op::COLLECT) {
|
||||
agg_value.values_.emplace_back(std::vector<TypedValue>());
|
||||
} else {
|
||||
agg_value.values_.emplace_back(TypedValue::Null);
|
||||
}
|
||||
}
|
||||
agg_value.counts_.resize(self_.aggregations_.size(), 0);
|
||||
|
||||
@ -964,7 +967,6 @@ void Aggregate::AggregateCursor::Update(
|
||||
auto agg_elem_it = self_.aggregations_.begin();
|
||||
for (; count_it < agg_value.counts_.end();
|
||||
count_it++, value_it++, agg_elem_it++) {
|
||||
|
||||
// COUNT(*) is the only case where input expression is optional
|
||||
// handle it here
|
||||
auto input_expr_ptr = std::get<0>(*agg_elem_it);
|
||||
@ -986,16 +988,21 @@ void Aggregate::AggregateCursor::Update(
|
||||
switch (agg_op) {
|
||||
case Aggregation::Op::MIN:
|
||||
case Aggregation::Op::MAX:
|
||||
*value_it = input_value;
|
||||
EnsureOkForMinMax(input_value);
|
||||
break;
|
||||
case Aggregation::Op::SUM:
|
||||
case Aggregation::Op::AVG:
|
||||
*value_it = input_value;
|
||||
EnsureOkForAvgSum(input_value);
|
||||
break;
|
||||
case Aggregation::Op::COUNT:
|
||||
*value_it = 1;
|
||||
break;
|
||||
case Aggregation::Op::COLLECT:
|
||||
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
|
||||
break;
|
||||
}
|
||||
*value_it = agg_op == Aggregation::Op::COUNT ? 1 : input_value;
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -1029,6 +1036,9 @@ void Aggregate::AggregateCursor::Update(
|
||||
EnsureOkForAvgSum(input_value);
|
||||
*value_it = *value_it + input_value;
|
||||
break;
|
||||
case Aggregation::Op::COLLECT:
|
||||
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
|
||||
break;
|
||||
} // end switch over Aggregation::Op enum
|
||||
} // end loop over all aggregations
|
||||
}
|
||||
|
@ -506,15 +506,15 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
|
||||
|
||||
TEST(CypherMainVisitorTest, Aggregation) {
|
||||
AstGenerator ast_generator(
|
||||
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
|
||||
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COLLECT(f), COUNT(*)");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 7U);
|
||||
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
||||
Aggregation::Op::AVG};
|
||||
std::string ids[] = {"a", "b", "c", "d", "e"};
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
||||
Aggregation::Op::AVG, Aggregation::Op::COLLECT};
|
||||
std::string ids[] = {"a", "b", "c", "d", "e", "f"};
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
auto *aggregation = dynamic_cast<Aggregation *>(
|
||||
return_clause->body_.named_expressions[i]->expression_);
|
||||
ASSERT_TRUE(aggregation);
|
||||
@ -524,7 +524,7 @@ TEST(CypherMainVisitorTest, Aggregation) {
|
||||
ASSERT_EQ(identifier->name_, ids[i]);
|
||||
}
|
||||
auto *aggregation = dynamic_cast<Aggregation *>(
|
||||
return_clause->body_.named_expressions[5]->expression_);
|
||||
return_clause->body_.named_expressions[6]->expression_);
|
||||
ASSERT_TRUE(aggregation);
|
||||
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
|
||||
ASSERT_FALSE(aggregation->expression_);
|
||||
|
@ -28,6 +28,14 @@ namespace query {
|
||||
|
||||
namespace test_common {
|
||||
|
||||
auto ToInt64List(const TypedValue &t) {
|
||||
std::vector<int64_t> list;
|
||||
for (auto x : t.Value<std::vector<TypedValue>>()) {
|
||||
list.push_back(x.Value<int64_t>());
|
||||
}
|
||||
return list;
|
||||
};
|
||||
|
||||
// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions,
|
||||
// so that they can be used to resolve function calls.
|
||||
struct OrderBy {
|
||||
|
@ -14,14 +14,15 @@
|
||||
#include "query/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "query/interpret/eval.hpp"
|
||||
#include "query/interpret/frame.hpp"
|
||||
#include "query_common.hpp"
|
||||
|
||||
using namespace query;
|
||||
using testing::Pair;
|
||||
using testing::UnorderedElementsAre;
|
||||
using testing::ElementsAre;
|
||||
using query::test_common::ToInt64List;
|
||||
|
||||
namespace {
|
||||
|
||||
struct NoContextExpressionEvaluator {
|
||||
NoContextExpressionEvaluator() {}
|
||||
Frame frame{128};
|
||||
@ -823,30 +824,23 @@ TEST(ExpressionEvaluator, FunctionRange) {
|
||||
EXPECT_TRUE(EvaluateFunction("RANGE", {1, 2, TypedValue::Null}).IsNull());
|
||||
EXPECT_THROW(EvaluateFunction("RANGE", {1, TypedValue::Null, 1.3}),
|
||||
QueryRuntimeException);
|
||||
auto to_int_list = [](const TypedValue &t) {
|
||||
std::vector<int64_t> list;
|
||||
for (auto x : t.Value<std::vector<TypedValue>>()) {
|
||||
list.push_back(x.Value<int64_t>());
|
||||
}
|
||||
return list;
|
||||
};
|
||||
EXPECT_THROW(EvaluateFunction("RANGE", {1, 2, 0}), QueryRuntimeException);
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {1, 3})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {1, 3})),
|
||||
ElementsAre(1, 2, 3));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-1, 5, 2})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-1, 5, 2})),
|
||||
ElementsAre(-1, 1, 3, 5));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 10, 3})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 10, 3})),
|
||||
ElementsAre(2, 5, 8));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, 2})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, 2})),
|
||||
ElementsAre(2));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre());
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {5, 1, -2})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre());
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {5, 1, -2})),
|
||||
ElementsAre(5, 3, 1));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {6, 1, -2})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {6, 1, -2})),
|
||||
ElementsAre(6, 4, 2));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, -3})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, -3})),
|
||||
ElementsAre(2));
|
||||
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-2, 4, -1})),
|
||||
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-2, 4, -1})),
|
||||
ElementsAre());
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,8 @@
|
||||
|
||||
using namespace query;
|
||||
using namespace query::plan;
|
||||
using testing::UnorderedElementsAre;
|
||||
using query::test_common::ToInt64List;
|
||||
|
||||
TEST(QueryPlan, Accumulate) {
|
||||
// simulate the following two query execution on an empty db
|
||||
@ -169,18 +171,19 @@ TEST(QueryPlan, AggregateOps) {
|
||||
auto n_p = PROPERTY_LOOKUP("n", prop);
|
||||
symbol_table[*n_p->expression_] = n.sym_;
|
||||
|
||||
std::vector<Expression *> aggregation_expressions(6, n_p);
|
||||
std::vector<Expression *> aggregation_expressions(7, n_p);
|
||||
aggregation_expressions[0] = nullptr;
|
||||
auto produce = MakeAggregationProduce(
|
||||
n.op_, symbol_table, storage, aggregation_expressions,
|
||||
{Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG},
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG,
|
||||
Aggregation::Op::COLLECT},
|
||||
{}, {});
|
||||
|
||||
// checks
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
ASSERT_EQ(results.size(), 1);
|
||||
ASSERT_EQ(results[0].size(), 6);
|
||||
ASSERT_EQ(results[0].size(), 7);
|
||||
// count(*)
|
||||
ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int);
|
||||
EXPECT_EQ(results[0][0].Value<int64_t>(), 4);
|
||||
@ -199,6 +202,9 @@ TEST(QueryPlan, AggregateOps) {
|
||||
// avg
|
||||
ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double);
|
||||
EXPECT_FLOAT_EQ(results[0][5].Value<double>(), 24 / 3.0);
|
||||
// collect
|
||||
ASSERT_EQ(results[0][6].type(), TypedValue::Type::List);
|
||||
EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre(5, 7, 12));
|
||||
}
|
||||
|
||||
TEST(QueryPlan, AggregateGroupByValues) {
|
||||
@ -320,6 +326,34 @@ TEST(QueryPlan, AggregateNoInput) {
|
||||
EXPECT_EQ(1, results[0][0].Value<int64_t>());
|
||||
}
|
||||
|
||||
// TODO: This test is valid but it fails. We don't handle aggregations correctly
|
||||
// in the case when there is no input. Also add similar tests for other
|
||||
// aggregation ops.
|
||||
// TEST(QueryPlan, AggregateCollectNoResults) {
|
||||
// Dbms dbms;
|
||||
// auto dba = dbms.active();
|
||||
// auto prop = dba->property("prop");
|
||||
//
|
||||
// AstTreeStorage storage;
|
||||
// SymbolTable symbol_table;
|
||||
//
|
||||
// // match all nodes and perform aggregations
|
||||
// auto n = MakeScanAll(storage, symbol_table, "n");
|
||||
// auto n_p = PROPERTY_LOOKUP("n", prop);
|
||||
// symbol_table[*n_p->expression_] = n.sym_;
|
||||
//
|
||||
// auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p},
|
||||
// {Aggregation::Op::COLLECT}, {}, {});
|
||||
//
|
||||
// // checks
|
||||
// auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
// ASSERT_EQ(results.size(), 1);
|
||||
// ASSERT_EQ(results[0].size(), 1);
|
||||
// ASSERT_EQ(results[0][0].type(), TypedValue::Type::List);
|
||||
// // Collect should return empty list if there are no results.
|
||||
// EXPECT_THAT(ToInt64List(results[0][0]), UnorderedElementsAre());
|
||||
//}
|
||||
|
||||
TEST(QueryPlan, AggregateCountEdgeCases) {
|
||||
// tests for detected bugs in the COUNT aggregation behavior
|
||||
// ensure that COUNT returns correctly for
|
||||
|
Loading…
Reference in New Issue
Block a user