From 32e56684db218c0a2e8a4c39533acb52bc64804e Mon Sep 17 00:00:00 2001 From: Ivan Paljak Date: Wed, 26 Feb 2020 13:13:36 +0100 Subject: [PATCH] Add Python class for mgp_vertex Reviewers: teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2688 --- include/mgp.py | 73 +++++++++-------- src/query/procedure/py_module.cpp | 125 +++++++++++++++++++++++++++--- 2 files changed, 154 insertions(+), 44 deletions(-) diff --git a/include/mgp.py b/include/mgp.py index 8209209dc..f696b4440 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -25,34 +25,21 @@ import _mgp class Label: '''Label of a Vertex.''' + __slots__ = ('_name',) + + def __init__(self, name): + self._name = name; @property def name(self) -> str: - pass - - -class Labels: - '''A collection of labels on a Vertex.''' - - def __len__(self) -> int: - '''Raise InvalidVertexError.''' - pass - - def __getitem__(self, index: int) -> Label: - '''Raise InvalidVertexError.''' - pass - - def __iter__(self) -> typing.Iterable[Label]: - '''Raise InvalidVertexError.''' - pass - - def __contains__(self, label: typing.Union[Label, str]) -> bool: - '''Test whether a label exists, either by name or another Label. - - Raise InvalidVertexError. - ''' - pass + return self._name; + def __eq__(self, other) -> bool: + if isinstance(other, Label): + return self._name == other.name + if isinstance(other, str): + return self._name == other + return NotImplemented # Named property value of a Vertex or an Edge. # It would be better to use typing.NamedTuple with typed fields, but that is @@ -197,38 +184,58 @@ class Vertex: in a query. You should not globally store an instance of a Vertex. Using an invalid Vertex instance will raise InvalidVertexError. ''' + __slots__ = ('_vertex',) def __init__(self, vertex): - raise NotImplementedError() + if not isinstance(vertex, _mgp.Vertex): + raise TypeError("Expected '_mgp.Vertex', got '{}'".fmt(type(vertex))) + self._vertex = vertex + + def is_valid(self) -> bool: + '''Return True if `self` is in valid context and may be used''' + return self._vertex.is_valid() @property def id(self) -> VertexId: '''Raise InvalidVertexError.''' - pass + if not self.is_valid(): + raise InvalidVertexError() + return self._vertex.get_id() @property - def labels(self) -> Labels: + def labels(self) -> typing.List[Label]: '''Raise InvalidVertexError.''' - pass + if not self.is_valid(): + raise InvalidVertexError() + return tuple(Label(self._vertex.label_at(i)) + for i in range(self._vertex.labels_count())) @property def properties(self) -> Properties: '''Raise InvalidVertexError.''' - pass + if not self.is_valid(): + raise InvalidVertexError() + return Properties(self._vertex) @property def in_edges(self) -> typing.Iterable[Edge]: '''Raise InvalidVertexError.''' - pass + if not self.is_valid(): + raise InvalidVertexError() + raise NotImplementedError() @property def out_edges(self) -> typing.Iterable[Edge]: '''Raise InvalidVertexError.''' - pass + if not self.is_valid(): + raise InvalidVertexError() + raise NotImplementedError() def __eq__(self, other) -> bool: - '''Raise InvalidVertexError.''' - pass + '''Raise InvalidVertexError''' + if not self.is_valid(): + raise InvalidVertexError() + return self._vertex == other._vertex class Path: diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 32b8d7d54..53fb05874 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -34,6 +34,8 @@ struct PyVerticesIterator { PyGraph *py_graph; }; +PyObject *MakePyVertex(mgp_vertex *vertex, PyGraph *py_graph); + void PyVerticesIteratorDealloc(PyVerticesIterator *self) { CHECK(self->it); CHECK(self->py_graph); @@ -51,9 +53,8 @@ PyObject *PyVerticesIteratorGet(PyVerticesIterator *self, CHECK(self->py_graph->graph); const auto *vertex = mgp_vertices_iterator_get(self->it); if (!vertex) Py_RETURN_NONE; - // TODO: Wrap mgp_vertex_copy(vertex) into _mgp.Vertex and return it. - PyErr_SetString(PyExc_NotImplementedError, "get"); - return nullptr; + return MakePyVertex(mgp_vertex_copy(vertex, self->py_graph->memory), + self->py_graph); } PyObject *PyVerticesIteratorNext(PyVerticesIterator *self, @@ -63,9 +64,8 @@ PyObject *PyVerticesIteratorNext(PyVerticesIterator *self, CHECK(self->py_graph->graph); const auto *vertex = mgp_vertices_iterator_next(self->it); if (!vertex) Py_RETURN_NONE; - // TODO: Wrap mgp_vertex_copy(vertex) into _mgp.Vertex and return it. - PyErr_SetString(PyExc_NotImplementedError, "next"); - return nullptr; + return MakePyVertex(mgp_vertex_copy(vertex, self->py_graph->memory), + self->py_graph); } static PyMethodDef PyVerticesIteratorMethods[] = { @@ -104,11 +104,7 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) { "Unable to find the vertex with given ID."); return nullptr; } - // TODO: Wrap into _mgp.Vertex and let it handle mgp_vertex_destroy via - // dealloc function. - mgp_vertex_destroy(vertex); - PyErr_SetString(PyExc_NotImplementedError, "get_vertex_by_id"); - return nullptr; + return MakePyVertex(mgp_vertex_copy(vertex, self->memory), self); } PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) { @@ -452,6 +448,112 @@ PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op) { return PyBool_FromLong(mgp_edge_equal(e1->edge, e2->edge)); } +struct PyVertex { + PyObject_HEAD + mgp_vertex *vertex; + PyGraph *py_graph; +}; + +void PyVertexDealloc(PyVertex *self) { + CHECK(self->vertex); + CHECK(self->py_graph); + // Avoid invoking `mgp_vertex_destroy` if we are not in valid execution + // context. The query execution should free all memory used during + // execution, so we may cause a double free issue. + if (self->py_graph->graph) mgp_vertex_destroy(self->vertex); + Py_DECREF(self->py_graph); + Py_TYPE(self)->tp_free(self); +} + +PyObject *PyVertexIsValid(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + return PyBool_FromLong(!!self->py_graph->graph); +} + +PyObject *PyVertexGetId(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + CHECK(self); + CHECK(self->vertex); + CHECK(self->py_graph); + CHECK(self->py_graph->graph); + return PyLong_FromLongLong(mgp_vertex_get_id(self->vertex).as_int); +} + +PyObject *PyVertexLabelsCount(PyVertex *self, PyObject *Py_UNUSED(ignored)) { + CHECK(self); + CHECK(self->vertex); + CHECK(self->py_graph); + CHECK(self->py_graph->graph); + return PyLong_FromSize_t(mgp_vertex_labels_count(self->vertex)); +} + +PyObject *PyVertexLabelAt(PyVertex *self, PyObject *args) { + CHECK(self); + CHECK(self->vertex); + CHECK(self->py_graph); + CHECK(self->py_graph->graph); + static_assert(std::numeric_limits::max() <= + std::numeric_limits::max()); + Py_ssize_t id; + if (!PyArg_ParseTuple(args, "n", &id)) return nullptr; + auto label = mgp_vertex_label_at(self->vertex, id); + if (label.name == nullptr || id < 0) { + PyErr_SetString(PyExc_IndexError, + "Unable to find the label with given ID."); + return nullptr; + } + return PyUnicode_FromString(label.name); +} + +static PyMethodDef PyVertexMethods[] = { + {"is_valid", reinterpret_cast(PyVertexIsValid), METH_NOARGS, + "Return True if Vertex is in valid context and may be used."}, + {"get_id", reinterpret_cast(PyVertexGetId), METH_NOARGS, + "Return vertex id."}, + {"labels_count", reinterpret_cast(PyVertexLabelsCount), + METH_NOARGS, "Return number of lables of a vertex."}, + {"label_at", reinterpret_cast(PyVertexLabelAt), METH_VARARGS, + "Return label of a vertex on a given index."}, + {nullptr}}; + +PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op); + +static PyTypeObject PyVertexType = { + PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Vertex", + .tp_doc = "Wraps struct mgp_vertex.", + .tp_basicsize = sizeof(PyVertex), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_new = PyType_GenericNew, + .tp_methods = PyVertexMethods, + .tp_dealloc = reinterpret_cast(PyVertexDealloc), + .tp_richcompare = PyVertexRichCompare, +}; + +PyObject *MakePyVertex(mgp_vertex *vertex, PyGraph *py_graph) { + CHECK(vertex->GetMemoryResource() == py_graph->memory->impl); + auto *py_vertex = PyObject_New(PyVertex, &PyVertexType); + if (!py_vertex) return nullptr; + py_vertex->vertex = vertex; + py_vertex->py_graph = py_graph; + Py_INCREF(py_graph); + return PyObject_Init(reinterpret_cast(py_vertex), &PyVertexType); +} + +PyObject *PyVertexRichCompare(PyObject *self, PyObject *other, int op) { + CHECK(self); + CHECK(other); + + if (Py_TYPE(self) != &PyVertexType || Py_TYPE(other) != &PyVertexType || + op != Py_EQ) { + Py_RETURN_NOTIMPLEMENTED; + } + + auto *v1 = reinterpret_cast(self); + auto *v2 = reinterpret_cast(other); + CHECK(v1->vertex); + CHECK(v2->vertex); + + return PyBool_FromLong(mgp_vertex_equal(v1->vertex, v2->vertex)); +} + PyObject *PyInitMgpModule() { PyObject *mgp = PyModule_Create(&PyMgpModule); if (!mgp) return nullptr; @@ -474,6 +576,7 @@ PyObject *PyInitMgpModule() { if (!register_type(&PyEdgeType, "Edge")) return nullptr; if (!register_type(&PyQueryProcType, "Proc")) return nullptr; if (!register_type(&PyQueryModuleType, "Module")) return nullptr; + if (!register_type(&PyVertexType, "Vertex")) return nullptr; Py_INCREF(Py_None); if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) { Py_DECREF(Py_None);