Add aggregation conversion from antlr to ast

Reviewers: teon.banek

Reviewed By: teon.banek

Differential Revision: https://phabricator.memgraph.io/D260
This commit is contained in:
Mislav Bradac 2017-04-11 14:39:34 +02:00
parent 7981bd19e0
commit dfa6800edd
5 changed files with 93 additions and 1 deletions

View File

@ -418,7 +418,11 @@ class Aggregation : public UnaryOperator {
public:
enum class Op { COUNT, MIN, MAX, SUM, AVG };
Op op_;
static const constexpr char *const kCount = "COUNT";
static const constexpr char *const kMin = "MIN";
static const constexpr char *const kMax = "MAX";
static const constexpr char *const kSum = "SUM";
static const constexpr char *const kAvg = "AVG";
void Accept(TreeVisitorBase &visitor) override {
if (visitor.PreVisit(*this)) {
@ -427,6 +431,7 @@ class Aggregation : public UnaryOperator {
visitor.PostVisit(*this);
}
}
Op op_;
protected:
Aggregation(int uid, Expression *expression, Op op)

View File

@ -588,6 +588,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
std::string variable = ctx->variable()->accept(this);
users_identifiers.insert(variable);
return static_cast<Expression *>(storage_.Create<Identifier>(variable));
} else if (ctx->functionInvocation()) {
return static_cast<Expression *>(ctx->functionInvocation()->accept(this));
}
// TODO: Implement this. We don't support comprehensions, functions,
// filtering... at the moment.
@ -633,6 +635,51 @@ antlrcpp::Any CypherMainVisitor::visitNumberLiteral(
}
}
antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
CypherParser::FunctionInvocationContext *ctx) {
if (ctx->DISTINCT()) {
throw NotYetImplemented();
}
std::vector<Expression *> expressions;
for (auto *expression : ctx->expression()) {
expressions.push_back(expression->accept(this));
}
std::string function_name = ctx->functionName()->accept(this);
if (expressions.size() == 1U) {
if (function_name == Aggregation::kCount) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::COUNT));
}
if (function_name == Aggregation::kMin) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::MIN));
}
if (function_name == Aggregation::kMax) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::MAX));
}
if (function_name == Aggregation::kSum) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::SUM));
}
if (function_name == Aggregation::kAvg) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
}
}
// it is not a aggregation, it is a regular function,
// will be implemented in next diff
throw NotYetImplemented();
}
antlrcpp::Any CypherMainVisitor::visitFunctionName(
CypherParser::FunctionNameContext *ctx) {
std::string function_name = ctx->getText();
std::transform(function_name.begin(), function_name.end(),
function_name.begin(), toupper);
return function_name;
}
antlrcpp::Any CypherMainVisitor::visitDoubleLiteral(
CypherParser::DoubleLiteralContext *ctx) {
// stod would be nicer but it uses current locale so we shouldn't use it.

View File

@ -370,6 +370,18 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
antlrcpp::Any visitParenthesizedExpression(
CypherParser::ParenthesizedExpressionContext *ctx) override;
/**
* @return Expression*
*/
antlrcpp::Any visitFunctionInvocation(
CypherParser::FunctionInvocationContext *ctx) override;
/**
* @return string - uppercased
*/
antlrcpp::Any visitFunctionName(
CypherParser::FunctionNameContext *ctx) override;
/**
* @return Literal*
*/

View File

@ -145,6 +145,14 @@ class ExpressionEvaluator : public TreeVisitorBase {
result_stack_.push_back(literal.value_);
}
void Visit(Aggregation &aggregation) override {
auto value = frame_[symbol_table_[aggregation]];
// Aggregation is probably always simple type, but let's switch accessor
// just to be sure.
SwitchAccessors(value);
result_stack_.emplace_back(std::move(value));
}
private:
// If the given TypedValue contains accessors, switch them to New or Old,
// depending on use_new_ flag.

View File

@ -329,6 +329,26 @@ TEST(CypherMainVisitorTest, UnaryMinusPlusOperators) {
ASSERT_EQ(operand->value_.Value<int64_t>(), 5);
}
TEST(CypherMainVisitorTest, Aggregation) {
AstGenerator ast_generator("RETURN COUNT(a), MIN(b), MAX(c), SUM(d), AVG(e)");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_EQ(return_clause->named_expressions_.size(), 5);
Aggregation::Op ops[] = {Aggregation::Op::COUNT, Aggregation::Op::MIN,
Aggregation::Op::MAX, Aggregation::Op::SUM,
Aggregation::Op::AVG};
std::string ids[] = {"a", "b", "c", "d", "e"};
for (int i = 0; i < 5; ++i) {
auto *aggregation = dynamic_cast<Aggregation *>(
return_clause->named_expressions_[i]->expression_);
ASSERT_TRUE(aggregation);
ASSERT_EQ(aggregation->op_, ops[i]);
auto *identifier = dynamic_cast<Identifier *>(aggregation->expression_);
ASSERT_TRUE(identifier);
ASSERT_EQ(identifier->name_, ids[i]);
}
}
TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
AstGenerator ast_generator("RETURN \"mi'rko\"");
auto *query = ast_generator.query_;