diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index fa9e643fa..2ffe84ed8 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -6,26 +6,24 @@ namespace query { -auto SymbolGenerator::CreateVariable(const std::string &name, - SymbolGenerator::Variable::Type type) { - auto symbol = symbol_table_.CreateSymbol(name); - auto variable = SymbolGenerator::Variable{symbol, type}; - scope_.variables[name] = variable; - return variable; +auto SymbolGenerator::CreateSymbol(const std::string &name, Symbol::Type type) { + auto symbol = symbol_table_.CreateSymbol(name, type); + scope_.symbols[name] = symbol; + return symbol; } -auto SymbolGenerator::GetOrCreateVariable( - const std::string &name, SymbolGenerator::Variable::Type type) { - auto search = scope_.variables.find(name); - if (search != scope_.variables.end()) { - auto variable = search->second; - if (type != SymbolGenerator::Variable::Type::Any && type != variable.type) { - throw TypeMismatchError(name, TypeToString(variable.type), - TypeToString(type)); +auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, + Symbol::Type type) { + auto search = scope_.symbols.find(name); + if (search != scope_.symbols.end()) { + auto symbol = search->second; + if (type != Symbol::Type::Any && type != symbol.type_) { + throw TypeMismatchError(name, Symbol::TypeToString(symbol.type_), + Symbol::TypeToString(type)); } return search->second; } - return CreateVariable(name, type); + return CreateSymbol(name, type); } // Clauses @@ -37,7 +35,7 @@ void SymbolGenerator::PostVisit(Return &ret) { for (auto &named_expr : ret.named_expressions_) { // Named expressions establish bindings for expressions which come after // return, but not for the expressions contained inside. - symbol_table_[*named_expr] = CreateVariable(named_expr->name_).symbol; + symbol_table_[*named_expr] = CreateSymbol(named_expr->name_); } } @@ -59,23 +57,23 @@ void SymbolGenerator::Visit(Identifier &ident) { // Additionally, we will support edge referencing in pattern: // `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r`, which would // usually raise redeclaration of `r`. - if (scope_.in_property_map && !HasVariable(ident.name_)) { + if (scope_.in_property_map && !HasSymbol(ident.name_)) { // Case 1) throw UnboundVariableError(ident.name_); } else if ((scope_.in_create_node || scope_.in_create_edge) && - HasVariable(ident.name_)) { + HasSymbol(ident.name_)) { // Case 2) throw RedeclareVariableError(ident.name_); } - auto type = Variable::Type::Vertex; + auto type = Symbol::Type::Vertex; if (scope_.in_edge_atom) { - type = Variable::Type::Edge; + type = Symbol::Type::Edge; } - symbol = GetOrCreateVariable(ident.name_, type).symbol; + symbol = GetOrCreateSymbol(ident.name_, type); } else { // Everything else references a bound symbol. - if (!HasVariable(ident.name_)) throw UnboundVariableError(ident.name_); - symbol = scope_.variables[ident.name_].symbol; + if (!HasSymbol(ident.name_)) throw UnboundVariableError(ident.name_); + symbol = scope_.symbols[ident.name_]; } symbol_table_[ident] = symbol; } @@ -128,8 +126,8 @@ void SymbolGenerator::PostVisit(EdgeAtom &edge_atom) { scope_.in_create_edge = false; } -bool SymbolGenerator::HasVariable(const std::string &name) { - return scope_.variables.find(name) != scope_.variables.end(); +bool SymbolGenerator::HasSymbol(const std::string &name) { + return scope_.symbols.find(name) != scope_.symbols.end(); } } // namespace query diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index b0ae49657..6f0ca345c 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -40,22 +40,8 @@ class SymbolGenerator : public TreeVisitorBase { void PostVisit(EdgeAtom &edge_atom) override; private: - // A variable stores the associated symbol and its type. - struct Variable { - // This is similar to TypedValue::Type, but this has `Any` type. - enum class Type { Any, Vertex, Edge, Path }; - - Symbol symbol; - Type type{Type::Any}; - }; - - std::string TypeToString(Variable::Type type) { - const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"}; - return enum_string[static_cast<int>(type)]; - } - // Scope stores the state of where we are when visiting the AST and a map of - // names to variables. + // names to symbols. struct Scope { bool in_pattern{false}; bool in_create{false}; @@ -67,20 +53,20 @@ class SymbolGenerator : public TreeVisitorBase { bool in_node_atom{false}; bool in_edge_atom{false}; bool in_property_map{false}; - std::map<std::string, Variable> variables; + std::map<std::string, Symbol> symbols; }; - bool HasVariable(const std::string &name); + bool HasSymbol(const std::string &name); - // Returns a new variable with a freshly generated symbol. Previous mapping - // of the same name to a different variable is replaced with the new one. - auto CreateVariable(const std::string &name, - Variable::Type type = Variable::Type::Any); + // Returns a freshly generated symbol. Previous mapping of the same name to a + // different symbol is replaced with the new one. + auto CreateSymbol(const std::string &name, + Symbol::Type type = Symbol::Type::Any); - // Returns the variable by name. If the mapping already exists, checks if the - // types match. Otherwise, returns a new variable. - auto GetOrCreateVariable(const std::string &name, - Variable::Type type = Variable::Type::Any); + // Returns the symbol by name. If the mapping already exists, checks if the + // types match. Otherwise, returns a new symbol. + auto GetOrCreateSymbol(const std::string &name, + Symbol::Type type = Symbol::Type::Any); SymbolTable &symbol_table_; Scope scope_; diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index f2061dbba..dcf21b360 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -6,32 +6,44 @@ #include "query/frontend/ast/ast.hpp" namespace query { + class Symbol { public: + // This is similar to TypedValue::Type, but this has `Any` type. + enum class Type { Any, Vertex, Edge, Path }; + + static std::string TypeToString(Type type) { + const char *enum_string[] = {"Any", "Vertex", "Edge", "Path"}; + return enum_string[static_cast<int>(type)]; + } + Symbol() {} - Symbol(const std::string& name, int position) - : name_(name), position_(position) {} + Symbol(const std::string &name, int position, Type type = Type::Any) + : name_(name), position_(position), type_(type) {} + std::string name_; int position_; + Type type_{Type::Any}; - bool operator==(const Symbol& other) const { - return position_ == other.position_ && name_ == other.name_; + bool operator==(const Symbol &other) const { + return position_ == other.position_ && name_ == other.name_ && + type_ == other.type_; } - bool operator!=(const Symbol& other) const { return !operator==(other); } - + bool operator!=(const Symbol &other) const { return !operator==(other); } }; class SymbolTable { public: - Symbol CreateSymbol(const std::string& name) { + Symbol CreateSymbol(const std::string &name, + Symbol::Type type = Symbol::Type::Any) { int position = position_++; - return Symbol(name, position); + return Symbol(name, position, type); } - auto &operator[](const Tree& tree) { return table_[tree.uid()]; } + auto &operator[](const Tree &tree) { return table_[tree.uid()]; } - auto &at(const Tree& tree) { return table_.at(tree.uid()); } - const auto &at(const Tree& tree) const { return table_.at(tree.uid()); } + auto &at(const Tree &tree) { return table_.at(tree.uid()); } + const auto &at(const Tree &tree) const { return table_.at(tree.uid()); } int max_position() const { return position_; } @@ -39,4 +51,5 @@ class SymbolTable { int position_{0}; std::map<int, Symbol> table_; }; + } diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 91cde8765..9ed6fa151 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -27,6 +27,7 @@ TEST(TestSymbolGenerator, MatchNodeReturn) { auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]); auto node_sym = symbol_table[*node_atom->identifier_]; EXPECT_EQ(node_sym.name_, "node_atom_1"); + EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex); auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]); auto named_expr = ret->named_expressions_[0]; auto column_sym = symbol_table[*named_expr]; @@ -86,10 +87,12 @@ TEST(TestSymbolGenerator, MatchSameEdge) { is_node = !is_node; } auto &node_symbol = node_symbols.front(); + EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex); for (auto &symbol : node_symbols) { EXPECT_EQ(node_symbol, symbol); } auto &edge_symbol = edge_symbols.front(); + EXPECT_EQ(edge_symbol.type_, Symbol::Type::Edge); for (auto &symbol : edge_symbols) { EXPECT_EQ(edge_symbol, symbol); } @@ -125,6 +128,7 @@ TEST(TestSymbolGenerator, CreateNodeReturn) { auto node_atom = dynamic_cast<NodeAtom *>(pattern->atoms_[0]); auto node_sym = symbol_table[*node_atom->identifier_]; EXPECT_EQ(node_sym.name_, "n"); + EXPECT_EQ(node_sym.type_, Symbol::Type::Vertex); auto ret = dynamic_cast<Return *>(query_ast->clauses_[1]); auto named_expr = ret->named_expressions_[0]; auto column_sym = symbol_table[*named_expr]; @@ -249,6 +253,7 @@ TEST(TestSymbolGenerator, CreateDelete) { EXPECT_EQ(symbol_table.max_position(), 1); auto node_symbol = symbol_table.at(*node->identifier_); auto ident_symbol = symbol_table.at(*ident); + EXPECT_EQ(node_symbol.type_, Symbol::Type::Vertex); EXPECT_EQ(node_symbol, ident_symbol); }