diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 50f066057..8737b0ed1 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -3735,34 +3735,6 @@ std::vector CallProcedure::ModifiedSymbols( namespace { -// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving -// `fully_qualified_procedure_name`. `memory` is used for temporary allocations -// inside this function. ModulePtr must be kept alive to make sure it won't be -// unloaded. -// @throw QueryRuntimeException if unable to find the procedure. -std::pair FindProcedureOrThrow( - const std::string_view &fully_qualified_procedure_name, - utils::MemoryResource *memory) { - utils::pmr::vector name_parts(memory); - utils::Split(&name_parts, fully_qualified_procedure_name, "."); - if (name_parts.size() == 1U) { - throw QueryRuntimeException("There's no top-level procedure '{}'", - fully_qualified_procedure_name); - } - auto last_dot_pos = fully_qualified_procedure_name.find_last_of('.'); - CHECK(last_dot_pos != std::string_view::npos); - const auto &module_name = - fully_qualified_procedure_name.substr(0, last_dot_pos); - const auto &proc_name = name_parts.back(); - auto module = procedure::gModuleRegistry.GetModuleNamed(module_name); - if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name); - const auto &proc_it = module->procedures.find(proc_name); - if (proc_it == module->procedures.end()) - throw QueryRuntimeException("'{}' does not have a procedure named '{}'", - module_name, proc_name); - return {std::move(module), &proc_it->second}; -} - void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name, const mgp_proc &proc, const std::vector &args, @@ -3880,8 +3852,14 @@ class CallProcedureCursor : public Cursor { // it's not possible for a single thread to request multiple read locks. // Builtin module registration in query/procedure/module.cpp depends on // this locking scheme. - const auto &[module, proc] = FindProcedureOrThrow( - self_->procedure_name_, context.evaluation_context.memory); + const auto &maybe_found = procedure::FindProcedure( + procedure::gModuleRegistry, self_->procedure_name_, + context.evaluation_context.memory); + if (!maybe_found) { + throw QueryRuntimeException("There is no procedure named '{}'.", + self_->procedure_name_); + } + const auto &[module, proc] = *maybe_found; result_.signature = &proc->results; // Use evaluation memory, as invoking a procedure is akin to a simple // evaluation of an expression. diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index 12ad1b09f..8ee9853d1 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -6,6 +6,9 @@ extern "C" { #include +#include "utils/pmr/vector.hpp" +#include "utils/string.hpp" + namespace query::procedure { ModuleRegistry gModuleRegistry; @@ -213,7 +216,7 @@ bool ModuleRegistry::LoadModuleLibrary(std::filesystem::path path) { return true; } -ModulePtr ModuleRegistry::GetModuleNamed(const std::string_view &name) { +ModulePtr ModuleRegistry::GetModuleNamed(const std::string_view &name) const { std::shared_lock guard(lock_); auto found_it = modules_.find(name); if (found_it == modules_.end()) return nullptr; @@ -270,4 +273,23 @@ void ModuleRegistry::UnloadAllModules() { modules_.clear(); } +std::optional> FindProcedure( + const ModuleRegistry &module_registry, + const std::string_view &fully_qualified_procedure_name, + utils::MemoryResource *memory) { + utils::pmr::vector name_parts(memory); + utils::Split(&name_parts, fully_qualified_procedure_name, "."); + if (name_parts.size() == 1U) return std::nullopt; + auto last_dot_pos = fully_qualified_procedure_name.find_last_of('.'); + CHECK(last_dot_pos != std::string_view::npos); + const auto &module_name = + fully_qualified_procedure_name.substr(0, last_dot_pos); + const auto &proc_name = name_parts.back(); + auto module = module_registry.GetModuleNamed(module_name); + if (!module) return std::nullopt; + const auto &proc_it = module->procedures.find(proc_name); + if (proc_it == module->procedures.end()) return std::nullopt; + return std::make_pair(std::move(module), &proc_it->second); +} + } // namespace query::procedure diff --git a/src/query/procedure/module.hpp b/src/query/procedure/module.hpp index d4ceb1cfc..a8de87fa1 100644 --- a/src/query/procedure/module.hpp +++ b/src/query/procedure/module.hpp @@ -4,12 +4,14 @@ #include #include +#include #include #include #include #include #include "query/procedure/mg_procedure_impl.hpp" +#include "utils/memory.hpp" #include "utils/rw_lock.hpp" namespace query::procedure { @@ -47,7 +49,7 @@ class ModulePtr final { /// Thread-safe registration of modules from libraries, uses utils::RWLock. class ModuleRegistry final { std::map> modules_; - utils::RWLock lock_{utils::RWLock::Priority::WRITE}; + mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE}; public: ModuleRegistry(); @@ -63,7 +65,7 @@ class ModuleRegistry final { /// Find a module with given name or return nullptr. /// Takes a read lock. - ModulePtr GetModuleNamed(const std::string_view &name); + ModulePtr GetModuleNamed(const std::string_view &name) const; /// Reload a module with given name and return true if successful. /// Takes a write lock. Builtin modules cannot be reloaded, though true will @@ -85,4 +87,13 @@ class ModuleRegistry final { /// Single, global module registry. extern ModuleRegistry gModuleRegistry; +/// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving +/// `fully_qualified_procedure_name`. `memory` is used for temporary allocations +/// inside this function. ModulePtr must be kept alive to make sure it won't be +/// unloaded. +std::optional> FindProcedure( + const ModuleRegistry &module_registry, + const std::string_view &fully_qualified_procedure_name, + utils::MemoryResource *memory); + } // namespace query::procedure