From 68e94e64e4b498135213b3a9676415c2df82ab10 Mon Sep 17 00:00:00 2001
From: Teon Banek <theongugl@gmail.com>
Date: Tue, 28 Mar 2017 11:04:28 +0200
Subject: [PATCH] Plan set operations

Summary:
Inline and remove GenProduce.
It doesn't do anything useful, so instantiating `Produce` can be done
inline.

Add dummy MakeCursor overrides for set operators so that the build passes.

Reviewers: florijan, mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D193
---
 src/query/frontend/ast/ast.hpp          | 13 ++++++++++++
 src/query/frontend/logical/operator.hpp | 13 ++++++++++++
 src/query/frontend/logical/planner.cpp  | 27 +++++++++++++++++-------
 tests/unit/query_common.hpp             | 28 +++++++++++++++++++++++++
 tests/unit/query_planner.cpp            | 16 ++++++++++++++
 5 files changed, 89 insertions(+), 8 deletions(-)

diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index 051a3c545..4ace5e24d 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -579,6 +579,10 @@ class SetProperty : public Clause {
 
  protected:
   SetProperty(int uid) : Clause(uid) {}
+  SetProperty(int uid, PropertyLookup *property_lookup, Expression *expression)
+      : Clause(uid),
+        property_lookup_(property_lookup),
+        expression_(expression) {}
 };
 
 class SetProperties : public Clause {
@@ -597,6 +601,12 @@ class SetProperties : public Clause {
 
  protected:
   SetProperties(int uid) : Clause(uid) {}
+  SetProperties(int uid, Identifier *identifier, Expression *expression,
+                bool update = false)
+      : Clause(uid),
+        identifier_(identifier),
+        expression_(expression),
+        update_(update) {}
 };
 
 class SetLabels : public Clause {
@@ -613,6 +623,9 @@ class SetLabels : public Clause {
 
  protected:
   SetLabels(int uid) : Clause(uid) {}
+  SetLabels(int uid, Identifier *identifier,
+            const std::vector<GraphDb::Label> &labels)
+      : Clause(uid), identifier_(identifier), labels_(labels) {}
 };
 
 // It would be better to call this AstTree, but we already have a class Tree,
diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp
index 078dd7ceb..ec9dfe35a 100644
--- a/src/query/frontend/logical/operator.hpp
+++ b/src/query/frontend/logical/operator.hpp
@@ -820,6 +820,11 @@ class SetProperty : public LogicalOperator {
               Expression *rhs)
       : input_(input), lhs_(lhs), rhs_(rhs) {}
 
+
+  std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override {
+    return nullptr;
+  }
+
   void Accept(LogicalOperatorVisitor &visitor) override {
     visitor.Visit(*this);
     input_->Accept(visitor);
@@ -847,6 +852,10 @@ class SetProperties : public LogicalOperator {
                 const Symbol input_symbol, Expression *rhs, Op op)
       : input_(input), input_symbol_(input_symbol), rhs_(rhs), op_(op) {}
 
+  std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override {
+    return nullptr;
+  }
+
   void Accept(LogicalOperatorVisitor &visitor) override {
     visitor.Visit(*this);
     input_->Accept(visitor);
@@ -867,6 +876,10 @@ class SetLabels : public LogicalOperator {
             const std::vector<GraphDb::Label> &labels)
       : input_(input), input_symbol_(input_symbol), labels_(labels) {}
 
+  std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override {
+    return nullptr;
+  }
+
   void Accept(LogicalOperatorVisitor &visitor) override {
     visitor.Visit(*this);
     input_->Accept(visitor);
diff --git a/src/query/frontend/logical/planner.cpp b/src/query/frontend/logical/planner.cpp
index a166fbe50..1fc7d24f8 100644
--- a/src/query/frontend/logical/planner.cpp
+++ b/src/query/frontend/logical/planner.cpp
@@ -71,7 +71,7 @@ auto GenCreateForPattern(Pattern &pattern, LogicalOperator *input_op,
   auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
                      EdgeAtom *edge, NodeAtom *node) {
     // Store the symbol from the first node as the input to CreateExpand.
-    auto input_symbol = symbol_table.at(*prev_node->identifier_);
+    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
     // If the expand node was already bound, then we need to indicate this,
     // so that CreateExpand only creates an edge.
     bool node_existing = false;
@@ -123,7 +123,7 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
   auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
                      EdgeAtom *edge, NodeAtom *node) {
     // Store the symbol from the first node as the input to Expand.
-    auto input_symbol = symbol_table.at(*prev_node->identifier_);
+    const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
     // If the expand symbols were already bound, then we need to indicate
     // this as a cycle. The Expand will then check whether the pattern holds
     // instead of writing the expansion to symbols.
@@ -161,11 +161,6 @@ auto GenMatch(Match &match, LogicalOperator *input_op,
   return last_op;
 }
 
-auto GenReturn(Return &ret, LogicalOperator *input_op) {
-  return new Produce(std::shared_ptr<LogicalOperator>(input_op),
-                     ret.named_expressions_);
-}
-
 }  // namespace
 
 std::unique_ptr<LogicalOperator> MakeLogicalPlan(
@@ -183,12 +178,28 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
     if (auto *match = dynamic_cast<Match *>(clause_ptr)) {
       input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
     } else if (auto *ret = dynamic_cast<Return *>(clause_ptr)) {
-      input_op = GenReturn(*ret, input_op);
+      input_op = new Produce(std::shared_ptr<LogicalOperator>(input_op),
+                             ret->named_expressions_);
     } else if (auto *create = dynamic_cast<Create *>(clause_ptr)) {
       input_op = GenCreate(*create, input_op, symbol_table, bound_symbols);
     } else if (auto *del = dynamic_cast<query::Delete *>(clause_ptr)) {
       input_op = new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
                                   del->expressions_, del->detach_);
+    } else if (auto *set = dynamic_cast<query::SetProperty *>(clause_ptr)) {
+      input_op =
+          new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
+                                set->property_lookup_, set->expression_);
+    } else if (auto *set = dynamic_cast<query::SetProperties *>(clause_ptr)) {
+      auto op = set->update_ ? plan::SetProperties::Op::UPDATE
+                             : plan::SetProperties::Op::REPLACE;
+      const auto &input_symbol = symbol_table.at(*set->identifier_);
+      input_op =
+          new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
+                                  input_symbol, set->expression_, op);
+    } else if (auto *set = dynamic_cast<query::SetLabels *>(clause_ptr)) {
+      const auto &input_symbol = symbol_table.at(*set->identifier_);
+      input_op = new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
+                                     input_symbol, set->labels_);
     } else {
       throw NotYetImplemented();
     }
diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp
index 77668d05f..21b5878d9 100644
--- a/tests/unit/query_common.hpp
+++ b/tests/unit/query_common.hpp
@@ -102,6 +102,33 @@ auto GetDelete(AstTreeStorage &storage, std::vector<Expression *> exprs,
   return del;
 }
 
+///
+/// Create a set property clause for given property lookup and the right hand
+/// side expression.
+///
+auto GetSet(AstTreeStorage &storage, PropertyLookup *prop_lookup,
+            Expression *expr) {
+  return storage.Create<SetProperty>(prop_lookup, expr);
+}
+
+///
+/// Create a set properties clause for given identifier name and the right hand
+/// side expression.
+///
+auto GetSet(AstTreeStorage &storage, const std::string &name, Expression *expr,
+            bool update = false) {
+  return storage.Create<SetProperties>(storage.Create<Identifier>(name), expr,
+                                       update);
+}
+
+///
+/// Create a set labels clause for given identifier name and labels.
+///
+auto GetSet(AstTreeStorage &storage, const std::string &name,
+            std::vector<GraphDb::Label> labels) {
+  return storage.Create<SetLabels>(storage.Create<Identifier>(name), labels);
+}
+
 }  // namespace test_common
 
 }  // namespace query
