Add python API for messages transformations (#181)

* Add python messages/transformations implementation

* Added fixed result return type to transformations

* Added is_deprecated to mgp_trans
This commit is contained in:
Kostas Kyrimis 2021-07-01 13:36:41 +03:00 committed by Antonio Andelic
parent ac230d0c2d
commit a928c158da
8 changed files with 516 additions and 47 deletions

View File

@ -845,7 +845,8 @@ const struct mgp_message *mgp_messages_at(const struct mgp_messages *, size_t);
/// Passed in arguments will not live longer than the callback's execution.
/// Therefore, you must not store them globally or use the passed in mgp_memory
/// to allocate global resources.
typedef void (*mgp_trans_cb)(const struct mgp_messages *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *);
typedef void (*mgp_trans_cb)(const struct mgp_messages *, const struct mgp_graph *, struct mgp_result *,
struct mgp_memory *);
/// Adds a transformation cb to the module pointed by mgp_module.
/// Return non-zero if the transformation is added successfully.

View File

@ -714,6 +714,18 @@ class Deprecated:
self.field_type = type_
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
    '''Check that `func` is acceptable as a registered callback.

    Raise TypeError if `func` is not callable, or if it is an `async def`
    function (a coroutine function or, on Python 3.6+, an async generator).
    Raise NotImplementedError if `func` is a generator function.
    '''
    if not callable(func):
        raise TypeError("Expected a callable object, got an instance of '{}'"
                        .format(type(func)))
    # `inspect.isasyncgenfunction` only exists from Python 3.6 onwards.
    async_predicates = [inspect.iscoroutinefunction]
    if sys.version_info >= (3, 6):
        async_predicates.append(inspect.isasyncgenfunction)
    for is_async in async_predicates:
        if is_async(func):
            raise TypeError("Callable must not be 'async def' function")
    if inspect.isgeneratorfunction(func):
        raise NotImplementedError("Generator functions are not supported")
def read_proc(func: typing.Callable[..., Record]):
'''
Register `func` as a a read-only procedure of the current module.
@ -754,16 +766,7 @@ def read_proc(func: typing.Callable[..., Record]):
CALL example.procedure(1) YIELD args, result;
Naturally, you may pass in different arguments or yield less fields.
'''
if not callable(func):
raise TypeError("Expected a callable object, got an instance of '{}'"
.format(type(func)))
if inspect.iscoroutinefunction(func):
raise TypeError("Callable must not be 'async def' function")
if sys.version_info >= (3, 6):
if inspect.isasyncgenfunction(func):
raise TypeError("Callable must not be 'async def' function")
if inspect.isgeneratorfunction(func):
raise NotImplementedError("Generator functions are not supported")
raise_if_does_not_meet_requirements(func)
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is ProcCtx:
@ -799,3 +802,123 @@ def read_proc(func: typing.Callable[..., Record]):
else:
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
return func
class InvalidMessageError(Exception):
    '''Signals using a message instance outside of the registered transformation.'''
class Message:
    '''Represents a message from a stream.

    Wraps an `_mgp.Message` and forwards every accessor to it. Each accessor
    first checks `is_valid()` and raises InvalidMessageError when the
    underlying object reports that it is no longer valid.
    '''
    __slots__ = ('_message',)

    def __init__(self, message):
        '''Raise TypeError if `message` is not an `_mgp.Message`.'''
        if not isinstance(message, _mgp.Message):
            raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
        self._message = message

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct. Besides, it doesn't make much sense to actually
        # copy _mgp.Message as that always references the original message.
        return Message(self._message)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._message.is_valid()

    def payload(self) -> bytes:
        '''Return the payload. Raise InvalidMessageError if context is invalid.'''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.payload()

    def topic_name(self) -> str:
        '''Return the topic name. Raise InvalidMessageError if context is invalid.'''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.topic_name()

    def key(self) -> bytes:
        '''Return the message key. Raise InvalidMessageError if context is invalid.'''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.key()

    def timestamp(self) -> int:
        '''Return the timestamp. Raise InvalidMessageError if context is invalid.'''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.timestamp()
class InvalidMessagesError(Exception):
    '''Signals using a messages instance outside of the registered transformation.'''
class Messages:
    '''Represents a list of messages from a stream.

    Wraps an `_mgp.Messages` and forwards to it. Accessors raise
    InvalidMessagesError when the underlying object reports that it is no
    longer valid.
    '''
    __slots__ = ('_messages',)

    def __init__(self, messages):
        '''Raise TypeError if `messages` is not an `_mgp.Messages`.'''
        if not isinstance(messages, _mgp.Messages):
            raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
        self._messages = messages

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct. Besides, it doesn't make much sense to actually
        # copy _mgp.Messages as that always references all the messages.
        return Messages(self._messages)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._messages.is_valid()

    def message_at(self, id: int) -> Message:
        '''Return the message at index `id`.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return Message(self._messages.message_at(id))

    def total_messages(self) -> int:
        '''Return the number of messages.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return self._messages.total_messages()
class TransCtx:
    '''Context of a transformation being executed.

    Access to a TransCtx is only valid during a single execution of a transformation.
    You should not globally store a TransCtx instance.
    '''
    # A one-element tuple (note the trailing comma), matching the other
    # wrapper classes; without the comma this is just a parenthesized string.
    __slots__ = ('_graph',)

    def __init__(self, graph):
        '''Raise TypeError if `graph` is not an `_mgp.Graph`.'''
        if not isinstance(graph, _mgp.Graph):
            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
        self._graph = Graph(graph)

    def is_valid(self) -> bool:
        '''Return True if the wrapped graph reports that it is still valid.'''
        return self._graph.is_valid()

    @property
    def graph(self) -> Graph:
        '''Raise InvalidContextError if context is invalid.'''
        if not self.is_valid():
            raise InvalidContextError()
        return self._graph
def transformation(func: typing.Callable[..., Record]):
    '''Register `func` as a transformation of the current module.

    The signature is accepted when the first parameter is annotated with
    `Messages`, or when there are exactly two parameters and the second one is
    annotated with `Messages`. Anything else raises NotImplementedError.
    If the first parameter is annotated with `TransCtx`, `func` is called with
    a freshly built `TransCtx` followed by the messages; otherwise it is
    called with the messages only.

    Returns `func` unchanged so this can be used as a decorator.
    '''
    raise_if_does_not_meet_requirements(func)
    params = tuple(inspect.signature(func).parameters.values())
    messages_first = bool(params) and params[0].annotation is Messages
    messages_second = len(params) == 2 and params[1].annotation is Messages
    if not (messages_first or messages_second):
        raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")

    # Both adapters take (graph, messages); only the first one forwards the
    # graph, wrapped in a TransCtx, to the user's function.
    if params[0].annotation is TransCtx:
        def adapter(graph, messages):
            return func(TransCtx(graph), messages)
    else:
        def adapter(graph, messages):
            return func(messages)

    wrapper = functools.wraps(func)(adapter)
    _mgp._MODULE.add_transformation(wrapper)
    return func

View File

@ -3734,6 +3734,7 @@ class CallProcedureCursor : public Cursor {
mgp_graph graph{context.db_accessor, graph_view, &context};
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
&result_);
// Reset result_.signature to nullptr, because outside of this scope we
// will no longer hold a lock on the `module`. If someone were to reload
// it, the pointer would be invalid.

View File

@ -8,10 +8,12 @@
#include "module.hpp"
#include "utils/algorithm.hpp"
#include "utils/concepts.hpp"
#include "utils/logging.hpp"
#include "utils/math.hpp"
#include "utils/memory.hpp"
#include "utils/string.hpp"
// This file contains implementation of top level C API functions, but this is
// all actually part of query::procedure. So use that namespace for simplicity.
// NOLINTNEXTLINE(google-build-using-namespace)
@ -1343,27 +1345,38 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
namespace {
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) {
if (!proc || !type) return 0;
if (!IsValidIdentifierName(name)) return 0;
if (proc->results.find(name) != proc->results.end()) return 0;
template <typename T>
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>;
template <ModuleProperties T>
bool AddResultToProp(T *prop, const char *name, const mgp_type *type, bool is_deprecated) {
if (!prop || !type) return false;
if (!IsValidIdentifierName(name)) return false;
if (prop->results.find(name) != prop->results.end()) return false;
try {
auto *memory = proc->results.get_allocator().GetMemoryResource();
proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
return 1;
auto *memory = prop->results.get_allocator().GetMemoryResource();
prop->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
return true;
} catch (...) {
return 0;
return false;
}
}
} // namespace
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
return AddResultToProc(proc, name, type, false);
return AddResultToProp(proc, name, type, false);
}
/// Registers the two fixed result fields on `trans`: "query" (string) and
/// "parameters" (nullable list of any). Neither is marked deprecated.
/// @return true only if both fields were added; stops at the first failure.
bool MgpTransAddFixedResult(mgp_trans *trans) {
  const bool query_added = AddResultToProp(trans, "query", mgp_type_string(), false);
  if (!query_added) {
    return false;
  }
  const bool parameters_added =
      AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
  return parameters_added;
}
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
return AddResultToProc(proc, name, type, true);
return AddResultToProp(proc, name, type, true);
}
int mgp_must_abort(const mgp_graph *graph) {

View File

@ -478,21 +478,23 @@ struct mgp_trans {
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory) : name(name, memory), cb(cb) {}
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory)
: name(name, memory), cb(cb), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const char *name,
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
utils::MemoryResource *memory)
: name(name, memory), cb(cb) {}
: name(name, memory), cb(cb), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory) : name(other.name, memory), cb(other.cb) {}
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory)
: name(other.name, memory), cb(other.cb), results(other.results) {}
mgp_trans(mgp_trans &&other, utils::MemoryResource *memory)
: name(std::move(other.name), memory), cb(std::move(other.cb)) {}
: name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results)) {}
mgp_trans(const mgp_trans &other) = default;
mgp_trans(mgp_trans &&other) = default;
@ -505,9 +507,13 @@ struct mgp_trans {
/// Name of the transformation.
utils::pmr::string name;
/// Entry-point for the transformation.
std::function<void(const mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb;
/// Fields this transformation returns.
utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> results;
};
bool MgpTransAddFixedResult(mgp_trans *trans);
struct mgp_module {
using allocator_type = utils::Allocator<mgp_module>;

View File

@ -7,6 +7,7 @@ extern "C" {
#include <optional>
#include "fmt/format.h"
#include "py/py.hpp"
#include "query/procedure/py_module.hpp"
#include "utils/file.hpp"
@ -275,30 +276,44 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
}
// Get required mgp_init_module
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
const char *error = dlerror();
if (!init_fn_ || error) {
spdlog::error("Unable to load module {}; {}", file_path, error);
char *dl_errored = dlerror();
if (!init_fn_ || dl_errored) {
spdlog::error("Unable to load module {}; {}", file_path, dl_errored);
dlclose(handle_);
handle_ = nullptr;
return false;
}
if (!WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
// Run mgp_init_module which must succeed.
int init_res = init_fn_(module_def, memory);
if (init_res != 0) {
spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res);
dlclose(handle_);
handle_ = nullptr;
return false;
}
return true;
})) {
auto module_cb = [&](auto *module_def, auto *memory) {
// Run mgp_init_module which must succeed.
int init_res = init_fn_(module_def, memory);
auto with_error = [this](std::string_view error_msg) {
spdlog::error(error_msg);
dlclose(handle_);
handle_ = nullptr;
return false;
};
if (init_res != 0) {
const auto error = fmt::format("Unable to load module {}; mgp_init_module_returned {} ", file_path, init_res);
return with_error(error);
}
for (auto &trans : module_def->transformations) {
const bool was_result_added = MgpTransAddFixedResult(&trans.second);
if (!was_result_added) {
const auto error =
fmt::format("Unable to add result to transformation in module {}; add result failed", file_path);
return with_error(error);
}
}
return true;
};
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) {
return false;
}
// Get optional mgp_shutdown_module
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
error = dlerror();
if (error) spdlog::warn("When loading module {}; {}", file_path, error);
dl_errored = dlerror();
if (dl_errored) spdlog::warn("When loading module {}; {}", file_path, dl_errored);
spdlog::info("Loaded module {}", file_path);
return true;
}
@ -376,11 +391,23 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
return false;
}
py_module_ = WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
return ImportPyModule(file_path.stem().c_str(), module_def);
});
bool succ = true;
auto module_cb = [&](auto *module_def, auto *memory) {
auto result = ImportPyModule(file_path.stem().c_str(), module_def);
for (auto &trans : module_def->transformations) {
succ = MgpTransAddFixedResult(&trans.second);
if (!succ) return result;
};
return result;
};
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb);
if (py_module_) {
spdlog::info("Loaded module {}", file_path);
if (!succ) {
spdlog::error("Unable to add result to transformation");
return false;
}
return true;
}
auto exc_info = py::FetchError().value();
@ -566,6 +593,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
if (name_parts.size() == 1U) return std::nullopt;
auto last_dot_pos = fully_qualified_name.find_last_of('.');
MG_ASSERT(last_dot_pos != std::string_view::npos);
const auto &module_name = fully_qualified_name.substr(0, last_dot_pos);
const auto &name = name_parts.back();
return std::make_pair(module_name, name);

