diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f00a3439..72999f286 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -348,7 +348,7 @@ set(memgraph_src_files ${src_dir}/database/graph_db.cpp ${src_dir}/database/graph_db_accessor.cpp ${src_dir}/query/stripper.cpp - ${src_dir}/query/backend/cpp/cypher_main_visitor.cpp + ${src_dir}/query/frontend/ast/cypher_main_visitor.cpp ${src_dir}/query/backend/cpp/typed_value.cpp ${src_dir}/query/frontend/ast/ast.cpp ) diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index e2c201227..18ce5da4c 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -11,10 +11,49 @@ namespace query { class Frame; class SymbolTable; +// Forward declares for TreeVisitorBase +class Query; +class Ident; +class Match; +class Return; +class Pattern; +class NodePart; +class EdgePart; + +class TreeVisitorBase { + public: + // Start of the tree is a Query. + virtual void PreVisit(Query& query) {} + virtual void Visit(Query& query) = 0; + virtual void PostVisit(Query& query) {} + // Expressions + virtual void PreVisit(Ident& ident) {} + virtual void Visit(Ident& ident) = 0; + virtual void PostVisit(Ident& ident) {} + // Clauses + virtual void PreVisit(Match& match) {} + virtual void Visit(Match& match) = 0; + virtual void PostVisit(Match& match) {} + virtual void PreVisit(Return& ret) {} + virtual void Visit(Return& ret) = 0; + virtual void PostVisit(Return& ret) {} + // Pattern and its subparts. + virtual void PreVisit(Pattern& pattern) {} + virtual void Visit(Pattern& pattern) = 0; + virtual void PostVisit(Pattern& pattern) {} + virtual void PreVisit(NodePart& node_part) {} + virtual void Visit(NodePart& node_part) = 0; + virtual void PostVisit(NodePart& node_part) {} + virtual void PreVisit(EdgePart& edge_part) {} + virtual void Visit(EdgePart& edge_part) = 0; + virtual void PostVisit(EdgePart& edge_part) {} +}; + class Tree { public: Tree(const int uid) : uid_(uid) {} int uid() const { return uid_; } + virtual void Accept(TreeVisitorBase& visitor) = 0; private: const int uid_; @@ -29,9 +68,14 @@ class Ident : public Expr { public: std::string identifier_; TypedValue Evaluate(Frame &frame, SymbolTable &symbol_table) override; + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; -class Part {}; +class Part : public Tree {}; class NodePart : public Part { public: @@ -39,12 +83,24 @@ public: // TODO: Mislav call GraphDb::label(label_name) to populate labels_! std::vector<GraphDb::Label> labels_; // TODO: properties + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + identifier_.Accept(visitor); + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; class EdgePart : public Part { public: Ident identifier_; // TODO: finish this: properties, types... + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + identifier_.Accept(visitor); + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; class Clause : public Tree {}; @@ -52,20 +108,52 @@ class Clause : public Tree {}; class Pattern : public Tree { public: std::vector<std::unique_ptr<Part>> node_parts_; + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + for (auto& node_part : node_parts_) { + node_part->Accept(visitor); + } + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; class Query : public Tree { public: std::vector<std::unique_ptr<Clause>> clauses_; + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + for (auto& clause : clauses_) { + clause->Accept(visitor); + } + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; class Match : public Clause { public: std::vector<std::unique_ptr<Pattern>> patterns_; + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + for (auto& pattern : patterns_) { + pattern->Accept(visitor); + } + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; class Return : public Clause { public: std::vector<std::unique_ptr<Expr>> exprs_; + void Accept(TreeVisitorBase& visitor) override { + visitor.PreVisit(*this); + for (auto& expr : exprs_) { + expr->Accept(visitor); + } + visitor.Visit(*this); + visitor.PostVisit(*this); + } }; } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 368785014..f60b152a7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1,4 +1,4 @@ -#include "query/backend/cpp/cypher_main_visitor.hpp" +#include "query/frontend/ast/cypher_main_visitor.hpp" #include <climits> #include <string> diff --git a/src/query/frontend/typecheck/typecheck.hpp b/src/query/frontend/typecheck/typecheck.hpp new file mode 100644 index 000000000..c10558124 --- /dev/null +++ b/src/query/frontend/typecheck/typecheck.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include "utils/exceptions/basic_exception.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/typecheck/symbol_table.hpp" + +namespace query { + +class TypeCheckVisitor : public TreeVisitorBase { + public: + TypeCheckVisitor(SymbolTable& symbol_table) : symbol_table_(symbol_table_) {} + + // Start of the tree is a Query. + void Visit(Query& query) override {} + // Expressions + void Visit(Ident& ident) override { + Symbol symbol; + } if (scope_.in_pattern) { + symbol = GetOrCreateSymbol(ident.identifier_); + } else { + if (!HasSymbol(ident.identifier_)) + // TODO: Special exception for type check + throw BasicException("Unbound identifier: " + ident.identifier_); + symbol = scope_.variables[ident.identifier_]; + } + symbol_table_[ident] = symbol; + } + // Clauses + void Visit(Match& match) override {} + void PreVisit(Return& ret) override { + scope_.in_return = true; + } + void PostVisit(Return& ret) override { + scope_.in_return = false; + } + void Visit(Return& ret) override {} + // Pattern and its subparts. + void PreVisit(Pattern& pattern) override { + scope_.in_pattern = true; + } + void PostVisit(Pattern& pattern) override { + scope_.in_pattern = false; + } + void Visit(Pattern& pattern) override {} + void Visit(NodePart& node_part) override {} + void Visit(EdgePart& edge_part) override {} + + private: + struct Scope { + Scope() : in_pattern(false), in_return(false) {} + bool in_pattern; + bool in_return; + std::map<std::string, Symbol> variables; + }; + + bool HasSymbol(const std::string& name) + { + return scope_.variables.find(name) != scope_.variables.end(); + } + + Symbol GetOrCreateSymbol(const std::string& name) + { + auto search = scope_.variables.find(name) + if (search != scope_.variables.end()) { + return *search; + } + scope_.variables[name] = symbol_table_.CreateSymbol(name); + } + SymbolTable& symbol_table_; + Scope scope_; +}; + +}