Generalize MakeLogicalPlan with regards to planner

Summary:
This change modifies the planning API to be more general, in order to support
picking different planning strategies. The current planning strategy has been
named RuleBasedPlanner.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D411
This commit is contained in:
Teon Banek 2017-05-24 16:13:25 +02:00
parent fd19f76cba
commit 62f6a58c32
4 changed files with 139 additions and 121 deletions

View File

@ -39,7 +39,8 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
high_level_tree->Accept(symbol_generator);
// high level tree -> logical plan
auto logical_plan = plan::MakeLogicalPlan(visitor.storage(), symbol_table);
auto logical_plan = plan::MakeLogicalPlan<plan::RuleBasedPlanner>(
visitor.storage(), symbol_table);
// generate frame based on symbol table max_position
Frame frame(symbol_table.max_position());

View File

@ -582,74 +582,6 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
return nullptr;
}
// Normalized representation of a pattern that needs to be matched.
struct Expansion {
// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
// Optional node at the other end of an edge. If the expansion contains an
// edge, then this node is required.
NodeAtom *node2 = nullptr;
};
// Normalized representation of a single or multiple Match clauses.
//
// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
// WHERE c.prop < 42` will produce the following.
// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and `(n) -[e3]- (m)`.
// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
// first `MATCH` and the set `{e3}` for the second.
// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
// other obtained from `WHERE` on symbol `c`.
struct Matching {
// All expansions that need to be performed across Match clauses.
std::vector<Expansion> expansions;
// Symbols for edges established in match, used to ensure Cyphermorphism.
// There are multiple sets, because each Match clause determines a single set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
// Pairs of filter expression and symbols used in them. The list should be
// filled using CollectPatternFilters function.
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
};
// Represents a read (+ write) part of a query. Each part ends with either:
// * RETURN clause;
// * WITH clause or
// * any of the write clauses.
//
// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
// generated QueryPart will have `matching` generated for the `MATCH`.
// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
// in that exact order. The pattern inside the first `MERGE` will be used to
// generate the first `merge_matching` element, and the second `MERGE` pattern
// will produce the second `merge_matching` element. This way, if someone
// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
// in the same order as their respective `merge_matching` elements.
struct QueryPart {
// All MATCH clauses merged into one Matching.
Matching matching;
// Each OPTIONAL MATCH converted to Matching.
std::vector<Matching> optional_matching;
// Matching for each MERGE clause. Since Merge is contained in
// remaining_clauses, this vector contains matching in the same order as Merge
// appears.
std::vector<Matching> merge_matching;
// All the remaining clauses (without Match).
std::vector<Clause *> remaining_clauses;
};
// Context which contains variables commonly used during planning.
struct PlanningContext {
SymbolTable &symbol_table;
AstTreeStorage &ast_storage;
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<Symbol> bound_symbols;
};
// Converts multiple Patterns to Expansions. Each Pattern can contain an
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
@ -709,42 +641,6 @@ void AddMatching(const Match &match, const SymbolTable &symbol_table,
matching);
}
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<QueryPart> CollectQueryParts(const SymbolTable &symbol_table,
AstTreeStorage &storage) {
auto query = storage.query();
std::vector<QueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : query->clauses_) {
if (auto *match = dynamic_cast<Match *>(clause)) {
if (match->optional_) {
query_part->optional_matching.emplace_back(Matching{});
AddMatching(*match, symbol_table, storage,
query_part->optional_matching.back());
} else {
debug_assert(query_part->optional_matching.empty(),
"Match clause cannot follow optional match.");
AddMatching(*match, symbol_table, storage, query_part->matching);
}
} else {
query_part->remaining_clauses.push_back(clause);
if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
query_part->merge_matching.back());
} else if (dynamic_cast<With *>(clause)) {
query_parts.emplace_back(QueryPart{});
query_part = &query_parts.back();
} else if (dynamic_cast<Return *>(clause)) {
// TODO: Support RETURN UNION ...
return query_parts;
}
}
}
return query_parts;
}
LogicalOperator *PlanMatching(const Matching &matching,
LogicalOperator *input_op,
AstTreeStorage &storage, MatchContext &context) {
@ -854,10 +750,45 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
} // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
SymbolTable &symbol_table) {
auto query_parts = CollectQueryParts(symbol_table, storage);
PlanningContext context{symbol_table, storage};
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<QueryPart> CollectQueryParts(const SymbolTable &symbol_table,
AstTreeStorage &storage) {
auto query = storage.query();
std::vector<QueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : query->clauses_) {
if (auto *match = dynamic_cast<Match *>(clause)) {
if (match->optional_) {
query_part->optional_matching.emplace_back(Matching{});
AddMatching(*match, symbol_table, storage,
query_part->optional_matching.back());
} else {
debug_assert(query_part->optional_matching.empty(),
"Match clause cannot follow optional match.");
AddMatching(*match, symbol_table, storage, query_part->matching);
}
} else {
query_part->remaining_clauses.push_back(clause);
if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
query_part->merge_matching.back());
} else if (dynamic_cast<With *>(clause)) {
query_parts.emplace_back(QueryPart{});
query_part = &query_parts.back();
} else if (dynamic_cast<Return *>(clause)) {
// TODO: Support RETURN UNION ...
return query_parts;
}
}
}
return query_parts;
}
std::unique_ptr<LogicalOperator> RuleBasedPlanner::Plan(
std::vector<QueryPart> &query_parts) {
auto &context = context_;
LogicalOperator *input_op = nullptr;
// Set to true if a query command writes to the database.
bool is_write = false;

View File

@ -11,14 +11,100 @@ class SymbolTable;
namespace plan {
// Normalized representation of a pattern that needs to be matched.
struct Expansion {
// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
// Optional node at the other end of an edge. If the expansion contains an
// edge, then this node is required.
NodeAtom *node2 = nullptr;
};
// Normalized representation of a single or multiple Match clauses.
//
// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
// WHERE c.prop < 42` will produce the following.
// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and `(n) -[e3]- (m)`.
// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
// first `MATCH` and the set `{e3}` for the second.
// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
// other obtained from `WHERE` on symbol `c`.
struct Matching {
// All expansions that need to be performed across Match clauses.
std::vector<Expansion> expansions;
// Symbols for edges established in match, used to ensure Cyphermorphism.
// There are multiple sets, because each Match clause determines a single set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
// Pairs of filter expression and symbols used in them. The list should be
// filled using CollectPatternFilters function.
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
};
// Represents a read (+ write) part of a query. Each part ends with either:
// * RETURN clause;
// * WITH clause or
// * any of the write clauses.
//
// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
// generated QueryPart will have `matching` generated for the `MATCH`.
// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
// in that exact order. The pattern inside the first `MERGE` will be used to
// generate the first `merge_matching` element, and the second `MERGE` pattern
// will produce the second `merge_matching` element. This way, if someone
// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
// in the same order as their respective `merge_matching` elements.
struct QueryPart {
// All MATCH clauses merged into one Matching.
Matching matching;
// Each OPTIONAL MATCH converted to Matching.
std::vector<Matching> optional_matching;
// Matching for each MERGE clause. Since Merge is contained in
// remaining_clauses, this vector contains matching in the same order as Merge
// appears.
std::vector<Matching> merge_matching;
// All the remaining clauses (without Match).
std::vector<Clause *> remaining_clauses;
};
// Context which contains variables commonly used during planning.
struct PlanningContext {
SymbolTable &symbol_table;
AstTreeStorage &ast_storage;
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<Symbol> bound_symbols;
};
class RuleBasedPlanner {
public:
RuleBasedPlanner(PlanningContext &context) : context_(context) {}
using PlanResult = std::unique_ptr<LogicalOperator>;
PlanResult Plan(std::vector<QueryPart> &);
private:
PlanningContext &context_;
};
std::vector<QueryPart> CollectQueryParts(const SymbolTable &, AstTreeStorage &);
/// @brief Generates the LogicalOperator tree and returns the root operation.
///
/// The tree is constructed by traversing the @c Query node from given
/// @c AstTreeStorage. The storage may also be used to create new AST nodes for
/// use in operators. @c SymbolTable is used to determine inputs and outputs of
/// certain operators.
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
AstTreeStorage &storage, query::SymbolTable &symbol_table);
template <class TPlanner>
typename TPlanner::PlanResult MakeLogicalPlan(AstTreeStorage &storage,
SymbolTable &symbol_table) {
auto query_parts = CollectQueryParts(symbol_table, storage);
PlanningContext context{symbol_table, storage};
return TPlanner(context).Plan(query_parts);
}
} // namespace plan

