Split planners into separate headers and templatize PlanningContext

Summary:
This change allows passing a different PlanningContext and/or
GraphDbAccessor to the planners. In turn, we can write tests which pass
a dummy context, so that the planning process is tested in isolation
from the rest of the system.

Reviewers: mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D700
This commit is contained in:
Teon Banek 2017-08-24 08:42:19 +02:00
parent 3ae35fa161
commit 4ccffbfd9b
5 changed files with 838 additions and 748 deletions

View File

@ -1,9 +1,11 @@
/// @file
/// This file is an entry point for invoking various planners via
/// `MakeLogicalPlan` API.
#pragma once
#include <memory>
#include "query/plan/operator.hpp"
#include "query/plan/rule_based_planner.hpp"
#include "query/plan/variable_start_planner.hpp"
namespace query {
@ -12,227 +14,28 @@ class SymbolTable;
namespace plan {
/// Normalized representation of a pattern that needs to be matched.
///
/// An expansion is either a single node (only @c node1 is set) or a triplet
/// of `node1`, `edge` and `node2`.
struct Expansion {
/// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
/// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
/// Direction of the edge, it may be flipped compared to original
/// @c EdgeAtom during plan generation. Defaults to
/// @c EdgeAtom::Direction::BOTH.
EdgeAtom::Direction direction = EdgeAtom::Direction::BOTH;
/// Set of symbols found inside the range expressions of a variable path edge.
std::unordered_set<Symbol> symbols_in_range{};
/// Optional node at the other end of an edge. If the expansion
/// contains an edge, then this node is required.
NodeAtom *node2 = nullptr;
};
/// Stores information on filters used inside the @c Matching of a @c QueryPart.
///
/// Besides the list of all filter expressions, per-symbol label and property
/// information is kept. That additional information is meant only for
/// generating indexed scans.
class Filters {
public:
/// Stores the symbols and expression used to filter a property.
struct PropertyFilter {
using Bound = ScanAllByLabelPropertyRange::Bound;
/// Set of used symbols in the @c expression.
std::unordered_set<Symbol> used_symbols;
/// Expression which when evaluated produces the value a property must
/// equal.
Expression *expression = nullptr;
/// Optional lower bound, used when the property is filtered by a range.
std::experimental::optional<Bound> lower_bound{};
/// Optional upper bound, used when the property is filtered by a range.
std::experimental::optional<Bound> upper_bound{};
};
/// All filter expressions that should be generated.
///
/// Each element pairs a filter expression with a set of symbols.
auto &all_filters() { return all_filters_; }
/// Read-only overload of @c all_filters.
const auto &all_filters() const { return all_filters_; }
/// Mapping from a symbol to labels that are filtered on it. These should be
/// used only for generating indexed scans.
const auto &label_filters() const { return label_filters_; }
/// Mapping from a symbol to properties that are filtered on it. These should
/// be used only for generating indexed scans.
const auto &property_filters() const { return property_filters_; }
/// Collects filtering information from a pattern.
///
/// Goes through all the atoms in a pattern and generates filter expressions
/// for found labels, properties and edge types. The generated expressions are
/// stored in @c all_filters. Also, @c label_filters and @c property_filters
/// are populated.
void CollectPatternFilters(Pattern &, SymbolTable &, AstTreeStorage &);
/// Collects filtering information from a where expression.
///
/// Takes the where expression and stores it in @c all_filters, then analyzes
/// the expression for additional information. The additional information is
/// used to populate @c label_filters and @c property_filters, so that indexed
/// scanning can use it.
void CollectWhereFilter(Where &, const SymbolTable &);
private:
/// Helper for analyzing a single filter expression.
/// NOTE(review): presumably this is what populates @c label_filters_ and
/// @c property_filters_; confirm against the definition.
void AnalyzeFilter(Expression *, const SymbolTable &);
/// Filter expressions, each paired with a set of symbols.
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> all_filters_;
/// Labels filtered on a symbol, keyed by that symbol.
std::unordered_map<Symbol, std::set<GraphDbTypes::Label>> label_filters_;
/// Property filters on a symbol, keyed by that symbol and then by property.
std::unordered_map<
Symbol, std::map<GraphDbTypes::Property, std::vector<PropertyFilter>>>
property_filters_;
};
/// Normalized representation of a single or multiple Match clauses.
///
/// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
/// WHERE c.prop < 42` will produce the following.
/// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and
/// `(n) -[e3]- (m)`.
/// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
/// first `MATCH` and the set `{e3}` for the second.
/// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
/// other obtained from `WHERE` on symbol `c`.
struct Matching {
/// All expansions that need to be performed across @c Match clauses.
std::vector<Expansion> expansions;
/// Symbols for edges established in match, used to ensure Cyphermorphism.
///
/// There are multiple sets, because each Match clause determines a single
/// set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
/// Information on used filter expressions while matching.
Filters filters;
/// Maps node symbols to expansions which bind them.
///
/// NOTE(review): the `int` values are presumably indices into
/// @c expansions; confirm against the code which fills this map.
std::unordered_map<Symbol, std::set<int>> node_symbol_to_expansions{};
/// All node and edge symbols across all expansions (from all matches).
std::unordered_set<Symbol> expansion_symbols{};
};
/// @brief Represents a read (+ write) part of a query. Parts are split on
/// `WITH` clauses.
///
/// Each part ends with either:
///
/// * `RETURN` clause;
/// * `WITH` clause or
/// * any of the write clauses.
///
/// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
/// generated QueryPart will have `matching` generated for the `MATCH`.
/// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
/// in that exact order. The pattern inside the first `MERGE` will be used to
/// generate the first `merge_matching` element, and the second `MERGE` pattern
/// will produce the second `merge_matching` element. This way, if someone
/// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
/// in the same order as their respective `merge_matching` elements.
struct QueryPart {
/// @brief All `MATCH` clauses merged into one @c Matching.
Matching matching;
/// @brief Each `OPTIONAL MATCH` converted to @c Matching.
///
/// There is one element per `OPTIONAL MATCH`; the vector is empty by default.
std::vector<Matching> optional_matching{};
/// @brief @c Matching for each `MERGE` clause.
///
/// Storing the normalized pattern of a @c Merge does not preclude storing the
/// @c Merge clause itself inside `remaining_clauses`. The reason is that we
/// need to have access to other parts of the clause, such as `SET` clauses
/// which need to be run.
///
/// Since @c Merge is contained in `remaining_clauses`, this vector contains
/// matching in the same order as @c Merge appears.
std::vector<Matching> merge_matching{};
/// @brief All the remaining clauses (without @c Match).
///
/// The clauses are stored as non-owning looking raw pointers.
/// NOTE(review): ownership presumably stays with the AST storage; confirm.
std::vector<Clause *> remaining_clauses{};
};
/// @brief Context which contains variables commonly used during planning.
///
/// The symbol table, AST storage and database accessor are held by reference,
/// so they need to stay alive for as long as the context is used.
struct PlanningContext {
/// @brief SymbolTable is used to determine inputs and outputs of planned
/// operators.
///
/// Newly created AST nodes may be added to reference existing symbols.
SymbolTable &symbol_table;
/// @brief The storage is used to traverse the AST as well as create new nodes
/// for use in operators.
AstTreeStorage &ast_storage;
/// @brief GraphDbAccessor, which may be used to get some information from the
/// database to generate better plans. The accessor is required only to live
/// long enough for the plan generation to finish.
const GraphDbAccessor &db;
/// @brief Symbol set is used to differentiate cycles in pattern matching.
///
/// During planning, symbols will be added as each operator produces values
/// for them. This way, the operator can be correctly initialized whether to
/// read a symbol or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and
/// write) the first `n`, but the latter `n` would only read the already
/// written information.
///
/// The set is owned by the context and starts out empty.
std::unordered_set<Symbol> bound_symbols{};
};
/// @brief Planner which uses hardcoded rules to produce operators.
///
/// The planner stores only a reference to the given @c PlanningContext, so the
/// context needs to outlive the planner.
///
/// @sa MakeLogicalPlan
class RuleBasedPlanner {
public:
/// @brief Creates a planner which uses the given @c context for planning.
explicit RuleBasedPlanner(PlanningContext &context) : context_(context) {}
/// @brief The result of plan generation is the root of the generated operator
/// tree.
using PlanResult = std::unique_ptr<LogicalOperator>;
/// @brief Generates the operator tree based on explicitly set rules.
///
/// @return Owning pointer to the root of the generated operator tree.
PlanResult Plan(std::vector<QueryPart> &);
private:
/// Context used during planning; referenced, not owned.
PlanningContext &context_;
};
/// @brief Planner which generates multiple plans by changing the order of graph
/// traversal.
///
/// This planner picks different starting nodes from which to start graph
/// traversal. Generating a single plan is backed by @c RuleBasedPlanner.
///
/// The planner stores only a reference to the given @c PlanningContext, so the
/// context needs to outlive the planner.
///
/// @sa MakeLogicalPlan
class VariableStartPlanner {
public:
/// @brief Creates a planner which uses the given @c context for planning.
explicit VariableStartPlanner(PlanningContext &context) : context_(context) {}
/// @brief The result of plan generation is a vector of roots to multiple
/// generated operator trees.
using PlanResult = std::vector<std::unique_ptr<LogicalOperator>>;
/// @brief Generate multiple plans by varying the order of graph traversal.
///
/// @return Owning pointers to the roots of all the generated operator trees.
PlanResult Plan(std::vector<QueryPart> &);
private:
/// Context used during planning; referenced, not owned.
PlanningContext &context_;
};
/// @brief Convert the AST to multiple @c QueryParts.
///
/// This function will normalize patterns inside @c Match and @c Merge clauses
/// and do some other preprocessing in order to generate multiple @c QueryPart
/// structures. @c AstTreeStorage and @c SymbolTable may be used to create new
/// AST nodes.
std::vector<QueryPart> CollectQueryParts(SymbolTable &, AstTreeStorage &);
/// @brief Generates the LogicalOperator tree and returns the resulting plan.
///
/// @tparam TPlanner Type of the planner used for generation.
/// @tparam TDbAccessor Type of the database accessor used for generation.
/// @param storage AstTreeStorage used to construct the operator tree by
/// traversing the @c Query node. The storage may also be used to create new
/// AST nodes for use in operators.
/// @param symbol_table SymbolTable used to determine inputs and outputs of
/// certain operators. Newly created AST nodes may be added to this symbol
/// table.
/// @param db Optional @c GraphDbAccessor, which is used to query database
/// information in order to improve generated plans.
/// @param db @c TDbAccessor, which is used to query database information in
/// order to improve generated plans.
/// @return @c PlanResult which depends on the @c TPlanner used.
///
/// @sa RuleBasedPlanner
/// @sa VariableStartPlanner
template <class TPlanner>
typename TPlanner::PlanResult MakeLogicalPlan(AstTreeStorage &storage,
SymbolTable &symbol_table,
const GraphDbAccessor &db) {
template <template <class> class TPlanner, class TDbAccessor>
auto MakeLogicalPlan(AstTreeStorage &storage, SymbolTable &symbol_table,
const TDbAccessor &db) {
auto query_parts = CollectQueryParts(symbol_table, storage);
PlanningContext context{symbol_table, storage, db};
return TPlanner(context).Plan(query_parts);
PlanningContext<TDbAccessor> context{symbol_table, storage, db};
return TPlanner<decltype(context)>(context).Plan(query_parts);
}
} // namespace plan

View File

@ -1,11 +1,10 @@
#include "query/plan/planner.hpp"
#include "query/plan/rule_based_planner.hpp"
#include <algorithm>
#include <functional>
#include <limits>
#include <unordered_set>
#include "query/frontend/ast/ast.hpp"
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
@ -13,14 +12,6 @@ namespace query::plan {
namespace {
// Tries to add `symbol` to the set of `bound_symbols`. The result is true when
// this call has bound the symbol, and false when the symbol was bound before.
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
                const Symbol &symbol) {
  return bound_symbols.insert(symbol).second;
}
/// Utility function for iterating pattern atoms and accumulating a result.
///
/// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore,
@ -83,44 +74,13 @@ void ForEachPattern(
}
}
// Generates the operators which create the nodes and edges of a `pattern`.
// Symbols which are created get added to `bound_symbols`. Nodes whose symbols
// are already bound are not created again, while an already bound edge symbol
// is reported as a failure.
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
                         const SymbolTable &symbol_table,
                         std::unordered_set<Symbol> &bound_symbols) {
  // Handles the starting node of the pattern.
  auto base = [&](NodeAtom *node) -> LogicalOperator * {
    const bool newly_bound =
        BindSymbol(bound_symbols, symbol_table.at(*node->identifier_));
    if (!newly_bound) return input_op;
    return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
  };
  // Handles each following (edge, node) pair of the pattern.
  auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
                     EdgeAtom *edge, NodeAtom *node) {
    // The previous node's symbol is the input of CreateExpand.
    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
    // When the target node is already bound, CreateExpand needs to be told,
    // so that it creates only the edge.
    const bool node_existing =
        !BindSymbol(bound_symbols, symbol_table.at(*node->identifier_));
    const bool edge_newly_bound =
        BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_));
    if (!edge_newly_bound) {
      permanent_fail("Symbols used for created edges cannot be redeclared.");
    }
    return new CreateExpand(node, edge,
                            std::shared_ptr<LogicalOperator>(last_op),
                            input_symbol, node_existing);
  };
  return ReducePattern<LogicalOperator *>(pattern, base, collect);
}
auto GenCreate(Create &create, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto last_op = input_op;
for (auto pattern : create.patterns_) {
last_op =
GenCreateForPattern(*pattern, last_op, symbol_table, bound_symbols);
last_op = impl::GenCreateForPattern(*pattern, last_op, symbol_table,
bound_symbols);
}
return last_op;
}
@ -174,42 +134,6 @@ Expression *BoolJoin(AstTreeStorage &storage, Expression *expr1,
return expr1 ? expr1 : expr2;
}
// Contextual information used for generating match operators.
struct MatchContext {
// Symbol table in which the symbols of pattern identifiers are looked up.
const SymbolTable &symbol_table;
// Already bound symbols, which are used to determine whether the operator
// should reference them or establish new. This is both read from and written
// to during generation.
std::unordered_set<Symbol> &bound_symbols;
// Determines whether the match should see the new graph state or not.
// Defaults to the old state.
GraphView graph_view = GraphView::OLD;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols{};
};
// Generates a single Filter operator on top of `last_op` for all the filters
// for which HasBoundFilterSymbols holds with the given `bound_symbols`. Such
// filters are joined with FilterAndOperator and removed from `all_filters`.
// If no filter qualifies, `last_op` is returned unchanged.
auto GenFilters(LogicalOperator *last_op,
                const std::unordered_set<Symbol> &bound_symbols,
                std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
                    &all_filters,
                AstTreeStorage &storage) {
  Expression *joined_expr = nullptr;
  auto it = all_filters.begin();
  while (it != all_filters.end()) {
    if (!HasBoundFilterSymbols(bound_symbols, *it)) {
      // Symbols are not yet bound, keep the filter for later.
      ++it;
      continue;
    }
    joined_expr =
        BoolJoin<FilterAndOperator>(storage, joined_expr, it->first);
    // Erasing yields the iterator to the next element.
    it = all_filters.erase(it);
  }
  if (joined_expr) {
    last_op =
        new Filter(std::shared_ptr<LogicalOperator>(last_op), joined_expr);
  }
  return last_op;
}
// Ast tree visitor which collects the context for a return body.
// The return body of WITH and RETURN clauses consists of:
//
@ -573,78 +497,6 @@ auto GenReturnBody(LogicalOperator *input_op, bool advance_command,
return last_op;
}
// Generates the operators for a WITH clause and resets `bound_symbols`, so
// that they contain only the output symbols of the WITH body.
auto GenWith(With &with, LogicalOperator *input_op, SymbolTable &symbol_table,
             bool is_write, std::unordered_set<Symbol> &bound_symbols,
             AstTreeStorage &storage) {
  ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
                         with.where_);
  // WITH is planned as Accumulate/Aggregate (advance_command) + Produce and an
  // optional Filter. When the query has written something and aggregation is
  // used, accumulation needs to come first, so that aggregation sees the
  // latest results. This is similar to the RETURN clause.
  bool accumulate = is_write;
  // The command needs to be advanced only if something was written.
  bool advance_command = is_write;
  LogicalOperator *with_op =
      GenReturnBody(input_op, advance_command, body, accumulate);
  // Expose only the symbols which are produced by WITH.
  bound_symbols.clear();
  for (const auto &output_symbol : body.output_symbols()) {
    bound_symbols.insert(output_symbol);
  }
  return with_op;
}
// Generates the operators for a RETURN clause.
//
// The results are accumulated when the query has written to the database, so
// that the returned expressions see the latest updated values. For example,
// in `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`, if the same `n`
// is matched 'k' times, we want to return 'k' results where the property
// value is the same, final result of 'k' increments. The command is never
// advanced here.
auto GenReturn(Return &ret, LogicalOperator *input_op,
               SymbolTable &symbol_table, bool is_write,
               const std::unordered_set<Symbol> &bound_symbols,
               AstTreeStorage &storage) {
  ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
  return GenReturnBody(input_op, /* advance_command = */ false, body,
                       /* accumulate = */ is_write);
}
// Produces the operator for a `clause` which writes to the database, with
// `input_op` as the input of that operator. When the clause is not one of the
// handled write clauses, nullptr is returned and `input_op` is left untouched.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
                                   const SymbolTable &symbol_table,
                                   std::unordered_set<Symbol> &bound_symbols) {
  if (auto *create = dynamic_cast<Create *>(clause)) {
    return GenCreate(*create, input_op, symbol_table, bound_symbols);
  }
  if (auto *del = dynamic_cast<query::Delete *>(clause)) {
    return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
                            del->expressions_, del->detach_);
  }
  if (auto *set_prop = dynamic_cast<query::SetProperty *>(clause)) {
    return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
                                 set_prop->property_lookup_,
                                 set_prop->expression_);
  }
  if (auto *set_props = dynamic_cast<query::SetProperties *>(clause)) {
    auto op = plan::SetProperties::Op::REPLACE;
    if (set_props->update_) op = plan::SetProperties::Op::UPDATE;
    const auto &input_symbol = symbol_table.at(*set_props->identifier_);
    return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
                                   input_symbol, set_props->expression_, op);
  }
  if (auto *set_labels = dynamic_cast<query::SetLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*set_labels->identifier_);
    return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
                               input_symbol, set_labels->labels_);
  }
  if (auto *rem_prop = dynamic_cast<query::RemoveProperty *>(clause)) {
    return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
                                    rem_prop->property_lookup_);
  }
  if (auto *rem_labels = dynamic_cast<query::RemoveLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*rem_labels->identifier_);
    return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
                                  input_symbol, rem_labels->labels_);
  }
  // Not a write clause which is handled here.
  return nullptr;
}
// Converts multiple Patterns to Expansions. Each Pattern can contain an
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
@ -803,176 +655,6 @@ bool FindBestLabelPropertyIndex(
return found;
}
// Generates a scan of `node_symbol` over indexed data. A label+property index
// is preferred, and the scan is done by range if the chosen property filter
// has any bounds, otherwise by the value of the filter expression. When no
// label+property index is found, the scan falls back to a label index.
ScanAll *GenScanByIndex(
    LogicalOperator *last_op, const GraphDbAccessor &db,
    const Symbol &node_symbol, const MatchContext &context,
    const std::set<GraphDbTypes::Label> &labels,
    const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>>
        &properties) {
  debug_assert(!labels.empty(),
               "Without labels, indexed data cannot be scanned.");
  GraphDbTypes::Label best_label;
  std::pair<GraphDbTypes::Property, Filters::PropertyFilter> best_property;
  const bool has_label_property_index = FindBestLabelPropertyIndex(
      db, labels, properties, node_symbol, context.bound_symbols, best_label,
      best_property);
  if (!has_label_property_index) {
    // Use just the label index (which ought to exist).
    auto label = FindBestLabelIndex(db, labels);
    return new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op),
                              node_symbol, label, context.graph_view);
  }
  const auto &prop_filter = best_property.second;
  if (prop_filter.lower_bound || prop_filter.upper_bound) {
    return new ScanAllByLabelPropertyRange(
        std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label,
        best_property.first, prop_filter.lower_bound, prop_filter.upper_bound,
        context.graph_view);
  }
  debug_assert(prop_filter.expression,
               "Property filter should either have bounds or an expression.");
  return new ScanAllByLabelPropertyValue(
      std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label,
      best_property.first, prop_filter.expression, context.graph_view);
}
// Generates the operators for a single Matching on top of `input_op`.
//
// For each expansion, a scan is generated for the first node if its symbol
// was not bound before. If the expansion has an edge, an expand operator is
// generated after that. Filters are generated through GenFilters before the
// first expansion, after each scan and after each expand. Symbols which get
// bound here are added to `context.bound_symbols` and appended to
// `context.new_symbols`.
//
// Returns the last generated operator, or `input_op` if nothing was generated.
LogicalOperator *PlanMatching(const Matching &matching,
LogicalOperator *input_op,
PlanningContext &planning_ctx,
MatchContext &context) {
auto &bound_symbols = context.bound_symbols;
auto &storage = planning_ctx.ast_storage;
const auto &symbol_table = context.symbol_table;
// Copy all_filters, because we will modify the list as we generate Filters.
auto all_filters = matching.filters.all_filters();
// Try to generate any filters even before the 1st match operator. This
// optimizes the optional match which filters only on symbols bound in regular
// match.
auto *last_op = GenFilters(input_op, bound_symbols, all_filters, storage);
for (const auto &expansion : matching.expansions) {
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
if (BindSymbol(bound_symbols, node1_symbol)) {
// We have just bound this symbol, so generate ScanAll which fills it.
auto labels = FindOr(matching.filters.label_filters(), node1_symbol,
std::set<GraphDbTypes::Label>())
.first;
if (labels.empty()) {
// Without labels, we can only generate ScanAll of everything.
last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
node1_symbol, context.graph_view);
} else {
// With labels, we can scan indexed data.
auto properties =
FindOr(matching.filters.property_filters(), node1_symbol,
std::map<GraphDbTypes::Property,
std::vector<Filters::PropertyFilter>>())
.first;
last_op = GenScanByIndex(last_op, planning_ctx.db, node1_symbol,
context, labels, properties);
}
context.new_symbols.emplace_back(node1_symbol);
// The scanned symbol is now bound, so more filters may be generated.
last_op = GenFilters(last_op, bound_symbols, all_filters, storage);
}
// We have an edge, so generate Expand.
if (expansion.edge) {
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
auto existing_node = false;
if (!BindSymbol(bound_symbols, node_symbol)) {
existing_node = true;
} else {
context.new_symbols.emplace_back(node_symbol);
}
const auto &edge_symbol = symbol_table.at(*expansion.edge->identifier_);
auto existing_edge = false;
if (!BindSymbol(bound_symbols, edge_symbol)) {
existing_edge = true;
} else {
context.new_symbols.emplace_back(edge_symbol);
}
// Pick the expand operator based on the kind of the edge atom: breadth
// first, variable length (ranged) or a regular single edge expansion.
if (auto *bf_atom = dynamic_cast<BreadthFirstAtom *>(expansion.edge)) {
const auto &traversed_edge_symbol =
symbol_table.at(*bf_atom->traversed_edge_identifier_);
const auto &next_node_symbol =
symbol_table.at(*bf_atom->next_node_identifier_);
last_op = new ExpandBreadthFirst(
node_symbol, edge_symbol, expansion.direction, bf_atom->max_depth_,
next_node_symbol, traversed_edge_symbol,
bf_atom->filter_expression_,
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, context.graph_view);
} else if (expansion.edge->has_range_) {
// The 4th argument tells whether the expansion direction differs from
// the direction of the original edge atom.
last_op = new ExpandVariable(
node_symbol, edge_symbol, expansion.direction,
expansion.direction != expansion.edge->direction_,
expansion.edge->lower_bound_, expansion.edge->upper_bound_,
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, existing_edge, context.graph_view);
} else {
last_op =
new Expand(node_symbol, edge_symbol, expansion.direction,
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, existing_edge, context.graph_view);
}
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different
// edges).
for (const auto &edge_symbols : matching.edge_symbols) {
// Only the sets which contain the just expanded edge are relevant.
if (edge_symbols.find(edge_symbol) == edge_symbols.end()) {
continue;
}
// Collect the other edge symbols of the set, which are already bound.
std::vector<Symbol> other_symbols;
for (const auto &symbol : edge_symbols) {
if (symbol == edge_symbol ||
bound_symbols.find(symbol) == bound_symbols.end()) {
continue;
}
other_symbols.push_back(symbol);
}
if (!other_symbols.empty()) {
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
other_symbols);
}
}
}
// The expand may have bound new symbols, so try generating filters again.
last_op = GenFilters(last_op, bound_symbols, all_filters, storage);
}
}
debug_assert(all_filters.empty(), "Expected to generate all filters");
return last_op;
}
// Generates the plan::Merge operator for a MERGE clause. The operator gets
// `input_op` as its input, the plan of `matching` as the "on match" branch and
// the creation of the merge pattern as the "on create" branch. Clauses from
// ON MATCH and ON CREATE are appended to their respective branches.
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
              const Matching &matching, PlanningContext &context) {
  // The match branch works on a copy of the bound symbols, because the create
  // branch must not see the symbols bound while planning the match.
  std::unordered_set<Symbol> match_bound_symbols(context.bound_symbols);
  MatchContext match_ctx{context.symbol_table, match_bound_symbols,
                         GraphView::NEW};
  auto on_match = PlanMatching(matching, nullptr, context, match_ctx);
  // The create branch uses the original bound symbols, so that they get
  // filled with the newly created symbols.
  auto on_create = GenCreateForPattern(
      *merge.pattern_, nullptr, context.symbol_table, context.bound_symbols);
  for (auto &create_clause : merge.on_create_) {
    on_create = HandleWriteClause(create_clause, on_create,
                                  context.symbol_table, context.bound_symbols);
    debug_assert(on_create, "Expected SET in MERGE ... ON CREATE");
  }
  for (auto &match_clause : merge.on_match_) {
    on_match = HandleWriteClause(match_clause, on_match, context.symbol_table,
                                 context.bound_symbols);
    debug_assert(on_match, "Expected SET in MERGE ... ON MATCH");
  }
  return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
                         std::shared_ptr<LogicalOperator>(on_match),
                         std::shared_ptr<LogicalOperator>(on_create));
}
} // namespace
// Analyzes the filter expression by collecting information on filtering labels
@ -1188,65 +870,179 @@ std::vector<QueryPart> CollectQueryParts(SymbolTable &symbol_table,
return query_parts;
}
std::unique_ptr<LogicalOperator> RuleBasedPlanner::Plan(
std::vector<QueryPart> &query_parts) {
auto &context = context_;
LogicalOperator *input_op = nullptr;
// Set to true if a query command writes to the database.
bool is_write = false;
for (const auto &query_part : query_parts) {
MatchContext match_ctx{context.symbol_table, context.bound_symbols};
input_op = PlanMatching(query_part.matching, input_op, context, match_ctx);
for (const auto &matching : query_part.optional_matching) {
MatchContext opt_ctx{context.symbol_table, context.bound_symbols};
auto *match_op = PlanMatching(matching, nullptr, context, opt_ctx);
if (match_op) {
input_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
std::shared_ptr<LogicalOperator>(match_op),
opt_ctx.new_symbols);
}
}
int merge_id = 0;
for (auto &clause : query_part.remaining_clauses) {
debug_assert(dynamic_cast<Match *>(clause) == nullptr,
"Unexpected Match in remaining clauses");
if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, context.symbol_table, is_write,
context.bound_symbols, context.ast_storage);
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
input_op = GenMerge(*merge, input_op,
query_part.merge_matching[merge_id++], context);
// Treat MERGE clause as write, because we do not know if it will create
// anything.
is_write = true;
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op = GenWith(*with, input_op, context.symbol_table, is_write,
context.bound_symbols, context.ast_storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto *op =
HandleWriteClause(clause, input_op, context.symbol_table,
context.bound_symbols)) {
is_write = true;
input_op = op;
} else if (auto *unwind = dynamic_cast<query::Unwind *>(clause)) {
const auto &symbol =
context.symbol_table.at(*unwind->named_expression_);
BindSymbol(context.bound_symbols, symbol);
input_op =
new plan::Unwind(std::shared_ptr<LogicalOperator>(input_op),
unwind->named_expression_->expression_, symbol);
} else if (auto *create_index =
dynamic_cast<query::CreateIndex *>(clause)) {
debug_assert(!input_op, "Unexpected operator before CreateIndex");
input_op = new plan::CreateIndex(create_index->label_,
create_index->property_);
} else {
throw utils::NotYetImplemented("clause conversion to operator(s)");
}
}
}
return std::unique_ptr<LogicalOperator>(input_op);
namespace impl {
// Tries to add `symbol` to the set of `bound_symbols`. The result is true when
// this call has bound the symbol, and false when the symbol was bound before.
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
                const Symbol &symbol) {
  return bound_symbols.insert(symbol).second;
}
// Generates a single Filter operator on top of `last_op` for all the filters
// for which HasBoundFilterSymbols holds with the given `bound_symbols`. Such
// filters are joined with FilterAndOperator and removed from `all_filters`.
// If no filter qualifies, `last_op` is returned unchanged.
LogicalOperator *GenFilters(
    LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
    std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
        &all_filters,
    AstTreeStorage &storage) {
  Expression *joined_expr = nullptr;
  auto it = all_filters.begin();
  while (it != all_filters.end()) {
    if (!HasBoundFilterSymbols(bound_symbols, *it)) {
      // Symbols are not yet bound, keep the filter for later.
      ++it;
      continue;
    }
    joined_expr =
        BoolJoin<FilterAndOperator>(storage, joined_expr, it->first);
    // Erasing yields the iterator to the next element.
    it = all_filters.erase(it);
  }
  if (!joined_expr) return last_op;
  return new Filter(std::shared_ptr<LogicalOperator>(last_op), joined_expr);
}
// Generates a scan of `node_symbol` over indexed data. A label+property index
// is preferred, and the scan is done by range if the chosen property filter
// has any bounds, otherwise by the value of the filter expression. When no
// label+property index is found, the scan falls back to a label index.
ScanAll *GenScanByIndex(
    LogicalOperator *last_op, const GraphDbAccessor &db,
    const Symbol &node_symbol, const MatchContext &context,
    const std::set<GraphDbTypes::Label> &labels,
    const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>>
        &properties) {
  debug_assert(!labels.empty(),
               "Without labels, indexed data cannot be scanned.");
  GraphDbTypes::Label best_label;
  std::pair<GraphDbTypes::Property, Filters::PropertyFilter> best_property;
  const bool has_label_property_index = FindBestLabelPropertyIndex(
      db, labels, properties, node_symbol, context.bound_symbols, best_label,
      best_property);
  if (!has_label_property_index) {
    // Use just the label index (which ought to exist).
    auto label = FindBestLabelIndex(db, labels);
    return new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op),
                              node_symbol, label, context.graph_view);
  }
  const auto &prop_filter = best_property.second;
  if (prop_filter.lower_bound || prop_filter.upper_bound) {
    return new ScanAllByLabelPropertyRange(
        std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label,
        best_property.first, prop_filter.lower_bound, prop_filter.upper_bound,
        context.graph_view);
  }
  debug_assert(prop_filter.expression,
               "Property filter should either have bounds or an expression.");
  return new ScanAllByLabelPropertyValue(
      std::shared_ptr<LogicalOperator>(last_op), node_symbol, best_label,
      best_property.first, prop_filter.expression, context.graph_view);
}
// Generates the operators for a RETURN clause.
//
// The results are accumulated when the query has written to the database, so
// that the returned expressions see the latest updated values. For example,
// in `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`, if the same `n`
// is matched 'k' times, we want to return 'k' results where the property
// value is the same, final result of 'k' increments. The command is never
// advanced here.
LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
                           SymbolTable &symbol_table, bool is_write,
                           const std::unordered_set<Symbol> &bound_symbols,
                           AstTreeStorage &storage) {
  ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
  return GenReturnBody(input_op, /* advance_command = */ false, body,
                       /* accumulate = */ is_write);
}
// Generates the operators which create the nodes and edges of a `pattern`.
// Symbols which are created get added to `bound_symbols`. Nodes whose symbols
// are already bound are not created again, while an already bound edge symbol
// is reported as a failure.
LogicalOperator *GenCreateForPattern(
    Pattern &pattern, LogicalOperator *input_op,
    const SymbolTable &symbol_table,
    std::unordered_set<Symbol> &bound_symbols) {
  // Handles the starting node of the pattern.
  auto base = [&](NodeAtom *node) -> LogicalOperator * {
    const bool newly_bound =
        BindSymbol(bound_symbols, symbol_table.at(*node->identifier_));
    if (!newly_bound) return input_op;
    return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
  };
  // Handles each following (edge, node) pair of the pattern.
  auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
                     EdgeAtom *edge, NodeAtom *node) {
    // The previous node's symbol is the input of CreateExpand.
    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
    // When the target node is already bound, CreateExpand needs to be told,
    // so that it creates only the edge.
    const bool node_existing =
        !BindSymbol(bound_symbols, symbol_table.at(*node->identifier_));
    const bool edge_newly_bound =
        BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_));
    if (!edge_newly_bound) {
      permanent_fail("Symbols used for created edges cannot be redeclared.");
    }
    return new CreateExpand(node, edge,
                            std::shared_ptr<LogicalOperator>(last_op),
                            input_symbol, node_existing);
  };
  return ReducePattern<LogicalOperator *>(pattern, base, collect);
}
// Generates the operator for a clause which writes to the database.
//
// Recognized clauses are Create, Delete, SetProperty, SetProperties,
// SetLabels, RemoveProperty and RemoveLabels. For any other clause nullptr is
// returned and `input_op` is left untouched.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
                                   const SymbolTable &symbol_table,
                                   std::unordered_set<Symbol> &bound_symbols) {
  // Wraps the input into the shared pointer expected by operators. This must
  // be invoked only by the branch which actually consumes the input.
  auto take_input = [input_op]() {
    return std::shared_ptr<LogicalOperator>(input_op);
  };
  if (auto *create = dynamic_cast<Create *>(clause)) {
    return GenCreate(*create, input_op, symbol_table, bound_symbols);
  }
  if (auto *del = dynamic_cast<query::Delete *>(clause)) {
    return new plan::Delete(take_input(), del->expressions_, del->detach_);
  }
  if (auto *set_property = dynamic_cast<query::SetProperty *>(clause)) {
    return new plan::SetProperty(take_input(), set_property->property_lookup_,
                                 set_property->expression_);
  }
  if (auto *set_properties = dynamic_cast<query::SetProperties *>(clause)) {
    auto op = plan::SetProperties::Op::REPLACE;
    if (set_properties->update_) {
      op = plan::SetProperties::Op::UPDATE;
    }
    const auto &input_symbol = symbol_table.at(*set_properties->identifier_);
    return new plan::SetProperties(take_input(), input_symbol,
                                   set_properties->expression_, op);
  }
  if (auto *set_labels = dynamic_cast<query::SetLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*set_labels->identifier_);
    return new plan::SetLabels(take_input(), input_symbol,
                               set_labels->labels_);
  }
  if (auto *remove_property = dynamic_cast<query::RemoveProperty *>(clause)) {
    return new plan::RemoveProperty(take_input(),
                                    remove_property->property_lookup_);
  }
  if (auto *remove_labels = dynamic_cast<query::RemoveLabels *>(clause)) {
    const auto &input_symbol = symbol_table.at(*remove_labels->identifier_);
    return new plan::RemoveLabels(take_input(), input_symbol,
                                  remove_labels->labels_);
  }
  // Not a write clause known to this function.
  return nullptr;
}
// Generates the operators for a WITH clause.
//
// WITH is Accumulate/Aggregate (with advance_command) + Produce and an
// optional Filter. Similarly to RETURN, when the query performed updates we
// accumulate first, so that aggregation sees the latest results. The command
// is advanced only if something was written; pure reads do not need it.
// After planning, `bound_symbols` is reset to contain only the symbols which
// the WITH clause outputs.
LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
                         SymbolTable &symbol_table, bool is_write,
                         std::unordered_set<Symbol> &bound_symbols,
                         AstTreeStorage &storage) {
  ReturnBodyContext with_body(with.body_, symbol_table, bound_symbols, storage,
                              with.where_);
  LogicalOperator *with_op =
      GenReturnBody(input_op, /* advance_command = */ is_write, with_body,
                    /* accumulate = */ is_write);
  // Expose only the symbols produced by WITH to the rest of the query.
  bound_symbols.clear();
  for (const auto &output_symbol : with_body.output_symbols()) {
    BindSymbol(bound_symbols, output_symbol);
  }
  return with_op;
}
} // namespace impl
} // namespace query::plan

