Implement simple function conversion
Reviewers: buda Reviewed By: buda Subscribers: florijan, teon.banek Differential Revision: https://phabricator.memgraph.io/D264
This commit is contained in:
parent
1273cea870
commit
36129cdcae
@ -362,6 +362,7 @@ set(memgraph_src_files
|
||||
${src_dir}/query/console.cpp
|
||||
${src_dir}/query/frontend/ast/cypher_main_visitor.cpp
|
||||
${src_dir}/query/typed_value.cpp
|
||||
${src_dir}/query/frontend/interpret/awesome_memgraph_functions.cpp
|
||||
${src_dir}/query/frontend/logical/operator.cpp
|
||||
${src_dir}/query/frontend/logical/planner.cpp
|
||||
${src_dir}/query/frontend/semantic/symbol_generator.cpp
|
||||
|
@ -413,6 +413,30 @@ class PropertyLookup : public Expression {
|
||||
: Expression(uid), expression_(expression), property_(property) {}
|
||||
};
|
||||
|
||||
class Function : public Expression {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
public:
|
||||
void Accept(TreeVisitorBase &visitor) override {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
visitor.Visit(*this);
|
||||
for (auto *argument : arguments_) {
|
||||
argument->Accept(visitor);
|
||||
}
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
}
|
||||
|
||||
std::function<TypedValue(const std::vector<TypedValue> &)> function_;
|
||||
std::vector<Expression *> arguments_;
|
||||
|
||||
protected:
|
||||
Function(int uid,
|
||||
std::function<TypedValue(const std::vector<TypedValue> &)> function,
|
||||
const std::vector<Expression *> &arguments)
|
||||
: Expression(uid), function_(function), arguments_(arguments) {}
|
||||
};
|
||||
|
||||
class Aggregation : public UnaryOperator {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
|
@ -10,6 +10,7 @@ class NamedExpression;
|
||||
class Identifier;
|
||||
class PropertyLookup;
|
||||
class Aggregation;
|
||||
class Function;
|
||||
class Create;
|
||||
class Match;
|
||||
class Return;
|
||||
@ -50,7 +51,7 @@ using TreeVisitorBase = ::utils::Visitor<
|
||||
DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
|
||||
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator,
|
||||
UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, Literal,
|
||||
PropertyLookup, Aggregation, Create, Match, Return, With, Pattern, NodeAtom,
|
||||
EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
|
||||
PropertyLookup, Aggregation, Function, Create, Match, Return, With, Pattern,
|
||||
NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
|
||||
RemoveProperty, RemoveLabels>;
|
||||
}
|
||||
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include "database/graph_db.hpp"
|
||||
#include "query/exceptions.hpp"
|
||||
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
|
||||
namespace query {
|
||||
@ -640,11 +641,11 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
if (ctx->DISTINCT()) {
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
std::string function_name = ctx->functionName()->accept(this);
|
||||
std::vector<Expression *> expressions;
|
||||
for (auto *expression : ctx->expression()) {
|
||||
expressions.push_back(expression->accept(this));
|
||||
}
|
||||
std::string function_name = ctx->functionName()->accept(this);
|
||||
if (expressions.size() == 1U) {
|
||||
if (function_name == Aggregation::kCount) {
|
||||
return static_cast<Expression *>(
|
||||
@ -667,9 +668,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG));
|
||||
}
|
||||
}
|
||||
// it is not a aggregation, it is a regular function,
|
||||
// will be implemented in next diff
|
||||
throw NotYetImplemented();
|
||||
auto function = NameToFunction(function_name);
|
||||
if (!function) throw SemanticException();
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Function>(function, expressions));
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitFunctionName(
|
||||
|
38
src/query/frontend/interpret/awesome_memgraph_functions.cpp
Normal file
38
src/query/frontend/interpret/awesome_memgraph_functions.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "query/exceptions.hpp"
|
||||
|
||||
namespace query {
|
||||
namespace {
|
||||
|
||||
TypedValue Abs(const std::vector<TypedValue> &args) {
|
||||
if (args.size() != 1U) {
|
||||
throw QueryRuntimeException("ABS requires one argument");
|
||||
}
|
||||
switch (args[0].type()) {
|
||||
case TypedValue::Type::Null:
|
||||
return TypedValue::Null;
|
||||
case TypedValue::Type::Bool:
|
||||
return args[0].Value<bool>();
|
||||
case TypedValue::Type::Int:
|
||||
return static_cast<int64_t>(
|
||||
std::abs(static_cast<long long>(args[0].Value<int64_t>())));
|
||||
case TypedValue::Type::Double:
|
||||
return std::abs(args[0].Value<double>());
|
||||
default:
|
||||
throw QueryRuntimeException("ABS called with incompatible type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction(
|
||||
const std::string &function_name) {
|
||||
if (function_name == "ABS") {
|
||||
return Abs;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
}
|
11
src/query/frontend/interpret/awesome_memgraph_functions.hpp
Normal file
11
src/query/frontend/interpret/awesome_memgraph_functions.hpp
Normal file
@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "query/typed_value.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction(
|
||||
const std::string &function_name);
|
||||
}
|
@ -15,7 +15,9 @@ class Frame {
|
||||
public:
|
||||
Frame(int size) : size_(size), elems_(size_) {}
|
||||
|
||||
TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position_]; }
|
||||
TypedValue &operator[](const Symbol &symbol) {
|
||||
return elems_[symbol.position_];
|
||||
}
|
||||
const TypedValue &operator[](const Symbol &symbol) const {
|
||||
return elems_[symbol.position_];
|
||||
}
|
||||
@ -153,6 +155,15 @@ class ExpressionEvaluator : public TreeVisitorBase {
|
||||
result_stack_.emplace_back(std::move(value));
|
||||
}
|
||||
|
||||
void PostVisit(Function &function) override {
|
||||
std::vector<TypedValue> arguments;
|
||||
for (int i = 0; i < static_cast<int>(function.arguments_.size()); ++i) {
|
||||
arguments.push_back(PopBack());
|
||||
}
|
||||
reverse(arguments.begin(), arguments.end());
|
||||
result_stack_.emplace_back(function.function_(arguments));
|
||||
}
|
||||
|
||||
private:
|
||||
// If the given TypedValue contains accessors, switch them to New or Old,
|
||||
// depending on use_new_ flag.
|
||||
|
@ -349,6 +349,27 @@ TEST(CypherMainVisitorTest, Aggregation) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, UndefinedFunction) {
|
||||
ASSERT_THROW(AstGenerator("RETURN "
|
||||
"IHopeWeWillNeverHaveAwesomeMemgraphProcedureWithS"
|
||||
"uchALongAndAwesomeNameSinceThisTestWouldFail(1)"),
|
||||
SemanticException);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, Function) {
|
||||
AstGenerator ast_generator("RETURN abs(n, 2)");
|
||||
auto *query = ast_generator.query_;
|
||||
auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_EQ(return_clause->named_expressions_.size(), 1);
|
||||
auto *function = dynamic_cast<Function *>(
|
||||
return_clause->named_expressions_[0]->expression_);
|
||||
ASSERT_TRUE(function);
|
||||
ASSERT_TRUE(function->function_);
|
||||
// Check if function is abs.
|
||||
ASSERT_EQ(function->function_({-2}).Value<int64_t>(), 2);
|
||||
ASSERT_EQ(function->arguments_.size(), 2);
|
||||
}
|
||||
|
||||
TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) {
|
||||
AstGenerator ast_generator("RETURN \"mi'rko\"");
|
||||
auto *query = ast_generator.query_;
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "query/frontend/interpret/interpret.hpp"
|
||||
#include "query/frontend/opencypher/parser.hpp"
|
||||
|
||||
@ -244,3 +245,39 @@ TEST(ExpressionEvaluator, IsNullOperator) {
|
||||
op->Accept(eval.eval);
|
||||
ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true);
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, Function) {
|
||||
AstTreeStorage storage;
|
||||
NoContextExpressionEvaluator eval;
|
||||
{
|
||||
std::vector<Expression *> arguments = {
|
||||
storage.Create<Literal>(TypedValue::Null)};
|
||||
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
|
||||
op->Accept(eval.eval);
|
||||
ASSERT_EQ(eval.eval.PopBack().type(), TypedValue::Type::Null);
|
||||
}
|
||||
{
|
||||
std::vector<Expression *> arguments = {storage.Create<Literal>(-2)};
|
||||
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
|
||||
op->Accept(eval.eval);
|
||||
ASSERT_EQ(eval.eval.PopBack().Value<int64_t>(), 2);
|
||||
}
|
||||
{
|
||||
std::vector<Expression *> arguments = {storage.Create<Literal>(-2.5)};
|
||||
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
|
||||
op->Accept(eval.eval);
|
||||
ASSERT_EQ(eval.eval.PopBack().Value<double>(), 2.5);
|
||||
}
|
||||
{
|
||||
std::vector<Expression *> arguments = {storage.Create<Literal>(true)};
|
||||
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
|
||||
op->Accept(eval.eval);
|
||||
ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true);
|
||||
}
|
||||
{
|
||||
std::vector<Expression *> arguments = {
|
||||
storage.Create<Literal>(std::vector<TypedValue>(5))};
|
||||
auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments);
|
||||
ASSERT_THROW(op->Accept(eval.eval), QueryRuntimeException);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user