From 8da6ce67c04fe76e7a7fd16a6cc2de013c112bef Mon Sep 17 00:00:00 2001
From: Teon Banek <theongugl@gmail.com>
Date: Wed, 22 Mar 2017 13:09:38 +0100
Subject: [PATCH] Add planning CreateExpand operator

Summary: Add planning of the CreateExpand operator. This is quite similar to planning Expand, but I wouldn't abstract the duplicated parts yet. Also, raise a semantic error when creating bidirectional edges.

Reviewers: buda, mislav.bradac, florijan

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D155
---
 src/query/frontend/logical/planner.cpp        | 190 ++++++++++--------
 .../frontend/semantic/symbol_generator.hpp    |   8 +-
 tests/unit/query_planner.cpp                  |  50 +++++
 tests/unit/query_semantic.cpp                 |  33 ++-
 4 files changed, 189 insertions(+), 92 deletions(-)

diff --git a/src/query/frontend/logical/planner.cpp b/src/query/frontend/logical/planner.cpp
index 1e811e330..9d9c77e60 100644
--- a/src/query/frontend/logical/planner.cpp
+++ b/src/query/frontend/logical/planner.cpp
@@ -9,41 +9,69 @@ namespace query {
 
 namespace {
 
-static LogicalOperator *GenCreate(
-    Create& create, std::shared_ptr<LogicalOperator> input_op)
-{
-  if (input_op) {
-    // TODO: Support clauses before CREATE, e.g. `MATCH (n) CREATE (m)`
-    throw NotYetImplemented();
-  }
-  if (create.patterns_.size() != 1) {
-    // TODO: Support creating multiple patterns, e.g. `CREATE (n), (m)`
-    throw NotYetImplemented();
-  }
-  auto &pattern = create.patterns_[0];
-  if (pattern->atoms_.size() != 1) {
-    // TODO: Support creating edges.
-    throw NotYetImplemented();
-  }
-  auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
-  debug_assert(node_atom, "First pattern atom is not a node");
-  return new CreateOp(node_atom);
-}
-
 // Returns false if the symbol was already bound, otherwise binds it and
 // returns true.
-bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol)
-{
+bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol) {
   auto insertion = bound_symbols.insert(symbol.position_);
   return insertion.second;
 }
 
-LogicalOperator *GenMatch(
-    Match& match,
-    std::shared_ptr<LogicalOperator> input_op,
-    const SymbolTable &symbol_table,
-    std::unordered_set<int> &bound_symbols)
-{
+LogicalOperator *GenCreateForPattern(Pattern &pattern,
+                                     LogicalOperator *input_op,
+                                     const SymbolTable &symbol_table,
+                                     std::unordered_set<int> &bound_symbols) {
+  auto atoms_it = pattern.atoms_.begin();
+  auto last_node = dynamic_cast<NodeAtom *>(*atoms_it++);
+  debug_assert(last_node, "First pattern atom is not a node");
+  auto last_op = input_op;
+  if (BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
+    // TODO: Pass last_op when CreateOp gets support for it. This will
+    // support e.g. `MATCH (n) CREATE (m)` and `CREATE (n), (m)`.
+    if (last_op) {
+      throw NotYetImplemented();
+    }
+    last_op = new CreateOp(last_node);
+  }
+  // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
+  while (atoms_it != pattern.atoms_.end()) {
+    auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++);
+    debug_assert(edge, "Expected an edge atom in pattern.");
+    debug_assert(atoms_it != pattern.atoms_.end(),
+                 "Edge atom should not end the pattern.");
+    // Store the symbol from the first node as the input to CreateExpand.
+    auto input_symbol = symbol_table.at(*last_node->identifier_);
+    last_node = dynamic_cast<NodeAtom *>(*atoms_it++);
+    debug_assert(last_node, "Expected a node atom in pattern.");
+    // If the expand node was already bound, then we need to indicate this,
+    // so that CreateExpand only creates an edge.
+    bool node_existing = false;
+    if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
+      node_existing = true;
+    }
+    if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
+      permanent_fail("Symbols used for created edges cannot be redeclared.");
+    }
+    last_op = new CreateExpand(last_node, edge,
+                               std::shared_ptr<LogicalOperator>(last_op),
+                               input_symbol, node_existing);
+  }
+  return last_op;
+}
+
+LogicalOperator *GenCreate(Create &create, LogicalOperator *input_op,
+                           const SymbolTable &symbol_table,
+                           std::unordered_set<int> &bound_symbols) {
+  auto last_op = input_op;
+  for (auto pattern : create.patterns_) {
+    last_op =
+        GenCreateForPattern(*pattern, last_op, symbol_table, bound_symbols);
+  }
+  return last_op;
+}
+
+LogicalOperator *GenMatch(Match &match, LogicalOperator *input_op,
+                          const SymbolTable &symbol_table,
+                          std::unordered_set<int> &bound_symbols) {
   if (input_op) {
     // TODO: Support clauses before match.
     throw NotYetImplemented();
@@ -55,7 +83,7 @@ LogicalOperator *GenMatch(
   auto &pattern = match.patterns_[0];
   debug_assert(!pattern->atoms_.empty(), "Missing atoms in pattern");
   auto atoms_it = pattern->atoms_.begin();
-  auto last_node = dynamic_cast<NodeAtom*>(*atoms_it++);
+  auto last_node = dynamic_cast<NodeAtom *>(*atoms_it++);
   debug_assert(last_node, "First pattern atom is not a node");
   // First atom always binds a symbol, and we don't care if it already existed,
   // because we create a ScanAll which writes that symbol. This may need to
@@ -63,67 +91,60 @@ LogicalOperator *GenMatch(
   BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_));
   LogicalOperator *last_op = new ScanAll(last_node);
   if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
-    last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
-                             symbol_table.at(*last_node->identifier_),
-                             last_node);
+    last_op =
+        new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
+                       symbol_table.at(*last_node->identifier_), last_node);
   }
