From a928c158da47d4ef07d44de4de1b7b1ffecc8a80 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Thu, 1 Jul 2021 13:36:41 +0300 Subject: [PATCH] Add python API for messages transformations (#181) * Add python messages/transformations implementation * Added fixed result return type to transformations * Added is_deprecated to mgp_trans --- include/mg_procedure.h | 3 +- include/mgp.py | 143 ++++++++++- src/query/plan/operator.cpp | 1 + src/query/procedure/mg_procedure_impl.cpp | 33 ++- src/query/procedure/mg_procedure_impl.hpp | 16 +- src/query/procedure/module.cpp | 66 +++-- src/query/procedure/py_module.cpp | 298 +++++++++++++++++++++- tests/unit/mgp_trans_c_api.cpp | 3 +- 8 files changed, 516 insertions(+), 47 deletions(-) diff --git a/include/mg_procedure.h b/include/mg_procedure.h index 4b8a6aed0..e13065d16 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -845,7 +845,8 @@ const struct mgp_message *mgp_messages_at(const struct mgp_messages *, size_t); /// Passed in arguments will not live longer than the callback's execution. /// Therefore, you must not store them globally or use the passed in mgp_memory /// to allocate global resources. -typedef void (*mgp_trans_cb)(const struct mgp_messages *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *); +typedef void (*mgp_trans_cb)(const struct mgp_messages *, const struct mgp_graph *, struct mgp_result *, + struct mgp_memory *); /// Adds a transformation cb to the module pointed by mgp_module. /// Return non-zero if the transformation is added successfully. diff --git a/include/mgp.py b/include/mgp.py index a7044a77d..025d9fd21 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -714,6 +714,18 @@ class Deprecated: self.field_type = type_ +def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]): + if not callable(func): + raise TypeError("Expected a callable object, got an instance of '{}'" + .format(type(func))) + if inspect.iscoroutinefunction(func): + raise TypeError("Callable must not be 'async def' function") + if sys.version_info >= (3, 6): + if inspect.isasyncgenfunction(func): + raise TypeError("Callable must not be 'async def' function") + if inspect.isgeneratorfunction(func): + raise NotImplementedError("Generator functions are not supported") + def read_proc(func: typing.Callable[..., Record]): ''' Register `func` as a a read-only procedure of the current module. @@ -754,16 +766,7 @@ def read_proc(func: typing.Callable[..., Record]): CALL example.procedure(1) YIELD args, result; Naturally, you may pass in different arguments or yield less fields. ''' - if not callable(func): - raise TypeError("Expected a callable object, got an instance of '{}'" - .format(type(func))) - if inspect.iscoroutinefunction(func): - raise TypeError("Callable must not be 'async def' function") - if sys.version_info >= (3, 6): - if inspect.isasyncgenfunction(func): - raise TypeError("Callable must not be 'async def' function") - if inspect.isgeneratorfunction(func): - raise NotImplementedError("Generator functions are not supported") + raise_if_does_not_meet_requirements(func) sig = inspect.signature(func) params = tuple(sig.parameters.values()) if params and params[0].annotation is ProcCtx: @@ -799,3 +802,123 @@ def read_proc(func: typing.Callable[..., Record]): else: mgp_proc.add_result(name, _typing_to_cypher_type(type_)) return func + +class InvalidMessageError(Exception): + '''Signals using a message instance outside of the registered transformation.''' + pass + +class Message: + '''Represents a message from a stream.''' + __slots__ = ('_message',) + + def __init__(self, message): + if not isinstance(message, _mgp.Message): + raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message))) + self._message = message + + def __deepcopy__(self, memo): + # This is the same as the shallow copy, because we want to share the + # underlying C struct. Besides, it doesn't make much sense to actually + # copy _mgp.Messages as that always references all the messages. + return Message(self._message) + + def is_valid(self) -> bool: + '''Return True if `self` is in valid context and may be used.''' + return self._message.is_valid() + + def payload(self) -> bytes: + if not self.is_valid(): + raise InvalidMessageError() + return self._messages._payload(_message) + + def topic_name(self) -> str: + if not self.is_valid(): + raise InvalidMessageError() + return self._messages._topic_name(_message) + + def key() -> bytes: + if not self.is_valid(): + raise InvalidMessageError() + return self._messages.key(_message) + + def timestamp() -> int: + if not self.is_valid(): + raise InvalidMessageError() + return self._messages.timestamp(_message) + +class InvalidMessagesError(Exception): + '''Signals using a messages instance outside of the registered transformation.''' + pass + +class Messages: + '''Represents a list of messages from a stream.''' + __slots__ = ('_messages',) + + def __init__(self, messages): + if not isinstance(messages, _mgp.Messages): + raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages))) + self._messages = messages + + def __deepcopy__(self, memo): + # This is the same as the shallow copy, because we want to share the + # underlying C struct. Besides, it doesn't make much sense to actually + # copy _mgp.Messages as that always references all the messages. + return Messages(self._messages) + + def is_valid(self) -> bool: + '''Return True if `self` is in valid context and may be used.''' + return self._messages.is_valid() + + def message_at(self, id : int) -> Message: + '''Raise InvalidMessagesError if context is invalid.''' + if not self.is_valid(): + raise InvalidMessagesError() + return Message(self._messages.message_at(id)) + + def total_messages() -> int: + '''Raise InvalidContextError if context is invalid.''' + if not self.is_valid(): + raise InvalidMessagesError() + return self._messages.total_messages() + +class TransCtx: + '''Context of a transformation being executed. + + Access to a TransCtx is only valid during a single execution of a transformation. + You should not globally store a TransCtx instance. + ''' + __slots__ = ('_graph') + + def __init__(self, graph): + if not isinstance(graph, _mgp.Graph): + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) + self._graph = Graph(graph) + + def is_valid(self) -> bool: + return self._graph.is_valid() + + @property + def graph(self) -> Graph: + '''Raise InvalidContextError if context is invalid.''' + if not self.is_valid(): + raise InvalidContextError() + return self._graph + +def transformation(func: typing.Callable[..., Record]): + raise_if_does_not_meet_requirements(func) + sig = inspect.signature(func) + params = tuple(sig.parameters.values()) + if not params or not params[0].annotation is Messages: + if not len(params) == 2 or not params[1].annotation is Messages: + raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)") + if params[0].annotation is TransCtx: + @functools.wraps(func) + def wrapper(graph, messages): + return func(TransCtx(graph), messages) + _mgp._MODULE.add_transformation(wrapper) + else: + @functools.wraps(func) + def wrapper(graph, messages): + return func(messages) + _mgp._MODULE.add_transformation(wrapper) + return func diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index d29e980b8..fc9d3faa6 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -3734,6 +3734,7 @@ class CallProcedureCursor : public Cursor { mgp_graph graph{context.db_accessor, graph_view, &context}; CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, &result_); + // Reset result_.signature to nullptr, because outside of this scope we // will no longer hold a lock on the `module`. If someone were to reload // it, the pointer would be invalid. diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 6ea693646..7f4d9e95a 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -8,10 +8,12 @@ #include "module.hpp" #include "utils/algorithm.hpp" +#include "utils/concepts.hpp" #include "utils/logging.hpp" #include "utils/math.hpp" #include "utils/memory.hpp" #include "utils/string.hpp" + // This file contains implementation of top level C API functions, but this is // all actually part of query::procedure. So use that namespace for simplicity. // NOLINTNEXTLINE(google-build-using-namespace) @@ -1343,27 +1345,38 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type, namespace { -int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) { - if (!proc || !type) return 0; - if (!IsValidIdentifierName(name)) return 0; - if (proc->results.find(name) != proc->results.end()) return 0; +template +concept ModuleProperties = utils::SameAsAnyOf; + +template +bool AddResultToProp(T *prop, const char *name, const mgp_type *type, bool is_deprecated) { + if (!prop || !type) return false; + if (!IsValidIdentifierName(name)) return false; + if (prop->results.find(name) != prop->results.end()) return false; try { - auto *memory = proc->results.get_allocator().GetMemoryResource(); - proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated)); - return 1; + auto *memory = prop->results.get_allocator().GetMemoryResource(); + prop->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated)); + return true; } catch (...) { - return 0; + return false; } } } // namespace int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) { - return AddResultToProc(proc, name, type, false); + return AddResultToProp(proc, name, type, false); +} + +bool MgpTransAddFixedResult(mgp_trans *trans) { + if (int err = AddResultToProp(trans, "query", mgp_type_string(), false); err != 1) { + return err; + } + return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false); } int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) { - return AddResultToProc(proc, name, type, true); + return AddResultToProp(proc, name, type, true); } int mgp_must_abort(const mgp_graph *graph) { diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp index e95bd47ab..b39651dc3 100644 --- a/src/query/procedure/mg_procedure_impl.hpp +++ b/src/query/procedure/mg_procedure_impl.hpp @@ -478,21 +478,23 @@ struct mgp_trans { /// @throw std::bad_alloc /// @throw std::length_error - mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory) : name(name, memory), cb(cb) {} + mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory) + : name(name, memory), cb(cb), results(memory) {} /// @throw std::bad_alloc /// @throw std::length_error mgp_trans(const char *name, std::function cb, utils::MemoryResource *memory) - : name(name, memory), cb(cb) {} + : name(name, memory), cb(cb), results(memory) {} /// @throw std::bad_alloc /// @throw std::length_error - mgp_trans(const mgp_trans &other, utils::MemoryResource *memory) : name(other.name, memory), cb(other.cb) {} + mgp_trans(const mgp_trans &other, utils::MemoryResource *memory) + : name(other.name, memory), cb(other.cb), results(other.results) {} mgp_trans(mgp_trans &&other, utils::MemoryResource *memory) - : name(std::move(other.name), memory), cb(std::move(other.cb)) {} + : name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results)) {} mgp_trans(const mgp_trans &other) = default; mgp_trans(mgp_trans &&other) = default; @@ -505,9 +507,13 @@ struct mgp_trans { /// Name of the transformation. utils::pmr::string name; /// Entry-point for the transformation. - std::function cb; + std::function cb; + /// Fields this transformation returns. + utils::pmr::map> results; }; +bool MgpTransAddFixedResult(mgp_trans *trans); + struct mgp_module { using allocator_type = utils::Allocator; diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index 30001a74d..8050cba5f 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -7,6 +7,7 @@ extern "C" { #include +#include "fmt/format.h" #include "py/py.hpp" #include "query/procedure/py_module.hpp" #include "utils/file.hpp" @@ -275,30 +276,44 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) { } // Get required mgp_init_module init_fn_ = reinterpret_cast(dlsym(handle_, "mgp_init_module")); - const char *error = dlerror(); - if (!init_fn_ || error) { - spdlog::error("Unable to load module {}; {}", file_path, error); + char *dl_errored = dlerror(); + if (!init_fn_ || dl_errored) { + spdlog::error("Unable to load module {}; {}", file_path, dl_errored); dlclose(handle_); handle_ = nullptr; return false; } - if (!WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) { - // Run mgp_init_module which must succeed. - int init_res = init_fn_(module_def, memory); - if (init_res != 0) { - spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res); - dlclose(handle_); - handle_ = nullptr; - return false; - } - return true; - })) { + auto module_cb = [&](auto *module_def, auto *memory) { + // Run mgp_init_module which must succeed. + int init_res = init_fn_(module_def, memory); + auto with_error = [this](std::string_view error_msg) { + spdlog::error(error_msg); + dlclose(handle_); + handle_ = nullptr; + return false; + }; + + if (init_res != 0) { + const auto error = fmt::format("Unable to load module {}; mgp_init_module_returned {} ", file_path, init_res); + return with_error(error); + } + for (auto &trans : module_def->transformations) { + const bool was_result_added = MgpTransAddFixedResult(&trans.second); + if (!was_result_added) { + const auto error = + fmt::format("Unable to add result to transformation in module {}; add result failed", file_path); + return with_error(error); + } + } + return true; + }; + if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) { return false; } // Get optional mgp_shutdown_module shutdown_fn_ = reinterpret_cast(dlsym(handle_, "mgp_shutdown_module")); - error = dlerror(); - if (error) spdlog::warn("When loading module {}; {}", file_path, error); + dl_errored = dlerror(); + if (dl_errored) spdlog::warn("When loading module {}; {}", file_path, dl_errored); spdlog::info("Loaded module {}", file_path); return true; } @@ -376,11 +391,23 @@ bool PythonModule::Load(const std::filesystem::path &file_path) { spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc); return false; } - py_module_ = WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) { - return ImportPyModule(file_path.stem().c_str(), module_def); - }); + bool succ = true; + auto module_cb = [&](auto *module_def, auto *memory) { + auto result = ImportPyModule(file_path.stem().c_str(), module_def); + for (auto &trans : module_def->transformations) { + succ = MgpTransAddFixedResult(&trans.second); + if (!succ) return result; + }; + return result; + }; + py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb); if (py_module_) { spdlog::info("Loaded module {}", file_path); + + if (!succ) { + spdlog::error("Unable to add result to transformation"); + return false; + } return true; } auto exc_info = py::FetchError().value(); @@ -566,6 +593,7 @@ std::optional> FindModuleNameAndPr if (name_parts.size() == 1U) return std::nullopt; auto last_dot_pos = fully_qualified_name.find_last_of('.'); MG_ASSERT(last_dot_pos != std::string_view::npos); + const auto &module_name = fully_qualified_name.substr(0, last_dot_pos); const auto &name = name_parts.back(); return std::make_pair(module_name, name); diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 96ef5f387..157a47cb6 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -5,6 +5,7 @@ #include #include "query/procedure/mg_procedure_impl.hpp" +#include "utils/pmr/vector.hpp" namespace query::procedure { @@ -178,7 +179,7 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) { MG_ASSERT(self->graph); MG_ASSERT(self->memory); static_assert(std::is_same_v); - int64_t id; + int64_t id = 0; if (!PyArg_ParseTuple(args, "l", &id)) return nullptr; auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory); if (!vertex) { @@ -400,6 +401,188 @@ struct PyQueryModule { }; // clang-format on +struct PyMessages { + PyObject_HEAD; + const mgp_messages *messages; + mgp_memory *memory; +}; + +struct PyMessage { + PyObject_HEAD; + const mgp_message *message; + const PyMessages *messages; + mgp_memory *memory; +}; + +PyObject *PyMessagesIsValid(const PyMessages *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(!!self->messages); +} + +PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + return PyMessagesIsValid(self->messages, nullptr); +} + +PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + auto payload_size = mgp_message_get_payload_size(self->message); + const auto *payload = mgp_message_get_payload(self->message); + auto *raw_bytes = PyByteArray_FromStringAndSize(payload, payload_size); + if (!raw_bytes) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + return nullptr; + } + return raw_bytes; +} + +PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + const auto *topic_name = mgp_message_topic_name(self->message); + auto *py_topic_name = PyUnicode_FromString(topic_name); + if (!py_topic_name) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + return nullptr; + } + return py_topic_name; +} + +PyObject *PyMessageGetKey(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + auto key_size = mgp_message_key_size(self->message); + const auto *key = mgp_message_key(self->message); + auto *raw_bytes = PyByteArray_FromStringAndSize(key, key_size); + if (!raw_bytes) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + return nullptr; + } + return raw_bytes; +} + +PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + auto timestamp = mgp_message_timestamp(self->message); + auto *py_int = PyLong_FromUnsignedLong(timestamp); + if (!py_int) { + PyErr_SetString(PyExc_IndexError, "Unable to get timestamp."); + return nullptr; + } + return py_int; +} + +// NOLINTNEXTLINE +static PyMethodDef PyMessageMethods[] = { + {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"is_valid", reinterpret_cast(PyMessageIsValid), METH_NOARGS, + "Return True if messages is in valid context and may be used."}, + {"payload", reinterpret_cast(PyMessageGetPayload), METH_NOARGS, "Get payload"}, + {"topic_name", reinterpret_cast(PyMessageGetTopicName), METH_NOARGS, "Get topic name."}, + {"key", reinterpret_cast(PyMessageGetKey), METH_NOARGS, "Get message key."}, + {"timestamp", reinterpret_cast(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."}, + {nullptr}, +}; + +void PyMessageDealloc(PyMessage *self) { + MG_ASSERT(self->memory); + MG_ASSERT(self->message); + MG_ASSERT(self->messages); + // NOLINTNEXTLINE + Py_DECREF(self->messages); + // NOLINTNEXTLINE + Py_TYPE(self)->tp_free(self); +} + +// NOLINTNEXTLINE +static PyTypeObject PyMessageType = { + PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Message", + .tp_basicsize = sizeof(PyMessage), + .tp_dealloc = reinterpret_cast(PyMessageDealloc), + // NOLINTNEXTLINE + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_message.", + // NOLINTNEXTLINE + .tp_methods = PyMessageMethods, +}; + +PyObject *PyMessagesInvalidate(PyMessages *self, PyObject *Py_UNUSED(ignored)) { + self->messages = nullptr; + self->memory = nullptr; + Py_RETURN_NONE; +} + +PyObject *PyMessagesGetTotalMessages(PyMessages *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->messages); + MG_ASSERT(self->memory); + auto size = self->messages->messages.size(); + auto *py_int = PyLong_FromSize_t(size); + if (!py_int) { + PyErr_SetString(PyExc_IndexError, "Unable to get timestamp."); + return nullptr; + } + return py_int; +} + +PyObject *PyMessagesGetMessageAt(PyMessages *self, PyObject *args) { + MG_ASSERT(self->messages); + MG_ASSERT(self->memory); + int64_t id = 0; + if (!PyArg_ParseTuple(args, "l", &id)) return nullptr; + if (id < 0) return nullptr; + const auto *message = mgp_messages_at(self->messages, id); + // NOLINTNEXTLINE + auto *py_message = PyObject_New(PyMessage, &PyMessageType); + if (!py_message) { + return nullptr; + } + py_message->message = message; + // NOLINTNEXTLINE + Py_INCREF(self); + py_message->messages = self; + py_message->memory = self->memory; + if (!message) { + PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index."); + return nullptr; + } + // NOLINTNEXTLINE + return reinterpret_cast(py_message); +} + +// NOLINTNEXTLINE +static PyMethodDef PyMessagesMethods[] = { + {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"invalidate", reinterpret_cast(PyMessagesInvalidate), METH_NOARGS, + "Invalidate the messages context thus preventing the messages from being used"}, + {"is_valid", reinterpret_cast(PyMessagesIsValid), METH_NOARGS, + "Return True if messages is in valid context and may be used."}, + {"total_messages", reinterpret_cast(PyMessagesGetTotalMessages), METH_VARARGS, + "Get number of messages available"}, + {"message_at", reinterpret_cast(PyMessagesGetMessageAt), METH_VARARGS, + "Get message at index idx from messages"}, + {nullptr}, +}; + +// NOLINTNEXTLINE +static PyTypeObject PyMessagesType = { + PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Messages", + .tp_basicsize = sizeof(PyMessages), + // NOLINTNEXTLINE + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Wraps struct mgp_messages.", + // NOLINTNEXTLINE + .tp_methods = PyMessagesMethods, +}; + +PyObject *MakePyMessages(const mgp_messages *msgs, mgp_memory *memory) { + MG_ASSERT(!msgs || (msgs && memory)); + // NOLINTNEXTLINE + auto *py_messages = PyObject_New(PyMessages, &PyMessagesType); + if (!py_messages) return nullptr; + py_messages->messages = msgs; + py_messages->memory = memory; + return reinterpret_cast(py_messages); +} + py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) { MG_ASSERT(list); MG_ASSERT(py_graph); @@ -585,6 +768,86 @@ void CallPythonProcedure(const py::Object &py_cb, const mgp_list *args, const mg } } +void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs, const mgp_graph *graph, + mgp_result *result, mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional &exc_info) -> std::optional { + if (!exc_info) return std::nullopt; + // Here we tell the traceback formatter to skip the first line of the + // traceback because that line will always be our wrapper function in our + // internal `mgp.py` file. With that line skipped, the user will always + // get only the relevant traceback that happened in his Python code. + return py::FormatException(*exc_info, /* skip_first_line = */ true); + }; + + auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional { + auto py_res = py_cb.Call(py_messages, py_graph); + if (!py_res) return py::FetchError(); + if (PySequence_Check(py_res.Ptr())) { + return AddMultipleRecordsFromPython(result, py_res); + } + return AddRecordFromPython(result, py_res); + }; + + auto cleanup = [](py::Object py_graph, py::Object py_messages) { + // Run `gc.collect` (reference cycle-detection) explicitly, so that we are + // sure the procedure cleaned up everything it held references to. If the + // user stored a reference to one of our `_mgp` instances then the + // internally used `mgp_*` structs will stay unfreed and a memory leak + // will be reported at the end of the query execution. + py::Object gc(PyImport_ImportModule("gc")); + if (!gc) { + LOG_FATAL(py::FetchError().value()); + } + + if (!gc.CallMethod("collect")) { + LOG_FATAL(py::FetchError().value()); + } + + // After making sure all references from our side have been cleared, + // invalidate the `_mgp.Graph` object. If the user kept a reference to one + // of our `_mgp` instances then this will prevent them from using those + // objects (whose internal `mgp_*` pointers are now invalid and would cause + // a crash). + if (!py_graph.CallMethod("invalidate")) { + LOG_FATAL(py::FetchError().value()); + } + if (!py_messages.CallMethod("invalidate")) { + LOG_FATAL(py::FetchError().value()); + } + }; + + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. + // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. + std::optional maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + py::Object py_messages(MakePyMessages(msgs, memory)); + if (py_graph && py_messages) { + try { + maybe_msg = error_to_msg(call(py_graph, py_messages)); + cleanup(py_graph, py_messages); + } catch (...) { + cleanup(py_graph, py_messages); + throw; + } + } else { + maybe_msg = error_to_msg(py::FetchError()); + } + } + + if (maybe_msg) { + mgp_result_set_error_msg(result, maybe_msg->c_str()); + } +} } // namespace PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) { @@ -619,10 +882,41 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) { return reinterpret_cast(py_proc); } +PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) { + MG_ASSERT(self->module); + if (!PyCallable_Check(cb)) { + PyErr_SetString(PyExc_TypeError, "Expected a callable object."); + return nullptr; + } + auto py_cb = py::Object::FromBorrow(cb); + py::Object py_name(py_cb.GetAttr("__name__")); + const auto *name = PyUnicode_AsUTF8(py_name.Ptr()); + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, "Transformation name is not a valid identifier"); + return nullptr; + } + auto *memory = self->module->transformations.get_allocator().GetMemoryResource(); + mgp_trans trans( + name, + [py_cb](const mgp_messages *msgs, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) { + CallPythonTransformation(py_cb, msgs, graph, result, memory); + }, + memory); + const auto [trans_it, did_insert] = self->module->transformations.emplace(name, std::move(trans)); + if (!did_insert) { + PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name."); + return nullptr; + } + Py_RETURN_NONE; +} + static PyMethodDef PyQueryModuleMethods[] = { {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, {"add_read_procedure", reinterpret_cast(PyQueryModuleAddReadProcedure), METH_O, "Register a read-only procedure with this module."}, + {"add_transformation", reinterpret_cast(PyQueryModuleAddTransformation), METH_O, + "Register a transformation with this module."}, {nullptr}, }; @@ -1348,6 +1642,8 @@ PyObject *PyInitMgpModule() { if (!register_type(&PyVertexType, "Vertex")) return nullptr; if (!register_type(&PyPathType, "Path")) return nullptr; if (!register_type(&PyCypherTypeType, "Type")) return nullptr; + if (!register_type(&PyMessagesType, "Messages")) return nullptr; + if (!register_type(&PyMessageType, "Message")) return nullptr; Py_INCREF(Py_None); if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) { Py_DECREF(Py_None); diff --git a/tests/unit/mgp_trans_c_api.cpp b/tests/unit/mgp_trans_c_api.cpp index f06b2e77c..54051c079 100644 --- a/tests/unit/mgp_trans_c_api.cpp +++ b/tests/unit/mgp_trans_c_api.cpp @@ -5,7 +5,8 @@ #include "test_utils.hpp" TEST(MgpTransTest, TestMgpTransApi) { - constexpr auto no_op_cb = [](const mgp_messages *msg, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {}; + constexpr auto no_op_cb = [](const mgp_messages *msg, const mgp_graph *graph, mgp_result *result, + mgp_memory *memory) {}; mgp_module module(utils::NewDeleteResource()); // If this is false, then mgp_module_add_transformation() // correctly calls IsValidIdentifier(). We don't need to test