From 9589dd97b676896a062e3b2dc9f566de5524db8e Mon Sep 17 00:00:00 2001
From: jeremy <jeremy.bailleux@memgraph.io>
Date: Fri, 30 Dec 2022 16:21:41 +0100
Subject: [PATCH] Impl and correct aggregate

---
 src/query/v2/multiframe.hpp    |   2 +-
 src/query/v2/plan/operator.cpp | 137 +++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/src/query/v2/multiframe.hpp b/src/query/v2/multiframe.hpp
index aeacda7d3..0365b449f 100644
--- a/src/query/v2/multiframe.hpp
+++ b/src/query/v2/multiframe.hpp
@@ -168,7 +168,7 @@ class ValidFramesModifier {
     Iterator &operator++() {
       do {
         ptr_++;
-      } while (*this != iterator_wrapper_->end() && ptr_->IsValid());
+      } while (*this != iterator_wrapper_->end() && !ptr_->IsValid());
       return *this;
     }
 
diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp
index b6bbf9ce1..43376aaed 100644
--- a/src/query/v2/plan/operator.cpp
+++ b/src/query/v2/plan/operator.cpp
@@ -1234,6 +1234,63 @@ class AggregateCursor : public Cursor {
     return true;
   }
 
+  /**
+   * Batched pull. As long as `pulled_all_input_` is unset, the whole input is
+   * first consumed and aggregated. Every call then places at most one
+   * aggregation result on the first frame of `multi_frame`; when nothing is
+   * left to emit, no frame is made valid by this function.
+   */
+  void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override {
+    SCOPED_PROFILE_OP("AggregateMF");
+
+    if (!pulled_all_input_) {
+      // ProcessAll sets pulled_all_input_ once the input yields no valid frames
+      while (!pulled_all_input_) {
+        ProcessAll(multi_frame, &context);
+      }
+      multi_frame.MakeAllFramesInvalid();
+      aggregation_it_ = aggregation_.begin();
+
+      // in case there is no input and no group_bys we need to emit one row of
+      // default values, just this once
+      if (aggregation_.empty() && self_.group_by_.empty()) {
+        // bind by reference, otherwise only a copy of the frame gets filled
+        auto &frame = multi_frame.GetFirstFrame();
+        frame.MakeValid();
+        auto *pull_memory = context.evaluation_context.memory;
+        // place default aggregation values on the frame
+        for (const auto &elem : self_.aggregations_) {
+          frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory);
+        }
+        // place null as remember values on the frame
+        for (const Symbol &remember_sym : self_.remember_) {
+          frame[remember_sym] = TypedValue(pull_memory);
+        }
+        return;
+      }
+    }
+
+    if (aggregation_it_ == aggregation_.end()) {
+      return;
+    }
+
+    // place aggregation values on the frame
+    auto &frame = multi_frame.GetFirstFrame();
+    frame.MakeValid();
+    auto aggregation_values_it = aggregation_it_->second.values_.begin();
+    for (const auto &aggregation_elem : self_.aggregations_) {
+      frame[aggregation_elem.output_sym] = *aggregation_values_it++;
+    }
+
+    // place remember values on the frame
+    auto remember_values_it = aggregation_it_->second.remember_.begin();
+    for (const Symbol &remember_sym : self_.remember_) {
+      frame[remember_sym] = *remember_values_it++;
+    }
+
+    ++aggregation_it_;
+  }
+
   void Shutdown() override { input_cursor_->Shutdown(); }
 
   void Reset() override {
@@ -1312,6 +1369,50 @@ class AggregateCursor : public Cursor {
     }
   }
 
+  /**
+   * Pulls one batch from the input and accumulates each of its valid frames.
+   * When the batch contains no valid frame the input is considered exhausted:
+   * the AVG aggregations are finalized and `pulled_all_input_` is set.
+   */
+  void ProcessAll(MultiFrame &multi_frame, ExecutionContext *context) {
+    input_cursor_->PullMultiple(multi_frame, *context);
+    // a consumer is needed instead of a reader because of the evaluator
+    auto valid_frames_consumer = multi_frame.GetValidFramesConsumer();
+    if (valid_frames_consumer.begin() == valid_frames_consumer.end()) {
+      // There are no valid frames, we stop. The averages are computed only
+      // here: doing it after every batch would divide the values accumulated
+      // from earlier batches again.
+      FinalizeAverages(context);
+      pulled_all_input_ = true;
+      return;
+    }
+
+    for (auto &frame : valid_frames_consumer) {
+      ExpressionEvaluator evaluator(&frame, context->symbol_table, context->evaluation_context, context->request_router,
+                                    storage::v3::View::NEW);
+      ProcessOne(frame, &evaluator);
+    }
+  }
+
+  /**
+   * Calculates AVG aggregations (so far they have only been summed) by
+   * dividing every accumulated value by its count. Must run exactly once,
+   * after the last input frame has been processed.
+   */
+  void FinalizeAverages(ExecutionContext *context) {
+    auto *pull_memory = context->evaluation_context.memory;
+    for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
+      if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
+      for (auto &kv : aggregation_) {
+        AggregationValue &agg_value = kv.second;
+        auto count = agg_value.counts_[pos];
+        if (count > 0) {
+          agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
+        }
+      }
+    }
+  }
+
   /**
    * Performs a single accumulation.
    */
@@ -1327,6 +1428,22 @@ class AggregateCursor : public Cursor {
     Update(evaluator, &agg_value);
   }
 
+  /**
+   * Performs a single accumulation for a frame of a multiframe. As a side
+   * effect the frame is marked invalid by EnsureInitialized.
+   */
+  void ProcessOne(FrameWithValidity &frame, ExpressionEvaluator *evaluator) {
+    auto *mem = aggregation_.get_allocator().GetMemoryResource();
+    utils::pmr::vector<TypedValue> group_by(mem);
+    group_by.reserve(self_.group_by_.size());
+    for (Expression *expression : self_.group_by_) {
+      group_by.emplace_back(expression->Accept(*evaluator));
+    }
+    auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second;
+    EnsureInitialized(frame, &agg_value);
+    Update(evaluator, &agg_value);
+  }
+
   /** Ensures the new AggregationValue has been initialized. This means
    * that the value vectors are filled with an appropriate number of Nulls,
    * counts are set to 0 and remember values are remembered.
@@ -1343,6 +1460,26 @@ class AggregateCursor : public Cursor {
     for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]);
   }
 
+  /** Ensures the new AggregationValue has been initialized. This means
+   * that the value vector is filled with the default value of each
+   * aggregation, counts are set to 0 and remember values are remembered.
+   * In every case the frame is marked invalid before returning.
+   */
+  void EnsureInitialized(FrameWithValidity &frame, AggregateCursor::AggregationValue *agg_value) const {
+    if (agg_value->values_.empty()) {
+      auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
+      for (const auto &agg_elem : self_.aggregations_) {
+        agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
+      }
+      agg_value->counts_.resize(self_.aggregations_.size(), 0);
+
+      for (const Symbol &remember_sym : self_.remember_) {
+        agg_value->remember_.push_back(frame[remember_sym]);
+      }
+    }
+    frame.MakeInvalid();
+  }
+
   /** Updates the given AggregationValue with new data. Assumes that
    * the AggregationValue has been initialized */
   void Update(ExpressionEvaluator *evaluator, AggregateCursor::AggregationValue *agg_value) {