From fbdcad1106ec58da0f2067257523ae07cb1fa260 Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Fri, 6 Mar 2020 15:31:31 +0100 Subject: [PATCH] Implement calling Python procedures Summary: You should now be able to invoke query procedures written in Python. To test the example you can run memgraph with PYTHONPATH set to `include`. For example, assuming you are in the root of the repo, run this command. PYTHONPATH=$PWD/include ./build/memgraph --query-modules-directory=./query_modules/ Alternatively, you can set a symlink inside the ./query_modules to point to `include/mgp.py`, so there's no need to set PYTHONPATH. For example, assuming you are in the root of the repo, run the following. cd ./query_modules ln -s ../include/mgp.py cd .. ./build/memgraph --query-modules-directory=./query_modules/ Depends on D207 Reviewers: mferencevic, ipaljak, dsantl Reviewed By: ipaljak Subscribers: buda, tlastre, pullbot Differential Revision: https://phabricator.memgraph.io/D2708 --- include/mgp.py | 23 ++- src/py/py.hpp | 7 + src/query/procedure/mg_procedure_impl.cpp | 14 +- src/query/procedure/mg_procedure_impl.hpp | 2 + src/query/procedure/py_module.cpp | 169 +++++++++++++++++++++- 5 files changed, 197 insertions(+), 18 deletions(-) diff --git a/include/mgp.py b/include/mgp.py index 3cf45e985..eca18be78 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -456,11 +456,22 @@ class ProcCtx: Access to a ProcCtx is only valid during a single execution of a procedure in a query. You should not globally store a ProcCtx instance. 
''' + __slots__ = ('_graph',) + + def __init__(self, graph): + if not isinstance(graph, _mgp.Graph): + raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph))) + self._graph = Graph(graph) + + def is_valid(self) -> bool: + return self._graph.is_valid() @property def graph(self) -> Graph: '''Raise InvalidContextError if context is invalid.''' - pass + if not self.is_valid(): + raise InvalidContextError() + return self._graph # Additional typing support @@ -634,13 +645,15 @@ def read_proc(func: typing.Callable[..., Record]): sig = inspect.signature(func) params = tuple(sig.parameters.values()) if params and params[0].annotation is ProcCtx: + @functools.wraps(func) + def wrapper(graph, args): + return func(ProcCtx(graph), *args) params = params[1:] - mgp_proc = _mgp._MODULE.add_read_procedure(func) + mgp_proc = _mgp._MODULE.add_read_procedure(wrapper) else: @functools.wraps(func) - def wrapper(*args): - args_without_context = args[1:] - return func(*args_without_context) + def wrapper(graph, args): + return func(*args) mgp_proc = _mgp._MODULE.add_read_procedure(wrapper) for param in params: name = param.name diff --git a/src/py/py.hpp b/src/py/py.hpp index 73b20e6c3..642b4d5f6 100644 --- a/src/py/py.hpp +++ b/src/py/py.hpp @@ -44,8 +44,15 @@ class [[nodiscard]] Object final { public: Object() = default; Object(std::nullptr_t) {} + /// Construct by taking the ownership of `PyObject *`. explicit Object(PyObject *ptr) noexcept : ptr_(ptr) {} + /// Construct from a borrowed `PyObject *`, i.e. non-owned pointer. 
+ static Object FromBorrow(PyObject *ptr) noexcept { + Py_XINCREF(ptr); + return Object(ptr); + } + ~Object() noexcept { Py_XDECREF(ptr_); } Object(const Object &other) noexcept : ptr_(other.ptr_) { Py_XINCREF(ptr_); } diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 4bda561d7..70224a228 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -1386,14 +1386,6 @@ const mgp_type *mgp_type_nullable(const mgp_type *type) { } } -namespace { -bool IsValidIdentifierName(const char *name) { - if (!name) return false; - std::regex regex("[_[:alpha:]][_[:alnum:]]*"); - return std::regex_match(name, regex); -} -} // namespace - mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb) { if (!module || !cb) return nullptr; @@ -1548,4 +1540,10 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) { (*stream) << ")"; } +bool IsValidIdentifierName(const char *name) { + if (!name) return false; + std::regex regex("[_[:alpha:]][_[:alnum:]]*"); + return std::regex_match(name, regex); +} + } // namespace query::procedure diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp index b7cd5ba55..61908466b 100644 --- a/src/query/procedure/mg_procedure_impl.hpp +++ b/src/query/procedure/mg_procedure_impl.hpp @@ -573,4 +573,6 @@ namespace query::procedure { /// @throw anything std::ostream::operator<< may throw. 
void PrintProcSignature(const mgp_proc &, std::ostream *); +bool IsValidIdentifierName(const char *name); + } // namespace query::procedure diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 3ffd77b37..b2c633abf 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -161,6 +161,12 @@ static PyTypeObject PyEdgesIteratorType = { .tp_dealloc = reinterpret_cast(PyEdgesIteratorDealloc), }; +PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) { + self->graph = nullptr; + self->memory = nullptr; + Py_RETURN_NONE; +} + PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) { return PyBool_FromLong(!!self->graph); } @@ -208,6 +214,9 @@ PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) { static PyMethodDef PyGraphMethods[] = { {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, + {"invalidate", reinterpret_cast(PyGraphInvalidate), + METH_NOARGS, + "Invalidate the Graph context thus preventing the Graph from being used."}, {"is_valid", reinterpret_cast(PyGraphIsValid), METH_NOARGS, "Return True if Graph is in valid context and may be used."}, {"get_vertex_by_id", reinterpret_cast(PyGraphGetVertexById), @@ -403,23 +412,173 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) { return py_tuple; } +py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) { + if (Py_TYPE(py_graph) != &PyGraphType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Graph."); + return nullptr; + } + return MgpListToPyTuple(list, reinterpret_cast(py_graph)); +} + +namespace { + +void SetErrorFromPython(mgp_result *result, const py::ExceptionInfo &exc_info) { + std::stringstream ss; + ss << exc_info; + const auto &msg = ss.str(); + mgp_result_set_error_msg(result, msg.c_str()); +} + +std::optional AddRecordFromPython(mgp_result *result, + py::Object py_record) { + py::Object 
py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return py::FetchError(); + auto record_cls = py_mgp.GetAttr("Record"); + if (!record_cls) return py::FetchError(); + if (!PyObject_IsInstance(py_record, record_cls)) { + std::stringstream ss; + ss << "Value '" << py_record << "' is not an instance of 'mgp.Record'"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return py::FetchError(); + } + py::Object fields(py_record.GetAttr("fields")); + if (!fields) return py::FetchError(); + if (!PyDict_Check(fields)) { + PyErr_SetString(PyExc_TypeError, + "Expected 'mgp.Record.fields' to be a 'dict'"); + return py::FetchError(); + } + py::Object items(PyDict_Items(fields)); + if (!items) return py::FetchError(); + auto *record = mgp_result_new_record(result); + if (!record) { + PyErr_NoMemory(); + return py::FetchError(); + } + Py_ssize_t len = PyList_GET_SIZE(static_cast(items)); + for (Py_ssize_t i = 0; i < len; ++i) { + auto *item = PyList_GET_ITEM(static_cast(items), i); + if (!item) return py::FetchError(); + CHECK(PyTuple_Check(item)); + auto *key = PyTuple_GetItem(item, 0); + if (!key) return py::FetchError(); + if (!PyUnicode_Check(key)) { + std::stringstream ss; + ss << "Field name '" << py::Object::FromBorrow(key) + << "' is not an instance of 'str'"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return py::FetchError(); + } + const auto *field_name = PyUnicode_AsUTF8(key); + if (!field_name) return py::FetchError(); + auto *val = PyTuple_GetItem(item, 1); + if (!val) return py::FetchError(); + mgp_memory memory{result->rows.get_allocator().GetMemoryResource()}; + mgp_value *field_val{nullptr}; + try { + // TODO: Make PyObjectToMgpValue set a Python exception instead. 
+ field_val = PyObjectToMgpValue(val, &memory); + } catch (const std::exception &e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return py::FetchError(); + } + CHECK(field_val); + if (!mgp_result_record_insert(record, field_name, field_val)) { + std::stringstream ss; + ss << "Unable to insert field '" << py::Object::FromBorrow(key) + << "' with value: '" << py::Object::FromBorrow(val) + << "'; did you set the correct field type?"; + const auto &msg = ss.str(); + PyErr_SetString(PyExc_ValueError, msg.c_str()); + mgp_value_destroy(field_val); + return py::FetchError(); + } + mgp_value_destroy(field_val); + } + return std::nullopt; +} + +std::optional AddMultipleRecordsFromPython( + mgp_result *result, py::Object py_seq) { + Py_ssize_t len = PySequence_Size(py_seq); + if (len == -1) return py::FetchError(); + for (Py_ssize_t i = 0; i < len; ++i) { + py::Object py_record(PySequence_GetItem(py_seq, i)); + if (!py_record) return py::FetchError(); + auto maybe_exc = AddRecordFromPython(result, py_record); + if (maybe_exc) return maybe_exc; + } + return std::nullopt; +} + +template +std::optional WithPyGraph(const mgp_graph *graph, + mgp_memory *memory, + const TFun &fun) { + py::Object py_graph(MakePyGraph(graph, memory)); + if (!py_graph) return py::FetchError(); + try { + auto maybe_exc = fun(py_graph); + // After `fun` finishes, invalidate the graph thus preventing its use in + // Python code. This is just a precaution if someone were to store + // `mgp_` objects globally in Python. + LOG_IF(FATAL, !py_graph.CallMethod("invalidate")) + << py::FetchError().value(); + // Run gc.collect (reference cycle-detection) explicitly, so that we are + // sure the procedure cleaned up everything it held references to. If any + // `mgp_` remains alive, that means the user stored it somewhere + // globally and that will get reported as a query procedure memory leak in + // our logs. 
+ py::Object gc(PyImport_ImportModule("gc")); + LOG_IF(FATAL, !gc) << py::FetchError().value(); + LOG_IF(FATAL, !gc.CallMethod("collect")) << py::FetchError().value(); + return maybe_exc; + } catch (...) { + LOG_IF(FATAL, !py_graph.CallMethod("invalidate")) + << py::FetchError().value(); + throw; + } +} + +} // namespace + PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) { CHECK(self->module); if (!PyCallable_Check(cb)) { PyErr_SetString(PyExc_TypeError, "Expected a callable object."); return nullptr; } - Py_INCREF(cb); - py::Object py_cb(cb); + auto py_cb = py::Object::FromBorrow(cb); py::Object py_name(py_cb.GetAttr("__name__")); const auto *name = PyUnicode_AsUTF8(py_name); - // TODO: Validate name + if (!name) return nullptr; + if (!IsValidIdentifierName(name)) { + PyErr_SetString(PyExc_ValueError, + "Procedure name is not a valid identifier"); + return nullptr; + } auto *memory = self->module->procedures.get_allocator().GetMemoryResource(); mgp_proc proc( name, - [py_cb](const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *) { + [py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result, + mgp_memory *memory) { auto gil = py::EnsureGIL(); - throw utils::NotYetImplemented("Invoking Python procedures"); + auto maybe_exc = + WithPyGraph(graph, memory, + [&](auto py_graph) -> std::optional { + py::Object py_args(MgpListToPyTuple(args, py_graph)); + if (!py_args) return py::FetchError(); + auto py_res = py_cb.Call(py_graph, py_args); + if (!py_res) return py::FetchError(); + if (PySequence_Check(py_res)) { + return AddMultipleRecordsFromPython(result, py_res); + } else { + return AddRecordFromPython(result, py_res); + } + }); + if (maybe_exc) return SetErrorFromPython(result, *maybe_exc); }, memory); const auto &[proc_it, did_insert] =