Handle asterisk in AST conversion

Summary: Convert `RETURN *` and `COUNT(*)` into the AST. TODO: handle the COUNT_ASTERISK aggregation in the operator (will be done in the next diff).

Reviewers: teon.banek, florijan, buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D350
This commit is contained in:
Mislav Bradac 2017-05-05 18:12:16 +02:00
parent 3428a9bfbd
commit 3d026b106d
4 changed files with 59 additions and 18 deletions

View File

@ -8,6 +8,7 @@
#include "database/graph_db_datatypes.hpp"
#include "query/frontend/ast/ast_visitor.hpp"
#include "query/typed_value.hpp"
#include "utils/assert.hpp"
#include "utils/visitor/visitable.hpp"
namespace query {
@ -38,8 +39,8 @@ class BinaryOperator : public Expression {
friend class AstTreeStorage;
public:
Expression *expression1_;
Expression *expression2_;
Expression *expression1_ = nullptr;
Expression *expression2_ = nullptr;
protected:
BinaryOperator(int uid) : Expression(uid) {}
@ -51,7 +52,7 @@ class UnaryOperator : public Expression {
friend class AstTreeStorage;
public:
Expression *expression_;
Expression *expression_ = nullptr;
protected:
UnaryOperator(int uid) : Expression(uid) {}
@ -454,9 +455,12 @@ class Identifier : public Expression {
public:
DEFVISITABLE(TreeVisitorBase);
std::string name_;
bool user_declared_ = true;
protected:
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
Identifier(int uid, const std::string &name, bool user_declared)
: Expression(uid), name_(name), user_declared_(user_declared) {}
};
class PropertyLookup : public Expression {
@ -572,7 +576,9 @@ class Aggregation : public UnaryOperator {
void Accept(TreeVisitorBase &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
expression_->Accept(visitor);
if (expression_) {
expression_->Accept(visitor);
}
visitor.PostVisit(*this);
}
}
@ -580,7 +586,11 @@ class Aggregation : public UnaryOperator {
protected:
Aggregation(int uid, Expression *expression, Op op)
: UnaryOperator(uid, expression), op_(op) {}
: UnaryOperator(uid, expression), op_(op) {
// Count without expression denotes count(*) in cypher.
debug_assert(expression || op == Aggregation::Op::Count,
"All aggregations, except COUNT require expression");
}
};
class NamedExpression : public Tree {
@ -780,6 +790,8 @@ enum class Ordering { ASC, DESC };
struct ReturnBody {
/** @brief True if distinct results should be produced. */
bool distinct = false;
/** @brief True if asterisk was found in return body */
bool all_identifiers = false;
/** @brief Expressions which are used to produce results. */
std::vector<NamedExpression *> named_expressions;
/** @brief Expressions used for ordering the results. */

View File

@ -83,7 +83,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
while (true) {
std::string id_name = kAnonPrefix + std::to_string(id++);
if (users_identifiers.find(id_name) == users_identifiers.end()) {
*identifier = storage_.Create<Identifier>(id_name);
*identifier = storage_.Create<Identifier>(id_name, false);
break;
}
}
@ -171,22 +171,21 @@ antlrcpp::Any CypherMainVisitor::visitReturnBody(
if (ctx->limit()) {
body.limit = static_cast<Expression *>(ctx->limit()->accept(this));
}
body.named_expressions =
ctx->returnItems()->accept(this).as<std::vector<NamedExpression *>>();
std::tie(body.all_identifiers, body.named_expressions) =
ctx->returnItems()
->accept(this)
.as<std::pair<bool, std::vector<NamedExpression *>>>();
return body;
}
antlrcpp::Any CypherMainVisitor::visitReturnItems(
CypherParser::ReturnItemsContext *ctx) {
if (ctx->getTokens(kReturnAllTokenId).size()) {
// TODO: implement *
throw utils::NotYetImplemented();
}
std::vector<NamedExpression *> named_expressions;
for (auto *item : ctx->returnItem()) {
named_expressions.push_back(item->accept(this));
}
return named_expressions;
return std::pair<bool, std::vector<NamedExpression *>>(
ctx->getTokens(kReturnAllTokenId).size(), named_expressions);
}
antlrcpp::Any CypherMainVisitor::visitReturnItem(
@ -669,9 +668,15 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
} else if (ctx->functionInvocation()) {
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
} else if (ctx->COUNT()) {
// Here we handle COUNT(*). COUNT(expression) is handled in
// visitFunctionInvocation with other aggregations. This is visible in
// functionInvocation and atom productions in opencypher grammar.
return static_cast<Expression *>(
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
}
// TODO: Implement this. We don't support comprehensions, functions,
// filtering... at the moment.
// TODO: Implement this. We don't support comprehensions, filtering... at the
// moment.
throw utils::NotYetImplemented();
}

View File

@ -165,7 +165,9 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override;
/**
* @return vector<NamedExpression*>
* @return pair<bool, vector<NamedExpression*>> first member is true if
* asterisk was found in return
* expressions.
*/
antlrcpp::Any visitReturnItems(
CypherParser::ReturnItemsContext *ctx) override;

View File

@ -90,6 +90,7 @@ TEST(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) {
auto *query = ast_generator.query_;
ASSERT_EQ(query->clauses_.size(), 1U);
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_FALSE(return_clause->body_.all_identifiers);
ASSERT_EQ(return_clause->body_.order_by.size(), 0U);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U);
ASSERT_FALSE(return_clause->body_.limit);
@ -147,12 +148,21 @@ TEST(CypherMainVisitorTest, ReturnNamedIdentifier) {
AstGenerator ast_generator("RETURN var AS var5");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_FALSE(return_clause->body_.all_identifiers);
auto *named_expr = return_clause->body_.named_expressions[0];
ASSERT_EQ(named_expr->name_, "var5");
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
ASSERT_EQ(identifier->name_, "var");
}
// Verifies that a bare `RETURN *` is converted into a Return clause whose body
// has `all_identifiers` set and carries no explicit named expressions.
TEST(CypherMainVisitorTest, ReturnAsterisk) {
  AstGenerator ast_generator("RETURN *");
  auto *query = ast_generator.query_;
  // Guard the index below; a missing clause should be a test failure, not an
  // out-of-bounds access.
  ASSERT_EQ(query->clauses_.size(), 1U);
  auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
  // dynamic_cast yields nullptr if the clause is not a Return; fail cleanly
  // instead of dereferencing a null pointer.
  ASSERT_TRUE(return_clause);
  ASSERT_TRUE(return_clause->body_.all_identifiers);
  ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U);
}
TEST(CypherMainVisitorTest, IntegerLiteral) {
AstGenerator ast_generator("RETURN 42");
auto *query = ast_generator.query_;
@ -448,10 +458,11 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
}
TEST(CypherMainVisitorTest, Aggregation) {
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)");
AstGenerator ast_generator(
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 5);
ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM,
Aggregation::Op::AVG};
@ -465,6 +476,11 @@ TEST(CypherMainVisitorTest, Aggregation) {
ASSERT_TRUE(identifier);
ASSERT_EQ(identifier->name_, ids[i]);
}
auto *aggregation = dynamic_cast<Aggregation *>(
return_clause->body_.named_expressions[5]->expression_);
ASSERT_TRUE(aggregation);
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
ASSERT_FALSE(aggregation->expression_);
}
TEST(CypherMainVisitorTest, UndefinedFunction) {
@ -593,6 +609,7 @@ TEST(CypherMainVisitorTest, NodePattern) {
ASSERT_TRUE(node->identifier_);
EXPECT_EQ(node->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(1));
EXPECT_FALSE(node->identifier_->user_declared_);
EXPECT_THAT(node->labels_, UnorderedElementsAre(
ast_generator.db_accessor_->label("label1"),
ast_generator.db_accessor_->label("label2"),
@ -621,6 +638,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) {
ASSERT_TRUE(node);
ASSERT_TRUE(node->identifier_);
EXPECT_EQ(node->identifier_->name_, "var");
EXPECT_TRUE(node->identifier_->user_declared_);
EXPECT_THAT(node->labels_, UnorderedElementsAre());
EXPECT_THAT(node->properties_, UnorderedElementsAre());
}
@ -645,6 +663,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) {
ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(2));
EXPECT_FALSE(edge->identifier_->user_declared_);
}
// PatternPart in braces.
@ -667,6 +686,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) {
ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_,
CypherMainVisitor::kAnonPrefix + std::to_string(2));
EXPECT_FALSE(edge->identifier_->user_declared_);
}
TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
@ -709,6 +729,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) {
EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
ASSERT_TRUE(edge->identifier_);
EXPECT_THAT(edge->identifier_->name_, "var");
EXPECT_TRUE(edge->identifier_->user_declared_);
}
// // Relationship with unbounded variable range.
@ -788,6 +809,7 @@ TEST(CypherMainVisitorTest, ReturnUnanemdIdentifier) {
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
ASSERT_TRUE(identifier);
ASSERT_EQ(identifier->name_, "var");
ASSERT_TRUE(identifier->user_declared_);
}
TEST(CypherMainVisitorTest, Create) {