Query::Plan::Aggregate - count edge-case fix

Reviewers: teon.banek

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D270
This commit is contained in:
florijan 2017-04-12 15:47:55 +02:00
parent afbec0a26e
commit c89c65a748
2 changed files with 72 additions and 8 deletions

View File

@ -1049,7 +1049,12 @@ void Aggregate::AggregateCursor::EnsureInitialized(
Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) {
if (agg_value.values_.size() > 0) return;
agg_value.values_.resize(self_.aggregations_.size(), TypedValue::Null);
for (const auto &agg_elem : self_.aggregations_) {
if (std::get<1>(agg_elem) == Aggregation::Op::COUNT)
agg_value.values_.emplace_back(TypedValue(0));
else
agg_value.values_.emplace_back(TypedValue::Null);
}
agg_value.counts_.resize(self_.aggregations_.size(), 0);
for (const Symbol &remember_sym : self_.remember_)
@ -1074,17 +1079,19 @@ void Aggregate::AggregateCursor::Update(
std::get<0>(*agg_elem_it)->Accept(evaluator);
TypedValue input_value = evaluator.PopBack();
// Aggregations skip Null input values.
if (input_value.type() == TypedValue::Type::Null) continue;
const auto &agg_op = std::get<1>(*agg_elem_it);
*count_it += 1;
if (*count_it == 1) {
// first value, nothing to aggregate. set and continue.
*value_it = input_value;
*value_it = agg_op == Aggregation::Op::COUNT ? 1 : input_value;
continue;
}
// aggregation of existing values
switch (std::get<1>(*agg_elem_it)) {
switch (agg_op) {
case Aggregation::Op::COUNT:
*value_it = *count_it;
break;

View File

@ -252,9 +252,9 @@ TEST(QueryPlan, AggregateGroupByValues) {
result_group_bys.insert(row[1]);
}
ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2);
EXPECT_TRUE(std::is_permutation(
group_by_vals.begin(), group_by_vals.end() - 2, result_group_bys.begin(),
TypedValue::BoolEqual{}));
EXPECT_TRUE(
std::is_permutation(group_by_vals.begin(), group_by_vals.end() - 2,
result_group_bys.begin(), TypedValue::BoolEqual{}));
}
TEST(QueryPlan, AggregateMultipleGroupBy) {
@ -323,10 +323,67 @@ TEST(QueryPlan, AggregateAdvance) {
auto match = MakeScanAll(storage, symbol_table, "m", aggregate);
EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
};
// check(false);
// check(false);
check(true);
}
TEST(QueryPlan, AggregateCountEdgeCases) {
// tests for detected bugs in the COUNT aggregation behavior
// ensure that COUNT returns correctly for
// - 0 vertices in database
// - 1 vertex in database, property not set
// - 1 vertex in database, property set
// - 2 vertices in database, property set on one
// - 2 vertices in database, property set on both
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
// returns -1 when there are no results
// otherwise returns MATCH (n) RETURN count(n.prop)
auto count = [&]() {
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p},
{Aggregation::Op::COUNT}, {}, {});
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
if (results.size() == 0) return -1L;
EXPECT_EQ(1, results.size());
EXPECT_EQ(1, results[0].size());
EXPECT_EQ(TypedValue::Type::Int, results[0][0].type());
return results[0][0].Value<int64_t>();
};
// no vertices yet in database
EXPECT_EQ(-1, count());
// one vertex, no property set
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(0, count());
// one vertex, property set
for (VertexAccessor va : dba->vertices()) va.PropsSet(prop, 42);
dba->advance_command();
EXPECT_EQ(1, count());
// two vertices, one with property set
dba->insert_vertex();
dba->advance_command();
EXPECT_EQ(1, count());
// two vertices, both with property set
for (VertexAccessor va : dba->vertices()) va.PropsSet(prop, 42);
dba->advance_command();
EXPECT_EQ(2, count());
}
TEST(QueryPlan, AggregateTypes) {
// testing exceptions that can get emitted by an aggregation
// does not check all combinations that can result in an exception
@ -358,7 +415,7 @@ TEST(QueryPlan, AggregateTypes) {
CollectProduce(produce, symbol_table, *dba).GetResults();
};
// everythin except for COUNT fails on a Vertex
// everything except for COUNT fails on a Vertex
auto n_id = n_p1->expression_;
aggregate(n_id, Aggregation::Op::COUNT);
EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), TypedValueException);