Preprocess Ast to QueryParts and plan ScanAllByLabel

Summary:
Mention the non-existent function name in semantic error. Don't merge optional
matches into one Matching, because it is an error to treat multiple optional
matches as a single optional match. Document new structures and functions. Add
not so smart ScanAllByLabel generation.

Reviewers: mislav.bradac, buda, florijan, lion

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D394
This commit is contained in:
Teon Banek 2017-05-26 12:05:00 +02:00
parent d3d8264fae
commit 74b082f050
4 changed files with 334 additions and 186 deletions

View File

@ -810,7 +810,8 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
}
}
auto function = NameToFunction(function_name);
if (!function) throw SemanticException("Function doesn't exist.");
if (!function)
throw SemanticException("Function '{}' doesn't exist.", function_name);
return static_cast<Expression *>(
storage_.Create<Function>(function, expressions));
}

View File

@ -60,6 +60,27 @@ auto ReducePattern(
return last_res;
}
void ForeachPattern(
Pattern &pattern, std::function<void(NodeAtom *)> base,
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
debug_assert(!pattern.atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern.atoms_.begin();
auto current_node = dynamic_cast<NodeAtom *>(*atoms_it++);
debug_assert(current_node, "First pattern atom is not a node");
base(current_node);
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
while (atoms_it != pattern.atoms_.end()) {
auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++);
debug_assert(edge, "Expected an edge atom in pattern.");
debug_assert(atoms_it != pattern.atoms_.end(),
"Edge atom should not end the pattern.");
auto prev_node = current_node;
current_node = dynamic_cast<NodeAtom *>(*atoms_it++);
debug_assert(current_node, "Expected a node atom in pattern.");
collect(prev_node, edge, current_node);
}
}
auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
@ -159,16 +180,15 @@ Expression *PropertiesEqual(AstTreeStorage &storage,
return filter_expr;
}
auto &CollectPatternFilters(
void CollectPatternFilters(
Pattern &pattern, const SymbolTable &symbol_table,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
UsedSymbolsCollector collector(symbol_table);
auto node_filter = [&](NodeAtom *node) {
Expression *labels_filter =
node->labels_.empty() ? nullptr
: labels_filter = storage.Create<LabelsTest>(
node->identifier_, node->labels_);
node->labels_.empty() ? nullptr : storage.Create<LabelsTest>(
node->identifier_, node->labels_);
auto *props_filter = PropertiesEqual(storage, collector, node);
if (labels_filter || props_filter) {
collector.symbols_.insert(symbol_table.at(*node->identifier_));
@ -177,9 +197,8 @@ auto &CollectPatternFilters(
collector.symbols_);
collector.symbols_.clear();
}
return &filters;
};
auto expand_filter = [&](auto *filters, NodeAtom *prev_node, EdgeAtom *edge,
auto expand_filter = [&](NodeAtom *prev_node, EdgeAtom *edge,
NodeAtom *node) {
Expression *types_filter = edge->edge_types_.empty()
? nullptr
@ -189,30 +208,14 @@ auto &CollectPatternFilters(
if (types_filter || props_filter) {
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
collector.symbols_.insert(edge_symbol);
filters->emplace_back(
filters.emplace_back(
BoolJoin<FilterAndOperator>(storage, types_filter, props_filter),
collector.symbols_);
collector.symbols_.clear();
}
return node_filter(node);
node_filter(node);
};
return *ReducePattern<
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> *>(
pattern, node_filter, expand_filter);
}
void CollectMatchFilters(
const Match &match, const SymbolTable &symbol_table,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
for (auto *pattern : match.patterns_) {
CollectPatternFilters(*pattern, symbol_table, filters, storage);
}
if (match.where_) {
UsedSymbolsCollector collector(symbol_table);
match.where_->expression_->Accept(collector);
filters.emplace_back(match.where_->expression_, collector.symbols_);
}
ForeachPattern(pattern, node_filter, expand_filter);
}
// Contextual information used for generating match operators.
@ -224,19 +227,13 @@ struct MatchContext {
std::unordered_set<Symbol> &bound_symbols;
// Determines whether the match should see the new graph state or not.
GraphView graph_view = GraphView::OLD;
// Pairs of filter expression and symbols used in them. The list should be
// filled using CollectPatternFilters function, and later modified during
// GenMatchForPattern.
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
// Symbols for edges established in match, used to ensure Cyphermorphism.
std::unordered_set<Symbol> edge_symbols;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols;
};
auto GenFilters(
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
std::list<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> &filters,
AstTreeStorage &storage) {
Expression *filter_expr = nullptr;
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
@ -255,114 +252,6 @@ auto GenFilters(
return last_op;
}
// Generates operators for matching the given pattern and appends them to
// input_op. Fills the context with all the new symbols and edge symbols.
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
MatchContext &context, AstTreeStorage &storage) {
auto &bound_symbols = context.bound_symbols;
const auto &symbol_table = context.symbol_table;
auto base = [&](NodeAtom *node) {
// Try to generate any filters even before the 1st match operator.
auto *last_op =
GenFilters(input_op, bound_symbols, context.filters, storage);
// If the first atom binds a symbol, we generate a ScanAll which writes it.
// Otherwise, someone else generates it (e.g. a previous ScanAll).
const auto &node_symbol = symbol_table.at(*node->identifier_);
if (BindSymbol(bound_symbols, node_symbol)) {
last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
node_symbol, context.graph_view);
context.new_symbols.emplace_back(node_symbol);
}
return GenFilters(last_op, bound_symbols, context.filters, storage);
};
auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
EdgeAtom *edge, NodeAtom *node) {
// Store the symbol from the first node as the input to Expand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol = symbol_table.at(*node->identifier_);
auto existing_node = false;
if (!BindSymbol(bound_symbols, node_symbol)) {
existing_node = true;
} else {
context.new_symbols.emplace_back(node_symbol);
}
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
auto existing_edge = false;
if (!BindSymbol(bound_symbols, edge_symbol)) {
existing_edge = true;
} else {
context.new_symbols.emplace_back(edge_symbol);
}
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
input_symbol, existing_node, existing_edge,
context.graph_view);
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different
// edges).
if (!context.edge_symbols.empty()) {
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
std::vector<Symbol>(context.edge_symbols.begin(),
context.edge_symbols.end()));
}
}
// Insert edge_symbol after creating ExpandUniquenessFilter, so that we
// avoid filtering by the same edge we just expanded.
context.edge_symbols.insert(edge_symbol);
return GenFilters(last_op, bound_symbols, context.filters, storage);
};
return ReducePattern<LogicalOperator *>(pattern, base, collect);
}
auto GenMatches(std::vector<Match *> &matches, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
auto *last_op = input_op;
MatchContext req_ctx{symbol_table, bound_symbols};
// Collect all non-optional match filters, so that we can put them as soon as
// possible in the operator tree. Optional match need to be treated
// specially, because they need to remain inside the optional match.
for (auto *match : matches) {
if (match->optional_) {
continue;
}
CollectMatchFilters(*match, symbol_table, req_ctx.filters, storage);
}
auto gen_match = [&storage](const Match &match, LogicalOperator *input_op,
MatchContext &context) {
auto *match_op = input_op;
for (auto *pattern : match.patterns_) {
match_op = GenMatchForPattern(*pattern, match_op, context, storage);
}
return match_op;
};
for (auto *match : matches) {
if (match->optional_) {
// Optional match needs to be standalone, so filter only by its filters
// and don't plug the previous match_op as input.
MatchContext opt_ctx{symbol_table, bound_symbols};
CollectMatchFilters(*match, symbol_table, opt_ctx.filters, storage);
auto *match_op = gen_match(*match, nullptr, opt_ctx);
last_op = new Optional(std::shared_ptr<LogicalOperator>(last_op),
std::shared_ptr<LogicalOperator>(match_op),
opt_ctx.new_symbols);
debug_assert(opt_ctx.filters.empty(),
"Expected to generate all optional filters");
} else {
// Since we reuse req_ctx, we need to clear the symbols for the new match.
req_ctx.edge_symbols.clear();
req_ctx.new_symbols.clear();
last_op = gen_match(*match, last_op, req_ctx);
}
}
debug_assert(req_ctx.filters.empty(), "Expected to generate all filters");
return last_op;
}
// Ast tree visitor which collects the context for a return body.
// The return body of WITH and RETURN clauses consists of:
//
@ -693,27 +582,269 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
return nullptr;
}
// Normalized representation of a pattern that needs to be matched.
struct Expansion {
// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
// Optional node at the other end of an edge. If the expansion contains an
// edge, then this node is required.
NodeAtom *node2 = nullptr;
};
// Normalized representation of a single or multiple Match clauses.
//
// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
// WHERE c.prop < 42` will produce the following.
// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and `(n) -[e3]- (m)`.
// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
// first `MATCH` and the set `{e3}` for the second.
// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
// other obtained from `WHERE` on symbol `c`.
struct Matching {
// All expansions that need to be performed across Match clauses.
std::vector<Expansion> expansions;
// Symbols for edges established in match, used to ensure Cyphermorphism.
// There are multiple sets, because each Match clause determines a single set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
// Pairs of filter expression and symbols used in them. The list should be
// filled using CollectPatternFilters function.
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> filters;
};
// Represents a read (+ write) part of a query. Each part ends with either:
// * RETURN clause;
// * WITH clause or
// * any of the write clauses.
//
// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
// generated QueryPart will have `matching` generated for the `MATCH`.
// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
// in that exact order. The pattern inside the first `MERGE` will be used to
// generate the first `merge_matching` element, and the second `MERGE` pattern
// will produce the second `merge_matching` element. This way, if someone
// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
// in the same order as their respective `merge_matching` elements.
struct QueryPart {
// All MATCH clauses merged into one Matching.
Matching matching;
// Each OPTIONAL MATCH converted to Matching.
std::vector<Matching> optional_matching;
// Matching for each MERGE clause. Since Merge is contained in
// remaining_clauses, this vector contains matching in the same order as Merge
// appears.
std::vector<Matching> merge_matching;
// All the remaining clauses (without Match).
std::vector<Clause *> remaining_clauses;
};
// Context which contains variables commonly used during planning.
struct PlanningContext {
SymbolTable &symbol_table;
AstTreeStorage &ast_storage;
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<Symbol> bound_symbols;
};
// Converts multiple Patterns to Expansions. Each Pattern can contain an
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
// conserve the semantics of the pattern. For example, in a pattern:
// (m) -[e]- (n) -[f]- (o) the same can be achieved with:
// (m) -[e]- (n), (n) -[f]- (o).
// This representation makes it easier to permute from which node or edge we
// want to start expanding.
std::vector<Expansion> NormalizePatterns(
const std::vector<Pattern *> &patterns) {
std::vector<Expansion> expansions;
auto collect_node = [&](auto *node) {
expansions.emplace_back(Expansion{node});
};
auto collect_expansion = [&](auto *prev_node, auto *edge,
auto *current_node) {
expansions.emplace_back(Expansion{prev_node, edge, current_node});
};
for (const auto &pattern : patterns) {
ForeachPattern(*pattern, collect_node, collect_expansion);
}
return expansions;
}
// Fills the given Matching, by converting the Match patterns to normalized
// representation as Expansions. Filters used in the Match are also collected,
// as well as edge symbols which determine Cyphermorphism. Collecting filters
// will lift them out of a pattern and generate new expressions (just like they
// were in a Where clause).
void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
const SymbolTable &symbol_table, AstTreeStorage &storage,
Matching &matching) {
auto expansions = NormalizePatterns(patterns);
std::unordered_set<Symbol> edge_symbols;
for (const auto &expansion : expansions) {
if (expansion.edge) {
edge_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
}
}
if (!edge_symbols.empty()) {
matching.edge_symbols.emplace_back(edge_symbols);
}
matching.expansions.insert(matching.expansions.end(), expansions.begin(),
expansions.end());
for (auto *pattern : patterns) {
CollectPatternFilters(*pattern, symbol_table, matching.filters, storage);
}
if (where) {
UsedSymbolsCollector collector(symbol_table);
where->expression_->Accept(collector);
matching.filters.emplace_back(where->expression_, collector.symbols_);
}
}
void AddMatching(const Match &match, const SymbolTable &symbol_table,
AstTreeStorage &storage, Matching &matching) {
return AddMatching(match.patterns_, match.where_, symbol_table, storage,
matching);
}
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<QueryPart> CollectQueryParts(const SymbolTable &symbol_table,
AstTreeStorage &storage) {
auto query = storage.query();
std::vector<QueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : query->clauses_) {
if (auto *match = dynamic_cast<Match *>(clause)) {
if (match->optional_) {
query_part->optional_matching.emplace_back(Matching{});
AddMatching(*match, symbol_table, storage,
query_part->optional_matching.back());
} else {
debug_assert(query_part->optional_matching.empty(),
"Match clause cannot follow optional match.");
AddMatching(*match, symbol_table, storage, query_part->matching);
}
} else {
query_part->remaining_clauses.push_back(clause);
if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage,
query_part->merge_matching.back());
} else if (dynamic_cast<With *>(clause)) {
query_parts.emplace_back(QueryPart{});
query_part = &query_parts.back();
} else if (dynamic_cast<Return *>(clause)) {
// TODO: Support RETURN UNION ...
return query_parts;
}
}
}
return query_parts;
}
LogicalOperator *PlanMatching(const Matching &matching,
LogicalOperator *input_op,
AstTreeStorage &storage, MatchContext &context) {
auto &bound_symbols = context.bound_symbols;
const auto &symbol_table = context.symbol_table;
// Copy filters, because we will modify the list as we generate Filters.
auto filters = matching.filters;
// Try to generate any filters even before the 1st match operator. This
// optimizes the optional match which filters only on symbols bound in regular
// match.
auto *last_op = GenFilters(input_op, bound_symbols, filters, storage);
for (const auto &expansion : matching.expansions) {
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
if (BindSymbol(bound_symbols, node1_symbol)) {
// We have just bound this symbol, so generate ScanAll which fills it.
const auto &labels = expansion.node1->labels_;
if (labels.empty()) {
last_op = new ScanAll(std::shared_ptr<LogicalOperator>(last_op),
node1_symbol, context.graph_view);
} else {
// Don't act smart by selecting the best label index, so take the first.
last_op = new ScanAllByLabel(std::shared_ptr<LogicalOperator>(last_op),
node1_symbol, labels.front(),
context.graph_view);
}
context.new_symbols.emplace_back(node1_symbol);
last_op = GenFilters(last_op, bound_symbols, filters, storage);
}
// We have an edge, so generate Expand.
if (expansion.edge) {
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
auto existing_node = false;
if (!BindSymbol(bound_symbols, node_symbol)) {
existing_node = true;
} else {
context.new_symbols.emplace_back(node_symbol);
}
const auto &edge_symbol = symbol_table.at(*expansion.edge->identifier_);
auto existing_edge = false;
if (!BindSymbol(bound_symbols, edge_symbol)) {
existing_edge = true;
} else {
context.new_symbols.emplace_back(edge_symbol);
}
last_op =
new Expand(expansion.node2, expansion.edge,
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, existing_edge, context.graph_view);
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different
// edges).
for (const auto &edge_symbols : matching.edge_symbols) {
if (edge_symbols.find(edge_symbol) == edge_symbols.end()) {
continue;
}
std::vector<Symbol> other_symbols;
for (const auto &symbol : edge_symbols) {
if (symbol == edge_symbol ||
bound_symbols.find(symbol) == bound_symbols.end()) {
continue;
}
other_symbols.push_back(symbol);
}
if (!other_symbols.empty()) {
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
other_symbols);
}
}
}
last_op = GenFilters(last_op, bound_symbols, filters, storage);
}
}
debug_assert(filters.empty(), "Expected to generate all filters");
return last_op;
}
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
const Matching &matching, PlanningContext &context) {
// Copy the bound symbol set, because we don't want to use the updated version
// when generating the create part.
std::unordered_set<Symbol> bound_symbols_copy(bound_symbols);
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
CollectPatternFilters(*merge.pattern_, symbol_table, context.filters,
storage);
std::unordered_set<Symbol> bound_symbols_copy(context.bound_symbols);
MatchContext match_ctx{context.symbol_table, bound_symbols_copy,
GraphView::NEW};
auto on_match =
GenMatchForPattern(*merge.pattern_, nullptr, context, storage);
PlanMatching(matching, nullptr, context.ast_storage, match_ctx);
// Use the original bound_symbols, so we fill it with new symbols.
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table,
bound_symbols);
auto on_create = GenCreateForPattern(
*merge.pattern_, nullptr, context.symbol_table, context.bound_symbols);
for (auto &set : merge.on_create_) {
on_create = HandleWriteClause(set, on_create, symbol_table, bound_symbols);
on_create = HandleWriteClause(set, on_create, context.symbol_table,
context.bound_symbols);
debug_assert(on_create, "Expected SET in MERGE ... ON CREATE");
}
for (auto &set : merge.on_match_) {
on_match = HandleWriteClause(set, on_match, symbol_table, bound_symbols);
on_match = HandleWriteClause(set, on_match, context.symbol_table,
context.bound_symbols);
debug_assert(on_match, "Expected SET in MERGE ... ON MATCH");
}
return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
@ -725,46 +856,52 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
SymbolTable &symbol_table) {
auto query = storage.query();
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<Symbol> bound_symbols;
auto query_parts = CollectQueryParts(symbol_table, storage);
PlanningContext context{symbol_table, storage};
LogicalOperator *input_op = nullptr;
// Set to true if a query command writes to the database.
bool is_write = false;
LogicalOperator *input_op = nullptr;
// All sequential Match clauses. Reset after encountering non-Match.
std::vector<Match *> matches;
for (auto &clause : query->clauses_) {
// Clauses which read from the database.
if (auto *match = dynamic_cast<Match *>(clause)) {
matches.emplace_back(match);
} else {
input_op =
GenMatches(matches, input_op, symbol_table, bound_symbols, storage);
matches.clear();
for (const auto &query_part : query_parts) {
MatchContext match_ctx{context.symbol_table, context.bound_symbols};
input_op = PlanMatching(query_part.matching, input_op, context.ast_storage,
match_ctx);
for (const auto &matching : query_part.optional_matching) {
MatchContext opt_ctx{context.symbol_table, context.bound_symbols};
auto *match_op =
PlanMatching(matching, nullptr, context.ast_storage, opt_ctx);
if (match_op) {
input_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
std::shared_ptr<LogicalOperator>(match_op),
opt_ctx.new_symbols);
}
}
int merge_id = 0;
for (auto &clause : query_part.remaining_clauses) {
debug_assert(dynamic_cast<Match *>(clause) == nullptr,
"Unexpected Match in remaining clauses");
if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, symbol_table, is_write,
bound_symbols, storage);
input_op = GenReturn(*ret, input_op, context.symbol_table, is_write,
context.bound_symbols, context.ast_storage);
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
input_op =
GenMerge(*merge, input_op, symbol_table, bound_symbols, storage);
input_op = GenMerge(*merge, input_op,
query_part.merge_matching[merge_id++], context);
// Treat MERGE clause as write, because we do not know if it will create
// anything.
is_write = true;
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op = GenWith(*with, input_op, symbol_table, is_write,
bound_symbols, storage);
input_op = GenWith(*with, input_op, context.symbol_table, is_write,
context.bound_symbols, context.ast_storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto *op = HandleWriteClause(clause, input_op, symbol_table,
bound_symbols)) {
} else if (auto *op =
HandleWriteClause(clause, input_op, context.symbol_table,
context.bound_symbols)) {
is_write = true;
input_op = op;
} else if (auto *unwind = dynamic_cast<query::Unwind *>(clause)) {
const auto &symbol = symbol_table.at(*unwind->named_expression_);
BindSymbol(bound_symbols, symbol);
const auto &symbol =
context.symbol_table.at(*unwind->named_expression_);
BindSymbol(context.bound_symbols, symbol);
input_op =
new plan::Unwind(std::shared_ptr<LogicalOperator>(input_op),
unwind->named_expression_->expression_, symbol);
@ -774,9 +911,6 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(AstTreeStorage &storage,
}
}
}
debug_assert(
matches.empty(),
"Expected Match clause(s) to be followed by an update or return clause");
return std::unique_ptr<LogicalOperator>(input_op);
}

View File

@ -19,6 +19,7 @@ namespace plan {
/// certain operators.
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
AstTreeStorage &storage, query::SymbolTable &symbol_table);
}
} // namespace plan
} // namespace query

