diff --git a/src/glue/communication.cpp b/src/glue/communication.cpp index fdf5129f6..60181e877 100644 --- a/src/glue/communication.cpp +++ b/src/glue/communication.cpp @@ -127,6 +127,8 @@ storage::Result ToBoltValue(const query::TypedValue &value, const storage return Value(value.ValueLocalDateTime()); case query::TypedValue::Type::Duration: return Value(value.ValueDuration()); + case query::TypedValue::Type::Function: + throw communication::bolt::ValueException("Unsupported conversion from TypedValue::Function to Value"); case query::TypedValue::Type::Graph: auto maybe_graph = ToBoltGraph(value.ValueGraph(), db, view); if (maybe_graph.HasError()) return maybe_graph.GetError(); diff --git a/src/query/common.cpp b/src/query/common.cpp index 793ae8044..3c75ed5ec 100644 --- a/src/query/common.cpp +++ b/src/query/common.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -62,6 +62,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) { case TypedValue::Type::Edge: case TypedValue::Type::Path: case TypedValue::Type::Graph: + case TypedValue::Type::Function: throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type()); case TypedValue::Type::Null: LOG_FATAL("Invalid type"); diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 2cfd11f8c..6f49ee99f 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -593,6 +593,7 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContex case TypedValue::Type::Duration: return TypedValue("DURATION", ctx.memory); case TypedValue::Type::Graph: + case TypedValue::Type::Function: throw 
QueryRuntimeException("Cannot fetch graph as it is not standardized openCypher type name"); } } diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 916082bb2..e09ddcc97 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -903,7 +904,17 @@ class ExpressionEvaluator : public ExpressionVisitor { return TypedValue(std::move(result), ctx_->memory); } - TypedValue Visit(Exists &exists) override { return TypedValue{frame_->at(symbol_table_->at(exists)), ctx_->memory}; } + TypedValue Visit(Exists &exists) override { + TypedValue &frame_exists_value = frame_->at(symbol_table_->at(exists)); + if (!frame_exists_value.IsFunction()) [[unlikely]] { + throw QueryRuntimeException( + "Unexpected behavior: Exists expected a function, got {}. Please report the problem on GitHub issues", + frame_exists_value.type()); + } + TypedValue result{ctx_->memory}; + frame_exists_value.ValueFunction()(&result); + return result; + } TypedValue Visit(All &all) override { auto list_value = all.list_expression_->Accept(*this); diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 1c8d021c7..63bf5cd40 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -2500,13 +2500,16 @@ std::vector EvaluatePatternFilter::ModifiedSymbols(const SymbolTable &ta } bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) { - OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("EvaluatePatternFilter"); + std::function function = [&frame, input_cursor = this->input_cursor_.get(), + &context](TypedValue *return_value) { + OOMExceptionEnabler oom_exception; + input_cursor->Reset(); - input_cursor_->Reset(); - - frame[self_.output_symbol_] = TypedValue(input_cursor_->Pull(frame, 
context), context.evaluation_context.memory); + }; + frame[self_.output_symbol_] = TypedValue(std::move(function)); return true; } diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index cd223dd8e..f3d0c1487 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -17,6 +17,7 @@ #include #include +#include "query/plan/preprocess.hpp" #include "utils/algorithm.hpp" #include "utils/exceptions.hpp" #include "utils/logging.hpp" @@ -516,14 +517,25 @@ bool HasBoundFilterSymbols(const std::unordered_set &bound_symbols, cons Expression *ExtractFilters(const std::unordered_set &bound_symbols, Filters &filters, AstStorage &storage) { Expression *filter_expr = nullptr; + std::vector and_joinable_filters{}; for (auto filters_it = filters.begin(); filters_it != filters.end();) { if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { - filter_expr = impl::BoolJoin(storage, filter_expr, filters_it->expression); + and_joinable_filters.emplace_back(*filters_it); filters_it = filters.erase(filters_it); } else { filters_it++; } } + // Idea here is to join filters in a way + // that pattern filter ( exists() ) is at the end + // so if any of the AND filters before + // evaluate to false we don't need to + // evaluate pattern ( exists() ) filter; the partition is stable so the remaining filters keep their order + std::stable_partition(and_joinable_filters.begin(), and_joinable_filters.end(), + [](const FilterInfo &filter_info) { return filter_info.type != FilterInfo::Type::Pattern; }); + for (auto &and_joinable_filter : and_joinable_filters) { + filter_expr = impl::BoolJoin(storage, filter_expr, and_joinable_filter.expression); + } return filter_expr; } diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index f87377ba5..ab2b3ae4b 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -313,6 +313,8 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) { 
return MGP_VALUE_TYPE_LOCAL_DATE_TIME; case memgraph::query::TypedValue::Type::Duration: return MGP_VALUE_TYPE_DURATION; + case memgraph::query::TypedValue::Type::Function: + throw std::logic_error{"mgp_value for TypedValue::Type::Function doesn't exist."}; case memgraph::query::TypedValue::Type::Graph: throw std::logic_error{"mgp_value for TypedValue::Type::Graph doesn't exist."}; } @@ -3672,7 +3674,8 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) { case TypedValue::Type::Edge: case TypedValue::Type::Path: case TypedValue::Type::Graph: - LOG_FATAL("value must not be a graph element"); + case TypedValue::Type::Function: + LOG_FATAL("value must not be a graph|function element"); } } diff --git a/src/query/typed_value.cpp b/src/query/typed_value.cpp index 13db88e1c..91893d71c 100644 --- a/src/query/typed_value.cpp +++ b/src/query/typed_value.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,6 +22,7 @@ #include "storage/v2/temporal.hpp" #include "utils/exceptions.hpp" #include "utils/fnv.hpp" +#include "utils/logging.hpp" #include "utils/memory.hpp" namespace memgraph::query { @@ -215,6 +216,9 @@ TypedValue::TypedValue(const TypedValue &other, utils::MemoryResource *memory) : case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); return; + case Type::Function: + new (&function_v) std::function(other.function_v); + return; case Type::Graph: auto *graph_ptr = utils::Allocator(memory_).new_object(*other.graph_v); new (&graph_v) std::unique_ptr(graph_ptr); @@ -268,6 +272,9 @@ TypedValue::TypedValue(TypedValue &&other, utils::MemoryResource *memory) : memo case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); break; + case Type::Function: + new (&function_v) 
std::function(std::move(other.function_v)); + break; case Type::Graph: if (other.GetMemoryResource() == memory_) { new (&graph_v) std::unique_ptr(std::move(other.graph_v)); @@ -343,6 +350,7 @@ DEFINE_VALUE_AND_TYPE_GETTERS(utils::Date, Date, date_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalTime, LocalTime, local_time_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime, local_date_time_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration, duration_v) +DEFINE_VALUE_AND_TYPE_GETTERS(std::function, Function, function_v) Graph &TypedValue::ValueGraph() { if (type_ != Type::Graph) { @@ -417,6 +425,8 @@ std::ostream &operator<<(std::ostream &os, const TypedValue::Type &type) { return os << "duration"; case TypedValue::Type::Graph: return os << "graph"; + case TypedValue::Type::Function: + return os << "function"; } LOG_FATAL("Unsupported TypedValue::Type"); } @@ -569,6 +579,9 @@ TypedValue &TypedValue::operator=(const TypedValue &other) { case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); return *this; + case Type::Function: + new (&function_v) std::function(other.function_v); + return *this; } LOG_FATAL("Unsupported TypedValue::Type"); } @@ -628,6 +641,9 @@ TypedValue &TypedValue::operator=(TypedValue &&other) noexcept(false) { case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); break; + case Type::Function: + new (&function_v) std::function{std::move(other.function_v)}; + break; case Type::Graph: if (other.GetMemoryResource() == memory_) { new (&graph_v) std::unique_ptr(std::move(other.graph_v)); @@ -676,6 +692,9 @@ void TypedValue::DestroyValue() { case Type::LocalDateTime: case Type::Duration: break; + case Type::Function: + std::destroy_at(&function_v); + break; case Type::Graph: { auto *graph = graph_v.release(); std::destroy_at(&graph_v); @@ -1153,6 +1172,8 @@ size_t TypedValue::Hash::operator()(const TypedValue &value) const { case TypedValue::Type::Duration: return utils::DurationHash{}(value.ValueDuration()); break; 
+ case TypedValue::Type::Function: + throw TypedValueException("Unsupported hash function for Function"); case TypedValue::Type::Graph: throw TypedValueException("Unsupported hash function for Graph"); } diff --git a/src/query/typed_value.hpp b/src/query/typed_value.hpp index c215e2276..0af38cecc 100644 --- a/src/query/typed_value.hpp +++ b/src/query/typed_value.hpp @@ -84,7 +84,8 @@ class TypedValue { LocalTime, LocalDateTime, Duration, - Graph + Graph, + Function }; // TypedValue at this exact moment of compilation is an incomplete type, and @@ -420,6 +421,9 @@ class TypedValue { new (&graph_v) std::unique_ptr(graph_ptr); } + explicit TypedValue(std::function &&other) + : function_v(std::move(other)), type_(Type::Function) {} + /** * Construct with the value of other. * Default utils::NewDeleteResource() is used for allocations. After the move, @@ -451,6 +455,7 @@ class TypedValue { TypedValue &operator=(const utils::LocalTime &); TypedValue &operator=(const utils::LocalDateTime &); TypedValue &operator=(const utils::Duration &); + TypedValue &operator=(const std::function &); /** Copy assign other, utils::MemoryResource of `this` is used */ TypedValue &operator=(const TypedValue &other); @@ -506,6 +511,7 @@ class TypedValue { DECLARE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime) DECLARE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration) DECLARE_VALUE_AND_TYPE_GETTERS(Graph, Graph) + DECLARE_VALUE_AND_TYPE_GETTERS(std::function, Function) #undef DECLARE_VALUE_AND_TYPE_GETTERS @@ -550,6 +556,7 @@ class TypedValue { utils::Duration duration_v; // As the unique_ptr is not allocator aware, it requires special attention when copying or moving graphs std::unique_ptr graph_v; + std::function function_v; }; /** diff --git a/tests/unit/formatters.hpp b/tests/unit/formatters.hpp index a5ee49166..5217fd65c 100644 --- a/tests/unit/formatters.hpp +++ b/tests/unit/formatters.hpp @@ -138,6 +138,8 @@ inline std::string ToString(const memgraph::query::TypedValue 
&value, const TAcc break; case memgraph::query::TypedValue::Type::Graph: throw std::logic_error{"Not implemented"}; + case memgraph::query::TypedValue::Type::Function: + throw std::logic_error{"Not implemented"}; } return os.str(); } diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 44d3ed301..c9786fe5e 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -83,6 +84,14 @@ class ExpressionEvaluatorTest : public ::testing::Test { return id; } + Exists *CreateExistsWithValue(std::string name, TypedValue &&value) { + auto id = storage.template Create(); + auto symbol = symbol_table.CreateSymbol(name, true); + id->MapTo(symbol); + frame[symbol] = std::move(value); + return id; + } + template auto Eval(TExpression *expr) { ctx.properties = NamesToProperties(storage.properties_, &dba); @@ -149,6 +158,33 @@ TYPED_TEST(ExpressionEvaluatorTest, AndOperatorShortCircuit) { } } +TYPED_TEST(ExpressionEvaluatorTest, AndExistsOperatorShortCircuit) { + { + std::function my_func = [](TypedValue * /*return_value*/) { + throw QueryRuntimeException("This should not be evaluated"); + }; + TypedValue func_should_not_evaluate{std::move(my_func)}; + + auto *op = this->storage.template Create( + this->storage.template Create(false), + this->CreateExistsWithValue("anon1", std::move(func_should_not_evaluate))); + auto value = this->Eval(op); + EXPECT_EQ(value.ValueBool(), false); + } + { + std::function my_func = [memory = this->ctx.memory](TypedValue *return_value) { + *return_value = TypedValue(false, memory); + }; + TypedValue should_evaluate{std::move(my_func)}; + + auto *op = + this->storage.template Create(this->storage.template Create(true), + this->CreateExistsWithValue("anon1", std::move(should_evaluate))); + auto value = this->Eval(op); + EXPECT_EQ(value.ValueBool(), false); + } +} + 
TYPED_TEST(ExpressionEvaluatorTest, AndOperatorNull) { { // Null doesn't short circuit diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 910ebdc54..bc4b2660c 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -853,6 +853,26 @@ TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { } } +TYPED_TEST(TestPlanner, MatchFilterWhere) { + // Test MATCH (n)-[r]-(m) WHERE exists((n)-[]-()) and n!=n and 7!=8 RETURN n + auto *query = QUERY(SINGLE_QUERY( + MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + WHERE(AND(EXISTS(PATTERN(NODE("n"), EDGE("edge2", memgraph::query::EdgeAtom::Direction::BOTH, {}, false), + NODE("node3", std::nullopt, false))), + AND(NEQ(IDENT("n"), IDENT("n")), NEQ(LITERAL(7), LITERAL(8))))), + RETURN("n"))); + + std::list pattern_filter{new ExpectScanAll(), new ExpectExpand(), new ExpectLimit(), + new ExpectEvaluatePatternFilter()}; + CheckPlan( + query, this->storage, + ExpectFilter(), // 7!=8 + ExpectScanAll(), + ExpectFilter(std::vector>{pattern_filter}), // filter pulls from expand + ExpectExpand(), ExpectProduce()); + DeleteListContent(&pattern_filter); +} + TYPED_TEST(TestPlanner, MultiMatchWhere) { // Test MATCH (n) -[r]- (m) MATCH (l) WHERE n.prop < 42 RETURN n FakeDbAccessor dba; diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 6f2f23df7..92089eb82 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -14,11 +14,13 @@ #include #include +#include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::query::plan { @@ -197,6 +199,29 @@ class ExpectFilter : public OpChecker { filter.pattern_filters_[i]->Accept(check_updates); } + // ordering in AND Operator must be ..., exists, exists, exists. 
+ auto *expr = filter.expression_; + std::vector filter_expressions; + while (auto *and_operator = utils::Downcast(expr)) { + auto *expr1 = and_operator->expression1_; + auto *expr2 = and_operator->expression2_; + filter_expressions.emplace_back(expr1); + expr = expr2; + } + if (expr) filter_expressions.emplace_back(expr); + + auto it = filter_expressions.begin(); + for (; it != filter_expressions.end(); it++) { + if ((*it)->GetTypeInfo().name == query::Exists::kType.name) { + break; + } + } + while (it != filter_expressions.end()) { + ASSERT_TRUE((*it)->GetTypeInfo().name == query::Exists::kType.name) + << "Filter expression is '" << (*it)->GetTypeInfo().name << "' expected '" << query::Exists::kType.name + << "'!"; + it++; + } } std::vector> pattern_filters_;