diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ea2b5e76..d329e9163 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ * Map indexing supported. * `assert` function added. * Use \u to specify 4 digit codepoint and \U for 8 digit -* `counter` function added. +* `counter` and `counterSet` functions added. ### Bug Fixes and Other Changes diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index acafd3a23..4059927df 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -501,6 +501,7 @@ functions. `all` | Check if all elements of a list satisfy a predicate.<br/>The syntax is: `all(variable IN list WHERE predicate)`. `assert` | Raises an exception reported to the client if the given argument is not `true`. `counter` | Generates integers that are guaranteed to be unique on the database level, for the given counter name. + `counterSet` | Sets the counter with the given name to the given value. #### String Operators diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index 0064de484..7d96f7e01 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -334,3 +334,10 @@ int64_t GraphDbAccessor::Counter(const std::string &name) { .emplace(name, std::make_tuple(name), std::make_tuple(0)) .first->second.fetch_add(1); } + +void GraphDbAccessor::CounterSet(const std::string &name, int64_t value) { + auto name_counter_pair = db_.counters_.access() + .emplace(name, std::make_tuple(name), std::make_tuple(value)); + if (!name_counter_pair.second) + name_counter_pair.first->second.store(value); +} diff --git a/src/database/graph_db_accessor.hpp b/src/database/graph_db_accessor.hpp index 8c27cee36..19c3d131a 100644 --- a/src/database/graph_db_accessor.hpp +++ b/src/database/graph_db_accessor.hpp @@ -568,6 +568,13 @@ class GraphDbAccessor { */ int64_t Counter(const std::string &name); + /** + * Sets the counter with the given name to the given value. Returns nothing. + * If the counter with the given name does not exist, a new counter is + * created and set to the given value. + */ + void CounterSet(const std::string &name, int64_t value); + private: /** * Insert this vertex into corresponding label and label+property (if it diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 0c773eaf5..14dcf7e2d 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -532,6 +532,19 @@ TypedValue Counter(const std::vector<TypedValue> &args, GraphDbAccessor &dba) { return dba.Counter(args[0].ValueString()); } + +TypedValue CounterSet(const std::vector<TypedValue> &args, GraphDbAccessor &dba) { + if (args.size() != 2U) { + throw QueryRuntimeException("counterSet takes two arguments"); + } + if (!args[0].IsString()) + throw QueryRuntimeException("first counterSet argument must be a string"); + if (!args[1].IsInt()) + throw QueryRuntimeException("first counterSet argument must be an int"); + + dba.CounterSet(args[0].ValueString(), args[1].ValueInt()); + return TypedValue::Null; +} } // annonymous namespace std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)> @@ -576,6 +589,7 @@ NameToFunction(const std::string &function_name) { if (function_name == kContains) return Contains; if (function_name == "ASSERT") return Assert; if (function_name == "COUNTER") return Counter; + if (function_name == "COUNTERSET") return CounterSet; return nullptr; } } // namespace query diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index 7e252b3f3..4a6eaafec 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -707,3 +707,15 @@ Feature: Functions | 1 | 0 | | 2 | 1 | + + Scenario: CounterSet test: + When executing query: + """ + WITH counter("n") AS zero + WITH counter("n") AS one, zero + WITH counterSet("n", 42) AS nothing, zero, one + RETURN counter("n") AS n, zero, one, counter("n2") AS n2 + """ + Then the result should be: + | n | zero | one | n2 | + | 42 | 0 | 1 | 0 | diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 4d350736c..e61f9f6de 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -1218,4 +1218,21 @@ TEST(ExpressionEvaluator, FunctionCounter) { EXPECT_EQ(EvaluateFunction("COUNTER", {"c2"}, dbms).ValueInt(), 1); } +TEST(ExpressionEvaluator, FunctionCounterSet) { + Dbms dbms; + EXPECT_THROW(EvaluateFunction("COUNTERSET", {}, dbms), QueryRuntimeException); + EXPECT_THROW(EvaluateFunction("COUNTERSET", {"a"}, dbms), + QueryRuntimeException); + EXPECT_THROW(EvaluateFunction("COUNTERSET", {"a", "b"}, dbms), + QueryRuntimeException); + EXPECT_THROW(EvaluateFunction("COUNTERSET", {"a", 11, 12}, dbms), + QueryRuntimeException); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 0); + EvaluateFunction("COUNTERSET", {"c1", 12}, dbms); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 12); + EvaluateFunction("COUNTERSET", {"c2", 42}, dbms); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c2"}, dbms).ValueInt(), 42); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 13); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c2"}, dbms).ValueInt(), 43); +} } // namespace