From 0db7883670a514d6c5b04805dad78b3d8ac1b8fe Mon Sep 17 00:00:00 2001 From: Mislav Bradac Date: Thu, 16 Mar 2017 15:00:34 +0100 Subject: [PATCH] Add AstTreeStorage Reviewers: teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D134 --- src/query/context.cpp | 4 +- src/query/context.hpp | 5 +- src/query/frontend/ast/ast.hpp | 125 +++++++++++++----- .../frontend/ast/cypher_main_visitor.cpp | 75 +++++------ .../frontend/ast/cypher_main_visitor.hpp | 5 +- src/query/frontend/logical/operator.hpp | 16 +-- src/query/frontend/logical/planner.cpp | 13 +- src/query/frontend/logical/planner.hpp | 5 +- tests/unit/interpreter.cpp | 84 +++++------- tests/unit/query_semantic.cpp | 79 +++++------ 10 files changed, 226 insertions(+), 185 deletions(-) diff --git a/src/query/context.cpp b/src/query/context.cpp index 10accd9b9..132777b3e 100644 --- a/src/query/context.cpp +++ b/src/query/context.cpp @@ -3,8 +3,8 @@ namespace query { -std::shared_ptr HighLevelAstConversion::Apply(Context &ctx, - antlr4::tree::ParseTree *tree) { +Query *HighLevelAstConversion::Apply(Context &ctx, + antlr4::tree::ParseTree *tree) { query::frontend::CypherMainVisitor visitor(ctx); visitor.visit(tree); return visitor.query(); diff --git a/src/query/context.hpp b/src/query/context.hpp index e065c27fd..fed918ae0 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -27,11 +27,8 @@ class Context { public: Context(Config config, GraphDbAccessor &db_accessor) : config_(config), db_accessor_(db_accessor) {} - int next_uid() { return uid_counter_++; } - Config config_; GraphDbAccessor &db_accessor_; - int uid_counter_ = 0; }; class LogicalPlanner { @@ -49,6 +46,6 @@ private: class HighLevelAstConversion { public: - std::shared_ptr Apply(Context &ctx, antlr4::tree::ParseTree *tree); + Query *Apply(Context &ctx, antlr4::tree::ParseTree *tree); }; } diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 2edbec2ac..dc177a9ac 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -9,43 +9,45 @@ namespace query { +class AstTreeStorage; + class Tree : public ::utils::Visitable { -public: - Tree(int uid) : uid_(uid) {} + friend class AstTreeStorage; + public: int uid() const { return uid_; } + protected: + Tree(int uid) : uid_(uid) {} + private: const int uid_; }; class Expression : public Tree { - public: + protected: Expression(int uid) : Tree(uid) {} }; class Identifier : public Expression { + friend class AstTreeStorage; public: - Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} - - DEFVISITABLE(TreeVisitorBase) - + DEFVISITABLE(TreeVisitorBase); std::string name_; + + protected: + Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} }; class PropertyLookup : public Expression { + friend class AstTreeStorage; public: - PropertyLookup(int uid, std::shared_ptr expression, - GraphDb::Property property) - : Expression(uid), expression_(expression), property_(property) {} - void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); expression_->Accept(visitor); visitor.PostVisit(*this); } - std::shared_ptr - expression_; // vertex or edge, what if map literal??? + Expression *expression_; GraphDb::Property property_; // TODO potential problem: property lookups are allowed on both map literals // and records, but map literals have strings as keys and records have @@ -53,11 +55,16 @@ class PropertyLookup : public Expression { // // possible solution: store both string and GraphDb::Property here and choose // between the two depending on Expression result + + protected: + PropertyLookup(int uid, Expression* expression, + GraphDb::Property property) + : Expression(uid), expression_(expression), property_(property) {} }; class NamedExpression : public Tree { + friend class AstTreeStorage; public: - NamedExpression(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); expression_->Accept(visitor); @@ -65,33 +72,45 @@ class NamedExpression : public Tree { } std::string name_; - std::shared_ptr expression_; + Expression* expression_; + + protected: + NamedExpression(int uid) : Tree(uid) {} + NamedExpression(int uid, std::string name, Expression *expression) : + Tree(uid), name_(name), expression_(expression) {} }; class PatternAtom : public Tree { - public: + friend class AstTreeStorage; + protected: PatternAtom(int uid) : Tree(uid) {} }; class NodeAtom : public PatternAtom { + friend class AstTreeStorage; public: - NodeAtom(int uid) : PatternAtom(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); identifier_->Accept(visitor); visitor.PostVisit(*this); } - std::shared_ptr identifier_; + Identifier* identifier_; std::vector labels_; - std::map> properties_; + std::map properties_; + + protected: + NodeAtom(int uid) : PatternAtom(uid) {} + NodeAtom(int uid, Identifier *identifier) : + PatternAtom(uid), identifier_(identifier) {} + }; class EdgeAtom : public PatternAtom { + friend class AstTreeStorage; public: enum class Direction { LEFT, RIGHT, BOTH }; - EdgeAtom(int uid) : PatternAtom(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); identifier_->Accept(visitor); @@ -99,17 +118,21 @@ class EdgeAtom : public PatternAtom { } Direction direction = Direction::BOTH; - std::shared_ptr identifier_; + Identifier* identifier_; + + protected: + EdgeAtom(int uid) : PatternAtom(uid) {} }; class Clause : public Tree { + friend class AstTreeStorage; public: Clause(int uid) : Tree(uid) {} }; class Pattern : public Tree { + friend class AstTreeStorage; public: - Pattern(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); for (auto &part : atoms_) { @@ -117,13 +140,16 @@ class Pattern : public Tree { } visitor.PostVisit(*this); } - std::shared_ptr identifier_; - std::vector> atoms_; + Identifier* identifier_; + std::vector atoms_; + + protected: + Pattern(int uid) : Tree(uid) {} }; class Query : public Tree { + friend class AstTreeStorage; public: - Query(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); for (auto &clause : clauses_) { @@ -131,7 +157,10 @@ class Query : public Tree { } visitor.PostVisit(*this); } - std::vector> clauses_; + std::vector clauses_; + + protected: + Query(int uid) : Tree(uid) {} }; class Create : public Clause { @@ -148,9 +177,9 @@ class Create : public Clause { }; class Match : public Clause { + friend class AstTreeStorage; public: - Match(int uid) : Clause(uid) {} - std::vector> patterns_; + std::vector patterns_; void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); for (auto &pattern : patterns_) { @@ -158,11 +187,14 @@ class Match : public Clause { } visitor.PostVisit(*this); } + + protected: + Match(int uid) : Clause(uid) {} }; class Return : public Clause { + friend class AstTreeStorage; public: - Return(int uid) : Clause(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); for (auto &expr : named_expressions_) { @@ -170,6 +202,39 @@ class Return : public Clause { } visitor.PostVisit(*this); } - std::vector> named_expressions_; + std::vector named_expressions_; + + protected: + Return(int uid) : Clause(uid) {} +}; + +// It would be better to call this AstTree, but we already have a class Tree, +// which could be renamed to Node or AstTreeNode, but we also have a class +// called NodeAtom... +class AstTreeStorage { + friend class AstTreeStorage; + + public: + AstTreeStorage() { + storage_.emplace_back(new Query(next_uid_++)); + } + AstTreeStorage(const AstTreeStorage &) = delete; + AstTreeStorage &operator=(const AstTreeStorage &) = delete; + + template + T *Create(Args&&... args) { + // Never call create for a Query. Call query() instead. + static_assert(!std::is_same::value, "Call query() instead"); + // TODO: use std::forward here + T *p = new T(next_uid_++, args...); + storage_.emplace_back(p); + return p; + } + + Query *query() { return dynamic_cast(storage_[0].get()); } + + private: + int next_uid_ = 0; + std::vector> storage_; }; } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 7c9be517f..b006c48fd 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -41,10 +41,9 @@ namespace { antlrcpp::Any CypherMainVisitor::visitSingleQuery(CypherParser::SingleQueryContext *ctx) { - query_ = std::make_shared(ctx_.next_uid()); + query_ = storage_.query(); for (auto *child : ctx->clause()) { - query_->clauses_.push_back( - child->accept(this).as>()); + query_->clauses_.push_back(child->accept(this)); } return query_; } @@ -52,18 +51,17 @@ antlrcpp::Any CypherMainVisitor::visitClause(CypherParser::ClauseContext *ctx) { if (!ctx->cypherReturn() && !ctx->cypherMatch()) { throw std::exception(); } - return 0; + return visitChildren(ctx); } antlrcpp::Any CypherMainVisitor::visitCypherMatch(CypherParser::CypherMatchContext *ctx) { - auto match = std::make_shared(ctx_.next_uid()); + auto *match = storage_.Create(); if (ctx->OPTIONAL() || ctx->where()) { throw std::exception(); } - match->patterns_ = - ctx->pattern()->accept(this).as>>(); - return std::shared_ptr(match); + match->patterns_ = ctx->pattern()->accept(this).as>(); + return match; } antlrcpp::Any @@ -84,22 +82,22 @@ CypherMainVisitor::visitReturnBody(CypherParser::ReturnBodyContext *ctx) { antlrcpp::Any CypherMainVisitor::visitReturnItems(CypherParser::ReturnItemsContext *ctx) { - auto return_clause = std::make_shared(ctx_.next_uid()); + auto *return_clause = storage_.Create(); if (ctx->getTokens(kReturnAllTokenId).size()) { throw std::exception(); } for (auto *item : ctx->returnItem()) { return_clause->named_expressions_.push_back(item->accept(this)); } - return std::shared_ptr(return_clause); + return return_clause; } antlrcpp::Any CypherMainVisitor::visitReturnItem(CypherParser::ReturnItemContext *ctx) { - auto named_expr = std::make_shared(ctx_.next_uid()); + auto *named_expr = storage_.Create(); if (ctx->variable()) { named_expr->name_ = - std::string(ctx_.next_uid(), ctx->variable()->accept(this)); + std::string(ctx->variable()->accept(this).as()); } else { // TODO: Should we get this by text or some escaping is needed? named_expr->name_ = std::string(ctx->getText()); @@ -110,16 +108,16 @@ CypherMainVisitor::visitReturnItem(CypherParser::ReturnItemContext *ctx) { antlrcpp::Any CypherMainVisitor::visitNodePattern(CypherParser::NodePatternContext *ctx) { - auto node = std::make_shared(ctx_.next_uid()); + auto *node = storage_.Create(); if (ctx->variable()) { // TODO: user's identifiers should be unchanged, but we must be sure that // ours identifier is not in a clash with user's. std::string variable = ctx->variable()->accept(this); - node->identifier_ = std::make_shared( - ctx_.next_uid(), kUserIdentPrefix + variable); + node->identifier_ = + storage_.Create(kUserIdentPrefix + variable); } else { - node->identifier_ = std::make_shared( - ctx_.next_uid(), kAnonIdentPrefix + std::to_string(next_ident_id_++)); + node->identifier_ = storage_.Create( + kAnonIdentPrefix + std::to_string(next_ident_id_++)); } if (ctx->nodeLabels()) { std::vector labels = ctx->nodeLabels()->accept(this); @@ -135,7 +133,7 @@ CypherMainVisitor::visitNodePattern(CypherParser::NodePatternContext *ctx) { // .as>(); } - return std::shared_ptr(node); + return node; } antlrcpp::Any @@ -184,7 +182,7 @@ CypherMainVisitor::visitSymbolicName(CypherParser::SymbolicNameContext *ctx) { antlrcpp::Any CypherMainVisitor::visitPattern(CypherParser::PatternContext *ctx) { - std::vector> patterns; + std::vector patterns; for (auto *pattern_part : ctx->patternPart()) { patterns.push_back(pattern_part->accept(this)); } @@ -193,15 +191,15 @@ CypherMainVisitor::visitPattern(CypherParser::PatternContext *ctx) { antlrcpp::Any CypherMainVisitor::visitPatternPart(CypherParser::PatternPartContext *ctx) { - std::shared_ptr pattern = ctx->anonymousPatternPart()->accept(this); + Pattern *pattern = ctx->anonymousPatternPart()->accept(this); if (ctx->variable()) { // TODO: don't change user's identifier name. std::string variable = ctx->variable()->accept(this); - pattern->identifier_ = std::make_shared( - ctx_.next_uid(), kUserIdentPrefix + variable); + pattern->identifier_ = + storage_.Create(kUserIdentPrefix + variable); } else { - pattern->identifier_ = std::make_shared( - ctx_.next_uid(), kAnonIdentPrefix + std::to_string(next_ident_id_++)); + pattern->identifier_ = storage_.Create( + kAnonIdentPrefix + std::to_string(next_ident_id_++)); } return pattern; } @@ -211,11 +209,11 @@ antlrcpp::Any CypherMainVisitor::visitPatternElement( if (ctx->patternElement()) { return ctx->patternElement()->accept(this); } - auto pattern = std::make_shared(ctx_.next_uid()); + auto pattern = storage_.Create(); pattern->atoms_.push_back(ctx->nodePattern()->accept(this)); for (auto *pattern_element_chain : ctx->patternElementChain()) { - std::pair, std::shared_ptr> - element = pattern_element_chain->accept(this); + std::pair element = + pattern_element_chain->accept(this); pattern->atoms_.push_back(element.first); pattern->atoms_.push_back(element.second); } @@ -224,23 +222,21 @@ antlrcpp::Any CypherMainVisitor::visitPatternElement( antlrcpp::Any CypherMainVisitor::visitPatternElementChain( CypherParser::PatternElementChainContext *ctx) { - return std::pair, std::shared_ptr>( - ctx->relationshipPattern() - ->accept(this) - .as>(), - ctx->nodePattern()->accept(this).as>()); + return std::pair( + ctx->relationshipPattern()->accept(this), + ctx->nodePattern()->accept(this)); } antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( CypherParser::RelationshipPatternContext *ctx) { - auto edge = std::make_shared(ctx_.next_uid()); + auto *edge = storage_.Create(); if (ctx->relationshipDetail()) { if (ctx->relationshipDetail()->variable()) { std::string variable = ctx->relationshipDetail()->variable()->accept(this); // TODO: Don't change user's identifier name. - edge->identifier_ = std::make_shared( - ctx_.next_uid(), kUserIdentPrefix + variable); + edge->identifier_ = + storage_.Create(kUserIdentPrefix + variable); } if (ctx->relationshipDetail()->relationshipTypes()) { throw std::exception(); @@ -260,8 +256,8 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( // relationship.lower_bound = range.first; // relationship.upper_bound = range.second; if (!edge->identifier_) { - edge->identifier_ = std::make_shared( - ctx_.next_uid(), kAnonIdentPrefix + std::to_string(next_ident_id_++)); + edge->identifier_ = storage_.Create( + kAnonIdentPrefix + std::to_string(next_ident_id_++)); } if (ctx->leftArrowHead() && !ctx->rightArrowHead()) { @@ -273,7 +269,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( // grammar. edge->direction = EdgeAtom::Direction::BOTH; } - return std::shared_ptr(edge); + return edge; } antlrcpp::Any CypherMainVisitor::visitRelationshipDetail( @@ -550,8 +546,7 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) { return ctx->parenthesizedExpression()->accept(this); } else if (ctx->variable()) { std::string variable = ctx->variable()->accept(this); - return std::shared_ptr(std::make_shared( - ctx_.next_uid(), kUserIdentPrefix + variable)); + return storage_.Create(kUserIdentPrefix + variable); } // TODO: Implement this. We don't support comprehensions, functions, // filtering... at the moment. diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index f73b8d8a4..3e95f8bc1 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -289,14 +289,15 @@ private: visitIntegerLiteral(CypherParser::IntegerLiteralContext *ctx) override; public: - std::shared_ptr query() { return query_; } + Query *query() { return query_; } private: Context &ctx_; int next_ident_id_; const std::string kUserIdentPrefix = "u_"; const std::string kAnonIdentPrefix = "a_"; - std::shared_ptr query_; + AstTreeStorage storage_; + Query *query_; }; } } diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp index 6064a59a2..d2679609d 100644 --- a/src/query/frontend/logical/operator.hpp +++ b/src/query/frontend/logical/operator.hpp @@ -29,7 +29,7 @@ class LogicalOperator { class CreateOp : public LogicalOperator { public: - CreateOp(std::shared_ptr node_atom) : node_atom_(node_atom) {} + CreateOp(NodeAtom* node_atom) : node_atom_(node_atom) {} private: class CreateOpCursor : public Cursor { @@ -66,12 +66,12 @@ public: } private: - std::shared_ptr node_atom_; + NodeAtom* node_atom_; }; class ScanAll : public LogicalOperator { public: - ScanAll(std::shared_ptr node_atom) : node_atom_(node_atom) {} + ScanAll(NodeAtom *node_atom) : node_atom_(node_atom) {} private: class ScanAllCursor : public Cursor { @@ -99,7 +99,7 @@ class ScanAll : public LogicalOperator { } private: - std::shared_ptr node_atom_; + NodeAtom *node_atom_; }; class NodeFilter : public LogicalOperator { @@ -107,7 +107,7 @@ class NodeFilter : public LogicalOperator { NodeFilter( std::shared_ptr input, Symbol input_symbol, std::vector labels, - std::map> properties) + std::map properties) : input_(input), input_symbol_(input_symbol), labels_(labels), @@ -158,13 +158,13 @@ class NodeFilter : public LogicalOperator { std::shared_ptr input_; const Symbol input_symbol_; std::vector labels_; - std::map> properties_; + std::map properties_; }; class Produce : public LogicalOperator { public: Produce(std::shared_ptr input, - std::vector> named_expressions) + std::vector named_expressions) : input_(input), named_expressions_(named_expressions) { children_.emplace_back(input); } @@ -198,6 +198,6 @@ class Produce : public LogicalOperator { private: std::shared_ptr input_; - std::vector> named_expressions_; + std::vector named_expressions_; }; } diff --git a/src/query/frontend/logical/planner.cpp b/src/query/frontend/logical/planner.cpp index 34f115acf..f0d252934 100644 --- a/src/query/frontend/logical/planner.cpp +++ b/src/query/frontend/logical/planner.cpp @@ -5,6 +5,8 @@ namespace query { +namespace { + static LogicalOperator *GenCreate( Create& create, std::shared_ptr input_op) { @@ -18,11 +20,11 @@ static LogicalOperator *GenCreate( if (pattern->atoms_.size() != 1) { throw NotYetImplemented(); } - auto node_atom = std::dynamic_pointer_cast(pattern->atoms_[0]); + auto *node_atom = dynamic_cast(pattern->atoms_[0]); return new CreateOp(node_atom); } -static LogicalOperator *GenMatch( +LogicalOperator *GenMatch( Match& match, std::shared_ptr input_op, const SymbolTable &symbol_table) @@ -37,7 +39,7 @@ static LogicalOperator *GenMatch( if (pattern->atoms_.size() != 1) { throw NotYetImplemented(); } - auto node_atom = std::dynamic_pointer_cast(pattern->atoms_[0]); + auto *node_atom = dynamic_cast(pattern->atoms_[0]); auto *scan_all = new ScanAll(node_atom); if (!node_atom->labels_.empty() || !node_atom->properties_.empty()) { auto &input_symbol = symbol_table.at(*node_atom->identifier_); @@ -47,20 +49,21 @@ static LogicalOperator *GenMatch( return scan_all; } -static Produce *GenReturn(Return& ret, std::shared_ptr input_op) +Produce *GenReturn(Return& ret, std::shared_ptr input_op) { if (!input_op) { throw NotYetImplemented(); } return new Produce(input_op, ret.named_expressions_); } +} std::unique_ptr MakeLogicalPlan( Query& query, const SymbolTable &symbol_table) { LogicalOperator *input_op = nullptr; for (auto &clause : query.clauses_) { - auto *clause_ptr = clause.get(); + auto *clause_ptr = clause; if (auto *match = dynamic_cast(clause_ptr)) { input_op = GenMatch(*match, std::shared_ptr(input_op), symbol_table); diff --git a/src/query/frontend/logical/planner.hpp b/src/query/frontend/logical/planner.hpp index d9d03cbf5..59c0af551 100644 --- a/src/query/frontend/logical/planner.hpp +++ b/src/query/frontend/logical/planner.hpp @@ -12,7 +12,6 @@ class SymbolTable; // Returns the root of LogicalOperator tree. The tree is constructed by // traversing the given AST Query node. SymbolTable is used to determine inputs // and outputs of certain operators. -std::unique_ptr MakeLogicalPlan( - Query& query, const SymbolTable &symbol_table); - +std::unique_ptr +MakeLogicalPlan(Query &query, const SymbolTable &symbol_table); } diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index a3ee818a6..acd1c3879 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -46,7 +46,8 @@ auto CollectProduce(std::shared_ptr produce, SymbolTable &symbol_table, auto cursor = produce->MakeCursor(db_accessor); while (cursor->Pull(frame, symbol_table)) { std::vector values; - for (auto &symbol : symbols) values.emplace_back(frame[symbol]); + for (auto &symbol : symbols) + values.emplace_back(frame[symbol]); stream.Result(values); } @@ -59,34 +60,12 @@ void ExecuteCreate(std::shared_ptr create, GraphDbAccessor &db) { SymbolTable symbol_table; Frame frame(symbol_table.max_position()); auto cursor = create->MakeCursor(db); - while (cursor->Pull(frame, symbol_table)) - ; + while (cursor->Pull(frame, symbol_table)) { + continue; + } } -/* - * Following are helper functions that create high level AST - * and logical operator objects. - */ - -auto MakeNamedExpression(Context &ctx, const std::string name, - std::shared_ptr expression) { - auto named_expression = std::make_shared(ctx.next_uid()); - named_expression->name_ = name; - named_expression->expression_ = expression; - return named_expression; -} - -auto MakeIdentifier(Context &ctx, const std::string name) { - return std::make_shared(ctx.next_uid(), name); -} - -auto MakeNode(Context &ctx, std::shared_ptr identifier) { - auto node = std::make_shared(ctx.next_uid()); - node->identifier_ = identifier; - return node; -} - -auto MakeScanAll(std::shared_ptr node_atom) { +auto MakeScanAll(NodeAtom *node_atom) { return std::make_shared(node_atom); } @@ -94,8 +73,7 @@ template auto MakeProduce(std::shared_ptr input, TNamedExpressions... named_expressions) { return std::make_shared( - input, - std::vector>{named_expressions...}); + input, std::vector{named_expressions...}); } /* @@ -110,15 +88,15 @@ TEST(Interpreter, MatchReturn) { dba->insert_vertex(); dba->insert_vertex(); - Config config; - Context ctx(config, *dba); + AstTreeStorage storage; // make a scan all - auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto node = storage.Create(storage.Create("n")); auto scan_all = MakeScanAll(node); // make a named expression and a produce - auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto output = + storage.Create("n", storage.Create("n")); auto produce = MakeProduce(scan_all, output); // fill up the symbol table @@ -156,11 +134,10 @@ TEST(Interpreter, NodeFilterLabelsAndProperties) { v4.PropsSet(property, 42); v5.PropsSet(property, 1); - Config config; - Context ctx(config, *dba); + AstTreeStorage storage; // make a scan all - auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto node = storage.Create(storage.Create("n")); auto scan_all = MakeScanAll(node); // node filtering @@ -169,10 +146,11 @@ TEST(Interpreter, NodeFilterLabelsAndProperties) { // TODO implement the test once int-literal expressions are available auto node_filter = std::make_shared( scan_all, n_symbol, std::vector{label}, - std::map>()); + std::map{}); // make a named expression and a produce - auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto output = + storage.Create("x", storage.Create("n")); auto produce = MakeProduce(node_filter, output); // fill up the symbol table @@ -193,26 +171,25 @@ TEST(Interpreter, NodeFilterMultipleLabels) { GraphDb::Label label2 = dba->label("label2"); GraphDb::Label label3 = dba->label("label3"); // the test will look for nodes that have label1 and label2 - dba->insert_vertex(); // NOT accepted - dba->insert_vertex().add_label(label1); // NOT accepted - dba->insert_vertex().add_label(label2); // NOT accepted - dba->insert_vertex().add_label(label3); // NOT accepted - auto v1 = dba->insert_vertex(); // YES accepted + dba->insert_vertex(); // NOT accepted + dba->insert_vertex().add_label(label1); // NOT accepted + dba->insert_vertex().add_label(label2); // NOT accepted + dba->insert_vertex().add_label(label3); // NOT accepted + auto v1 = dba->insert_vertex(); // YES accepted v1.add_label(label1); v1.add_label(label2); - auto v2 = dba->insert_vertex(); // NOT accepted + auto v2 = dba->insert_vertex(); // NOT accepted v2.add_label(label1); v2.add_label(label3); - auto v3 = dba->insert_vertex(); // YES accepted + auto v3 = dba->insert_vertex(); // YES accepted v3.add_label(label1); v3.add_label(label2); v3.add_label(label3); - Config config; - Context ctx(config, *dba); + AstTreeStorage storage; // make a scan all - auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto node = storage.Create(storage.Create("n")); auto scan_all = MakeScanAll(node); // node filtering @@ -221,10 +198,11 @@ TEST(Interpreter, NodeFilterMultipleLabels) { // TODO implement the test once int-literal expressions are available auto node_filter = std::make_shared( scan_all, n_symbol, std::vector{label1, label2}, - std::map>()); + std::map()); // make a named expression and a produce - auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto output = + storage.Create("n", storage.Create("n")); auto produce = MakeProduce(node_filter, output); // fill up the symbol table @@ -239,13 +217,13 @@ TEST(Interpreter, NodeFilterMultipleLabels) { TEST(Interpreter, CreateNodeWithAttributes) { Dbms dbms; auto dba = dbms.active(); - Config config; - Context ctx(config, *dba); GraphDb::Label label = dba->label("Person"); GraphDb::Property property = dba->label("age"); - auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + AstTreeStorage storage; + + auto node = storage.Create(storage.Create("n")); node->labels_.emplace_back(label); // TODO make a property here with an int literal expression // node->properties_[property] = TypedValue(42); diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index c19c1f032..81cb1e318 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -9,23 +9,24 @@ using namespace query; +namespace { + // Build a simple AST which describes: // MATCH (node_atom_1) RETURN node_atom_1 AS node_atom_1 -static std::unique_ptr MatchNodeReturn() { - int uid = 0; - auto node_atom = std::make_shared(uid++); - node_atom->identifier_ = std::make_shared(uid++, "node_atom_1"); - auto pattern = std::make_shared(uid++); +Query *MatchNodeReturn(AstTreeStorage &storage) { + auto node_atom = storage.Create(); + node_atom->identifier_ = storage.Create("node_atom_1"); + auto pattern = storage.Create(); pattern->atoms_.emplace_back(node_atom); - auto match = std::make_shared(uid++); + auto match = storage.Create(); match->patterns_.emplace_back(pattern); - auto query = std::make_unique(uid++); + auto query = storage.query(); query->clauses_.emplace_back(match); - auto named_expr = std::make_shared(uid++); + auto named_expr = storage.Create(); named_expr->name_ = "node_atom_1"; - named_expr->expression_ = std::make_shared(uid++, "node_atom_1"); - auto ret = std::make_shared(uid++); + named_expr->expression_ = storage.Create("node_atom_1"); + auto ret = storage.Create(); ret->named_expressions_.emplace_back(named_expr); query->clauses_.emplace_back(ret); return query; @@ -34,24 +35,23 @@ static std::unique_ptr MatchNodeReturn() { // AST using variable in return bound by naming the previous return expression. // This is treated as an unbound variable. // MATCH (node_atom_1) RETURN node_atom_1 AS n, n AS n -static std::unique_ptr MatchUnboundMultiReturn() { - int uid = 0; - auto node_atom = std::make_shared(uid++); - node_atom->identifier_ = std::make_shared(uid++, "node_atom_1"); - auto pattern = std::make_shared(uid++); +Query *MatchUnboundMultiReturn(AstTreeStorage &storage) { + auto node_atom = storage.Create(); + node_atom->identifier_ = storage.Create("node_atom_1"); + auto pattern = storage.Create(); pattern->atoms_.emplace_back(node_atom); - auto match = std::make_shared(uid++); + auto match = storage.Create(); match->patterns_.emplace_back(pattern); - auto query = std::make_unique(uid++); + auto query = storage.query(); query->clauses_.emplace_back(match); - auto named_expr_1 = std::make_shared(uid++); + auto named_expr_1 = storage.Create(); named_expr_1->name_ = "n"; - named_expr_1->expression_ = std::make_shared(uid++, "node_atom_1"); - auto named_expr_2 = std::make_shared(uid++); + named_expr_1->expression_ = storage.Create("node_atom_1"); + auto named_expr_2 = storage.Create(); named_expr_2->name_ = "n"; - named_expr_2->expression_ = std::make_shared(uid++, "n"); - auto ret = std::make_shared(uid++); + named_expr_2->expression_ = storage.Create("n"); + auto ret = storage.Create(); ret->named_expressions_.emplace_back(named_expr_1); ret->named_expressions_.emplace_back(named_expr_2); query->clauses_.emplace_back(ret); @@ -59,21 +59,20 @@ static std::unique_ptr MatchUnboundMultiReturn() { } // AST with unbound variable in return: MATCH (n) RETURN x AS x -static std::unique_ptr MatchNodeUnboundReturn() { - int uid = 0; - auto node_atom = std::make_shared(uid++); - node_atom->identifier_ = std::make_shared(uid++, "n"); - auto pattern = std::make_shared(uid++); +Query *MatchNodeUnboundReturn(AstTreeStorage &storage) { + auto node_atom = storage.Create(); + node_atom->identifier_ = storage.Create("n"); + auto pattern = storage.Create(); pattern->atoms_.emplace_back(node_atom); - auto match = std::make_shared(uid++); + auto match = storage.Create(); match->patterns_.emplace_back(pattern); - auto query = std::make_unique(uid++); + auto query = storage.query(); query->clauses_.emplace_back(match); - auto named_expr = std::make_shared(uid++); + auto named_expr = storage.Create(); named_expr->name_ = "x"; - named_expr->expression_ = std::make_shared(uid++, "x"); - auto ret = std::make_shared(uid++); + named_expr->expression_ = storage.Create("x"); + auto ret = storage.Create(); ret->named_expressions_.emplace_back(named_expr); query->clauses_.emplace_back(ret); return query; @@ -81,16 +80,17 @@ static std::unique_ptr MatchNodeUnboundReturn() { TEST(TestSymbolGenerator, MatchNodeReturn) { SymbolTable symbol_table; - auto query_ast = MatchNodeReturn(); + AstTreeStorage storage; + auto query_ast = MatchNodeReturn(storage); SymbolGenerator symbol_generator(symbol_table); query_ast->Accept(symbol_generator); EXPECT_EQ(symbol_table.max_position(), 2); - auto match = std::dynamic_pointer_cast(query_ast->clauses_[0]); + auto match = dynamic_cast(query_ast->clauses_[0]); auto pattern = match->patterns_[0]; - auto node_atom = std::dynamic_pointer_cast(pattern->atoms_[0]); + auto node_atom = dynamic_cast(pattern->atoms_[0]); auto node_sym = symbol_table[*node_atom->identifier_]; EXPECT_EQ(node_sym.name_, "node_atom_1"); - auto ret = std::dynamic_pointer_cast(query_ast->clauses_[1]); + auto ret = dynamic_cast(query_ast->clauses_[1]); auto named_expr = ret->named_expressions_[0]; auto column_sym = symbol_table[*named_expr]; EXPECT_EQ(node_sym.name_, column_sym.name_); @@ -101,14 +101,17 @@ TEST(TestSymbolGenerator, MatchNodeReturn) { TEST(TestSymbolGenerator, MatchUnboundMultiReturn) { SymbolTable symbol_table; - auto query_ast = MatchUnboundMultiReturn(); + AstTreeStorage storage; + auto query_ast = MatchUnboundMultiReturn(storage); SymbolGenerator symbol_generator(symbol_table); EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException); } TEST(TestSymbolGenerator, MatchNodeUnboundReturn) { SymbolTable symbol_table; - auto query_ast = MatchNodeUnboundReturn(); + AstTreeStorage storage; + auto query_ast = MatchNodeUnboundReturn(storage); SymbolGenerator symbol_generator(symbol_table); EXPECT_THROW(query_ast->Accept(symbol_generator), SemanticException); } +}