From 5e6aaf231dd6686a13cbd498d1c8903427ac7c8f Mon Sep 17 00:00:00 2001
From: Teon Banek <theongugl@gmail.com>
Date: Fri, 17 Mar 2017 09:57:20 +0100
Subject: [PATCH] Add tests for symbol generation in CREATE clause

Summary:
Add tests for symbol generation in CREATE clause and correctly (hopefully)
check symbols in create clause and properties.

Reviewers: florijan, mislav.bradac, buda

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D135
---
 src/query/exceptions.hpp                      |  12 +
 src/query/frontend/ast/ast.hpp                |  14 +-
 .../frontend/semantic/symbol_generator.hpp    |  84 ++++--
 tests/unit/query_semantic.cpp                 | 264 ++++++++++++++++--
 4 files changed, 331 insertions(+), 43 deletions(-)

diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp
index 479961bc3..d09fd08eb 100644
--- a/src/query/exceptions.hpp
+++ b/src/query/exceptions.hpp
@@ -26,6 +26,18 @@ class SemanticException : public BasicException {
   SemanticException() : BasicException("") {}
 };
 
+class UnboundVariableError : public SemanticException {
+ public:
+  UnboundVariableError(const std::string &name)
+      : SemanticException("Unbound variable: " + name) {}
+};
+
+class RedeclareVariableError : public SemanticException {
+ public:
+  RedeclareVariableError(const std::string& name)
+      : SemanticException("Redeclaring variable: " + name) {}
+};
+
 class CppCodeGeneratorException : public StacktraceException {
  public:
   using StacktraceException::StacktraceException;
diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index 58f4f8a90..48464d097 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -82,8 +82,12 @@ class NamedExpression : public Tree {
 
 class PatternAtom : public Tree {
   friend class AstTreeStorage;
+ public:
+  Identifier* identifier_ = nullptr;
  protected:
   PatternAtom(int uid) : Tree(uid) {}
+  PatternAtom(int uid, Identifier *identifier)
+      : Tree(uid), identifier_(identifier) {}
 };
 
 class NodeAtom : public PatternAtom {
@@ -95,14 +99,11 @@ class NodeAtom : public PatternAtom {
     visitor.PostVisit(*this);
   }
 
-  Identifier* identifier_ = nullptr;
   std::vector<GraphDb::Label> labels_;
   std::map<GraphDb::Property, Expression*> properties_;
 
  protected:
-  NodeAtom(int uid) : PatternAtom(uid) {}
-  NodeAtom(int uid, Identifier *identifier) :
-      PatternAtom(uid), identifier_(identifier) {}
+  using PatternAtom::PatternAtom;
 };
 
 class EdgeAtom : public PatternAtom {
@@ -117,11 +118,10 @@ class EdgeAtom : public PatternAtom {
   }
 
   Direction direction_ = Direction::BOTH;
-  Identifier* identifier_ = nullptr;
   std::vector<GraphDb::EdgeType> types_;
 
  protected:
-  EdgeAtom(int uid) : PatternAtom(uid) {}
+  using PatternAtom::PatternAtom;
 };
 
 class Clause : public Tree {
@@ -166,7 +166,7 @@ class Query : public Tree {
 class Create : public Clause {
  public:
   Create(int uid) : Clause(uid) {}
-  std::vector<std::shared_ptr<Pattern>> patterns_;
+  std::vector<Pattern*> patterns_;
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
     for (auto &pattern : patterns_) {
diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp
index fa6888c00..01c62097a 100644
--- a/src/query/frontend/semantic/symbol_generator.hpp
+++ b/src/query/frontend/semantic/symbol_generator.hpp
@@ -8,13 +8,20 @@ namespace query {
 
 class SymbolGenerator : public TreeVisitorBase {
  public:
-  SymbolGenerator(SymbolTable& symbol_table) : symbol_table_(symbol_table) {}
+  SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
 
   using TreeVisitorBase::Visit;
   using TreeVisitorBase::PostVisit;
 
   // Clauses
-  void PostVisit(Return& ret) override {
+  void Visit(Create &create) override {
+    scope_.in_create = true;
+  }
+  void PostVisit(Create &create) override {
+    scope_.in_create = false;
+  }
+
+  void PostVisit(Return &ret) override {
     for (auto &named_expr : ret.named_expressions_) {
       // Named expressions establish bindings for expressions which come after
       // return, but not for the expressions contained inside.
@@ -23,47 +30,92 @@ class SymbolGenerator : public TreeVisitorBase {
   }
 
   // Expressions
-  void Visit(Identifier& ident) override {
+  void Visit(Identifier &ident) override {
     Symbol symbol;
     if (scope_.in_pattern) {
+      // Patterns can bind new symbols or reference already bound. But there
+      // are the following special cases:
+      //  1) Expressions in property maps `{prop_name: expr}` can only reference
+      //     bound variables.
+      //  2) Patterns used to create nodes and edges cannot redeclare already
+      //     established bindings. Declaration only happens in single node
+      //     patterns and in edge patterns. OpenCypher example,
+      //     `MATCH (n) CREATE (n)` should throw an error that `n` is already
+      //     declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
+      //     since `n` now references the bound node instead of declaring it.
+      //     Additionally, we will support edge referencing in pattern:
+      //     `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r`, which would
+      //     usually raise redeclaration of `r`.
+      if (scope_.in_property_map && !HasSymbol(ident.name_)) {
+        // Case 1)
+        throw UnboundVariableError(ident.name_);
+      } else if ((scope_.in_create_node || scope_.in_create_edge) &&
+                 HasSymbol(ident.name_)) {
+        // Case 2)
+        throw RedeclareVariableError(ident.name_);
+      }
       symbol = GetOrCreateSymbol(ident.name_);
     } else {
+      // Everything else references a bound symbol.
       if (!HasSymbol(ident.name_))
-        // TODO: Special exception for type check
-        throw SemanticException("Unbound identifier: " + ident.name_);
+        throw UnboundVariableError(ident.name_);
       symbol = scope_.variables[ident.name_];
     }
     symbol_table_[ident] = symbol;
   }
+
   // Pattern and its subparts.
-  void Visit(Pattern& pattern) override {
+  void Visit(Pattern &pattern) override {
     scope_.in_pattern = true;
+    if (scope_.in_create && pattern.atoms_.size() == 1) {
+      debug_assert(dynamic_cast<NodeAtom*>(pattern.atoms_[0]),
+                   "Expected a single NodeAtom in Pattern");
+      scope_.in_create_node = true;
+    }
   }
-  void PostVisit(Pattern& pattern) override {
+  void PostVisit(Pattern &pattern) override {
     scope_.in_pattern = false;
+    scope_.in_create_node = false;
+  }
+
+  void Visit(NodeAtom &node_atom) override {
+    scope_.in_property_map = true;
+    for (auto kv : node_atom.properties_) {
+      kv.second->Accept(*this);
+    }
+    scope_.in_property_map = false;
+  }
+
+  void Visit(EdgeAtom &edge_atom) override {
+    if (scope_.in_create) {
+      scope_.in_create_edge = true;
+    }
+  }
+  void PostVisit(EdgeAtom &edge_atom) override {
+    scope_.in_create_edge = false;
   }
 
  private:
   struct Scope {
-    Scope() : in_pattern(false) {}
-    bool in_pattern;
+    bool in_pattern{false};
+    bool in_create{false};
+    bool in_create_node{false};
+    bool in_create_edge{false};
+    bool in_property_map{false};
     std::map<std::string, Symbol> variables;
   };
 
-  bool HasSymbol(const std::string& name)
-  {
+  bool HasSymbol(const std::string &name) {
     return scope_.variables.find(name) != scope_.variables.end();
   }
 
-  Symbol CreateSymbol(const std::string &name)
-  {
+  Symbol CreateSymbol(const std::string &name) {
     auto symbol = symbol_table_.CreateSymbol(name);
     scope_.variables[name] = symbol;
     return symbol;
   }
 
-  Symbol GetOrCreateSymbol(const std::string& name)
-  {
+  Symbol GetOrCreateSymbol(const std::string &name) {
     auto search = scope_.variables.find(name);
     if (search != scope_.variables.end()) {
       return search->second;
@@ -71,7 +123,7 @@ class SymbolGenerator : public TreeVisitorBase {
     return CreateSymbol(name);
   }
 
-  SymbolTable& symbol_table_;
+  SymbolTable &symbol_table_;
   Scope scope_;
 };
 
diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp
index 81cb1e318..1838fbc3c 100644
--- a/tests/unit/query_semantic.cpp
+++ b/tests/unit/query_semantic.cpp
@@ -11,15 +11,35 @@ using namespace query;
 
 namespace {
 
+// Returns a `(name1) -[name2]- (name3) ...` pattern.
+auto GetPattern(AstTreeStorage &storage, std::vector<std::string> names) {
+  bool is_node{true};
+  auto pattern = storage.Create<Pattern>();
+  for (auto &name : names) {
+    PatternAtom *atom;
+    auto identifier = storage.Create<Identifier>(name);
+    if (is_node) {
+      atom = storage.Create<NodeAtom>(identifier);
+    } else {
+      atom = storage.Create<EdgeAtom>(identifier);
+    }
+    pattern->atoms_.emplace_back(atom);
+    is_node = !is_node;
+  }
+  return pattern;
+}
+
+// Returns a `MATCH (node)` clause.
+auto GetMatchNode(AstTreeStorage &storage, const std::string &node_name) {
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {node_name}));
+  return match;
+}
+
 // Build a simple AST which describes:
 // MATCH (node_atom_1) RETURN node_atom_1 AS node_atom_1
 Query *MatchNodeReturn(AstTreeStorage &storage) {
-  auto node_atom = storage.Create<NodeAtom>();
-  node_atom->identifier_ = storage.Create<Identifier>("node_atom_1");
-  auto pattern = storage.Create<Pattern>();
-  pattern->atoms_.emplace_back(node_atom);
-  auto match = storage.Create<Match>();
-  match->patterns_.emplace_back(pattern);
+  auto match = GetMatchNode(storage, "node_atom_1");
   auto query = storage.query();
   query->clauses_.emplace_back(match);
 
@@ -36,12 +56,7 @@ Query *MatchNodeReturn(AstTreeStorage &storage) {
 // This is treated as an unbound variable.
 // MATCH (node_atom_1) RETURN node_atom_1 AS n, n AS n
 Query *MatchUnboundMultiReturn(AstTreeStorage &storage) {
-  auto node_atom = storage.Create<NodeAtom>();
-  node_atom->identifier_ = storage.Create<Identifier>("node_atom_1");
-  auto pattern = storage.Create<Pattern>();
-  pattern->atoms_.emplace_back(node_atom);
-  auto match = storage.Create<Match>();
-  match->patterns_.emplace_back(pattern);
+  auto match = GetMatchNode(storage, "node_atom_1");
   auto query = storage.query();
   query->clauses_.emplace_back(match);
 
@@ -60,12 +75,7 @@ Query *MatchUnboundMultiReturn(AstTreeStorage &storage) {
 
 // AST with unbound variable in return: MATCH (n) RETURN x AS x
 Query *MatchNodeUnboundReturn(AstTreeStorage &storage) {
-  auto node_atom = storage.Create<NodeAtom>();
-  node_atom->identifier_ = storage.Create<Identifier>("n");
-  auto pattern = storage.Create<Pattern>();
-  pattern->atoms_.emplace_back(node_atom);
-  auto match = storage.Create<Match>();
-  match->patterns_.emplace_back(pattern);
+  auto match = GetMatchNode(storage, "n");
   auto query = storage.query();
   query->clauses_.emplace_back(match);
 
@@ -78,6 +88,123 @@ Query *MatchNodeUnboundReturn(AstTreeStorage &storage) {
   return query;
 }
 
+// AST with match pattern referencing an edge multiple times:
+// MATCH (n) -[r]-> (n) -[r]-> (n) RETURN r AS r
+// This usually throws a redeclaration error, but we support it.
+Query *MatchSameEdge(AstTreeStorage &storage) {
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "n", "r", "n"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+
+  auto named_expr = storage.Create<NamedExpression>();
+  named_expr->name_ = "r";
+  named_expr->expression_ = storage.Create<Identifier>("r");
+  auto ret = storage.Create<Return>();
+  ret->named_expressions_.emplace_back(named_expr);
+  query->clauses_.emplace_back(ret);
+  return query;
+}
+
+std::string prop_name = "prop";
+
+// AST with unbound variable in create: CREATE ({prop: x})
+Query *CreatePropertyUnbound(AstTreeStorage &storage) {
+  auto prop_expr = storage.Create<Identifier>("x");
+  auto node_atom = storage.Create<NodeAtom>();
+  node_atom->identifier_ = storage.Create<Identifier>("anon");
+  node_atom->properties_[&prop_name] = prop_expr;
+  auto pattern = storage.Create<Pattern>();
+  pattern->atoms_.emplace_back(node_atom);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  return query;
+}
+
+// Simple AST returning a created node: CREATE (n) RETURN n
+Query *CreateNodeReturn(AstTreeStorage &storage) {
+  auto node_atom = storage.Create<NodeAtom>();
+  node_atom->identifier_ = storage.Create<Identifier>("n");
+  auto pattern = storage.Create<Pattern>();
+  pattern->atoms_.emplace_back(node_atom);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+
+  auto named_expr = storage.Create<NamedExpression>();
+  named_expr->name_ = "n";
+  named_expr->expression_ = storage.Create<Identifier>("n");
+  auto ret = storage.Create<Return>();
+  ret->named_expressions_.emplace_back(named_expr);
+  query->clauses_.emplace_back(ret);
+  return query;
+}
+
+// AST with redeclaring a variable when creating nodes: CREATE (n), (n)
+Query *CreateRedeclareNode(AstTreeStorage &storage) {
+  auto create = storage.Create<Create>();
+  for (int patterns = 0; patterns < 2; ++patterns) {
+    auto pattern = storage.Create<Pattern>();
+    auto node_atom = storage.Create<NodeAtom>();
+    node_atom->identifier_ = storage.Create<Identifier>("n");
+    pattern->atoms_.emplace_back(node_atom);
+    create->patterns_.emplace_back(pattern);
+  }
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  return query;
+}
+
+// AST with redeclaring a variable when creating nodes with multiple creates:
+// CREATE (n) CREATE (n)
+Query *MultiCreateRedeclareNode(AstTreeStorage &storage) {
+  auto query = storage.query();
+
+  for (int creates = 0; creates < 2; ++creates) {
+    auto pattern = storage.Create<Pattern>();
+    auto node_atom = storage.Create<NodeAtom>();
+    node_atom->identifier_ = storage.Create<Identifier>("n");
+    pattern->atoms_.emplace_back(node_atom);
+    auto create = storage.Create<Create>();
+    create->patterns_.emplace_back(pattern);
+    query->clauses_.emplace_back(create);
+  }
+  return query;
+}
+
+// AST with redeclaring a match node variable in create: MATCH (n) CREATE (n)
+Query *MatchCreateRedeclareNode(AstTreeStorage &storage) {
+  auto match = GetMatchNode(storage, "n");
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+
+  auto node_atom_2 = storage.Create<NodeAtom>();
+  node_atom_2->identifier_ = storage.Create<Identifier>("n");
+  auto pattern_2 = storage.Create<Pattern>();
+  pattern_2->atoms_.emplace_back(node_atom_2);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern_2);
+  query->clauses_.emplace_back(create);
+  return query;
+}
+
+// AST with redeclaring a match edge variable in create:
+// MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
+Query *MatchCreateRedeclareEdge(AstTreeStorage &storage) {
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(GetPattern(storage, {"n", "r", "l"}));
+  query->clauses_.emplace_back(create);
+  return query;
+}
+
 TEST(TestSymbolGenerator, MatchNodeReturn) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
@@ -104,7 +231,7 @@ TEST(TestSymbolGenerator, MatchUnboundMultiReturn) {
   AstTreeStorage storage;
   auto query_ast = MatchUnboundMultiReturn(storage);
   SymbolGenerator symbol_generator(symbol_table);
-  EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), UnboundVariableError);
 }
 
 TEST(TestSymbolGenerator, MatchNodeUnboundReturn) {
@@ -112,6 +239,103 @@ TEST(TestSymbolGenerator, MatchNodeUnboundReturn) {
   AstTreeStorage storage;
   auto query_ast = MatchNodeUnboundReturn(storage);
   SymbolGenerator symbol_generator(symbol_table);
-  EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), UnboundVariableError);
 }
+
+TEST(TestSymbolGenerator, MatchSameEdge) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = MatchSameEdge(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  query_ast->Accept(symbol_generator);
+  EXPECT_EQ(symbol_table.max_position(), 3);
+  auto match = dynamic_cast<Match*>(query_ast->clauses_[0]);
+  auto pattern = match->patterns_[0];
+  std::vector<Symbol> node_symbols;
+  std::vector<Symbol> edge_symbols;
+  bool is_node{true};
+  for (auto &atom : pattern->atoms_) {
+    auto symbol = symbol_table[*atom->identifier_];
+    if (is_node) {
+      node_symbols.emplace_back(symbol);
+    } else {
+      edge_symbols.emplace_back(symbol);
+    }
+    is_node = !is_node;
+  }
+  auto &node_symbol = node_symbols.front();
+  for (auto &symbol : node_symbols) {
+    EXPECT_EQ(node_symbol, symbol);
+  }
+  auto &edge_symbol = edge_symbols.front();
+  for (auto &symbol : edge_symbols) {
+    EXPECT_EQ(edge_symbol, symbol);
+  }
+  auto ret = dynamic_cast<Return*>(query_ast->clauses_[1]);
+  auto named_expr = ret->named_expressions_[0];
+  auto ret_symbol = symbol_table[*named_expr->expression_];
+  EXPECT_EQ(edge_symbol, ret_symbol);
+}
+
+TEST(TestSymbolGenerator, CreatePropertyUnbound) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = CreatePropertyUnbound(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), UnboundVariableError);
+}
+
+TEST(TestSymbolGenerator, CreateNodeReturn) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = CreateNodeReturn(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  query_ast->Accept(symbol_generator);
+  EXPECT_EQ(symbol_table.max_position(), 2);
+  auto create = dynamic_cast<Create*>(query_ast->clauses_[0]);
+  auto pattern = create->patterns_[0];
+  auto node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
+  auto node_sym = symbol_table[*node_atom->identifier_];
+  EXPECT_EQ(node_sym.name_, "n");
+  auto ret = dynamic_cast<Return*>(query_ast->clauses_[1]);
+  auto named_expr = ret->named_expressions_[0];
+  auto column_sym = symbol_table[*named_expr];
+  EXPECT_EQ(node_sym.name_, column_sym.name_);
+  EXPECT_NE(node_sym, column_sym);
+  auto ret_sym = symbol_table[*named_expr->expression_];
+  EXPECT_EQ(node_sym, ret_sym);
+}
+
+TEST(TestSymbolGenerator, CreateRedeclareNode) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = CreateRedeclareNode(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+}
+
+TEST(TestSymbolGenerator, MultiCreateRedeclareNode) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = MultiCreateRedeclareNode(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+}
+
+TEST(TestSymbolGenerator, MatchCreateRedeclareNode) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = MatchCreateRedeclareNode(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+}
+
+TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) {
+  SymbolTable symbol_table;
+  AstTreeStorage storage;
+  auto query_ast = MatchCreateRedeclareEdge(storage);
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+}
+
 }