Fix distinct, now doesn't impacts other aggregates (#1235)
Before a distinct on one aggregate would impact distinct on another aggregate. Fixed the logical error and at the same time did some memory optimisations.
This commit is contained in:
parent
1553fcb958
commit
eb4e2b019d
@ -3405,7 +3405,10 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::M
|
|||||||
class AggregateCursor : public Cursor {
|
class AggregateCursor : public Cursor {
|
||||||
public:
|
public:
|
||||||
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
|
AggregateCursor(const Aggregate &self, utils::MemoryResource *mem)
|
||||||
: self_(self), input_cursor_(self_.input_->MakeCursor(mem)), aggregation_(mem) {}
|
: self_(self),
|
||||||
|
input_cursor_(self_.input_->MakeCursor(mem)),
|
||||||
|
aggregation_(mem),
|
||||||
|
reused_group_by_(self.group_by_.size(), mem) {}
|
||||||
|
|
||||||
bool Pull(Frame &frame, ExecutionContext &context) override {
|
bool Pull(Frame &frame, ExecutionContext &context) override {
|
||||||
SCOPED_PROFILE_OP_BY_REF(self_);
|
SCOPED_PROFILE_OP_BY_REF(self_);
|
||||||
@ -3498,6 +3501,8 @@ class AggregateCursor : public Cursor {
|
|||||||
// custom equality
|
// custom equality
|
||||||
TypedValueVectorEqual>
|
TypedValueVectorEqual>
|
||||||
aggregation_;
|
aggregation_;
|
||||||
|
// this is a for object reuse, to avoid re-allocating this buffer
|
||||||
|
utils::pmr::vector<TypedValue> reused_group_by_;
|
||||||
// iterator over the accumulated cache
|
// iterator over the accumulated cache
|
||||||
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
|
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
|
||||||
// this LogicalOp pulls all from the input on it's first pull
|
// this LogicalOp pulls all from the input on it's first pull
|
||||||
@ -3520,16 +3525,36 @@ class AggregateCursor : public Cursor {
|
|||||||
ProcessOne(*frame, &evaluator);
|
ProcessOne(*frame, &evaluator);
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate AVG aggregations (so far they have only been summed)
|
// post processing
|
||||||
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
|
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
|
||||||
if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
|
switch (self_.aggregations_[pos].op) {
|
||||||
for (auto &kv : aggregation_) {
|
case Aggregation::Op::AVG: {
|
||||||
AggregationValue &agg_value = kv.second;
|
// calculate AVG aggregations (so far they have only been summed)
|
||||||
auto count = agg_value.counts_[pos];
|
for (auto &kv : aggregation_) {
|
||||||
auto *pull_memory = context->evaluation_context.memory;
|
AggregationValue &agg_value = kv.second;
|
||||||
if (count > 0) {
|
auto count = agg_value.counts_[pos];
|
||||||
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
|
auto *pull_memory = context->evaluation_context.memory;
|
||||||
|
if (count > 0) {
|
||||||
|
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
case Aggregation::Op::COUNT: {
|
||||||
|
// Copy counts to be the value
|
||||||
|
for (auto &kv : aggregation_) {
|
||||||
|
AggregationValue &agg_value = kv.second;
|
||||||
|
agg_value.values_[pos] = agg_value.counts_[pos];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case Aggregation::Op::MIN:
|
||||||
|
case Aggregation::Op::MAX:
|
||||||
|
case Aggregation::Op::SUM:
|
||||||
|
case Aggregation::Op::COLLECT_LIST:
|
||||||
|
case Aggregation::Op::COLLECT_MAP:
|
||||||
|
case Aggregation::Op::PROJECT:
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3538,14 +3563,16 @@ class AggregateCursor : public Cursor {
|
|||||||
* Performs a single accumulation.
|
* Performs a single accumulation.
|
||||||
*/
|
*/
|
||||||
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
|
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
|
||||||
auto *mem = aggregation_.get_allocator().GetMemoryResource();
|
// Preallocated group_by, since most of the time the aggregation key won't be unique
|
||||||
utils::pmr::vector<TypedValue> group_by(mem);
|
reused_group_by_.clear();
|
||||||
group_by.reserve(self_.group_by_.size());
|
|
||||||
for (Expression *expression : self_.group_by_) {
|
for (Expression *expression : self_.group_by_) {
|
||||||
group_by.emplace_back(expression->Accept(*evaluator));
|
reused_group_by_.emplace_back(expression->Accept(*evaluator));
|
||||||
}
|
}
|
||||||
auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second;
|
auto *mem = aggregation_.get_allocator().GetMemoryResource();
|
||||||
EnsureInitialized(frame, &agg_value);
|
auto res = aggregation_.try_emplace(reused_group_by_, mem);
|
||||||
|
auto &agg_value = res.first->second;
|
||||||
|
if (res.second /*was newly inserted*/) EnsureInitialized(frame, &agg_value);
|
||||||
Update(evaluator, &agg_value);
|
Update(evaluator, &agg_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3556,14 +3583,21 @@ class AggregateCursor : public Cursor {
|
|||||||
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
|
void EnsureInitialized(const Frame &frame, AggregateCursor::AggregationValue *agg_value) const {
|
||||||
if (!agg_value->values_.empty()) return;
|
if (!agg_value->values_.empty()) return;
|
||||||
|
|
||||||
|
const auto num_of_aggregations = self_.aggregations_.size();
|
||||||
|
agg_value->values_.reserve(num_of_aggregations);
|
||||||
|
agg_value->unique_values_.reserve(num_of_aggregations);
|
||||||
|
|
||||||
|
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
|
||||||
for (const auto &agg_elem : self_.aggregations_) {
|
for (const auto &agg_elem : self_.aggregations_) {
|
||||||
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
|
|
||||||
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
|
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
|
||||||
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
|
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
|
||||||
}
|
}
|
||||||
agg_value->counts_.resize(self_.aggregations_.size(), 0);
|
agg_value->counts_.resize(num_of_aggregations, 0);
|
||||||
|
|
||||||
for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]);
|
agg_value->remember_.reserve(self_.remember_.size());
|
||||||
|
for (const Symbol &remember_sym : self_.remember_) {
|
||||||
|
agg_value->remember_.push_back(frame[remember_sym]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the given AggregationValue with new data. Assumes that
|
/** Updates the given AggregationValue with new data. Assumes that
|
||||||
@ -3580,13 +3614,14 @@ class AggregateCursor : public Cursor {
|
|||||||
auto value_it = agg_value->values_.begin();
|
auto value_it = agg_value->values_.begin();
|
||||||
auto unique_values_it = agg_value->unique_values_.begin();
|
auto unique_values_it = agg_value->unique_values_.begin();
|
||||||
auto agg_elem_it = self_.aggregations_.begin();
|
auto agg_elem_it = self_.aggregations_.begin();
|
||||||
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) {
|
const auto counts_end = agg_value->counts_.end();
|
||||||
|
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
|
||||||
// COUNT(*) is the only case where input expression is optional
|
// COUNT(*) is the only case where input expression is optional
|
||||||
// handle it here
|
// handle it here
|
||||||
auto input_expr_ptr = agg_elem_it->value;
|
auto input_expr_ptr = agg_elem_it->value;
|
||||||
if (!input_expr_ptr) {
|
if (!input_expr_ptr) {
|
||||||
*count_it += 1;
|
*count_it += 1;
|
||||||
*value_it = *count_it;
|
// value is deferred to post-processing
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3598,7 +3633,7 @@ class AggregateCursor : public Cursor {
|
|||||||
if (agg_elem_it->distinct) {
|
if (agg_elem_it->distinct) {
|
||||||
auto insert_result = unique_values_it->insert(input_value);
|
auto insert_result = unique_values_it->insert(input_value);
|
||||||
if (!insert_result.second) {
|
if (!insert_result.second) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*count_it += 1;
|
*count_it += 1;
|
||||||
@ -3607,19 +3642,19 @@ class AggregateCursor : public Cursor {
|
|||||||
switch (agg_op) {
|
switch (agg_op) {
|
||||||
case Aggregation::Op::MIN:
|
case Aggregation::Op::MIN:
|
||||||
case Aggregation::Op::MAX:
|
case Aggregation::Op::MAX:
|
||||||
*value_it = input_value;
|
|
||||||
EnsureOkForMinMax(input_value);
|
EnsureOkForMinMax(input_value);
|
||||||
|
*value_it = std::move(input_value);
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::SUM:
|
case Aggregation::Op::SUM:
|
||||||
case Aggregation::Op::AVG:
|
case Aggregation::Op::AVG:
|
||||||
*value_it = input_value;
|
|
||||||
EnsureOkForAvgSum(input_value);
|
EnsureOkForAvgSum(input_value);
|
||||||
|
*value_it = std::move(input_value);
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::COUNT:
|
case Aggregation::Op::COUNT:
|
||||||
*value_it = 1;
|
// value is deferred to post-processing
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::COLLECT_LIST:
|
case Aggregation::Op::COLLECT_LIST:
|
||||||
value_it->ValueList().push_back(input_value);
|
value_it->ValueList().push_back(std::move(input_value));
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::PROJECT: {
|
case Aggregation::Op::PROJECT: {
|
||||||
EnsureOkForProject(input_value);
|
EnsureOkForProject(input_value);
|
||||||
@ -3629,7 +3664,7 @@ class AggregateCursor : public Cursor {
|
|||||||
case Aggregation::Op::COLLECT_MAP:
|
case Aggregation::Op::COLLECT_MAP:
|
||||||
auto key = agg_elem_it->key->Accept(*evaluator);
|
auto key = agg_elem_it->key->Accept(*evaluator);
|
||||||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||||||
value_it->ValueMap().emplace(key.ValueString(), input_value);
|
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
@ -3638,7 +3673,7 @@ class AggregateCursor : public Cursor {
|
|||||||
// aggregation of existing values
|
// aggregation of existing values
|
||||||
switch (agg_op) {
|
switch (agg_op) {
|
||||||
case Aggregation::Op::COUNT:
|
case Aggregation::Op::COUNT:
|
||||||
*value_it = *count_it;
|
// value is deferred to post-processing
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::MIN: {
|
case Aggregation::Op::MIN: {
|
||||||
EnsureOkForMinMax(input_value);
|
EnsureOkForMinMax(input_value);
|
||||||
@ -3647,7 +3682,7 @@ class AggregateCursor : public Cursor {
|
|||||||
// since we skip nulls we either have a valid comparison, or
|
// since we skip nulls we either have a valid comparison, or
|
||||||
// an exception was just thrown above
|
// an exception was just thrown above
|
||||||
// safe to assume a bool TypedValue
|
// safe to assume a bool TypedValue
|
||||||
if (comparison_result.ValueBool()) *value_it = input_value;
|
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||||||
} catch (const TypedValueException &) {
|
} catch (const TypedValueException &) {
|
||||||
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
|
throw QueryRuntimeException("Unable to get MIN of '{}' and '{}'.", input_value.type(), value_it->type());
|
||||||
}
|
}
|
||||||
@ -3658,7 +3693,7 @@ class AggregateCursor : public Cursor {
|
|||||||
EnsureOkForMinMax(input_value);
|
EnsureOkForMinMax(input_value);
|
||||||
try {
|
try {
|
||||||
TypedValue comparison_result = input_value > *value_it;
|
TypedValue comparison_result = input_value > *value_it;
|
||||||
if (comparison_result.ValueBool()) *value_it = input_value;
|
if (comparison_result.ValueBool()) *value_it = std::move(input_value);
|
||||||
} catch (const TypedValueException &) {
|
} catch (const TypedValueException &) {
|
||||||
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
|
throw QueryRuntimeException("Unable to get MAX of '{}' and '{}'.", input_value.type(), value_it->type());
|
||||||
}
|
}
|
||||||
@ -3672,7 +3707,7 @@ class AggregateCursor : public Cursor {
|
|||||||
*value_it = *value_it + input_value;
|
*value_it = *value_it + input_value;
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::COLLECT_LIST:
|
case Aggregation::Op::COLLECT_LIST:
|
||||||
value_it->ValueList().push_back(input_value);
|
value_it->ValueList().push_back(std::move(input_value));
|
||||||
break;
|
break;
|
||||||
case Aggregation::Op::PROJECT: {
|
case Aggregation::Op::PROJECT: {
|
||||||
EnsureOkForProject(input_value);
|
EnsureOkForProject(input_value);
|
||||||
@ -3682,7 +3717,7 @@ class AggregateCursor : public Cursor {
|
|||||||
case Aggregation::Op::COLLECT_MAP:
|
case Aggregation::Op::COLLECT_MAP:
|
||||||
auto key = agg_elem_it->key->Accept(*evaluator);
|
auto key = agg_elem_it->key->Accept(*evaluator);
|
||||||
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
|
||||||
value_it->ValueMap().emplace(key.ValueString(), input_value);
|
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
|
||||||
break;
|
break;
|
||||||
} // end switch over Aggregation::Op enum
|
} // end switch over Aggregation::Op enum
|
||||||
} // end loop over all aggregations
|
} // end loop over all aggregations
|
||||||
|
@ -1135,3 +1135,26 @@ Feature: Functions
|
|||||||
| 0 |
|
| 0 |
|
||||||
| 1 |
|
| 1 |
|
||||||
| 2 |
|
| 2 |
|
||||||
|
|
||||||
|
Scenario: Aggregate distinct does not impact other aggregates:
|
||||||
|
Given an empty graph
|
||||||
|
And having executed:
|
||||||
|
"""
|
||||||
|
CREATE (:Node_A {id:1})
|
||||||
|
CREATE (:Node_A {id:2})
|
||||||
|
CREATE (:Node_A {id:3})
|
||||||
|
CREATE (:Node_B {id:1})
|
||||||
|
CREATE (:Node_B {id:2})
|
||||||
|
CREATE (:Node_B {id:3})
|
||||||
|
CREATE (:Node_B {id:4})
|
||||||
|
CREATE (:Node_B {id:4})
|
||||||
|
"""
|
||||||
|
When executing query:
|
||||||
|
"""
|
||||||
|
MATCH (a:Node_A), (b:Node_B)
|
||||||
|
RETURN COUNT(DISTINCT a.id) AS A_COUNT,
|
||||||
|
COUNT(b.id) AS B_COUNT;
|
||||||
|
"""
|
||||||
|
Then the result should be:
|
||||||
|
| A_COUNT | B_COUNT |
|
||||||
|
| 3 | 15 |
|
||||||
|
Loading…
Reference in New Issue
Block a user