From 52709ad04c28286ac7e7e0ef4cce4082cc614bbd Mon Sep 17 00:00:00 2001
From: Teon Banek <teon.banek@memgraph.io>
Date: Wed, 30 Aug 2017 15:37:00 +0200
Subject: [PATCH] Inline filter inside ExpandVariable

Summary:
Reorder class definitions in ast.hpp.
Test inlining filters in ExpandVariable.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D726
---
 CMakeLists.txt                        |   7 +-
 src/query/frontend/ast/ast.hpp        | 142 ++++++-----
 src/query/plan/operator.cpp           |  92 ++++---
 src/query/plan/operator.hpp           |   4 +-
 src/query/plan/rule_based_planner.cpp | 355 ++++++++++++++------------
 src/query/plan/rule_based_planner.hpp |  40 ++-
 tests/unit/query_planner.cpp          |  18 +-
 7 files changed, 374 insertions(+), 284 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 93024fdae..db39d34fc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,6 +71,9 @@ endif()
 # TODO: set here 17 once it will be available in the cmake version (3.8)
 set(cxx_standard 14)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -Wall -Wno-c++1z-extensions")
+# Don't omit the frame pointer in RelWithDebInfo, to aid call chain debugging.
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO
+    "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -fno-omit-frame-pointer")
 # -----------------------------------------------------------------------------
 
 # dir variables
@@ -99,7 +102,7 @@ endif()
 
 # default build type is debug
 if ("${CMAKE_BUILD_TYPE}" STREQUAL "")
-    set(CMAKE_BUILD_TYPE "debug")
+    set(CMAKE_BUILD_TYPE "Debug")
 endif()
 message(STATUS "CMake build type: ${CMAKE_BUILD_TYPE}")
 # -----------------------------------------------------------------------------
@@ -356,7 +359,7 @@ string(STRIP ${COMMIT_HASH} COMMIT_HASH)
 set(MEMGRAPH_BUILD_NAME
     "memgraph_${COMMIT_NO}_${COMMIT_HASH}_${COMMIT_BRANCH}_${CMAKE_BUILD_TYPE}")
 add_custom_target(memgraph_link_target ALL
-	COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME})
+  COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME})
 # -----------------------------------------------------------------------------
 
 # memgraph main executable
diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index 84166259f..cb2395d1c 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -77,6 +77,8 @@ class Tree : public ::utils::Visitable<HierarchicalTreeVisitor>,
   const int uid_;
 };
 
