From 591d086013e66fc1b753f554d3c960dbb6c14fe5 Mon Sep 17 00:00:00 2001
From: Teon Banek <teon.banek@memgraph.io>
Date: Tue, 22 Aug 2017 16:24:40 +0200
Subject: [PATCH] Map symbols to expansions to speed up variable planning

Summary:
Test variable planning BFS.
Add more tests for variable planning of ExpandVariable.
Don't recreate the whole matching when varying expansions.
Use explicit constructors in private planner classes.

Reviewers: mislav.bradac, florijan

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D691
---
 src/query/plan/planner.hpp                  |   8 +-
 src/query/plan/rule_based_planner.cpp       |  22 ++-
 src/query/plan/variable_start_planner.cpp   | 182 +++++++++++---------
 tests/benchmark/query/planner.cpp           |  50 ++++++
 tests/unit/query_variable_start_planner.cpp |  55 ++++++
 5 files changed, 232 insertions(+), 85 deletions(-)
 create mode 100644 tests/benchmark/query/planner.cpp

diff --git a/src/query/plan/planner.hpp b/src/query/plan/planner.hpp
index e5b8e9b3b..86cc1de3a 100644
--- a/src/query/plan/planner.hpp
+++ b/src/query/plan/planner.hpp
@@ -99,6 +99,10 @@ struct Matching {
   std::vector<std::unordered_set<Symbol>> edge_symbols;
   /// Information on used filter expressions while matching.
   Filters filters;
+  /// Maps node symbols to indices of `expansions` which bind them.
+  std::unordered_map<Symbol, std::set<int>> node_symbol_to_expansions{};
+  /// All node and edge symbols across all expansions (from all matches).
+  std::unordered_set<Symbol> expansion_symbols{};
 };
 
 /// @brief Represents a read (+ write) part of a query. Parts are split on
