Add python API for messages transformations (#181)
* Add python messages/transformations implementation * Added fixed result return type to transformations * Added is_deprecated to mgp_trans
This commit is contained in:
parent
ac230d0c2d
commit
a928c158da
@ -845,7 +845,8 @@ const struct mgp_message *mgp_messages_at(const struct mgp_messages *, size_t);
|
|||||||
/// Passed in arguments will not live longer than the callback's execution.
|
/// Passed in arguments will not live longer than the callback's execution.
|
||||||
/// Therefore, you must not store them globally or use the passed in mgp_memory
|
/// Therefore, you must not store them globally or use the passed in mgp_memory
|
||||||
/// to allocate global resources.
|
/// to allocate global resources.
|
||||||
typedef void (*mgp_trans_cb)(const struct mgp_messages *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *);
|
typedef void (*mgp_trans_cb)(const struct mgp_messages *, const struct mgp_graph *, struct mgp_result *,
|
||||||
|
struct mgp_memory *);
|
||||||
|
|
||||||
/// Adds a transformation cb to the module pointed by mgp_module.
|
/// Adds a transformation cb to the module pointed by mgp_module.
|
||||||
/// Return non-zero if the transformation is added successfully.
|
/// Return non-zero if the transformation is added successfully.
|
||||||
|
143
include/mgp.py
143
include/mgp.py
@ -714,6 +714,18 @@ class Deprecated:
|
|||||||
self.field_type = type_
|
self.field_type = type_
|
||||||
|
|
||||||
|
|
||||||
|
def raise_if_does_not_meet_requirements(func: typing.Callable[..., Record]):
    '''Check that `func` is a plain synchronous callable.

    Raise TypeError if `func` is not callable, or if it is an `async def`
    function (coroutine function or asynchronous generator function).
    Raise NotImplementedError if `func` is a generator function.
    '''
    if not callable(func):
        raise TypeError("Expected a callable object, got an instance of '{}'"
                        .format(type(func)))
    # `inspect.isasyncgenfunction` only exists from Python 3.6 onwards, so it
    # must not be touched on older interpreters.
    is_async_generator = (sys.version_info >= (3, 6)
                          and inspect.isasyncgenfunction(func))
    if inspect.iscoroutinefunction(func) or is_async_generator:
        raise TypeError("Callable must not be 'async def' function")
    if inspect.isgeneratorfunction(func):
        raise NotImplementedError("Generator functions are not supported")
|
||||||
|
|
||||||
def read_proc(func: typing.Callable[..., Record]):
|
def read_proc(func: typing.Callable[..., Record]):
|
||||||
'''
|
'''
|
||||||
Register `func` as a a read-only procedure of the current module.
|
Register `func` as a a read-only procedure of the current module.
|
||||||
@ -754,16 +766,7 @@ def read_proc(func: typing.Callable[..., Record]):
|
|||||||
CALL example.procedure(1) YIELD args, result;
|
CALL example.procedure(1) YIELD args, result;
|
||||||
Naturally, you may pass in different arguments or yield less fields.
|
Naturally, you may pass in different arguments or yield less fields.
|
||||||
'''
|
'''
|
||||||
if not callable(func):
|
raise_if_does_not_meet_requirements(func)
|
||||||
raise TypeError("Expected a callable object, got an instance of '{}'"
|
|
||||||
.format(type(func)))
|
|
||||||
if inspect.iscoroutinefunction(func):
|
|
||||||
raise TypeError("Callable must not be 'async def' function")
|
|
||||||
if sys.version_info >= (3, 6):
|
|
||||||
if inspect.isasyncgenfunction(func):
|
|
||||||
raise TypeError("Callable must not be 'async def' function")
|
|
||||||
if inspect.isgeneratorfunction(func):
|
|
||||||
raise NotImplementedError("Generator functions are not supported")
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
params = tuple(sig.parameters.values())
|
params = tuple(sig.parameters.values())
|
||||||
if params and params[0].annotation is ProcCtx:
|
if params and params[0].annotation is ProcCtx:
|
||||||
@ -799,3 +802,123 @@ def read_proc(func: typing.Callable[..., Record]):
|
|||||||
else:
|
else:
|
||||||
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
class InvalidMessageError(Exception):
    '''Signals using a message instance outside of the registered transformation.'''
    # No extra state or behaviour; the docstring alone forms the class body.
|
||||||
|
|
||||||
|
class Message:
    '''Represents a message from a stream.'''
    __slots__ = ('_message',)

    def __init__(self, message):
        '''Wrap a `_mgp.Message`; raise TypeError for any other type.'''
        if not isinstance(message, _mgp.Message):
            raise TypeError("Expected '_mgp.Message', got '{}'".format(type(message)))
        self._message = message

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct instead of duplicating it.
        return Message(self._message)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._message.is_valid()

    def payload(self) -> bytes:
        '''Return the raw payload of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        # NOTE(review): the C++ side builds this value with
        # PyByteArray_FromStringAndSize, so a `bytearray` is likely returned
        # rather than `bytes` -- confirm before relying on immutability.
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.payload()

    def topic_name(self) -> str:
        '''Return the name of the topic the message belongs to.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.topic_name()

    def key(self) -> bytes:
        '''Return the raw key of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.key()

    def timestamp(self) -> int:
        '''Return the timestamp of the message.

        Raise InvalidMessageError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessageError()
        return self._message.timestamp()
|
||||||
|
|
||||||
|
class InvalidMessagesError(Exception):
    '''Signals using a messages instance outside of the registered transformation.'''
    # No extra state or behaviour; the docstring alone forms the class body.
|
||||||
|
|
||||||
|
class Messages:
    '''Represents a list of messages from a stream.'''
    __slots__ = ('_messages',)

    def __init__(self, messages):
        '''Wrap a `_mgp.Messages`; raise TypeError for any other type.'''
        if not isinstance(messages, _mgp.Messages):
            raise TypeError("Expected '_mgp.Messages', got '{}'".format(type(messages)))
        self._messages = messages

    def __deepcopy__(self, memo):
        # This is the same as the shallow copy, because we want to share the
        # underlying C struct. Besides, it doesn't make much sense to actually
        # copy _mgp.Messages as that always references all the messages.
        return Messages(self._messages)

    def is_valid(self) -> bool:
        '''Return True if `self` is in valid context and may be used.'''
        return self._messages.is_valid()

    def message_at(self, id : int) -> Message:
        '''Return the message stored at index `id`.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return Message(self._messages.message_at(id))

    def total_messages(self) -> int:
        '''Return the number of messages available.

        Raise InvalidMessagesError if context is invalid.
        '''
        if not self.is_valid():
            raise InvalidMessagesError()
        return self._messages.total_messages()
|
||||||
|
|
||||||
|
class TransCtx:
    '''Context of a transformation being executed.

    Access to a TransCtx is only valid during a single execution of a transformation.
    You should not globally store a TransCtx instance.
    '''
    # A one-element tuple (note the trailing comma), consistent with the other
    # wrapper classes; `('_graph')` without the comma is just a plain string.
    __slots__ = ('_graph',)

    def __init__(self, graph):
        '''Wrap a `_mgp.Graph`; raise TypeError for any other type.'''
        if not isinstance(graph, _mgp.Graph):
            raise TypeError("Expected '_mgp.Graph', got '{}'".format(type(graph)))
        self._graph = Graph(graph)

    def is_valid(self) -> bool:
        '''Return True if the wrapped graph is still in a valid context.'''
        return self._graph.is_valid()

    @property
    def graph(self) -> Graph:
        '''Raise InvalidContextError if context is invalid.'''
        if not self.is_valid():
            raise InvalidContextError()
        return self._graph
|
||||||
|
|
||||||
|
def transformation(func: typing.Callable[..., Record]):
    '''Register `func` as a transformation of the current module.

    `func` must have one of the signatures `(TransCtx, Messages)` or
    `(Messages)`, recognised through the parameter annotations. The same
    restrictions as for procedures apply: `func` must be a plain callable,
    not an `async def` or generator function.

    Raise NotImplementedError if the signature is not one of the supported
    ones. Return `func` unchanged so this can be used as a decorator.
    '''
    raise_if_does_not_meet_requirements(func)
    sig = inspect.signature(func)
    params = tuple(sig.parameters.values())

    takes_ctx = (len(params) == 2
                 and params[0].annotation is TransCtx
                 and params[1].annotation is Messages)
    takes_only_messages = bool(params) and params[0].annotation is Messages
    # Reject e.g. `(Foo, Messages)`: it used to pass validation and then
    # failed on every call because only `messages` was passed to `func`.
    if not (takes_ctx or takes_only_messages):
        raise NotImplementedError("Valid signatures for transformations are (TransCtx, Messages) or (Messages)")

    # NOTE(review): the wrapper expects the native caller to pass the graph
    # first and the messages second -- verify against the C++ call site.
    # NOTE(review): `messages` is forwarded as received, it is not wrapped in
    # `Messages` the way the graph is wrapped in `TransCtx` -- confirm intent.
    if takes_ctx:
        @functools.wraps(func)
        def wrapper(graph, messages):
            return func(TransCtx(graph), messages)
    else:
        @functools.wraps(func)
        def wrapper(graph, messages):
            return func(messages)
    _mgp._MODULE.add_transformation(wrapper)
    return func
|
||||||
|
@ -3734,6 +3734,7 @@ class CallProcedureCursor : public Cursor {
|
|||||||
mgp_graph graph{context.db_accessor, graph_view, &context};
|
mgp_graph graph{context.db_accessor, graph_view, &context};
|
||||||
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
||||||
&result_);
|
&result_);
|
||||||
|
|
||||||
// Reset result_.signature to nullptr, because outside of this scope we
|
// Reset result_.signature to nullptr, because outside of this scope we
|
||||||
// will no longer hold a lock on the `module`. If someone were to reload
|
// will no longer hold a lock on the `module`. If someone were to reload
|
||||||
// it, the pointer would be invalid.
|
// it, the pointer would be invalid.
|
||||||
|
@ -8,10 +8,12 @@
|
|||||||
|
|
||||||
#include "module.hpp"
|
#include "module.hpp"
|
||||||
#include "utils/algorithm.hpp"
|
#include "utils/algorithm.hpp"
|
||||||
|
#include "utils/concepts.hpp"
|
||||||
#include "utils/logging.hpp"
|
#include "utils/logging.hpp"
|
||||||
#include "utils/math.hpp"
|
#include "utils/math.hpp"
|
||||||
#include "utils/memory.hpp"
|
#include "utils/memory.hpp"
|
||||||
#include "utils/string.hpp"
|
#include "utils/string.hpp"
|
||||||
|
|
||||||
// This file contains implementation of top level C API functions, but this is
|
// This file contains implementation of top level C API functions, but this is
|
||||||
// all actually part of query::procedure. So use that namespace for simplicity.
|
// all actually part of query::procedure. So use that namespace for simplicity.
|
||||||
// NOLINTNEXTLINE(google-build-using-namespace)
|
// NOLINTNEXTLINE(google-build-using-namespace)
|
||||||
@ -1343,27 +1345,38 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type,
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, bool is_deprecated) {
|
template <typename T>
|
||||||
if (!proc || !type) return 0;
|
concept ModuleProperties = utils::SameAsAnyOf<T, mgp_proc, mgp_trans>;
|
||||||
if (!IsValidIdentifierName(name)) return 0;
|
|
||||||
if (proc->results.find(name) != proc->results.end()) return 0;
|
template <ModuleProperties T>
|
||||||
|
bool AddResultToProp(T *prop, const char *name, const mgp_type *type, bool is_deprecated) {
|
||||||
|
if (!prop || !type) return false;
|
||||||
|
if (!IsValidIdentifierName(name)) return false;
|
||||||
|
if (prop->results.find(name) != prop->results.end()) return false;
|
||||||
try {
|
try {
|
||||||
auto *memory = proc->results.get_allocator().GetMemoryResource();
|
auto *memory = prop->results.get_allocator().GetMemoryResource();
|
||||||
proc->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
|
prop->results.emplace(utils::pmr::string(name, memory), std::make_pair(type->impl.get(), is_deprecated));
|
||||||
return 1;
|
return true;
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
return 0;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
int mgp_proc_add_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||||
return AddResultToProc(proc, name, type, false);
|
return AddResultToProp(proc, name, type, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MgpTransAddFixedResult(mgp_trans *trans) {
|
||||||
|
if (int err = AddResultToProp(trans, "query", mgp_type_string(), false); err != 1) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
return AddResultToProp(trans, "parameters", mgp_type_nullable(mgp_type_list(mgp_type_any())), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, const mgp_type *type) {
|
||||||
return AddResultToProc(proc, name, type, true);
|
return AddResultToProp(proc, name, type, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mgp_must_abort(const mgp_graph *graph) {
|
int mgp_must_abort(const mgp_graph *graph) {
|
||||||
|
@ -478,21 +478,23 @@ struct mgp_trans {
|
|||||||
|
|
||||||
/// @throw std::bad_alloc
|
/// @throw std::bad_alloc
|
||||||
/// @throw std::length_error
|
/// @throw std::length_error
|
||||||
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory) : name(name, memory), cb(cb) {}
|
mgp_trans(const char *name, mgp_trans_cb cb, utils::MemoryResource *memory)
|
||||||
|
: name(name, memory), cb(cb), results(memory) {}
|
||||||
|
|
||||||
/// @throw std::bad_alloc
|
/// @throw std::bad_alloc
|
||||||
/// @throw std::length_error
|
/// @throw std::length_error
|
||||||
mgp_trans(const char *name,
|
mgp_trans(const char *name,
|
||||||
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
||||||
utils::MemoryResource *memory)
|
utils::MemoryResource *memory)
|
||||||
: name(name, memory), cb(cb) {}
|
: name(name, memory), cb(cb), results(memory) {}
|
||||||
|
|
||||||
/// @throw std::bad_alloc
|
/// @throw std::bad_alloc
|
||||||
/// @throw std::length_error
|
/// @throw std::length_error
|
||||||
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory) : name(other.name, memory), cb(other.cb) {}
|
mgp_trans(const mgp_trans &other, utils::MemoryResource *memory)
|
||||||
|
: name(other.name, memory), cb(other.cb), results(other.results) {}
|
||||||
|
|
||||||
mgp_trans(mgp_trans &&other, utils::MemoryResource *memory)
|
mgp_trans(mgp_trans &&other, utils::MemoryResource *memory)
|
||||||
: name(std::move(other.name), memory), cb(std::move(other.cb)) {}
|
: name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results)) {}
|
||||||
|
|
||||||
mgp_trans(const mgp_trans &other) = default;
|
mgp_trans(const mgp_trans &other) = default;
|
||||||
mgp_trans(mgp_trans &&other) = default;
|
mgp_trans(mgp_trans &&other) = default;
|
||||||
@ -505,9 +507,13 @@ struct mgp_trans {
|
|||||||
/// Name of the transformation.
|
/// Name of the transformation.
|
||||||
utils::pmr::string name;
|
utils::pmr::string name;
|
||||||
/// Entry-point for the transformation.
|
/// Entry-point for the transformation.
|
||||||
std::function<void(const mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
std::function<void(const mgp_messages *, const mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
||||||
|
/// Fields this transformation returns.
|
||||||
|
utils::pmr::map<utils::pmr::string, std::pair<const query::procedure::CypherType *, bool>> results;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool MgpTransAddFixedResult(mgp_trans *trans);
|
||||||
|
|
||||||
struct mgp_module {
|
struct mgp_module {
|
||||||
using allocator_type = utils::Allocator<mgp_module>;
|
using allocator_type = utils::Allocator<mgp_module>;
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ extern "C" {
|
|||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
#include "fmt/format.h"
|
||||||
#include "py/py.hpp"
|
#include "py/py.hpp"
|
||||||
#include "query/procedure/py_module.hpp"
|
#include "query/procedure/py_module.hpp"
|
||||||
#include "utils/file.hpp"
|
#include "utils/file.hpp"
|
||||||
@ -275,30 +276,44 @@ bool SharedLibraryModule::Load(const std::filesystem::path &file_path) {
|
|||||||
}
|
}
|
||||||
// Get required mgp_init_module
|
// Get required mgp_init_module
|
||||||
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
|
init_fn_ = reinterpret_cast<int (*)(mgp_module *, mgp_memory *)>(dlsym(handle_, "mgp_init_module"));
|
||||||
const char *error = dlerror();
|
char *dl_errored = dlerror();
|
||||||
if (!init_fn_ || error) {
|
if (!init_fn_ || dl_errored) {
|
||||||
spdlog::error("Unable to load module {}; {}", file_path, error);
|
spdlog::error("Unable to load module {}; {}", file_path, dl_errored);
|
||||||
dlclose(handle_);
|
dlclose(handle_);
|
||||||
handle_ = nullptr;
|
handle_ = nullptr;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
|
auto module_cb = [&](auto *module_def, auto *memory) {
|
||||||
// Run mgp_init_module which must succeed.
|
// Run mgp_init_module which must succeed.
|
||||||
int init_res = init_fn_(module_def, memory);
|
int init_res = init_fn_(module_def, memory);
|
||||||
if (init_res != 0) {
|
auto with_error = [this](std::string_view error_msg) {
|
||||||
spdlog::error("Unable to load module {}; mgp_init_module_returned {}", file_path, init_res);
|
spdlog::error(error_msg);
|
||||||
dlclose(handle_);
|
dlclose(handle_);
|
||||||
handle_ = nullptr;
|
handle_ = nullptr;
|
||||||
return false;
|
return false;
|
||||||
}
|
};
|
||||||
return true;
|
|
||||||
})) {
|
if (init_res != 0) {
|
||||||
|
const auto error = fmt::format("Unable to load module {}; mgp_init_module_returned {} ", file_path, init_res);
|
||||||
|
return with_error(error);
|
||||||
|
}
|
||||||
|
for (auto &trans : module_def->transformations) {
|
||||||
|
const bool was_result_added = MgpTransAddFixedResult(&trans.second);
|
||||||
|
if (!was_result_added) {
|
||||||
|
const auto error =
|
||||||
|
fmt::format("Unable to add result to transformation in module {}; add result failed", file_path);
|
||||||
|
return with_error(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
if (!WithModuleRegistration(&procedures_, &transformations_, module_cb)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Get optional mgp_shutdown_module
|
// Get optional mgp_shutdown_module
|
||||||
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
|
shutdown_fn_ = reinterpret_cast<int (*)()>(dlsym(handle_, "mgp_shutdown_module"));
|
||||||
error = dlerror();
|
dl_errored = dlerror();
|
||||||
if (error) spdlog::warn("When loading module {}; {}", file_path, error);
|
if (dl_errored) spdlog::warn("When loading module {}; {}", file_path, dl_errored);
|
||||||
spdlog::info("Loaded module {}", file_path);
|
spdlog::info("Loaded module {}", file_path);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -376,11 +391,23 @@ bool PythonModule::Load(const std::filesystem::path &file_path) {
|
|||||||
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
|
spdlog::error("Unable to load module {}; {}", file_path, *maybe_exc);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
py_module_ = WithModuleRegistration(&procedures_, &transformations_, [&](auto *module_def, auto *memory) {
|
bool succ = true;
|
||||||
return ImportPyModule(file_path.stem().c_str(), module_def);
|
auto module_cb = [&](auto *module_def, auto *memory) {
|
||||||
});
|
auto result = ImportPyModule(file_path.stem().c_str(), module_def);
|
||||||
|
for (auto &trans : module_def->transformations) {
|
||||||
|
succ = MgpTransAddFixedResult(&trans.second);
|
||||||
|
if (!succ) return result;
|
||||||
|
};
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
py_module_ = WithModuleRegistration(&procedures_, &transformations_, module_cb);
|
||||||
if (py_module_) {
|
if (py_module_) {
|
||||||
spdlog::info("Loaded module {}", file_path);
|
spdlog::info("Loaded module {}", file_path);
|
||||||
|
|
||||||
|
if (!succ) {
|
||||||
|
spdlog::error("Unable to add result to transformation");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto exc_info = py::FetchError().value();
|
auto exc_info = py::FetchError().value();
|
||||||
@ -566,6 +593,7 @@ std::optional<std::pair<std::string_view, std::string_view>> FindModuleNameAndPr
|
|||||||
if (name_parts.size() == 1U) return std::nullopt;
|
if (name_parts.size() == 1U) return std::nullopt;
|
||||||
auto last_dot_pos = fully_qualified_name.find_last_of('.');
|
auto last_dot_pos = fully_qualified_name.find_last_of('.');
|
||||||
MG_ASSERT(last_dot_pos != std::string_view::npos);
|
MG_ASSERT(last_dot_pos != std::string_view::npos);
|
||||||
|
|
||||||
const auto &module_name = fully_qualified_name.substr(0, last_dot_pos);
|
const auto &module_name = fully_qualified_name.substr(0, last_dot_pos);
|
||||||
const auto &name = name_parts.back();
|
const auto &name = name_parts.back();
|
||||||
return std::make_pair(module_name, name);
|
return std::make_pair(module_name, name);
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "query/procedure/mg_procedure_impl.hpp"
|
#include "query/procedure/mg_procedure_impl.hpp"
|
||||||
|
#include "utils/pmr/vector.hpp"
|
||||||
|
|
||||||
namespace query::procedure {
|
namespace query::procedure {
|
||||||
|
|
||||||
@ -178,7 +179,7 @@ PyObject *PyGraphGetVertexById(PyGraph *self, PyObject *args) {
|
|||||||
MG_ASSERT(self->graph);
|
MG_ASSERT(self->graph);
|
||||||
MG_ASSERT(self->memory);
|
MG_ASSERT(self->memory);
|
||||||
static_assert(std::is_same_v<int64_t, long>);
|
static_assert(std::is_same_v<int64_t, long>);
|
||||||
int64_t id;
|
int64_t id = 0;
|
||||||
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
|
if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
|
||||||
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
|
auto *vertex = mgp_graph_get_vertex_by_id(self->graph, mgp_vertex_id{id}, self->memory);
|
||||||
if (!vertex) {
|
if (!vertex) {
|
||||||
@ -400,6 +401,188 @@ struct PyQueryModule {
|
|||||||
};
|
};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
// Python object wrapping a batch of stream messages (`mgp_messages`).
struct PyMessages {
|
||||||
|
PyObject_HEAD;
|
||||||
|
// Wrapped batch; set to nullptr by PyMessagesInvalidate, and checked against
// nullptr by PyMessagesIsValid.
const mgp_messages *messages;
|
||||||
|
// Memory handle given to MakePyMessages; also set to nullptr on invalidation.
mgp_memory *memory;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Python object wrapping a single stream message (`mgp_message`).
struct PyMessage {
|
||||||
|
PyObject_HEAD;
|
||||||
|
// Wrapped message, obtained through mgp_messages_at.
const mgp_message *message;
|
||||||
|
// Batch this message was taken from. PyMessagesGetMessageAt takes a
// reference on it (Py_INCREF) and PyMessageDealloc releases it (Py_DECREF).
const PyMessages *messages;
|
||||||
|
// Copied from the owning PyMessages when the message object is created.
mgp_memory *memory;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Return Python `True` while the wrapped `mgp_messages` pointer is set and
/// `False` once it has been cleared.
PyObject *PyMessagesIsValid(const PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  const bool has_messages = self->messages != nullptr;
  return PyBool_FromLong(has_messages ? 1 : 0);
}
|
||||||
|
|
||||||
|
/// A message is reported valid exactly when the batch it was taken from is
/// still valid.
PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  const PyMessages *owner = self->messages;
  return PyMessagesIsValid(owner, nullptr);
}
|
||||||
|
|
||||||
|
/// Return the payload of the wrapped message as a new Python `bytearray`.
/// On failure sets a Python RuntimeError and returns nullptr.
PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  const auto size = mgp_message_get_payload_size(self->message);
  const auto *data = mgp_message_get_payload(self->message);
  if (auto *py_payload = PyByteArray_FromStringAndSize(data, size)) {
    return py_payload;
  }
  PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload");
  return nullptr;
}
|
||||||
|
|
||||||
|
/// Return the topic name of the wrapped message as a new Python `str`.
/// On failure sets a Python RuntimeError and returns nullptr.
PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  const auto *topic_name = mgp_message_topic_name(self->message);
  auto *py_topic_name = PyUnicode_FromString(topic_name);
  if (!py_topic_name) {
    // The message used to talk about the payload (copy-paste slip); name the
    // value that actually failed to convert.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get topic name");
    return nullptr;
  }
  return py_topic_name;
}
|
||||||
|
|
||||||
|
/// Return the key of the wrapped message as a new Python `bytearray`.
/// On failure sets a Python RuntimeError and returns nullptr.
PyObject *PyMessageGetKey(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  auto key_size = mgp_message_key_size(self->message);
  const auto *key = mgp_message_key(self->message);
  auto *raw_bytes = PyByteArray_FromStringAndSize(key, key_size);
  if (!raw_bytes) {
    // The message used to talk about the payload (copy-paste slip); name the
    // value that actually failed to convert.
    PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from key");
    return nullptr;
  }
  return raw_bytes;
}
|
||||||
|
|
||||||
|
/// Return the timestamp of the wrapped message as a new Python `int`.
/// On failure sets a Python IndexError and returns nullptr.
PyObject *PyMessageGetTimestamp(PyMessage *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->message);
  MG_ASSERT(self->memory);
  // NOTE(review): the value is converted as an unsigned long -- confirm this
  // matches the return type of mgp_message_timestamp.
  const auto timestamp = mgp_message_timestamp(self->message);
  if (auto *py_timestamp = PyLong_FromUnsignedLong(timestamp)) {
    return py_timestamp;
  }
  PyErr_SetString(PyExc_IndexError, "Unable to get timestamp.");
  return nullptr;
}
|
||||||
|
|
||||||
|
// Method table of the `_mgp.Message` Python type. Every entry is registered
// with METH_NOARGS; the table is terminated by the `{nullptr}` sentinel.
// NOLINTNEXTLINE
|
||||||
|
static PyMethodDef PyMessageMethods[] = {
|
||||||
|
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||||
|
{"is_valid", reinterpret_cast<PyCFunction>(PyMessageIsValid), METH_NOARGS,
|
||||||
|
"Return True if messages is in valid context and may be used."},
|
||||||
|
{"payload", reinterpret_cast<PyCFunction>(PyMessageGetPayload), METH_NOARGS, "Get payload"},
|
||||||
|
{"topic_name", reinterpret_cast<PyCFunction>(PyMessageGetTopicName), METH_NOARGS, "Get topic name."},
|
||||||
|
{"key", reinterpret_cast<PyCFunction>(PyMessageGetKey), METH_NOARGS, "Get message key."},
|
||||||
|
{"timestamp", reinterpret_cast<PyCFunction>(PyMessageGetTimestamp), METH_NOARGS, "Get message timestamp."},
|
||||||
|
{nullptr},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Deallocator of `_mgp.Message`: drops the reference held on the owning
/// `PyMessages` object and then frees the message object itself.
void PyMessageDealloc(PyMessage *self) {
  MG_ASSERT(self->memory);
  MG_ASSERT(self->message);
  MG_ASSERT(self->messages);
  // Release the reference on the owning batch before freeing this object.
  // NOLINTNEXTLINE
  Py_DECREF(self->messages);
  auto *type = Py_TYPE(self);
  // NOLINTNEXTLINE
  type->tp_free(self);
}
|
||||||
|
|
||||||
|
// Type object of `_mgp.Message`, wrapping `struct mgp_message`. Instances are
// created from C++ with PyObject_New and torn down by PyMessageDealloc.
// NOLINTNEXTLINE
|
||||||
|
static PyTypeObject PyMessageType = {
|
||||||
|
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Message",
|
||||||
|
.tp_basicsize = sizeof(PyMessage),
|
||||||
|
.tp_dealloc = reinterpret_cast<destructor>(PyMessageDealloc),
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||||
|
.tp_doc = "Wraps struct mgp_message.",
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
.tp_methods = PyMessageMethods,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Clear both wrapped pointers so that `is_valid` reports False from now on.
/// Always returns Python `None`.
PyObject *PyMessagesInvalidate(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  self->memory = nullptr;
  self->messages = nullptr;
  Py_RETURN_NONE;
}
|
||||||
|
|
||||||
|
/// Return the number of messages in the wrapped batch as a new Python `int`.
/// On failure sets a Python IndexError and returns nullptr.
PyObject *PyMessagesGetTotalMessages(PyMessages *self, PyObject *Py_UNUSED(ignored)) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  auto size = self->messages->messages.size();
  auto *py_int = PyLong_FromSize_t(size);
  if (!py_int) {
    // The message used to talk about a timestamp (copy-paste slip); name the
    // value that actually failed to convert.
    PyErr_SetString(PyExc_IndexError, "Unable to get the total number of messages.");
    return nullptr;
  }
  return py_int;
}
|
||||||
|
|
||||||
|
/// Return a new `_mgp.Message` wrapping the message at the index given as the
/// single integer argument in `args`.
///
/// The returned object holds a strong reference to `self`, which is released
/// in PyMessageDealloc. On a negative or out-of-range index a Python
/// IndexError is set and nullptr is returned.
PyObject *PyMessagesGetMessageAt(PyMessages *self, PyObject *args) {
  MG_ASSERT(self->messages);
  MG_ASSERT(self->memory);
  int64_t id = 0;
  if (!PyArg_ParseTuple(args, "l", &id)) return nullptr;
  // Returning nullptr without an exception set makes the interpreter raise a
  // confusing SystemError, so report a negative index explicitly.
  if (id < 0) {
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  // Look the message up *before* allocating the wrapper: bailing out after
  // the allocation would leak both the wrapper and the reference on `self`.
  const auto *message = mgp_messages_at(self->messages, static_cast<size_t>(id));
  if (!message) {
    PyErr_SetString(PyExc_IndexError, "Unable to find the message with given index.");
    return nullptr;
  }
  // NOLINTNEXTLINE
  auto *py_message = PyObject_New(PyMessage, &PyMessageType);
  if (!py_message) {
    return nullptr;
  }
  py_message->message = message;
  // The message keeps its batch alive for as long as the message exists.
  // NOLINTNEXTLINE
  Py_INCREF(self);
  py_message->messages = self;
  py_message->memory = self->memory;
  // NOLINTNEXTLINE
  return reinterpret_cast<PyObject *>(py_message);
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static PyMethodDef PyMessagesMethods[] = {
|
||||||
|
{"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},
|
||||||
|
{"invalidate", reinterpret_cast<PyCFunction>(PyMessagesInvalidate), METH_NOARGS,
|
||||||
|
"Invalidate the messages context thus preventing the messages from being used"},
|
||||||
|
{"is_valid", reinterpret_cast<PyCFunction>(PyMessagesIsValid), METH_NOARGS,
|
||||||
|
"Return True if messages is in valid context and may be used."},
|
||||||
|
{"total_messages", reinterpret_cast<PyCFunction>(PyMessagesGetTotalMessages), METH_VARARGS,
|
||||||
|
"Get number of messages available"},
|
||||||
|
{"message_at", reinterpret_cast<PyCFunction>(PyMessagesGetMessageAt), METH_VARARGS,
|
||||||
|
"Get message at index idx from messages"},
|
||||||
|
{nullptr},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Type object of `_mgp.Messages`, wrapping `struct mgp_messages`. Instances
// are created from C++ with PyObject_New; no custom tp_dealloc is set here.
// NOLINTNEXTLINE
|
||||||
|
static PyTypeObject PyMessagesType = {
|
||||||
|
PyVarObject_HEAD_INIT(nullptr, 0).tp_name = "_mgp.Messages",
|
||||||
|
.tp_basicsize = sizeof(PyMessages),
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
.tp_flags = Py_TPFLAGS_DEFAULT,
|
||||||
|
.tp_doc = "Wraps struct mgp_messages.",
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
.tp_methods = PyMessagesMethods,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Create a new `_mgp.Messages` object wrapping `msgs` and `memory`.
/// `memory` must be non-null whenever `msgs` is non-null. Returns nullptr if
/// the Python object cannot be allocated.
PyObject *MakePyMessages(const mgp_messages *msgs, mgp_memory *memory) {
  MG_ASSERT(!msgs || (msgs && memory));
  // NOLINTNEXTLINE
  auto *wrapper = PyObject_New(PyMessages, &PyMessagesType);
  if (wrapper == nullptr) {
    return nullptr;
  }
  wrapper->memory = memory;
  wrapper->messages = msgs;
  return reinterpret_cast<PyObject *>(wrapper);
}
|
||||||
|
|
||||||
py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
|
py::Object MgpListToPyTuple(const mgp_list *list, PyGraph *py_graph) {
|
||||||
MG_ASSERT(list);
|
MG_ASSERT(list);
|
||||||
MG_ASSERT(py_graph);
|
MG_ASSERT(py_graph);
|
||||||
@ -585,6 +768,86 @@ void CallPythonProcedure(const py::Object &py_cb, const mgp_list *args, const mg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Runs the Python transformation callback `py_cb` over `msgs` and stores the
/// outcome in `result`.
///
/// The GIL is held for the whole call. `msgs` and `graph` are wrapped into
/// Python objects (created with `memory`) and passed to the callback as
/// `(messages, graph)`. A sequence return value is added as multiple records,
/// anything else as a single record. If a Python error occurs, its formatted
/// text is set as the error message of `result`. Once both wrappers were
/// created, they are invalidated before this function is left, also when a
/// C++ exception passes through.
void CallPythonTransformation(const py::Object &py_cb, const mgp_messages *msgs, const mgp_graph *graph,
                              mgp_result *result, mgp_memory *memory) {
  auto gil = py::EnsureGIL();

  // Converts an optional Python error into an optional formatted message.
  auto format_error = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
    if (exc_info) {
      // The traceback formatter is told to skip the first line of the
      // traceback, because that line is always the wrapper function from the
      // internal `mgp.py` file. That way the user only sees the part of the
      // traceback that belongs to their own Python code.
      return py::FormatException(*exc_info, /* skip_first_line = */ true);
    }
    return std::nullopt;
  };

  // Calls the user callback and copies whatever it returned into `result`.
  auto invoke = [&](py::Object py_graph, py::Object py_messages) -> std::optional<py::ExceptionInfo> {
    auto py_res = py_cb.Call(py_messages, py_graph);
    if (!py_res) return py::FetchError();
    if (PySequence_Check(py_res.Ptr())) {
      return AddMultipleRecordsFromPython(result, py_res);
    } else {
      return AddRecordFromPython(result, py_res);
    }
  };

  // Collects garbage and invalidates both wrapper objects.
  auto invalidate_wrappers = [](py::Object py_graph, py::Object py_messages) {
    // `gc.collect` (reference cycle detection) is run explicitly to make sure
    // the callback released everything it held references to. A reference to
    // one of our `_mgp` instances that the user stored somewhere keeps the
    // internally used `mgp_*` structs unfreed, and a memory leak is then
    // reported at the end of the query execution.
    py::Object gc_module(PyImport_ImportModule("gc"));
    if (!gc_module) {
      LOG_FATAL(py::FetchError().value());
    }

    if (!gc_module.CallMethod("collect")) {
      LOG_FATAL(py::FetchError().value());
    }

    // With all references from our side cleared, invalidate the wrappers. If
    // the user kept a reference to one of our `_mgp` instances, this stops
    // them from using those objects, whose internal `mgp_*` pointers are no
    // longer valid and would cause a crash.
    if (!py_graph.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
    if (!py_messages.CallMethod("invalidate")) {
      LOG_FATAL(py::FetchError().value());
    }
  };

  // *VERY IMPORTANT*: no extra references to `_mgp` instances (other than the
  // wrappers created below) may be kept alive here, since additional reference
  // counts would prevent their deallocation. The `ExceptionInfo` object is the
  // main hazard: its `traceback` field references the Python frames and their
  // arguments, and through them our `_mgp` instances. For that reason the
  // `ExceptionInfo` only ever exists as a temporary; the error message is
  // extracted from it and the object is destroyed right away.
  std::optional<std::string> error_msg;
  {
    py::Object py_graph(MakePyGraph(graph, memory));
    py::Object py_messages(MakePyMessages(msgs, memory));
    if (!py_graph || !py_messages) {
      error_msg = format_error(py::FetchError());
    } else {
      try {
        error_msg = format_error(invoke(py_graph, py_messages));
        invalidate_wrappers(py_graph, py_messages);
      } catch (...) {
        invalidate_wrappers(py_graph, py_messages);
        throw;
      }
    }
  }

  if (error_msg.has_value()) {
    mgp_result_set_error_msg(result, error_msg->c_str());
  }
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||||
@ -619,10 +882,41 @@ PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
|||||||
return reinterpret_cast<PyObject *>(py_proc);
|
return reinterpret_cast<PyObject *>(py_proc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Registers the Python callable `cb` as a transformation of the module wrapped
/// by `self`.
///
/// The transformation is stored under the callable's `__name__`. Returns None
/// on success. On failure a Python exception is set and nullptr is returned:
///  - TypeError if `cb` is not callable,
///  - an attribute lookup error if `cb` has no `__name__`,
///  - ValueError if the name is not a valid identifier or a transformation
///    with the same name is already registered.
PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
  MG_ASSERT(self->module);
  if (!PyCallable_Check(cb)) {
    PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
    return nullptr;
  }
  auto py_cb = py::Object::FromBorrow(cb);
  py::Object py_name(py_cb.GetAttr("__name__"));
  if (!py_name) {
    // Not every callable carries `__name__` (e.g. `functools.partial`
    // objects). `PyUnicode_AsUTF8` must not be handed a null pointer, so stop
    // here and report the lookup failure to the caller.
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_AttributeError, "Callable object has no attribute '__name__'.");
    }
    return nullptr;
  }
  // `name` points into `py_name`, which stays alive until the end of this
  // function.
  const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
  if (!name) return nullptr;
  if (!IsValidIdentifierName(name)) {
    PyErr_SetString(PyExc_ValueError, "Transformation name is not a valid identifier");
    return nullptr;
  }
  auto *memory = self->module->transformations.get_allocator().GetMemoryResource();
  // The lambda keeps its own copy of `py_cb`. Its last parameter is named
  // differently from the local `memory` above to avoid shadowing it.
  mgp_trans trans(
      name,
      [py_cb](const mgp_messages *msgs, const mgp_graph *graph, mgp_result *result, mgp_memory *cb_memory) {
        CallPythonTransformation(py_cb, msgs, graph, result, cb_memory);
      },
      memory);
  const bool did_insert = self->module->transformations.emplace(name, std::move(trans)).second;
  if (!did_insert) {
    PyErr_SetString(PyExc_ValueError, "Already registered a transformation with the same name.");
    return nullptr;
  }
  Py_RETURN_NONE;
}
|
||||||
|
|
||||||
// Table of Python-callable methods implemented by the PyQueryModule functions.
// Each entry is {name, C function, calling convention flags, docstring}; the
// trailing `{nullptr}` entry terminates the table.
static PyMethodDef PyQueryModuleMethods[] = {
    // Takes no arguments (METH_NOARGS); handled by DisallowPickleAndCopy.
    {"__reduce__", reinterpret_cast<PyCFunction>(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"},

    // Registration methods take exactly one object argument (METH_O).
    {"add_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddReadProcedure), METH_O,
     "Register a read-only procedure with this module."},
    {"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
     "Register a transformation with this module."},

    {nullptr},
};
|
||||||
|
|
||||||
@ -1348,6 +1642,8 @@ PyObject *PyInitMgpModule() {
|
|||||||
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
|
if (!register_type(&PyVertexType, "Vertex")) return nullptr;
|
||||||
if (!register_type(&PyPathType, "Path")) return nullptr;
|
if (!register_type(&PyPathType, "Path")) return nullptr;
|
||||||
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
|
if (!register_type(&PyCypherTypeType, "Type")) return nullptr;
|
||||||
|
if (!register_type(&PyMessagesType, "Messages")) return nullptr;
|
||||||
|
if (!register_type(&PyMessageType, "Message")) return nullptr;
|
||||||
Py_INCREF(Py_None);
|
Py_INCREF(Py_None);
|
||||||
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
|
if (PyModule_AddObject(mgp, "_MODULE", Py_None) < 0) {
|
||||||
Py_DECREF(Py_None);
|
Py_DECREF(Py_None);
|
||||||
|
@ -5,7 +5,8 @@
|
|||||||
#include "test_utils.hpp"
|
#include "test_utils.hpp"
|
||||||
|
|
||||||
TEST(MgpTransTest, TestMgpTransApi) {
|
TEST(MgpTransTest, TestMgpTransApi) {
|
||||||
constexpr auto no_op_cb = [](const mgp_messages *msg, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {};
|
constexpr auto no_op_cb = [](const mgp_messages *msg, const mgp_graph *graph, mgp_result *result,
|
||||||
|
mgp_memory *memory) {};
|
||||||
mgp_module module(utils::NewDeleteResource());
|
mgp_module module(utils::NewDeleteResource());
|
||||||
// If this is false, then mgp_module_add_transformation()
|
// If this is false, then mgp_module_add_transformation()
|
||||||
// correctly calls IsValidIdentifier(). We don't need to test
|
// correctly calls IsValidIdentifier(). We don't need to test
|
||||||
|
Loading…
Reference in New Issue
Block a user