From bfbec8d55087cd514aa416e7566a86d7d121aa01 Mon Sep 17 00:00:00 2001 From: florijan Date: Wed, 13 Sep 2017 17:09:04 +0200 Subject: [PATCH] Counter function added Reviewers: buda, mferencevic, mislav.bradac Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D787 --- CHANGELOG.md | 1 + docs/user_technical/open-cypher.md | 1 + src/database/graph_db.hpp | 4 ++ src/database/graph_db_accessor.cpp | 6 +++ src/database/graph_db_accessor.hpp | 7 ++++ .../interpret/awesome_memgraph_functions.cpp | 22 +++++++--- .../memgraph_V1/features/functions.feature | 17 ++++++++ tests/unit/query_expression_evaluator.cpp | 40 +++++++++++++++---- 8 files changed, 84 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e26675074..4ea2b5e76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ * Map indexing supported. * `assert` function added. * Use \u to specify 4 digit codepoint and \U for 8 digit +* `counter` function added. ### Bug Fixes and Other Changes diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index 010024c0b..acafd3a23 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -500,6 +500,7 @@ functions. `contains` | Check if the first argument has an element which is equal to the second argument. `all` | Check if all elements of a list satisfy a predicate.
The syntax is: `all(variable IN list WHERE predicate)`. `assert` | Raises an exception reported to the client if the given argument is not `true`. + `counter` | Generates integers that are guaranteed to be unique on the database level, for the given counter name. #### String Operators diff --git a/src/database/graph_db.hpp b/src/database/graph_db.hpp index 6e2e8452f..a118c1707 100644 --- a/src/database/graph_db.hpp +++ b/src/database/graph_db.hpp @@ -5,6 +5,7 @@ #include "cppitertools/filter.hpp" #include "cppitertools/imap.hpp" +#include "data_structures/concurrent/concurrent_map.hpp" #include "data_structures/concurrent/concurrent_set.hpp" #include "data_structures/concurrent/skiplist.hpp" #include "database/graph_db_datatypes.hpp" @@ -120,4 +121,7 @@ class GraphDb { // Periodically wakes up and hints to transactions that are running for a long // time to stop their execution. Scheduler transaction_killer_; + + // DB level global counters, used in the "counter" function + ConcurrentMap> counters_; }; diff --git a/src/database/graph_db_accessor.cpp b/src/database/graph_db_accessor.cpp index ac7d3b668..0064de484 100644 --- a/src/database/graph_db_accessor.cpp +++ b/src/database/graph_db_accessor.cpp @@ -328,3 +328,9 @@ const std::string &GraphDbAccessor::PropertyName( debug_assert(!commited_ && !aborted_, "Accessor committed or aborted"); return *property; } + +int64_t GraphDbAccessor::Counter(const std::string &name) { + return db_.counters_.access() + .emplace(name, std::make_tuple(name), std::make_tuple(0)) + .first->second.fetch_add(1); +} diff --git a/src/database/graph_db_accessor.hpp b/src/database/graph_db_accessor.hpp index 0e32b5e44..8c27cee36 100644 --- a/src/database/graph_db_accessor.hpp +++ b/src/database/graph_db_accessor.hpp @@ -561,6 +561,13 @@ class GraphDbAccessor { if (!accessor.new_) accessor.new_ = accessor.vlist_->update(*transaction_); } + /** + * Returns the current value of the counter with the given name, and + * increments that counter. If the counter with the given name does not exist, + * a new counter is created and this function returns 0. + */ + int64_t Counter(const std::string &name); + private: /** * Insert this vertex into corresponding label and label+property (if it diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 8c1bc0b20..0c773eaf5 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -512,18 +512,27 @@ TypedValue Assert(const std::vector &args, GraphDbAccessor &) { throw QueryRuntimeException("assert takes one or two arguments"); } if (args[0].type() != TypedValue::Type::Bool) - throw QueryRuntimeException("first assert argument must be bool"); + throw QueryRuntimeException("first assert argument must be bool"); if (args.size() == 2U && args[1].type() != TypedValue::Type::String) - throw QueryRuntimeException("second assert argument must be a string"); + throw QueryRuntimeException("second assert argument must be a string"); if (!args[0].ValueBool()) { std::string message("assertion failed"); - if (args.size() == 2U) - message += ": " + args[1].ValueString(); + if (args.size() == 2U) message += ": " + args[1].ValueString(); throw QueryRuntimeException(message); } return args[0]; } -} // annonymous namespace + +TypedValue Counter(const std::vector &args, GraphDbAccessor &dba) { + if (args.size() != 1U) { + throw QueryRuntimeException("counter takes one argument"); + } + if (!args[0].IsString()) + throw QueryRuntimeException("first counter argument must be a string"); + + return dba.Counter(args[0].ValueString()); +} +} // annonymous namespace std::function &, GraphDbAccessor &)> NameToFunction(const std::string &function_name) { @@ -566,6 +575,7 @@ NameToFunction(const std::string &function_name) { if (function_name == kEndsWith) return EndsWith; if (function_name == kContains) return Contains; if (function_name == "ASSERT") return Assert; + if (function_name == "COUNTER") return Counter; return nullptr; } -} // namespace query +} // namespace query diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index 8025500cf..7e252b3f3 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -690,3 +690,20 @@ Feature: Functions Then the result should be: | res | | true | + + Scenario: Counter test: + Given an empty graph + And having executed: + """ + CREATE (), (), () + """ + When executing query: + """ + MATCH (n) SET n.id = counter("n.id") WITH n SKIP 1 + RETURN n.id, counter("other") AS c2 + """ + Then the result should be: + | n.id | c2 | + | 1 | 0 | + | 2 | 1 | + diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 44afc4fde..4d350736c 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -36,11 +36,13 @@ struct NoContextExpressionEvaluator { }; TypedValue EvaluateFunction(const std::string &function_name, - const std::vector &args) { + const std::vector &args, Dbms &dbms) { AstTreeStorage storage; - NoContextExpressionEvaluator eval; - Dbms dbms; + SymbolTable symbol_table; auto dba = dbms.active(); + Frame frame{128}; + Parameters parameters; + ExpressionEvaluator eval{frame, parameters, symbol_table, *dba}; std::vector expressions; for (const auto &arg : args) { @@ -48,7 +50,13 @@ TypedValue EvaluateFunction(const std::string &function_name, } auto *op = storage.Create(NameToFunction(function_name), expressions); - return op->Accept(eval.eval); + return op->Accept(eval); +} + +TypedValue EvaluateFunction(const std::string &function_name, + const std::vector &args) { + Dbms dbms; + return EvaluateFunction(function_name, args, dbms); } TEST(ExpressionEvaluator, OrOperator) { @@ -1166,13 +1174,17 @@ TEST(ExpressionEvaluator, FunctionAllWhereWrongType) { TEST(ExpressionEvaluator, FunctionAssert) { // Invalid calls. ASSERT_THROW(EvaluateFunction("ASSERT", {}), QueryRuntimeException); - ASSERT_THROW(EvaluateFunction("ASSERT", {false, false}), QueryRuntimeException); - ASSERT_THROW(EvaluateFunction("ASSERT", {"string", false}), QueryRuntimeException); - ASSERT_THROW(EvaluateFunction("ASSERT", {false, "reason", true}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, false}), + QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {"string", false}), + QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, "reason", true}), + QueryRuntimeException); // Valid calls, assertion fails. ASSERT_THROW(EvaluateFunction("ASSERT", {false}), QueryRuntimeException); - ASSERT_THROW(EvaluateFunction("ASSERT", {false, "message"}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, "message"}), + QueryRuntimeException); try { EvaluateFunction("ASSERT", {false, "bbgba"}); } catch (QueryRuntimeException &e) { @@ -1194,4 +1206,16 @@ TEST(ExpressionEvaluator, ParameterLookup) { EXPECT_EQ(value.Value(), 42); } +TEST(ExpressionEvaluator, FunctionCounter) { + Dbms dbms; + EXPECT_THROW(EvaluateFunction("COUNTER", {}, dbms), QueryRuntimeException); + EXPECT_THROW(EvaluateFunction("COUNTER", {"a", "b"}, dbms), + QueryRuntimeException); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 0); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 1); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c2"}, dbms).ValueInt(), 0); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c1"}, dbms).ValueInt(), 2); + EXPECT_EQ(EvaluateFunction("COUNTER", {"c2"}, dbms).ValueInt(), 1); +} + } // namespace