Implement collect aggregation

Reviewers: teon.banek, florijan

Reviewed By: teon.banek, florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D386
This commit is contained in:
Mislav Bradac 2017-05-19 15:49:25 +02:00
parent 28fbeb8e9d
commit f9cd87bb46
7 changed files with 83 additions and 32 deletions

View File

@ -593,12 +593,13 @@ class Aggregation : public UnaryOperator {
friend class AstTreeStorage;
public:
enum class Op { COUNT, MIN, MAX, SUM, AVG };
enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT };
static const constexpr char *const kCount = "COUNT";
static const constexpr char *const kMin = "MIN";
static const constexpr char *const kMax = "MAX";
static const constexpr char *const kSum = "SUM";
static const constexpr char *const kAvg = "AVG";
static const constexpr char *const kCollect = "COLLECT";
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {

View File

@ -788,6 +788,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
}
if (function_name == Aggregation::kCollect) {
return static_cast<Expression *>(storage_.Create<Aggregation>(
expressions[0], Aggregation::Op::COLLECT));
}
}
auto function = NameToFunction(function_name);
if (!function) throw SemanticException();

View File

@ -938,10 +938,13 @@ void Aggregate::AggregateCursor::EnsureInitialized(
if (agg_value.values_.size() > 0) return;
for (const auto &agg_elem : self_.aggregations_) {
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT)
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) {
agg_value.values_.emplace_back(TypedValue(0));
else
} else if (std::get<1>(agg_elem) == Aggregation::Op::COLLECT) {
agg_value.values_.emplace_back(std::vector<TypedValue>());
} else {
agg_value.values_.emplace_back(TypedValue::Null);
}
}
agg_value.counts_.resize(self_.aggregations_.size(), 0);
@ -964,7 +967,6 @@ void Aggregate::AggregateCursor::Update(
auto agg_elem_it = self_.aggregations_.begin();
for (; count_it < agg_value.counts_.end();
count_it++, value_it++, agg_elem_it++) {
// COUNT(*) is the only case where input expression is optional
// handle it here
auto input_expr_ptr = std::get<0>(*agg_elem_it);
@ -986,16 +988,21 @@ void Aggregate::AggregateCursor::Update(
switch (agg_op) {
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
*value_it = input_value;
EnsureOkForMinMax(input_value);
break;
case Aggregation::Op::SUM:
case Aggregation::Op::AVG:
*value_it = input_value;
EnsureOkForAvgSum(input_value);
break;
case Aggregation::Op::COUNT:
*value_it = 1;
break;
case Aggregation::Op::COLLECT:
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
}
*value_it = agg_op == Aggregation::Op::COUNT ? 1 : input_value;
continue;
}
@ -1029,6 +1036,9 @@ void Aggregate::AggregateCursor::Update(
EnsureOkForAvgSum(input_value);
*value_it = *value_it + input_value;
break;
case Aggregation::Op::COLLECT:
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
} // end switch over Aggregation::Op enum
} // end loop over all aggregations
}

View File

