Remember only symbols from group by during aggregation
Summary: This avoids the unnecessary work of storing symbols which can never be read after an aggregation is complete. With regard to distributed execution, a major benefit is gained in reducing what is transferred over the network. Hopefully, this doesn't break some obscure case where we actually needed to remember all used symbols. Reviewers: florijan, mislav.bradac, msantl Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1158
This commit is contained in:
parent
abdd7d8e9d
commit
03e571ea3a
src/query/plan
@ -29,36 +29,6 @@ void ForEachPattern(
|
||||
}
|
||||
}
|
||||
|
||||
// Collects symbols from identifiers found in visited AST nodes.
// Every visited Identifier is looked up in the symbol table and its symbol is
// added to `symbols_`; the symbol bound by an `All` node is removed again once
// that node has been fully visited, so only free symbols remain.
|
||||
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
// The symbol table is stored by reference, not copied.
// NOTE(review): presumably the caller must keep it alive for as long as the
// collector is used -- confirm at call sites.
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
|
||||
// Re-expose the base class overloads, which would otherwise be hidden by the
// overrides declared below.
using HierarchicalTreeVisitor::PostVisit;
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
using HierarchicalTreeVisitor::Visit;
|
||||
|
||||
bool PostVisit(All &all) override {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Records the symbol that the symbol table associates with `ident`.
bool Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
return true;
|
||||
}
|
||||
|
||||
// The following leaf nodes contribute no symbols; they only return true.
// NOTE(review): true presumably tells the traversal to continue -- confirm
// against HierarchicalTreeVisitor.
bool Visit(PrimitiveLiteral &) override { return true; }
|
||||
bool Visit(ParameterLookup &) override { return true; }
|
||||
bool Visit(query::CreateIndex &) override { return true; }
|
||||
|
||||
// Symbols collected so far.
std::unordered_set<Symbol> symbols_;
|
||||
// Table used to map identifiers to their symbols.
const SymbolTable &symbol_table_;
|
||||
};
|
||||
|
||||
// Converts multiple Patterns to Expansions. Each Pattern can contain an
|
||||
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
|
||||
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
|
||||
|
@ -13,6 +13,36 @@
|
||||
|
||||
namespace query::plan {
|
||||
|
||||
/// Collects symbols from identifiers found in visited AST nodes.
///
/// Every visited Identifier is looked up in the symbol table and its symbol is
/// added to `symbols_`. The symbol bound by an `All` node is removed again
/// once that node has been fully visited, so only free symbols remain.
|
||||
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
|
||||
public:
|
||||
/// The symbol table is stored by reference, not copied.
/// NOTE(review): presumably the caller must keep it alive for as long as the
/// collector is used -- confirm at call sites.
explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
|
||||
: symbol_table_(symbol_table) {}
|
||||
|
||||
// Re-expose the base class overloads, which would otherwise be hidden by the
// overrides declared below.
using HierarchicalTreeVisitor::PostVisit;
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
using HierarchicalTreeVisitor::Visit;
|
||||
|
||||
bool PostVisit(All &all) override {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Records the symbol that the symbol table associates with `ident`.
bool Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
return true;
|
||||
}
|
||||
|
||||
// The following leaf nodes contribute no symbols; they only return true.
// NOTE(review): true presumably tells the traversal to continue -- confirm
// against HierarchicalTreeVisitor.
bool Visit(PrimitiveLiteral &) override { return true; }
|
||||
bool Visit(ParameterLookup &) override { return true; }
|
||||
bool Visit(query::CreateIndex &) override { return true; }
|
||||
|
||||
/// Symbols collected so far.
std::unordered_set<Symbol> symbols_;
|
||||
/// Table used to map identifiers to their symbols.
const SymbolTable &symbol_table_;
|
||||
};
|
||||
|
||||
/// Normalized representation of a pattern that needs to be matched.
|
||||
struct Expansion {
|
||||
/// The first node in the expansion, it can be a single node.
|
||||
|
@ -116,7 +116,14 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
named_expr->Accept(*this);
|
||||
named_expressions_.emplace_back(named_expr);
|
||||
}
|
||||
// Collect aggregations.
|
||||
// Collect symbols used in group by expressions.
|
||||
if (!aggregations_.empty()) {
|
||||
UsedSymbolsCollector collector(symbol_table_);
|
||||
for (auto &group_by : group_by_) {
|
||||
group_by->Accept(collector);
|
||||
}
|
||||
group_by_used_symbols_ = collector.symbols_;
|
||||
}
|
||||
if (aggregations_.empty()) {
|
||||
// Visit order_by and where if we do not have aggregations. This way we
|
||||
// prevent collecting group_by expressions from order_by and where, which
|
||||
@ -205,8 +212,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
|
||||
bool Visit(Identifier &ident) override {
|
||||
const auto &symbol = symbol_table_.at(ident);
|
||||
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
|
||||
output_symbols_.end()) {
|
||||
if (!utils::Contains(output_symbols_, symbol)) {
|
||||
// Don't pick up new symbols, even though they may be used in ORDER BY or
|
||||
// WHERE.
|
||||
used_symbols_.insert(symbol);
|
||||
@ -392,9 +398,9 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
auto *limit() const { return body_.limit; }
|
||||
// Optional Where clause for filtering.
|
||||
const auto *where() const { return where_; }
|
||||
// Set of symbols used inside the visited expressions outside of aggregation
|
||||
// expression. This only includes old symbols, even though new ones may have
|
||||
// been used in ORDER BY or WHERE.
|
||||
// Set of symbols used inside the visited expressions, including the inside of
|
||||
// aggregation expression. This only includes old symbols, even though new
|
||||
// ones may have been used in ORDER BY or WHERE.
|
||||
const auto &used_symbols() const { return used_symbols_; }
|
||||
// List of aggregation elements found in expressions.
|
||||
const auto &aggregations() const { return aggregations_; }
|
||||
@ -402,6 +408,8 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
|
||||
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
|
||||
const auto &group_by() const { return group_by_; }
|
||||
// Set of symbols used in group by expressions.
|
||||
const auto &group_by_used_symbols() const { return group_by_used_symbols_; }
|
||||
// All symbols generated by named expressions. They are collected in order of
|
||||
// named_expressions.
|
||||
const auto &output_symbols() const { return output_symbols_; }
|
||||
@ -416,7 +424,15 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
std::vector<Symbol> output_symbols_;
|
||||
std::vector<Aggregate::Element> aggregations_;
|
||||
std::vector<Expression *> group_by_;
|
||||
// Flag indicating whether an expression contains an aggregation.
|
||||
std::unordered_set<Symbol> group_by_used_symbols_;
|
||||
// Flag stack indicating whether an expression contains an aggregation. A
|
||||
// stack is needed so that we differentiate the case where a child
|
||||
// sub-expression has an aggregation, while the other child doesn't. For
|
||||
// example AST, (+ (sum x) y)
|
||||
// * (sum x) -- Has an aggregation.
|
||||
// * y -- Doesn't, we need to group by this.
|
||||
// * (+ (sum x) y) -- The whole expression has an aggregation, so we don't
|
||||
// group by it.
|
||||
std::list<bool> has_aggregation_;
|
||||
std::vector<NamedExpression *> named_expressions_;
|
||||
};
|
||||
@ -435,8 +451,10 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
|
||||
}
|
||||
if (!body.aggregations().empty()) {
|
||||
// When we have aggregation, SKIP/LIMIT should always come after it.
|
||||
std::vector<Symbol> remember(body.group_by_used_symbols().begin(),
|
||||
body.group_by_used_symbols().end());
|
||||
last_op = new Aggregate(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.aggregations(), body.group_by(), used_symbols);
|
||||
body.aggregations(), body.group_by(), remember);
|
||||
}
|
||||
last_op = new Produce(std::shared_ptr<LogicalOperator>(last_op),
|
||||
body.named_expressions());
|
||||
|
Loading…
Reference in New Issue
Block a user