Add aggregation distinct (#654) (#665)

This commit is contained in:
niko4299 2022-12-03 12:48:44 +01:00 committed by GitHub
parent 6e4047a847
commit 3e11f38548
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 644 additions and 86 deletions

View File

@ -461,7 +461,8 @@ cpp<#
(lcp:define-class aggregation (binary-operator)
((op "Op" :scope :public)
(symbol-pos :int32_t :initval -1 :scope :public
:documentation "Symbol table position of the symbol this Aggregation is mapped to."))
:documentation "Symbol table position of the symbol this Aggregation is mapped to.")
(distinct :bool :initval "false" :scope :public))
(:public
(lcp:define-enum op
(count min max sum avg collect-list collect-map project)
@ -505,8 +506,8 @@ cpp<#
/// Aggregation's first expression is the value being aggregated. The second
/// expression is the key used only in COLLECT_MAP.
Aggregation(Expression *expression1, Expression *expression2, Op op)
: BinaryOperator(expression1, expression2), op_(op) {
Aggregation(Expression *expression1, Expression *expression2, Op op, bool distinct)
: BinaryOperator(expression1, expression2), op_(op), distinct_(distinct) {
// COUNT without expression denotes COUNT(*) in cypher.
DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT,
"All aggregations, except COUNT require expression");

View File

@ -2106,7 +2106,7 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
// Here we handle COUNT(*). COUNT(expression) is handled in
// visitFunctionInvocation with other aggregations. This is visible in
// functionInvocation and atom productions in opencypher grammar.
return static_cast<Expression *>(storage_->Create<Aggregation>(nullptr, nullptr, Aggregation::Op::COUNT));
return static_cast<Expression *>(storage_->Create<Aggregation>(nullptr, nullptr, Aggregation::Op::COUNT, false));
} else if (ctx->ALL()) {
auto *ident = storage_->Create<Identifier>(
std::any_cast<std::string>(ctx->filterExpression()->idInColl()->variable()->accept(this)));
@ -2222,9 +2222,7 @@ antlrcpp::Any CypherMainVisitor::visitNumberLiteral(MemgraphCypher::NumberLitera
}
antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) {
if (ctx->DISTINCT()) {
throw utils::NotYetImplemented("DISTINCT function call");
}
const auto is_distinct = ctx->DISTINCT() != nullptr;
auto function_name = std::any_cast<std::string>(ctx->functionName()->accept(this));
std::vector<Expression *> expressions;
for (auto *expression : ctx->expression()) {
@ -2232,33 +2230,38 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
}
if (expressions.size() == 1U) {
if (function_name == Aggregation::kCount) {
return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COUNT));
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COUNT, is_distinct));
}
if (function_name == Aggregation::kMin) {
return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MIN));
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MIN, is_distinct));
}
if (function_name == Aggregation::kMax) {
return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MAX));
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::MAX, is_distinct));
}
if (function_name == Aggregation::kSum) {
return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::SUM));
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::SUM, is_distinct));
}
if (function_name == Aggregation::kAvg) {
return static_cast<Expression *>(storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::AVG));
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::AVG, is_distinct));
}
if (function_name == Aggregation::kCollect) {
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COLLECT_LIST));
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::COLLECT_LIST, is_distinct));
}
if (function_name == Aggregation::kProject) {
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::PROJECT));
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::PROJECT, is_distinct));
}
}
if (expressions.size() == 2U && function_name == Aggregation::kCollect) {
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP));
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP, is_distinct));
}
auto is_user_defined_function = [](const std::string &function_name) {

View File

@ -41,6 +41,7 @@
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/property_value.hpp"
#include "storage/v2/view.hpp"
#include "utils/algorithm.hpp"
@ -3211,7 +3212,8 @@ class AggregateCursor : public Cursor {
// aggregation map. The vectors in an AggregationValue contain one element for
// each aggregation in this LogicalOp.
struct AggregationValue {
explicit AggregationValue(utils::MemoryResource *mem) : counts_(mem), values_(mem), remember_(mem) {}
explicit AggregationValue(utils::MemoryResource *mem)
: counts_(mem), values_(mem), remember_(mem), unique_values_(mem) {}
// how many input rows have been aggregated in respective values_ element so
// far
@ -3224,6 +3226,10 @@ class AggregateCursor : public Cursor {
utils::pmr::vector<TypedValue> values_;
// remember values.
utils::pmr::vector<TypedValue> remember_;
using TSet = utils::pmr::unordered_set<TypedValue, TypedValue::Hash, TypedValue::BoolEqual>;
utils::pmr::vector<TSet> unique_values_;
};
const Aggregate &self_;
@ -3299,6 +3305,7 @@ class AggregateCursor : public Cursor {
for (const auto &agg_elem : self_.aggregations_) {
auto *mem = agg_value->values_.get_allocator().GetMemoryResource();
agg_value->values_.emplace_back(DefaultAggregationOpValue(agg_elem, mem));
agg_value->unique_values_.emplace_back(AggregationValue::TSet(mem));
}
agg_value->counts_.resize(self_.aggregations_.size(), 0);
@ -3317,8 +3324,9 @@ class AggregateCursor : public Cursor {
auto count_it = agg_value->counts_.begin();
auto value_it = agg_value->values_.begin();
auto unique_values_it = agg_value->unique_values_.begin();
auto agg_elem_it = self_.aggregations_.begin();
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, agg_elem_it++) {
for (; count_it < agg_value->counts_.end(); count_it++, value_it++, unique_values_it++, agg_elem_it++) {
// COUNT(*) is the only case where input expression is optional
// handle it here
auto input_expr_ptr = agg_elem_it->value;
@ -3333,6 +3341,12 @@ class AggregateCursor : public Cursor {
// Aggregations skip Null input values.
if (input_value.IsNull()) continue;
const auto &agg_op = agg_elem_it->op;
// DISTINCT aggregations consume only the first occurrence of each input
// value. A repeated value must skip just this aggregation: the remaining
// aggregations of the same input row still have to be updated, so advance
// to the next element (as the Null check above does) instead of leaving
// the loop.
if (agg_elem_it->distinct) {
  auto insert_result = unique_values_it->insert(input_value);
  if (!insert_result.second) {
    continue;
  }
}
*count_it += 1;
if (*count_it == 1) {
// first value, nothing to aggregate. check type, set and continue.

View File

@ -1657,7 +1657,8 @@ elements are in an undefined state after aggregation.")
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(op "::Aggregation::Op")
(output-sym "Symbol"))
(output-sym "Symbol")
(distinct bool :initval "false" ))
(:documentation
"An aggregation element, contains:
(input data expression, key expression - only used in COLLECT_MAP, type of
@ -2282,9 +2283,9 @@ clauses.
(:public
#>cpp
Foreach() = default;
Foreach(std::shared_ptr<LogicalOperator> input,
Foreach(std::shared_ptr<LogicalOperator> input,
std::shared_ptr<LogicalOperator> updates,
Expression *named_expr,
Expression *named_expr,
Symbol loop_variable_symbol);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;

View File

@ -401,6 +401,8 @@ json ToJson(const Aggregate::Element &elem) {
}
json["op"] = utils::ToLowerCase(Aggregation::OpToString(elem.op));
json["output_symbol"] = ToJson(elem.output_sym);
json["distinct"] = elem.distinct;
return json;
}
////////////////////////// END HELPER FUNCTIONS ////////////////////////////////

View File

@ -344,7 +344,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
bool PostVisit(Aggregation &aggr) override {
// Aggregation contains a virtual symbol, where the result will be stored.
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
aggregations_.emplace_back(
Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol, aggr.distinct_});
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
// two expressions, so we can have 0, 1 or 2 elements on the
// has_aggregation_stack for this Aggregation expression.

View File

@ -181,6 +181,9 @@ class Pokec(Dataset):
def benchmark__arango__aggregate(self):
return ("MATCH (n:User) RETURN n.age, COUNT(*)", {})
def benchmark__arango__aggregate_with_distinct(self):
return ("MATCH (n:User) RETURN COUNT(DISTINCT n.age)", {})
def benchmark__arango__aggregate_with_filter(self):
return ("MATCH (n:User) WHERE n.age >= 18 RETURN n.age, COUNT(*)", {})

View File

@ -643,17 +643,70 @@ TEST_F(PrintToJsonTest, Aggregate) {
{
"value" : "(PropertyLookup (Identifier \"node\") \"value\")",
"op" : "sum",
"output_symbol" : "sum"
"output_symbol" : "sum",
"distinct" : false
},
{
"value" : "(PropertyLookup (Identifier \"node\") \"value\")",
"key" : "(PropertyLookup (Identifier \"node\") \"color\")",
"op" : "collect",
"output_symbol" : "map"
"output_symbol" : "map",
"distinct" : false
},
{
"op": "count",
"output_symbol": "count"
"output_symbol": "count",
"distinct" : false
}
],
"group_by" : [
"(PropertyLookup (Identifier \"node\") \"type\")"
],
"remember" : ["node"],
"input" : {
"name" : "ScanAll",
"output_symbol" : "node",
"input" : { "name" : "Once" }
}
})sep");
}
TEST_F(PrintToJsonTest, AggregateWithDistinct) {
memgraph::storage::PropertyId value = dba.NameToProperty("value");
memgraph::storage::PropertyId color = dba.NameToProperty("color");
memgraph::storage::PropertyId type = dba.NameToProperty("type");
auto node_sym = GetSymbol("node");
std::shared_ptr<LogicalOperator> last_op = std::make_shared<ScanAll>(nullptr, node_sym);
last_op = std::make_shared<plan::Aggregate>(
last_op,
std::vector<Aggregate::Element>{
{PROPERTY_LOOKUP("node", value), nullptr, Aggregation::Op::SUM, GetSymbol("sum"), true},
{PROPERTY_LOOKUP("node", value), PROPERTY_LOOKUP("node", color), Aggregation::Op::COLLECT_MAP,
GetSymbol("map"), true},
{nullptr, nullptr, Aggregation::Op::COUNT, GetSymbol("count"), true}},
std::vector<Expression *>{PROPERTY_LOOKUP("node", type)}, std::vector<Symbol>{node_sym});
Check(last_op.get(), R"sep(
{
"name" : "Aggregate",
"aggregations" : [
{
"value" : "(PropertyLookup (Identifier \"node\") \"value\")",
"op" : "sum",
"output_symbol" : "sum",
"distinct" : true
},
{
"value" : "(PropertyLookup (Identifier \"node\") \"value\")",
"key" : "(PropertyLookup (Identifier \"node\") \"color\")",
"op" : "collect",
"output_symbol" : "map",
"distinct" : true
},
{
"op": "count",
"output_symbol": "count",
"distinct" : true
}
],
"group_by" : [

View File

@ -549,12 +549,15 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
#define LESS_EQ(expr1, expr2) storage.Create<memgraph::query::LessEqualOperator>((expr1), (expr2))
#define GREATER(expr1, expr2) storage.Create<memgraph::query::GreaterOperator>((expr1), (expr2))
#define GREATER_EQ(expr1, expr2) storage.Create<memgraph::query::GreaterEqualOperator>((expr1), (expr2))
#define SUM(expr) storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::SUM)
#define COUNT(expr) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::COUNT)
#define AVG(expr) storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::AVG)
#define COLLECT_LIST(expr) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::COLLECT_LIST)
#define SUM(expr, distinct) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::SUM, (distinct))
#define COUNT(expr, distinct) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::COUNT, (distinct))
#define AVG(expr, distinct) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::AVG, (distinct))
#define COLLECT_LIST(expr, distinct) \
storage.Create<memgraph::query::Aggregation>((expr), nullptr, memgraph::query::Aggregation::Op::COLLECT_LIST, \
(distinct))
#define EQ(expr1, expr2) storage.Create<memgraph::query::EqualOperator>((expr1), (expr2))
#define NEQ(expr1, expr2) storage.Create<memgraph::query::NotEqualOperator>((expr1), (expr2))
#define AND(expr1, expr2) storage.Create<memgraph::query::AndOperator>((expr1), (expr2))

View File

@ -596,7 +596,7 @@ TEST_F(ExpressionEvaluatorTest, LabelsTest) {
}
TEST_F(ExpressionEvaluatorTest, Aggregation) {
auto aggr = storage.Create<Aggregation>(storage.Create<PrimitiveLiteral>(42), nullptr, Aggregation::Op::COUNT);
auto aggr = storage.Create<Aggregation>(storage.Create<PrimitiveLiteral>(42), nullptr, Aggregation::Op::COUNT, false);
auto aggr_sym = symbol_table.CreateSymbol("aggr", true);
aggr->MapTo(aggr_sym);
frame[aggr_sym] = TypedValue(1);

View File

@ -396,7 +396,7 @@ TYPED_TEST(TestPlanner, MatchWithSumWhereReturn) {
FakeDbAccessor dba;
auto prop = dba.Property("prop");
AstStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto sum = SUM(PROPERTY_LOOKUP("n", prop), false);
auto literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")),
WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", AS("result"))));
@ -410,7 +410,7 @@ TYPED_TEST(TestPlanner, MatchReturnSum) {
auto prop1 = dba.Property("prop1");
auto prop2 = dba.Property("prop2");
AstStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop1));
auto sum = SUM(PROPERTY_LOOKUP("n", prop1), false);
auto n_prop2 = PROPERTY_LOOKUP("n", prop2);
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group"))));
auto aggr = ExpectAggregate({sum}, {n_prop2});
@ -426,7 +426,54 @@ TYPED_TEST(TestPlanner, CreateWithSum) {
AstStorage storage;
auto ident_n = IDENT("n");
auto n_prop = PROPERTY_LOOKUP(ident_n, prop);
auto sum = SUM(n_prop);
auto sum = SUM(n_prop, false);
auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum"))));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto aggr = ExpectAggregate({sum}, {});
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
// We expect both the accumulation and aggregation because the part before
// WITH updates the database.
CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce());
}
TYPED_TEST(TestPlanner, MatchWithSumWithDistinctWhereReturn) {
// Test MATCH (n) WITH SUM(DISTINCT n.prop) + 42 AS sum WHERE sum < 42
// RETURN sum AS result
FakeDbAccessor dba;
auto prop = dba.Property("prop");
AstStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop), true);
auto literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")),
WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", AS("result"))));
auto aggr = ExpectAggregate({sum}, {literal});
CheckPlan<TypeParam>(query, storage, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), ExpectProduce());
}
TYPED_TEST(TestPlanner, MatchReturnSumWithDistinct) {
// Test MATCH (n) RETURN SUM(DISTINCT n.prop1) AS sum, n.prop2 AS group
FakeDbAccessor dba;
auto prop1 = dba.Property("prop1");
auto prop2 = dba.Property("prop2");
AstStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop1), true);
auto n_prop2 = PROPERTY_LOOKUP("n", prop2);
auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group"))));
auto aggr = ExpectAggregate({sum}, {n_prop2});
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), aggr, ExpectProduce());
}
TYPED_TEST(TestPlanner, CreateWithSumWithDistinct) {
// Test CREATE (n) WITH SUM(DISTINCT n.prop) AS sum
FakeDbAccessor dba;
auto prop = dba.Property("prop");
AstStorage storage;
auto ident_n = IDENT("n");
auto n_prop = PROPERTY_LOOKUP(ident_n, prop);
auto sum = SUM(n_prop, true);
auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), WITH(sum, AS("sum"))));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
@ -484,7 +531,24 @@ TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) {
AstStorage storage;
auto ident_n = IDENT("n");
auto n_prop = PROPERTY_LOOKUP(ident_n, prop);
auto sum = SUM(n_prop);
auto sum = SUM(n_prop, false);
auto query =
QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto aggr = ExpectAggregate({sum}, {});
auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query);
CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectSkip(), ExpectLimit());
}
TYPED_TEST(TestPlanner, CreateReturnSumWithDistinctSkipLimit) {
// Test CREATE (n) RETURN SUM(DISTINCT n.prop) AS s SKIP 2 LIMIT 1
FakeDbAccessor dba;
auto prop = dba.Property("prop");
AstStorage storage;
auto ident_n = IDENT("n");
auto n_prop = PROPERTY_LOOKUP(ident_n, prop);
auto sum = SUM(n_prop, true);
auto query =
QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n"))), RETURN(sum, AS("s"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
@ -539,8 +603,18 @@ TYPED_TEST(TestPlanner, CreateWithOrderByWhere) {
TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) {
// Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result
AstStorage storage;
auto sum = SUM(LITERAL(1));
auto count = COUNT(LITERAL(2));
auto sum = SUM(LITERAL(1), false);
auto count = COUNT(LITERAL(2), false);
auto *query = QUERY(SINGLE_QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))));
auto aggr = ExpectAggregate({sum, count}, {});
CheckPlan<TypeParam>(query, storage, aggr, ExpectProduce(), ExpectOrderBy());
}
TYPED_TEST(TestPlanner, ReturnAddSumCountWithDistinctOrderBy) {
// Test RETURN SUM(DISTINCT 1) + COUNT(DISTINCT 2) AS result ORDER BY result
AstStorage storage;
auto sum = SUM(LITERAL(1), true);
auto count = COUNT(LITERAL(2), true);
auto *query = QUERY(SINGLE_QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))));
auto aggr = ExpectAggregate({sum, count}, {});
CheckPlan<TypeParam>(query, storage, aggr, ExpectProduce(), ExpectOrderBy());
@ -610,7 +684,7 @@ TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) {
auto prop = dba.Property("prop");
AstStorage storage;
auto node_n = NODE("n");
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto sum = SUM(PROPERTY_LOOKUP("n", prop), false);
auto query = QUERY(SINGLE_QUERY(CREATE(PATTERN(node_n)), WITH_DISTINCT(sum, AS("s")),
WHERE(LESS(IDENT("s"), LITERAL(42))), RETURN("s")));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
@ -749,7 +823,7 @@ TYPED_TEST(TestPlanner, MatchReturnAsteriskSum) {
FakeDbAccessor dba;
auto prop = dba.Property("prop");
AstStorage storage;
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto sum = SUM(PROPERTY_LOOKUP("n", prop), false);
auto ret = RETURN(sum, AS("s"));
ret->body_.all_identifiers = true;
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret));
@ -818,7 +892,7 @@ TYPED_TEST(TestPlanner, MultipleOptionalMatchReturn) {
TYPED_TEST(TestPlanner, FunctionAggregationReturn) {
// Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by
AstStorage storage;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by"))));
auto aggr = ExpectAggregate({sum}, {group_by_literal});
@ -835,7 +909,7 @@ TYPED_TEST(TestPlanner, FunctionWithoutArguments) {
TYPED_TEST(TestPlanner, ListLiteralAggregationReturn) {
// Test RETURN [SUM(2)] AS result, 42 AS group_by
AstStorage storage;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by"))));
auto aggr = ExpectAggregate({sum}, {group_by_literal});
@ -846,7 +920,7 @@ TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) {
// Test RETURN {sum: SUM(2)} AS result, 42 AS group_by
AstStorage storage;
FakeDbAccessor dba;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(
SINGLE_QUERY(RETURN(MAP({storage.GetPropertyIx("sum"), sum}), AS("result"), group_by_literal, AS("group_by"))));
@ -857,7 +931,7 @@ TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) {
TYPED_TEST(TestPlanner, EmptyListIndexAggregation) {
// Test RETURN [][SUM(2)] AS result, 42 AS group_by
AstStorage storage;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto empty_list = LIST();
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(storage.Create<memgraph::query::SubscriptOperator>(empty_list, sum),
@ -872,7 +946,7 @@ TYPED_TEST(TestPlanner, EmptyListIndexAggregation) {
TYPED_TEST(TestPlanner, ListSliceAggregationReturn) {
// Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by
AstStorage storage;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto list = LIST(LITERAL(1), LITERAL(2));
auto group_by_literal = LITERAL(42);
auto *query =
@ -886,7 +960,7 @@ TYPED_TEST(TestPlanner, ListSliceAggregationReturn) {
TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) {
// Test RETURN [sum(2), 42]
AstStorage storage;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(LIST(sum, group_by_literal), AS("result"))));
auto aggr = ExpectAggregate({sum}, {group_by_literal});
@ -896,8 +970,8 @@ TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) {
TYPED_TEST(TestPlanner, AggregatonWithListWithAggregationAndGroupBy) {
// Test RETURN sum(2), [sum(3), 42]
AstStorage storage;
auto sum2 = SUM(LITERAL(2));
auto sum3 = SUM(LITERAL(3));
auto sum2 = SUM(LITERAL(2), false);
auto sum3 = SUM(LITERAL(3), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(sum2, AS("sum2"), LIST(sum3, group_by_literal), AS("list"))));
auto aggr = ExpectAggregate({sum2, sum3}, {group_by_literal});
@ -908,7 +982,7 @@ TYPED_TEST(TestPlanner, MapWithAggregationAndGroupBy) {
// Test RETURN {lit: 42, sum: sum(2)}
AstStorage storage;
FakeDbAccessor dba;
auto sum = SUM(LITERAL(2));
auto sum = SUM(LITERAL(2), false);
auto group_by_literal = LITERAL(42);
auto *query = QUERY(SINGLE_QUERY(RETURN(
MAP({storage.GetPropertyIx("sum"), sum}, {storage.GetPropertyIx("lit"), group_by_literal}), AS("result"))));
@ -1121,7 +1195,7 @@ TYPED_TEST(TestPlanner, SecondPropertyIndex) {
TYPED_TEST(TestPlanner, ReturnSumGroupByAll) {
// Test RETURN sum([1,2,3]), all(x in [1] where x = 1)
AstStorage storage;
auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3)));
auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), false);
auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1))));
auto *query = QUERY(SINGLE_QUERY(RETURN(sum, AS("sum"), all, AS("all"))));
auto aggr = ExpectAggregate({sum}, {all});

View File

@ -109,7 +109,7 @@ std::shared_ptr<Produce> MakeAggregationProduce(std::shared_ptr<LogicalOperator>
AstStorage &storage, const std::vector<Expression *> aggr_inputs,
const std::vector<Aggregation::Op> aggr_ops,
const std::vector<Expression *> group_by_exprs,
const std::vector<Symbol> remember) {
const std::vector<Symbol> remember, const bool distinct) {
// prepare all the aggregations
std::vector<Aggregate::Element> aggregates;
std::vector<NamedExpression *> named_expressions;
@ -124,7 +124,7 @@ std::shared_ptr<Produce> MakeAggregationProduce(std::shared_ptr<LogicalOperator>
named_expressions.push_back(named_expr);
// the key expression is only used in COLLECT_MAP
Expression *key_expr_ptr = aggr_op == Aggregation::Op::COLLECT_MAP ? LITERAL("key") : nullptr;
aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym});
aggregates.emplace_back(Aggregate::Element{*aggr_inputs_it++, key_expr_ptr, aggr_op, aggr_sym, distinct});
}
// Produce will also evaluate group_by expressions and return them after the
@ -155,16 +155,21 @@ class QueryPlanAggregateOps : public ::testing::Test {
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue());
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(7)).HasValue());
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(12)).HasValue());
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue());
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(5)).HasValue());
ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, memgraph::storage::PropertyValue(12)).HasValue());
// a missing property (null) gets ignored by all aggregations except
// COUNT(*)
dba.InsertVertex();
dba.AdvanceCommand();
}
auto AggregationResults(bool with_group_by, std::vector<Aggregation::Op> ops = {
Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG,
Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) {
auto AggregationResults(bool with_group_by, bool distinct,
std::vector<Aggregation::Op> ops = {
Aggregation::Op::COUNT, Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG,
Aggregation::Op::COLLECT_LIST, Aggregation::Op::COLLECT_MAP}) {
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop);
@ -173,7 +178,8 @@ class QueryPlanAggregateOps : public ::testing::Test {
std::vector<Expression *> group_bys;
if (with_group_by) group_bys.push_back(n_p);
aggregation_expressions[0] = nullptr;
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {});
auto produce =
MakeAggregationProduce(n.op_, symbol_table, storage, aggregation_expressions, ops, group_bys, {}, distinct);
auto context = MakeContext(storage, symbol_table, &dba);
return CollectProduce(*produce, &context);
}
@ -181,16 +187,16 @@ class QueryPlanAggregateOps : public ::testing::Test {
TEST_F(QueryPlanAggregateOps, WithData) {
AddData();
auto results = AggregationResults(false);
auto results = AggregationResults(false, false);
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0].size(), 8);
// count(*)
ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][0].ValueInt(), 4);
EXPECT_EQ(results[0][0].ValueInt(), 7);
// count
ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][1].ValueInt(), 3);
EXPECT_EQ(results[0][1].ValueInt(), 6);
// min
ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][2].ValueInt(), 5);
@ -199,13 +205,13 @@ TEST_F(QueryPlanAggregateOps, WithData) {
EXPECT_EQ(results[0][3].ValueInt(), 12);
// sum
ASSERT_EQ(results[0][4].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][4].ValueInt(), 24);
EXPECT_EQ(results[0][4].ValueInt(), 46);
// avg
ASSERT_EQ(results[0][5].type(), TypedValue::Type::Double);
EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 24 / 3.0);
EXPECT_FLOAT_EQ(results[0][5].ValueDouble(), 46 / 6.0);
// collect list
ASSERT_EQ(results[0][6].type(), TypedValue::Type::List);
EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12));
EXPECT_THAT(ToIntList(results[0][6]), UnorderedElementsAre(5, 7, 12, 5, 5, 12));
// collect map
ASSERT_EQ(results[0][7].type(), TypedValue::Type::Map);
auto map = ToIntMap(results[0][7]);
@ -216,46 +222,46 @@ TEST_F(QueryPlanAggregateOps, WithData) {
TEST_F(QueryPlanAggregateOps, WithoutDataWithGroupBy) {
{
auto results = AggregationResults(true, {Aggregation::Op::COUNT});
auto results = AggregationResults(true, false, {Aggregation::Op::COUNT});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][0].ValueInt(), 0);
}
{
auto results = AggregationResults(true, {Aggregation::Op::SUM});
auto results = AggregationResults(true, false, {Aggregation::Op::SUM});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][0].ValueInt(), 0);
}
{
auto results = AggregationResults(true, {Aggregation::Op::AVG});
auto results = AggregationResults(true, false, {Aggregation::Op::AVG});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
}
{
auto results = AggregationResults(true, {Aggregation::Op::MIN});
auto results = AggregationResults(true, false, {Aggregation::Op::MIN});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
}
{
auto results = AggregationResults(true, {Aggregation::Op::MAX});
auto results = AggregationResults(true, false, {Aggregation::Op::MAX});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
}
{
auto results = AggregationResults(true, {Aggregation::Op::COLLECT_LIST});
auto results = AggregationResults(true, false, {Aggregation::Op::COLLECT_LIST});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::List);
}
{
auto results = AggregationResults(true, {Aggregation::Op::COLLECT_MAP});
auto results = AggregationResults(true, false, {Aggregation::Op::COLLECT_MAP});
EXPECT_EQ(results.size(), 1);
EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map);
}
}
TEST_F(QueryPlanAggregateOps, WithoutDataWithoutGroupBy) {
auto results = AggregationResults(false);
auto results = AggregationResults(false, false);
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0].size(), 8);
// count(*)
@ -325,7 +331,8 @@ TEST(QueryPlan, AggregateGroupByValues) {
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop);
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_});
auto produce =
MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}, false);
auto context = MakeContext(storage, symbol_table, &dba);
auto results = CollectProduce(*produce, &context);
@ -372,7 +379,7 @@ TEST(QueryPlan, AggregateMultipleGroupBy) {
auto n_p3 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop3);
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT},
{n_p1, n_p2, n_p3}, {n.sym_});
{n_p1, n_p2, n_p3}, {n.sym_}, false);
auto context = MakeContext(storage, symbol_table, &dba);
auto results = CollectProduce(*produce, &context);
@ -387,7 +394,7 @@ TEST(QueryPlan, AggregateNoInput) {
SymbolTable symbol_table;
auto two = LITERAL(2);
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {});
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}, false);
auto context = MakeContext(storage, symbol_table, &dba);
auto results = CollectProduce(*produce, &context);
EXPECT_EQ(1, results.size());
@ -419,7 +426,7 @@ TEST(QueryPlan, AggregateCountEdgeCases) {
// returns -1 when there are no results
// otherwise returns MATCH (n) RETURN count(n.prop)
auto count = [&]() {
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {});
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}, false);
auto context = MakeContext(storage, symbol_table, &dba);
auto results = CollectProduce(*produce, &context);
if (results.size() == 0) return -1L;
@ -479,7 +486,7 @@ TEST(QueryPlan, AggregateFirstValueTypes) {
auto n_id = n_prop_string->expression_;
auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) {
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {});
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, false);
auto context = MakeContext(storage, symbol_table, &dba);
CollectProduce(*produce, &context);
};
@ -533,7 +540,7 @@ TEST(QueryPlan, AggregateTypes) {
auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2);
auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) {
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {});
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {expression}, {aggr_op}, {}, {}, false);
auto context = MakeContext(storage, symbol_table, &dba);
CollectProduce(*produce, &context);
};
@ -609,3 +616,396 @@ TEST(QueryPlan, Unwind) {
expected_y_it++;
}
}
// Runs all aggregations over the fixture data inserted by AddData() and checks each output column.
// NOTE(review): the arguments (false, true) are presumably "no group by" and "distinct" — confirm against
// AggregationResults. The expected numbers below are tied to whatever AddData() inserts.
TEST_F(QueryPlanAggregateOps, WithDataDistinct) {
  AddData();
  auto results = AggregationResults(false, true);
  ASSERT_EQ(results.size(), 1);

  // Single output row; one column per aggregation.
  auto &row = results[0];
  ASSERT_EQ(row.size(), 8);

  // column 0: count(*)
  ASSERT_EQ(row[0].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[0].ValueInt(), 7);

  // column 1: count
  ASSERT_EQ(row[1].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[1].ValueInt(), 3);

  // column 2: min
  ASSERT_EQ(row[2].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[2].ValueInt(), 5);

  // column 3: max
  ASSERT_EQ(row[3].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[3].ValueInt(), 12);

  // column 4: sum
  ASSERT_EQ(row[4].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[4].ValueInt(), 24);

  // column 5: avg
  ASSERT_EQ(row[5].type(), TypedValue::Type::Double);
  EXPECT_FLOAT_EQ(row[5].ValueDouble(), 24 / 3.0);

  // column 6: collect list
  ASSERT_EQ(row[6].type(), TypedValue::Type::List);
  EXPECT_THAT(ToIntList(row[6]), UnorderedElementsAre(5, 7, 12));

  // column 7: collect map; one entry under "key" whose value is any one of the collected values
  ASSERT_EQ(row[7].type(), TypedValue::Type::Map);
  auto collected = ToIntMap(row[7]);
  ASSERT_EQ(collected.size(), 1);
  EXPECT_EQ(collected.begin()->first, "key");
  const std::set<int> allowed_values{5, 7, 12};
  EXPECT_TRUE(allowed_values.count(collected.begin()->second) != 0);
}
// Runs each aggregation op on its own, without calling AddData(), and checks the type of the single
// produced value (plus the zero value for COUNT and SUM).
// NOTE(review): the two leading `true` arguments are presumably "with group by" and "distinct" — confirm
// against AggregationResults.
//
// The row count is checked with ASSERT_EQ rather than EXPECT_EQ because results[0][0] is indexed right
// after it; a non-fatal check would let a wrong row count turn into an out-of-bounds access.
TEST_F(QueryPlanAggregateOps, WithoutDataWithDistinctAndWithGroupBy) {
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::COUNT});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int);
    EXPECT_EQ(results[0][0].ValueInt(), 0);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::SUM});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Int);
    EXPECT_EQ(results[0][0].ValueInt(), 0);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::AVG});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::MIN});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::MAX});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Null);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::COLLECT_LIST});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::List);
  }
  {
    auto results = AggregationResults(true, true, {Aggregation::Op::COLLECT_MAP});
    ASSERT_EQ(results.size(), 1);
    EXPECT_EQ(results[0][0].type(), TypedValue::Type::Map);
  }
}
// Runs all aggregations without calling AddData() and checks the value each one produces in that case.
// NOTE(review): the arguments (false, true) are presumably "no group by" and "distinct" — confirm against
// AggregationResults.
TEST_F(QueryPlanAggregateOps, WithoutDataWithDistinctAndWithoutGroupBy) {
  auto results = AggregationResults(false, true);
  ASSERT_EQ(results.size(), 1);

  // Single output row; one column per aggregation.
  auto &row = results[0];
  ASSERT_EQ(row.size(), 8);

  // column 0: count(*)
  ASSERT_EQ(row[0].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[0].ValueInt(), 0);

  // column 1: count
  ASSERT_EQ(row[1].type(), TypedValue::Type::Int);
  EXPECT_EQ(row[1].ValueInt(), 0);

  // column 2: min
  EXPECT_TRUE(row[2].IsNull());

  // column 3: max
  EXPECT_TRUE(row[3].IsNull());

  // column 4: sum
  EXPECT_EQ(row[4].ValueInt(), 0);

  // column 5: avg
  EXPECT_TRUE(row[5].IsNull());

  // column 6: collect list
  ASSERT_EQ(row[6].type(), TypedValue::Type::List);
  auto collected_list = ToIntList(row[6]);
  EXPECT_EQ(collected_list.size(), 0);

  // column 7: collect map
  ASSERT_EQ(row[7].type(), TypedValue::Type::Map);
  auto collected_map = ToIntMap(row[7]);
  EXPECT_EQ(collected_map.size(), 0);
}
TEST(QueryPlan, AggregateGroupByValuesWithDistinct) {
  // Groups vertices by a property holding values of many different types and checks that the expected
  // number of groups comes out. The group-by value is read back from the second output column, so the
  // "remember" part of the aggregation API is exercised as well.
  using PropVal = memgraph::storage::PropertyValue;

  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);

  // The last kDuplicateCount entries compare equal to earlier ones and must not form groups of their own.
  constexpr int kDuplicateCount = 2;

  // Property values to spread over the vertices; every entry forms its own group unless noted otherwise.
  std::vector<PropVal> group_keys;
  group_keys.emplace_back(4);
  group_keys.emplace_back(7);
  group_keys.emplace_back(7.3);
  group_keys.emplace_back(7.2);
  group_keys.emplace_back("Johhny");
  group_keys.emplace_back("Jane");
  group_keys.emplace_back("1");
  group_keys.emplace_back(true);
  group_keys.emplace_back(false);
  group_keys.emplace_back(std::vector<PropVal>{PropVal(1)});
  group_keys.emplace_back(std::vector<PropVal>{PropVal(1), PropVal(2)});
  group_keys.emplace_back(std::vector<PropVal>{PropVal(2), PropVal(1)});
  group_keys.emplace_back(PropVal());
  // no new group: 7.0 == 7
  group_keys.emplace_back(7.0);
  // no new group: equal to the [1, 2] list above
  group_keys.emplace_back(std::vector<PropVal>{PropVal(1), PropVal(2.0)});

  const auto expected_groups = group_keys.size() - kDuplicateCount;

  // Create many vertices, cycling through the property values.
  auto prop = dba.NameToProperty("prop");
  for (int i = 0; i < 1000; ++i) {
    ASSERT_TRUE(dba.InsertVertex().SetProperty(prop, group_keys[i % group_keys.size()]).HasValue());
  }
  dba.AdvanceCommand();

  AstStorage storage;
  SymbolTable symbol_table;

  // Scan all vertices, COUNT n.prop with the distinct flag set, grouped by n.prop.
  auto n = MakeScanAll(storage, symbol_table, "n");
  auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop);
  auto produce =
      MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {n_p}, {n.sym_}, true);
  auto context = MakeContext(storage, symbol_table, &dba);
  auto results = CollectProduce(*produce, &context);
  ASSERT_EQ(results.size(), expected_groups);

  // Each row is (count, group-by value); the count is expected to be 1 whenever the group value is not null.
  std::unordered_set<TypedValue, TypedValue::Hash, TypedValue::BoolEqual> seen_keys;
  for (const auto &row : results) {
    ASSERT_EQ(2, row.size());
    if (!row[1].IsNull()) {
      ASSERT_EQ(1, row[0].ValueInt());
    }
    seen_keys.insert(row[1]);
  }
  ASSERT_EQ(seen_keys.size(), expected_groups);

  // The collected group keys must match the non-duplicate input values, in any order.
  std::vector<TypedValue> expected_keys;
  expected_keys.reserve(group_keys.size());
  for (const auto &key : group_keys) {
    expected_keys.emplace_back(key);
  }
  EXPECT_TRUE(std::is_permutation(expected_keys.begin(), expected_keys.end() - kDuplicateCount, seen_keys.begin(),
                                  TypedValue::BoolEqual{}));
}
TEST(QueryPlan, AggregateMultipleGroupByWithDistinct) {
  // Vertices get three properties whose values cycle with periods 2, 3 and 5. Only prop1 and prop2 are
  // used as group-by keys; prop3 is set but not referenced by the plan. The test expects COUNT over prop1
  // with the distinct flag set to be 1 in every group.
  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);

  auto prop1 = dba.NameToProperty("prop1");
  auto prop2 = dba.NameToProperty("prop2");
  auto prop3 = dba.NameToProperty("prop3");
  for (int i = 0; i < 2 * 3 * 5; ++i) {
    auto v = dba.InsertVertex();
    ASSERT_TRUE(v.SetProperty(prop1, memgraph::storage::PropertyValue(static_cast<bool>(i % 2))).HasValue());
    ASSERT_TRUE(v.SetProperty(prop2, memgraph::storage::PropertyValue(i % 3)).HasValue());
    ASSERT_TRUE(v.SetProperty(prop3, memgraph::storage::PropertyValue("value" + std::to_string(i % 5))).HasValue());
  }
  dba.AdvanceCommand();

  AstStorage storage;
  SymbolTable symbol_table;

  // match all nodes and perform aggregations
  auto n = MakeScanAll(storage, symbol_table, "n");
  auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop1);
  auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop2);
  auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, {Aggregation::Op::COUNT}, {n_p1, n_p2},
                                        {n.sym_}, true);
  auto context = MakeContext(storage, symbol_table, &dba);
  auto results = CollectProduce(*produce, &context);

  // prop1 takes 2 values and prop2 takes 3, and over 30 consecutive i every (i % 2, i % 3) pair occurs,
  // so one row per pair is expected. Without this check the loop below would pass on an empty result.
  ASSERT_EQ(results.size(), 2 * 3);
  for (const auto &row : results) {
    ASSERT_EQ(1, row[0].ValueInt());
  }
}
TEST(QueryPlan, AggregateNoInputWithDistinct) {
  // Builds a COUNT over the literal 2 with the distinct flag set, passing nullptr as the input operator,
  // and expects a single row holding the integer 1.
  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);

  AstStorage storage;
  SymbolTable symbol_table;

  auto two = LITERAL(2);
  auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two}, {Aggregation::Op::COUNT}, {}, {}, true);
  auto context = MakeContext(storage, symbol_table, &dba);
  auto results = CollectProduce(*produce, &context);

  // Fatal checks: results[0][0] is indexed below, so a wrong shape must stop the test instead of
  // continuing into an out-of-bounds access.
  ASSERT_EQ(1, results.size());
  ASSERT_EQ(1, results[0].size());
  // Fatal as well, so ValueInt() is only called on a value already known to be an Int.
  ASSERT_EQ(TypedValue::Type::Int, results[0][0].type());
  EXPECT_EQ(1, results[0][0].ValueInt());
}
TEST(QueryPlan, AggregateCountEdgeCasesWithDistinct) {
  // Regression coverage for COUNT with the distinct flag set. The count is checked for:
  //  - an empty database
  //  - one vertex without the property
  //  - one vertex with the property
  //  - two vertices, property set on one
  //  - two vertices, property set on both (to the same value)
  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);
  auto prop = dba.NameToProperty("prop");

  AstStorage storage;
  SymbolTable symbol_table;

  auto n = MakeScanAll(storage, symbol_table, "n");
  auto n_p = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop);

  // Evaluates MATCH (n) RETURN count(n.prop) with the distinct flag set.
  // Yields -1 when the plan produces no rows at all.
  auto count_distinct = [&]() {
    auto plan = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, {Aggregation::Op::COUNT}, {}, {}, true);
    auto ctx = MakeContext(storage, symbol_table, &dba);
    auto rows = CollectProduce(*plan, &ctx);
    if (rows.empty()) return -1L;
    EXPECT_EQ(1, rows.size());
    EXPECT_EQ(1, rows[0].size());
    EXPECT_EQ(TypedValue::Type::Int, rows[0][0].type());
    return rows[0][0].ValueInt();
  };

  // empty database
  EXPECT_EQ(0, count_distinct());

  // one vertex, property not set
  dba.InsertVertex();
  dba.AdvanceCommand();
  EXPECT_EQ(0, count_distinct());

  // one vertex, property set
  for (auto vertex : dba.Vertices(memgraph::storage::View::OLD)) {
    ASSERT_TRUE(vertex.SetProperty(prop, memgraph::storage::PropertyValue(42)).HasValue());
  }
  dba.AdvanceCommand();
  EXPECT_EQ(1, count_distinct());

  // two vertices, property set on one of them
  dba.InsertVertex();
  dba.AdvanceCommand();
  EXPECT_EQ(1, count_distinct());

  // two vertices, both holding the same property value, so the expected count stays 1
  for (auto vertex : dba.Vertices(memgraph::storage::View::OLD)) {
    ASSERT_TRUE(vertex.SetProperty(prop, memgraph::storage::PropertyValue(42)).HasValue());
  }
  dba.AdvanceCommand();
  EXPECT_EQ(1, count_distinct());
}
TEST(QueryPlan, AggregateFirstValueTypesWithDistinct) {
  // Exercises the exceptions raised by the first-value type check when the distinct flag is set,
  // using a single vertex that carries one string and one int property.
  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);

  auto v1 = dba.InsertVertex();
  auto prop_string = dba.NameToProperty("string");
  ASSERT_TRUE(v1.SetProperty(prop_string, memgraph::storage::PropertyValue("johhny")).HasValue());
  auto prop_int = dba.NameToProperty("int");
  ASSERT_TRUE(v1.SetProperty(prop_int, memgraph::storage::PropertyValue(12)).HasValue());
  dba.AdvanceCommand();

  AstStorage storage;
  SymbolTable symbol_table;

  auto n = MakeScanAll(storage, symbol_table, "n");
  auto n_prop_string = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_string);
  auto n_prop_int = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), prop_int);
  // The expression the string property lookup is applied to, i.e. the identifier `n` itself.
  auto n_id = n_prop_string->expression_;

  // Builds and fully pulls a plan aggregating `expr` with `op`, distinct flag set.
  auto run_aggregation = [&](Expression *expr, Aggregation::Op op) {
    auto plan = MakeAggregationProduce(n.op_, symbol_table, storage, {expr}, {op}, {}, {}, true);
    auto ctx = MakeContext(storage, symbol_table, &dba);
    CollectProduce(*plan, &ctx);
  };

  // on a Vertex: COUNT works; MIN, MAX, AVG and SUM throw
  run_aggregation(n_id, Aggregation::Op::COUNT);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::MIN), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::MAX), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::AVG), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::SUM), QueryRuntimeException);

  // on a string: COUNT, MIN and MAX work; AVG and SUM throw
  run_aggregation(n_prop_string, Aggregation::Op::COUNT);
  run_aggregation(n_prop_string, Aggregation::Op::MIN);
  run_aggregation(n_prop_string, Aggregation::Op::MAX);
  EXPECT_THROW(run_aggregation(n_prop_string, Aggregation::Op::AVG), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_prop_string, Aggregation::Op::SUM), QueryRuntimeException);

  // on an int: every op works
  run_aggregation(n_prop_int, Aggregation::Op::COUNT);
  run_aggregation(n_prop_int, Aggregation::Op::MIN);
  run_aggregation(n_prop_int, Aggregation::Op::MAX);
  run_aggregation(n_prop_int, Aggregation::Op::AVG);
  run_aggregation(n_prop_int, Aggregation::Op::SUM);
  run_aggregation(n_prop_int, Aggregation::Op::COLLECT_LIST);
  run_aggregation(n_prop_int, Aggregation::Op::COLLECT_MAP);
}
TEST(QueryPlan, AggregateTypesWithDistinct) {
  // Exercises exceptions an aggregation can raise when the distinct flag is set. Not every failing
  // combination is covered here (that logic is defined and tested by TypedValue).
  memgraph::storage::Storage db;
  auto storage_dba = db.Access();
  memgraph::query::DbAccessor dba(&storage_dba);

  // p1 only ever holds strings
  auto p1 = dba.NameToProperty("p1");
  ASSERT_TRUE(dba.InsertVertex().SetProperty(p1, memgraph::storage::PropertyValue("string")).HasValue());
  ASSERT_TRUE(dba.InsertVertex().SetProperty(p1, memgraph::storage::PropertyValue("str2")).HasValue());

  // p2 mixes an int and a bool
  auto p2 = dba.NameToProperty("p2");
  ASSERT_TRUE(dba.InsertVertex().SetProperty(p2, memgraph::storage::PropertyValue(42)).HasValue());
  ASSERT_TRUE(dba.InsertVertex().SetProperty(p2, memgraph::storage::PropertyValue(true)).HasValue());
  dba.AdvanceCommand();

  AstStorage storage;
  SymbolTable symbol_table;

  auto n = MakeScanAll(storage, symbol_table, "n");
  auto n_p1 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p1);
  auto n_p2 = PROPERTY_LOOKUP(IDENT("n")->MapTo(n.sym_), p2);

  // Builds and fully pulls a plan aggregating `expr` with `op`, distinct flag set.
  auto run_aggregation = [&](Expression *expr, Aggregation::Op op) {
    auto plan = MakeAggregationProduce(n.op_, symbol_table, storage, {expr}, {op}, {}, {}, true);
    auto ctx = MakeContext(storage, symbol_table, &dba);
    CollectProduce(*plan, &ctx);
  };

  // on a Vertex: COUNT and the COLLECT ops work; MIN, MAX, AVG and SUM throw
  auto n_id = n_p1->expression_;
  run_aggregation(n_id, Aggregation::Op::COUNT);
  run_aggregation(n_id, Aggregation::Op::COLLECT_LIST);
  run_aggregation(n_id, Aggregation::Op::COLLECT_MAP);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::MIN), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::MAX), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::AVG), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_id, Aggregation::Op::SUM), QueryRuntimeException);

  // on strings: only AVG and SUM throw
  run_aggregation(n_p1, Aggregation::Op::COUNT);
  run_aggregation(n_p1, Aggregation::Op::COLLECT_LIST);
  run_aggregation(n_p1, Aggregation::Op::COLLECT_MAP);
  run_aggregation(n_p1, Aggregation::Op::MIN);
  run_aggregation(n_p1, Aggregation::Op::MAX);
  EXPECT_THROW(run_aggregation(n_p1, Aggregation::Op::AVG), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_p1, Aggregation::Op::SUM), QueryRuntimeException);

  // on a mix of int and bool: COUNT and the COLLECT ops work; everything else throws
  run_aggregation(n_p2, Aggregation::Op::COUNT);
  run_aggregation(n_p2, Aggregation::Op::COLLECT_LIST);
  run_aggregation(n_p2, Aggregation::Op::COLLECT_MAP);
  EXPECT_THROW(run_aggregation(n_p2, Aggregation::Op::MIN), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_p2, Aggregation::Op::MAX), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_p2, Aggregation::Op::AVG), QueryRuntimeException);
  EXPECT_THROW(run_aggregation(n_p2, Aggregation::Op::SUM), QueryRuntimeException);
}

