diff --git a/CHANGELOG.md b/CHANGELOG.md index f538bf5e7..ac6cf1c84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Major Features and Improvements * CASE construct (without aggregations) +* rand() function added ### Bug Fixes and Other Changes diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index a27c936c7..323a728f8 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -440,6 +440,7 @@ functions. `sign` | Applies the signum function to a given number and returns the result. The signum of positive numbers is 1, of negative -1 and for 0 returns 0. `e` | Returns the base of the natural logarithm. `pi` | Returns the constant *pi*. + `rand` | Returns a random floating point number between 0 (inclusive) and 1 (exclusive). `startsWith` | Check if the first argument starts with the second. `endsWith` | Check if the first argument ends with the second. `contains` | Check if the first argument has an element which is equal to the second argument. diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index d8465c57a..cecee848f 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -9,7 +9,9 @@ #include "utils/on_scope_exit.hpp" GraphDbAccessor::GraphDbAccessor(GraphDb &db) - : db_(db), transaction_(db.tx_engine_.Begin()) {} + : db_(db), transaction_(db.tx_engine_.Begin()) { + pseudo_rand_gen_.seed(std::random_device()()); +} GraphDbAccessor::~GraphDbAccessor() { if (!commited_ && !aborted_) { @@ -332,3 +334,5 @@ const std::string &GraphDbAccessor::PropertyName( debug_assert(!commited_ && !aborted_, "Accessor committed or aborted"); return *property; } + +double GraphDbAccessor::Rand() { return rand_dist_(pseudo_rand_gen_); } diff --git a/src/database/graph_db_accessor.hpp b/src/database/graph_db_accessor.hpp index edf156910..47bd958b2 100644 --- a/src/database/graph_db_accessor.hpp +++ b/src/database/graph_db_accessor.hpp @@ -6,6 +6,7 @@ #pragma once #include <experimental/optional> +#include <random> #include "cppitertools/filter.hpp" #include "cppitertools/imap.hpp" @@ -552,6 +553,11 @@ class GraphDbAccessor { if (!accessor.new_) accessor.new_ = accessor.vlist_->update(*transaction_); } + /** + * Returns a uniformly random-generated number from the [0, 1) interval. + */ + double Rand(); + private: /** * Insert this vertex into corresponding label and label+property (if it @@ -584,6 +590,7 @@ class GraphDbAccessor { void UpdatePropertyIndex(const GraphDbTypes::Property &property, const RecordAccessor<Vertex> &record_accessor, const Vertex *const vertex); + GraphDb &db_; /** The current transaction */ @@ -591,4 +598,8 @@ class GraphDbAccessor { bool commited_{false}; bool aborted_{false}; + + // Random number generation stuff. + std::mt19937 pseudo_rand_gen_; + std::uniform_real_distribution<> rand_dist_{0, 1}; }; diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index cbc49f648..304c5e499 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -453,6 +453,13 @@ TypedValue Pi(const std::vector<TypedValue> &args, GraphDbAccessor &) { return M_PI; } +TypedValue Rand(const std::vector<TypedValue> &args, GraphDbAccessor &dba) { + if (args.size() != 0U) { + throw QueryRuntimeException("rand shouldn't be called with arguments"); + } + return dba.Rand(); +} + template <bool (*Predicate)(const std::string &s1, const std::string &s2)> TypedValue StringMatchOperator(const std::vector<TypedValue> &args, GraphDbAccessor &) { @@ -534,6 +541,7 @@ NameToFunction(const std::string &function_name) { if (function_name == "SIGN") return Sign; if (function_name == "E") return E; if (function_name == "PI") return Pi; + if (function_name == "RAND") return Rand; if (function_name == kStartsWith) return StartsWith; if (function_name == kEndsWith) return EndsWith; if (function_name == kContains) return Contains; diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index 5399c4ce6..3dc29b65f 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -616,6 +616,15 @@ Feature: Functions | n | | 3.141592653589793 | + Scenario: Rand test: + When executing query: + """ + WITH rand() as r RETURN r >= 0.0 AND r < 1.0 as result + """ + Then the result should be: + | result | + | true | + Scenario: All test 01: When executing query: """ diff --git a/tests/unit/graph_db_accessor.cpp b/tests/unit/graph_db_accessor.cpp index b99084761..083528e3c 100644 --- a/tests/unit/graph_db_accessor.cpp +++ b/tests/unit/graph_db_accessor.cpp @@ -352,6 +352,19 @@ TEST(GraphDbAccessorTest, Transfer) { EXPECT_EQ(dba3->Transfer(e12)->PropsAt(prop).Value<int64_t>(), 12); } +TEST(GraphDbAccessorTest, Rand) { + Dbms dbms; + auto dba = dbms.active(); + + double a = dba->Rand(); + EXPECT_GE(a, 0.0); + EXPECT_LT(a, 1.0); + double b = dba->Rand(); + EXPECT_GE(b, 0.0); + EXPECT_LT(b, 1.0); + EXPECT_NE(a, b); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); // ::testing::GTEST_FLAG(filter) = "*.DetachRemoveVertex"; diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 3abb26989..c9dbbf83a 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -1047,6 +1047,12 @@ TEST(ExpressionEvaluator, FunctionPi) { ASSERT_DOUBLE_EQ(EvaluateFunction("PI", {}).Value<double>(), M_PI); } +TEST(ExpressionEvaluator, FunctionRand) { + ASSERT_THROW(EvaluateFunction("RAND", {1}), QueryRuntimeException); + ASSERT_GE(EvaluateFunction("RAND", {}).Value<double>(), 0.0); + ASSERT_LT(EvaluateFunction("RAND", {}).Value<double>(), 1.0); +} + TEST(ExpressionEvaluator, FunctionStartsWith) { EXPECT_THROW(EvaluateFunction(kStartsWith, {}), QueryRuntimeException); EXPECT_TRUE(EvaluateFunction(kStartsWith, {"a", TypedValue::Null}).IsNull());