@@ -135,5 +162,6 @@ auto GetDelete(AstTreeStorage &storage, std::vector<Expression *> exprs,
 #define DELETE(...) query::test_common::GetDelete(storage, {__VA_ARGS__})
 #define DETACH_DELETE(...) \
   query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
+#define SET(...) query::test_common::GetSet(storage, __VA_ARGS__)
 #define QUERY(...) query::test_common::GetQuery(storage, {__VA_ARGS__})
 #define LESS(expr1, expr2) storage.Create<query::LessOperator>((expr1), (expr2))
diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp
index bceccd862..df4f264e5 100644
--- a/tests/unit/query_planner.cpp
+++ b/tests/unit/query_planner.cpp
@@ -35,6 +35,9 @@ class PlanChecker : public LogicalOperatorVisitor {
   void Visit(EdgeFilter &op) override { AssertType(op); }
   void Visit(Filter &op) override { AssertType(op); }
   void Visit(Produce &op) override { AssertType(op); }
+  void Visit(SetProperty &op) override { AssertType(op); }
+  void Visit(SetProperties &op) override { AssertType(op); }
+  void Visit(SetLabels &op) override { AssertType(op); }
 
  private:
   void AssertType(const LogicalOperator &op) {
@@ -153,4 +156,17 @@ TEST(TestLogicalPlanner, MatchDelete) {
   CheckPlan(*query, {typeid(ScanAll).hash_code(), typeid(Delete).hash_code()});
 }
 
+TEST(TestLogicalPlanner, MatchNodeSet) {
+  // Test MATCH (n) SET n.prop = 42, n = n, n :label
+  AstTreeStorage storage;
+  std::string prop("prop");
+  std::string label("label");
+  auto query = QUERY(MATCH(PATTERN(NODE("n"))),
+                     SET(PROPERTY_LOOKUP("n", &prop), LITERAL(42)),
+                     SET("n", IDENT("n")), SET("n", {&label}));
+  CheckPlan(*query,
+            {typeid(ScanAll).hash_code(), typeid(SetProperty).hash_code(),
+             typeid(SetProperties).hash_code(), typeid(SetLabels).hash_code()});
+}
+
 }