View File

@ -0,0 +1,448 @@
/// @file
#pragma once
#include "query/frontend/ast/ast.hpp"
#include "query/plan/operator.hpp"
namespace query::plan {
/// Normalized representation of a pattern that needs to be matched.
///
/// A single expansion is either a lone node (`node1` only) or a
/// `node1 - edge - node2` step. @c RuleBasedPlanner::PlanMatching generates a
/// scan for `node1` if its symbol is not yet bound and, when `edge` is set, an
/// expand operator towards `node2`.
struct Expansion {
/// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
/// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
/// Direction of the edge, it may be flipped compared to original
/// @c EdgeAtom during plan generation.
EdgeAtom::Direction direction = EdgeAtom::Direction::BOTH;
/// Set of symbols found inside the range expressions of a variable path edge.
std::unordered_set<Symbol> symbols_in_range{};
/// Optional node at the other end of an edge. If the expansion
/// contains an edge, then this node is required.
NodeAtom *node2 = nullptr;
};
/// Stores information on filters used inside the @c Matching of a @c QueryPart.
///
/// Three views of the filters are kept: the complete list of filter
/// expressions (`all_filters`), and two lookup tables keyed by symbol
/// (`label_filters` and `property_filters`) which exist only to help with
/// generating indexed scans.
class Filters {
public:
/// Stores the symbols and expression used to filter a property.
///
/// A filter is expected to carry either an `expression` (equality) or at
/// least one of the bounds (range).
struct PropertyFilter {
using Bound = ScanAllByLabelPropertyRange::Bound;
/// Set of used symbols in the @c expression.
std::unordered_set<Symbol> used_symbols;
/// Expression which when evaluated produces the value a property must
/// equal.
Expression *expression = nullptr;
/// Optional lower bound of the property value, used for range filters.
std::experimental::optional<Bound> lower_bound{};
/// Optional upper bound of the property value, used for range filters.
std::experimental::optional<Bound> upper_bound{};
};
/// All filter expressions that should be generated.
///
/// Each element pairs the filter expression with a set of symbols.
auto &all_filters() { return all_filters_; }
const auto &all_filters() const { return all_filters_; }
/// Mapping from a symbol to labels that are filtered on it. These should be
/// used only for generating indexed scans.
const auto &label_filters() const { return label_filters_; }
/// Mapping from a symbol to properties that are filtered on it. These should
/// be used only for generating indexed scans.
const auto &property_filters() const { return property_filters_; }
/// Collects filtering information from a pattern.
///
/// Goes through all the atoms in a pattern and generates filter expressions
/// for found labels, properties and edge types. The generated expressions are
/// stored in @c all_filters. Also, @c label_filters and @c property_filters
/// are populated.
void CollectPatternFilters(Pattern &, SymbolTable &, AstTreeStorage &);
/// Collects filtering information from a where expression.
///
/// Takes the where expression and stores it in @c all_filters, then analyzes
/// the expression for additional information. The additional information is
/// used to populate @c label_filters and @c property_filters, so that indexed
/// scanning can use it.
void CollectWhereFilter(Where &, const SymbolTable &);
private:
// Helper taking a filter expression and the symbol table; it is defined
// outside of this header.
void AnalyzeFilter(Expression *, const SymbolTable &);
// Backing storage for all_filters().
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> all_filters_;
// Backing storage for label_filters().
std::unordered_map<Symbol, std::set<GraphDbTypes::Label>> label_filters_;
// Backing storage for property_filters(); a property may have multiple
// filters on the same symbol.
std::unordered_map<
Symbol, std::map<GraphDbTypes::Property, std::vector<PropertyFilter>>>
property_filters_;
};
/// Normalized representation of a single or multiple Match clauses.
///
/// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
/// WHERE c.prop < 42` will produce the following.
/// Expansions will store `(a) -[e1]- (b)`, `(b) -[e2]- (c)` and
/// `(n) -[e3]- (m)`.
/// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
/// first `MATCH` and the set `{e3}` for the second.
/// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
/// other obtained from `WHERE` on symbol `c`.
struct Matching {
/// All expansions that need to be performed across @c Match clauses.
std::vector<Expansion> expansions;
/// Symbols for edges established in match, used to ensure Cyphermorphism.
///
/// There are multiple sets, because each Match clause determines a single
/// set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
/// Information on used filter expressions while matching.
Filters filters;
/// Maps node symbols to expansions which bind them.
///
/// NOTE(review): the stored integers are presumably indices into
/// `expansions`; confirm against the code which fills this map.
std::unordered_map<Symbol, std::set<int>> node_symbol_to_expansions{};
/// All node and edge symbols across all expansions (from all matches).
std::unordered_set<Symbol> expansion_symbols{};
};
/// @brief Represents a read (+ write) part of a query. Parts are split on
/// `WITH` clauses.
///
/// Each part ends with either:
///
///  * `RETURN` clause;
///  * `WITH` clause or
///  * any of the write clauses.
///
/// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
/// generated QueryPart will have `matching` generated for the `MATCH`.
/// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
/// in that exact order. The pattern inside the first `MERGE` will be used to
/// generate the first `merge_matching` element, and the second `MERGE` pattern
/// will produce the second `merge_matching` element. This way, if someone
/// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
/// in the same order as their respective `merge_matching` elements.
struct QueryPart {
/// @brief All `MATCH` clauses merged into one @c Matching.
Matching matching;
/// @brief Each `OPTIONAL MATCH` converted to @c Matching.
std::vector<Matching> optional_matching{};
/// @brief @c Matching for each `MERGE` clause.
///
/// Storing the normalized pattern of a @c Merge does not preclude storing the
/// @c Merge clause itself inside `remaining_clauses`. The reason is that we
/// need to have access to other parts of the clause, such as `SET` clauses
/// which need to be run.
///
/// Since @c Merge is contained in `remaining_clauses`, this vector contains
/// matching in the same order as @c Merge appears. @c RuleBasedPlanner relies
/// on this: it indexes this vector with a counter that is incremented for
/// every @c Merge found while traversing `remaining_clauses`.
std::vector<Matching> merge_matching{};
/// @brief All the remaining clauses (without @c Match).
std::vector<Clause *> remaining_clauses{};
};
/// @brief Context which contains variables commonly used during planning.
///
/// @tparam TDbAccessor Type of the database accessor stored in `db`.
///
/// NOTE(review): @c impl::GenScanByIndex is declared to take a
/// `const GraphDbAccessor &`, and @c RuleBasedPlanner passes `db` to it, so a
/// substitute accessor type presumably has to be convertible to
/// `GraphDbAccessor` -- confirm before relying on a dummy accessor.
template <class TDbAccessor>
struct PlanningContext {
/// @brief SymbolTable is used to determine inputs and outputs of planned
/// operators.
///
/// Newly created AST nodes may be added to reference existing symbols.
SymbolTable &symbol_table;
/// @brief The storage is used to traverse the AST as well as create new nodes
/// for use in operators.
AstTreeStorage &ast_storage;
/// @brief TDbAccessor, which may be used to get some information from the
/// database to generate better plans. The accessor is required only to live
/// long enough for the plan generation to finish.
const TDbAccessor &db;
/// @brief Symbol set is used to differentiate cycles in pattern matching.
///
/// During planning, symbols will be added as each operator produces values
/// for them. This way, the operator can be correctly initialized whether to
/// read a symbol or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and
/// write) the first `n`, but the latter `n` would only read the already
/// written information.
std::unordered_set<Symbol> bound_symbols{};
};
// Contextual information used for generating match operators.
//
// The struct is aggregate-initialized positionally by RuleBasedPlanner (symbol
// table, bound symbols and optionally the graph view), so the order of members
// must be kept.
struct MatchContext {
// Symbol table used to look up the symbols of pattern atoms.
const SymbolTable &symbol_table;
// Already bound symbols, which are used to determine whether the operator
// should reference them or establish new. This is both read from and written
// to during generation.
std::unordered_set<Symbol> &bound_symbols;
// Determines whether the match should see the new graph state or not.
GraphView graph_view = GraphView::OLD;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols{};
};
/// @brief Convert the AST to multiple @c QueryParts.
///
/// This function will normalize patterns inside @c Match and @c Merge clauses
/// and do some other preprocessing in order to generate multiple @c QueryPart
/// structures. @c AstTreeStorage and @c SymbolTable may be used to create new
/// AST nodes.
///
/// @return The generated @c QueryPart structures, which can be passed to the
/// `Plan` method of a planner.
std::vector<QueryPart> CollectQueryParts(SymbolTable &, AstTreeStorage &);
namespace impl {
// These functions are an internal implementation of RuleBasedPlanner. To avoid
// writing the whole code inline in this header file, they are declared here and
// defined in the cpp file.
// Adds the symbol to the set of bound symbols. RuleBasedPlanner treats a true
// result as "the symbol has just been bound" and a false result as "the symbol
// was already bound".
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
const Symbol &symbol);
// Generates filter operators on top of last_op. RuleBasedPlanner passes a
// private copy of the filters and expects it to be empty once matching is fully
// planned, i.e. the filters which get generated are removed from all_filters.
LogicalOperator *GenFilters(
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
&all_filters,
AstTreeStorage &storage);
// Generates a scan of node_symbol which uses the given label and property
// filters. RuleBasedPlanner calls this only when labels is not empty.
ScanAll *GenScanByIndex(
LogicalOperator *last_op, const GraphDbAccessor &db,
const Symbol &node_symbol, const MatchContext &context,
const std::set<GraphDbTypes::Label> &labels,
const std::map<GraphDbTypes::Property, std::vector<Filters::PropertyFilter>>
&properties);
// Generates operators for the RETURN clause. is_write tells whether the query
// has written to the database before this clause.
LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage);
// Generates operators which create the given pattern. Newly declared symbols
// are added to bound_symbols.
LogicalOperator *GenCreateForPattern(Pattern &pattern,
LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols);
// Generates an operator for a clause which writes to the database. Returns
// nullptr if the clause isn't handled.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols);
// Generates operators for the WITH clause. After the call, bound_symbols
// contains only the symbols exposed by WITH.
LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage);
} // namespace impl
/// @brief Planner which uses hardcoded rules to produce operators.
///
/// @tparam TPlanningContext Type of the planning context. The planner uses its
/// `symbol_table`, `ast_storage`, `db` and `bound_symbols` members.
///
/// @sa MakeLogicalPlan
template <class TPlanningContext>
class RuleBasedPlanner {
 public:
  /// @brief Constructs the planner; the context is stored by reference and is
  /// modified during planning (e.g. its `bound_symbols`).
  explicit RuleBasedPlanner(TPlanningContext &context) : context_(context) {}

