diff --git a/src/glue/communication.cpp b/src/glue/communication.cpp index fdf5129f6..60181e877 100644 --- a/src/glue/communication.cpp +++ b/src/glue/communication.cpp @@ -127,6 +127,8 @@ storage::Result ToBoltValue(const query::TypedValue &value, const storage return Value(value.ValueLocalDateTime()); case query::TypedValue::Type::Duration: return Value(value.ValueDuration()); + case query::TypedValue::Type::Function: + throw communication::bolt::ValueException("Unsupported conversion from TypedValue::Function to Value"); case query::TypedValue::Type::Graph: auto maybe_graph = ToBoltGraph(value.ValueGraph(), db, view); if (maybe_graph.HasError()) return maybe_graph.GetError(); diff --git a/src/memgraph.cpp b/src/memgraph.cpp index ff8ebfd1f..00a159aa5 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -65,10 +65,13 @@ void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbm std::string line; while (std::getline(file, line)) { if (!line.empty()) { - auto results = interpreter.Prepare(line, {}, {}); - memgraph::query::DiscardValueResultStream stream; - interpreter.Pull(&stream, {}, results.qid); - + try { + auto results = interpreter.Prepare(line, {}, {}); + memgraph::query::DiscardValueResultStream stream; + interpreter.Pull(&stream, {}, results.qid); + } catch (const memgraph::query::UserAlreadyExistsException &e) { + spdlog::warn("{} The rest of the init-file will be run.", e.what()); + } if (audit_log) { audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB); } diff --git a/src/query/common.cpp b/src/query/common.cpp index 793ae8044..3c75ed5ec 100644 --- a/src/query/common.cpp +++ b/src/query/common.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -62,6 +62,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) { case TypedValue::Type::Edge: case TypedValue::Type::Path: case TypedValue::Type::Graph: + case TypedValue::Type::Function: throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type()); case TypedValue::Type::Null: LOG_FATAL("Invalid type"); diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index 1b2e712f9..ac8cc8fe8 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -126,6 +126,12 @@ class InfoInMulticommandTxException : public QueryException { SPECIALIZE_GET_EXCEPTION_NAME(InfoInMulticommandTxException) }; +class UserAlreadyExistsException : public QueryException { + public: + using QueryException::QueryException; + SPECIALIZE_GET_EXCEPTION_NAME(UserAlreadyExistsException) +}; + /** * An exception for an illegal operation that can not be detected * before the query starts executing over data. diff --git a/src/query/frame_change.hpp b/src/query/frame_change.hpp index 32fe1f36e..7baf1fe41 100644 --- a/src/query/frame_change.hpp +++ b/src/query/frame_change.hpp @@ -8,41 +8,42 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include +#include #include "query/typed_value.hpp" +#include "utils/fnv.hpp" #include "utils/memory.hpp" #include "utils/pmr/unordered_map.hpp" #include "utils/pmr/vector.hpp" namespace memgraph::query { // Key is hash output, value is vector of unique elements -using CachedType = utils::pmr::unordered_map>; +using CachedType = utils::pmr::unordered_map>; struct CachedValue { + using allocator_type = utils::Allocator; + // Cached value, this can be probably templateized CachedType cache_; - explicit CachedValue(utils::MemoryResource *mem) : cache_(mem) {} + explicit CachedValue(utils::MemoryResource *mem) : cache_{mem} {}; + CachedValue(const CachedValue &other, utils::MemoryResource *mem) : cache_(other.cache_, mem) {} + CachedValue(CachedValue &&other, utils::MemoryResource *mem) : cache_(std::move(other.cache_), mem){}; - CachedValue(CachedType &&cache, memgraph::utils::MemoryResource *memory) : cache_(std::move(cache), memory) {} + CachedValue(CachedValue &&other) noexcept : CachedValue(std::move(other), other.GetMemoryResource()) {} - CachedValue(const CachedValue &other, memgraph::utils::MemoryResource *memory) : cache_(other.cache_, memory) {} + CachedValue(const CachedValue &other) + : CachedValue(other, std::allocator_traits::select_on_container_copy_construction( + other.GetMemoryResource()) + .GetMemoryResource()) {} - CachedValue(CachedValue &&other, memgraph::utils::MemoryResource *memory) : cache_(std::move(other.cache_), memory) {} - - CachedValue(CachedValue &&other) noexcept = delete; - - /// Copy construction without memgraph::utils::MemoryResource is not allowed. - CachedValue(const CachedValue &) = delete; + utils::MemoryResource *GetMemoryResource() const { return cache_.get_allocator().GetMemoryResource(); } CachedValue &operator=(const CachedValue &) = delete; CachedValue &operator=(CachedValue &&) = delete; ~CachedValue() = default; - memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { - return cache_.get_allocator().GetMemoryResource(); - } - bool CacheValue(const TypedValue &maybe_list) { if (!maybe_list.IsList()) { return false; @@ -70,7 +71,7 @@ struct CachedValue { } private: - static bool IsValueInVec(const std::vector &vec_values, const TypedValue &value) { + static bool IsValueInVec(const utils::pmr::vector &vec_values, const TypedValue &value) { return std::any_of(vec_values.begin(), vec_values.end(), [&value](auto &vec_value) { const auto is_value_equal = vec_value == value; if (is_value_equal.IsNull()) return false; @@ -82,35 +83,70 @@ struct CachedValue { // Class tracks keys for which user can cache values which help with faster search or faster retrieval // in the future. Used for IN LIST operator. class FrameChangeCollector { + /** Allocator type so that STL containers are aware that we need one */ + using allocator_type = utils::Allocator; + public: - explicit FrameChangeCollector() : tracked_values_(&memory_resource_){}; + explicit FrameChangeCollector(utils::MemoryResource *mem = utils::NewDeleteResource()) : tracked_values_{mem} {} + + FrameChangeCollector(FrameChangeCollector &&other, utils::MemoryResource *mem) + : tracked_values_(std::move(other.tracked_values_), mem) {} + FrameChangeCollector(const FrameChangeCollector &other, utils::MemoryResource *mem) + : tracked_values_(other.tracked_values_, mem) {} + + FrameChangeCollector(const FrameChangeCollector &other) + : FrameChangeCollector(other, std::allocator_traits::select_on_container_copy_construction( + other.GetMemoryResource()) + .GetMemoryResource()){}; + + FrameChangeCollector(FrameChangeCollector &&other) noexcept + : FrameChangeCollector(std::move(other), other.GetMemoryResource()) {} + + /** Copy assign other, utils::MemoryResource of `this` is used */ + FrameChangeCollector &operator=(const FrameChangeCollector &) = default; + + /** Move assign other, utils::MemoryResource of `this` is used. */ + FrameChangeCollector &operator=(FrameChangeCollector &&) noexcept = default; + + utils::MemoryResource *GetMemoryResource() const { return tracked_values_.get_allocator().GetMemoryResource(); } CachedValue &AddTrackingKey(const std::string &key) { - const auto &[it, _] = tracked_values_.emplace(key, tracked_values_.get_allocator().GetMemoryResource()); + const auto &[it, _] = tracked_values_.emplace( + std::piecewise_construct, std::forward_as_tuple(utils::pmr::string(key, utils::NewDeleteResource())), + std::forward_as_tuple()); return it->second; } - bool IsKeyTracked(const std::string &key) const { return tracked_values_.contains(key); } + bool IsKeyTracked(const std::string &key) const { + return tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource())); + } bool IsKeyValueCached(const std::string &key) const { - return IsKeyTracked(key) && !tracked_values_.at(key).cache_.empty(); + return IsKeyTracked(key) && !tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource())).cache_.empty(); } bool ResetTrackingValue(const std::string &key) { - if (!tracked_values_.contains(key)) { + if (!tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource()))) { return false; } - tracked_values_.erase(key); + tracked_values_.erase(utils::pmr::string(key, utils::NewDeleteResource())); AddTrackingKey(key); return true; } - CachedValue &GetCachedValue(const std::string &key) { return tracked_values_.at(key); } + CachedValue &GetCachedValue(const std::string &key) { + return tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource())); + } bool IsTrackingValues() const { return !tracked_values_.empty(); } + ~FrameChangeCollector() = default; + private: - utils::MonotonicBufferResource memory_resource_{0}; - memgraph::utils::pmr::unordered_map tracked_values_; + struct PmrStringHash { + size_t operator()(const utils::pmr::string &key) const { return utils::Fnv(key); } + }; + + utils::pmr::unordered_map tracked_values_; }; } // namespace memgraph::query diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index 2cfd11f8c..6f49ee99f 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -593,6 +593,7 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContex case TypedValue::Type::Duration: return TypedValue("DURATION", ctx.memory); case TypedValue::Type::Graph: + case TypedValue::Type::Function: throw QueryRuntimeException("Cannot fetch graph as it is not standardized openCypher type name"); } } diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 916082bb2..e09ddcc97 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -903,7 +904,17 @@ class ExpressionEvaluator : public ExpressionVisitor { return TypedValue(std::move(result), ctx_->memory); } - TypedValue Visit(Exists &exists) override { return TypedValue{frame_->at(symbol_table_->at(exists)), ctx_->memory}; } + TypedValue Visit(Exists &exists) override { + TypedValue &frame_exists_value = frame_->at(symbol_table_->at(exists)); + if (!frame_exists_value.IsFunction()) [[unlikely]] { + throw QueryRuntimeException( + "Unexpected behavior: Exists expected a function, got {}. Please report the problem on GitHub issues", + frame_exists_value.type()); + } + TypedValue result{ctx_->memory}; + frame_exists_value.ValueFunction()(&result); + return result; + } TypedValue Visit(All &all) override { auto list_value = all.list_expression_->Accept(*this); diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index bb4cd5b94..37482d2cf 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -478,7 +478,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ MG_ASSERT(password.IsString() || password.IsNull()); if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt)) { - throw QueryRuntimeException("User '{}' already exists.", username); + throw UserAlreadyExistsException("User '{}' already exists.", username); } // If the license is not valid we create users with admin access diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 1c8d021c7..63bf5cd40 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -2500,13 +2500,16 @@ std::vector EvaluatePatternFilter::ModifiedSymbols(const SymbolTable &ta } bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) { - OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("EvaluatePatternFilter"); + std::function function = [&frame, self = this->self_, input_cursor = this->input_cursor_.get(), + &context](TypedValue *return_value) { + OOMExceptionEnabler oom_exception; + input_cursor->Reset(); - input_cursor_->Reset(); - - frame[self_.output_symbol_] = TypedValue(input_cursor_->Pull(frame, context), context.evaluation_context.memory); + *return_value = TypedValue(input_cursor->Pull(frame, context), context.evaluation_context.memory); + }; + frame[self_.output_symbol_] = TypedValue(std::move(function)); return true; } diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index cd223dd8e..f3d0c1487 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -17,6 +17,7 @@ #include #include +#include "query/plan/preprocess.hpp" #include "utils/algorithm.hpp" #include "utils/exceptions.hpp" #include "utils/logging.hpp" @@ -516,14 +517,25 @@ bool HasBoundFilterSymbols(const std::unordered_set &bound_symbols, cons Expression *ExtractFilters(const std::unordered_set &bound_symbols, Filters &filters, AstStorage &storage) { Expression *filter_expr = nullptr; + std::vector and_joinable_filters{}; for (auto filters_it = filters.begin(); filters_it != filters.end();) { if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { - filter_expr = impl::BoolJoin(storage, filter_expr, filters_it->expression); + and_joinable_filters.emplace_back(*filters_it); filters_it = filters.erase(filters_it); } else { filters_it++; } } + // Idea here is to join filters in a way + // that pattern filter ( exists() ) is at the end + // so if any of the AND filters before + // evaluate to false we don't need to + // evaluate pattern ( exists() ) filter + std::partition(and_joinable_filters.begin(), and_joinable_filters.end(), + [](const FilterInfo &filter_info) { return filter_info.type != FilterInfo::Type::Pattern; }); + for (auto &and_joinable_filter : and_joinable_filters) { + filter_expr = impl::BoolJoin(storage, filter_expr, and_joinable_filter.expression); + } return filter_expr; } diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index f87377ba5..ab2b3ae4b 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -313,6 +313,8 @@ mgp_value_type FromTypedValueType(memgraph::query::TypedValue::Type type) { return MGP_VALUE_TYPE_LOCAL_DATE_TIME; case memgraph::query::TypedValue::Type::Duration: return MGP_VALUE_TYPE_DURATION; + case memgraph::query::TypedValue::Type::Function: + throw std::logic_error{"mgp_value for TypedValue::Type::Function doesn't exist."}; case memgraph::query::TypedValue::Type::Graph: throw std::logic_error{"mgp_value for TypedValue::Type::Graph doesn't exist."}; } @@ -3672,7 +3674,8 @@ std::ostream &PrintValue(const TypedValue &value, std::ostream *stream) { case TypedValue::Type::Edge: case TypedValue::Type::Path: case TypedValue::Type::Graph: - LOG_FATAL("value must not be a graph element"); + case TypedValue::Type::Function: + LOG_FATAL("value must not be a graph|function element"); } } diff --git a/src/query/typed_value.cpp b/src/query/typed_value.cpp index 13db88e1c..91893d71c 100644 --- a/src/query/typed_value.cpp +++ b/src/query/typed_value.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,6 +22,7 @@ #include "storage/v2/temporal.hpp" #include "utils/exceptions.hpp" #include "utils/fnv.hpp" +#include "utils/logging.hpp" #include "utils/memory.hpp" namespace memgraph::query { @@ -215,6 +216,9 @@ TypedValue::TypedValue(const TypedValue &other, utils::MemoryResource *memory) : case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); return; + case Type::Function: + new (&function_v) std::function(other.function_v); + return; case Type::Graph: auto *graph_ptr = utils::Allocator(memory_).new_object(*other.graph_v); new (&graph_v) std::unique_ptr(graph_ptr); @@ -268,6 +272,9 @@ TypedValue::TypedValue(TypedValue &&other, utils::MemoryResource *memory) : memo case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); break; + case Type::Function: + new (&function_v) std::function(other.function_v); + break; case Type::Graph: if (other.GetMemoryResource() == memory_) { new (&graph_v) std::unique_ptr(std::move(other.graph_v)); @@ -343,6 +350,7 @@ DEFINE_VALUE_AND_TYPE_GETTERS(utils::Date, Date, date_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalTime, LocalTime, local_time_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime, local_date_time_v) DEFINE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration, duration_v) +DEFINE_VALUE_AND_TYPE_GETTERS(std::function, Function, function_v) Graph &TypedValue::ValueGraph() { if (type_ != Type::Graph) { @@ -417,6 +425,8 @@ std::ostream &operator<<(std::ostream &os, const TypedValue::Type &type) { return os << "duration"; case TypedValue::Type::Graph: return os << "graph"; + case TypedValue::Type::Function: + return os << "function"; } LOG_FATAL("Unsupported TypedValue::Type"); } @@ -569,6 +579,9 @@ TypedValue &TypedValue::operator=(const TypedValue &other) { case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); return *this; + case Type::Function: + new (&function_v) std::function(other.function_v); + return *this; } LOG_FATAL("Unsupported TypedValue::Type"); } @@ -628,6 +641,9 @@ TypedValue &TypedValue::operator=(TypedValue &&other) noexcept(false) { case Type::Duration: new (&duration_v) utils::Duration(other.duration_v); break; + case Type::Function: + new (&function_v) std::function{other.function_v}; + break; case Type::Graph: if (other.GetMemoryResource() == memory_) { new (&graph_v) std::unique_ptr(std::move(other.graph_v)); @@ -676,6 +692,9 @@ void TypedValue::DestroyValue() { case Type::LocalDateTime: case Type::Duration: break; + case Type::Function: + std::destroy_at(&function_v); + break; case Type::Graph: { auto *graph = graph_v.release(); std::destroy_at(&graph_v); @@ -1153,6 +1172,8 @@ size_t TypedValue::Hash::operator()(const TypedValue &value) const { case TypedValue::Type::Duration: return utils::DurationHash{}(value.ValueDuration()); break; + case TypedValue::Type::Function: + throw TypedValueException("Unsupported hash function for Function"); case TypedValue::Type::Graph: throw TypedValueException("Unsupported hash function for Graph"); } diff --git a/src/query/typed_value.hpp b/src/query/typed_value.hpp index c215e2276..0af38cecc 100644 --- a/src/query/typed_value.hpp +++ b/src/query/typed_value.hpp @@ -84,7 +84,8 @@ class TypedValue { LocalTime, LocalDateTime, Duration, - Graph + Graph, + Function }; // TypedValue at this exact moment of compilation is an incomplete type, and @@ -420,6 +421,9 @@ class TypedValue { new (&graph_v) std::unique_ptr(graph_ptr); } + explicit TypedValue(std::function &&other) + : function_v(std::move(other)), type_(Type::Function) {} + /** * Construct with the value of other. * Default utils::NewDeleteResource() is used for allocations. After the move, @@ -451,6 +455,7 @@ class TypedValue { TypedValue &operator=(const utils::LocalTime &); TypedValue &operator=(const utils::LocalDateTime &); TypedValue &operator=(const utils::Duration &); + TypedValue &operator=(const std::function &); /** Copy assign other, utils::MemoryResource of `this` is used */ TypedValue &operator=(const TypedValue &other); @@ -506,6 +511,7 @@ class TypedValue { DECLARE_VALUE_AND_TYPE_GETTERS(utils::LocalDateTime, LocalDateTime) DECLARE_VALUE_AND_TYPE_GETTERS(utils::Duration, Duration) DECLARE_VALUE_AND_TYPE_GETTERS(Graph, Graph) + DECLARE_VALUE_AND_TYPE_GETTERS(std::function, Function) #undef DECLARE_VALUE_AND_TYPE_GETTERS @@ -550,6 +556,7 @@ class TypedValue { utils::Duration duration_v; // As the unique_ptr is not allocator aware, it requires special attention when copying or moving graphs std::unique_ptr graph_v; + std::function function_v; }; /** diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 73d98ce6a..c61f046dc 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -1,38 +1,14 @@ -# telemetry test binaries add_subdirectory(telemetry) - -# ssl test binaries add_subdirectory(ssl) - -# transactions test binaries add_subdirectory(transactions) - -# auth test binaries add_subdirectory(auth) - -# lba test binaries add_subdirectory(fine_grained_access) - -# audit test binaries add_subdirectory(audit) - -# ldap test binaries add_subdirectory(ldap) - -# mg_import_csv test binaries add_subdirectory(mg_import_csv) - -# license_check test binaries add_subdirectory(license_info) - -#environment variable check binaries add_subdirectory(env_variable_check) - -#flag check binaries add_subdirectory(flag_check) - -#storage mode binaries add_subdirectory(storage_mode) - -#run time settings binaries add_subdirectory(run_time_settings) +add_subdirectory(init_file) diff --git a/tests/integration/init_file/CMakeLists.txt b/tests/integration/init_file/CMakeLists.txt new file mode 100644 index 000000000..41f2af6cc --- /dev/null +++ b/tests/integration/init_file/CMakeLists.txt @@ -0,0 +1,6 @@ +set(target_name memgraph__integration__init_file) +set(tester_target_name ${target_name}__tester) + +add_executable(${tester_target_name} tester.cpp) +set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester) +target_link_libraries(${tester_target_name} mg-communication) diff --git a/tests/integration/init_file/auth.cypherl b/tests/integration/init_file/auth.cypherl new file mode 100644 index 000000000..3a2f8d441 --- /dev/null +++ b/tests/integration/init_file/auth.cypherl @@ -0,0 +1 @@ +CREATE USER memgraph1 IDENTIFIED BY '1234'; diff --git a/tests/integration/init_file/runner.py b/tests/integration/init_file/runner.py new file mode 100644 index 000000000..fcaa10f95 --- /dev/null +++ b/tests/integration/init_file/runner.py @@ -0,0 +1,60 @@ +import argparse +import os +import subprocess +import sys +import tempfile +import time + +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) +BUILD_DIR = os.path.join(PROJECT_DIR, "build") +INIT_FILE = os.path.join(SCRIPT_DIR, "auth.cypherl") +SIGNAL_SIGTERM = 15 + + +def wait_for_server(port, delay=0.1): + cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] + while subprocess.call(cmd) != 0: + time.sleep(0.01) + time.sleep(delay) + + +def prepare_memgraph(memgraph_args): + memgraph = subprocess.Popen(list(map(str, memgraph_args))) + time.sleep(0.1) + assert memgraph.poll() is None, "Memgraph process died prematurely!" + wait_for_server(7687) + return memgraph + + +def terminate_memgraph(memgraph): + pid = memgraph.pid + try: + os.kill(pid, SIGNAL_SIGTERM) + except os.OSError: + assert False, "Memgraph process didn't exit cleanly!" + time.sleep(1) + + +def execute_test_restart_memgraph_with_init_file(memgraph_binary: str, tester_binary: str) -> None: + storage_directory = tempfile.TemporaryDirectory() + tester_args = [tester_binary, "--username", "memgraph1", "--password", "1234"] + memgraph = prepare_memgraph([memgraph_binary, "--data-directory", storage_directory.name, "--init-file", INIT_FILE]) + subprocess.run(tester_args, stdout=subprocess.PIPE, check=True).check_returncode() + terminate_memgraph(memgraph) + memgraph = prepare_memgraph([memgraph_binary, "--data-directory", storage_directory.name, "--init-file", INIT_FILE]) + subprocess.run(tester_args, stdout=subprocess.PIPE, check=True).check_returncode() + terminate_memgraph(memgraph) + + +if __name__ == "__main__": + memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") + tester_binary = os.path.join(BUILD_DIR, "tests", "integration", "init_file", "tester") + + parser = argparse.ArgumentParser() + parser.add_argument("--memgraph", default=memgraph_binary) + parser.add_argument("--tester", default=tester_binary) + args = parser.parse_args() + + execute_test_restart_memgraph_with_init_file(args.memgraph, args.tester) + sys.exit(0) diff --git a/tests/integration/init_file/tester.cpp b/tests/integration/init_file/tester.cpp new file mode 100644 index 000000000..d4486ead5 --- /dev/null +++ b/tests/integration/init_file/tester.cpp @@ -0,0 +1,47 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include + +#include "communication/bolt/client.hpp" +#include "io/network/endpoint.hpp" +#include "io/network/utils.hpp" +#include "utils/logging.hpp" + +DEFINE_string(address, "127.0.0.1", "Server address"); +DEFINE_int32(port, 7687, "Server port"); +DEFINE_string(username, "", "Username for the database"); +DEFINE_string(password, "", "Password for the database"); +DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); + +// NOLINTNEXTLINE(bugprone-exception-escape) +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + memgraph::logging::RedirectToStderr(); + + memgraph::communication::SSLInit sslInit; + + memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port); + + memgraph::communication::ClientContext context(FLAGS_use_ssl); + memgraph::communication::bolt::Client client(context); + + client.Connect(endpoint, FLAGS_username, FLAGS_password); + auto ret = client.Execute("SHOW USERS", {}); + auto size = ret.records.size(); + MG_ASSERT(size == 1, "Too much users returned for SHOW USERA (got {}, expected 1)!", size); + auto row0_size = ret.records[0].size(); + MG_ASSERT(row0_size == 1, "Too much entries in query dump row (got {}, expected 1)!", row0_size); + auto user = ret.records[0][0].ValueString(); + MG_ASSERT(user == "memgraph1", "Unexpected user returned for SHOW USERS (got {}, expected memgraph)!", user); + + return 0; +} diff --git a/tests/unit/formatters.hpp b/tests/unit/formatters.hpp index a5ee49166..5217fd65c 100644 --- a/tests/unit/formatters.hpp +++ b/tests/unit/formatters.hpp @@ -138,6 +138,8 @@ inline std::string ToString(const memgraph::query::TypedValue &value, const TAcc break; case memgraph::query::TypedValue::Type::Graph: throw std::logic_error{"Not implemented"}; + case memgraph::query::TypedValue::Type::Function: + throw std::logic_error{"Not implemented"}; } return os.str(); } diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 44d3ed301..c9786fe5e 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -83,6 +84,14 @@ class ExpressionEvaluatorTest : public ::testing::Test { return id; } + Exists *CreateExistsWithValue(std::string name, TypedValue &&value) { + auto id = storage.template Create(); + auto symbol = symbol_table.CreateSymbol(name, true); + id->MapTo(symbol); + frame[symbol] = std::move(value); + return id; + } + template auto Eval(TExpression *expr) { ctx.properties = NamesToProperties(storage.properties_, &dba); @@ -149,6 +158,33 @@ TYPED_TEST(ExpressionEvaluatorTest, AndOperatorShortCircuit) { } } +TYPED_TEST(ExpressionEvaluatorTest, AndExistsOperatorShortCircuit) { + { + std::function my_func = [](TypedValue * /*return_value*/) { + throw QueryRuntimeException("This should not be evaluated"); + }; + TypedValue func_should_not_evaluate{std::move(my_func)}; + + auto *op = this->storage.template Create( + this->storage.template Create(false), + this->CreateExistsWithValue("anon1", std::move(func_should_not_evaluate))); + auto value = this->Eval(op); + EXPECT_EQ(value.ValueBool(), false); + } + { + std::function my_func = [memory = this->ctx.memory](TypedValue *return_value) { + *return_value = TypedValue(false, memory); + }; + TypedValue should_evaluate{std::move(my_func)}; + + auto *op = + this->storage.template Create(this->storage.template Create(true), + this->CreateExistsWithValue("anon1", std::move(should_evaluate))); + auto value = this->Eval(op); + EXPECT_EQ(value.ValueBool(), false); + } +} + TYPED_TEST(ExpressionEvaluatorTest, AndOperatorNull) { { // Null doesn't short circuit diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 910ebdc54..bc4b2660c 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -853,6 +853,26 @@ TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { } } +TYPED_TEST(TestPlanner, MatchFilterWhere) { + // Test MATCH (n)-[r]-(m) WHERE exists((n)-[]-()) and n!=n and 7!=8 RETURN n + auto *query = QUERY(SINGLE_QUERY( + MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), + WHERE(AND(EXISTS(PATTERN(NODE("n"), EDGE("edge2", memgraph::query::EdgeAtom::Direction::BOTH, {}, false), + NODE("node3", std::nullopt, false))), + AND(NEQ(IDENT("n"), IDENT("n")), NEQ(LITERAL(7), LITERAL(8))))), + RETURN("n"))); + + std::list pattern_filter{new ExpectScanAll(), new ExpectExpand(), new ExpectLimit(), + new ExpectEvaluatePatternFilter()}; + CheckPlan( + query, this->storage, + ExpectFilter(), // 7!=8 + ExpectScanAll(), + ExpectFilter(std::vector>{pattern_filter}), // filter pulls from expand + ExpectExpand(), ExpectProduce()); + DeleteListContent(&pattern_filter); +} + TYPED_TEST(TestPlanner, MultiMatchWhere) { // Test MATCH (n) -[r]- (m) MATCH (l) WHERE n.prop < 42 RETURN n FakeDbAccessor dba; diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 6f2f23df7..92089eb82 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -14,11 +14,13 @@ #include #include +#include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::query::plan { @@ -197,6 +199,29 @@ class ExpectFilter : public OpChecker { filter.pattern_filters_[i]->Accept(check_updates); } + // ordering in AND Operator must be ..., exists, exists, exists. + auto *expr = filter.expression_; + std::vector filter_expressions; + while (auto *and_operator = utils::Downcast(expr)) { + auto *expr1 = and_operator->expression1_; + auto *expr2 = and_operator->expression2_; + filter_expressions.emplace_back(expr1); + expr = expr2; + } + if (expr) filter_expressions.emplace_back(expr); + + auto it = filter_expressions.begin(); + for (; it != filter_expressions.end(); it++) { + if ((*it)->GetTypeInfo().name == query::Exists::kType.name) { + break; + } + } + while (it != filter_expressions.end()) { + ASSERT_TRUE((*it)->GetTypeInfo().name == query::Exists::kType.name) + << "Filter expression is '" << (*it)->GetTypeInfo().name << "' expected '" << query::Exists::kType.name + << "'!"; + it++; + } } std::vector> pattern_filters_;