Fix distinct, now doesn't impacts other aggregates (#1235)

Before a distinct on one aggregate would impact distinct on another
aggregate. Fixed the logical error and at the same time did some memory
optimisations.
This commit is contained in:
Gareth Andrew Lloyd 2023-09-20 16:45:55 +01:00 committed by GitHub
parent 1553fcb958
commit eb4e2b019d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 31 deletions

View File

@ -3405,7 +3405,10 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::M
class AggregateCursor : public Cursor {
public:
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
: self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem) {}
: self_(self),
input_cursor_(self_.input_->MakeCursor(mem)),
aggregation_(mem),
reused_group_by_(self.group_by_.size(), mem) {}
bool Pull(Frame &frame, ExecutionContext &context) override {
SCOPED_PROFILE_OP_BY_REF(self_);
@ -3498,6 +3501,8 @@ class AggregateCursor : public Cursor {
// custom equality
TypedValueVectorEqual>
aggregation_;
// this is a for object reuse, to avoid re-allocating this buffer
utils::pmr::vector<TypedValue> reused_group_by_;
// iterator over the accumulated cache
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
// this LogicalOp pulls all from the input on it's first pull
@ -3520,16 +3525,36 @@ class AggregateCursor : public Cursor {
ProcessOne(*frame, &evaluator);
}
// calculate AVG aggregations (so far they have only been summed)
// post processing
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
auto count = agg_value.counts_[pos];
auto *pull_memory = context->evaluation_context.memory;
if (count > 0) {
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
switch (self_.aggregations_[pos].op) {
case Aggregation::Op::AVG: {
// calculate AVG aggregations (so far they have only been summed)
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
auto count = agg_value.counts_[pos];
auto *pull_memory = context->evaluation_context.memory;
if (count > 0) {
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
}
}
break;
}
case Aggregation::Op::COUNT: {
// Copy counts to be the value
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
agg_value.values_[pos] = agg_value.counts_[pos];
}
break;
}
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
case Aggregation::Op::SUM:
case Aggregation::Op::COLLECT_LIST:
case Aggregation::Op::COLLECT_MAP:
case Aggregation::Op::PROJECT:
break;
}
}
}
@ -3538,14 +3563,16 @@ class AggregateCursor : public Cursor {
* Performs a single accumulation.
*/
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
auto *mem = aggregation_.get_allocator().GetMemoryResource();
utils::pmr::vector<TypedValue> group_by(mem);
group_by.reserve(self_.group_by_.size());
// Preallocated group_by, since most of the time the aggregation key won't be unique
reused_group_by_.clear();
for (Expression *expression : self_.group_by_) {
group_by.emplace_back(expression->Accept(*evaluator));
reused_group_by_.emplace_back(expression->Accept(*evaluator));
}
auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second;
EnsureInitialized(frame, &agg_value);
auto *mem = aggregation_.get_allocator().GetMemoryResource();
auto res = aggregation_.try_emplace(reused_group_by_, mem);
auto &agg_value = res.first->second;
if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value);
Update(evaluator, &agg_value);
}
@ -3556,14 +3583,21 @@ class AggregateCursor : public Cursor {
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
if (!agg_value->values_.empty()) return;
const auto num_of_aggregations = self_.aggregations_.size();
agg_value->values_.reserve(num_of_aggregations);
agg_value->unique_values_.reserve(num_of_aggregations);
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
for (const auto &agg_elem : self_.aggregations_) {
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
}
agg_value->counts_.resize(self_.aggregations_.size(), 0);
agg_value->counts_.resize(num_of_aggregations, 0);
for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]);
agg_value->remember_.reserve(self_.remember_.size());
for (const Symbol &remember_sym : self_.remember_) {
agg_value->remember_.push_back(frame[remember_sym]);
}
}
/** Updates the given AggregationValue with new data. Assumes that
@ -3580,13 +3614,14 @@ class AggregateCursor : public Cursor {
auto value_it = agg_value->values_.begin();
auto unique_values_it = agg_value->unique_values_.begin();
auto agg_elem_it = self_.aggregations_.begin();
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) {
const auto counts_end = agg_value->counts_.end();
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
// COUNT(*) is the only case where input expression is optional
// handle it here
auto input_expr_ptr = agg_elem_it->value;
if (!input_expr_ptr) {
*count_it += 1;
*value_it = *count_it;
// value is deferred to post-processing
continue;
}
@ -3598,7 +3633,7 @@ class AggregateCursor : public Cursor {
if (agg_elem_it->distinct) {
auto insert_result = unique_values_it->insert(input_value);
if (!insert_result.second) {
break;
continue;
}
}
*count_it += 1;
@ -3607,19 +3642,19 @@ class AggregateCursor : public Cursor {
switch (agg_op) {
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
*value_it = input_value;
EnsureOkForMinMax(input_value);
*value_it = std::move(input_value);
break;
case Aggregation::Op::SUM:
case Aggregation::Op::AVG:
*value_it = input_value;
EnsureOkForAvgSum(input_value);
*value_it = std::move(input_value);
break;
case Aggregation::Op::COUNT:
*value_it = 1;
// value is deferred to post-processing
break;
case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(input_value);
value_it->ValueList().push_back(std::move(input_value));
break;
case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value);
@ -3629,7 +3664,7 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), input_value);
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break;
}
continue;
@ -3638,7 +3673,7 @@ class AggregateCursor : public Cursor {
// aggregation of existing values
switch (agg_op) {
case Aggregation::Op::COUNT:
*value_it = *count_it;
// value is deferred to post-processing
break;
case Aggregation::Op::MIN: {
EnsureOkForMinMax(input_value);
@ -3647,7 +3682,7 @@ class AggregateCursor : public Cursor {
// since we skip nulls we either have a valid comparison, or
// an exception was just thrown above
// safe to assume a bool TypedValue
if (comparison_result.ValueBool()) *value_it = input_value;
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
} catch (const TypedValueException &) {
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
}
@ -3658,7 +3693,7 @@ class AggregateCursor : public Cursor {
EnsureOkForMinMax(input_value);
try {
TypedValue comparison_result = input_value > *value_it;
if (comparison_result.ValueBool()) *value_it = input_value;
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
} catch (const TypedValueException &) {
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
}
@ -3672,7 +3707,7 @@ class AggregateCursor : public Cursor {
*value_it = *value_it + input_value;
break;
case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(input_value);
value_it->ValueList().push_back(std::move(input_value));
break;
case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value);
@ -3682,7 +3717,7 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), input_value);
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break;
} // end switch over Aggregation::Op enum
} // end loop over all aggregations

View File

@ -1135,3 +1135,26 @@ Feature: Functions
| 0 |
| 1 |
| 2 |
Scenario: Aggregate distinct does not impact other aggregates:
Given an empty graph
And having executed:
"""
CREATE (:Node_A {id:1})
CREATE (:Node_A {id:2})
CREATE (:Node_A {id:3})
CREATE (:Node_B {id:1})
CREATE (:Node_B {id:2})
CREATE (:Node_B {id:3})
CREATE (:Node_B {id:4})
CREATE (:Node_B {id:4})
"""
When executing query:
"""
MATCH (a:Node_A), (b:Node_B)
RETURN COUNT(DISTINCT a.id) AS A_COUNT,
COUNT(b.id) AS B_COUNT;
"""
Then the result should be:
| A_COUNT | B_COUNT |
| 3 | 15 |