diff --git a/src/query/entry.hpp b/src/query/entry.hpp index e0abc3000..43bc9899b 100644 --- a/src/query/entry.hpp +++ b/src/query/entry.hpp @@ -6,7 +6,8 @@ #include "query/frontend/interpret/interpret.hpp" #include "query/frontend/logical/planner.hpp" #include "query/frontend/opencypher/parser.hpp" -#include "query/frontend/typecheck/typecheck.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/frontend/semantic/symbol_generator.hpp" namespace query { @@ -52,11 +53,11 @@ class Engine { // symbol table fill SymbolTable symbol_table; - TypeCheckVisitor typecheck_visitor(symbol_table); - high_level_tree->Accept(typecheck_visitor); + SymbolGenerator symbol_generator(symbol_table); + high_level_tree->Accept(symbol_generator); // high level tree -> logical plan - auto logical_plan = Apply(*high_level_tree); + auto logical_plan = MakeLogicalPlan(*high_level_tree); // generate frame based on symbol table max_position Frame frame(symbol_table.max_position()); diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index ff48a623e..479961bc3 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -7,6 +7,7 @@ namespace query { class SyntaxException : public BasicException { public: + using BasicException::BasicException; SyntaxException() : BasicException("") {} }; @@ -19,8 +20,9 @@ class SyntaxException : public BasicException { // query and only report line numbers of semantic errors (not position in the // line) if multiple line strings are not allowed by grammar. We could also // print whole line that contains error instead of specifying line number. -class SemanticException : BasicException { +class SemanticException : public BasicException { public: + using BasicException::BasicException; SemanticException() : BasicException("") {} }; diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 953b6ca4a..64c7ddb76 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -28,7 +28,6 @@ public: Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); visitor.Visit(*this); visitor.PostVisit(*this); } @@ -40,9 +39,8 @@ class NamedExpression : public Tree { public: NamedExpression(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); - expression_->Accept(visitor); visitor.Visit(*this); + expression_->Accept(visitor); visitor.PostVisit(*this); } @@ -59,9 +57,8 @@ class NodeAtom : public PatternAtom { public: NodeAtom(int uid) : PatternAtom(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); - identifier_->Accept(visitor); visitor.Visit(*this); + identifier_->Accept(visitor); visitor.PostVisit(*this); } @@ -75,9 +72,8 @@ public: EdgeAtom(int uid) : PatternAtom(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); - identifier_->Accept(visitor); visitor.Visit(*this); + identifier_->Accept(visitor); visitor.PostVisit(*this); } @@ -94,11 +90,10 @@ class Pattern : public Tree { public: Pattern(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); + visitor.Visit(*this); for (auto &part : atoms_) { part->Accept(visitor); } - visitor.Visit(*this); visitor.PostVisit(*this); } std::shared_ptr<Identifier> identifier_; @@ -109,11 +104,10 @@ class Query : public Tree { public: Query(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); + visitor.Visit(*this); for (auto &clause : clauses_) { clause->Accept(visitor); } - visitor.Visit(*this); visitor.PostVisit(*this); } std::vector<std::shared_ptr<Clause>> clauses_; @@ -124,11 +118,10 @@ public: Match(int uid) : Clause(uid) {} std::vector<std::shared_ptr<Pattern>> patterns_; void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); + visitor.Visit(*this); for (auto &pattern : patterns_) { pattern->Accept(visitor); } - visitor.Visit(*this); visitor.PostVisit(*this); } }; @@ -137,11 +130,10 @@ class Return : public Clause { public: Return(int uid) : Clause(uid) {} void Accept(TreeVisitorBase &visitor) override { - visitor.PreVisit(*this); + visitor.Visit(*this); for (auto &expr : named_expressions_) { expr->Accept(visitor); } - visitor.Visit(*this); visitor.PostVisit(*this); } std::vector<std::shared_ptr<NamedExpression>> named_expressions_; diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index c3b6aea40..0bb6f9948 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -16,31 +16,23 @@ class TreeVisitorBase { public: virtual ~TreeVisitorBase() {} // Start of the tree is a Query. - virtual void PreVisit(Query&) {} virtual void Visit(Query&) {} virtual void PostVisit(Query&) {} // Expressions - virtual void PreVisit(NamedExpression&) {} virtual void Visit(NamedExpression&) {} virtual void PostVisit(NamedExpression&) {} - virtual void PreVisit(Identifier&) {} virtual void Visit(Identifier&) {} virtual void PostVisit(Identifier&) {} // Clauses - virtual void PreVisit(Match&) {} virtual void Visit(Match&) {} virtual void PostVisit(Match&) {} - virtual void PreVisit(Return&) {} virtual void Visit(Return&) {} virtual void PostVisit(Return&) {} // Pattern and its subparts. - virtual void PreVisit(Pattern&) {} virtual void Visit(Pattern&) {} virtual void PostVisit(Pattern&) {} - virtual void PreVisit(NodeAtom&) {} virtual void Visit(NodeAtom&) {} virtual void PostVisit(NodeAtom&) {} - virtual void PreVisit(EdgeAtom&) {} virtual void Visit(EdgeAtom&) {} virtual void PostVisit(EdgeAtom&) {} }; diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index c11551762..b1e99b2f8 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -1,6 +1,6 @@ #pragma once -#include "utils/exceptions/basic_exception.hpp" +#include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" @@ -10,48 +10,30 @@ class SymbolGenerator : public TreeVisitorBase { public: SymbolGenerator(SymbolTable& symbol_table) : symbol_table_(symbol_table) {} - // Expressions - void PreVisit(NamedExpression& named_expr) override { - scope_.in_named_expr = true; + // Clauses + void PostVisit(Return& ret) override { + for (auto &named_expr : ret.named_expressions_) { + // Named expressions establish bindings for expressions which come after + // return, but not for the expressions contained inside. + symbol_table_[*named_expr] = CreateSymbol(named_expr->name_); + } } + + // Expressions void Visit(Identifier& ident) override { Symbol symbol; - if (scope_.in_named_expr) { - // TODO: Handle this better, so that the `with` variables aren't - // shadowed. - if (HasSymbol(ident.name_)) { - scope_.revert_variable = true; - scope_.old_symbol = scope_.variables[ident.name_]; - } - symbol = CreateSymbol(ident.name_); - } else if (scope_.in_pattern) { + if (scope_.in_pattern) { symbol = GetOrCreateSymbol(ident.name_); } else { if (!HasSymbol(ident.name_)) // TODO: Special exception for type check - throw BasicException("Unbound identifier: " + ident.name_); + throw SemanticException("Unbound identifier: " + ident.name_); symbol = scope_.variables[ident.name_]; } symbol_table_[ident] = symbol; } - void PostVisit(Identifier& ident) override { - if (scope_.in_named_expr) { - if (scope_.revert_variable) { - scope_.variables[ident.name_] = scope_.old_symbol; - } - scope_.in_named_expr = false; - scope_.revert_variable = false; - } - } - // Clauses - void PreVisit(Return& ret) override { - scope_.in_return = true; - } - void PostVisit(Return& ret) override { - scope_.in_return = false; - } // Pattern and its subparts. - void PreVisit(Pattern& pattern) override { + void Visit(Pattern& pattern) override { scope_.in_pattern = true; } void PostVisit(Pattern& pattern) override { @@ -60,14 +42,8 @@ class SymbolGenerator : public TreeVisitorBase { private: struct Scope { - Scope() - : in_pattern(false), in_return(false), in_named_expr(false), - revert_variable(false) {} + Scope() : in_pattern(false) {} bool in_pattern; - bool in_return; - bool in_named_expr; - bool revert_variable; - Symbol old_symbol; std::map<std::string, Symbol> variables; }; diff --git a/src/query/frontend/semantic/symbol_table.hpp b/src/query/frontend/semantic/symbol_table.hpp index 6fc222705..515747176 100644 --- a/src/query/frontend/semantic/symbol_table.hpp +++ b/src/query/frontend/semantic/symbol_table.hpp @@ -6,11 +6,19 @@ #include "query/frontend/ast/ast.hpp" namespace query { -struct Symbol { +class Symbol { + public: Symbol() {} - Symbol(const std::string& name, int position) : name_(name), position_(position) {} + Symbol(const std::string& name, int position) + : name_(name), position_(position) {} std::string name_; int position_; + + bool operator==(const Symbol& other) const { + return position_ == other.position_ && name_ == other.name_; + } + bool operator!=(const Symbol& other) const { return !operator==(other); } + }; class SymbolTable { diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp new file mode 100644 index 000000000..c19c1f032 --- /dev/null +++ b/tests/unit/query_semantic.cpp @@ -0,0 +1,114 @@ +#include <memory> + +#include "gtest/gtest.h" + +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/interpret/interpret.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/frontend/semantic/symbol_generator.hpp" + +using namespace query; + +// Build a simple AST which describes: +// MATCH (node_atom_1) RETURN node_atom_1 AS node_atom_1 +static std::unique_ptr<Query> MatchNodeReturn() { + int uid = 0; + auto node_atom = std::make_shared<NodeAtom>(uid++); + node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1"); + auto pattern = std::make_shared<Pattern>(uid++); + pattern->atoms_.emplace_back(node_atom); + auto match = std::make_shared<Match>(uid++); + match->patterns_.emplace_back(pattern); + auto query = std::make_unique<Query>(uid++); + query->clauses_.emplace_back(match); + + auto named_expr = std::make_shared<NamedExpression>(uid++); + named_expr->name_ = "node_atom_1"; + named_expr->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1"); + auto ret = std::make_shared<Return>(uid++); + ret->named_expressions_.emplace_back(named_expr); + query->clauses_.emplace_back(ret); + return query; +} + +// AST using variable in return bound by naming the previous return expression. +// This is treated as an unbound variable. +// MATCH (node_atom_1) RETURN node_atom_1 AS n, n AS n +static std::unique_ptr<Query> MatchUnboundMultiReturn() { + int uid = 0; + auto node_atom = std::make_shared<NodeAtom>(uid++); + node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1"); + auto pattern = std::make_shared<Pattern>(uid++); + pattern->atoms_.emplace_back(node_atom); + auto match = std::make_shared<Match>(uid++); + match->patterns_.emplace_back(pattern); + auto query = std::make_unique<Query>(uid++); + query->clauses_.emplace_back(match); + + auto named_expr_1 = std::make_shared<NamedExpression>(uid++); + named_expr_1->name_ = "n"; + named_expr_1->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1"); + auto named_expr_2 = std::make_shared<NamedExpression>(uid++); + named_expr_2->name_ = "n"; + named_expr_2->expression_ = std::make_shared<Identifier>(uid++, "n"); + auto ret = std::make_shared<Return>(uid++); + ret->named_expressions_.emplace_back(named_expr_1); + ret->named_expressions_.emplace_back(named_expr_2); + query->clauses_.emplace_back(ret); + return query; +} + +// AST with unbound variable in return: MATCH (n) RETURN x AS x +static std::unique_ptr<Query> MatchNodeUnboundReturn() { + int uid = 0; + auto node_atom = std::make_shared<NodeAtom>(uid++); + node_atom->identifier_ = std::make_shared<Identifier>(uid++, "n"); + auto pattern = std::make_shared<Pattern>(uid++); + pattern->atoms_.emplace_back(node_atom); + auto match = std::make_shared<Match>(uid++); + match->patterns_.emplace_back(pattern); + auto query = std::make_unique<Query>(uid++); + query->clauses_.emplace_back(match); + + auto named_expr = std::make_shared<NamedExpression>(uid++); + named_expr->name_ = "x"; + named_expr->expression_ = std::make_shared<Identifier>(uid++, "x"); + auto ret = std::make_shared<Return>(uid++); + ret->named_expressions_.emplace_back(named_expr); + query->clauses_.emplace_back(ret); + return query; +} + +TEST(TestSymbolGenerator, MatchNodeReturn) { + SymbolTable symbol_table; + auto query_ast = MatchNodeReturn(); + SymbolGenerator symbol_generator(symbol_table); + query_ast->Accept(symbol_generator); + EXPECT_EQ(symbol_table.max_position(), 2); + auto match = std::dynamic_pointer_cast<Match>(query_ast->clauses_[0]); + auto pattern = match->patterns_[0]; + auto node_atom = std::dynamic_pointer_cast<NodeAtom>(pattern->atoms_[0]); + auto node_sym = symbol_table[*node_atom->identifier_]; + EXPECT_EQ(node_sym.name_, "node_atom_1"); + auto ret = std::dynamic_pointer_cast<Return>(query_ast->clauses_[1]); + auto named_expr = ret->named_expressions_[0]; + auto column_sym = symbol_table[*named_expr]; + EXPECT_EQ(node_sym.name_, column_sym.name_); + EXPECT_NE(node_sym, column_sym); + auto ret_sym = symbol_table[*named_expr->expression_]; + EXPECT_EQ(node_sym, ret_sym); +} + +TEST(TestSymbolGenerator, MatchUnboundMultiReturn) { + SymbolTable symbol_table; + auto query_ast = MatchUnboundMultiReturn(); + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException); +} + +TEST(TestSymbolGenerator, MatchNodeUnboundReturn) { + SymbolTable symbol_table; + auto query_ast = MatchNodeUnboundReturn(); + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException); +}