diff --git a/CHANGELOG.md b/CHANGELOG.md index b5c114f0c..906fa7afe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ * Maps can now be stored as vertex/edge properties. * `collect` aggregation now supports Map collection. * Map indexing supported. +* `assert` function added. ### Bug Fixes and Other Changes diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index 9db276940..859ce511c 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -453,6 +453,7 @@ functions. `endsWith` | Check if the first argument ends with the second. `contains` | Check if the first argument has an element which is equal to the second argument. `all` | Check if all elements of a list satisfy a predicate.
The syntax is: `all(variable IN list WHERE predicate)`. + `assert` | Raises an exception reported to the client if the given argument is not `true`. #### String Operators diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index c7acca997..8c1bc0b20 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -506,7 +506,24 @@ bool ContainsPredicate(const std::string &s1, const std::string &s2) { return s1.find(s2) != std::string::npos; } auto Contains = StringMatchOperator; + +TypedValue Assert(const std::vector &args, GraphDbAccessor &) { + if (args.size() < 1U || args.size() > 2U) { + throw QueryRuntimeException("assert takes one or two arguments"); + } + if (args[0].type() != TypedValue::Type::Bool) + throw QueryRuntimeException("first assert argument must be bool"); + if (args.size() == 2U && args[1].type() != TypedValue::Type::String) + throw QueryRuntimeException("second assert argument must be a string"); + if (!args[0].ValueBool()) { + std::string message("assertion failed"); + if (args.size() == 2U) + message += ": " + args[1].ValueString(); + throw QueryRuntimeException(message); + } + return args[0]; } +} // annonymous namespace std::function &, GraphDbAccessor &)> NameToFunction(const std::string &function_name) { @@ -548,6 +565,7 @@ NameToFunction(const std::string &function_name) { if (function_name == kStartsWith) return StartsWith; if (function_name == kEndsWith) return EndsWith; if (function_name == kContains) return Contains; + if (function_name == "ASSERT") return Assert; return nullptr; } -} +} // namespace query diff --git a/src/utils/string.hpp b/src/utils/string.hpp index 95bf0ff03..31562ce54 100644 --- a/src/utils/string.hpp +++ b/src/utils/string.hpp @@ -20,7 +20,7 @@ namespace utils { * * @return trimmed string */ -inline std::string Trim(const std::string& s) { +inline std::string Trim(const std::string &s) { auto begin = s.begin(); auto end = s.end(); if (begin == end) { @@ -58,11 +58,11 @@ inline std::string ToUpperCase(std::string s) { /** * Join strings in vector separated by a given separator. */ -inline std::string Join(const std::vector& strings, - const std::string& separator) { +inline std::string Join(const std::vector &strings, + const std::string &separator) { if (strings.size() == 0U) return ""; int64_t total_size = 0; - for (const auto& x : strings) { + for (const auto &x : strings) { total_size += x.size(); } total_size += separator.size() * (static_cast(strings.size()) - 1); @@ -80,8 +80,8 @@ inline std::string Join(const std::vector& strings, * Replaces all occurences of in with . */ // TODO: This could be implemented much more efficient. -inline std::string Replace(std::string src, const std::string& match, - const std::string& replacement) { +inline std::string Replace(std::string src, const std::string &match, + const std::string &replacement) { for (size_t pos = src.find(match); pos != std::string::npos; pos = src.find(match, pos + replacement.size())) { src.erase(pos, match.length()).insert(pos, replacement); @@ -113,7 +113,7 @@ inline std::vector Split(const std::string &src, * Parse double using classic locale, throws BasicException if it wasn't able to * parse whole string. */ -inline double ParseDouble(const std::string& s) { +inline double ParseDouble(const std::string &s) { // stod would be nicer but it uses current locale so we shouldn't use it. double t = 0.0; std::istringstream iss(s); @@ -124,4 +124,12 @@ inline double ParseDouble(const std::string& s) { } return t; } + +/** + * Checks if the given string `s` ends with the given `suffix`. + */ +inline bool EndsWith(const std::string &s, const std::string &suffix) { + return s.size() >= suffix.size() && + s.compare(s.size() - suffix.size(), std::string::npos, suffix) == 0; +} } diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index 3dc29b65f..8025500cf 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -649,3 +649,44 @@ Feature: Functions RETURN all(x IN [1, 2, '3'] WHERE x < 3) AS a """ Then an error should be raised + + + Scenario: Assert test fail, no message: + Given an empty graph + And having executed: + """ + CREATE ({a: 1}) + """ + When executing query: + """ + MATCH (n) RETURN assert(n.a = 2) AS res + """ + Then an error should be raised + + + Scenario: Assert test fail: + Given an empty graph + And having executed: + """ + CREATE ({a: 1, b: "string"}) + """ + When executing query: + """ + MATCH (n) RETURN assert(n.a = 2, n.b) AS res + """ + Then an error should be raised + + + Scenario: Assert test pass: + Given an empty graph + And having executed: + """ + CREATE ({a: 1, b: "string"}) + """ + When executing query: + """ + MATCH (n) RETURN assert(n.a = 1, n.b) AS res + """ + Then the result should be: + | res | + | true | diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index fb577cdd0..aadc2c0fb 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -14,6 +14,7 @@ #include "query/interpret/awesome_memgraph_functions.hpp" #include "query/interpret/eval.hpp" #include "query/interpret/frame.hpp" +#include "utils/string.hpp" #include "query_common.hpp" @@ -1159,4 +1160,25 @@ TEST(ExpressionEvaluator, FunctionAllWhereWrongType) { eval.symbol_table[*all->identifier_] = x_sym; EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException); } + +TEST(ExpressionEvaluator, FunctionAssert) { + // Invalid calls. + ASSERT_THROW(EvaluateFunction("ASSERT", {}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, false}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {"string", false}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, "reason", true}), QueryRuntimeException); + + // Valid calls, assertion fails. + ASSERT_THROW(EvaluateFunction("ASSERT", {false}), QueryRuntimeException); + ASSERT_THROW(EvaluateFunction("ASSERT", {false, "message"}), QueryRuntimeException); + try { + EvaluateFunction("ASSERT", {false, "bbgba"}); + } catch (QueryRuntimeException &e) { + ASSERT_TRUE(utils::EndsWith(e.what(), "bbgba")); + } + + // Valid calls, assertion passes. + ASSERT_TRUE(EvaluateFunction("ASSERT", {true}).ValueBool()); + ASSERT_TRUE(EvaluateFunction("ASSERT", {true, "message"}).ValueBool()); +} }