diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e3eba504..036d36574 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -362,6 +362,7 @@ set(memgraph_src_files ${src_dir}/query/console.cpp ${src_dir}/query/frontend/ast/cypher_main_visitor.cpp ${src_dir}/query/typed_value.cpp + ${src_dir}/query/frontend/interpret/awesome_memgraph_functions.cpp ${src_dir}/query/frontend/logical/operator.cpp ${src_dir}/query/frontend/logical/planner.cpp ${src_dir}/query/frontend/semantic/symbol_generator.cpp diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index f9e96c93a..cc8772e2c 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -413,6 +413,30 @@ class PropertyLookup : public Expression { : Expression(uid), expression_(expression), property_(property) {} }; +class Function : public Expression { + friend class AstTreeStorage; + + public: + void Accept(TreeVisitorBase &visitor) override { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + for (auto *argument : arguments_) { + argument->Accept(visitor); + } + visitor.PostVisit(*this); + } + } + + std::function<TypedValue(const std::vector<TypedValue> &)> function_; + std::vector<Expression *> arguments_; + + protected: + Function(int uid, + std::function<TypedValue(const std::vector<TypedValue> &)> function, + const std::vector<Expression *> &arguments) + : Expression(uid), function_(function), arguments_(arguments) {} +}; + class Aggregation : public UnaryOperator { friend class AstTreeStorage; diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 7fecf156f..e69d9ba47 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -10,6 +10,7 @@ class NamedExpression; class Identifier; class PropertyLookup; class Aggregation; +class Function; class Create; class Match; class Return; @@ -50,7 +51,7 @@ using TreeVisitorBase = ::utils::Visitor< DivisionOperator, ModOperator, NotEqualOperator, EqualOperator, LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, Literal, - PropertyLookup, Aggregation, Create, Match, Return, With, Pattern, NodeAtom, - EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, + PropertyLookup, Aggregation, Function, Create, Match, Return, With, Pattern, + NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels>; } diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 61a665d91..6808bf824 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -11,6 +11,7 @@ #include "database/graph_db.hpp" #include "query/exceptions.hpp" +#include "query/frontend/interpret/awesome_memgraph_functions.hpp" #include "utils/assert.hpp" namespace query { @@ -640,11 +641,11 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation( if (ctx->DISTINCT()) { throw NotYetImplemented(); } + std::string function_name = ctx->functionName()->accept(this); std::vector<Expression *> expressions; for (auto *expression : ctx->expression()) { expressions.push_back(expression->accept(this)); } - std::string function_name = ctx->functionName()->accept(this); if (expressions.size() == 1U) { if (function_name == Aggregation::kCount) { return static_cast<Expression *>( @@ -667,9 +668,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation( storage_.Create<Aggregation>(expressions[0], Aggregation::Op::AVG)); } } - // it is not a aggregation, it is a regular function, - // will be implemented in next diff - throw NotYetImplemented(); + auto function = NameToFunction(function_name); + if (!function) throw SemanticException(); + return static_cast<Expression *>( + storage_.Create<Function>(function, expressions)); } antlrcpp::Any CypherMainVisitor::visitFunctionName( diff --git a/src/query/frontend/interpret/awesome_memgraph_functions.cpp b/src/query/frontend/interpret/awesome_memgraph_functions.cpp new file mode 100644 index 000000000..9a38750b4 --- /dev/null +++ b/src/query/frontend/interpret/awesome_memgraph_functions.cpp @@ -0,0 +1,38 @@ +#include "query/frontend/interpret/awesome_memgraph_functions.hpp" + +#include <cmath> +#include <cstdlib> + +#include "query/exceptions.hpp" + +namespace query { +namespace { + +TypedValue Abs(const std::vector<TypedValue> &args) { + if (args.size() != 1U) { + throw QueryRuntimeException("ABS requires one argument"); + } + switch (args[0].type()) { + case TypedValue::Type::Null: + return TypedValue::Null; + case TypedValue::Type::Bool: + return args[0].Value<bool>(); + case TypedValue::Type::Int: + return static_cast<int64_t>( + std::abs(static_cast<long long>(args[0].Value<int64_t>()))); + case TypedValue::Type::Double: + return std::abs(args[0].Value<double>()); + default: + throw QueryRuntimeException("ABS called with incompatible type"); + } +} +} + +std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction( + const std::string &function_name) { + if (function_name == "ABS") { + return Abs; + } + return nullptr; +} +} diff --git a/src/query/frontend/interpret/awesome_memgraph_functions.hpp b/src/query/frontend/interpret/awesome_memgraph_functions.hpp new file mode 100644 index 000000000..7f5b28169 --- /dev/null +++ b/src/query/frontend/interpret/awesome_memgraph_functions.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include <vector> + +#include "query/typed_value.hpp" + +namespace query { + +std::function<TypedValue(const std::vector<TypedValue> &)> NameToFunction( + const std::string &function_name); +} diff --git a/src/query/frontend/interpret/interpret.hpp b/src/query/frontend/interpret/interpret.hpp index b04446c92..ac4f90dfe 100644 --- a/src/query/frontend/interpret/interpret.hpp +++ b/src/query/frontend/interpret/interpret.hpp @@ -15,7 +15,9 @@ class Frame { public: Frame(int size) : size_(size), elems_(size_) {} - TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position_]; } + TypedValue &operator[](const Symbol &symbol) { + return elems_[symbol.position_]; + } const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position_]; } @@ -153,6 +155,15 @@ class ExpressionEvaluator : public TreeVisitorBase { result_stack_.emplace_back(std::move(value)); } + void PostVisit(Function &function) override { + std::vector<TypedValue> arguments; + for (int i = 0; i < static_cast<int>(function.arguments_.size()); ++i) { + arguments.push_back(PopBack()); + } + reverse(arguments.begin(), arguments.end()); + result_stack_.emplace_back(function.function_(arguments)); + } + private: // If the given TypedValue contains accessors, switch them to New or Old, // depending on use_new_ flag. diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 6a8e5deed..5c376933b 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -349,6 +349,27 @@ TEST(CypherMainVisitorTest, Aggregation) { } } +TEST(CypherMainVisitorTest, UndefinedFunction) { + ASSERT_THROW(AstGenerator("RETURN " + "IHopeWeWillNeverHaveAwesomeMemgraphProcedureWithS" + "uchALongAndAwesomeNameSinceThisTestWouldFail(1)"), + SemanticException); +} + +TEST(CypherMainVisitorTest, Function) { + AstGenerator ast_generator("RETURN abs(n, 2)"); + auto *query = ast_generator.query_; + auto *return_clause = dynamic_cast<Return *>(query->clauses_[0]); + ASSERT_EQ(return_clause->named_expressions_.size(), 1); + auto *function = dynamic_cast<Function *>( + return_clause->named_expressions_[0]->expression_); + ASSERT_TRUE(function); + ASSERT_TRUE(function->function_); + // Check if function is abs. + ASSERT_EQ(function->function_({-2}).Value<int64_t>(), 2); + ASSERT_EQ(function->arguments_.size(), 2); +} + TEST(CypherMainVisitorTest, StringLiteralDoubleQuotes) { AstGenerator ast_generator("RETURN \"mi'rko\""); auto *query = ast_generator.query_; diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index e344460ed..bc2309cc1 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -11,6 +11,7 @@ #include "gtest/gtest.h" #include "query/frontend/ast/ast.hpp" +#include "query/frontend/interpret/awesome_memgraph_functions.hpp" #include "query/frontend/interpret/interpret.hpp" #include "query/frontend/opencypher/parser.hpp" @@ -244,3 +245,39 @@ TEST(ExpressionEvaluator, IsNullOperator) { op->Accept(eval.eval); ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true); } + +TEST(ExpressionEvaluator, Function) { + AstTreeStorage storage; + NoContextExpressionEvaluator eval; + { + std::vector<Expression *> arguments = { + storage.Create<Literal>(TypedValue::Null)}; + auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments); + op->Accept(eval.eval); + ASSERT_EQ(eval.eval.PopBack().type(), TypedValue::Type::Null); + } + { + std::vector<Expression *> arguments = {storage.Create<Literal>(-2)}; + auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments); + op->Accept(eval.eval); + ASSERT_EQ(eval.eval.PopBack().Value<int64_t>(), 2); + } + { + std::vector<Expression *> arguments = {storage.Create<Literal>(-2.5)}; + auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments); + op->Accept(eval.eval); + ASSERT_EQ(eval.eval.PopBack().Value<double>(), 2.5); + } + { + std::vector<Expression *> arguments = {storage.Create<Literal>(true)}; + auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments); + op->Accept(eval.eval); + ASSERT_EQ(eval.eval.PopBack().Value<bool>(), true); + } + { + std::vector<Expression *> arguments = { + storage.Create<Literal>(std::vector<TypedValue>(5))}; + auto *op = storage.Create<Function>(NameToFunction("ABS"), arguments); + ASSERT_THROW(op->Accept(eval.eval), QueryRuntimeException); + } +}