diff --git a/src/memgraph.cpp b/src/memgraph.cpp
index 63c6ffbb4..adf10ec34 100644
--- a/src/memgraph.cpp
+++ b/src/memgraph.cpp
@@ -577,7 +577,9 @@ void SingleNodeMain() {
   if (!FLAGS_query_modules_directory.empty()) {
     for (const auto &entry :
          std::filesystem::directory_iterator(FLAGS_query_modules_directory)) {
-      if (entry.is_regular_file() && entry.path().extension() == ".so")
+      // No extension filter here; ModuleRegistry::LoadModuleLibrary picks the
+      // loader by extension and rejects files it does not recognize.
+      if (entry.is_regular_file())
         query::procedure::gModuleRegistry.LoadModuleLibrary(entry.path());
     }
   }
diff --git a/src/memgraph_init.cpp b/src/memgraph_init.cpp
index 0fb5d4712..28006833b 100644
--- a/src/memgraph_init.cpp
+++ b/src/memgraph_init.cpp
@@ -4,6 +4,7 @@
 
 #include "config.hpp"
 #include "glue/communication.hpp"
+#include "py/py.hpp"
 #include "query/exceptions.hpp"
 #include "requests/requests.hpp"
 #include "storage/v2/view.hpp"
@@ -223,6 +224,17 @@ int WithInit(int argc, char **argv,
 
   // Unhandled exception handler init.
   std::set_terminate(&utils::TerminateHandler);
+  // Initialize Python
+  auto *program_name = Py_DecodeLocale(argv[0], nullptr);
+  CHECK(program_name);
+  // Set program name, so Python can find its way to runtime libraries relative
+  // to executable.
+  Py_SetProgramName(program_name);
+  Py_InitializeEx(0 /* = initsigs */);
+  PyEval_InitThreads();
+  // Release the GIL for the rest of startup and for memgraph_main(); code that
+  // calls into Python re-acquires it on demand through py::EnsureGIL.
+  Py_BEGIN_ALLOW_THREADS;
 
   // Initialize the communication library.
   communication::Init();
@@ -247,5 +259,12 @@ int WithInit(int argc, char **argv,
   requests::Init();
 
   memgraph_main();
+  // Re-acquire the GIL released by Py_BEGIN_ALLOW_THREADS above.
+  Py_END_ALLOW_THREADS;
+  // Shutdown Python
+  Py_Finalize();
+  // Free the name only after finalization, the interpreter may still refer to
+  // it until then.
+  PyMem_RawFree(program_name);
   return 0;
 }
diff --git a/src/py/py.hpp b/src/py/py.hpp
new file mode 100644
index 000000000..04cf7cb1b
--- /dev/null
+++ b/src/py/py.hpp
@@ -0,0 +1,209 @@
+/// @file
+/// Provides a C++ API for working with Python's original C API.
+#pragma once
+
+#include <optional>
+#include <ostream>
+#include <string_view>
+
+// Define to use Py_ssize_t for API returning length of something. Some future
+// Python version will only support Py_ssize_t, so it's best to always define
+// this macro before including Python.h.
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+#if PY_MAJOR_VERSION != 3 || PY_MINOR_VERSION < 5
+#error "Minimum supported Python API is 3.5"
+#endif
+
+namespace py {
+
+/// Ensure the current thread is ready to call Python C API.
+///
+/// You must *not* try to ensure the GIL when the runtime is finalizing, as
+/// that will terminate the thread. You may use `_Py_IsFinalizing` or
+/// `sys.is_finalizing()` to check for such a case.
+class EnsureGIL final {
+  PyGILState_STATE gil_state_;
+
+ public:
+  EnsureGIL() : gil_state_(PyGILState_Ensure()) {}
+  ~EnsureGIL() { PyGILState_Release(gil_state_); }
+  EnsureGIL(const EnsureGIL &) = delete;
+  EnsureGIL(EnsureGIL &&) = delete;
+  EnsureGIL &operator=(const EnsureGIL &) = delete;
+  EnsureGIL &operator=(EnsureGIL &&) = delete;
+};
+
+/// Owns a `PyObject *` and supports a more C++ idiomatic API to objects.
+///
+/// Constructing from a raw pointer takes over one reference without adding a
+/// new one; copying adds a reference and destruction releases one.
+class Object final {
+  PyObject *ptr_{nullptr};
+
+ public:
+  Object() = default;
+  explicit Object(PyObject *ptr) noexcept : ptr_(ptr) {}
+
+  ~Object() noexcept { Py_XDECREF(ptr_); }
+
+  Object(const Object &other) noexcept : ptr_(other.ptr_) { Py_XINCREF(ptr_); }
+
+  Object(Object &&other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }
+
+  Object &operator=(const Object &other) noexcept {
+    if (this == &other) return *this;
+    // Release the old reference last: dropping it may run arbitrary Python
+    // code, which must never observe a dangling `ptr_`.
+    auto *old_ptr = ptr_;
+    ptr_ = other.ptr_;
+    Py_XINCREF(ptr_);
+    Py_XDECREF(old_ptr);
+    return *this;
+  }
+
+  Object &operator=(Object &&other) noexcept {
+    if (this == &other) return *this;
+    // Same ordering as in copy assignment, release the old reference last.
+    auto *old_ptr = ptr_;
+    ptr_ = other.ptr_;
+    other.ptr_ = nullptr;
+    Py_XDECREF(old_ptr);
+    return *this;
+  }
+
+  /// Borrow the underlying pointer; ownership stays with this Object.
+  operator PyObject *() const { return ptr_; }
+
+  /// True if this Object holds a (non-null) Python object.
+  operator bool() const { return ptr_; }
+
+  /// Equivalent to `str(o)` in Python.
+  ///
+  /// Returned Object is nullptr if an error occurred.
+  /// @sa FetchError
+  Object Str() const { return Object(PyObject_Str(ptr_)); }
+
+  /// Equivalent to `callable()` in Python.
+  ///
+  /// Returned Object is nullptr if an error occurred.
+  /// @sa FetchError
+  Object Call() const { return Object(PyObject_CallObject(ptr_, nullptr)); }
+
+  /// Equivalent to `callable(*args)` in Python.
+  ///
+  /// The arguments are passed as a null-terminated list, so a null argument
+  /// silently cuts off itself and everything after it; pass `Py_None` instead.
+  ///
+  /// Returned Object is nullptr if an error occurred.
+  /// @sa FetchError
+  template <class... TArgs>
+  Object Call(const TArgs &... args) const {
+    return Object(PyObject_CallFunctionObjArgs(
+        ptr_, static_cast<PyObject *>(args)..., nullptr));
+  }
+
+  /// Equivalent to `obj.meth_name()` in Python.
+  ///
+  /// Returned Object is nullptr if an error occurred.
+  /// @sa FetchError
+  Object CallMethod(std::string_view meth_name) const {
+    Object name(
+        PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
+    return Object(PyObject_CallMethodObjArgs(ptr_, name, nullptr));
+  }
+
+  /// Equivalent to `obj.meth_name(*args)` in Python.
+  ///
+  /// As with Call, a null argument ends the argument list early.
+  ///
+  /// Returned Object is nullptr if an error occurred.
+  /// @sa FetchError
+  template <class... TArgs>
+  Object CallMethod(std::string_view meth_name, const TArgs &... args) const {
+    Object name(
+        PyUnicode_FromStringAndSize(meth_name.data(), meth_name.size()));
+    return Object(PyObject_CallMethodObjArgs(
+        ptr_, name, static_cast<PyObject *>(args)..., nullptr));
+  }
+};
+
+/// Write Object to stream as if `str(o)` was called in Python.
+///
+/// If the object cannot be converted, the Python error raised by the
+/// conversion is cleared and a placeholder is written instead.
+inline std::ostream &operator<<(std::ostream &os, const Object &py_object) {
+  auto py_str = py_object.Str();
+  const char *utf8 = py_str ? PyUnicode_AsUTF8(py_str) : nullptr;
+  if (!utf8) {
+    // Streaming a null `const char *` is undefined behavior.
+    PyErr_Clear();
+    return os << "<unprintable Python object>";
+  }
+  return os << utf8;
+}
+
+/// Stores information on a raised Python exception.
+/// @sa FetchError
+struct ExceptionInfo final {
+  /// Type of the exception, if nullptr there is no exception.
+  Object type;
+  /// Optional value of the exception.
+  Object value;
+  /// Optional traceback of the exception.
+  Object traceback;
+};
+
+/// Write ExceptionInfo to stream just like the Python interpreter would.
+inline std::ostream &operator<<(std::ostream &os,
+                                const ExceptionInfo &exc_info) {
+  if (!exc_info.type) return os;
+  Object traceback_mod(PyImport_ImportModule("traceback"));
+  CHECK(traceback_mod);
+  Object format_exception_fn(
+      PyObject_GetAttrString(traceback_mod, "format_exception"));
+  CHECK(format_exception_fn);
+  // Object::Call stops at the first null argument, so the optional members
+  // have to be passed as `None` when they are missing.
+  PyObject *value =
+      exc_info.value ? static_cast<PyObject *>(exc_info.value) : Py_None;
+  PyObject *traceback =
+      exc_info.traceback ? static_cast<PyObject *>(exc_info.traceback)
+                         : Py_None;
+  auto list = format_exception_fn.Call(exc_info.type, value, traceback);
+  CHECK(list);
+  auto len = PyList_GET_SIZE(static_cast<PyObject *>(list));
+  for (Py_ssize_t i = 0; i < len; ++i) {
+    auto *py_str = PyList_GET_ITEM(static_cast<PyObject *>(list), i);
+    const char *utf8 = PyUnicode_AsUTF8(py_str);
+    if (!utf8) {
+      // Never stream a null pointer; skip a line that cannot be encoded.
+      PyErr_Clear();
+      continue;
+    }
+    os << utf8;
+  }
+  return os;
+}
+
+/// Get the current exception info and clear the current exception indicator.
+///
+/// This is normally used to catch and handle exceptions via C API.
+///
+/// @return `std::nullopt` if no exception is set, otherwise the normalized
+///   exception.
+inline std::optional<ExceptionInfo> FetchError() {
+  PyObject *exc_type = nullptr;
+  PyObject *exc_value = nullptr;
+  PyObject *traceback = nullptr;
+  PyErr_Fetch(&exc_type, &exc_value, &traceback);
+  if (!exc_type) return std::nullopt;
+  // PyErr_Fetch may hand back an "unnormalized" exception, where the value is
+  // not an instance of the type. Normalize it so that the value can be given
+  // to Python code such as `traceback.format_exception`.
+  PyErr_NormalizeException(&exc_type, &exc_value, &traceback);
+  return ExceptionInfo{Object(exc_type), Object(exc_value), Object(traceback)};
+}
+
+}  // namespace py
diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt
index 90bba7824..f7023af74 100644
--- a/src/query/CMakeLists.txt
+++ b/src/query/CMakeLists.txt
@@ -34,6 +34,8 @@ add_dependencies(mg-query generate_lcp_query)
 target_include_directories(mg-query PRIVATE ${CMAKE_SOURCE_DIR}/include)
 target_link_libraries(mg-query dl cppitertools)
 target_link_libraries(mg-query mg-storage-v2)
+find_package(Python3 3.5 REQUIRED COMPONENTS Development)
+target_link_libraries(mg-query Python3::Python)
 
 # Generate Antlr openCypher parser
 
diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp
index 8ee9853d1..1754d922a 100644
--- a/src/query/procedure/module.cpp
+++ b/src/query/procedure/module.cpp
@@ -6,6 +6,7 @@ extern "C" {
 
 #include <optional>
 
+#include "py/py.hpp"
 #include "utils/pmr/vector.hpp"
 #include "utils/string.hpp"
 
@@ -15,6 +16,57 @@ ModuleRegistry gModuleRegistry;
 
 namespace {
 
+/// Logs `message`, `path` and the pending Python exception (if there is one).
+///
+/// Fetching the exception also clears Python's error indicator, so the failure
+/// does not leak into the next Python C API call.
+void LogPythonError(const char *message, const std::filesystem::path &path) {
+  auto maybe_exc_info = py::FetchError();
+  if (maybe_exc_info) {
+    LOG(ERROR) << message << path << "; " << *maybe_exc_info;
+  } else {
+    LOG(ERROR) << message << path;
+  }
+}
+
+/// Imports the Python file at `path` as a Python module.
+///
+/// The directory of `path` is appended to `sys.path` unless it is already
+/// there, and the module is then imported by the stem of `path`. Every failure
+/// is logged along with the Python exception that caused it.
+///
+/// @return always `std::nullopt` for now; creating a Module from the imported
+///   Python module is not implemented yet (see the TODO below).
+std::optional<Module> LoadModuleFromPythonFile(std::filesystem::path path) {
+  LOG(INFO) << "Loading module " << path << " ...";
+  py::EnsureGIL gil;
+  // Borrowed reference, must not be wrapped in an owning py::Object.
+  auto *py_path = PySys_GetObject("path");
+  CHECK(py_path);
+  py::Object import_dir(PyUnicode_FromString(path.parent_path().c_str()));
+  if (!import_dir) {
+    // E.g. the directory name is not valid UTF-8.
+    LogPythonError("Unable to load module ", path);
+    return std::nullopt;
+  }
+  int import_dir_in_path = PySequence_Contains(py_path, import_dir);
+  if (import_dir_in_path == -1) {
+    LogPythonError("Unexpected error when loading module ", path);
+    return std::nullopt;
+  }
+  if (import_dir_in_path == 0 && PyList_Append(py_path, import_dir) != 0) {
+    LogPythonError("Unable to load module ", path);
+    return std::nullopt;
+  }
+  py::Object py_module(PyImport_ImportModule(path.stem().c_str()));
+  if (!py_module) {
+    LogPythonError("Unable to load module ", path);
+    return std::nullopt;
+  }
+  // TODO: Actually create a module
+  return std::nullopt;
+}
+
 std::optional<Module> LoadModuleFromSharedLibrary(std::filesystem::path path) {
   LOG(INFO) << "Loading module " << path << " ...";
   Module module{path};
@@ -210,7 +262,16 @@ bool ModuleRegistry::LoadModuleLibrary(std::filesystem::path path) {
     LOG(ERROR) << "Unable to overwrite an already loaded module " << path;
     return false;
   }
-  auto maybe_module = LoadModuleFromSharedLibrary(path);
+  // Pick the loader by file extension; other files are not query modules.
+  std::optional<Module> maybe_module;
+  if (path.extension() == ".so") {
+    maybe_module = LoadModuleFromSharedLibrary(path);
+  } else if (path.extension() == ".py") {
+    maybe_module = LoadModuleFromPythonFile(path);
+  } else {
+    LOG(ERROR) << "Unknown query module file " << path;
+    return false;
+  }
   if (!maybe_module) return false;
   modules_[module_name] = std::move(*maybe_module);
   return true;