Query::Plan::Aggregation - optional input bugfix
Reviewers: teon.banek, mislav.bradac Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D280
This commit is contained in:
parent
bba20cf89c
commit
3d0181b28b
@ -977,12 +977,15 @@ std::unique_ptr<Cursor> Aggregate::MakeCursor(GraphDbAccessor &db) {
|
||||
|
||||
Aggregate::AggregateCursor::AggregateCursor(Aggregate &self,
|
||||
GraphDbAccessor &db)
|
||||
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
|
||||
: self_(self),
|
||||
db_(db),
|
||||
input_cursor_(self.input_ ? self_.input_->MakeCursor(db) : nullptr) {}
|
||||
|
||||
bool Aggregate::AggregateCursor::Pull(Frame &frame,
|
||||
const SymbolTable &symbol_table) {
|
||||
if (!pulled_all_input_) {
|
||||
PullAllInput(frame, symbol_table);
|
||||
ProcessAll(frame, symbol_table);
|
||||
|
||||
pulled_all_input_ = true;
|
||||
aggregation_it_ = aggregation_.begin();
|
||||
|
||||
@ -1015,23 +1018,15 @@ bool Aggregate::AggregateCursor::Pull(Frame &frame,
|
||||
return true;
|
||||
}
|
||||
|
||||
void Aggregate::AggregateCursor::PullAllInput(Frame &frame,
|
||||
const SymbolTable &symbol_table) {
|
||||
void Aggregate::AggregateCursor::ProcessAll(Frame &frame,
|
||||
const SymbolTable &symbol_table) {
|
||||
ExpressionEvaluator evaluator(frame, symbol_table);
|
||||
evaluator.SwitchNew();
|
||||
|
||||
while (input_cursor_->Pull(frame, symbol_table)) {
|
||||
// create the group-by list of values
|
||||
std::list<TypedValue> group_by;
|
||||
for (Expression *expression : self_.group_by_) {
|
||||
expression->Accept(evaluator);
|
||||
group_by.emplace_back(evaluator.PopBack());
|
||||
}
|
||||
|
||||
AggregationValue &agg_value = aggregation_[group_by];
|
||||
EnsureInitialized(frame, agg_value);
|
||||
Update(frame, symbol_table, evaluator, agg_value);
|
||||
}
|
||||
if (input_cursor_)
|
||||
while (input_cursor_->Pull(frame, symbol_table))
|
||||
ProcessOne(frame, symbol_table, evaluator);
|
||||
else
|
||||
ProcessOne(frame, symbol_table, evaluator);
|
||||
|
||||
// calculate AVG aggregations (so far they have only been summed)
|
||||
for (int pos = 0; pos < self_.aggregations_.size(); ++pos) {
|
||||
@ -1045,6 +1040,21 @@ void Aggregate::AggregateCursor::PullAllInput(Frame &frame,
|
||||
}
|
||||
}
|
||||
|
||||
void Aggregate::AggregateCursor::ProcessOne(Frame &frame,
|
||||
const SymbolTable &symbol_table,
|
||||
ExpressionEvaluator &evaluator) {
|
||||
// create the group-by list of values
|
||||
std::list<TypedValue> group_by;
|
||||
for (Expression *expression : self_.group_by_) {
|
||||
expression->Accept(evaluator);
|
||||
group_by.emplace_back(evaluator.PopBack());
|
||||
}
|
||||
|
||||
AggregationValue &agg_value = aggregation_[group_by];
|
||||
EnsureInitialized(frame, agg_value);
|
||||
Update(frame, symbol_table, evaluator, agg_value);
|
||||
}
|
||||
|
||||
void Aggregate::AggregateCursor::EnsureInitialized(
|
||||
Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) {
|
||||
if (agg_value.values_.size() > 0) return;
|
||||
|
@ -873,6 +873,7 @@ class Aggregate : public LogicalOperator {
|
||||
|
||||
Aggregate &self_;
|
||||
GraphDbAccessor &db_;
|
||||
// optional
|
||||
std::unique_ptr<Cursor> input_cursor_;
|
||||
// storage for aggregated data
|
||||
// map key is the list of group-by values
|
||||
@ -891,14 +892,21 @@ class Aggregate : public LogicalOperator {
|
||||
bool pulled_all_input_{false};
|
||||
|
||||
/**
|
||||
* Pulls from the input until it's exhausted. Accumulates the results
|
||||
* in the `aggregation_` map.
|
||||
* Pulls from the input operator until exhausted and aggregates the
|
||||
* results. If the input operator is not provided, a single call
|
||||
* to ProccessOne is issued.
|
||||
*
|
||||
* Accumulation automatically groups the results so that `aggregation_`
|
||||
* cache cardinality depends on number of
|
||||
* aggregation results, and not on the number of inputs.
|
||||
*/
|
||||
void PullAllInput(Frame &frame, const SymbolTable &symbolTable);
|
||||
void ProcessAll(Frame &frame, const SymbolTable &symbol_table);
|
||||
|
||||
/**
|
||||
* Performs a single accumulation.
|
||||
*/
|
||||
void ProcessOne(Frame &frame, const SymbolTable &symbolTable,
|
||||
ExpressionEvaluator &evaluator);
|
||||
|
||||
/** Ensures the new AggregationValue has been initialized. This means
|
||||
* that the value vectors are filled with an appropriate number of Nulls,
|
||||
|
@ -326,6 +326,26 @@ TEST(QueryPlan, AggregateAdvance) {
|
||||
check(true);
|
||||
}
|
||||
|
||||
TEST(QueryPlan, AggregateNoInput) {
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
AstTreeStorage storage;
|
||||
SymbolTable symbol_table;
|
||||
|
||||
auto two = LITERAL(2);
|
||||
auto output = NEXPR("two", IDENT("two"));
|
||||
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
|
||||
|
||||
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
|
||||
{Aggregation::Op::COUNT},
|
||||
{}, {});
|
||||
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
|
||||
EXPECT_EQ(1, results.size());
|
||||
EXPECT_EQ(1, results[0].size());
|
||||
EXPECT_EQ(TypedValue::Type::Int, results[0][0].type());
|
||||
EXPECT_EQ(1, results[0][0].Value<int64_t>());
|
||||
}
|
||||
|
||||
TEST(QueryPlan, AggregateCountEdgeCases) {
|
||||
// tests for detected bugs in the COUNT aggregation behavior
|
||||
// ensure that COUNT returns correctly for
|
||||
|
Loading…
Reference in New Issue
Block a user