diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 669428d7d..f39d92c7a 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1013,8 +1013,7 @@ class Function : public Expression { private: std::string function_name_; - std::function &, Context *)> - function_; + std::function function_; }; class Aggregation : public BinaryOperator { diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 5b24d1e94..e7ffdb1a2 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -32,20 +32,22 @@ namespace { // TODO: Implement degrees, haversin, radians // TODO: Implement spatial functions -TypedValue Coalesce(const std::vector &args, Context *) { - if (args.size() == 0U) { +TypedValue Coalesce(TypedValue *args, int64_t nargs, Context *) { + // TODO: Perhaps this function should be done by the evaluator itself, so as + // to avoid evaluating all the arguments. + if (nargs == 0) { throw QueryRuntimeException("'coalesce' requires at least one argument."); } - for (auto &arg : args) { - if (arg.type() != TypedValue::Type::Null) { - return arg; + for (int64_t i = 0; i < nargs; ++i) { + if (args[i].type() != TypedValue::Type::Null) { + return args[i]; } } return TypedValue::Null; } -TypedValue EndNode(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue EndNode(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'endNode' requires exactly one argument."); } switch (args[0].type()) { @@ -58,8 +60,8 @@ TypedValue EndNode(const std::vector &args, Context *) { } } -TypedValue Head(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Head(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'head' requires exactly one argument."); } switch (args[0].type()) { @@ -75,8 +77,8 @@ TypedValue Head(const std::vector &args, Context *) { } } -TypedValue Last(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Last(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'last' requires exactly one argument."); } switch (args[0].type()) { @@ -92,8 +94,8 @@ TypedValue Last(const std::vector &args, Context *) { } } -TypedValue Properties(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Properties(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'properties' requires exactly one argument."); } auto get_properties = [&](const auto &record_accessor) { @@ -117,8 +119,8 @@ TypedValue Properties(const std::vector &args, Context *ctx) { } } -TypedValue Size(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Size(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'size' requires exactly one argument."); } switch (args[0].type()) { @@ -142,8 +144,8 @@ TypedValue Size(const std::vector &args, Context *) { } } -TypedValue StartNode(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue StartNode(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'startNode' requires exactly one argument."); } switch (args[0].type()) { @@ -156,8 +158,8 @@ TypedValue StartNode(const std::vector &args, Context *) { } } -TypedValue Degree(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Degree(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'degree' requires exactly one argument."); } switch (args[0].type()) { @@ -172,8 +174,8 @@ TypedValue Degree(const std::vector &args, Context *) { } } -TypedValue ToBoolean(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue ToBoolean(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'toBoolean' requires exactly one argument."); } switch (args[0].type()) { @@ -197,8 +199,8 @@ TypedValue ToBoolean(const std::vector &args, Context *) { } } -TypedValue ToFloat(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue ToFloat(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'toFloat' requires exactly one argument."); } switch (args[0].type()) { @@ -220,8 +222,8 @@ TypedValue ToFloat(const std::vector &args, Context *) { } } -TypedValue ToInteger(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue ToInteger(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'toInteger' requires exactly one argument'"); } switch (args[0].type()) { @@ -248,8 +250,8 @@ TypedValue ToInteger(const std::vector &args, Context *) { } } -TypedValue Type(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Type(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'type' requires exactly one argument."); } switch (args[0].type()) { @@ -263,8 +265,8 @@ TypedValue Type(const std::vector &args, Context *ctx) { } } -TypedValue Keys(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Keys(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'keys' requires exactly one argument."); } auto get_keys = [&](const auto &record_accessor) { @@ -286,8 +288,8 @@ TypedValue Keys(const std::vector &args, Context *ctx) { } } -TypedValue Labels(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Labels(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'labels' requires exactly one argument."); } switch (args[0].type()) { @@ -305,8 +307,8 @@ TypedValue Labels(const std::vector &args, Context *ctx) { } } -TypedValue Nodes(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Nodes(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'nodes' requires exactly one argument."); } if (args[0].IsNull()) return TypedValue::Null; @@ -317,8 +319,8 @@ TypedValue Nodes(const std::vector &args, Context *) { return std::vector(vertices.begin(), vertices.end()); } -TypedValue Relationships(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Relationships(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException( "'relationships' requires exactly one argument."); } @@ -330,8 +332,8 @@ TypedValue Relationships(const std::vector &args, Context *) { return std::vector(edges.begin(), edges.end()); } -TypedValue Range(const std::vector &args, Context *) { - if (args.size() != 2U && args.size() != 3U) { +TypedValue Range(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 2 && nargs != 3) { throw QueryRuntimeException("'range' requires two or three arguments."); } bool has_null = false; @@ -342,11 +344,11 @@ TypedValue Range(const std::vector &args, Context *) { throw QueryRuntimeException("arguments of 'range' must be integers."); } }; - std::for_each(args.begin(), args.end(), check_type); + for (int64_t i = 0; i < nargs; ++i) check_type(args[i]); if (has_null) return TypedValue::Null; auto lbound = args[0].Value(); auto rbound = args[1].Value(); - int64_t step = args.size() == 3U ? args[2].Value() : 1; + int64_t step = nargs == 3 ? args[2].Value() : 1; if (step == 0) { throw QueryRuntimeException("step argument of 'range' can't be zero."); } @@ -363,8 +365,8 @@ TypedValue Range(const std::vector &args, Context *) { return list; } -TypedValue Tail(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Tail(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'tail' requires exactly one argument."); } switch (args[0].type()) { @@ -381,8 +383,8 @@ TypedValue Tail(const std::vector &args, Context *) { } } -TypedValue Abs(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Abs(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'abs' requires exactly one argument."); } switch (args[0].type()) { @@ -399,8 +401,8 @@ TypedValue Abs(const std::vector &args, Context *) { } #define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \ - TypedValue name(const std::vector &args, Context *) { \ - if (args.size() != 1U) { \ + TypedValue name(TypedValue *args, int64_t nargs, Context *) { \ + if (nargs != 1) { \ throw QueryRuntimeException("'" #lowercased_name \ "' requires exactly one argument."); \ } \ @@ -435,8 +437,8 @@ WRAP_CMATH_FLOAT_FUNCTION(Tan, tan) #undef WRAP_CMATH_FLOAT_FUNCTION -TypedValue Atan2(const std::vector &args, Context *) { - if (args.size() != 2U) { +TypedValue Atan2(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 2) { throw QueryRuntimeException("'atan2' requires two arguments."); } if (args[0].type() == TypedValue::Type::Null) return TypedValue::Null; @@ -456,8 +458,8 @@ TypedValue Atan2(const std::vector &args, Context *) { return atan2(y, x); } -TypedValue Sign(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue Sign(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'sign' requires exactly one argument."); } auto sign = [](auto x) { return (0 < x) - (x < 0); }; @@ -473,32 +475,32 @@ TypedValue Sign(const std::vector &args, Context *) { } } -TypedValue E(const std::vector &args, Context *) { - if (args.size() != 0U) { +TypedValue E(TypedValue *, int64_t nargs, Context *) { + if (nargs != 0) { throw QueryRuntimeException("'e' requires no arguments."); } return M_E; } -TypedValue Pi(const std::vector &args, Context *) { - if (args.size() != 0U) { +TypedValue Pi(TypedValue *, int64_t nargs, Context *) { + if (nargs != 0) { throw QueryRuntimeException("'pi' requires no arguments."); } return M_PI; } -TypedValue Rand(const std::vector &args, Context *) { +TypedValue Rand(TypedValue *, int64_t nargs, Context *) { static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; static thread_local std::uniform_real_distribution<> rand_dist_{0, 1}; - if (args.size() != 0U) { + if (nargs != 0) { throw QueryRuntimeException("'rand' requires no arguments."); } return rand_dist_(pseudo_rand_gen_); } template -TypedValue StringMatchOperator(const std::vector &args, Context *) { - if (args.size() != 2U) { +TypedValue StringMatchOperator(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 2) { throw QueryRuntimeException( "'startsWith' and 'endsWith' require two arguments."); } @@ -540,27 +542,27 @@ bool ContainsPredicate(const std::string &s1, const std::string &s2) { } auto Contains = StringMatchOperator; -TypedValue Assert(const std::vector &args, Context *) { - if (args.size() < 1U || args.size() > 2U) { +TypedValue Assert(TypedValue *args, int64_t nargs, Context *) { + if (nargs < 1 || nargs > 2) { throw QueryRuntimeException("'assert' requires one or two arguments"); } if (args[0].type() != TypedValue::Type::Bool) throw QueryRuntimeException( "First argument of 'assert' must be a boolean."); - if (args.size() == 2U && args[1].type() != TypedValue::Type::String) + if (nargs == 2 && args[1].type() != TypedValue::Type::String) throw QueryRuntimeException( "Second argument of 'assert' must be a string."); if (!args[0].ValueBool()) { std::string message("Assertion failed"); - if (args.size() == 2U) message += ": " + args[1].ValueString(); + if (nargs == 2) message += ": " + args[1].ValueString(); message += "."; throw QueryRuntimeException(message); } return args[0]; } -TypedValue Counter(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Counter(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'counter' requires exactly one argument."); } if (!args[0].IsString()) @@ -569,8 +571,8 @@ TypedValue Counter(const std::vector &args, Context *ctx) { return ctx->db_accessor_.Counter(args[0].ValueString()); } -TypedValue CounterSet(const std::vector &args, Context *ctx) { - if (args.size() != 2U) { +TypedValue CounterSet(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 2) { throw QueryRuntimeException("'counterSet' requires two arguments."); } if (!args[0].IsString()) @@ -583,16 +585,16 @@ TypedValue CounterSet(const std::vector &args, Context *ctx) { return TypedValue::Null; } -TypedValue IndexInfo(const std::vector &args, Context *ctx) { - if (args.size() != 0U) +TypedValue IndexInfo(TypedValue *, int64_t nargs, Context *ctx) { + if (nargs != 0) throw QueryRuntimeException("'indexInfo' requires no arguments."); auto info = ctx->db_accessor_.IndexInfo(); return std::vector(info.begin(), info.end()); } -TypedValue WorkerId(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue WorkerId(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'workerId' requires exactly one argument."); } auto &arg = args[0]; @@ -607,8 +609,8 @@ TypedValue WorkerId(const std::vector &args, Context *) { } } -TypedValue Id(const std::vector &args, Context *ctx) { - if (args.size() != 1U) { +TypedValue Id(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 1) { throw QueryRuntimeException("'id' requires exactly one argument."); } auto &arg = args[0]; @@ -624,8 +626,8 @@ TypedValue Id(const std::vector &args, Context *ctx) { } } -TypedValue ToString(const std::vector &args, Context *) { - if (args.size() != 1U) { +TypedValue ToString(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 1) { throw QueryRuntimeException("'toString' requires exactly one argument."); } auto &arg = args[0]; @@ -646,15 +648,15 @@ TypedValue ToString(const std::vector &args, Context *) { } } -TypedValue Timestamp(const std::vector &args, Context *ctx) { - if (args.size() != 0) { +TypedValue Timestamp(TypedValue *, int64_t nargs, Context *ctx) { + if (nargs != 0) { throw QueryRuntimeException("'timestamp' requires no arguments."); } return ctx->timestamp_; } -TypedValue Left(const std::vector &args, Context *ctx) { - if (args.size() != 2) { +TypedValue Left(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 2) { throw QueryRuntimeException("'left' requires two arguments."); } switch (args[0].type()) { @@ -675,8 +677,8 @@ TypedValue Left(const std::vector &args, Context *ctx) { } } -TypedValue Right(const std::vector &args, Context *ctx) { - if (args.size() != 2) { +TypedValue Right(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 2) { throw QueryRuntimeException("'right' requires two arguments."); } switch (args[0].type()) { @@ -702,8 +704,8 @@ TypedValue Right(const std::vector &args, Context *ctx) { } #define WRAP_STRING_FUNCTION(name, lowercased_name, function) \ - TypedValue name(const std::vector &args, Context *) { \ - if (args.size() != 1U) { \ + TypedValue name(TypedValue *args, int64_t nargs, Context *) { \ + if (nargs != 1) { \ throw QueryRuntimeException("'" #lowercased_name \ "' requires exactly one argument."); \ } \ @@ -725,8 +727,8 @@ WRAP_STRING_FUNCTION(Reverse, reverse, utils::Reversed); WRAP_STRING_FUNCTION(ToLower, toLower, utils::ToLowerCase); WRAP_STRING_FUNCTION(ToUpper, toUpper, utils::ToUpperCase); -TypedValue Replace(const std::vector &args, Context *ctx) { - if (args.size() != 3U) { +TypedValue Replace(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 3) { throw QueryRuntimeException("'replace' requires three arguments."); } if (!args[0].IsNull() && !args[0].IsString()) { @@ -748,8 +750,8 @@ TypedValue Replace(const std::vector &args, Context *ctx) { args[2].ValueString()); } -TypedValue Split(const std::vector &args, Context *ctx) { - if (args.size() != 2U) { +TypedValue Split(TypedValue *args, int64_t nargs, Context *ctx) { + if (nargs != 2) { throw QueryRuntimeException("'split' requires two arguments."); } if (!args[0].IsNull() && !args[0].IsString()) { @@ -771,8 +773,8 @@ TypedValue Split(const std::vector &args, Context *ctx) { return result; } -TypedValue Substring(const std::vector &args, Context *) { - if (args.size() != 2U && args.size() != 3U) { +TypedValue Substring(TypedValue *args, int64_t nargs, Context *) { + if (nargs != 2 && nargs != 3) { throw QueryRuntimeException("'substring' requires two or three arguments."); } if (!args[0].IsNull() && !args[0].IsString()) { @@ -783,7 +785,7 @@ TypedValue Substring(const std::vector &args, Context *) { throw QueryRuntimeException( "Second argument of 'substring' should be a non-negative integer."); } - if (args.size() == 3U && (!args[2].IsInt() || args[2].ValueInt() < 0)) { + if (nargs == 3 && (!args[2].IsInt() || args[2].ValueInt() < 0)) { throw QueryRuntimeException( "Third argument of 'substring' should be a non-negative integer."); } @@ -792,7 +794,7 @@ TypedValue Substring(const std::vector &args, Context *) { } const auto &str = args[0].ValueString(); int start = args[1].ValueInt(); - if (args.size() == 2U) { + if (nargs == 2) { return start < str.size() ? str.substr(start) : ""; } int len = args[2].ValueInt(); @@ -801,7 +803,7 @@ TypedValue Substring(const std::vector &args, Context *) { } // namespace -std::function &, Context *)> +std::function NameToFunction(const std::string &function_name) { // Scalar functions if (function_name == "COALESCE") return Coalesce; @@ -878,4 +880,5 @@ NameToFunction(const std::string &function_name) { return nullptr; } + } // namespace query diff --git a/src/query/interpret/awesome_memgraph_functions.hpp b/src/query/interpret/awesome_memgraph_functions.hpp index 670bcfdb6..b42f4cb54 100644 --- a/src/query/interpret/awesome_memgraph_functions.hpp +++ b/src/query/interpret/awesome_memgraph_functions.hpp @@ -1,3 +1,4 @@ +/// @file #pragma once #include @@ -14,6 +15,14 @@ const char kEndsWith[] = "ENDSWITH"; const char kContains[] = "CONTAINS"; } // namespace -std::function &, Context *)> +/// Return the function implementation with the given name. +/// +/// Note, returned function signature uses C-style access to an array to allow +/// having an array stored anywhere the caller likes, as long as it is +/// contiguous in memory. Since most functions don't take many arguments, it's +/// convenient to have them stored in the calling stack frame. +std::function NameToFunction(const std::string &function_name); + } // namespace query diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index b1b4e9517..b357c7e39 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -8,6 +8,7 @@ #include "database/graph_db_accessor.hpp" #include "query/common.hpp" +#include "query/context.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" @@ -157,8 +158,7 @@ class ExpressionEvaluator : public TreeVisitor { // Exceptions have higher priority than returning nulls when list expression // is not null. if (_list.type() != TypedValue::Type::List) { - throw QueryRuntimeException("IN expected a list, got {}.", - _list.type()); + throw QueryRuntimeException("IN expected a list, got {}.", _list.type()); } auto list = _list.Value>(); @@ -366,11 +366,22 @@ class ExpressionEvaluator : public TreeVisitor { } TypedValue Visit(Function &function) override { - std::vector arguments; - for (const auto &argument : function.arguments_) { - arguments.emplace_back(argument->Accept(*this)); + // Stack allocate evaluated arguments when there's a small number of them. + if (function.arguments_.size() <= 8) { + TypedValue arguments[8]; + for (size_t i = 0; i < function.arguments_.size(); ++i) { + arguments[i] = function.arguments_[i]->Accept(*this); + } + return function.function()(arguments, function.arguments_.size(), + context_); + } else { + std::vector arguments; + arguments.reserve(function.arguments_.size()); + for (const auto &argument : function.arguments_) { + arguments.emplace_back(argument->Accept(*this)); + } + return function.function()(arguments.data(), arguments.size(), context_); } - return function.function()(arguments, context_); } TypedValue Visit(Reduce &reduce) override { diff --git a/src/query/interpret/frame.hpp b/src/query/interpret/frame.hpp index d53ca7e9d..218916392 100644 --- a/src/query/interpret/frame.hpp +++ b/src/query/interpret/frame.hpp @@ -9,7 +9,7 @@ namespace query { class Frame { public: - Frame(int size) : size_(size), elems_(size_) {} + explicit Frame(int size) : size_(size), elems_(size_) {} TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; diff --git a/tests/benchmark/query/eval.cpp b/tests/benchmark/query/eval.cpp new file mode 100644 index 000000000..af6294d09 --- /dev/null +++ b/tests/benchmark/query/eval.cpp @@ -0,0 +1,55 @@ +#include + +#include "query/frontend/ast/ast.hpp" +#include "query/interpret/eval.hpp" + +static void BenchmarkCoalesceCallWithNulls(benchmark::State &state) { + int64_t num_args = state.range(0); + query::AstStorage ast_storage; + std::vector arguments; + arguments.reserve(num_args); + for (int64_t i = 0; i < num_args; ++i) { + arguments.emplace_back( + ast_storage.Create(query::TypedValue::Null)); + } + auto *function = ast_storage.Create("COALESCE", arguments); + query::Frame frame(0); + database::GraphDbAccessor *dba = nullptr; + query::Context context(*dba); + query::ExpressionEvaluator evaluator(frame, &context, query::GraphView::OLD); + while (state.KeepRunning()) { + function->Accept(evaluator); + } +} + +static void BenchmarkCoalesceCallWithStrings(benchmark::State &state) { + int64_t num_args = state.range(0); + query::AstStorage ast_storage; + std::vector arguments; + arguments.reserve(num_args); + for (int64_t i = 0; i < num_args; ++i) { + std::string val = "some_string " + std::to_string(i); + arguments.emplace_back(ast_storage.Create(val)); + } + auto *function = ast_storage.Create("COALESCE", arguments); + query::Frame frame(0); + database::GraphDbAccessor *dba = nullptr; + query::Context context(*dba); + query::ExpressionEvaluator evaluator(frame, &context, query::GraphView::OLD); + while (state.KeepRunning()) { + function->Accept(evaluator); + } +} + +// We are interested in benchmarking the usual amount of arguments +BENCHMARK(BenchmarkCoalesceCallWithNulls) + ->RangeMultiplier(2) + ->Range(1, 256) + ->ThreadRange(1, 16); + +BENCHMARK(BenchmarkCoalesceCallWithStrings) + ->RangeMultiplier(2) + ->Range(1, 256) + ->ThreadRange(1, 16); + +BENCHMARK_MAIN();