Fix distinct, now doesn't impacts other aggregates (#1235)

Before a distinct on one aggregate would impact distinct on another
aggregate. Fixed the logical error and at the same time did some memory
optimisations.
This commit is contained in:
Gareth Andrew Lloyd 2023-09-20 16:45:55 +01:00 committed by GitHub
parent 1553fcb958
commit eb4e2b019d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 31 deletions

View File

@ -3405,7 +3405,10 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::M
class AggregateCursor : public Cursor { class AggregateCursor : public Cursor {
public: public:
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem) AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
: self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem) {} : self_(self),
input_cursor_(self_.input_->MakeCursor(mem)),
aggregation_(mem),
reused_group_by_(self.group_by_.size(), mem) {}
bool Pull(Frame &frame, ExecutionContext &context) override { bool Pull(Frame &frame, ExecutionContext &context) override {
SCOPED_PROFILE_OP_BY_REF(self_); SCOPED_PROFILE_OP_BY_REF(self_);
@ -3498,6 +3501,8 @@ class AggregateCursor : public Cursor {
// custom equality // custom equality
TypedValueVectorEqual> TypedValueVectorEqual>
aggregation_; aggregation_;
// this is a for object reuse, to avoid re-allocating this buffer
utils::pmr::vector<TypedValue> reused_group_by_;
// iterator over the accumulated cache // iterator over the accumulated cache
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin(); decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
// this LogicalOp pulls all from the input on it's first pull // this LogicalOp pulls all from the input on it's first pull
@ -3520,16 +3525,36 @@ class AggregateCursor : public Cursor {
ProcessOne(*frame, &evaluator); ProcessOne(*frame, &evaluator);
} }
// calculate AVG aggregations (so far they have only been summed) // post processing
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) { for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue; switch (self_.aggregations_[pos].op) {
for (auto &kv : aggregation_) { case Aggregation::Op::AVG: {
AggregationValue &agg_value = kv.second; // calculate AVG aggregations (so far they have only been summed)
auto count = agg_value.counts_[pos]; for (auto &kv : aggregation_) {
auto *pull_memory = context->evaluation_context.memory; AggregationValue &agg_value = kv.second;
if (count > 0) { auto count = agg_value.counts_[pos];
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory); auto *pull_memory = context->evaluation_context.memory;
if (count > 0) {
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
}
}
break;
} }
case Aggregation::Op::COUNT: {
// Copy counts to be the value
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
agg_value.values_[pos] = agg_value.counts_[pos];
}
break;
}
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
case Aggregation::Op::SUM:
case Aggregation::Op::COLLECT_LIST:
case Aggregation::Op::COLLECT_MAP:
case Aggregation::Op::PROJECT:
break;
} }
} }
} }
@ -3538,14 +3563,16 @@ class AggregateCursor : public Cursor {
* Performs a single accumulation. * Performs a single accumulation.
*/ */
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) { void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
auto *mem = aggregation_.get_allocator().GetMemoryResource(); // Preallocated group_by, since most of the time the aggregation key won't be unique
utils::pmr::vector<TypedValue> group_by(mem); reused_group_by_.clear();
group_by.reserve(self_.group_by_.size());
for (Expression *expression : self_.group_by_) { for (Expression *expression : self_.group_by_) {
group_by.emplace_back(expression->Accept(*evaluator)); reused_group_by_.emplace_back(expression->Accept(*evaluator));
} }
auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second; auto *mem = aggregation_.get_allocator().GetMemoryResource();
EnsureInitialized(frame, &agg_value); auto res = aggregation_.try_emplace(reused_group_by_, mem);
auto &agg_value = res.first->second;
if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value);
Update(evaluator, &agg_value); Update(evaluator, &agg_value);
} }
@ -3556,14 +3583,21 @@ class AggregateCursor : public Cursor {
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const { void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
if (!agg_value->values_.empty()) return; if (!agg_value->values_.empty()) return;
const auto num_of_aggregations = self_.aggregations_.size();
agg_value->values_.reserve(num_of_aggregations);
agg_value->unique_values_.reserve(num_of_aggregations);
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
for (const auto &agg_elem : self_.aggregations_) { for (const auto &agg_elem : self_.aggregations_) {
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem)); agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem)); agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
} }
agg_value->counts_.resize(self_.aggregations_.size(), 0); agg_value->counts_.resize(num_of_aggregations, 0);
for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]); agg_value->remember_.reserve(self_.remember_.size());
for (const Symbol &remember_sym : self_.remember_) {
agg_value->remember_.push_back(frame[remember_sym]);
}
} }
/** Updates the given AggregationValue with new data. Assumes that /** Updates the given AggregationValue with new data. Assumes that
@ -3580,13 +3614,14 @@ class AggregateCursor : public Cursor {
auto value_it = agg_value->values_.begin(); auto value_it = agg_value->values_.begin();
auto unique_values_it = agg_value->unique_values_.begin(); auto unique_values_it = agg_value->unique_values_.begin();
auto agg_elem_it = self_.aggregations_.begin(); auto agg_elem_it = self_.aggregations_.begin();
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) { const auto counts_end = agg_value->counts_.end();
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
// COUNT(*) is the only case where input expression is optional // COUNT(*) is the only case where input expression is optional
// handle it here // handle it here
auto input_expr_ptr = agg_elem_it->value; auto input_expr_ptr = agg_elem_it->value;
if (!input_expr_ptr) { if (!input_expr_ptr) {
*count_it += 1; *count_it += 1;
*value_it = *count_it; // value is deferred to post-processing
continue; continue;
} }
@ -3598,7 +3633,7 @@ class AggregateCursor : public Cursor {
if (agg_elem_it->distinct) { if (agg_elem_it->distinct) {
auto insert_result = unique_values_it->insert(input_value); auto insert_result = unique_values_it->insert(input_value);
if (!insert_result.second) { if (!insert_result.second) {
break; continue;
} }
} }
*count_it += 1; *count_it += 1;
@ -3607,19 +3642,19 @@ class AggregateCursor : public Cursor {
switch (agg_op) { switch (agg_op) {
case Aggregation::Op::MIN: case Aggregation::Op::MIN:
case Aggregation::Op::MAX: case Aggregation::Op::MAX:
*value_it = input_value;
EnsureOkForMinMax(input_value); EnsureOkForMinMax(input_value);
*value_it = std::move(input_value);
break; break;
case Aggregation::Op::SUM: case Aggregation::Op::SUM:
case Aggregation::Op::AVG: case Aggregation::Op::AVG:
*value_it = input_value;
EnsureOkForAvgSum(input_value); EnsureOkForAvgSum(input_value);
*value_it = std::move(input_value);
break; break;
case Aggregation::Op::COUNT: case Aggregation::Op::COUNT:
*value_it = 1; // value is deferred to post-processing
break; break;
case Aggregation::Op::COLLECT_LIST: case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(input_value); value_it->ValueList().push_back(std::move(input_value));
break; break;
case Aggregation::Op::PROJECT: { case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value); EnsureOkForProject(input_value);
@ -3629,7 +3664,7 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_MAP: case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator); auto key = agg_elem_it->key->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), input_value); value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break; break;
} }
continue; continue;
@ -3638,7 +3673,7 @@ class AggregateCursor : public Cursor {
// aggregation of existing values // aggregation of existing values
switch (agg_op) { switch (agg_op) {
case Aggregation::Op::COUNT: case Aggregation::Op::COUNT:
*value_it = *count_it; // value is deferred to post-processing
break; break;
case Aggregation::Op::MIN: { case Aggregation::Op::MIN: {
EnsureOkForMinMax(input_value); EnsureOkForMinMax(input_value);
@ -3647,7 +3682,7 @@ class AggregateCursor : public Cursor {
// since we skip nulls we either have a valid comparison, or // since we skip nulls we either have a valid comparison, or
// an exception was just thrown above // an exception was just thrown above
// safe to assume a bool TypedValue // safe to assume a bool TypedValue
if (comparison_result.ValueBool()) *value_it = input_value; if (comparison_result.ValueBool()) *value_it = std::move(input_value);
} catch (const TypedValueException &) { } catch (const TypedValueException &) {
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type()); throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
} }
@ -3658,7 +3693,7 @@ class AggregateCursor : public Cursor {
EnsureOkForMinMax(input_value); EnsureOkForMinMax(input_value);
try { try {
TypedValue comparison_result = input_value > *value_it; TypedValue comparison_result = input_value > *value_it;
if (comparison_result.ValueBool()) *value_it = input_value; if (comparison_result.ValueBool()) *value_it = std::move(input_value);
} catch (const TypedValueException &) { } catch (const TypedValueException &) {
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type()); throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
} }
@ -3672,7 +3707,7 @@ class AggregateCursor : public Cursor {
*value_it = *value_it + input_value; *value_it = *value_it + input_value;
break; break;
case Aggregation::Op::COLLECT_LIST: case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(input_value); value_it->ValueList().push_back(std::move(input_value));
break; break;
case Aggregation::Op::PROJECT: { case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value); EnsureOkForProject(input_value);
@ -3682,7 +3717,7 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_MAP: case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator); auto key = agg_elem_it->key->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string."); if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), input_value); value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break; break;
} // end switch over Aggregation::Op enum } // end switch over Aggregation::Op enum
} // end loop over all aggregations } // end loop over all aggregations

View File

@ -1135,3 +1135,26 @@ Feature: Functions
| 0 | | 0 |
| 1 | | 1 |
| 2 | | 2 |
Scenario: Aggregate distinct does not impact other aggregates:
Given an empty graph
And having executed:
"""
CREATE (:Node_A {id:1})
CREATE (:Node_A {id:2})
CREATE (:Node_A {id:3})
CREATE (:Node_B {id:1})
CREATE (:Node_B {id:2})
CREATE (:Node_B {id:3})
CREATE (:Node_B {id:4})
CREATE (:Node_B {id:4})
"""
When executing query:
"""
MATCH (a:Node_A), (b:Node_B)
RETURN COUNT(DISTINCT a.id) AS A_COUNT,
COUNT(b.id) AS B_COUNT;
"""
Then the result should be:
| A_COUNT | B_COUNT |
| 3 | 15 |