diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index f78f5f5c3..577d15113 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -461,7 +461,8 @@ cpp<# (lcp:define-class aggregation (binary-operator) ((op "Op" :scope :public) (symbol-pos :int32_t :initval -1 :scope :public - :documentation "Symbol table position of the symbol this Aggregation is mapped to.")) + :documentation "Symbol table position of the symbol this Aggregation is mapped to.") + (distinct :bool :initval "false" :scope :public)) (:public (lcp:define-enum op (count min max sum avg collect-list collect-map project) @@ -505,8 +506,8 @@ cpp<# /// Aggregation's first expression is the value being aggregated. The second /// expression is the key used only in COLLECT_MAP. - Aggregation(Expression *expression1, Expression *expression2, Op op) - : BinaryOperator(expression1, expression2), op_(op) { + Aggregation(Expression *expression1, Expression *expression2, Op op, bool distinct) + : BinaryOperator(expression1, expression2), op_(op), distinct_(distinct) { // COUNT without expression denotes COUNT(*) in cypher. DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT, "All aggregations, except COUNT require expression"); diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index d4c66a7a9..4d8663772 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -2106,7 +2106,7 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) { // Here we handle COUNT(*). COUNT(expression) is handled in // visitFunctionInvocation with other aggregations. This is visible in // functionInvocation and atom producions in opencypher grammar. 
- return static_cast(storage_->Create(nullptr, nullptr, Aggregation::Op::COUNT)); + return static_cast(storage_->Create(nullptr, nullptr, Aggregation::Op::COUNT, false)); } else if (ctx->ALL()) { auto *ident = storage_->Create( std::any_cast(ctx->filterExpression()->idInColl()->variable()->accept(this))); @@ -2222,9 +2222,7 @@ antlrcpp::Any CypherMainVisitor::visitNumberLiteral(MemgraphCypher::NumberLitera } antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) { - if (ctx->DISTINCT()) { - throw utils::NotYetImplemented("DISTINCT function call"); - } + const auto is_distinct = ctx->DISTINCT() != nullptr; auto function_name = std::any_cast(ctx->functionName()->accept(this)); std::vector expressions; for (auto *expression : ctx->expression()) { @@ -2232,33 +2230,38 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio } if (expressions.size() == 1U) { if (function_name == Aggregation::kCount) { - return static_cast(storage_->Create(expressions[0], nullptr, Aggregation::Op::COUNT)); + return static_cast( + storage_->Create(expressions[0], nullptr, Aggregation::Op::COUNT, is_distinct)); } if (function_name == Aggregation::kMin) { - return static_cast(storage_->Create(expressions[0], nullptr, Aggregation::Op::MIN)); + return static_cast( + storage_->Create(expressions[0], nullptr, Aggregation::Op::MIN, is_distinct)); } if (function_name == Aggregation::kMax) { - return static_cast(storage_->Create(expressions[0], nullptr, Aggregation::Op::MAX)); + return static_cast( + storage_->Create(expressions[0], nullptr, Aggregation::Op::MAX, is_distinct)); } if (function_name == Aggregation::kSum) { - return static_cast(storage_->Create(expressions[0], nullptr, Aggregation::Op::SUM)); + return static_cast( + storage_->Create(expressions[0], nullptr, Aggregation::Op::SUM, is_distinct)); } if (function_name == Aggregation::kAvg) { - return static_cast(storage_->Create(expressions[0], nullptr, 
Aggregation::Op::AVG)); + return static_cast( + storage_->Create(expressions[0], nullptr, Aggregation::Op::AVG, is_distinct)); } if (function_name == Aggregation::kCollect) { return static_cast( - storage_->Create(expressions[0], nullptr, Aggregation::Op::COLLECT_LIST)); + storage_->Create(expressions[0], nullptr, Aggregation::Op::COLLECT_LIST, is_distinct)); } if (function_name == Aggregation::kProject) { return static_cast( - storage_->Create(expressions[0], nullptr, Aggregation::Op::PROJECT)); + storage_->Create(expressions[0], nullptr, Aggregation::Op::PROJECT, is_distinct)); } } if (expressions.size() == 2U && function_name == Aggregation::kCollect) { return static_cast( - storage_->Create(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP)); + storage_->Create(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP, is_distinct)); } auto is_user_defined_function = [](const std::string &function_name) { diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 34d3dcdc0..d37f0b141 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -41,6 +41,7 @@ #include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" +#include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/view.hpp" #include "utils/algorithm.hpp" @@ -3211,7 +3212,8 @@ class AggregateCursor : public Cursor { // aggregation map. The vectors in an AggregationValue contain one element for // each aggregation in this LogicalOp. 
struct AggregationValue { - explicit AggregationValue(utils::MemoryResource *mem) : counts_(mem), values_(mem), remember_(mem) {} + explicit AggregationValue(utils::MemoryResource *mem) + : counts_(mem), values_(mem), remember_(mem), unique_values_(mem) {} // how many input rows have been aggregated in respective values_ element so // far @@ -3224,6 +3226,10 @@ class AggregateCursor : public Cursor { utils::pmr::vector values_; // remember values. utils::pmr::vector remember_; + + using TSet = utils::pmr::unordered_set; + + utils::pmr::vector unique_values_; }; const Aggregate &self_; @@ -3299,6 +3305,7 @@ class AggregateCursor : public Cursor { for (const auto &agg_elem : self_.aggregations_) { auto *mem = agg_value->values_.get_allocator().GetMemoryResource(); agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem)); + agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem)); } agg_value->counts_.resize(self_.aggregations_.size(), 0); @@ -3317,8 +3324,9 @@ class AggregateCursor : public Cursor { auto count_it = agg_value->counts_.begin(); auto value_it = agg_value->values_.begin(); + auto unique_values_it = agg_value->unique_values_.begin(); auto agg_elem_it = self_.aggregations_.begin(); - for (; count_it < agg_value->counts_.end(); count_it++, value_it++, agg_elem_it++) { + for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) { // COUNT(*) is the only case where input expression is optional // handle it here auto input_expr_ptr = agg_elem_it->value; @@ -3333,6 +3341,13 @@ class AggregateCursor : public Cursor { // Aggregations skip Null input values. if (input_value.IsNull()) continue; const auto &agg_op = agg_elem_it->op; + if (agg_elem_it->distinct) { + auto insert_result = unique_values_it->insert(input_value); + if (!insert_result.second) { + // A repeated value is skipped by this aggregation only; the remaining aggregations must still process the row. + continue; + } + } *count_it += 1; if (*count_it == 1) { // first value, nothing to aggregate. check type, set and continue. 
diff --git a/src/query/plan/operator.lcp b/src/query/plan/operator.lcp index d7b531cfc..195a09088 100644 --- a/src/query/plan/operator.lcp +++ b/src/query/plan/operator.lcp @@ -1657,7 +1657,8 @@ elements are in an undefined state after aggregation.") :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "Expression")) (op "::Aggregation::Op") - (output-sym "Symbol")) + (output-sym "Symbol") + (distinct bool :initval "false" )) (:documentation "An aggregation element, contains: (input data expression, key expression - only used in COLLECT_MAP, type of @@ -2282,9 +2283,9 @@ clauses. (:public #>cpp Foreach() = default; - Foreach(std::shared_ptr input, + Foreach(std::shared_ptr input, std::shared_ptr updates, - Expression *named_expr, + Expression *named_expr, Symbol loop_variable_symbol); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index 17644068a..99b4826a6 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -401,6 +401,8 @@ json ToJson(const Aggregate::Element &elem) { } json["op"] = utils::ToLowerCase(Aggregation::OpToString(elem.op)); json["output_symbol"] = ToJson(elem.output_sym); + json["distinct"] = elem.distinct; + return json; } ////////////////////////// END HELPER FUNCTIONS //////////////////////////////// diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 6bca9454c..ab4a0752c 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -344,7 +344,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { bool PostVisit(Aggregation &aggr) override { // Aggregation contains a virtual symbol, where the result will be stored. 
const auto &symbol = symbol_table_.at(aggr); - aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol}); + aggregations_.emplace_back( + Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol, aggr.distinct_}); // Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses // two expressions, so we can have 0, 1 or 2 elements on the // has_aggregation_stack for this Aggregation expression. diff --git a/tests/mgbench/datasets.py b/tests/mgbench/datasets.py index 3a4dd6820..455d86af5 100644 --- a/tests/mgbench/datasets.py +++ b/tests/mgbench/datasets.py @@ -181,6 +181,9 @@ class Pokec(Dataset): def benchmark__arango__aggregate(self): return ("MATCH (n:User) RETURN n.age, COUNT(*)", {}) + def benchmark__arango__aggregate_with_distinct(self): + return ("MATCH (n:User) RETURN COUNT(DISTINCT n.age)", {}) + def benchmark__arango__aggregate_with_filter(self): return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {}) diff --git a/tests/unit/plan_pretty_print.cpp b/tests/unit/plan_pretty_print.cpp index 0148a3082..5ad8b8e02 100644 --- a/tests/unit/plan_pretty_print.cpp +++ b/tests/unit/plan_pretty_print.cpp @@ -643,17 +643,70 @@ TEST_F(PrintToJsonTest, Aggregate) { { "value" : "(PropertyLookup (Identifier \"node\") \"value\")", "op" : "sum", - "output_symbol" : "sum" + "output_symbol" : "sum", + "distinct" : false }, { "value" : "(PropertyLookup (Identifier \"node\") \"value\")", "key" : "(PropertyLookup (Identifier \"node\") \"color\")", "op" : "collect", - "output_symbol" : "map" + "output_symbol" : "map", + "distinct" : false }, { "op": "count", - "output_symbol": "count" + "output_symbol": "count", + "distinct" : false + } + ], + "group_by" : [ + "(PropertyLookup (Identifier \"node\") \"type\")" + ], + "remember" : ["node"], + "input" : { + "name" : "ScanAll", + "output_symbol" : "node", + "input" : { "name" : "Once" } + } + })sep"); +} + +TEST_F(PrintToJsonTest, 
AggregateWithDistinct) { + memgraph::storage::PropertyId value = dba.NameToProperty("value"); + memgraph::storage::PropertyId color = dba.NameToProperty("color"); + memgraph::storage::PropertyId type = dba.NameToProperty("type"); + auto node_sym = GetSymbol("node"); + std::shared_ptr last_op = std::make_shared(nullptr, node_sym); + last_op = std::make_shared( + last_op, + std::vector{ + {PROPERTY_LOOKUP("node", value), nullptr, Aggregation::Op::SUM, GetSymbol("sum"), true}, + {PROPERTY_LOOKUP("node", value), PROPERTY_LOOKUP("node", color), Aggregation::Op::COLLECT_MAP, + GetSymbol("map"), true}, + {nullptr, nullptr, Aggregation::Op::COUNT, GetSymbol("count"), true}}, + std::vector{PROPERTY_LOOKUP("node", type)}, std::vector{node_sym}); + + Check(last_op.get(), R"sep( + { + "name" : "Aggregate", + "aggregations" : [ + { + "value" : "(PropertyLookup (Identifier \"node\") \"value\")", + "op" : "sum", + "output_symbol" : "sum", + "distinct" : true + }, + { + "value" : "(PropertyLookup (Identifier \"node\") \"value\")", + "key" : "(PropertyLookup (Identifier \"node\") \"color\")", + "op" : "collect", + "output_symbol" : "map", + "distinct" : true + }, + { + "op": "count", + "output_symbol": "count", + "distinct" : true } ], "group_by" : [ diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index f8da96366..4017a1656 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -549,12 +549,15 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec #define LESS_EQ(expr1, expr2) storage.Create((expr1), (expr2)) #define GREATER(expr1, expr2) storage.Create((expr1), (expr2)) #define GREATER_EQ(expr1, expr2) storage.Create((expr1), (expr2)) -#define SUM(expr) storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::SUM) -#define COUNT(expr) \ - storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::COUNT) -#define AVG(expr) storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::AVG) 
-#define COLLECT_LIST(expr) \ - storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::COLLECT_LIST) +#define SUM(expr, distinct) \ + storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::SUM, (distinct)) +#define COUNT(expr, distinct) \ + storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::COUNT, (distinct)) +#define AVG(expr, distinct) \ + storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::AVG, (distinct)) +#define COLLECT_LIST(expr, distinct) \ + storage.Create((expr), nullptr, memgraph::query::Aggregation::Op::COLLECT_LIST, \ + (distinct)) #define EQ(expr1, expr2) storage.Create((expr1), (expr2)) #define NEQ(expr1, expr2) storage.Create((expr1), (expr2)) #define AND(expr1, expr2) storage.Create((expr1), (expr2)) diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 13bccea6a..6e61e07d6 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -596,7 +596,7 @@ TEST_F(ExpressionEvaluatorTest, LabelsTest) { } TEST_F(ExpressionEvaluatorTest, Aggregation) { - auto aggr = storage.Create(storage.Create(42), nullptr, Aggregation::Op::COUNT); + auto aggr = storage.Create(storage.Create(42), nullptr, Aggregation::Op::COUNT, false); auto aggr_sym = symbol_table.CreateSymbol("aggr", true); aggr->MapTo(aggr_sym); frame[aggr_sym] = TypedValue(1); diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 93d2f33c7..ce9f74dc9 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -396,7 +396,7 @@ TYPED_TEST(TestPlanner, MatchWithSumWhereReturn) { FakeDbAccessor dba; auto prop = dba.Property("prop"); AstStorage storage; - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto sum = SUM(PROPERTY_LOOKUP("n", prop), false); auto literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", 
AS("result")))); @@ -410,7 +410,7 @@ TYPED_TEST(TestPlanner, MatchReturnSum) { auto prop1 = dba.Property("prop1"); auto prop2 = dba.Property("prop2"); AstStorage storage; - auto sum = SUM(PROPERTY_LOOKUP("n", prop1)); + auto sum = SUM(PROPERTY_LOOKUP("n", prop1), false); auto n_prop2 = PROPERTY_LOOKUP("n", prop2); auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group")))); auto aggr = ExpectAggregate({sum}, {n_prop2}); @@ -426,7 +426,54 @@ TYPED_TEST(TestPlanner, CreateWithSum) { AstStorage storage; auto ident_n = IDENT("n"); auto n_prop = PROPERTY_LOOKUP(ident_n, prop); - auto sum = SUM(n_prop); + auto sum = SUM(n_prop, false); + auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")))); + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); + auto aggr = ExpectAggregate({sum}, {}); + auto planner = MakePlanner(&dba, storage, symbol_table, query); + // We expect both the accumulation and aggregation because the part before + // WITH updates the database. 
+ CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce()); +} + +TYPED_TEST(TestPlanner, MatchWithSumWithDistinctWhereReturn) { + // Test MATCH (n) WITH SUM(DISTINCT n.prop) + 42 AS sum WHERE sum < 42 + // RETURN sum AS result + FakeDbAccessor dba; + auto prop = dba.Property("prop"); + AstStorage storage; + auto sum = SUM(PROPERTY_LOOKUP("n", prop), true); + auto literal = LITERAL(42); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), + WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", AS("result")))); + auto aggr = ExpectAggregate({sum}, {literal}); + CheckPlan(query, storage, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), ExpectProduce()); +} + +TYPED_TEST(TestPlanner, MatchReturnSumWithDistinct) { + // Test MATCH (n) RETURN SUM(DISTINCT n.prop1) AS sum, n.prop2 AS group + FakeDbAccessor dba; + auto prop1 = dba.Property("prop1"); + auto prop2 = dba.Property("prop2"); + AstStorage storage; + auto sum = SUM(PROPERTY_LOOKUP("n", prop1), true); + auto n_prop2 = PROPERTY_LOOKUP("n", prop2); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group")))); + auto aggr = ExpectAggregate({sum}, {n_prop2}); + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto planner = MakePlanner(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), aggr, ExpectProduce()); +} + +TYPED_TEST(TestPlanner, CreateWithSumWithDistinct) { + // Test CREATE (n) WITH SUM(DISTINCT n.prop) AS sum + FakeDbAccessor dba; + auto prop = dba.Property("prop"); + AstStorage storage; + auto ident_n = IDENT("n"); + auto n_prop = PROPERTY_LOOKUP(ident_n, prop); + auto sum = SUM(n_prop, true); auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum")))); auto symbol_table = memgraph::query::MakeSymbolTable(query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); @@ -484,7 +531,24 @@ 
TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) { AstStorage storage; auto ident_n = IDENT("n"); auto n_prop = PROPERTY_LOOKUP(ident_n, prop); - auto sum = SUM(n_prop); + auto sum = SUM(n_prop, false); + auto query = + QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); + auto symbol_table = memgraph::query::MakeSymbolTable(query); + auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); + auto aggr = ExpectAggregate({sum}, {}); + auto planner = MakePlanner(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectSkip(), ExpectLimit()); +} + +TYPED_TEST(TestPlanner, CreateReturnSumWithDistinctSkipLimit) { + // Test CREATE (n) RETURN SUM(DISTINCT n.prop) AS s SKIP 2 LIMIT 1 + FakeDbAccessor dba; + auto prop = dba.Property("prop"); + AstStorage storage; + auto ident_n = IDENT("n"); + auto n_prop = PROPERTY_LOOKUP(ident_n, prop); + auto sum = SUM(n_prop, true); auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); auto symbol_table = memgraph::query::MakeSymbolTable(query); @@ -539,8 +603,18 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) { // Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result AstStorage storage; - auto sum = SUM(LITERAL(1)); - auto count = COUNT(LITERAL(2)); + auto sum = SUM(LITERAL(1), false); + auto count = COUNT(LITERAL(2), false); + auto *query = QUERY(SINGLE_QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))))); + auto aggr = ExpectAggregate({sum, count}, {}); + CheckPlan(query, storage, aggr, ExpectProduce(), ExpectOrderBy()); +} + +TYPED_TEST(TestPlanner, ReturnAddSumCountWithDistinctOrderBy) { + // Test RETURN SUM(DISTINCT 1) + COUNT(DISTINCT 2) AS result ORDER BY result + AstStorage storage; + auto sum = SUM(LITERAL(1), true); + auto count = COUNT(LITERAL(2), true); auto *query = 
QUERY(SINGLE_QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))))); auto aggr = ExpectAggregate({sum, count}, {}); CheckPlan(query, storage, aggr, ExpectProduce(), ExpectOrderBy()); @@ -610,7 +684,7 @@ TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) { auto prop = dba.Property("prop"); AstStorage storage; auto node_n = NODE("n"); - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto sum = SUM(PROPERTY_LOOKUP("n", prop), false); auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(node_n)), WITH_DISTINCT(sum, AS("s")), WHERE(LESS(IDENT("s"), LITERAL(42))), RETURN("s"))); auto symbol_table = memgraph::query::MakeSymbolTable(query); @@ -749,7 +823,7 @@ TYPED_TEST(TestPlanner, MatchReturnAsteriskSum) { FakeDbAccessor dba; auto prop = dba.Property("prop"); AstStorage storage; - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto sum = SUM(PROPERTY_LOOKUP("n", prop), false); auto ret = RETURN(sum, AS("s")); ret->body_.all_identifiers = true; auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); @@ -818,7 +892,7 @@ TYPED_TEST(TestPlanner, MultipleOptionalMatchReturn) { TYPED_TEST(TestPlanner, FunctionAggregationReturn) { // Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by AstStorage storage; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto group_by_literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); @@ -835,7 +909,7 @@ TYPED_TEST(TestPlanner, FunctionWithoutArguments) { TYPED_TEST(TestPlanner, ListLiteralAggregationReturn) { // Test RETURN [SUM(2)] AS result, 42 AS group_by AstStorage storage; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto group_by_literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); @@ -846,7 +920,7 @@ 
TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) { // Test RETURN {sum: SUM(2)} AS result, 42 AS group_by AstStorage storage; FakeDbAccessor dba; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto group_by_literal = LITERAL(42); auto *query = QUERY( SINGLE_QUERY(RETURN(MAP({storage.GetPropertyIx("sum"), sum}), AS("result"), group_by_literal, AS("group_by")))); @@ -857,7 +931,7 @@ TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) { TYPED_TEST(TestPlanner, EmptyListIndexAggregation) { // Test RETURN [][SUM(2)] AS result, 42 AS group_by AstStorage storage; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto empty_list = LIST(); auto group_by_literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(RETURN(storage.Create(empty_list, sum), @@ -872,7 +946,7 @@ TYPED_TEST(TestPlanner, EmptyListIndexAggregation) { TYPED_TEST(TestPlanner, ListSliceAggregationReturn) { // Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by AstStorage storage; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto list = LIST(LITERAL(1), LITERAL(2)); auto group_by_literal = LITERAL(42); auto *query = @@ -886,7 +960,7 @@ TYPED_TEST(TestPlanner, ListSliceAggregationReturn) { TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) { // Test RETURN [sum(2), 42] AstStorage storage; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto group_by_literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(RETURN(LIST(sum, group_by_literal), AS("result")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); @@ -896,8 +970,8 @@ TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) { TYPED_TEST(TestPlanner, AggregatonWithListWithAggregationAndGroupBy) { // Test RETURN sum(2), [sum(3), 42] AstStorage storage; - auto sum2 = SUM(LITERAL(2)); - auto sum3 = SUM(LITERAL(3)); + auto sum2 = SUM(LITERAL(2), false); + auto sum3 = SUM(LITERAL(3), false); auto group_by_literal = LITERAL(42); auto *query = 
QUERY(SINGLE_QUERY(RETURN(sum2, AS("sum2"), LIST(sum3, group_by_literal), AS("list")))); auto aggr = ExpectAggregate({sum2, sum3}, {group_by_literal}); @@ -908,7 +982,7 @@ TYPED_TEST(TestPlanner, MapWithAggregationAndGroupBy) { // Test RETURN {lit: 42, sum: sum(2)} AstStorage storage; FakeDbAccessor dba; - auto sum = SUM(LITERAL(2)); + auto sum = SUM(LITERAL(2), false); auto group_by_literal = LITERAL(42); auto *query = QUERY(SINGLE_QUERY(RETURN( MAP({storage.GetPropertyIx("sum"), sum}, {storage.GetPropertyIx("lit"), group_by_literal}), AS("result")))); @@ -1121,7 +1195,7 @@ TYPED_TEST(TestPlanner, SecondPropertyIndex) { TYPED_TEST(TestPlanner, ReturnSumGroupByAll) { // Test RETURN sum([1,2,3]), all(x in [1] where x = 1) AstStorage storage; - auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3))); + auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), false); auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1)))); auto *query = QUERY(SINGLE_QUERY(RETURN(sum, AS("sum"), all, AS("all")))); auto aggr = ExpectAggregate({sum}, {all}); diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp index 08f1965f0..7fb9e7987 100644 --- a/tests/unit/query_plan_accumulate_aggregate.cpp +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -109,7 +109,7 @@ std::shared_ptr MakeAggregationProduce(std::shared_ptr AstStorage &storage, const std::vector aggr_inputs, const std::vector aggr_ops, const std::vector group_by_exprs, - const std::vector remember) { + const std::vector remember, const bool distinct) { // prepare all the aggregations std::vector aggregates; std::vector named_expressions; @@ -124,7 +124,7 @@ std::shared_ptr MakeAggregationProduce(std::shared_ptr named_expressions.push_back(named_expr); // the key expression is only used in COLLECT_MAP Expression *key_expr_ptr = aggr_op == Aggregation::Op::COLLECT_MAP ? 
LITERAL("key") : nullptr; - aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym}); + aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym, distinct}); } // Produce will also evaluate group_by expressions and return them after the @@ -155,16 +155,21 @@ class QueryPlanAggregateOps : public ::testing::Test { ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue()); ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(7)).HasValue()); ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(12)).HasValue()); + ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue()); + ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue()); + ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(12)).HasValue()); + // a missing property (null) gets ignored by all aggregations except // COUNT(*) dba.InsertVertex(); dba.AdvanceCommand(); } - auto AggregationResults(bool with_group_by, std::vector ops = { - Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, - Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, - Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) { + auto AggregationResults(bool with_group_by, bool distinct, + std::vector ops = { + Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN, + Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG, + Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) { // match all nodes and perform aggregations auto n = MakeScanAll(storage, symbol_table, "n"); auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); @@ -173,7 +178,8 @@ class QueryPlanAggregateOps : public ::testing::Test { std::vector group_bys; if (with_group_by) group_bys.push_back(n_p); 
aggregation_expressions[0] = nullptr; - auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {}); + auto produce = + MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {}, distinct); auto context = MakeContext(storage, symbol_table, &dba); return CollectProduce(*produce, &context); } @@ -181,16 +187,16 @@ class QueryPlanAggregateOps : public ::testing::Test { TEST_F(QueryPlanAggregateOps, WithData) { AddData(); - auto results = AggregationResults(false); + auto results = AggregationResults(false, false); ASSERT_EQ(results.size(), 1); ASSERT_EQ(results[0].size(), 8); // count(*) ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); - EXPECT_EQ(results[0][0].ValueInt(), 4); + EXPECT_EQ(results[0][0].ValueInt(), 7); // count ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); - EXPECT_EQ(results[0][1].ValueInt(), 3); + EXPECT_EQ(results[0][1].ValueInt(), 6); // min ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int); EXPECT_EQ(results[0][2].ValueInt(), 5); @@ -199,13 +205,13 @@ TEST_F(QueryPlanAggregateOps, WithData) { EXPECT_EQ(results[0][3].ValueInt(), 12); // sum ASSERT_EQ(results[0][4].type(), TypedValue::Type::Int); - EXPECT_EQ(results[0][4].ValueInt(), 24); + EXPECT_EQ(results[0][4].ValueInt(), 46); // avg ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double); - EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 24 / 3.0); + EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 46 / 6.0); // collect list ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); - EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12)); + EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12, 5, 5, 12)); // collect map ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); auto map = ToIntMap(results[0][7]); @@ -216,46 +222,46 @@ TEST_F(QueryPlanAggregateOps, WithData) { TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) { { - auto results = 
AggregationResults(true, {Aggregation::Op::COUNT}); + auto results = AggregationResults(true, false, {Aggregation::Op::COUNT}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int); EXPECT_EQ(results[0][0].ValueInt(), 0); } { - auto results = AggregationResults(true, {Aggregation::Op::SUM}); + auto results = AggregationResults(true, false, {Aggregation::Op::SUM}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int); EXPECT_EQ(results[0][0].ValueInt(), 0); } { - auto results = AggregationResults(true, {Aggregation::Op::AVG}); + auto results = AggregationResults(true, false, {Aggregation::Op::AVG}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); } { - auto results = AggregationResults(true, {Aggregation::Op::MIN}); + auto results = AggregationResults(true, false, {Aggregation::Op::MIN}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); } { - auto results = AggregationResults(true, {Aggregation::Op::MAX}); + auto results = AggregationResults(true, false, {Aggregation::Op::MAX}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); } { - auto results = AggregationResults(true, {Aggregation::Op::COLLECT_LIST}); + auto results = AggregationResults(true, false, {Aggregation::Op::COLLECT_LIST}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::List); } { - auto results = AggregationResults(true, {Aggregation::Op::COLLECT_MAP}); + auto results = AggregationResults(true, false, {Aggregation::Op::COLLECT_MAP}); EXPECT_EQ(results.size(), 1); EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map); } } TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) { - auto results = AggregationResults(false); + auto results = AggregationResults(false, false); ASSERT_EQ(results.size(), 1); ASSERT_EQ(results[0].size(), 8); // count(*) @@ -325,7 +331,8 @@ TEST(QueryPlan, 
AggregateGroupByValues) { auto n = MakeScanAll(storage, symbol_table, "n"); auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); - auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}); + auto produce = + MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}, false); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); @@ -372,7 +379,7 @@ TEST(QueryPlan, AggregateMultipleGroupBy) { auto n_p3 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop3); auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, - {n_p1, n_p2, n_p3}, {n.sym_}); + {n_p1, n_p2, n_p3}, {n.sym_}, false); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); @@ -387,7 +394,7 @@ TEST(QueryPlan, AggregateNoInput) { SymbolTable symbol_table; auto two = LITERAL(2); - auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}); + auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}, false); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); EXPECT_EQ(1, results.size()); @@ -419,7 +426,7 @@ TEST(QueryPlan, AggregateCountEdgeCases) { // returns -1 when there are no results // otherwise returns MATCH (n) RETURN count(n.prop) auto count = [&]() { - auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}); + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}, false); auto context = MakeContext(storage, symbol_table, &dba); auto results = CollectProduce(*produce, &context); if (results.size() == 0) return -1L; @@ -479,7 +486,7 @@ TEST(QueryPlan, AggregateFirstValueTypes) { 
auto n_id = n_prop_string->expression_; auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { - auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, false); auto context = MakeContext(storage, symbol_table, &dba); CollectProduce(*produce, &context); }; @@ -533,7 +540,7 @@ TEST(QueryPlan, AggregateTypes) { auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { - auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}); + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, false); auto context = MakeContext(storage, symbol_table, &dba); CollectProduce(*produce, &context); }; @@ -609,3 +616,396 @@ TEST(QueryPlan, Unwind) { expected_y_it++; } } + +TEST_F(QueryPlanAggregateOps, WithDataDistinct) { + AddData(); + auto results = AggregationResults(false, true); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 7); + // count + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 3); + // min + ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][2].ValueInt(), 5); + // max + ASSERT_EQ(results[0][3].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][3].ValueInt(), 12); + // sum + ASSERT_EQ(results[0][4].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][4].ValueInt(), 24); + // avg + ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double); + EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 24 / 3.0); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 
12)); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + auto map = ToIntMap(results[0][7]); + ASSERT_EQ(map.size(), 1); + EXPECT_EQ(map.begin()->first, "key"); + EXPECT_FALSE(std::set({5, 7, 12}).insert(map.begin()->second).second); +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithDistinctAndWithGroupBy) { + { + auto results = AggregationResults(true, true, {Aggregation::Op::COUNT}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 0); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::SUM}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 0); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::AVG}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::MIN}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::MAX}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::COLLECT_LIST}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::List); + } + { + auto results = AggregationResults(true, true, {Aggregation::Op::COLLECT_MAP}); + EXPECT_EQ(results.size(), 1); + EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map); + } +} + +TEST_F(QueryPlanAggregateOps, WithoutDataWithDistinctAndWithoutGroupBy) { + auto results = AggregationResults(false, true); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 8); + // count(*) + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].ValueInt(), 0); + // count + ASSERT_EQ(results[0][1].type(), 
TypedValue::Type::Int); + EXPECT_EQ(results[0][1].ValueInt(), 0); + // min + EXPECT_TRUE(results[0][2].IsNull()); + // max + EXPECT_TRUE(results[0][3].IsNull()); + // sum + EXPECT_EQ(results[0][4].ValueInt(), 0); + // avg + EXPECT_TRUE(results[0][5].IsNull()); + // collect list + ASSERT_EQ(results[0][6].type(), TypedValue::Type::List); + EXPECT_EQ(ToIntList(results[0][6]).size(), 0); + // collect map + ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map); + EXPECT_EQ(ToIntMap(results[0][7]).size(), 0); +} + +TEST(QueryPlan, AggregateGroupByValuesWithDistinct) { + // Tests that distinct groups are aggregated properly for values of all types. + // Also test the "remember" part of the Aggregation API as final results are + // obtained via a property lookup of a remembered node. + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + + // a vector of memgraph::storage::PropertyValue to be set as property values on vertices + // most of them should result in a distinct group (commented where not) + std::vector group_by_vals; + group_by_vals.emplace_back(4); + group_by_vals.emplace_back(7); + group_by_vals.emplace_back(7.3); + group_by_vals.emplace_back(7.2); + group_by_vals.emplace_back("Johhny"); + group_by_vals.emplace_back("Jane"); + group_by_vals.emplace_back("1"); + group_by_vals.emplace_back(true); + group_by_vals.emplace_back(false); + group_by_vals.emplace_back(std::vector{memgraph::storage::PropertyValue(1)}); + group_by_vals.emplace_back(std::vector{memgraph::storage::PropertyValue(1), + memgraph::storage::PropertyValue(2)}); + group_by_vals.emplace_back(std::vector{memgraph::storage::PropertyValue(2), + memgraph::storage::PropertyValue(1)}); + group_by_vals.emplace_back(memgraph::storage::PropertyValue()); + // should NOT result in another group because 7.0 == 7 + group_by_vals.emplace_back(7.0); + // should NOT result in another group + 
group_by_vals.emplace_back(std::vector{memgraph::storage::PropertyValue(1), + memgraph::storage::PropertyValue(2.0)}); + + // generate a lot of vertices and set props on them + auto prop = dba.NameToProperty("prop"); + for (int i = 0; i < 1000; ++i) + ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, group_by_vals[i % group_by_vals.size()]).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + auto produce = + MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}, true); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + ASSERT_EQ(results.size(), group_by_vals.size() - 2); + std::unordered_set result_group_bys; + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + if (!row[1].IsNull()) { + ASSERT_EQ(1, row[0].ValueInt()); + } + result_group_bys.insert(row[1]); + } + ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2); + std::vector group_by_tvals; + group_by_tvals.reserve(group_by_vals.size()); + for (const auto &v : group_by_vals) group_by_tvals.emplace_back(v); + EXPECT_TRUE(std::is_permutation(group_by_tvals.begin(), group_by_tvals.end() - 2, result_group_bys.begin(), + TypedValue::BoolEqual{})); +} + +TEST(QueryPlan, AggregateMultipleGroupByWithDistinct) { + // in this test we have 3 different properties that have different values + // for different records and assert that we get the correct combination + // of values in our groups + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + + auto prop1 = dba.NameToProperty("prop1"); + auto prop2 = dba.NameToProperty("prop2"); + auto prop3 = dba.NameToProperty("prop3"); + for (int i = 0; i < 2 * 3 * 5; ++i) { + auto v = 
dba.InsertVertex(); + ASSERT_TRUE(v.SetProperty(prop1, memgraph::storage::PropertyValue(static_cast(i % 2))).HasValue()); + ASSERT_TRUE(v.SetProperty(prop2, memgraph::storage::PropertyValue(i % 3)).HasValue()); + ASSERT_TRUE(v.SetProperty(prop3, memgraph::storage::PropertyValue("value" + std::to_string(i % 5))).HasValue()); + } + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop2); + + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, {n_p1, n_p2}, + {n.sym_}, true); + + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + for (const auto &row : results) { + ASSERT_EQ(1, row[0].ValueInt()); + } +} + +TEST(QueryPlan, AggregateNoInputWithDistinct) { + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + AstStorage storage; + SymbolTable symbol_table; + + auto two = LITERAL(2); + auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}, true); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + EXPECT_EQ(1, results[0][0].ValueInt()); +} + +TEST(QueryPlan, AggregateCountEdgeCasesWithDistinct) { + // tests for detected bugs in the COUNT aggregation behavior + // ensure that COUNT returns correctly for + // - 0 vertices in database + // - 1 vertex in database, property not set + // - 1 vertex in database, property set + // - 2 vertices in database, property set on one + // - 2 vertices in database, property set on both 
+ + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + auto prop = dba.NameToProperty("prop"); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop); + + // returns -1 when there are no results + // otherwise returns MATCH (n) RETURN count(n.prop) + auto count = [&]() { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}, true); + auto context = MakeContext(storage, symbol_table, &dba); + auto results = CollectProduce(*produce, &context); + if (results.size() == 0) return -1L; + EXPECT_EQ(1, results.size()); + EXPECT_EQ(1, results[0].size()); + EXPECT_EQ(TypedValue::Type::Int, results[0][0].type()); + return results[0][0].ValueInt(); + }; + + // no vertices yet in database + EXPECT_EQ(0, count()); + + // one vertex, no property set + dba.InsertVertex(); + dba.AdvanceCommand(); + EXPECT_EQ(0, count()); + + // one vertex, property set + for (auto va : dba.Vertices(memgraph::storage::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, memgraph::storage::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, one with property set + dba.InsertVertex(); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); + + // two vertices, both with property set + for (auto va : dba.Vertices(memgraph::storage::View::OLD)) + ASSERT_TRUE(va.SetProperty(prop, memgraph::storage::PropertyValue(42)).HasValue()); + dba.AdvanceCommand(); + EXPECT_EQ(1, count()); +} + +TEST(QueryPlan, AggregateFirstValueTypesWithDistinct) { + // testing exceptions that get emitted by the first-value + // type check + + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + + auto v1 = dba.InsertVertex(); + auto prop_string = dba.NameToProperty("string"); + 
ASSERT_TRUE(v1.SetProperty(prop_string, memgraph::storage::PropertyValue("johhny")).HasValue()); + auto prop_int = dba.NameToProperty("int"); + ASSERT_TRUE(v1.SetProperty(prop_int, memgraph::storage::PropertyValue(12)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_prop_string = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_string); + auto n_prop_int = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_int); + auto n_id = n_prop_string->expression_; + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, true); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + aggregate(n_id, Aggregation::Op::COUNT); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_prop_string, Aggregation::Op::COUNT); + aggregate(n_prop_string, Aggregation::Op::MIN); + aggregate(n_prop_string, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_prop_string, Aggregation::Op::SUM), QueryRuntimeException); + + // on ints nothing fails + aggregate(n_prop_int, Aggregation::Op::COUNT); + aggregate(n_prop_int, Aggregation::Op::MIN); + aggregate(n_prop_int, Aggregation::Op::MAX); + aggregate(n_prop_int, Aggregation::Op::AVG); + aggregate(n_prop_int, Aggregation::Op::SUM); + aggregate(n_prop_int, Aggregation::Op::COLLECT_LIST); + aggregate(n_prop_int, 
Aggregation::Op::COLLECT_MAP); +} + +TEST(QueryPlan, AggregateTypesWithDistinct) { + // testing exceptions that can get emitted by an aggregation + // does not check all combinations that can result in an exception + // (that logic is defined and tested by TypedValue) + + memgraph::storage::Storage db; + auto storage_dba = db.Access(); + memgraph::query::DbAccessor dba(&storage_dba); + + auto p1 = dba.NameToProperty("p1"); // has only string props + ASSERT_TRUE(dba.InsertVertex().SetProperty(p1, memgraph::storage::PropertyValue("string")).HasValue()); + ASSERT_TRUE(dba.InsertVertex().SetProperty(p1, memgraph::storage::PropertyValue("str2")).HasValue()); + auto p2 = dba.NameToProperty("p2"); // combines int and bool + ASSERT_TRUE(dba.InsertVertex().SetProperty(p2, memgraph::storage::PropertyValue(42)).HasValue()); + ASSERT_TRUE(dba.InsertVertex().SetProperty(p2, memgraph::storage::PropertyValue(true)).HasValue()); + dba.AdvanceCommand(); + + AstStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1); + auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2); + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, true); + auto context = MakeContext(storage, symbol_table, &dba); + CollectProduce(*produce, &context); + }; + + // everything except for COUNT and COLLECT fails on a Vertex + auto n_id = n_p1->expression_; + aggregate(n_id, Aggregation::Op::COUNT); + aggregate(n_id, Aggregation::Op::COLLECT_LIST); + aggregate(n_id, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), 
QueryRuntimeException); + + // on strings AVG and SUM fail + aggregate(n_p1, Aggregation::Op::COUNT); + aggregate(n_p1, Aggregation::Op::COLLECT_LIST); + aggregate(n_p1, Aggregation::Op::COLLECT_MAP); + aggregate(n_p1, Aggregation::Op::MIN); + aggregate(n_p1, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::SUM), QueryRuntimeException); + + // combination of int and bool, everything except COUNT and COLLECT fails + aggregate(n_p2, Aggregation::Op::COUNT); + aggregate(n_p2, Aggregation::Op::COLLECT_LIST); + aggregate(n_p2, Aggregation::Op::COLLECT_MAP); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MIN), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MAX), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), QueryRuntimeException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), QueryRuntimeException); +} diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 335b6ab2b..1577187fc 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -90,7 +90,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { } PRE_VISIT(Unwind); PRE_VISIT(Distinct); - + bool PreVisit(Foreach &op) override { CheckOp(op); return false; @@ -216,6 +216,7 @@ class ExpectAggregate : public OpChecker { EXPECT_EQ(typeid(aggr_elem.value).hash_code(), typeid(aggr->expression1_).hash_code()); EXPECT_EQ(typeid(aggr_elem.key).hash_code(), typeid(aggr->expression2_).hash_code()); EXPECT_EQ(aggr_elem.op, aggr->op_); + EXPECT_EQ(aggr_elem.distinct, aggr->distinct_); EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); } EXPECT_EQ(aggr_it, aggregations_.end()); diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 45e60811a..89e3de946 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -253,7 +253,7 @@ 
TEST_F(TestSymbolGenerator, MatchWithWhere) { TEST_F(TestSymbolGenerator, MatchWithWhereUnbound) { // Test MATCH (old) WITH COUNT(old) AS c WHERE old.prop < 42 auto prop = dba.NameToProperty("prop"); - auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old")), AS("c")), + auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old"), false), AS("c")), WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42))))); EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError); } @@ -313,7 +313,7 @@ TEST_F(TestSymbolGenerator, MatchReturnSum) { // Test MATCH (n) RETURN SUM(n.prop) + 42 AS result auto prop = dba.NameToProperty("prop"); auto node = NODE("n"); - auto sum = SUM(PROPERTY_LOOKUP("n", prop)); + auto sum = SUM(PROPERTY_LOOKUP("n", prop), false); auto as_result = AS("result"); auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(ADD(sum, LITERAL(42)), as_result))); auto symbol_table = memgraph::query::MakeSymbolTable(query); @@ -330,8 +330,9 @@ TEST_F(TestSymbolGenerator, MatchReturnSum) { TEST_F(TestSymbolGenerator, NestedAggregation) { // Test MATCH (n) RETURN SUM(42 + SUM(n.prop)) AS s auto prop = dba.NameToProperty("prop"); - auto query = QUERY( - SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop)))), AS("s")))); + auto query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop), false)), false), AS("s")))); EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException); } @@ -339,7 +340,7 @@ TEST_F(TestSymbolGenerator, WrongAggregationContext) { // Test MATCH (n) WITH n.prop AS prop WHERE SUM(prop) < 42 auto prop = dba.NameToProperty("prop"); auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(PROPERTY_LOOKUP("n", prop), AS("prop")), - WHERE(LESS(SUM(IDENT("prop")), LITERAL(42))))); + WHERE(LESS(SUM(IDENT("prop"), false), LITERAL(42))))); 
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException); } @@ -429,14 +430,15 @@ TEST_F(TestSymbolGenerator, LimitUsingIdentifier) { TEST_F(TestSymbolGenerator, OrderByAggregation) { // Test MATCH (old) RETURN old AS new ORDER BY COUNT(1) - auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN("old", AS("new"), ORDER_BY(COUNT(LITERAL(1)))))); + auto query = + QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN("old", AS("new"), ORDER_BY(COUNT(LITERAL(1), false))))); EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException); } TEST_F(TestSymbolGenerator, OrderByUnboundVariable) { // Test MATCH (old) RETURN COUNT(old) AS new ORDER BY old - auto query = - QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN(COUNT(IDENT("old")), AS("new"), ORDER_BY(IDENT("old"))))); + auto query = QUERY( + SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN(COUNT(IDENT("old"), false), AS("new"), ORDER_BY(IDENT("old"))))); EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError); } @@ -446,7 +448,7 @@ TEST_F(TestSymbolGenerator, AggregationOrderBy) { auto ident_old = IDENT("old"); auto as_new = AS("new"); auto ident_new = IDENT("new"); - auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(COUNT(ident_old), as_new, ORDER_BY(ident_new)))); + auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(COUNT(ident_old, false), as_new, ORDER_BY(ident_new)))); auto symbol_table = memgraph::query::MakeSymbolTable(query); // Symbols for pattern, `old`, `count(old)` and `new` EXPECT_EQ(symbol_table.max_position(), 4);