Implement calling Python procedures

Summary:
You should now be able to invoke query procedures written in Python. To
test the example you can run memgraph with PYTHONPATH set to `include`.
For example, assuming you are in the root of the repo, run this command.

    PYTHONPATH=$PWD/include ./build/memgraph --query-modules-directory=./query_modules/

Alternatively, you can set a symlink inside the ./query_modules to point
to `include/mgp.py`, so there's no need to set PYTHONPATH. For example,
assuming you are in the root of the repo, run the following.

    cd ./query_modules
    ln -s ../include/mgp.py
    cd ..
    ./build/memgraph --query-modules-directory=./query_modules/

Depends on D207

Reviewers: mferencevic, ipaljak, dsantl

Reviewed By: ipaljak

Subscribers: buda, tlastre, pullbot

Differential Revision: https://phabricator.memgraph.io/D2708
This commit is contained in:
Teon Banek 2020-03-06 15:31:31 +01:00
parent 6f83fff171
commit fbdcad1106
5 changed files with 197 additions and 18 deletions

View File

@ -456,11 +456,22 @@ class ProcCtx:
Access to a ProcCtx is only valid during a single execution of a procedure
in a query. You should not globally store a ProcCtx instance.
'''
__slots__ = ('_graph',)
def __init__(self, graph):
if not isinstance(graph, _mgp.Graph):
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
self._graph = Graph(graph)
def is_valid(self) -> bool:
return self._graph.is_valid()
@property
def graph(self) -> Graph:
'''Raise InvalidContextError if context is invalid.'''
pass
if not self.is_valid():
raise InvalidContextError()
return self._graph
# Additional typing support
@ -634,13 +645,15 @@ def read_proc(func: typing.Callable[..., Record]):
sig = inspect.signature(func)
params = tuple(sig.parameters.values())
if params and params[0].annotation is ProcCtx:
@functools.wraps(func)
def wrapper(graph, args):
return func(ProcCtx(graph), *args)
params = params[1:]
mgp_proc = _mgp._MODULE.add_read_procedure(func)
mgp_proc = _mgp._MODULE.add_read_procedure(wrapper)
else:
@functools.wraps(func)
def wrapper(*args):
args_without_context = args[1:]
return func(*args_without_context)
def wrapper(graph, args):
return func(*args)
mgp_proc = _mgp._MODULE.add_read_procedure(wrapper)
for param in params:
name = param.name

View File

@ -44,8 +44,15 @@ class [[nodiscard]] Object final {
public:
Object() = default;
Object(std::nullptr_t) {}
/// Construct by taking the ownership of `PyObject *`.
explicit Object(PyObject *ptr) noexcept : ptr_(ptr) {}
/// Construct from a borrowed `PyObject *`, i.e. non-owned pointer.
static Object FromBorrow(PyObject *ptr) noexcept {
Py_XINCREF(ptr);
return Object(ptr);
}
~Object() noexcept { Py_XDECREF(ptr_); }
Object(const Object &other) noexcept : ptr_(other.ptr_) { Py_XINCREF(ptr_); }

View File

@ -1386,14 +1386,6 @@ const mgp_type *mgp_type_nullable(const mgp_type *type) {
}
}
namespace {
bool IsValidIdentifierName(const char *name) {
if (!name) return false;
std::regex regex("[_[:alpha:]][_[:alnum:]]*");
return std::regex_match(name, regex);
}
} // namespace
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name,
mgp_proc_cb cb) {
if (!module || !cb) return nullptr;
@ -1548,4 +1540,10 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
(*stream) << ")";
}
bool IsValidIdentifierName(const char *name) {
if (!name) return false;
std::regex regex("[_[:alpha:]][_[:alnum:]]*");
return std::regex_match(name, regex);
}
} // namespace query::procedure

View File

@ -573,4 +573,6 @@ namespace query::procedure {
/// @throw anything std::ostream::operator<< may throw.
void PrintProcSignature(const mgp_proc &, std::ostream *);
bool IsValidIdentifierName(const char *name);
} // namespace query::procedure

View File

@ -161,6 +161,12 @@ static PyTypeObject PyEdgesIteratorType = {
.tp_dealloc = reinterpret_cast<destructor>(PyEdgesIteratorDealloc),
};
PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
self->graph = nullptr;
self->memory = nullptr;
Py_RETURN_NONE;
}
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
return PyBool_FromLong(!!self->graph);
}
@ -208,6 +214,9 @@ PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
static PyMethodDef PyGraphMethods[] = {
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
METH_NOARGS, "__reduce__ is not supported"},
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate),
METH_NOARGS,
"Invalidate the Graph context thus preventing the Graph from being used."},
{"is_valid", reinterpret_cast<PyCFunction>(PyGraphIsValid), METH_NOARGS,
"Return True if Graph is in valid context and may be used."},
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById),
@ -403,23 +412,173 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
return py_tuple;
}
py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) {
if (Py_TYPE(py_graph) != &PyGraphType) {
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Graph.");
return nullptr;
}
return MgpListToPyTuple(list, reinterpret_cast<PyGraph *>(py_graph));
}
namespace {
void SetErrorFromPython(mgp_result *result, const py::ExceptionInfo &exc_info) {
std::stringstream ss;
ss << exc_info;
const auto &msg = ss.str();
mgp_result_set_error_msg(result, msg.c_str());
}
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
py::Object py_record) {
py::Object py_mgp(PyImport_ImportModule("mgp"));
if (!py_mgp) return py::FetchError();
auto record_cls = py_mgp.GetAttr("Record");
if (!record_cls) return py::FetchError();
if (!PyObject_IsInstance(py_record, record_cls)) {
std::stringstream ss;
ss << "Value '" << py_record << "' is not an instance of 'mgp.Record'";
const auto &msg = ss.str();
PyErr_SetString(PyExc_TypeError, msg.c_str());
return py::FetchError();
}
py::Object fields(py_record.GetAttr("fields"));
if (!fields) return py::FetchError();
if (!PyDict_Check(fields)) {
PyErr_SetString(PyExc_TypeError,
"Expected 'mgp.Record.fields' to be a 'dict'");
return py::FetchError();
}
py::Object items(PyDict_Items(fields));
if (!items) return py::FetchError();
auto *record = mgp_result_new_record(result);
if (!record) {
PyErr_NoMemory();
return py::FetchError();
}
Py_ssize_t len = PyList_GET_SIZE(static_cast<PyObject *>(items));
for (Py_ssize_t i = 0; i < len; ++i) {
auto *item = PyList_GET_ITEM(static_cast<PyObject *>(items), i);
if (!item) return py::FetchError();
CHECK(PyTuple_Check(item));
auto *key = PyTuple_GetItem(item, 0);
if (!key) return py::FetchError();
if (!PyUnicode_Check(key)) {
std::stringstream ss;
ss << "Field name '" << py::Object::FromBorrow(key)
<< "' is not an instance of 'str'";
const auto &msg = ss.str();
PyErr_SetString(PyExc_TypeError, msg.c_str());
return py::FetchError();
}
const auto *field_name = PyUnicode_AsUTF8(key);
if (!field_name) return py::FetchError();
auto *val = PyTuple_GetItem(item, 1);
if (!val) return py::FetchError();
mgp_memory memory{result->rows.get_allocator().GetMemoryResource()};
mgp_value *field_val{nullptr};
try {
// TODO: Make PyObjectToMgpValue set a Python exception instead.
field_val = PyObjectToMgpValue(val, &memory);
} catch (const std::exception &e) {
PyErr_SetString(PyExc_ValueError, e.what());
return py::FetchError();
}
CHECK(field_val);
if (!mgp_result_record_insert(record, field_name, field_val)) {
std::stringstream ss;
ss << "Unable to insert field '" << py::Object::FromBorrow(key)
<< "' with value: '" << py::Object::FromBorrow(val)
<< "'; did you set the correct field type?";
const auto &msg = ss.str();
PyErr_SetString(PyExc_ValueError, msg.c_str());
mgp_value_destroy(field_val);
return py::FetchError();
}
mgp_value_destroy(field_val);
}
return std::nullopt;
}
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
mgp_result *result, py::Object py_seq) {
Py_ssize_t len = PySequence_Size(py_seq);
if (len == -1) return py::FetchError();
for (Py_ssize_t i = 0; i < len; ++i) {
py::Object py_record(PySequence_GetItem(py_seq, i));
if (!py_record) return py::FetchError();
auto maybe_exc = AddRecordFromPython(result, py_record);
if (maybe_exc) return maybe_exc;
}
return std::nullopt;
}
template <class TFun>
std::optional<py::ExceptionInfo> WithPyGraph(const mgp_graph *graph,
mgp_memory *memory,
const TFun &fun) {
py::Object py_graph(MakePyGraph(graph, memory));
if (!py_graph) return py::FetchError();
try {
auto maybe_exc = fun(py_graph);
// After `fun` finishes, invalidate the graph thus preventing its use in
// Python code. This is just a precaution if someone were to store
// `mgp_<type>` objects globally in Python.
LOG_IF(FATAL, !py_graph.CallMethod("invalidate"))
<< py::FetchError().value();
// Run gc.collect (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If any
// `mgp_<type>` remains alive, that means the user stored in somewhere
// globally and that will get reported as a query procedure memory leak in
// our logs.
py::Object gc(PyImport_ImportModule("gc"));
LOG_IF(FATAL, !gc) << py::FetchError().value();
LOG_IF(FATAL, !gc.CallMethod("collect")) << py::FetchError().value();
return maybe_exc;
} catch (...) {
LOG_IF(FATAL, !py_graph.CallMethod("invalidate"))
<< py::FetchError().value();
throw;
}
}
} // namespace
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
CHECK(self->module);
if (!PyCallable_Check(cb)) {
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
return nullptr;
}
Py_INCREF(cb);
py::Object py_cb(cb);
auto py_cb = py::Object::FromBorrow(cb);
py::Object py_name(py_cb.GetAttr("__name__"));
const auto *name = PyUnicode_AsUTF8(py_name);
// TODO: Validate name
if (!name) return nullptr;
if (!IsValidIdentifierName(name)) {
PyErr_SetString(PyExc_ValueError,
"Procedure name is not a valid identifier");
return nullptr;
}
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
mgp_proc proc(
name,
[py_cb](const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *) {
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
throw utils::NotYetImplemented("Invoking Python procedures");
auto maybe_exc =
WithPyGraph(graph, memory,
[&](auto py_graph) -> std::optional<py::ExceptionInfo> {
py::Object py_args(MgpListToPyTuple(args, py_graph));
if (!py_args) return py::FetchError();
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return py::FetchError();
if (PySequence_Check(py_res)) {
return AddMultipleRecordsFromPython(result, py_res);
} else {
return AddRecordFromPython(result, py_res);
}
});
if (maybe_exc) return SetErrorFromPython(result, *maybe_exc);
},
memory);
const auto &[proc_it, did_insert] =