Add tests for symbol generation

Summary:
Add tests for symbol generator

Also, remove redundant PreVisit method from AST visitor

Reviewers: mislav.bradac, buda, florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D124
This commit is contained in:
Teon Banek 2017-03-15 14:20:03 +01:00
parent b956d0812b
commit 42e8d339c5
7 changed files with 153 additions and 68 deletions

View File

@ -6,7 +6,8 @@
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/logical/planner.hpp"
#include "query/frontend/opencypher/parser.hpp"
#include "query/frontend/typecheck/typecheck.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
namespace query {
@ -52,11 +53,11 @@ class Engine {
// symbol table fill
SymbolTable symbol_table;
TypeCheckVisitor typecheck_visitor(symbol_table);
high_level_tree->Accept(typecheck_visitor);
SymbolGenerator symbol_generator(symbol_table);
high_level_tree->Accept(symbol_generator);
// high level tree -> logical plan
auto logical_plan = Apply(*high_level_tree);
auto logical_plan = MakeLogicalPlan(*high_level_tree);
// generate frame based on symbol table max_position
Frame frame(symbol_table.max_position());

View File

@ -7,6 +7,7 @@ namespace query {
class SyntaxException : public BasicException {
public:
using BasicException::BasicException;
SyntaxException() : BasicException("") {}
};
@ -19,8 +20,9 @@ class SyntaxException : public BasicException {
// query and only report line numbers of semantic errors (not position in the
// line) if multiple line strings are not allowed by grammar. We could also
// print whole line that contains error instead of specifying line number.
class SemanticException : BasicException {
class SemanticException : public BasicException {
public:
using BasicException::BasicException;
SemanticException() : BasicException("") {}
};

View File

@ -28,7 +28,6 @@ public:
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
visitor.Visit(*this);
visitor.PostVisit(*this);
}
@ -40,9 +39,8 @@ class NamedExpression : public Tree {
public:
NamedExpression(int uid) : Tree(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
expression_->Accept(visitor);
visitor.Visit(*this);
expression_->Accept(visitor);
visitor.PostVisit(*this);
}
@ -59,9 +57,8 @@ class NodeAtom : public PatternAtom {
public:
NodeAtom(int uid) : PatternAtom(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
identifier_->Accept(visitor);
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
@ -75,9 +72,8 @@ public:
EdgeAtom(int uid) : PatternAtom(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
identifier_->Accept(visitor);
visitor.Visit(*this);
identifier_->Accept(visitor);
visitor.PostVisit(*this);
}
@ -94,11 +90,10 @@ class Pattern : public Tree {
public:
Pattern(int uid) : Tree(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
visitor.Visit(*this);
for (auto &part : atoms_) {
part->Accept(visitor);
}
visitor.Visit(*this);
visitor.PostVisit(*this);
}
std::shared_ptr<Identifier> identifier_;
@ -109,11 +104,10 @@ class Query : public Tree {
public:
Query(int uid) : Tree(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
visitor.Visit(*this);
for (auto &clause : clauses_) {
clause->Accept(visitor);
}
visitor.Visit(*this);
visitor.PostVisit(*this);
}
std::vector<std::shared_ptr<Clause>> clauses_;
@ -124,11 +118,10 @@ public:
Match(int uid) : Clause(uid) {}
std::vector<std::shared_ptr<Pattern>> patterns_;
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
visitor.Visit(*this);
for (auto &pattern : patterns_) {
pattern->Accept(visitor);
}
visitor.Visit(*this);
visitor.PostVisit(*this);
}
};
@ -137,11 +130,10 @@ class Return : public Clause {
public:
Return(int uid) : Clause(uid) {}
void Accept(TreeVisitorBase &visitor) override {
visitor.PreVisit(*this);
visitor.Visit(*this);
for (auto &expr : named_expressions_) {
expr->Accept(visitor);
}
visitor.Visit(*this);
visitor.PostVisit(*this);
}
std::vector<std::shared_ptr<NamedExpression>> named_expressions_;

View File

@ -16,31 +16,23 @@ class TreeVisitorBase {
public:
virtual ~TreeVisitorBase() {}
// Start of the tree is a Query.
virtual void PreVisit(Query&) {}
virtual void Visit(Query&) {}
virtual void PostVisit(Query&) {}
// Expressions
virtual void PreVisit(NamedExpression&) {}
virtual void Visit(NamedExpression&) {}
virtual void PostVisit(NamedExpression&) {}
virtual void PreVisit(Identifier&) {}
virtual void Visit(Identifier&) {}
virtual void PostVisit(Identifier&) {}
// Clauses
virtual void PreVisit(Match&) {}
virtual void Visit(Match&) {}
virtual void PostVisit(Match&) {}
virtual void PreVisit(Return&) {}
virtual void Visit(Return&) {}
virtual void PostVisit(Return&) {}
// Pattern and its subparts.
virtual void PreVisit(Pattern&) {}
virtual void Visit(Pattern&) {}
virtual void PostVisit(Pattern&) {}
virtual void PreVisit(NodeAtom&) {}
virtual void Visit(NodeAtom&) {}
virtual void PostVisit(NodeAtom&) {}
virtual void PreVisit(EdgeAtom&) {}
virtual void Visit(EdgeAtom&) {}
virtual void PostVisit(EdgeAtom&) {}
};

View File

@ -1,6 +1,6 @@
#pragma once
#include "utils/exceptions/basic_exception.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
@ -10,48 +10,30 @@ class SymbolGenerator : public TreeVisitorBase {
public:
SymbolGenerator(SymbolTable& symbol_table) : symbol_table_(symbol_table) {}
// Expressions
void PreVisit(NamedExpression& named_expr) override {
scope_.in_named_expr = true;
// Clauses
void PostVisit(Return& ret) override {
for (auto &named_expr : ret.named_expressions_) {
// Named expressions establish bindings for expressions which come after
// return, but not for the expressions contained inside.
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
}
}
// Expressions
void Visit(Identifier& ident) override {
Symbol symbol;
if (scope_.in_named_expr) {
// TODO: Handle this better, so that the `with` variables aren't
// shadowed.
if (HasSymbol(ident.name_)) {
scope_.revert_variable = true;
scope_.old_symbol = scope_.variables[ident.name_];
}
symbol = CreateSymbol(ident.name_);
} else if (scope_.in_pattern) {
if (scope_.in_pattern) {
symbol = GetOrCreateSymbol(ident.name_);
} else {
if (!HasSymbol(ident.name_))
// TODO: Special exception for type check
throw BasicException("Unbound identifier: " + ident.name_);
throw SemanticException("Unbound identifier: " + ident.name_);
symbol = scope_.variables[ident.name_];
}
symbol_table_[ident] = symbol;
}
void PostVisit(Identifier& ident) override {
if (scope_.in_named_expr) {
if (scope_.revert_variable) {
scope_.variables[ident.name_] = scope_.old_symbol;
}
scope_.in_named_expr = false;
scope_.revert_variable = false;
}
}
// Clauses
void PreVisit(Return& ret) override {
scope_.in_return = true;
}
void PostVisit(Return& ret) override {
scope_.in_return = false;
}
// Pattern and its subparts.
void PreVisit(Pattern& pattern) override {
void Visit(Pattern& pattern) override {
scope_.in_pattern = true;
}
void PostVisit(Pattern& pattern) override {
@ -60,14 +42,8 @@ class SymbolGenerator : public TreeVisitorBase {
private:
struct Scope {
Scope()
: in_pattern(false), in_return(false), in_named_expr(false),
revert_variable(false) {}
Scope() : in_pattern(false) {}
bool in_pattern;
bool in_return;
bool in_named_expr;
bool revert_variable;
Symbol old_symbol;
std::map<std::string, Symbol> variables;
};

View File

@ -6,11 +6,19 @@
#include "query/frontend/ast/ast.hpp"
namespace query {
struct Symbol {
class Symbol {
public:
Symbol() {}
Symbol(const std::string& name, int position) : name_(name), position_(position) {}
Symbol(const std::string& name, int position)
: name_(name), position_(position) {}
std::string name_;
int position_;
bool operator==(const Symbol& other) const {
return position_ == other.position_ && name_ == other.name_;
}
bool operator!=(const Symbol& other) const { return !operator==(other); }
};
class SymbolTable {

View File

@ -0,0 +1,114 @@
#include <memory>
#include "gtest/gtest.h"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
using namespace query;
// Build a simple AST which describes:
// MATCH (node_atom_1) RETURN node_atom_1 AS node_atom_1
static std::unique_ptr<Query> MatchNodeReturn() {
int uid = 0;
auto node_atom = std::make_shared<NodeAtom>(uid++);
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1");
auto pattern = std::make_shared<Pattern>(uid++);
pattern->atoms_.emplace_back(node_atom);
auto match = std::make_shared<Match>(uid++);
match->patterns_.emplace_back(pattern);
auto query = std::make_unique<Query>(uid++);
query->clauses_.emplace_back(match);
auto named_expr = std::make_shared<NamedExpression>(uid++);
named_expr->name_ = "node_atom_1";
named_expr->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1");
auto ret = std::make_shared<Return>(uid++);
ret->named_expressions_.emplace_back(named_expr);
query->clauses_.emplace_back(ret);
return query;
}
// AST using variable in return bound by naming the previous return expression.
// This is treated as an unbound variable.
// MATCH (node_atom_1) RETURN node_atom_1 AS n, n AS n
static std::unique_ptr<Query> MatchUnboundMultiReturn() {
int uid = 0;
auto node_atom = std::make_shared<NodeAtom>(uid++);
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1");
auto pattern = std::make_shared<Pattern>(uid++);
pattern->atoms_.emplace_back(node_atom);
auto match = std::make_shared<Match>(uid++);
match->patterns_.emplace_back(pattern);
auto query = std::make_unique<Query>(uid++);
query->clauses_.emplace_back(match);
auto named_expr_1 = std::make_shared<NamedExpression>(uid++);
named_expr_1->name_ = "n";
named_expr_1->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1");
auto named_expr_2 = std::make_shared<NamedExpression>(uid++);
named_expr_2->name_ = "n";
named_expr_2->expression_ = std::make_shared<Identifier>(uid++, "n");
auto ret = std::make_shared<Return>(uid++);
ret->named_expressions_.emplace_back(named_expr_1);
ret->named_expressions_.emplace_back(named_expr_2);
query->clauses_.emplace_back(ret);
return query;
}
// AST with unbound variable in return: MATCH (n) RETURN x AS x
static std::unique_ptr<Query> MatchNodeUnboundReturn() {
int uid = 0;
auto node_atom = std::make_shared<NodeAtom>(uid++);
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "n");
auto pattern = std::make_shared<Pattern>(uid++);
pattern->atoms_.emplace_back(node_atom);
auto match = std::make_shared<Match>(uid++);
match->patterns_.emplace_back(pattern);
auto query = std::make_unique<Query>(uid++);
query->clauses_.emplace_back(match);
auto named_expr = std::make_shared<NamedExpression>(uid++);
named_expr->name_ = "x";
named_expr->expression_ = std::make_shared<Identifier>(uid++, "x");
auto ret = std::make_shared<Return>(uid++);
ret->named_expressions_.emplace_back(named_expr);
query->clauses_.emplace_back(ret);
return query;
}
TEST(TestSymbolGenerator, MatchNodeReturn) {
SymbolTable symbol_table;
auto query_ast = MatchNodeReturn();
SymbolGenerator symbol_generator(symbol_table);
query_ast->Accept(symbol_generator);
EXPECT_EQ(symbol_table.max_position(), 2);
auto match = std::dynamic_pointer_cast<Match>(query_ast->clauses_[0]);
auto pattern = match->patterns_[0];
auto node_atom = std::dynamic_pointer_cast<NodeAtom>(pattern->atoms_[0]);
auto node_sym = symbol_table[*node_atom->identifier_];
EXPECT_EQ(node_sym.name_, "node_atom_1");
auto ret = std::dynamic_pointer_cast<Return>(query_ast->clauses_[1]);
auto named_expr = ret->named_expressions_[0];
auto column_sym = symbol_table[*named_expr];
EXPECT_EQ(node_sym.name_, column_sym.name_);
EXPECT_NE(node_sym, column_sym);
auto ret_sym = symbol_table[*named_expr->expression_];
EXPECT_EQ(node_sym, ret_sym);
}
TEST(TestSymbolGenerator, MatchUnboundMultiReturn) {
SymbolTable symbol_table;
auto query_ast = MatchUnboundMultiReturn();
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
}
TEST(TestSymbolGenerator, MatchNodeUnboundReturn) {
SymbolTable symbol_table;
auto query_ast = MatchNodeUnboundReturn();
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
}