Implement STARTS WITH, ENDS WITH, CONTAINS
Reviewers: buda, teon.banek, florijan Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D389
This commit is contained in:
parent
f9cd87bb46
commit
839d63284b
@ -45,7 +45,8 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
|
||||
for (Clause *clause : query_->clauses_) {
|
||||
if (dynamic_cast<Unwind *>(clause)) {
|
||||
if (has_update || has_return) {
|
||||
throw SemanticException("Unwind can't be after return or update clause");
|
||||
throw SemanticException(
|
||||
"Unwind can't be after return or update clause");
|
||||
}
|
||||
} else if (auto *match = dynamic_cast<Match *>(clause)) {
|
||||
if (has_update || has_return) {
|
||||
@ -619,7 +620,22 @@ antlrcpp::Any CypherMainVisitor::visitExpression3a(
|
||||
expression = static_cast<Expression *>(storage_.Create<InListOperator>(
|
||||
expression, op->expression3b()->accept(this)));
|
||||
} else {
|
||||
throw utils::NotYetImplemented();
|
||||
std::function<TypedValue(const std::vector<TypedValue> &,
|
||||
GraphDbAccessor &)>
|
||||
f;
|
||||
if (op->STARTS() && op->WITH()) {
|
||||
f = NameToFunction(kStartsWith);
|
||||
} else if (op->ENDS() && op->WITH()) {
|
||||
f = NameToFunction(kEndsWith);
|
||||
} else if (op->CONTAINS()) {
|
||||
f = NameToFunction(kContains);
|
||||
} else {
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
auto expression2 = op->expression3b()->accept(this);
|
||||
std::vector<Expression *> args = {expression, expression2};
|
||||
expression =
|
||||
static_cast<Expression *>(storage_.Create<Function>(f, args));
|
||||
}
|
||||
}
|
||||
return expression;
|
||||
@ -709,8 +725,8 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
|
||||
}
|
||||
// TODO: Implement this. We don't support comprehensions, filtering... at the
|
||||
// moment.
|
||||
// TODO: Implement this. We don't support comprehensions, filtering... at
|
||||
// the moment.
|
||||
throw utils::NotYetImplemented();
|
||||
}
|
||||
|
||||
@ -794,7 +810,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(
|
||||
}
|
||||
}
|
||||
auto function = NameToFunction(function_name);
|
||||
if (!function) throw SemanticException();
|
||||
if (!function) throw SemanticException("Function doesn't exist.");
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Function>(function, expressions));
|
||||
}
|
||||
|
@ -437,6 +437,50 @@ TypedValue Pi(const std::vector<TypedValue> &args, GraphDbAccessor &) {
|
||||
}
|
||||
return M_PI;
|
||||
}
|
||||
|
||||
template <bool (*Predicate)(const std::string &s1, const std::string &s2)>
|
||||
TypedValue StringMatchOperator(const std::vector<TypedValue> &args,
|
||||
GraphDbAccessor &) {
|
||||
if (args.size() != 2U) {
|
||||
throw QueryRuntimeException(
|
||||
"startsWith shouldn't be called with 2 arguments");
|
||||
}
|
||||
bool has_null = false;
|
||||
auto check_arg = [&](const TypedValue &t) {
|
||||
if (t.IsNull()) {
|
||||
has_null = true;
|
||||
} else if (t.type() != TypedValue::Type::String) {
|
||||
throw QueryRuntimeException("startsWith called with incompatible type");
|
||||
}
|
||||
};
|
||||
check_arg(args[0]);
|
||||
check_arg(args[1]);
|
||||
if (has_null) return TypedValue::Null;
|
||||
const auto &s1 = args[0].Value<std::string>();
|
||||
const auto &s2 = args[1].Value<std::string>();
|
||||
return Predicate(s1, s2);
|
||||
}
|
||||
|
||||
// Check if s1 starts with s2.
|
||||
bool StartsWithPredicate(const std::string &s1, const std::string &s2) {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return std::equal(s2.begin(), s2.end(), s1.begin());
|
||||
}
|
||||
auto StartsWith = StringMatchOperator<StartsWithPredicate>;
|
||||
|
||||
// Check if s1 ends with s2.
|
||||
bool EndsWithPredicate(const std::string &s1, const std::string &s2) {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return std::equal(s2.rbegin(), s2.rend(), s1.rbegin());
|
||||
}
|
||||
auto EndsWith = StringMatchOperator<EndsWithPredicate>;
|
||||
|
||||
// Check if s1 contains s2.
|
||||
bool ContainsPredicate(const std::string &s1, const std::string &s2) {
|
||||
if (s1.size() < s2.size()) return false;
|
||||
return s1.find(s2) != std::string::npos;
|
||||
}
|
||||
auto Contains = StringMatchOperator<ContainsPredicate>;
|
||||
}
|
||||
|
||||
std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
|
||||
@ -474,6 +518,9 @@ NameToFunction(const std::string &function_name) {
|
||||
if (function_name == "SIGN") return Sign;
|
||||
if (function_name == "E") return E;
|
||||
if (function_name == "PI") return Pi;
|
||||
if (function_name == kStartsWith) return StartsWith;
|
||||
if (function_name == kEndsWith) return EndsWith;
|
||||
if (function_name == kContains) return Contains;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,12 @@
|
||||
|
||||
namespace query {
|
||||
|
||||
namespace {
|
||||
const char kStartsWith[] = "STARTSWITH";
|
||||
const char kEndsWith[] = "ENDSWITH";
|
||||
const char kContains[] = "CONTAINS";
|
||||
}
|
||||
|
||||
std::function<TypedValue(const std::vector<TypedValue> &, GraphDbAccessor &)>
|
||||
NameToFunction(const std::string &function_name);
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ using testing::ElementsAre;
|
||||
using query::test_common::ToInt64List;
|
||||
|
||||
namespace {
|
||||
|
||||
struct NoContextExpressionEvaluator {
|
||||
NoContextExpressionEvaluator() {}
|
||||
Frame frame{128};
|
||||
@ -973,4 +974,37 @@ TEST(ExpressionEvaluator, FunctionPi) {
|
||||
ASSERT_THROW(EvaluateFunction("PI", {1}), QueryRuntimeException);
|
||||
ASSERT_DOUBLE_EQ(EvaluateFunction("PI", {}).Value<double>(), M_PI);
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionStartsWith) {
|
||||
EXPECT_THROW(EvaluateFunction(kStartsWith, {}), QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"a", TypedValue::Null}).IsNull());
|
||||
EXPECT_THROW(EvaluateFunction(kStartsWith, {TypedValue::Null, 1.3}),
|
||||
QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abc", "abc"}).Value<bool>());
|
||||
EXPECT_TRUE(EvaluateFunction(kStartsWith, {"abcdef", "abc"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abcdef", "aBc"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kStartsWith, {"abc", "abcd"}).Value<bool>());
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionEndsWith) {
|
||||
EXPECT_THROW(EvaluateFunction(kEndsWith, {}), QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"a", TypedValue::Null}).IsNull());
|
||||
EXPECT_THROW(EvaluateFunction(kEndsWith, {TypedValue::Null, 1.3}),
|
||||
QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abc", "abc"}).Value<bool>());
|
||||
EXPECT_TRUE(EvaluateFunction(kEndsWith, {"abcdef", "def"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kEndsWith, {"abcdef", "dEf"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kEndsWith, {"bcd", "abcd"}).Value<bool>());
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionContains) {
|
||||
EXPECT_THROW(EvaluateFunction(kContains, {}), QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kContains, {"a", TypedValue::Null}).IsNull());
|
||||
EXPECT_THROW(EvaluateFunction(kContains, {TypedValue::Null, 1.3}),
|
||||
QueryRuntimeException);
|
||||
EXPECT_TRUE(EvaluateFunction(kContains, {"abc", "abc"}).Value<bool>());
|
||||
EXPECT_TRUE(EvaluateFunction(kContains, {"abcde", "bcd"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value<bool>());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user