diff --git a/CHANGELOG.md b/CHANGELOG.md
index b5c114f0c..906fa7afe 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
* Maps can now be stored as vertex/edge properties.
* `collect` aggregation now supports Map collection.
* Map indexing supported.
+* `assert` function added.
### Bug Fixes and Other Changes
diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md
index 9db276940..859ce511c 100644
--- a/docs/user_technical/open-cypher.md
+++ b/docs/user_technical/open-cypher.md
@@ -453,6 +453,7 @@ functions.
`endsWith` | Check if the first argument ends with the second.
`contains` | Check if the first argument has an element which is equal to the second argument.
`all` | Check if all elements of a list satisfy a predicate.
The syntax is: `all(variable IN list WHERE predicate)`.
+ `assert` | Raises an exception reported to the client if the given argument is not `true`.
#### String Operators
diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp
index c7acca997..8c1bc0b20 100644
--- a/src/query/interpret/awesome_memgraph_functions.cpp
+++ b/src/query/interpret/awesome_memgraph_functions.cpp
@@ -506,7 +506,24 @@ bool ContainsPredicate(const std::string &s1, const std::string &s2) {
return s1.find(s2) != std::string::npos;
}
auto Contains = StringMatchOperator;
+
+TypedValue Assert(const std::vector &args, GraphDbAccessor &) {
+ if (args.size() < 1U || args.size() > 2U) {
+ throw QueryRuntimeException("assert takes one or two arguments");
+ }
+ if (args[0].type() != TypedValue::Type::Bool)
+ throw QueryRuntimeException("first assert argument must be bool");
+ if (args.size() == 2U && args[1].type() != TypedValue::Type::String)
+ throw QueryRuntimeException("second assert argument must be a string");
+ if (!args[0].ValueBool()) {
+ std::string message("assertion failed");
+ if (args.size() == 2U)
+ message += ": " + args[1].ValueString();
+ throw QueryRuntimeException(message);
+ }
+ return args[0];
}
+} // annonymous namespace
std::function &, GraphDbAccessor &)>
NameToFunction(const std::string &function_name) {
@@ -548,6 +565,7 @@ NameToFunction(const std::string &function_name) {
if (function_name == kStartsWith) return StartsWith;
if (function_name == kEndsWith) return EndsWith;
if (function_name == kContains) return Contains;
+ if (function_name == "ASSERT") return Assert;
return nullptr;
}
-}
+} // namespace query
diff --git a/src/utils/string.hpp b/src/utils/string.hpp
index 95bf0ff03..31562ce54 100644
--- a/src/utils/string.hpp
+++ b/src/utils/string.hpp
@@ -20,7 +20,7 @@ namespace utils {
*
* @return trimmed string
*/
-inline std::string Trim(const std::string& s) {
+inline std::string Trim(const std::string &s) {
auto begin = s.begin();
auto end = s.end();
if (begin == end) {
@@ -58,11 +58,11 @@ inline std::string ToUpperCase(std::string s) {
/**
* Join strings in vector separated by a given separator.
*/
-inline std::string Join(const std::vector& strings,
- const std::string& separator) {
+inline std::string Join(const std::vector &strings,
+ const std::string &separator) {
if (strings.size() == 0U) return "";
int64_t total_size = 0;
- for (const auto& x : strings) {
+ for (const auto &x : strings) {
total_size += x.size();
}
total_size += separator.size() * (static_cast(strings.size()) - 1);
@@ -80,8 +80,8 @@ inline std::string Join(const std::vector& strings,
* Replaces all occurences of in with .
*/
// TODO: This could be implemented much more efficient.
-inline std::string Replace(std::string src, const std::string& match,
- const std::string& replacement) {
+inline std::string Replace(std::string src, const std::string &match,
+ const std::string &replacement) {
for (size_t pos = src.find(match); pos != std::string::npos;
pos = src.find(match, pos + replacement.size())) {
src.erase(pos, match.length()).insert(pos, replacement);
@@ -113,7 +113,7 @@ inline std::vector Split(const std::string &src,
* Parse double using classic locale, throws BasicException if it wasn't able to
* parse whole string.
*/
-inline double ParseDouble(const std::string& s) {
+inline double ParseDouble(const std::string &s) {
// stod would be nicer but it uses current locale so we shouldn't use it.
double t = 0.0;
std::istringstream iss(s);
@@ -124,4 +124,12 @@ inline double ParseDouble(const std::string& s) {
}
return t;
}
+
+/**
+ * Checks if the given string `s` ends with the given `suffix`.
+ */
+inline bool EndsWith(const std::string &s, const std::string &suffix) {
+ return s.size() >= suffix.size() &&
+ s.compare(s.size() - suffix.size(), std::string::npos, suffix) == 0;
+}
}
diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature
index 3dc29b65f..8025500cf 100644
--- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature
+++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature
@@ -649,3 +649,44 @@ Feature: Functions
RETURN all(x IN [1, 2, '3'] WHERE x < 3) AS a
"""
Then an error should be raised
+
+
+ Scenario: Assert test fail, no message:
+ Given an empty graph
+ And having executed:
+ """
+ CREATE ({a: 1})
+ """
+ When executing query:
+ """
+ MATCH (n) RETURN assert(n.a = 2) AS res
+ """
+ Then an error should be raised
+
+
+ Scenario: Assert test fail:
+ Given an empty graph
+ And having executed:
+ """
+ CREATE ({a: 1, b: "string"})
+ """
+ When executing query:
+ """
+ MATCH (n) RETURN assert(n.a = 2, n.b) AS res
+ """
+ Then an error should be raised
+
+
+ Scenario: Assert test pass:
+ Given an empty graph
+ And having executed:
+ """
+ CREATE ({a: 1, b: "string"})
+ """
+ When executing query:
+ """
+ MATCH (n) RETURN assert(n.a = 1, n.b) AS res
+ """
+ Then the result should be:
+ | res |
+ | true |
diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp
index fb577cdd0..aadc2c0fb 100644
--- a/tests/unit/query_expression_evaluator.cpp
+++ b/tests/unit/query_expression_evaluator.cpp
@@ -14,6 +14,7 @@
#include "query/interpret/awesome_memgraph_functions.hpp"
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
+#include "utils/string.hpp"
#include "query_common.hpp"
@@ -1159,4 +1160,25 @@ TEST(ExpressionEvaluator, FunctionAllWhereWrongType) {
eval.symbol_table[*all->identifier_] = x_sym;
EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException);
}
+
+TEST(ExpressionEvaluator, FunctionAssert) {
+ // Invalid calls.
+ ASSERT_THROW(EvaluateFunction("ASSERT", {}), QueryRuntimeException);
+ ASSERT_THROW(EvaluateFunction("ASSERT", {false, false}), QueryRuntimeException);
+ ASSERT_THROW(EvaluateFunction("ASSERT", {"string", false}), QueryRuntimeException);
+ ASSERT_THROW(EvaluateFunction("ASSERT", {false, "reason", true}), QueryRuntimeException);
+
+ // Valid calls, assertion fails.
+ ASSERT_THROW(EvaluateFunction("ASSERT", {false}), QueryRuntimeException);
+ ASSERT_THROW(EvaluateFunction("ASSERT", {false, "message"}), QueryRuntimeException);
+ try {
+ EvaluateFunction("ASSERT", {false, "bbgba"});
+ } catch (QueryRuntimeException &e) {
+ ASSERT_TRUE(utils::EndsWith(e.what(), "bbgba"));
+ }
+
+ // Valid calls, assertion passes.
+ ASSERT_TRUE(EvaluateFunction("ASSERT", {true}).ValueBool());
+ ASSERT_TRUE(EvaluateFunction("ASSERT", {true, "message"}).ValueBool());
+}
}