Register Python procedures

Summary:
With this diff you should now be able to register `example.py` and read
procedures found there. The procedures will be listed through `CALL
mg.procedures() YIELD *` query, but invoking them will raise
`NotYetImplemented`.

If you wish to test this, you will need to run the Memgraph executable
with PYTHONPATH set to the `include` directory where `mgp.py` is found.
Additionally, you need to pass `--query-modules-directory` flag to
Memgraph, such that it points to where it will find the `example.py`.

For example, when running from the root directory of Memgraph repo, the
shell invocation below should do the trick (assuming `./build/memgraph`
is where is the executable). Make sure that `./query_modules/` does not
have `example.so` built, as that may interfere with loading
`example.py`.

    PYTHONPATH=$PWD/include ./build/memgraph --query-modules-directory=./query_modules/

Reviewers: mferencevic, ipaljak, llugovic

Reviewed By: mferencevic, ipaljak, llugovic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2678
This commit is contained in:
Teon Banek 2020-02-19 10:39:18 +01:00
parent 47ce444c02
commit 34db077cbd
5 changed files with 458 additions and 132 deletions

View File

@ -15,6 +15,9 @@ This module provides the API for usage in custom openCypher procedures.
# 3.5, but variable type annotations are only available with Python 3.6+
from collections import namedtuple
import functools
import inspect
import sys
import typing
import _mgp
@ -236,10 +239,11 @@ class Path:
class Record:
'''Represents a record of resulting field values.'''
__slots__ = ('fields',)
def __init__(self, **kwargs):
'''Initialize with name=value fields in kwargs.'''
pass
self.fields = kwargs
class InvalidProcCtxError(Exception):
@ -346,8 +350,12 @@ Nullable = typing.Optional
class Deprecated:
'''Annotate a resulting Record's field as deprecated.'''
__slots__ = ('field_type',)
def __init__(self, type_):
pass
if not isinstance(type_, type):
raise TypeError("Expected 'type', got '{}'".format(type_))
self.field_type = type_
def read_proc(func: typing.Callable[..., Record]):
@ -390,4 +398,47 @@ def read_proc(func: typing.Callable[..., Record]):
CALL example.procedure(1) YIELD args, result;
Naturally, you may pass in different arguments or yield less fields.
'''
pass
if not callable(func):
raise TypeError("Expected a callable object, got an instance of '{}'"
.format(type(func)))
if inspect.iscoroutinefunction(func):
raise TypeError("Callable must not be 'async def' function")
if sys.version_info.minor >= 6:
if inspect.isasyncgenfunction(func):
raise TypeError("Callable must not be 'async def' function")
if inspect.isgeneratorfunction(func):
raise NotImplementedError("Generator functions are not supported")
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is ProcCtx:
params = params[1:]
mgp_proc = _mgp._MODULE.add_read_procedure(func)
else:
@functools.wraps(func)
def wrapper(*args):
args_without_context = args[1:]
return func(*args_without_context)
mgp_proc = _mgp._MODULE.add_read_procedure(wrapper)
for param in params:
name = param.name
type_ = param.annotation
# TODO: Convert type_ to _mgp.CypherType
if type_ is param.empty:
type_ = object
if param.default is param.empty:
mgp_proc.add_arg(name, type_)
else:
mgp_proc.add_opt_arg(name, type_, param.default)
if sig.return_annotation is not sig.empty:
record = sig.return_annotation
if not isinstance(record, Record):
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'"
.format(func.__name__, type(record)))
for name, type_ in record.fields.items():
# TODO: Convert type_ to _mgp.CypherType
if isinstance(type_, Deprecated):
field_type = type_.field_type
mgp_proc.add_deprecated_result(name, field_type)
else:
mgp_proc.add_result(name, type_)
return func

View File

@ -7,6 +7,7 @@ extern "C" {
#include <optional>
#include "py/py.hpp"
#include "query/procedure/py_module.hpp"
#include "utils/pmr/vector.hpp"
#include "utils/string.hpp"
@ -14,7 +15,50 @@ namespace query::procedure {
ModuleRegistry gModuleRegistry;
Module::~Module() {}
class BuiltinModule final : public Module {
public:
BuiltinModule();
~BuiltinModule() override;
BuiltinModule(const BuiltinModule &) = delete;
BuiltinModule(BuiltinModule &&) = delete;
BuiltinModule &operator=(const BuiltinModule &) = delete;
BuiltinModule &operator=(BuiltinModule &&) = delete;
bool Close() override;
bool Reload() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
void AddProcedure(std::string_view name, mgp_proc proc);
private:
/// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_;
};
BuiltinModule::BuiltinModule() {}
BuiltinModule::~BuiltinModule() {}
bool BuiltinModule::Reload() { return true; }
bool BuiltinModule::Close() { return true; }
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
const {
return &procedures_;
}
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) {
procedures_.emplace(name, std::move(proc));
}
namespace {
void RegisterMgReload(ModuleRegistry *module_registry, utils::RWLock *lock,
BuiltinModule *module) {
// Reloading relies on the fact that regular procedure invocation through
@ -129,26 +173,58 @@ void RegisterMgProcedures(
module->AddProcedure("procedures", std::move(procedures));
}
// Run `fun` with `mgp_module *` and `mgp_memory *` arguments. If `fun` returned
// a `true` value, store the `mgp_module::procedures` into `proc_map`. The
// return value of WithModuleRegistration is the same as that of `fun`. Note,
// the return value need only be convertible to `bool`, it does not have to be
// `bool` itself.
template <class TProcMap, class TFun>
auto WithModuleRegistration(TProcMap *proc_map, const TFun &fun) {
// We probably don't need more than 256KB for module initialization.
constexpr size_t stack_bytes = 256 * 1024;
unsigned char stack_memory[stack_bytes];
utils::MonotonicBufferResource monotonic_memory(stack_memory, stack_bytes);
mgp_memory memory{&monotonic_memory};
mgp_module module_def{memory.impl};
auto res = fun(&module_def, &memory);
if (res)
// Copy procedures into resulting proc_map.
for (const auto &proc : module_def.procedures) proc_map->emplace(proc);
return res;
}
} // namespace
Module::~Module() {}
class SharedLibraryModule final : public Module {
public:
SharedLibraryModule();
~SharedLibraryModule() override;
SharedLibraryModule(const SharedLibraryModule &) = delete;
SharedLibraryModule(SharedLibraryModule &&) = delete;
SharedLibraryModule &operator=(const SharedLibraryModule &) = delete;
SharedLibraryModule &operator=(SharedLibraryModule &&) = delete;
BuiltinModule::BuiltinModule() {}
bool Load(std::filesystem::path file_path);
BuiltinModule::~BuiltinModule() {}
bool Close() override;
bool BuiltinModule::Reload() { return true; }
bool Reload() override;
bool BuiltinModule::Close() { return true; }
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
const std::map<std::string, mgp_proc, std::less<>> *BuiltinModule::Procedures()
const {
return &procedures_;
}
void BuiltinModule::AddProcedure(std::string_view name, mgp_proc proc) {
procedures_.emplace(name, std::move(proc));
}
private:
/// Path as requested for loading the module from a library.
std::filesystem::path file_path_;
/// System handle to shared library.
void *handle_;
/// Required initialization function called on module load.
std::function<int(mgp_module *, mgp_memory *)> init_fn_;
/// Optional shutdown function called on module unload.
std::function<int()> shutdown_fn_;
/// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_;
};
SharedLibraryModule::SharedLibraryModule() : handle_(nullptr) {}
@ -176,24 +252,21 @@ bool SharedLibraryModule::Load(std::filesystem::path file_path) {
handle_ = nullptr;
return false;
}
// We probably don't need more than 256KB for module initialazation.
constexpr size_t stack_bytes = 256 * 1024;
unsigned char stack_memory[stack_bytes];
utils::MonotonicBufferResource monotonic_memory(stack_memory, stack_bytes);
mgp_memory memory{&monotonic_memory};
mgp_module module_def{memory.impl};
// Run mgp_init_module which must succeed.
int init_res = init_fn_(&module_def, &memory);
if (init_res != 0) {
LOG(ERROR) << "Unable to load module " << file_path
<< "; mgp_init_module returned " << init_res;
dlclose(handle_);
handle_ = nullptr;
if (!WithModuleRegistration(
&procedures_, [&](auto *module_def, auto *memory) {
// Run mgp_init_module which must succeed.
int init_res = init_fn_(module_def, memory);
if (init_res != 0) {
LOG(ERROR) << "Unable to load module " << file_path
<< "; mgp_init_module returned " << init_res;
dlclose(handle_);
handle_ = nullptr;
return false;
}
return true;
})) {
return false;
}
// Copy procedures into our memory.
for (const auto &proc : module_def.procedures)
procedures_.emplace(proc);
// Get optional mgp_shutdown_module
shutdown_fn_ =
reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
@ -228,7 +301,6 @@ bool SharedLibraryModule::Close() {
bool SharedLibraryModule::Reload() {
CHECK(handle_) << "Attempting to reload a module that has not been loaded...";
LOG(INFO) << "Reloading module " << file_path_ << " ...";
if (!Close()) return false;
return Load(file_path_);
}
@ -240,11 +312,37 @@ const std::map<std::string, mgp_proc, std::less<>>
return &procedures_;
}
class PythonModule final : public Module {
public:
PythonModule();
~PythonModule() override;
PythonModule(const PythonModule &) = delete;
PythonModule(PythonModule &&) = delete;
PythonModule &operator=(const PythonModule &) = delete;
PythonModule &operator=(PythonModule &&) = delete;
bool Load(std::filesystem::path file_path);
bool Close() override;
bool Reload() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
private:
py::Object py_module_;
std::map<std::string, mgp_proc, std::less<>> procedures_;
};
PythonModule::PythonModule() {}
PythonModule::~PythonModule() {}
PythonModule::~PythonModule() {
if (py_module_) Close();
}
bool PythonModule::Load(std::filesystem::path file_path) {
CHECK(!py_module_) << "Attempting to load an already loaded module...";
LOG(INFO) << "Loading module " << file_path << " ...";
auto gil = py::EnsureGIL();
auto *py_path = PySys_GetObject("path");
@ -262,29 +360,47 @@ bool PythonModule::Load(std::filesystem::path file_path) {
return false;
}
}
py::Object py_module(PyImport_ImportModule(file_path.stem().c_str()));
if (!py_module) {
auto exc_info = py::FetchError().value();
LOG(ERROR) << "Unable to load module " << file_path << "; " << exc_info;
return false;
}
// TODO: Actually create a module
py_module_ =
WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
return ImportPyModule(file_path.stem().c_str(), module_def);
});
if (py_module_) return true;
auto exc_info = py::FetchError().value();
LOG(ERROR) << "Unable to load module " << file_path << "; " << exc_info;
return false;
}
bool PythonModule::Close() {
//TODO: implement
return false;
CHECK(py_module_)
<< "Attempting to close a module that has not been loaded...";
// Deleting procedures will probably release PyObject closures, so we need to
// take the GIL.
auto gil = py::EnsureGIL();
procedures_.clear();
py_module_ = py::Object(nullptr);
return true;
}
bool PythonModule::Reload() {
//TODO: implement
CHECK(py_module_)
<< "Attempting to reload a module that has not been loaded...";
auto gil = py::EnsureGIL();
procedures_.clear();
py_module_ =
WithModuleRegistration(&procedures_, [&](auto *module_def, auto *memory) {
return ReloadPyModule(py_module_, module_def);
});
if (py_module_) return true;
auto exc_info = py::FetchError().value();
LOG(ERROR) << "Unable to reload module; " << exc_info;
return false;
}
const std::map<std::string, mgp_proc, std::less<>> *PythonModule::Procedures()
const {
return nullptr;
CHECK(py_module_) << "Attempting to access procedures of a module that has "
"not been loaded...";
return &procedures_;
}
ModuleRegistry::ModuleRegistry() {
@ -312,7 +428,7 @@ bool ModuleRegistry::LoadModuleLibrary(std::filesystem::path path) {
if (!loaded) return false;
modules_[module_name] = std::move(module);
} else {
LOG(ERROR) << "Unkown query module file " << path;
LOG(ERROR) << "Unknown query module file " << path;
return false;
}
return true;
@ -334,6 +450,7 @@ bool ModuleRegistry::ReloadModuleNamed(const std::string_view &name) {
return false;
}
auto &module = found_it->second;
LOG(INFO) << "Reloading module '" << name << "' ...";
if (!module->Reload()) {
modules_.erase(found_it);
return false;
@ -344,6 +461,7 @@ bool ModuleRegistry::ReloadModuleNamed(const std::string_view &name) {
bool ModuleRegistry::ReloadAllModules() {
std::unique_lock<utils::RWLock> guard(lock_);
for (auto &[name, module] : modules_) {
LOG(INFO) << "Reloading module '" << name << "' ...";
if (!module->Reload()) {
modules_.erase(name);
return false;

View File

@ -36,79 +36,6 @@ class Module {
const = 0;
};
class BuiltinModule final : public Module {
public:
BuiltinModule();
~BuiltinModule() override;
BuiltinModule(const BuiltinModule &) = delete;
BuiltinModule(BuiltinModule &&) = delete;
BuiltinModule &operator=(const BuiltinModule &) = delete;
BuiltinModule &operator=(BuiltinModule &&) = delete;
bool Close() override;
bool Reload() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
void AddProcedure(std::string_view name, mgp_proc proc);
private:
/// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_;
};
class SharedLibraryModule final : public Module {
public:
SharedLibraryModule();
~SharedLibraryModule() override;
SharedLibraryModule(const SharedLibraryModule &) = delete;
SharedLibraryModule(SharedLibraryModule &&) = delete;
SharedLibraryModule &operator=(const SharedLibraryModule &) = delete;
SharedLibraryModule &operator=(SharedLibraryModule &&) = delete;
bool Load(std::filesystem::path file_path);
bool Close() override;
bool Reload() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
private:
/// Path as requested for loading the module from a library.
std::filesystem::path file_path_;
/// System handle to shared library.
void *handle_;
/// Required initialization function called on module load.
std::function<int(mgp_module *, mgp_memory *)> init_fn_;
/// Optional shutdown function called on module unload.
std::function<int()> shutdown_fn_;
/// Registered procedures
std::map<std::string, mgp_proc, std::less<>> procedures_;
};
class PythonModule final : public Module {
public:
PythonModule();
~PythonModule() override;
PythonModule(const PythonModule &) = delete;
PythonModule(PythonModule &&) = delete;
PythonModule &operator=(const PythonModule &) = delete;
PythonModule &operator=(PythonModule &&) = delete;
bool Load(std::filesystem::path file_path);
bool Close() override;
bool Reload() override;
const std::map<std::string, mgp_proc, std::less<>> *Procedures()
const override;
};
/// Proxy for a registered Module, acquires a read lock from ModuleRegistry.
class ModulePtr final {
const Module *module_{nullptr};

View File

@ -340,6 +340,177 @@ PyObject *MakePyGraph(const mgp_graph *graph, mgp_memory *memory) {
return PyObject_Init(reinterpret_cast<PyObject *>(py_graph), &PyGraphType);
}
struct PyQueryProc {
PyObject_HEAD
mgp_proc *proc;
};
PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) {
CHECK(self->proc);
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (!mgp_proc_add_arg(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_arg.");
return nullptr;
}
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
CHECK(self->proc);
const char *name = nullptr;
PyObject *py_type = nullptr;
PyObject *py_value = nullptr;
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value))
return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()};
mgp_value *value;
try {
value = PyObjectToMgpValue(py_value, &memory);
} catch (const std::bad_alloc &e) {
PyErr_SetString(PyExc_MemoryError, e.what());
return nullptr;
} catch (const std::overflow_error &e) {
PyErr_SetString(PyExc_OverflowError, e.what());
return nullptr;
} catch (const std::invalid_argument &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return nullptr;
} catch (const std::exception &e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
CHECK(value);
if (!mgp_proc_add_opt_arg(self->proc, name, type, value)) {
mgp_value_destroy(value);
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_opt_arg.");
return nullptr;
}
mgp_value_destroy(value);
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
CHECK(self->proc);
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (!mgp_proc_add_result(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_result.");
return nullptr;
}
Py_RETURN_NONE;
}
PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
CHECK(self->proc);
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (!mgp_proc_add_deprecated_result(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError,
"Invalid call to mgp_proc_add_deprecated_result.");
return nullptr;
}
Py_RETURN_NONE;
}
static PyMethodDef PyQueryProcMethods[] = {
{"add_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddArg), METH_VARARGS,
"Add a required argument to a procedure."},
{"add_opt_arg", reinterpret_cast<PyCFunction>(PyQueryProcAddOptArg),
METH_VARARGS,
"Add an optional argument with a default value to a procedure."},
{"add_result", reinterpret_cast<PyCFunction>(PyQueryProcAddResult),
METH_VARARGS, "Add a result field to a procedure."},
{"add_deprecated_result",
reinterpret_cast<PyCFunction>(PyQueryProcAddDeprecatedResult),
METH_VARARGS,
"Add a result field to a procedure and mark it as deprecated."},
{nullptr},
};
static PyTypeObject PyQueryProcType = {
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Proc",
.tp_doc = "Wraps struct mgp_proc.",
.tp_basicsize = sizeof(PyQueryProc),
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_new = PyType_GenericNew,
.tp_methods = PyQueryProcMethods,
};
struct PyQueryModule {
PyObject_HEAD
mgp_module *module;
};
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
CHECK(self->module);
if (!PyCallable_Check(cb)) {
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
return nullptr;
}
Py_INCREF(cb);
py::Object py_cb(cb);
py::Object py_name(PyObject_GetAttrString(py_cb, "__name__"));
const auto *name = PyUnicode_AsUTF8(py_name);
// TODO: Validate name
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
mgp_proc proc(
name,
[py_cb](const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *) {
auto gil = py::EnsureGIL();
throw utils::NotYetImplemented("Invoking Python procedures");
},
memory);
const auto &[proc_it, did_insert] =
self->module->procedures.emplace(name, std::move(proc));
if (!did_insert) {
PyErr_SetString(PyExc_ValueError,
"Already registered a procedure with the same name.");
return nullptr;
}
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);
if (!py_proc) return nullptr;
py_proc->proc = &proc_it->second;
return PyObject_Init(reinterpret_cast<PyObject *>(py_proc), &PyQueryProcType);
}
static PyMethodDef PyQueryModuleMethods[] = {
{"add_read_procedure",
reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
"Register a read-only procedure with this module."},
{nullptr},
};
static PyTypeObject PyQueryModuleType = {
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Module",
.tp_doc = "Wraps struct mgp_module.",
.tp_basicsize = sizeof(PyQueryModule),
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_new = PyType_GenericNew,
.tp_methods = PyQueryModuleMethods,
};
PyObject *MakePyQueryModule(mgp_module *module) {
auto *py_query_module = PyObject_New(PyQueryModule, &PyQueryModuleType);
if (!py_query_module) return nullptr;
py_query_module->module = module;
return PyObject_Init(reinterpret_cast<PyObject *>(py_query_module),
&PyQueryModuleType);
}
static PyModuleDef PyMgpModule = {
PyModuleDef_HEAD_INIT,
.m_name = "_mgp",
@ -348,26 +519,68 @@ static PyModuleDef PyMgpModule = {
};
PyObject *PyInitMgpModule() {
if (PyType_Ready(&PyVerticesIteratorType) < 0) return nullptr;
if (PyType_Ready(&PyGraphType) < 0) return nullptr;
PyObject *mgp = PyModule_Create(&PyMgpModule);
if (!mgp) return nullptr;
Py_INCREF(&PyVerticesIteratorType);
if (PyModule_AddObject(
mgp, "VerticesIterator",
reinterpret_cast<PyObject *>(&PyVerticesIteratorType)) < 0) {
Py_DECREF(&PyVerticesIteratorType);
Py_DECREF(mgp);
auto register_type = [mgp](auto *type, const auto *name) -> bool {
if (PyType_Ready(type) < 0) {
Py_DECREF(mgp);
return false;
}
Py_INCREF(type);
if (PyModule_AddObject(mgp, name, reinterpret_cast<PyObject *>(type)) < 0) {
Py_DECREF(type);
Py_DECREF(mgp);
return false;
}
return true;
};
if (!register_type(&PyVerticesIteratorType, "VerticesIterator"))
return nullptr;
}
Py_INCREF(&PyGraphType);
if (PyModule_AddObject(mgp, "Graph",
reinterpret_cast<PyObject *>(&PyGraphType)) < 0) {
Py_DECREF(&PyGraphType);
if (!register_type(&PyGraphType, "Graph")) return nullptr;
if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
Py_INCREF(Py_None);
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
Py_DECREF(Py_None);
Py_DECREF(mgp);
return nullptr;
}
return mgp;
}
namespace {
template <class TFun>
auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
py::Object py_mgp(PyImport_ImportModule("_mgp"));
py::Object py_mgp_module(PyObject_GetAttrString(py_mgp, "_MODULE"));
CHECK(py_mgp_module) << "Expected '_mgp' to have attribute '_MODULE'";
// NOTE: This check is not thread safe, but this should only go through
// ModuleRegistry::LoadModuleLibrary which ought to serialize loading.
CHECK(py_mgp_module == Py_None)
<< "Expected '_mgp._MODULE' to be None as we are just starting to "
"import a new module. Is some other thread also importing Python "
"modules?";
CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import";
auto *py_query_module = MakePyQueryModule(module_def);
CHECK(py_query_module);
CHECK(0 <= PyObject_SetAttrString(py_mgp, "_MODULE", py_query_module));
auto ret = fun();
CHECK(0 <= PyObject_SetAttrString(py_mgp, "_MODULE", Py_None));
return ret;
}
} // namespace
py::Object ImportPyModule(const char *name, mgp_module *module_def) {
return WithMgpModule(
module_def, [name]() { return py::Object(PyImport_ImportModule(name)); });
}
py::Object ReloadPyModule(PyObject *py_module, mgp_module *module_def) {
return WithMgpModule(module_def, [py_module]() {
return py::Object(PyImport_ReloadModule(py_module));
});
}
} // namespace query::procedure

View File

@ -6,6 +6,7 @@
struct mgp_graph;
struct mgp_memory;
struct mgp_module;
struct mgp_value;
namespace query::procedure {
@ -30,4 +31,20 @@ PyObject *PyInitMgpModule();
/// Create an instance of _mgp.Graph class.
PyObject *MakePyGraph(const mgp_graph *, mgp_memory *);
/// Import a module with given name in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ImportPyModule(const char *, mgp_module *);
/// Reload already loaded Python module in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ReloadPyModule(PyObject *, mgp_module *);
} // namespace query::procedure