add 'sample' awesome memgraph function
Summary: This simple function is required by the Tensorflow integration so that Memgraph can always return regular matrices of desired size. Reviewers: teon.banek, mtomic, dsantl Reviewed By: mtomic, dsantl Subscribers: mferencevic, pullbot Differential Revision: https://phabricator.memgraph.io/D1783
This commit is contained in:
parent
1af728b505
commit
664622f68e
@ -422,6 +422,44 @@ TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &,
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue UniformSample(TypedValue *args, int64_t nargs,
|
||||
const EvaluationContext &,
|
||||
database::GraphDbAccessor *) {
|
||||
static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
|
||||
if (nargs != 2) {
|
||||
throw QueryRuntimeException(
|
||||
"'uniformSample' requires exactly two arguments.");
|
||||
}
|
||||
switch (args[0].type()) {
|
||||
case TypedValue::Type::Null:
|
||||
if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) {
|
||||
return TypedValue::Null;
|
||||
}
|
||||
throw QueryRuntimeException(
|
||||
"Second argument of 'uniformSample' must be a non-negative integer.");
|
||||
case TypedValue::Type::List:
|
||||
if (args[1].IsInt() && args[1].ValueInt() >= 0) {
|
||||
auto &population = args[0].Value<std::vector<TypedValue>>();
|
||||
auto population_size = population.size();
|
||||
if (population_size == 0) return TypedValue::Null;
|
||||
auto desired_length = args[1].ValueInt();
|
||||
std::uniform_int_distribution<uint64_t> rand_dist{0,
|
||||
population_size - 1};
|
||||
std::vector<TypedValue> sampled;
|
||||
sampled.reserve(desired_length);
|
||||
for (int i = 0; i < desired_length; ++i) {
|
||||
sampled.push_back(population[rand_dist(pseudo_rand_gen_)]);
|
||||
}
|
||||
return sampled;
|
||||
}
|
||||
throw QueryRuntimeException(
|
||||
"Second argument of 'uniformSample' must be a non-negative integer.");
|
||||
default:
|
||||
throw QueryRuntimeException(
|
||||
"First argument of 'uniformSample' must be a list.");
|
||||
}
|
||||
}
|
||||
|
||||
TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &,
|
||||
database::GraphDbAccessor *) {
|
||||
if (nargs != 1) {
|
||||
@ -895,6 +933,7 @@ NameToFunction(const std::string &function_name) {
|
||||
if (function_name == "RANGE") return Range;
|
||||
if (function_name == "RELATIONSHIPS") return Relationships;
|
||||
if (function_name == "TAIL") return Tail;
|
||||
if (function_name == "UNIFORMSAMPLE") return UniformSample;
|
||||
|
||||
// Mathematical functions - numeric
|
||||
if (function_name == "ABS") return Abs;
|
||||
|
@ -1232,6 +1232,43 @@ TEST_F(FunctionTest, Tail) {
|
||||
ASSERT_THROW(EvaluateFunction("TAIL", {2}), QueryRuntimeException);
|
||||
}
|
||||
|
||||
TEST_F(FunctionTest, UniformSample) {
|
||||
ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE", {}), QueryRuntimeException);
|
||||
ASSERT_TRUE(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, TypedValue::Null})
|
||||
.IsNull());
|
||||
ASSERT_TRUE(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, 1}).IsNull());
|
||||
ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE",
|
||||
{std::vector<TypedValue>{}, TypedValue::Null}),
|
||||
QueryRuntimeException);
|
||||
ASSERT_TRUE(EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{}, 1})
|
||||
.IsNull());
|
||||
ASSERT_THROW(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, -1}),
|
||||
QueryRuntimeException);
|
||||
ASSERT_EQ(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, 0})
|
||||
.ValueList()
|
||||
.size(),
|
||||
0);
|
||||
ASSERT_EQ(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, 2})
|
||||
.ValueList()
|
||||
.size(),
|
||||
2);
|
||||
ASSERT_EQ(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, 3})
|
||||
.ValueList()
|
||||
.size(),
|
||||
3);
|
||||
ASSERT_EQ(
|
||||
EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, 5})
|
||||
.ValueList()
|
||||
.size(),
|
||||
5);
|
||||
}
|
||||
|
||||
TEST_F(FunctionTest, Abs) {
|
||||
ASSERT_THROW(EvaluateFunction("ABS", {}), QueryRuntimeException);
|
||||
ASSERT_TRUE(EvaluateFunction("ABS", {TypedValue::Null}).IsNull());
|
||||
|
Loading…
Reference in New Issue
Block a user