Handle asterisk in AST conversion
Summary: TODO: Handle COUNT_ASTERISK aggregation in operator (will do in next diff) Reviewers: teon.banek, florijan, buda Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D350
This commit is contained in:
parent
3428a9bfbd
commit
3d026b106d
@ -8,6 +8,7 @@
|
|||||||
#include "database/graph_db_datatypes.hpp"
|
#include "database/graph_db_datatypes.hpp"
|
||||||
#include "query/frontend/ast/ast_visitor.hpp"
|
#include "query/frontend/ast/ast_visitor.hpp"
|
||||||
#include "query/typed_value.hpp"
|
#include "query/typed_value.hpp"
|
||||||
|
#include "utils/assert.hpp"
|
||||||
#include "utils/visitor/visitable.hpp"
|
#include "utils/visitor/visitable.hpp"
|
||||||
|
|
||||||
namespace query {
|
namespace query {
|
||||||
@ -38,8 +39,8 @@ class BinaryOperator : public Expression {
|
|||||||
friend class AstTreeStorage;
|
friend class AstTreeStorage;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Expression *expression1_;
|
Expression *expression1_ = nullptr;
|
||||||
Expression *expression2_;
|
Expression *expression2_ = nullptr;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BinaryOperator(int uid) : Expression(uid) {}
|
BinaryOperator(int uid) : Expression(uid) {}
|
||||||
@ -51,7 +52,7 @@ class UnaryOperator : public Expression {
|
|||||||
friend class AstTreeStorage;
|
friend class AstTreeStorage;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Expression *expression_;
|
Expression *expression_ = nullptr;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
UnaryOperator(int uid) : Expression(uid) {}
|
UnaryOperator(int uid) : Expression(uid) {}
|
||||||
@ -454,9 +455,12 @@ class Identifier : public Expression {
|
|||||||
public:
|
public:
|
||||||
DEFVISITABLE(TreeVisitorBase);
|
DEFVISITABLE(TreeVisitorBase);
|
||||||
std::string name_;
|
std::string name_;
|
||||||
|
bool user_declared_ = true;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
|
Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
|
||||||
|
Identifier(int uid, const std::string &name, bool user_declared)
|
||||||
|
: Expression(uid), name_(name), user_declared_(user_declared) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class PropertyLookup : public Expression {
|
class PropertyLookup : public Expression {
|
||||||
@ -572,7 +576,9 @@ class Aggregation : public UnaryOperator {
|
|||||||
void Accept(TreeVisitorBase &visitor) override {
|
void Accept(TreeVisitorBase &visitor) override {
|
||||||
if (visitor.PreVisit(*this)) {
|
if (visitor.PreVisit(*this)) {
|
||||||
visitor.Visit(*this);
|
visitor.Visit(*this);
|
||||||
expression_->Accept(visitor);
|
if (expression_) {
|
||||||
|
expression_->Accept(visitor);
|
||||||
|
}
|
||||||
visitor.PostVisit(*this);
|
visitor.PostVisit(*this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -580,7 +586,11 @@ class Aggregation : public UnaryOperator {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
Aggregation(int uid, Expression *expression, Op op)
|
Aggregation(int uid, Expression *expression, Op op)
|
||||||
: UnaryOperator(uid, expression), op_(op) {}
|
: UnaryOperator(uid, expression), op_(op) {
|
||||||
|
// Count without expression denotes count(*) in cypher.
|
||||||
|
debug_assert(expression || op == Aggregation::Op::Count,
|
||||||
|
"All aggregations, except COUNT require expression");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class NamedExpression : public Tree {
|
class NamedExpression : public Tree {
|
||||||
@ -780,6 +790,8 @@ enum class Ordering { ASC, DESC };
|
|||||||
struct ReturnBody {
|
struct ReturnBody {
|
||||||
/** @brief True if distinct results should be produced. */
|
/** @brief True if distinct results should be produced. */
|
||||||
bool distinct = false;
|
bool distinct = false;
|
||||||
|
/** @brief True if asterisk was found in return body */
|
||||||
|
bool all_identifiers = false;
|
||||||
/** @brief Expressions which are used to produce results. */
|
/** @brief Expressions which are used to produce results. */
|
||||||
std::vector<NamedExpression *> named_expressions;
|
std::vector<NamedExpression *> named_expressions;
|
||||||
/** @brief Expressions used for ordering the results. */
|
/** @brief Expressions used for ordering the results. */
|
||||||
|
@ -83,7 +83,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
|
|||||||
while (true) {
|
while (true) {
|
||||||
std::string id_name = kAnonPrefix + std::to_string(id++);
|
std::string id_name = kAnonPrefix + std::to_string(id++);
|
||||||
if (users_identifiers.find(id_name) == users_identifiers.end()) {
|
if (users_identifiers.find(id_name) == users_identifiers.end()) {
|
||||||
*identifier = storage_.Create<Identifier>(id_name);
|
*identifier = storage_.Create<Identifier>(id_name, false);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,22 +171,21 @@ antlrcpp::Any CypherMainVisitor::visitReturnBody(
|
|||||||
if (ctx->limit()) {
|
if (ctx->limit()) {
|
||||||
body.limit = static_cast<Expression *>(ctx->limit()->accept(this));
|
body.limit = static_cast<Expression *>(ctx->limit()->accept(this));
|
||||||
}
|
}
|
||||||
body.named_expressions =
|
std::tie(body.all_identifiers, body.named_expressions) =
|
||||||
ctx->returnItems()->accept(this).as<std::vector<NamedExpression *>>();
|
ctx->returnItems()
|
||||||
|
->accept(this)
|
||||||
|
.as<std::pair<bool, std::vector<NamedExpression *>>>();
|
||||||
return body;
|
return body;
|
||||||
}
|
}
|
||||||
|
|
||||||
antlrcpp::Any CypherMainVisitor::visitReturnItems(
|
antlrcpp::Any CypherMainVisitor::visitReturnItems(
|
||||||
CypherParser::ReturnItemsContext *ctx) {
|
CypherParser::ReturnItemsContext *ctx) {
|
||||||
if (ctx->getTokens(kReturnAllTokenId).size()) {
|
|
||||||
// TODO: implement *
|
|
||||||
throw utils::NotYetImplemented();
|
|
||||||
}
|
|
||||||
std::vector<NamedExpression *> named_expressions;
|
std::vector<NamedExpression *> named_expressions;
|
||||||
for (auto *item : ctx->returnItem()) {
|
for (auto *item : ctx->returnItem()) {
|
||||||
named_expressions.push_back(item->accept(this));
|
named_expressions.push_back(item->accept(this));
|
||||||
}
|
}
|
||||||
return named_expressions;
|
return std::pair<bool, std::vector<NamedExpression *>>(
|
||||||
|
ctx->getTokens(kReturnAllTokenId).size(), named_expressions);
|
||||||
}
|
}
|
||||||
|
|
||||||
antlrcpp::Any CypherMainVisitor::visitReturnItem(
|
antlrcpp::Any CypherMainVisitor::visitReturnItem(
|
||||||
@ -669,9 +668,15 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
|
|||||||
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
|
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
|
||||||
} else if (ctx->functionInvocation()) {
|
} else if (ctx->functionInvocation()) {
|
||||||
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
|
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
|
||||||
|
} else if (ctx->COUNT()) {
|
||||||
|
// Here we handle COUNT(*). COUNT(expression) is handled in
|
||||||
|
// visitFunctionInvocation with other aggregations. This is visible in
|
||||||
|
// functionInvocation and atom producions in opencypher grammar.
|
||||||
|
return static_cast<Expression *>(
|
||||||
|
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
|
||||||
}
|
}
|
||||||
// TODO: Implement this. We don't support comprehensions, functions,
|
// TODO: Implement this. We don't support comprehensions, filtering... at the
|
||||||
// filtering... at the moment.
|
// moment.
|
||||||
throw utils::NotYetImplemented();
|
throw utils::NotYetImplemented();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,7 +165,9 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
|
|||||||
antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override;
|
antlrcpp::Any visitReturnBody(CypherParser::ReturnBodyContext *ctx) override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return vector<NamedExpression*>
|
* @return pair<bool, vector<NamedExpression*>> first member is true if
|
||||||
|
* asterisk was found in return
|
||||||
|
* expressions.
|
||||||
*/
|
*/
|
||||||
antlrcpp::Any visitReturnItems(
|
antlrcpp::Any visitReturnItems(
|
||||||
CypherParser::ReturnItemsContext *ctx) override;
|
CypherParser::ReturnItemsContext *ctx) override;
|
||||||
|
@ -90,6 +90,7 @@ TEST(CypherMainVisitorTest, ReturnNoDistinctNoBagSemantics) {
|
|||||||
auto *query = ast_generator.query_;
|
auto *query = ast_generator.query_;
|
||||||
ASSERT_EQ(query->clauses_.size(), 1U);
|
ASSERT_EQ(query->clauses_.size(), 1U);
|
||||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||||
|
ASSERT_FALSE(return_clause->body_.all_identifiers);
|
||||||
ASSERT_EQ(return_clause->body_.order_by.size(), 0U);
|
ASSERT_EQ(return_clause->body_.order_by.size(), 0U);
|
||||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U);
|
ASSERT_EQ(return_clause->body_.named_expressions.size(), 1U);
|
||||||
ASSERT_FALSE(return_clause->body_.limit);
|
ASSERT_FALSE(return_clause->body_.limit);
|
||||||
@ -147,12 +148,21 @@ TEST(CypherMainVisitorTest, ReturnNamedIdentifier) {
|
|||||||
AstGenerator ast_generator("RETURN var AS var5");
|
AstGenerator ast_generator("RETURN var AS var5");
|
||||||
auto *query = ast_generator.query_;
|
auto *query = ast_generator.query_;
|
||||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||||
|
ASSERT_FALSE(return_clause->body_.all_identifiers);
|
||||||
auto *named_expr = return_clause->body_.named_expressions[0];
|
auto *named_expr = return_clause->body_.named_expressions[0];
|
||||||
ASSERT_EQ(named_expr->name_, "var5");
|
ASSERT_EQ(named_expr->name_, "var5");
|
||||||
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
||||||
ASSERT_EQ(identifier->name_, "var");
|
ASSERT_EQ(identifier->name_, "var");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CypherMainVisitorTest, ReturnAsterisk) {
|
||||||
|
AstGenerator ast_generator("RETURN *");
|
||||||
|
auto *query = ast_generator.query_;
|
||||||
|
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||||
|
ASSERT_TRUE(return_clause->body_.all_identifiers);
|
||||||
|
ASSERT_EQ(return_clause->body_.named_expressions.size(), 0U);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CypherMainVisitorTest, IntegerLiteral) {
|
TEST(CypherMainVisitorTest, IntegerLiteral) {
|
||||||
AstGenerator ast_generator("RETURN 42");
|
AstGenerator ast_generator("RETURN 42");
|
||||||
auto *query = ast_generator.query_;
|
auto *query = ast_generator.query_;
|
||||||
@ -448,10 +458,11 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CypherMainVisitorTest, Aggregation) {
|
TEST(CypherMainVisitorTest, Aggregation) {
|
||||||
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)");
|
AstGenerator ast_generator(
|
||||||
|
"RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e), COUNT(*)");
|
||||||
auto *query = ast_generator.query_;
|
auto *query = ast_generator.query_;
|
||||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||||
ASSERT_EQ(return_clause->body_.named_expressions.size(), 5);
|
ASSERT_EQ(return_clause->body_.named_expressions.size(), 6U);
|
||||||
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
||||||
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
||||||
Aggregation::Op::AVG};
|
Aggregation::Op::AVG};
|
||||||
@ -465,6 +476,11 @@ TEST(CypherMainVisitorTest, Aggregation) {
|
|||||||
ASSERT_TRUE(identifier);
|
ASSERT_TRUE(identifier);
|
||||||
ASSERT_EQ(identifier->name_, ids[i]);
|
ASSERT_EQ(identifier->name_, ids[i]);
|
||||||
}
|
}
|
||||||
|
auto *aggregation = dynamic_cast<Aggregation *>(
|
||||||
|
return_clause->body_.named_expressions[5]->expression_);
|
||||||
|
ASSERT_TRUE(aggregation);
|
||||||
|
ASSERT_EQ(aggregation->op_, Aggregation::Op::COUNT);
|
||||||
|
ASSERT_FALSE(aggregation->expression_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CypherMainVisitorTest, UndefinedFunction) {
|
TEST(CypherMainVisitorTest, UndefinedFunction) {
|
||||||
@ -593,6 +609,7 @@ TEST(CypherMainVisitorTest, NodePattern) {
|
|||||||
ASSERT_TRUE(node->identifier_);
|
ASSERT_TRUE(node->identifier_);
|
||||||
EXPECT_EQ(node->identifier_->name_,
|
EXPECT_EQ(node->identifier_->name_,
|
||||||
CypherMainVisitor::kAnonPrefix + std::to_string(1));
|
CypherMainVisitor::kAnonPrefix + std::to_string(1));
|
||||||
|
EXPECT_FALSE(node->identifier_->user_declared_);
|
||||||
EXPECT_THAT(node->labels_, UnorderedElementsAre(
|
EXPECT_THAT(node->labels_, UnorderedElementsAre(
|
||||||
ast_generator.db_accessor_->label("label1"),
|
ast_generator.db_accessor_->label("label1"),
|
||||||
ast_generator.db_accessor_->label("label2"),
|
ast_generator.db_accessor_->label("label2"),
|
||||||
@ -621,6 +638,7 @@ TEST(CypherMainVisitorTest, NodePatternIdentifier) {
|
|||||||
ASSERT_TRUE(node);
|
ASSERT_TRUE(node);
|
||||||
ASSERT_TRUE(node->identifier_);
|
ASSERT_TRUE(node->identifier_);
|
||||||
EXPECT_EQ(node->identifier_->name_, "var");
|
EXPECT_EQ(node->identifier_->name_, "var");
|
||||||
|
EXPECT_TRUE(node->identifier_->user_declared_);
|
||||||
EXPECT_THAT(node->labels_, UnorderedElementsAre());
|
EXPECT_THAT(node->labels_, UnorderedElementsAre());
|
||||||
EXPECT_THAT(node->properties_, UnorderedElementsAre());
|
EXPECT_THAT(node->properties_, UnorderedElementsAre());
|
||||||
}
|
}
|
||||||
@ -645,6 +663,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternNoDetails) {
|
|||||||
ASSERT_TRUE(edge->identifier_);
|
ASSERT_TRUE(edge->identifier_);
|
||||||
EXPECT_THAT(edge->identifier_->name_,
|
EXPECT_THAT(edge->identifier_->name_,
|
||||||
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
||||||
|
EXPECT_FALSE(edge->identifier_->user_declared_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatternPart in braces.
|
// PatternPart in braces.
|
||||||
@ -667,6 +686,7 @@ TEST(CypherMainVisitorTest, PatternPartBraces) {
|
|||||||
ASSERT_TRUE(edge->identifier_);
|
ASSERT_TRUE(edge->identifier_);
|
||||||
EXPECT_THAT(edge->identifier_->name_,
|
EXPECT_THAT(edge->identifier_->name_,
|
||||||
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
CypherMainVisitor::kAnonPrefix + std::to_string(2));
|
||||||
|
EXPECT_FALSE(edge->identifier_->user_declared_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
|
TEST(CypherMainVisitorTest, RelationshipPatternDetails) {
|
||||||
@ -709,6 +729,7 @@ TEST(CypherMainVisitorTest, RelationshipPatternVariable) {
|
|||||||
EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
|
EXPECT_EQ(edge->direction_, EdgeAtom::Direction::RIGHT);
|
||||||
ASSERT_TRUE(edge->identifier_);
|
ASSERT_TRUE(edge->identifier_);
|
||||||
EXPECT_THAT(edge->identifier_->name_, "var");
|
EXPECT_THAT(edge->identifier_->name_, "var");
|
||||||
|
EXPECT_TRUE(edge->identifier_->user_declared_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Relationship with unbounded variable range.
|
// // Relationship with unbounded variable range.
|
||||||
@ -788,6 +809,7 @@ TEST(CypherMainVisitorTest, ReturnUnanemdIdentifier) {
|
|||||||
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
auto *identifier = dynamic_cast<Identifier *>(named_expr->expression_);
|
||||||
ASSERT_TRUE(identifier);
|
ASSERT_TRUE(identifier);
|
||||||
ASSERT_EQ(identifier->name_, "var");
|
ASSERT_EQ(identifier->name_, "var");
|
||||||
|
ASSERT_TRUE(identifier->user_declared_);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CypherMainVisitorTest, Create) {
|
TEST(CypherMainVisitorTest, Create) {
|
||||||
|
Loading…
Reference in New Issue
Block a user