View File

@ -49,6 +49,7 @@ class PlanChecker : public HierarchicalLogicalOperatorVisitor {
PRE_VISIT(CreateExpand);
PRE_VISIT(Delete);
PRE_VISIT(ScanAll);
PRE_VISIT(ScanAllByLabel);
PRE_VISIT(Expand);
PRE_VISIT(Filter);
PRE_VISIT(Produce);
@ -111,6 +112,7 @@ using ExpectCreateNode = OpChecker<CreateNode>;
using ExpectCreateExpand = OpChecker<CreateExpand>;
using ExpectDelete = OpChecker<Delete>;
using ExpectScanAll = OpChecker<ScanAll>;
using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>;
using ExpectExpand = OpChecker<Expand>;
using ExpectFilter = OpChecker<Filter>;
using ExpectProduce = OpChecker<Produce>;
@ -292,7 +294,7 @@ TEST(TestLogicalPlanner, MatchLabeledNodes) {
auto dba = dbms.active();
auto label = dba->label("label");
QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n")));
CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce());
CheckPlan(storage, ExpectScanAllByLabel(), ExpectFilter(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchPathReturn) {
@ -847,4 +849,14 @@ TEST(TestLogicalPlanner, UnwindMergeNodeProperty) {
for (auto &op : on_create) delete op;
}
TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) {
// Test OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n
AstTreeStorage storage;
QUERY(OPTIONAL_MATCH(PATTERN(NODE("n"))), OPTIONAL_MATCH(PATTERN(NODE("m"))),
RETURN(IDENT("n"), AS("n")));
std::list<BaseOpChecker *> optional{new ExpectScanAll()};
CheckPlan(storage, ExpectOptional(optional), ExpectOptional(optional),
ExpectProduce());
}
} // namespace