Patch summary (text before the first "diff --git" line is ignored by patch tools):
  * exceptions.hpp: adds TypeMismatchError, a SemanticException whose message is
    built with fmt::format.
  * symbol_generator.hpp: the scope now maps names to Variable (symbol + type)
    instead of bare Symbol; identifiers in patterns are typed Vertex or Edge,
    GetOrCreateVariable throws TypeMismatchError on a conflicting type, and
    creating an edge requires exactly one relationship type.
  * query_semantic.cpp: inlines the MatchCreateRedeclareEdge AST builder into
    its test and adds tests for the two type mismatch cases and for multiple
    edge types in CREATE.
NOTE(review): this patch was restored from a copy whose newlines were collapsed
and whose <...> tokens were stripped. Line breaks were checked against the hunk
line counts, but indentation, blank context lines, the fmt include path, the
Create<...> arguments and the static_cast target type are reconstructed --
verify against the target tree before applying.

diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp
index d09fd08eb..872769ea2 100644
--- a/src/query/exceptions.hpp
+++ b/src/query/exceptions.hpp
@@ -3,6 +3,8 @@
 #include "utils/exceptions/basic_exception.hpp"
 #include "utils/exceptions/stacktrace_exception.hpp"
 
+#include <fmt/format.h>
+
 namespace query {
 
 class SyntaxException : public BasicException {
@@ -34,10 +36,19 @@ class UnboundVariableError : public SemanticException {
 
 class RedeclareVariableError : public SemanticException {
  public:
-  RedeclareVariableError(const std::string& name)
+  RedeclareVariableError(const std::string &name)
       : SemanticException("Redeclaring variable: " + name) {}
 };
 
+class TypeMismatchError : public SemanticException {
+ public:
+  TypeMismatchError(const std::string &name, const std::string &datum,
+                    const std::string &expected)
+      : SemanticException(fmt::format(
+            "Type mismatch: '{}' already defined as '{}', but expected '{}'.",
+            name, datum, expected)) {}
+};
+
 class CppCodeGeneratorException : public StacktraceException {
  public:
   using StacktraceException::StacktraceException;
@@ -62,5 +73,4 @@ class QueryEngineException : public StacktraceException {
  public:
   using StacktraceException::StacktraceException;
 };
-
 }
diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp
index 01c62097a..9c44557f6 100644
--- a/src/query/frontend/semantic/symbol_generator.hpp
+++ b/src/query/frontend/semantic/symbol_generator.hpp
@@ -25,7 +25,7 @@ class SymbolGenerator : public TreeVisitorBase {
     for (auto &named_expr : ret.named_expressions_) {
       // Named expressions establish bindings for expressions which come after
       // return, but not for the expressions contained inside.
-      symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
+      symbol_table_[*named_expr] = CreateVariable(named_expr->name_).symbol;
     }
   }
 
@@ -46,20 +46,24 @@ class SymbolGenerator : public TreeVisitorBase {
       // Additionally, we will support edge referencing in pattern:
       // `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r`, which would
       // usually raise redeclaration of `r`.
-      if (scope_.in_property_map && !HasSymbol(ident.name_)) {
+      if (scope_.in_property_map && !HasVariable(ident.name_)) {
         // Case 1)
         throw UnboundVariableError(ident.name_);
       } else if ((scope_.in_create_node || scope_.in_create_edge) &&
-                 HasSymbol(ident.name_)) {
+                 HasVariable(ident.name_)) {
         // Case 2)
         throw RedeclareVariableError(ident.name_);
       }
-      symbol = GetOrCreateSymbol(ident.name_);
+      auto type = Variable::Type::Vertex;
+      if (scope_.in_edge_atom) {
+        type = Variable::Type::Edge;
+      }
+      symbol = GetOrCreateVariable(ident.name_, type).symbol;
     } else {
       // Everything else references a bound symbol.
-      if (!HasSymbol(ident.name_))
+      if (!HasVariable(ident.name_))
         throw UnboundVariableError(ident.name_);
-      symbol = scope_.variables[ident.name_];
+      symbol = scope_.variables[ident.name_].symbol;
     }
     symbol_table_[ident] = symbol;
   }
@@ -75,52 +79,96 @@ class SymbolGenerator : public TreeVisitorBase {
   }
   void PostVisit(Pattern &pattern) override {
     scope_.in_pattern = false;
-    scope_.in_create_node = false;
   }
 
   void Visit(NodeAtom &node_atom) override {
+    scope_.in_node_atom = true;
     scope_.in_property_map = true;
     for (auto kv : node_atom.properties_) {
       kv.second->Accept(*this);
     }
     scope_.in_property_map = false;
   }
+  void PostVisit(NodeAtom &node_atom) override {
+    scope_.in_node_atom = false;
+  }
 
   void Visit(EdgeAtom &edge_atom) override {
+    scope_.in_edge_atom = true;
     if (scope_.in_create) {
       scope_.in_create_edge = true;
+      if (edge_atom.edge_types_.size() != 1) {
+        throw SemanticException("A single relationship type must be specified "
+                                "when creating an edge.");
+      }
     }
   }
   void PostVisit(EdgeAtom &edge_atom) override {
+    scope_.in_edge_atom = false;
     scope_.in_create_edge = false;
   }
 
  private:
+  // A variable stores the associated symbol and its type.
+  struct Variable {
+    // This is similar to TypedValue::Type, but this has `Any` type.
+    enum class Type : unsigned {
+      Any, Vertex, Edge, Path
+    };
+
+    Symbol symbol;
+    Type type{Type::Any};
+  };
+
+  std::string TypeToString(Variable::Type type) {
+    const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"};
+    return enum_string[static_cast<unsigned>(type)];
+  }
+
+  // Scope stores the state of where we are when visiting the AST and a map of
+  // names to variables.
   struct Scope {
     bool in_pattern{false};
     bool in_create{false};
+    // in_create_node is true if we are creating *only* a node. Therefore, it
+    // is *not* equivalent to in_create && in_node_atom.
     bool in_create_node{false};
+    // True if creating an edge; shortcut for in_create && in_edge_atom.
     bool in_create_edge{false};
+    bool in_node_atom{false};
+    bool in_edge_atom{false};
     bool in_property_map{false};
-    std::map<std::string, Symbol> variables;
+    std::map<std::string, Variable> variables;
   };
 
-  bool HasSymbol(const std::string &name) {
+  bool HasVariable(const std::string &name) {
     return scope_.variables.find(name) != scope_.variables.end();
   }
 
-  Symbol CreateSymbol(const std::string &name) {
+  // Returns a new variable with a freshly generated symbol. Previous mapping
+  // of the same name to a different variable is replaced with the new one.
+  Variable CreateVariable(
+      const std::string &name, Variable::Type type = Variable::Type::Any) {
     auto symbol = symbol_table_.CreateSymbol(name);
-    scope_.variables[name] = symbol;
-    return symbol;
+    auto variable = Variable{symbol, type};
+    scope_.variables[name] = variable;
+    return variable;
   }
 
-  Symbol GetOrCreateSymbol(const std::string &name) {
+  // Returns the variable by name. If the mapping already exists, checks if the
+  // types match. Otherwise, returns a new variable.
+  Variable GetOrCreateVariable(
+      const std::string &name, Variable::Type type = Variable::Type::Any) {
     auto search = scope_.variables.find(name);
     if (search != scope_.variables.end()) {
+      auto variable = search->second;
+      if (type != Variable::Type::Any && type != variable.type) {
+        throw TypeMismatchError(name, TypeToString(variable.type),
+                                TypeToString(type));
+      }
       return search->second;
     }
-    return CreateSymbol(name);
+    return CreateVariable(name, type);
   }
 
   SymbolTable &symbol_table_;
diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp
index 1838fbc3c..f6bbbc036 100644
--- a/tests/unit/query_semantic.cpp
+++ b/tests/unit/query_semantic.cpp
@@ -191,20 +191,6 @@ Query *MatchCreateRedeclareNode(AstTreeStorage &storage) {
   return query;
 }
 
-// AST with redeclaring a match edge variable in create:
-// MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
-Query *MatchCreateRedeclareEdge(AstTreeStorage &storage) {
-  auto match = storage.Create<Match>();
-  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
-  auto query = storage.query();
-  query->clauses_.emplace_back(match);
-
-  auto create = storage.Create<Create>();
-  create->patterns_.emplace_back(GetPattern(storage, {"n", "r", "l"}));
-  query->clauses_.emplace_back(create);
-  return query;
-}
-
 TEST(TestSymbolGenerator, MatchNodeReturn) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
@@ -333,9 +319,70 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareNode) {
 TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) {
   SymbolTable symbol_table;
   AstTreeStorage storage;
-  auto query_ast = MatchCreateRedeclareEdge(storage);
+  // AST with redeclaring a match edge variable in create:
+  // MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+
+  auto create = storage.Create<Create>();
+  auto pattern = GetPattern(storage, {"n", "r", "l"});
+  auto edge_atom = dynamic_cast<EdgeAtom *>(pattern->atoms_[1]);
+  std::string relationship("relationship");
+  edge_atom->edge_types_.emplace_back(&relationship);
+  create->patterns_.emplace_back(pattern);
+  query->clauses_.emplace_back(create);
   SymbolGenerator symbol_generator(symbol_table);
-  EXPECT_THROW(query_ast->Accept(symbol_generator), RedeclareVariableError);
+  EXPECT_THROW(query->Accept(symbol_generator), RedeclareVariableError);
+}
+
+TEST(TestSymbolGenerator, MatchTypeMismatch) {
+  AstTreeStorage storage;
+  // Using an edge variable as a node causes a type mismatch.
+  // MATCH (n) -[r]-> (r)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "r"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), TypeMismatchError);
+}
+
+TEST(TestSymbolGenerator, MatchCreateTypeMismatch) {
+  AstTreeStorage storage;
+  // Using an edge variable as a node causes a type mismatch.
+  // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]- (n2)
+  auto match = storage.Create<Match>();
+  match->patterns_.emplace_back(GetPattern(storage, {"n1", "r1", "n2"}));
+  auto query = storage.query();
+  query->clauses_.emplace_back(match);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(GetPattern(storage, {"r1", "r2", "n2"}));
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), TypeMismatchError);
+}
+
+TEST(TestSymbolGenerator, CreateMultipleEdgeType) {
+  AstTreeStorage storage;
+  // Multiple edge relationship types are not allowed when creating edges.
+  // CREATE (n) -[r :rel1 | :rel2]- (m)
+  auto pattern = GetPattern(storage, {"n", "r", "m"});
+  auto edge_atom = dynamic_cast<EdgeAtom *>(pattern->atoms_[1]);
+  std::string rel1("rel1");
+  edge_atom->edge_types_.emplace_back(&rel1);
+  std::string rel2("rel2");
+  edge_atom->edge_types_.emplace_back(&rel2);
+  auto create = storage.Create<Create>();
+  create->patterns_.emplace_back(pattern);
+  auto query = storage.query();
+  query->clauses_.emplace_back(create);
+  SymbolTable symbol_table;
+  SymbolGenerator symbol_generator(symbol_table);
+  EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
 }
 
 }