Add planning of distributed execution of AVG

Reviewers: florijan, msantl

Reviewed By: msantl

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1151
This commit is contained in:
Teon Banek 2018-01-29 15:04:55 +01:00
parent 760c6246d8
commit d41ffb5039

View File

@ -179,6 +179,7 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
case Aggregation::Op::SUM:
case Aggregation::Op::AVG:
break;
default:
return false;
@ -194,79 +195,123 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
input->OutputSymbols(distributed_plan_.symbol_table)));
return true;
}
// Aggregate uses associative operation(s), so split the work across master
// and workers.
auto make_merge_aggregation = [this](auto op, const auto &worker_sym) {
auto *worker_ident =
distributed_plan_.ast_storage.Create<Identifier>(worker_sym.name());
distributed_plan_.symbol_table[*worker_ident] = worker_sym;
auto merge_name =
Aggregation::OpToString(op) + std::to_string(worker_ident->uid());
auto make_ident = [this](const auto &symbol) {
auto *ident =
distributed_plan_.ast_storage.Create<Identifier>(symbol.name());
distributed_plan_.symbol_table[*ident] = symbol;
return ident;
};
auto make_named_expr = [&](const auto &in_sym, const auto &out_sym) {
auto *nexpr = distributed_plan_.ast_storage.Create<NamedExpression>(
out_sym.name(), make_ident(in_sym));
distributed_plan_.symbol_table[*nexpr] = out_sym;
return nexpr;
};
auto make_merge_aggregation = [&](auto op, const auto &worker_sym) {
auto *worker_ident = make_ident(worker_sym);
auto merge_name = Aggregation::OpToString(op) +
std::to_string(worker_ident->uid()) + "<-" +
worker_sym.name();
auto merge_sym = distributed_plan_.symbol_table.CreateSymbol(
merge_name, false, Symbol::Type::Number);
return Aggregate::Element{worker_ident, nullptr, op, merge_sym};
};
// Aggregate uses associative operation(s), so split the work across master
// and workers.
std::vector<Aggregate::Element> master_aggrs;
master_aggrs.reserve(aggr_op.aggregations().size());
std::vector<Aggregate::Element> worker_aggrs;
worker_aggrs.reserve(aggr_op.aggregations().size());
// We will need to create a Produce operator which moves the final results
// from new (merge) symbols into old aggregation symbols, because
// expressions following the aggregation expect the result in old symbols.
std::vector<NamedExpression *> produce_exprs;
produce_exprs.reserve(aggr_op.aggregations().size());
for (const auto &aggr : aggr_op.aggregations()) {
switch (aggr.op) {
// Count, like sum, only needs to sum all of the results on master.
case Aggregation::Op::COUNT:
case Aggregation::Op::SUM:
master_aggrs.emplace_back(
make_merge_aggregation(Aggregation::Op::SUM, aggr.output_sym));
case Aggregation::Op::SUM: {
worker_aggrs.emplace_back(aggr);
auto merge_aggr =
make_merge_aggregation(Aggregation::Op::SUM, aggr.output_sym);
master_aggrs.emplace_back(merge_aggr);
produce_exprs.emplace_back(
make_named_expr(merge_aggr.output_sym, aggr.output_sym));
break;
}
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
master_aggrs.emplace_back(
make_merge_aggregation(aggr.op, aggr.output_sym));
case Aggregation::Op::MAX: {
worker_aggrs.emplace_back(aggr);
auto merge_aggr = make_merge_aggregation(aggr.op, aggr.output_sym);
master_aggrs.emplace_back(merge_aggr);
produce_exprs.emplace_back(
make_named_expr(merge_aggr.output_sym, aggr.output_sym));
break;
}
// AVG is split into:
// * workers: SUM(expr), COUNT(expr)
// * master: SUM(worker_sum) / toFloat(SUM(worker_count)) AS avg
case Aggregation::Op::AVG: {
auto worker_sum_sym = distributed_plan_.symbol_table.CreateSymbol(
aggr.output_sym.name() + "_SUM", false, Symbol::Type::Number);
Aggregate::Element worker_sum{aggr.value, aggr.key,
Aggregation::Op::SUM, worker_sum_sym};
worker_aggrs.emplace_back(worker_sum);
auto worker_count_sym = distributed_plan_.symbol_table.CreateSymbol(
aggr.output_sym.name() + "_COUNT", false, Symbol::Type::Number);
Aggregate::Element worker_count{
aggr.value, aggr.key, Aggregation::Op::COUNT, worker_count_sym};
worker_aggrs.emplace_back(worker_count);
auto master_sum =
make_merge_aggregation(Aggregation::Op::SUM, worker_sum_sym);
master_aggrs.emplace_back(master_sum);
auto master_count =
make_merge_aggregation(Aggregation::Op::SUM, worker_count_sym);
master_aggrs.emplace_back(master_count);
auto *master_sum_ident = make_ident(master_sum.output_sym);
auto *master_count_ident = make_ident(master_count.output_sym);
auto *to_float = distributed_plan_.ast_storage.Create<Function>(
"TOFLOAT", std::vector<Expression *>{master_count_ident});
auto *div_expr =
distributed_plan_.ast_storage.Create<DivisionOperator>(
master_sum_ident, to_float);
auto *as_avg = distributed_plan_.ast_storage.Create<NamedExpression>(
aggr.output_sym.name(), div_expr);
distributed_plan_.symbol_table[*as_avg] = aggr.output_sym;
produce_exprs.emplace_back(as_avg);
break;
}
default:
throw utils::NotYetImplemented("distributed planning");
}
}
// Rewiring is done in PostVisit(Produce), so just store our results.
master_aggrs_ = master_aggrs;
worker_aggr_ = std::make_shared<Aggregate>(
aggr_op.input(), worker_aggrs, aggr_op.group_by(), aggr_op.remember());
std::vector<Symbol> pull_symbols;
pull_symbols.reserve(worker_aggrs.size() + aggr_op.remember().size());
for (const auto &aggr : worker_aggrs)
pull_symbols.push_back(aggr.output_sym);
for (const auto &sym : aggr_op.remember()) pull_symbols.push_back(sym);
auto pull_op = std::make_shared<PullRemote>(
worker_aggr_, distributed_plan_.plan_id, pull_symbols);
auto master_aggr_op = std::make_shared<Aggregate>(
pull_op, master_aggrs, aggr_op.group_by(), aggr_op.remember());
// Make our master Aggregate into Produce + Aggregate
master_aggr_ = std::make_unique<Produce>(master_aggr_op, produce_exprs);
return true;
}
// Pre-visit hook for Produce: does no work of its own and unconditionally
// returns true.
// NOTE(review): presumably `true` tells the hierarchical visitor to descend
// into the operator's input — confirm against
// HierarchicalLogicalOperatorVisitor.
bool PreVisit(Produce &) override { return true; }
// Post-visit hook for Produce. When master/worker aggregation state has been
// stored in the planner's members, this splices the master side in as the
// input of `produce` and records the worker side in
// distributed_plan_.worker_plan. Always returns true.
//
// NOTE(review): this span comes from a rendered diff and appears to interleave
// the pre-change and post-change statements (two early-return guards, two
// worker_plan assignments, two set_input calls). Confirm which statements
// belong to the current revision before relying on the flow below.
bool PostVisit(Produce &produce) override {
// Nothing to rewire when no aggregation state was stored.
if (master_aggrs_.empty()) return true;
if (!master_aggr_) return true;
// We have to rewire master/worker aggregation.
DCHECK(worker_aggr_);
DCHECK(!distributed_plan_.worker_plan);
// The checked cast guards the unchecked static_pointer_cast on the next line.
// NOTE(review): DCHECK is presumably compiled out of release builds — confirm.
DCHECK(std::dynamic_pointer_cast<Aggregate>(produce.input()));
auto aggr_op = std::static_pointer_cast<Aggregate>(produce.input());
// Symbols handed to PullRemote: every aggregation output symbol first, then
// every remembered symbol.
std::vector<Symbol> pull_symbols;
pull_symbols.reserve(aggr_op->aggregations().size() +
aggr_op->remember().size());
for (const auto &aggr : aggr_op->aggregations())
pull_symbols.push_back(aggr.output_sym);
for (const auto &sym : aggr_op->remember()) pull_symbols.push_back(sym);
// The Aggregate currently feeding `produce` becomes the worker plan.
distributed_plan_.worker_plan = aggr_op;
auto pull_op = std::make_shared<PullRemote>(
aggr_op, distributed_plan_.plan_id, pull_symbols);
// The master-side Aggregate reads from the PullRemote and reuses the original
// group_by and remember lists with the stored merge aggregations.
auto master_aggr_op = std::make_shared<Aggregate>(
pull_op, master_aggrs_, aggr_op->group_by(), aggr_op->remember());
// Create a Produce operator which only moves the final results from new
// symbols into old aggregation symbols, because expressions following the
// aggregation expect the result in old symbols.
std::vector<NamedExpression *> produce_exprs;
produce_exprs.reserve(aggr_op->aggregations().size());
// Pairs master_aggrs_[i] with aggregations()[i], so it relies on the two
// sequences having matching length and order.
// NOTE(review): `int i` is compared against size(), presumably an unsigned
// type, which would be a signed/unsigned comparison — confirm.
for (int i = 0; i < aggr_op->aggregations().size(); ++i) {
const auto &merge_result_sym = master_aggrs_[i].output_sym;
const auto &original_result_sym = aggr_op->aggregations()[i].output_sym;
auto *ident = distributed_plan_.ast_storage.Create<Identifier>(
merge_result_sym.name());
distributed_plan_.symbol_table[*ident] = merge_result_sym;
auto *nexpr = distributed_plan_.ast_storage.Create<NamedExpression>(
original_result_sym.name(), ident);
distributed_plan_.symbol_table[*nexpr] = original_result_sym;
produce_exprs.emplace_back(nexpr);
}
// Wire our master Produce into Produce + Aggregate
produce.set_input(std::make_shared<Produce>(master_aggr_op, produce_exprs));
// Clear the stored merge aggregations so the early-return guard on
// master_aggrs_ fires on a later call.
master_aggrs_.clear();
// Hand the stored worker operator to the plan and splice the stored master
// operator in as the input of `produce`; both members are moved from here.
distributed_plan_.worker_plan = std::move(worker_aggr_);
produce.set_input(std::move(master_aggr_));
return true;
}
@ -286,7 +331,8 @@ class DistributedPlanner : public HierarchicalLogicalOperatorVisitor {
private:
DistributedPlan &distributed_plan_;
// Used for rewiring the master/worker aggregation in PostVisit(Produce)
std::vector<Aggregate::Element> master_aggrs_;
std::shared_ptr<LogicalOperator> worker_aggr_;
std::unique_ptr<LogicalOperator> master_aggr_;
bool has_scan_all_ = false;
void RaiseIfCartesian() {