Implement calling Python procedures
Summary: You should now be able to invoke query procedures written in Python.
To test the example, run memgraph with PYTHONPATH set to `include`. For example, assuming you are in the root of the repo, run this command:
    PYTHONPATH=$PWD/include ./build/memgraph --query-modules-directory=./query_modules/
Alternatively, you can create a symlink inside `./query_modules` that points to `include/mgp.py`, so there is no need to set PYTHONPATH. For example, assuming you are in the root of the repo, run the following:
    cd ./query_modules
    ln -s ../include/mgp.py
    cd ..
    ./build/memgraph --query-modules-directory=./query_modules/
Depends on D207
Reviewers: mferencevic, ipaljak, dsantl
Reviewed By: ipaljak
Subscribers: buda, tlastre, pullbot
Differential Revision: https://phabricator.memgraph.io/D2708
This commit is contained in:
parent
6f83fff171
commit
fbdcad1106
@ -456,11 +456,22 @@ class ProcCtx:
|
||||
Access to a ProcCtx is only valid during a single execution of a procedure
|
||||
in a query. You should not globally store a ProcCtx instance.
|
||||
'''
|
||||
__slots__ = ('_graph',)
|
||||
|
||||
def __init__(self, graph):
|
||||
if not isinstance(graph, _mgp.Graph):
|
||||
raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
|
||||
self._graph = Graph(graph)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
return self._graph.is_valid()
|
||||
|
||||
@property
|
||||
def graph(self) -> Graph:
|
||||
'''Raise InvalidContextError if context is invalid.'''
|
||||
pass
|
||||
if not self.is_valid():
|
||||
raise InvalidContextError()
|
||||
return self._graph
|
||||
|
||||
|
||||
# Additional typing support
|
||||
@ -634,13 +645,15 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
sig = inspect.signature(func)
|
||||
params = tuple(sig.parameters.values())
|
||||
if params and params[0].annotation is ProcCtx:
|
||||
@functools.wraps(func)
|
||||
def wrapper(graph, args):
|
||||
return func(ProcCtx(graph), *args)
|
||||
params = params[1:]
|
||||
mgp_proc = _mgp._MODULE.add_read_procedure(func)
|
||||
mgp_proc = _mgp._MODULE.add_read_procedure(wrapper)
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args):
|
||||
args_without_context = args[1:]
|
||||
return func(*args_without_context)
|
||||
def wrapper(graph, args):
|
||||
return func(*args)
|
||||
mgp_proc = _mgp._MODULE.add_read_procedure(wrapper)
|
||||
for param in params:
|
||||
name = param.name
|
||||
|
@ -44,8 +44,15 @@ class [[nodiscard]] Object final {
|
||||
public:
|
||||
/// Default constructor; `ptr_` is not assigned here (its declaration is
/// outside this view).
Object() = default;
|
||||
/// Allow construction from a literal `nullptr`; like the default constructor,
/// nothing is assigned explicitly.
Object(std::nullptr_t) {}
|
||||
/// Construct by taking the ownership of `PyObject *`.
/// The reference count is NOT incremented, while the destructor decrements
/// it, so `ptr` must be a reference the caller already owns.
|
||||
explicit Object(PyObject *ptr) noexcept : ptr_(ptr) {}
|
||||
|
||||
/// Construct from a borrowed `PyObject *`, i.e. non-owned pointer.
/// The reference count is incremented first, so the decrement done by the
/// destructor is balanced. `Py_XINCREF` is the null-tolerant variant, so a
/// null `ptr` is accepted.
|
||||
static Object FromBorrow(PyObject *ptr) noexcept {
|
||||
Py_XINCREF(ptr);
|
||||
return Object(ptr);
|
||||
}
|
||||
|
||||
/// Release the held reference (null-tolerant).
~Object() noexcept { Py_XDECREF(ptr_); }
|
||||
|
||||
/// Copy shares the same `PyObject` and takes an additional reference to it.
Object(const Object &other) noexcept : ptr_(other.ptr_) { Py_XINCREF(ptr_); }
|
||||
|
@ -1386,14 +1386,6 @@ const mgp_type *mgp_type_nullable(const mgp_type *type) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Return true if `name` is non-null and, taken as a whole, is an identifier:
/// an alphabetic character or underscore followed by any number of
/// alphanumeric characters or underscores. The empty string is rejected.
bool IsValidIdentifierName(const char *name) {
  if (!name) return false;
  // Compiling a std::regex is costly, so the pattern is built once on first
  // use instead of on every call. Matching against a const regex does not
  // modify it.
  static const std::regex kIdentifierRegex("[_[:alpha:]][_[:alnum:]]*");
  return std::regex_match(name, kIdentifierRegex);
}
|
||||
} // namespace
|
||||
|
||||
mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name,
|
||||
mgp_proc_cb cb) {
|
||||
if (!module || !cb) return nullptr;
|
||||
@ -1548,4 +1540,10 @@ void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) {
|
||||
(*stream) << ")";
|
||||
}
|
||||
|
||||
/// Return true if `name` is non-null and, taken as a whole, is an identifier:
/// an alphabetic character or underscore followed by any number of
/// alphanumeric characters or underscores. The empty string is rejected.
bool IsValidIdentifierName(const char *name) {
  if (!name) return false;
  // Compiling a std::regex is costly, so the pattern is built once on first
  // use instead of on every call. Matching against a const regex does not
  // modify it.
  static const std::regex kIdentifierRegex("[_[:alpha:]][_[:alnum:]]*");
  return std::regex_match(name, kIdentifierRegex);
}
|
||||
|
||||
} // namespace query::procedure
|
||||
|
@ -573,4 +573,6 @@ namespace query::procedure {
|
||||
/// @throw anything std::ostream::operator<< may throw.
|
||||
void PrintProcSignature(const mgp_proc &, std::ostream *);
|
||||
|
||||
/// Return true if `name` is non-null and, in its entirety, consists of an
/// alphabetic character or underscore followed by zero or more alphanumeric
/// characters or underscores.
bool IsValidIdentifierName(const char *name);
|
||||
|
||||
} // namespace query::procedure
|
||||
|
@ -161,6 +161,12 @@ static PyTypeObject PyEdgesIteratorType = {
|
||||
.tp_dealloc = reinterpret_cast<destructor>(PyEdgesIteratorDealloc),
|
||||
};
|
||||
|
||||
// Python-callable `invalidate`: clear the `graph` and `memory` pointers held by
// `self` and return None. PyGraphIsValid reports False afterwards because it
// tests `self->graph`.
PyObject *PyGraphInvalidate(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
self->graph = nullptr;
|
||||
self->memory = nullptr;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Python-callable `is_valid`: return a Python bool that is True exactly when
// `self->graph` is non-null.
PyObject *PyGraphIsValid(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
// `!!` collapses the pointer to 0/1 before it is handed to PyBool_FromLong.
return PyBool_FromLong(!!self->graph);
|
||||
}
|
||||
@ -208,6 +214,9 @@ PyObject *PyGraphIterVertices(PyGraph *self, PyObject *Py_UNUSED(ignored)) {
|
||||
static PyMethodDef PyGraphMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy),
|
||||
METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"invalidate", reinterpret_cast<PyCFunction>(PyGraphInvalidate),
|
||||
METH_NOARGS,
|
||||
"Invalidate the Graph context thus preventing the Graph from being used."},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyGraphIsValid), METH_NOARGS,
|
||||
"Return True if Graph is in valid context and may be used."},
|
||||
{"get_vertex_by_id", reinterpret_cast<PyCFunction>(PyGraphGetVertexById),
|
||||
@ -403,23 +412,173 @@ py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
|
||||
return py_tuple;
|
||||
}
|
||||
|
||||
// Overload taking an untyped `PyObject *` for the graph. The object's type must
// be exactly `PyGraphType`; otherwise a Python TypeError is set and a null
// py::Object is returned. On success the call is forwarded to the overload that
// takes a `PyGraph *`.
py::Object MgpListToPyTuple(const mgp_list *list, PyObject *py_graph) {
|
||||
// Exact type comparison: the cast below is only done for a PyGraphType object.
if (Py_TYPE(py_graph) != &PyGraphType) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a _mgp.Graph.");
|
||||
return nullptr;
|
||||
}
|
||||
return MgpListToPyTuple(list, reinterpret_cast<PyGraph *>(py_graph));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Format `exc_info` through its stream operator and store the resulting text as
// the error message of `result`.
void SetErrorFromPython(mgp_result *result, const py::ExceptionInfo &exc_info) {
|
||||
std::stringstream ss;
|
||||
ss << exc_info;
|
||||
// Binding the temporary to a const reference keeps the string alive for the
// `c_str()` call below.
const auto &msg = ss.str();
|
||||
mgp_result_set_error_msg(result, msg.c_str());
|
||||
}
|
||||
|
||||
/// Append one record to `result`, filled from the Python object `py_record`.
///
/// `py_record` has to be an instance of `mgp.Record` whose `fields` attribute
/// is a `dict` with `str` keys; each value is converted with
/// `PyObjectToMgpValue` and inserted under its key.
///
/// @return std::nullopt on success, otherwise what `py::FetchError()` yields
///         for the Python exception that was raised or set here.
std::optional<py::ExceptionInfo> AddRecordFromPython(mgp_result *result,
                                                     py::Object py_record) {
  py::Object py_mgp(PyImport_ImportModule("mgp"));
  if (!py_mgp) return py::FetchError();
  auto record_cls = py_mgp.GetAttr("Record");
  if (!record_cls) return py::FetchError();
  // PyObject_IsInstance is tri-state: 1 (instance), 0 (not an instance) and -1
  // (the check itself raised). Simply negating the result would treat -1 as
  // "is an instance" and continue with a Python exception still pending.
  const int is_record = PyObject_IsInstance(py_record, record_cls);
  if (is_record == -1) return py::FetchError();
  if (is_record == 0) {
    std::stringstream ss;
    ss << "Value '" << py_record << "' is not an instance of 'mgp.Record'";
    const auto &msg = ss.str();
    PyErr_SetString(PyExc_TypeError, msg.c_str());
    return py::FetchError();
  }
  py::Object fields(py_record.GetAttr("fields"));
  if (!fields) return py::FetchError();
  if (!PyDict_Check(fields)) {
    PyErr_SetString(PyExc_TypeError,
                    "Expected 'mgp.Record.fields' to be a 'dict'");
    return py::FetchError();
  }
  py::Object items(PyDict_Items(fields));
  if (!items) return py::FetchError();
  auto *record = mgp_result_new_record(result);
  if (!record) {
    PyErr_NoMemory();
    return py::FetchError();
  }
  // The memory resource is the same for every field, so set it up once.
  mgp_memory memory{result->rows.get_allocator().GetMemoryResource()};
  Py_ssize_t len = PyList_GET_SIZE(static_cast<PyObject *>(items));
  for (Py_ssize_t i = 0; i < len; ++i) {
    // `item`, `key` and `val` are borrowed references owned by `items`, hence
    // `FromBorrow` wherever they are wrapped for printing.
    auto *item = PyList_GET_ITEM(static_cast<PyObject *>(items), i);
    if (!item) return py::FetchError();
    CHECK(PyTuple_Check(item));
    auto *key = PyTuple_GetItem(item, 0);
    if (!key) return py::FetchError();
    if (!PyUnicode_Check(key)) {
      std::stringstream ss;
      ss << "Field name '" << py::Object::FromBorrow(key)
         << "' is not an instance of 'str'";
      const auto &msg = ss.str();
      PyErr_SetString(PyExc_TypeError, msg.c_str());
      return py::FetchError();
    }
    const auto *field_name = PyUnicode_AsUTF8(key);
    if (!field_name) return py::FetchError();
    auto *val = PyTuple_GetItem(item, 1);
    if (!val) return py::FetchError();
    mgp_value *field_val{nullptr};
    try {
      // TODO: Make PyObjectToMgpValue set a Python exception instead.
      field_val = PyObjectToMgpValue(val, &memory);
    } catch (const std::exception &e) {
      PyErr_SetString(PyExc_ValueError, e.what());
      return py::FetchError();
    }
    CHECK(field_val);
    if (!mgp_result_record_insert(record, field_name, field_val)) {
      std::stringstream ss;
      ss << "Unable to insert field '" << py::Object::FromBorrow(key)
         << "' with value: '" << py::Object::FromBorrow(val)
         << "'; did you set the correct field type?";
      const auto &msg = ss.str();
      PyErr_SetString(PyExc_ValueError, msg.c_str());
      mgp_value_destroy(field_val);
      return py::FetchError();
    }
    // `field_val` is destroyed on the success path too; presumably
    // mgp_result_record_insert stores its own copy (TODO confirm).
    mgp_value_destroy(field_val);
  }
  return std::nullopt;
}
|
||||
|
||||
// Add every element of the Python sequence `py_seq` to `result` by passing it
// to AddRecordFromPython, in index order. Stops at the first failure and
// returns that error; returns std::nullopt when all elements were added.
std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(
|
||||
mgp_result *result, py::Object py_seq) {
|
||||
Py_ssize_t len = PySequence_Size(py_seq);
|
||||
// -1 is treated as failure of the size query.
if (len == -1) return py::FetchError();
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
py::Object py_record(PySequence_GetItem(py_seq, i));
|
||||
if (!py_record) return py::FetchError();
|
||||
auto maybe_exc = AddRecordFromPython(result, py_record);
|
||||
if (maybe_exc) return maybe_exc;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Run `fun` with a Python graph object created by MakePyGraph(graph, memory).
// Whether `fun` returns or throws, the graph object's `invalidate` method is
// called afterwards; a failure to invalidate is logged as FATAL. On normal
// return `gc.collect` is run as well and the value produced by `fun` is
// returned. An exception thrown by `fun` is rethrown after invalidation.
// If the graph object cannot be created, the result of py::FetchError() is
// returned and `fun` is not called.
template <class TFun>
|
||||
std::optional<py::ExceptionInfo> WithPyGraph(const mgp_graph *graph,
|
||||
mgp_memory *memory,
|
||||
const TFun &fun) {
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
if (!py_graph) return py::FetchError();
|
||||
try {
|
||||
auto maybe_exc = fun(py_graph);
|
||||
// After `fun` finishes, invalidate the graph thus preventing its use in
|
||||
// Python code. This is just a precaution if someone were to store
|
||||
// `mgp_<type>` objects globally in Python.
|
||||
LOG_IF(FATAL, !py_graph.CallMethod("invalidate"))
|
||||
<< py::FetchError().value();
|
||||
// Run gc.collect (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If any
|
||||
// `mgp_<type>` remains alive, that means the user stored it somewhere
|
||||
// globally and that will get reported as a query procedure memory leak in
|
||||
// our logs.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
LOG_IF(FATAL, !gc) << py::FetchError().value();
|
||||
LOG_IF(FATAL, !gc.CallMethod("collect")) << py::FetchError().value();
|
||||
return maybe_exc;
|
||||
} catch (...) {
|
||||
// Same invalidation as on the normal path, then let the exception continue.
LOG_IF(FATAL, !py_graph.CallMethod("invalidate"))
|
||||
<< py::FetchError().value();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
CHECK(self->module);
|
||||
if (!PyCallable_Check(cb)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
|
||||
return nullptr;
|
||||
}
|
||||
Py_INCREF(cb);
|
||||
py::Object py_cb(cb);
|
||||
auto py_cb = py::Object::FromBorrow(cb);
|
||||
py::Object py_name(py_cb.GetAttr("__name__"));
|
||||
const auto *name = PyUnicode_AsUTF8(py_name);
|
||||
// TODO: Validate name
|
||||
if (!name) return nullptr;
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Procedure name is not a valid identifier");
|
||||
return nullptr;
|
||||
}
|
||||
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
|
||||
mgp_proc proc(
|
||||
name,
|
||||
[py_cb](const mgp_list *, const mgp_graph *, mgp_result *, mgp_memory *) {
|
||||
[py_cb](const mgp_list *args, const mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
throw utils::NotYetImplemented("Invoking Python procedures");
|
||||
auto maybe_exc =
|
||||
WithPyGraph(graph, memory,
|
||||
[&](auto py_graph) -> std::optional<py::ExceptionInfo> {
|
||||
py::Object py_args(MgpListToPyTuple(args, py_graph));
|
||||
if (!py_args) return py::FetchError();
|
||||
auto py_res = py_cb.Call(py_graph, py_args);
|
||||
if (!py_res) return py::FetchError();
|
||||
if (PySequence_Check(py_res)) {
|
||||
return AddMultipleRecordsFromPython(result, py_res);
|
||||
} else {
|
||||
return AddRecordFromPython(result, py_res);
|
||||
}
|
||||
});
|
||||
if (maybe_exc) return SetErrorFromPython(result, *maybe_exc);
|
||||
},
|
||||
memory);
|
||||
const auto &[proc_it, did_insert] =
|
||||
|
Loading…
Reference in New Issue
Block a user