diff --git a/src/query/frontend/logical/operator.cpp b/src/query/frontend/logical/operator.cpp index b1e3666b9..2d7f2ce5d 100644 --- a/src/query/frontend/logical/operator.cpp +++ b/src/query/frontend/logical/operator.cpp @@ -1049,7 +1049,12 @@ void Aggregate::AggregateCursor::EnsureInitialized( Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) { if (agg_value.values_.size() > 0) return; - agg_value.values_.resize(self_.aggregations_.size(), TypedValue::Null); + for (const auto &agg_elem : self_.aggregations_) { + if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) + agg_value.values_.emplace_back(TypedValue(0)); + else + agg_value.values_.emplace_back(TypedValue::Null); + } agg_value.counts_.resize(self_.aggregations_.size(), 0); for (const Symbol &remember_sym : self_.remember_) @@ -1074,17 +1079,19 @@ void Aggregate::AggregateCursor::Update( std::get<0>(*agg_elem_it)->Accept(evaluator); TypedValue input_value = evaluator.PopBack(); + // Aggregations skip Null input values. if (input_value.type() == TypedValue::Type::Null) continue; + const auto &agg_op = std::get<1>(*agg_elem_it); *count_it += 1; if (*count_it == 1) { // first value, nothing to aggregate. set and continue. - *value_it = input_value; + *value_it = agg_op == Aggregation::Op::COUNT ? 1 : input_value; continue; } // aggregation of existing values - switch (std::get<1>(*agg_elem_it)) { + switch (agg_op) { case Aggregation::Op::COUNT: *value_it = *count_it; break; diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index 1137c7536..70b9bec48 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -252,9 +252,9 @@ TEST(QueryPlan, AggregateGroupByValues) { result_group_bys.insert(row[1]); } ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2); - EXPECT_TRUE(std::is_permutation( - group_by_vals.begin(), group_by_vals.end() - 2, result_group_bys.begin(), - TypedValue::BoolEqual{})); + EXPECT_TRUE( + std::is_permutation(group_by_vals.begin(), group_by_vals.end() - 2, + result_group_bys.begin(), TypedValue::BoolEqual{})); } TEST(QueryPlan, AggregateMultipleGroupBy) { @@ -323,10 +323,67 @@ TEST(QueryPlan, AggregateAdvance) { auto match = MakeScanAll(storage, symbol_table, "m", aggregate); EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table)); }; -// check(false); + // check(false); check(true); } +TEST(QueryPlan, AggregateCountEdgeCases) { + // tests for detected bugs in the COUNT aggregation behavior + // ensure that COUNT returns correctly for + // - 0 vertices in database + // - 1 vertex in database, property not set + // - 1 vertex in database, property set + // - 2 vertices in database, property set on one + // - 2 vertices in database, property set on both + + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + + // returns -1 when there are no results + // otherwise returns MATCH (n) RETURN count(n.prop) + auto count = [&]() { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, + {Aggregation::Op::COUNT}, {}, {}); + auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); + if (results.size() == 0) return -1L; + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + return results[0][0].Value(); + }; + + // no vertices yet in database + EXPECT_EQ(-1, count()); + + // one vertex, no property set + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(0, count()); + + // one vertex, property set + for (VertexAccessor va : dba->vertices()) va.PropsSet(prop, 42); + dba->advance_command(); + EXPECT_EQ(1, count()); + + // two vertices, one with property set + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(1, count()); + + // two vertices, both with property set + for (VertexAccessor va : dba->vertices()) va.PropsSet(prop, 42); + dba->advance_command(); + EXPECT_EQ(2, count()); +} + TEST(QueryPlan, AggregateTypes) { // testing exceptions that can get emitted by an aggregation // does not check all combinations that can result in an exception @@ -358,7 +415,7 @@ TEST(QueryPlan, AggregateTypes) { CollectProduce(produce, symbol_table, *dba).GetResults(); }; - // everythin except for COUNT fails on a Vertex + // everything except for COUNT fails on a Vertex auto n_id = n_p1->expression_; aggregate(n_id, Aggregation::Op::COUNT); EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), TypedValueException);