Add tests for symbol generation
Summary: Add tests for symbol generator Also, remove redundant PreVisit method from AST visitor Reviewers: mislav.bradac, buda, florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D124
This commit is contained in:
parent
b956d0812b
commit
42e8d339c5
@ -6,7 +6,8 @@
|
||||
#include "query/frontend/interpret/interpret.hpp"
|
||||
#include "query/frontend/logical/planner.hpp"
|
||||
#include "query/frontend/opencypher/parser.hpp"
|
||||
#include "query/frontend/typecheck/typecheck.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
@ -52,11 +53,11 @@ class Engine {
|
||||
|
||||
// symbol table fill
|
||||
SymbolTable symbol_table;
|
||||
TypeCheckVisitor typecheck_visitor(symbol_table);
|
||||
high_level_tree->Accept(typecheck_visitor);
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
high_level_tree->Accept(symbol_generator);
|
||||
|
||||
// high level tree -> logical plan
|
||||
auto logical_plan = Apply(*high_level_tree);
|
||||
auto logical_plan = MakeLogicalPlan(*high_level_tree);
|
||||
|
||||
// generate frame based on symbol table max_position
|
||||
Frame frame(symbol_table.max_position());
|
||||
|
@ -7,6 +7,7 @@ namespace query {
|
||||
|
||||
class SyntaxException : public BasicException {
|
||||
public:
|
||||
using BasicException::BasicException;
|
||||
SyntaxException() : BasicException("") {}
|
||||
};
|
||||
|
||||
@ -19,8 +20,9 @@ class SyntaxException : public BasicException {
|
||||
// query and only report line numbers of semantic errors (not position in the
|
||||
// line) if multiple line strings are not allowed by grammar. We could also
|
||||
// print whole line that contains error instead of specifying line number.
|
||||
class SemanticException : BasicException {
|
||||
class SemanticException : public BasicException {
|
||||
public:
|
||||
using BasicException::BasicException;
|
||||
SemanticException() : BasicException("") {}
|
||||
};
|
||||
|
||||
|
@ -28,7 +28,6 @@ public:
|
||||
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
|
||||
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
visitor.Visit(*this);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
@ -40,9 +39,8 @@ class NamedExpression : public Tree {
|
||||
public:
|
||||
NamedExpression(int uid) : Tree(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
expression_->Accept(visitor);
|
||||
visitor.Visit(*this);
|
||||
expression_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
@ -59,9 +57,8 @@ class NodeAtom : public PatternAtom {
|
||||
public:
|
||||
NodeAtom(int uid) : PatternAtom(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
identifier_->Accept(visitor);
|
||||
visitor.Visit(*this);
|
||||
identifier_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
@ -75,9 +72,8 @@ public:
|
||||
|
||||
EdgeAtom(int uid) : PatternAtom(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
identifier_->Accept(visitor);
|
||||
visitor.Visit(*this);
|
||||
identifier_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
@ -94,11 +90,10 @@ class Pattern : public Tree {
|
||||
public:
|
||||
Pattern(int uid) : Tree(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
visitor.Visit(*this);
|
||||
for (auto &part : atoms_) {
|
||||
part->Accept(visitor);
|
||||
}
|
||||
visitor.Visit(*this);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
std::shared_ptr<Identifier> identifier_;
|
||||
@ -109,11 +104,10 @@ class Query : public Tree {
|
||||
public:
|
||||
Query(int uid) : Tree(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
visitor.Visit(*this);
|
||||
for (auto &clause : clauses_) {
|
||||
clause->Accept(visitor);
|
||||
}
|
||||
visitor.Visit(*this);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
std::vector<std::shared_ptr<Clause>> clauses_;
|
||||
@ -124,11 +118,10 @@ public:
|
||||
Match(int uid) : Clause(uid) {}
|
||||
std::vector<std::shared_ptr<Pattern>> patterns_;
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
visitor.Visit(*this);
|
||||
for (auto &pattern : patterns_) {
|
||||
pattern->Accept(visitor);
|
||||
}
|
||||
visitor.Visit(*this);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
};
|
||||
@ -137,11 +130,10 @@ class Return : public Clause {
|
||||
public:
|
||||
Return(int uid) : Clause(uid) {}
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
visitor.PreVisit(*this);
|
||||
visitor.Visit(*this);
|
||||
for (auto &expr : named_expressions_) {
|
||||
expr->Accept(visitor);
|
||||
}
|
||||
visitor.Visit(*this);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
std::vector<std::shared_ptr<NamedExpression>> named_expressions_;
|
||||
|
@ -16,31 +16,23 @@ class TreeVisitorBase {
|
||||
public:
|
||||
virtual ~TreeVisitorBase() {}
|
||||
// Start of the tree is a Query.
|
||||
virtual void PreVisit(Query&) {}
|
||||
virtual void Visit(Query&) {}
|
||||
virtual void PostVisit(Query&) {}
|
||||
// Expressions
|
||||
virtual void PreVisit(NamedExpression&) {}
|
||||
virtual void Visit(NamedExpression&) {}
|
||||
virtual void PostVisit(NamedExpression&) {}
|
||||
virtual void PreVisit(Identifier&) {}
|
||||
virtual void Visit(Identifier&) {}
|
||||
virtual void PostVisit(Identifier&) {}
|
||||
// Clauses
|
||||
virtual void PreVisit(Match&) {}
|
||||
virtual void Visit(Match&) {}
|
||||
virtual void PostVisit(Match&) {}
|
||||
virtual void PreVisit(Return&) {}
|
||||
virtual void Visit(Return&) {}
|
||||
virtual void PostVisit(Return&) {}
|
||||
// Pattern and its subparts.
|
||||
virtual void PreVisit(Pattern&) {}
|
||||
virtual void Visit(Pattern&) {}
|
||||
virtual void PostVisit(Pattern&) {}
|
||||
virtual void PreVisit(NodeAtom&) {}
|
||||
virtual void Visit(NodeAtom&) {}
|
||||
virtual void PostVisit(NodeAtom&) {}
|
||||
virtual void PreVisit(EdgeAtom&) {}
|
||||
virtual void Visit(EdgeAtom&) {}
|
||||
virtual void PostVisit(EdgeAtom&) {}
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "utils/exceptions/basic_exception.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
|
||||
@ -10,48 +10,30 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
public:
|
||||
SymbolGenerator(SymbolTable& symbol_table) : symbol_table_(symbol_table) {}
|
||||
|
||||
// Expressions
|
||||
void PreVisit(NamedExpression& named_expr) override {
|
||||
scope_.in_named_expr = true;
|
||||
// Clauses
|
||||
void PostVisit(Return& ret) override {
|
||||
for (auto &named_expr : ret.named_expressions_) {
|
||||
// Named expressions establish bindings for expressions which come after
|
||||
// return, but not for the expressions contained inside.
|
||||
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
||||
}
|
||||
}
|
||||
|
||||
// Expressions
|
||||
void Visit(Identifier& ident) override {
|
||||
Symbol symbol;
|
||||
if (scope_.in_named_expr) {
|
||||
// TODO: Handle this better, so that the `with` variables aren't
|
||||
// shadowed.
|
||||
if (HasSymbol(ident.name_)) {
|
||||
scope_.revert_variable = true;
|
||||
scope_.old_symbol = scope_.variables[ident.name_];
|
||||
}
|
||||
symbol = CreateSymbol(ident.name_);
|
||||
} else if (scope_.in_pattern) {
|
||||
if (scope_.in_pattern) {
|
||||
symbol = GetOrCreateSymbol(ident.name_);
|
||||
} else {
|
||||
if (!HasSymbol(ident.name_))
|
||||
// TODO: Special exception for type check
|
||||
throw BasicException("Unbound identifier: " + ident.name_);
|
||||
throw SemanticException("Unbound identifier: " + ident.name_);
|
||||
symbol = scope_.variables[ident.name_];
|
||||
}
|
||||
symbol_table_[ident] = symbol;
|
||||
}
|
||||
void PostVisit(Identifier& ident) override {
|
||||
if (scope_.in_named_expr) {
|
||||
if (scope_.revert_variable) {
|
||||
scope_.variables[ident.name_] = scope_.old_symbol;
|
||||
}
|
||||
scope_.in_named_expr = false;
|
||||
scope_.revert_variable = false;
|
||||
}
|
||||
}
|
||||
// Clauses
|
||||
void PreVisit(Return& ret) override {
|
||||
scope_.in_return = true;
|
||||
}
|
||||
void PostVisit(Return& ret) override {
|
||||
scope_.in_return = false;
|
||||
}
|
||||
// Pattern and its subparts.
|
||||
void PreVisit(Pattern& pattern) override {
|
||||
void Visit(Pattern& pattern) override {
|
||||
scope_.in_pattern = true;
|
||||
}
|
||||
void PostVisit(Pattern& pattern) override {
|
||||
@ -60,14 +42,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
||||
|
||||
private:
|
||||
struct Scope {
|
||||
Scope()
|
||||
: in_pattern(false), in_return(false), in_named_expr(false),
|
||||
revert_variable(false) {}
|
||||
Scope() : in_pattern(false) {}
|
||||
bool in_pattern;
|
||||
bool in_return;
|
||||
bool in_named_expr;
|
||||
bool revert_variable;
|
||||
Symbol old_symbol;
|
||||
std::map<std::string, Symbol> variables;
|
||||
};
|
||||
|
||||
|
@ -6,11 +6,19 @@
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
|
||||
namespace query {
|
||||
struct Symbol {
|
||||
class Symbol {
|
||||
public:
|
||||
Symbol() {}
|
||||
Symbol(const std::string& name, int position) : name_(name), position_(position) {}
|
||||
Symbol(const std::string& name, int position)
|
||||
: name_(name), position_(position) {}
|
||||
std::string name_;
|
||||
int position_;
|
||||
|
||||
bool operator==(const Symbol& other) const {
|
||||
return position_ == other.position_ && name_ == other.name_;
|
||||
}
|
||||
bool operator!=(const Symbol& other) const { return !operator==(other); }
|
||||
|
||||
};
|
||||
|
||||
class SymbolTable {
|
||||
|
114
tests/unit/query_semantic.cpp
Normal file
114
tests/unit/query_semantic.cpp
Normal file
@ -0,0 +1,114 @@
|
||||
#include <memory>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/interpret/interpret.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
|
||||
using namespace query;
|
||||
|
||||
// Build a simple AST which describes:
|
||||
// MATCH (node_atom_1) RETURN node_atom_1 AS node_atom_1
|
||||
static std::unique_ptr<Query> MatchNodeReturn() {
|
||||
int uid = 0;
|
||||
auto node_atom = std::make_shared<NodeAtom>(uid++);
|
||||
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1");
|
||||
auto pattern = std::make_shared<Pattern>(uid++);
|
||||
pattern->atoms_.emplace_back(node_atom);
|
||||
auto match = std::make_shared<Match>(uid++);
|
||||
match->patterns_.emplace_back(pattern);
|
||||
auto query = std::make_unique<Query>(uid++);
|
||||
query->clauses_.emplace_back(match);
|
||||
|
||||
auto named_expr = std::make_shared<NamedExpression>(uid++);
|
||||
named_expr->name_ = "node_atom_1";
|
||||
named_expr->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1");
|
||||
auto ret = std::make_shared<Return>(uid++);
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
return query;
|
||||
}
|
||||
|
||||
// AST using variable in return bound by naming the previous return expression.
|
||||
// This is treated as an unbound variable.
|
||||
// MATCH (node_atom_1) RETURN node_atom_1 AS n, n AS n
|
||||
static std::unique_ptr<Query> MatchUnboundMultiReturn() {
|
||||
int uid = 0;
|
||||
auto node_atom = std::make_shared<NodeAtom>(uid++);
|
||||
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "node_atom_1");
|
||||
auto pattern = std::make_shared<Pattern>(uid++);
|
||||
pattern->atoms_.emplace_back(node_atom);
|
||||
auto match = std::make_shared<Match>(uid++);
|
||||
match->patterns_.emplace_back(pattern);
|
||||
auto query = std::make_unique<Query>(uid++);
|
||||
query->clauses_.emplace_back(match);
|
||||
|
||||
auto named_expr_1 = std::make_shared<NamedExpression>(uid++);
|
||||
named_expr_1->name_ = "n";
|
||||
named_expr_1->expression_ = std::make_shared<Identifier>(uid++, "node_atom_1");
|
||||
auto named_expr_2 = std::make_shared<NamedExpression>(uid++);
|
||||
named_expr_2->name_ = "n";
|
||||
named_expr_2->expression_ = std::make_shared<Identifier>(uid++, "n");
|
||||
auto ret = std::make_shared<Return>(uid++);
|
||||
ret->named_expressions_.emplace_back(named_expr_1);
|
||||
ret->named_expressions_.emplace_back(named_expr_2);
|
||||
query->clauses_.emplace_back(ret);
|
||||
return query;
|
||||
}
|
||||
|
||||
// AST with unbound variable in return: MATCH (n) RETURN x AS x
|
||||
static std::unique_ptr<Query> MatchNodeUnboundReturn() {
|
||||
int uid = 0;
|
||||
auto node_atom = std::make_shared<NodeAtom>(uid++);
|
||||
node_atom->identifier_ = std::make_shared<Identifier>(uid++, "n");
|
||||
auto pattern = std::make_shared<Pattern>(uid++);
|
||||
pattern->atoms_.emplace_back(node_atom);
|
||||
auto match = std::make_shared<Match>(uid++);
|
||||
match->patterns_.emplace_back(pattern);
|
||||
auto query = std::make_unique<Query>(uid++);
|
||||
query->clauses_.emplace_back(match);
|
||||
|
||||
auto named_expr = std::make_shared<NamedExpression>(uid++);
|
||||
named_expr->name_ = "x";
|
||||
named_expr->expression_ = std::make_shared<Identifier>(uid++, "x");
|
||||
auto ret = std::make_shared<Return>(uid++);
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
return query;
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchNodeReturn) {
|
||||
SymbolTable symbol_table;
|
||||
auto query_ast = MatchNodeReturn();
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query_ast->Accept(symbol_generator);
|
||||
EXPECT_EQ(symbol_table.max_position(), 2);
|
||||
auto match = std::dynamic_pointer_cast<Match>(query_ast->clauses_[0]);
|
||||
auto pattern = match->patterns_[0];
|
||||
auto node_atom = std::dynamic_pointer_cast<NodeAtom>(pattern->atoms_[0]);
|
||||
auto node_sym = symbol_table[*node_atom->identifier_];
|
||||
EXPECT_EQ(node_sym.name_, "node_atom_1");
|
||||
auto ret = std::dynamic_pointer_cast<Return>(query_ast->clauses_[1]);
|
||||
auto named_expr = ret->named_expressions_[0];
|
||||
auto column_sym = symbol_table[*named_expr];
|
||||
EXPECT_EQ(node_sym.name_, column_sym.name_);
|
||||
EXPECT_NE(node_sym, column_sym);
|
||||
auto ret_sym = symbol_table[*named_expr->expression_];
|
||||
EXPECT_EQ(node_sym, ret_sym);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchUnboundMultiReturn) {
|
||||
SymbolTable symbol_table;
|
||||
auto query_ast = MatchUnboundMultiReturn();
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, MatchNodeUnboundReturn) {
|
||||
SymbolTable symbol_table;
|
||||
auto query_ast = MatchNodeUnboundReturn();
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException);
|
||||
}
|
Loading…
Reference in New Issue
Block a user