diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index ac2d33747..1332fd741 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -882,6 +882,46 @@ cpp<# cpp<#) (:serialize :capnp)) +(lcp:define-class coalesce (expression) + ((expressions "std::vector" + :scope :public + :capnp-type "List(Tree)" + :capnp-save (save-ast-vector "Expression *") + :capnp-load (load-ast-vector "Expression *") + :documentation "A list of expressions to evaluate. None of the expressions should be nullptr.")) + (:public + #>cpp + Coalesce() = default; + + DEFVISITABLE(TreeVisitor); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + for (auto *expr : expressions_) { + if (!expr->Accept(visitor)) break; + } + } + return visitor.PostVisit(*this); + } + + Coalesce *Clone(AstStorage &storage) const override { + std::vector expressions; + expressions.reserve(expressions_.size()); + for (const auto &expr : expressions_) { + expressions.emplace_back(expr->Clone(storage)); + } + return storage.Create(std::move(expressions)); + } + cpp<# + ) + (:private + #>cpp + Coalesce(int uid, const std::vector &expressions) + : Expression(uid), expressions_(expressions) {} + + friend class AstStorage; + cpp<#) + (:serialize :capnp)) + (lcp:define-class extract (expression) ((identifier "Identifier *" :initval "nullptr" :scope :public :capnp-type "Tree" :capnp-init nil diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 73d394f23..7417430f8 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -15,6 +15,7 @@ class LabelsTest; class Aggregation; class Function; class Reduce; +class Coalesce; class Extract; class All; class Single; @@ -73,9 +74,10 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator, ListSlicingOperator, IfOperator, 
UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, - Aggregation, Function, Reduce, Extract, All, Single, Create, Match, Return, - With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, - SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>; + Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Create, + Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, + SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, + Unwind>; using TreeLeafVisitor = ::utils::LeafVisitor; + Coalesce, Extract, All, Single, ParameterLookup, Create, Match, Return, + With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, + SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, + Identifier, PrimitiveLiteral, IndexQuery, AuthQuery, StreamQuery>; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 1ff5fd33d..c25911afc 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1318,6 +1318,13 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) { return static_cast(storage_->Create(variable)); } else if (ctx->functionInvocation()) { return static_cast(ctx->functionInvocation()->accept(this)); + } else if (ctx->COALESCE()) { + std::vector exprs; + for (auto *expr_context : ctx->expression()) { + exprs.emplace_back(expr_context->accept(this).as()); + } + return static_cast( + storage_->Create(std::move(exprs))); } else if (ctx->COUNT()) { // Here we handle COUNT(*). COUNT(expression) is handled in // visitFunctionInvocation with other aggregations. 
This is visible in diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index a2f0612b2..75ae99f09 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -194,6 +194,7 @@ atom : literal | ( FILTER '(' filterExpression ')' ) | ( EXTRACT '(' extractExpression ')' ) | ( REDUCE '(' reduceExpression ')' ) + | ( COALESCE '(' expression ( ',' expression )* ')' ) | ( ALL '(' filterExpression ')' ) | ( ANY '(' filterExpression ')' ) | ( NONE '(' filterExpression ')' ) diff --git a/src/query/frontend/opencypher/grammar/CypherLexer.g4 b/src/query/frontend/opencypher/grammar/CypherLexer.g4 index 985b3ba6b..b68b7eab3 100644 --- a/src/query/frontend/opencypher/grammar/CypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/CypherLexer.g4 @@ -77,6 +77,7 @@ ASCENDING : A S C E N D I N G ; BFS : B F S ; BY : B Y ; CASE : C A S E ; +COALESCE : C O A L E S C E ; CONTAINS : C O N T A I N S ; COUNT : C O U N T ; CREATE : C R E A T E ; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 7078717e5..2da9b7d72 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -79,18 +79,18 @@ class Trie { const int kBitsetSize = 65536; const trie::Trie kKeywords = { - "union", "all", "optional", "match", "unwind", "as", - "merge", "on", "create", "set", "detach", "delete", - "remove", "with", "distinct", "return", "order", "by", - "skip", "limit", "ascending", "asc", "descending", "desc", - "where", "or", "xor", "and", "not", "in", - "starts", "ends", "contains", "is", "null", "case", - "when", "then", "else", "end", "count", "filter", - "extract", "any", "none", "single", "true", "false", - "reduce", "user", "password", "alter", "drop", "stream", - "streams", "load", "data", "kafka", "transform", "batch", - "interval", "show", "start", "stop", "size", 
"topic", - "test", "unique", "explain"}; + "union", "all", "optional", "match", "unwind", "as", + "merge", "on", "create", "set", "detach", "delete", + "remove", "with", "distinct", "return", "order", "by", + "skip", "limit", "ascending", "asc", "descending", "desc", + "where", "or", "xor", "and", "not", "in", + "starts", "ends", "contains", "is", "null", "case", + "when", "then", "else", "end", "count", "filter", + "extract", "any", "none", "single", "true", "false", + "reduce", "coalesce", "user", "password", "alter", "drop", + "stream", "streams", "load", "data", "kafka", "transform", + "batch", "interval", "show", "start", "stop", "size", + "topic", "test", "unique", "explain"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset kUnescapedNameAllowedStarts(std::string( diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 2cd292b1d..a57802a29 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -32,21 +32,6 @@ namespace { // TODO: Implement degrees, haversin, radians // TODO: Implement spatial functions -TypedValue Coalesce(TypedValue *args, int64_t nargs, const EvaluationContext &, - database::GraphDbAccessor *) { - // TODO: Perhaps this function should be done by the evaluator itself, so as - // to avoid evaluating all the arguments. 
- if (nargs == 0) { - throw QueryRuntimeException("'coalesce' requires at least one argument."); - } - for (int64_t i = 0; i < nargs; ++i) { - if (args[i].type() != TypedValue::Type::Null) { - return args[i]; - } - } - return TypedValue::Null; -} - TypedValue EndNode(TypedValue *args, int64_t nargs, const EvaluationContext &, database::GraphDbAccessor *) { if (nargs != 1) { @@ -887,7 +872,6 @@ std::function NameToFunction(const std::string &function_name) { // Scalar functions - if (function_name == "COALESCE") return Coalesce; if (function_name == "DEGREE") return Degree; if (function_name == "INDEGREE") return InDegree; if (function_name == "OUTDEGREE") return OutDegree; diff --git a/src/query/interpret/awesome_memgraph_functions.hpp b/src/query/interpret/awesome_memgraph_functions.hpp index f0ee45d1d..dd4602d6f 100644 --- a/src/query/interpret/awesome_memgraph_functions.hpp +++ b/src/query/interpret/awesome_memgraph_functions.hpp @@ -13,7 +13,6 @@ namespace { const char kStartsWith[] = "STARTSWITH"; const char kEndsWith[] = "ENDSWITH"; const char kContains[] = "CONTAINS"; -const char kCoalesce[] = "COALESCE"; } // namespace /// Return the function implementation with the given name. diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 512c0deac..bb1091db9 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -367,25 +367,24 @@ class ExpressionEvaluator : public TreeVisitor { return value; } - TypedValue Visit(Function &function) override { - // Handle COALESCE specially -- evaluate the arguments in order until one of - // them produces a non-null value. 
- if (function.function_name_ == kCoalesce) { - if (function.arguments_.size() == 0) { - throw QueryRuntimeException( - "'coalesce' requires at least one argument."); - } + TypedValue Visit(Coalesce &coalesce) override { + auto &exprs = coalesce.expressions_; - for (int64_t i = 0; i < function.arguments_.size(); ++i) { - TypedValue val = function.arguments_[i]->Accept(*this); - if (val.type() != TypedValue::Type::Null) { - return val; - } - } - - return TypedValue::Null; + if (exprs.empty()) { + throw QueryRuntimeException("'coalesce' requires at least one argument."); } + for (auto *expr : exprs) { + TypedValue val = expr->Accept(*this); + if (!val.IsNull()) { + return val; + } + } + + return TypedValue::Null; + } + + TypedValue Visit(Function &function) override { // Stack allocate evaluated arguments when there's a small number of them. if (function.arguments_.size() <= 8) { TypedValue arguments[8]; diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index b06692360..1fba2122a 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -182,7 +182,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { public: bool PostVisit(ListLiteral &list_literal) override { DCHECK(list_literal.elements_.size() <= has_aggregation_.size()) - << "Expected has_aggregation_ flags as much as there are list " + << "Expected as many has_aggregation_ flags as there are list " "elements."; PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; }); return true; @@ -241,6 +241,19 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool PostVisit(Coalesce &coalesce) override { + CHECK(has_aggregation_.size() >= coalesce.expressions_.size()) + << "Expected >= " << coalesce.expressions_.size() + << " has_aggregation_ flags for COALESCE arguments"; + bool has_aggr = false; + for (size_t i = 0; i < coalesce.expressions_.size(); ++i) { + has_aggr = 
has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + bool PostVisit(Extract &extract) override { // Remove the symbol bound by extract, because we are only interested // in free (unbound) symbols. @@ -310,7 +323,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { bool PostVisit(Function &function) override { DCHECK(function.arguments_.size() <= has_aggregation_.size()) - << "Expected has_aggregation_ flags as much as there are " + << "Expected as many has_aggregation_ flags as there are " "function arguments."; bool has_aggr = false; auto it = has_aggregation_.end(); diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index b826656e8..263dd4006 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -602,6 +602,8 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, storage.Create( \ storage.Create(accumulator), initializer, \ storage.Create(variable), list, expr) +#define COALESCE(...) 
\ + storage.Create(std::vector{__VA_ARGS__}) #define EXTRACT(variable, list, expr) \ storage.Create(storage.Create(variable), \ list, expr) diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 259bbb2ef..80035a5ae 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -812,6 +812,40 @@ TEST_F(ExpressionEvaluatorTest, FunctionExtractExceptions) { EXPECT_THROW(extract->Accept(eval), QueryRuntimeException); } +TEST_F(ExpressionEvaluatorTest, Coalesce) { + // coalesce() + EXPECT_THROW(COALESCE()->Accept(eval), QueryRuntimeException); + + // coalesce(null, null) + EXPECT_TRUE(COALESCE(LITERAL(TypedValue::Null), LITERAL(TypedValue::Null)) + ->Accept(eval) + .IsNull()); + + // coalesce(null, 2, 3) + EXPECT_EQ(COALESCE(LITERAL(TypedValue::Null), LITERAL(2), LITERAL(3)) + ->Accept(eval) + .ValueInt(), + 2); + + // coalesce(null, 2, assert(false), 3) + EXPECT_EQ(COALESCE(LITERAL(TypedValue::Null), LITERAL(2), + FN("ASSERT", LITERAL(false)), LITERAL(3)) + ->Accept(eval) + .ValueInt(), + 2); + + // coalesce(null, assert(false)) + EXPECT_THROW(COALESCE(LITERAL(TypedValue::Null), FN("ASSERT", LITERAL(false))) + ->Accept(eval), + QueryRuntimeException); + + // coalesce([null, null]) + EXPECT_FALSE(COALESCE(LITERAL(TypedValue(std::vector{ + TypedValue::Null, TypedValue::Null}))) + ->Accept(eval) + .IsNull()); +} + class ExpressionEvaluatorPropertyLookup : public ExpressionEvaluatorTest { protected: std::pair prop_age = @@ -891,28 +925,6 @@ class FunctionTest : public ExpressionEvaluatorTest { } }; -TEST_F(FunctionTest, Coalesce) { - ASSERT_THROW(EvaluateFunction("COALESCE", {}), QueryRuntimeException); - ASSERT_TRUE(EvaluateFunction("COALESCE", {TypedValue::Null, TypedValue::Null}) - .IsNull()); - ASSERT_EQ(EvaluateFunction("COALESCE", {TypedValue::Null, 2, 3}).ValueInt(), - 2); - - // (null, 2, assert(false), 3) - auto expressions1 = ExpressionsFromTypedValues({TypedValue::Null, 2, 3}); 
- expressions1.insert( - expressions1.begin() + 2, - storage.Create("ASSERT", ExpressionsFromTypedValues({false}))); - ASSERT_EQ(EvaluateFunctionWithExprs("COALESCE", expressions1).ValueInt(), 2); - - // (null, assert(false)) - auto expressions2 = ExpressionsFromTypedValues({TypedValue::Null}); - expressions2.push_back( - storage.Create("ASSERT", ExpressionsFromTypedValues({false}))); - ASSERT_THROW(EvaluateFunctionWithExprs("COALESCE", expressions2), - QueryRuntimeException); -} - TEST_F(FunctionTest, EndNode) { ASSERT_THROW(EvaluateFunction("ENDNODE", {}), QueryRuntimeException); ASSERT_TRUE(EvaluateFunction("ENDNODE", {TypedValue::Null}).IsNull());