Handle asterisk in AST conversion
Summary: TODO: Handle COUNT_ASTERISK aggregation in operator (will do in next diff) Reviewers: teon.banek, florijan, buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D350
This commit is contained in:
parent
3428a9bfbd
commit
3d026b106d
@ -8,6 +8,7 @@
|
||||
#include "database/graph_db_datatypes.hpp"
|
||||
#include "query/frontend/ast/ast_visitor.hpp"
|
||||
#include "query/typed_value.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
#include "utils/visitor/visitable.hpp"
|
||||
|
||||
namespace query {
|
||||
@ -38,8 +39,8 @@ class BinaryOperator : public Expression {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
public:
|
||||
Expression *expression1_;
|
||||
Expression *expression2_;
|
||||
Expression *expression1_ = nullptr;
|
||||
Expression *expression2_ = nullptr;
|
||||
|
||||
protected:
|
||||
BinaryOperator(int uid) : Expression(uid) {}
|
||||
@ -51,7 +52,7 @@ class UnaryOperator : public Expression {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
public:
|
||||
Expression *expression_;
|
||||
Expression *expression_ = nullptr;
|
||||
|
||||
protected:
|
||||
UnaryOperator(int uid) : Expression(uid) {}
|
||||
@ -454,9 +455,12 @@ class Identifier : public Expression {
|
||||
public:
|
||||
DEFVISITABLE(TreeVisitorBase);
|
||||
std::string name_;
|
||||
bool user_declared_ = true;
|
||||
|
||||
protected:
|
||||
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
|
||||
Identifier(int uid, const std::string &name, bool user_declared)
|
||||
: Expression(uid), name_(name), user_declared_(user_declared) {}
|
||||
};
|
||||
|
||||
class PropertyLookup : public Expression {
|
||||
@ -572,7 +576,9 @@ class Aggregation : public UnaryOperator {
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
visitor.Visit(*this);
|
||||
expression_->Accept(visitor);
|
||||
if (expression_) {
|
||||
expression_->Accept(visitor);
|
||||
}
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
}
|
||||
@ -580,7 +586,11 @@ class Aggregation : public UnaryOperator {
|
||||
|
||||
protected:
|
||||
Aggregation(int uid, Expression *expression, Op op)
|
||||
: UnaryOperator(uid, expression), op_(op) {}
|
||||
: UnaryOperator(uid, expression), op_(op) {
|
||||
// Count without expression denotes count(*) in cypher.
|
||||
debug_assert(expression || op == Aggregation::Op::Count,
|
||||
"All aggregations, except COUNT require expression");
|
||||
}
|
||||
};
|
||||
|
||||
class NamedExpression : public Tree {
|
||||
@ -780,6 +790,8 @@ enum class Ordering { ASC, DESC };
|
||||
struct ReturnBody {
|
||||
/** @brief True if distinct results should be produced. */
|
||||
bool distinct = false;
|
||||
/** @brief True if asterisk was found in return body */
|
||||
bool all_identifiers = false;
|
||||
/** @brief Expressions which are used to produce results. */
|
||||
std::vector<NamedExpression *> named_expressions;
|
||||
/** @brief Expressions used for ordering the results. */
|
||||
|
@ -83,7 +83,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
|
||||
while (true) {
|
||||
std::string id_name = kAnonPrefix + std::to_string(id++);
|
||||
if (users_identifiers.find(id_name) == users_identifiers.end()) {
|
||||
*identifier = storage_.Create<Identifier>(id_name);
|
||||
*identifier = storage_.Create<Identifier>(id_name, false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -171,22 +171,21 @@ antlrcpp::Any CypherMainVisitor::visitReturnBody(
|
||||
if (ctx->limit()) {
|
||||
body.limit = static_cast<Expression *>(ctx->limit()->accept(this));
|
||||
}
|
||||
body.named_expressions =
|
||||
ctx->returnItems()->accept(this).as<std::vector<NamedExpression *>>();
|
||||
std::tie(body.all_identifiers, body.named_expressions) =
|
||||
ctx->returnItems()
|
||||
->accept(this)
|
||||
.as<std::pair<bool, std::vector<NamedExpression *>>>();
|
||||
return body;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitReturnItems(
|
||||
CypherParser::ReturnItemsContext *ctx) {
|
||||
if (ctx->getTokens(kReturnAllTokenId).size()) {
|
||||
// TODO: implement *
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
std::vector<NamedExpression *> named_expressions;
|
||||
for (auto *item : ctx->returnItem()) {
|
||||
named_expressions.push_back(item->accept(this));
|
||||
}
|
||||
return named_expressions;
|
||||
return std::pair<bool, std::vector<NamedExpression *>>(
|
||||
ctx->getTokens(kReturnAllTokenId).size(), named_expressions);
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitReturnItem(
|
||||
@ -669,9 +668,15 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
|
||||
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
|
||||
} else if (ctx->functionInvocation()) {
|
||||
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
|
||||
} else if (ctx->COUNT()) {
|
||||
// Here we handle COUNT(*). COUNT(expression) is handled in
|
||||
// visitFunctionInvocation with other aggregations. This is visible in
|
||||
// functionInvocation and atom producions in opencypher grammar.
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
|
||||
}
|
||||
// TODO: Implement this. We don't support comprehensions, functions,
|
||||
// filtering... at the moment.
|
||||
// TODO: Implement this. We don't support comprehensions, filtering... at the
|
||||
// moment.
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
|
||||
|
@ -165,7 +165,9 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
|
||||
antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return vector<NamedExpression*>
|
||||
* @return pair<bool, vector<NamedExpression*>> first member is true if
|
||||
* asterisk was found in return
|
||||
* expressions.
|
||||
*/
|
||||
antlrcpp::Any visitReturnItems(
|
||||
CypherParser::ReturnItemsContext *ctx) override;
|
||||
|
@ -90,6 +90,7 @@ TEST(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) {
|
||||
auto *query = ast_generator.query_;
|
||||
ASSERT_EQ(query->clauses_.size(), 1U);
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_FALSE(return_clause->body_.all_identifiers);
|
||||
ASSERT_EQ(return_clause->body_.order_by.size(), 0U);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U);
|
||||
ASSERT_FALSE(return_clause->body_.limit);
|
||||
@ -147,12 +148,21 @@ TEST(CypherMainVisitorTest, ReturnNamedIdentifier) {
|
||||
AstGenerator ast_generator("RETURN var AS var5");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_FALSE(return_clause->body_.all_identifiers);
|
||||
auto *named_expr = return_clause->body_.named_expressions[0];
|
||||
ASSERT_EQ(named_expr->name_, "var5");
|
||||
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
||||
ASSERT_EQ(identifier->name_, "var");
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, ReturnAsterisk) {
|
||||
AstGenerator ast_generator("RETURN *");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_TRUE(return_clause->body_.all_identifiers);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, IntegerLiteral) {
|
||||
AstGenerator ast_generator("RETURN 42");
|
||||
auto *query = ast_generator.query_;
|
||||
@ -448,10 +458,11 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, Aggregation) {
|
||||
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)");
|
||||
AstGenerator ast_generator(
|
||||
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 5);
|
||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
|
||||
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
||||
Aggregation::Op::AVG};
|
||||
@ -465,6 +476,11 @@ TEST(CypherMainVisitorTest, Aggregation) {
|
||||
ASSERT_TRUE(identifier);
|
||||
ASSERT_EQ(identifier->name_, ids[i]);
|
||||
}
|
||||
auto *aggregation = dynamic_cast<Aggregation *>(
|
||||
return_clause->body_.named_expressions[5]->expression_);
|
||||
ASSERT_TRUE(aggregation);
|
||||
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
|
||||
ASSERT_FALSE(aggregation->expression_);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, UndefinedFunction) {
|
||||
@ -593,6 +609,7 @@ TEST(CypherMainVisitorTest, NodePattern) {
|
||||
ASSERT_TRUE(node->identifier_);
|
||||
EXPECT_EQ(node->identifier_->name_,
|
||||
CypherMainVisitor::kAnonPrefix + std::to_string(1));
|
||||
EXPECT_FALSE(node->identifier_->user_declared_);
|
||||
EXPECT_THAT(node->labels_, UnorderedElementsAre(
|
||||
ast_generator.db_accessor_->label("label1"),
|
||||
ast_generator.db_accessor_->label("label2"),
|
||||
@ -621,6 +638,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) {
|
||||
ASSERT_TRUE(node);
|
||||
ASSERT_TRUE(node->identifier_);
|
||||
EXPECT_EQ(node->identifier_->name_, "var");
|
||||
EXPECT_TRUE(node->identifier_->user_declared_);
|
||||
EXPECT_THAT(node->labels_, UnorderedElementsAre());
|
||||
EXPECT_THAT(node->properties_, UnorderedElementsAre());
|
||||
}
|
||||
@ -645,6 +663,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) {
|
||||
ASSERT_TRUE(edge->identifier_);
|
||||
EXPECT_THAT(edge->identifier_->name_,
|
||||
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
||||
EXPECT_FALSE(edge->identifier_->user_declared_);
|
||||
}
|
||||
|
||||
// PatternPart in braces.
|
||||
@ -667,6 +686,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) {
|
||||
ASSERT_TRUE(edge->identifier_);
|
||||
EXPECT_THAT(edge->identifier_->name_,
|
||||
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
||||
EXPECT_FALSE(edge->identifier_->user_declared_);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
|
||||
@ -709,6 +729,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) {
|
||||
EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
|
||||
ASSERT_TRUE(edge->identifier_);
|
||||
EXPECT_THAT(edge->identifier_->name_, "var");
|
||||
EXPECT_TRUE(edge->identifier_->user_declared_);
|
||||
}
|
||||
|
||||
// // Relationship with unbounded variable range.
|
||||
@ -788,6 +809,7 @@ TEST(CypherMainVisitorTest, ReturnUnanemdIdentifier) {
|
||||
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
||||
ASSERT_TRUE(identifier);
|
||||
ASSERT_EQ(identifier->name_, "var");
|
||||
ASSERT_TRUE(identifier->user_declared_);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, Create) {
|
||||
|
Loading…
Reference in New Issue
Block a user