View File

@ -222,7 +222,7 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table,
template <class... TChecker>
auto CheckPlan(AstTreeStorage &storage, TChecker... checker) {
auto symbol_table = MakeSymbolTable(*storage.query());
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, checker...);
}
@ -240,7 +240,7 @@ TEST(TestLogicalPlanner, CreateNodeReturn) {
auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n")));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce());
}
@ -516,7 +516,7 @@ TEST(TestLogicalPlanner, CreateWithSum) {
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
// We expect both the accumulation and aggregation because the part before
// WITH updates the database.
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr,
@ -553,7 +553,7 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) {
RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1))));
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
// Since we have a write query, we need to have Accumulate. This is a bit
// different than Neo4j 3.0, which optimizes WITH followed by RETURN as a
// single RETURN clause and then moves Skip and Limit before Accumulate. This
@ -576,7 +576,7 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) {
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)});
auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(),
ExpectSkip(), ExpectLimit());
}
@ -615,7 +615,7 @@ TEST(TestLogicalPlanner, CreateWithOrderByWhere) {
symbol_table.at(*r_prop->expression_), // `r` in ORDER BY
symbol_table.at(*m_prop->expression_), // `m` in WHERE
});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc,
ExpectProduce(), ExpectFilter(), ExpectOrderBy());
}
@ -653,7 +653,7 @@ TEST(TestLogicalPlanner, MatchMerge) {
auto symbol_table = MakeSymbolTable(*query);
// We expect Accumulate after Merge, because it is considered as a write.
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectScanAll(),
ExpectMerge(on_match, on_create), acc, ExpectProduce());
for (auto &op : on_match) delete op;
@ -710,7 +710,7 @@ TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) {
auto symbol_table = MakeSymbolTable(*query);
auto acc = ExpectAccumulate({symbol_table.at(*node_n->identifier_)});
auto aggr = ExpectAggregate({sum}, {});
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(),
ExpectFilter(), ExpectDistinct(), ExpectProduce());
}
@ -792,7 +792,7 @@ TEST(TestLogicalPlanner, MatchReturnAsterisk) {
ret->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), ret);
auto symbol_table = MakeSymbolTable(*query);
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(),
ExpectProduce());
std::vector<std::string> output_names;
@ -814,7 +814,7 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) {
ret->body_.all_identifiers = true;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret);
auto symbol_table = MakeSymbolTable(*query);
auto plan = MakeLogicalPlan(storage, symbol_table);
auto plan = MakeLogicalPlan<RuleBasedPlanner>(storage, symbol_table);
auto *produce = dynamic_cast<Produce *>(plan.get());
ASSERT_TRUE(produce);
const auto &named_expressions = produce->named_expressions();