add 'sample' awesome memgraph function

Summary: This simple function ('uniformSample') is required by the TensorFlow integration so that Memgraph can always return regular matrices of the desired size.

Reviewers: teon.banek, mtomic, dsantl

Reviewed By: mtomic, dsantl

Subscribers: mferencevic, pullbot

Differential Revision: https://phabricator.memgraph.io/D1783
This commit is contained in:
Marin Petricevic 2019-01-07 15:43:29 +01:00
parent 1af728b505
commit 664622f68e
2 changed files with 76 additions and 0 deletions

View File

@ -422,6 +422,44 @@ TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &,
}
}
// Implements the query function `uniformSample(list, count)`.
//
// Builds a list of `count` elements, each picked independently and uniformly
// at random from `list`. Because every pick is independent, elements may
// repeat and `count` may exceed the size of `list`.
//
// Argument handling:
//  * exactly two arguments are required;
//  * if the first argument is Null, the result is Null provided the second
//    argument is Null or a non-negative integer;
//  * if the first argument is a list, the second argument must be a
//    non-negative integer; an empty list yields Null because there is nothing
//    to sample from;
//  * any other first argument type is rejected.
//
// Throws QueryRuntimeException on any violation of the rules above.
TypedValue UniformSample(TypedValue *args, int64_t nargs,
                         const EvaluationContext &,
                         database::GraphDbAccessor *) {
  // One generator per thread, seeded once from the system entropy source, so
  // that concurrent callers never share (and never race on) generator state.
  static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()};
  if (nargs != 2) {
    throw QueryRuntimeException(
        "'uniformSample' requires exactly two arguments.");
  }
  switch (args[0].type()) {
    case TypedValue::Type::Null:
      if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) {
        return TypedValue::Null;
      }
      throw QueryRuntimeException(
          "Second argument of 'uniformSample' must be a non-negative integer.");
    case TypedValue::Type::List: {
      if (!args[1].IsInt() || args[1].ValueInt() < 0) {
        throw QueryRuntimeException(
            "Second argument of 'uniformSample' must be a non-negative "
            "integer.");
      }
      const auto &population = args[0].Value<std::vector<TypedValue>>();
      const auto population_size = population.size();
      // The distribution below needs a non-empty index range.
      if (population_size == 0) return TypedValue::Null;
      const auto desired_length = args[1].ValueInt();
      std::uniform_int_distribution<uint64_t> rand_dist{0,
                                                        population_size - 1};
      std::vector<TypedValue> sampled;
      sampled.reserve(static_cast<size_t>(desired_length));
      // The counter is 64-bit on purpose: `desired_length` is a 64-bit value
      // and a plain `int` counter would overflow (undefined behaviour) for
      // requests larger than INT_MAX.
      for (int64_t i = 0; i < desired_length; ++i) {
        sampled.push_back(population[rand_dist(pseudo_rand_gen_)]);
      }
      return sampled;
    }
    default:
      throw QueryRuntimeException(
          "First argument of 'uniformSample' must be a list.");
  }
}
TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &,
database::GraphDbAccessor *) {
if (nargs != 1) {
@ -895,6 +933,7 @@ NameToFunction(const std::string &function_name) {
if (function_name == "RANGE") return Range;
if (function_name == "RELATIONSHIPS") return Relationships;
if (function_name == "TAIL") return Tail;
if (function_name == "UNIFORMSAMPLE") return UniformSample;
// Mathematical functions - numeric
if (function_name == "ABS") return Abs;

View File

@ -1232,6 +1232,43 @@ TEST_F(FunctionTest, Tail) {
ASSERT_THROW(EvaluateFunction("TAIL", {2}), QueryRuntimeException);
}
// Checks argument validation of UNIFORMSAMPLE and that, for a non-empty
// population, the returned list has exactly the requested length (including
// lengths of zero and lengths larger than the population).
TEST_F(FunctionTest, UniformSample) {
  // Wrong number of arguments.
  ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE", {}), QueryRuntimeException);

  // A Null population yields Null for a Null or non-negative length.
  const auto null_with_null =
      EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, TypedValue::Null});
  ASSERT_TRUE(null_with_null.IsNull());
  const auto null_with_length =
      EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, 1});
  ASSERT_TRUE(null_with_length.IsNull());

  // A list population requires an integer length; an empty list yields Null.
  ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE",
                                {std::vector<TypedValue>{}, TypedValue::Null}),
               QueryRuntimeException);
  const auto empty_population =
      EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{}, 1});
  ASSERT_TRUE(empty_population.IsNull());

  // Negative lengths are rejected.
  ASSERT_THROW(
      EvaluateFunction("UNIFORMSAMPLE", {std::vector<TypedValue>{1, 2, 3}, -1}),
      QueryRuntimeException);

  // The result always has the requested number of elements.
  auto sampled_size = [&](int requested) {
    return EvaluateFunction("UNIFORMSAMPLE",
                            {std::vector<TypedValue>{1, 2, 3}, requested})
        .ValueList()
        .size();
  };
  for (const int requested : {0, 2, 3, 5}) {
    ASSERT_EQ(sampled_size(requested), requested);
  }
}
TEST_F(FunctionTest, Abs) {
ASSERT_THROW(EvaluateFunction("ABS", {}), QueryRuntimeException);
ASSERT_TRUE(EvaluateFunction("ABS", {TypedValue::Null}).IsNull());