diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4fa35d7ac..875c3def0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -63,6 +63,7 @@ set(mg_single_node_sources query/plan/rewrite/index_lookup.cpp query/plan/rule_based_planner.cpp query/plan/variable_start_planner.cpp + query/plugin/plugin.cpp query/repl.cpp query/typed_value.cpp storage/common/constraints/record.cpp @@ -133,6 +134,7 @@ set(mg_single_node_v2_sources query/plan/rewrite/index_lookup.cpp query/plan/rule_based_planner.cpp query/plan/variable_start_planner.cpp + query/plugin/plugin.cpp query/typed_value.cpp memgraph_init.cpp ) diff --git a/src/query/plugin/plugin.cpp b/src/query/plugin/plugin.cpp new file mode 100644 index 000000000..0462012f4 --- /dev/null +++ b/src/query/plugin/plugin.cpp @@ -0,0 +1,142 @@ +#include "query/plugin/plugin.hpp" + +extern "C" { +#include +} + +#include + +namespace query::plugin { + +PluginRegistry gPluginRegistry; + +namespace { + +std::optional LoadPluginFromSharedLibrary(std::filesystem::path path) { + LOG(INFO) << "Loading plugin " << path << " ..."; + Plugin plugin{path}; + dlerror(); // Clear any existing error. + plugin.handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!plugin.handle) { + LOG(ERROR) << "Unable to load plugin " << path << "; " << dlerror(); + return std::nullopt; + } + // Get required mg_main + plugin.main_fn = + reinterpret_cast(dlsym(plugin.handle, "mg_main")); + const char *error = dlerror(); + if (!plugin.main_fn || error) { + LOG(ERROR) << "Unable to load plugin " << path << "; " << error; + dlclose(plugin.handle); + return std::nullopt; + } + // Get optional mg_init_module + plugin.init_fn = + reinterpret_cast(dlsym(plugin.handle, "mg_init_module")); + error = dlerror(); + if (error) LOG(WARNING) << "When loading plugin " << path << "; " << error; + // Run mg_init_module which must succeed. + if (plugin.init_fn) { + int init_res = plugin.init_fn(); + if (init_res != 0) { + LOG(ERROR) << "Unable to load plugin " << path + << "; mg_init_module returned " << init_res; + dlclose(plugin.handle); + return std::nullopt; + } + } + // Get optional mg_shutdown_module + plugin.shutdown_fn = + reinterpret_cast(dlsym(plugin.handle, "mg_shutdown_module")); + error = dlerror(); + if (error) LOG(WARNING) << "When loading plugin " << path << "; " << error; + LOG(INFO) << "Loaded plugin " << path; + return plugin; +} + +bool ClosePlugin(Plugin *plugin) { + LOG(INFO) << "Closing plugin " << plugin->file_path << " ..."; + if (plugin->shutdown_fn) { + int shutdown_res = plugin->shutdown_fn(); + if (shutdown_res != 0) { + LOG(WARNING) << "When closing plugin " << plugin->file_path + << "; mg_shutdown_module returned " << shutdown_res; + } + } + if (dlclose(plugin->handle) != 0) { + LOG(ERROR) << "Failed to close plugin " << plugin->file_path << "; " + << dlerror(); + return false; + } + LOG(INFO) << "Closed plugin " << plugin->file_path; + return true; +} + +} // namespace + +bool PluginRegistry::LoadPluginLibrary(std::filesystem::path path) { + std::unique_lock guard(lock_); + std::string plugin_name(path.stem()); + if (plugins_.find(plugin_name) != plugins_.end()) return true; + auto maybe_plugin = LoadPluginFromSharedLibrary(path); + if (!maybe_plugin) return false; + plugins_[plugin_name] = std::move(*maybe_plugin); + return true; +} + +PluginPtr PluginRegistry::GetPluginNamed(const std::string_view &name) { + std::shared_lock guard(lock_); + // NOTE: std::unordered_map::find cannot work with std::string_view :( + auto found_it = plugins_.find(std::string(name)); + if (found_it == plugins_.end()) return nullptr; + return PluginPtr(&found_it->second, std::move(guard)); +} + +bool PluginRegistry::ReloadPluginNamed(const std::string_view &name) { + std::unique_lock guard(lock_); + // NOTE: std::unordered_map::find cannot work with std::string_view :( + auto found_it = plugins_.find(std::string(name)); + if (found_it == plugins_.end()) { + LOG(ERROR) << "Trying to reload plugin '" << name + << "' which is not loaded."; + return false; + } + auto &plugin = found_it->second; + if (!ClosePlugin(&plugin)) { + plugins_.erase(found_it); + return false; + } + auto maybe_plugin = LoadPluginFromSharedLibrary(plugin.file_path); + if (!maybe_plugin) { + plugins_.erase(found_it); + return false; + } + plugin = std::move(*maybe_plugin); + return true; +} + +bool PluginRegistry::ReloadAllPlugins() { + std::unique_lock guard(lock_); + for (auto &[name, plugin] : plugins_) { + if (!ClosePlugin(&plugin)) { + plugins_.erase(name); + return false; + } + auto maybe_plugin = LoadPluginFromSharedLibrary(plugin.file_path); + if (!maybe_plugin) { + plugins_.erase(name); + return false; + } + plugin = std::move(*maybe_plugin); + } + return true; +} + +void PluginRegistry::UnloadAllPlugins() { + std::unique_lock guard(lock_); + for (auto &name_and_plugin : plugins_) ClosePlugin(&name_and_plugin.second); + plugins_.clear(); +} + +} // namespace query::plugin diff --git a/src/query/plugin/plugin.hpp b/src/query/plugin/plugin.hpp new file mode 100644 index 000000000..5e4643846 --- /dev/null +++ b/src/query/plugin/plugin.hpp @@ -0,0 +1,89 @@ +/// @file API for loading and registering plugins providing custom oC procedures +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "utils/rw_lock.hpp" + +struct mg_value; +struct mg_graph; +struct mg_result; + +namespace query::plugin { + +struct Plugin final { + /// Path as requested for loading the plugin from a library. + std::filesystem::path file_path; + /// System handle to shared library. + void *handle; + /// Entry-point for plugin's custom procedure. + std::function + main_fn; + /// Optional initialization function called on plugin load. + std::function init_fn; + /// Optional shutdown function called on plugin unload. + std::function shutdown_fn; +}; + + +/// Proxy for a registered Plugin, acquires a read lock from PluginRegistry. +class PluginPtr final { + const Plugin *plugin_{nullptr}; + std::shared_lock lock_; + + public: + PluginPtr() = default; + PluginPtr(std::nullptr_t) {} + PluginPtr(const Plugin *plugin, std::shared_lock lock) + : plugin_(plugin), lock_(std::move(lock)) {} + + explicit operator bool() const { return static_cast(plugin_); } + + const Plugin &operator*() const { return *plugin_; } + const Plugin *operator->() const { return plugin_; } +}; + +/// Thread-safe registration of plugins from libraries, uses utils::RWLock. +class PluginRegistry final { + std::unordered_map plugins_; + utils::RWLock lock_{utils::RWLock::Priority::WRITE}; + + public: + /// Load a plugin from the given path and return true if successful. + /// + /// A write lock is taken during the execution of this method. Loading a + /// plugin is done through `dlopen` facility and path is resolved accordingly. + /// The plugin is registered using the filename part of the path, with the + /// extension removed. If a plugin with the same name already exists, the + /// function does nothing. + bool LoadPluginLibrary(std::filesystem::path path); + + /// Find a plugin with given name or return nullptr. + /// Takes a read lock. + PluginPtr GetPluginNamed(const std::string_view &name); + + /// Reload a plugin with given name and return true if successful. + /// Takes a write lock. If false was returned, then the plugin is no longer + /// registered. + bool ReloadPluginNamed(const std::string_view &name); + + /// Reload all loaded plugins and return true if successful. + /// Takes a write lock. If false was returned, the plugin which failed to + /// reload is no longer registered. Remaining plugins may or may not be + /// reloaded, but are valid and registered. + bool ReloadAllPlugins(); + + /// Remove all loaded plugins. + /// Takes a write lock. + void UnloadAllPlugins(); +}; + +/// Single, global plugin registry. +extern PluginRegistry gPluginRegistry; + +} // namespace query::plugin