Fix distinct, now doesn't impacts other aggregates (#1235)
Before a distinct on one aggregate would impact distinct on another aggregate. Fixed the logical error and at the same time did some memory optimisations.
This commit is contained in:
parent
1553fcb958
commit
eb4e2b019d
@ -3405,7 +3405,10 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::M
|
||||
class AggregateCursor : public Cursor {
|
||||
public:
|
||||
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
|
||||
: self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem) {}
|
||||
: self_(self),
|
||||
input_cursor_(self_.input_->MakeCursor(mem)),
|
||||
aggregation_(mem),
|
||||
reused_group_by_(self.group_by_.size(), mem) {}
|
||||
|
||||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||||
SCOPED_PROFILE_OP_BY_REF(self_);
|
||||
@ -3498,6 +3501,8 @@ class AggregateCursor : public Cursor {
|
||||
// custom equality
|
||||
TypedValueVectorEqual>
|
||||
aggregation_;
|
||||
// this is a for object reuse, to avoid re-allocating this buffer
|
||||
utils::pmr::vector<TypedValue> reused_group_by_;
|
||||
// iterator over the accumulated cache
|
||||
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
|
||||
// this LogicalOp pulls all from the input on it's first pull
|
||||
@ -3520,16 +3525,36 @@ class AggregateCursor : public Cursor {
|
||||
ProcessOne(*frame, &evaluator);
|
||||
}
|
||||
|
||||
// calculate AVG aggregations (so far they have only been summed)
|
||||
// post processing
|
||||
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
|
||||
if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
|
||||
for (auto &kv : aggregation_) {
|
||||
AggregationValue &agg_value = kv.second;
|
||||
auto count = agg_value.counts_[pos];
|
||||
auto *pull_memory = context->evaluation_context.memory;
|
||||
if (count > 0) {
|
||||
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
|
||||
switch (self_.aggregations_[pos].op) {
|
||||
case Aggregation::Op::AVG: {
|
||||
// calculate AVG aggregations (so far they have only been summed)
|
||||
for (auto &kv : aggregation_) {
|
||||
AggregationValue &agg_value = kv.second;
|
||||
auto count = agg_value.counts_[pos];
|
||||
auto *pull_memory = context->evaluation_context.memory;
|
||||
if (count > 0) {
|
||||
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Aggregation::Op::COUNT: {
|
||||
// Copy counts to be the value
|
||||
for (auto &kv : aggregation_) {
|
||||
AggregationValue &agg_value = kv.second;
|
||||
agg_value.values_[pos] = agg_value.counts_[pos];
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Aggregation::Op::MIN:
|
||||
case Aggregation::Op::MAX:
|
||||
case Aggregation::Op::SUM:
|
||||
case Aggregation::Op::COLLECT_LIST:
|
||||
case Aggregation::Op::COLLECT_MAP:
|
||||
case Aggregation::Op::PROJECT:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3538,14 +3563,16 @@ class AggregateCursor : public Cursor {
|
||||
* Performs a single accumulation.
|
||||
*/
|
||||
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
|
||||
auto *mem = aggregation_.get_allocator().GetMemoryResource();
|
||||
utils::pmr::vector<TypedValue> group_by(mem);
|
||||
group_by.reserve(self_.group_by_.size());
|
||||
// Preallocated group_by, since most of the time the aggregation key won't be unique
|
||||
reused_group_by_.clear();
|
||||
|
||||
for (Expression *expression : self_.group_by_) {
|
||||
group_by.emplace_back(expression->Accept(*evaluator));
|
||||
reused_group_by_.emplace_back(expression->Accept(*evaluator));
|
||||
}
|
||||
auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second;
|
||||
EnsureInitialized(frame, &agg_value);
|
||||
auto *mem = aggregation_.get_allocator().GetMemoryResource();
|
||||
auto res = aggregation_.try_emplace(reused_group_by_, mem);
|
||||
auto &agg_value = res.first->second;
|
||||
if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value);
|
||||
Update(evaluator, &agg_value);
|
||||
}
|
||||
|
||||
@ -3556,14 +3583,21 @@ class AggregateCursor : public Cursor {
|
||||
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
|
||||
if (!agg_value->values_.empty()) return;
|
||||
|
||||
const auto num_of_aggregations = self_.aggregations_.size();
|
||||
agg_value->values_.reserve(num_of_aggregations);
|
||||
agg_value->unique_values_.reserve(num_of_aggregations);
|
||||
|
||||
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
|
||||
for (const auto &agg_elem : self_.aggregations_) {
|
||||
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
|
||||
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
|
||||
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
|
||||
}
|
||||
agg_value->counts_.resize(self_.aggregations_.size(), 0);
|
||||
agg_value->counts_.resize(num_of_aggregations, 0);
|
||||
|
||||
for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]);
|
||||
agg_value->remember_.reserve(self_.remember_.size());
|
||||
for (const Symbol &remember_sym : self_.remember_) {
|
||||
agg_value->remember_.push_back(frame[remember_sym]);
|
||||
}
|
||||
}
|
||||
|
||||
/** Updates the given AggregationValue with new data. Assumes that
|
||||
@ -3580,13 +3614,14 @@ class AggregateCursor : public Cursor {
|
||||
auto value_it = agg_value->values_.begin();
|
||||
auto unique_values_it = agg_value->unique_values_.begin();
|
||||
auto agg_elem_it = self_.aggregations_.begin();
|
||||
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) {
|
||||
const auto counts_end = agg_value->counts_.end();
|
||||
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
|
||||
// COUNT(*) is the only case where input expression is optional
|
||||
// handle it here
|
||||
auto input_expr_ptr = agg_elem_it->value;
|
||||
if (!input_expr_ptr) {
|
||||
*count_it += 1;
|
||||
*value_it = *count_it;
|
||||
// value is deferred to post-processing
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -3598,7 +3633,7 @@ class AggregateCursor : public Cursor {
|
||||
if (agg_elem_it->distinct) {
|
||||
auto insert_result = unique_values_it->insert(input_value);
|
||||
if (!insert_result.second) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
*count_it += 1;
|
||||
@ -3607,19 +3642,19 @@ class AggregateCursor : public Cursor {
|
||||
switch (agg_op) {
|
||||
case Aggregation::Op::MIN:
|
||||
case Aggregation::Op::MAX:
|
||||
*value_it = input_value;
|
||||
EnsureOkForMinMax(input_value);
|
||||
*value_it = std::move(input_value);
|
||||
break;
|
||||
case Aggregation::Op::SUM:
|
||||
case Aggregation::Op::AVG:
|
||||
*value_it = input_value;
|
||||
EnsureOkForAvgSum(input_value);
|
||||
*value_it = std::move(input_value);
|
||||
break;
|
||||
case Aggregation::Op::COUNT:
|
||||
*value_it = 1;
|
||||
// value is deferred to post-processing
|
||||
break;
|
||||
case Aggregation::Op::COLLECT_LIST:
|
||||
value_it->ValueList().push_back(input_value);
|
||||
value_it->ValueList().push_back(std::move(input_value));
|
||||
break;
|
||||
case Aggregation::Op::PROJECT: {
|
||||
EnsureOkForProject(input_value);
|
||||
@ -3629,7 +3664,7 @@ class AggregateCursor : public Cursor {
|
||||
case Aggregation::Op::COLLECT_MAP:
|
||||
auto key = agg_elem_it->key->Accept(*evaluator);
|
||||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||||
value_it->ValueMap().emplace(key.ValueString(), input_value);
|
||||
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
@ -3638,7 +3673,7 @@ class AggregateCursor : public Cursor {
|
||||
// aggregation of existing values
|
||||
switch (agg_op) {
|
||||
case Aggregation::Op::COUNT:
|
||||
*value_it = *count_it;
|
||||
// value is deferred to post-processing
|
||||
break;
|
||||
case Aggregation::Op::MIN: {
|
||||
EnsureOkForMinMax(input_value);
|
||||
@ -3647,7 +3682,7 @@ class AggregateCursor : public Cursor {
|
||||
// since we skip nulls we either have a valid comparison, or
|
||||
// an exception was just thrown above
|
||||
// safe to assume a bool TypedValue
|
||||
if (comparison_result.ValueBool()) *value_it = input_value;
|
||||
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||||
} catch (const TypedValueException &) {
|
||||
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
|
||||
}
|
||||
@ -3658,7 +3693,7 @@ class AggregateCursor : public Cursor {
|
||||
EnsureOkForMinMax(input_value);
|
||||
try {
|
||||
TypedValue comparison_result = input_value > *value_it;
|
||||
if (comparison_result.ValueBool()) *value_it = input_value;
|
||||
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||||
} catch (const TypedValueException &) {
|
||||
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
|
||||
}
|
||||
@ -3672,7 +3707,7 @@ class AggregateCursor : public Cursor {
|
||||
*value_it = *value_it + input_value;
|
||||
break;
|
||||
case Aggregation::Op::COLLECT_LIST:
|
||||
value_it->ValueList().push_back(input_value);
|
||||
value_it->ValueList().push_back(std::move(input_value));
|
||||
break;
|
||||
case Aggregation::Op::PROJECT: {
|
||||
EnsureOkForProject(input_value);
|
||||
@ -3682,7 +3717,7 @@ class AggregateCursor : public Cursor {
|
||||
case Aggregation::Op::COLLECT_MAP:
|
||||
auto key = agg_elem_it->key->Accept(*evaluator);
|
||||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||||
value_it->ValueMap().emplace(key.ValueString(), input_value);
|
||||
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||||
break;
|
||||
} // end switch over Aggregation::Op enum
|
||||
} // end loop over all aggregations
|
||||
|
@ -1135,3 +1135,26 @@ Feature: Functions
|
||||
| 0 |
|
||||
| 1 |
|
||||
| 2 |
|
||||
|
||||
Scenario: Aggregate distinct does not impact other aggregates:
|
||||
Given an empty graph
|
||||
And having executed:
|
||||
"""
|
||||
CREATE (:Node_A {id:1})
|
||||
CREATE (:Node_A {id:2})
|
||||
CREATE (:Node_A {id:3})
|
||||
CREATE (:Node_B {id:1})
|
||||
CREATE (:Node_B {id:2})
|
||||
CREATE (:Node_B {id:3})
|
||||
CREATE (:Node_B {id:4})
|
||||
CREATE (:Node_B {id:4})
|
||||
"""
|
||||
When executing query:
|
||||
"""
|
||||
MATCH (a:Node_A), (b:Node_B)
|
||||
RETURN COUNT(DISTINCT a.id) AS A_COUNT,
|
||||
COUNT(b.id) AS B_COUNT;
|
||||
"""
|
||||
Then the result should be:
|
||||
| A_COUNT | B_COUNT |
|
||||
| 3 | 15 |
|
||||
|
Loading…
Reference in New Issue
Block a user