Implement simple function conversion

Reviewers: buda

Reviewed By: buda

Subscribers: florijan, teon.banek

Differential Revision: https://phabricator.memgraph.io/D264
This commit is contained in:
Mislav Bradac 2017-04-11 19:12:52 +02:00
parent 1273cea870
commit 36129cdcae
9 changed files with 153 additions and 7 deletions

View File

@ -362,6 +362,7 @@ set(memgraph_src_files
${src_dir}/query/console.cpp
${src_dir}/query/frontend/ast/cypher_main_visitor.cpp
${src_dir}/query/typed_value.cpp
${src_dir}/query/frontend/interpret/awesome_memgraph_functions.cpp
${src_dir}/query/frontend/logical/operator.cpp
${src_dir}/query/frontend/logical/planner.cpp
${src_dir}/query/frontend/semantic/symbol_generator.cpp

View File

@ -413,6 +413,30 @@ class PropertyLookup : public Expression {
: Expression(uid), expression_(expression), property_(property) {}
};
class Function : public Expression {
friend class AstTreeStorage;
public:
void Accept(TreeVisitorBase &visitor) override {
if (visitor.PreVisit(*this)) {
visitor.Visit(*this);
for (auto *argument : arguments_) {
argument->Accept(visitor);
}
visitor.PostVisit(*this);
}
}
std::function<TypedValue(const std::vector<TypedValue> &)> function_;
std::vector<Expression *> arguments_;
protected:
Function(int uid,
std::function<TypedValue(const std::vector<TypedValue> &)> function,
const std::vector<Expression *> &arguments)
: Expression(uid), function_(function), arguments_(arguments) {}
};
class Aggregation : public UnaryOperator {
friend class AstTreeStorage;

View File

@ -10,6 +10,7 @@ class NamedExpression;
class Identifier;
class PropertyLookup;
class Aggregation;
class Function;
class Create;
class Match;
class Return;
@ -50,7 +51,7 @@ using TreeVisitorBase = ::utils::Visitor<
DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator,
UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, Literal,
PropertyLookup, Aggregation, Create, Match, Return, With, Pattern, NodeAtom,
EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
PropertyLookup, Aggregation, Function, Create, Match, Return, With, Pattern,
NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels>;
}

View File

@ -11,6 +11,7 @@
#include "database/graph_db.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
#include "utils/assert.hpp"
namespace query {
@ -640,11 +641,11 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
if (ctx->DISTINCT()) {
throw NotYetImplemented();
}
std::string function_name = ctx->functionName()->accept(this);
std::vector<Expression *> expressions;
for (auto *expression : ctx->expression()) {
expressions.push_back(expression->accept(this));
}
std::string function_name = ctx->functionName()->accept(this);
if (expressions.size() == 1U) {
if (function_name == Aggregation::kCount) {
return static_cast<Expression *>(
@ -667,9 +668,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
}
}
// it is not a aggregation, it is a regular function,
// will be implemented in next diff
throw NotYetImplemented();
auto function = NameToFunction(function_name);
if (!function) throw SemanticException();
return static_cast<Expression *>(
storage_.Create<Function>(function, expressions));
}
antlrcpp::Any CypherMainVisitor::visitFunctionName(

View File

@ -0,0 +1,38 @@
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
#include <cmath>
#include <cstdlib>
#include "query/exceptions.hpp"
namespace query {
namespace {
TypedValue Abs(const std::vector<TypedValue> &args) {
if (args.size() != 1U) {
throw QueryRuntimeException("ABS requires one argument");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue::Null;
case TypedValue::Type::Bool:
return args[0].Value<bool>();
case TypedValue::Type::Int:
return static_cast<int64_t>(
std::abs(static_cast<long long>(args[0].Value<int64_t>())));
case TypedValue::Type::Double:
return std::abs(args[0].Value<double>());
default:
throw QueryRuntimeException("ABS called with incompatible type");
}
}
}
std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction(
const std::string &function_name) {
if (function_name == "ABS") {
return Abs;
}
return nullptr;
}
}

View File

@ -0,0 +1,11 @@
#pragma once
#include <vector>
#include "query/typed_value.hpp"
namespace query {
std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction(
const std::string &function_name);
}

View File

@ -15,7 +15,9 @@ class Frame {
public:
Frame(int size) : size_(size), elems_(size_) {}
TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position_]; }
TypedValue &operator[](const Symbol &symbol) {
return elems_[symbol.position_];
}
const TypedValue &operator[](const Symbol &symbol) const {
return elems_[symbol.position_];
}
@ -153,6 +155,15 @@ class ExpressionEvaluator : public TreeVisitorBase {
result_stack_.emplace_back(std::move(value));
}
void PostVisit(Function &function) override {
std::vector<TypedValue> arguments;
for (int i = 0; i < static_cast<int>(function.arguments_.size()); ++i) {
arguments.push_back(PopBack());
}
reverse(arguments.begin(), arguments.end());
result_stack_.emplace_back(function.function_(arguments));
}
private:
// If the given TypedValue contains accessors, switch them to New or Old,
// depending on use_new_ flag.

View File

@ -349,6 +349,27 @@ TEST(CypherMainVisitorTest, Aggregation) {
}
}
TEST(CypherMainVisitorTest, UndefinedFunction) {
ASSERT_THROW(AstGenerator("RETURN "
"IHopeWeWillNeverHaveAwesomeMemgraphProcedureWithS"
"uchALongAndAwesomeNameSinceThisTestWouldFail(1)"),
SemanticException);
}
TEST(CypherMainVisitorTest, Function) {
AstGenerator ast_generator("RETURN abs(n, 2)");
auto *query = ast_generator.query_;
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_EQ(return_clause->named_expressions_.size(), 1);
auto *function = dynamic_cast<Function *>(
return_clause->named_expressions_[0]->expression_);
ASSERT_TRUE(function);
ASSERT_TRUE(function->function_);
// Check if function is abs.
ASSERT_EQ(function->function_({-2}).Value<int64_t>(), 2);
ASSERT_EQ(function->arguments_.size(), 2);
}
TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
AstGenerator ast_generator("RETURN \"mi'rko\"");
auto *query = ast_generator.query_;

View File

@ -11,6 +11,7 @@
#include "gtest/gtest.h"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/opencypher/parser.hpp"
@ -244,3 +245,39 @@ TEST(ExpressionEvaluator, IsNullOperator) {
op->Accept(eval.eval);
ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true);
}
TEST(ExpressionEvaluator, Function) {
AstTreeStorage storage;
NoContextExpressionEvaluator eval;
{
std::vector<Expression *> arguments = {
storage.Create<Literal>(TypedValue::Null)};
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
op->Accept(eval.eval);
ASSERT_EQ(eval.eval.PopBack().type(), TypedValue::Type::Null);
}
{
std::vector<Expression *> arguments = {storage.Create<Literal>(-2)};
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
op->Accept(eval.eval);
ASSERT_EQ(eval.eval.PopBack().Value<int64_t>(), 2);
}
{
std::vector<Expression *> arguments = {storage.Create<Literal>(-2.5)};
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
op->Accept(eval.eval);
ASSERT_EQ(eval.eval.PopBack().Value<double>(), 2.5);
}
{
std::vector<Expression *> arguments = {storage.Create<Literal>(true)};
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
op->Accept(eval.eval);
ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true);
}
{
std::vector<Expression *> arguments = {
storage.Create<Literal>(std::vector<TypedValue>(5))};
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
ASSERT_THROW(op->Accept(eval.eval), QueryRuntimeException);
}
}