Query - Aggregation with no input, SQL-style handling.

Reviewers: mislav.bradac, buda, teon.banek

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D400
This commit is contained in:
florijan 2017-05-30 09:37:24 +02:00
parent e631eb4eb2
commit 7403338f38
2 changed files with 128 additions and 67 deletions

View File

@ -920,12 +920,44 @@ Aggregate::AggregateCursor::AggregateCursor(Aggregate &self,
GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
namespace {

/** Returns the default TypedValue for an Aggregation operation.
 *
 * The same value serves two purposes: it is the result reported when the
 * aggregation op receives no inputs at all, and it is the value an
 * aggregation result is initialized with when there are inputs.
 *
 * @param op the aggregation operation to get the default value for.
 * @return a TypedValue constructed from 0 for COUNT, an empty list for
 *     COLLECT, and Null for SUM, MIN, MAX and AVG.
 */
TypedValue DefaultAggregationOpValue(Aggregation::Op op) {
  switch (op) {
    case Aggregation::Op::COUNT:
      return TypedValue(0);
    case Aggregation::Op::SUM:
    case Aggregation::Op::MIN:
    case Aggregation::Op::MAX:
    case Aggregation::Op::AVG:
      return TypedValue::Null;
    case Aggregation::Op::COLLECT:
      return TypedValue(std::vector<TypedValue>());
  }
  // Deliberately no `default:` label in the switch, so compilers that check
  // switch coverage can still report an op that is not handled above. This
  // return is reached only for a value matching none of the cases; without it
  // control would flow off the end of a value-returning function, which is
  // undefined behaviour.
  return TypedValue::Null;
}

}  // namespace
bool Aggregate::AggregateCursor::Pull(Frame &frame,
const SymbolTable &symbol_table) {
if (!pulled_all_input_) {
ProcessAll(frame, symbol_table);
pulled_all_input_ = true;
aggregation_it_ = aggregation_.begin();
// in case there is no input and no group_bys we need to return true just
// this once
if (aggregation_.empty() && self_.group_by_.empty()) {
// place default aggregation values on the frame
for (const auto &elem : self_.aggregations_)
frame[std::get<2>(elem)] = DefaultAggregationOpValue(std::get<1>(elem));
// place null as remember values on the frame
for (const Symbol &remember_sym : self_.remember_)
frame[remember_sym] = TypedValue::Null;
return true;
}
}
if (aggregation_it_ == aggregation_.end()) return false;
@ -981,15 +1013,9 @@ void Aggregate::AggregateCursor::EnsureInitialized(
Aggregate::AggregateCursor::AggregationValue &agg_value) const {
if (agg_value.values_.size() > 0) return;
for (const auto &agg_elem : self_.aggregations_) {
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) {
agg_value.values_.emplace_back(TypedValue(0));
} else if (std::get<1>(agg_elem) == Aggregation::Op::COLLECT) {
agg_value.values_.emplace_back(std::vector<TypedValue>());
} else {
agg_value.values_.emplace_back(TypedValue::Null);
}
}
for (const auto &agg_elem : self_.aggregations_)
agg_value.values_.emplace_back(
DefaultAggregationOpValue(std::get<1>(agg_elem)));
agg_value.counts_.resize(self_.aggregations_.size(), 0);
for (const Symbol &remember_sym : self_.remember_)

View File

@ -115,8 +115,6 @@ std::shared_ptr<Produce> MakeAggregationProduce(
const std::vector<Aggregation::Op> aggr_ops,
const std::vector<Expression *> group_by_exprs,
const std::vector<Symbol> remember) {
permanent_assert(aggr_inputs.size() == aggr_ops.size(),
"Provide as many aggr inputs as aggr ops");
// prepare all the aggregations
std::vector<Aggregate::Element> aggregates;
std::vector<NamedExpression *> named_expressions;
@ -148,40 +146,55 @@ std::shared_ptr<Produce> MakeAggregationProduce(
return std::make_shared<Produce>(aggregation, named_expressions);
}
TEST(QueryPlan, AggregateOps) {
/** Test fixture for all the aggregation ops in one return */
class QueryPlanAggregateOps : public ::testing::Test {
protected:
Dbms dbms;
auto dba = dbms.active();
// setup is several nodes most of which have an int property set
// we will take the sum, avg, min, max and count
// we won't group by anything
auto prop = dba->property("prop");
dba->insert_vertex().PropsSet(prop, 5);
dba->insert_vertex().PropsSet(prop, 7);
dba->insert_vertex().PropsSet(prop, 12);
// a missing property (null) gets ignored by all aggregations except COUNT(*)
dba->insert_vertex();
dba->advance_command();
std::unique_ptr<GraphDbAccessor> dba = dbms.active();
GraphDbTypes::Property prop = dba->property("prop");
AstTreeStorage storage;
SymbolTable symbol_table;
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
void AddData() {
// setup is several nodes most of which have an int property set
// we will take the sum, avg, min, max and count
// we won't group by anything
dba->insert_vertex().PropsSet(prop, 5);
dba->insert_vertex().PropsSet(prop, 7);
dba->insert_vertex().PropsSet(prop, 12);
// a missing property (null) gets ignored by all aggregations except
// COUNT(*)
dba->insert_vertex();
dba->advance_command();
}
std::vector<Expression *> aggregation_expressions(7, n_p);
aggregation_expressions[0] = nullptr;
auto produce = MakeAggregationProduce(
n.op_, symbol_table, storage, aggregation_expressions,
{Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG,
Aggregation::Op::COLLECT},
{}, {});
auto AggregationResults(bool with_group_by,
std::vector<Aggregation::Op> ops = {
Aggregation::Op::COUNT, Aggregation::Op::COUNT,
Aggregation::Op::MIN, Aggregation::Op::MAX,
Aggregation::Op::SUM, Aggregation::Op::AVG,
Aggregation::Op::COLLECT}) {
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
std::vector<Expression *> aggregation_expressions(7, n_p);
std::vector<Expression *> group_bys;
if (with_group_by) group_bys.push_back(n_p);
aggregation_expressions[0] = nullptr;
auto produce =
MakeAggregationProduce(n.op_, symbol_table, storage,
aggregation_expressions, ops, group_bys, {});
return CollectProduce(produce, symbol_table, *dba).GetResults();
}
};
TEST_F(QueryPlanAggregateOps, WithData) {
AddData();
auto results = AggregationResults(false);
// checks
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0].size(), 7);
// count(*)
@ -207,6 +220,56 @@ TEST(QueryPlan, AggregateOps) {
EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre(5, 7, 12));
}
TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) {
  // AddData() is not called, so the aggregation gets no input. Because the
  // aggregation is grouped by the property lookup, every op is expected to
  // yield zero result rows.
  auto grouped_row_count = [this](Aggregation::Op op) {
    return AggregationResults(true, {op}).size();
  };

  // one expectation per op, so a failure points at the offending op
  EXPECT_EQ(grouped_row_count(Aggregation::Op::COUNT), 0);
  EXPECT_EQ(grouped_row_count(Aggregation::Op::SUM), 0);
  EXPECT_EQ(grouped_row_count(Aggregation::Op::AVG), 0);
  EXPECT_EQ(grouped_row_count(Aggregation::Op::MIN), 0);
  EXPECT_EQ(grouped_row_count(Aggregation::Op::MAX), 0);
  EXPECT_EQ(grouped_row_count(Aggregation::Op::COLLECT), 0);
}
TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) {
  // AddData() is not called and nothing is grouped by, so exactly one row
  // holding the default value of each aggregation op is expected.
  auto rows = AggregationResults(false);
  ASSERT_EQ(rows.size(), 1);
  auto &row = rows[0];
  ASSERT_EQ(row.size(), 7);

  // column positions follow the default op order used by AggregationResults
  const int kCountAll = 0;
  const int kCount = 1;
  const int kMin = 2;
  const int kMax = 3;
  const int kSum = 4;
  const int kAvg = 5;
  const int kCollect = 6;

  // both counts are the integer zero
  ASSERT_EQ(row[kCountAll].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[kCountAll].Value<int64_t>(), 0);
  ASSERT_EQ(row[kCount].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[kCount].Value<int64_t>(), 0);

  // min, max, sum and avg are all null
  EXPECT_TRUE(row[kMin].IsNull());
  EXPECT_TRUE(row[kMax].IsNull());
  EXPECT_TRUE(row[kSum].IsNull());
  EXPECT_TRUE(row[kAvg].IsNull());

  // collect is an empty list
  ASSERT_EQ(row[kCollect].type(), TypedValue::Type::List);
  EXPECT_THAT(ToInt64List(row[kCollect]), UnorderedElementsAre());
}
TEST(QueryPlan, AggregateGroupByValues) {
// tests that distinct groups are aggregated properly
// for values of all types
@ -326,34 +389,6 @@ TEST(QueryPlan, AggregateNoInput) {
EXPECT_EQ(1, results[0][0].Value<int64_t>());
}
// TODO: This test is valid but it fails. We don't handle aggregations correctly
// in the case when there is no input. Also add similar tests for other
// aggregation ops.
// TEST(QueryPlan, AggregateCollectNoResults) {
// Dbms dbms;
// auto dba = dbms.active();
// auto prop = dba->property("prop");
//
// AstTreeStorage storage;
// SymbolTable symbol_table;
//
// // match all nodes and perform aggregations
// auto n = MakeScanAll(storage, symbol_table, "n");
// auto n_p = PROPERTY_LOOKUP("n", prop);
// symbol_table[*n_p->expression_] = n.sym_;
//
// auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p},
// {Aggregation::Op::COLLECT}, {}, {});
//
// // checks
// auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
// ASSERT_EQ(results.size(), 1);
// ASSERT_EQ(results[0].size(), 1);
// ASSERT_EQ(results[0][0].type(), TypedValue::Type::List);
// // Collect should return empty list if there are no results.
// EXPECT_THAT(ToInt64List(results[0][0]), UnorderedElementsAre());
//}
TEST(QueryPlan, AggregateCountEdgeCases) {
// tests for detected bugs in the COUNT aggregation behavior
// ensure that COUNT returns correctly for
@ -388,7 +423,7 @@ TEST(QueryPlan, AggregateCountEdgeCases) {
};
// no vertices yet in database
EXPECT_EQ(-1, count());
EXPECT_EQ(0, count());
// one vertex, no property set
dba->insert_vertex();