diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 238bc1d7b..70f8678cb 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -1028,6 +1028,47 @@ cpp<# (:serialize (:slk)) (:clone)) +;; TODO: This is pretty much copy pasted from All. Consider merging Reduce, +;; All, Any and Single into something like a higher-order function call which +;; takes a list argument and a function which is applied on list elements. +(lcp:define-class any (expression) + ((identifier "Identifier *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Identifier")) + (list-expression "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (where "Where *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Where"))) + (:public + #>cpp + Any() = default; + + DEFVISITABLE(ExpressionVisitor); + DEFVISITABLE(ExpressionVisitor); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + cpp<#) + (:protected + #>cpp + Any(Identifier *identifier, Expression *list_expression, Where *where) + : identifier_(identifier), + list_expression_(list_expression), + where_(where) {} + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:define-class parameter-lookup (expression) ((token-position :int32_t :initval -1 :scope :public :documentation "This field contains token position of *literal* used to create ParameterLookup object. If ParameterLookup object is not created from a literal leave this value at -1.")) diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index eba58818e..ffc6229fd 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -19,6 +19,7 @@ class Coalesce; class Extract; class All; class Single; +class Any; class ParameterLookup; class CallProcedure; class Create; @@ -79,7 +80,7 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< GreaterEqualOperator, InListOperator, SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, - Aggregation, Function, Reduce, Coalesce, Extract, All, Single, + Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch>; @@ -107,8 +108,8 @@ class ExpressionVisitor SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, - Extract, All, Single, ParameterLookup, Identifier, PrimitiveLiteral, - RegexMatch> {}; + Extract, All, Single, Any, ParameterLookup, Identifier, + PrimitiveLiteral, RegexMatch> {}; template class QueryVisitor diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 023ac2448..9e08779b7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1377,6 +1377,20 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) { Where *where = ctx->filterExpression()->where()->accept(this); return static_cast( storage_->Create(ident, list_expr, where)); + } else if (ctx->ANY()) { + auto *ident = storage_->Create(ctx->filterExpression() + ->idInColl() + ->variable() + ->accept(this) + .as()); + Expression *list_expr = + ctx->filterExpression()->idInColl()->expression()->accept(this); + if (!ctx->filterExpression()->where()) { + throw SyntaxException("ANY(...) requires a WHERE predicate."); + } + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast( + storage_->Create(ident, list_expr, where)); } else if (ctx->REDUCE()) { auto *accumulator = storage_->Create( ctx->reduceExpression()->accumulator->accept(this).as()); diff --git a/src/query/frontend/ast/pretty_print.cpp b/src/query/frontend/ast/pretty_print.cpp index f9cb4998e..26320ee18 100644 --- a/src/query/frontend/ast/pretty_print.cpp +++ b/src/query/frontend/ast/pretty_print.cpp @@ -51,6 +51,7 @@ class ExpressionPrettyPrinter : public ExpressionVisitor { void Visit(Extract &op) override; void Visit(All &op) override; void Visit(Single &op) override; + void Visit(Any &op) override; void Visit(Identifier &op) override; void Visit(PrimitiveLiteral &op) override; void Visit(PropertyLookup &op) override; @@ -286,6 +287,11 @@ void ExpressionPrettyPrinter::Visit(Single &op) { op.where_->expression_); } +void ExpressionPrettyPrinter::Visit(Any &op) { + PrintOperator(out_, "Any", op.identifier_, op.list_expression_, + op.where_->expression_); +} + void ExpressionPrettyPrinter::Visit(Identifier &op) { PrintOperator(out_, "Identifier", op.name_); } diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 10a082b5d..b88c36b20 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -354,6 +354,12 @@ bool SymbolGenerator::PreVisit(Single &single) { return false; } +bool SymbolGenerator::PreVisit(Any &any) { + any.list_expression_->Accept(*this); + VisitWithIdentifiers(any.where_->expression_, {any.identifier_}); + return false; +} + bool SymbolGenerator::PreVisit(Reduce &reduce) { reduce.initializer_->Accept(*this); reduce.list_->Accept(*this); diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index c9718c749..103febb11 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -58,6 +58,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool PostVisit(IfOperator &) override; bool PreVisit(All &) override; bool PreVisit(Single &) override; + bool PreVisit(Any &) override; bool PreVisit(Reduce &) override; bool PreVisit(Extract &) override; diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index f62a15863..063ad95c4 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -523,6 +523,32 @@ class ExpressionEvaluator : public ExpressionVisitor { return TypedValue(predicate_satisfied, ctx_->memory); } + TypedValue Visit(Any &any) override { + auto list_value = any.list_expression_->Accept(*this); + if (list_value.IsNull()) { + return TypedValue(ctx_->memory); + } + if (list_value.type() != TypedValue::Type::List) { + throw QueryRuntimeException("ANY expected a list, got {}.", + list_value.type()); + } + const auto &list = list_value.ValueList(); + const auto &symbol = symbol_table_->at(*any.identifier_); + for (const auto &element : list) { + frame_->at(symbol) = element; + auto result = any.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException( + "Predicate of ANY must evaluate to boolean, got {}.", + result.type()); + } + if (result.IsNull() || result.ValueBool()) { + return result; + } + } + return TypedValue(false, ctx_->memory); + } + TypedValue Visit(ParameterLookup ¶m_lookup) override { return TypedValue( ctx_->parameters.AtTokenPosition(param_lookup.token_position_), diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 66a80d798..629305099 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -37,6 +37,13 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { return true; } + bool PostVisit(Any &any) override { + // Remove the symbol which is bound by all, because we are only interested + // in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*any.identifier_)); + return true; + } + bool PostVisit(Reduce &reduce) override { // Remove the symbols bound by reduce, because we are only interested // in free (unbound) symbols. diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 303010d6e..3d90a5fee 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -165,6 +165,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool PostVisit(Any &any) override { + // Remove the symbol which is bound by any, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*any.identifier_)); + CHECK(has_aggregation_.size() >= 3U) + << "Expected 3 has_aggregation_ flags for ANY arguments"; + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + bool PostVisit(Reduce &reduce) override { // Remove the symbols bound by reduce, because we are only interested // in free (unbound) symbols. diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index b82ba2f45..bd63a14c5 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -639,6 +639,9 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, #define SINGLE(variable, list, where) \ storage.Create(storage.Create(variable), \ list, where) +#define ANY(variable, list, where) \ + storage.Create(storage.Create(variable), \ + list, where) #define REDUCE(accumulator, initializer, variable, list, expr) \ storage.Create( \ storage.Create(accumulator), initializer, \ diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 76cd3fb00..8db7b5bee 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -769,6 +769,49 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingleNullList) { EXPECT_TRUE(value.IsNull()); } +TEST_F(ExpressionEvaluatorTest, FunctionAny) { + AstStorage storage; + auto *ident_x = IDENT("x"); + auto *any = + ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1)))); + const auto x_sym = symbol_table.CreateSymbol("x", true); + any->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); + auto value = Eval(any); + ASSERT_TRUE(value.IsBool()); + EXPECT_TRUE(value.ValueBool()); +} + +TEST_F(ExpressionEvaluatorTest, FunctionAny2) { + AstStorage storage; + auto *ident_x = IDENT("x"); + auto *any = + ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(0)))); + const auto x_sym = symbol_table.CreateSymbol("x", true); + any->identifier_->MapTo(x_sym); + ident_x->MapTo(x_sym); + auto value = Eval(any); + ASSERT_TRUE(value.IsBool()); + EXPECT_FALSE(value.ValueBool()); +} + +TEST_F(ExpressionEvaluatorTest, FunctionAnyNullList) { + AstStorage storage; + auto *any = ANY("x", LITERAL(storage::PropertyValue()), WHERE(LITERAL(true))); + const auto x_sym = symbol_table.CreateSymbol("x", true); + any->identifier_->MapTo(x_sym); + auto value = Eval(any); + EXPECT_TRUE(value.IsNull()); +} + +TEST_F(ExpressionEvaluatorTest, FunctionAnyWhereWrongType) { + AstStorage storage; + auto *any = ANY("x", LIST(LITERAL(1)), WHERE(LITERAL(2))); + const auto x_sym = symbol_table.CreateSymbol("x", true); + any->identifier_->MapTo(x_sym); + EXPECT_THROW(Eval(any), QueryRuntimeException); +} + TEST_F(ExpressionEvaluatorTest, FunctionReduce) { AstStorage storage; auto *ident_sum = IDENT("sum");