@@ -166,7 +170,7 @@ struct PlanningContext {
 /// @sa MakeLogicalPlan
 class RuleBasedPlanner {
  public:
-  RuleBasedPlanner(PlanningContext &context) : context_(context) {}
+  explicit RuleBasedPlanner(PlanningContext &context) : context_(context) {}
 
   /// @brief The result of plan generation is the root of the generated operator
   /// tree.
@@ -187,7 +191,7 @@ class RuleBasedPlanner {
 /// @sa MakeLogicalPlan
 class VariableStartPlanner {
  public:
-  VariableStartPlanner(PlanningContext &context) : context_(context) {}
+  explicit VariableStartPlanner(PlanningContext &context) : context_(context) {}
 
   /// @brief The result of plan generation is a vector of roots to multiple
   /// generated operator trees.
diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp
index 6095ca65f..c85b47811 100644
--- a/src/query/plan/rule_based_planner.cpp
+++ b/src/query/plan/rule_based_planner.cpp
@@ -128,7 +128,7 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
 // Collects symbols from identifiers found in visited AST nodes.
 class UsedSymbolsCollector : public HierarchicalTreeVisitor {
  public:
-  UsedSymbolsCollector(const SymbolTable &symbol_table)
+  explicit UsedSymbolsCollector(const SymbolTable &symbol_table)
       : symbol_table_(symbol_table) {}
 
   using HierarchicalTreeVisitor::PreVisit;
@@ -702,15 +702,29 @@ void AddMatching(const std::vector<Pattern *> &patterns, Where *where,
   auto expansions = NormalizePatterns(symbol_table, patterns);
   std::unordered_set<Symbol> edge_symbols;
   for (const auto &expansion : expansions) {
+    // Matching may already have some expansions, so offset our index.
+    const int expansion_ix = matching.expansions.size();
+    // Map node1 symbol to expansion
+    const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_);
+    matching.node_symbol_to_expansions[node1_sym].insert(expansion_ix);
+    // Add node1 to all symbols.
+    matching.expansion_symbols.insert(node1_sym);
     if (expansion.edge) {
-      edge_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
+      const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_);
+      // Fill edge symbols for Cyphermorphism.
+      edge_symbols.insert(edge_sym);
+      // Map node2 symbol to expansion
+      const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_);
+      matching.node_symbol_to_expansions[node2_sym].insert(expansion_ix);
+      // Add edge and node2 to all symbols
+      matching.expansion_symbols.insert(edge_sym);
+      matching.expansion_symbols.insert(node2_sym);
     }
+    matching.expansions.push_back(expansion);
   }
   if (!edge_symbols.empty()) {
     matching.edge_symbols.emplace_back(edge_symbols);
   }
-  matching.expansions.insert(matching.expansions.end(), expansions.begin(),
-                             expansions.end());
   for (auto *pattern : patterns) {
     matching.filters.CollectPatternFilters(*pattern, symbol_table, storage);
   }
diff --git a/src/query/plan/variable_start_planner.cpp b/src/query/plan/variable_start_planner.cpp
index 6535d04a4..fea4df2b2 100644
--- a/src/query/plan/variable_start_planner.cpp
+++ b/src/query/plan/variable_start_planner.cpp
@@ -1,6 +1,7 @@
 #include "query/plan/planner.hpp"
 
 #include <limits>
+#include <queue>
 
 #include "cppitertools/slice.hpp"
 #include "gflags/gflags.h"
@@ -17,7 +18,7 @@ namespace {
 
 class NodeSymbolHash {
  public:
-  NodeSymbolHash(const SymbolTable &symbol_table)
+  explicit NodeSymbolHash(const SymbolTable &symbol_table)
       : symbol_table_(symbol_table) {}
 
   size_t operator()(const NodeAtom *node_atom) const {
@@ -30,11 +31,11 @@ class NodeSymbolHash {
 
 class NodeSymbolEqual {
  public:
-  NodeSymbolEqual(const SymbolTable &symbol_table)
+  explicit NodeSymbolEqual(const SymbolTable &symbol_table)
       : symbol_table_(symbol_table) {}
 
-  size_t operator()(const NodeAtom *node_atom1,
-                    const NodeAtom *node_atom2) const {
+  bool operator()(const NodeAtom *node_atom1,
+                  const NodeAtom *node_atom2) const {
     return symbol_table_.at(*node_atom1->identifier_) ==
            symbol_table_.at(*node_atom2->identifier_);
   }
@@ -43,14 +44,20 @@ class NodeSymbolEqual {
   const SymbolTable &symbol_table_;
 };
 
-// Finds the next Expansion which has one of its nodes among the already
-// expanded symbols. The function may modify expansions, by flipping their nodes
-// and direction. This is done, so that the return iterator always points to the
-// expansion whose node1 is the already expanded one, while node2 may not be.
-auto NextExpansion(const SymbolTable &symbol_table,
-                   const std::unordered_set<Symbol> &expanded_symbols,
-                   const std::unordered_set<Symbol> &all_expansion_symbols,
-                   std::vector<Expansion> &expansions) {
+// Add applicable expansions for `node_symbol` to `next_expansions`. These
+// expansions are removed from `node_symbol_to_expansions`, while
+// `seen_expansions` and `expanded_symbols` are populated with new data.
+void AddNextExpansions(
+    const Symbol &node_symbol, const Matching &matching,
+    const SymbolTable &symbol_table,
+    std::unordered_set<Symbol> &expanded_symbols,
+    std::unordered_map<Symbol, std::set<int>> &node_symbol_to_expansions,
+    std::unordered_set<int> &seen_expansions,
+    std::queue<Expansion> &next_expansions) {
+  auto node_to_expansions_it = node_symbol_to_expansions.find(node_symbol);
+  if (node_to_expansions_it == node_symbol_to_expansions.end()) {
+    return;
+  }
   // Returns true if the expansion is a regular expand or if it is a variable
   // path expand, but with bound symbols used inside the range expression.
   auto can_expand = [&](auto &expansion) {
@@ -60,84 +67,103 @@ auto NextExpansion(const SymbolTable &symbol_table,
       // therefore bound. If the symbols are not found in the whole expansion,
       // then the semantic analysis should guarantee that the symbols have been
       // bound long before we expand.
-      if (all_expansion_symbols.find(range_symbol) !=
-              all_expansion_symbols.end() &&
+      if (matching.expansion_symbols.find(range_symbol) !=
+              matching.expansion_symbols.end() &&
           expanded_symbols.find(range_symbol) == expanded_symbols.end()) {
         return false;
       }
     }
     return true;
   };
-  auto expansion_it = expansions.begin();
-  for (; expansion_it != expansions.end(); ++expansion_it) {
-    if (!can_expand(*expansion_it)) {
+  auto &node_expansions = node_to_expansions_it->second;
+  auto node_expansions_it = node_expansions.begin();
+  while (node_expansions_it != node_to_expansions_it->second.end()) {
+    auto expansion_id = *node_expansions_it;
+    if (seen_expansions.find(expansion_id) != seen_expansions.end()) {
+      // Skip and erase seen (already expanded) expansions.
+      node_expansions_it = node_expansions.erase(node_expansions_it);
       continue;
     }
-    const auto &node1_symbol =
-        symbol_table.at(*expansion_it->node1->identifier_);
-    if (expanded_symbols.find(node1_symbol) != expanded_symbols.end()) {
-      return expansion_it;
+    auto expansion = matching.expansions[expansion_id];
+    if (!can_expand(expansion)) {
+      // Skip but save expansions which need other symbols for later.
+      ++node_expansions_it;
+      continue;
     }
-    // Try expanding from node2 by flipping the expansion.
-    auto *node2 = expansion_it->node2;
-    if (node2 &&
-        expanded_symbols.find(symbol_table.at(*node2->identifier_)) !=
-            expanded_symbols.end() &&
+    if (symbol_table.at(*expansion.node1->identifier_) != node_symbol) {
+      // We are not expanding from node1, so flip the expansion.
+      debug_assert(
+          expansion.node2 &&
+              symbol_table.at(*expansion.node2->identifier_) == node_symbol,
+          "Expected node_symbol to be bound in node2");
+      if (!dynamic_cast<BreadthFirstAtom *>(expansion.edge)) {
         // BFS must *not* be flipped. Doing that changes the BFS results.
-        !dynamic_cast<BreadthFirstAtom *>(expansion_it->edge)) {
-      std::swap(expansion_it->node2, expansion_it->node1);
-      if (expansion_it->direction != EdgeAtom::Direction::BOTH) {
-        expansion_it->direction =
-            expansion_it->direction == EdgeAtom::Direction::IN
-                ? EdgeAtom::Direction::OUT
-                : EdgeAtom::Direction::IN;
+        std::swap(expansion.node1, expansion.node2);
+        if (expansion.direction != EdgeAtom::Direction::BOTH) {
+          expansion.direction = expansion.direction == EdgeAtom::Direction::IN
+                                    ? EdgeAtom::Direction::OUT
+                                    : EdgeAtom::Direction::IN;
+        }
       }
-      return expansion_it;
     }
+    seen_expansions.insert(expansion_id);
+    expanded_symbols.insert(symbol_table.at(*expansion.node1->identifier_));
+    if (expansion.edge) {
+      expanded_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
+      expanded_symbols.insert(symbol_table.at(*expansion.node2->identifier_));
+    }
+    next_expansions.emplace(std::move(expansion));
+    node_expansions_it = node_expansions.erase(node_expansions_it);
+  }
+  if (node_expansions.empty()) {
+    node_symbol_to_expansions.erase(node_to_expansions_it);
   }
-  return expansion_it;
 }
 
 // Generates expansions emanating from the start_node by forming a chain. When
 // the chain can no longer be continued, a different starting node is picked
 // among remaining expansions and the process continues. This is done until all
-// original_expansions are used.
-std::vector<Expansion> ExpansionsFrom(
-    const NodeAtom *start_node, std::vector<Expansion> original_expansions,
-    const SymbolTable &symbol_table) {
-  std::vector<Expansion> expansions;
+// matching.expansions are used.
+std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node,
+                                      const Matching &matching,
+                                      const SymbolTable &symbol_table) {
+  // Make a copy of node_symbol_to_expansions, because we will modify it as
+  // expansions are chained.
+  auto node_symbol_to_expansions = matching.node_symbol_to_expansions;
+  std::unordered_set<int> seen_expansions;
+  std::queue<Expansion> next_expansions;
   std::unordered_set<Symbol> expanded_symbols(
       {symbol_table.at(*start_node->identifier_)});
-  std::unordered_set<Symbol> all_expansion_symbols;
-  for (const auto &expansion : original_expansions) {
-    all_expansion_symbols.insert(
-        symbol_table.at(*expansion.node1->identifier_));
-    if (expansion.edge) {
-      all_expansion_symbols.insert(
-          symbol_table.at(*expansion.edge->identifier_));
-      all_expansion_symbols.insert(
-          symbol_table.at(*expansion.node2->identifier_));
+  auto add_next_expansions = [&](const auto *node) {
+    AddNextExpansions(symbol_table.at(*node->identifier_), matching,
+                      symbol_table, expanded_symbols, node_symbol_to_expansions,
+                      seen_expansions, next_expansions);
+  };
+  add_next_expansions(start_node);
+  // Potential optimization: expansions and next_expansions could be merged into
+  // a single vector and an index could be used to determine from which
+  // expansion additional expansions should be added.
+  std::vector<Expansion> expansions;
+  while (!next_expansions.empty()) {
+    auto expansion = next_expansions.front();
+    next_expansions.pop();
+    expansions.emplace_back(expansion);
+    add_next_expansions(expansion.node1);
+    if (expansion.node2) {
+      add_next_expansions(expansion.node2);
     }
   }
-  while (!original_expansions.empty()) {
-    auto next_it = NextExpansion(symbol_table, expanded_symbols,
-                                 all_expansion_symbols, original_expansions);
-    if (next_it == original_expansions.end()) {
-      // We could pick a new starting expansion, but to avoid runtime
-      // complexity, simply append the remaining expansions and return them.
-      // They should have a correct order, since the original expansions were
-      // verified during semantic analysis.
-      expansions.insert(expansions.end(), original_expansions.begin(),
-                        original_expansions.end());
-      return expansions;
+  if (!node_symbol_to_expansions.empty()) {
+    // We could pick a new starting expansion, but to avoid runtime
+    // complexity, simply append the remaining expansions. They should have the
+    // correct order, since the original expansions were verified during
+    // semantic analysis.
+    for (int i = 0; i < matching.expansions.size(); ++i) {
+      if (seen_expansions.find(i) != seen_expansions.end()) {
+        continue;
+      }
+      expansions.emplace_back(matching.expansions[i]);
     }
-    expanded_symbols.insert(symbol_table.at(*next_it->node1->identifier_));
-    if (next_it->node2) {
-      expanded_symbols.insert(symbol_table.at(*next_it->edge->identifier_));
-      expanded_symbols.insert(symbol_table.at(*next_it->node2->identifier_));
-    }
-    expansions.emplace_back(*next_it);
-    original_expansions.erase(next_it);
   }
   return expansions;
 }
@@ -178,17 +204,17 @@ class VaryMatchingStart {
 
     iterator(VaryMatchingStart &self, bool is_done)
         : self_(self),
-          // Use the original matching as the first matching, for the case when
-          // there are no nodes.
+          // Use the original matching as the first matching. We are only
+          // interested in changing the expansions part, so the remaining fields
+          // should stay the same. This also produces a matching for the case
+          // when there are no nodes.
           current_matching_(self.matching_) {
       if (!self_.nodes_.empty()) {
-        // Overwrite the original matching with the new one by generating it
-        // from the first start node.
+        // Overwrite the original matching expansions with the new ones by
+        // generating them from the first start node.
         start_nodes_it_ = self_.nodes_.begin();
-        current_matching_ = Matching{
-            ExpansionsFrom(**start_nodes_it_, self_.matching_.expansions,
-                           self_.symbol_table_),
-            self_.matching_.edge_symbols, self_.matching_.filters};
+        current_matching_.expansions = ExpansionsFrom(
+            **start_nodes_it_, self_.matching_, self_.symbol_table_);
       }
       debug_assert(
           start_nodes_it_ || self_.nodes_.empty(),
@@ -215,10 +241,8 @@ class VaryMatchingStart {
         return *this;
       }
       const auto &start_node = **start_nodes_it_;
-      current_matching_ =
-          Matching{ExpansionsFrom(start_node, self_.matching_.expansions,
-                                  self_.symbol_table_),
-                   self_.matching_.edge_symbols, self_.matching_.filters};
+      current_matching_.expansions =
+          ExpansionsFrom(start_node, self_.matching_, self_.symbol_table_);
       return *this;
     }
 
diff --git a/tests/benchmark/query/planner.cpp b/tests/benchmark/query/planner.cpp
new file mode 100644
index 000000000..7b6658756
--- /dev/null
+++ b/tests/benchmark/query/planner.cpp
@@ -0,0 +1,50 @@
+#include <string>
+
+#include <benchmark/benchmark_api.h>
+
+#include "database/dbms.hpp"
+#include "query/frontend/semantic/symbol_generator.hpp"
+#include "query/plan/planner.hpp"
+
+// Add chained MATCH (node1) -- (node2), MATCH (node2) -- (node3) ... clauses.
+static void AddMatches(int num_matches, query::AstTreeStorage &storage) {
+  for (int i = 0; i < num_matches; ++i) {
+    auto *match = storage.Create<query::Match>();
+    auto *pattern = storage.Create<query::Pattern>();
+    pattern->identifier_ = storage.Create<query::Identifier>("path");
+    match->patterns_.emplace_back(pattern);
+    std::string node1_name = "node" + std::to_string(i - 1);
+    pattern->atoms_.emplace_back(storage.Create<query::NodeAtom>(
+        storage.Create<query::Identifier>(node1_name)));
+    pattern->atoms_.emplace_back(storage.Create<query::EdgeAtom>(
+        storage.Create<query::Identifier>("edge" + std::to_string(i)),
+        query::EdgeAtom::Direction::BOTH));
+    pattern->atoms_.emplace_back(storage.Create<query::NodeAtom>(
+        storage.Create<query::Identifier>("node" + std::to_string(i))));
+    storage.query()->clauses_.emplace_back(match);
+  }
+}
+
+static void BM_MakeLogicalPlan(benchmark::State &state) {
+  while (state.KeepRunning()) {
+    state.PauseTiming();
+    Dbms dbms;
+    auto dba = dbms.active();
+    query::AstTreeStorage storage;
+    int num_matches = state.range(0);
+    AddMatches(num_matches, storage);
+    query::SymbolTable symbol_table;
+    query::SymbolGenerator symbol_generator(symbol_table);
+    storage.query()->Accept(symbol_generator);
+    state.ResumeTiming();
+    query::plan::MakeLogicalPlan<query::plan::VariableStartPlanner>(
+        storage, symbol_table, *dba);
+  }
+};
+
+BENCHMARK(BM_MakeLogicalPlan)
+    ->RangeMultiplier(2)
+    ->Range(50, 400)
+    ->Unit(benchmark::kMillisecond);
+
+BENCHMARK_MAIN();
diff --git a/tests/unit/query_variable_start_planner.cpp b/tests/unit/query_variable_start_planner.cpp
index 4a7e0dbe0..db8d70dc2 100644
--- a/tests/unit/query_variable_start_planner.cpp
+++ b/tests/unit/query_variable_start_planner.cpp
@@ -235,4 +235,59 @@ TEST(TestVariableStartPlanner, MatchVariableExpand) {
   });
 }
 
+TEST(TestVariableStartPlanner, MatchVariableExpandReferenceNode) {
+  Dbms dbms;
+  auto dba = dbms.active();
+  auto id = dba->Property("id");
+  // Graph (v1 {id:1}) -[:r1]-> (v2 {id: 2}) -[:r2]-> (v3 {id: 3})
+  auto v1 = dba->InsertVertex();
+  v1.PropsSet(id, 1);
+  auto v2 = dba->InsertVertex();
+  v2.PropsSet(id, 2);
+  auto v3 = dba->InsertVertex();
+  v3.PropsSet(id, 3);
+  auto r1 = dba->InsertEdge(v1, v2, dba->EdgeType("r1"));
+  auto r2 = dba->InsertEdge(v2, v3, dba->EdgeType("r2"));
+  dba->AdvanceCommand();
+  // Test MATCH (n) -[r*..n.id]-> (m) RETURN r
+  AstTreeStorage storage;
+  auto edge = EDGE("r", Direction::OUT);
+  edge->has_range_ = true;
+  edge->upper_bound_ = PROPERTY_LOOKUP("n", id);
+  QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"));
+  // We expect to get a single column with the following rows:
+  TypedValue r1_list(std::vector<TypedValue>{r1});  // [r1] (v1 -[*..1]-> v2)
+  TypedValue r2_list(std::vector<TypedValue>{r2});  // [r2] (v2 -[*..2]-> v3)
+  CheckPlansProduce(2, storage, *dba, [&](const auto &results) {
+    AssertRows(results, {{r1_list}, {r2_list}});
+  });
+}
+
+TEST(TestVariableStartPlanner, MatchBfs) {
+  Dbms dbms;
+  auto dba = dbms.active();
+  auto id = dba->Property("id");
+  // Graph (v1 {id:1}) -[:r1]-> (v2 {id: 2}) -[:r2]-> (v3 {id: 3})
+  auto v1 = dba->InsertVertex();
+  v1.PropsSet(id, 1);
+  auto v2 = dba->InsertVertex();
+  v2.PropsSet(id, 2);
+  auto v3 = dba->InsertVertex();
+  v3.PropsSet(id, 3);
+  auto r1 = dba->InsertEdge(v1, v2, dba->EdgeType("r1"));
+  dba->InsertEdge(v2, v3, dba->EdgeType("r2"));
+  dba->AdvanceCommand();
+  // Test MATCH (n) -bfs[r](r, n|n.id <> 3, 10)-> (m) RETURN r
+  AstTreeStorage storage;
+  auto *bfs = storage.Create<query::BreadthFirstAtom>(
+      IDENT("r"), Direction::OUT, IDENT("r"), IDENT("n"),
+      NEQ(PROPERTY_LOOKUP("n", id), LITERAL(3)), LITERAL(10));
+  QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r"));
+  // We expect to get a single column with the following rows:
+  TypedValue r1_list(std::vector<TypedValue>{r1});  // [r1]
+  CheckPlansProduce(2, storage, *dba, [&](const auto &results) {
+    AssertRows(results, {{r1_list}});
+  });
+}
+
 }  // namespace