From 4b5c0d3426b626d595a46afef6d4451636fc94ae Mon Sep 17 00:00:00 2001 From: Lovro Lugovic Date: Thu, 11 Oct 2018 15:04:49 +0200 Subject: [PATCH] Implement `coalesce` as a special operator Reviewers: teon.banek, mtomic Reviewed By: mtomic Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1652 --- .../interpret/awesome_memgraph_functions.hpp | 1 + src/query/interpret/eval.hpp | 18 +++++++++ tests/unit/query_expression_evaluator.cpp | 37 +++++++++++++++++-- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/query/interpret/awesome_memgraph_functions.hpp b/src/query/interpret/awesome_memgraph_functions.hpp index dd4602d6f..f0ee45d1d 100644 --- a/src/query/interpret/awesome_memgraph_functions.hpp +++ b/src/query/interpret/awesome_memgraph_functions.hpp @@ -13,6 +13,7 @@ namespace { const char kStartsWith[] = "STARTSWITH"; const char kEndsWith[] = "ENDSWITH"; const char kContains[] = "CONTAINS"; +const char kCoalesce[] = "COALESCE"; } // namespace /// Return the function implementation with the given name. diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 1ec7a5e14..2bd8712cf 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -371,6 +371,24 @@ class ExpressionEvaluator : public TreeVisitor { } TypedValue Visit(Function &function) override { + // Handle COALESCE specially -- evaluate the arguments in order until one of + // them produces a non-null value. + if (function.function_name_ == kCoalesce) { + if (function.arguments_.size() == 0) { + throw QueryRuntimeException( + "'coalesce' requires at least one argument."); + } + + for (int64_t i = 0; i < function.arguments_.size(); ++i) { + TypedValue val = function.arguments_[i]->Accept(*this); + if (val.type() != TypedValue::Type::Null) { + return val; + } + } + + return TypedValue::Null; + } + // Stack allocate evaluated arguments when there's a small number of them. if (function.arguments_.size() <= 8) { TypedValue arguments[8]; diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 71a7e52f0..d05d86246 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -860,20 +860,35 @@ TEST_F(ExpressionEvaluatorPropertyLookup, MapLiteral) { class FunctionTest : public ExpressionEvaluatorTest { protected: - TypedValue EvaluateFunction(const std::string &function_name, - const std::vector &args) { + std::vector ExpressionsFromTypedValues( + const std::vector &tvs) { std::vector expressions; - for (size_t i = 0; i < args.size(); ++i) { + expressions.reserve(tvs.size()); + + for (size_t i = 0; i < tvs.size(); ++i) { auto *ident = storage.Create("arg_" + std::to_string(i), true); auto sym = symbol_table.CreateSymbol("arg_" + std::to_string(i), true); symbol_table[*ident] = sym; - frame[sym] = args[i]; + frame[sym] = tvs[i]; expressions.push_back(ident); } + + return expressions; + } + + TypedValue EvaluateFunctionWithExprs( + const std::string &function_name, + const std::vector &expressions) { auto *op = storage.Create(function_name, expressions); return op->Accept(eval); } + + TypedValue EvaluateFunction(const std::string &function_name, + const std::vector &args) { + return EvaluateFunctionWithExprs(function_name, + ExpressionsFromTypedValues(args)); + } }; TEST_F(FunctionTest, Coalesce) { @@ -882,6 +897,20 @@ TEST_F(FunctionTest, Coalesce) { .IsNull()); ASSERT_EQ(EvaluateFunction("COALESCE", {TypedValue::Null, 2, 3}).ValueInt(), 2); + + // (null, 2, assert(false), 3) + auto expressions1 = ExpressionsFromTypedValues({TypedValue::Null, 2, 3}); + expressions1.insert( + expressions1.begin() + 2, + storage.Create("ASSERT", ExpressionsFromTypedValues({false}))); + ASSERT_EQ(EvaluateFunctionWithExprs("COALESCE", expressions1).ValueInt(), 2); + + // (null, assert(false)) + auto expressions2 = ExpressionsFromTypedValues({TypedValue::Null}); + expressions2.push_back( + storage.Create("ASSERT", ExpressionsFromTypedValues({false}))); + ASSERT_THROW(EvaluateFunctionWithExprs("COALESCE", expressions2), + QueryRuntimeException); } TEST_F(FunctionTest, EndNode) {