Support typing annotations in Python Query Modules

Reviewers: mferencevic, ipaljak, tlastre

Reviewed By: mferencevic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2689
This commit is contained in:
Teon Banek 2020-02-27 12:24:26 +01:00
parent 32e56684db
commit ad892f2db3
2 changed files with 246 additions and 25 deletions

View File

@ -28,11 +28,11 @@ class Label:
__slots__ = ('_name',)
def __init__(self, name):
self._name = name;
self._name = name
@property
def name(self) -> str:
return self._name;
return self._name
def __eq__(self, other) -> bool:
if isinstance(other, Label):
@ -41,6 +41,7 @@ class Label:
return self._name == other
return NotImplemented
# Named property value of a Vertex or an Edge.
# It would be better to use typing.NamedTuple with typed fields, but that is
# not available in Python 3.5.
@ -169,7 +170,10 @@ class Edge:
return self._edge == other._edge
VertexId = typing.NewType('VertexId', int)
if sys.version_info >= (3, 5, 2):
VertexId = typing.NewType('VertexId', int)
else:
VertexId = int
class InvalidVertexError(Exception):
@ -371,15 +375,110 @@ class ProcCtx:
Number = typing.Union[int, float]
List = typing.List
Map = typing.Union[dict, Edge, Vertex]
Any = typing.Union[bool, str, Number, Map, Path, list]
List = typing.List
Nullable = typing.Optional
class UnsupportedTypingError(Exception):
'''Signals a typing annotation is not supported as a _mgp.CypherType.'''
def __init__(self, type_):
super().__init__("Unsupported typing annotation '{}'".format(type_))
def _typing_to_cypher_type(type_):
'''Convert typing annotation to a _mgp.CypherType instance.'''
simple_types = {
typing.Any: _mgp.type_nullable(_mgp.type_any()),
object: _mgp.type_nullable(_mgp.type_any()),
list: _mgp.type_list(_mgp.type_nullable(_mgp.type_any())),
Any: _mgp.type_any(),
bool: _mgp.type_bool(),
str: _mgp.type_string(),
int: _mgp.type_int(),
float: _mgp.type_float(),
Number: _mgp.type_number(),
Map: _mgp.type_map(),
Vertex: _mgp.type_node(),
Edge: _mgp.type_relationship(),
Path: _mgp.type_path()
}
try:
return simple_types[type_]
except KeyError:
pass
if sys.version_info >= (3, 8):
complex_type = typing.get_origin(type_)
type_args = typing.get_args(type_)
if complex_type == typing.Union:
# If we have a Union with NoneType inside, it means we are building
# a nullable type.
if isinstance(None, type_args):
types = tuple(t for t in type_args if not isinstance(None, t))
if len(types) == 1:
type_arg, = types
else:
# We cannot do typing.Union[*types], so do the equivalent
# with __getitem__ which does not even need arg unpacking.
type_arg = typing.Union.__getitem__(types)
return _mgp.type_nullable(_typing_to_cypher_type(type_arg))
elif complex_type == list:
type_arg, = type_args
return _mgp.type_list(_typing_to_cypher_type(type_arg))
raise UnsupportedTypingError(type_)
else:
# We cannot get to type args in any reliable way prior to 3.8, but we
# still want to support typing.Optional and typing.List, so just parse
# their string representations. Hopefully, that is always pretty
# printed the same way. `typing.List[type]` is printed as such, while
# `typing.Optional[type]` is printed as 'typing.Union[type, NoneType]'
def parse_type_args(type_as_str):
return tuple(map(str.strip,
type_as_str[type_as_str.index('[') + 1: -1].split(',')))
def get_simple_type(type_as_str):
for simple_type, cypher_type in simple_types.items():
if type_as_str == str(simple_type):
return cypher_type
# Fallback to comparing to __name__ if it exits. This handles
# the cases like when we have 'object' which is
# `object.__name__`, but `str(object)` is "<class 'object'>"
try:
if type_as_str == simple_type.__name__:
return cypher_type
except AttributeError:
pass
def parse_typing(type_as_str):
if type_as_str.startswith('typing.Union'):
type_args_as_str = parse_type_args(type_as_str)
none_type_as_str = type(None).__name__
if none_type_as_str in type_args_as_str:
types = tuple(t for t in type_args_as_str if t != none_type_as_str)
if len(types) == 1:
type_arg_as_str, = types
else:
type_arg_as_str = 'typing.Union[' + ', '.join(types) + ']'
simple_type = get_simple_type(type_arg_as_str)
if simple_type is not None:
return _mgp.type_nullable(simple_type)
return _mgp.type_nullable(parse_typing(type_arg_as_str))
elif type_as_str.startswith('typing.List'):
type_arg_as_str, = parse_type_args(type_as_str)
simple_type = get_simple_type(type_arg_as_str)
if simple_type is not None:
return _mgp.type_list(simple_type)
return _mgp.type_list(parse_typing(type_arg_as_str))
raise UnsupportedTypingError(type_)
return parse_typing(str(type_))
# Procedure registration
class Deprecated:
@ -387,8 +486,6 @@ class Deprecated:
__slots__ = ('field_type',)
def __init__(self, type_):
if not isinstance(type_, type):
raise TypeError("Expected 'type', got '{}'".format(type_))
self.field_type = type_
@ -437,7 +534,7 @@ def read_proc(func: typing.Callable[..., Record]):
.format(type(func)))
if inspect.iscoroutinefunction(func):
raise TypeError("Callable must not be 'async def' function")
if sys.version_info.minor >= 6:
if sys.version_info >= (3, 6):
if inspect.isasyncgenfunction(func):
raise TypeError("Callable must not be 'async def' function")
if inspect.isgeneratorfunction(func):
@ -456,23 +553,22 @@ def read_proc(func: typing.Callable[..., Record]):
for param in params:
name = param.name
type_ = param.annotation
# TODO: Convert type_ to _mgp.CypherType
if type_ is param.empty:
type_ = object
cypher_type = _typing_to_cypher_type(type_)
if param.default is param.empty:
mgp_proc.add_arg(name, type_)
mgp_proc.add_arg(name, cypher_type)
else:
mgp_proc.add_opt_arg(name, type_, param.default)
mgp_proc.add_opt_arg(name, cypher_type, param.default)
if sig.return_annotation is not sig.empty:
record = sig.return_annotation
if not isinstance(record, Record):
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'"
.format(func.__name__, type(record)))
for name, type_ in record.fields.items():
# TODO: Convert type_ to _mgp.CypherType
if isinstance(type_, Deprecated):
field_type = type_.field_type
mgp_proc.add_deprecated_result(name, field_type)
cypher_type = _typing_to_cypher_type(type_.field_type)
mgp_proc.add_deprecated_result(name, cypher_type)
else:
mgp_proc.add_result(name, type_)
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
return func

