diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 509a966fb..3ffd77b37 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -1123,12 +1123,20 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph) { case MGP_VALUE_TYPE_VERTEX: throw utils::NotYetImplemented("MgpValueToPyObject"); case MGP_VALUE_TYPE_EDGE: { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return nullptr; const auto *e = mgp_value_get_edge(&value); - return py::Object(reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph))); + py::Object py_edge( + reinterpret_cast<PyObject *>(MakePyEdge(*e, py_graph))); + return py_mgp.CallMethod("Edge", py_edge); } case MGP_VALUE_TYPE_PATH: { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) return nullptr; const auto *p = mgp_value_get_path(&value); - return py::Object(reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph))); + py::Object py_path( + reinterpret_cast<PyObject *>(MakePyPath(*p, py_graph))); + return py_mgp.CallMethod("Path", py_path); } } } @@ -1164,6 +1172,31 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) { return v; }; + auto is_mgp_instance = [](PyObject *obj, const char *mgp_type_name) { + py::Object py_mgp(PyImport_ImportModule("mgp")); + if (!py_mgp) { + PyErr_Clear(); + // This way we skip conversions of types from user-facing 'mgp' module. + return false; + } + auto mgp_type = py_mgp.GetAttr(mgp_type_name); + if (!mgp_type) { + PyErr_Clear(); + std::stringstream ss; + ss << "'mgp' module is missing '" << mgp_type_name << "' type"; + throw std::invalid_argument(ss.str()); + } + int res = PyObject_IsInstance(obj, mgp_type); + if (res == -1) { + PyErr_Clear(); + std::stringstream ss; + ss << "Error when checking object is instance of 'mgp." << mgp_type_name + << "' type"; + throw std::invalid_argument(ss.str()); + } + return static_cast<bool>(res); + }; + mgp_value *mgp_v{nullptr}; if (o == Py_None) { @@ -1258,6 +1291,28 @@ mgp_value *PyObjectToMgpValue(PyObject *o, mgp_memory *memory) { } } else if (Py_TYPE(o) == &PyVertexType) { throw utils::NotYetImplemented("PyObjectToMgpValue"); + } else if (is_mgp_instance(o, "Edge")) { + py::Object edge(PyObject_GetAttrString(o, "_edge")); + if (!edge) { + PyErr_Clear(); + throw std::invalid_argument("'mgp.Edge' is missing '_edge' attribute"); + } + return PyObjectToMgpValue(edge, memory); + } else if (is_mgp_instance(o, "Vertex")) { + py::Object vertex(PyObject_GetAttrString(o, "_vertex")); + if (!vertex) { + PyErr_Clear(); + throw std::invalid_argument( + "'mgp.Vertex' is missing '_vertex' attribute"); + } + return PyObjectToMgpValue(vertex, memory); + } else if (is_mgp_instance(o, "Path")) { + py::Object path(PyObject_GetAttrString(o, "_path")); + if (!path) { + PyErr_Clear(); + throw std::invalid_argument("'mgp.Path' is missing '_path' attribute"); + } + return PyObjectToMgpValue(path, memory); } else { throw std::invalid_argument("Unsupported PyObject conversion"); } diff --git a/src/query/procedure/py_module.hpp b/src/query/procedure/py_module.hpp index bd9e53130..29ed38e20 100644 --- a/src/query/procedure/py_module.hpp +++ b/src/query/procedure/py_module.hpp @@ -16,6 +16,12 @@ struct PyGraph; /// Convert an `mgp_value` into a Python object, referencing the given `PyGraph` /// instance and using the same allocator as the graph. /// +/// Values of type `MGP_VALUE_TYPE_VERTEX`, `MGP_VALUE_TYPE_EDGE` and +/// `MGP_VALUE_TYPE_PATH` are returned as `mgp.Vertex`, `mgp.Edge` and +/// `mgp.Path` respectively, and *not* their internal `_mgp` +/// representations. Other value types are converted to equivalent builtin +/// Python objects. +/// /// Return a non-null `py::Object` instance on success. Otherwise, return a null /// `py::Object` instance and set the appropriate Python exception. py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph); @@ -25,6 +31,9 @@ py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph); /// Convert a Python object into `mgp_value`, constructing it using the given /// `mgp_memory` allocator. /// +/// If the user-facing 'mgp' module can be imported, this function will handle +/// conversion of 'mgp.Vertex', 'mgp.Edge' and 'mgp.Path' values. +/// /// @throw std::bad_alloc /// @throw std::overflow_error if attempting to convert a Python integer which /// too large to fit into int64_t. diff --git a/tests/apollo_runs.py b/tests/apollo_runs.py index 3cfff3ef0..00357a96b 100755 --- a/tests/apollo_runs.py +++ b/tests/apollo_runs.py @@ -75,6 +75,11 @@ def get_runs(build_dir, include=None, exclude=None, outfile=None, if name.endswith("storage_v2_durability"): prefix = "TIMEOUT=300 " + # py_module unit test requires user-facing 'mgp' module + if name.endswith("py_module"): + mgp_path = os.path.join("..", "include", "mgp.py") + files.append(os.path.relpath(mgp_path, dirname)) + # get output files outfile_paths = [] if outfile: diff --git a/tests/unit/query_procedure_py_module.cpp b/tests/unit/query_procedure_py_module.cpp index 63eb4e871..251854bea 100644 --- a/tests/unit/query_procedure_py_module.cpp +++ b/tests/unit/query_procedure_py_module.cpp @@ -1,5 +1,6 @@ #include <gtest/gtest.h> +#include <filesystem> #include <string> #include "query/procedure/mg_procedure_impl.hpp" @@ -180,7 +181,25 @@ int main(int argc, char **argv) { PyEval_InitThreads(); int test_result; { - py::Object mgp(PyImport_ImportModule("_mgp")); + // Setup importing 'mgp' module by adding its directory to `sys.path`. + std::filesystem::path invocation_path(argv[0]); + auto mgp_py_path = + invocation_path.parent_path() / "../../../include/mgp.py"; + CHECK(std::filesystem::exists(mgp_py_path)); + auto *py_path = PySys_GetObject("path"); + CHECK(py_path); + py::Object import_dir( + PyUnicode_FromString(mgp_py_path.parent_path().c_str())); + if (PyList_Append(py_path, import_dir) != 0) { + auto exc_info = py::FetchError().value(); + LOG(FATAL) << exc_info; + } + py::Object mgp(PyImport_ImportModule("mgp")); + if (!mgp) { + auto exc_info = py::FetchError().value(); + LOG(FATAL) << exc_info; + } + // Now run tests. Py_BEGIN_ALLOW_THREADS; test_result = RUN_ALL_TESTS(); Py_END_ALLOW_THREADS;