Support typing annotations in Python Query Modules
Reviewers: mferencevic, ipaljak, tlastre Reviewed By: mferencevic Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2689
This commit is contained in:
parent
32e56684db
commit
ad892f2db3
126
include/mgp.py
126
include/mgp.py
@ -28,11 +28,11 @@ class Label:
|
||||
__slots__ = ('_name',)
|
||||
|
||||
def __init__(self, name):
|
||||
self._name = name;
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name;
|
||||
return self._name
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if isinstance(other, Label):
|
||||
@ -41,6 +41,7 @@ class Label:
|
||||
return self._name == other
|
||||
return NotImplemented
|
||||
|
||||
|
||||
# Named property value of a Vertex or an Edge.
|
||||
# It would be better to use typing.NamedTuple with typed fields, but that is
|
||||
# not available in Python 3.5.
|
||||
@ -169,7 +170,10 @@ class Edge:
|
||||
return self._edge == other._edge
|
||||
|
||||
|
||||
VertexId = typing.NewType('VertexId', int)
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
VertexId = typing.NewType('VertexId', int)
|
||||
else:
|
||||
VertexId = int
|
||||
|
||||
|
||||
class InvalidVertexError(Exception):
|
||||
@ -371,15 +375,110 @@ class ProcCtx:
|
||||
|
||||
Number = typing.Union[int, float]
|
||||
|
||||
List = typing.List
|
||||
|
||||
Map = typing.Union[dict, Edge, Vertex]
|
||||
|
||||
Any = typing.Union[bool, str, Number, Map, Path, list]
|
||||
|
||||
List = typing.List
|
||||
|
||||
Nullable = typing.Optional
|
||||
|
||||
|
||||
class UnsupportedTypingError(Exception):
|
||||
'''Signals a typing annotation is not supported as a _mgp.CypherType.'''
|
||||
|
||||
def __init__(self, type_):
|
||||
super().__init__("Unsupported typing annotation '{}'".format(type_))
|
||||
|
||||
|
||||
def _typing_to_cypher_type(type_):
|
||||
'''Convert typing annotation to a _mgp.CypherType instance.'''
|
||||
simple_types = {
|
||||
typing.Any: _mgp.type_nullable(_mgp.type_any()),
|
||||
object: _mgp.type_nullable(_mgp.type_any()),
|
||||
list: _mgp.type_list(_mgp.type_nullable(_mgp.type_any())),
|
||||
Any: _mgp.type_any(),
|
||||
bool: _mgp.type_bool(),
|
||||
str: _mgp.type_string(),
|
||||
int: _mgp.type_int(),
|
||||
float: _mgp.type_float(),
|
||||
Number: _mgp.type_number(),
|
||||
Map: _mgp.type_map(),
|
||||
Vertex: _mgp.type_node(),
|
||||
Edge: _mgp.type_relationship(),
|
||||
Path: _mgp.type_path()
|
||||
}
|
||||
try:
|
||||
return simple_types[type_]
|
||||
except KeyError:
|
||||
pass
|
||||
if sys.version_info >= (3, 8):
|
||||
complex_type = typing.get_origin(type_)
|
||||
type_args = typing.get_args(type_)
|
||||
if complex_type == typing.Union:
|
||||
# If we have a Union with NoneType inside, it means we are building
|
||||
# a nullable type.
|
||||
if isinstance(None, type_args):
|
||||
types = tuple(t for t in type_args if not isinstance(None, t))
|
||||
if len(types) == 1:
|
||||
type_arg, = types
|
||||
else:
|
||||
# We cannot do typing.Union[*types], so do the equivalent
|
||||
# with __getitem__ which does not even need arg unpacking.
|
||||
type_arg = typing.Union.__getitem__(types)
|
||||
return _mgp.type_nullable(_typing_to_cypher_type(type_arg))
|
||||
elif complex_type == list:
|
||||
type_arg, = type_args
|
||||
return _mgp.type_list(_typing_to_cypher_type(type_arg))
|
||||
raise UnsupportedTypingError(type_)
|
||||
else:
|
||||
# We cannot get to type args in any reliable way prior to 3.8, but we
|
||||
# still want to support typing.Optional and typing.List, so just parse
|
||||
# their string representations. Hopefully, that is always pretty
|
||||
# printed the same way. `typing.List[type]` is printed as such, while
|
||||
# `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]'
|
||||
def parse_type_args(type_as_str):
|
||||
return tuple(map(str.strip,
|
||||
type_as_str[type_as_str.index('[') + 1: -1].split(',')))
|
||||
|
||||
def get_simple_type(type_as_str):
|
||||
for simple_type, cypher_type in simple_types.items():
|
||||
if type_as_str == str(simple_type):
|
||||
return cypher_type
|
||||
# Fallback to comparing to __name__ if it exits. This handles
|
||||
# the cases like when we have 'object' which is
|
||||
# `object.__name__`, but `str(object)` is "<class 'object'>"
|
||||
try:
|
||||
if type_as_str == simple_type.__name__:
|
||||
return cypher_type
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def parse_typing(type_as_str):
|
||||
if type_as_str.startswith('typing.Union'):
|
||||
type_args_as_str = parse_type_args(type_as_str)
|
||||
none_type_as_str = type(None).__name__
|
||||
if none_type_as_str in type_args_as_str:
|
||||
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
|
||||
if len(types) == 1:
|
||||
type_arg_as_str, = types
|
||||
else:
|
||||
type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']'
|
||||
simple_type = get_simple_type(type_arg_as_str)
|
||||
if simple_type is not None:
|
||||
return _mgp.type_nullable(simple_type)
|
||||
return _mgp.type_nullable(parse_typing(type_arg_as_str))
|
||||
elif type_as_str.startswith('typing.List'):
|
||||
type_arg_as_str, = parse_type_args(type_as_str)
|
||||
simple_type = get_simple_type(type_arg_as_str)
|
||||
if simple_type is not None:
|
||||
return _mgp.type_list(simple_type)
|
||||
return _mgp.type_list(parse_typing(type_arg_as_str))
|
||||
raise UnsupportedTypingError(type_)
|
||||
|
||||
return parse_typing(str(type_))
|
||||
|
||||
|
||||
# Procedure registration
|
||||
|
||||
class Deprecated:
|
||||
@ -387,8 +486,6 @@ class Deprecated:
|
||||
__slots__ = ('field_type',)
|
||||
|
||||
def __init__(self, type_):
|
||||
if not isinstance(type_, type):
|
||||
raise TypeError("Expected 'type', got '{}'".format(type_))
|
||||
self.field_type = type_
|
||||
|
||||
|
||||
@ -437,7 +534,7 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
.format(type(func)))
|
||||
if inspect.iscoroutinefunction(func):
|
||||
raise TypeError("Callable must not be 'async def' function")
|
||||
if sys.version_info.minor >= 6:
|
||||
if sys.version_info >= (3, 6):
|
||||
if inspect.isasyncgenfunction(func):
|
||||
raise TypeError("Callable must not be 'async def' function")
|
||||
if inspect.isgeneratorfunction(func):
|
||||
@ -456,23 +553,22 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
for param in params:
|
||||
name = param.name
|
||||
type_ = param.annotation
|
||||
# TODO: Convert type_ to _mgp.CypherType
|
||||
if type_ is param.empty:
|
||||
type_ = object
|
||||
cypher_type = _typing_to_cypher_type(type_)
|
||||
if param.default is param.empty:
|
||||
mgp_proc.add_arg(name, type_)
|
||||
mgp_proc.add_arg(name, cypher_type)
|
||||
else:
|
||||
mgp_proc.add_opt_arg(name, type_, param.default)
|
||||
mgp_proc.add_opt_arg(name, cypher_type, param.default)
|
||||
if sig.return_annotation is not sig.empty:
|
||||
record = sig.return_annotation
|
||||
if not isinstance(record, Record):
|
||||
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'"
|
||||
.format(func.__name__, type(record)))
|
||||
for name, type_ in record.fields.items():
|
||||
# TODO: Convert type_ to _mgp.CypherType
|
||||
if isinstance(type_, Deprecated):
|
||||
field_type = type_.field_type
|
||||
mgp_proc.add_deprecated_result(name, field_type)
|
||||
cypher_type = _typing_to_cypher_type(type_.field_type)
|
||||
mgp_proc.add_deprecated_result(name, cypher_type)
|
||||
else:
|
||||
mgp_proc.add_result(name, type_)
|
||||
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
||||
return func
|
||||
|
@ -158,6 +158,28 @@ PyObject *MakePyGraph(const mgp_graph *graph, mgp_memory *memory) {
|
||||
return PyObject_Init(reinterpret_cast<PyObject *>(py_graph), &PyGraphType);
|
||||
}
|
||||
|
||||
struct PyCypherType {
|
||||
PyObject_HEAD
|
||||
const mgp_type *type;
|
||||
};
|
||||
|
||||
static PyTypeObject PyCypherTypeType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
.tp_name = "_mgp.Type",
|
||||
.tp_doc = "Wraps struct mgp_type.",
|
||||
.tp_basicsize = sizeof(PyCypherType),
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_new = PyType_GenericNew,
|
||||
};
|
||||
|
||||
PyObject *MakePyCypherType(const mgp_type *type) {
|
||||
auto *py_type = PyObject_New(PyCypherType, &PyCypherTypeType);
|
||||
if (!py_type) return nullptr;
|
||||
py_type->type = type;
|
||||
return PyObject_Init(reinterpret_cast<PyObject *>(py_type),
|
||||
&PyCypherTypeType);
|
||||
}
|
||||
|
||||
struct PyQueryProc {
|
||||
PyObject_HEAD
|
||||
mgp_proc *proc;
|
||||
@ -168,8 +190,11 @@ PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) {
|
||||
const char *name = nullptr;
|
||||
PyObject *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
|
||||
// TODO: Convert Python type to mgp_type
|
||||
const auto *type = mgp_type_nullable(mgp_type_any());
|
||||
if (Py_TYPE(py_type) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (!mgp_proc_add_arg(self->proc, name, type)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_arg.");
|
||||
return nullptr;
|
||||
@ -184,8 +209,11 @@ PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
|
||||
PyObject *py_value = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value))
|
||||
return nullptr;
|
||||
// TODO: Convert Python type to mgp_type
|
||||
const auto *type = mgp_type_nullable(mgp_type_any());
|
||||
if (Py_TYPE(py_type) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()};
|
||||
mgp_value *value;
|
||||
try {
|
||||
@ -218,8 +246,11 @@ PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
|
||||
const char *name = nullptr;
|
||||
PyObject *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
|
||||
// TODO: Convert Python type to mgp_type
|
||||
const auto *type = mgp_type_nullable(mgp_type_any());
|
||||
if (Py_TYPE(py_type) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (!mgp_proc_add_result(self->proc, name, type)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_result.");
|
||||
return nullptr;
|
||||
@ -232,8 +263,11 @@ PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
|
||||
const char *name = nullptr;
|
||||
PyObject *py_type = nullptr;
|
||||
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
|
||||
// TODO: Convert Python type to mgp_type
|
||||
const auto *type = mgp_type_nullable(mgp_type_any());
|
||||
if (Py_TYPE(py_type) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
|
||||
if (!mgp_proc_add_deprecated_result(self->proc, name, type)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Invalid call to mgp_proc_add_deprecated_result.");
|
||||
@ -329,11 +363,100 @@ PyObject *MakePyQueryModule(mgp_module *module) {
|
||||
&PyQueryModuleType);
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeNullable(PyObject *mod, PyObject *obj) {
|
||||
if (Py_TYPE(obj) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_type = reinterpret_cast<PyCypherType *>(obj);
|
||||
return MakePyCypherType(mgp_type_nullable(py_type->type));
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) {
|
||||
if (Py_TYPE(obj) != &PyCypherTypeType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_type = reinterpret_cast<PyCypherType *>(obj);
|
||||
return MakePyCypherType(mgp_type_list(py_type->type));
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_any());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_bool());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeString(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_string());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_int());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeFloat(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_float());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeNumber(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_number());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_map());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_node());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypeRelationship(PyObject *mod,
|
||||
PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_relationship());
|
||||
}
|
||||
|
||||
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
|
||||
return MakePyCypherType(mgp_type_path());
|
||||
}
|
||||
|
||||
static PyMethodDef PyMgpModuleMethods[] = {
|
||||
{"type_nullable", PyMgpModuleTypeNullable, METH_O,
|
||||
"Build a type representing either a `null` value or a value of given "
|
||||
"type."},
|
||||
{"type_list", PyMgpModuleTypeList, METH_O,
|
||||
"Build a type representing a list of values of given type."},
|
||||
{"type_any", PyMgpModuleTypeAny, METH_NOARGS,
|
||||
"Get the type representing any value that isn't `null`."},
|
||||
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS,
|
||||
"Get the type representing boolean values."},
|
||||
{"type_string", PyMgpModuleTypeString, METH_NOARGS,
|
||||
"Get the type representing string values."},
|
||||
{"type_int", PyMgpModuleTypeInt, METH_NOARGS,
|
||||
"Get the type representing integer values."},
|
||||
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS,
|
||||
"Get the type representing floating-point values."},
|
||||
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS,
|
||||
"Get the type representing any number value."},
|
||||
{"type_map", PyMgpModuleTypeMap, METH_NOARGS,
|
||||
"Get the type representing map values."},
|
||||
{"type_node", PyMgpModuleTypeNode, METH_NOARGS,
|
||||
"Get the type representing graph node values."},
|
||||
{"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS,
|
||||
"Get the type representing graph relationship values."},
|
||||
{"type_path", PyMgpModuleTypePath, METH_NOARGS,
|
||||
"Get the type representing a graph path (walk) from one node to another."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
static PyModuleDef PyMgpModule = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
.m_name = "_mgp",
|
||||
.m_doc = "Contains raw bindings to mg_procedure.h C API.",
|
||||
.m_size = -1,
|
||||
.m_methods = PyMgpModuleMethods,
|
||||
};
|
||||
|
||||
struct PyEdge {
|
||||
@ -403,7 +526,8 @@ static PyMethodDef PyEdgeMethods[] = {
|
||||
PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op);
|
||||
|
||||
static PyTypeObject PyEdgeType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Edge",
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
.tp_name = "_mgp.Edge",
|
||||
.tp_doc = "Wraps struct mgp_edge.",
|
||||
.tp_basicsize = sizeof(PyEdge),
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
@ -577,6 +701,7 @@ PyObject *PyInitMgpModule() {
|
||||
if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
|
||||
if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
|
||||
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
|
||||
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
|
||||
Py_INCREF(Py_None);
|
||||
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
|
||||
Py_DECREF(Py_None);
|
||||
@ -591,6 +716,7 @@ namespace {
|
||||
template <class TFun>
|
||||
auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
|
||||
py::Object py_mgp(PyImport_ImportModule("_mgp"));
|
||||
CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import";
|
||||
py::Object py_mgp_module(PyObject_GetAttrString(py_mgp, "_MODULE"));
|
||||
CHECK(py_mgp_module) << "Expected '_mgp' to have attribute '_MODULE'";
|
||||
// NOTE: This check is not thread safe, but this should only go through
|
||||
@ -599,7 +725,6 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
|
||||
<< "Expected '_mgp._MODULE' to be None as we are just starting to "
|
||||
"import a new module. Is some other thread also importing Python "
|
||||
"modules?";
|
||||
CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import";
|
||||
auto *py_query_module = MakePyQueryModule(module_def);
|
||||
CHECK(py_query_module);
|
||||
CHECK(0 <= PyObject_SetAttrString(py_mgp, "_MODULE", py_query_module));
|
||||
|
Loading…
Reference in New Issue
Block a user