Merge pull request #729 from memgraph/T1216-MG-implement-aggregate

Implement Aggregate with `MultiFrame`
This commit is contained in:
János Benjamin Antal 2023-01-30 15:56:30 +01:00 committed by GitHub
commit 7fa7586940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 15 deletions

View File

@ -202,7 +202,7 @@ class DistributedCreateNodeCursor : public Cursor {
request_router->CreateVertices(NodeCreationInfoToRequests(context, multi_frame));
}
PlaceNodesOnTheMultiFrame(multi_frame, context);
return false;
return true;
}
void Shutdown() override { input_cursor_->Shutdown(); }
@ -1320,6 +1320,55 @@ class AggregateCursor : public Cursor {
auto remember_values_it = aggregation_it_->second.remember_.begin();
for (const Symbol &remember_sym : self_.remember_) frame[remember_sym] = *remember_values_it++;
++aggregation_it_;
return true;
}
/**
 * MultiFrame variant of Pull.
 *
 * On the first call the whole input is consumed through ProcessAll and the
 * aggregation results are computed. Every call (including the first one)
 * then places at most one aggregation result on the first frame of
 * `multi_frame` and marks that frame as valid.
 *
 * @param multi_frame frames used both for pulling the input and for
 *        publishing the produced row.
 * @param context execution context; its evaluation memory is used for the
 *        default values produced when there is no input.
 * @return true if a row was placed on the first frame, false once all
 *         aggregation results have been returned.
 */
bool PullMultiple(MultiFrame &multi_frame, ExecutionContext &context) override {
  SCOPED_PROFILE_OP("AggregateMF");

  if (!pulled_all_input_) {
    ProcessAll(multi_frame, &context);
    pulled_all_input_ = true;
    MG_ASSERT(!multi_frame.HasValidFrame(), "ProcessAll didn't consumed all input frames!");
    aggregation_it_ = aggregation_.begin();

    // in case there is no input and no group_bys we need to return true
    // just this once
    if (aggregation_.empty() && self_.group_by_.empty()) {
      // Bind by reference: a by-value `auto` would copy the frame, so the
      // validity flag and the values below would be written to a temporary
      // and never reach the caller's multi frame.
      auto &frame = multi_frame.GetFirstFrame();
      frame.MakeValid();
      auto *pull_memory = context.evaluation_context.memory;
      // place default aggregation values on the frame
      for (const auto &elem : self_.aggregations_) {
        frame[elem.output_sym] = DefaultAggregationOpValue(elem, pull_memory);
      }
      // place null as remember values on the frame
      for (const Symbol &remember_sym : self_.remember_) {
        frame[remember_sym] = TypedValue(pull_memory);
      }
      return true;
    }
  }

  // every aggregation result has already been returned
  if (aggregation_it_ == aggregation_.end()) {
    return false;
  }

  // place aggregation values on the frame
  auto &frame = multi_frame.GetFirstFrame();
  frame.MakeValid();
  auto aggregation_values_it = aggregation_it_->second.values_.begin();
  for (const auto &aggregation_elem : self_.aggregations_) {
    frame[aggregation_elem.output_sym] = *aggregation_values_it++;
  }

  // place remember values on the frame
  auto remember_values_it = aggregation_it_->second.remember_.begin();
  for (const Symbol &remember_sym : self_.remember_) {
    frame[remember_sym] = *remember_values_it++;
  }

  // advance to the result that the next call will return
  ++aggregation_it_;
  return true;
}
@ -1388,18 +1437,22 @@ class AggregateCursor : public Cursor {
ProcessOne(*frame, &evaluator);
}
// calculate AVG aggregations (so far they have only been summed)
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
if (self_.aggregations_[pos].op != Aggregation::Op::AVG) continue;
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
auto count = agg_value.counts_[pos];
auto *pull_memory = context->evaluation_context.memory;
if (count > 0) {
agg_value.values_[pos] = agg_value.values_[pos] / TypedValue(static_cast<double>(count), pull_memory);
}
CalculateAverages(*context);
}
/**
 * Pulls from the input cursor until it reports no more data, feeding every
 * valid frame of each pulled batch into ProcessOne. Once the input is
 * exhausted, CalculateAverages is called.
 *
 * @param multi_frame frames that the input cursor fills on each pull.
 * @param context execution context handed to the input cursor and used to
 *        build the per-frame expression evaluator.
 */
