From 664622f68e3b82cee53eaf31bb966d0a51a6de84 Mon Sep 17 00:00:00 2001 From: Marin Petricevic Date: Mon, 7 Jan 2019 15:43:29 +0100 Subject: [PATCH] add 'sample' awesome memgraph function Summary: This simple function is required by the Tensorflow integration so that Memgraph can always return regular matrices of desired size. Reviewers: teon.banek, mtomic, dsantl Reviewed By: mtomic, dsantl Subscribers: mferencevic, pullbot Differential Revision: https://phabricator.memgraph.io/D1783 --- .../interpret/awesome_memgraph_functions.cpp | 39 +++++++++++++++++++ tests/unit/query_expression_evaluator.cpp | 37 ++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index a57802a29..dfec85abb 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -422,6 +422,44 @@ TypedValue Tail(TypedValue *args, int64_t nargs, const EvaluationContext &, } } +TypedValue UniformSample(TypedValue *args, int64_t nargs, + const EvaluationContext &, + database::GraphDbAccessor *) { + static thread_local std::mt19937 pseudo_rand_gen_{std::random_device{}()}; + if (nargs != 2) { + throw QueryRuntimeException( + "'uniformSample' requires exactly two arguments."); + } + switch (args[0].type()) { + case TypedValue::Type::Null: + if (args[1].IsNull() || (args[1].IsInt() && args[1].ValueInt() >= 0)) { + return TypedValue::Null; + } + throw QueryRuntimeException( + "Second argument of 'uniformSample' must be a non-negative integer."); + case TypedValue::Type::List: + if (args[1].IsInt() && args[1].ValueInt() >= 0) { + auto &population = args[0].Value>(); + auto population_size = population.size(); + if (population_size == 0) return TypedValue::Null; + auto desired_length = args[1].ValueInt(); + std::uniform_int_distribution rand_dist{0, + population_size - 1}; + std::vector sampled; + sampled.reserve(desired_length); + for (int i = 0; i < desired_length; ++i) { + sampled.push_back(population[rand_dist(pseudo_rand_gen_)]); + } + return sampled; + } + throw QueryRuntimeException( + "Second argument of 'uniformSample' must be a non-negative integer."); + default: + throw QueryRuntimeException( + "First argument of 'uniformSample' must be a list."); + } +} + TypedValue Abs(TypedValue *args, int64_t nargs, const EvaluationContext &, database::GraphDbAccessor *) { if (nargs != 1) { @@ -895,6 +933,7 @@ NameToFunction(const std::string &function_name) { if (function_name == "RANGE") return Range; if (function_name == "RELATIONSHIPS") return Relationships; if (function_name == "TAIL") return Tail; + if (function_name == "UNIFORMSAMPLE") return UniformSample; // Mathematical functions - numeric if (function_name == "ABS") return Abs; diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 87b58e4d5..115eaf0b6 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -1232,6 +1232,43 @@ TEST_F(FunctionTest, Tail) { ASSERT_THROW(EvaluateFunction("TAIL", {2}), QueryRuntimeException); } +TEST_F(FunctionTest, UniformSample) { + ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE", {}), QueryRuntimeException); + ASSERT_TRUE( + EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, TypedValue::Null}) + .IsNull()); + ASSERT_TRUE( + EvaluateFunction("UNIFORMSAMPLE", {TypedValue::Null, 1}).IsNull()); + ASSERT_THROW(EvaluateFunction("UNIFORMSAMPLE", + {std::vector{}, TypedValue::Null}), + QueryRuntimeException); + ASSERT_TRUE(EvaluateFunction("UNIFORMSAMPLE", {std::vector{}, 1}) + .IsNull()); + ASSERT_THROW( + EvaluateFunction("UNIFORMSAMPLE", {std::vector{1, 2, 3}, -1}), + QueryRuntimeException); + ASSERT_EQ( + EvaluateFunction("UNIFORMSAMPLE", {std::vector{1, 2, 3}, 0}) + .ValueList() + .size(), + 0); + ASSERT_EQ( + EvaluateFunction("UNIFORMSAMPLE", {std::vector{1, 2, 3}, 2}) + .ValueList() + .size(), + 2); + ASSERT_EQ( + EvaluateFunction("UNIFORMSAMPLE", {std::vector{1, 2, 3}, 3}) + .ValueList() + .size(), + 3); + ASSERT_EQ( + EvaluateFunction("UNIFORMSAMPLE", {std::vector{1, 2, 3}, 5}) + .ValueList() + .size(), + 5); +} + TEST_F(FunctionTest, Abs) { ASSERT_THROW(EvaluateFunction("ABS", {}), QueryRuntimeException); ASSERT_TRUE(EvaluateFunction("ABS", {TypedValue::Null}).IsNull());