  /// @brief The result of plan generation is the root of the generated operator
  /// tree.
  using PlanResult = std::unique_ptr<LogicalOperator>;

  /// @brief Generates the operator tree based on explicitly set rules.
  ///
  /// For each query part, the regular matching is planned first, followed by
  /// optional matchings and finally the remaining clauses in their order of
  /// appearance.
  ///
  /// @throw utils::NotYetImplemented if a clause cannot be converted to
  /// operators. The partially generated plan is released in that case.
  PlanResult Plan(std::vector<QueryPart> &query_parts) {
    auto &context = context_;
    // Root of the plan built so far. It is an owning raw pointer; every newly
    // created operator takes over the ownership of the previous root.
    LogicalOperator *input_op = nullptr;
    // Set to true if a query command writes to the database.
    bool is_write = false;
    for (const auto &query_part : query_parts) {
      MatchContext match_ctx{context.symbol_table, context.bound_symbols};
      input_op = PlanMatching(query_part.matching, input_op, match_ctx);
      for (const auto &matching : query_part.optional_matching) {
        MatchContext opt_ctx{context.symbol_table, context.bound_symbols};
        // Optional matching is planned as a separate branch without input.
        auto *match_op = PlanMatching(matching, nullptr, opt_ctx);
        if (match_op) {
          input_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
                                  std::shared_ptr<LogicalOperator>(match_op),
                                  opt_ctx.new_symbols);
        }
      }
      // Index of the next `merge_matching` element; they are stored in the
      // same order in which Merge clauses appear in `remaining_clauses`.
      int merge_id = 0;
      for (auto &clause : query_part.remaining_clauses) {
        debug_assert(dynamic_cast<Match *>(clause) == nullptr,
                     "Unexpected Match in remaining clauses");
        if (auto *ret = dynamic_cast<Return *>(clause)) {
          input_op =
              impl::GenReturn(*ret, input_op, context.symbol_table, is_write,
                              context.bound_symbols, context.ast_storage);
        } else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
          input_op =
              GenMerge(*merge, input_op, query_part.merge_matching[merge_id++]);
          // Treat MERGE clause as write, because we do not know if it will
          // create anything.
          is_write = true;
        } else if (auto *with = dynamic_cast<query::With *>(clause)) {
          input_op =
              impl::GenWith(*with, input_op, context.symbol_table, is_write,
                            context.bound_symbols, context.ast_storage);
          // WITH clause advances the command, so reset the flag.
          is_write = false;
        } else if (auto *op = impl::HandleWriteClause(clause, input_op,
                                                      context.symbol_table,
                                                      context.bound_symbols)) {
          is_write = true;
          input_op = op;
        } else if (auto *unwind = dynamic_cast<query::Unwind *>(clause)) {
          const auto &symbol =
              context.symbol_table.at(*unwind->named_expression_);
          impl::BindSymbol(context.bound_symbols, symbol);
          input_op =
              new plan::Unwind(std::shared_ptr<LogicalOperator>(input_op),
                               unwind->named_expression_->expression_, symbol);
        } else if (auto *create_index =
                       dynamic_cast<query::CreateIndex *>(clause)) {
          debug_assert(!input_op, "Unexpected operator before CreateIndex");
          input_op = new plan::CreateIndex(create_index->label_,
                                           create_index->property_);
        } else {
          // Nothing took over the ownership of the plan built so far, so
          // release it here. Otherwise it would leak when the exception
          // propagates.
          std::unique_ptr<LogicalOperator> abandoned_plan(input_op);
          throw utils::NotYetImplemented("clause conversion to operator(s)");
        }
      }
    }
    return std::unique_ptr<LogicalOperator>(input_op);
  }

 private:
  TPlanningContext &context_;

  /// @brief Generates the operators which perform the given matching.
  ///
  /// @param matching Normalized match to plan.
  /// @param input_op Operator (may be nullptr) on top of which the matching is
  /// planned; the ownership is passed to the generated operators.
  /// @param match_context Symbols bound during planning are written to its
  /// `bound_symbols` and `new_symbols`.
  /// @return Root of the generated operators.
  LogicalOperator *PlanMatching(const Matching &matching,
                                LogicalOperator *input_op,
                                MatchContext &match_context) {
    auto &bound_symbols = match_context.bound_symbols;
    auto &storage = context_.ast_storage;
    const auto &symbol_table = match_context.symbol_table;
    // Copy all_filters, because we will modify the list as we generate Filters.
    auto all_filters = matching.filters.all_filters();
    // Try to generate any filters even before the 1st match operator. This
    // optimizes the optional match which filters only on symbols bound in
    // regular match.
    auto *last_op =
        impl::GenFilters(input_op, bound_symbols, all_filters, storage);
    for (const auto &expansion : matching.expansions) {
      const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
      if (impl::BindSymbol(bound_symbols, node1_symbol)) {
        // We have just bound this symbol, so generate ScanAll which fills it.
        auto labels = FindOr(matching.filters.label_filters(), node1_symbol,
                             std::set<GraphDbTypes::Label>())
                          .first;
        if (labels.empty()) {
          // Without labels, we can only generate ScanAll of everything.
          last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
                                node1_symbol, match_context.graph_view);
        } else {
          // With labels, we can scan indexed data.
          auto properties =
              FindOr(matching.filters.property_filters(), node1_symbol,
                     std::map<GraphDbTypes::Property,
                              std::vector<Filters::PropertyFilter>>())
                  .first;
          last_op = impl::GenScanByIndex(last_op, context_.db, node1_symbol,
                                         match_context, labels, properties);
        }
        match_context.new_symbols.emplace_back(node1_symbol);
        last_op =
            impl::GenFilters(last_op, bound_symbols, all_filters, storage);
      }
      // We have an edge, so generate Expand.
      if (expansion.edge) {
        // If the expand symbols were already bound, then we need to indicate
        // that they exist. The Expand will then check whether the pattern holds
        // instead of writing the expansion to symbols.
        const auto &node_symbol =
            symbol_table.at(*expansion.node2->identifier_);
        auto existing_node = false;
        if (!impl::BindSymbol(bound_symbols, node_symbol)) {
          existing_node = true;
        } else {
          match_context.new_symbols.emplace_back(node_symbol);
        }
        const auto &edge_symbol = symbol_table.at(*expansion.edge->identifier_);
        auto existing_edge = false;
        if (!impl::BindSymbol(bound_symbols, edge_symbol)) {
          existing_edge = true;
        } else {
          match_context.new_symbols.emplace_back(edge_symbol);
        }
        // Pick the expand operator which corresponds to the kind of the edge.
        if (auto *bf_atom = dynamic_cast<BreadthFirstAtom *>(expansion.edge)) {
          const auto &traversed_edge_symbol =
              symbol_table.at(*bf_atom->traversed_edge_identifier_);
          const auto &next_node_symbol =
              symbol_table.at(*bf_atom->next_node_identifier_);
          last_op = new ExpandBreadthFirst(
              node_symbol, edge_symbol, expansion.direction,
              bf_atom->max_depth_, next_node_symbol, traversed_edge_symbol,
              bf_atom->filter_expression_,
              std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
              existing_node, match_context.graph_view);
        } else if (expansion.edge->has_range_) {
          // The 4th argument tells whether the direction of the expansion
          // differs from the direction of the original edge atom.
          last_op = new ExpandVariable(
              node_symbol, edge_symbol, expansion.direction,
              expansion.direction != expansion.edge->direction_,
              expansion.edge->lower_bound_, expansion.edge->upper_bound_,
              std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
              existing_node, existing_edge, match_context.graph_view);
        } else {
          last_op = new Expand(node_symbol, edge_symbol, expansion.direction,
                               std::shared_ptr<LogicalOperator>(last_op),
                               node1_symbol, existing_node, existing_edge,
                               match_context.graph_view);
        }
        if (!existing_edge) {
          // Ensure Cyphermorphism (different edge symbols always map to
          // different edges).
          for (const auto &edge_symbols : matching.edge_symbols) {
            if (edge_symbols.find(edge_symbol) == edge_symbols.end()) {
              continue;
            }
            // Collect the other, already bound edge symbols of the same set.
            std::vector<Symbol> other_symbols;
            for (const auto &symbol : edge_symbols) {
              if (symbol == edge_symbol ||
                  bound_symbols.find(symbol) == bound_symbols.end()) {
                continue;
              }
              other_symbols.push_back(symbol);
            }
            if (!other_symbols.empty()) {
              last_op = new ExpandUniquenessFilter<EdgeAccessor>(
                  std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
                  other_symbols);
            }
          }
        }
        last_op =
            impl::GenFilters(last_op, bound_symbols, all_filters, storage);
      }
    }
    debug_assert(all_filters.empty(), "Expected to generate all filters");
    return last_op;
  }

  /// @brief Generates the Merge operator for a `MERGE` clause.
  ///
  /// The operator has two branches: `on_match`, which is the planned
  /// `matching` followed by the `ON MATCH` clauses, and `on_create`, which
  /// creates the pattern and is followed by the `ON CREATE` clauses.
  auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
                const Matching &matching) {
    // Copy the bound symbol set, because we don't want to use the updated
    // version when generating the create part.
    std::unordered_set<Symbol> bound_symbols_copy(context_.bound_symbols);
    MatchContext match_ctx{context_.symbol_table, bound_symbols_copy,
                           GraphView::NEW};
    auto on_match = PlanMatching(matching, nullptr, match_ctx);
    // Use the original bound_symbols, so we fill it with new symbols.
    auto on_create = impl::GenCreateForPattern(*merge.pattern_, nullptr,
                                               context_.symbol_table,
                                               context_.bound_symbols);
    for (auto &set : merge.on_create_) {
      on_create = impl::HandleWriteClause(set, on_create, context_.symbol_table,
                                          context_.bound_symbols);
      debug_assert(on_create, "Expected SET in MERGE ... ON CREATE");
    }
    for (auto &set : merge.on_match_) {
      on_match = impl::HandleWriteClause(set, on_match, context_.symbol_table,
                                         context_.bound_symbols);
      debug_assert(on_match, "Expected SET in MERGE ... ON MATCH");
    }
    return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
                           std::shared_ptr<LogicalOperator>(on_match),
                           std::shared_ptr<LogicalOperator>(on_create));
  }
};
} // namespace query::plan