View File

@ -158,6 +158,28 @@ PyObject *MakePyGraph(const mgp_graph *graph, mgp_memory *memory) {
return PyObject_Init(reinterpret_cast<PyObject *>(py_graph), &PyGraphType);
}
struct PyCypherType {
PyObject_HEAD
const mgp_type *type;
};
static PyTypeObject PyCypherTypeType = {
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Type",
.tp_doc = "Wraps struct mgp_type.",
.tp_basicsize = sizeof(PyCypherType),
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_new = PyType_GenericNew,
};
PyObject *MakePyCypherType(const mgp_type *type) {
auto *py_type = PyObject_New(PyCypherType, &PyCypherTypeType);
if (!py_type) return nullptr;
py_type->type = type;
return PyObject_Init(reinterpret_cast<PyObject *>(py_type),
&PyCypherTypeType);
}
struct PyQueryProc {
PyObject_HEAD
mgp_proc *proc;
@ -168,8 +190,11 @@ PyObject *PyQueryProcAddArg(PyQueryProc *self, PyObject *args) {
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (Py_TYPE(py_type) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (!mgp_proc_add_arg(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_arg.");
return nullptr;
@ -184,8 +209,11 @@ PyObject *PyQueryProcAddOptArg(PyQueryProc *self, PyObject *args) {
PyObject *py_value = nullptr;
if (!PyArg_ParseTuple(args, "sOO", &name, &py_type, &py_value))
return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (Py_TYPE(py_type) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
mgp_memory memory{self->proc->opt_args.get_allocator().GetMemoryResource()};
mgp_value *value;
try {
@ -218,8 +246,11 @@ PyObject *PyQueryProcAddResult(PyQueryProc *self, PyObject *args) {
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (Py_TYPE(py_type) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (!mgp_proc_add_result(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError, "Invalid call to mgp_proc_add_result.");
return nullptr;
@ -232,8 +263,11 @@ PyObject *PyQueryProcAddDeprecatedResult(PyQueryProc *self, PyObject *args) {
const char *name = nullptr;
PyObject *py_type = nullptr;
if (!PyArg_ParseTuple(args, "sO", &name, &py_type)) return nullptr;
// TODO: Convert Python type to mgp_type
const auto *type = mgp_type_nullable(mgp_type_any());
if (Py_TYPE(py_type) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
const auto *type = reinterpret_cast<PyCypherType *>(py_type)->type;
if (!mgp_proc_add_deprecated_result(self->proc, name, type)) {
PyErr_SetString(PyExc_ValueError,
"Invalid call to mgp_proc_add_deprecated_result.");
@ -329,11 +363,100 @@ PyObject *MakePyQueryModule(mgp_module *module) {
&PyQueryModuleType);
}
PyObject *PyMgpModuleTypeNullable(PyObject *mod, PyObject *obj) {
if (Py_TYPE(obj) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
auto *py_type = reinterpret_cast<PyCypherType *>(obj);
return MakePyCypherType(mgp_type_nullable(py_type->type));
}
PyObject *PyMgpModuleTypeList(PyObject *mod, PyObject *obj) {
if (Py_TYPE(obj) != &PyCypherTypeType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Type.");
return nullptr;
}
auto *py_type = reinterpret_cast<PyCypherType *>(obj);
return MakePyCypherType(mgp_type_list(py_type->type));
}
PyObject *PyMgpModuleTypeAny(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_any());
}
PyObject *PyMgpModuleTypeBool(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_bool());
}
PyObject *PyMgpModuleTypeString(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_string());
}
PyObject *PyMgpModuleTypeInt(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_int());
}
PyObject *PyMgpModuleTypeFloat(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_float());
}
PyObject *PyMgpModuleTypeNumber(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_number());
}
PyObject *PyMgpModuleTypeMap(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_map());
}
PyObject *PyMgpModuleTypeNode(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_node());
}
PyObject *PyMgpModuleTypeRelationship(PyObject *mod,
PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_relationship());
}
PyObject *PyMgpModuleTypePath(PyObject *mod, PyObject *Py_UNUSED(ignored)) {
return MakePyCypherType(mgp_type_path());
}
static PyMethodDef PyMgpModuleMethods[] = {
{"type_nullable", PyMgpModuleTypeNullable, METH_O,
"Build a type representing either a `null` value or a value of given "
"type."},
{"type_list", PyMgpModuleTypeList, METH_O,
"Build a type representing a list of values of given type."},
{"type_any", PyMgpModuleTypeAny, METH_NOARGS,
"Get the type representing any value that isn't `null`."},
{"type_bool", PyMgpModuleTypeBool, METH_NOARGS,
"Get the type representing boolean values."},
{"type_string", PyMgpModuleTypeString, METH_NOARGS,
"Get the type representing string values."},
{"type_int", PyMgpModuleTypeInt, METH_NOARGS,
"Get the type representing integer values."},
{"type_float", PyMgpModuleTypeFloat, METH_NOARGS,
"Get the type representing floating-point values."},
{"type_number", PyMgpModuleTypeNumber, METH_NOARGS,
"Get the type representing any number value."},
{"type_map", PyMgpModuleTypeMap, METH_NOARGS,
"Get the type representing map values."},
{"type_node", PyMgpModuleTypeNode, METH_NOARGS,
"Get the type representing graph node values."},
{"type_relationship", PyMgpModuleTypeRelationship, METH_NOARGS,
"Get the type representing graph relationship values."},
{"type_path", PyMgpModuleTypePath, METH_NOARGS,
"Get the type representing a graph path (walk) from one node to another."},
{nullptr},
};
static PyModuleDef PyMgpModule = {
PyModuleDef_HEAD_INIT,
.m_name = "_mgp",
.m_doc = "Contains raw bindings to mg_procedure.h C API.",
.m_size = -1,
.m_methods = PyMgpModuleMethods,
};
struct PyEdge {
@ -403,7 +526,8 @@ static PyMethodDef PyEdgeMethods[] = {
PyObject *PyEdgeRichCompare(PyObject *self, PyObject *other, int op);
static PyTypeObject PyEdgeType = {
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Edge",
PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "_mgp.Edge",
.tp_doc = "Wraps struct mgp_edge.",
.tp_basicsize = sizeof(PyEdge),
.tp_flags = Py_TPFLAGS_DEFAULT,
@ -577,6 +701,7 @@ PyObject *PyInitMgpModule() {
if (!register_type(&PyQueryProcType, "Proc")) return nullptr;
if (!register_type(&PyQueryModuleType, "Module")) return nullptr;
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
Py_INCREF(Py_None);
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
Py_DECREF(Py_None);
@ -591,6 +716,7 @@ namespace {
template <class TFun>
auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
py::Object py_mgp(PyImport_ImportModule("_mgp"));
CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import";
py::Object py_mgp_module(PyObject_GetAttrString(py_mgp, "_MODULE"));
CHECK(py_mgp_module) << "Expected '_mgp' to have attribute '_MODULE'";
// NOTE: This check is not thread safe, but this should only go through
@ -599,7 +725,6 @@ auto WithMgpModule(mgp_module *module_def, const TFun &fun) {
<< "Expected '_mgp._MODULE' to be None as we are just starting to "
"import a new module. Is some other thread also importing Python "
"modules?";
CHECK(py_mgp) << "Expected builtin '_mgp' to be available for import";
auto *py_query_module = MakePyQueryModule(module_def);
CHECK(py_query_module);
CHECK(0 <= PyObject_SetAttrString(py_mgp, "_MODULE", py_query_module));