diff --git a/include/mgp.py b/include/mgp.py index f696b4440..d580b5b6d 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -28,11 +28,11 @@ class Label: __slots__ = ('_name',) def __init__(self, name): - self._name = name; + self._name = name @property def name(self) -> str: - return self._name; + return self._name def __eq__(self, other) -> bool: if isinstance(other, Label): @@ -41,6 +41,7 @@ class Label: return self._name == other return NotImplemented + # Named property value of a Vertex or an Edge. # It would be better to use typing.NamedTuple with typed fields, but that is # not available in Python 3.5. @@ -169,7 +170,10 @@ class Edge: return self._edge == other._edge -VertexId = typing.NewType('VertexId', int) +if sys.version_info >= (3, 5, 2): + VertexId = typing.NewType('VertexId', int) +else: + VertexId = int class InvalidVertexError(Exception): @@ -371,15 +375,110 @@ class ProcCtx: Number = typing.Union[int, float] -List = typing.List - Map = typing.Union[dict, Edge, Vertex] Any = typing.Union[bool, str, Number, Map, Path, list] +List = typing.List + Nullable = typing.Optional +class UnsupportedTypingError(Exception): + '''Signals a typing annotation is not supported as a _mgp.CypherType.''' + + def __init__(self, type_): + super().__init__("Unsupported typing annotation '{}'".format(type_)) + + +def _typing_to_cypher_type(type_): + '''Convert typing annotation to a _mgp.CypherType instance.''' + simple_types = { + typing.Any: _mgp.type_nullable(_mgp.type_any()), + object: _mgp.type_nullable(_mgp.type_any()), + list: _mgp.type_list(_mgp.type_nullable(_mgp.type_any())), + Any: _mgp.type_any(), + bool: _mgp.type_bool(), + str: _mgp.type_string(), + int: _mgp.type_int(), + float: _mgp.type_float(), + Number: _mgp.type_number(), + Map: _mgp.type_map(), + Vertex: _mgp.type_node(), + Edge: _mgp.type_relationship(), + Path: _mgp.type_path() + } + try: + return simple_types[type_] + except KeyError: + pass + if sys.version_info >= (3, 8): + complex_type = typing.get_origin(type_) + type_args = typing.get_args(type_) + if complex_type == typing.Union: + # If we have a Union with NoneType inside, it means we are building + # a nullable type. + if isinstance(None, type_args): + types = tuple(t for t in type_args if not isinstance(None, t)) + if len(types) == 1: + type_arg, = types + else: + # We cannot do typing.Union[*types], so do the equivalent + # with __getitem__ which does not even need arg unpacking. + type_arg = typing.Union.__getitem__(types) + return _mgp.type_nullable(_typing_to_cypher_type(type_arg)) + elif complex_type == list: + type_arg, = type_args + return _mgp.type_list(_typing_to_cypher_type(type_arg)) + raise UnsupportedTypingError(type_) + else: + # We cannot get to type args in any reliable way prior to 3.8, but we + # still want to support typing.Optional and typing.List, so just parse + # their string representations. Hopefully, that is always pretty + # printed the same way. `typing.List[type]` is printed as such, while + # `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]' + def parse_type_args(type_as_str): + return tuple(map(str.strip, + type_as_str[type_as_str.index('[') + 1: -1].split(','))) + + def get_simple_type(type_as_str): + for simple_type, cypher_type in simple_types.items(): + if type_as_str == str(simple_type): + return cypher_type + # Fallback to comparing to __name__ if it exits. This handles + # the cases like when we have 'object' which is + # `object.__name__`, but `str(object)` is "" + try: + if type_as_str == simple_type.__name__: + return cypher_type + except AttributeError: + pass + + def parse_typing(type_as_str): + if type_as_str.startswith('typing.Union'): + type_args_as_str = parse_type_args(type_as_str) + none_type_as_str = type(None).__name__ + if none_type_as_str in type_args_as_str: + types = tuple(t for t in type_args_as_str if t != none_type_as_str) + if len(types) == 1: + type_arg_as_str, = types + else: + type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']' + simple_type = get_simple_type(type_arg_as_str) + if simple_type is not None: + return _mgp.type_nullable(simple_type) + return _mgp.type_nullable(parse_typing(type_arg_as_str)) + elif type_as_str.startswith('typing.List'): + type_arg_as_str, = parse_type_args(type_as_str) + simple_type = get_simple_type(type_arg_as_str) + if simple_type is not None: + return _mgp.type_list(simple_type) + return _mgp.type_list(parse_typing(type_arg_as_str)) + raise UnsupportedTypingError(type_) + + return parse_typing(str(type_)) + + # Procedure registration class Deprecated: @@ -387,8 +486,6 @@ class Deprecated: __slots__ = ('field_type',) def __init__(self, type_): - if not isinstance(type_, type): - raise TypeError("Expected 'type', got '{}'".format(type_)) self.field_type = type_ @@ -437,7 +534,7 @@ def read_proc(func: typing.Callable[..., Record]): .format(type(func))) if inspect.iscoroutinefunction(func): raise TypeError("Callable must not be 'async def' function") - if sys.version_info.minor >= 6: + if sys.version_info >= (3, 6): if inspect.isasyncgenfunction(func): raise TypeError("Callable must not be 'async def' function") if inspect.isgeneratorfunction(func): @@ -456,23 +553,22 @@ def read_proc(func: typing.Callable[..., Record]): for param in params: name = param.name type_ = param.annotation - # TODO: Convert type_ to _mgp.CypherType if type_ is param.empty: type_ = object + cypher_type = _typing_to_cypher_type(type_) if param.default is param.empty: - mgp_proc.add_arg(name, type_) + mgp_proc.add_arg(name, cypher_type) else: - mgp_proc.add_opt_arg(name, type_, param.default) + mgp_proc.add_opt_arg(name, cypher_type, param.default) if sig.return_annotation is not sig.empty: record = sig.return_annotation if not isinstance(record, Record): raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'" .format(func.__name__, type(record))) for name, type_ in record.fields.items(): - # TODO: Convert type_ to _mgp.CypherType if isinstance(type_, Deprecated): - field_type = type_.field_type - mgp_proc.add_deprecated_result(name, field_type) + cypher_type = _typing_to_cypher_type(type_.field_type) + mgp_proc.add_deprecated_result(name, cypher_type) else: - mgp_proc.add_result(name, type_) + mgp_proc.add_result(name, _typing_to_cypher_type(type_)) return func diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 53fb05874..1bdf799da 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -158,6 +158,28 @@ PyObject *MakePyGraph(const mgp_graph *graph, mgp_memory *memory) { return PyObject_Init(reinterpret_cast(py_graph), &PyGraphType); } +struct PyCypherType { + PyObject_HEAD + const mgp_type *type; +}; + +static PyTypeObject PyCypherTypeType = { + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Type", + .tp_doc = "Wraps struct mgp_type.", + .tp_basicsize = sizeof(PyCypherType), + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_new = PyType_GenericNew, +}; + +PyObject *MakePyCypherType(const mgp_type *type) { + auto *py_type = PyObject_New(PyCypherType, &PyCypherTypeType); + if (!py_type) return nullptr; + py_type->type = type; + return PyObject_Init(reinterpret_cast(py_type), + &PyCypherTypeType); +} + struct PyQueryProc { PyObject_HEAD mgp_proc *proc; @@ -168,8 +190,11 @@ PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) { const char *name = nullptr; PyObject *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr; - // TODO: Convert Python type to mgp_type - const auto *type = mgp_type_nullable(mgp_type_any()); + if (Py_TYPE(py_type) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + const auto *type = reinterpret_cast(py_type)->type; if (!mgp_proc_add_arg(self->proc, name, type)) { PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_arg."); return nullptr; @@ -184,8 +209,11 @@ PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) { PyObject *py_value = nullptr; if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value)) return nullptr; - // TODO: Convert Python type to mgp_type - const auto *type = mgp_type_nullable(mgp_type_any()); + if (Py_TYPE(py_type) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + const auto *type = reinterpret_cast(py_type)->type; mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()}; mgp_value *value; try { @@ -218,8 +246,11 @@ PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) { const char *name = nullptr; PyObject *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr; - // TODO: Convert Python type to mgp_type - const auto *type = mgp_type_nullable(mgp_type_any()); + if (Py_TYPE(py_type) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + const auto *type = reinterpret_cast(py_type)->type; if (!mgp_proc_add_result(self->proc, name, type)) { PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_result."); return nullptr; @@ -232,8 +263,11 @@ PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) { const char *name = nullptr; PyObject *py_type = nullptr; if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr; - // TODO: Convert Python type to mgp_type - const auto *type = mgp_type_nullable(mgp_type_any()); + if (Py_TYPE(py_type) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + const auto *type = reinterpret_cast(py_type)->type; if (!mgp_proc_add_deprecated_result(self->proc, name, type)) { PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_deprecated_result."); @@ -329,11 +363,100 @@ PyObject *MakePyQueryModule(mgp_module *module) { &PyQueryModuleType); } +PyObject *PyMgpModuleTypeNullable(PyObject *mod, PyObject *obj) { + if (Py_TYPE(obj) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + auto *py_type = reinterpret_cast(obj); + return MakePyCypherType(mgp_type_nullable(py_type->type)); +} + +PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) { + if (Py_TYPE(obj) != &PyCypherTypeType) { + PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type."); + return nullptr; + } + auto *py_type = reinterpret_cast(obj); + return MakePyCypherType(mgp_type_list(py_type->type)); +} + +PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_any()); +} + +PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_bool()); +} + +PyObject *PyMgpModuleTypeString(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_string()); +} + +PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_int()); +} + +PyObject *PyMgpModuleTypeFloat(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_float()); +} + +PyObject *PyMgpModuleTypeNumber(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_number()); +} + +PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_map()); +} + +PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_node()); +} + +PyObject *PyMgpModuleTypeRelationship(PyObject *mod, + PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_relationship()); +} + +PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) { + return MakePyCypherType(mgp_type_path()); +} + +static PyMethodDef PyMgpModuleMethods[] = { + {"type_nullable", PyMgpModuleTypeNullable, METH_O, + "Build a type representing either a `null` value or a value of given " + "type."}, + {"type_list", PyMgpModuleTypeList, METH_O, + "Build a type representing a list of values of given type."}, + {"type_any", PyMgpModuleTypeAny, METH_NOARGS, + "Get the type representing any value that isn't `null`."}, + {"type_bool", PyMgpModuleTypeBool, METH_NOARGS, + "Get the type representing boolean values."}, + {"type_string", PyMgpModuleTypeString, METH_NOARGS, + "Get the type representing string values."}, + {"type_int", PyMgpModuleTypeInt, METH_NOARGS, + "Get the type representing integer values."}, + {"type_float", PyMgpModuleTypeFloat, METH_NOARGS, + "Get the type representing floating-point values."}, + {"type_number", PyMgpModuleTypeNumber, METH_NOARGS, + "Get the type representing any number value."}, + {"type_map", PyMgpModuleTypeMap, METH_NOARGS, + "Get the type representing map values."}, + {"type_node", PyMgpModuleTypeNode, METH_NOARGS, + "Get the type representing graph node values."}, + {"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS, + "Get the type representing graph relationship values."}, + {"type_path", PyMgpModuleTypePath, METH_NOARGS, + "Get the type representing a graph path (walk) from one node to another."}, + {nullptr}, +}; + static PyModuleDef PyMgpModule = { PyModuleDef_HEAD_INIT, .m_name = "_mgp", .m_doc = "Contains raw bindings to mg_procedure.h C API.", .m_size = -1, + .m_methods = PyMgpModuleMethods, }; struct PyEdge { @@ -403,7 +526,8 @@ static PyMethodDef PyEdgeMethods[] = { PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op); static PyTypeObject PyEdgeType = { - PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Edge", + PyVarObject_HEAD_INIT(nullptr, 0) + .tp_name = "_mgp.Edge", .tp_doc = "Wraps struct mgp_edge.", .tp_basicsize = sizeof(PyEdge), .tp_flags = Py_TPFLAGS_DEFAULT, @@ -577,6 +701,7 @@ PyObject *PyInitMgpModule() { if (!register_type(&PyQueryProcType, "Proc")) return nullptr; if (!register_type(&PyQueryModuleType, "Module")) return nullptr; if (!register_type(&PyVertexType, "Vertex")) return nullptr; + if (!register_type(&PyCypherTypeType, "Type")) return nullptr; Py_INCREF(Py_None); if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) { Py_DECREF(Py_None); @@ -591,6 +716,7 @@ namespace { template auto WithMgpModule(mgp_module *module_def, const TFun &fun) { py::Object py_mgp(PyImport_ImportModule("_mgp")); + CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import"; py::Object py_mgp_module(PyObject_GetAttrString(py_mgp, "_MODULE")); CHECK(py_mgp_module) << "Expected '_mgp' to have attribute '_MODULE'"; // NOTE: This check is not thread safe, but this should only go through @@ -599,7 +725,6 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) { << "Expected '_mgp._MODULE' to be None as we are just starting to " "import a new module. Is some other thread also importing Python " "modules?"; - CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import"; auto *py_query_module = MakePyQueryModule(module_def); CHECK(py_query_module); CHECK(0 <= PyObject_SetAttrString(py_mgp, "_MODULE", py_query_module));