From 69eca9b0437a3370d13a01669b8eabbcc9ff9599 Mon Sep 17 00:00:00 2001 From: Antonio Andelic Date: Fri, 11 Feb 2022 11:29:41 +0100 Subject: [PATCH] Procedures for handling modules (#330) --- .clang-tidy | 1 + libs/setup.sh | 4 +- src/auth/models.cpp | 6 +- src/auth/models.hpp | 52 +- src/glue/auth.cpp | 6 +- src/memgraph.cpp | 4 +- src/query/frontend/ast/ast.lcp | 5 +- .../frontend/ast/cypher_main_visitor.cpp | 4 +- .../opencypher/grammar/MemgraphCypher.g4 | 2 + .../opencypher/grammar/MemgraphCypherLexer.g4 | 2 + .../frontend/semantic/required_privileges.cpp | 12 +- src/query/plan/operator.cpp | 10 +- src/query/procedure/mg_procedure_impl.cpp | 9 +- src/query/procedure/mg_procedure_impl.hpp | 38 +- src/query/procedure/module.cpp | 457 +++++++++++++++--- src/query/procedure/module.hpp | 8 +- src/query/procedure/py_module.cpp | 13 +- src/query/stream/streams.cpp | 106 ++-- src/storage/v2/delta.hpp | 11 +- src/storage/v2/storage.cpp | 20 +- tests/e2e/CMakeLists.txt | 1 + tests/e2e/memory/memory_control.cpp | 7 +- .../memory/memory_limit_global_alloc_proc.cpp | 21 +- tests/e2e/module_file_manager/CMakeLists.txt | 4 + .../module_file_manager.cpp | 270 +++++++++++ tests/e2e/module_file_manager/workloads.yaml | 14 + tests/e2e/triggers/on_delete_triggers.cpp | 5 +- tests/e2e/triggers/on_update_triggers.cpp | 4 +- tests/e2e/triggers/privilige_check.cpp | 21 +- tests/unit/cypher_main_visitor.cpp | 4 +- tests/unit/query_common.hpp | 11 +- tests/unit/query_procedure_mgp_module.cpp | 10 +- tests/unit/query_required_privileges.cpp | 24 + 33 files changed, 938 insertions(+), 228 deletions(-) create mode 100644 tests/e2e/module_file_manager/CMakeLists.txt create mode 100644 tests/e2e/module_file_manager/module_file_manager.cpp create mode 100644 tests/e2e/module_file_manager/workloads.yaml diff --git a/.clang-tidy b/.clang-tidy index 1560bebe0..35d6c9b84 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,6 +1,7 @@ --- Checks: '*, -abseil-string-find-str-contains, + -altera-id-dependent-backward-branch, -altera-struct-pack-align, -altera-unroll-loops, -android-*, diff --git a/libs/setup.sh b/libs/setup.sh index 7c7160626..4b02a1bf8 100755 --- a/libs/setup.sh +++ b/libs/setup.sh @@ -211,8 +211,8 @@ git apply ../rocksdb.patch popd # mgclient -mgclient_tag="v1.3.0" # (2021-09-23) -repo_clone_try_double "${primary_urls[mgclient]}" "${secondary_urls[mgclient]}" "mgclient" "$mgclient_tag" true +mgclient_tag="96e95c6845463cbe88948392be58d26da0d5ffd3" # (2022-02-08) +repo_clone_try_double "${primary_urls[mgclient]}" "${secondary_urls[mgclient]}" "mgclient" "$mgclient_tag" sed -i 's/\${CMAKE_INSTALL_LIBDIR}/lib/' mgclient/src/CMakeLists.txt # pymgclient diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 9d225b416..5eaa042a6 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -68,6 +68,10 @@ std::string PermissionToString(Permission permission) { return "AUTH"; case Permission::STREAM: return "STREAM"; + case Permission::MODULE_READ: + return "MODULE_READ"; + case Permission::MODULE_WRITE: + return "MODULE_WRITE"; } } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 7647dbec3..e4affc236 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -19,34 +19,36 @@ namespace auth { // bitmask. // clang-format off enum class Permission : uint64_t { - MATCH = 1, - CREATE = 1U << 1U, - MERGE = 1U << 2U, - DELETE = 1U << 3U, - SET = 1U << 4U, - REMOVE = 1U << 5U, - INDEX = 1U << 6U, - STATS = 1U << 7U, - CONSTRAINT = 1U << 8U, - DUMP = 1U << 9U, - REPLICATION = 1U << 10U, - DURABILITY = 1U << 11U, - READ_FILE = 1U << 12U, - FREE_MEMORY = 1U << 13U, - TRIGGER = 1U << 14U, - CONFIG = 1U << 15U, - AUTH = 1U << 16U, - STREAM = 1U << 17U + MATCH = 1, + CREATE = 1U << 1U, + MERGE = 1U << 2U, + DELETE = 1U << 3U, + SET = 1U << 4U, + REMOVE = 1U << 5U, + INDEX = 1U << 6U, + STATS = 1U << 7U, + CONSTRAINT = 1U << 8U, + DUMP = 1U << 9U, + REPLICATION = 1U << 10U, + DURABILITY = 1U << 11U, + READ_FILE = 1U << 12U, + FREE_MEMORY = 1U << 13U, + TRIGGER = 1U << 14U, + CONFIG = 1U << 15U, + AUTH = 1U << 16U, + STREAM = 1U << 17U, + MODULE_READ = 1U << 18U, + MODULE_WRITE = 1U << 19U }; // clang-format on // Constant list of all available permissions. -const std::vector kPermissionsAll = {Permission::MATCH, Permission::CREATE, Permission::MERGE, - Permission::DELETE, Permission::SET, Permission::REMOVE, - Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, - Permission::DUMP, Permission::AUTH, Permission::REPLICATION, - Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, - Permission::TRIGGER, Permission::CONFIG, Permission::STREAM}; +const std::vector kPermissionsAll = { + Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, + Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, + Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, + Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, + Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE}; // Function that converts a permission to its string representation. std::string PermissionToString(Permission permission); diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 48474239f..990877cd2 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -51,6 +51,10 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::AUTH; case query::AuthQuery::Privilege::STREAM: return auth::Permission::STREAM; + case query::AuthQuery::Privilege::MODULE_READ: + return auth::Permission::MODULE_READ; + case query::AuthQuery::Privilege::MODULE_WRITE: + return auth::Permission::MODULE_WRITE; } } } // namespace glue diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 19a73403a..d6f3eebc0 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -1146,7 +1146,7 @@ int main(int argc, char **argv) { SessionData session_data{&db, &interpreter_context, &auth}; #endif - query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories); + query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories, FLAGS_data_directory); query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories(); AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex); diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index b08833c87..f50aba5d4 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2247,7 +2247,7 @@ cpp<# (:serialize)) (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint - dump replication durability read_file free_memory trigger config stream) + dump replication durability read_file free_memory trigger config stream module_read module_write) (:serialize)) #>cpp AuthQuery() = default; @@ -2287,7 +2287,8 @@ const std::vector kPrivilegesAll = { AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, - AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM}; + AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, + AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE}; cpp<# (lcp:define-class info-query (query) diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index ab2e02cc6..6ac33cd25 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1091,7 +1091,7 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur if (!maybe_found) { throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_); } - call_proc->is_write_ = maybe_found->second->is_write_procedure; + call_proc->is_write_ = maybe_found->second->info.is_write; auto *yield_ctx = ctx->yieldProcedureResults(); if (!yield_ctx) { @@ -1330,6 +1330,8 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->CONFIG()) return AuthQuery::Privilege::CONFIG; if (ctx->DURABILITY()) return AuthQuery::Privilege::DURABILITY; if (ctx->STREAM()) return AuthQuery::Privilege::STREAM; + if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; + if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; LOG_FATAL("Should not get here - unknown privilege!"); } diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index be9497e20..02de4493c 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -239,6 +239,8 @@ privilege : CREATE | CONFIG | DURABILITY | STREAM + | MODULE_READ + | MODULE_WRITE ; privilegeList : privilege ( ',' privilege )* ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index e678a73e9..a45651c66 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -70,6 +70,8 @@ LOAD : L O A D ; LOCK : L O C K ; MAIN : M A I N ; MODE : M O D E ; +MODULE_READ : M O D U L E UNDERSCORE R E A D ; +MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ; NEXT : N E X T ; NO : N O ; PASSWORD : P A S S W O R D ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 6e61dfc2c..42f74a4af 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,6 +11,8 @@ #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" +#include "query/procedure/module.hpp" +#include "utils/memory.hpp" namespace query { @@ -82,8 +84,12 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis AddPrivilege(AuthQuery::Privilege::CREATE); return false; } - bool PreVisit(CallProcedure & /*unused*/) override { - // TODO: Corresponding privilege + bool PreVisit(CallProcedure &procedure) override { + const auto maybe_proc = + procedure::FindProcedure(procedure::gModuleRegistry, procedure.procedure_name_, utils::NewDeleteResource()); + if (maybe_proc && maybe_proc->second->info.required_privilege) { + AddPrivilege(*maybe_proc->second->info.required_privilege); + } return false; } bool PreVisit(Delete & /*unused*/) override { diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index b8652d3a3..ada565853 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -2107,7 +2107,7 @@ template concept AccessorWithProperties = requires(T value, storage::PropertyId property_id, storage::PropertyValue property_value) { { value.ClearProperties() } -> std::same_as>>; - { value.SetProperty(property_id, property_value) }; + {value.SetProperty(property_id, property_value)}; }; /// Helper function that sets the given values on either a Vertex or an Edge. @@ -3813,13 +3813,13 @@ class CallProcedureCursor : public Cursor { throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_); } const auto &[module, proc] = *maybe_found; - if (proc->is_write_procedure != self_->is_write_) { + if (proc->info.is_write != self_->is_write_) { auto get_proc_type_str = [](bool is_write) { return is_write ? "write" : "read"; }; throw QueryRuntimeException("The procedure named '{}' was a {} procedure, but changed to be a {} procedure.", self_->procedure_name_, get_proc_type_str(self_->is_write_), - get_proc_type_str(proc->is_write_procedure)); + get_proc_type_str(proc->info.is_write)); } - const auto graph_view = proc->is_write_procedure ? storage::View::NEW : storage::View::OLD; + const auto graph_view = proc->info.is_write ? storage::View::NEW : storage::View::OLD; ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, graph_view); diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 0ddefa6b2..e28abdbea 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -2353,7 +2353,8 @@ mgp_error mgp_type_nullable(mgp_type *type, mgp_type **result) { } namespace { -mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, bool is_write_procedure) { +mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, + const ProcedureInfo &procedure_info) { if (!IsValidIdentifierName(name)) { throw std::invalid_argument{fmt::format("Invalid procedure name: {}", name)}; } @@ -2363,16 +2364,16 @@ mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_pro auto *memory = module->procedures.get_allocator().GetMemoryResource(); // May throw std::bad_alloc, std::length_error - return &module->procedures.emplace(name, mgp_proc(name, cb, memory, is_write_procedure)).first->second; + return &module->procedures.emplace(name, mgp_proc(name, cb, memory, procedure_info)).first->second; } } // namespace mgp_error mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) { - return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, false); }, result); + return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = false}); }, result); } mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) { - return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, true); }, result); + return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result); } mgp_error mgp_proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) { diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp index df8653d2b..697708537 100644 --- a/src/query/procedure/mg_procedure_impl.hpp +++ b/src/query/procedure/mg_procedure_impl.hpp @@ -23,6 +23,7 @@ #include "integrations/pulsar/consumer.hpp" #include "query/context.hpp" #include "query/db_accessor.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/procedure/cypher_type_ptr.hpp" #include "query/typed_value.hpp" #include "storage/v2/view.hpp" @@ -653,40 +654,29 @@ struct mgp_type { query::procedure::CypherTypePtr impl; }; +struct ProcedureInfo { + bool is_write = false; + std::optional required_privilege = std::nullopt; +}; struct mgp_proc { using allocator_type = utils::Allocator; /// @throw std::bad_alloc /// @throw std::length_error - mgp_proc(const char *name, mgp_proc_cb cb, utils::MemoryResource *memory, bool is_write_procedure) - : name(name, memory), - cb(cb), - args(memory), - opt_args(memory), - results(memory), - is_write_procedure(is_write_procedure) {} + mgp_proc(const char *name, mgp_proc_cb cb, utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} /// @throw std::bad_alloc /// @throw std::length_error mgp_proc(const char *name, std::function cb, - utils::MemoryResource *memory, bool is_write_procedure) - : name(name, memory), - cb(cb), - args(memory), - opt_args(memory), - results(memory), - is_write_procedure(is_write_procedure) {} + utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} /// @throw std::bad_alloc /// @throw std::length_error mgp_proc(const std::string_view name, std::function cb, - utils::MemoryResource *memory, bool is_write_procedure) - : name(name, memory), - cb(cb), - args(memory), - opt_args(memory), - results(memory), - is_write_procedure(is_write_procedure) {} + utils::MemoryResource *memory, const ProcedureInfo &info = {}) + : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {} /// @throw std::bad_alloc /// @throw std::length_error @@ -696,7 +686,7 @@ struct mgp_proc { args(other.args, memory), opt_args(other.opt_args, memory), results(other.results, memory), - is_write_procedure(other.is_write_procedure) {} + info(other.info) {} mgp_proc(mgp_proc &&other, utils::MemoryResource *memory) : name(std::move(other.name), memory), @@ -704,7 +694,7 @@ struct mgp_proc { args(std::move(other.args), memory), opt_args(std::move(other.opt_args), memory), results(std::move(other.results), memory), - is_write_procedure(other.is_write_procedure) {} + info(other.info) {} mgp_proc(const mgp_proc &other) = default; mgp_proc(mgp_proc &&other) = default; @@ -724,7 +714,7 @@ struct mgp_proc { utils::pmr::vector> opt_args; /// Fields this procedure returns, as a (name -> (type, is_deprecated)) map. utils::pmr::map> results; - bool is_write_procedure{false}; + ProcedureInfo info; }; struct mgp_trans { diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index e322d9cee..4bbc67b43 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -10,14 +10,14 @@ // licenses/APL.txt. #include "query/procedure/module.hpp" -#include "utils/memory.hpp" + +#include +#include extern "C" { #include } -#include - #include #include @@ -26,6 +26,7 @@ extern "C" { #include "query/procedure/py_module.hpp" #include "utils/file.hpp" #include "utils/logging.hpp" +#include "utils/memory.hpp" #include "utils/message.hpp" #include "utils/pmr/vector.hpp" #include "utils/string.hpp" @@ -83,6 +84,12 @@ void BuiltinModule::AddTransformation(std::string_view name, mgp_trans trans) { namespace { +auto WithUpgradedLock(auto *lock, const auto &function) { + lock->unlock_shared(); + utils::OnScopeExit shared_lock{[&] { lock->lock_shared(); }}; + function(); +}; + void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { // Loading relies on the fact that regular procedure invocation through // CallProcedureCursor::Pull takes ModuleRegistry::lock_ with READ access. To @@ -96,31 +103,19 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, Builti // single thread may only take either a READ or a WRITE lock, it's not // possible for a thread to hold both. If a thread tries to do that, it will // deadlock immediately (no other thread needs to do anything). - auto with_unlock_shared = [lock](const auto &load_function) { - lock->unlock_shared(); - try { - load_function(); - // There's no finally in C++, but we have to return our original READ lock - // state in any possible case. - } catch (...) { - lock->lock_shared(); - throw; - } - lock->lock_shared(); + auto load_all_cb = [module_registry, lock](mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, + mgp_memory * /*memory*/) { + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); }; - auto load_all_cb = [module_registry, with_unlock_shared](mgp_list * /*args*/, mgp_graph * /*graph*/, - mgp_result * /*result*/, mgp_memory * /*memory*/) { - with_unlock_shared([&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); - }; - mgp_proc load_all("load_all", load_all_cb, utils::NewDeleteResource(), false); + mgp_proc load_all("load_all", load_all_cb, utils::NewDeleteResource()); module->AddProcedure("load_all", std::move(load_all)); - auto load_cb = [module_registry, with_unlock_shared](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, - mgp_memory * /*memory*/) { + auto load_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, + mgp_memory * /*memory*/) { MG_ASSERT(Call(mgp_list_size, args) == 1U, "Should have been type checked already"); auto *arg = Call(mgp_list_at, args, 0); MG_ASSERT(CallBool(mgp_value_is_string, arg), "Should have been type checked already"); bool succ = false; - with_unlock_shared([&]() { + WithUpgradedLock(lock, [&]() { const char *arg_as_string{nullptr}; if (const auto err = mgp_value_get_string(arg, &arg_as_string); err != MGP_ERROR_NO_ERROR) { succ = false; @@ -132,7 +127,7 @@ void RegisterMgLoad(ModuleRegistry *module_registry, utils::RWLock *lock, Builti MG_ASSERT(mgp_result_set_error_msg(result, "Failed to (re)load the module.") == MGP_ERROR_NO_ERROR); } }; - mgp_proc load("load", load_cb, utils::NewDeleteResource(), false); + mgp_proc load("load", load_cb, utils::NewDeleteResource()); MG_ASSERT(mgp_proc_add_arg(&load, "module_name", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); module->AddProcedure("load", std::move(load)); } @@ -172,11 +167,8 @@ void RegisterMgProcedures( for (const auto &[proc_name, proc] : *module->Procedures()) { mgp_result_record *record{nullptr}; - { - const auto success = TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result); - if (!success) { - return; - } + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; } const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); @@ -185,12 +177,9 @@ void RegisterMgProcedures( } MgpUniquePtr is_editable_value{nullptr, mgp_value_destroy}; - { - const auto success = TryOrSetError( - [&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, result); - if (!success) { - return; - } + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; } utils::pmr::string full_name(module_name, memory->impl); @@ -211,15 +200,12 @@ void RegisterMgProcedures( } MgpUniquePtr is_write_value{nullptr, mgp_value_destroy}; - { - const auto success = TryOrSetError( - [&, &proc = proc] { - return CreateMgpObject(is_write_value, mgp_value_make_bool, proc.is_write_procedure ? 1 : 0, memory); - }, - result); - if (!success) { - return; - } + if (!TryOrSetError( + [&, &proc = proc] { + return CreateMgpObject(is_write_value, mgp_value_make_bool, proc.info.is_write ? 1 : 0, memory); + }, + result)) { + return; } if (!InsertResultOrSetError(result, record, "name", name_value.get())) { @@ -244,7 +230,7 @@ void RegisterMgProcedures( } } }; - mgp_proc procedures("procedures", procedures_cb, utils::NewDeleteResource(), false); + mgp_proc procedures("procedures", procedures_cb, utils::NewDeleteResource()); MG_ASSERT(mgp_proc_add_result(&procedures, "name", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&procedures, "signature", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&procedures, "is_write", Call(mgp_type_bool)) == MGP_ERROR_NO_ERROR); @@ -269,11 +255,8 @@ void RegisterMgTransformations(const std::mapTransformations()) { mgp_result_record *record{nullptr}; - { - const auto success = TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result); - if (!success) { - return; - } + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; } const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); @@ -282,12 +265,9 @@ void RegisterMgTransformations(const std::map is_editable_value{nullptr, mgp_value_destroy}; - { - const auto success = TryOrSetError( - [&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, result); - if (!success) { - return; - } + if (!TryOrSetError([&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; } utils::pmr::string full_name(module_name, memory->impl); @@ -313,13 +293,358 @@ void RegisterMgTransformations(const std::map(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&procedures, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&procedures, "is_editable", Call(mgp_type_bool)) == MGP_ERROR_NO_ERROR); module->AddProcedure("transformations", std::move(procedures)); } +namespace { +bool IsAllowedExtension(const auto &extension) { + constexpr std::array allowed_extensions{".py"}; + return std::any_of(allowed_extensions.begin(), allowed_extensions.end(), + [&](const auto allowed_extension) { return allowed_extension == extension; }); +} + +bool IsSubPath(const auto &base, const auto &destination) { + const auto relative = std::filesystem::relative(destination, base); + return !relative.empty() && *relative.begin() != ".."; +} + +std::optional ReadFile(const auto &path) { + std::ifstream file(path); + if (!file.is_open()) { + return std::nullopt; + } + + const auto size = std::filesystem::file_size(path); + std::string content(size, '\0'); + file.read(content.data(), static_cast(size)); + return std::move(content); +} + +// Return the module directory that contains the `path` +utils::BasicResult ParentModuleDirectory(const ModuleRegistry &module_registry, + const std::filesystem::path &path) { + const auto &module_directories = module_registry.GetModulesDirectory(); + + auto longest_parent_directory = module_directories.end(); + auto max_length = std::numeric_limits::min(); + for (auto it = module_directories.begin(); it != module_directories.end(); ++it) { + if (IsSubPath(*it, path)) { + const auto length = std::filesystem::canonical(*it).string().size(); + if (length > max_length) { + longest_parent_directory = it; + max_length = length; + } + } + } + + if (longest_parent_directory == module_directories.end()) { + return "The specified file isn't contained in any of the module directories."; + } + + return *longest_parent_directory; +} +} // namespace + +void RegisterMgGetModuleFiles(ModuleRegistry *module_registry, BuiltinModule *module) { + auto get_module_files_cb = [module_registry](mgp_list * /*args*/, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + for (const auto &module_directory : module_registry->GetModulesDirectory()) { + for (const auto &dir_entry : std::filesystem::recursive_directory_iterator(module_directory)) { + if (dir_entry.is_regular_file() && IsAllowedExtension(dir_entry.path().extension())) { + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_string = GetPathString(dir_entry); + const auto is_editable = IsFileEditable(dir_entry); + + const auto path_value = GetStringValueOrSetError(path_string.c_str(), memory, result); + if (!path_value) { + return; + } + + MgpUniquePtr is_editable_value{nullptr, mgp_value_destroy}; + if (!TryOrSetError( + [&] { return CreateMgpObject(is_editable_value, mgp_value_make_bool, is_editable, memory); }, + result)) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + if (!InsertResultOrSetError(result, record, "is_editable", is_editable_value.get())) { + return; + } + } + } + } + }; + + mgp_proc get_module_files("get_module_files", get_module_files_cb, utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_READ}); + MG_ASSERT(mgp_proc_add_result(&get_module_files, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&get_module_files, "is_editable", Call(mgp_type_bool)) == + MGP_ERROR_NO_ERROR); + module->AddProcedure("get_module_files", std::move(get_module_files)); +} + +void RegisterMgGetModuleFile(ModuleRegistry *module_registry, BuiltinModule *module) { + auto get_module_file_cb = [module_registry](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + MG_ASSERT(Call(mgp_list_size, args) == 1U, "Should have been type checked already"); + auto *arg = Call(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } + + if (auto maybe_error_msg = ParentModuleDirectory(*module_registry, path); maybe_error_msg.HasError()) { + static_cast(mgp_result_set_error_msg(result, maybe_error_msg.GetError())); + return; + } + + const auto maybe_content = ReadFile(path); + if (!maybe_content) { + static_cast(mgp_result_set_error_msg(result, "Couldn't read the content of the file.")); + return; + } + + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto content_value = GetStringValueOrSetError(maybe_content->c_str(), memory, result); + if (!content_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "content", content_value.get())) { + return; + } + }; + mgp_proc get_module_file("get_module_file", std::move(get_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_READ}); + MG_ASSERT(mgp_proc_add_arg(&get_module_file, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&get_module_file, "content", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + module->AddProcedure("get_module_file", std::move(get_module_file)); +} + +namespace { +utils::BasicResult WriteToFile(const std::filesystem::path &file, const std::string_view content) { + std::ofstream output_file{file}; + if (!output_file.is_open()) { + return fmt::format("Failed to open the file at location {}", file); + } + output_file.write(content.data(), static_cast(content.size())); + output_file.flush(); + return {}; +} +} // namespace + +void RegisterMgCreateModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto create_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory *memory) { + MG_ASSERT(Call(mgp_list_size, args) == 2U, "Should have been type checked already"); + auto *filename_arg = Call(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, filename_arg), "Should have been type checked already"); + const char *filename_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(filename_arg, &filename_str); }, result)) { + return; + } + + const auto file_path = module_registry->InternalModuleDir() / filename_str; + + if (!IsSubPath(module_registry->InternalModuleDir(), file_path)) { + static_cast(mgp_result_set_error_msg( + result, + "Invalid relative path defined. The module file cannot be define outside the internal modules directory.")); + return; + } + + if (!IsAllowedExtension(file_path.extension())) { + static_cast(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (std::filesystem::exists(file_path)) { + static_cast(mgp_result_set_error_msg(result, "File with the same name already exists!")); + return; + } + + utils::EnsureDir(file_path.parent_path()); + + auto *content_arg = Call(mgp_list_at, args, 1); + MG_ASSERT(CallBool(mgp_value_is_string, content_arg), "Should have been type checked already"); + const char *content_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(content_arg, &content_str); }, result)) { + return; + } + + if (auto maybe_error = WriteToFile(file_path, {content_str, std::strlen(content_str)}); maybe_error.HasError()) { + static_cast(mgp_result_set_error_msg(result, maybe_error.GetError().c_str())); + return; + } + + mgp_result_record *record{nullptr}; + if (!TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; + } + + const auto path_value = GetStringValueOrSetError(std::filesystem::canonical(file_path).c_str(), memory, result); + if (!path_value) { + return; + } + + if (!InsertResultOrSetError(result, record, "path", path_value.get())) { + return; + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc create_module_file("create_module_file", std::move(create_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&create_module_file, "filename", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_arg(&create_module_file, "content", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_result(&create_module_file, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + module->AddProcedure("create_module_file", std::move(create_module_file)); +} + +void RegisterMgUpdateModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto update_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory * /*memory*/) { + MG_ASSERT(Call(mgp_list_size, args) == 2U, "Should have been type checked already"); + auto *path_arg = Call(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, path_arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(path_arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } + + if (auto maybe_error_msg = ParentModuleDirectory(*module_registry, path); maybe_error_msg.HasError()) { + static_cast(mgp_result_set_error_msg(result, maybe_error_msg.GetError())); + return; + } + + auto *content_arg = Call(mgp_list_at, args, 1); + MG_ASSERT(CallBool(mgp_value_is_string, content_arg), "Should have been type checked already"); + const char *content_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(content_arg, &content_str); }, result)) { + return; + } + + if (auto maybe_error = WriteToFile(path, {content_str, std::strlen(content_str)}); maybe_error.HasError()) { + static_cast(mgp_result_set_error_msg(result, maybe_error.GetError().c_str())); + return; + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc update_module_file("update_module_file", std::move(update_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&update_module_file, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + MG_ASSERT(mgp_proc_add_arg(&update_module_file, "content", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + module->AddProcedure("update_module_file", std::move(update_module_file)); +} + +void RegisterMgDeleteModuleFile(ModuleRegistry *module_registry, utils::RWLock *lock, BuiltinModule *module) { + auto delete_module_file_cb = [module_registry, lock](mgp_list *args, mgp_graph * /*unused*/, mgp_result *result, + mgp_memory * /*memory*/) { + MG_ASSERT(Call(mgp_list_size, args) == 1U, "Should have been type checked already"); + auto *path_arg = Call(mgp_list_at, args, 0); + MG_ASSERT(CallBool(mgp_value_is_string, path_arg), "Should have been type checked already"); + const char *path_str{nullptr}; + if (!TryOrSetError([&] { return mgp_value_get_string(path_arg, &path_str); }, result)) { + return; + } + + const std::filesystem::path path{path_str}; + + if (!path.is_absolute()) { + static_cast(mgp_result_set_error_msg(result, "The path should be an absolute path.")); + return; + } + + if (!IsAllowedExtension(path.extension())) { + static_cast(mgp_result_set_error_msg(result, "The specified file isn't in the supported format.")); + return; + } + + if (!std::filesystem::exists(path)) { + static_cast(mgp_result_set_error_msg(result, "The specified file doesn't exist.")); + return; + } + + const auto parent_module_directory = ParentModuleDirectory(*module_registry, path); + if (parent_module_directory.HasError()) { + static_cast(mgp_result_set_error_msg(result, parent_module_directory.GetError())); + return; + } + + std::error_code ec; + if (!std::filesystem::remove(path, ec)) { + static_cast( + mgp_result_set_error_msg(result, fmt::format("Failed to delete the module: {}", ec.message()).c_str())); + return; + } + + auto parent_path = path.parent_path(); + while (!std::filesystem::is_symlink(parent_path) && std::filesystem::is_empty(parent_path) && + !std::filesystem::equivalent(*parent_module_directory, parent_path)) { + std::filesystem::remove(parent_path); + parent_path = parent_path.parent_path(); + } + + WithUpgradedLock(lock, [&]() { module_registry->UnloadAndLoadModulesFromDirectories(); }); + }; + mgp_proc delete_module_file("delete_module_file", std::move(delete_module_file_cb), utils::NewDeleteResource(), + {.required_privilege = AuthQuery::Privilege::MODULE_WRITE}); + MG_ASSERT(mgp_proc_add_arg(&delete_module_file, "path", Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); + module->AddProcedure("delete_module_file", std::move(delete_module_file)); +} + // Run `fun` with `mgp_module *` and `mgp_memory *` arguments. If `fun` returned // a `true` value, store the `mgp_module::procedures` and // `mgp_module::transformations into `proc_map`. The return value of WithModuleRegistration @@ -560,6 +885,9 @@ bool PythonModule::Close() { py_module_ = py::Object(nullptr); return false; } + + // Remove the cached bytecode if it's present + std::filesystem::remove_all(file_path_.parent_path() / "__pycache__"); py_module_ = py::Object(nullptr); spdlog::info("Closed module {}", file_path_); return true; @@ -627,13 +955,24 @@ ModuleRegistry::ModuleRegistry() { RegisterMgProcedures(&modules_, module.get()); RegisterMgTransformations(&modules_, module.get()); RegisterMgLoad(this, &lock_, module.get()); + RegisterMgGetModuleFiles(this, module.get()); + RegisterMgGetModuleFile(this, module.get()); + RegisterMgCreateModuleFile(this, &lock_, module.get()); + RegisterMgUpdateModuleFile(this, &lock_, module.get()); + RegisterMgDeleteModuleFile(this, &lock_, module.get()); modules_.emplace("mg", std::move(module)); } -void ModuleRegistry::SetModulesDirectory(std::vector modules_dirs) { +void ModuleRegistry::SetModulesDirectory(std::vector modules_dirs, + const std::filesystem::path &data_directory) { + internal_module_dir_ = data_directory / "internal_modules"; + utils::EnsureDirOrDie(internal_module_dir_); modules_dirs_ = std::move(modules_dirs); + modules_dirs_.push_back(internal_module_dir_); } +const std::vector &ModuleRegistry::GetModulesDirectory() const { return modules_dirs_; } + bool ModuleRegistry::LoadModuleIfFound(const std::filesystem::path &modules_dir, const std::string_view name) { if (!utils::DirExists(modules_dir)) { spdlog::error( @@ -722,6 +1061,8 @@ bool ModuleRegistry::RegisterMgProcedure(const std::string_view name, mgp_proc p return false; } +const std::filesystem::path &ModuleRegistry::InternalModuleDir() const noexcept { return internal_module_dir_; } + namespace { /// This function returns a pair of either diff --git a/src/query/procedure/module.hpp b/src/query/procedure/module.hpp index cefdbc780..9e269687d 100644 --- a/src/query/procedure/module.hpp +++ b/src/query/procedure/module.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -87,7 +87,8 @@ class ModuleRegistry final { ModuleRegistry(); /// Set the modules directories that will be used when (re)loading modules. - void SetModulesDirectory(std::vector modules_dir); + void SetModulesDirectory(std::vector modules_dir, const std::filesystem::path &data_directory); + const std::vector &GetModulesDirectory() const; /// Atomically load or reload a module with a particular name from the given /// directory. @@ -121,8 +122,11 @@ class ModuleRegistry final { bool RegisterMgProcedure(std::string_view name, mgp_proc proc); + const std::filesystem::path &InternalModuleDir() const noexcept; + private: std::vector modules_dirs_; + std::filesystem::path internal_module_dir_; }; /// Single, global module registry. diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 3f3716c1c..b73a0204f 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -1046,12 +1046,11 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w return nullptr; } auto *memory = self->module->procedures.get_allocator().GetMemoryResource(); - mgp_proc proc( - name, - [py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) { - CallPythonProcedure(py_cb, args, graph, result, memory); - }, - memory, is_write_procedure); + mgp_proc proc(name, + [py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) { + CallPythonProcedure(py_cb, args, graph, result, memory); + }, + memory, {.is_write = is_write_procedure}); const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc)); if (!did_insert) { PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index e1e83c5ce..3e3689a3c 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -196,7 +196,7 @@ void Streams::RegisterKafkaProcedures() { it->second); }; - mgp_proc proc(proc_name, set_stream_offset, utils::NewDeleteResource(), false); + mgp_proc proc(proc_name, set_stream_offset, utils::NewDeleteResource()); MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_arg(&proc, "offset", procedure::Call(mgp_type_int)) == MGP_ERROR_NO_ERROR); @@ -226,12 +226,8 @@ void Streams::RegisterKafkaProcedures() { auto stream_source_ptr = kafka_stream.stream_source->Lock(); const auto info = stream_source_ptr->Info(kafka_stream.transformation_name); mgp_result_record *record{nullptr}; - { - const auto success = - procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result); - if (!success) { - return; - } + if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; } const auto consumer_group_value = @@ -241,15 +237,13 @@ void Streams::RegisterKafkaProcedures() { } procedure::MgpUniquePtr topic_names{nullptr, mgp_list_destroy}; - { - const auto success = procedure::TryOrSetError( - [&] { - return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), memory); - }, - result); - if (!success) { - return; - } + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), + memory); + }, + result)) { + return; } for (const auto &topic : info.topics) { @@ -261,15 +255,14 @@ void Streams::RegisterKafkaProcedures() { } procedure::MgpUniquePtr topics_value{nullptr, mgp_value_destroy}; - { - const auto success = procedure::TryOrSetError( - [&] { return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.get()); }, - result); - if (!success) { - return; - } - static_cast(topic_names.release()); + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.get()); + }, + result)) { + return; } + static_cast(topic_names.release()); const auto bootstrap_servers_value = procedure::GetStringValueOrSetError(info.bootstrap_servers.c_str(), memory, result); @@ -282,12 +275,9 @@ void Streams::RegisterKafkaProcedures() { -> procedure::MgpUniquePtr { procedure::MgpUniquePtr configs_value{nullptr, mgp_value_destroy}; procedure::MgpUniquePtr configs{nullptr, mgp_map_destroy}; - { - const auto success = procedure::TryOrSetError( - [&] { return procedure::CreateMgpObject(configs, mgp_map_make_empty, memory); }, result); - if (!success) { - return configs_value; - } + if (!procedure::TryOrSetError( + [&] { return procedure::CreateMgpObject(configs, mgp_map_make_empty, memory); }, result)) { + return configs_value; } for (const auto &[key, value] : configs_to_convert) { @@ -298,15 +288,12 @@ void Streams::RegisterKafkaProcedures() { configs->items.emplace(key, std::move(*value_value)); } - { - const auto success = procedure::TryOrSetError( - [&] { return procedure::CreateMgpObject(configs_value, mgp_value_make_map, configs.get()); }, - result); - if (!success) { - return configs_value; - } - static_cast(configs.release()); + if (!procedure::TryOrSetError( + [&] { return procedure::CreateMgpObject(configs_value, mgp_value_make_map, configs.get()); }, + result)) { + return configs_value; } + static_cast(configs.release()); return configs_value; }; @@ -358,7 +345,7 @@ void Streams::RegisterKafkaProcedures() { it->second); }; - mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource(), false); + mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource()); MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&proc, consumer_group_result_name.data(), @@ -395,12 +382,8 @@ void Streams::RegisterPulsarProcedures() { auto stream_source_ptr = pulsar_stream.stream_source->Lock(); const auto info = stream_source_ptr->Info(pulsar_stream.transformation_name); mgp_result_record *record{nullptr}; - { - const auto success = - procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result); - if (!success) { - return; - } + if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) { + return; } auto service_url_value = procedure::GetStringValueOrSetError(info.service_url.c_str(), memory, result); @@ -409,15 +392,13 @@ void Streams::RegisterPulsarProcedures() { } procedure::MgpUniquePtr topic_names{nullptr, mgp_list_destroy}; - { - const auto success = procedure::TryOrSetError( - [&] { - return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), memory); - }, - result); - if (!success) { - return; - } + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(), + memory); + }, + result)) { + return; } for (const auto &topic : info.topics) { @@ -429,15 +410,12 @@ void Streams::RegisterPulsarProcedures() { } procedure::MgpUniquePtr topics_value{nullptr, mgp_value_destroy}; - { - const auto success = procedure::TryOrSetError( - [&] { - return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.release()); - }, - result); - if (!success) { - return; - } + if (!procedure::TryOrSetError( + [&] { + return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.release()); + }, + result)) { + return; } if (!procedure::InsertResultOrSetError(result, record, topics_result_name.data(), topics_value.get())) { @@ -455,7 +433,7 @@ void Streams::RegisterPulsarProcedures() { it->second); }; - mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource(), false); + mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource()); MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call(mgp_type_string)) == MGP_ERROR_NO_ERROR); MG_ASSERT(mgp_proc_add_result(&proc, service_url_result_name.data(), diff --git a/src/storage/v2/delta.hpp b/src/storage/v2/delta.hpp index 6f4fe1e32..16f343fc7 100644 --- a/src/storage/v2/delta.hpp +++ b/src/storage/v2/delta.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -43,17 +43,19 @@ class PreviousPtr { public: enum class Type { + NULLPTR, DELTA, VERTEX, EDGE, }; struct Pointer { + Pointer() = default; explicit Pointer(Delta *delta) : type(Type::DELTA), delta(delta) {} explicit Pointer(Vertex *vertex) : type(Type::VERTEX), vertex(vertex) {} explicit Pointer(Edge *edge) : type(Type::EDGE), edge(edge) {} - Type type; + Type type{Type::NULLPTR}; Delta *delta{nullptr}; Vertex *vertex{nullptr}; Edge *edge{nullptr}; @@ -65,6 +67,9 @@ class PreviousPtr { Pointer Get() const { uintptr_t value = storage_.load(std::memory_order_acquire); + if (value == 0) { + return {}; + } uintptr_t type = value & kMask; if (type == kDelta) { return Pointer{reinterpret_cast(value & ~kMask)}; @@ -108,6 +113,8 @@ inline bool operator==(const PreviousPtr::Pointer &a, const PreviousPtr::Pointer return a.edge == b.edge; case PreviousPtr::Type::DELTA: return a.delta == b.delta; + case PreviousPtr::Type::NULLPTR: + return b.type == PreviousPtr::Type::NULLPTR; } } diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index f90862cdc..10655afd6 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -377,7 +377,8 @@ Storage::Storage(Config config) if (auto maybe_error = this->CreateSnapshot(); maybe_error.HasError()) { switch (maybe_error.GetError()) { case CreateSnapshotError::DisabledForReplica: - spdlog::warn(utils::MessageWithLink("Snapshots are disabled for replicas.", "https://memgr.ph/replication")); + spdlog::warn( + utils::MessageWithLink("Snapshots are disabled for replicas.", "https://memgr.ph/replication")); break; } } @@ -826,6 +827,7 @@ utils::BasicResult Storage::Accessor::Commit( // vertices. for (const auto &delta : transaction_.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) { continue; } @@ -855,6 +857,7 @@ utils::BasicResult Storage::Accessor::Commit( // to be validated/committed. for (const auto &delta : transaction_.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) { continue; } @@ -865,6 +868,7 @@ utils::BasicResult Storage::Accessor::Commit( // vertices. for (const auto &delta : transaction_.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) { continue; } @@ -1064,6 +1068,8 @@ void Storage::Accessor::Abort() { break; } case PreviousPtr::Type::DELTA: + // pointer probably couldn't be set because allocation failed + case PreviousPtr::Type::NULLPTR: break; } } @@ -1437,6 +1443,7 @@ void Storage::CollectGarbage() { guard = std::unique_lock(parent.edge->lock); break; case PreviousPtr::Type::DELTA: + case PreviousPtr::Type::NULLPTR: LOG_FATAL("Invalid database state!"); } } @@ -1449,6 +1456,9 @@ void Storage::CollectGarbage() { prev_delta->next.store(nullptr, std::memory_order_release); break; } + case PreviousPtr::Type::NULLPTR: { + LOG_FATAL("Invalid pointer!"); + } } break; } @@ -1595,6 +1605,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ }); } auto prev = delta->prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::DELTA) break; delta = prev.delta; } @@ -1617,6 +1628,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ // and modify vertex data. for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; find_and_apply_deltas(&delta, *prev.vertex, [](auto action) { switch (action) { @@ -1638,6 +1650,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ // 2. Process all Vertex deltas and store all operations that create edges. for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; find_and_apply_deltas(&delta, *prev.vertex, [](auto action) { switch (action) { @@ -1659,6 +1672,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ // 3. Process all Edge deltas and store all operations that modify edge data. for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::EDGE) continue; find_and_apply_deltas(&delta, *prev.edge, [](auto action) { switch (action) { @@ -1680,6 +1694,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ // 4. Process all Vertex deltas and store all operations that delete edges. for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; find_and_apply_deltas(&delta, *prev.vertex, [](auto action) { switch (action) { @@ -1701,6 +1716,7 @@ void Storage::AppendToWal(const Transaction &transaction, uint64_t final_commit_ // 5. Process all Vertex deltas and store all operations that delete vertices. for (const auto &delta : transaction.deltas) { auto prev = delta.prev.Get(); + MG_ASSERT(prev.type != PreviousPtr::Type::NULLPTR, "Invalid pointer!"); if (prev.type != PreviousPtr::Type::VERTEX) continue; find_and_apply_deltas(&delta, *prev.vertex, [](auto action) { switch (action) { diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt index 06e7f8baf..7660e74e1 100644 --- a/tests/e2e/CMakeLists.txt +++ b/tests/e2e/CMakeLists.txt @@ -13,5 +13,6 @@ add_subdirectory(isolation_levels) add_subdirectory(streams) add_subdirectory(temporal_types) add_subdirectory(write_procedures) +add_subdirectory(module_file_manager) copy_e2e_python_files(pytest_runner pytest_runner.sh "") diff --git a/tests/e2e/memory/memory_control.cpp b/tests/e2e/memory/memory_control.cpp index 03563923c..9cab68b09 100644 --- a/tests/e2e/memory/memory_control.cpp +++ b/tests/e2e/memory/memory_control.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -42,10 +42,11 @@ int main(int argc, char **argv) { LOG_FATAL("The test timed out"); } client->Execute(create_query); - if (!client->FetchOne()) { + try { + client->DiscardAll(); + } catch (const mg::TransientException & /*unused*/) { break; } - client->DiscardAll(); } spdlog::info("Memgraph is out of memory"); diff --git a/tests/e2e/memory/memory_limit_global_alloc_proc.cpp b/tests/e2e/memory/memory_limit_global_alloc_proc.cpp index 13b8bfa26..5f8cb7864 100644 --- a/tests/e2e/memory/memory_limit_global_alloc_proc.cpp +++ b/tests/e2e/memory/memory_limit_global_alloc_proc.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -10,8 +10,8 @@ // licenses/APL.txt. #include -#include #include +#include #include "utils/logging.hpp" #include "utils/timer.hpp" @@ -31,11 +31,20 @@ int main(int argc, char **argv) { if (!client) { LOG_FATAL("Failed to connect!"); } - bool result = client->Execute("CALL libglobal_memory_limit_proc.error() YIELD *"); - auto result1 = client->FetchAll(); - MG_ASSERT(result1 != std::nullopt && result1->size() == 0); + MG_ASSERT(client->Execute("CALL libglobal_memory_limit_proc.error() YIELD *")); + MG_ASSERT(std::invoke([&] { + try { + auto result1 = client->FetchAll(); + } catch (const mg::ClientException &e) { + MG_ASSERT(e.what() == std::string_view{"libglobal_memory_limit_proc.error: Out of memory"}, + "Invalid message received"); + return true; + } + return false; + }), + "Procedure didn't throw the expected `mg::ClientException`"); - result = client->Execute("CALL libglobal_memory_limit_proc.success() YIELD *"); + MG_ASSERT(client->Execute("CALL libglobal_memory_limit_proc.success() YIELD *")); auto result2 = client->FetchAll(); MG_ASSERT(result2 != std::nullopt && result2->size() > 0); return 0; diff --git a/tests/e2e/module_file_manager/CMakeLists.txt b/tests/e2e/module_file_manager/CMakeLists.txt new file mode 100644 index 000000000..84d8845ff --- /dev/null +++ b/tests/e2e/module_file_manager/CMakeLists.txt @@ -0,0 +1,4 @@ +find_package(gflags REQUIRED) + +add_executable(memgraph__e2e__module_file_manager module_file_manager.cpp) +target_link_libraries(memgraph__e2e__module_file_manager gflags mgclient mg-utils mg-io Threads::Threads) diff --git a/tests/e2e/module_file_manager/module_file_manager.cpp b/tests/e2e/module_file_manager/module_file_manager.cpp new file mode 100644 index 000000000..68cc0413e --- /dev/null +++ b/tests/e2e/module_file_manager/module_file_manager.cpp @@ -0,0 +1,270 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include + +#include +#include + +#include "utils/file.hpp" +#include "utils/logging.hpp" +#include "utils/timer.hpp" + +DEFINE_uint64(bolt_port, 7687, "Bolt port"); +DEFINE_uint64(timeout, 120, "Timeout seconds"); + +namespace { +auto GetClient() { + auto client = + mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); + MG_ASSERT(client, "Failed to connect!"); + + return client; +} + +std::vector GetModuleFiles(auto &client) { + MG_ASSERT(client->Execute("CALL mg.get_module_files() YIELD path")); + + const auto result_rows = client->FetchAll(); + MG_ASSERT(result_rows, "Failed to get results"); + + std::vector result; + result.reserve(result_rows->size()); + + for (const auto &row : *result_rows) { + MG_ASSERT(row.size() == 1, "Invalid result received from mg.get_module_files"); + result.emplace_back(row[0].ValueString()); + } + + return result; +} + +bool ModuleFileExists(auto &client, const auto &path) { + const auto module_files = GetModuleFiles(client); + + return std::any_of(module_files.begin(), module_files.end(), + [&](const auto &module_file) { return module_file == path; }); +} + +void AssertModuleFileExists(auto &client, const auto &path) { + MG_ASSERT(ModuleFileExists(client, path), "Module file {} is missing", path); +} + +void AssertModuleFileNotExists(auto &client, const auto &path) { + MG_ASSERT(!ModuleFileExists(client, path), "Invalid module file {} is present", path); +} + +bool ProcedureExists(auto &client, const std::string_view procedure_name, + std::optional path = std::nullopt) { + MG_ASSERT(client->Execute("CALL mg.procedures() YIELD name, path")); + + const auto result_rows = client->FetchAll(); + MG_ASSERT(result_rows, "Failed to get results for mg.procedures()"); + + return std::find_if(result_rows->begin(), result_rows->end(), [&, procedure_name](const auto &row) { + MG_ASSERT(row.size() == 2, "Invalid result received from mg.procedures()"); + if (row[0].ValueString() == procedure_name) { + if (path) { + return row[1].ValueString() == std::filesystem::canonical(*path).generic_string(); + } + return true; + } + return false; + }) != result_rows->end(); +} + +void AssertProcedureExists(auto &client, const std::string_view procedure_name, + std::optional path = std::nullopt) { + MG_ASSERT(ProcedureExists(client, procedure_name, path), "Procedure {} is missing", procedure_name); +} + +void AssertProcedureNotExists(auto &client, const std::string_view procedure_name) { + MG_ASSERT(!ProcedureExists(client, procedure_name), "Invalid procedure ('{}') is present", procedure_name); +} + +template +void AssertQueryFails(auto &client, const std::string &query, std::optional expected_message) { + spdlog::info("Asserting query '{}' fails", query); + MG_ASSERT(client->Execute(query)); + try { + client->FetchAll(); + } catch (const TException &exception) { + if (expected_message) { + MG_ASSERT(*expected_message == exception.what(), + "Exception with a different message was thrown.\n\t\tExpected: {}\n\t\tActual: {}", *expected_message, + exception.what()); + } + return; + } + + LOG_FATAL("Didn't receive expected exception"); +} + +std::string CreateModuleFileQuery(const std::string_view filename, const std::string_view content) { + return fmt::format("CALL mg.create_module_file('{}', '{}') YIELD path", filename, content); +} + +std::filesystem::path CreateModuleFile(auto &client, const std::string_view filename, const std::string_view content) { + spdlog::info("Creating module file '{}' with content:\n{}", filename, content); + MG_ASSERT(client->Execute(CreateModuleFileQuery(filename, content))); + + const auto result_row = client->FetchOne(); + MG_ASSERT(result_row && result_row->size() == 1, "Received invalid result from mg.create_module_file"); + MG_ASSERT(!client->FetchOne().has_value(), "Too many results received from mg.create_module_file"); + + return result_row->at(0).ValueString(); +} + +std::string GetModuleFileQuery(const std::filesystem::path &path) { + return fmt::format("CALL mg.get_module_file({}) YIELD content", path); +} + +std::string GetModuleFile(auto &client, const std::filesystem::path &path) { + spdlog::info("Getting content of module file '{}'", path); + MG_ASSERT(client->Execute(GetModuleFileQuery(path))); + + const auto result_row = client->FetchOne(); + MG_ASSERT(result_row && result_row->size() == 1, "Received invalid result from mg.get_module_file"); + MG_ASSERT(!client->FetchOne().has_value(), "Too many results received from mg.get_module_file"); + + return std::string{result_row->at(0).ValueString()}; +} + +std::string UpdateModuleFileQuery(const std::filesystem::path &path, const std::string_view content) { + return fmt::format("CALL mg.update_module_file({}, '{}')", path, content); +} + +void UpdateModuleFile(auto &client, const std::filesystem::path &path, const std::string_view content) { + spdlog::info("Updating module file {} with content:\n{}", path, content); + MG_ASSERT(client->Execute(UpdateModuleFileQuery(path, content))); + MG_ASSERT(client->FetchAll().has_value()); +} + +std::string DeleteModuleFileQuery(const std::filesystem::path &path) { + return fmt::format("CALL mg.delete_module_file({})", path); +} + +void DeleteModuleFile(auto &client, const std::filesystem::path &path) { + spdlog::info("Deleting module file {}", path); + MG_ASSERT(client->Execute(DeleteModuleFileQuery(path))); + MG_ASSERT(client->FetchAll().has_value()); +} + +constexpr std::string_view module_content1 = R"(import mgp + +@mgp.read_proc +def simple1(ctx: mgp.ProcCtx) -> mgp.Record(result=bool): + return mgp.Record(mutable=True))"; + +constexpr std::string_view module_content2 = R"(import mgp + +@mgp.read_proc +def simple2(ctx: mgp.ProcCtx) -> mgp.Record(result=bool): + return mgp.Record(mutable=True))"; + +} // namespace + +int main(int argc, char **argv) { + google::SetUsageMessage("Memgraph E2E Isolation Levels"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + logging::RedirectToStderr(); + + mg::Client::Init(); + auto client = GetClient(); + + AssertQueryFails(client, CreateModuleFileQuery("some.cpp", "some content"), + "mg.create_module_file: The specified file isn't in the supported format."); + + AssertQueryFails(client, CreateModuleFileQuery("../some.cpp", "some content"), + "mg.create_module_file: Invalid relative path defined. The module file cannot " + "be define outside the internal modules directory."); + + AssertProcedureNotExists(client, "some.simple1"); + const auto module_path = CreateModuleFile(client, "some.py", module_content1); + AssertQueryFails(client, CreateModuleFileQuery("some.py", "some content"), + "mg.create_module_file: File with the same name already exists!"); + + AssertProcedureExists(client, "some.simple1", module_path); + AssertModuleFileExists(client, module_path); + MG_ASSERT(GetModuleFile(client, module_path) == module_content1, + "Content received from mg.get_module_file is incorrect"); + + AssertQueryFails(client, GetModuleFileQuery("some.py"), + "mg.get_module_file: The path should be an absolute path."); + + AssertQueryFails(client, GetModuleFileQuery(module_path.parent_path() / "some.cpp"), + "mg.get_module_file: The specified file isn't in the supported format."); + + AssertQueryFails(client, GetModuleFileQuery(module_path.parent_path() / "some2.py"), + "mg.get_module_file: The specified file doesn't exist."); + + AssertQueryFails(client, UpdateModuleFileQuery("some.py", "some content"), + "mg.update_module_file: The path should be an absolute path."); + + AssertQueryFails(client, + UpdateModuleFileQuery(module_path.parent_path() / "some.cpp", "some content"), + "mg.update_module_file: The specified file isn't in the supported format."); + + AssertQueryFails(client, + UpdateModuleFileQuery(module_path.parent_path() / "some2.py", "some content"), + "mg.update_module_file: The specified file doesn't exist."); + + UpdateModuleFile(client, module_path, module_content2); + AssertProcedureNotExists(client, "some.simple1"); + AssertProcedureExists(client, "some.simple2", module_path); + AssertModuleFileExists(client, module_path); + MG_ASSERT(GetModuleFile(client, module_path) == module_content2, + "Content received from mg.get_module_file is incorrect"); + + AssertQueryFails(client, DeleteModuleFileQuery("some.py"), + "mg.delete_module_file: The path should be an absolute path."); + + AssertQueryFails(client, DeleteModuleFileQuery(module_path.parent_path() / "some.cpp"), + "mg.delete_module_file: The specified file isn't in the supported format."); + + AssertQueryFails(client, DeleteModuleFileQuery(module_path.parent_path() / "some2.py"), + "mg.delete_module_file: The specified file doesn't exist."); + + DeleteModuleFile(client, module_path); + AssertProcedureNotExists(client, "some.simple1"); + AssertProcedureNotExists(client, "some.simple2"); + AssertModuleFileNotExists(client, module_path); + + const auto non_module_directory = + std::filesystem::temp_directory_path() / "module_file_manager_e2e_non_module_directory"; + utils::EnsureDirOrDie(non_module_directory); + const auto non_module_file_path{non_module_directory / "something.py"}; + + { + std::ofstream non_module_file{non_module_file_path}; + MG_ASSERT(non_module_file.is_open(), "Failed to open {} for writing", non_module_file_path); + constexpr std::string_view content = "import mgp"; + non_module_file.write(content.data(), content.size()); + non_module_file.flush(); + } + + AssertQueryFails( + client, GetModuleFileQuery(non_module_file_path), + "mg.get_module_file: The specified file isn't contained in any of the module directories."); + + AssertQueryFails( + client, UpdateModuleFileQuery(non_module_file_path, "some content"), + "mg.update_module_file: The specified file isn't contained in any of the module directories."); + + AssertQueryFails( + client, DeleteModuleFileQuery(non_module_file_path), + "mg.delete_module_file: The specified file isn't contained in any of the module directories."); + + MG_ASSERT(std::filesystem::remove_all(non_module_directory), "Failed to cleanup directories"); + + return 0; +} diff --git a/tests/e2e/module_file_manager/workloads.yaml b/tests/e2e/module_file_manager/workloads.yaml new file mode 100644 index 000000000..b6670c6ef --- /dev/null +++ b/tests/e2e/module_file_manager/workloads.yaml @@ -0,0 +1,14 @@ +bolt_port: &bolt_port "7687" +template_cluster: &template_cluster + cluster: + main: + args: ["--bolt-port", *bolt_port, "--log-level=TRACE"] + log_file: "module-file-manager-e2e.log" + setup_queries: [] + validation_queries: [] + +workloads: + - name: "Module File Manager" + binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager" + args: ["--bolt-port", *bolt_port] + <<: *template_cluster diff --git a/tests/e2e/triggers/on_delete_triggers.cpp b/tests/e2e/triggers/on_delete_triggers.cpp index 5e8500855..bb3a2d4ef 100644 --- a/tests/e2e/triggers/on_delete_triggers.cpp +++ b/tests/e2e/triggers/on_delete_triggers.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -101,14 +101,17 @@ void DropOnDeleteTriggers(mg::Client &client, const std::unordered_set +bool FunctionThrows(const auto &function) { + try { + function(); + } catch (const TException & /*unused*/) { + return true; + } + return false; +} + int main(int argc, char **argv) { gflags::SetUsageMessage("Memgraph E2E Triggers privilege check"); gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -52,7 +62,10 @@ int main(int argc, char **argv) { "UNWIND createdVertices as createdVertex " "CREATE (n: {} {{ id: createdVertex.id }})", kTriggerPrefix, vertexLabel, vertexLabel)); - client.DiscardAll(); + const bool succeeded = !FunctionThrows([&] { client.DiscardAll(); }); + + MG_ASSERT(succeeded == should_succeed, "Unexpected outcome from creating triggers: expected {}, actual {}", + should_succeed, succeeded); const auto number_of_triggers_after = get_number_of_triggers(); if (should_succeed) { MG_ASSERT(number_of_triggers_after == number_of_triggers_before + 1); @@ -162,10 +175,12 @@ int main(int argc, char **argv) { "CREATE (n: {} {{ id: createdVertex.id }})", kTriggerPrefix, kUserWithoutCreate, kUserWithoutCreate)); client_without_create->DiscardAll(); + userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate)); userless_client->DiscardAll(); - CreateVertex(*userless_client, kVertexId); + MG_ASSERT(FunctionThrows([&] { CreateVertex(*userless_client, kVertexId); }), + "Create should have thrown because user doesn't have privilege for CREATE"); CheckNumberOfAllVertices(*userless_client, 0); return 0; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 6d1d469c8..9af86aa3a 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -211,7 +211,7 @@ class MockModule : public procedure::Module { const std::map> *Procedures() const override { return &procedures; } const std::map> *Transformations() const override { return &transformations; } - std::optional Path() const override { return std::nullopt; }; + std::optional Path() const override { return std::nullopt; } std::map> procedures{}; std::map> transformations{}; @@ -248,7 +248,7 @@ class CypherMainVisitorTest : public ::testing::TestWithParam &results, const ProcedureType type) { utils::MemoryResource *memory = utils::NewDeleteResource(); const bool is_write = type == ProcedureType::WRITE; - mgp_proc proc(name, DummyProcCallback, memory, is_write); + mgp_proc proc(name, DummyProcCallback, memory, {.is_write = is_write}); for (const auto arg : args) { proc.args.emplace_back(utils::pmr::string{arg, memory}, &any_type); } diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index f28c972aa..4ead307fd 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -451,6 +451,14 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate return merge; } +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector arguments = {}) { + auto *call_procedure = storage.Create(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + } // namespace test_common } // namespace query @@ -558,3 +566,4 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate #define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ storage.Create((action), (user), (role), (user_or_role), password, (privileges)) #define DROP_USER(usernames) storage.Create((usernames)) +#define CALL_PROCEDURE(...) query::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_procedure_mgp_module.cpp b/tests/unit/query_procedure_mgp_module.cpp index f5d047628..d7fb13f67 100644 --- a/tests/unit/query_procedure_mgp_module.cpp +++ b/tests/unit/query_procedure_mgp_module.cpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2022 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -125,12 +125,12 @@ TEST(Module, ProcedureSignatureOnlyOptArg) { TEST(Module, ReadWriteProcedures) { mgp_module module(utils::NewDeleteResource()); auto *read_proc = EXPECT_MGP_NO_ERROR(mgp_proc *, mgp_module_add_read_procedure, &module, "read", &DummyCallback); - EXPECT_FALSE(read_proc->is_write_procedure); + EXPECT_FALSE(read_proc->info.is_write); auto *write_proc = EXPECT_MGP_NO_ERROR(mgp_proc *, mgp_module_add_write_procedure, &module, "write", &DummyCallback); - EXPECT_TRUE(write_proc->is_write_procedure); + EXPECT_TRUE(write_proc->info.is_write); mgp_proc read_proc_with_function{"dummy_name", std::function{ [](mgp_list *, mgp_graph *, mgp_result *, mgp_memory *) {}}, - utils::NewDeleteResource(), false}; - EXPECT_FALSE(read_proc_with_function.is_write_procedure); + utils::NewDeleteResource()}; + EXPECT_FALSE(read_proc_with_function.info.is_write); } diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp index aea519ce4..e6dc3c6ad 100644 --- a/tests/unit/query_required_privileges.cpp +++ b/tests/unit/query_required_privileges.cpp @@ -191,3 +191,27 @@ TEST_F(TestPrivilegeExtractor, ShowVersion) { auto *query = storage.Create(); EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STATS)); } + +TEST_F(TestPrivilegeExtractor, CallProcedureQuery) { + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_files"))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.create_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY( + SINGLE_QUERY(CALL_PROCEDURE("mg.update_module_file", {LITERAL("some_name.py"), LITERAL("some content")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.get_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_READ)); + } + { + auto *query = QUERY(SINGLE_QUERY(CALL_PROCEDURE("mg.delete_module_file", {LITERAL("some_name.py")}))); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::MODULE_WRITE)); + } +}