From 84e7aec27b240ee294b80f461fcb7a79bcf16766 Mon Sep 17 00:00:00 2001
From: Matija Santl <matija.santl@memgraph.com>
Date: Thu, 15 Feb 2018 10:20:43 +0100
Subject: [PATCH] Implement Cartesian Cursor

Summary:
Pulls the left operator's cursor to exhaustion and caches every result.
Then, for each pull of the right operator's cursor, yields one row per
cached left result, which together form the Cartesian product.

Reviewers: teon.banek, florijan

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1201
---
 src/query/plan/operator.cpp                   |  96 +++++++++++--
 src/query/plan/operator.hpp                   |   2 +
 tests/unit/query_plan_match_filter_return.cpp | 133 ++++++++++++++++++
 3 files changed, 222 insertions(+), 9 deletions(-)

diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp
index 741c5460d..892006b2c 100644
--- a/src/query/plan/operator.cpp
+++ b/src/query/plan/operator.cpp
@@ -328,10 +328,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
                                   context.symbol_table_, db, graph_view_);
     auto convert = [&evaluator](const auto &bound)
         -> std::experimental::optional<utils::Bound<PropertyValue>> {
-      if (!bound) return std::experimental::nullopt;
-      return std::experimental::make_optional(utils::Bound<PropertyValue>(
-          bound.value().value()->Accept(evaluator), bound.value().type()));
-    };
+          if (!bound) return std::experimental::nullopt;
+          return std::experimental::make_optional(utils::Bound<PropertyValue>(
+              bound.value().value()->Accept(evaluator), bound.value().type()));
+        };
     return db.Vertices(label_, property_, convert(lower_bound()),
                        convert(upper_bound()), graph_view_ == GraphView::NEW);
   };
@@ -1058,9 +1058,8 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
                                   self_.graph_view_);
     // For the given (vertex, edge, vertex) tuple checks if they satisfy the
     // "where" condition. if so, places them in the priority queue.
-    auto expand_pair = [this, &evaluator, &frame](VertexAccessor from,
-                                                  EdgeAccessor edge,
-                                                  VertexAccessor vertex) {
+    auto expand_pair = [this, &evaluator, &frame](
+        VertexAccessor from, EdgeAccessor edge, VertexAccessor vertex) {
       SwitchAccessor(edge, self_.graph_view_);
       SwitchAccessor(vertex, self_.graph_view_);
 
@@ -3026,6 +3025,86 @@ class SynchronizeCursor : public Cursor {
     }
   }
 };
+
+/// Cursor producing the cartesian product of the left and right operators:
+/// the left results are cached, then paired with every right result.
+class CartesianCursor : public Cursor {
+ public:
+  CartesianCursor(const Cartesian &self, database::GraphDbAccessor &db)
+      : self_(self),
+        left_op_cursor_(self.left_op()->MakeCursor(db)),
+        right_op_cursor_(self_.right_op()->MakeCursor(db)) {
+    CHECK(left_op_cursor_ != nullptr)
+        << "CartesianCursor: Missing left operator cursor.";
+    CHECK(right_op_cursor_ != nullptr)
+        << "CartesianCursor: Missing right operator cursor.";
+  }
+
+  /// Puts the next (left, right) pair on `frame`; false when exhausted.
+  bool Pull(Frame &frame, Context &context) override {
+    // Copy, don't move: the values must stay usable on the live frame.
+    auto copy_frame = [&frame]() {
+      std::vector<TypedValue> result;
+      for (const auto &elem : frame.elems()) result.emplace_back(elem);
+      return result;
+    };
+
+    if (!cartesian_pull_initialized_) {
+      // Pull all left_op frames.
+      while (left_op_cursor_->Pull(frame, context)) {
+        left_op_frames_.emplace_back(copy_frame());
+      }
+      // We're setting the iterator to 'end' here so it pulls the right cursor.
+      left_op_frames_it_ = left_op_frames_.end();
+      cartesian_pull_initialized_ = true;
+    }
+
+    // If left operator yielded zero results there is no cartesian product.
+    if (left_op_frames_.empty()) return false;
+
+    // Copies just the values of `symbols` from a snapshot onto the frame.
+    auto restore_frame = [&frame](const std::vector<Symbol> &symbols,
+                                  const std::vector<TypedValue> &restore_from) {
+      for (const auto &symbol : symbols)
+        frame[symbol] = restore_from[symbol.position()];
+    };
+
+    if (left_op_frames_it_ == left_op_frames_.end()) {
+      // Every left frame was paired with the current right result (or this is
+      // the first pull): advance right_op_cursor_ and rewind the left side.
+      if (!right_op_cursor_->Pull(frame, context)) return false;
+      right_op_frame_ = copy_frame();
+      left_op_frames_it_ = left_op_frames_.begin();
+    } else {
+      // Make sure right_op_cursor last pulled results are on frame.
+      restore_frame(self_.right_symbols(), right_op_frame_);
+    }
+
+    restore_frame(self_.left_symbols(), *left_op_frames_it_);
+    ++left_op_frames_it_;
+    return true;
+  }
+
+  /// Resets both input cursors and drops all cached state.
+  void Reset() override {
+    left_op_cursor_->Reset();
+    right_op_cursor_->Reset();
+    right_op_frame_.clear();
+    left_op_frames_.clear();
+    left_op_frames_it_ = left_op_frames_.end();
+    cartesian_pull_initialized_ = false;
+  }
+
+ private:
+  const Cartesian &self_;
+  std::vector<std::vector<TypedValue>> left_op_frames_;
+  std::vector<TypedValue> right_op_frame_;
+  const std::unique_ptr<Cursor> left_op_cursor_;
+  const std::unique_ptr<Cursor> right_op_cursor_;
+  std::vector<std::vector<TypedValue>>::iterator left_op_frames_it_;
+  bool cartesian_pull_initialized_{false};
+};
+
 }  // namespace
 
 std::unique_ptr<Cursor> Synchronize::MakeCursor(
@@ -3042,8 +3121,7 @@ bool Cartesian::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
 
 std::unique_ptr<Cursor> Cartesian::MakeCursor(
     database::GraphDbAccessor &db) const {
-  // TODO: Implement cursor.
-  return nullptr;
+  return std::make_unique<CartesianCursor>(*this, db);
 }
 
 }  // namespace query::plan
diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp
index 06e659872..c890c2db8 100644
--- a/src/query/plan/operator.hpp
+++ b/src/query/plan/operator.hpp
@@ -2402,7 +2402,9 @@ class Cartesian : public LogicalOperator {
       database::GraphDbAccessor &db) const override;
 
   auto left_op() const { return left_op_; }
+  const auto &left_symbols() const { return left_symbols_; }  // no copy
   auto right_op() const { return right_op_; }
+  const auto &right_symbols() const { return right_symbols_; }  // no copy
 
  private:
   std::shared_ptr<LogicalOperator> left_op_;
diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp
index b11b2d308..eb5ecff2f 100644
--- a/tests/unit/query_plan_match_filter_return.cpp
+++ b/tests/unit/query_plan_match_filter_return.cpp
@@ -246,6 +246,139 @@ TEST(QueryPlan, NodeFilterMultipleLabels) {
   EXPECT_EQ(results.size(), 2);
 }
 
+// Cartesian of two ScanAlls over 3 vertices must yield all 9 (n, m) pairs.
+TEST(QueryPlan, Cartesian) {
+  database::SingleNode db;
+  database::GraphDbAccessor dba(db);
+  auto add_vertex = [&dba](std::string label) {
+    auto vertex = dba.InsertVertex();
+    vertex.add_label(dba.Label(label));
+    return vertex;
+  };
+
+  std::vector<VertexAccessor> vertices{add_vertex("v1"), add_vertex("v2"),
+                                       add_vertex("v3")};
+  dba.AdvanceCommand();
+
+  AstTreeStorage storage;
+  SymbolTable symbol_table;
+
+  auto n = MakeScanAll(storage, symbol_table, "n");
+  auto m = MakeScanAll(storage, symbol_table, "m");
+  auto return_n = NEXPR("n", IDENT("n"));
+  symbol_table[*return_n->expression_] = n.sym_;
+  symbol_table[*return_n] =
+      symbol_table.CreateSymbol("named_expression_1", true);
+  auto return_m = NEXPR("m", IDENT("m"));
+  symbol_table[*return_m->expression_] = m.sym_;
+  symbol_table[*return_m] =
+      symbol_table.CreateSymbol("named_expression_2", true);
+
+  std::vector<Symbol> left_symbols{n.sym_};
+  std::vector<Symbol> right_symbols{m.sym_};
+  auto cartesian_op =
+      std::make_shared<Cartesian>(n.op_, left_symbols, m.op_, right_symbols);
+  auto produce = MakeProduce(cartesian_op, return_n, return_m);
+  auto results = CollectProduce(produce.get(), symbol_table, dba);
+  // Fatal check: the loops below index `results` and must not run past it.
+  ASSERT_EQ(results.size(), 9);
+  // Column 0 (n) varies fastest; column 1 (m) changes every 3 rows.
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      EXPECT_EQ(results[3 * i + j][0].Value<VertexAccessor>(), vertices[j]);
+      EXPECT_EQ(results[3 * i + j][1].Value<VertexAccessor>(), vertices[i]);
+    }
+  }
+}
+
+// Cartesian over an empty database: both scans are empty, so is the product.
+TEST(QueryPlan, CartesianEmptySet) {
+  database::SingleNode db;
+  database::GraphDbAccessor dba(db);
+  AstTreeStorage storage;
+  SymbolTable symbol_table;
+
+  // Two independent ScanAll inputs, each returned under its own name.
+  auto scan_n = MakeScanAll(storage, symbol_table, "n");
+  auto scan_m = MakeScanAll(storage, symbol_table, "m");
+  auto out_n = NEXPR("n", IDENT("n"));
+  symbol_table[*out_n->expression_] = scan_n.sym_;
+  symbol_table[*out_n] =
+      symbol_table.CreateSymbol("named_expression_1", true);
+  auto out_m = NEXPR("m", IDENT("m"));
+  symbol_table[*out_m->expression_] = scan_m.sym_;
+  symbol_table[*out_m] =
+      symbol_table.CreateSymbol("named_expression_2", true);
+
+  std::vector<Symbol> n_symbols{scan_n.sym_};
+  std::vector<Symbol> m_symbols{scan_m.sym_};
+  auto product = std::make_shared<Cartesian>(scan_n.op_, n_symbols,
+                                             scan_m.op_, m_symbols);
+  auto produce = MakeProduce(product, out_n, out_m);
+
+  auto rows = CollectProduce(produce.get(), symbol_table, dba);
+  EXPECT_EQ(rows.size(), 0);
+}
+
+// Nested Cartesian: (n x m) x l over 3 vertices must yield all 27 triples.
+TEST(QueryPlan, CartesianThreeWay) {
+  database::SingleNode db;
+  database::GraphDbAccessor dba(db);
+  auto add_vertex = [&dba](std::string label) {
+    auto vertex = dba.InsertVertex();
+    vertex.add_label(dba.Label(label));
+    return vertex;
+  };
+
+  std::vector<VertexAccessor> vertices{add_vertex("v1"), add_vertex("v2"),
+                                       add_vertex("v3")};
+  dba.AdvanceCommand();
+
+  AstTreeStorage storage;
+  SymbolTable symbol_table;
+  auto n = MakeScanAll(storage, symbol_table, "n");
+  auto m = MakeScanAll(storage, symbol_table, "m");
+  auto l = MakeScanAll(storage, symbol_table, "l");
+  auto return_n = NEXPR("n", IDENT("n"));
+  symbol_table[*return_n->expression_] = n.sym_;
+  symbol_table[*return_n] =
+      symbol_table.CreateSymbol("named_expression_1", true);
+  auto return_m = NEXPR("m", IDENT("m"));
+  symbol_table[*return_m->expression_] = m.sym_;
+  symbol_table[*return_m] =
+      symbol_table.CreateSymbol("named_expression_2", true);
+  auto return_l = NEXPR("l", IDENT("l"));
+  symbol_table[*return_l->expression_] = l.sym_;
+  symbol_table[*return_l] =
+      symbol_table.CreateSymbol("named_expression_3", true);
+
+  std::vector<Symbol> n_symbols{n.sym_};
+  std::vector<Symbol> m_symbols{m.sym_};
+  std::vector<Symbol> n_m_symbols{n.sym_, m.sym_};
+  std::vector<Symbol> l_symbols{l.sym_};
+  auto cartesian_op_1 =
+      std::make_shared<Cartesian>(n.op_, n_symbols, m.op_, m_symbols);
+  // The (n x m) product becomes the left input of the product with l.
+  auto cartesian_op_2 = std::make_shared<Cartesian>(cartesian_op_1, n_m_symbols,
+                                                    l.op_, l_symbols);
+  auto produce = MakeProduce(cartesian_op_2, return_n, return_m, return_l);
+  auto results = CollectProduce(produce.get(), symbol_table, dba);
+  // Fatal check: the loops below index `results` and must not run past it.
+  ASSERT_EQ(results.size(), 27);
+  // Column 0 (n) varies fastest, then column 1 (m), then column 2 (l).
+  int id = 0;
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      for (int k = 0; k < 3; ++k) {
+        EXPECT_EQ(results[id][0].Value<VertexAccessor>(), vertices[k]);
+        EXPECT_EQ(results[id][1].Value<VertexAccessor>(), vertices[j]);
+        EXPECT_EQ(results[id][2].Value<VertexAccessor>(), vertices[i]);
+        ++id;
+      }
+    }
+  }
+}
+
 class ExpandFixture : public testing::Test {
  protected:
   database::SingleNode db_;