Query::Plan::Aggregation - optional input bugfix

Reviewers: teon.banek, mislav.bradac

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D280
This commit is contained in:
florijan 2017-04-13 16:47:11 +02:00
parent bba20cf89c
commit 3d0181b28b
3 changed files with 58 additions and 20 deletions

View File

@ -977,12 +977,15 @@ std::unique_ptr<Cursor> Aggregate::MakeCursor(GraphDbAccessor &db) {
Aggregate::AggregateCursor::AggregateCursor(Aggregate &self,
GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
: self_(self),
db_(db),
input_cursor_(self.input_ ? self_.input_->MakeCursor(db) : nullptr) {}
bool Aggregate::AggregateCursor::Pull(Frame &frame,
const SymbolTable &symbol_table) {
if (!pulled_all_input_) {
PullAllInput(frame, symbol_table);
ProcessAll(frame, symbol_table);
pulled_all_input_ = true;
aggregation_it_ = aggregation_.begin();
@ -1015,23 +1018,15 @@ bool Aggregate::AggregateCursor::Pull(Frame &frame,
return true;
}
void Aggregate::AggregateCursor::PullAllInput(Frame &frame,
const SymbolTable &symbol_table) {
void Aggregate::AggregateCursor::ProcessAll(Frame &frame,
const SymbolTable &symbol_table) {
ExpressionEvaluator evaluator(frame, symbol_table);
evaluator.SwitchNew();
while (input_cursor_->Pull(frame, symbol_table)) {
// create the group-by list of values
std::list<TypedValue> group_by;
for (Expression *expression : self_.group_by_) {
expression->Accept(evaluator);
group_by.emplace_back(evaluator.PopBack());
}
AggregationValue &agg_value = aggregation_[group_by];
EnsureInitialized(frame, agg_value);
Update(frame, symbol_table, evaluator, agg_value);
}
if (input_cursor_)
while (input_cursor_->Pull(frame, symbol_table))
ProcessOne(frame, symbol_table, evaluator);
else
ProcessOne(frame, symbol_table, evaluator);
// calculate AVG aggregations (so far they have only been summed)
for (int pos = 0; pos < self_.aggregations_.size(); ++pos) {
@ -1045,6 +1040,21 @@ void Aggregate::AggregateCursor::PullAllInput(Frame &frame,
}
}
// Performs a single accumulation step: evaluates the group-by expressions
// against the current frame, looks up the aggregation entry stored under
// the resulting list of values and updates that entry.
void Aggregate::AggregateCursor::ProcessOne(Frame &frame,
                                            const SymbolTable &symbol_table,
                                            ExpressionEvaluator &evaluator) {
  // the evaluated group-by values, in declaration order, form the lookup key
  std::list<TypedValue> key;
  for (Expression *group_expr : self_.group_by_) {
    group_expr->Accept(evaluator);
    key.emplace_back(evaluator.PopBack());
  }

  // fetch the entry for this key, make sure it is set up, then fold in the
  // current frame
  AggregationValue &target = aggregation_[key];
  EnsureInitialized(frame, target);
  Update(frame, symbol_table, evaluator, target);
}
void Aggregate::AggregateCursor::EnsureInitialized(
Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) {
if (agg_value.values_.size() > 0) return;

View File

@ -873,6 +873,7 @@ class Aggregate : public LogicalOperator {
Aggregate &self_;
GraphDbAccessor &db_;
// optional
std::unique_ptr<Cursor> input_cursor_;
// storage for aggregated data
// map key is the list of group-by values
@ -891,14 +892,21 @@ class Aggregate : public LogicalOperator {
bool pulled_all_input_{false};
/**
* Pulls from the input until it's exhausted. Accumulates the results
* in the `aggregation_` map.
* Pulls from the input operator until exhausted and aggregates the
* results. If the input operator is not provided, a single call
* to ProcessOne is issued.
*
* Accumulation automatically groups the results so that `aggregation_`
* cache cardinality depends on number of
* aggregation results, and not on the number of inputs.
*/
void PullAllInput(Frame &frame, const SymbolTable &symbolTable);
void ProcessAll(Frame &frame, const SymbolTable &symbol_table);
/**
* Performs a single accumulation.
*/
void ProcessOne(Frame &frame, const SymbolTable &symbolTable,
ExpressionEvaluator &evaluator);
/** Ensures the new AggregationValue has been initialized. This means
* that the value vectors are filled with an appropriate number of Nulls,

View File

@ -326,6 +326,26 @@ TEST(QueryPlan, AggregateAdvance) {
check(true);
}
// Checks aggregation when no input operator is supplied: nullptr is passed
// as the first argument to MakeAggregationProduce (presumably the input
// operator - confirm against the helper's signature). A COUNT aggregation
// over the literal 2 is expected to yield exactly one result row that holds
// a single Int value equal to 1.
TEST(QueryPlan, AggregateNoInput) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
// the expression being aggregated and the named output registered in the
// symbol table
auto two = LITERAL(2);
auto output = NEXPR("two", IDENT("two"));
symbol_table[*output->expression_] = symbol_table.CreateSymbol("two");
// COUNT(two); the two trailing arguments are left empty
auto produce = MakeAggregationProduce(nullptr, symbol_table, storage, {two},
{Aggregation::Op::COUNT},
{}, {});
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
// exactly one row with one column: the Int value 1
EXPECT_EQ(1, results.size());
EXPECT_EQ(1, results[0].size());
EXPECT_EQ(TypedValue::Type::Int, results[0][0].type());
EXPECT_EQ(1, results[0][0].Value<int64_t>());
}
TEST(QueryPlan, AggregateCountEdgeCases) {
// tests for detected bugs in the COUNT aggregation behavior
// ensure that COUNT returns correctly for