From 52709ad04c28286ac7e7e0ef4cce4082cc614bbd Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Wed, 30 Aug 2017 15:37:00 +0200 Subject: [PATCH] Inline filter inside ExpandVariable Summary: Reorder class definition in ast.hpp. Test inlining filters in ExpandVariable. Reviewers: florijan, mislav.bradac Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D726 --- CMakeLists.txt | 7 +- src/query/frontend/ast/ast.hpp | 142 ++++++----- src/query/plan/operator.cpp | 92 ++++--- src/query/plan/operator.hpp | 4 +- src/query/plan/rule_based_planner.cpp | 355 ++++++++++++++------------ src/query/plan/rule_based_planner.hpp | 40 ++- tests/unit/query_planner.cpp | 18 +- 7 files changed, 374 insertions(+), 284 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 93024fdae..db39d34fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,6 +71,9 @@ endif() # TODO: set here 17 once it will be available in the cmake version (3.8) set(cxx_standard 14) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -Wall -Wno-c++1z-extensions") +# Don't omit frame pointer in RelWithDebInfo, for additional callchain debug. 
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO + "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -fno-omit-frame-pointer") # ----------------------------------------------------------------------------- # dir variables @@ -99,7 +102,7 @@ endif() # default build type is debug if ("${CMAKE_BUILD_TYPE}" STREQUAL "") - set(CMAKE_BUILD_TYPE "debug") + set(CMAKE_BUILD_TYPE "Debug") endif() message(STATUS "CMake build type: ${CMAKE_BUILD_TYPE}") # ----------------------------------------------------------------------------- @@ -356,7 +359,7 @@ string(STRIP ${COMMIT_HASH} COMMIT_HASH) set(MEMGRAPH_BUILD_NAME "memgraph_${COMMIT_NO}_${COMMIT_HASH}_${COMMIT_BRANCH}_${CMAKE_BUILD_TYPE}") add_custom_target(memgraph_link_target ALL - COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME}) + COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME}) # ----------------------------------------------------------------------------- # memgraph main executable diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 84166259f..cb2395d1c 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -77,6 +77,8 @@ class Tree : public ::utils::Visitable<HierarchicalTreeVisitor>, const int uid_; }; +// Expressions + class Expression : public Tree { friend class AstTreeStorage; @@ -87,6 +89,29 @@ class Expression : public Tree { Expression(int uid) : Tree(uid) {} }; +class Where : public Tree { + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor<TypedValue>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + Where *Clone(AstTreeStorage &storage) const override { + return storage.Create<Where>(expression_->Clone(storage)); + } + + Expression *expression_ = 
nullptr; + + protected: + Where(int uid) : Tree(uid) {} + Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {} +}; + class BinaryOperator : public Expression { friend class AstTreeStorage; @@ -840,6 +865,42 @@ class Aggregation : public BinaryOperator { } }; +class All : public Expression { + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor<TypedValue>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + All *Clone(AstTreeStorage &storage) const override { + return storage.Create<All>(identifier_->Clone(storage), + list_expression_->Clone(storage), + where_->Clone(storage)); + } + + Identifier *identifier_ = nullptr; + Expression *list_expression_ = nullptr; + Where *where_ = nullptr; + + protected: + All(int uid, Identifier *identifier, Expression *list_expression, + Where *where) + : Expression(uid), + identifier_(identifier), + list_expression_(list_expression), + where_(where) { + debug_assert(identifier, "identifier must not be nullptr"); + debug_assert(list_expression, "list_expression must not be nullptr"); + debug_assert(where, "where must not be nullptr"); + } +}; + class NamedExpression : public Tree { friend class AstTreeStorage; @@ -877,6 +938,8 @@ class NamedExpression : public Tree { token_position_(token_position) {} }; +// Pattern atoms + class PatternAtom : public Tree { friend class AstTreeStorage; @@ -1026,15 +1089,6 @@ class BreadthFirstAtom : public EdgeAtom { max_depth_(max_depth) {} }; -class Clause : public Tree { - friend class AstTreeStorage; - - public: - Clause(int uid) : Tree(uid) {} - - Clause *Clone(AstTreeStorage &storage) const override = 0; -}; - class Pattern : public Tree { friend class AstTreeStorage; @@ -1065,6 +1119,17 @@ class Pattern : public Tree { Pattern(int uid) : Tree(uid) {} }; +// Clauses + 
+class Clause : public Tree { + friend class AstTreeStorage; + + public: + Clause(int uid) : Tree(uid) {} + + Clause *Clone(AstTreeStorage &storage) const override = 0; +}; + class Query : public Tree { friend class AstTreeStorage; @@ -1120,65 +1185,6 @@ class Create : public Clause { std::vector<Pattern *> patterns_; }; -class Where : public Tree { - friend class AstTreeStorage; - - public: - DEFVISITABLE(TreeVisitor<TypedValue>); - bool Accept(HierarchicalTreeVisitor &visitor) override { - if (visitor.PreVisit(*this)) { - expression_->Accept(visitor); - } - return visitor.PostVisit(*this); - } - - Where *Clone(AstTreeStorage &storage) const override { - return storage.Create<Where>(expression_->Clone(storage)); - } - - Expression *expression_ = nullptr; - - protected: - Where(int uid) : Tree(uid) {} - Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {} -}; - -class All : public Expression { - friend class AstTreeStorage; - - public: - DEFVISITABLE(TreeVisitor<TypedValue>); - bool Accept(HierarchicalTreeVisitor &visitor) override { - if (visitor.PreVisit(*this)) { - identifier_->Accept(visitor) && list_expression_->Accept(visitor) && - where_->Accept(visitor); - } - return visitor.PostVisit(*this); - } - - All *Clone(AstTreeStorage &storage) const override { - return storage.Create<All>(identifier_->Clone(storage), - list_expression_->Clone(storage), - where_->Clone(storage)); - } - - Identifier *identifier_ = nullptr; - Expression *list_expression_ = nullptr; - Where *where_ = nullptr; - - protected: - All(int uid, Identifier *identifier, Expression *list_expression, - Where *where) - : Expression(uid), - identifier_(identifier), - list_expression_(list_expression), - where_(where) { - debug_assert(identifier, "identifier must not be nullptr"); - debug_assert(list_expression, "list_expression must not be nullptr"); - debug_assert(where, "where must not be nullptr"); - } -}; - class Match : public Clause { friend class AstTreeStorage; 
diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 75cfa15fc..519da233f 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -44,6 +44,18 @@ void ExpectType(Symbol symbol, TypedValue value, TypedValue::Type expected) { symbol.name(), value.type()); } +// Returns boolean result of evaluating filter expression. Null is treated as +// false. Other non boolean values raise a QueryRuntimeException. +bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) { + TypedValue result = filter->Accept(evaluator); + // Null is treated like false. + if (result.IsNull()) return false; + if (result.type() != TypedValue::Type::Bool) + throw QueryRuntimeException( + "Filter expression must be a bool or null, but got {}.", result.type()); + return result.Value<bool>(); +} + } // namespace bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) { @@ -239,7 +251,7 @@ ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, output_symbol_(output_symbol), graph_view_(graph_view) { permanent_assert(graph_view != GraphView::AS_IS, - "ScanAll must have explicitly defined GraphView") + "ScanAll must have explicitly defined GraphView"); } ACCEPT_WITH_INPUT(ScanAll) @@ -300,10 +312,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor( ExpressionEvaluator evaluator(frame, symbol_table, db, graph_view_); auto convert = [&evaluator](const auto &bound) -> std::experimental::optional<utils::Bound<PropertyValue>> { - if (!bound) return std::experimental::nullopt; - return std::experimental::make_optional(utils::Bound<PropertyValue>( - bound.value().value()->Accept(evaluator), bound.value().type())); - }; + if (!bound) return std::experimental::nullopt; + return std::experimental::make_optional(utils::Bound<PropertyValue>( + bound.value().value()->Accept(evaluator), bound.value().type())); + }; return db.Vertices(label_, property_, convert(lower_bound()), convert(upper_bound()), graph_view_ == GraphView::NEW); 
}; @@ -531,12 +543,14 @@ ExpandVariable::ExpandVariable(Symbol node_symbol, Symbol edge_symbol, Expression *lower_bound, Expression *upper_bound, const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node, - bool existing_edge, GraphView graph_view) + bool existing_edge, GraphView graph_view, + Expression *filter) : ExpandCommon(node_symbol, edge_symbol, direction, input, input_symbol, existing_node, existing_edge, graph_view), lower_bound_(lower_bound), upper_bound_(upper_bound), - is_reverse_(is_reverse) {} + is_reverse_(is_reverse), + filter_(filter) {} bool Expand::ExpandCursor::HandleExistingEdge(const EdgeAccessor &new_edge, Frame &frame) const { @@ -612,8 +626,9 @@ class ExpandVariableCursor : public Cursor { : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} bool Pull(Frame &frame, const SymbolTable &symbol_table) override { + ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_); while (true) { - if (Expand(frame)) return true; + if (Expand(frame, symbol_table)) return true; if (PullInput(frame, symbol_table)) { // if lower bound is zero we also yield empty paths @@ -625,8 +640,11 @@ class ExpandVariableCursor : public Cursor { // take into account existing_edge when yielding empty paths if ((!self_.existing_edge_ || edges_on_frame.empty()) && // Place the start vertex on the frame. - self_.HandleExistingNode(start_vertex, frame)) + self_.HandleExistingNode(start_vertex, frame)) { + if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_)) + continue; return true; + } } // if lower bound is not zero, we just continue, the next // loop iteration will attempt to expand and we're good @@ -793,7 +811,8 @@ class ExpandVariableCursor : public Cursor { * case no more expansions are available from the current input * vertex and another Pull from the input cursor should be performed. 
*/ - bool Expand(Frame &frame) { + bool Expand(Frame &frame, const SymbolTable &symbol_table) { + ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_); // some expansions might not be valid due to // edge uniqueness, existing_edge, existing_node criterions, // so expand in a loop until either the input vertex is @@ -851,6 +870,10 @@ class ExpandVariableCursor : public Cursor { auto edge_placement_result = HandleEdgePlacement(current_edge.first, edges_on_frame); if (edge_placement_result == EdgePlacementResult::MISMATCH) continue; + // Skip expanding out of filtered expansion. It is assumed that the + // expression does not use the vertex which has yet to be put on frame. + // Therefore, this check is done as soon as the edge is on the frame. + if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_)) continue; VertexAccessor current_vertex = current_edge.second == EdgeAtom::Direction::IN @@ -1050,16 +1073,7 @@ bool Filter::FilterCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { // and edges. ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::OLD); while (input_cursor_->Pull(frame, symbol_table)) { - TypedValue result = self_.expression_->Accept(evaluator); - // Null is treated like false. - if (result.IsNull()) continue; - - if (result.type() != TypedValue::Type::Bool) - throw QueryRuntimeException( - "Filter expression must be a bool or null, but got {}.", - result.type()); - if (!result.Value<bool>()) continue; - return true; + if (EvaluateFilter(evaluator, self_.expression_)) return true; } return false; } @@ -1203,11 +1217,11 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame, // Skip setting properties on Null (can occur in optional match). break; case TypedValue::Type::Map: - // Semantically modifying a map makes sense, but it's not supported due to - // all the copying we do (when PropertyValue -> TypedValue and in - // ExpressionEvaluator). 
So even though we set a map property here, that - // is never visible to the user and it's not stored. - // TODO: fix above described bug + // Semantically modifying a map makes sense, but it's not supported due to + // all the copying we do (when PropertyValue -> TypedValue and in + // ExpressionEvaluator). So even though we set a map property here, that + // is never visible to the user and it's not stored. + // TODO: fix above described bug default: throw QueryRuntimeException( "Properties can only be set on Vertices and Edges"); @@ -1737,14 +1751,14 @@ void Aggregate::AggregateCursor::Update( *value_it = 1; break; case Aggregation::Op::COLLECT_LIST: - value_it->Value<std::vector<TypedValue>>().push_back(input_value); - break; + value_it->Value<std::vector<TypedValue>>().push_back(input_value); + break; case Aggregation::Op::COLLECT_MAP: - auto key = agg_elem_it->key->Accept(evaluator); - if (key.type() != TypedValue::Type::String) - throw QueryRuntimeException("Map key must be a string"); - value_it->Value<std::map<std::string, TypedValue>>().emplace( - key.Value<std::string>(), input_value); + auto key = agg_elem_it->key->Accept(evaluator); + if (key.type() != TypedValue::Type::String) + throw QueryRuntimeException("Map key must be a string"); + value_it->Value<std::map<std::string, TypedValue>>().emplace( + key.Value<std::string>(), input_value); break; } continue; @@ -1789,14 +1803,14 @@ void Aggregate::AggregateCursor::Update( *value_it = *value_it + input_value; break; case Aggregation::Op::COLLECT_LIST: - value_it->Value<std::vector<TypedValue>>().push_back(input_value); - break; + value_it->Value<std::vector<TypedValue>>().push_back(input_value); + break; case Aggregation::Op::COLLECT_MAP: - auto key = agg_elem_it->key->Accept(evaluator); - if (key.type() != TypedValue::Type::String) - throw QueryRuntimeException("Map key must be a string"); - value_it->Value<std::map<std::string, TypedValue>>().emplace( - key.Value<std::string>(), input_value); + auto 
key = agg_elem_it->key->Accept(evaluator); + if (key.type() != TypedValue::Type::String) + throw QueryRuntimeException("Map key must be a string"); + value_it->Value<std::map<std::string, TypedValue>>().emplace( + key.Value<std::string>(), input_value); break; } // end switch over Aggregation::Op enum } // end loop over all aggregations diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index ab957280b..433fe3a8c 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -633,7 +633,8 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon { Expression *lower_bound, Expression *upper_bound, const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node, bool existing_edge, - GraphView graph_view = GraphView::AS_IS); + GraphView graph_view = GraphView::AS_IS, + Expression *filter = nullptr); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; @@ -646,6 +647,7 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon { // True if the path should be written as expanding from node_symbol to // input_symbol. 
bool is_reverse_; + Expression *filter_; }; /** diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 33f6463c2..8a78453a7 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -114,10 +114,9 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { const SymbolTable &symbol_table_; }; -bool HasBoundFilterSymbols( - const std::unordered_set<Symbol> &bound_symbols, - const std::pair<Expression *, std::unordered_set<Symbol>> &filter) { - for (const auto &symbol : filter.second) { +bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, + const Filters::FilterInfo &filter) { + for (const auto &symbol : filter.used_symbols) { if (bound_symbols.find(symbol) == bound_symbols.end()) { return false; } @@ -357,8 +356,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { // Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses // two expressions, so we can have 0, 1 or 2 elements on the // has_aggregation_stack for this Aggregation expression. - if (aggr.op_ == Aggregation::Op::COLLECT_MAP) - has_aggregation_.pop_back(); + if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back(); if (aggr.expression1_) has_aggregation_.back() = true; else @@ -594,11 +592,178 @@ void AddMatching(const Match &match, SymbolTable &symbol_table, matching); } +// Iterates over `all_filters` joining them in one expression via +// `FilterAndOperator`. Filters which use unbound symbols are skipped, as well +// as those that fail the `predicate` function. The function takes a single +// argument, `FilterInfo`. All the joined filters are removed from +// `all_filters`. 
+template <class TPredicate> +Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, + std::vector<Filters::FilterInfo> &all_filters, + AstTreeStorage &storage, + const TPredicate &predicate) { + Expression *filter_expr = nullptr; + for (auto filters_it = all_filters.begin(); + filters_it != all_filters.end();) { + if (HasBoundFilterSymbols(bound_symbols, *filters_it) && + predicate(*filters_it)) { + filter_expr = BoolJoin<FilterAndOperator>(storage, filter_expr, + filters_it->expression); + filters_it = all_filters.erase(filters_it); + } else { + filters_it++; + } + } + return filter_expr; +} + } // namespace +namespace impl { + +// Returns false if the symbol was already bound, otherwise binds it and +// returns true. +bool BindSymbol(std::unordered_set<Symbol> &bound_symbols, + const Symbol &symbol) { + auto insertion = bound_symbols.insert(symbol); + return insertion.second; +} + +Expression *FindExpandVariableFilter( + const std::unordered_set<Symbol> &bound_symbols, + const Symbol &expands_to_node, + std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage) { + return ExtractFilters(bound_symbols, all_filters, storage, + [&](const auto &filter) { + return filter.is_for_expand_variable && + filter.used_symbols.find(expands_to_node) == + filter.used_symbols.end(); + }); +} + +LogicalOperator *GenFilters(LogicalOperator *last_op, + const std::unordered_set<Symbol> &bound_symbols, + std::vector<Filters::FilterInfo> &all_filters, + AstTreeStorage &storage) { + auto *filter_expr = ExtractFilters(bound_symbols, all_filters, storage, + [](const auto &) { return true; }); + if (filter_expr) { + last_op = + new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr); + } + return last_op; +} + +LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op, + SymbolTable &symbol_table, bool is_write, + const std::unordered_set<Symbol> &bound_symbols, + AstTreeStorage &storage) { + // Similar to WITH clause, but we want to 
accumulate and advance command when + // the query writes to the database. This way we handle the case when we want + // to return expressions with the latest updated results. For example, + // `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same + // `n` multiple 'k' times, we want to return 'k' results where the property + // value is the same, final result of 'k' increments. + bool accumulate = is_write; + bool advance_command = false; + ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage); + return GenReturnBody(input_op, advance_command, body, accumulate); +} + +LogicalOperator *GenCreateForPattern( + Pattern &pattern, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols) { + auto base = [&](NodeAtom *node) -> LogicalOperator * { + if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) + return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op)); + else + return input_op; + }; + + auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node, + EdgeAtom *edge, NodeAtom *node) { + // Store the symbol from the first node as the input to CreateExpand. + const auto &input_symbol = symbol_table.at(*prev_node->identifier_); + // If the expand node was already bound, then we need to indicate this, + // so that CreateExpand only creates an edge. + bool node_existing = false; + if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { + node_existing = true; + } + if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { + permanent_fail("Symbols used for created edges cannot be redeclared."); + } + return new CreateExpand(node, edge, + std::shared_ptr<LogicalOperator>(last_op), + input_symbol, node_existing); + }; + + return ReducePattern<LogicalOperator *>(pattern, base, collect); +} + +// Generate an operator for a clause which writes to the database. If the clause +// isn't handled, returns nullptr. 
+LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set<Symbol> &bound_symbols) { + if (auto *create = dynamic_cast<Create *>(clause)) { + return GenCreate(*create, input_op, symbol_table, bound_symbols); + } else if (auto *del = dynamic_cast<query::Delete *>(clause)) { + return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op), + del->expressions_, del->detach_); + } else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) { + return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op), + set->property_lookup_, set->expression_); + } else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) { + auto op = set->update_ ? plan::SetProperties::Op::UPDATE + : plan::SetProperties::Op::REPLACE; + const auto &input_symbol = symbol_table.at(*set->identifier_); + return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op), + input_symbol, set->expression_, op); + } else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) { + const auto &input_symbol = symbol_table.at(*set->identifier_); + return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op), + input_symbol, set->labels_); + } else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) { + return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op), + rem->property_lookup_); + } else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) { + const auto &input_symbol = symbol_table.at(*rem->identifier_); + return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op), + input_symbol, rem->labels_); + } + return nullptr; +} + +LogicalOperator *GenWith(With &with, LogicalOperator *input_op, + SymbolTable &symbol_table, bool is_write, + std::unordered_set<Symbol> &bound_symbols, + AstTreeStorage &storage) { + // WITH clause is Accumulate/Aggregate (advance_command) + Produce and + // optional Filter. 
In case of update and aggregation, we want to accumulate + // first, so that when aggregating, we get the latest results. Similar to + // RETURN clause. + bool accumulate = is_write; + // No need to advance the command if we only performed reads. + bool advance_command = is_write; + ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, + with.where_); + LogicalOperator *last_op = + GenReturnBody(input_op, advance_command, body, accumulate); + // Reset bound symbols, so that only those in WITH are exposed. + bound_symbols.clear(); + for (const auto &symbol : body.output_symbols()) { + BindSymbol(bound_symbols, symbol); + } + return last_op; +} + +} // namespace impl + // Analyzes the filter expression by collecting information on filtering labels -// and properties to be used with indexing. Note that all filters are never -// updated here, but only labels and properties are. +// and properties to be used with indexing. Note that `all_filters_` are never +// updated here, but only `label_filters_` and `property_filters_` are. void Filters::AnalyzeFilter(Expression *expr, const SymbolTable &symbol_table) { using Bound = ScanAllByLabelPropertyRange::Bound; auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup, @@ -714,11 +879,11 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, collector.symbols_.insert(symbol); // PropertyLookup uses the symbol. 
if (is_variable_path) { all_filters_.emplace_back( - storage.Create<All>(identifier, atom->identifier_, - storage.Create<Where>(prop_equal)), - collector.symbols_); + FilterInfo{storage.Create<All>(identifier, atom->identifier_, + storage.Create<Where>(prop_equal)), + collector.symbols_, true}); } else { - all_filters_.emplace_back(prop_equal, collector.symbols_); + all_filters_.emplace_back(FilterInfo{prop_equal, collector.symbols_}); } } }; @@ -729,9 +894,9 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, label_filters_[node_symbol].insert(node->labels_.begin(), node->labels_.end()); // Create a LabelsTest and store it in all_filters_. - all_filters_.emplace_back( + all_filters_.emplace_back(FilterInfo{ storage.Create<LabelsTest>(node->identifier_, node->labels_), - std::unordered_set<Symbol>{node_symbol}); + std::unordered_set<Symbol>{node_symbol}}); } add_properties_filter(node); }; @@ -740,19 +905,19 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, if (!edge->edge_types_.empty()) { if (edge->has_range_) { // We need a new identifier and symbol for All. 
- auto *identifier = edge->identifier_->Clone(storage); - symbol_table[*identifier] = - symbol_table.CreateSymbol(identifier->name_, false); + auto *ident_in_all = edge->identifier_->Clone(storage); + symbol_table[*ident_in_all] = + symbol_table.CreateSymbol(ident_in_all->name_, false); auto *edge_type_test = - storage.Create<EdgeTypeTest>(identifier, edge->edge_types_); - all_filters_.emplace_back( - storage.Create<All>(identifier, edge->identifier_, + storage.Create<EdgeTypeTest>(ident_in_all, edge->edge_types_); + all_filters_.emplace_back(FilterInfo{ + storage.Create<All>(ident_in_all, edge->identifier_, storage.Create<Where>(edge_type_test)), - std::unordered_set<Symbol>{edge_symbol}); + std::unordered_set<Symbol>{edge_symbol}, true}); } else { - all_filters_.emplace_back( + all_filters_.emplace_back(FilterInfo{ storage.Create<EdgeTypeTest>(edge->identifier_, edge->edge_types_), - std::unordered_set<Symbol>{edge_symbol}); + std::unordered_set<Symbol>{edge_symbol}}); } } add_properties_filter(edge, edge->has_range_); @@ -761,13 +926,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, ForEachPattern(pattern, add_node_filter, add_expand_filter); } -// Adds the where filter expression to all filters and collects additional +// Adds the where filter expression to `all_filters_` and collects additional // information for potential property and label indexing. 
void Filters::CollectWhereFilter(Where &where, const SymbolTable &symbol_table) { UsedSymbolsCollector collector(symbol_table); where.expression_->Accept(collector); - all_filters_.emplace_back(where.expression_, collector.symbols_); + all_filters_.emplace_back(FilterInfo{where.expression_, collector.symbols_}); AnalyzeFilter(where.expression_, symbol_table); } @@ -809,144 +974,4 @@ std::vector<QueryPart> CollectQueryParts(SymbolTable &symbol_table, return query_parts; } -namespace impl { - -// Returns false if the symbol was already bound, otherwise binds it and -// returns true. -bool BindSymbol(std::unordered_set<Symbol> &bound_symbols, - const Symbol &symbol) { - auto insertion = bound_symbols.insert(symbol); - return insertion.second; -} - -LogicalOperator *GenFilters( - LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols, - std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> - &all_filters, - AstTreeStorage &storage) { - Expression *filter_expr = nullptr; - for (auto filters_it = all_filters.begin(); - filters_it != all_filters.end();) { - if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { - filter_expr = - BoolJoin<FilterAndOperator>(storage, filter_expr, filters_it->first); - filters_it = all_filters.erase(filters_it); - } else { - filters_it++; - } - } - if (filter_expr) { - last_op = - new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr); - } - return last_op; -} - -LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op, - SymbolTable &symbol_table, bool is_write, - const std::unordered_set<Symbol> &bound_symbols, - AstTreeStorage &storage) { - // Similar to WITH clause, but we want to accumulate and advance command when - // the query writes to the database. This way we handle the case when we want - // to return expressions with the latest updated results. For example, - // `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. 
If we match same - // `n` multiple 'k' times, we want to return 'k' results where the property - // value is the same, final result of 'k' increments. - bool accumulate = is_write; - bool advance_command = false; - ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage); - return GenReturnBody(input_op, advance_command, body, accumulate); -} - -LogicalOperator *GenCreateForPattern( - Pattern &pattern, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set<Symbol> &bound_symbols) { - auto base = [&](NodeAtom *node) -> LogicalOperator * { - if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) - return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op)); - else - return input_op; - }; - - auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node, - EdgeAtom *edge, NodeAtom *node) { - // Store the symbol from the first node as the input to CreateExpand. - const auto &input_symbol = symbol_table.at(*prev_node->identifier_); - // If the expand node was already bound, then we need to indicate this, - // so that CreateExpand only creates an edge. - bool node_existing = false; - if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { - node_existing = true; - } - if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { - permanent_fail("Symbols used for created edges cannot be redeclared."); - } - return new CreateExpand(node, edge, - std::shared_ptr<LogicalOperator>(last_op), - input_symbol, node_existing); - }; - - return ReducePattern<LogicalOperator *>(pattern, base, collect); -} - -// Generate an operator for a clause which writes to the database. If the clause -// isn't handled, returns nullptr. 
-LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op, - const SymbolTable &symbol_table, - std::unordered_set<Symbol> &bound_symbols) { - if (auto *create = dynamic_cast<Create *>(clause)) { - return GenCreate(*create, input_op, symbol_table, bound_symbols); - } else if (auto *del = dynamic_cast<query::Delete *>(clause)) { - return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op), - del->expressions_, del->detach_); - } else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) { - return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op), - set->property_lookup_, set->expression_); - } else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) { - auto op = set->update_ ? plan::SetProperties::Op::UPDATE - : plan::SetProperties::Op::REPLACE; - const auto &input_symbol = symbol_table.at(*set->identifier_); - return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op), - input_symbol, set->expression_, op); - } else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) { - const auto &input_symbol = symbol_table.at(*set->identifier_); - return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op), - input_symbol, set->labels_); - } else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) { - return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op), - rem->property_lookup_); - } else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) { - const auto &input_symbol = symbol_table.at(*rem->identifier_); - return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op), - input_symbol, rem->labels_); - } - return nullptr; -} - -LogicalOperator *GenWith(With &with, LogicalOperator *input_op, - SymbolTable &symbol_table, bool is_write, - std::unordered_set<Symbol> &bound_symbols, - AstTreeStorage &storage) { - // WITH clause is Accumulate/Aggregate (advance_command) + Produce and - // optional Filter. 
In case of update and aggregation, we want to accumulate - // first, so that when aggregating, we get the latest results. Similar to - // RETURN clause. - bool accumulate = is_write; - // No need to advance the command if we only performed reads. - bool advance_command = is_write; - ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, - with.where_); - LogicalOperator *last_op = - GenReturnBody(input_op, advance_command, body, accumulate); - // Reset bound symbols, so that only those in WITH are exposed. - bound_symbols.clear(); - for (const auto &symbol : body.output_symbols()) { - BindSymbol(bound_symbols, symbol); - } - return last_op; -} - -} // namespace impl - } // namespace query::plan diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 795686ad4..672f028e6 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -38,7 +38,19 @@ class Filters { std::experimental::optional<Bound> upper_bound{}; }; - /// All filter expressions that should be generated. + /// Stores additional information for a filter expression. + struct FilterInfo { + /// The filter expression which must be satisfied. + Expression *expression; + /// Set of used symbols by the filter @c expression. + std::unordered_set<Symbol> used_symbols; + /// True if the filter is to be applied on multiple expanding edges. + /// This is used to inline filtering in an @c ExpandVariable operator. + bool is_for_expand_variable = false; + }; + + /// List of FilterInfo objects corresponding to all filter expressions that + /// should be generated. auto &all_filters() { return all_filters_; } const auto &all_filters() const { return all_filters_; } /// Mapping from a symbol to labels that are filtered on it. 
These should be @@ -66,7 +78,7 @@ class Filters { private: void AnalyzeFilter(Expression *, const SymbolTable &); - std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> all_filters_; + std::vector<FilterInfo> all_filters_; std::unordered_map<Symbol, std::set<GraphDbTypes::Label>> label_filters_; std::unordered_map< Symbol, std::map<GraphDbTypes::Property, std::vector<PropertyFilter>>> @@ -190,11 +202,20 @@ namespace impl { bool BindSymbol(std::unordered_set<Symbol> &bound_symbols, const Symbol &symbol); -LogicalOperator *GenFilters( - LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols, - std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> - &all_filters, - AstTreeStorage &storage); +// Looks for filter expressions, which can be inlined in an ExpandVariable +// operator. Such expressions are merged into one (via `and`) and removed from +// `all_filters`. If the expression uses `expands_to_node`, it is skipped. In +// such a case, we cannot cut variable expand short, since filtering may be +// satisfied by a node deeper in the path. 
+Expression *FindExpandVariableFilter( + const std::unordered_set<Symbol> &bound_symbols, + const Symbol &expands_to_node, + std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage); + +LogicalOperator *GenFilters(LogicalOperator *last_op, + const std::unordered_set<Symbol> &bound_symbols, + std::vector<Filters::FilterInfo> &all_filters, + AstTreeStorage &storage); LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op, SymbolTable &symbol_table, bool is_write, @@ -464,12 +485,15 @@ class RuleBasedPlanner { std::shared_ptr<LogicalOperator>(last_op), node1_symbol, existing_node, match_context.graph_view); } else if (expansion.edge->has_range_) { + auto *filter_expr = impl::FindExpandVariableFilter( + bound_symbols, node_symbol, all_filters, storage); last_op = new ExpandVariable( node_symbol, edge_symbol, expansion.direction, expansion.direction != expansion.edge->direction_, expansion.edge->lower_bound_, expansion.edge->upper_bound_, std::shared_ptr<LogicalOperator>(last_op), node1_symbol, - existing_node, existing_edge, match_context.graph_view); + existing_node, existing_edge, match_context.graph_view, + filter_expr); } else { last_op = new Expand(node_symbol, edge_symbol, expansion.direction, std::shared_ptr<LogicalOperator>(last_op), diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 8ec502e1c..aaa7c1af8 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -1252,7 +1252,7 @@ TEST(TestLogicalPlanner, MatchExpandVariableNoBounds) { CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableFiltered) { +TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) { // Test MATCH (n) -[r :type * {prop: 42}]-> (m) RETURN r Dbms dbms; auto dba = dbms.active(); @@ -1263,6 +1263,22 @@ TEST(TestLogicalPlanner, MatchExpandVariableFiltered) { edge->has_range_ = true; edge->properties_[prop] = LITERAL(42); 
QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")); + CheckPlan(storage, ExpectScanAll(), + ExpectExpandVariable(), // Filter is inlined in expand + ExpectProduce()); +} + +TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) { + // Test MATCH (n) -[r :type * {prop: m.prop = 42}]-> (m) RETURN r + Dbms dbms; + auto dba = dbms.active(); + auto type = dba->EdgeType("type"); + auto prop = PROPERTY_PAIR("prop"); + AstTreeStorage storage; + auto edge = EDGE("r", type); + edge->has_range_ = true; + edge->properties_[prop] = EQ(PROPERTY_LOOKUP("m", prop), LITERAL(42)); + QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r")); CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectFilter(), ExpectProduce()); }