-  EdgeAtom *last_edge = nullptr;
   // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
-  for ( ; atoms_it != pattern->atoms_.end(); ++atoms_it) {
-    if (last_edge) {
-      // Store the symbol from the first node as the input to Expand.
-      auto input_symbol = symbol_table.at(*last_node->identifier_);
-      last_node = dynamic_cast<NodeAtom*>(*atoms_it);
-      debug_assert(last_node, "Expected a node atom in pattern.");
-      // If the expand symbols were already bound, then we need to indicate
-      // this as a cycle. The Expand will then check whether the pattern holds
-      // instead of writing the expansion to symbols.
-      auto node_cycle = false;
-      auto edge_cycle = false;
-      if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
-        node_cycle = true;
-      }
-      if (!BindSymbol(bound_symbols, symbol_table.at(*last_edge->identifier_))) {
-        edge_cycle = true;
-      }
-      last_op = new Expand(last_node, last_edge,
-                           std::shared_ptr<LogicalOperator>(last_op),
-                           input_symbol, node_cycle, edge_cycle);
-      if (!last_edge->edge_types_.empty()) {
-        last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
-                                 symbol_table.at(*last_edge->identifier_),
-                                 last_edge);
-      }
-      if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
-        last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
-                                 symbol_table.at(*last_node->identifier_),
-                                 last_node);
-      }
-      // Don't forget to clear the edge, because we expect the next
-      // (EdgeAtom, NodeAtom) sequence.
-      last_edge = nullptr;
-    } else {
-      last_edge = dynamic_cast<EdgeAtom*>(*atoms_it);
-      debug_assert(last_edge, "Expected an edge atom in pattern.");
+  while (atoms_it != pattern->atoms_.end()) {
+    auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++);
+    debug_assert(edge, "Expected an edge atom in pattern.");
+    debug_assert(atoms_it != pattern->atoms_.end(),
+                 "Edge atom should not end the pattern.");
+    // Store the symbol from the first node as the input to Expand.
+    auto input_symbol = symbol_table.at(*last_node->identifier_);
+    last_node = dynamic_cast<NodeAtom *>(*atoms_it++);
+    debug_assert(last_node, "Expected a node atom in pattern.");
+    // If the expand symbols were already bound, then we need to indicate
+    // this as a cycle. The Expand will then check whether the pattern holds
+    // instead of writing the expansion to symbols.
+    auto node_cycle = false;
+    auto edge_cycle = false;
+    if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
+      node_cycle = true;
+    }
+    if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
+      edge_cycle = true;
+    }
+    last_op =
+        new Expand(last_node, edge, std::shared_ptr<LogicalOperator>(last_op),
+                   input_symbol, node_cycle, edge_cycle);
+    if (!edge->edge_types_.empty() || !edge->properties_.empty()) {
+      last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
+                               symbol_table.at(*edge->identifier_), edge);
+    }
+    if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
+      last_op =
+          new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
+                         symbol_table.at(*last_node->identifier_), last_node);
     }
   }
