diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 42d57cfe4..b5296cd2e 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -593,12 +593,13 @@ class Aggregation : public UnaryOperator { friend class AstTreeStorage; public: - enum class Op { COUNT, MIN, MAX, SUM, AVG }; + enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT }; static const constexpr char *const kCount = "COUNT"; static const constexpr char *const kMin = "MIN"; static const constexpr char *const kMax = "MAX"; static const constexpr char *const kSum = "SUM"; static const constexpr char *const kAvg = "AVG"; + static const constexpr char *const kCollect = "COLLECT"; DEFVISITABLE(TreeVisitor<TypedValue>); bool Accept(HierarchicalTreeVisitor &visitor) override { diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 00b801dc7..21aa3bf7b 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -788,6 +788,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation( return static_cast<Expression *>( storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG)); } + if (function_name == Aggregation::kCollect) { + return static_cast<Expression *>(storage_.Create<Aggregation>( + expressions[0], Aggregation::Op::COLLECT)); + } } auto function = NameToFunction(function_name); if (!function) throw SemanticException(); diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 6f8f60369..b8583ee59 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -938,10 +938,13 @@ void Aggregate::AggregateCursor::EnsureInitialized( if (agg_value.values_.size() > 0) return; for (const auto &agg_elem : self_.aggregations_) { - if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) + if (std::get<1>(agg_elem) == Aggregation::Op::COUNT) { agg_value.values_.emplace_back(TypedValue(0)); - else + } else if (std::get<1>(agg_elem) == Aggregation::Op::COLLECT) { + agg_value.values_.emplace_back(std::vector<TypedValue>()); + } else { agg_value.values_.emplace_back(TypedValue::Null); + } } agg_value.counts_.resize(self_.aggregations_.size(), 0); @@ -964,7 +967,6 @@ void Aggregate::AggregateCursor::Update( auto agg_elem_it = self_.aggregations_.begin(); for (; count_it < agg_value.counts_.end(); count_it++, value_it++, agg_elem_it++) { - // COUNT(*) is the only case where input expression is optional // handle it here auto input_expr_ptr = std::get<0>(*agg_elem_it); @@ -986,16 +988,21 @@ void Aggregate::AggregateCursor::Update( switch (agg_op) { case Aggregation::Op::MIN: case Aggregation::Op::MAX: + *value_it = input_value; EnsureOkForMinMax(input_value); break; case Aggregation::Op::SUM: case Aggregation::Op::AVG: + *value_it = input_value; EnsureOkForAvgSum(input_value); break; case Aggregation::Op::COUNT: + *value_it = 1; + break; + case Aggregation::Op::COLLECT: + value_it->Value<std::vector<TypedValue>>().push_back(input_value); break; } - *value_it = agg_op == Aggregation::Op::COUNT ? 1 : input_value; continue; } @@ -1029,6 +1036,9 @@ void Aggregate::AggregateCursor::Update( EnsureOkForAvgSum(input_value); *value_it = *value_it + input_value; break; + case Aggregation::Op::COLLECT: + value_it->Value<std::vector<TypedValue>>().push_back(input_value); + break; } // end switch over Aggregation::Op enum } // end loop over all aggregations } diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index f58332dbe..451995308 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -506,15 +506,15 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) { TEST(CypherMainVisitorTest, Aggregation) { AstGenerator ast_generator( - "RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)"); + "RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COLLECT(f), COUNT(*)"); auto *query = ast_generator.query_; auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]); - ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 7U); Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN, - Aggregation::Op::MAX, Aggregation::Op::SUM, - Aggregation::Op::AVG}; - std::string ids[] = {"a", "b", "c", "d", "e"}; - for (int i = 0; i < 5; ++i) { + Aggregation::Op::MAX, Aggregation::Op::SUM, + Aggregation::Op::AVG, Aggregation::Op::COLLECT}; + std::string ids[] = {"a", "b", "c", "d", "e", "f"}; + for (int i = 0; i < 6; ++i) { auto *aggregation = dynamic_cast<Aggregation *>( return_clause->body_.named_expressions[i]->expression_); ASSERT_TRUE(aggregation); @@ -524,7 +524,7 @@ TEST(CypherMainVisitorTest, Aggregation) { ASSERT_EQ(identifier->name_, ids[i]); } auto *aggregation = dynamic_cast<Aggregation *>( - return_clause->body_.named_expressions[5]->expression_); + return_clause->body_.named_expressions[6]->expression_); ASSERT_TRUE(aggregation); ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT); ASSERT_FALSE(aggregation->expression_); diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 7ed0a4749..54e3cd8ce 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -28,6 +28,14 @@ namespace query { namespace test_common { +auto ToInt64List(const TypedValue &t) { + std::vector<int64_t> list; + for (auto x : t.Value<std::vector<TypedValue>>()) { + list.push_back(x.Value<int64_t>()); + } + return list; +}; + // Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, // so that they can be used to resolve function calls. struct OrderBy { diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 85f291f88..909a56a8d 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -14,14 +14,15 @@ #include "query/interpret/awesome_memgraph_functions.hpp" #include "query/interpret/eval.hpp" #include "query/interpret/frame.hpp" +#include "query_common.hpp" using namespace query; using testing::Pair; using testing::UnorderedElementsAre; using testing::ElementsAre; +using query::test_common::ToInt64List; namespace { - struct NoContextExpressionEvaluator { NoContextExpressionEvaluator() {} Frame frame{128}; @@ -823,30 +824,23 @@ TEST(ExpressionEvaluator, FunctionRange) { EXPECT_TRUE(EvaluateFunction("RANGE", {1, 2, TypedValue::Null}).IsNull()); EXPECT_THROW(EvaluateFunction("RANGE", {1, TypedValue::Null, 1.3}), QueryRuntimeException); - auto to_int_list = [](const TypedValue &t) { - std::vector<int64_t> list; - for (auto x : t.Value<std::vector<TypedValue>>()) { - list.push_back(x.Value<int64_t>()); - } - return list; - }; EXPECT_THROW(EvaluateFunction("RANGE", {1, 2, 0}), QueryRuntimeException); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {1, 3})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {1, 3})), ElementsAre(1, 2, 3)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-1, 5, 2})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-1, 5, 2})), ElementsAre(-1, 1, 3, 5)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 10, 3})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 10, 3})), ElementsAre(2, 5, 8)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, 2})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, 2})), ElementsAre(2)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre()); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {5, 1, -2})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {3, 0, 5})), ElementsAre()); + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {5, 1, -2})), ElementsAre(5, 3, 1)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {6, 1, -2})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {6, 1, -2})), ElementsAre(6, 4, 2)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {2, 2, -3})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {2, 2, -3})), ElementsAre(2)); - EXPECT_THAT(to_int_list(EvaluateFunction("RANGE", {-2, 4, -1})), + EXPECT_THAT(ToInt64List(EvaluateFunction("RANGE", {-2, 4, -1})), ElementsAre()); } diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index bd3590547..b3480d488 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -21,6 +21,8 @@ using namespace query; using namespace query::plan; +using testing::UnorderedElementsAre; +using query::test_common::ToInt64List; TEST(QueryPlan, Accumulate) { // simulate the following two query execution on an empty db @@ -169,18 +171,19 @@ TEST(QueryPlan, AggregateOps) { auto n_p = PROPERTY_LOOKUP("n", prop); symbol_table[*n_p->expression_] = n.sym_; - std::vector<Expression *> aggregation_expressions(6, n_p); + std::vector<Expression *> aggregation_expressions(7, n_p); aggregation_expressions[0] = nullptr; auto produce = MakeAggregationProduce( n.op_, symbol_table, storage, aggregation_expressions, {Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, - Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG}, + Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, + Aggregation::Op::COLLECT}, {}, {}); // checks auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); ASSERT_EQ(results.size(), 1); - ASSERT_EQ(results[0].size(), 6); + ASSERT_EQ(results[0].size(), 7); // count(*) ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); EXPECT_EQ(results[0][0].Value<int64_t>(), 4); @@ -199,6 +202,9 @@ TEST(QueryPlan, AggregateOps) { // avg ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double); EXPECT_FLOAT_EQ(results[0][5].Value<double>(), 24 / 3.0); + // collect + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_THAT(ToInt64List(results[0][6]), UnorderedElementsAre(5, 7, 12)); } TEST(QueryPlan, AggregateGroupByValues) { @@ -320,6 +326,34 @@ TEST(QueryPlan, AggregateNoInput) { EXPECT_EQ(1, results[0][0].Value<int64_t>()); } +// TODO: This test is valid but it fails. We don't handle aggregations correctly +// in the case when there is no input. Also add similar tests for other +// aggregation ops. +// TEST(QueryPlan, AggregateCollectNoResults) { +// Dbms dbms; +// auto dba = dbms.active(); +// auto prop = dba->property("prop"); +// +// AstTreeStorage storage; +// SymbolTable symbol_table; +// +// // match all nodes and perform aggregations +// auto n = MakeScanAll(storage, symbol_table, "n"); +// auto n_p = PROPERTY_LOOKUP("n", prop); +// symbol_table[*n_p->expression_] = n.sym_; +// +// auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, +// {Aggregation::Op::COLLECT}, {}, {}); +// +// // checks +// auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); +// ASSERT_EQ(results.size(), 1); +// ASSERT_EQ(results[0].size(), 1); +// ASSERT_EQ(results[0][0].type(), TypedValue::Type::List); +// // Collect should return empty list if there are no results. +// EXPECT_THAT(ToInt64List(results[0][0]), UnorderedElementsAre()); +//} + TEST(QueryPlan, AggregateCountEdgeCases) { // tests for detected bugs in the COUNT aggregation behavior // ensure that COUNT returns correctly for