diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 21aa3bf7b..1f5ce0643 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -45,7 +45,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( for (Clause *clause : query_->clauses_) { if (dynamic_cast(clause)) { if (has_update || has_return) { - throw SemanticException("Unwind can't be after return or update clause"); + throw SemanticException( + "Unwind can't be after return or update clause"); } } else if (auto *match = dynamic_cast(clause)) { if (has_update || has_return) { @@ -619,7 +620,22 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a( expression = static_cast(storage_.Create( expression, op->expression3b()->accept(this))); } else { - throw utils::NotYetImplemented(); + std::function &, + GraphDbAccessor &)> + f; + if (op->STARTS() && op->WITH()) { + f = NameToFunction(kStartsWith); + } else if (op->ENDS() && op->WITH()) { + f = NameToFunction(kEndsWith); + } else if (op->CONTAINS()) { + f = NameToFunction(kContains); + } else { + throw utils::NotYetImplemented(); + } + auto expression2 = op->expression3b()->accept(this); + std::vector args = {expression, expression2}; + expression = + static_cast(storage_.Create(f, args)); } } return expression; @@ -709,8 +725,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) { return static_cast( storage_.Create(nullptr, Aggregation::Op::COUNT)); } - // TODO: Implement this. We don't support comprehensions, filtering... at the - // moment. + // TODO: Implement this. We don't support comprehensions, filtering... at + // the moment. throw utils::NotYetImplemented(); } @@ -794,7 +810,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation( } } auto function = NameToFunction(function_name); - if (!function) throw SemanticException(); + if (!function) throw SemanticException("Function doesn't exist."); return static_cast( storage_.Create(function, expressions)); } diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 3a833695d..49ea09c5d 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -437,6 +437,50 @@ TypedValue Pi(const std::vector &args, GraphDbAccessor &) { } return M_PI; } + +template +TypedValue StringMatchOperator(const std::vector &args, + GraphDbAccessor &) { + if (args.size() != 2U) { + throw QueryRuntimeException( + "startsWith shouldn't be called with 2 arguments"); + } + bool has_null = false; + auto check_arg = [&](const TypedValue &t) { + if (t.IsNull()) { + has_null = true; + } else if (t.type() != TypedValue::Type::String) { + throw QueryRuntimeException("startsWith called with incompatible type"); + } + }; + check_arg(args[0]); + check_arg(args[1]); + if (has_null) return TypedValue::Null; + const auto &s1 = args[0].Value(); + const auto &s2 = args[1].Value(); + return Predicate(s1, s2); +} + +// Check if s1 starts with s2. +bool StartsWithPredicate(const std::string &s1, const std::string &s2) { + if (s1.size() < s2.size()) return false; + return std::equal(s2.begin(), s2.end(), s1.begin()); +} +auto StartsWith = StringMatchOperator; + +// Check if s1 ends with s2. +bool EndsWithPredicate(const std::string &s1, const std::string &s2) { + if (s1.size() < s2.size()) return false; + return std::equal(s2.rbegin(), s2.rend(), s1.rbegin()); +} +auto EndsWith = StringMatchOperator; + +// Check if s1 contains s2. +bool ContainsPredicate(const std::string &s1, const std::string &s2) { + if (s1.size() < s2.size()) return false; + return s1.find(s2) != std::string::npos; +} +auto Contains = StringMatchOperator; } std::function &, GraphDbAccessor &)> @@ -474,6 +518,9 @@ NameToFunction(const std::string &function_name) { if (function_name == "SIGN") return Sign; if (function_name == "E") return E; if (function_name == "PI") return Pi; + if (function_name == kStartsWith) return StartsWith; + if (function_name == kEndsWith) return EndsWith; + if (function_name == kContains) return Contains; return nullptr; } } diff --git a/src/query/interpret/awesome_memgraph_functions.hpp b/src/query/interpret/awesome_memgraph_functions.hpp index 2744d16ec..2374e6b88 100644 --- a/src/query/interpret/awesome_memgraph_functions.hpp +++ b/src/query/interpret/awesome_memgraph_functions.hpp @@ -7,6 +7,12 @@ namespace query { +namespace { +const char kStartsWith[] = "STARTSWITH"; +const char kEndsWith[] = "ENDSWITH"; +const char kContains[] = "CONTAINS"; +} + std::function &, GraphDbAccessor &)> NameToFunction(const std::string &function_name); } diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 909a56a8d..15eb3b96d 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -23,6 +23,7 @@ using testing::ElementsAre; using query::test_common::ToInt64List; namespace { + struct NoContextExpressionEvaluator { NoContextExpressionEvaluator() {} Frame frame{128}; @@ -973,4 +974,37 @@ TEST(ExpressionEvaluator, FunctionPi) { ASSERT_THROW(EvaluateFunction("PI", {1}), QueryRuntimeException); ASSERT_DOUBLE_EQ(EvaluateFunction("PI", {}).Value(), M_PI); } + +TEST(ExpressionEvaluator, FunctionStartsWith) { + EXPECT_THROW(EvaluateFunction(kStartsWith, {}), QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kStartsWith, {"a", TypedValue::Null}).IsNull()); + EXPECT_THROW(EvaluateFunction(kStartsWith, {TypedValue::Null, 1.3}), + QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abc", "abc"}).Value()); + EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abcdef", "abc"}).Value()); + EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abcdef", "aBc"}).Value()); + EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abc", "abcd"}).Value()); +} + +TEST(ExpressionEvaluator, FunctionEndsWith) { + EXPECT_THROW(EvaluateFunction(kEndsWith, {}), QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kEndsWith, {"a", TypedValue::Null}).IsNull()); + EXPECT_THROW(EvaluateFunction(kEndsWith, {TypedValue::Null, 1.3}), + QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abc", "abc"}).Value()); + EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abcdef", "def"}).Value()); + EXPECT_FALSE(EvaluateFunction(kEndsWith, {"abcdef", "dEf"}).Value()); + EXPECT_FALSE(EvaluateFunction(kEndsWith, {"bcd", "abcd"}).Value()); +} + +TEST(ExpressionEvaluator, FunctionContains) { + EXPECT_THROW(EvaluateFunction(kContains, {}), QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kContains, {"a", TypedValue::Null}).IsNull()); + EXPECT_THROW(EvaluateFunction(kContains, {TypedValue::Null, 1.3}), + QueryRuntimeException); + EXPECT_TRUE(EvaluateFunction(kContains, {"abc", "abc"}).Value()); + EXPECT_TRUE(EvaluateFunction(kContains, {"abcde", "bcd"}).Value()); + EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value()); + EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value()); +} }