Remember only symbols from group by during aggregation

Summary:
This avoids the unnecessary work of storing symbols which can never be
read after an aggregation is complete. With regards to distributed, a
major benefit is gained in reducing what is transferred over the
network. Hopefully, this doesn't break some obscure case where we
actually needed to remember all used symbols.

Reviewers: florijan, mislav.bradac, msantl

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1158
This commit is contained in:
Teon Banek 2018-01-31 15:01:29 +01:00
parent abdd7d8e9d
commit 03e571ea3a
3 changed files with 56 additions and 38 deletions

View File

@ -29,36 +29,6 @@ void ForEachPattern(
}
}
// Collects symbols from identifiers found in visited AST nodes.
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
public:
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*all.identifier_));
return true;
}
bool Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
return true;
}
bool Visit(PrimitiveLiteral &) override { return true; }
bool Visit(ParameterLookup &) override { return true; }
bool Visit(query::CreateIndex &) override { return true; }
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
};
// Converts multiple Patterns to Expansions. Each Pattern can contain an
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
// done by splitting a pattern into triplets (node1, edge, node2). The triplets

View File

@ -13,6 +13,36 @@
namespace query::plan {
/// Collects symbols from identifiers found in visited AST nodes.
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
public:
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
: symbol_table_(symbol_table) {}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*all.identifier_));
return true;
}
bool Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
return true;
}
bool Visit(PrimitiveLiteral &) override { return true; }
bool Visit(ParameterLookup &) override { return true; }
bool Visit(query::CreateIndex &) override { return true; }
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
};
/// Normalized representation of a pattern that needs to be matched.
struct Expansion {
/// The first node in the expansion, it can be a single node.

View File

@ -116,7 +116,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
named_expr->Accept(*this);
named_expressions_.emplace_back(named_expr);
}
// Collect aggregations.
// Collect symbols used in group by expressions.
if (!aggregations_.empty()) {
UsedSymbolsCollector collector(symbol_table_);
for (auto &group_by : group_by_) {
group_by->Accept(collector);
}
group_by_used_symbols_ = collector.symbols_;
}
if (aggregations_.empty()) {
// Visit order_by and where if we do not have aggregations. This way we
// prevent collecting group_by expressions from order_by and where, which
@ -205,8 +212,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
bool Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
output_symbols_.end()) {
if (!utils::Contains(output_symbols_, symbol)) {
// Don't pick up new symbols, even though they may be used in ORDER BY or
// WHERE.
used_symbols_.insert(symbol);
@ -392,9 +398,9 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
auto *limit() const { return body_.limit; }
// Optional Where clause for filtering.
const auto *where() const { return where_; }
// Set of symbols used inside the visited expressions outside of aggregation
// expression. These only includes old symbols, even though new ones may have
// been used in ORDER BY or WHERE.
// Set of symbols used inside the visited expressions, including the inside of
// aggregation expression. These only includes old symbols, even though new
// ones may have been used in ORDER BY or WHERE.
const auto &used_symbols() const { return used_symbols_; }
// List of aggregation elements found in expressions.
const auto &aggregations() const { return aggregations_; }
@ -402,6 +408,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
const auto &group_by() const { return group_by_; }
// Set of symbols used in group by expressions.
const auto &group_by_used_symbols() const { return group_by_used_symbols_; }
// All symbols generated by named expressions. They are collected in order of
// named_expressions.
const auto &output_symbols() const { return output_symbols_; }
@ -416,7 +424,15 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
std::vector<Symbol> output_symbols_;
std::vector<Aggregate::Element> aggregations_;
std::vector<Expression *> group_by_;
// Flag indicating whether an expression contains an aggregation.
std::unordered_set<Symbol> group_by_used_symbols_;
// Flag stack indicating whether an expression contains an aggregation. A
// stack is needed so that we differentiate the case where a child
// sub-expression has an aggregation, while the other child doesn't. For
// example AST, (+ (sum x) y)
// * (sum x) -- Has an aggregation.
// * y -- Doesn't, we need to group by this.
// * (+ (sum x) y) -- The whole expression has an aggregation, so we don't
// group by it.
std::list<bool> has_aggregation_;
std::vector<NamedExpression *> named_expressions_;
};
@ -435,8 +451,10 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
}
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
std::vector<Symbol> remember(body.group_by_used_symbols().begin(),
body.group_by_used_symbols().end());
last_op = new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
body.aggregations(), body.group_by(), used_symbols);
body.aggregations(), body.group_by(), remember);
}
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
body.named_expressions());