Use memory from EvaluationContext in awesome functions

Reviewers: mferencevic, msantl, mtomic

Reviewed By: msantl

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2133
This commit is contained in:
Teon Banek 2019-06-07 16:21:16 +02:00
parent 3fd14e2d5f
commit 65ab2574bc
3 changed files with 260 additions and 226 deletions

View File

@ -34,51 +34,57 @@ namespace {
// TODO: Implement degrees, haversin, radians
// TODO: Implement spatial functions
TypedValue EndNode(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
/////////////////////////// IMPORTANT NOTE! ////////////////////////////////////
// All of the functions take mutable `TypedValue *` to arguments, but none of
// the functions should ever need to actually modify the arguments! Let's try to
// keep our sanity in a good state by treating the arguments as immutable.
////////////////////////////////////////////////////////////////////////////////
TypedValue EndNode(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'endNode' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Edge:
return TypedValue(args[0].Value<EdgeAccessor>().to());
return TypedValue(args[0].Value<EdgeAccessor>().to(), ctx.memory);
default:
throw QueryRuntimeException("'endNode' argument must be an edge.");
}
}
TypedValue Head(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Head(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'head' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::List: {
const auto &list = args[0].ValueList();
if (list.empty()) return TypedValue();
return list[0];
if (list.empty()) return TypedValue(ctx.memory);
return TypedValue(list[0], ctx.memory);
}
default:
throw QueryRuntimeException("'head' argument must be a list.");
}
}
TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'last' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::List: {
const auto &list = args[0].ValueList();
if (list.empty()) return TypedValue();
return list.back();
if (list.empty()) return TypedValue(ctx.memory);
return TypedValue(list.back(), ctx.memory);
}
default:
throw QueryRuntimeException("'last' argument must be a list.");
@ -86,22 +92,21 @@ TypedValue Last(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
TypedValue Properties(TypedValue *args, int64_t nargs,
const EvaluationContext &,
const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 1) {
throw QueryRuntimeException("'properties' requires exactly one argument.");
}
auto get_properties = [&](const auto &record_accessor) {
std::map<std::string, TypedValue> properties;
TypedValue::TMap properties(ctx.memory);
for (const auto &property : record_accessor.Properties()) {
properties[dba->PropertyName(property.first)] =
TypedValue(property.second);
properties.emplace(dba->PropertyName(property.first), property.second);
}
return TypedValue(properties);
return TypedValue(std::move(properties));
};
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex:
return get_properties(args[0].Value<VertexAccessor>());
case TypedValue::Type::Edge:
@ -112,83 +117,89 @@ TypedValue Properties(TypedValue *args, int64_t nargs,
}
}
TypedValue Size(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Size(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'size' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::List:
return TypedValue(static_cast<int64_t>(args[0].ValueList().size()));
return TypedValue(static_cast<int64_t>(args[0].ValueList().size()),
ctx.memory);
case TypedValue::Type::String:
return TypedValue(static_cast<int64_t>(args[0].ValueString().size()));
return TypedValue(static_cast<int64_t>(args[0].ValueString().size()),
ctx.memory);
case TypedValue::Type::Map:
// neo4j doesn't implement size for map, but I don't see a good reason not
// to do it.
return TypedValue(static_cast<int64_t>(args[0].ValueMap().size()));
return TypedValue(static_cast<int64_t>(args[0].ValueMap().size()),
ctx.memory);
case TypedValue::Type::Path:
return TypedValue(
static_cast<int64_t>(args[0].ValuePath().edges().size()));
static_cast<int64_t>(args[0].ValuePath().edges().size()), ctx.memory);
default:
throw QueryRuntimeException(
"'size' argument must be a string, a collection or a path.");
}
}
TypedValue StartNode(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue StartNode(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'startNode' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Edge:
return TypedValue(args[0].Value<EdgeAccessor>().from());
return TypedValue(args[0].Value<EdgeAccessor>().from(), ctx.memory);
default:
throw QueryRuntimeException("'startNode' argument must be an edge.");
}
}
TypedValue Degree(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Degree(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'degree' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex: {
auto &vertex = args[0].Value<VertexAccessor>();
const auto &vertex = args[0].Value<VertexAccessor>();
return TypedValue(
static_cast<int64_t>(vertex.out_degree() + vertex.in_degree()));
static_cast<int64_t>(vertex.out_degree() + vertex.in_degree()),
ctx.memory);
}
default:
throw QueryRuntimeException("'degree' argument must be a node.");
}
}
TypedValue InDegree(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
TypedValue InDegree(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'inDegree' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex: {
auto &vertex = args[0].Value<VertexAccessor>();
return TypedValue(static_cast<int64_t>(vertex.in_degree()));
const auto &vertex = args[0].Value<VertexAccessor>();
return TypedValue(static_cast<int64_t>(vertex.in_degree()), ctx.memory);
}
default:
throw QueryRuntimeException("'inDegree' argument must be a node.");
}
}
TypedValue OutDegree(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue OutDegree(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'outDegree' requires exactly one argument.");
@ -196,35 +207,36 @@ TypedValue OutDegree(TypedValue *args, int64_t nargs, const EvaluationContext &,
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex: {
auto &vertex = args[0].Value<VertexAccessor>();
return TypedValue(static_cast<int64_t>(vertex.out_degree()));
const auto &vertex = args[0].Value<VertexAccessor>();
return TypedValue(static_cast<int64_t>(vertex.out_degree()), ctx.memory);
}
default:
throw QueryRuntimeException("'outDegree' argument must be a node.");
}
}
TypedValue ToBoolean(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue ToBoolean(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'toBoolean' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Bool:
return TypedValue(args[0].Value<bool>());
return TypedValue(args[0].Value<bool>(), ctx.memory);
case TypedValue::Type::Int:
return TypedValue(args[0].ValueInt() != 0L);
return TypedValue(args[0].ValueInt() != 0L, ctx.memory);
case TypedValue::Type::String: {
auto s = utils::ToUpperCase(utils::Trim(args[0].ValueString()));
if (s == "TRUE") return TypedValue(true);
if (s == "FALSE") return TypedValue(false);
if (s == "TRUE") return TypedValue(true, ctx.memory);
if (s == "FALSE") return TypedValue(false, ctx.memory);
// I think this is just stupid and that exception should be thrown, but
// neo4j does it this way...
return TypedValue();
return TypedValue(ctx.memory);
}
default:
throw QueryRuntimeException(
@ -232,24 +244,25 @@ TypedValue ToBoolean(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue ToFloat(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
TypedValue ToFloat(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'toFloat' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Int:
return TypedValue(static_cast<double>(args[0].Value<int64_t>()));
return TypedValue(static_cast<double>(args[0].Value<int64_t>()),
ctx.memory);
case TypedValue::Type::Double:
return args[0];
return TypedValue(args[0], ctx.memory);
case TypedValue::Type::String:
try {
return TypedValue(
utils::ParseDouble(utils::Trim(args[0].ValueString())));
utils::ParseDouble(utils::Trim(args[0].ValueString())), ctx.memory);
} catch (const utils::BasicException &) {
return TypedValue();
return TypedValue(ctx.memory);
}
default:
throw QueryRuntimeException(
@ -257,28 +270,31 @@ TypedValue ToFloat(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue ToInteger(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue ToInteger(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'toInteger' requires exactly one argument'");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Bool:
return TypedValue(args[0].ValueBool() ? 1L : 0L);
return TypedValue(args[0].ValueBool() ? 1L : 0L, ctx.memory);
case TypedValue::Type::Int:
return args[0];
return TypedValue(args[0], ctx.memory);
case TypedValue::Type::Double:
return TypedValue(static_cast<int64_t>(args[0].Value<double>()));
return TypedValue(static_cast<int64_t>(args[0].Value<double>()),
ctx.memory);
case TypedValue::Type::String:
try {
// Yup, this is correct. String is valid if it has floating point
// number, then it is parsed and converted to int.
return TypedValue(static_cast<int64_t>(
utils::ParseDouble(utils::Trim(args[0].ValueString()))));
return TypedValue(static_cast<int64_t>(utils::ParseDouble(
utils::Trim(args[0].ValueString()))),
ctx.memory);
} catch (const utils::BasicException &) {
return TypedValue();
return TypedValue(ctx.memory);
}
default:
throw QueryRuntimeException(
@ -286,37 +302,38 @@ TypedValue ToInteger(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue Type(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Type(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 1) {
throw QueryRuntimeException("'type' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Edge:
return TypedValue(
dba->EdgeTypeName(args[0].Value<EdgeAccessor>().EdgeType()));
dba->EdgeTypeName(args[0].Value<EdgeAccessor>().EdgeType()),
ctx.memory);
default:
throw QueryRuntimeException("'type' argument must be an edge.");
}
}
TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 1) {
throw QueryRuntimeException("'keys' requires exactly one argument.");
}
auto get_keys = [&](const auto &record_accessor) {
std::vector<TypedValue> keys;
TypedValue::TVector keys(ctx.memory);
for (const auto &property : record_accessor.Properties()) {
keys.emplace_back(dba->PropertyName(property.first));
}
return TypedValue(keys);
return TypedValue(std::move(keys));
};
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex:
return get_keys(args[0].Value<VertexAccessor>());
case TypedValue::Type::Edge:
@ -326,55 +343,61 @@ TypedValue Keys(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue Labels(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Labels(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 1) {
throw QueryRuntimeException("'labels' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Vertex: {
std::vector<TypedValue> labels;
TypedValue::TVector labels(ctx.memory);
for (const auto &label : args[0].Value<VertexAccessor>().labels()) {
labels.emplace_back(dba->LabelName(label));
}
return TypedValue(labels);
return TypedValue(std::move(labels));
}
default:
throw QueryRuntimeException("'labels' argument must be a node.");
}
}
TypedValue Nodes(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Nodes(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'nodes' requires exactly one argument.");
}
if (args[0].IsNull()) return TypedValue();
if (args[0].IsNull()) return TypedValue(ctx.memory);
if (!args[0].IsPath()) {
throw QueryRuntimeException("'nodes' argument should be a path.");
}
auto &vertices = args[0].ValuePath().vertices();
return TypedValue(std::vector<TypedValue>(vertices.begin(), vertices.end()));
const auto &vertices = args[0].ValuePath().vertices();
TypedValue::TVector values(ctx.memory);
values.reserve(vertices.size());
for (const auto &v : vertices) values.emplace_back(v);
return TypedValue(std::move(values));
}
TypedValue Relationships(TypedValue *args, int64_t nargs,
const EvaluationContext &,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException(
"'relationships' requires exactly one argument.");
}
if (args[0].IsNull()) return TypedValue();
if (args[0].IsNull()) return TypedValue(ctx.memory);
if (!args[0].IsPath()) {
throw QueryRuntimeException("'relationships' argument must be a path.");
}
auto &edges = args[0].ValuePath().edges();
return TypedValue(std::vector<TypedValue>(edges.begin(), edges.end()));
const auto &edges = args[0].ValuePath().edges();
TypedValue::TVector values(ctx.memory);
values.reserve(edges.size());
for (const auto &e : edges) values.emplace_back(e);
return TypedValue(std::move(values));
}
TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 2 && nargs != 3) {
throw QueryRuntimeException("'range' requires two or three arguments.");
@ -388,14 +411,14 @@ TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
};
for (int64_t i = 0; i < nargs; ++i) check_type(args[i]);
if (has_null) return TypedValue();
if (has_null) return TypedValue(ctx.memory);
auto lbound = args[0].Value<int64_t>();
auto rbound = args[1].Value<int64_t>();
int64_t step = nargs == 3 ? args[2].Value<int64_t>() : 1;
if (step == 0) {
throw QueryRuntimeException("step argument of 'range' can't be zero.");
}
std::vector<TypedValue> list;
TypedValue::TVector list(ctx.memory);
if (lbound <= rbound && step > 0) {
for (auto i = lbound; i <= rbound; i += step) {
list.emplace_back(i);
@ -405,22 +428,22 @@ TypedValue Range(TypedValue *args, int64_t nargs, const EvaluationContext &,
list.emplace_back(i);
}
}
return TypedValue(list);
return TypedValue(std::move(list));
}
TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'tail' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::List: {
auto list = args[0].ValueList();
if (list.empty()) return TypedValue(list);
TypedValue::TVector list(args[0].ValueList(), ctx.memory);
if (list.empty()) return TypedValue(std::move(list));
list.erase(list.begin());
return TypedValue(list);
return TypedValue(std::move(list));
}
default:
throw QueryRuntimeException("'tail' argument must be a list.");
@ -428,7 +451,7 @@ TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
TypedValue UniformSample(TypedValue *args, int64_t nargs,
const EvaluationContext &,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
if (nargs != 2) {
@ -438,7 +461,7 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs,
switch (args[0].type()) {
case TypedValue::Type::Null:
if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) {
return TypedValue();
return TypedValue(ctx.memory);
}
throw QueryRuntimeException(
"Second argument of 'uniformSample' must be a non-negative integer.");
@ -446,16 +469,16 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs,
if (args[1].IsInt() && args[1].ValueInt() >= 0) {
auto &population = args[0].ValueList();
auto population_size = population.size();
if (population_size == 0) return TypedValue();
if (population_size == 0) return TypedValue(ctx.memory);
auto desired_length = args[1].ValueInt();
std::uniform_int_distribution<uint64_t> rand_dist{0,
population_size - 1};
std::vector<TypedValue> sampled;
TypedValue::TVector sampled(ctx.memory);
sampled.reserve(desired_length);
for (int i = 0; i < desired_length; ++i) {
sampled.push_back(population[rand_dist(pseudo_rand_gen_)]);
sampled.emplace_back(population[rand_dist(pseudo_rand_gen_)]);
}
return TypedValue(sampled);
return TypedValue(std::move(sampled));
}
throw QueryRuntimeException(
"Second argument of 'uniformSample' must be a non-negative integer.");
@ -465,37 +488,39 @@ TypedValue UniformSample(TypedValue *args, int64_t nargs,
}
}
TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'abs' requires exactly one argument.");
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Int:
return TypedValue(std::abs(args[0].Value<int64_t>()));
return TypedValue(std::abs(args[0].Value<int64_t>()), ctx.memory);
case TypedValue::Type::Double:
return TypedValue(std::abs(args[0].Value<double>()));
return TypedValue(std::abs(args[0].Value<double>()), ctx.memory);
default:
throw QueryRuntimeException("'abs' argument should be a number.");
}
}
#define WRAP_CMATH_FLOAT_FUNCTION(name, lowercased_name) \
TypedValue name(TypedValue *args, int64_t nargs, const EvaluationContext &, \
database::GraphDbAccessor *) { \
TypedValue name(TypedValue *args, int64_t nargs, \
const EvaluationContext &ctx, database::GraphDbAccessor *) { \
if (nargs != 1) { \
throw QueryRuntimeException("'" #lowercased_name \
"' requires exactly one argument."); \
} \
switch (args[0].type()) { \
case TypedValue::Type::Null: \
return TypedValue(); \
return TypedValue(ctx.memory); \
case TypedValue::Type::Int: \
return TypedValue(lowercased_name(args[0].Value<int64_t>())); \
return TypedValue(lowercased_name(args[0].Value<int64_t>()), \
ctx.memory); \
case TypedValue::Type::Double: \
return TypedValue(lowercased_name(args[0].Value<double>())); \
return TypedValue(lowercased_name(args[0].Value<double>()), \
ctx.memory); \
default: \
throw QueryRuntimeException(#lowercased_name \
" argument must be a number."); \
@ -520,13 +545,13 @@ WRAP_CMATH_FLOAT_FUNCTION(Tan, tan)
#undef WRAP_CMATH_FLOAT_FUNCTION
TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 2) {
throw QueryRuntimeException("'atan2' requires two arguments.");
}
if (args[0].type() == TypedValue::Type::Null) return TypedValue();
if (args[1].type() == TypedValue::Type::Null) return TypedValue();
if (args[0].type() == TypedValue::Type::Null) return TypedValue(ctx.memory);
if (args[1].type() == TypedValue::Type::Null) return TypedValue(ctx.memory);
auto to_double = [](const TypedValue &t) -> double {
switch (t.type()) {
case TypedValue::Type::Int:
@ -539,18 +564,18 @@ TypedValue Atan2(TypedValue *args, int64_t nargs, const EvaluationContext &,
};
double y = to_double(args[0]);
double x = to_double(args[1]);
return TypedValue(atan2(y, x));
return TypedValue(atan2(y, x), ctx.memory);
}
TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'sign' requires exactly one argument.");
}
auto sign = [](auto x) { return TypedValue((0 < x) - (x < 0)); };
auto sign = [&](auto x) { return TypedValue((0 < x) - (x < 0), ctx.memory); };
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::Int:
return sign(args[0].Value<int64_t>());
case TypedValue::Type::Double:
@ -560,36 +585,36 @@ TypedValue Sign(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue E(TypedValue *, int64_t nargs, const EvaluationContext &,
TypedValue E(TypedValue *, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 0) {
throw QueryRuntimeException("'e' requires no arguments.");
}
return TypedValue(M_E);
return TypedValue(M_E, ctx.memory);
}
TypedValue Pi(TypedValue *, int64_t nargs, const EvaluationContext &,
TypedValue Pi(TypedValue *, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 0) {
throw QueryRuntimeException("'pi' requires no arguments.");
}
return TypedValue(M_PI);
return TypedValue(M_PI, ctx.memory);
}
TypedValue Rand(TypedValue *, int64_t nargs, const EvaluationContext &,
TypedValue Rand(TypedValue *, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
static thread_local std::uniform_real_distribution<> rand_dist_{0, 1};
if (nargs != 0) {
throw QueryRuntimeException("'rand' requires no arguments.");
}
return TypedValue(rand_dist_(pseudo_rand_gen_));
return TypedValue(rand_dist_(pseudo_rand_gen_), ctx.memory);
}
template <bool (*Predicate)(const TypedValue::TString &s1,
const TypedValue::TString &s2)>
TypedValue StringMatchOperator(TypedValue *args, int64_t nargs,
const EvaluationContext &,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 2) {
throw QueryRuntimeException(
@ -606,10 +631,10 @@ TypedValue StringMatchOperator(TypedValue *args, int64_t nargs,
};
check_arg(args[0]);
check_arg(args[1]);
if (has_null) return TypedValue();
if (has_null) return TypedValue(ctx.memory);
const auto &s1 = args[0].ValueString();
const auto &s2 = args[1].ValueString();
return TypedValue(Predicate(s1, s2));
return TypedValue(Predicate(s1, s2), ctx.memory);
}
// Check if s1 starts with s2.
@ -636,7 +661,7 @@ bool ContainsPredicate(const TypedValue::TString &s1,
}
auto Contains = StringMatchOperator<ContainsPredicate>;
TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs < 1 || nargs > 2) {
throw QueryRuntimeException("'assert' requires one or two arguments");
@ -656,7 +681,7 @@ TypedValue Assert(TypedValue *args, int64_t nargs, const EvaluationContext &,
message += ".";
throw QueryRuntimeException(message);
}
return args[0];
return TypedValue(args[0], ctx.memory);
}
#if defined(MG_SINGLE_NODE) || defined(MG_SINGLE_NODE_HA)
@ -684,22 +709,24 @@ TypedValue Counter(TypedValue *args, int64_t nargs,
auto value = it->second;
it->second += step;
return TypedValue(value);
return TypedValue(value, context.memory);
}
#endif
#ifdef MG_DISTRIBUTED
TypedValue WorkerId(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
TypedValue WorkerId(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'workerId' requires exactly one argument.");
}
auto &arg = args[0];
const auto &arg = args[0];
switch (arg.type()) {
case TypedValue::Type::Vertex:
return TypedValue(arg.ValueVertex().GlobalAddress().worker_id());
return TypedValue(arg.ValueVertex().GlobalAddress().worker_id(),
ctx.memory);
case TypedValue::Type::Edge:
return TypedValue(arg.ValueEdge().GlobalAddress().worker_id());
return TypedValue(arg.ValueEdge().GlobalAddress().worker_id(),
ctx.memory);
default:
throw QueryRuntimeException(
"'workerId' argument must be a node or an edge.");
@ -707,41 +734,43 @@ TypedValue WorkerId(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
#endif
TypedValue Id(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Id(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 1) {
throw QueryRuntimeException("'id' requires exactly one argument.");
}
auto &arg = args[0];
const auto &arg = args[0];
switch (arg.type()) {
case TypedValue::Type::Vertex: {
return TypedValue(arg.ValueVertex().CypherId());
return TypedValue(arg.ValueVertex().CypherId(), ctx.memory);
}
case TypedValue::Type::Edge: {
return TypedValue(arg.ValueEdge().CypherId());
return TypedValue(arg.ValueEdge().CypherId(), ctx.memory);
}
default:
throw QueryRuntimeException("'id' argument must be a node or an edge.");
}
}
TypedValue ToString(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
TypedValue ToString(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
if (nargs != 1) {
throw QueryRuntimeException("'toString' requires exactly one argument.");
}
auto &arg = args[0];
const auto &arg = args[0];
switch (arg.type()) {
case TypedValue::Type::Null:
return TypedValue();
return TypedValue(ctx.memory);
case TypedValue::Type::String:
return arg;
return TypedValue(arg, ctx.memory);
case TypedValue::Type::Int:
return TypedValue(std::to_string(arg.ValueInt()));
// TODO: This is making a pointless copy of std::string, we may want to
// use a different conversion to string
return TypedValue(std::to_string(arg.ValueInt()), ctx.memory);
case TypedValue::Type::Double:
return TypedValue(std::to_string(arg.ValueDouble()));
return TypedValue(std::to_string(arg.ValueDouble()), ctx.memory);
case TypedValue::Type::Bool:
return TypedValue(arg.ValueBool() ? "true" : "false");
return TypedValue(arg.ValueBool() ? "true" : "false", ctx.memory);
default:
throw QueryRuntimeException(
"'toString' argument must be a number, a string or a boolean.");
@ -753,10 +782,10 @@ TypedValue Timestamp(TypedValue *, int64_t nargs, const EvaluationContext &ctx,
if (nargs != 0) {
throw QueryRuntimeException("'timestamp' requires no arguments.");
}
return TypedValue(ctx.timestamp);
return TypedValue(ctx.timestamp, ctx.memory);
}
TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 2) {
throw QueryRuntimeException("'left' requires two arguments.");
@ -764,16 +793,15 @@ TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &,
switch (args[0].type()) {
case TypedValue::Type::Null:
if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) {
return TypedValue();
return TypedValue(ctx.memory);
}
throw QueryRuntimeException(
"Second argument of 'left' must be a non-negative integer.");
case TypedValue::Type::String:
if (args[1].IsInt() && args[1].ValueInt() >= 0) {
auto *memory = args[0].GetMemoryResource();
return TypedValue(
utils::Substr(args[0].ValueString(), 0, args[1].ValueInt()),
memory);
ctx.memory);
}
throw QueryRuntimeException(
"Second argument of 'left' must be a non-negative integer.");
@ -782,7 +810,7 @@ TypedValue Left(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 2) {
throw QueryRuntimeException("'right' requires two arguments.");
@ -790,7 +818,7 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &,
switch (args[0].type()) {
case TypedValue::Type::Null:
if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) {
return TypedValue();
return TypedValue(ctx.memory);
}
throw QueryRuntimeException(
"Second argument of 'right' must be a non-negative integer.");
@ -798,11 +826,10 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &,
const auto &str = args[0].ValueString();
if (args[1].IsInt() && args[1].ValueInt() >= 0) {
auto len = args[1].ValueInt();
auto *memory = args[0].GetMemoryResource();
return len <= str.size()
? TypedValue(utils::Substr(str, str.size() - len, len),
memory)
: TypedValue(str, memory);
ctx.memory)
: TypedValue(str, ctx.memory);
}
throw QueryRuntimeException(
"Second argument of 'right' must be a non-negative integer.");
@ -814,7 +841,8 @@ TypedValue Right(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
TypedValue CallStringFunction(
TypedValue *args, int64_t nargs, const std::string &name,
TypedValue *args, int64_t nargs, utils::MemoryResource *memory,
const std::string &name,
std::function<TypedValue::TString(const TypedValue::TString &)> fun) {
if (nargs != 1) {
throw QueryRuntimeException("'" + name +
@ -822,62 +850,68 @@ TypedValue CallStringFunction(
}
switch (args[0].type()) {
case TypedValue::Type::Null:
return TypedValue(args[0].GetMemoryResource());
return TypedValue(memory);
case TypedValue::Type::String:
return TypedValue(fun(args[0].ValueString()));
return TypedValue(fun(args[0].ValueString()), memory);
default:
throw QueryRuntimeException("'" + name +
"' argument should be a string.");
}
}
TypedValue LTrim(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue LTrim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "lTrim", [](const auto &str) {
return TypedValue::TString(utils::LTrim(str), str.get_allocator());
return CallStringFunction(
args, nargs, ctx.memory, "lTrim", [&](const auto &str) {
return TypedValue::TString(utils::LTrim(str), ctx.memory);
});
}
TypedValue RTrim(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue RTrim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "rTrim", [](const auto &str) {
return TypedValue::TString(utils::RTrim(str), str.get_allocator());
return CallStringFunction(
args, nargs, ctx.memory, "rTrim", [&](const auto &str) {
return TypedValue::TString(utils::RTrim(str), ctx.memory);
});
}
TypedValue Trim(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Trim(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "trim", [](const auto &str) {
return TypedValue::TString(utils::Trim(str), str.get_allocator());
return CallStringFunction(
args, nargs, ctx.memory, "trim", [&](const auto &str) {
return TypedValue::TString(utils::Trim(str), ctx.memory);
});
}
TypedValue Reverse(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "reverse", [](const auto &str) {
return utils::Reversed(str, str.get_allocator());
});
TypedValue Reverse(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
return CallStringFunction(
args, nargs, ctx.memory, "reverse",
[&](const auto &str) { return utils::Reversed(str, ctx.memory); });
}
TypedValue ToLower(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "toLower", [](const auto &str) {
TypedValue::TString res(str.get_allocator());
TypedValue ToLower(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, ctx.memory, "toLower",
[&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToLowerCase(&res, str);
return res;
});
}
TypedValue ToUpper(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, "toUpper", [](const auto &str) {
TypedValue::TString res(str.get_allocator());
TypedValue ToUpper(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx, database::GraphDbAccessor *) {
return CallStringFunction(args, nargs, ctx.memory, "toUpper",
[&](const auto &str) {
TypedValue::TString res(ctx.memory);
utils::ToUpperCase(&res, str);
return res;
});
}
TypedValue Replace(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Replace(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 3) {
throw QueryRuntimeException("'replace' requires three arguments.");
@ -895,13 +929,15 @@ TypedValue Replace(TypedValue *args, int64_t nargs, const EvaluationContext &,
"Third argument of 'replace' should be a string.");
}
if (args[0].IsNull() || args[1].IsNull() || args[2].IsNull()) {
return TypedValue();
return TypedValue(ctx.memory);
}
return TypedValue(utils::Replace(args[0].ValueString(), args[1].ValueString(),
args[2].ValueString()));
TypedValue::TString replaced(ctx.memory);
utils::Replace(&replaced, args[0].ValueString(), args[1].ValueString(),
args[2].ValueString());
return TypedValue(std::move(replaced));
}
TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &ctx,
database::GraphDbAccessor *dba) {
if (nargs != 2) {
throw QueryRuntimeException("'split' requires two arguments.");
@ -915,17 +951,15 @@ TypedValue Split(TypedValue *args, int64_t nargs, const EvaluationContext &,
"Second argument of 'split' should be a string.");
}
if (args[0].IsNull() || args[1].IsNull()) {
return TypedValue();
return TypedValue(ctx.memory);
}
std::vector<TypedValue> result;
for (const auto &str :
utils::Split(args[0].ValueString(), args[1].ValueString())) {
result.emplace_back(str);
}
return TypedValue(result);
TypedValue::TVector result(ctx.memory);
utils::Split(&result, args[0].ValueString(), args[1].ValueString());
return TypedValue(std::move(result));
}
TypedValue Substring(TypedValue *args, int64_t nargs, const EvaluationContext &,
TypedValue Substring(TypedValue *args, int64_t nargs,
const EvaluationContext &ctx,
database::GraphDbAccessor *) {
if (nargs != 2 && nargs != 3) {
throw QueryRuntimeException("'substring' requires two or three arguments.");
@ -943,21 +977,20 @@ TypedValue Substring(TypedValue *args, int64_t nargs, const EvaluationContext &,
"Third argument of 'substring' should be a non-negative integer.");
}
if (args[0].IsNull()) {
return TypedValue();
return TypedValue(ctx.memory);
}
const auto &str = args[0].ValueString();
auto start = args[1].ValueInt();
auto *memory = args[0].GetMemoryResource();
if (nargs == 2) {
return TypedValue(utils::Substr(str, start), memory);
return TypedValue(utils::Substr(str, start), ctx.memory);
}
auto len = args[2].ValueInt();
return TypedValue(utils::Substr(str, start, len), memory);
return TypedValue(utils::Substr(str, start, len), ctx.memory);
}
} // namespace
std::function<TypedValue(TypedValue *, int64_t, const EvaluationContext &,
std::function<TypedValue(TypedValue *, int64_t, const EvaluationContext &ctx,
database::GraphDbAccessor *)>
NameToFunction(const std::string &function_name) {
// Scalar functions

View File

@ -381,22 +381,23 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
for (size_t i = 0; i < function.arguments_.size(); ++i) {
arguments[i] = function.arguments_[i]->Accept(*this);
}
// TODO: Update awesome_memgraph_functions to use the allocator from ctx_
return TypedValue(function.function_(
arguments, function.arguments_.size(), *ctx_, dba_),
ctx_->memory);
auto res = function.function_(arguments, function.arguments_.size(),
*ctx_, dba_);
CHECK(res.GetMemoryResource() == ctx_->memory);
return res;
} else {
TypedValue::TVector arguments(ctx_->memory);
arguments.reserve(function.arguments_.size());
for (const auto &argument : function.arguments_) {
arguments.emplace_back(argument->Accept(*this));
}
// TODO: Update awesome_memgraph_functions to use the allocator from ctx_
return TypedValue(
function.function_(arguments.data(), arguments.size(), *ctx_, dba_),
ctx_->memory);
auto res =
function.function_(arguments.data(), arguments.size(), *ctx_, dba_);
CHECK(res.GetMemoryResource() == ctx_->memory);
return res;
}
}
TypedValue Visit(Reduce &reduce) override {
auto list_value = reduce.list_->Accept(*this);
if (list_value.IsNull()) {

View File

@ -175,7 +175,7 @@ template <class TAllocator>
std::basic_string<char, std::char_traits<char>, TAllocator> *Replace(
std::basic_string<char, std::char_traits<char>, TAllocator> *out,
const std::string_view &src, const std::string_view &match,
const std::string_view &replacement, const TAllocator &alloc) {
const std::string_view &replacement) {
// TODO: This could be implemented much more efficiently.
*out = src;
for (size_t pos = out->find(match); pos != std::string::npos;
@ -190,7 +190,7 @@ inline std::string Replace(const std::string_view &src,
const std::string_view &match,
const std::string_view &replacement) {
std::string res;
Replace(&res, src, match, replacement, std::allocator<char>());
Replace(&res, src, match, replacement);
return res;
}