-  debug_assert(!last_edge, "Edge atom should not end the pattern.");
   return last_op;
 }
 
-Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op)
-{
+Produce *GenReturn(Return &ret, LogicalOperator *input_op) {
   if (!input_op) {
     // TODO: Support standalone RETURN clause (e.g. RETURN 2)
     throw NotYetImplemented();
   }
-  return new Produce(input_op, ret.named_expressions_);
-}
+  return new Produce(std::shared_ptr<LogicalOperator>(input_op),
+                     ret.named_expressions_);
 }
 
+}  // namespace
+
 std::unique_ptr<LogicalOperator> MakeLogicalPlan(
-    Query& query, const SymbolTable &symbol_table)
-{
+    Query &query, const SymbolTable &symbol_table) {
   // TODO: Extract functions and state into a class with methods. Possibly a
   // visitor or similar to avoid all those dynamic casts.
   LogicalOperator *input_op = nullptr;
@@ -134,13 +155,12 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
   std::unordered_set<int> bound_symbols;
   for (auto &clause : query.clauses_) {
     auto *clause_ptr = clause;
-    if (auto *match = dynamic_cast<Match*>(clause_ptr)) {
-      input_op = GenMatch(*match, std::shared_ptr<LogicalOperator>(input_op),
-                          symbol_table, bound_symbols);
-    } else if (auto *ret = dynamic_cast<Return*>(clause_ptr)) {
-      input_op = GenReturn(*ret, std::shared_ptr<LogicalOperator>(input_op));
-    } else if (auto *create = dynamic_cast<Create*>(clause_ptr)) {
-      input_op = GenCreate(*create, std::shared_ptr<LogicalOperator>(input_op));
+    if (auto *match = dynamic_cast<Match *>(clause_ptr)) {
+      input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
+    } else if (auto *ret = dynamic_cast<Return *>(clause_ptr)) {
+      input_op = GenReturn(*ret, input_op);
+    } else if (auto *create = dynamic_cast<Create *>(clause_ptr)) {
+      input_op = GenCreate(*create, input_op, symbol_table, bound_symbols);
     } else {
       throw NotYetImplemented();
     }
@@ -148,4 +168,4 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
   return std::unique_ptr<LogicalOperator>(input_op);
 }
 
-}
+}  // namespace query
diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp
index 9c44557f6..bf0558e34 100644
--- a/src/query/frontend/semantic/symbol_generator.hpp
+++ b/src/query/frontend/semantic/symbol_generator.hpp
@@ -101,6 +101,10 @@ class SymbolGenerator : public TreeVisitorBase {
         throw SemanticException("A single relationship type must be specified "
                                 "when creating an edge.");
       }