View File

@ -5,6 +5,7 @@
#include <string>
#include "query/procedure/mg_procedure_impl.hpp"
#include "utils/pmr/vector.hpp"
namespace query::procedure {
@ -178,7 +179,7 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) {
MG_ASSERT(self->graph);
MG_ASSERT(self->memory);
static_assert(std::is_same_v<int64_t, long>);
int64_t id;
int64_t id = 0;
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
if (!vertex) {
@ -400,6 +401,188 @@ struct PyQueryModule {
};
// clang-format on
/// Python object (exposed as `_mgp.Messages`) wrapping a `mgp_messages` pointer.
/// `PyMessagesInvalidate` resets both pointers to nullptr; a null `messages`
/// is what `PyMessagesIsValid` reports as "not valid".
struct PyMessages {
PyObject_HEAD;
// Wrapped messages; nullptr once the object has been invalidated.
const mgp_messages *messages;
// Memory handle passed to MakePyMessages; copied into every PyMessage created from this object.
mgp_memory *memory;
};
/// Python object (exposed as `_mgp.Message`) wrapping a single `mgp_message`.
/// Instances are created by `PyMessagesGetMessageAt`.
struct PyMessage {
PyObject_HEAD;
// The wrapped message, obtained via mgp_messages_at.
const mgp_message *message;
// Owning PyMessages; a reference is taken on creation and released in PyMessageDealloc.
const PyMessages *messages;
// Copied from the owning PyMessages at creation time.
mgp_memory *memory;
};
/// Returns Python True while `self` still holds a non-null `mgp_messages`
/// pointer, and Python False otherwise.
PyObject *PyMessagesIsValid(const PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  const bool has_messages = self->messages != nullptr;
  return PyBool_FromLong(has_messages ? 1 : 0);
}
/// A message is valid exactly when the PyMessages it was created from is
/// valid, so this simply delegates to PyMessagesIsValid on the owner.
PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  const PyMessages *owner = self->messages;
  return PyMessagesIsValid(owner, nullptr);
}
/// Returns the message payload as a new Python bytearray.
/// On failure sets a RuntimeError and returns nullptr.
PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  const auto length = mgp_message_get_payload_size(self->message);
  const auto *data = mgp_message_get_payload(self->message);
  if (auto *py_payload = PyByteArray_FromStringAndSize(data, length)) {
    return py_payload;
  }
  PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload");
  return nullptr;
}
/// Returns the topic name of the wrapped message as a new Python str.
/// On failure sets a RuntimeError and returns nullptr.
PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  const auto *topic_name = mgp_message_topic_name(self->message);
  auto *py_topic_name = PyUnicode_FromString(topic_name);
  if (!py_topic_name) {
    // The message used to be copy-pasted from the payload getter.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get the topic name");
    return nullptr;
  }
  return py_topic_name;
}
/// Returns the key of the wrapped message as a new Python bytearray.
/// On failure sets a RuntimeError and returns nullptr.
PyObject *PyMessageGetKey(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  auto key_size = mgp_message_key_size(self->message);
  const auto *key = mgp_message_key(self->message);
  auto *raw_bytes = PyByteArray_FromStringAndSize(key, key_size);
  if (!raw_bytes) {
    // The message used to be copy-pasted from the payload getter.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from key");
    return nullptr;
  }
  return raw_bytes;
}
/// Returns the timestamp of the wrapped message as a new Python int.
/// On failure sets an IndexError and returns nullptr.
// NOTE(review): the value is converted with PyLong_FromUnsignedLong; confirm
// that this matches the signedness of mgp_message_timestamp's return type.
PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  if (auto *py_timestamp = PyLong_FromUnsignedLong(mgp_message_timestamp(self->message))) {
    return py_timestamp;
  }
  PyErr_SetString(PyExc_IndexError, "Unable to get timestamp.");
  return nullptr;
}
// Method table of `_mgp.Message`. Every entry takes no arguments
// (METH_NOARGS); the table is terminated by the `{nullptr}` sentinel.
// Pickling/copying is rejected through `__reduce__`.
// NOLINTNEXTLINE
static PyMethodDef PyMessageMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"is_valid", reinterpret_cast<PyCFunction>(PyMessageIsValid), METH_NOARGS,
"Return True if messages is in valid context and may be used."},
{"payload", reinterpret_cast<PyCFunction>(PyMessageGetPayload), METH_NOARGS, "Get payload"},
{"topic_name", reinterpret_cast<PyCFunction>(PyMessageGetTopicName), METH_NOARGS, "Get topic name."},
{"key", reinterpret_cast<PyCFunction>(PyMessageGetKey), METH_NOARGS, "Get message key."},
{"timestamp", reinterpret_cast<PyCFunction>(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."},
{nullptr},
};
/// Deallocator of `_mgp.Message`: releases the reference held on the owning
/// PyMessages and then frees the object through its type's `tp_free`.
void PyMessageDealloc(PyMessage *self) {
  MG_ASSERT(self->memory);
  MG_ASSERT(self->message);
  MG_ASSERT(self->messages);
  // Drop the reference that was taken when this message was created.
  const PyMessages *owner = self->messages;
  // NOLINTNEXTLINE
  Py_DECREF(owner);
  auto *type = Py_TYPE(self);
  // NOLINTNEXTLINE
  type->tp_free(self);
}
// Type object of `_mgp.Message`. Uses PyMessageDealloc as deallocator and
// PyMessageMethods as its method table; no `tp_new` is set here.
// NOLINTNEXTLINE
static PyTypeObject PyMessageType = {
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Message",
.tp_basicsize = sizeof(PyMessage),
.tp_dealloc = reinterpret_cast<destructor>(PyMessageDealloc),
// NOLINTNEXTLINE
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Wraps struct mgp_message.",
// NOLINTNEXTLINE
.tp_methods = PyMessageMethods,
};
/// Clears both wrapped pointers so that `is_valid` reports False from now on.
/// Always returns None.
PyObject *PyMessagesInvalidate(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  self->memory = nullptr;
  self->messages = nullptr;
  Py_RETURN_NONE;
}
/// Returns the number of wrapped messages as a new Python int.
/// On failure sets an IndexError and returns nullptr.
PyObject *PyMessagesGetTotalMessages(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  auto size = self->messages->messages.size();
  auto *py_int = PyLong_FromSize_t(size);
  if (!py_int) {
    // The message used to be copy-pasted from the timestamp getter.
    PyErr_SetString(PyExc_IndexError, "Unable to get the total number of messages.");
    return nullptr;
  }
  return py_int;
}
/// Returns a new `_mgp.Message` wrapping the message at the given index.
///
/// `args` must be a tuple holding a single integer. A negative index, or an
/// index for which `mgp_messages_at` yields no message, raises IndexError.
/// The returned object keeps a reference on `self`, which is released in
/// PyMessageDealloc.
PyObject *PyMessagesGetMessageAt(PyMessages *self, PyObject *args) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  int64_t id = 0;
  if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
  if (id < 0) {
    // Returning nullptr without an exception set would surface as a
    // SystemError in Python, so report the bad index explicitly.
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  const auto *message = mgp_messages_at(self->messages, static_cast<size_t>(id));
  if (!message) {
    // Check before allocating: bailing out after PyObject_New would leak the
    // new object together with the extra reference on `self`.
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  // NOLINTNEXTLINE
  auto *py_message = PyObject_New(PyMessage, &PyMessageType);
  if (!py_message) {
    return nullptr;
  }
  py_message->message = message;
  // NOLINTNEXTLINE
  Py_INCREF(self);
  py_message->messages = self;
  py_message->memory = self->memory;
  // NOLINTNEXTLINE
  return reinterpret_cast<PyObject *>(py_message);
}
// Method table of `_mgp.Messages`, terminated by the `{nullptr}` sentinel.
// Only `message_at` takes arguments (the index); `total_messages` ignores its
// second parameter, so it is registered as METH_NOARGS like the other
// argument-less methods.
// NOLINTNEXTLINE
static PyMethodDef PyMessagesMethods[] = {
    {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
    {"invalidate", reinterpret_cast<PyCFunction>(PyMessagesInvalidate), METH_NOARGS,
     "Invalidate the messages context thus preventing the messages from being used"},
    {"is_valid", reinterpret_cast<PyCFunction>(PyMessagesIsValid), METH_NOARGS,
     "Return True if messages is in valid context and may be used."},
    {"total_messages", reinterpret_cast<PyCFunction>(PyMessagesGetTotalMessages), METH_NOARGS,
     "Get number of messages available"},
    {"message_at", reinterpret_cast<PyCFunction>(PyMessagesGetMessageAt), METH_VARARGS,
     "Get message at index idx from messages"},
    {nullptr},
};
// Type object of `_mgp.Messages`. Uses PyMessagesMethods as its method table;
// neither `tp_new` nor a custom `tp_dealloc` is set here.
// NOLINTNEXTLINE
static PyTypeObject PyMessagesType = {
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Messages",
.tp_basicsize = sizeof(PyMessages),
// NOLINTNEXTLINE
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Wraps struct mgp_messages.",
// NOLINTNEXTLINE
.tp_methods = PyMessagesMethods,
};
/// Creates a new `_mgp.Messages` object holding `msgs` and `memory`.
/// `memory` must be non-null whenever `msgs` is non-null.
/// @return a new reference, or nullptr if the allocation failed.
PyObject *MakePyMessages(const mgp_messages *msgs, mgp_memory *memory) {
  MG_ASSERT(!msgs || memory);
  // NOLINTNEXTLINE
  auto *wrapper = PyObject_New(PyMessages, &PyMessagesType);
  if (wrapper == nullptr) {
    return nullptr;
  }
  wrapper->messages = msgs;
  wrapper->memory = memory;
  return reinterpret_cast<PyObject *>(wrapper);
}
py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
MG_ASSERT(list);
MG_ASSERT(py_graph);
@ -585,6 +768,86 @@ void CallPythonProcedure(const py::Object &py_cb, const mgp_list *args, const mg
}
}
/// Invokes the Python transformation `py_cb` and stores its outcome in `result`.
///
/// The callback is called with a freshly created `_mgp.Graph` and
/// `_mgp.Messages`, in that order, which is the order the wrappers registered
/// from `mgp.py` declare (`wrapper(graph, messages)`). A returned sequence is
/// added as multiple records, anything else as a single record. Any Python
/// error is formatted and set as the error message of `result`. Both `_mgp`
/// objects are invalidated before returning.
void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs, const mgp_graph *graph,
                              mgp_result *result, mgp_memory *memory) {
  auto gil = py::EnsureGIL();

  auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
    if (!exc_info) return std::nullopt;
    // Here we tell the traceback formatter to skip the first line of the
    // traceback because that line will always be our wrapper function in our
    // internal `mgp.py` file. With that line skipped, the user will always
    // get only the relevant traceback that happened in his Python code.
    return py::FormatException(*exc_info, /* skip_first_line = */ true);
  };

  auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
    // The graph goes first: the Python-side wrappers take (graph, messages).
    auto py_res = py_cb.Call(py_graph, py_messages);
    if (!py_res) return py::FetchError();
    if (PySequence_Check(py_res.Ptr())) {
      return AddMultipleRecordsFromPython(result, py_res);
    }
    return AddRecordFromPython(result, py_res);
  };

  auto cleanup = [](py::Object py_graph, py::Object py_messages) {
    // Run `gc.collect` (reference cycle-detection) explicitly, so that we are
    // sure the transformation cleaned up everything it held references to. If
    // the user stored a reference to one of our `_mgp` instances then the
    // internally used `mgp_*` structs will stay unfreed and a memory leak
    // will be reported at the end of the query execution.
    py::Object gc(PyImport_ImportModule("gc"));
    if (!gc) {
      LOG_FATAL(py::FetchError().value());
    }
    if (!gc.CallMethod("collect")) {
      LOG_FATAL(py::FetchError().value());
    }
    // After making sure all references from our side have been cleared,
    // invalidate the `_mgp.Graph` and `_mgp.Messages` objects. If the user
    // kept a reference to one of our `_mgp` instances then this will prevent
    // them from using those objects (whose internal `mgp_*` pointers are now
    // invalid and would cause a crash).
    if (!py_graph.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
    if (!py_messages.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
  };

  // It is *VERY IMPORTANT* to note that this code takes great care not to keep
  // any extra references to any `_mgp` instances (except for `_mgp.Graph` and
  // `_mgp.Messages`), so as not to introduce extra reference counts and
  // prevent their deallocation. In particular, the `ExceptionInfo` object has
  // a `traceback` field that contains references to the Python frames and
  // their arguments, and therefore our `_mgp` instances as well. Within this
  // code we ensure not to keep the `ExceptionInfo` object alive so that no
  // extra reference counts are introduced. We only fetch the error message
  // and immediately destroy the object.
  std::optional<std::string> maybe_msg;
  {
    py::Object py_graph(MakePyGraph(graph, memory));
    py::Object py_messages(MakePyMessages(msgs, memory));
    if (py_graph && py_messages) {
      try {
        maybe_msg = error_to_msg(call(py_graph, py_messages));
        cleanup(py_graph, py_messages);
      } catch (...) {
        cleanup(py_graph, py_messages);
        throw;
      }
    } else {
      maybe_msg = error_to_msg(py::FetchError());
    }
  }

  if (maybe_msg) {
    mgp_result_set_error_msg(result, maybe_msg->c_str());
  }
}
} // namespace
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
@ -619,10 +882,41 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
return reinterpret_cast<PyObject *>(py_proc);
}
/// Registers the Python callable `cb` as a transformation of `self->module`,
/// using the callable's `__name__` as the transformation name.
///
/// Sets a Python exception and returns nullptr when `cb` is not callable
/// (TypeError), when its name cannot be read or is not a valid identifier
/// (ValueError for the latter), or when a transformation with the same name
/// is already registered (ValueError). Returns None on success.
PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
  MG_ASSERT(self->module);
  if (!PyCallable_Check(cb)) {
    PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
    return nullptr;
  }
  auto py_cb = py::Object::FromBorrow(cb);
  py::Object py_name(py_cb.GetAttr("__name__"));
  // Not every callable has `__name__`; the failed lookup has already set the
  // Python error, so just propagate it instead of dereferencing a null object.
  if (!py_name) return nullptr;
  const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
  if (!name) return nullptr;
  if (!IsValidIdentifierName(name)) {
    PyErr_SetString(PyExc_ValueError, "Transformation name is not a valid identifier");
    return nullptr;
  }
  auto *memory = self->module->transformations.get_allocator().GetMemoryResource();
  // The lambda captures `py_cb` by value, so the stored transformation keeps
  // its own handle to the Python callable.
  mgp_trans trans(
      name,
      [py_cb](const mgp_messages *msgs, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
        CallPythonTransformation(py_cb, msgs, graph, result, memory);
      },
      memory);
  const auto [trans_it, did_insert] = self->module->transformations.emplace(name, std::move(trans));
  if (!did_insert) {
    PyErr_SetString(PyExc_ValueError, "Already registered a transformation with the same name.");
    return nullptr;
  }
  Py_RETURN_NONE;
}
// Method table of the Python query-module object, terminated by the
// `{nullptr}` sentinel. Both registration methods are METH_O: they take
// exactly one argument, the callback to register.
static PyMethodDef PyQueryModuleMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
{"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
"Register a read-only procedure with this module."},
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
"Register a transformation with this module."},
{nullptr},
};
@ -1348,6 +1642,8 @@ PyObject *PyInitMgpModule() {
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
if (!register_type(&PyPathType, "Path")) return nullptr;
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
if (!register_type(&PyMessagesType, "Messages")) return nullptr;
if (!register_type(&PyMessageType, "Message")) return nullptr;
Py_INCREF(Py_None);
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
Py_DECREF(Py_None);

View File

@ -5,7 +5,8 @@
#include "test_utils.hpp"
TEST(MgpTransTest, TestMgpTransApi) {
constexpr auto no_op_cb = [](const mgp_messages *msg, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {};
constexpr auto no_op_cb = [](const mgp_messages *msg, const mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {};
mgp_module module(utils::NewDeleteResource());
// If this is false, then mgp_module_add_transformation()
// correctly calls IsValidIdentifier(). We don't need to test