View File

@ -1,18 +1,15 @@
#include "query/plan/planner.hpp"
#include "query/plan/variable_start_planner.hpp"
#include <limits>
#include <queue>
#include "cppitertools/slice.hpp"
#include "gflags/gflags.h"
#include "utils/flag_validation.hpp"
DEFINE_VALIDATED_uint64(
query_max_plans, 1000U, "Maximum number of generated plans for a query",
FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
namespace query::plan {
namespace query::plan::impl {
namespace {
@ -278,129 +275,6 @@ class VaryMatchingStart {
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes_;
};
// Produces a Cartesian product among the given vectors (sets).
// For example:
//
//    std::vector<int> first_set{1,2,3};
//    std::vector<int> second_set{4,5};
//    std::vector<std::vector<int>> all_sets{first_set, second_set};
//    // Iteration yields {1, 4}, {2, 4}, {3, 4}, {1, 5}, {2, 5}, {3, 5}
//    CartesianProduct<int> product(all_sets);
//    for (const auto &set : product) {
//      ...
//    }
//
// The product is created lazily by iterating over the constructed
// CartesianProduct instance. Elements of the leftmost set change the fastest.
//
// Iterators keep a reference to the instance which created them, so they must
// not outlive it. The instance itself stores only the sets, which makes it
// safe to copy and move.
template <typename T>
class CartesianProduct {
 public:
  CartesianProduct(std::vector<std::vector<T>> sets)
      : original_sets_(std::move(sets)) {}

  class iterator {
   public:
    typedef std::input_iterator_tag iterator_category;
    typedef std::vector<T> value_type;
    typedef long difference_type;
    typedef const std::vector<T> &reference;
    typedef const std::vector<T> *pointer;

    // Creates the iterator at the first product, or a past-the-end iterator
    // if `is_done` is set, there are no sets, or any of the sets is empty.
    explicit iterator(CartesianProduct &self, bool is_done)
        : self_(self), is_done_(is_done) {
      const auto &all_sets = self.original_sets_;
      if (is_done || all_sets.empty()) {
        is_done_ = true;
        return;
      }
      for (auto set_it = all_sets.cbegin(); set_it != all_sets.cend();
           ++set_it) {
        auto elem_it = set_it->cbegin();
        if (elem_it == set_it->cend()) {
          // One of the sets is empty, so there is no product.
          is_done_ = true;
          return;
        }
        // Collect the first product, by taking the first element of each set.
        current_product_.emplace_back(*elem_it);
        // Store starting iterators to all sets.
        sets_.emplace_back(set_it, elem_it);
      }
    }

    // Advances to the next product; does nothing on a past-the-end iterator.
    iterator &operator++() {
      if (is_done_) return *this;
      // Increment the leftmost set iterator.
      auto sets_it = sets_.begin();
      ++sets_it->second;
      // If the leftmost is at the end, reset it and increment the next
      // leftmost.
      while (sets_it->second == sets_it->first->cend()) {
        sets_it->second = sets_it->first->cbegin();
        ++sets_it;
        if (sets_it == sets_.end()) {
          // The leftmost set is the last set and it was exhausted, so we are
          // done.
          is_done_ = true;
          return *this;
        }
        ++sets_it->second;
      }
      // We can now collect another product from the modified set iterators.
      debug_assert(
          current_product_.size() == sets_.size(),
          "Expected size of current_product_ to match the size of sets_");
      size_t i = 0;
      // Change only the prefix of the product, remaining elements (after
      // sets_it) should be the same.
      auto first_unmodified = sets_it + 1;
      for (auto kv_it = sets_.begin(); kv_it != first_unmodified; ++kv_it) {
        current_product_[i++] = *kv_it->second;
      }
      return *this;
    }

    // Iterators are equal if they belong to the same CartesianProduct and
    // are either both exhausted or at the same position.
    bool operator==(const iterator &other) const {
      if (&self_ != &other.self_) return false;
      return (is_done_ && other.is_done_) || (sets_ == other.sets_);
    }

    bool operator!=(const iterator &other) const { return !(*this == other); }

    // Iterator interface says that dereferencing a past-the-end iterator is
    // undefined, so don't bother checking if we are done.
    reference operator*() const { return current_product_; }
    pointer operator->() const { return &current_product_; }

   private:
    using SetIterator = typename std::vector<std::vector<T>>::const_iterator;
    using ElementIterator = typename std::vector<T>::const_iterator;

    CartesianProduct &self_;
    // Vector of (original_sets_iterator, set_iterator) pairs. The
    // original_sets_iterator points to the set among all the sets, while the
    // set_iterator points to an element inside the pointed to set.
    std::vector<std::pair<SetIterator, ElementIterator>> sets_;
    // Currently built product from pointed to elements in all sets.
    std::vector<T> current_product_;
    // Set to true when we have generated all products.
    bool is_done_ = false;
  };

  auto begin() { return iterator(*this, false); }
  auto end() { return iterator(*this, true); }

 private:
  friend class iterator;
  // The original sets whose Cartesian product we are calculating. Iterators
  // into this vector are always obtained from it on demand; caching them as
  // members would leave a copied instance pointing into the storage of the
  // instance it was copied from.
  std::vector<std::vector<T>> original_sets_;
};
// Similar to VaryMatchingStart, but varies the starting nodes for all given
// matchings. After all matchings produce multiple alternative starts, the
// Cartesian product of all of them is returned.
@ -413,10 +287,12 @@ auto VaryMultiMatchingStarts(const std::vector<Matching> &matchings,
variants.emplace_back(
std::vector<Matching>(variant.begin(), variant.end()));
}
return iter::slice(CartesianProduct<Matching>(std::move(variants)), 0UL,
return iter::slice(MakeCartesianProduct(std::move(variants)), 0UL,
FLAGS_query_max_plans);
}
} // namespace
// Produces alternative query parts out of a single part by varying how each
// graph matching is done.
std::vector<QueryPart> VaryQueryPartMatching(const QueryPart &query_part,
@ -474,32 +350,4 @@ std::vector<QueryPart> VaryQueryPartMatching(const QueryPart &query_part,
return variants;
}
// Generates different, equivalent query parts by taking different graph
// matching routes for each query part.
auto VaryQueryMatching(const std::vector<QueryPart> &query_parts,
const SymbolTable &symbol_table) {
std::vector<std::vector<QueryPart>> alternative_query_parts;
for (const auto &query_part : query_parts) {
alternative_query_parts.emplace_back(
VaryQueryPartMatching(query_part, symbol_table));
}
return iter::slice(
CartesianProduct<QueryPart>(std::move(alternative_query_parts)), 0UL,
FLAGS_query_max_plans);
}
} // namespace
std::vector<std::unique_ptr<LogicalOperator>> VariableStartPlanner::Plan(
std::vector<QueryPart> &query_parts) {
std::vector<std::unique_ptr<LogicalOperator>> plans;
auto alternatives = VaryQueryMatching(query_parts, context_.symbol_table);
RuleBasedPlanner rule_planner(context_);
for (auto alternative_query_parts : alternatives) {
context_.bound_symbols.clear();
plans.emplace_back(rule_planner.Plan(alternative_query_parts));
}
return plans;
}
} // namespace query::plan
} // namespace query::plan::impl

View File

@ -0,0 +1,195 @@
/// @file
#pragma once
#include <utility>
#include <vector>

#include "cppitertools/slice.hpp"
#include "gflags/gflags.h"
#include "query/plan/rule_based_planner.hpp"
DECLARE_uint64(query_max_plans);
namespace query::plan {
/// Produces a Cartesian product among the given vectors (sets).
/// For example:
///
///    std::vector<int> first_set{1,2,3};
///    std::vector<int> second_set{4,5};
///    std::vector<std::vector<int>> all_sets{first_set, second_set};
///    // Iteration yields {1, 4}, {2, 4}, {3, 4}, {1, 5}, {2, 5}, {3, 5}
///    CartesianProduct<int> product(all_sets);
///    for (const auto &set : product) {
///      ...
///    }
///
/// The product is created lazily by iterating over the constructed
/// CartesianProduct instance. Elements of the leftmost set change the fastest.
///
/// Iterators keep a reference to the instance which created them, so they must
/// not outlive it. The instance itself stores only the sets, which makes it
/// safe to copy and move.
template <typename T>
class CartesianProduct {
 public:
  CartesianProduct(std::vector<std::vector<T>> sets)
      : original_sets_(std::move(sets)) {}

  class iterator {
   public:
    typedef std::input_iterator_tag iterator_category;
    typedef std::vector<T> value_type;
    typedef long difference_type;
    typedef const std::vector<T> &reference;
    typedef const std::vector<T> *pointer;

    // Creates the iterator at the first product, or a past-the-end iterator
    // if `is_done` is set, there are no sets, or any of the sets is empty.
    explicit iterator(CartesianProduct &self, bool is_done)
        : self_(self), is_done_(is_done) {
      const auto &all_sets = self.original_sets_;
      if (is_done || all_sets.empty()) {
        is_done_ = true;
        return;
      }
      for (auto set_it = all_sets.cbegin(); set_it != all_sets.cend();
           ++set_it) {
        auto elem_it = set_it->cbegin();
        if (elem_it == set_it->cend()) {
          // One of the sets is empty, so there is no product.
          is_done_ = true;
          return;
        }
        // Collect the first product, by taking the first element of each set.
        current_product_.emplace_back(*elem_it);
        // Store starting iterators to all sets.
        sets_.emplace_back(set_it, elem_it);
      }
    }

    // Advances to the next product; does nothing on a past-the-end iterator.
    iterator &operator++() {
      if (is_done_) return *this;
      // Increment the leftmost set iterator.
      auto sets_it = sets_.begin();
      ++sets_it->second;
      // If the leftmost is at the end, reset it and increment the next
      // leftmost.
      while (sets_it->second == sets_it->first->cend()) {
        sets_it->second = sets_it->first->cbegin();
        ++sets_it;
        if (sets_it == sets_.end()) {
          // The leftmost set is the last set and it was exhausted, so we are
          // done.
          is_done_ = true;
          return *this;
        }
        ++sets_it->second;
      }
      // We can now collect another product from the modified set iterators.
      debug_assert(
          current_product_.size() == sets_.size(),
          "Expected size of current_product_ to match the size of sets_");
      size_t i = 0;
      // Change only the prefix of the product, remaining elements (after
      // sets_it) should be the same.
      auto first_unmodified = sets_it + 1;
      for (auto kv_it = sets_.begin(); kv_it != first_unmodified; ++kv_it) {
        current_product_[i++] = *kv_it->second;
      }
      return *this;
    }

    // Iterators are equal if they belong to the same CartesianProduct and
    // are either both exhausted or at the same position.
    bool operator==(const iterator &other) const {
      if (&self_ != &other.self_) return false;
      return (is_done_ && other.is_done_) || (sets_ == other.sets_);
    }

    bool operator!=(const iterator &other) const { return !(*this == other); }

    // Iterator interface says that dereferencing a past-the-end iterator is
    // undefined, so don't bother checking if we are done.
    reference operator*() const { return current_product_; }
    pointer operator->() const { return &current_product_; }

   private:
    using SetIterator = typename std::vector<std::vector<T>>::const_iterator;
    using ElementIterator = typename std::vector<T>::const_iterator;

    CartesianProduct &self_;
    // Vector of (original_sets_iterator, set_iterator) pairs. The
    // original_sets_iterator points to the set among all the sets, while the
    // set_iterator points to an element inside the pointed to set.
    std::vector<std::pair<SetIterator, ElementIterator>> sets_;
    // Currently built product from pointed to elements in all sets.
    std::vector<T> current_product_;
    // Set to true when we have generated all products.
    bool is_done_ = false;
  };

  auto begin() { return iterator(*this, false); }
  auto end() { return iterator(*this, true); }

 private:
  friend class iterator;
  // The original sets whose Cartesian product we are calculating. Iterators
  // into this vector are always obtained from it on demand; caching them as
  // members would leave a copied instance pointing into the storage of the
  // instance it was copied from.
  std::vector<std::vector<T>> original_sets_;
};
/// Convenience function for creating CartesianProduct by deducing template
/// arguments from function arguments.
///
/// @param sets Vectors whose Cartesian product should be generated. They are
/// taken by value and moved into the returned object, so passing an rvalue
/// avoids copying the sets.
/// @return CartesianProduct<T> which owns the passed in sets.
template <typename T>
auto MakeCartesianProduct(std::vector<std::vector<T>> sets) {
  // `sets` is our own copy, hand it over instead of copying it once more.
  return CartesianProduct<T>(std::move(sets));
}
namespace impl {
// Internal helper of VariableStartPlanner, which produces alternative query
// parts out of a single part by varying how each graph matching is done. It is
// declared here, so that the templated planner can use it from the header.
std::vector<QueryPart> VaryQueryPartMatching(const QueryPart &query_part,
const SymbolTable &symbol_table);
} // namespace impl
/// @brief Planner which generates multiple plans by changing the order of graph
/// traversal.
///
/// This planner picks different starting nodes from which to start graph
/// traversal. Generating a single plan is backed by @c RuleBasedPlanner.
///
/// @sa MakeLogicalPlan
template <class TPlanningContext>
class VariableStartPlanner {
public:
explicit VariableStartPlanner(TPlanningContext &context)
: context_(context) {}
/// @brief The result of plan generation is a vector of roots to multiple
/// generated operator trees.
using PlanResult = std::vector<std::unique_ptr<LogicalOperator>>;
/// @brief Generate multiple plans by varying the order of graph traversal.
PlanResult Plan(std::vector<QueryPart> &query_parts) {
std::vector<std::unique_ptr<LogicalOperator>> plans;
auto alternatives = VaryQueryMatching(query_parts, context_.symbol_table);
RuleBasedPlanner<TPlanningContext> rule_planner(context_);
for (auto alternative_query_parts : alternatives) {
context_.bound_symbols.clear();
plans.emplace_back(rule_planner.Plan(alternative_query_parts));
}
return plans;
}
private:
TPlanningContext &context_;
// Generates different, equivalent query parts by taking different graph
// matching routes for each query part.
auto VaryQueryMatching(const std::vector<QueryPart> &query_parts,
const SymbolTable &symbol_table) {
std::vector<std::vector<QueryPart>> alternative_query_parts;
for (const auto &query_part : query_parts) {
alternative_query_parts.emplace_back(
impl::VaryQueryPartMatching(query_part, symbol_table));
}
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)),
0UL, FLAGS_query_max_plans);
}
};
} // namespace query::plan