View File

@ -90,7 +90,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor {
}
PRE_VISIT(Unwind);
PRE_VISIT(Distinct);
bool PreVisit(Foreach &op) override {
CheckOp(op);
return false;
@ -216,6 +216,7 @@ class ExpectAggregate : public OpChecker<Aggregate> {
EXPECT_EQ(typeid(aggr_elem.value).hash_code(), typeid(aggr->expression1_).hash_code());
EXPECT_EQ(typeid(aggr_elem.key).hash_code(), typeid(aggr->expression2_).hash_code());
EXPECT_EQ(aggr_elem.op, aggr->op_);
EXPECT_EQ(aggr_elem.distinct, aggr->distinct_);
EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr));
}
EXPECT_EQ(aggr_it, aggregations_.end());

View File

@ -253,7 +253,7 @@ TEST_F(TestSymbolGenerator, MatchWithWhere) {
TEST_F(TestSymbolGenerator, MatchWithWhereUnbound) {
// Test MATCH (old) WITH COUNT(old) AS c WHERE old.prop < 42
auto prop = dba.NameToProperty("prop");
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old")), AS("c")),
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH(COUNT(IDENT("old"), false), AS("c")),
WHERE(LESS(PROPERTY_LOOKUP("old", prop), LITERAL(42)))));
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError);
}
@ -313,7 +313,7 @@ TEST_F(TestSymbolGenerator, MatchReturnSum) {
// Test MATCH (n) RETURN SUM(n.prop) + 42 AS result
auto prop = dba.NameToProperty("prop");
auto node = NODE("n");
auto sum = SUM(PROPERTY_LOOKUP("n", prop));
auto sum = SUM(PROPERTY_LOOKUP("n", prop), false);
auto as_result = AS("result");
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(ADD(sum, LITERAL(42)), as_result)));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
@ -330,8 +330,9 @@ TEST_F(TestSymbolGenerator, MatchReturnSum) {
TEST_F(TestSymbolGenerator, NestedAggregation) {
// Test MATCH (n) RETURN SUM(42 + SUM(n.prop)) AS s
auto prop = dba.NameToProperty("prop");
auto query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop)))), AS("s"))));
auto query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))),
RETURN(SUM(ADD(LITERAL(42), SUM(PROPERTY_LOOKUP("n", prop), false)), false), AS("s"))));
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException);
}
@ -339,7 +340,7 @@ TEST_F(TestSymbolGenerator, WrongAggregationContext) {
// Test MATCH (n) WITH n.prop AS prop WHERE SUM(prop) < 42
auto prop = dba.NameToProperty("prop");
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WITH(PROPERTY_LOOKUP("n", prop), AS("prop")),
WHERE(LESS(SUM(IDENT("prop")), LITERAL(42)))));
WHERE(LESS(SUM(IDENT("prop"), false), LITERAL(42)))));
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException);
}
@ -429,14 +430,15 @@ TEST_F(TestSymbolGenerator, LimitUsingIdentifier) {
TEST_F(TestSymbolGenerator, OrderByAggregation) {
// Test MATCH (old) RETURN old AS new ORDER BY COUNT(1)
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN("old", AS("new"), ORDER_BY(COUNT(LITERAL(1))))));
auto query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN("old", AS("new"), ORDER_BY(COUNT(LITERAL(1), false)))));
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), SemanticException);
}
TEST_F(TestSymbolGenerator, OrderByUnboundVariable) {
// Test MATCH (old) RETURN COUNT(old) AS new ORDER BY old
auto query =
QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN(COUNT(IDENT("old")), AS("new"), ORDER_BY(IDENT("old")))));
auto query = QUERY(
SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), RETURN(COUNT(IDENT("old"), false), AS("new"), ORDER_BY(IDENT("old")))));
EXPECT_THROW(memgraph::query::MakeSymbolTable(query), UnboundVariableError);
}
@ -446,7 +448,7 @@ TEST_F(TestSymbolGenerator, AggregationOrderBy) {
auto ident_old = IDENT("old");
auto as_new = AS("new");
auto ident_new = IDENT("new");
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(COUNT(ident_old), as_new, ORDER_BY(ident_new))));
auto query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node)), RETURN(COUNT(ident_old, false), as_new, ORDER_BY(ident_new))));
auto symbol_table = memgraph::query::MakeSymbolTable(query);
// Symbols for pattern, `old`, `count(old)` and `new`
EXPECT_EQ(symbol_table.max_position(), 4);