diff --git a/src/query/frontend/logical/planner.cpp b/src/query/frontend/logical/planner.cpp index 9d9c77e60..40a769695 100644 --- a/src/query/frontend/logical/planner.cpp +++ b/src/query/frontend/logical/planner.cpp @@ -16,51 +16,87 @@ bool BindSymbol(std::unordered_set &bound_symbols, const Symbol &symbol) { return insertion.second; } -LogicalOperator *GenCreateForPattern(Pattern &pattern, - LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set bound_symbols) { +/// Utility function for iterating pattern atoms and accumulating a result. +/// +/// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore, +/// the `base` function is called on the first `NodeAtom`, while the `collect` +/// is called for the whole triplet. Result of the function is passed to the +/// next call. Final result is returned. +/// +/// Example usage of counting edge atoms in the pattern. +/// +/// auto base = [](NodeAtom *first_node) { return 0; }; +/// auto collect = [](int accum, NodeAtom *prev_node, EdgeAtom *edge, +/// NodeAtom *node) { +/// return accum + 1; +/// }; +/// int edge_count = ReducePattern(pattern, base, collect); +/// +// TODO: It might be a good idea to move this somewhere else, for easier usage +// in other files. +template +auto ReducePattern( + Pattern &pattern, std::function base, + std::function collect) { + debug_assert(!pattern.atoms_.empty(), "Missing atoms in pattern"); auto atoms_it = pattern.atoms_.begin(); - auto last_node = dynamic_cast(*atoms_it++); - debug_assert(last_node, "First pattern atom is not a node"); - auto last_op = input_op; - if (BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { - // TODO: Pass last_op when CreateOp gets support for it. This will - // support e.g. `MATCH (n) CREATE (m)` and `CREATE (n), (m)`. - if (last_op) { - throw NotYetImplemented(); - } - last_op = new CreateOp(last_node); - } + auto current_node = dynamic_cast(*atoms_it++); + debug_assert(current_node, "First pattern atom is not a node"); + auto last_res = base(current_node); // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* while (atoms_it != pattern.atoms_.end()) { auto edge = dynamic_cast(*atoms_it++); debug_assert(edge, "Expected an edge atom in pattern."); debug_assert(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern."); + auto prev_node = current_node; + current_node = dynamic_cast(*atoms_it++); + debug_assert(current_node, "Expected a node atom in pattern."); + last_res = collect(last_res, prev_node, edge, current_node); + } + return last_res; +} + +auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set bound_symbols) { + auto base = [&](NodeAtom *node) -> LogicalOperator * { + if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { + // TODO: Pass input_op when CreateOp gets support for it. This will + // support e.g. `MATCH (n) CREATE (m)` and `CREATE (n), (m)`. + if (input_op) { + throw NotYetImplemented(); + } + return new CreateOp(node); + } else { + return input_op; + } + }; + + auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node, + EdgeAtom *edge, NodeAtom *node) { // Store the symbol from the first node as the input to CreateExpand. - auto input_symbol = symbol_table.at(*last_node->identifier_); - last_node = dynamic_cast(*atoms_it++); - debug_assert(last_node, "Expected a node atom in pattern."); + auto input_symbol = symbol_table.at(*prev_node->identifier_); // If the expand node was already bound, then we need to indicate this, // so that CreateExpand only creates an edge. bool node_existing = false; - if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { + if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { node_existing = true; } if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { permanent_fail("Symbols used for created edges cannot be redeclared."); } - last_op = new CreateExpand(last_node, edge, - std::shared_ptr(last_op), - input_symbol, node_existing); - } - return last_op; + return new CreateExpand(node, edge, + std::shared_ptr(last_op), + input_symbol, node_existing); + }; + + return ReducePattern(pattern, base, collect); } -LogicalOperator *GenCreate(Create &create, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set bound_symbols) { +auto GenCreate(Create &create, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set bound_symbols) { auto last_op = input_op; for (auto pattern : create.patterns_) { last_op = @@ -69,70 +105,62 @@ LogicalOperator *GenCreate(Create &create, LogicalOperator *input_op, return last_op; } -LogicalOperator *GenMatch(Match &match, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { - if (input_op) { - // TODO: Support clauses before match. - throw NotYetImplemented(); - } - if (match.patterns_.size() != 1) { - // TODO: Support matching multiple patterns. - throw NotYetImplemented(); - } - auto &pattern = match.patterns_[0]; - debug_assert(!pattern->atoms_.empty(), "Missing atoms in pattern"); - auto atoms_it = pattern->atoms_.begin(); - auto last_node = dynamic_cast(*atoms_it++); - debug_assert(last_node, "First pattern atom is not a node"); - // First atom always binds a symbol, and we don't care if it already existed, - // because we create a ScanAll which writes that symbol. This may need to - // change when we support clauses before match. - BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_)); - LogicalOperator *last_op = new ScanAll(last_node); - if (!last_node->labels_.empty() || !last_node->properties_.empty()) { - last_op = - new NodeFilter(std::shared_ptr(last_op), - symbol_table.at(*last_node->identifier_), last_node); - } - // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* - while (atoms_it != pattern->atoms_.end()) { - auto edge = dynamic_cast(*atoms_it++); - debug_assert(edge, "Expected an edge atom in pattern."); - debug_assert(atoms_it != pattern->atoms_.end(), - "Edge atom should not end the pattern."); +auto GenMatch(Match &match, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set &bound_symbols) { + auto base = [&](NodeAtom *node) { + if (input_op) { + // TODO: Support clauses before match. + throw NotYetImplemented(); + } + // First atom always binds a symbol, and we don't care if it already + // existed, + // because we create a ScanAll which writes that symbol. This may need to + // change when we support clauses before match. + BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)); + LogicalOperator *last_op = new ScanAll(node); + if (!node->labels_.empty() || !node->properties_.empty()) { + last_op = new NodeFilter(std::shared_ptr(last_op), + symbol_table.at(*node->identifier_), node); + } + return last_op; + }; + auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node, + EdgeAtom *edge, NodeAtom *node) { // Store the symbol from the first node as the input to Expand. - auto input_symbol = symbol_table.at(*last_node->identifier_); - last_node = dynamic_cast(*atoms_it++); - debug_assert(last_node, "Expected a node atom in pattern."); + auto input_symbol = symbol_table.at(*prev_node->identifier_); // If the expand symbols were already bound, then we need to indicate // this as a cycle. The Expand will then check whether the pattern holds // instead of writing the expansion to symbols. auto node_cycle = false; auto edge_cycle = false; - if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { + if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { node_cycle = true; } if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { edge_cycle = true; } - last_op = - new Expand(last_node, edge, std::shared_ptr(last_op), - input_symbol, node_cycle, edge_cycle); + last_op = new Expand(node, edge, std::shared_ptr(last_op), + input_symbol, node_cycle, edge_cycle); if (!edge->edge_types_.empty() || !edge->properties_.empty()) { last_op = new EdgeFilter(std::shared_ptr(last_op), symbol_table.at(*edge->identifier_), edge); } - if (!last_node->labels_.empty() || !last_node->properties_.empty()) { - last_op = - new NodeFilter(std::shared_ptr(last_op), - symbol_table.at(*last_node->identifier_), last_node); + if (!node->labels_.empty() || !node->properties_.empty()) { + last_op = new NodeFilter(std::shared_ptr(last_op), + symbol_table.at(*node->identifier_), node); } + return last_op; + }; + + if (match.patterns_.size() != 1) { + // TODO: Support matching multiple patterns. + throw NotYetImplemented(); } - return last_op; + return ReducePattern(*match.patterns_[0], base, collect); } -Produce *GenReturn(Return &ret, LogicalOperator *input_op) { +auto GenReturn(Return &ret, LogicalOperator *input_op) { if (!input_op) { // TODO: Support standalone RETURN clause (e.g. RETURN 2) throw NotYetImplemented();