From b33a65413783af2c96e16fd5090227198a8bff75 Mon Sep 17 00:00:00 2001
From: Teon Banek <theongugl@gmail.com>
Date: Wed, 22 Mar 2017 09:42:43 +0100
Subject: [PATCH] Add basic type checking to SymbolGenerator

Summary:
Test for simple type mismatch of node/edge types.
Add basic type checking of variables.
Check for edge type when creating an edge.
Add documentation to private structs and methods.

Reviewers: mislav.bradac, buda, florijan

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D148
---
 src/query/exceptions.hpp                      | 14 +++-
 .../frontend/semantic/symbol_generator.hpp    | 76 ++++++++++++++----
 tests/unit/query_semantic.cpp                 | 79 +++++++++++++++----
 3 files changed, 137 insertions(+), 32 deletions(-)

diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp
index d09fd08eb..872769ea2 100644
--- a/src/query/exceptions.hpp
+++ b/src/query/exceptions.hpp
@@ -3,6 +3,8 @@
 #include "utils/exceptions/basic_exception.hpp"
 #include "utils/exceptions/stacktrace_exception.hpp"
 
+#include <fmt/format.h>
+
 namespace query {
 
 class SyntaxException : public BasicException {
@@ -34,10 +36,19 @@ class UnboundVariableError : public SemanticException {
 
 class RedeclareVariableError : public SemanticException {
  public:
-  RedeclareVariableError(const std::string& name)
+  RedeclareVariableError(const std::string &name)
       : SemanticException("Redeclaring variable: " + name) {}
 };
 
+class TypeMismatchError : public SemanticException {
+ public:
+  TypeMismatchError(const std::string &name, const std::string &datum,
+                    const std::string &expected)
+      : SemanticException(fmt::format(
+            "Type mismatch: '{}' already defined as '{}', but expected '{}'.",
+            name, datum, expected)) {}
+};
+
 class CppCodeGeneratorException : public StacktraceException {
  public:
   using StacktraceException::StacktraceException;
@@ -62,5 +73,4 @@ class QueryEngineException : public StacktraceException {
  public:
   using StacktraceException::StacktraceException;
 };
-
 }
diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp
index 01c62097a..9c44557f6 100644
--- a/src/query/frontend/semantic/symbol_generator.hpp
+++ b/src/query/frontend/semantic/symbol_generator.hpp
@@ -25,7 +25,7 @@ class SymbolGenerator : public TreeVisitorBase {
     for (auto &named_expr : ret.named_expressions_) {
       // Named expressions establish bindings for expressions which come after
       // return, but not for the expressions contained inside.
-      symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
+      symbol_table_[*named_expr] = CreateVariable(named_expr->name_).symbol;
     }
   }
 
@@ -46,20 +46,24 @@ class SymbolGenerator : public TreeVisitorBase {
       //     Additionally, we will support edge referencing in pattern:
       //     `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r`, which would
       //     usually raise redeclaration of `r`.
-      if (scope_.in_property_map && !HasSymbol(ident.name_)) {
+      if (scope_.in_property_map && !HasVariable(ident.name_)) {
         // Case 1)
         throw UnboundVariableError(ident.name_);
       } else if ((scope_.in_create_node || scope_.in_create_edge) &&
-                 HasSymbol(ident.name_)) {
+                 HasVariable(ident.name_)) {
         // Case 2)
         throw RedeclareVariableError(ident.name_);
       }
-      symbol = GetOrCreateSymbol(ident.name_);
+      auto type = Variable::Type::Vertex;
+      if (scope_.in_edge_atom) {
+        type = Variable::Type::Edge;
+      }
+      symbol = GetOrCreateVariable(ident.name_, type).symbol;
     } else {
       // Everything else references a bound symbol.
-      if (!HasSymbol(ident.name_))
+      if (!HasVariable(ident.name_))
         throw UnboundVariableError(ident.name_);
-      symbol = scope_.variables[ident.name_];
+      symbol = scope_.variables[ident.name_].symbol;
     }
     symbol_table_[ident] = symbol;
   }
@@ -75,52 +79,96 @@ class SymbolGenerator : public TreeVisitorBase {
   }
   void PostVisit(Pattern &pattern) override {
     scope_.in_pattern = false;
-    scope_.in_create_node = false;
   }
 
   void Visit(NodeAtom &node_atom) override {
+    scope_.in_node_atom = true;
     scope_.in_property_map = true;
     for (auto kv : node_atom.properties_) {
       kv.second->Accept(*this);
     }
     scope_.in_property_map = false;
   }
+  void PostVisit(NodeAtom &node_atom) override {
+    scope_.in_node_atom = false;
+  }
 
   void Visit(EdgeAtom &edge_atom) override {
+    scope_.in_edge_atom = true;
     if (scope_.in_create) {
       scope_.in_create_edge = true;
+      if (edge_atom.edge_types_.size() != 1) {
+        throw SemanticException("A single relationship type must be specified "
+                                "when creating an edge.");
+      }
     }
   }
   void PostVisit(EdgeAtom &edge_atom) override {
+    scope_.in_edge_atom = false;
     scope_.in_create_edge = false;
   }
 
  private:
+  // A variable stores the associated symbol and its type.
+  struct Variable {
+    // Similar to TypedValue::Type, but with an additional `Any` type.
+    enum class Type : unsigned {
+      Any, Vertex, Edge, Path
+    };
+
+    Symbol symbol;
+    Type type{Type::Any};
+  };
+
+  std::string TypeToString(Variable::Type type) {
+    const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"};
+    return enum_string[static_cast<int>(type)];
+  }
+
+  // Scope stores the state of where we are when visiting the AST and a map of
+  // names to variables.
   struct Scope {
     bool in_pattern{false};
     bool in_create{false};
+    // in_create_node is true if we are creating *only* a node. Therefore, it
+    // is *not* equivalent to in_create && in_node_atom.
     bool in_create_node{false};
+    // True if creating an edge; shortcut for in_create && in_edge_atom.
     bool in_create_edge{false};
+    bool in_node_atom{false};
+    bool in_edge_atom{false};
     bool in_property_map{false};
-    std::map<std::string, Symbol> variables;
+    std::map<std::string, Variable> variables;
   };
 
-  bool HasSymbol(const std::string &name) {
+  bool HasVariable(const std::string &name) {
     return scope_.variables.find(name) != scope_.variables.end();
   }
 
-  Symbol CreateSymbol(const std::string &name) {
+  // Returns a new variable with a freshly generated symbol. Previous mapping
+  // of the same name to a different variable is replaced with the new one.
+  Variable CreateVariable(
+      const std::string &name, Variable::Type type = Variable::Type::Any) {
     auto symbol = symbol_table_.CreateSymbol(name);
-    scope_.variables[name] = symbol;
-    return symbol;
+    auto variable = Variable{symbol, type};
+    scope_.variables[name] = variable;
+    return variable;
   }
 
-  Symbol GetOrCreateSymbol(const std::string &name) {
+  // Returns the variable mapped to name, throwing TypeMismatchError if its type
+  // conflicts with a non-Any requested type. Creates the variable if unmapped.
+  Variable GetOrCreateVariable(
+      const std::string &name, Variable::Type type = Variable::Type::Any) {
     auto search = scope_.variables.find(name);
     if (search != scope_.variables.end()) {
+      auto variable = search->second;
+      if (type != Variable::Type::Any && type != variable.type) {
+        throw TypeMismatchError(name, TypeToString(variable.type),
+                                TypeToString(type));
+      }
       return search->second;
     }
-    return CreateSymbol(name);
+    return CreateVariable(name, type);
   }
 
   SymbolTable &symbol_table_;
diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp
index 1838fbc3c..f6bbbc036 100644
--- a/tests/unit/query_semantic.cpp
+++ b/tests/unit/query_semantic.cpp
@@ -191,20 +191,6 @@ Query *MatchCreateRedeclareNode(AstTreeStorage &storage) {
   return query;
 }
 
-// AST with redeclaring a match edge variable in create:
-// MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
-Query *MatchCreateRedeclareEdge(AstTreeStorage &storage) {
-  auto match = storage.Create<Match>();
-  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
-  auto query = storage.query();
-  query->clauses_.emplace_back(match);
-
-  auto create = storage.Create<Create>();
-  create->patterns_.emplace_back(GetPattern(storage, {"n", "r", "l"}));
-  query->clauses_.emplace_back(create);
-  return query;
-}
-
 TEST(TestSymbolGenerator, MatchNodeReturn) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
@@ -333,9 +319,70 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareNode) {
 TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
-  auto query_ast = MatchCreateRedeclareEdge(storage);
+  // AST with redeclaring a match edge variable in create:
+  // MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+
+  auto create = storage.Create<Create>();
+  auto pattern = GetPattern(storage, {"n", "r", "l"});
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  std::string relationship("relationship");
+  edge_atom->edge_types_.emplace_back(&relationship);
+  create->patterns_.emplace_back(pattern);
+  query->clauses_.emplace_back(create);
   SymbolGenerator symbol_generator(symbol_table);
-  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+  EXPECT_THROW(query->Accept(symbol_generator), RedeclareVariableError);
+}
+
+TEST(TestSymbolGenerator, MatchTypeMismatch) {
+  AstTreeStorage storage;
+  // Using an edge variable as a node causes a type mismatch.
+  // MATCH (n) -[r]-> (r)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "r"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), TypeMismatchError);
+}
+
+TEST(TestSymbolGenerator, MatchCreateTypeMismatch) {
+  AstTreeStorage storage;
+  // Using an edge variable as a node causes a type mismatch.
+  // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]- (n2)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n1", "r1", "n2"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(GetPattern(storage, {"r1", "r2", "n2"}));
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), TypeMismatchError);
+}
+
+TEST(TestSymbolGenerator, CreateMultipleEdgeType) {
+  AstTreeStorage storage;
+  // Multiple edge relationship types are not allowed when creating edges.
+  // CREATE (n) -[r :rel1 | :rel2]- (m)
+  auto pattern = GetPattern(storage, {"n", "r", "m"});
+  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
+  std::string rel1("rel1");
+  edge_atom->edge_types_.emplace_back(&rel1);
+  std::string rel2("rel2");
+  edge_atom->edge_types_.emplace_back(&rel2);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
 }
 
 }