+      if (edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
+        throw SemanticException("Bidirectional relationship are not supported "
+                                "when creating an edge");
+      }
     }
   }
   void PostVisit(EdgeAtom &edge_atom) override {
@@ -112,9 +116,7 @@ class SymbolGenerator : public TreeVisitorBase {
   // A variable stores the associated symbol and its type.
   struct Variable {
     // This is similar to TypedValue::Type, but this has `Any` type.
-    enum class Type : unsigned {
-      Any, Vertex, Edge, Path
-    };
+    enum class Type { Any, Vertex, Edge, Path };
 
     Symbol symbol;
     Type type{Type::Any};
diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp
index 4c846c33e..9ecf56381 100644
--- a/tests/unit/query_planner.cpp
+++ b/tests/unit/query_planner.cpp
@@ -21,6 +21,7 @@ class PlanChecker : public LogicalOperatorVisitor {
   PlanChecker(std::list<size_t> types) : types_(types) {}
 
   void Visit(CreateOp &op) override { AssertType(op); }
+  void Visit(CreateExpand &op) override { AssertType(op); }
   void Visit(ScanAll &op) override { AssertType(op); }
   void Visit(Expand &op) override { AssertType(op); }
   void Visit(NodeFilter &op) override { AssertType(op); }
@@ -102,6 +103,55 @@ TEST(TestLogicalPlanner, CreateNodeReturn) {
   plan->Accept(plan_checker);
 }
 
+TEST(TestLogicalPlanner, CreateExpand) {
+  // Test CREATE (n) -[r :relationship]-> (m)
+  AstTreeStorage storage;
+  auto create = storage.Create<Create>();
+  auto pattern = GetPattern(storage, {"n", "r", "m"});
+  create->patterns_.emplace_back(pattern);
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::RIGHT;
+  std::string relationship("relationship");
+  edge_atom->edge_types_.emplace_back(&relationship);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  query->Accept(symbol_generator);
+  auto plan = MakeLogicalPlan(*query, symbol_table);
+  std::list<size_t> expected_types;
+  expected_types.emplace_back(typeid(CreateOp).hash_code());
+  expected_types.emplace_back(typeid(CreateExpand).hash_code());
+  PlanChecker plan_checker(expected_types);
+  plan->Accept(plan_checker);
+}
+
+TEST(TestLogicalPlanner, MatchCreateExpand) {
+  // Test MATCH (n) CREATE (n) -[r :relationship]-> (m)
+  AstTreeStorage storage;
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+  auto create = storage.Create<Create>();
+  auto pattern = GetPattern(storage, {"n", "r", "m"});
+  create->patterns_.emplace_back(pattern);
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::RIGHT;
+  std::string relationship("relationship");
+  edge_atom->edge_types_.emplace_back(&relationship);
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  query->Accept(symbol_generator);
+  auto plan = MakeLogicalPlan(*query, symbol_table);
+  std::list<size_t> expected_types;
+  expected_types.emplace_back(typeid(ScanAll).hash_code());
+  expected_types.emplace_back(typeid(CreateExpand).hash_code());
+  PlanChecker plan_checker(expected_types);
+  plan->Accept(plan_checker);
+}
+
 TEST(TestLogicalPlanner, MatchLabeledNodes) {
   // Test MATCH (n :label) RETURN n AS n
   AstTreeStorage storage;
diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp
index f6bbbc036..467f7cb26 100644
--- a/tests/unit/query_semantic.cpp
+++ b/tests/unit/query_semantic.cpp
@@ -320,7 +320,7 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
   // AST with redeclaring a match edge variable in create:
-  // MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
+  // MATCH (n) -[r]- (m) CREATE (n) -[r] -> (l)
   auto match = storage.Create<Match>();
   match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
   auto query = storage.query();
@@ -329,6 +329,7 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) {
   auto create = storage.Create<Create>();
   auto pattern = GetPattern(storage, {"n", "r", "l"});
   auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::RIGHT;
   std::string relationship("relationship");
   edge_atom->edge_types_.emplace_back(&relationship);
   create->patterns_.emplace_back(pattern);
@@ -353,13 +354,16 @@ TEST(TestSymbolGenerator, MatchTypeMismatch) {
 TEST(TestSymbolGenerator, MatchCreateTypeMismatch) {
   AstTreeStorage storage;
   // Using an edge variable as a node causes a type mismatch.
-  // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]- (n2)
+  // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]-> (n2)
   auto match = storage.Create<Match>();
   match->patterns_.emplace_back(GetPattern(storage, {"n1", "r1", "n2"}));
   auto query = storage.query();
   query->clauses_.emplace_back(match);
   auto create = storage.Create<Create>();
-  create->patterns_.emplace_back(GetPattern(storage, {"r1", "r2", "n2"}));
+  auto pattern = GetPattern(storage, {"r1", "r2", "n2"});
+  create->patterns_.emplace_back(pattern);
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::RIGHT;
   query->clauses_.emplace_back(create);
   SymbolTable symbol_table;
   SymbolGenerator symbol_generator(symbol_table);
@@ -369,9 +373,30 @@ TEST(TestSymbolGenerator, MatchCreateTypeMismatch) {
 TEST(TestSymbolGenerator, CreateMultipleEdgeType) {
   AstTreeStorage storage;
   // Multiple edge relationship are not allowed when creating edges.
-  // CREATE (n) -[r :rel1 | :rel2]- (m)
+  // CREATE (n) -[r :rel1 | :rel2]-> (m)
   auto pattern = GetPattern(storage, {"n", "r", "m"});
   auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::RIGHT;
+  std::string rel1("rel1");
+  edge_atom->edge_types_.emplace_back(&rel1);
+  std::string rel2("rel2");
+  edge_atom->edge_types_.emplace_back(&rel2);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
+}
+
+TEST(TestSymbolGenerator, CreateBidirectionalEdge) {
+  AstTreeStorage storage;
+  // Bidirectional relationships are not allowed when creating edges.
+  // CREATE (n) -[r :rel1]- (m)
+  auto pattern = GetPattern(storage, {"n", "r", "m"});
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  edge_atom->direction_ = EdgeAtom::Direction::BOTH;
   std::string rel1("rel1");
   edge_atom->edge_types_.emplace_back(&rel1);
   std::string rel2("rel2");