diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 17d2f3041..e360f4fe4 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -920,12 +920,44 @@ Aggregate::AggregateCursor::AggregateCursor(Aggregate &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} +namespace { +/** Returns the default TypedValue for an Aggregation operation. + * This value is valid both for returning when there are no inputs + * to the aggregation op, and for initializing an aggregation result + * when there are inputs. */ +TypedValue DefaultAggregationOpValue(Aggregation::Op op) { + switch (op) { + case Aggregation::Op::COUNT: + return TypedValue(0); + case Aggregation::Op::SUM: + case Aggregation::Op::MIN: + case Aggregation::Op::MAX: + case Aggregation::Op::AVG: + return TypedValue::Null; + case Aggregation::Op::COLLECT: + return TypedValue(std::vector()); + } +} +} + bool Aggregate::AggregateCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { if (!pulled_all_input_) { ProcessAll(frame, symbol_table); pulled_all_input_ = true; aggregation_it_ = aggregation_.begin(); + + // in case there is no input and no group_bys we need to return true just + // this once + if (aggregation_.empty() && self_.group_by_.empty()) { + // place default aggregation values on the frame + for (const auto &elem : self_.aggregations_) + frame[std::get<2>(elem)] = DefaultAggregationOpValue(std::get<1>(elem)); + // place null as remember values on the frame + for (const Symbol &remember_sym : self_.remember_) + frame[remember_sym] = TypedValue::Null; + return true; + } } if (aggregation_it_ == aggregation_.end()) return false; @@ -981,15 +1013,9 @@ void Aggregate::AggregateCursor::EnsureInitialized( Aggregate::AggregateCursor::AggregationValue &agg_value) const { if (agg_value.values_.size() > 0) return; - for (const auto &agg_elem : self_.aggregations_) { - if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) { - 
agg_value.values_.emplace_back(TypedValue(0)); - } else if (std::get<1>(agg_elem) == Aggregation::Op::COLLECT) { - agg_value.values_.emplace_back(std::vector()); - } else { - agg_value.values_.emplace_back(TypedValue::Null); - } - } + for (const auto &agg_elem : self_.aggregations_) + agg_value.values_.emplace_back( + DefaultAggregationOpValue(std::get<1>(agg_elem))); agg_value.counts_.resize(self_.aggregations_.size(), 0); for (const Symbol &remember_sym : self_.remember_) diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index b3480d488..43e73b541 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -115,8 +115,6 @@ std::shared_ptr MakeAggregationProduce( const std::vector aggr_ops, const std::vector group_by_exprs, const std::vector remember) { - permanent_assert(aggr_inputs.size() == aggr_ops.size(), - "Provide as many aggr inputs as aggr ops"); // prepare all the aggregations std::vector aggregates; std::vector named_expressions; @@ -148,40 +146,55 @@ std::shared_ptr MakeAggregationProduce( return std::make_shared(aggregation, named_expressions); } -TEST(QueryPlan, AggregateOps) { +/** Test fixture for all the aggregation ops in one return */ +class QueryPlanAggregateOps : public ::testing::Test { + protected: Dbms dbms; - auto dba = dbms.active(); - - // setup is several nodes most of which have an int property set - // we will take the sum, avg, min, max and count - // we won't group by anything - auto prop = dba->property("prop"); - dba->insert_vertex().PropsSet(prop, 5); - dba->insert_vertex().PropsSet(prop, 7); - dba->insert_vertex().PropsSet(prop, 12); - // a missing property (null) gets ignored by all aggregations except COUNT(*) - dba->insert_vertex(); - dba->advance_command(); + std::unique_ptr dba = dbms.active(); + GraphDbTypes::Property prop = dba->property("prop"); AstTreeStorage storage; SymbolTable symbol_table; - // match 
all nodes and perform aggregations - auto n = MakeScanAll(storage, symbol_table, "n"); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; + void AddData() { + // setup is several nodes most of which have an int property set + // we will take the sum, avg, min, max and count + // we won't group by anything + dba->insert_vertex().PropsSet(prop, 5); + dba->insert_vertex().PropsSet(prop, 7); + dba->insert_vertex().PropsSet(prop, 12); + // a missing property (null) gets ignored by all aggregations except + // COUNT(*) + dba->insert_vertex(); + dba->advance_command(); + } - std::vector aggregation_expressions(7, n_p); - aggregation_expressions[0] = nullptr; - auto produce = MakeAggregationProduce( - n.op_, symbol_table, storage, aggregation_expressions, - {Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, - Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, - Aggregation::Op::COLLECT}, - {}, {}); + auto AggregationResults(bool with_group_by, + std::vector ops = { + Aggregation::Op::COUNT, Aggregation::Op::COUNT, + Aggregation::Op::MIN, Aggregation::Op::MAX, + Aggregation::Op::SUM, Aggregation::Op::AVG, + Aggregation::Op::COLLECT}) { + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + + std::vector aggregation_expressions(7, n_p); + std::vector group_bys; + if (with_group_by) group_bys.push_back(n_p); + aggregation_expressions[0] = nullptr; + auto produce = + MakeAggregationProduce(n.op_, symbol_table, storage, + aggregation_expressions, ops, group_bys, {}); + return CollectProduce(produce, symbol_table, *dba).GetResults(); + } +}; + +TEST_F(QueryPlanAggregateOps, WithData) { + AddData(); + auto results = AggregationResults(false); - // checks - auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); ASSERT_EQ(results.size(), 1); 
ASSERT_EQ(results[0].size(), 7); // count(*) @@ -207,6 +220,56 @@ TEST(QueryPlan, AggregateOps) { EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre(5, 7, 12)); } +TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) { + { + auto results = AggregationResults(true, {Aggregation::Op::COUNT}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::SUM}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::AVG}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::MIN}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::MAX}); + EXPECT_EQ(results.size(), 0); + } + { + auto results = AggregationResults(true, {Aggregation::Op::COLLECT}); + EXPECT_EQ(results.size(), 0); + } +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) { + auto results = AggregationResults(false); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 7); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].Value(), 0); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].Value(), 0); + // min + EXPECT_TRUE(results[0][2].IsNull()); + // max + EXPECT_TRUE(results[0][3].IsNull()); + // sum + EXPECT_TRUE(results[0][4].IsNull()); + // avg + EXPECT_TRUE(results[0][5].IsNull()); + // collect + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre()); +} + TEST(QueryPlan, AggregateGroupByValues) { // tests that distinct groups are aggregated properly // for values of all types @@ -326,34 +389,6 @@ TEST(QueryPlan, AggregateNoInput) { EXPECT_EQ(1, results[0][0].Value()); } -// TODO: This test is valid but it fails. We don't handle aggregations correctly -// in the case when there is no input. 
Also add similar tests for other -// aggregation ops. -// TEST(QueryPlan, AggregateCollectNoResults) { -// Dbms dbms; -// auto dba = dbms.active(); -// auto prop = dba->property("prop"); -// -// AstTreeStorage storage; -// SymbolTable symbol_table; -// -// // match all nodes and perform aggregations -// auto n = MakeScanAll(storage, symbol_table, "n"); -// auto n_p = PROPERTY_LOOKUP("n", prop); -// symbol_table[*n_p->expression_] = n.sym_; -// -// auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, -// {Aggregation::Op::COLLECT}, {}, {}); -// -// // checks -// auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); -// ASSERT_EQ(results.size(), 1); -// ASSERT_EQ(results[0].size(), 1); -// ASSERT_EQ(results[0][0].type(), TypedValue::Type::List); -// // Collect should return empty list if there are no results. -// EXPECT_THAT(ToInt64List(results[0][0]), UnorderedElementsAre()); -//} - TEST(QueryPlan, AggregateCountEdgeCases) { // tests for detected bugs in the COUNT aggregation behavior // ensure that COUNT returns correctly for @@ -388,7 +423,7 @@ TEST(QueryPlan, AggregateCountEdgeCases) { }; // no vertices yet in database - EXPECT_EQ(-1, count()); + EXPECT_EQ(0, count()); // one vertex, no property set dba->insert_vertex();