Handle asterisk in AST conversion

Summary: Handle the asterisk (RETURN * and COUNT(*)) when converting the parse tree to the AST. TODO: handle the COUNT(*) aggregation in the operator (will be done in the next diff).

Reviewers: teon.banek, florijan, buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D350
This commit is contained in:
Mislav Bradac 2017-05-05 18:12:16 +02:00
parent 3428a9bfbd
commit 3d026b106d
4 changed files with 59 additions and 18 deletions

View File

@ -8,6 +8,7 @@
#include "database/graph_db_datatypes.hpp" #include "database/graph_db_datatypes.hpp"
#include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/ast/ast_visitor.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "utils/assert.hpp"
#include "utils/visitor/visitable.hpp" #include "utils/visitor/visitable.hpp"
namespace query { namespace query {
@ -38,8 +39,8 @@ class BinaryOperator : public Expression {
friend class AstTreeStorage; friend class AstTreeStorage;
public: public:
Expression *expression1_; Expression *expression1_ = nullptr;
Expression *expression2_; Expression *expression2_ = nullptr;
protected: protected:
BinaryOperator(int uid) : Expression(uid) {} BinaryOperator(int uid) : Expression(uid) {}
@ -51,7 +52,7 @@ class UnaryOperator : public Expression {
friend class AstTreeStorage; friend class AstTreeStorage;
public: public:
Expression *expression_; Expression *expression_ = nullptr;
protected: protected:
UnaryOperator(int uid) : Expression(uid) {} UnaryOperator(int uid) : Expression(uid) {}
@ -454,9 +455,12 @@ class Identifier : public Expression {
public: public:
DEFVISITABLE(TreeVisitorBase); DEFVISITABLE(TreeVisitorBase);
std::string name_; std::string name_;
bool user_declared_ = true;
protected: protected:
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
Identifier(int uid, const std::string &name, bool user_declared)
: Expression(uid), name_(name), user_declared_(user_declared) {}
}; };
class PropertyLookup : public Expression { class PropertyLookup : public Expression {
@ -572,7 +576,9 @@ class Aggregation : public UnaryOperator {
void Accept(TreeVisitorBase &visitor) override { void Accept(TreeVisitorBase &visitor) override {
if (visitor.PreVisit(*this)) { if (visitor.PreVisit(*this)) {
visitor.Visit(*this); visitor.Visit(*this);
expression_->Accept(visitor); if (expression_) {
expression_->Accept(visitor);
}
visitor.PostVisit(*this); visitor.PostVisit(*this);
} }
} }
@ -580,7 +586,11 @@ class Aggregation : public UnaryOperator {
protected: protected:
Aggregation(int uid, Expression *expression, Op op) Aggregation(int uid, Expression *expression, Op op)
: UnaryOperator(uid, expression), op_(op) {} : UnaryOperator(uid, expression), op_(op) {
// Count without expression denotes count(*) in cypher.
debug_assert(expression || op == Aggregation::Op::COUNT,
"All aggregations, except COUNT require expression");
}
}; };
class NamedExpression : public Tree { class NamedExpression : public Tree {
@ -780,6 +790,8 @@ enum class Ordering { ASC, DESC };
struct ReturnBody { struct ReturnBody {
/** @brief True if distinct results should be produced. */ /** @brief True if distinct results should be produced. */
bool distinct = false; bool distinct = false;
/** @brief True if asterisk was found in return body */
bool all_identifiers = false;
/** @brief Expressions which are used to produce results. */ /** @brief Expressions which are used to produce results. */
std::vector<NamedExpression *> named_expressions; std::vector<NamedExpression *> named_expressions;
/** @brief Expressions used for ordering the results. */ /** @brief Expressions used for ordering the results. */

View File

@ -83,7 +83,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
while (true) { while (true) {
std::string id_name = kAnonPrefix + std::to_string(id++); std::string id_name = kAnonPrefix + std::to_string(id++);
if (users_identifiers.find(id_name) == users_identifiers.end()) { if (users_identifiers.find(id_name) == users_identifiers.end()) {
*identifier = storage_.Create<Identifier>(id_name); *identifier = storage_.Create<Identifier>(id_name, false);
break; break;
} }
} }
@ -171,22 +171,21 @@ antlrcpp::Any CypherMainVisitor::visitReturnBody(
if (ctx->limit()) { if (ctx->limit()) {
body.limit = static_cast<Expression *>(ctx->limit()->accept(this)); body.limit = static_cast<Expression *>(ctx->limit()->accept(this));
} }
body.named_expressions = std::tie(body.all_identifiers, body.named_expressions) =
ctx->returnItems()->accept(this).as<std::vector<NamedExpression *>>(); ctx->returnItems()
->accept(this)
.as<std::pair<bool, std::vector<NamedExpression *>>>();
return body; return body;
} }
antlrcpp::Any CypherMainVisitor::visitReturnItems( antlrcpp::Any CypherMainVisitor::visitReturnItems(
CypherParser::ReturnItemsContext *ctx) { CypherParser::ReturnItemsContext *ctx) {
if (ctx->getTokens(kReturnAllTokenId).size()) {
// TODO: implement *
throw utils::NotYetImplemented();
}
std::vector<NamedExpression *> named_expressions; std::vector<NamedExpression *> named_expressions;
for (auto *item : ctx->returnItem()) { for (auto *item : ctx->returnItem()) {
named_expressions.push_back(item->accept(this)); named_expressions.push_back(item->accept(this));
} }
return named_expressions; return std::pair<bool, std::vector<NamedExpression *>>(
ctx->getTokens(kReturnAllTokenId).size(), named_expressions);
} }
antlrcpp::Any CypherMainVisitor::visitReturnItem( antlrcpp::Any CypherMainVisitor::visitReturnItem(
@ -669,9 +668,15 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
return static_cast<Expression *>(storage_.Create<Identifier>(variable)); return static_cast<Expression *>(storage_.Create<Identifier>(variable));
} else if (ctx->functionInvocation()) { } else if (ctx->functionInvocation()) {
return static_cast<Expression *>(ctx->functionInvocation()->accept(this)); return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
} else if (ctx->COUNT()) {
// Here we handle COUNT(*). COUNT(expression) is handled in
// visitFunctionInvocation with other aggregations. This is visible in
// functionInvocation and atom productions in opencypher grammar.
return static_cast<Expression *>(
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
} }
// TODO: Implement this. We don't support comprehensions, functions, // TODO: Implement this. We don't support comprehensions, filtering... at the
// filtering... at the moment. // moment.
throw utils::NotYetImplemented(); throw utils::NotYetImplemented();
} }

View File

@ -165,7 +165,9 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override; antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override;
/** /**
* @return vector<NamedExpression*> * @return pair<bool, vector<NamedExpression*>> first member is true if
* asterisk was found in return
* expressions.
*/ */
antlrcpp::Any visitReturnItems( antlrcpp::Any visitReturnItems(
CypherParser::ReturnItemsContext *ctx) override; CypherParser::ReturnItemsContext *ctx) override;

View File

@ -90,6 +90,7 @@ TEST(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) {
auto *query = ast_generator.query_; auto *query = ast_generator.query_;
ASSERT_EQ(query->clauses_.size(), 1U); ASSERT_EQ(query->clauses_.size(), 1U);
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]); auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_FALSE(return_clause->body_.all_identifiers);
ASSERT_EQ(return_clause->body_.order_by.size(), 0U); ASSERT_EQ(return_clause->body_.order_by.size(), 0U);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U); ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U);
ASSERT_FALSE(return_clause->body_.limit); ASSERT_FALSE(return_clause->body_.limit);
@ -147,12 +148,21 @@ TEST(CypherMainVisitorTest, ReturnNamedIdentifier) {
AstGenerator ast_generator("RETURN var AS var5"); AstGenerator ast_generator("RETURN var AS var5");
auto *query = ast_generator.query_; auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]); auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_FALSE(return_clause->body_.all_identifiers);
auto *named_expr = return_clause->body_.named_expressions[0]; auto *named_expr = return_clause->body_.named_expressions[0];
ASSERT_EQ(named_expr->name_, "var5"); ASSERT_EQ(named_expr->name_, "var5");
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
ASSERT_EQ(identifier->name_, "var"); ASSERT_EQ(identifier->name_, "var");
} }
TEST(CypherMainVisitorTest, ReturnAsterisk) {
AstGenerator ast_generator("RETURN *");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_TRUE(return_clause->body_.all_identifiers);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U);
}
TEST(CypherMainVisitorTest, IntegerLiteral) { TEST(CypherMainVisitorTest, IntegerLiteral) {
AstGenerator ast_generator("RETURN 42"); AstGenerator ast_generator("RETURN 42");
auto *query = ast_generator.query_; auto *query = ast_generator.query_;
@ -448,10 +458,11 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
} }
TEST(CypherMainVisitorTest, Aggregation) { TEST(CypherMainVisitorTest, Aggregation) {
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)"); AstGenerator ast_generator(
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
auto *query = ast_generator.query_; auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]); auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 5); ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM, Aggregation::Op::MAX, Aggregation::Op::SUM,
Aggregation::Op::AVG}; Aggregation::Op::AVG};
@ -465,6 +476,11 @@ TEST(CypherMainVisitorTest, Aggregation) {
ASSERT_TRUE(identifier); ASSERT_TRUE(identifier);
ASSERT_EQ(identifier->name_, ids[i]); ASSERT_EQ(identifier->name_, ids[i]);
} }
auto *aggregation = dynamic_cast<Aggregation *>(
return_clause->body_.named_expressions[5]->expression_);
ASSERT_TRUE(aggregation);
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
ASSERT_FALSE(aggregation->expression_);
} }
TEST(CypherMainVisitorTest, UndefinedFunction) { TEST(CypherMainVisitorTest, UndefinedFunction) {
@ -593,6 +609,7 @@ TEST(CypherMainVisitorTest, NodePattern) {
ASSERT_TRUE(node->identifier_); ASSERT_TRUE(node->identifier_);
EXPECT_EQ(node->identifier_->name_, EXPECT_EQ(node->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(1)); CypherMainVisitor::kAnonPrefix + std::to_string(1));
EXPECT_FALSE(node->identifier_->user_declared_);
EXPECT_THAT(node->labels_, UnorderedElementsAre( EXPECT_THAT(node->labels_, UnorderedElementsAre(
ast_generator.db_accessor_->label("label1"), ast_generator.db_accessor_->label("label1"),
ast_generator.db_accessor_->label("label2"), ast_generator.db_accessor_->label("label2"),
@ -621,6 +638,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) {
ASSERT_TRUE(node); ASSERT_TRUE(node);
ASSERT_TRUE(node->identifier_); ASSERT_TRUE(node->identifier_);
EXPECT_EQ(node->identifier_->name_, "var"); EXPECT_EQ(node->identifier_->name_, "var");
EXPECT_TRUE(node->identifier_->user_declared_);
EXPECT_THAT(node->labels_, UnorderedElementsAre()); EXPECT_THAT(node->labels_, UnorderedElementsAre());
EXPECT_THAT(node->properties_, UnorderedElementsAre()); EXPECT_THAT(node->properties_, UnorderedElementsAre());
} }
@ -645,6 +663,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) {
ASSERT_TRUE(edge->identifier_); ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_, EXPECT_THAT(edge->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(2)); CypherMainVisitor::kAnonPrefix + std::to_string(2));
EXPECT_FALSE(edge->identifier_->user_declared_);
} }
// PatternPart in braces. // PatternPart in braces.
@ -667,6 +686,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) {
ASSERT_TRUE(edge->identifier_); ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_, EXPECT_THAT(edge->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(2)); CypherMainVisitor::kAnonPrefix + std::to_string(2));
EXPECT_FALSE(edge->identifier_->user_declared_);
} }
TEST(CypherMainVisitorTest, RelationshipPatternDetails) { TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
@ -709,6 +729,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) {
EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT); EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
ASSERT_TRUE(edge->identifier_); ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_, "var"); EXPECT_THAT(edge->identifier_->name_, "var");
EXPECT_TRUE(edge->identifier_->user_declared_);
} }
// // Relationship with unbounded variable range. // // Relationship with unbounded variable range.
@ -788,6 +809,7 @@ TEST(CypherMainVisitorTest, ReturnUnanemdIdentifier) {
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_); auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
ASSERT_TRUE(identifier); ASSERT_TRUE(identifier);
ASSERT_EQ(identifier->name_, "var"); ASSERT_EQ(identifier->name_, "var");
ASSERT_TRUE(identifier->user_declared_);
} }
TEST(CypherMainVisitorTest, Create) { TEST(CypherMainVisitorTest, Create) {