Add python API for messages transformations (#181)
* Add python messages/transformations implementation * Added fixed result return type to transformations * Added is_deprecated to mgp_trans
This commit is contained in:
parent
ac230d0c2d
commit
a928c158da
@ -845,7 +845,8 @@ const struct mgp_message *mgp_messages_at(const struct mgp_messages *, size_t);
|
||||
/// Passed in arguments will not live longer than the callback's execution.
|
||||
/// Therefore, you must not store them globally or use the passed in mgp_memory
|
||||
/// to allocate global resources.
|
||||
typedef void (*mgp_trans_cb)(const struct mgp_messages *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *);
|
||||
typedef void (*mgp_trans_cb)(const struct mgp_messages *, const struct mgp_graph *, struct mgp_result *,
|
||||
struct mgp_memory *);
|
||||
|
||||
/// Adds a transformation cb to the module pointed by mgp_module.
|
||||
/// Return non-zero if the transformation is added successfully.
|
||||
|
143
include/mgp.py
143
include/mgp.py
@ -714,6 +714,18 @@ class Deprecated:
|
||||
self.field_type = type_
|
||||
|
||||
|
||||
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
    '''Validate that `func` can be registered with the module.

    Raise TypeError if `func` is not callable, or if it is a coroutine
    function or an asynchronous generator function.
    Raise NotImplementedError if `func` is a generator function.
    '''
    if not callable(func):
        template = "Expected a callable object, got an instance of '{}'"
        raise TypeError(template.format(type(func)))
    # The async-generator check is only attempted on Python 3.6 and newer.
    if inspect.iscoroutinefunction(func) or (
            sys.version_info >= (3, 6) and inspect.isasyncgenfunction(func)):
        raise TypeError("Callable must not be 'async def' function")
    if inspect.isgeneratorfunction(func):
        raise NotImplementedError("Generator functions are not supported")
|
||||
|
||||
def read_proc(func: typing.Callable[..., Record]):
|
||||
'''
|
||||
Register `func` as a read-only procedure of the current module.
|
||||
@ -754,16 +766,7 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
CALL example.procedure(1) YIELD args, result;
|
||||
Naturally, you may pass in different arguments or yield less fields.
|
||||
'''
|
||||
if not callable(func):
|
||||
raise TypeError("Expected a callable object, got an instance of '{}'"
|
||||
.format(type(func)))
|
||||
if inspect.iscoroutinefunction(func):
|
||||
raise TypeError("Callable must not be 'async def' function")
|
||||
if sys.version_info >= (3, 6):
|
||||
if inspect.isasyncgenfunction(func):
|
||||
raise TypeError("Callable must not be 'async def' function")
|
||||
if inspect.isgeneratorfunction(func):
|
||||
raise NotImplementedError("Generator functions are not supported")
|
||||
raise_if_does_not_meet_requirements(func)
|
||||
sig = inspect.signature(func)
|
||||
params = tuple(sig.parameters.values())
|
||||
if params and params[0].annotation is ProcCtx:
|
||||
@ -799,3 +802,123 @@ def read_proc(func: typing.Callable[..., Record]):
|
||||
else:
|
||||
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
||||
return func
|
||||
|
||||
class InvalidMessageError(Exception):
    '''Raised when a message instance is used outside of the registered transformation.'''
|
||||
|
||||
class Message:
    '''Represents a message from a stream.

    Wraps an `_mgp.Message` instance. Every accessor first checks
    `is_valid()` and raises InvalidMessageError when the check fails,
    then forwards to the method of the same name on the wrapped object.
    '''
    __slots__ = ('_message',)

    def __init__(self, message):
        if not isinstance(message, _mgp.Message):
            raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
        self._message = message

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct instead of duplicating it.
        return Message(self._message)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._message.is_valid()

    def payload(self) -> bytes:
        '''Return the payload of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.payload()

    def topic_name(self) -> str:
        '''Return the name of the topic the message belongs to.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.topic_name()

    def key(self) -> bytes:
        '''Return the key of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.key()

    def timestamp(self) -> int:
        '''Return the timestamp of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.timestamp()
|
||||
|
||||
class InvalidMessagesError(Exception):
    '''Raised when a messages instance is used outside of the registered transformation.'''
|
||||
|
||||
class Messages:
    '''Represents a list of messages from a stream.

    Wraps an `_mgp.Messages` instance and forwards to it after checking
    that it is still valid.
    '''
    __slots__ = ('_messages',)

    def __init__(self, messages):
        if not isinstance(messages, _mgp.Messages):
            raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
        self._messages = messages

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct. Besides, it doesn't make much sense to actually
        # copy _mgp.Messages as that always references all the messages.
        return Messages(self._messages)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._messages.is_valid()

    def message_at(self, id: int) -> Message:
        '''Return the message stored at index `id`.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return Message(self._messages.message_at(id))

    def total_messages(self) -> int:
        '''Return the number of messages available.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return self._messages.total_messages()
|
||||
|
||||
class TransCtx:
    '''Context of a transformation being executed.

    Access to a TransCtx is only valid during a single execution of a transformation.
    You should not globally store a TransCtx instance.
    '''
    # A one-element tuple, matching the other wrapper classes in this module.
    __slots__ = ('_graph',)

    def __init__(self, graph):
        if not isinstance(graph, _mgp.Graph):
            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
        self._graph = Graph(graph)

    def is_valid(self) -> bool:
        '''Return True if the wrapped graph reports itself as valid.'''
        return self._graph.is_valid()

    @property
    def graph(self) -> Graph:
        '''Return the wrapped graph.

        Raise InvalidContextError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidContextError()
        return self._graph
|
||||
|
||||
def transformation(func: typing.Callable[..., Record]):
    '''Register `func` as a transformation of the current module.

    `func` must be annotated with one of the two supported signatures:
    `(TransCtx, Messages)` or `(Messages)`. Any other signature raises
    NotImplementedError. The registered wrapper is invoked with the raw
    graph and messages objects and converts them to `TransCtx` and
    `Messages` before calling `func`.

    Return `func` unchanged, so this can be used as a decorator.
    '''
    raise_if_does_not_meet_requirements(func)
    sig = inspect.signature(func)
    params = tuple(sig.parameters.values())
    takes_ctx_and_messages = (len(params) == 2
                              and params[0].annotation is TransCtx
                              and params[1].annotation is Messages)
    takes_only_messages = (len(params) == 1
                           and params[0].annotation is Messages)
    # Reject everything the wrappers below could not call correctly.
    if not (takes_ctx_and_messages or takes_only_messages):
        raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")
    if takes_ctx_and_messages:
        @functools.wraps(func)
        def wrapper(graph, messages):
            return func(TransCtx(graph), Messages(messages))
    else:
        @functools.wraps(func)
        def wrapper(graph, messages):
            return func(Messages(messages))
    _mgp._MODULE.add_transformation(wrapper)
    return func
|
||||
|
@ -3734,6 +3734,7 @@ class CallProcedureCursor : public Cursor {
|
||||
mgp_graph graph{context.db_accessor, graph_view, &context};
|
||||
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
||||
&result_);
|
||||
|
||||
// Reset result_.signature to nullptr, because outside of this scope we
|
||||
// will no longer hold a lock on the `module`. If someone were to reload
|
||||
// it, the pointer would be invalid.
|
||||
|
@ -8,10 +8,12 @@
|
||||
|
||||
#include "module.hpp"
|
||||
#include "utils/algorithm.hpp"
|
||||
#include "utils/concepts.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
#include "utils/math.hpp"
|
||||
#include "utils/memory.hpp"
|
||||
#include "utils/string.hpp"
|
||||
|
||||
// This file contains implementation of top level C API functions, but this is
|
||||
// all actually part of query::procedure. So use that namespace for simplicity.
|
||||
// NOLINTNEXTLINE(google-build-using-namespace)
|
||||
@ -1343,27 +1345,38 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
|
||||
|
||||
namespace {
|
||||
|
||||
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) {
|
||||
if (!proc || !type) return 0;
|
||||
if (!IsValidIdentifierName(name)) return 0;
|
||||
if (proc->results.find(name) != proc->results.end()) return 0;
|
||||
template <typename T>
|
||||
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>;
|
||||
|
||||
template <ModuleProperties T>
|
||||
bool AddResultToProp(T *prop, const char *name, const mgp_type *type, bool is_deprecated) {
|
||||
if (!prop || !type) return false;
|
||||
if (!IsValidIdentifierName(name)) return false;
|
||||
if (prop->results.find(name) != prop->results.end()) return false;
|
||||
try {
|
||||
auto *memory = proc->results.get_allocator().GetMemoryResource();
|
||||
proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
|
||||
return 1;
|
||||
auto *memory = prop->results.get_allocator().GetMemoryResource();
|
||||
prop->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
|
||||
return true;
|
||||
} catch (...) {
|
||||
return 0;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
return AddResultToProc(proc, name, type, false);
|
||||
return AddResultToProp(proc, name, type, false);
|
||||
}
|
||||
|
||||
/// Adds the fixed result fields of a transformation ("query" and
/// "parameters") to `trans`.
/// Returns false as soon as one of the fields could not be added.
bool MgpTransAddFixedResult(mgp_trans *trans) {
  // AddResultToProp reports success as a bool, so test it directly instead of
  // round-tripping the value through an int.
  if (!AddResultToProp(trans, "query", mgp_type_string(), false)) {
    return false;
  }
  return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
}
|
||||
|
||||
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||
return AddResultToProc(proc, name, type, true);
|
||||
return AddResultToProp(proc, name, type, true);
|
||||
}
|
||||
|
||||
int mgp_must_abort(const mgp_graph *graph) {
|
||||
|
@ -478,21 +478,23 @@ struct mgp_trans {
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory) : name(name, memory), cb(cb) {}
|
||||
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory)
|
||||
: name(name, memory), cb(cb), results(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_trans(const char *name,
|
||||
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
||||
utils::MemoryResource *memory)
|
||||
: name(name, memory), cb(cb) {}
|
||||
: name(name, memory), cb(cb), results(memory) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory) : name(other.name, memory), cb(other.cb) {}
|
||||
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory)
|
||||
: name(other.name, memory), cb(other.cb), results(other.results) {}
|
||||
|
||||
mgp_trans(mgp_trans &&other, utils::MemoryResource *memory)
|
||||
: name(std::move(other.name), memory), cb(std::move(other.cb)) {}
|
||||
: name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results)) {}
|
||||
|
||||
mgp_trans(const mgp_trans &other) = default;
|
||||
mgp_trans(mgp_trans &&other) = default;
|
||||
@ -505,9 +507,13 @@ struct mgp_trans {
|
||||
/// Name of the transformation.
|
||||
utils::pmr::string name;
|
||||
/// Entry-point for the transformation.
|
||||
std::function<void(const mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
||||
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
||||
/// Fields this transformation returns.
|
||||
utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> results;
|
||||
};
|
||||
|
||||
bool MgpTransAddFixedResult(mgp_trans *trans);
|
||||
|
||||
struct mgp_module {
|
||||
using allocator_type = utils::Allocator<mgp_module>;
|
||||
|
||||
|
@ -7,6 +7,7 @@ extern "C" {
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "fmt/format.h"
|
||||
#include "py/py.hpp"
|
||||
#include "query/procedure/py_module.hpp"
|
||||
#include "utils/file.hpp"
|
||||
@ -275,30 +276,44 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
||||
}
|
||||
// Get required mgp_init_module
|
||||
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
|
||||
const char *error = dlerror();
|
||||
if (!init_fn_ || error) {
|
||||
spdlog::error("Unable to load module {}; {}", file_path, error);
|
||||
char *dl_errored = dlerror();
|
||||
if (!init_fn_ || dl_errored) {
|
||||
spdlog::error("Unable to load module {}; {}", file_path, dl_errored);
|
||||
dlclose(handle_);
|
||||
handle_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
if (!WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
|
||||
// Run mgp_init_module which must succeed.
|
||||
int init_res = init_fn_(module_def, memory);
|
||||
if (init_res != 0) {
|
||||
spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res);
|
||||
dlclose(handle_);
|
||||
handle_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})) {
|
||||
auto module_cb = [&](auto *module_def, auto *memory) {
|
||||
// Run mgp_init_module which must succeed.
|
||||
int init_res = init_fn_(module_def, memory);
|
||||
auto with_error = [this](std::string_view error_msg) {
|
||||
spdlog::error(error_msg);
|
||||
dlclose(handle_);
|
||||
handle_ = nullptr;
|
||||
return false;
|
||||
};
|
||||
|
||||
if (init_res != 0) {
|
||||
const auto error = fmt::format("Unable to load module {}; mgp_init_module_returned {} ", file_path, init_res);
|
||||
return with_error(error);
|
||||
}
|
||||
for (auto &trans : module_def->transformations) {
|
||||
const bool was_result_added = MgpTransAddFixedResult(&trans.second);
|
||||
if (!was_result_added) {
|
||||
const auto error =
|
||||
fmt::format("Unable to add result to transformation in module {}; add result failed", file_path);
|
||||
return with_error(error);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) {
|
||||
return false;
|
||||
}
|
||||
// Get optional mgp_shutdown_module
|
||||
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
|
||||
error = dlerror();
|
||||
if (error) spdlog::warn("When loading module {}; {}", file_path, error);
|
||||
dl_errored = dlerror();
|
||||
if (dl_errored) spdlog::warn("When loading module {}; {}", file_path, dl_errored);
|
||||
spdlog::info("Loaded module {}", file_path);
|
||||
return true;
|
||||
}
|
||||
@ -376,11 +391,23 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
||||
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
|
||||
return false;
|
||||
}
|
||||
py_module_ = WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
|
||||
return ImportPyModule(file_path.stem().c_str(), module_def);
|
||||
});
|
||||
bool succ = true;
|
||||
auto module_cb = [&](auto *module_def, auto *memory) {
|
||||
auto result = ImportPyModule(file_path.stem().c_str(), module_def);
|
||||
for (auto &trans : module_def->transformations) {
|
||||
succ = MgpTransAddFixedResult(&trans.second);
|
||||
if (!succ) return result;
|
||||
};
|
||||
return result;
|
||||
};
|
||||
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb);
|
||||
if (py_module_) {
|
||||
spdlog::info("Loaded module {}", file_path);
|
||||
|
||||
if (!succ) {
|
||||
spdlog::error("Unable to add result to transformation");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
auto exc_info = py::FetchError().value();
|
||||
@ -566,6 +593,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
|
||||
if (name_parts.size() == 1U) return std::nullopt;
|
||||
auto last_dot_pos = fully_qualified_name.find_last_of('.');
|
||||
MG_ASSERT(last_dot_pos != std::string_view::npos);
|
||||
|
||||
const auto &module_name = fully_qualified_name.substr(0, last_dot_pos);
|
||||
const auto &name = name_parts.back();
|
||||
return std::make_pair(module_name, name);
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "query/procedure/mg_procedure_impl.hpp"
|
||||
#include "utils/pmr/vector.hpp"
|
||||
|
||||
namespace query::procedure {
|
||||
|
||||
@ -178,7 +179,7 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) {
|
||||
MG_ASSERT(self->graph);
|
||||
MG_ASSERT(self->memory);
|
||||
static_assert(std::is_same_v<int64_t, long>);
|
||||
int64_t id;
|
||||
int64_t id = 0;
|
||||
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
|
||||
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
|
||||
if (!vertex) {
|
||||
@ -400,6 +401,188 @@ struct PyQueryModule {
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
struct PyMessages {
|
||||
PyObject_HEAD;
|
||||
const mgp_messages *messages;
|
||||
mgp_memory *memory;
|
||||
};
|
||||
|
||||
struct PyMessage {
|
||||
PyObject_HEAD;
|
||||
const mgp_message *message;
|
||||
const PyMessages *messages;
|
||||
mgp_memory *memory;
|
||||
};
|
||||
|
||||
/// Returns Python `True` while the wrapped `mgp_messages` pointer is set,
/// `False` otherwise.
PyObject *PyMessagesIsValid(const PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  const bool has_messages = self->messages != nullptr;
  return PyBool_FromLong(has_messages);
}
|
||||
|
||||
/// A message is valid exactly as long as the batch it came from is valid.
PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  const PyMessages *owner = self->messages;
  return PyMessagesIsValid(owner, nullptr);
}
|
||||
|
||||
/// Returns the payload of the wrapped message as a Python bytearray, or
/// nullptr with a RuntimeError set if the bytearray could not be created.
PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  const auto size = mgp_message_get_payload_size(self->message);
  const auto *data = mgp_message_get_payload(self->message);
  if (auto *py_payload = PyByteArray_FromStringAndSize(data, size)) {
    return py_payload;
  }
  PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload");
  return nullptr;
}
|
||||
|
||||
/// Returns the topic name of the wrapped message as a Python str, or nullptr
/// with a RuntimeError set if the string could not be created.
PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  const auto *topic_name = mgp_message_topic_name(self->message);
  auto *py_topic_name = PyUnicode_FromString(topic_name);
  if (!py_topic_name) {
    // The previous message was copied from the payload getter and did not
    // describe this failure.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get the topic name of the message");
    return nullptr;
  }
  return py_topic_name;
}
|
||||
|
||||
/// Returns the key of the wrapped message as a Python bytearray, or nullptr
/// with a RuntimeError set if the bytearray could not be created.
PyObject *PyMessageGetKey(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  auto key_size = mgp_message_key_size(self->message);
  const auto *key = mgp_message_key(self->message);
  auto *raw_bytes = PyByteArray_FromStringAndSize(key, key_size);
  if (!raw_bytes) {
    // The previous message was copied from the payload getter and named the
    // wrong field.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from key");
    return nullptr;
  }
  return raw_bytes;
}
|
||||
|
||||
/// Returns the timestamp of the wrapped message as a Python int, or nullptr
/// with an IndexError set if the int could not be created.
PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  const auto raw_timestamp = mgp_message_timestamp(self->message);
  if (auto *py_timestamp = PyLong_FromUnsignedLong(raw_timestamp)) {
    return py_timestamp;
  }
  PyErr_SetString(PyExc_IndexError, "Unable to get timestamp.");
  return nullptr;
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static PyMethodDef PyMessageMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyMessageIsValid), METH_NOARGS,
|
||||
"Return True if messages is in valid context and may be used."},
|
||||
{"payload", reinterpret_cast<PyCFunction>(PyMessageGetPayload), METH_NOARGS, "Get payload"},
|
||||
{"topic_name", reinterpret_cast<PyCFunction>(PyMessageGetTopicName), METH_NOARGS, "Get topic name."},
|
||||
{"key", reinterpret_cast<PyCFunction>(PyMessageGetKey), METH_NOARGS, "Get message key."},
|
||||
{"timestamp", reinterpret_cast<PyCFunction>(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
/// Deallocator of `_mgp.Message`: drops the reference held on the owning
/// PyMessages object and frees the instance.
void PyMessageDealloc(PyMessage *self) {
  MG_ASSERT(self->memory);
  MG_ASSERT(self->message);
  MG_ASSERT(self->messages);
  // Pairs with the Py_INCREF done when the message was created in
  // PyMessagesGetMessageAt.
  // NOLINTNEXTLINE
  Py_DECREF(self->messages);
  // NOLINTNEXTLINE
  Py_TYPE(self)->tp_free(self);
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static PyTypeObject PyMessageType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Message",
|
||||
.tp_basicsize = sizeof(PyMessage),
|
||||
.tp_dealloc = reinterpret_cast<destructor>(PyMessageDealloc),
|
||||
// NOLINTNEXTLINE
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_doc = "Wraps struct mgp_message.",
|
||||
// NOLINTNEXTLINE
|
||||
.tp_methods = PyMessageMethods,
|
||||
};
|
||||
|
||||
/// Clears both wrapped pointers so that later `is_valid` calls return False.
/// Always returns Python `None`.
PyObject *PyMessagesInvalidate(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  self->memory = nullptr;
  self->messages = nullptr;
  Py_RETURN_NONE;
}
|
||||
|
||||
/// Returns the number of messages in the wrapped batch as a Python int, or
/// nullptr with an IndexError set if the int could not be created.
PyObject *PyMessagesGetTotalMessages(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  auto size = self->messages->messages.size();
  auto *py_int = PyLong_FromSize_t(size);
  if (!py_int) {
    // The previous message was copied from the timestamp getter.
    PyErr_SetString(PyExc_IndexError, "Unable to get the total number of messages.");
    return nullptr;
  }
  return py_int;
}
|
||||
|
||||
/// Returns a new `_mgp.Message` wrapping the message at the index given in
/// `args`.
/// Returns nullptr with a Python exception set if the arguments cannot be
/// parsed, the index is negative, no message exists at that index, or the
/// Python object cannot be allocated.
PyObject *PyMessagesGetMessageAt(PyMessages *self, PyObject *args) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  // The "l" format unit writes a C long into `id`.
  static_assert(std::is_same_v<int64_t, long>);
  int64_t id = 0;
  if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
  if (id < 0) {
    // Returning nullptr without an exception set would surface as a
    // SystemError in Python, so report the bad index explicitly.
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  // Look the message up before allocating anything, so the failure path has
  // nothing to release.
  const auto *message = mgp_messages_at(self->messages, id);
  if (!message) {
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  // NOLINTNEXTLINE
  auto *py_message = PyObject_New(PyMessage, &PyMessageType);
  if (!py_message) {
    return nullptr;
  }
  py_message->message = message;
  // The message keeps its batch alive; PyMessageDealloc drops this reference.
  // NOLINTNEXTLINE
  Py_INCREF(self);
  py_message->messages = self;
  py_message->memory = self->memory;
  // NOLINTNEXTLINE
  return reinterpret_cast<PyObject *>(py_message);
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static PyMethodDef PyMessagesMethods[] = {
|
||||
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||
{"invalidate", reinterpret_cast<PyCFunction>(PyMessagesInvalidate), METH_NOARGS,
|
||||
"Invalidate the messages context thus preventing the messages from being used"},
|
||||
{"is_valid", reinterpret_cast<PyCFunction>(PyMessagesIsValid), METH_NOARGS,
|
||||
"Return True if messages is in valid context and may be used."},
|
||||
{"total_messages", reinterpret_cast<PyCFunction>(PyMessagesGetTotalMessages), METH_VARARGS,
|
||||
"Get number of messages available"},
|
||||
{"message_at", reinterpret_cast<PyCFunction>(PyMessagesGetMessageAt), METH_VARARGS,
|
||||
"Get message at index idx from messages"},
|
||||
{nullptr},
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static PyTypeObject PyMessagesType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Messages",
|
||||
.tp_basicsize = sizeof(PyMessages),
|
||||
// NOLINTNEXTLINE
|
||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
.tp_doc = "Wraps struct mgp_messages.",
|
||||
// NOLINTNEXTLINE
|
||||
.tp_methods = PyMessagesMethods,
|
||||
};
|
||||
|
||||
/// Creates a new `_mgp.Messages` object wrapping `msgs` and `memory`.
/// `memory` must be set whenever `msgs` is set.
/// Returns nullptr if the Python object could not be allocated.
PyObject *MakePyMessages(const mgp_messages *msgs, mgp_memory *memory) {
  MG_ASSERT(!msgs || (msgs && memory));
  // NOLINTNEXTLINE
  auto *wrapper = PyObject_New(PyMessages, &PyMessagesType);
  if (!wrapper) {
    return nullptr;
  }
  wrapper->memory = memory;
  wrapper->messages = msgs;
  return reinterpret_cast<PyObject *>(wrapper);
}
|
||||
|
||||
py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
|
||||
MG_ASSERT(list);
|
||||
MG_ASSERT(py_graph);
|
||||
@ -585,6 +768,86 @@ void CallPythonProcedure(const py::Object &py_cb, const mgp_list *args, const mg
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the Python transformation callback `py_cb` with Python wrappers of
/// `graph` and `msgs`, stores the returned record(s) in `result` and, if
/// anything fails, sets the formatted Python error as the error message of
/// `result`. Both wrappers are invalidated before the function returns.
void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs, const mgp_graph *graph,
                              mgp_result *result, mgp_memory *memory) {
  auto gil = py::EnsureGIL();

  auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
    if (!exc_info) return std::nullopt;
    // Here we tell the traceback formatter to skip the first line of the
    // traceback because that line will always be our wrapper function in our
    // internal `mgp.py` file. With that line skipped, the user will always
    // get only the relevant traceback that happened in his Python code.
    return py::FormatException(*exc_info, /* skip_first_line = */ true);
  };

  auto call = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
    // The wrapper registered from `mgp.py` is declared as
    // `wrapper(graph, messages)`, so the graph has to be passed first.
    auto py_res = py_cb.Call(py_graph, py_messages);
    if (!py_res) return py::FetchError();
    if (PySequence_Check(py_res.Ptr())) {
      return AddMultipleRecordsFromPython(result, py_res);
    }
    return AddRecordFromPython(result, py_res);
  };

  auto cleanup = [](py::Object py_graph, py::Object py_messages) {
    // Run `gc.collect` (reference cycle-detection) explicitly, so that we are
    // sure the transformation cleaned up everything it held references to. If
    // the user stored a reference to one of our `_mgp` instances then the
    // internally used `mgp_*` structs will stay unfreed and a memory leak
    // will be reported at the end of the query execution.
    py::Object gc(PyImport_ImportModule("gc"));
    if (!gc) {
      LOG_FATAL(py::FetchError().value());
    }

    if (!gc.CallMethod("collect")) {
      LOG_FATAL(py::FetchError().value());
    }

    // After making sure all references from our side have been cleared,
    // invalidate the `_mgp.Graph` and `_mgp.Messages` objects. If the user
    // kept a reference to one of our `_mgp` instances then this will prevent
    // them from using those objects (whose internal `mgp_*` pointers are now
    // invalid and would cause a crash).
    if (!py_graph.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
    if (!py_messages.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
  };

  // It is *VERY IMPORTANT* to note that this code takes great care not to keep
  // any extra references to any `_mgp` instances (except for `_mgp.Graph`), so
  // as not to introduce extra reference counts and prevent their deallocation.
  // In particular, the `ExceptionInfo` object has a `traceback` field that
  // contains references to the Python frames and their arguments, and therefore
  // our `_mgp` instances as well. Within this code we ensure not to keep the
  // `ExceptionInfo` object alive so that no extra reference counts are
  // introduced. We only fetch the error message and immediately destroy the
  // object.
  std::optional<std::string> maybe_msg;
  {
    py::Object py_graph(MakePyGraph(graph, memory));
    py::Object py_messages(MakePyMessages(msgs, memory));
    if (py_graph && py_messages) {
      try {
        maybe_msg = error_to_msg(call(py_graph, py_messages));
        cleanup(py_graph, py_messages);
      } catch (...) {
        // Invalidate the wrappers even when the call throws, then propagate.
        cleanup(py_graph, py_messages);
        throw;
      }
    } else {
      maybe_msg = error_to_msg(py::FetchError());
    }
  }

  if (maybe_msg) {
    mgp_result_set_error_msg(result, maybe_msg->c_str());
  }
}
|
||||
} // namespace
|
||||
|
||||
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
@ -619,10 +882,41 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
return reinterpret_cast<PyObject *>(py_proc);
|
||||
}
|
||||
|
||||
/// Registers the Python callable `cb` as a transformation of the module held
/// by `self`, under the name taken from `cb.__name__`.
/// Returns Python `None` on success. Returns nullptr with a Python exception
/// set if `cb` is not callable, its name cannot be read or is not a valid
/// identifier, or a transformation with that name is already registered.
PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
  MG_ASSERT(self->module);
  if (!PyCallable_Check(cb)) {
    PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
    return nullptr;
  }
  auto py_cb = py::Object::FromBorrow(cb);
  py::Object py_name(py_cb.GetAttr("__name__"));
  // Not every callable has `__name__`; bail out instead of handing a null
  // pointer to PyUnicode_AsUTF8.
  if (!py_name) return nullptr;
  const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
  if (!name) return nullptr;
  if (!IsValidIdentifierName(name)) {
    PyErr_SetString(PyExc_ValueError, "Transformation name is not a valid identifier");
    return nullptr;
  }
  auto *memory = self->module->transformations.get_allocator().GetMemoryResource();
  // The callback captures `py_cb` by value, keeping the Python callable alive
  // for as long as the transformation is registered.
  mgp_trans trans(
      name,
      [py_cb](const mgp_messages *msgs, const mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
        CallPythonTransformation(py_cb, msgs, graph, result, memory);
      },
      memory);
  const auto [trans_it, did_insert] = self->module->transformations.emplace(name, std::move(trans));
  if (!did_insert) {
    PyErr_SetString(PyExc_ValueError, "Already registered a transformation with the same name.");
    return nullptr;
  }
  Py_RETURN_NONE;
}
|
||||
|
||||
// Method table of the Python query module object; both registration methods
// take exactly one argument (the callable).
static PyMethodDef PyQueryModuleMethods[] = {
    {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
    {"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
     "Register a read-only procedure with this module."},
    {"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
     "Register a transformation with this module."},
    // Sentinel terminating the table.
    {nullptr},
};
|
||||
|
||||
@ -1348,6 +1642,8 @@ PyObject *PyInitMgpModule() {
|
||||
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
|
||||
if (!register_type(&PyPathType, "Path")) return nullptr;
|
||||
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
|
||||
if (!register_type(&PyMessagesType, "Messages")) return nullptr;
|
||||
if (!register_type(&PyMessageType, "Message")) return nullptr;
|
||||
Py_INCREF(Py_None);
|
||||
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
|
||||
Py_DECREF(Py_None);
|
||||
|
@ -5,7 +5,8 @@
|
||||
#include "test_utils.hpp"
|
||||
|
||||
TEST(MgpTransTest, TestMgpTransApi) {
|
||||
constexpr auto no_op_cb = [](const mgp_messages *msg, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {};
|
||||
constexpr auto no_op_cb = [](const mgp_messages *msg, const mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {};
|
||||
mgp_module module(utils::NewDeleteResource());
|
||||
// If this is false, then mgp_module_add_transformation()
|
||||
// correctly calls IsValidIdentifier(). We don't need to test
|
||||
|
Loading…
Reference in New Issue
Block a user