diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index beff453ce..9b615f8ae 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -34,51 +34,57 @@ namespace { // TODO: Implement degrees, haversin, radians // TODO: Implement spatial functions -TypedValue EndNode(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { +/////////////////////////// IMPORTANT NOTE! //////////////////////////////////// +// All of the functions take mutable `TypedValue *` to arguments, but none of +// the functions should ever need to actually modify the arguments! Let's try to +// keep our sanity in a good state by treating the arguments as immutable. +//////////////////////////////////////////////////////////////////////////////// + +TypedValue EndNode(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'endNode' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Edge: - return TypedValue(args[0].Value().to()); + return TypedValue(args[0].Value().to(), ctx.memory); default: throw QueryRuntimeException("'endNode' argument must be an edge."); } } -TypedValue Head(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Head(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'head' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::List: { const auto &list = args[0].ValueList(); - if (list.empty()) return TypedValue(); - return list[0]; + if (list.empty()) return TypedValue(ctx.memory); + return TypedValue(list[0], ctx.memory); } default: throw QueryRuntimeException("'head' argument must be a list."); } } -TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'last' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::List: { const auto &list = args[0].ValueList(); - if (list.empty()) return TypedValue(); - return list.back(); + if (list.empty()) return TypedValue(ctx.memory); + return TypedValue(list.back(), ctx.memory); } default: throw QueryRuntimeException("'last' argument must be a list."); @@ -86,22 +92,21 @@ TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &, } TypedValue Properties(TypedValue *args, int64_t nargs, - const EvaluationContext &, + const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 1) { throw QueryRuntimeException("'properties' requires exactly one argument."); } auto get_properties = [&](const auto &record_accessor) { - std::map properties; + TypedValue::TMap properties(ctx.memory); for (const auto &property : record_accessor.Properties()) { - properties[dba->PropertyName(property.first)] = - TypedValue(property.second); + properties.emplace(dba->PropertyName(property.first), property.second); } - return TypedValue(properties); + return TypedValue(std::move(properties)); }; switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: return get_properties(args[0].Value()); case TypedValue::Type::Edge: @@ -112,83 +117,89 @@ TypedValue Properties(TypedValue *args, int64_t nargs, } } -TypedValue Size(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Size(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'size' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::List: - return TypedValue(static_cast(args[0].ValueList().size())); + return TypedValue(static_cast(args[0].ValueList().size()), + ctx.memory); case TypedValue::Type::String: - return TypedValue(static_cast(args[0].ValueString().size())); + return TypedValue(static_cast(args[0].ValueString().size()), + ctx.memory); case TypedValue::Type::Map: // neo4j doesn't implement size for map, but I don't see a good reason not // to do it. - return TypedValue(static_cast(args[0].ValueMap().size())); + return TypedValue(static_cast(args[0].ValueMap().size()), + ctx.memory); case TypedValue::Type::Path: return TypedValue( - static_cast(args[0].ValuePath().edges().size())); + static_cast(args[0].ValuePath().edges().size()), ctx.memory); default: throw QueryRuntimeException( "'size' argument must be a string, a collection or a path."); } } -TypedValue StartNode(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue StartNode(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'startNode' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Edge: - return TypedValue(args[0].Value().from()); + return TypedValue(args[0].Value().from(), ctx.memory); default: throw QueryRuntimeException("'startNode' argument must be an edge."); } } -TypedValue Degree(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Degree(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'degree' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: { - auto &vertex = args[0].Value(); + const auto &vertex = args[0].Value(); return TypedValue( - static_cast(vertex.out_degree() + vertex.in_degree())); + static_cast(vertex.out_degree() + vertex.in_degree()), + ctx.memory); } default: throw QueryRuntimeException("'degree' argument must be a node."); } } -TypedValue InDegree(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { +TypedValue InDegree(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'inDegree' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: { - auto &vertex = args[0].Value(); - return TypedValue(static_cast(vertex.in_degree())); + const auto &vertex = args[0].Value(); + return TypedValue(static_cast(vertex.in_degree()), ctx.memory); } default: throw QueryRuntimeException("'inDegree' argument must be a node."); } } -TypedValue OutDegree(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue OutDegree(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'outDegree' requires exactly one argument."); @@ -196,35 +207,36 @@ TypedValue OutDegree(TypedValue *args, int64_t nargs, const EvaluationContext &, switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: { - auto &vertex = args[0].Value(); - return TypedValue(static_cast(vertex.out_degree())); + const auto &vertex = args[0].Value(); + return TypedValue(static_cast(vertex.out_degree()), ctx.memory); } default: throw QueryRuntimeException("'outDegree' argument must be a node."); } } -TypedValue ToBoolean(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue ToBoolean(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'toBoolean' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Bool: - return TypedValue(args[0].Value()); + return TypedValue(args[0].Value(), ctx.memory); case TypedValue::Type::Int: - return TypedValue(args[0].ValueInt() != 0L); + return TypedValue(args[0].ValueInt() != 0L, ctx.memory); case TypedValue::Type::String: { auto s = utils::ToUpperCase(utils::Trim(args[0].ValueString())); - if (s == "TRUE") return TypedValue(true); - if (s == "FALSE") return TypedValue(false); + if (s == "TRUE") return TypedValue(true, ctx.memory); + if (s == "FALSE") return TypedValue(false, ctx.memory); // I think this is just stupid and that exception should be thrown, but // neo4j does it this way... - return TypedValue(); + return TypedValue(ctx.memory); } default: throw QueryRuntimeException( @@ -232,24 +244,25 @@ TypedValue ToBoolean(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue ToFloat(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { +TypedValue ToFloat(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'toFloat' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Int: - return TypedValue(static_cast(args[0].Value())); + return TypedValue(static_cast(args[0].Value()), + ctx.memory); case TypedValue::Type::Double: - return args[0]; + return TypedValue(args[0], ctx.memory); case TypedValue::Type::String: try { return TypedValue( - utils::ParseDouble(utils::Trim(args[0].ValueString()))); + utils::ParseDouble(utils::Trim(args[0].ValueString())), ctx.memory); } catch (const utils::BasicException &) { - return TypedValue(); + return TypedValue(ctx.memory); } default: throw QueryRuntimeException( @@ -257,28 +270,31 @@ TypedValue ToFloat(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue ToInteger(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue ToInteger(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'toInteger' requires exactly one argument'"); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Bool: - return TypedValue(args[0].ValueBool() ? 1L : 0L); + return TypedValue(args[0].ValueBool() ? 1L : 0L, ctx.memory); case TypedValue::Type::Int: - return args[0]; + return TypedValue(args[0], ctx.memory); case TypedValue::Type::Double: - return TypedValue(static_cast(args[0].Value())); + return TypedValue(static_cast(args[0].Value()), + ctx.memory); case TypedValue::Type::String: try { // Yup, this is correct. String is valid if it has floating point // number, then it is parsed and converted to int. - return TypedValue(static_cast( - utils::ParseDouble(utils::Trim(args[0].ValueString())))); + return TypedValue(static_cast(utils::ParseDouble( + utils::Trim(args[0].ValueString()))), + ctx.memory); } catch (const utils::BasicException &) { - return TypedValue(); + return TypedValue(ctx.memory); } default: throw QueryRuntimeException( @@ -286,37 +302,38 @@ TypedValue ToInteger(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue Type(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Type(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 1) { throw QueryRuntimeException("'type' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Edge: return TypedValue( - dba->EdgeTypeName(args[0].Value().EdgeType())); + dba->EdgeTypeName(args[0].Value().EdgeType()), + ctx.memory); default: throw QueryRuntimeException("'type' argument must be an edge."); } } -TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 1) { throw QueryRuntimeException("'keys' requires exactly one argument."); } auto get_keys = [&](const auto &record_accessor) { - std::vector keys; + TypedValue::TVector keys(ctx.memory); for (const auto &property : record_accessor.Properties()) { keys.emplace_back(dba->PropertyName(property.first)); } - return TypedValue(keys); + return TypedValue(std::move(keys)); }; switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: return get_keys(args[0].Value()); case TypedValue::Type::Edge: @@ -326,55 +343,61 @@ TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue Labels(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Labels(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 1) { throw QueryRuntimeException("'labels' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Vertex: { - std::vector labels; + TypedValue::TVector labels(ctx.memory); for (const auto &label : args[0].Value().labels()) { labels.emplace_back(dba->LabelName(label)); } - return TypedValue(labels); + return TypedValue(std::move(labels)); } default: throw QueryRuntimeException("'labels' argument must be a node."); } } -TypedValue Nodes(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Nodes(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'nodes' requires exactly one argument."); } - if (args[0].IsNull()) return TypedValue(); + if (args[0].IsNull()) return TypedValue(ctx.memory); if (!args[0].IsPath()) { throw QueryRuntimeException("'nodes' argument should be a path."); } - auto &vertices = args[0].ValuePath().vertices(); - return TypedValue(std::vector(vertices.begin(), vertices.end())); + const auto &vertices = args[0].ValuePath().vertices(); + TypedValue::TVector values(ctx.memory); + values.reserve(vertices.size()); + for (const auto &v : vertices) values.emplace_back(v); + return TypedValue(std::move(values)); } TypedValue Relationships(TypedValue *args, int64_t nargs, - const EvaluationContext &, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException( "'relationships' requires exactly one argument."); } - if (args[0].IsNull()) return TypedValue(); + if (args[0].IsNull()) return TypedValue(ctx.memory); if (!args[0].IsPath()) { throw QueryRuntimeException("'relationships' argument must be a path."); } - auto &edges = args[0].ValuePath().edges(); - return TypedValue(std::vector(edges.begin(), edges.end())); + const auto &edges = args[0].ValuePath().edges(); + TypedValue::TVector values(ctx.memory); + values.reserve(edges.size()); + for (const auto &e : edges) values.emplace_back(e); + return TypedValue(std::move(values)); } -TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 2 && nargs != 3) { throw QueryRuntimeException("'range' requires two or three arguments."); @@ -388,14 +411,14 @@ TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &, } }; for (int64_t i = 0; i < nargs; ++i) check_type(args[i]); - if (has_null) return TypedValue(); + if (has_null) return TypedValue(ctx.memory); auto lbound = args[0].Value(); auto rbound = args[1].Value(); int64_t step = nargs == 3 ? args[2].Value() : 1; if (step == 0) { throw QueryRuntimeException("step argument of 'range' can't be zero."); } - std::vector list; + TypedValue::TVector list(ctx.memory); if (lbound <= rbound && step > 0) { for (auto i = lbound; i <= rbound; i += step) { list.emplace_back(i); @@ -405,22 +428,22 @@ TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &, list.emplace_back(i); } } - return TypedValue(list); + return TypedValue(std::move(list)); } -TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'tail' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::List: { - auto list = args[0].ValueList(); - if (list.empty()) return TypedValue(list); + TypedValue::TVector list(args[0].ValueList(), ctx.memory); + if (list.empty()) return TypedValue(std::move(list)); list.erase(list.begin()); - return TypedValue(list); + return TypedValue(std::move(list)); } default: throw QueryRuntimeException("'tail' argument must be a list."); @@ -428,7 +451,7 @@ TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &, } TypedValue UniformSample(TypedValue *args, int64_t nargs, - const EvaluationContext &, + const EvaluationContext &ctx, database::GraphDbAccessor *) { static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; if (nargs != 2) { @@ -438,7 +461,7 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs, switch (args[0].type()) { case TypedValue::Type::Null: if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) { - return TypedValue(); + return TypedValue(ctx.memory); } throw QueryRuntimeException( "Second argument of 'uniformSample' must be a non-negative integer."); @@ -446,16 +469,16 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs, if (args[1].IsInt() && args[1].ValueInt() >= 0) { auto &population = args[0].ValueList(); auto population_size = population.size(); - if (population_size == 0) return TypedValue(); + if (population_size == 0) return TypedValue(ctx.memory); auto desired_length = args[1].ValueInt(); std::uniform_int_distribution rand_dist{0, population_size - 1}; - std::vector sampled; + TypedValue::TVector sampled(ctx.memory); sampled.reserve(desired_length); for (int i = 0; i < desired_length; ++i) { - sampled.push_back(population[rand_dist(pseudo_rand_gen_)]); + sampled.emplace_back(population[rand_dist(pseudo_rand_gen_)]); } - return TypedValue(sampled); + return TypedValue(std::move(sampled)); } throw QueryRuntimeException( "Second argument of 'uniformSample' must be a non-negative integer."); @@ -465,41 +488,43 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs, } } -TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'abs' requires exactly one argument."); } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Int: - return TypedValue(std::abs(args[0].Value())); + return TypedValue(std::abs(args[0].Value()), ctx.memory); case TypedValue::Type::Double: - return TypedValue(std::abs(args[0].Value())); + return TypedValue(std::abs(args[0].Value()), ctx.memory); default: throw QueryRuntimeException("'abs' argument should be a number."); } } -#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \ - TypedValue name(TypedValue *args, int64_t nargs, const EvaluationContext &, \ - database::GraphDbAccessor *) { \ - if (nargs != 1) { \ - throw QueryRuntimeException("'" #lowercased_name \ - "' requires exactly one argument."); \ - } \ - switch (args[0].type()) { \ - case TypedValue::Type::Null: \ - return TypedValue(); \ - case TypedValue::Type::Int: \ - return TypedValue(lowercased_name(args[0].Value())); \ - case TypedValue::Type::Double: \ - return TypedValue(lowercased_name(args[0].Value())); \ - default: \ - throw QueryRuntimeException(#lowercased_name \ - " argument must be a number."); \ - } \ +#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \ + TypedValue name(TypedValue *args, int64_t nargs, \ + const EvaluationContext &ctx, database::GraphDbAccessor *) { \ + if (nargs != 1) { \ + throw QueryRuntimeException("'" #lowercased_name \ + "' requires exactly one argument."); \ + } \ + switch (args[0].type()) { \ + case TypedValue::Type::Null: \ + return TypedValue(ctx.memory); \ + case TypedValue::Type::Int: \ + return TypedValue(lowercased_name(args[0].Value()), \ + ctx.memory); \ + case TypedValue::Type::Double: \ + return TypedValue(lowercased_name(args[0].Value()), \ + ctx.memory); \ + default: \ + throw QueryRuntimeException(#lowercased_name \ + " argument must be a number."); \ + } \ } WRAP_CMATH_FLOAT_FUNCTION(Ceil, ceil) @@ -520,13 +545,13 @@ WRAP_CMATH_FLOAT_FUNCTION(Tan, tan) #undef WRAP_CMATH_FLOAT_FUNCTION -TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 2) { throw QueryRuntimeException("'atan2' requires two arguments."); } - if (args[0].type() == TypedValue::Type::Null) return TypedValue(); - if (args[1].type() == TypedValue::Type::Null) return TypedValue(); + if (args[0].type() == TypedValue::Type::Null) return TypedValue(ctx.memory); + if (args[1].type() == TypedValue::Type::Null) return TypedValue(ctx.memory); auto to_double = [](const TypedValue &t) -> double { switch (t.type()) { case TypedValue::Type::Int: @@ -539,18 +564,18 @@ TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &, }; double y = to_double(args[0]); double x = to_double(args[1]); - return TypedValue(atan2(y, x)); + return TypedValue(atan2(y, x), ctx.memory); } -TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'sign' requires exactly one argument."); } - auto sign = [](auto x) { return TypedValue((0 < x) - (x < 0)); }; + auto sign = [&](auto x) { return TypedValue((0 < x) - (x < 0), ctx.memory); }; switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::Int: return sign(args[0].Value()); case TypedValue::Type::Double: @@ -560,36 +585,36 @@ TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue E(TypedValue *, int64_t nargs, const EvaluationContext &, +TypedValue E(TypedValue *, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 0) { throw QueryRuntimeException("'e' requires no arguments."); } - return TypedValue(M_E); + return TypedValue(M_E, ctx.memory); } -TypedValue Pi(TypedValue *, int64_t nargs, const EvaluationContext &, +TypedValue Pi(TypedValue *, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 0) { throw QueryRuntimeException("'pi' requires no arguments."); } - return TypedValue(M_PI); + return TypedValue(M_PI, ctx.memory); } -TypedValue Rand(TypedValue *, int64_t nargs, const EvaluationContext &, +TypedValue Rand(TypedValue *, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; static thread_local std::uniform_real_distribution<> rand_dist_{0, 1}; if (nargs != 0) { throw QueryRuntimeException("'rand' requires no arguments."); } - return TypedValue(rand_dist_(pseudo_rand_gen_)); + return TypedValue(rand_dist_(pseudo_rand_gen_), ctx.memory); } template TypedValue StringMatchOperator(TypedValue *args, int64_t nargs, - const EvaluationContext &, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 2) { throw QueryRuntimeException( @@ -606,10 +631,10 @@ TypedValue StringMatchOperator(TypedValue *args, int64_t nargs, }; check_arg(args[0]); check_arg(args[1]); - if (has_null) return TypedValue(); + if (has_null) return TypedValue(ctx.memory); const auto &s1 = args[0].ValueString(); const auto &s2 = args[1].ValueString(); - return TypedValue(Predicate(s1, s2)); + return TypedValue(Predicate(s1, s2), ctx.memory); } // Check if s1 starts with s2. @@ -636,7 +661,7 @@ bool ContainsPredicate(const TypedValue::TString &s1, } auto Contains = StringMatchOperator; -TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs < 1 || nargs > 2) { throw QueryRuntimeException("'assert' requires one or two arguments"); @@ -656,7 +681,7 @@ TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &, message += "."; throw QueryRuntimeException(message); } - return args[0]; + return TypedValue(args[0], ctx.memory); } #if defined(MG_SINGLE_NODE) || defined(MG_SINGLE_NODE_HA) @@ -684,22 +709,24 @@ TypedValue Counter(TypedValue *args, int64_t nargs, auto value = it->second; it->second += step; - return TypedValue(value); + return TypedValue(value, context.memory); } #endif #ifdef MG_DISTRIBUTED -TypedValue WorkerId(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { +TypedValue WorkerId(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'workerId' requires exactly one argument."); } - auto &arg = args[0]; + const auto &arg = args[0]; switch (arg.type()) { case TypedValue::Type::Vertex: - return TypedValue(arg.ValueVertex().GlobalAddress().worker_id()); + return TypedValue(arg.ValueVertex().GlobalAddress().worker_id(), + ctx.memory); case TypedValue::Type::Edge: - return TypedValue(arg.ValueEdge().GlobalAddress().worker_id()); + return TypedValue(arg.ValueEdge().GlobalAddress().worker_id(), + ctx.memory); default: throw QueryRuntimeException( "'workerId' argument must be a node or an edge."); @@ -707,41 +734,43 @@ TypedValue WorkerId(TypedValue *args, int64_t nargs, const EvaluationContext &, } #endif -TypedValue Id(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Id(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 1) { throw QueryRuntimeException("'id' requires exactly one argument."); } - auto &arg = args[0]; + const auto &arg = args[0]; switch (arg.type()) { case TypedValue::Type::Vertex: { - return TypedValue(arg.ValueVertex().CypherId()); + return TypedValue(arg.ValueVertex().CypherId(), ctx.memory); } case TypedValue::Type::Edge: { - return TypedValue(arg.ValueEdge().CypherId()); + return TypedValue(arg.ValueEdge().CypherId(), ctx.memory); } default: throw QueryRuntimeException("'id' argument must be a node or an edge."); } } -TypedValue ToString(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { +TypedValue ToString(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 1) { throw QueryRuntimeException("'toString' requires exactly one argument."); } - auto &arg = args[0]; + const auto &arg = args[0]; switch (arg.type()) { case TypedValue::Type::Null: - return TypedValue(); + return TypedValue(ctx.memory); case TypedValue::Type::String: - return arg; + return TypedValue(arg, ctx.memory); case TypedValue::Type::Int: - return TypedValue(std::to_string(arg.ValueInt())); + // TODO: This is making a pointless copy of std::string, we may want to + // use a different conversion to string + return TypedValue(std::to_string(arg.ValueInt()), ctx.memory); case TypedValue::Type::Double: - return TypedValue(std::to_string(arg.ValueDouble())); + return TypedValue(std::to_string(arg.ValueDouble()), ctx.memory); case TypedValue::Type::Bool: - return TypedValue(arg.ValueBool() ? "true" : "false"); + return TypedValue(arg.ValueBool() ? "true" : "false", ctx.memory); default: throw QueryRuntimeException( "'toString' argument must be a number, a string or a boolean."); @@ -753,10 +782,10 @@ TypedValue Timestamp(TypedValue *, int64_t nargs, const EvaluationContext &ctx, if (nargs != 0) { throw QueryRuntimeException("'timestamp' requires no arguments."); } - return TypedValue(ctx.timestamp); + return TypedValue(ctx.timestamp, ctx.memory); } -TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 2) { throw QueryRuntimeException("'left' requires two arguments."); @@ -764,16 +793,15 @@ TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &, switch (args[0].type()) { case TypedValue::Type::Null: if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) { - return TypedValue(); + return TypedValue(ctx.memory); } throw QueryRuntimeException( "Second argument of 'left' must be a non-negative integer."); case TypedValue::Type::String: if (args[1].IsInt() && args[1].ValueInt() >= 0) { - auto *memory = args[0].GetMemoryResource(); return TypedValue( utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()), - memory); + ctx.memory); } throw QueryRuntimeException( "Second argument of 'left' must be a non-negative integer."); @@ -782,7 +810,7 @@ TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &, } } -TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 2) { throw QueryRuntimeException("'right' requires two arguments."); @@ -790,7 +818,7 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &, switch (args[0].type()) { case TypedValue::Type::Null: if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) { - return TypedValue(); + return TypedValue(ctx.memory); } throw QueryRuntimeException( "Second argument of 'right' must be a non-negative integer."); @@ -798,11 +826,10 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &, const auto &str = args[0].ValueString(); if (args[1].IsInt() && args[1].ValueInt() >= 0) { auto len = args[1].ValueInt(); - auto *memory = args[0].GetMemoryResource(); return len <= str.size() ? TypedValue(utils::Substr(str, str.size() - len, len), - memory) - : TypedValue(str, memory); + ctx.memory) + : TypedValue(str, ctx.memory); } throw QueryRuntimeException( "Second argument of 'right' must be a non-negative integer."); @@ -814,7 +841,8 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &, } TypedValue CallStringFunction( - TypedValue *args, int64_t nargs, const std::string &name, + TypedValue *args, int64_t nargs, utils::MemoryResource *memory, + const std::string &name, std::function fun) { if (nargs != 1) { throw QueryRuntimeException("'" + name + @@ -822,62 +850,68 @@ TypedValue CallStringFunction( } switch (args[0].type()) { case TypedValue::Type::Null: - return TypedValue(args[0].GetMemoryResource()); + return TypedValue(memory); case TypedValue::Type::String: - return TypedValue(fun(args[0].ValueString())); + return TypedValue(fun(args[0].ValueString()), memory); default: throw QueryRuntimeException("'" + name + "' argument should be a string."); } } -TypedValue LTrim(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue LTrim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "lTrim", [](const auto &str) { - return TypedValue::TString(utils::LTrim(str), str.get_allocator()); - }); + return CallStringFunction( + args, nargs, ctx.memory, "lTrim", [&](const auto &str) { + return TypedValue::TString(utils::LTrim(str), ctx.memory); + }); } -TypedValue RTrim(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue RTrim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "rTrim", [](const auto &str) { - return TypedValue::TString(utils::RTrim(str), str.get_allocator()); - }); + return CallStringFunction( + args, nargs, ctx.memory, "rTrim", [&](const auto &str) { + return TypedValue::TString(utils::RTrim(str), ctx.memory); + }); } -TypedValue Trim(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Trim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "trim", [](const auto &str) { - return TypedValue::TString(utils::Trim(str), str.get_allocator()); - }); + return CallStringFunction( + args, nargs, ctx.memory, "trim", [&](const auto &str) { + return TypedValue::TString(utils::Trim(str), ctx.memory); + }); } -TypedValue Reverse(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "reverse", [](const auto &str) { - return utils::Reversed(str, str.get_allocator()); - }); +TypedValue Reverse(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { + return CallStringFunction( + args, nargs, ctx.memory, "reverse", + [&](const auto &str) { return utils::Reversed(str, ctx.memory); }); } -TypedValue ToLower(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "toLower", [](const auto &str) { - TypedValue::TString res(str.get_allocator()); - utils::ToLowerCase(&res, str); - return res; - }); +TypedValue ToLower(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { + return CallStringFunction(args, nargs, ctx.memory, "toLower", + [&](const auto &str) { + TypedValue::TString res(ctx.memory); + utils::ToLowerCase(&res, str); + return res; + }); } -TypedValue ToUpper(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { - return CallStringFunction(args, nargs, "toUpper", [](const auto &str) { - TypedValue::TString res(str.get_allocator()); - utils::ToUpperCase(&res, str); - return res; - }); +TypedValue ToUpper(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { + return CallStringFunction(args, nargs, ctx.memory, "toUpper", + [&](const auto &str) { + TypedValue::TString res(ctx.memory); + utils::ToUpperCase(&res, str); + return res; + }); } -TypedValue Replace(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Replace(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 3) { throw QueryRuntimeException("'replace' requires three arguments."); @@ -895,13 +929,15 @@ TypedValue Replace(TypedValue *args, int64_t nargs, const EvaluationContext &, "Third argument of 'replace' should be a string."); } if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) { - return TypedValue(); + return TypedValue(ctx.memory); } - return TypedValue(utils::Replace(args[0].ValueString(), args[1].ValueString(), - args[2].ValueString())); + TypedValue::TString replaced(ctx.memory); + utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(), + args[2].ValueString()); + return TypedValue(std::move(replaced)); } -TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &ctx, database::GraphDbAccessor *dba) { if (nargs != 2) { throw QueryRuntimeException("'split' requires two arguments."); @@ -915,17 +951,15 @@ TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &, "Second argument of 'split' should be a string."); } if (args[0].IsNull() || args[1].IsNull()) { - return TypedValue(); + return TypedValue(ctx.memory); } - std::vector result; - for (const auto &str : - utils::Split(args[0].ValueString(), args[1].ValueString())) { - result.emplace_back(str); - } - return TypedValue(result); + TypedValue::TVector result(ctx.memory); + utils::Split(&result, args[0].ValueString(), args[1].ValueString()); + return TypedValue(std::move(result)); } -TypedValue Substring(TypedValue *args, int64_t nargs, const EvaluationContext &, +TypedValue Substring(TypedValue *args, int64_t nargs, + const EvaluationContext &ctx, database::GraphDbAccessor *) { if (nargs != 2 && nargs != 3) { throw QueryRuntimeException("'substring' requires two or three arguments."); @@ -943,21 +977,20 @@ TypedValue Substring(TypedValue *args, int64_t nargs, const EvaluationContext &, "Third argument of 'substring' should be a non-negative integer."); } if (args[0].IsNull()) { - return TypedValue(); + return TypedValue(ctx.memory); } const auto &str = args[0].ValueString(); auto start = args[1].ValueInt(); - auto *memory = args[0].GetMemoryResource(); if (nargs == 2) { - return TypedValue(utils::Substr(str, start), memory); + return TypedValue(utils::Substr(str, start), ctx.memory); } auto len = args[2].ValueInt(); - return TypedValue(utils::Substr(str, start, len), memory); + return TypedValue(utils::Substr(str, start, len), ctx.memory); } } // namespace -std::function NameToFunction(const std::string &function_name) { // Scalar functions diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 43861ffe8..c52d91047 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -381,22 +381,23 @@ class ExpressionEvaluator : public ExpressionVisitor { for (size_t i = 0; i < function.arguments_.size(); ++i) { arguments[i] = function.arguments_[i]->Accept(*this); } - // TODO: Update awesome_memgraph_functions to use the allocator from ctx_ - return TypedValue(function.function_( - arguments, function.arguments_.size(), *ctx_, dba_), - ctx_->memory); + auto res = function.function_(arguments, function.arguments_.size(), + *ctx_, dba_); + CHECK(res.GetMemoryResource() == ctx_->memory); + return res; } else { TypedValue::TVector arguments(ctx_->memory); arguments.reserve(function.arguments_.size()); for (const auto &argument : function.arguments_) { arguments.emplace_back(argument->Accept(*this)); } - // TODO: Update awesome_memgraph_functions to use the allocator from ctx_ - return TypedValue( - function.function_(arguments.data(), arguments.size(), *ctx_, dba_), - ctx_->memory); + auto res = + function.function_(arguments.data(), arguments.size(), *ctx_, dba_); + CHECK(res.GetMemoryResource() == ctx_->memory); + return res; } } + TypedValue Visit(Reduce &reduce) override { auto list_value = reduce.list_->Accept(*this); if (list_value.IsNull()) { diff --git a/src/utils/string.hpp b/src/utils/string.hpp index 7b418843b..6ffcd8755 100644 --- a/src/utils/string.hpp +++ b/src/utils/string.hpp @@ -175,7 +175,7 @@ template std::basic_string, TAllocator> *Replace( std::basic_string, TAllocator> *out, const std::string_view &src, const std::string_view &match, - const std::string_view &replacement, const TAllocator &alloc) { + const std::string_view &replacement) { // TODO: This could be implemented much more efficiently. *out = src; for (size_t pos = out->find(match); pos != std::string::npos; @@ -190,7 +190,7 @@ inline std::string Replace(const std::string_view &src, const std::string_view &match, const std::string_view &replacement) { std::string res; - Replace(&res, src, match, replacement, std::allocator()); + Replace(&res, src, match, replacement); return res; }