@ -506,15 +506,15 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
TEST(CypherMainVisitorTest, Aggregation) {
AstGenerator ast_generator(
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COLLECT(f), COUNT(*)");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 7U);
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM,
Aggregation::Op::AVG};
std::string ids[] = {"a", "b", "c", "d", "e"};
for (int i = 0; i < 5; ++i) {
Aggregation::Op::MAX, Aggregation::Op::SUM,
Aggregation::Op::AVG, Aggregation::Op::COLLECT};
std::string ids[] = {"a", "b", "c", "d", "e", "f"};
for (int i = 0; i < 6; ++i) {
auto *aggregation = dynamic_cast<Aggregation *>(
return_clause->body_.named_expressions[i]->expression_);
ASSERT_TRUE(aggregation);
@ -524,7 +524,7 @@ TEST(CypherMainVisitorTest, Aggregation) {
ASSERT_EQ(identifier->name_, ids[i]);
}
auto *aggregation = dynamic_cast<Aggregation *>(
return_clause->body_.named_expressions[5]->expression_);
return_clause->body_.named_expressions[6]->expression_);
ASSERT_TRUE(aggregation);
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
ASSERT_FALSE(aggregation->expression_);

View File

@ -28,6 +28,14 @@ namespace query {
namespace test_common {
/// Converts a list-valued TypedValue into a std::vector<int64_t>.
///
/// @param t value that is read as a std::vector<TypedValue>; every element
///          of that list is in turn read as an int64_t.
/// @return the element values, in the same order as in the list.
///
/// NOTE(review): behavior when `t` is not a list, or an element is not an
/// integer, is whatever TypedValue::Value<T>() does on a type mismatch --
/// confirm against TypedValue.
///
/// Declared `inline` so that the definition can be included from several
/// translation units without producing duplicate symbols.
inline auto ToInt64List(const TypedValue &t) {
  const auto &elements = t.Value<std::vector<TypedValue>>();
  std::vector<int64_t> list;
  // The final size is known up front, so allocate once.
  list.reserve(elements.size());
  // Iterate by const reference: copying each TypedValue only to read an
  // integer out of it is wasted work.
  for (const auto &element : elements) {
    list.push_back(element.Value<int64_t>());
  }
  return list;
}
// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions,
// so that they can be used to resolve function calls.
struct OrderBy {

View File

@ -14,14 +14,15 @@
#include "query/interpret/awesome_memgraph_functions.hpp"
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
#include "query_common.hpp"
using namespace query;
using testing::Pair;
using testing::UnorderedElementsAre;
using testing::ElementsAre;
using query::test_common::ToInt64List;
namespace {
struct NoContextExpressionEvaluator {
NoContextExpressionEvaluator() {}
Frame frame{128};
@ -823,30 +824,23 @@ TEST(ExpressionEvaluator, FunctionRange) {
EXPECT_TRUE(EvaluateFunction("RANGE", {1, 2, TypedValue::Null}).IsNull());
EXPECT_THROW(EvaluateFunction("RANGE", {1, TypedValue::Null, 1.3}),
QueryRuntimeException);
auto to_int_list = [](const TypedValue &t) {
std::vector<int64_t> list;
for (auto x : t.Value<std::vector<TypedValue>>()) {
list.push_back(x.Value<int64_t>());
}
return list;
};
EXPECT_THROW(EvaluateFunction("RANGE", {1, 2, 0}), QueryRuntimeException);
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {1, 3})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {1, 3})),
ElementsAre(1, 2, 3));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-1, 5, 2})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-1, 5, 2})),
ElementsAre(-1, 1, 3, 5));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 10, 3})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 10, 3})),
ElementsAre(2, 5, 8));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, 2})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, 2})),
ElementsAre(2));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre());
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {5, 1, -2})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre());
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {5, 1, -2})),
ElementsAre(5, 3, 1));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {6, 1, -2})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {6, 1, -2})),
ElementsAre(6, 4, 2));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, -3})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, -3})),
ElementsAre(2));
EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-2, 4, -1})),
EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-2, 4, -1})),
ElementsAre());
}

View File

@ -21,6 +21,8 @@
using namespace query;
using namespace query::plan;
using testing::UnorderedElementsAre;
using query::test_common::ToInt64List;
TEST(QueryPlan, Accumulate) {
// simulate the following two query execution on an empty db
@ -169,18 +171,19 @@ TEST(QueryPlan, AggregateOps) {
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
std::vector<Expression *> aggregation_expressions(6, n_p);
std::vector<Expression *> aggregation_expressions(7, n_p);
aggregation_expressions[0] = nullptr;
auto produce = MakeAggregationProduce(
n.op_, symbol_table, storage, aggregation_expressions,
{Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG},
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG,
Aggregation::Op::COLLECT},
{}, {});
// checks
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0].size(), 6);
ASSERT_EQ(results[0].size(), 7);
// count(*)
ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][0].Value<int64_t>(), 4);
@ -199,6 +202,9 @@ TEST(QueryPlan, AggregateOps) {
// avg
ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double);
EXPECT_FLOAT_EQ(results[0][5].Value<double>(), 24 / 3.0);
// collect
ASSERT_EQ(results[0][6].type(), TypedValue::Type::List);
EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre(5, 7, 12));
}
TEST(QueryPlan, AggregateGroupByValues) {
@ -320,6 +326,34 @@ TEST(QueryPlan, AggregateNoInput) {
EXPECT_EQ(1, results[0][0].Value<int64_t>());
}
// TODO: This test is valid but currently fails: aggregations are not handled
// correctly when there is no input. Also add similar tests for the other
// aggregation ops.
// TEST(QueryPlan, AggregateCollectNoResults) {
// Dbms dbms;
// auto dba = dbms.active();
// auto prop = dba->property("prop");
//
// AstTreeStorage storage;
// SymbolTable symbol_table;
//
// // match all nodes and perform aggregations
// auto n = MakeScanAll(storage, symbol_table, "n");
// auto n_p = PROPERTY_LOOKUP("n", prop);
// symbol_table[*n_p->expression_] = n.sym_;
//
// auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p},
// {Aggregation::Op::COLLECT}, {}, {});
//
// // checks
// auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
// ASSERT_EQ(results.size(), 1);
// ASSERT_EQ(results[0].size(), 1);
// ASSERT_EQ(results[0][0].type(), TypedValue::Type::List);
// // Collect should return empty list if there are no results.
// EXPECT_THAT(ToInt64List(results[0][0]), UnorderedElementsAre());
//}
TEST(QueryPlan, AggregateCountEdgeCases) {
// tests for detected bugs in the COUNT aggregation behavior
// ensure that COUNT returns correctly for