From 9589dd97b676896a062e3b2dc9f566de5524db8e Mon Sep 17 00:00:00 2001
From: jeremy <jeremy.bailleux@memgraph.io>
Date: Fri, 30 Dec 2022 16:21:41 +0100
Subject: [PATCH] Impl and correct aggregate

---
 src/query/v2/multiframe.hpp    |   2 +-
 src/query/v2/plan/operator.cpp | 116 +++++++++++++++++++++++++++++++++
 2 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/src/query/v2/multiframe.hpp b/src/query/v2/multiframe.hpp
index aeacda7d3..0365b449f 100644
--- a/src/query/v2/multiframe.hpp
+++ b/src/query/v2/multiframe.hpp
@@ -168,7 +168,7 @@ class ValidFramesModifier {
     Iterator &operator++() {
       do {
         ptr_++;
-      } while (*this != iterator_wrapper_->end() && ptr_->IsValid());
+      } while (*this != iterator_wrapper_->end() && !ptr_->IsValid());
 
       return *this;
     }
diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp
index b6bbf9ce1..43376aaed 100644
--- a/src/query/v2/plan/operator.cpp
+++ b/src/query/v2/plan/operator.cpp
@@ -1234,6 +1234,55 @@ class AggregateCursor : public Cursor {
     return true;
   }
 
+  void PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override {
+    SCOPED_PROFILE_OP("AggregateMF");
+
+    if (!pulled_all_input_) {
+      while (!pulled_all_input_) {
+        ProcessAll(multi_frame, &context);
+      }
+      multi_frame.MakeAllFramesInvalid();
+      aggregation_it_ = aggregation_.begin();
+
+      // in case there is no input and no group_bys we need to return true
+      // just this once
+      if (aggregation_.empty() && self_.group_by_.empty()) {
+        auto frame = multi_frame.GetFirstFrame();
+        frame.MakeValid();
+        auto *pull_memory = context.evaluation_context.memory;
+        // place default aggregation values on the frame
+        for (const auto &elem : self_.aggregations_) {
+          frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory);
+        }
+        // place null as remember values on the frame
+        for (const Symbol &remember_sym : self_.remember_) {
+          frame[remember_sym] = TypedValue(pull_memory);
+        }
+        return;
+      }
+    }
+
+    if (aggregation_it_ == aggregation_.end()) {
+      return;
+    }
+
+    // place aggregation values on the frame
+    auto &frame = multi_frame.GetFirstFrame();
+    frame.MakeValid();
+    auto aggregation_values_it = aggregation_it_->second.values_.begin();
+    for (const auto &aggregation_elem : self_.aggregations_) {
+      frame[aggregation_elem.output_sym] = *aggregation_values_it++;
+    }
+
+    // place remember values on the frame
+    auto remember_values_it = aggregation_it_->second.remember_.begin();
+    for (const Symbol &remember_sym : self_.remember_) {
+      frame[remember_sym] = *remember_values_it++;
+    }
+
+    aggregation_it_++;
+  }
+
   void Shutdown() override { input_cursor_->Shutdown(); }
 
   void Reset() override {
@@ -1312,6 +1361,36 @@ class AggregateCursor : public Cursor {
     }
   }
 
+  void ProcessAll(MultiFrame &multi_frame, ExecutionContext *context) {
+    input_cursor_->PullMultiple(multi_frame, *context);
+    auto valid_frames_modifier =
+        multi_frame.GetValidFramesConsumer();  // consumer is needed i.o. reader because of the evaluator
+    if (valid_frames_modifier.begin() == valid_frames_modifier.end()) {
+      // There are no valid frames, we stop
+      pulled_all_input_ = true;
+      return;
+    }
+
+    for (auto &frame : valid_frames_modifier) {
+      ExpressionEvaluator evaluator(&frame, context->symbol_table, context->evaluation_context, context->request_router,
+                                    storage::v3::View::NEW);
+      ProcessOne(frame, &evaluator);
+    }
+
+    // calculate AVG aggregations (so far they have only been summed)
+    for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
+      if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
+      for (auto &kv : aggregation_) {
+        AggregationValue &agg_value = kv.second;
+        auto count = agg_value.counts_[pos];
+        auto *pull_memory = context->evaluation_context.memory;
+        if (count > 0) {
+          agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
+        }
+      }
+    }
+  }
+
   /**
    * Performs a single accumulation.
    */
@@ -1327,6 +1406,21 @@ class AggregateCursor : public Cursor {
     Update(evaluator, &agg_value);
   }
 
+  /**
+   * Performs a single accumulation.
+   */
+  void ProcessOne(FrameWithValidity &frame, ExpressionEvaluator *evaluator) {
+    auto *mem = aggregation_.get_allocator().GetMemoryResource();
+    utils::pmr::vector<TypedValue> group_by(mem);
+    group_by.reserve(self_.group_by_.size());
+    for (Expression *expression : self_.group_by_) {
+      group_by.emplace_back(expression->Accept(*evaluator));
+    }
+    auto &agg_value = aggregation_.try_emplace(std::move(group_by), mem).first->second;
+    EnsureInitialized(frame, &agg_value);
+    Update(evaluator, &agg_value);
+  }
+
   /** Ensures the new AggregationValue has been initialized. This means
    * that the value vectors are filled with an appropriate number of Nulls,
    * counts are set to 0 and remember values are remembered.
@@ -1343,6 +1437,28 @@ class AggregateCursor : public Cursor {
     for (const Symbol &remember_sym : self_.remember_) agg_value->remember_.push_back(frame[remember_sym]);
   }
 
+  /** Ensures the new AggregationValue has been initialized. This means
+   * that the value vectors are filled with an appropriate number of Nulls,
+   * counts are set to 0 and remember values are remembered.
+   */
+  void EnsureInitialized(FrameWithValidity &frame, AggregateCursor::AggregationValue *agg_value) const {
+    if (!agg_value->values_.empty()) {
+      frame.MakeInvalid();
+      return;
+    }
+
+    for (const auto &agg_elem : self_.aggregations_) {
+      auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
+      agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
+    }
+    agg_value->counts_.resize(self_.aggregations_.size(), 0);
+
+    for (const Symbol &remember_sym : self_.remember_) {
+      agg_value->remember_.push_back(frame[remember_sym]);
+    }
+    frame.MakeInvalid();
+  }
+
   /** Updates the given AggregationValue with new data. Assumes that
    * the AggregationValue has been initialized */
   void Update(ExpressionEvaluator *evaluator, AggregateCursor::AggregationValue *agg_value) {