Preprocess Ast to QueryParts and plan ScanAllByLabel
Summary: Mention the non-existent function name in semantic error. Don't merge optional matches into one Matching, because it is an error to treat multiple optional matches as a single optional match. Document new structures and functions. Add not so smart ScanAllByLabel generation. Reviewers: mislav.bradac, buda, florijan, lion Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D394
This commit is contained in:
parent
d3d8264fae
commit
74b082f050
@ -810,7 +810,8 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
}
|
||||
}
|
||||
auto function = NameToFunction(function_name);
|
||||
if (!function) throw SemanticException("Function doesn't exist.");
|
||||
if (!function)
|
||||
throw SemanticException("Function '{}' doesn't exist.", function_name);
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Function>(function, expressions));
|
||||
}
|
||||
|
@ -60,6 +60,27 @@ auto ReducePattern(
|
||||
return last_res;
|
||||
}
|
||||
|
||||
void ForeachPattern(
|
||||
Pattern &pattern, std::function<void(NodeAtom *)> base,
|
||||
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
|
||||
debug_assert(!pattern.atoms_.empty(), "Missing atoms in pattern");
|
||||
auto atoms_it = pattern.atoms_.begin();
|
||||
auto current_node = dynamic_cast<NodeAtom *>(*atoms_it++);
|
||||
debug_assert(current_node, "First pattern atom is not a node");
|
||||
base(current_node);
|
||||
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
|
||||
while (atoms_it != pattern.atoms_.end()) {
|
||||
auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++);
|
||||
debug_assert(edge, "Expected an edge atom in pattern.");
|
||||
debug_assert(atoms_it != pattern.atoms_.end(),
|
||||
"Edge atom should not end the pattern.");
|
||||
auto prev_node = current_node;
|
||||
current_node = dynamic_cast<NodeAtom *>(*atoms_it++);
|
||||
debug_assert(current_node, "Expected a node atom in pattern.");
|
||||
collect(prev_node, edge, current_node);
|
||||
}
|
||||
}
|
||||
|
||||
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols) {
|
||||
@ -159,16 +180,15 @@ Expression *PropertiesEqual(AstTreeStorage &storage,
|
||||
return filter_expr;
|
||||
}
|
||||
|
||||
auto &CollectPatternFilters(
|
||||
void CollectPatternFilters(
|
||||
Pattern &pattern, const SymbolTable &symbol_table,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
auto node_filter = [&](NodeAtom *node) {
|
||||
Expression *labels_filter =
|
||||
node->labels_.empty() ? nullptr
|
||||
: labels_filter = storage.Create<LabelsTest>(
|
||||
node->identifier_, node->labels_);
|
||||
node->labels_.empty() ? nullptr : storage.Create<LabelsTest>(
|
||||
node->identifier_, node->labels_);
|
||||
auto *props_filter = PropertiesEqual(storage, collector, node);
|
||||
if (labels_filter || props_filter) {
|
||||
collector.symbols_.insert(symbol_table.at(*node->identifier_));
|
||||
@ -177,9 +197,8 @@ auto &CollectPatternFilters(
|
||||
collector.symbols_);
|
||||
collector.symbols_.clear();
|
||||
}
|
||||
return &filters;
|
||||
};
|
||||
auto expand_filter = [&](auto *filters, NodeAtom *prev_node, EdgeAtom *edge,
|
||||
auto expand_filter = [&](NodeAtom *prev_node, EdgeAtom *edge,
|
||||
NodeAtom *node) {
|
||||
Expression *types_filter = edge->edge_types_.empty()
|
||||
? nullptr
|
||||
@ -189,30 +208,14 @@ auto &CollectPatternFilters(
|
||||
if (types_filter || props_filter) {
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
collector.symbols_.insert(edge_symbol);
|
||||
filters->emplace_back(
|
||||
filters.emplace_back(
|
||||
BoolJoin<FilterAndOperator>(storage, types_filter, props_filter),
|
||||
collector.symbols_);
|
||||
collector.symbols_.clear();
|
||||
}
|
||||
return node_filter(node);
|
||||
node_filter(node);
|
||||
};
|
||||
return *ReducePattern<
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> *>(
|
||||
pattern, node_filter, expand_filter);
|
||||
}
|
||||
|
||||
void CollectMatchFilters(
|
||||
const Match &match, const SymbolTable &symbol_table,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
for (auto *pattern : match.patterns_) {
|
||||
CollectPatternFilters(*pattern, symbol_table, filters, storage);
|
||||
}
|
||||
if (match.where_) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
match.where_->expression_->Accept(collector);
|
||||
filters.emplace_back(match.where_->expression_, collector.symbols_);
|
||||
}
|
||||
ForeachPattern(pattern, node_filter, expand_filter);
|
||||
}
|
||||
|
||||
// Contextual information used for generating match operators.
|
||||
@ -224,19 +227,13 @@ struct MatchContext {
|
||||
std::unordered_set<Symbol> &bound_symbols;
|
||||
// Determines whether the match should see the new graph state or not.
|
||||
GraphView graph_view = GraphView::OLD;
|
||||
// Pairs of filter expression and symbols used in them. The list should be
|
||||
// filled using CollectPatternFilters function, and later modified during
|
||||
// GenMatchForPattern.
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
|
||||
// Symbols for edges established in match, used to ensure Cyphermorphism.
|
||||
std::unordered_set<Symbol> edge_symbols;
|
||||
// All the newly established symbols in match.
|
||||
std::vector<Symbol> new_symbols;
|
||||
};
|
||||
|
||||
auto GenFilters(
|
||||
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
|
||||
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
|
||||
AstTreeStorage &storage) {
|
||||
Expression *filter_expr = nullptr;
|
||||
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
|
||||
@ -255,114 +252,6 @@ auto GenFilters(
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Generates operators for matching the given pattern and appends them to
|
||||
// input_op. Fills the context with all the new symbols and edge symbols.
|
||||
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
MatchContext &context, AstTreeStorage &storage) {
|
||||
auto &bound_symbols = context.bound_symbols;
|
||||
const auto &symbol_table = context.symbol_table;
|
||||
auto base = [&](NodeAtom *node) {
|
||||
// Try to generate any filters even before the 1st match operator.
|
||||
auto *last_op =
|
||||
GenFilters(input_op, bound_symbols, context.filters, storage);
|
||||
// If the first atom binds a symbol, we generate a ScanAll which writes it.
|
||||
// Otherwise, someone else generates it (e.g. a previous ScanAll).
|
||||
const auto &node_symbol = symbol_table.at(*node->identifier_);
|
||||
if (BindSymbol(bound_symbols, node_symbol)) {
|
||||
last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
|
||||
node_symbol, context.graph_view);
|
||||
context.new_symbols.emplace_back(node_symbol);
|
||||
}
|
||||
return GenFilters(last_op, bound_symbols, context.filters, storage);
|
||||
};
|
||||
auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
|
||||
EdgeAtom *edge, NodeAtom *node) {
|
||||
// Store the symbol from the first node as the input to Expand.
|
||||
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
|
||||
// If the expand symbols were already bound, then we need to indicate
|
||||
// that they exist. The Expand will then check whether the pattern holds
|
||||
// instead of writing the expansion to symbols.
|
||||
const auto &node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto existing_node = false;
|
||||
if (!BindSymbol(bound_symbols, node_symbol)) {
|
||||
existing_node = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(node_symbol);
|
||||
}
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
auto existing_edge = false;
|
||||
if (!BindSymbol(bound_symbols, edge_symbol)) {
|
||||
existing_edge = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(edge_symbol);
|
||||
}
|
||||
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
|
||||
input_symbol, existing_node, existing_edge,
|
||||
context.graph_view);
|
||||
if (!existing_edge) {
|
||||
// Ensure Cyphermorphism (different edge symbols always map to different
|
||||
// edges).
|
||||
if (!context.edge_symbols.empty()) {
|
||||
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
|
||||
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
|
||||
std::vector<Symbol>(context.edge_symbols.begin(),
|
||||
context.edge_symbols.end()));
|
||||
}
|
||||
}
|
||||
// Insert edge_symbol after creating ExpandUniquenessFilter, so that we
|
||||
// avoid filtering by the same edge we just expanded.
|
||||
context.edge_symbols.insert(edge_symbol);
|
||||
return GenFilters(last_op, bound_symbols, context.filters, storage);
|
||||
};
|
||||
return ReducePattern<LogicalOperator *>(pattern, base, collect);
|
||||
}
|
||||
|
||||
auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
auto *last_op = input_op;
|
||||
MatchContext req_ctx{symbol_table, bound_symbols};
|
||||
// Collect all non-optional match filters, so that we can put them as soon as
|
||||
// possible in the operator tree. Optional match need to be treated
|
||||
// specially, because they need to remain inside the optional match.
|
||||
for (auto *match : matches) {
|
||||
if (match->optional_) {
|
||||
continue;
|
||||
}
|
||||
CollectMatchFilters(*match, symbol_table, req_ctx.filters, storage);
|
||||
}
|
||||
auto gen_match = [&storage](const Match &match, LogicalOperator *input_op,
|
||||
MatchContext &context) {
|
||||
auto *match_op = input_op;
|
||||
for (auto *pattern : match.patterns_) {
|
||||
match_op = GenMatchForPattern(*pattern, match_op, context, storage);
|
||||
}
|
||||
return match_op;
|
||||
};
|
||||
for (auto *match : matches) {
|
||||
if (match->optional_) {
|
||||
// Optional match needs to be standalone, so filter only by its filters
|
||||
// and don't plug the previous match_op as input.
|
||||
MatchContext opt_ctx{symbol_table, bound_symbols};
|
||||
CollectMatchFilters(*match, symbol_table, opt_ctx.filters, storage);
|
||||
auto *match_op = gen_match(*match, nullptr, opt_ctx);
|
||||
last_op = new Optional(std::shared_ptr<LogicalOperator>(last_op),
|
||||
std::shared_ptr<LogicalOperator>(match_op),
|
||||
opt_ctx.new_symbols);
|
||||
debug_assert(opt_ctx.filters.empty(),
|
||||
"Expected to generate all optional filters");
|
||||
} else {
|
||||
// Since we reuse req_ctx, we need to clear the symbols for the new match.
|
||||
req_ctx.edge_symbols.clear();
|
||||
req_ctx.new_symbols.clear();
|
||||
last_op = gen_match(*match, last_op, req_ctx);
|
||||
}
|
||||
}
|
||||
debug_assert(req_ctx.filters.empty(), "Expected to generate all filters");
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Ast tree visitor which collects the context for a return body.
|
||||
// The return body of WITH and RETURN clauses consists of:
|
||||
//
|
||||
@ -693,27 +582,269 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Normalized representation of a pattern that needs to be matched.
|
||||
struct Expansion {
|
||||
// The first node in the expansion, it can be a single node.
|
||||
NodeAtom *node1 = nullptr;
|
||||
// Optional edge which connects the 2 nodes.
|
||||
EdgeAtom *edge = nullptr;
|
||||
// Optional node at the other end of an edge. If the expansion contains an
|
||||
// edge, then this node is required.
|
||||
NodeAtom *node2 = nullptr;
|
||||
};
|
||||
|
||||
// Normalized representation of a single or multiple Match clauses.
|
||||
//
|
||||
// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
|
||||
// WHERE c.prop < 42` will produce the following.
|
||||
// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and `(n) -[e3]- (m)`.
|
||||
// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
|
||||
// first `MATCH` and the set `{e3}` for the second.
|
||||
// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
|
||||
// other obtained from `WHERE` on symbol `c`.
|
||||
struct Matching {
|
||||
// All expansions that need to be performed across Match clauses.
|
||||
std::vector<Expansion> expansions;
|
||||
// Symbols for edges established in match, used to ensure Cyphermorphism.
|
||||
// There are multiple sets, because each Match clause determines a single set.
|
||||
std::vector<std::unordered_set<Symbol>> edge_symbols;
|
||||
// Pairs of filter expression and symbols used in them. The list should be
|
||||
// filled using CollectPatternFilters function.
|
||||
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
|
||||
};
|
||||
|
||||
// Represents a read (+ write) part of a query. Each part ends with either:
|
||||
// * RETURN clause;
|
||||
// * WITH clause or
|
||||
// * any of the write clauses.
|
||||
//
|
||||
// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
|
||||
// generated QueryPart will have `matching` generated for the `MATCH`.
|
||||
// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
|
||||
// in that exact order. The pattern inside the first `MERGE` will be used to
|
||||
// generate the first `merge_matching` element, and the second `MERGE` pattern
|
||||
// will produce the second `merge_matching` element. This way, if someone
|
||||
// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
|
||||
// in the same order as their respective `merge_matching` elements.
|
||||
struct QueryPart {
|
||||
// All MATCH clauses merged into one Matching.
|
||||
Matching matching;
|
||||
// Each OPTIONAL MATCH converted to Matching.
|
||||
std::vector<Matching> optional_matching;
|
||||
// Matching for each MERGE clause. Since Merge is contained in
|
||||
// remaining_clauses, this vector contains matching in the same order as Merge
|
||||
// appears.
|
||||
std::vector<Matching> merge_matching;
|
||||
// All the remaining clauses (without Match).
|
||||
std::vector<Clause *> remaining_clauses;
|
||||
};
|
||||
|
||||
// Context which contains variables commonly used during planning.
|
||||
struct PlanningContext {
|
||||
SymbolTable &symbol_table;
|
||||
AstTreeStorage &ast_storage;
|
||||
// bound_symbols set is used to differentiate cycles in pattern matching, so
|
||||
// that the operator can be correctly initialized whether to read the symbol
|
||||
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
|
||||
// `n`, but the latter `n` would only read the already written information.
|
||||
std::unordered_set<Symbol> bound_symbols;
|
||||
};
|
||||
|
||||
// Converts multiple Patterns to Expansions. Each Pattern can contain an
|
||||
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
|
||||
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
|
||||
// conserve the semantics of the pattern. For example, in a pattern:
|
||||
// (m) -[e]- (n) -[f]- (o) the same can be achieved with:
|
||||
// (m) -[e]- (n), (n) -[f]- (o).
|
||||
// This representation makes it easier to permute from which node or edge we
|
||||
// want to start expanding.
|
||||
std::vector<Expansion> NormalizePatterns(
|
||||
const std::vector<Pattern *> &patterns) {
|
||||
std::vector<Expansion> expansions;
|
||||
auto collect_node = [&](auto *node) {
|
||||
expansions.emplace_back(Expansion{node});
|
||||
};
|
||||
auto collect_expansion = [&](auto *prev_node, auto *edge,
|
||||
auto *current_node) {
|
||||
expansions.emplace_back(Expansion{prev_node, edge, current_node});
|
||||
};
|
||||
for (const auto &pattern : patterns) {
|
||||
ForeachPattern(*pattern, collect_node, collect_expansion);
|
||||
}
|
||||
return expansions;
|
||||
}
|
||||
|
||||
// Fills the given Matching, by converting the Match patterns to normalized
|
||||
// representation as Expansions. Filters used in the Match are also collected,
|
||||
// as well as edge symbols which determine Cyphermorphism. Collecting filters
|
||||
// will lift them out of a pattern and generate new expressions (just like they
|
||||
// were in a Where clause).
|
||||
void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
|
||||
const SymbolTable &symbol_table, AstTreeStorage &storage,
|
||||
Matching &matching) {
|
||||
auto expansions = NormalizePatterns(patterns);
|
||||
std::unordered_set<Symbol> edge_symbols;
|
||||
for (const auto &expansion : expansions) {
|
||||
if (expansion.edge) {
|
||||
edge_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
|
||||
}
|
||||
}
|
||||
if (!edge_symbols.empty()) {
|
||||
matching.edge_symbols.emplace_back(edge_symbols);
|
||||
}
|
||||
matching.expansions.insert(matching.expansions.end(), expansions.begin(),
|
||||
expansions.end());
|
||||
for (auto *pattern : patterns) {
|
||||
CollectPatternFilters(*pattern, symbol_table, matching.filters, storage);
|
||||
}
|
||||
if (where) {
|
||||
UsedSymbolsCollector collector(symbol_table);
|
||||
where->expression_->Accept(collector);
|
||||
matching.filters.emplace_back(where->expression_, collector.symbols_);
|
||||
}
|
||||
}
|
||||
void AddMatching(const Match &match, const SymbolTable &symbol_table,
|
||||
AstTreeStorage &storage, Matching &matching) {
|
||||
return AddMatching(match.patterns_, match.where_, symbol_table, storage,
|
||||
matching);
|
||||
}
|
||||
|
||||
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
|
||||
// created, e.g. filter expressions.
|
||||
std::vector<QueryPart> CollectQueryParts(const SymbolTable &symbol_table,
|
||||
AstTreeStorage &storage) {
|
||||
auto query = storage.query();
|
||||
std::vector<QueryPart> query_parts(1);
|
||||
auto *query_part = &query_parts.back();
|
||||
for (auto &clause : query->clauses_) {
|
||||
if (auto *match = dynamic_cast<Match *>(clause)) {
|
||||
if (match->optional_) {
|
||||
query_part->optional_matching.emplace_back(Matching{});
|
||||
AddMatching(*match, symbol_table, storage,
|
||||
query_part->optional_matching.back());
|
||||
} else {
|
||||
debug_assert(query_part->optional_matching.empty(),
|
||||
"Match clause cannot follow optional match.");
|
||||
AddMatching(*match, symbol_table, storage, query_part->matching);
|
||||
}
|
||||
} else {
|
||||
query_part->remaining_clauses.push_back(clause);
|
||||
if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
|
||||
query_part->merge_matching.emplace_back(Matching{});
|
||||
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
|
||||
query_part->merge_matching.back());
|
||||
} else if (dynamic_cast<With *>(clause)) {
|
||||
query_parts.emplace_back(QueryPart{});
|
||||
query_part = &query_parts.back();
|
||||
} else if (dynamic_cast<Return *>(clause)) {
|
||||
// TODO: Support RETURN UNION ...
|
||||
return query_parts;
|
||||
}
|
||||
}
|
||||
}
|
||||
return query_parts;
|
||||
}
|
||||
|
||||
LogicalOperator *PlanMatching(const Matching &matching,
|
||||
LogicalOperator *input_op,
|
||||
AstTreeStorage &storage, MatchContext &context) {
|
||||
auto &bound_symbols = context.bound_symbols;
|
||||
const auto &symbol_table = context.symbol_table;
|
||||
// Copy filters, because we will modify the list as we generate Filters.
|
||||
auto filters = matching.filters;
|
||||
// Try to generate any filters even before the 1st match operator. This
|
||||
// optimizes the optional match which filters only on symbols bound in regular
|
||||
// match.
|
||||
auto *last_op = GenFilters(input_op, bound_symbols, filters, storage);
|
||||
for (const auto &expansion : matching.expansions) {
|
||||
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
|
||||
if (BindSymbol(bound_symbols, node1_symbol)) {
|
||||
// We have just bound this symbol, so generate ScanAll which fills it.
|
||||
const auto &labels = expansion.node1->labels_;
|
||||
if (labels.empty()) {
|
||||
last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
|
||||
node1_symbol, context.graph_view);
|
||||
} else {
|
||||
// Don't act smart by selecting the best label index, so take the first.
|
||||
last_op = new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op),
|
||||
node1_symbol, labels.front(),
|
||||
context.graph_view);
|
||||
}
|
||||
context.new_symbols.emplace_back(node1_symbol);
|
||||
last_op = GenFilters(last_op, bound_symbols, filters, storage);
|
||||
}
|
||||
// We have an edge, so generate Expand.
|
||||
if (expansion.edge) {
|
||||
// If the expand symbols were already bound, then we need to indicate
|
||||
// that they exist. The Expand will then check whether the pattern holds
|
||||
// instead of writing the expansion to symbols.
|
||||
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
|
||||
auto existing_node = false;
|
||||
if (!BindSymbol(bound_symbols, node_symbol)) {
|
||||
existing_node = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(node_symbol);
|
||||
}
|
||||
const auto &edge_symbol = symbol_table.at(*expansion.edge->identifier_);
|
||||
auto existing_edge = false;
|
||||
if (!BindSymbol(bound_symbols, edge_symbol)) {
|
||||
existing_edge = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(edge_symbol);
|
||||
}
|
||||
last_op =
|
||||
new Expand(expansion.node2, expansion.edge,
|
||||
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
|
||||
existing_node, existing_edge, context.graph_view);
|
||||
if (!existing_edge) {
|
||||
// Ensure Cyphermorphism (different edge symbols always map to different
|
||||
// edges).
|
||||
for (const auto &edge_symbols : matching.edge_symbols) {
|
||||
if (edge_symbols.find(edge_symbol) == edge_symbols.end()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<Symbol> other_symbols;
|
||||
for (const auto &symbol : edge_symbols) {
|
||||
if (symbol == edge_symbol ||
|
||||
bound_symbols.find(symbol) == bound_symbols.end()) {
|
||||
continue;
|
||||
}
|
||||
other_symbols.push_back(symbol);
|
||||
}
|
||||
if (!other_symbols.empty()) {
|
||||
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
|
||||
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
|
||||
other_symbols);
|
||||
}
|
||||
}
|
||||
}
|
||||
last_op = GenFilters(last_op, bound_symbols, filters, storage);
|
||||
}
|
||||
}
|
||||
debug_assert(filters.empty(), "Expected to generate all filters");
|
||||
return last_op;
|
||||
}
|
||||
|
||||
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<Symbol> &bound_symbols,
|
||||
AstTreeStorage &storage) {
|
||||
const Matching &matching, PlanningContext &context) {
|
||||
// Copy the bound symbol set, because we don't want to use the updated version
|
||||
// when generating the create part.
|
||||
std::unordered_set<Symbol> bound_symbols_copy(bound_symbols);
|
||||
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
|
||||
CollectPatternFilters(*merge.pattern_, symbol_table, context.filters,
|
||||
storage);
|
||||
std::unordered_set<Symbol> bound_symbols_copy(context.bound_symbols);
|
||||
MatchContext match_ctx{context.symbol_table, bound_symbols_copy,
|
||||
GraphView::NEW};
|
||||
auto on_match =
|
||||
GenMatchForPattern(*merge.pattern_, nullptr, context, storage);
|
||||
PlanMatching(matching, nullptr, context.ast_storage, match_ctx);
|
||||
// Use the original bound_symbols, so we fill it with new symbols.
|
||||
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table,
|
||||
bound_symbols);
|
||||
auto on_create = GenCreateForPattern(
|
||||
*merge.pattern_, nullptr, context.symbol_table, context.bound_symbols);
|
||||
for (auto &set : merge.on_create_) {
|
||||
on_create = HandleWriteClause(set, on_create, symbol_table, bound_symbols);
|
||||
on_create = HandleWriteClause(set, on_create, context.symbol_table,
|
||||
context.bound_symbols);
|
||||
debug_assert(on_create, "Expected SET in MERGE ... ON CREATE");
|
||||
}
|
||||
for (auto &set : merge.on_match_) {
|
||||
on_match = HandleWriteClause(set, on_match, symbol_table, bound_symbols);
|
||||
on_match = HandleWriteClause(set, on_match, context.symbol_table,
|
||||
context.bound_symbols);
|
||||
debug_assert(on_match, "Expected SET in MERGE ... ON MATCH");
|
||||
}
|
||||
return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
|
||||
@ -725,46 +856,52 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
|
||||
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
|
||||
SymbolTable &symbol_table) {
|
||||
auto query = storage.query();
|
||||
// bound_symbols set is used to differentiate cycles in pattern matching, so
|
||||
// that the operator can be correctly initialized whether to read the symbol
|
||||
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
|
||||
// `n`, but the latter `n` would only read the already written information.
|
||||
std::unordered_set<Symbol> bound_symbols;
|
||||
auto query_parts = CollectQueryParts(symbol_table, storage);
|
||||
PlanningContext context{symbol_table, storage};
|
||||
LogicalOperator *input_op = nullptr;
|
||||
// Set to true if a query command writes to the database.
|
||||
bool is_write = false;
|
||||
LogicalOperator *input_op = nullptr;
|
||||
// All sequential Match clauses. Reset after encountering non-Match.
|
||||
std::vector<Match *> matches;
|
||||
for (auto &clause : query->clauses_) {
|
||||
// Clauses which read from the database.
|
||||
if (auto *match = dynamic_cast<Match *>(clause)) {
|
||||
matches.emplace_back(match);
|
||||
} else {
|
||||
input_op =
|
||||
GenMatches(matches, input_op, symbol_table, bound_symbols, storage);
|
||||
matches.clear();
|
||||
for (const auto &query_part : query_parts) {
|
||||
MatchContext match_ctx{context.symbol_table, context.bound_symbols};
|
||||
input_op = PlanMatching(query_part.matching, input_op, context.ast_storage,
|
||||
match_ctx);
|
||||
for (const auto &matching : query_part.optional_matching) {
|
||||
MatchContext opt_ctx{context.symbol_table, context.bound_symbols};
|
||||
auto *match_op =
|
||||
PlanMatching(matching, nullptr, context.ast_storage, opt_ctx);
|
||||
if (match_op) {
|
||||
input_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
|
||||
std::shared_ptr<LogicalOperator>(match_op),
|
||||
opt_ctx.new_symbols);
|
||||
}
|
||||
}
|
||||
int merge_id = 0;
|
||||
for (auto &clause : query_part.remaining_clauses) {
|
||||
debug_assert(dynamic_cast<Match *>(clause) == nullptr,
|
||||
"Unexpected Match in remaining clauses");
|
||||
if (auto *ret = dynamic_cast<Return *>(clause)) {
|
||||
input_op = GenReturn(*ret, input_op, symbol_table, is_write,
|
||||
bound_symbols, storage);
|
||||
input_op = GenReturn(*ret, input_op, context.symbol_table, is_write,
|
||||
context.bound_symbols, context.ast_storage);
|
||||
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
|
||||
input_op =
|
||||
GenMerge(*merge, input_op, symbol_table, bound_symbols, storage);
|
||||
input_op = GenMerge(*merge, input_op,
|
||||
query_part.merge_matching[merge_id++], context);
|
||||
// Treat MERGE clause as write, because we do not know if it will create
|
||||
// anything.
|
||||
is_write = true;
|
||||
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
|
||||
input_op = GenWith(*with, input_op, symbol_table, is_write,
|
||||
bound_symbols, storage);
|
||||
input_op = GenWith(*with, input_op, context.symbol_table, is_write,
|
||||
context.bound_symbols, context.ast_storage);
|
||||
// WITH clause advances the command, so reset the flag.
|
||||
is_write = false;
|
||||
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,
|
||||
bound_symbols)) {
|
||||
} else if (auto *op =
|
||||
HandleWriteClause(clause, input_op, context.symbol_table,
|
||||
context.bound_symbols)) {
|
||||
is_write = true;
|
||||
input_op = op;
|
||||
} else if (auto *unwind = dynamic_cast<query::Unwind *>(clause)) {
|
||||
const auto &symbol = symbol_table.at(*unwind->named_expression_);
|
||||
BindSymbol(bound_symbols, symbol);
|
||||
const auto &symbol =
|
||||
context.symbol_table.at(*unwind->named_expression_);
|
||||
BindSymbol(context.bound_symbols, symbol);
|
||||
input_op =
|
||||
new plan::Unwind(std::shared_ptr<LogicalOperator>(input_op),
|
||||
unwind->named_expression_->expression_, symbol);
|
||||
@ -774,9 +911,6 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_assert(
|
||||
matches.empty(),
|
||||
"Expected Match clause(s) to be followed by an update or return clause");
|
||||
return std::unique_ptr<LogicalOperator>(input_op);
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ namespace plan {
|
||||
/// certain operators.
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
AstTreeStorage &storage, query::SymbolTable &symbol_table);
|
||||
}
|
||||
|
||||
} // namespace plan
|
||||
|
||||
} // namespace query
|
||||
|
@ -49,6 +49,7 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor {
|
||||
PRE_VISIT(CreateExpand);
|
||||
PRE_VISIT(Delete);
|
||||
PRE_VISIT(ScanAll);
|
||||
PRE_VISIT(ScanAllByLabel);
|
||||
PRE_VISIT(Expand);
|
||||
PRE_VISIT(Filter);
|
||||
PRE_VISIT(Produce);
|
||||
@ -111,6 +112,7 @@ using ExpectCreateNode = OpChecker<CreateNode>;
|
||||
using ExpectCreateExpand = OpChecker<CreateExpand>;
|
||||
using ExpectDelete = OpChecker<Delete>;
|
||||
using ExpectScanAll = OpChecker<ScanAll>;
|
||||
using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>;
|
||||
using ExpectExpand = OpChecker<Expand>;
|
||||
using ExpectFilter = OpChecker<Filter>;
|
||||
using ExpectProduce = OpChecker<Produce>;
|
||||
@ -292,7 +294,7 @@ TEST(TestLogicalPlanner, MatchLabeledNodes) {
|
||||
auto dba = dbms.active();
|
||||
auto label = dba->label("label");
|
||||
QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n")));
|
||||
CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce());
|
||||
CheckPlan(storage, ExpectScanAllByLabel(), ExpectFilter(), ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchPathReturn) {
|
||||
@ -847,4 +849,14 @@ TEST(TestLogicalPlanner, UnwindMergeNodeProperty) {
|
||||
for (auto &op : on_create) delete op;
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) {
|
||||
// Test OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n
|
||||
AstTreeStorage storage;
|
||||
QUERY(OPTIONAL_MATCH(PATTERN(NODE("n"))), OPTIONAL_MATCH(PATTERN(NODE("m"))),
|
||||
RETURN(IDENT("n"), AS("n")));
|
||||
std::list<BaseOpChecker *> optional{new ExpectScanAll()};
|
||||
CheckPlan(storage, ExpectOptional(optional), ExpectOptional(optional),
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user