From 3d026b106d946d3f4b57631a45ad28fadf25ad39 Mon Sep 17 00:00:00 2001 From: Mislav Bradac Date: Fri, 5 May 2017 18:12:16 +0200 Subject: [PATCH] Handle asterisk in AST conversion Summary: TODO: Handle COUNT_ASTERISK aggregation in operator (will do in next diff) Reviewers: teon.banek, florijan, buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D350 --- src/query/frontend/ast/ast.hpp | 22 ++++++++++++---- .../frontend/ast/cypher_main_visitor.cpp | 25 +++++++++++------- .../frontend/ast/cypher_main_visitor.hpp | 4 ++- tests/unit/cypher_main_visitor.cpp | 26 +++++++++++++++++-- 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index f21ee36a6..0b7a479ac 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -8,6 +8,7 @@ #include "database/graph_db_datatypes.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/typed_value.hpp" +#include "utils/assert.hpp" #include "utils/visitor/visitable.hpp" namespace query { @@ -38,8 +39,8 @@ class BinaryOperator : public Expression { friend class AstTreeStorage; public: - Expression *expression1_; - Expression *expression2_; + Expression *expression1_ = nullptr; + Expression *expression2_ = nullptr; protected: BinaryOperator(int uid) : Expression(uid) {} @@ -51,7 +52,7 @@ class UnaryOperator : public Expression { friend class AstTreeStorage; public: - Expression *expression_; + Expression *expression_ = nullptr; protected: UnaryOperator(int uid) : Expression(uid) {} @@ -454,9 +455,12 @@ class Identifier : public Expression { public: DEFVISITABLE(TreeVisitorBase); std::string name_; + bool user_declared_ = true; protected: Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} + Identifier(int uid, const std::string &name, bool user_declared) + : Expression(uid), name_(name), user_declared_(user_declared) {} }; class PropertyLookup : public Expression { @@ -572,7 +576,9 @@ class Aggregation : public UnaryOperator { void Accept(TreeVisitorBase &visitor) override { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - expression_->Accept(visitor); + if (expression_) { + expression_->Accept(visitor); + } visitor.PostVisit(*this); } } @@ -580,7 +586,11 @@ class Aggregation : public UnaryOperator { protected: Aggregation(int uid, Expression *expression, Op op) - : UnaryOperator(uid, expression), op_(op) {} + : UnaryOperator(uid, expression), op_(op) { + // Count without expression denotes count(*) in cypher. + debug_assert(expression || op == Aggregation::Op::Count, + "All aggregations, except COUNT require expression"); + } }; class NamedExpression : public Tree { @@ -780,6 +790,8 @@ enum class Ordering { ASC, DESC }; struct ReturnBody { /** @brief True if distinct results should be produced. */ bool distinct = false; + /** @brief True if asterisk was found in return body */ + bool all_identifiers = false; /** @brief Expressions which are used to produce results. */ std::vector named_expressions; /** @brief Expressions used for ordering the results. */ diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 2839727c1..5f107cbde 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -83,7 +83,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( while (true) { std::string id_name = kAnonPrefix + std::to_string(id++); if (users_identifiers.find(id_name) == users_identifiers.end()) { - *identifier = storage_.Create(id_name); + *identifier = storage_.Create(id_name, false); break; } } @@ -171,22 +171,21 @@ antlrcpp::Any CypherMainVisitor::visitReturnBody( if (ctx->limit()) { body.limit = static_cast(ctx->limit()->accept(this)); } - body.named_expressions = - ctx->returnItems()->accept(this).as>(); + std::tie(body.all_identifiers, body.named_expressions) = + ctx->returnItems() + ->accept(this) + .as>>(); return body; } antlrcpp::Any CypherMainVisitor::visitReturnItems( CypherParser::ReturnItemsContext *ctx) { - if (ctx->getTokens(kReturnAllTokenId).size()) { - // TODO: implement * - throw utils::NotYetImplemented(); - } std::vector named_expressions; for (auto *item : ctx->returnItem()) { named_expressions.push_back(item->accept(this)); } - return named_expressions; + return std::pair>( + ctx->getTokens(kReturnAllTokenId).size(), named_expressions); } antlrcpp::Any CypherMainVisitor::visitReturnItem( @@ -669,9 +668,15 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) { return static_cast(storage_.Create(variable)); } else if (ctx->functionInvocation()) { return static_cast(ctx->functionInvocation()->accept(this)); + } else if (ctx->COUNT()) { + // Here we handle COUNT(*). COUNT(expression) is handled in + // visitFunctionInvocation with other aggregations. This is visible in + // functionInvocation and atom producions in opencypher grammar. + return static_cast( + storage_.Create(nullptr, Aggregation::Op::COUNT)); } - // TODO: Implement this. We don't support comprehensions, functions, - // filtering... at the moment. + // TODO: Implement this. We don't support comprehensions, filtering... at the + // moment. throw utils::NotYetImplemented(); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 93daadb85..52fffa5fe 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -165,7 +165,9 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override; /** - * @return vector + * @return pair> first member is true if + * asterisk was found in return + * expressions. */ antlrcpp::Any visitReturnItems( CypherParser::ReturnItemsContext *ctx) override; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index f3208a6ef..947ee96e4 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -90,6 +90,7 @@ TEST(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) { auto *query = ast_generator.query_; ASSERT_EQ(query->clauses_.size(), 1U); auto *return_clause = dynamic_cast(query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); ASSERT_EQ(return_clause->body_.order_by.size(), 0U); ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); ASSERT_FALSE(return_clause->body_.limit); @@ -147,12 +148,21 @@ TEST(CypherMainVisitorTest, ReturnNamedIdentifier) { AstGenerator ast_generator("RETURN var AS var5"); auto *query = ast_generator.query_; auto *return_clause = dynamic_cast(query->clauses_[0]); + ASSERT_FALSE(return_clause->body_.all_identifiers); auto *named_expr = return_clause->body_.named_expressions[0]; ASSERT_EQ(named_expr->name_, "var5"); auto *identifier = dynamic_cast(named_expr->expression_); ASSERT_EQ(identifier->name_, "var"); } +TEST(CypherMainVisitorTest, ReturnAsterisk) { + AstGenerator ast_generator("RETURN *"); + auto *query = ast_generator.query_; + auto *return_clause = dynamic_cast(query->clauses_[0]); + ASSERT_TRUE(return_clause->body_.all_identifiers); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U); +} + TEST(CypherMainVisitorTest, IntegerLiteral) { AstGenerator ast_generator("RETURN 42"); auto *query = ast_generator.query_; @@ -448,10 +458,11 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) { } TEST(CypherMainVisitorTest, Aggregation) { - AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)"); + AstGenerator ast_generator( + "RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)"); auto *query = ast_generator.query_; auto *return_clause = dynamic_cast(query->clauses_[0]); - ASSERT_EQ(return_clause->body_.named_expressions.size(), 5); + ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U); Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::AVG}; @@ -465,6 +476,11 @@ TEST(CypherMainVisitorTest, Aggregation) { ASSERT_TRUE(identifier); ASSERT_EQ(identifier->name_, ids[i]); } + auto *aggregation = dynamic_cast( + return_clause->body_.named_expressions[5]->expression_); + ASSERT_TRUE(aggregation); + ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT); + ASSERT_FALSE(aggregation->expression_); } TEST(CypherMainVisitorTest, UndefinedFunction) { @@ -593,6 +609,7 @@ TEST(CypherMainVisitorTest, NodePattern) { ASSERT_TRUE(node->identifier_); EXPECT_EQ(node->identifier_->name_, CypherMainVisitor::kAnonPrefix + std::to_string(1)); + EXPECT_FALSE(node->identifier_->user_declared_); EXPECT_THAT(node->labels_, UnorderedElementsAre( ast_generator.db_accessor_->label("label1"), ast_generator.db_accessor_->label("label2"), @@ -621,6 +638,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) { ASSERT_TRUE(node); ASSERT_TRUE(node->identifier_); EXPECT_EQ(node->identifier_->name_, "var"); + EXPECT_TRUE(node->identifier_->user_declared_); EXPECT_THAT(node->labels_, UnorderedElementsAre()); EXPECT_THAT(node->properties_, UnorderedElementsAre()); } @@ -645,6 +663,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) { ASSERT_TRUE(edge->identifier_); EXPECT_THAT(edge->identifier_->name_, CypherMainVisitor::kAnonPrefix + std::to_string(2)); + EXPECT_FALSE(edge->identifier_->user_declared_); } // PatternPart in braces. @@ -667,6 +686,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) { ASSERT_TRUE(edge->identifier_); EXPECT_THAT(edge->identifier_->name_, CypherMainVisitor::kAnonPrefix + std::to_string(2)); + EXPECT_FALSE(edge->identifier_->user_declared_); } TEST(CypherMainVisitorTest, RelationshipPatternDetails) { @@ -709,6 +729,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) { EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT); ASSERT_TRUE(edge->identifier_); EXPECT_THAT(edge->identifier_->name_, "var"); + EXPECT_TRUE(edge->identifier_->user_declared_); } // // Relationship with unbounded variable range. @@ -788,6 +809,7 @@ TEST(CypherMainVisitorTest, ReturnUnanemdIdentifier) { auto *identifier = dynamic_cast(named_expr->expression_); ASSERT_TRUE(identifier); ASSERT_EQ(identifier->name_, "var"); + ASSERT_TRUE(identifier->user_declared_); } TEST(CypherMainVisitorTest, Create) {