diff --git a/src/query/context.hpp b/src/query/context.hpp index eea60317c..92e4b609e 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -1,14 +1,52 @@ #pragma once #include "antlr4-runtime.h" +#include "database/graph_db_accessor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" +class TypedcheckedTree {}; + +class LogicalPlan {}; + +class Context; + +class LogicalPlanGenerator { + public: + std::vector<LogicalPlan> Generate(TypedcheckedTree&, Context&) { + return {LogicalPlan()}; + } +}; + +struct Config { + LogicalPlanGenerator logical_plan_generator; +}; + class Context { + public: int uid_counter; + Context(Config config, GraphDbAccessor& db_accessor) + : config(config), db_accessor(db_accessor) {} + + Config config; + GraphDbAccessor& db_accessor; +}; + +class LogicalPlanner { + public: + LogicalPlanner(Context ctx) : ctx_(ctx) {} + + LogicalPlan Apply(TypedcheckedTree typedchecked_tree) { + return ctx_.config.logical_plan_generator.Generate(typedchecked_tree, + ctx_)[0]; + } + + private: + Context ctx_; }; class HighLevelAstConversion { - void Apply(const Context &ctx, antlr4::tree::ParseTree *tree) { + public: + void Apply(const Context& ctx, antlr4::tree::ParseTree* tree) { query::frontend::CypherMainVisitor visitor(ctx); visitor.visit(tree); } diff --git a/src/query/entry.hpp b/src/query/entry.hpp index 6a5ae25e2..373a5d5fc 100644 --- a/src/query/entry.hpp +++ b/src/query/entry.hpp @@ -2,21 +2,44 @@ #include "database/graph_db_accessor.hpp" #include "query/frontend/opencypher/parser.hpp" +#include "query/context.hpp" namespace query { template <typename Stream> class Engine { public: - Engine() {} + Engine() { + } auto Execute(const std::string &query, GraphDbAccessor &db_accessor, Stream &stream) { - frontend::opencypher::Parser parser(query); + Config config; + Context ctx(config, db_accessor); + ::frontend::opencypher::Parser parser(query); auto low_level_tree = parser.tree(); - // high level tree - // typechecked tree - // logical tree + auto high_level_tree = low2high_tree.Apply(ctx, low_level_tree); + TypedcheckedTree typechecked_tree; + auto logical_plan = LogicalPlanner(ctx).Apply(typechecked_tree); + // interpret & stream results + // generate frame based on symbol table max_position + Frame frame(size); + auto cursor = logical_plan.MakeCursor(frame); + logical_plan.WriteHeader(stream); + auto symbols = logical_plan.OutputSymbols(symbol_table); + while (cursor.pull(frame, context)) { + std::vector<TypedValue> values(symbols.size()); + for (auto symbol : symbols) { + values.emplace_back(frame[symbol]); + } + stream.Result(values); + } + stream.Summary({"type": "r"}); + } +private: + Context ctx; + HighLevelAstConversion low2high_tree; + }; } diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 34abc30dc..a6a32fbeb 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -7,4 +7,7 @@ TypedValue Ident::Evaluate(Frame& frame, SymbolTable& symbol_table) { return frame[symbol_table[*this].position_]; } +void NamedExpr::Evaluate(Frame& frame, SymbolTable& symbol_table) { + frame[symbol_table[*ident_].position_] = expr_->Evaluate(frame, symbol_table); +} } diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 772e3234d..fb1ea4efd 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -50,24 +50,24 @@ class TreeVisitorBase { }; class Tree { -public: + public: Tree(const int uid) : uid_(uid) {} int uid() const { return uid_; } virtual void Accept(TreeVisitorBase& visitor) = 0; -private: + private: const int uid_; }; class Expr : public Tree { -public: - virtual TypedValue Evaluate(Frame &, SymbolTable &) = 0; + public: + virtual TypedValue Evaluate(Frame&, SymbolTable&) = 0; }; class Ident : public Expr { -public: + public: std::string identifier_; - TypedValue Evaluate(Frame &frame, SymbolTable &symbol_table) override; + TypedValue Evaluate(Frame& frame, SymbolTable& symbol_table) override; void Accept(TreeVisitorBase& visitor) override { visitor.PreVisit(*this); visitor.Visit(*this); @@ -77,8 +77,15 @@ public: class Part : public Tree {}; +class NamedExpr : public Tree { + public: + std::shared_ptr<Ident> ident_; + std::shared_ptr<Expr> expr_; + void Evaluate(Frame& frame, SymbolTable& symbol_table); +}; + class NodePart : public Part { -public: + public: Ident identifier_; // TODO: Mislav call GraphDb::label(label_name) to populate labels_! std::vector<GraphDb::Label> labels_; @@ -92,7 +99,7 @@ public: }; class EdgePart : public Part { -public: + public: Ident identifier_; // TODO: finish this: properties, types... void Accept(TreeVisitorBase& visitor) override { @@ -106,7 +113,7 @@ public: class Clause : public Tree {}; class Pattern : public Tree { -public: + public: std::vector<std::shared_ptr<NodePart>> node_parts_; void Accept(TreeVisitorBase& visitor) override { visitor.PreVisit(*this); @@ -119,7 +126,7 @@ public: }; class Query : public Tree { -public: + public: std::vector<std::unique_ptr<Clause>> clauses_; void Accept(TreeVisitorBase& visitor) override { visitor.PreVisit(*this); @@ -132,7 +139,7 @@ public: }; class Match : public Clause { -public: + public: std::vector<std::unique_ptr<Pattern>> patterns_; void Accept(TreeVisitorBase& visitor) override { visitor.PreVisit(*this); @@ -145,7 +152,7 @@ public: }; class Return : public Clause { -public: + public: std::vector<std::shared_ptr<Expr>> exprs_; void Accept(TreeVisitorBase& visitor) override { visitor.PreVisit(*this); diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp index b82b735ea..829666443 100644 --- a/src/query/frontend/logical/operator.hpp +++ b/src/query/frontend/logical/operator.hpp @@ -9,6 +9,24 @@ #include "query/frontend/typecheck/symbol_table.hpp" namespace query { + +class ConsoleResultStream : public Loggable { + public: + ConsoleResultStream() : Loggable("ConsoleResultStream") {} + + void Header(const std::vector<std::string>&) { logger.info("header"); } + + void Result(std::vector<TypedValue>& values) { + for (auto value : values) { + logger.info(" result"); + } + } + + void Summary(const std::map<std::string, TypedValue>&) { + logger.info("summary"); + } +}; + class Cursor { public: virtual bool pull(Frame&, SymbolTable&) = 0; @@ -19,6 +37,10 @@ class LogicalOperator { public: auto children() { return children_; }; virtual std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor db) = 0; + virtual void WriteHeader(ConsoleResultStream&) {} + virtual std::vector<Symbol> OutputSymbols(SymbolTable& symbol_table) { + return {}; + } virtual ~LogicalOperator() {} protected: @@ -32,8 +54,10 @@ class ScanAll : public LogicalOperator { private: class ScanAllCursor : public Cursor { public: - ScanAllCursor(ScanAll& parent, GraphDbAccessor db) - : parent_(parent), db_(db), vertices_(db.vertices()), + ScanAllCursor(ScanAll& self, GraphDbAccessor db) + : self_(self), + db_(db), + vertices_(db.vertices()), vertices_it_(vertices_.begin()) {} bool pull(Frame& frame, SymbolTable& symbol_table) override { @@ -47,14 +71,14 @@ class ScanAll : public LogicalOperator { } private: - ScanAll& parent_; + ScanAll& self_; GraphDbAccessor db_; decltype(db_.vertices()) vertices_; decltype(vertices_.begin()) vertices_it_; bool evaluate(Frame& frame, SymbolTable& symbol_table, VertexAccessor& vertex) { - auto node_part = parent_.node_part_; + auto node_part = self_.node_part_; for (auto label : node_part->labels_) { if (!vertex.has_label(label)) return false; } @@ -65,8 +89,7 @@ class ScanAll : public LogicalOperator { public: std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor db) override { - Cursor* cursor = new ScanAllCursor(*this, db); - return std::unique_ptr<Cursor>(cursor); + return std::make_unique<Cursor>(ScanAllCursor(*this, db)); } private: @@ -76,29 +99,56 @@ class ScanAll : public LogicalOperator { class Produce : public LogicalOperator { public: - Produce(std::shared_ptr<LogicalOperator> op, std::vector<std::shared_ptr<Expr>> exprs) - : exprs_(exprs) { - children_.emplace_back(op); + Produce(std::shared_ptr<LogicalOperator> input, + std::vector<std::shared_ptr<NamedExpr>> exprs) + : input_(input), exprs_(exprs) { + children_.emplace_back(input); } + void WriteHeader(ConsoleResultStream& stream) override { + // TODO: write real result + stream.Header({"n"}); + } + + std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor db) override { + return std::make_unique<Cursor>(ProduceCursor(*this, db)); + } + + std::vector<Symbol> OutputSymbols(SymbolTable& symbol_table) override { + std::vector<Symbol> result(exprs_.size()); + for (auto named_expr : exprs_) { + result.emplace_back(symbol_table[*named_expr->ident_]); + } + return result; +} + private: class ProduceCursor : public Cursor { public: - ProduceCursor(Produce& parent) : parent_(parent) {} - bool pull(Frame &frame, SymbolTable& symbol_table) override { - for (auto expr : parent_.exprs_) { - frame[symbol_table[*expr].position_] = expr->Evaluate(frame, symbol_table); + ProduceCursor(Produce& self, GraphDbAccessor db) + : self_(self), self_cursor_(self_.MakeCursor(db)) {} + bool pull(Frame& frame, SymbolTable& symbol_table) override { + if (self_cursor_->pull(frame, symbol_table)) { + for (auto expr : self_.exprs_) { + expr->Evaluate(frame, symbol_table); + } + return true; } - return true; + return false; } + private: - Produce& parent_; + Produce& self_; + std::unique_ptr<Cursor> self_cursor_; }; + public: std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor) override { return std::unique_ptr<Cursor>(new ProduceCursor(*this)); } + private: - std::vector<std::shared_ptr<Expr>> exprs_; + std::shared_ptr<LogicalOperator> input_; + std::vector<std::shared_ptr<NamedExpr>> exprs_; }; } diff --git a/tests/manual/compiler_prototype.cpp b/tests/manual/compiler_prototype.cpp index 7907978af..665acc041 100644 --- a/tests/manual/compiler_prototype.cpp +++ b/tests/manual/compiler_prototype.cpp @@ -10,23 +10,6 @@ using std::cout; using std::cin; using std::endl; -class ConsoleResultStream : public Loggable { - public: - ConsoleResultStream() : Loggable("ConsoleResultStream") {} - - void Header(const std::vector<std::string>&) { logger.info("header"); } - - void Result(std::vector<TypedValue>& values) { - for (auto value : values) { - logger.info(" result"); - } - } - - void Summary(const std::map<std::string, TypedValue>&) { - logger.info("summary"); - } -}; - int main(int argc, char* argv[]) { // init arguments REGISTER_ARGS(argc, argv);