void ProcessAll(MultiFrame &multi_frame, ExecutionContext *context) {
  while (input_cursor_->PullMultiple(multi_frame, *context)) {
    // a consumer is needed instead of a reader because of the evaluator
    auto frames_consumer = multi_frame.GetValidFramesConsumer();
    for (auto &valid_frame : frames_consumer) {
      ExpressionEvaluator frame_evaluator(&valid_frame, context->symbol_table, context->evaluation_context,
                                          context->request_router, storage::v3::View::NEW);
      ProcessOne(valid_frame, &frame_evaluator);
    }
  }

  // all input has been processed, the averages can be calculated now
  CalculateAverages(*context);
}
/**
@ -1417,6 +1470,20 @@ class AggregateCursor : public Cursor {
Update(evaluator, &agg_value);
}
/**
 * Divides the stored value of every AVG aggregation by its count, for each
 * entry in `aggregation_`. Entries whose count is not positive are left
 * unchanged.
 *
 * @param context execution context; its evaluation memory is used for the
 *        temporary divisor value.
 */
void CalculateAverages(ExecutionContext &context) {
  auto *pull_memory = context.evaluation_context.memory;
  for (size_t agg_idx = 0; agg_idx < self_.aggregations_.size(); ++agg_idx) {
    // only AVG aggregations need this post-processing step
    if (self_.aggregations_[agg_idx].op == Aggregation::Op::AVG) {
      for (auto &group : aggregation_) {
        AggregationValue &group_value = group.second;
        auto sample_count = group_value.counts_[agg_idx];
        if (!(sample_count > 0)) continue;
        group_value.values_[agg_idx] =
            group_value.values_[agg_idx] / TypedValue(static_cast<double>(sample_count), pull_memory);
      }
    }
  }
}
/** Ensures the new AggregationValue has been initialized. This means
* that the value vectors are filled with an appropriate number of Nulls,
* counts are set to 0 and remember values are remembered.
@ -1450,7 +1517,7 @@ class AggregateCursor : public Cursor {
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, agg_elem_it++) {
// COUNT(*) is the only case where input expression is optional
// handle it here
auto input_expr_ptr = agg_elem_it->value;
auto *input_expr_ptr = agg_elem_it->value;
if (!input_expr_ptr) {
*count_it += 1;
*value_it = *count_it;
@ -1541,7 +1608,7 @@ class AggregateCursor : public Cursor {
/** Checks if the given TypedValue is legal in MIN and MAX. If not
* an appropriate exception is thrown. */
void EnsureOkForMinMax(const TypedValue &value) const {
static void EnsureOkForMinMax(const TypedValue &value) {
switch (value.type()) {
case TypedValue::Type::Bool:
case TypedValue::Type::Int:
@ -1557,7 +1624,7 @@ class AggregateCursor : public Cursor {
/** Checks if the given TypedValue is legal in AVG and SUM. If not
* an appropriate exception is thrown. */
void EnsureOkForAvgSum(const TypedValue &value) const {
static void EnsureOkForAvgSum(const TypedValue &value) {
switch (value.type()) {
case TypedValue::Type::Int:
case TypedValue::Type::Double:

View File

@ -86,7 +86,7 @@ class TestPlanner : public ::testing::Test {};
using PlannerTypes = ::testing::Types<Planner>;
TYPED_TEST_CASE(TestPlanner, PlannerTypes);
TYPED_TEST_SUITE(TestPlanner, PlannerTypes);
TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) {
const char *prim_label_name = "prim_label_one";