Implement STARTS WITH, ENDS WITH, CONTAINS

Reviewers: buda, teon.banek, florijan

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D389
This commit is contained in:
Mislav Bradac 2017-05-19 19:25:54 +02:00
parent f9cd87bb46
commit 839d63284b
4 changed files with 108 additions and 5 deletions

View File

@ -45,7 +45,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
for (Clause *clause : query_->clauses_) {
if (dynamic_cast<Unwind *>(clause)) {
if (has_update || has_return) {
throw SemanticException("Unwind can't be after return or update clause");
throw SemanticException(
"Unwind can't be after return or update clause");
}
} else if (auto *match = dynamic_cast<Match *>(clause)) {
if (has_update || has_return) {
@ -619,7 +620,22 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a(
expression = static_cast<Expression *>(storage_.Create<InListOperator>(
expression, op->expression3b()->accept(this)));
} else {
throw utils::NotYetImplemented();
std::function<TypedValue(const std::vector<TypedValue> &,
GraphDbAccessor &)>
f;
if (op->STARTS() && op->WITH()) {
f = NameToFunction(kStartsWith);
} else if (op->ENDS() && op->WITH()) {
f = NameToFunction(kEndsWith);
} else if (op->CONTAINS()) {
f = NameToFunction(kContains);
} else {
throw utils::NotYetImplemented();
}
auto expression2 = op->expression3b()->accept(this);
std::vector<Expression *> args = {expression, expression2};
expression =
static_cast<Expression *>(storage_.Create<Function>(f, args));
}
}
return expression;
@ -709,8 +725,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
return static_cast<Expression *>(
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
}
// TODO: Implement this. We don't support comprehensions, filtering... at the
// moment.
// TODO: Implement this. We don't support comprehensions, filtering... at
// the moment.
throw utils::NotYetImplemented();
}
@ -794,7 +810,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
}
}
auto function = NameToFunction(function_name);
if (!function) throw SemanticException();
if (!function) throw SemanticException("Function doesn't exist.");
return static_cast<Expression *>(
storage_.Create<Function>(function, expressions));
}

View File

@ -437,6 +437,50 @@ TypedValue Pi(const std::vector<TypedValue> &args, GraphDbAccessor &) {
}
return M_PI;
}
template <bool (*Predicate)(const std::string &s1, const std::string &s2)>
TypedValue StringMatchOperator(const std::vector<TypedValue> &args,
GraphDbAccessor &) {
if (args.size() != 2U) {
throw QueryRuntimeException(
"startsWith shouldn't be called with 2 arguments");
}
bool has_null = false;
auto check_arg = [&](const TypedValue &t) {
if (t.IsNull()) {
has_null = true;
} else if (t.type() != TypedValue::Type::String) {
throw QueryRuntimeException("startsWith called with incompatible type");
}
};
check_arg(args[0]);
check_arg(args[1]);
if (has_null) return TypedValue::Null;
const auto &s1 = args[0].Value<std::string>();
const auto &s2 = args[1].Value<std::string>();
return Predicate(s1, s2);
}
// Check if s1 starts with s2.
bool StartsWithPredicate(const std::string &s1, const std::string &s2) {
if (s1.size() < s2.size()) return false;
return std::equal(s2.begin(), s2.end(), s1.begin());
}
auto StartsWith = StringMatchOperator<StartsWithPredicate>;
// Check if s1 ends with s2.
bool EndsWithPredicate(const std::string &s1, const std::string &s2) {
if (s1.size() < s2.size()) return false;
return std::equal(s2.rbegin(), s2.rend(), s1.rbegin());
}
auto EndsWith = StringMatchOperator<EndsWithPredicate>;
// Check if s1 contains s2.
bool ContainsPredicate(const std::string &s1, const std::string &s2) {
if (s1.size() < s2.size()) return false;
return s1.find(s2) != std::string::npos;
}
auto Contains = StringMatchOperator<ContainsPredicate>;
}
std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
@ -474,6 +518,9 @@ NameToFunction(const std::string &function_name) {
if (function_name == "SIGN") return Sign;
if (function_name == "E") return E;
if (function_name == "PI") return Pi;
if (function_name == kStartsWith) return StartsWith;
if (function_name == kEndsWith) return EndsWith;
if (function_name == kContains) return Contains;
return nullptr;
}
}

View File

@ -7,6 +7,12 @@
namespace query {
namespace {
const char kStartsWith[] = "STARTSWITH";
const char kEndsWith[] = "ENDSWITH";
const char kContains[] = "CONTAINS";
}
std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
NameToFunction(const std::string &function_name);
}

View File

@ -23,6 +23,7 @@ using testing::ElementsAre;
using query::test_common::ToInt64List;
namespace {
struct NoContextExpressionEvaluator {
NoContextExpressionEvaluator() {}
Frame frame{128};
@ -973,4 +974,37 @@ TEST(ExpressionEvaluator, FunctionPi) {
ASSERT_THROW(EvaluateFunction("PI", {1}), QueryRuntimeException);
ASSERT_DOUBLE_EQ(EvaluateFunction("PI", {}).Value<double>(), M_PI);
}
TEST(ExpressionEvaluator, FunctionStartsWith) {
EXPECT_THROW(EvaluateFunction(kStartsWith, {}), QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"a", TypedValue::Null}).IsNull());
EXPECT_THROW(EvaluateFunction(kStartsWith, {TypedValue::Null, 1.3}),
QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abc", "abc"}).Value<bool>());
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abcdef", "abc"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abcdef", "aBc"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abc", "abcd"}).Value<bool>());
}
TEST(ExpressionEvaluator, FunctionEndsWith) {
EXPECT_THROW(EvaluateFunction(kEndsWith, {}), QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"a", TypedValue::Null}).IsNull());
EXPECT_THROW(EvaluateFunction(kEndsWith, {TypedValue::Null, 1.3}),
QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abc", "abc"}).Value<bool>());
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abcdef", "def"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kEndsWith, {"abcdef", "dEf"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kEndsWith, {"bcd", "abcd"}).Value<bool>());
}
TEST(ExpressionEvaluator, FunctionContains) {
EXPECT_THROW(EvaluateFunction(kContains, {}), QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kContains, {"a", TypedValue::Null}).IsNull());
EXPECT_THROW(EvaluateFunction(kContains, {TypedValue::Null, 1.3}),
QueryRuntimeException);
EXPECT_TRUE(EvaluateFunction(kContains, {"abc", "abc"}).Value<bool>());
EXPECT_TRUE(EvaluateFunction(kContains, {"abcde", "bcd"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value<bool>());
}
}