diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 792d278e8..238638737 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -3463,7 +3463,7 @@ class AggregateCursor : public Cursor { SCOPED_PROFILE_OP_BY_REF(self_); if (!pulled_all_input_) { - ProcessAll(&frame, &context); + if (!ProcessAll(&frame, &context) && self_.AreAllAggregationsForCollecting()) return false; pulled_all_input_ = true; aggregation_it_ = aggregation_.begin(); @@ -3487,7 +3487,6 @@ class AggregateCursor : public Cursor { return true; } } - if (aggregation_it_ == aggregation_.end()) return false; // place aggregation values on the frame @@ -3567,12 +3566,16 @@ class AggregateCursor : public Cursor { * cache cardinality depends on number of * aggregation results, and not on the number of inputs. */ - void ProcessAll(Frame *frame, ExecutionContext *context) { + bool ProcessAll(Frame *frame, ExecutionContext *context) { ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor, storage::View::NEW); + + bool pulled = false; while (input_cursor_->Pull(*frame, *context)) { ProcessOne(*frame, &evaluator); + pulled = true; } + if (!pulled) return false; // post processing for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) { @@ -3606,6 +3609,7 @@ class AggregateCursor : public Cursor { break; } } + return true; } /** @@ -3819,6 +3823,12 @@ UniqueCursorPtr Aggregate::MakeCursor(utils::MemoryResource *mem) const { return MakeUniqueCursorPtr<AggregateCursor>(mem, *this, mem); } +auto Aggregate::AreAllAggregationsForCollecting() const -> bool { + return std::all_of(aggregations_.begin(), aggregations_.end(), [](const auto &agg) { + return agg.op == Aggregation::Op::COLLECT_LIST || agg.op == Aggregation::Op::COLLECT_MAP; + }); +} + Skip::Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression) : input_(input), expression_(expression) {} diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 4951b5137..ba844796a 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -1758,6 +1758,9 @@ class Aggregate : public memgraph::query::plan::LogicalOperator { Aggregate() = default; Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Element> &aggregations, const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember); + + auto AreAllAggregationsForCollecting() const -> bool; + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; diff --git a/tests/e2e/replication/show_while_creating_invalid_state.py b/tests/e2e/replication/show_while_creating_invalid_state.py index f8fae4cd6..996955dc1 100644 --- a/tests/e2e/replication/show_while_creating_invalid_state.py +++ b/tests/e2e/replication/show_while_creating_invalid_state.py @@ -697,7 +697,7 @@ def test_sync_replication_when_main_is_killed(): ) # 2/ - QUERY_TO_CHECK = "MATCH (n) RETURN COLLECT(n.name);" + QUERY_TO_CHECK = "MATCH (n) RETURN COUNT(n.name);" last_result_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK)[0][0] for index in range(50): interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(f"CREATE (p:Number {{name:{index}}})") diff --git a/tests/gql_behave/tests/memgraph_V1/features/aggregations.feature b/tests/gql_behave/tests/memgraph_V1/features/aggregations.feature index 8fe6a47ad..cff138432 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/aggregations.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/aggregations.feature @@ -401,4 +401,30 @@ Feature: Aggregations MATCH p=()-[:Z]->() WITH project(p) as graph WITH graph.edges as edges UNWIND edges as e RETURN e.prop as y ORDER BY y DESC """ Then the result should be: - | y | + | y | + + Scenario: Empty collect aggregation: + Given an empty graph + And having executed + """ + CREATE (s:Subnet {ip: "192.168.0.1"}) + """ + When executing query: + """ + MATCH (subnet:Subnet) WHERE FALSE WITH subnet, collect(subnet.ip) as ips RETURN id(subnet) as id + """ + Then the result should be empty + + Scenario: Empty count aggregation: + Given an empty graph + And having executed + """ + CREATE (s:Subnet {ip: "192.168.0.1"}) + """ + When executing query: + """ + MATCH (subnet:Subnet) WHERE FALSE WITH subnet, count(subnet.ip) as ips RETURN id(subnet) as id + """ + Then the result should be: + | id | + | null | diff --git a/tests/gql_behave/tests/memgraph_V1_on_disk/features/aggregations.feature b/tests/gql_behave/tests/memgraph_V1_on_disk/features/aggregations.feature index 8fe6a47ad..cff138432 100644 --- a/tests/gql_behave/tests/memgraph_V1_on_disk/features/aggregations.feature +++ b/tests/gql_behave/tests/memgraph_V1_on_disk/features/aggregations.feature @@ -401,4 +401,30 @@ Feature: Aggregations MATCH p=()-[:Z]->() WITH project(p) as graph WITH graph.edges as edges UNWIND edges as e RETURN e.prop as y ORDER BY y DESC """ Then the result should be: - | y | + | y | + + Scenario: Empty collect aggregation: + Given an empty graph + And having executed + """ + CREATE (s:Subnet {ip: "192.168.0.1"}) + """ + When executing query: + """ + MATCH (subnet:Subnet) WHERE FALSE WITH subnet, collect(subnet.ip) as ips RETURN id(subnet) as id + """ + Then the result should be empty + + Scenario: Empty count aggregation: + Given an empty graph + And having executed + """ + CREATE (s:Subnet {ip: "192.168.0.1"}) + """ + When executing query: + """ + MATCH (subnet:Subnet) WHERE FALSE WITH subnet, count(subnet.ip) as ips RETURN id(subnet) as id + """ + Then the result should be: + | id | + | null | diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index e271e0f6a..bbf3e0311 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -277,13 +277,11 @@ TYPED_TEST(QueryPlanAggregateOps, WithoutDataWithGroupBy) { } { auto results = this->AggregationResults(true, false, {Aggregation::Op::COLLECT_LIST}); - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0][0].type(), TypedValue::Type::List); + EXPECT_EQ(results.size(), 0); } { auto results = this->AggregationResults(true, false, {Aggregation::Op::COLLECT_MAP}); - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map); + EXPECT_EQ(results.size(), 0); } } @@ -695,13 +693,11 @@ TYPED_TEST(QueryPlanAggregateOps, WithoutDataWithDistinctAndWithGroupBy) { } { auto results = this->AggregationResults(true, true, {Aggregation::Op::COLLECT_LIST}); - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0][0].type(), TypedValue::Type::List); + EXPECT_EQ(results.size(), 0); } { auto results = this->AggregationResults(true, true, {Aggregation::Op::COLLECT_MAP}); - EXPECT_EQ(results.size(), 1); - EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map); + EXPECT_EQ(results.size(), 0); } }