+// Expressions
+
 class Expression : public Tree {
   friend class AstTreeStorage;
 
@@ -87,6 +89,29 @@ class Expression : public Tree {
   Expression(int uid) : Tree(uid) {}
 };
 
+class Where : public Tree {
+  friend class AstTreeStorage;
+
+ public:
+  DEFVISITABLE(TreeVisitor<TypedValue>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  Where *Clone(AstTreeStorage &storage) const override {
+    return storage.Create<Where>(expression_->Clone(storage));
+  }
+
+  Expression *expression_ = nullptr;
+
+ protected:
+  Where(int uid) : Tree(uid) {}
+  Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
+};
+
 class BinaryOperator : public Expression {
   friend class AstTreeStorage;
 
@@ -840,6 +865,42 @@ class Aggregation : public BinaryOperator {
   }
 };
 
+class All : public Expression {
+  friend class AstTreeStorage;
+
+ public:
+  DEFVISITABLE(TreeVisitor<TypedValue>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
+          where_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  All *Clone(AstTreeStorage &storage) const override {
+    return storage.Create<All>(identifier_->Clone(storage),
+                               list_expression_->Clone(storage),
+                               where_->Clone(storage));
+  }
+
+  Identifier *identifier_ = nullptr;
+  Expression *list_expression_ = nullptr;
+  Where *where_ = nullptr;
+
+ protected:
+  All(int uid, Identifier *identifier, Expression *list_expression,
+      Where *where)
+      : Expression(uid),
+        identifier_(identifier),
+        list_expression_(list_expression),
+        where_(where) {
+    debug_assert(identifier, "identifier must not be nullptr");
+    debug_assert(list_expression, "list_expression must not be nullptr");
+    debug_assert(where, "where must not be nullptr");
+  }
+};
+
 class NamedExpression : public Tree {
   friend class AstTreeStorage;
 
@@ -877,6 +938,8 @@ class NamedExpression : public Tree {
         token_position_(token_position) {}
 };
 
+// Pattern atoms
+
 class PatternAtom : public Tree {
   friend class AstTreeStorage;
 
@@ -1026,15 +1089,6 @@ class BreadthFirstAtom : public EdgeAtom {
         max_depth_(max_depth) {}
 };
 
-class Clause : public Tree {
-  friend class AstTreeStorage;
-
- public:
-  Clause(int uid) : Tree(uid) {}
-
-  Clause *Clone(AstTreeStorage &storage) const override = 0;
-};
-
 class Pattern : public Tree {
   friend class AstTreeStorage;
 
@@ -1065,6 +1119,17 @@ class Pattern : public Tree {
   Pattern(int uid) : Tree(uid) {}
 };
 
+// Clauses
+
+class Clause : public Tree {
+  friend class AstTreeStorage;
+
+ public:
+  Clause(int uid) : Tree(uid) {}
+
+  Clause *Clone(AstTreeStorage &storage) const override = 0;
+};
+
 class Query : public Tree {
   friend class AstTreeStorage;
 
@@ -1120,65 +1185,6 @@ class Create : public Clause {
   std::vector<Pattern *> patterns_;
 };
 
-class Where : public Tree {
-  friend class AstTreeStorage;
-
- public:
-  DEFVISITABLE(TreeVisitor<TypedValue>);
-  bool Accept(HierarchicalTreeVisitor &visitor) override {
-    if (visitor.PreVisit(*this)) {
-      expression_->Accept(visitor);
-    }
-    return visitor.PostVisit(*this);
-  }
-
-  Where *Clone(AstTreeStorage &storage) const override {
-    return storage.Create<Where>(expression_->Clone(storage));
-  }
-
-  Expression *expression_ = nullptr;
-
- protected:
-  Where(int uid) : Tree(uid) {}
-  Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
-};
-
-class All : public Expression {
-  friend class AstTreeStorage;
-
- public:
-  DEFVISITABLE(TreeVisitor<TypedValue>);
-  bool Accept(HierarchicalTreeVisitor &visitor) override {
-    if (visitor.PreVisit(*this)) {
-      identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
-          where_->Accept(visitor);
-    }
-    return visitor.PostVisit(*this);
-  }
-
-  All *Clone(AstTreeStorage &storage) const override {
-    return storage.Create<All>(identifier_->Clone(storage),
-                               list_expression_->Clone(storage),
-                               where_->Clone(storage));
-  }
-
-  Identifier *identifier_ = nullptr;
-  Expression *list_expression_ = nullptr;
-  Where *where_ = nullptr;
-
- protected:
-  All(int uid, Identifier *identifier, Expression *list_expression,
-      Where *where)
-      : Expression(uid),
-        identifier_(identifier),
-        list_expression_(list_expression),
-        where_(where) {
-    debug_assert(identifier, "identifier must not be nullptr");
-    debug_assert(list_expression, "list_expression must not be nullptr");
-    debug_assert(where, "where must not be nullptr");
-  }
-};
-
 class Match : public Clause {
   friend class AstTreeStorage;
 
diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index 75cfa15fc..519da233f 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -44,6 +44,18 @@ void ExpectType(Symbol symbol, TypedValue value, TypedValue::Type expected) {
                                 symbol.name(), value.type());
 }
 
+// Returns the boolean result of evaluating `filter`. Null is treated as false.
+// Any other non-boolean value raises a QueryRuntimeException.
+bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) {
+  TypedValue result = filter->Accept(evaluator);
+  // Null is treated like false.
+  if (result.IsNull()) return false;
+  if (result.type() != TypedValue::Type::Bool)
+    throw QueryRuntimeException(
+        "Filter expression must be a bool or null, but got {}.", result.type());
+  return result.Value<bool>();
+}
+
 }  // namespace
 
 bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) {
@@ -239,7 +251,7 @@ ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input,
       output_symbol_(output_symbol),
       graph_view_(graph_view) {
   permanent_assert(graph_view != GraphView::AS_IS,
-                   "ScanAll must have explicitly defined GraphView")
+                   "ScanAll must have explicitly defined GraphView");
 }
 
 ACCEPT_WITH_INPUT(ScanAll)
@@ -300,10 +312,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
     ExpressionEvaluator evaluator(frame, symbol_table, db, graph_view_);
     auto convert = [&evaluator](const auto &bound)
         -> std::experimental::optional<utils::Bound<PropertyValue>> {
-          if (!bound) return std::experimental::nullopt;
-          return std::experimental::make_optional(utils::Bound<PropertyValue>(
-              bound.value().value()->Accept(evaluator), bound.value().type()));
-        };
+      if (!bound) return std::experimental::nullopt;
+      return std::experimental::make_optional(utils::Bound<PropertyValue>(
+          bound.value().value()->Accept(evaluator), bound.value().type()));
+    };
     return db.Vertices(label_, property_, convert(lower_bound()),
                        convert(upper_bound()), graph_view_ == GraphView::NEW);
   };
@@ -531,12 +543,14 @@ ExpandVariable::ExpandVariable(Symbol node_symbol, Symbol edge_symbol,
                                Expression *lower_bound, Expression *upper_bound,
                                const std::shared_ptr<LogicalOperator> &input,
                                Symbol input_symbol, bool existing_node,
-                               bool existing_edge, GraphView graph_view)
+                               bool existing_edge, GraphView graph_view,
+                               Expression *filter)
     : ExpandCommon(node_symbol, edge_symbol, direction, input, input_symbol,
                    existing_node, existing_edge, graph_view),
       lower_bound_(lower_bound),
       upper_bound_(upper_bound),
-      is_reverse_(is_reverse) {}
+      is_reverse_(is_reverse),
+      filter_(filter) {}
 
 bool Expand::ExpandCursor::HandleExistingEdge(const EdgeAccessor &new_edge,
                                               Frame &frame) const {
@@ -612,8 +626,9 @@ class ExpandVariableCursor : public Cursor {
       : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {}
 
   bool Pull(Frame &frame, const SymbolTable &symbol_table) override {
+    ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_);
     while (true) {
-      if (Expand(frame)) return true;
+      if (Expand(frame, symbol_table)) return true;
 
       if (PullInput(frame, symbol_table)) {
         // if lower bound is zero we also yield empty paths
@@ -625,8 +640,11 @@ class ExpandVariableCursor : public Cursor {
           // take into account existing_edge when yielding empty paths
           if ((!self_.existing_edge_ || edges_on_frame.empty()) &&
               // Place the start vertex on the frame.
-              self_.HandleExistingNode(start_vertex, frame))
+              self_.HandleExistingNode(start_vertex, frame)) {
+            if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_))
+              continue;
             return true;
+          }
         }
         // if lower bound is not zero, we just continue, the next
         // loop iteration will attempt to expand and we're good
@@ -793,7 +811,8 @@ class ExpandVariableCursor : public Cursor {
    * case no more expansions are available from the current input
    * vertex and another Pull from the input cursor should be performed.
    */
-  bool Expand(Frame &frame) {
+  bool Expand(Frame &frame, const SymbolTable &symbol_table) {
+    ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_);
     // some expansions might not be valid due to
     // edge uniqueness, existing_edge, existing_node criterions,
     // so expand in a loop until either the input vertex is
@@ -851,6 +870,10 @@ class ExpandVariableCursor : public Cursor {
       auto edge_placement_result =
           HandleEdgePlacement(current_edge.first, edges_on_frame);
       if (edge_placement_result == EdgePlacementResult::MISMATCH) continue;
+      // Skip expansions rejected by the filter. It is assumed that the filter
+      // expression does not use the vertex that has yet to be put on the frame.
+      // Therefore, this check is done as soon as the edge is on the frame.
+      if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_)) continue;
 
       VertexAccessor current_vertex =
           current_edge.second == EdgeAtom::Direction::IN
@@ -1050,16 +1073,7 @@ bool Filter::FilterCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
   // and edges.
   ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::OLD);
   while (input_cursor_->Pull(frame, symbol_table)) {
-    TypedValue result = self_.expression_->Accept(evaluator);
-    // Null is treated like false.
-    if (result.IsNull()) continue;
-
-    if (result.type() != TypedValue::Type::Bool)
-      throw QueryRuntimeException(
-          "Filter expression must be a bool or null, but got {}.",
-          result.type());
-    if (!result.Value<bool>()) continue;
-    return true;
+    if (EvaluateFilter(evaluator, self_.expression_)) return true;
   }
   return false;
 }
@@ -1203,11 +1217,11 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame,
       // Skip setting properties on Null (can occur in optional match).
       break;
     case TypedValue::Type::Map:
-      // Semantically modifying a map makes sense, but it's not supported due to
-      // all the copying we do (when PropertyValue -> TypedValue and in
-      // ExpressionEvaluator). So even though we set a map property here, that
-      // is never visible to the user and it's not stored.
-      // TODO: fix above described bug
+    // Semantically modifying a map makes sense, but it's not supported due to
+    // all the copying we do (when PropertyValue -> TypedValue and in
+    // ExpressionEvaluator). So even though we set a map property here, that
+    // is never visible to the user and it's not stored.
+    // TODO: fix above described bug
     default:
       throw QueryRuntimeException(
           "Properties can only be set on Vertices and Edges");
@@ -1737,14 +1751,14 @@ void Aggregate::AggregateCursor::Update(
           *value_it = 1;
           break;
         case Aggregation::Op::COLLECT_LIST:
-            value_it->Value<std::vector<TypedValue>>().push_back(input_value);
-            break;
+          value_it->Value<std::vector<TypedValue>>().push_back(input_value);
+          break;
         case Aggregation::Op::COLLECT_MAP:
-            auto key = agg_elem_it->key->Accept(evaluator);
-            if (key.type() != TypedValue::Type::String)
-              throw QueryRuntimeException("Map key must be a string");
-            value_it->Value<std::map<std::string, TypedValue>>().emplace(
-                key.Value<std::string>(), input_value);
+          auto key = agg_elem_it->key->Accept(evaluator);
+          if (key.type() != TypedValue::Type::String)
+            throw QueryRuntimeException("Map key must be a string");
+          value_it->Value<std::map<std::string, TypedValue>>().emplace(
+              key.Value<std::string>(), input_value);
           break;
       }
       continue;
@@ -1789,14 +1803,14 @@ void Aggregate::AggregateCursor::Update(
         *value_it = *value_it + input_value;
         break;
       case Aggregation::Op::COLLECT_LIST:
-          value_it->Value<std::vector<TypedValue>>().push_back(input_value);
-          break;
+        value_it->Value<std::vector<TypedValue>>().push_back(input_value);
+        break;
       case Aggregation::Op::COLLECT_MAP:
-          auto key = agg_elem_it->key->Accept(evaluator);
-          if (key.type() != TypedValue::Type::String)
-            throw QueryRuntimeException("Map key must be a string");
-          value_it->Value<std::map<std::string, TypedValue>>().emplace(
-              key.Value<std::string>(), input_value);
+        auto key = agg_elem_it->key->Accept(evaluator);
+        if (key.type() != TypedValue::Type::String)
+          throw QueryRuntimeException("Map key must be a string");
+        value_it->Value<std::map<std::string, TypedValue>>().emplace(
+            key.Value<std::string>(), input_value);
         break;
     }  // end switch over Aggregation::Op enum
   }    // end loop over all aggregations
diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp
index ab957280b..433fe3a8c 100644
--- a/src/query/plan/operator.hpp
+++ b/src/query/plan/operator.hpp
@@ -633,7 +633,8 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon {
                  Expression *lower_bound, Expression *upper_bound,
                  const std::shared_ptr<LogicalOperator> &input,
                  Symbol input_symbol, bool existing_node, bool existing_edge,
-                 GraphView graph_view = GraphView::AS_IS);
+                 GraphView graph_view = GraphView::AS_IS,
+                 Expression *filter = nullptr);
 
   bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
   std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
@@ -646,6 +647,7 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon {
   // True if the path should be written as expanding from node_symbol to
   // input_symbol.
   bool is_reverse_;
+  Expression *filter_;
 };
 
 /**
diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp
index 33f6463c2..8a78453a7 100644
--- a/src/query/plan/rule_based_planner.cpp
+++ b/src/query/plan/rule_based_planner.cpp
@@ -114,10 +114,9 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
   const SymbolTable &symbol_table_;
 };
 
-bool HasBoundFilterSymbols(
-    const std::unordered_set<Symbol> &bound_symbols,
-    const std::pair<Expression *, std::unordered_set<Symbol>> &filter) {
-  for (const auto &symbol : filter.second) {
+bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
+                           const Filters::FilterInfo &filter) {
+  for (const auto &symbol : filter.used_symbols) {
     if (bound_symbols.find(symbol) == bound_symbols.end()) {
       return false;
     }
@@ -357,8 +356,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
     // Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
     // two expressions, so we can have 0, 1 or 2 elements on the
     // has_aggregation_stack for this Aggregation expression.
-    if (aggr.op_ == Aggregation::Op::COLLECT_MAP)
-      has_aggregation_.pop_back();
+    if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back();
     if (aggr.expression1_)
       has_aggregation_.back() = true;
     else
@@ -594,11 +592,178 @@ void AddMatching(const Match &match, SymbolTable &symbol_table,
                      matching);
 }
 
+// Iterates over `all_filters`, joining them into one expression via
+// `FilterAndOperator`. Filters which use unbound symbols are skipped, as well
+// as those that fail the `predicate` function. The `predicate` takes a single
+// argument, a `FilterInfo`. All the joined filters are removed from
+// `all_filters`. Returns nullptr if no filter was joined.
+template <class TPredicate>
+Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
+                           std::vector<Filters::FilterInfo> &all_filters,
+                           AstTreeStorage &storage,
+                           const TPredicate &predicate) {
+  Expression *filter_expr = nullptr;
+  for (auto filters_it = all_filters.begin();
+       filters_it != all_filters.end();) {
+    if (HasBoundFilterSymbols(bound_symbols, *filters_it) &&
+        predicate(*filters_it)) {
+      filter_expr = BoolJoin<FilterAndOperator>(storage, filter_expr,
+                                                filters_it->expression);
+      filters_it = all_filters.erase(filters_it);
+    } else {
+      filters_it++;
+    }
+  }
+  return filter_expr;
+}
+
 }  // namespace
 
+namespace impl {
+
+// Returns false if the symbol was already bound, otherwise binds it and
+// returns true.
+bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
+                const Symbol &symbol) {
+  auto insertion = bound_symbols.insert(symbol);
+  return insertion.second;
+}
+
+Expression *FindExpandVariableFilter(
+    const std::unordered_set<Symbol> &bound_symbols,
+    const Symbol &expands_to_node,
+    std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage) {
+  return ExtractFilters(bound_symbols, all_filters, storage,
+                        [&](const auto &filter) {
+                          return filter.is_for_expand_variable &&
+                                 filter.used_symbols.find(expands_to_node) ==
+                                     filter.used_symbols.end();
+                        });
+}
+
+LogicalOperator *GenFilters(LogicalOperator *last_op,
+                            const std::unordered_set<Symbol> &bound_symbols,
+                            std::vector<Filters::FilterInfo> &all_filters,
+                            AstTreeStorage &storage) {
+  auto *filter_expr = ExtractFilters(bound_symbols, all_filters, storage,
+                                     [](const auto &) { return true; });
+  if (filter_expr) {
+    last_op =
+        new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr);
+  }
+  return last_op;
+}
+
+LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
+                           SymbolTable &symbol_table, bool is_write,
+                           const std::unordered_set<Symbol> &bound_symbols,
+                           AstTreeStorage &storage) {
+  // Similar to the WITH clause, but here we only accumulate (the command is not
+  // advanced) when the query writes to the database. This way we handle the
+  // case when we want to return expressions with the latest updated results.
+  // For example, `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we
+  // match the same `n` 'k' times, we want to return 'k' results where the
+  // property value is the same: the final result of 'k' increments.
+  bool accumulate = is_write;
+  bool advance_command = false;
+  ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
+  return GenReturnBody(input_op, advance_command, body, accumulate);
+}
+
+LogicalOperator *GenCreateForPattern(
+    Pattern &pattern, LogicalOperator *input_op,
+    const SymbolTable &symbol_table,
+    std::unordered_set<Symbol> &bound_symbols) {
+  auto base = [&](NodeAtom *node) -> LogicalOperator * {
+    if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
+      return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
+    else
+      return input_op;
+  };
+
+  auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
+                     EdgeAtom *edge, NodeAtom *node) {
+    // Store the symbol from the first node as the input to CreateExpand.
+    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
+    // If the expand node was already bound, then we need to indicate this,
+    // so that CreateExpand only creates an edge.
+    bool node_existing = false;
+    if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
+      node_existing = true;
+    }
+    if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
+      permanent_fail("Symbols used for created edges cannot be redeclared.");
+    }
+    return new CreateExpand(node, edge,
+                            std::shared_ptr<LogicalOperator>(last_op),
+                            input_symbol, node_existing);
+  };
+
+  return ReducePattern<LogicalOperator *>(pattern, base, collect);
+}
+
+// Generates an operator for a clause which writes to the database. Returns
+// nullptr if the clause isn't handled.
+LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
+                                   const SymbolTable &symbol_table,
+                                   std::unordered_set<Symbol> &bound_symbols) {
+  if (auto *create = dynamic_cast<Create *>(clause)) {
+    return GenCreate(*create, input_op, symbol_table, bound_symbols);
+  } else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
+    return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
+                            del->expressions_, del->detach_);
+  } else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) {
+    return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
+                                 set->property_lookup_, set->expression_);
+  } else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) {
+    auto op = set->update_ ? plan::SetProperties::Op::UPDATE
+                           : plan::SetProperties::Op::REPLACE;
+    const auto &input_symbol = symbol_table.at(*set->identifier_);
+    return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
+                                   input_symbol, set->expression_, op);
+  } else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) {
+    const auto &input_symbol = symbol_table.at(*set->identifier_);
+    return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
+                               input_symbol, set->labels_);
+  } else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) {
+    return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
+                                    rem->property_lookup_);
+  } else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) {
+    const auto &input_symbol = symbol_table.at(*rem->identifier_);
+    return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
+                                  input_symbol, rem->labels_);
+  }
+  return nullptr;
+}
+
+LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
+                         SymbolTable &symbol_table, bool is_write,
+                         std::unordered_set<Symbol> &bound_symbols,
+                         AstTreeStorage &storage) {
+  // WITH clause is Accumulate/Aggregate (advance_command) + Produce and
+  // optional Filter. In case of update and aggregation, we want to accumulate
+  // first, so that when aggregating, we get the latest results. Similar to
+  // RETURN clause.
+  bool accumulate = is_write;
+  // No need to advance the command if we only performed reads.
+  bool advance_command = is_write;
+  ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
+                         with.where_);
+  LogicalOperator *last_op =
+      GenReturnBody(input_op, advance_command, body, accumulate);
+  // Reset bound symbols, so that only those in WITH are exposed.
+  bound_symbols.clear();
+  for (const auto &symbol : body.output_symbols()) {
+    BindSymbol(bound_symbols, symbol);
+  }
+  return last_op;
+}
+
+}  // namespace impl
+
 // Analyzes the filter expression by collecting information on filtering labels
-// and properties to be used with indexing. Note that all filters are never
-// updated here, but only labels and properties are.
+// and properties to be used with indexing. Note that `all_filters_` is never
+// updated here; only `label_filters_` and `property_filters_` are.
 void Filters::AnalyzeFilter(Expression *expr, const SymbolTable &symbol_table) {
   using Bound = ScanAllByLabelPropertyRange::Bound;
   auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup,
@@ -714,11 +879,11 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
       collector.symbols_.insert(symbol);  // PropertyLookup uses the symbol.
       if (is_variable_path) {
         all_filters_.emplace_back(
-            storage.Create<All>(identifier, atom->identifier_,
-                                storage.Create<Where>(prop_equal)),
-            collector.symbols_);
+            FilterInfo{storage.Create<All>(identifier, atom->identifier_,
+                                           storage.Create<Where>(prop_equal)),
+                       collector.symbols_, true});
       } else {
-        all_filters_.emplace_back(prop_equal, collector.symbols_);
+        all_filters_.emplace_back(FilterInfo{prop_equal, collector.symbols_});
       }
     }
   };
@@ -729,9 +894,9 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
       label_filters_[node_symbol].insert(node->labels_.begin(),
                                          node->labels_.end());
       // Create a LabelsTest and store it in all_filters_.
-      all_filters_.emplace_back(
+      all_filters_.emplace_back(FilterInfo{
           storage.Create<LabelsTest>(node->identifier_, node->labels_),
-          std::unordered_set<Symbol>{node_symbol});
+          std::unordered_set<Symbol>{node_symbol}});
     }
     add_properties_filter(node);
   };
@@ -740,19 +905,19 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
     if (!edge->edge_types_.empty()) {
       if (edge->has_range_) {
         // We need a new identifier and symbol for All.
-        auto *identifier = edge->identifier_->Clone(storage);
-        symbol_table[*identifier] =
-            symbol_table.CreateSymbol(identifier->name_, false);
+        auto *ident_in_all = edge->identifier_->Clone(storage);
+        symbol_table[*ident_in_all] =
+            symbol_table.CreateSymbol(ident_in_all->name_, false);
         auto *edge_type_test =
-            storage.Create<EdgeTypeTest>(identifier, edge->edge_types_);
-        all_filters_.emplace_back(
-            storage.Create<All>(identifier, edge->identifier_,
+            storage.Create<EdgeTypeTest>(ident_in_all, edge->edge_types_);
+        all_filters_.emplace_back(FilterInfo{
+            storage.Create<All>(ident_in_all, edge->identifier_,
                                 storage.Create<Where>(edge_type_test)),
-            std::unordered_set<Symbol>{edge_symbol});
+            std::unordered_set<Symbol>{edge_symbol}, true});
       } else {
-        all_filters_.emplace_back(
+        all_filters_.emplace_back(FilterInfo{
             storage.Create<EdgeTypeTest>(edge->identifier_, edge->edge_types_),
-            std::unordered_set<Symbol>{edge_symbol});
+            std::unordered_set<Symbol>{edge_symbol}});
       }
     }
     add_properties_filter(edge, edge->has_range_);
@@ -761,13 +926,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
   ForEachPattern(pattern, add_node_filter, add_expand_filter);
 }
 
-// Adds the where filter expression to all filters and collects additional
+// Adds the where filter expression to `all_filters_` and collects additional
 // information for potential property and label indexing.
 void Filters::CollectWhereFilter(Where &where,
                                  const SymbolTable &symbol_table) {
   UsedSymbolsCollector collector(symbol_table);
   where.expression_->Accept(collector);
-  all_filters_.emplace_back(where.expression_, collector.symbols_);
+  all_filters_.emplace_back(FilterInfo{where.expression_, collector.symbols_});
   AnalyzeFilter(where.expression_, symbol_table);
 }
 
@@ -809,144 +974,4 @@ std::vector<QueryPart> CollectQueryParts(SymbolTable &symbol_table,
   return query_parts;
 }
 
-namespace impl {
-
-// Returns false if the symbol was already bound, otherwise binds it and
-// returns true.
-bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
-                const Symbol &symbol) {
-  auto insertion = bound_symbols.insert(symbol);
-  return insertion.second;
-}
-
-LogicalOperator *GenFilters(
-    LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
-    std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
-        &all_filters,
-    AstTreeStorage &storage) {
-  Expression *filter_expr = nullptr;
-  for (auto filters_it = all_filters.begin();
-       filters_it != all_filters.end();) {
-    if (HasBoundFilterSymbols(bound_symbols, *filters_it)) {
-      filter_expr =
-          BoolJoin<FilterAndOperator>(storage, filter_expr, filters_it->first);
-      filters_it = all_filters.erase(filters_it);
-    } else {
-      filters_it++;
-    }
-  }
-  if (filter_expr) {
-    last_op =
-        new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr);
-  }
-  return last_op;
-}
-
-LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
-                           SymbolTable &symbol_table, bool is_write,
-                           const std::unordered_set<Symbol> &bound_symbols,
-                           AstTreeStorage &storage) {
-  // Similar to WITH clause, but we want to accumulate and advance command when
-  // the query writes to the database. This way we handle the case when we want
-  // to return expressions with the latest updated results. For example,
-  // `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same
-  // `n` multiple 'k' times, we want to return 'k' results where the property
-  // value is the same, final result of 'k' increments.
-  bool accumulate = is_write;
-  bool advance_command = false;
-  ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
-  return GenReturnBody(input_op, advance_command, body, accumulate);
-}
-
-LogicalOperator *GenCreateForPattern(
-    Pattern &pattern, LogicalOperator *input_op,
-    const SymbolTable &symbol_table,
-    std::unordered_set<Symbol> &bound_symbols) {
-  auto base = [&](NodeAtom *node) -> LogicalOperator * {
-    if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
-      return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
-    else
-      return input_op;
-  };
-
-  auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
-                     EdgeAtom *edge, NodeAtom *node) {
-    // Store the symbol from the first node as the input to CreateExpand.
-    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
-    // If the expand node was already bound, then we need to indicate this,
-    // so that CreateExpand only creates an edge.
-    bool node_existing = false;
-    if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
-      node_existing = true;
-    }
-    if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
-      permanent_fail("Symbols used for created edges cannot be redeclared.");
-    }
-    return new CreateExpand(node, edge,
-                            std::shared_ptr<LogicalOperator>(last_op),
-                            input_symbol, node_existing);
-  };
-
-  return ReducePattern<LogicalOperator *>(pattern, base, collect);
-}
-
-// Generate an operator for a clause which writes to the database. If the clause
-// isn't handled, returns nullptr.
-LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
-                                   const SymbolTable &symbol_table,
-                                   std::unordered_set<Symbol> &bound_symbols) {
-  if (auto *create = dynamic_cast<Create *>(clause)) {
-    return GenCreate(*create, input_op, symbol_table, bound_symbols);
-  } else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
-    return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
-                            del->expressions_, del->detach_);
-  } else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) {
-    return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
-                                 set->property_lookup_, set->expression_);
-  } else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) {
-    auto op = set->update_ ? plan::SetProperties::Op::UPDATE
-                           : plan::SetProperties::Op::REPLACE;
-    const auto &input_symbol = symbol_table.at(*set->identifier_);
-    return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
-                                   input_symbol, set->expression_, op);
-  } else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) {
-    const auto &input_symbol = symbol_table.at(*set->identifier_);
-    return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
-                               input_symbol, set->labels_);
-  } else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) {
-    return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
-                                    rem->property_lookup_);
-  } else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) {
-    const auto &input_symbol = symbol_table.at(*rem->identifier_);
-    return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
-                                  input_symbol, rem->labels_);
-  }
-  return nullptr;
-}
-
-LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
-                         SymbolTable &symbol_table, bool is_write,
-                         std::unordered_set<Symbol> &bound_symbols,
-                         AstTreeStorage &storage) {
-  // WITH clause is Accumulate/Aggregate (advance_command) + Produce and
-  // optional Filter. In case of update and aggregation, we want to accumulate
-  // first, so that when aggregating, we get the latest results. Similar to
-  // RETURN clause.
-  bool accumulate = is_write;
-  // No need to advance the command if we only performed reads.
-  bool advance_command = is_write;
-  ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
-                         with.where_);
-  LogicalOperator *last_op =
-      GenReturnBody(input_op, advance_command, body, accumulate);
-  // Reset bound symbols, so that only those in WITH are exposed.
-  bound_symbols.clear();
-  for (const auto &symbol : body.output_symbols()) {
-    BindSymbol(bound_symbols, symbol);
-  }
-  return last_op;
-}
-
-}  // namespace impl
-
 }  // namespace query::plan
diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp
index 795686ad4..672f028e6 100644
--- a/src/query/plan/rule_based_planner.hpp
+++ b/src/query/plan/rule_based_planner.hpp
@@ -38,7 +38,19 @@ class Filters {
     std::experimental::optional<Bound> upper_bound{};
   };
 
-  /// All filter expressions that should be generated.
+  /// Stores additional information for a filter expression.
+  struct FilterInfo {
+    /// The filter expression which must be satisfied.
+    Expression *expression;
+    /// Set of symbols used by the filter @c expression.
+    std::unordered_set<Symbol> used_symbols;
+    /// True if the filter is to be applied on multiple expanding edges.
+    /// This is used to inline filtering in an @c ExpandVariable operator.
+    bool is_for_expand_variable = false;
+  };
+
+  /// List of FilterInfo objects corresponding to all filter expressions that
+  /// should be generated.
   auto &all_filters() { return all_filters_; }
   const auto &all_filters() const { return all_filters_; }
   /// Mapping from a symbol to labels that are filtered on it. These should be
@@ -66,7 +78,7 @@ class Filters {
  private:
   void AnalyzeFilter(Expression *, const SymbolTable &);
 
-  std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> all_filters_;
+  std::vector<FilterInfo> all_filters_;
   std::unordered_map<Symbol, std::set<GraphDbTypes::Label>> label_filters_;
   std::unordered_map<
       Symbol, std::map<GraphDbTypes::Property, std::vector<PropertyFilter>>>
@@ -190,11 +202,20 @@ namespace impl {
 bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
                 const Symbol &symbol);
 
-LogicalOperator *GenFilters(
-    LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
-    std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
-        &all_filters,
-    AstTreeStorage &storage);
+// Looks for filter expressions that can be inlined in an ExpandVariable
+// operator. Such expressions are merged into one (via `and`) and removed from
+// `all_filters`. If the expression uses `expands_to_node`, it is skipped. In
+// such a case, we cannot cut variable expand short, since filtering may be
+// satisfied by a node deeper in the path.
+Expression *FindExpandVariableFilter(
+    const std::unordered_set<Symbol> &bound_symbols,
+    const Symbol &expands_to_node,
+    std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage);
+
+LogicalOperator *GenFilters(LogicalOperator *last_op,
+                            const std::unordered_set<Symbol> &bound_symbols,
+                            std::vector<Filters::FilterInfo> &all_filters,
+                            AstTreeStorage &storage);
 
 LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
                            SymbolTable &symbol_table, bool is_write,
@@ -464,12 +485,15 @@ class RuleBasedPlanner {
               std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
               existing_node, match_context.graph_view);
         } else if (expansion.edge->has_range_) {
+          auto *filter_expr = impl::FindExpandVariableFilter(
+              bound_symbols, node_symbol, all_filters, storage);
           last_op = new ExpandVariable(
               node_symbol, edge_symbol, expansion.direction,
               expansion.direction != expansion.edge->direction_,
               expansion.edge->lower_bound_, expansion.edge->upper_bound_,
               std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
-              existing_node, existing_edge, match_context.graph_view);
+              existing_node, existing_edge, match_context.graph_view,
+              filter_expr);
         } else {
           last_op = new Expand(node_symbol, edge_symbol, expansion.direction,
                                std::shared_ptr<LogicalOperator>(last_op),
diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp
index 8ec502e1c..aaa7c1af8 100644
--- a/tests/unit/query_planner.cpp
+++ b/tests/unit/query_planner.cpp
@@ -1252,7 +1252,7 @@ TEST(TestLogicalPlanner, MatchExpandVariableNoBounds) {
   CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce());
 }
 
-TEST(TestLogicalPlanner, MatchExpandVariableFiltered) {
+TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) {
   // Test MATCH (n) -[r :type * {prop: 42}]-> (m) RETURN r
   Dbms dbms;
   auto dba = dbms.active();
@@ -1263,6 +1263,22 @@ TEST(TestLogicalPlanner, MatchExpandVariableFiltered) {
   edge->has_range_ = true;
   edge->properties_[prop] = LITERAL(42);
   QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"));
+  CheckPlan(storage, ExpectScanAll(),
+            ExpectExpandVariable(),  // Filter is inlined in expand
+            ExpectProduce());
+}
+
+TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) {
+  // Test MATCH (n) -[r :type * {prop: m.prop = 42}]-> (m) RETURN r
+  Dbms dbms;
+  auto dba = dbms.active();
+  auto type = dba->EdgeType("type");
+  auto prop = PROPERTY_PAIR("prop");
+  AstTreeStorage storage;
+  auto edge = EDGE("r", type);
+  edge->has_range_ = true;
+  edge->properties_[prop] = EQ(PROPERTY_LOOKUP("m", prop), LITERAL(42));
+  QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"));
   CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectFilter(),
             ExpectProduce());
 }