diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index f7a6dcde5..d120a385b 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -521,6 +521,7 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { public: Query *query() { return query_; } + AstTreeStorage &storage() { return storage_; } const static std::string kAnonPrefix; private: diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 79561507f..76389cd65 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -122,6 +122,18 @@ void SymbolGenerator::PostVisit(Unwind &unwind) { symbol_table_[*unwind.named_expression_] = CreateSymbol(name); } +void SymbolGenerator::Visit(Match &) { scope_.in_match = true; } +void SymbolGenerator::PostVisit(Match &) { + scope_.in_match = false; + // Check variables in property maps after visiting Match, so that they can + // reference symbols out of bind order. + for (auto &ident : scope_.identifiers_in_property_maps) { + if (!HasSymbol(ident->name_)) throw UnboundVariableError(ident->name_); + symbol_table_[*ident] = scope_.symbols[ident->name_]; + } + scope_.identifiers_in_property_maps.clear(); +} + // Expressions void SymbolGenerator::Visit(Identifier &ident) { @@ -152,6 +164,11 @@ void SymbolGenerator::Visit(Identifier &ident) { type = Symbol::Type::Edge; } symbol = GetOrCreateSymbol(ident.name_, type); + } else if (scope_.in_pattern && scope_.in_property_map && scope_.in_match) { + // Variables in property maps during MATCH can reference symbols bound later + // in the same MATCH. We collect them here, so that they can be checked + // after visiting Match. + scope_.identifiers_in_property_maps.emplace_back(&ident); } else { // Everything else references a bound symbol. 
if (!HasSymbol(ident.name_)) throw UnboundVariableError(ident.name_); diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 4adfd9ffb..ef964480b 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -34,6 +34,8 @@ class SymbolGenerator : public TreeVisitorBase { void Visit(Merge &) override; void PostVisit(Merge &) override; void PostVisit(Unwind &) override; + void Visit(Match &) override; + void PostVisit(Match &) override; // Expressions void Visit(Identifier &) override; @@ -72,10 +74,14 @@ class SymbolGenerator : public TreeVisitorBase { bool in_limit{false}; bool in_order_by{false}; bool in_where{false}; + bool in_match{false}; // True if the return/with contains an aggregation in any named expression. bool has_aggregation{false}; // Map from variable names to symbols. std::map symbols; + // Identifiers found in property maps of patterns in a single Match clause. + // They need to be checked after visiting Match. 
+ std::vector identifiers_in_property_maps; }; bool HasSymbol(const std::string &name); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 41b52a90a..3cd14c7d1 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -36,7 +36,7 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor, high_level_tree->Accept(symbol_generator); // high level tree -> logical plan - auto logical_plan = plan::MakeLogicalPlan(*high_level_tree, symbol_table); + auto logical_plan = plan::MakeLogicalPlan(visitor.storage(), symbol_table); // generate frame based on symbol table max_position Frame frame(symbol_table.max_position()); diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 78fb20c60..afd143d2f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -357,107 +357,6 @@ bool Expand::ExpandCursor::HandleExistingNode(const VertexAccessor new_node, } } -NodeFilter::NodeFilter(const std::shared_ptr &input, - Symbol input_symbol, const NodeAtom *node_atom) - : input_(input ? input : std::make_shared()), - input_symbol_(input_symbol), - node_atom_(node_atom) {} - -ACCEPT_WITH_INPUT(NodeFilter) - -std::unique_ptr NodeFilter::MakeCursor(GraphDbAccessor &db) { - return std::make_unique(*this, db); -} - -NodeFilter::NodeFilterCursor::NodeFilterCursor(const NodeFilter &self, - GraphDbAccessor &db) - : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} - -bool NodeFilter::NodeFilterCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - while (input_cursor_->Pull(frame, symbol_table)) { - auto &vertex = frame[self_.input_symbol_].Value(); - // Filter needs to use the old, unmodified vertex, even though we may change - // properties or labels during the current command. 
- vertex.SwitchOld(); - if (VertexPasses(vertex, frame, symbol_table)) return true; - } - return false; -} - -void NodeFilter::NodeFilterCursor::Reset() { input_cursor_->Reset(); } - -bool NodeFilter::NodeFilterCursor::VertexPasses( - const VertexAccessor &vertex, Frame &frame, - const SymbolTable &symbol_table) { - for (auto label : self_.node_atom_->labels_) - if (!vertex.has_label(label)) return false; - - ExpressionEvaluator expression_evaluator(frame, symbol_table, db_, - GraphView::OLD); - // We don't want newly set properties to affect filtering. - for (auto prop_pair : self_.node_atom_->properties_) { - prop_pair.second->Accept(expression_evaluator); - TypedValue comparison_result = - vertex.PropsAt(prop_pair.first) == expression_evaluator.PopBack(); - if (comparison_result.IsNull() || !comparison_result.Value()) - return false; - } - return true; -} - -EdgeFilter::EdgeFilter(const std::shared_ptr &input, - Symbol input_symbol, const EdgeAtom *edge_atom) - : input_(input ? input : std::make_shared()), - input_symbol_(input_symbol), - edge_atom_(edge_atom) {} - -ACCEPT_WITH_INPUT(EdgeFilter) - -std::unique_ptr EdgeFilter::MakeCursor(GraphDbAccessor &db) { - return std::make_unique(*this, db); -} -EdgeFilter::EdgeFilterCursor::EdgeFilterCursor(const EdgeFilter &self, - GraphDbAccessor &db) - : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} - -bool EdgeFilter::EdgeFilterCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - while (input_cursor_->Pull(frame, symbol_table)) { - auto &edge = frame[self_.input_symbol_].Value(); - // Filter needs to use the old, unmodified edge, even though we may change - // properties or types during the current command. 
- edge.SwitchOld(); - if (EdgePasses(edge, frame, symbol_table)) return true; - } - return false; -} - -void EdgeFilter::EdgeFilterCursor::Reset() { input_cursor_->Reset(); } - -bool EdgeFilter::EdgeFilterCursor::EdgePasses(const EdgeAccessor &edge, - Frame &frame, - const SymbolTable &symbol_table) { - // edge type filtering - logical OR - const auto &types = self_.edge_atom_->edge_types_; - GraphDbTypes::EdgeType type = edge.edge_type(); - if (types.size() && std::none_of(types.begin(), types.end(), - [type](auto t) { return t == type; })) - return false; - - // We don't want newly set properties to affect filtering. - ExpressionEvaluator expression_evaluator(frame, symbol_table, db_, - GraphView::OLD); - for (auto prop_pair : self_.edge_atom_->properties_) { - prop_pair.second->Accept(expression_evaluator); - TypedValue comparison_result = - edge.PropsAt(prop_pair.first) == expression_evaluator.PopBack(); - if (comparison_result.IsNull() || !comparison_result.Value()) - return false; - } - return true; -} - Filter::Filter(const std::shared_ptr &input, Expression *expression) : input_(input ? input : std::make_shared()), diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 0b348a366..262a8ea45 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -55,8 +55,6 @@ class CreateNode; class CreateExpand; class ScanAll; class Expand; -class NodeFilter; -class EdgeFilter; class Filter; class Produce; class Delete; @@ -80,9 +78,9 @@ class Distinct; /** @brief Base class for visitors of @c LogicalOperator class hierarchy. 
*/ using LogicalOperatorVisitor = ::utils::Visitor< - Once, CreateNode, CreateExpand, ScanAll, Expand, NodeFilter, EdgeFilter, - Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, - RemoveProperty, RemoveLabels, ExpandUniquenessFilter, + Once, CreateNode, CreateExpand, ScanAll, Expand, Filter, Produce, Delete, + SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, + ExpandUniquenessFilter, ExpandUniquenessFilter, Accumulate, AdvanceCommand, Aggregate, Skip, Limit, OrderBy, Merge, Optional, Unwind, Distinct>; @@ -425,88 +423,6 @@ class Expand : public LogicalOperator { }; }; -/** @brief Operator which filters nodes by labels and properties. - * - * This operator is used to implement `MATCH (n :label {prop: value})`, so that - * it filters nodes with specified labels and properties by value. - */ -class NodeFilter : public LogicalOperator { - public: - /** @brief Construct @c NodeFilter. - * - * @param input Optional, preceding @c LogicalOperator. - * @param input_symbol @c Symbol where the node to be filtered is stored. - * @param node_atom @c NodeAtom with labels and properties to filter by. - */ - NodeFilter(const std::shared_ptr &input, Symbol input_symbol, - const NodeAtom *node_atom); - void Accept(LogicalOperatorVisitor &visitor) override; - std::unique_ptr MakeCursor(GraphDbAccessor &db) override; - - private: - const std::shared_ptr input_; - const Symbol input_symbol_; - const NodeAtom *node_atom_; - - class NodeFilterCursor : public Cursor { - public: - NodeFilterCursor(const NodeFilter &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; - void Reset() override; - - private: - const NodeFilter &self_; - GraphDbAccessor &db_; - const std::unique_ptr input_cursor_; - - /** Helper function for checking if the given vertex - * passes this filter. 
*/ - bool VertexPasses(const VertexAccessor &vertex, Frame &frame, - const SymbolTable &symbol_table); - }; -}; - -/** @brief Operator which filters edges by relationship type and properties. - * - * This operator is used to implement `MATCH () -[r :label {prop: value}]- ()`, - * so that it filters edges with specified types and properties by value. - */ -class EdgeFilter : public LogicalOperator { - public: - /** @brief Construct @c EdgeFilter. - * - * @param input Optional, preceding @c LogicalOperator. - * @param input_symbol @c Symbol where the edge to be filtered is stored. - * @param edge_atom @c EdgeAtom with edge types and properties to filter by. - */ - EdgeFilter(const std::shared_ptr &input, Symbol input_symbol, - const EdgeAtom *edge_atom); - void Accept(LogicalOperatorVisitor &visitor) override; - std::unique_ptr MakeCursor(GraphDbAccessor &db) override; - - private: - const std::shared_ptr input_; - const Symbol input_symbol_; - const EdgeAtom *edge_atom_; - - class EdgeFilterCursor : public Cursor { - public: - EdgeFilterCursor(const EdgeFilter &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; - void Reset() override; - - private: - const EdgeFilter &self_; - GraphDbAccessor &db_; - const std::unique_ptr input_cursor_; - - /** Helper function for checking if the given edge satisfied - * the criteria of this edge filter. */ - bool EdgePasses(const EdgeAccessor &edge, Frame &frame, - const SymbolTable &symbol_table); - }; -}; - /** * @brief Filter whose Pull returns true only when the given expression * evaluates into true. diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index dd4b8bc11..2a1a98216 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -101,6 +101,113 @@ auto GenCreate(Create &create, LogicalOperator *input_op, return last_op; } +// Collects symbols from identifiers found in visited AST nodes. 
+class UsedSymbolsCollector : public TreeVisitorBase { + public: + UsedSymbolsCollector(const SymbolTable &symbol_table) + : symbol_table_(symbol_table) {} + + using TreeVisitorBase::Visit; + void Visit(Identifier &ident) override { + const auto &symbol = symbol_table_.at(ident); + symbols_.insert(symbol.position_); + } + + std::unordered_set symbols_; + const SymbolTable &symbol_table_; +}; + +bool HasBoundFilterSymbols( + const std::unordered_set &bound_symbols, + const std::pair> &filter) { + for (const auto &symbol : filter.second) { + if (bound_symbols.find(symbol) == bound_symbols.end()) { + return false; + } + } + return true; +} + +template +Expression *BoolJoin(AstTreeStorage &storage, Expression *expr1, + Expression *expr2) { + if (expr1 && expr2) { + return storage.Create(expr1, expr2); + } + return expr1 ? expr1 : expr2; +} + +template +Expression *PropertiesEqual(AstTreeStorage &storage, + UsedSymbolsCollector &collector, TAtom *atom) { + Expression *filter_expr = nullptr; + for (auto &prop_pair : atom->properties_) { + prop_pair.second->Accept(collector); + auto *property_lookup = + storage.Create(atom->identifier_, prop_pair.first); + auto *prop_equal = + storage.Create(property_lookup, prop_pair.second); + filter_expr = BoolJoin(storage, filter_expr, prop_equal); + } + return filter_expr; +} + +auto &CollectPatternFilters( + Pattern &pattern, const SymbolTable &symbol_table, + std::list>> &filters, + AstTreeStorage &storage) { + UsedSymbolsCollector collector(symbol_table); + auto node_filter = [&](NodeAtom *node) { + Expression *labels_filter = + node->labels_.empty() ? 
nullptr + : labels_filter = storage.Create( + node->identifier_, node->labels_); + auto *props_filter = PropertiesEqual(storage, collector, node); + if (labels_filter || props_filter) { + collector.symbols_.insert(symbol_table.at(*node->identifier_).position_); + filters.emplace_back( + BoolJoin(storage, labels_filter, props_filter), + collector.symbols_); + collector.symbols_.clear(); + } + return &filters; + }; + auto expand_filter = [&](auto *filters, NodeAtom *prev_node, EdgeAtom *edge, + NodeAtom *node) { + Expression *types_filter = edge->edge_types_.empty() + ? nullptr + : storage.Create( + edge->identifier_, edge->edge_types_); + auto *props_filter = PropertiesEqual(storage, collector, edge); + if (types_filter || props_filter) { + const auto &edge_symbol = symbol_table.at(*edge->identifier_); + collector.symbols_.insert(edge_symbol.position_); + filters->emplace_back( + BoolJoin(storage, types_filter, props_filter), + collector.symbols_); + collector.symbols_.clear(); + } + return node_filter(node); + }; + return *ReducePattern< + std::list>> *>( + pattern, node_filter, expand_filter); +} + +void CollectMatchFilters( + const Match &match, const SymbolTable &symbol_table, + std::list>> &filters, + AstTreeStorage &storage) { + for (auto *pattern : match.patterns_) { + CollectPatternFilters(*pattern, symbol_table, filters, storage); + } + if (match.where_) { + UsedSymbolsCollector collector(symbol_table); + match.where_->expression_->Accept(collector); + filters.emplace_back(match.where_->expression_, collector.symbols_); + } +} + // Contextual information used for generating match operators. struct MatchContext { const SymbolTable &symbol_table; @@ -110,20 +217,47 @@ struct MatchContext { std::unordered_set &bound_symbols; // Determines whether the match should see the new graph state or not. GraphView graph_view = GraphView::OLD; + // Pairs of filter expression and symbols used in them. 
The list should be + // filled using CollectPatternFilters function, and later modified during + // GenMatchForPattern. + std::list>> filters; // Symbols for edges established in match, used to ensure Cyphermorphism. std::unordered_set edge_symbols; // All the newly established symbols in match. std::vector new_symbols; }; +auto GenFilters( + LogicalOperator *last_op, const std::unordered_set &bound_symbols, + std::list>> &filters, + AstTreeStorage &storage) { + Expression *filter_expr = nullptr; + for (auto filters_it = filters.begin(); filters_it != filters.end();) { + if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { + filter_expr = + BoolJoin(storage, filter_expr, filters_it->first); + filters_it = filters.erase(filters_it); + } else { + filters_it++; + } + } + if (filter_expr) { + last_op = + new Filter(std::shared_ptr(last_op), filter_expr); + } + return last_op; +} + // Generates operators for matching the given pattern and appends them to // input_op. Fills the context with all the new symbols and edge symbols. auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, - MatchContext &context) { + MatchContext &context, AstTreeStorage &storage) { auto &bound_symbols = context.bound_symbols; const auto &symbol_table = context.symbol_table; auto base = [&](NodeAtom *node) { - LogicalOperator *last_op = input_op; + // Try to generate any filters even before the 1st match operator. + auto *last_op = + GenFilters(input_op, bound_symbols, context.filters, storage); // If the first atom binds a symbol, we generate a ScanAll which writes it. // Otherwise, someone else generates it (e.g. a previous ScanAll). 
const auto &node_symbol = symbol_table.at(*node->identifier_); @@ -132,13 +266,7 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, context.graph_view); context.new_symbols.emplace_back(node_symbol); } - // Even though we may skip generating ScanAll, we still want to add a filter - // in case this atom adds more labels/properties for filtering. - if (!node->labels_.empty() || !node->properties_.empty()) { - last_op = new NodeFilter(std::shared_ptr(last_op), - symbol_table.at(*node->identifier_), node); - } - return last_op; + return GenFilters(last_op, bound_symbols, context.filters, storage); }; auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) { @@ -177,38 +305,54 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, // Insert edge_symbol after creating ExpandUniquenessFilter, so that we // avoid filtering by the same edge we just expanded. context.edge_symbols.insert(edge_symbol); - if (!edge->edge_types_.empty() || !edge->properties_.empty()) { - last_op = new EdgeFilter(std::shared_ptr(last_op), - symbol_table.at(*edge->identifier_), edge); - } - if (!node->labels_.empty() || !node->properties_.empty()) { - last_op = new NodeFilter(std::shared_ptr(last_op), - symbol_table.at(*node->identifier_), node); - } - return last_op; + return GenFilters(last_op, bound_symbols, context.filters, storage); }; return ReducePattern(pattern, base, collect); } -auto GenMatch(Match &match, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { - auto last_op = match.optional_ ? 
nullptr : input_op; - MatchContext context{symbol_table, bound_symbols}; - for (auto pattern : match.patterns_) { - last_op = GenMatchForPattern(*pattern, last_op, context); +auto GenMatches(std::vector &matches, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set &bound_symbols, + AstTreeStorage &storage) { + auto *last_op = input_op; + MatchContext req_ctx{symbol_table, bound_symbols}; + // Collect all non-optional match filters, so that we can put them as soon as + // possible in the operator tree. Optional matches need to be treated + // specially, because their filters need to remain inside the optional match. + for (auto *match : matches) { + if (match->optional_) { + continue; + } + CollectMatchFilters(*match, symbol_table, req_ctx.filters, storage); } - if (match.where_) { - last_op = new Filter(std::shared_ptr(last_op), - match.where_->expression_); - } - // Plan Optional after Filter. because with `OPTIONAL MATCH ... WHERE`, - // filtering is done while looking for the pattern. - if (match.optional_) { - last_op = new Optional(std::shared_ptr(input_op), - std::shared_ptr(last_op), - context.new_symbols); + auto gen_match = [&storage](const Match &match, LogicalOperator *input_op, + MatchContext &context) { + auto *match_op = input_op; + for (auto *pattern : match.patterns_) { + match_op = GenMatchForPattern(*pattern, match_op, context, storage); + } + return match_op; + }; + for (auto *match : matches) { + if (match->optional_) { + // Optional match needs to be standalone, so filter only by its filters + // and don't plug the previous match_op as input. 
+ MatchContext opt_ctx{symbol_table, bound_symbols}; + CollectMatchFilters(*match, symbol_table, opt_ctx.filters, storage); + auto *match_op = gen_match(*match, nullptr, opt_ctx); + last_op = new Optional(std::shared_ptr(last_op), + std::shared_ptr(match_op), + opt_ctx.new_symbols); + debug_assert(opt_ctx.filters.empty(), + "Expected to generate all optional filters"); + } else { + // Since we reuse req_ctx, we need to clear the symbols for the new match. + req_ctx.edge_symbols.clear(); + req_ctx.new_symbols.clear(); + last_op = gen_match(*match, last_op, req_ctx); + } } + debug_assert(req_ctx.filters.empty(), "Expected to generate all filters"); return last_op; } @@ -256,7 +400,9 @@ class ReturnBodyContext : public TreeVisitorBase { using TreeVisitorBase::Visit; using TreeVisitorBase::PostVisit; - void Visit(PrimitiveLiteral &) override { has_aggregation_.emplace_back(false); } + void Visit(PrimitiveLiteral &) override { + has_aggregation_.emplace_back(false); + } void Visit(ListLiteral &) override { has_aggregation_.emplace_back(false); } @@ -485,12 +631,15 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op, auto GenMerge(query::Merge &merge, LogicalOperator *input_op, const SymbolTable &symbol_table, - std::unordered_set &bound_symbols) { + std::unordered_set &bound_symbols, AstTreeStorage &storage) { // Copy the bound symbol set, because we don't want to use the updated version // when generating the create part. std::unordered_set bound_symbols_copy(bound_symbols); MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW}; - auto on_match = GenMatchForPattern(*merge.pattern_, nullptr, context); + CollectPatternFilters(*merge.pattern_, symbol_table, context.filters, + storage); + auto on_match = + GenMatchForPattern(*merge.pattern_, nullptr, context, storage); // Use the original bound_symbols, so we fill it with new symbols. 
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table, bound_symbols); @@ -510,7 +659,8 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op, } // namespace std::unique_ptr MakeLogicalPlan( - query::Query &query, const query::SymbolTable &symbol_table) { + AstTreeStorage &storage, const SymbolTable &symbol_table) { + auto query = storage.query(); // bound_symbols set is used to differentiate cycles in pattern matching, so // that the operator can be correctly initialized whether to read the symbol // or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first @@ -519,35 +669,47 @@ std::unique_ptr MakeLogicalPlan( // Set to true if a query command writes to the database. bool is_write = false; LogicalOperator *input_op = nullptr; - for (auto &clause : query.clauses_) { + // All sequential Match clauses. Reset after encountering non-Match. + std::vector matches; + for (auto &clause : query->clauses_) { // Clauses which read from the database. if (auto *match = dynamic_cast(clause)) { - input_op = GenMatch(*match, input_op, symbol_table, bound_symbols); - } else if (auto *ret = dynamic_cast(clause)) { - input_op = GenReturn(*ret, input_op, symbol_table, is_write); - } else if (auto *merge = dynamic_cast(clause)) { - input_op = GenMerge(*merge, input_op, symbol_table, bound_symbols); - // Treat MERGE clause as write, because we do not know if it will create - // anything. - is_write = true; - } else if (auto *with = dynamic_cast(clause)) { - input_op = - GenWith(*with, input_op, symbol_table, is_write, bound_symbols); - // WITH clause advances the command, so reset the flag. 
- is_write = false; - } else if (auto *op = HandleWriteClause(clause, input_op, symbol_table, - bound_symbols)) { - is_write = true; - input_op = op; - } else if (auto *unwind = dynamic_cast(clause)) { - input_op = new plan::Unwind(std::shared_ptr(input_op), - unwind->named_expression_->expression_, - symbol_table.at(*unwind->named_expression_)); + matches.emplace_back(match); } else { - throw utils::NotYetImplemented( - "Encountered a clause which cannot be converted to operator(s)"); + input_op = + GenMatches(matches, input_op, symbol_table, bound_symbols, storage); + matches.clear(); + if (auto *ret = dynamic_cast(clause)) { + input_op = GenReturn(*ret, input_op, symbol_table, is_write); + } else if (auto *merge = dynamic_cast(clause)) { + input_op = + GenMerge(*merge, input_op, symbol_table, bound_symbols, storage); + // Treat MERGE clause as write, because we do not know if it will create + // anything. + is_write = true; + } else if (auto *with = dynamic_cast(clause)) { + input_op = + GenWith(*with, input_op, symbol_table, is_write, bound_symbols); + // WITH clause advances the command, so reset the flag. 
+ is_write = false; + } else if (auto *op = HandleWriteClause(clause, input_op, symbol_table, + bound_symbols)) { + is_write = true; + input_op = op; + } else if (auto *unwind = dynamic_cast(clause)) { + input_op = + new plan::Unwind(std::shared_ptr(input_op), + unwind->named_expression_->expression_, + symbol_table.at(*unwind->named_expression_)); + } else { + throw utils::NotYetImplemented( + "Encountered a clause which cannot be converted to operator(s)"); + } } } + debug_assert( + matches.empty(), + "Expected Match clause(s) to be followed by an update or return clause"); return std::unique_ptr(input_op); } diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp index 1dada21a2..af76d6559 100644 --- a/src/query/plan/planner.hpp +++ b/src/query/plan/planner.hpp @@ -6,15 +6,19 @@ namespace query { -class Query; +class AstTreeStorage; class SymbolTable; namespace plan { -// Returns the root of LogicalOperator tree. The tree is constructed by -// traversing the given AST Query node. SymbolTable is used to determine inputs -// and outputs of certain operators. +/// @brief Generates the LogicalOperator tree and returns the root operation. +/// +/// The tree is constructed by traversing the @c Query node from given +/// @c AstTreeStorage. The storage may also be used to create new AST nodes for +/// use in operators. @c SymbolTable is used to determine inputs and outputs of +/// certain operators. 
std::unique_ptr MakeLogicalPlan( - query::Query &query, const query::SymbolTable &symbol_table); -} + AstTreeStorage &storage, const query::SymbolTable &symbol_table); } + +} // namespace query diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 60f1f6342..7ed0a4749 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -86,6 +86,10 @@ auto GetPropertyLookup(AstTreeStorage &storage, const std::string &name, return storage.Create(storage.Create(name), property); } +auto GetPropertyLookup(AstTreeStorage &storage, Expression *expr, + GraphDbTypes::Property property) { + return storage.Create(expr, property); +} /// /// Create an EdgeAtom with given name, edge_type and direction. @@ -415,3 +419,6 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, storage.Create((expr), query::Aggregation::Op::SUM) #define COUNT(expr) \ storage.Create((expr), query::Aggregation::Op::COUNT) +#define EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create((expr1), (expr2)) diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 49a414a11..9492538d0 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -707,8 +707,10 @@ TEST(QueryPlan, NodeFilterSet) { scan_all.node_->properties_[prop] = LITERAL(42); auto expand = MakeExpand(storage, symbol_table, scan_all.op_, scan_all.sym_, "r", EdgeAtom::Direction::BOTH, false, "m", false); - auto node_filter = - std::make_shared(expand.op_, scan_all.sym_, scan_all.node_); + auto *filter_expr = + EQ(storage.Create(scan_all.node_->identifier_, prop), + LITERAL(42)); + auto node_filter = std::make_shared(expand.op_, filter_expr); // SET n.prop = n.prop + 1 auto set_prop = PROPERTY_LOOKUP("n", prop); symbol_table[*set_prop->expression_] = scan_all.sym_; 
diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index 9f77c5336..5ccdab33f 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -142,7 +142,10 @@ TEST(QueryPlan, NodeFilterLabelsAndProperties) { n.node_->properties_[property] = LITERAL(42); // node filtering - auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); + auto *filter_expr = + AND(storage.Create(n.node_->identifier_, n.node_->labels_), + EQ(PROPERTY_LOOKUP(n.node_->identifier_, property), LITERAL(42))); + auto node_filter = std::make_shared(n.op_, filter_expr); // make a named expression and a produce auto output = NEXPR("x", IDENT("n")); @@ -194,7 +197,9 @@ TEST(QueryPlan, NodeFilterMultipleLabels) { n.node_->labels_.emplace_back(label2); // node filtering - auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); + auto *filter_expr = + storage.Create(n.node_->identifier_, n.node_->labels_); + auto node_filter = std::make_shared(n.op_, filter_expr); // make a named expression and a produce auto output = NEXPR("n", IDENT("n")); @@ -491,8 +496,11 @@ TEST(QueryPlan, EdgeFilter) { EdgeAtom::Direction::RIGHT, false, "m", false); r_m.edge_->edge_types_.push_back(edge_types[0]); r_m.edge_->properties_[prop] = LITERAL(42); - auto edge_filter = - std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); + auto *filter_expr = + AND(storage.Create(r_m.edge_->identifier_, + r_m.edge_->edge_types_), + EQ(PROPERTY_LOOKUP(r_m.edge_->identifier_, prop), LITERAL(42))); + auto edge_filter = std::make_shared(r_m.op_, filter_expr); // make a named expression and a produce auto output = NEXPR("m", IDENT("m")); @@ -511,26 +519,6 @@ TEST(QueryPlan, EdgeFilter) { EXPECT_EQ(3, test_filter()); } -TEST(QueryPlan, EdgeFilterEmpty) { - Dbms dbms; - auto dba = dbms.active(); - - auto v1 = dba->insert_vertex(); - auto v2 = dba->insert_vertex(); - dba->insert_edge(v1, v2, dba->edge_type("type")); - 
dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", - EdgeAtom::Direction::RIGHT, false, "m", false); - auto edge_filter = - std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); - EXPECT_EQ(1, PullAll(edge_filter, *dba, symbol_table)); -} - TEST(QueryPlan, EdgeFilterMultipleTypes) { Dbms dbms; auto dba = dbms.active(); @@ -552,11 +540,12 @@ TEST(QueryPlan, EdgeFilterMultipleTypes) { auto n = MakeScanAll(storage, symbol_table, "n"); auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", EdgeAtom::Direction::RIGHT, false, "m", false); - // add a property filter - auto edge_filter = - std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); + // add an edge type filter r_m.edge_->edge_types_.push_back(type_1); r_m.edge_->edge_types_.push_back(type_2); + auto *filter_expr = storage.Create(r_m.edge_->identifier_, + r_m.edge_->edge_types_); + auto edge_filter = std::make_shared(r_m.op_, filter_expr); // make a named expression and a produce auto output = NEXPR("m", IDENT("m")); diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 93c0046b3..9dbc152ce 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -44,8 +44,6 @@ class PlanChecker : public LogicalOperatorVisitor { void Visit(Delete &op) override { CheckOp(op); } void Visit(ScanAll &op) override { CheckOp(op); } void Visit(Expand &op) override { CheckOp(op); } - void Visit(NodeFilter &op) override { CheckOp(op); } - void Visit(EdgeFilter &op) override { CheckOp(op); } void Visit(Filter &op) override { CheckOp(op); } void Visit(Produce &op) override { CheckOp(op); } void Visit(SetProperty &op) override { CheckOp(op); } @@ -104,8 +102,6 @@ using ExpectCreateExpand = OpChecker; using ExpectDelete = OpChecker; using ExpectScanAll = OpChecker; using ExpectExpand = OpChecker; -using ExpectNodeFilter = OpChecker; -using 
ExpectEdgeFilter = OpChecker; using ExpectFilter = OpChecker; using ExpectProduce = OpChecker; using ExpectSetProperty = OpChecker; @@ -212,17 +208,17 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, } template -auto CheckPlan(query::Query &query, TChecker... checker) { - auto symbol_table = MakeSymbolTable(query); - auto plan = MakeLogicalPlan(query, symbol_table); +auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { + auto symbol_table = MakeSymbolTable(*storage.query()); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, checker...); } TEST(TestLogicalPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n AS n AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n"))); - CheckPlan(*query, ExpectScanAll(), ExpectProduce()); + QUERY(MATCH(PATTERN(NODE("n"))), RETURN(IDENT("n"), AS("n"))); + CheckPlan(storage, ExpectScanAll(), ExpectProduce()); } TEST(TestLogicalPlanner, CreateNodeReturn) { @@ -232,7 +228,7 @@ TEST(TestLogicalPlanner, CreateNodeReturn) { auto query = QUERY(CREATE(PATTERN(NODE("n"))), RETURN(ident_n, AS("n"))); auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce()); } @@ -242,16 +238,16 @@ TEST(TestLogicalPlanner, CreateExpand) { Dbms dbms; auto dba = dbms.active(); auto relationship = dba->edge_type("relationship"); - auto query = QUERY(CREATE(PATTERN( - NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m")))); - CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand()); + QUERY(CREATE(PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), + NODE("m")))); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand()); } TEST(TestLogicalPlanner, CreateMultipleNode) { // Test CREATE (n), (m) 
AstTreeStorage storage; - auto query = QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m")))); - CheckPlan(*query, ExpectCreateNode(), ExpectCreateNode()); + QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m")))); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateNode()); } TEST(TestLogicalPlanner, CreateNodeExpandNode) { @@ -260,10 +256,10 @@ TEST(TestLogicalPlanner, CreateNodeExpandNode) { Dbms dbms; auto dba = dbms.active(); auto relationship = dba->edge_type("rel"); - auto query = QUERY(CREATE( + QUERY(CREATE( PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), NODE("m")), PATTERN(NODE("l")))); - CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(), + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), ExpectCreateNode()); } @@ -273,11 +269,10 @@ TEST(TestLogicalPlanner, MatchCreateExpand) { Dbms dbms; auto dba = dbms.active(); auto relationship = dba->edge_type("relationship"); - auto query = - QUERY(MATCH(PATTERN(NODE("n"))), - CREATE(PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), - NODE("m")))); - CheckPlan(*query, ExpectScanAll(), ExpectCreateExpand()); + QUERY(MATCH(PATTERN(NODE("n"))), + CREATE(PATTERN(NODE("n"), EDGE("r", relationship, Direction::RIGHT), + NODE("m")))); + CheckPlan(storage, ExpectScanAll(), ExpectCreateExpand()); } TEST(TestLogicalPlanner, MatchLabeledNodes) { @@ -286,9 +281,8 @@ TEST(TestLogicalPlanner, MatchLabeledNodes) { Dbms dbms; auto dba = dbms.active(); auto label = dba->label("label"); - auto query = - QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n"))); - CheckPlan(*query, ExpectScanAll(), ExpectNodeFilter(), ExpectProduce()); + QUERY(MATCH(PATTERN(NODE("n", label))), RETURN(IDENT("n"), AS("n"))); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce()); } TEST(TestLogicalPlanner, MatchPathReturn) { @@ -297,10 +291,9 @@ TEST(TestLogicalPlanner, MatchPathReturn) { Dbms dbms; auto dba = dbms.active(); auto relationship = dba->edge_type("relationship"); - auto 
query = - QUERY(MATCH(PATTERN(NODE("n"), EDGE("r", relationship), NODE("m"))), - RETURN(IDENT("n"), AS("n"))); - CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectEdgeFilter(), + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r", relationship), NODE("m"))), + RETURN(IDENT("n"), AS("n"))); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectFilter(), ExpectProduce()); } @@ -310,17 +303,17 @@ TEST(TestLogicalPlanner, MatchWhereReturn) { Dbms dbms; auto dba = dbms.active(); auto property = dba->property("property"); - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))), - RETURN(IDENT("n"), AS("n"))); - CheckPlan(*query, ExpectScanAll(), ExpectFilter(), ExpectProduce()); + QUERY(MATCH(PATTERN(NODE("n"))), + WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))), + RETURN(IDENT("n"), AS("n"))); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce()); } TEST(TestLogicalPlanner, MatchDelete) { // Test MATCH (n) DELETE n AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n"))); - CheckPlan(*query, ExpectScanAll(), ExpectDelete()); + QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n"))); + CheckPlan(storage, ExpectScanAll(), ExpectDelete()); } TEST(TestLogicalPlanner, MatchNodeSet) { @@ -330,11 +323,10 @@ TEST(TestLogicalPlanner, MatchNodeSet) { auto dba = dbms.active(); auto prop = dba->property("prop"); auto label = dba->label("label"); - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)), - SET("n", IDENT("n")), SET("n", {label})); - CheckPlan(*query, ExpectScanAll(), ExpectSetProperty(), ExpectSetProperties(), - ExpectSetLabels()); + QUERY(MATCH(PATTERN(NODE("n"))), SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)), + SET("n", IDENT("n")), SET("n", {label})); + CheckPlan(storage, ExpectScanAll(), ExpectSetProperty(), + ExpectSetProperties(), ExpectSetLabels()); } TEST(TestLogicalPlanner, MatchRemove) { @@ -344,95 +336,101 @@ 
TEST(TestLogicalPlanner, MatchRemove) { auto dba = dbms.active(); auto prop = dba->property("prop"); auto label = dba->label("label"); - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label})); - CheckPlan(*query, ExpectScanAll(), ExpectRemoveProperty(), + QUERY(MATCH(PATTERN(NODE("n"))), REMOVE(PROPERTY_LOOKUP("n", prop)), + REMOVE("n", {label})); + CheckPlan(storage, ExpectScanAll(), ExpectRemoveProperty(), ExpectRemoveLabels()); } TEST(TestLogicalPlanner, MatchMultiPattern) { - // Test MATCH (n) -[r]- (m), (j) -[e]- (i) + // Test MATCH (n) -[r]- (m), (j) -[e]- (i) RETURN n AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), - PATTERN(NODE("j"), EDGE("e"), NODE("i")))); + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), + PATTERN(NODE("j"), EDGE("e"), NODE("i"))), + RETURN(IDENT("n"), AS("n"))); // We expect the expansions after the first to have a uniqueness filter in a // single MATCH clause. - CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), - ExpectExpand(), ExpectExpandUniquenessFilter()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), + ExpectExpand(), ExpectExpandUniquenessFilter(), + ExpectProduce()); } TEST(TestLogicalPlanner, MatchMultiPatternSameStart) { - // Test MATCH (n), (n) -[e]- (m) + // Test MATCH (n), (n) -[e]- (m) RETURN n AstTreeStorage storage; - auto query = QUERY( - MATCH(PATTERN(NODE("n")), PATTERN(NODE("n"), EDGE("e"), NODE("m")))); + QUERY(MATCH(PATTERN(NODE("n")), PATTERN(NODE("n"), EDGE("e"), NODE("m"))), + RETURN(IDENT("n"), AS("n"))); // We expect the second pattern to generate only an Expand, since another // ScanAll would be redundant. 
- CheckPlan(*query, ExpectScanAll(), ExpectExpand()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); } TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) { - // Test MATCH (n) -[r]- (m), (m) -[e]- (l) + // Test MATCH (n) -[r]- (m), (m) -[e]- (l) RETURN n AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), - PATTERN(NODE("m"), EDGE("e"), NODE("l")))); + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), + PATTERN(NODE("m"), EDGE("e"), NODE("l"))), + RETURN(IDENT("n"), AS("n"))); // We expect the second pattern to generate only an Expand. Another // ScanAll would be redundant, as it would generate the nodes obtained from // expansion. Additionally, a uniqueness filter is expected. - CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpandUniquenessFilter(), ExpectProduce()); } TEST(TestLogicalPlanner, MultiMatch) { - // Test MATCH (n) -[r]- (m) MATCH (j) -[e]- (i) -[f]- (h) + // Test MATCH (n) -[r]- (m) MATCH (j) -[e]- (i) -[f]- (h) RETURN n AstTreeStorage storage; - auto query = QUERY( - MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), - MATCH(PATTERN(NODE("j"), EDGE("e"), NODE("i"), EDGE("f"), NODE("h")))); + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + MATCH(PATTERN(NODE("j"), EDGE("e"), NODE("i"), EDGE("f"), NODE("h"))), + RETURN(IDENT("n"), AS("n"))); // Multiple MATCH clauses form a Cartesian product, so the uniqueness should // not cross MATCH boundaries. 
- CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter()); + ExpectExpandUniquenessFilter(), ExpectProduce()); } TEST(TestLogicalPlanner, MultiMatchSameStart) { - // Test MATCH (n) MATCH (n) -[r]- (m) + // Test MATCH (n) MATCH (n) -[r]- (m) RETURN n AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")))); + QUERY(MATCH(PATTERN(NODE("n"))), + MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + RETURN(IDENT("n"), AS("n"))); // Similar to MatchMultiPatternSameStart, we expect only Expand from second // MATCH clause. - CheckPlan(*query, ExpectScanAll(), ExpectExpand()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); } TEST(TestLogicalPlanner, MatchExistingEdge) { - // Test MATCH (n) -[r]- (m) -[r]- (j) + // Test MATCH (n) -[r]- (m) -[r]- (j) RETURN n AstTreeStorage storage; - auto query = QUERY( - MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"), EDGE("r"), NODE("j")))); + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"), EDGE("r"), NODE("j"))), + RETURN(IDENT("n"), AS("n"))); // There is no ExpandUniquenessFilter for referencing the same edge. 
- CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectProduce()); } TEST(TestLogicalPlanner, MultiMatchExistingEdgeOtherEdge) { - // Test MATCH (n) -[r]- (m) MATCH (m) -[r]- (j) -[e]- (l) + // Test MATCH (n) -[r]- (m) MATCH (m) -[r]- (j) -[e]- (l) RETURN n AstTreeStorage storage; - auto query = QUERY( - MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), - MATCH(PATTERN(NODE("m"), EDGE("r"), NODE("j"), EDGE("e"), NODE("l")))); + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + MATCH(PATTERN(NODE("m"), EDGE("r"), NODE("j"), EDGE("e"), NODE("l"))), + RETURN(IDENT("n"), AS("n"))); // We need ExpandUniquenessFilter for edge `e` against `r` in second MATCH. - CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectExpand(), ExpectExpandUniquenessFilter()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpand(), ExpectExpandUniquenessFilter(), + ExpectProduce()); } TEST(TestLogicalPlanner, MatchWithReturn) { // Test MATCH (old) WITH old AS new RETURN new AS new AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")), - RETURN(IDENT("new"), AS("new"))); + QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")), + RETURN(IDENT("new"), AS("new"))); // No accumulation since we only do reads. 
- CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectProduce()); } TEST(TestLogicalPlanner, MatchWithWhereReturn) { @@ -441,11 +439,11 @@ TEST(TestLogicalPlanner, MatchWithWhereReturn) { auto dba = dbms.active(); auto prop = dba->property("prop"); AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")), - WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))), - RETURN(IDENT("new"), AS("new"))); + QUERY(MATCH(PATTERN(NODE("old"))), WITH(IDENT("old"), AS("new")), + WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))), + RETURN(IDENT("new"), AS("new"))); // No accumulation since we only do reads. - CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectFilter(), + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectFilter(), ExpectProduce()); } @@ -456,10 +454,9 @@ TEST(TestLogicalPlanner, CreateMultiExpand) { auto r = dba->edge_type("r"); auto p = dba->edge_type("p"); AstTreeStorage storage; - auto query = QUERY( - CREATE(PATTERN(NODE("n"), EDGE("r", r, Direction::RIGHT), NODE("m")), - PATTERN(NODE("n"), EDGE("p", p, Direction::RIGHT), NODE("l")))); - CheckPlan(*query, ExpectCreateNode(), ExpectCreateExpand(), + QUERY(CREATE(PATTERN(NODE("n"), EDGE("r", r, Direction::RIGHT), NODE("m")), + PATTERN(NODE("n"), EDGE("p", p, Direction::RIGHT), NODE("l")))); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), ExpectCreateExpand()); } @@ -472,12 +469,11 @@ TEST(TestLogicalPlanner, MatchWithSumWhereReturn) { AstTreeStorage storage; auto sum = SUM(PROPERTY_LOOKUP("n", prop)); auto literal = LITERAL(42); - auto query = - QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), - WHERE(LESS(IDENT("sum"), LITERAL(42))), - RETURN(IDENT("sum"), AS("result"))); + QUERY(MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), + WHERE(LESS(IDENT("sum"), LITERAL(42))), + RETURN(IDENT("sum"), AS("result"))); auto aggr 
= ExpectAggregate({sum}, {literal}); - CheckPlan(*query, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), + CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), ExpectProduce()); } @@ -490,10 +486,10 @@ TEST(TestLogicalPlanner, MatchReturnSum) { AstTreeStorage storage; auto sum = SUM(PROPERTY_LOOKUP("n", prop1)); auto n_prop2 = PROPERTY_LOOKUP("n", prop2); - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - RETURN(sum, AS("sum"), n_prop2, AS("group"))); + QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(sum, AS("sum"), n_prop2, AS("group"))); auto aggr = ExpectAggregate({sum}, {n_prop2}); - CheckPlan(*query, ExpectScanAll(), aggr, ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); } TEST(TestLogicalPlanner, CreateWithSum) { @@ -508,7 +504,7 @@ TEST(TestLogicalPlanner, CreateWithSum) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); // We expect both the accumulation and aggregation because the part before // WITH updates the database. 
CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, @@ -521,20 +517,18 @@ TEST(TestLogicalPlanner, MatchWithCreate) { auto dba = dbms.active(); auto r_type = dba->edge_type("r"); AstTreeStorage storage; - auto query = - QUERY(MATCH(PATTERN(NODE("n"))), WITH(IDENT("n"), AS("a")), - CREATE(PATTERN(NODE("a"), EDGE("r", r_type, Direction::RIGHT), - NODE("b")))); - CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); + QUERY(MATCH(PATTERN(NODE("n"))), WITH(IDENT("n"), AS("a")), + CREATE(PATTERN(NODE("a"), EDGE("r", r_type, Direction::RIGHT), + NODE("b")))); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); } TEST(TestLogicalPlanner, MatchReturnSkipLimit) { // Test MATCH (n) RETURN n SKIP 2 LIMIT 1 AstTreeStorage storage; - auto query = - QUERY(MATCH(PATTERN(NODE("n"))), - RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); - CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectSkip(), + QUERY(MATCH(PATTERN(NODE("n"))), + RETURN(IDENT("n"), AS("n"), SKIP(LITERAL(2)), LIMIT(LITERAL(1)))); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), ExpectLimit()); } @@ -547,7 +541,7 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { RETURN(IDENT("m"), AS("m"), LIMIT(LITERAL(1)))); auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); // Since we have a write query, we need to have Accumulate. This is a bit // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a // single RETURN clause and then moves Skip and Limit before Accumulate. 
This @@ -570,7 +564,7 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*n_prop->expression_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectSkip(), ExpectLimit()); } @@ -582,8 +576,8 @@ TEST(TestLogicalPlanner, MatchReturnOrderBy) { auto prop = dba->property("prop"); AstTreeStorage storage; auto ret = RETURN(IDENT("n"), AS("n"), ORDER_BY(PROPERTY_LOOKUP("n", prop))); - auto query = QUERY(MATCH(PATTERN(NODE("n"))), ret); - CheckPlan(*query, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); + QUERY(MATCH(PATTERN(NODE("n"))), ret); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); } TEST(TestLogicalPlanner, CreateWithOrderByWhere) { @@ -610,7 +604,7 @@ TEST(TestLogicalPlanner, CreateWithOrderByWhere) { symbol_table.at(*r_prop->expression_), // `r` in ORDER BY symbol_table.at(*m_prop->expression_), // `m` in WHERE }); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, ExpectProduce(), ExpectFilter(), ExpectOrderBy()); } @@ -620,10 +614,9 @@ TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { AstTreeStorage storage; auto sum = SUM(LITERAL(1)); auto count = COUNT(LITERAL(2)); - auto query = - QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))); + QUERY(RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result")))); auto aggr = ExpectAggregate({sum, count}, {}); - CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy()); + CheckPlan(storage, aggr, ExpectProduce(), ExpectOrderBy()); } TEST(TestLogicalPlanner, MatchMerge) { @@ -642,14 +635,14 @@ TEST(TestLogicalPlanner, MatchMerge) { 
ON_MATCH(SET(PROPERTY_LOOKUP("n", prop), LITERAL(42))), ON_CREATE(SET("m", IDENT("n")))), RETURN(ident_n, AS("n"))); - std::list on_match{ - new ExpectExpand(), new ExpectEdgeFilter(), new ExpectSetProperty()}; + std::list on_match{new ExpectExpand(), new ExpectFilter(), + new ExpectSetProperty()}; std::list on_create{new ExpectCreateExpand(), new ExpectSetProperties()}; auto symbol_table = MakeSymbolTable(*query); // We expect Accumulate after Merge, because it is considered as a write. auto acc = ExpectAccumulate({symbol_table.at(*ident_n)}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectMerge(on_match, on_create), acc, ExpectProduce()); for (auto &op : on_match) delete op; @@ -664,30 +657,31 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { auto dba = dbms.active(); auto prop = dba->property("prop"); AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - OPTIONAL_MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), - WHERE(LESS(PROPERTY_LOOKUP("m", prop), LITERAL(42))), - RETURN(IDENT("r"), AS("r"))); + QUERY(MATCH(PATTERN(NODE("n"))), + OPTIONAL_MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + WHERE(LESS(PROPERTY_LOOKUP("m", prop), LITERAL(42))), + RETURN(IDENT("r"), AS("r"))); std::list optional{new ExpectScanAll(), new ExpectExpand(), new ExpectFilter()}; - CheckPlan(*query, ExpectScanAll(), ExpectOptional(optional), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectOptional(optional), + ExpectProduce()); } TEST(TestLogicalPlanner, MatchUnwindReturn) { // Test MATCH (n) UNWIND [1,2,3] AS x RETURN n AS n, x AS x AstTreeStorage storage; - auto query = QUERY(MATCH(PATTERN(NODE("n"))), - UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), - RETURN(IDENT("n"), AS("n"), IDENT("x"), AS("x"))); - CheckPlan(*query, ExpectScanAll(), ExpectUnwind(), ExpectProduce()); + QUERY(MATCH(PATTERN(NODE("n"))), + 
UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), + RETURN(IDENT("n"), AS("n"), IDENT("x"), AS("x"))); + CheckPlan(storage, ExpectScanAll(), ExpectUnwind(), ExpectProduce()); } TEST(TestLogicalPlanner, ReturnDistinctOrderBySkipLimit) { // Test RETURN DISTINCT 1 ORDER BY 1 SKIP 1 LIMIT 1 AstTreeStorage storage; - auto query = QUERY(RETURN_DISTINCT(LITERAL(1), AS("1"), ORDER_BY(LITERAL(1)), - SKIP(LITERAL(1)), LIMIT(LITERAL(1)))); - CheckPlan(*query, ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), + QUERY(RETURN_DISTINCT(LITERAL(1), AS("1"), ORDER_BY(LITERAL(1)), + SKIP(LITERAL(1)), LIMIT(LITERAL(1)))); + CheckPlan(storage, ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), ExpectSkip(), ExpectLimit()); } @@ -705,9 +699,76 @@ TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { auto symbol_table = MakeSymbolTable(*query); auto acc = ExpectAccumulate({symbol_table.at(*node_n->identifier_)}); auto aggr = ExpectAggregate({sum}, {}); - auto plan = MakeLogicalPlan(*query, symbol_table); + auto plan = MakeLogicalPlan(storage, symbol_table); CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), ExpectFilter(), ExpectDistinct(), ExpectProduce()); } +TEST(TestLogicalPlanner, MatchCrossReferenceVariable) { + // Test MATCH (n {prop: m.prop}), (m {prop: n.prop}) RETURN n + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto node_n = NODE("n"); + auto m_prop = PROPERTY_LOOKUP("m", prop); + node_n->properties_[prop] = m_prop; + auto node_m = NODE("m"); + auto n_prop = PROPERTY_LOOKUP("n", prop); + node_m->properties_[prop] = n_prop; + QUERY(MATCH(PATTERN(node_n), PATTERN(node_m)), RETURN(IDENT("n"), AS("n"))); + // We expect both ScanAll to come before filters (2 are joined into one), + // because they need to populate the symbol values. 
+ CheckPlan(storage, ExpectScanAll(), ExpectScanAll(), ExpectFilter(), + ExpectProduce()); +} + +TEST(TestLogicalPlanner, MatchWhereBeforeExpand) { + // Test MATCH (n) -[r]- (m) WHERE n.prop < 42 RETURN n + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), + RETURN(IDENT("n"), AS("n"))); + // We expect Filter to come immediately after ScanAll, since it only uses `n`. + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectProduce()); +} + +TEST(TestLogicalPlanner, MultiMatchWhere) { + // Test MATCH (n) -[r]- (m) MATCH (l) WHERE n.prop < 42 RETURN n + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + MATCH(PATTERN(NODE("l"))), + WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), + RETURN(IDENT("n"), AS("n"))); + // Even though WHERE is in the second MATCH clause, we expect Filter to come + // before second ScanAll, since it only uses the value from first ScanAll. + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectScanAll(), ExpectProduce()); +} + +TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { + // Test MATCH (n) -[r]- (m) OPTIONAL MATCH (l) WHERE n.prop < 42 RETURN n + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + OPTIONAL_MATCH(PATTERN(NODE("l"))), + WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), + RETURN(IDENT("n"), AS("n"))); + // Even though WHERE is in the second MATCH clause, and it uses the value from + // first ScanAll, it must remain part of the Optional. It should come before + // optional ScanAll.
+ std::list optional{new ExpectFilter(), new ExpectScanAll()}; + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectOptional(optional), + ExpectProduce()); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 14c4857b4..3ea95a4f0 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -761,4 +761,35 @@ TEST(TestSymbolGenerator, WithUnwindReturn) { EXPECT_NE(elem, symbol_table.at(*ret_as_elem)); } +TEST(TestSymbolGenerator, MatchCrossReferenceVariable) { + // MATCH (n {prop: m.prop}), (m {prop: n.prop}) RETURN n + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("prop"); + AstTreeStorage storage; + auto node_n = NODE("n"); + auto m_prop = PROPERTY_LOOKUP("m", prop); + node_n->properties_[prop] = m_prop; + auto node_m = NODE("m"); + auto n_prop = PROPERTY_LOOKUP("n", prop); + node_m->properties_[prop] = n_prop; + auto ident_n = IDENT("n"); + auto as_n = AS("n"); + auto query = + QUERY(MATCH(PATTERN(node_n), PATTERN(node_m)), RETURN(ident_n, as_n)); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for `n`, `m` and `AS n` + EXPECT_EQ(symbol_table.max_position(), 3); + auto n = symbol_table.at(*node_n->identifier_); + EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); + EXPECT_EQ(n, symbol_table.at(*ident_n)); + EXPECT_NE(n, symbol_table.at(*as_n)); + auto m = symbol_table.at(*node_m->identifier_); + EXPECT_EQ(m, symbol_table.at(*m_prop->expression_)); + EXPECT_NE(n, m); + EXPECT_NE(m, symbol_table.at(*as_n)); +} + } // namespace