Add aggregation conversion from antlr to ast
Reviewers: teon.banek Reviewed By: teon.banek Differential Revision: https://phabricator.memgraph.io/D260
This commit is contained in:
parent
7981bd19e0
commit
dfa6800edd
@ -418,7 +418,11 @@ class Aggregation : public UnaryOperator {
|
||||
|
||||
public:
|
||||
enum class Op { COUNT, MIN, MAX, SUM, AVG };
|
||||
Op op_;
|
||||
static const constexpr char *const kCount = "COUNT";
|
||||
static const constexpr char *const kMin = "MIN";
|
||||
static const constexpr char *const kMax = "MAX";
|
||||
static const constexpr char *const kSum = "SUM";
|
||||
static const constexpr char *const kAvg = "AVG";
|
||||
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
@ -427,6 +431,7 @@ class Aggregation : public UnaryOperator {
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
}
|
||||
Op op_;
|
||||
|
||||
protected:
|
||||
Aggregation(int uid, Expression *expression, Op op)
|
||||
|
@ -588,6 +588,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
|
||||
std::string variable = ctx->variable()->accept(this);
|
||||
users_identifiers.insert(variable);
|
||||
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
|
||||
} else if (ctx->functionInvocation()) {
|
||||
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
|
||||
}
|
||||
// TODO: Implement this. We don't support comprehensions, functions,
|
||||
// filtering... at the moment.
|
||||
@ -633,6 +635,51 @@ antlrcpp::Any CypherMainVisitor::visitNumberLiteral(
|
||||
}
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
CypherParser::FunctionInvocationContext *ctx) {
|
||||
if (ctx->DISTINCT()) {
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
std::vector<Expression *> expressions;
|
||||
for (auto *expression : ctx->expression()) {
|
||||
expressions.push_back(expression->accept(this));
|
||||
}
|
||||
std::string function_name = ctx->functionName()->accept(this);
|
||||
if (expressions.size() == 1U) {
|
||||
if (function_name == Aggregation::kCount) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::COUNT));
|
||||
}
|
||||
if (function_name == Aggregation::kMin) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::MIN));
|
||||
}
|
||||
if (function_name == Aggregation::kMax) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::MAX));
|
||||
}
|
||||
if (function_name == Aggregation::kSum) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::SUM));
|
||||
}
|
||||
if (function_name == Aggregation::kAvg) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
|
||||
}
|
||||
}
|
||||
// it is not a aggregation, it is a regular function,
|
||||
// will be implemented in next diff
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitFunctionName(
|
||||
CypherParser::FunctionNameContext *ctx) {
|
||||
std::string function_name = ctx->getText();
|
||||
std::transform(function_name.begin(), function_name.end(),
|
||||
function_name.begin(), toupper);
|
||||
return function_name;
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(
|
||||
CypherParser::DoubleLiteralContext *ctx) {
|
||||
// stod would be nicer but it uses current locale so we shouldn't use it.
|
||||
|
@ -370,6 +370,18 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
|
||||
antlrcpp::Any visitParenthesizedExpression(
|
||||
CypherParser::ParenthesizedExpressionContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Expression*
|
||||
*/
|
||||
antlrcpp::Any visitFunctionInvocation(
|
||||
CypherParser::FunctionInvocationContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return string - uppercased
|
||||
*/
|
||||
antlrcpp::Any visitFunctionName(
|
||||
CypherParser::FunctionNameContext *ctx) override;
|
||||
|
||||
/**
|
||||
* @return Literal*
|
||||
*/
|
||||
|
@ -145,6 +145,14 @@ class ExpressionEvaluator : public TreeVisitorBase {
|
||||
result_stack_.push_back(literal.value_);
|
||||
}
|
||||
|
||||
void Visit(Aggregation &aggregation) override {
|
||||
auto value = frame_[symbol_table_[aggregation]];
|
||||
// Aggregation is probably always simple type, but let's switch accessor
|
||||
// just to be sure.
|
||||
SwitchAccessors(value);
|
||||
result_stack_.emplace_back(std::move(value));
|
||||
}
|
||||
|
||||
private:
|
||||
// If the given TypedValue contains accessors, switch them to New or Old,
|
||||
// depending on use_new_ flag.
|
||||
|
@ -329,6 +329,26 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
|
||||
ASSERT_EQ(operand->value_.Value<int64_t>(), 5);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, Aggregation) {
|
||||
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_EQ(return_clause->named_expressions_.size(), 5);
|
||||
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
|
||||
Aggregation::Op::MAX, Aggregation::Op::SUM,
|
||||
Aggregation::Op::AVG};
|
||||
std::string ids[] = {"a", "b", "c", "d", "e"};
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
auto *aggregation = dynamic_cast<Aggregation *>(
|
||||
return_clause->named_expressions_[i]->expression_);
|
||||
ASSERT_TRUE(aggregation);
|
||||
ASSERT_EQ(aggregation->op_, ops[i]);
|
||||
auto *identifier = dynamic_cast<Identifier *>(aggregation->expression_);
|
||||
ASSERT_TRUE(identifier);
|
||||
ASSERT_EQ(identifier->name_, ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
|
||||
AstGenerator ast_generator("RETURN \"mi'rko\"");
|
||||
auto *query = ast_generator.query_;
|
||||
|
Loading…
Reference in New Issue
Block a user