diff --git a/src/py/py.hpp b/src/py/py.hpp index 674c98630..fedd0ed2d 100644 --- a/src/py/py.hpp +++ b/src/py/py.hpp @@ -226,8 +226,22 @@ inline std::ostream &operator<<(std::ostream &os, return os; } +/// Format ExceptionInfo as a string just like the Python interpreter would. +[[nodiscard]] inline std::string FormatException( + const ExceptionInfo &exc_info) { + std::stringstream ss; + ss << exc_info; + return ss.str(); +} + /// Get the current exception info and clear the current exception indicator. /// +/// NOTE: This function must be used with caution because it returns the +/// exception information as real Python objects. The returned objects will have +/// references to the current objects on the Python frame in them. That could +/// cause unintentional lifetime extension of objects that you potentially want +/// to destroy or, even worse, that you have already destroyed. +/// /// This is normally used to catch and handle exceptions via C API. [[nodiscard]] inline std::optional FetchError() { PyObject *exc_type, *exc_value, *traceback; diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 83dae0904..ff933e699 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -417,13 +417,6 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) { namespace { -void SetErrorFromPython(mgp_result *result, const py::ExceptionInfo &exc_info) { - std::stringstream ss; - ss << exc_info; - const auto &msg = ss.str(); - mgp_result_set_error_msg(result, msg.c_str()); -} - std::optional AddRecordFromPython(mgp_result *result, py::Object py_record) { py::Object py_mgp(PyImport_ImportModule("mgp")); @@ -508,32 +501,75 @@ std::optional AddMultipleRecordsFromPython( return std::nullopt; } -template -std::optional WithPyGraph(const mgp_graph *graph, - mgp_memory *memory, - const TFun &fun) { - py::Object py_graph(MakePyGraph(graph, memory)); - if (!py_graph) return py::FetchError(); - try { - auto maybe_exc = fun(py_graph); - // After `fun` finishes, invalidate the graph thus preventing its use in - // Python code. This is just a precaution if someone were to store - // `mgp_` objects globally in Python. - LOG_IF(FATAL, !py_graph.CallMethod("invalidate")) - << py::FetchError().value(); - // Run gc.collect (reference cycle-detection) explicitly, so that we are - // sure the procedure cleaned up everything it held references to. If any - // `mgp_` remains alive, that means the user stored in somewhere - // globally and that will get reported as a query procedure memory leak in - // our logs. +void CallPythonProcedure(py::Object py_cb, const mgp_list *args, + const mgp_graph *graph, mgp_result *result, + mgp_memory *memory) { + auto gil = py::EnsureGIL(); + + auto error_to_msg = [](const std::optional &exc_info) + -> std::optional { + if (!exc_info) return std::nullopt; + return py::FormatException(*exc_info); + }; + + auto call = [&](py::Object py_graph) -> std::optional { + py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr())); + if (!py_args) return py::FetchError(); + auto py_res = py_cb.Call(py_graph, py_args); + if (!py_res) return py::FetchError(); + if (PySequence_Check(py_res.Ptr())) { + return AddMultipleRecordsFromPython(result, py_res); + } else { + return AddRecordFromPython(result, py_res); + } + }; + + auto cleanup = [](py::Object py_graph) { + // Run `gc.collect` (reference cycle-detection) explicitly, so that we are + // sure the procedure cleaned up everything it held references to. If the + // user stored a reference to one of our `_mgp` instances then the + // internally used `mgp_*` structs will stay unfreed and a memory leak + // will be reported at the end of the query execution. py::Object gc(PyImport_ImportModule("gc")); LOG_IF(FATAL, !gc) << py::FetchError().value(); LOG_IF(FATAL, !gc.CallMethod("collect")) << py::FetchError().value(); - return maybe_exc; - } catch (...) { + + // After making sure all references from our side have been cleared, + // invalidate the `_mgp.Graph` object. If the user kept a reference to one + // of our `_mgp` instances then this will prevent them from using those + // objects (whose internal `mgp_*` pointers are now invalid and would cause + // a crash). LOG_IF(FATAL, !py_graph.CallMethod("invalidate")) << py::FetchError().value(); - throw; + }; + + // It is *VERY IMPORTANT* to note that this code takes great care not to keep + // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so + // as not to introduce extra reference counts and prevent their deallocation. + // In particular, the `ExceptionInfo` object has a `traceback` field that + // contains references to the Python frames and their arguments, and therefore + // our `_mgp` instances as well. Within this code we ensure not to keep the + // `ExceptionInfo` object alive so that no extra reference counts are + // introduced. We only fetch the error message and immediately destroy the + // object. + std::optional maybe_msg; + { + py::Object py_graph(MakePyGraph(graph, memory)); + if (py_graph) { + try { + maybe_msg = error_to_msg(call(py_graph)); + cleanup(py_graph); + } catch (...) { + cleanup(py_graph); + throw; + } + } else { + maybe_msg = error_to_msg(py::FetchError()); + } + } + + if (maybe_msg) { + mgp_result_set_error_msg(result, maybe_msg->c_str()); } } @@ -559,21 +595,7 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) { name, [py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) { - auto gil = py::EnsureGIL(); - auto maybe_exc = WithPyGraph( - graph, memory, - [&](auto py_graph) -> std::optional { - py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr())); - if (!py_args) return py::FetchError(); - auto py_res = py_cb.Call(py_graph, py_args); - if (!py_res) return py::FetchError(); - if (PySequence_Check(py_res.Ptr())) { - return AddMultipleRecordsFromPython(result, py_res); - } else { - return AddRecordFromPython(result, py_res); - } - }); - if (maybe_exc) return SetErrorFromPython(result, *maybe_exc); + CallPythonProcedure(py_cb, args, graph, result, memory); }, memory); const auto &[proc_it, did_insert] =