Add python & cpp batching option in procedures
* Add API for batching from the procedure * Use PoolResource for batched procedures
This commit is contained in:
parent
00226dee24
commit
d573eda8bb
@ -677,6 +677,16 @@ inline mgp_proc *module_add_write_procedure(mgp_module *module, const char *name
|
||||
return MgInvoke<mgp_proc *>(mgp_module_add_write_procedure, module, name, cb);
|
||||
}
|
||||
|
||||
inline mgp_proc *module_add_batch_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup) {
|
||||
return MgInvoke<mgp_proc *>(mgp_module_add_batch_read_procedure, module, name, cb, initializer, cleanup);
|
||||
}
|
||||
|
||||
inline mgp_proc *module_add_batch_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup) {
|
||||
return MgInvoke<mgp_proc *>(mgp_module_add_batch_write_procedure, module, name, cb, initializer, cleanup);
|
||||
}
|
||||
|
||||
inline void proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
|
||||
MgInvokeVoid(mgp_proc_add_arg, proc, name, type);
|
||||
}
|
||||
|
@ -1318,6 +1318,13 @@ MGP_ENUM_CLASS mgp_log_level{
|
||||
/// to allocate global resources.
|
||||
typedef void (*mgp_proc_cb)(struct mgp_list *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *);
|
||||
|
||||
/// Cleanup for a query module read procedure. Can't be invoked through OpenCypher. Cleans batched stream.
|
||||
typedef void (*mgp_proc_cleanup)();
|
||||
|
||||
/// Initializer for a query module batched read procedure. Can't be invoked through OpenCypher. Initializes batched
|
||||
/// stream.
|
||||
typedef void (*mgp_proc_initializer)(struct mgp_list *, struct mgp_graph *, struct mgp_memory *);
|
||||
|
||||
/// Register a read-only procedure to a module.
|
||||
///
|
||||
/// The `name` must be a sequence of digits, underscores, lowercase and
|
||||
@ -1342,6 +1349,30 @@ enum mgp_error mgp_module_add_read_procedure(struct mgp_module *module, const ch
|
||||
enum mgp_error mgp_module_add_write_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
struct mgp_proc **result);
|
||||
|
||||
/// Register a readable batched procedure to a module.
|
||||
///
|
||||
/// The `name` must be a valid identifier, following the same rules as the
|
||||
/// procedure`name` in mgp_module_add_read_procedure.
|
||||
///
|
||||
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_proc.
|
||||
/// Return mgp_error::MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid procedure name.
|
||||
/// RETURN mgp_error::MGP_ERROR_LOGIC_ERROR if a procedure with the same name was already registered.
|
||||
enum mgp_error mgp_module_add_batch_read_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
struct mgp_proc **result);
|
||||
|
||||
/// Register a writeable batched procedure to a module.
|
||||
///
|
||||
/// The `name` must be a valid identifier, following the same rules as the
|
||||
/// procedure`name` in mgp_module_add_read_procedure.
|
||||
///
|
||||
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_proc.
|
||||
/// Return mgp_error::MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid procedure name.
|
||||
/// RETURN mgp_error::MGP_ERROR_LOGIC_ERROR if a procedure with the same name was already registered.
|
||||
enum mgp_error mgp_module_add_batch_write_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
struct mgp_proc **result);
|
||||
|
||||
/// Add a required argument to a procedure.
|
||||
///
|
||||
/// The order of adding arguments will correspond to the order the procedure
|
||||
|
@ -1319,6 +1319,20 @@ inline void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureT
|
||||
std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
|
||||
mgp_memory *memory);
|
||||
|
||||
/// @brief Adds a batch procedure to the query module.
|
||||
/// @param callback - procedure callback
|
||||
/// @param initializer - procedure initializer
|
||||
/// @param cleanup - procedure cleanup
|
||||
/// @param name - procedure name
|
||||
/// @param proc_type - procedure type (read/write)
|
||||
/// @param parameters - procedure parameters
|
||||
/// @param returns - procedure return values
|
||||
/// @param module - the query module that the procedure is added to
|
||||
/// @param memory - access to memory
|
||||
inline void AddBatchProcedure(mgp_proc_cb callback, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
std::string_view name, ProcedureType proc_type, std::vector<Parameter> parameters,
|
||||
std::vector<Return> returns, mgp_module *module, mgp_memory *memory);
|
||||
|
||||
/// @brief Adds a function to the query module.
|
||||
/// @param callback - function callback
|
||||
/// @param name - function name
|
||||
@ -3430,14 +3444,11 @@ inline mgp_type *Return::GetMGPType() const {
|
||||
return util::ToMGPType(type_);
|
||||
}
|
||||
|
||||
void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType proc_type,
|
||||
std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
|
||||
mgp_memory *memory) {
|
||||
auto proc = (proc_type == ProcedureType::Read) ? mgp::module_add_read_procedure(module, name.data(), callback)
|
||||
: mgp::module_add_write_procedure(module, name.data(), callback);
|
||||
|
||||
// do not enter
|
||||
namespace detail {
|
||||
void AddParamsReturnsToProc(mgp_proc *proc, std::vector<Parameter> ¶meters, const std::vector<Return> &returns) {
|
||||
for (const auto ¶meter : parameters) {
|
||||
auto parameter_name = parameter.name.data();
|
||||
const auto *parameter_name = parameter.name.data();
|
||||
if (!parameter.optional) {
|
||||
mgp::proc_add_arg(proc, parameter_name, parameter.GetMGPType());
|
||||
} else {
|
||||
@ -3446,18 +3457,35 @@ void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType pro
|
||||
}
|
||||
|
||||
for (const auto return_ : returns) {
|
||||
auto return_name = return_.name.data();
|
||||
|
||||
const auto *return_name = return_.name.data();
|
||||
mgp::proc_add_result(proc, return_name, return_.GetMGPType());
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType proc_type,
|
||||
std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
|
||||
mgp_memory *memory) {
|
||||
auto *proc = (proc_type == ProcedureType::Read) ? mgp::module_add_read_procedure(module, name.data(), callback)
|
||||
: mgp::module_add_write_procedure(module, name.data(), callback);
|
||||
detail::AddParamsReturnsToProc(proc, parameters, returns);
|
||||
}
|
||||
|
||||
void AddBatchProcedure(mgp_proc_cb callback, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
std::string_view name, ProcedureType proc_type, std::vector<Parameter> parameters,
|
||||
std::vector<Return> returns, mgp_module *module, mgp_memory *memory) {
|
||||
auto *proc = (proc_type == ProcedureType::Read)
|
||||
? mgp::module_add_batch_read_procedure(module, name.data(), callback, initializer, cleanup)
|
||||
: mgp::module_add_batch_write_procedure(module, name.data(), callback, initializer, cleanup);
|
||||
detail::AddParamsReturnsToProc(proc, parameters, returns);
|
||||
}
|
||||
|
||||
void AddFunction(mgp_func_cb callback, std::string_view name, std::vector<Parameter> parameters, mgp_module *module,
|
||||
mgp_memory *memory) {
|
||||
auto func = mgp::module_add_function(module, name.data(), callback);
|
||||
auto *func = mgp::module_add_function(module, name.data(), callback);
|
||||
|
||||
for (const auto ¶meter : parameters) {
|
||||
auto parameter_name = parameter.name.data();
|
||||
const auto *parameter_name = parameter.name.data();
|
||||
|
||||
if (!parameter.optional) {
|
||||
mgp::func_add_arg(func, parameter_name, parameter.GetMGPType());
|
||||
|
159
include/mgp.py
159
include/mgp.py
@ -1402,6 +1402,13 @@ class UnsupportedTypingError(Exception):
|
||||
super().__init__("Unsupported typing annotation '{}'".format(type_))
|
||||
|
||||
|
||||
class UnequalTypesError(Exception):
|
||||
"""Signals a typing annotation is not equal between types"""
|
||||
|
||||
def __init__(self, type1_: typing.Any, type2_: typing.Any):
|
||||
super().__init__(f"Unequal typing annotation '{type1_}' and '{type2_}'")
|
||||
|
||||
|
||||
def _typing_to_cypher_type(type_):
|
||||
"""Convert typing annotation to a _mgp.CypherType instance."""
|
||||
simple_types = {
|
||||
@ -1514,6 +1521,72 @@ def _typing_to_cypher_type(type_):
|
||||
return parse_typing(str(type_))
|
||||
|
||||
|
||||
def _is_typing_same(type1_, type2_):
|
||||
"""Convert typing annotation to a _mgp.CypherType instance."""
|
||||
simple_types = {
|
||||
typing.Any: 1,
|
||||
object: 2,
|
||||
list: 3,
|
||||
Any: 4,
|
||||
bool: 5,
|
||||
str: 6,
|
||||
int: 7,
|
||||
float: 8,
|
||||
Number: 9,
|
||||
Map: 10,
|
||||
Vertex: 11,
|
||||
Edge: 12,
|
||||
Path: 13,
|
||||
Date: 14,
|
||||
LocalTime: 15,
|
||||
LocalDateTime: 16,
|
||||
Duration: 17,
|
||||
}
|
||||
try:
|
||||
return simple_types[type1_] == simple_types[type2_]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
# skip type checks
|
||||
return True
|
||||
|
||||
complex_type1 = typing.get_origin(type1_)
|
||||
type_args1 = typing.get_args(type2_)
|
||||
|
||||
complex_type2 = typing.get_origin(type1_)
|
||||
type_args2 = typing.get_args(type2_)
|
||||
|
||||
if complex_type2 != complex_type1:
|
||||
raise UnequalTypesError(type1_, type2_)
|
||||
|
||||
if complex_type1 == typing.Union:
|
||||
contains_none_arg1 = type(None) in type_args1
|
||||
contains_none_arg2 = type(None) in type_args2
|
||||
|
||||
if contains_none_arg1 != contains_none_arg2:
|
||||
raise UnequalTypesError(type1_, type2_)
|
||||
|
||||
if contains_none_arg1:
|
||||
types1 = tuple(t for t in type_args1 if t is not type(None)) # noqa E721
|
||||
types2 = tuple(t for t in type_args2 if t is not type(None)) # noqa E721
|
||||
if len(types1) != len(types2):
|
||||
raise UnequalTypesError(types1, types2)
|
||||
if len(types1) == 1:
|
||||
(type_arg1,) = types1
|
||||
(type_arg2,) = types2
|
||||
else:
|
||||
type_arg1 = typing.Union.__getitem__(types1)
|
||||
type_arg2 = typing.Union.__getitem__(types2)
|
||||
return _is_typing_same(type_arg1, type_arg2)
|
||||
elif complex_type1 == list:
|
||||
(type_arg1,) = type_args1
|
||||
(type_arg2,) = type_args2
|
||||
return _is_typing_same(type_arg1, type_arg2)
|
||||
# skip type checks
|
||||
return True
|
||||
|
||||
|
||||
# Procedure registration
|
||||
|
||||
|
||||
@ -1673,6 +1746,92 @@ def write_proc(func: typing.Callable[..., Record]):
|
||||
return _register_proc(func, True)
|
||||
|
||||
|
||||
def _register_batch_proc(
|
||||
func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable, is_write: bool
|
||||
):
|
||||
raise_if_does_not_meet_requirements(func)
|
||||
register_func = _mgp.Module.add_batch_write_procedure if is_write else _mgp.Module.add_batch_read_procedure
|
||||
func_sig = inspect.signature(func)
|
||||
func_params = tuple(func_sig.parameters.values())
|
||||
|
||||
initializer_sig = inspect.signature(initializer)
|
||||
initializer_params = tuple(initializer_sig.parameters.values())
|
||||
|
||||
assert (
|
||||
func_params and initializer_params or not func_params and not initializer_params
|
||||
), "Both function params and initializer params must exist or not exist"
|
||||
|
||||
assert len(func_params) == len(initializer_params), "Number of params must be same"
|
||||
|
||||
assert initializer_sig.return_annotation is initializer_sig.empty, "Initializer can't return anything"
|
||||
|
||||
if func_params and func_params[0].annotation is ProcCtx:
|
||||
assert (
|
||||
initializer_params and initializer_params[0].annotation is ProcCtx
|
||||
), "Initializer must have mgp.ProcCtx as first parameter"
|
||||
|
||||
@wraps(func)
|
||||
def wrapper_func(graph, args):
|
||||
return func(ProcCtx(graph), *args)
|
||||
|
||||
@wraps(initializer)
|
||||
def wrapper_initializer(graph, args):
|
||||
return initializer(ProcCtx(graph), *args)
|
||||
|
||||
func_params = func_params[1:]
|
||||
initializer_params = initializer_params[1:]
|
||||
mgp_proc = register_func(_mgp._MODULE, wrapper_func, wrapper_initializer, cleanup)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def wrapper_func(graph, args):
|
||||
return func(*args)
|
||||
|
||||
@wraps(initializer)
|
||||
def wrapper_initializer(graph, args):
|
||||
return initializer(*args)
|
||||
|
||||
mgp_proc = register_func(_mgp._MODULE, wrapper_func, wrapper_initializer, cleanup)
|
||||
|
||||
for func_param, initializer_param in zip(func_params, initializer_params):
|
||||
func_param_name = func_param.name
|
||||
func_param_type_ = func_param.annotation
|
||||
if func_param_type_ is func_param.empty:
|
||||
func_param_type_ = object
|
||||
initializer_param_type_ = initializer_param.annotation
|
||||
if initializer_param.annotation is initializer_param.empty:
|
||||
initializer_param_type_ = object
|
||||
|
||||
assert _is_typing_same(
|
||||
func_param_type_, initializer_param_type_
|
||||
), "Types of initializer and function must be same"
|
||||
|
||||
func_cypher_type = _typing_to_cypher_type(func_param_type_)
|
||||
if func_param.default is func_param.empty:
|
||||
mgp_proc.add_arg(func_param_name, func_cypher_type)
|
||||
else:
|
||||
mgp_proc.add_opt_arg(func_param_name, func_cypher_type, func_param.default)
|
||||
if func_sig.return_annotation is not func_sig.empty:
|
||||
record = func_sig.return_annotation
|
||||
if not isinstance(record, Record):
|
||||
raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record)))
|
||||
for name, type_ in record.fields.items():
|
||||
if isinstance(type_, Deprecated):
|
||||
cypher_type = _typing_to_cypher_type(type_.field_type)
|
||||
mgp_proc.add_deprecated_result(name, cypher_type)
|
||||
else:
|
||||
mgp_proc.add_result(name, _typing_to_cypher_type(type_))
|
||||
return func
|
||||
|
||||
|
||||
def add_batch_write_proc(func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable):
|
||||
return _register_batch_proc(func, initializer, cleanup, True)
|
||||
|
||||
|
||||
def add_batch_read_proc(func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable):
|
||||
return _register_batch_proc(func, initializer, cleanup, False)
|
||||
|
||||
|
||||
class InvalidMessageError(Exception):
|
||||
"""
|
||||
Signals using a message instance outside of the registered transformation.
|
||||
|
@ -10,6 +10,7 @@
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include "query/interpreter.hpp"
|
||||
#include <bits/ranges_algo.h>
|
||||
#include <fmt/core.h>
|
||||
|
||||
#include <algorithm>
|
||||
@ -50,6 +51,7 @@
|
||||
#include "query/plan/planner.hpp"
|
||||
#include "query/plan/profile.hpp"
|
||||
#include "query/plan/vertex_count_cache.hpp"
|
||||
#include "query/procedure/module.hpp"
|
||||
#include "query/stream.hpp"
|
||||
#include "query/stream/common.hpp"
|
||||
#include "query/trigger.hpp"
|
||||
@ -78,7 +80,6 @@
|
||||
#include "utils/tsc.hpp"
|
||||
#include "utils/typeinfo.hpp"
|
||||
#include "utils/variant_helpers.hpp"
|
||||
|
||||
namespace memgraph::metrics {
|
||||
extern Event ReadQuery;
|
||||
extern Event WriteQuery;
|
||||
@ -1288,6 +1289,30 @@ inline static void TryCaching(const AstStorage &ast_storage, FrameChangeCollecto
|
||||
}
|
||||
}
|
||||
|
||||
bool IsCallBatchedProcedureQuery(const std::vector<memgraph::query::Clause *> &clauses) {
|
||||
EvaluationContext evaluation_context;
|
||||
|
||||
return std::ranges::any_of(clauses, [&evaluation_context](const auto &clause) -> bool {
|
||||
if (clause->GetTypeInfo() == CallProcedure::kType) {
|
||||
auto *call_procedure_clause = utils::Downcast<CallProcedure>(clause);
|
||||
|
||||
const auto &maybe_found = memgraph::query::procedure::FindProcedure(
|
||||
procedure::gModuleRegistry, call_procedure_clause->procedure_name_, evaluation_context.memory);
|
||||
if (!maybe_found) {
|
||||
throw QueryRuntimeException("There is no procedure named '{}'.", call_procedure_clause->procedure_name_);
|
||||
}
|
||||
const auto &[module, proc] = *maybe_found;
|
||||
if (proc->info.is_batched) {
|
||||
spdlog::trace("Using PoolResource for batched query procedure");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
|
||||
InterpreterContext *interpreter_context, DbAccessor *dba,
|
||||
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
|
||||
@ -1319,7 +1344,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
|
||||
contains_csv = true;
|
||||
}
|
||||
// If this is LOAD CSV query, use PoolResource without MonotonicMemoryResource as we want to reuse allocated memory
|
||||
auto use_monotonic_memory = !contains_csv;
|
||||
auto use_monotonic_memory = !contains_csv && !IsCallBatchedProcedureQuery(clauses);
|
||||
auto plan = CypherQueryToPlan(parsed_query.stripped_query.hash(), std::move(parsed_query.ast_storage), cypher_query,
|
||||
parsed_query.parameters,
|
||||
parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
|
||||
@ -1451,7 +1476,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
|
||||
contains_csv = true;
|
||||
}
|
||||
// If this is LOAD CSV query, use PoolResource without MonotonicMemoryResource as we want to reuse allocated memory
|
||||
auto use_monotonic_memory = !contains_csv;
|
||||
auto use_monotonic_memory = !contains_csv && !IsCallBatchedProcedureQuery(clauses);
|
||||
|
||||
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE");
|
||||
Frame frame(0);
|
||||
@ -2858,14 +2883,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
|
||||
auto *profile_query = utils::Downcast<ProfileQuery>(parsed_query.query);
|
||||
cypher_query = profile_query->cypher_query_;
|
||||
}
|
||||
|
||||
if (const auto &clauses = cypher_query->single_query_->clauses_;
|
||||
std::any_of(clauses.begin(), clauses.end(),
|
||||
[](const auto *clause) { return clause->GetTypeInfo() == LoadCsv::kType; })) {
|
||||
IsCallBatchedProcedureQuery(clauses) || std::any_of(clauses.begin(), clauses.end(), [](const auto *clause) {
|
||||
return clause->GetTypeInfo() == LoadCsv::kType;
|
||||
})) {
|
||||
// Using PoolResource without MonotonicMemoryResouce for LOAD CSV reduces memory usage.
|
||||
// QueryExecution MemoryResource is mostly used for allocations done on Frame and storing `row`s
|
||||
query_executions_[query_executions_.size() - 1] = std::make_unique<QueryExecution>(
|
||||
utils::PoolResource(8, kExecutionPoolMaxBlockSize, utils::NewDeleteResource(), utils::NewDeleteResource()));
|
||||
query_executions_[query_executions_.size() - 1] = std::make_unique<QueryExecution>(utils::PoolResource(
|
||||
128, kExecutionPoolMaxBlockSize, utils::NewDeleteResource(), utils::NewDeleteResource()));
|
||||
query_execution_ptr = &query_executions_.back();
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <random>
|
||||
#include <string>
|
||||
@ -4465,7 +4466,8 @@ namespace {
|
||||
|
||||
void CallCustomProcedure(const std::string_view fully_qualified_procedure_name, const mgp_proc &proc,
|
||||
const std::vector<Expression *> &args, mgp_graph &graph, ExpressionEvaluator *evaluator,
|
||||
utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result) {
|
||||
utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result,
|
||||
const bool call_initializer = false) {
|
||||
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
|
||||
"Expected mgp_value to use custom allocator and makes STL "
|
||||
"containers aware of that");
|
||||
@ -4489,6 +4491,11 @@ void CallCustomProcedure(const std::string_view fully_qualified_procedure_name,
|
||||
}
|
||||
|
||||
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
|
||||
if (call_initializer) {
|
||||
MG_ASSERT(proc.initializer);
|
||||
mgp_memory initializer_memory{memory};
|
||||
proc.initializer.value()(&proc_args, &graph, &initializer_memory);
|
||||
}
|
||||
if (memory_limit) {
|
||||
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
|
||||
utils::GetReadableSize(*memory_limit));
|
||||
@ -4516,18 +4523,22 @@ void CallCustomProcedure(const std::string_view fully_qualified_procedure_name,
|
||||
class CallProcedureCursor : public Cursor {
|
||||
const CallProcedure *self_;
|
||||
UniqueCursorPtr input_cursor_;
|
||||
mgp_result result_;
|
||||
decltype(result_.rows.end()) result_row_it_{result_.rows.end()};
|
||||
mgp_result *result_;
|
||||
decltype(result_->rows.end()) result_row_it_{result_->rows.end()};
|
||||
size_t result_signature_size_{0};
|
||||
bool stream_exhausted{true};
|
||||
bool call_initializer{false};
|
||||
std::optional<std::function<void()>> cleanup_{std::nullopt};
|
||||
|
||||
public:
|
||||
CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem)
|
||||
: self_(self),
|
||||
input_cursor_(self_->input_->MakeCursor(mem)),
|
||||
// result_ needs to live throughout multiple Pull evaluations, until all
|
||||
// rows are produced. Therefore, we use the memory dedicated for the
|
||||
// whole execution.
|
||||
result_(nullptr, mem) {
|
||||
// rows are produced. We don't use the memory dedicated for QueryExecution (and Frame),
|
||||
// but memory dedicated for procedure to wipe result_ and everything allocated in procedure all at once.
|
||||
result_(utils::Allocator<mgp_result>(self_->memory_resource)
|
||||
.new_object<mgp_result>(nullptr, self_->memory_resource)) {
|
||||
MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure");
|
||||
}
|
||||
|
||||
@ -4541,11 +4552,7 @@ class CallProcedureCursor : public Cursor {
|
||||
// empty result set vs procedures which return `void`. We currently don't
|
||||
// have procedures registering what they return.
|
||||
// This `while` loop will skip over empty results.
|
||||
while (result_row_it_ == result_.rows.end()) {
|
||||
if (!input_cursor_->Pull(frame, context)) return false;
|
||||
result_.signature = nullptr;
|
||||
result_.rows.clear();
|
||||
result_.error_msg.reset();
|
||||
while (result_row_it_ == result_->rows.end()) {
|
||||
// It might be a good idea to resolve the procedure name once, at the
|
||||
// start. Unfortunately, this could deadlock if we tried to invoke a
|
||||
// procedure from a module (read lock) and reload a module (write lock)
|
||||
@ -4565,30 +4572,61 @@ class CallProcedureCursor : public Cursor {
|
||||
self_->procedure_name_, get_proc_type_str(self_->is_write_),
|
||||
get_proc_type_str(proc->info.is_write));
|
||||
}
|
||||
if (!proc->info.is_batched) {
|
||||
stream_exhausted = true;
|
||||
}
|
||||
|
||||
if (stream_exhausted) {
|
||||
if (!input_cursor_->Pull(frame, context)) {
|
||||
if (proc->cleanup) {
|
||||
proc->cleanup.value()();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
stream_exhausted = false;
|
||||
if (proc->initializer) {
|
||||
call_initializer = true;
|
||||
MG_ASSERT(proc->cleanup);
|
||||
proc->cleanup.value()();
|
||||
}
|
||||
}
|
||||
if (!cleanup_ && proc->cleanup) [[unlikely]] {
|
||||
cleanup_.emplace(*proc->cleanup);
|
||||
}
|
||||
// Unpluging memory without calling destruct on each object since everything was allocated with this memory
|
||||
// resource
|
||||
self_->monotonic_memory.Release();
|
||||
result_ =
|
||||
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
|
||||
|
||||
const auto graph_view = proc->info.is_write ? storage::View::NEW : storage::View::OLD;
|
||||
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
|
||||
graph_view);
|
||||
|
||||
result_.signature = &proc->results;
|
||||
// Use evaluation memory, as invoking a procedure is akin to a simple
|
||||
// evaluation of an expression.
|
||||
result_->signature = &proc->results;
|
||||
|
||||
// Use special memory as invoking procedure is complex
|
||||
// TODO: This will probably need to be changed when we add support for
|
||||
// generator like procedures which yield a new result on each invocation.
|
||||
auto *memory = context.evaluation_context.memory;
|
||||
// generator like procedures which yield a new result on new query calls.
|
||||
auto *memory = self_->memory_resource;
|
||||
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
|
||||
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
|
||||
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
|
||||
&result_);
|
||||
result_, call_initializer);
|
||||
|
||||
if (call_initializer) call_initializer = false;
|
||||
|
||||
// Reset result_.signature to nullptr, because outside of this scope we
|
||||
// will no longer hold a lock on the `module`. If someone were to reload
|
||||
// it, the pointer would be invalid.
|
||||
result_signature_size_ = result_.signature->size();
|
||||
result_.signature = nullptr;
|
||||
if (result_.error_msg) {
|
||||
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_.error_msg);
|
||||
result_signature_size_ = result_->signature->size();
|
||||
result_->signature = nullptr;
|
||||
if (result_->error_msg) {
|
||||
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg);
|
||||
}
|
||||
result_row_it_ = result_.rows.begin();
|
||||
result_row_it_ = result_->rows.begin();
|
||||
|
||||
stream_exhausted = result_row_it_ == result_->rows.end();
|
||||
}
|
||||
|
||||
auto &values = result_row_it_->values;
|
||||
@ -4621,12 +4659,20 @@ class CallProcedureCursor : public Cursor {
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
result_.rows.clear();
|
||||
result_.error_msg.reset();
|
||||
input_cursor_->Reset();
|
||||
self_->monotonic_memory.Release();
|
||||
result_ =
|
||||
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
|
||||
if (cleanup_) {
|
||||
cleanup_.value()();
|
||||
}
|
||||
}
|
||||
|
||||
void Shutdown() override {}
|
||||
void Shutdown() override {
|
||||
self_->monotonic_memory.Release();
|
||||
if (cleanup_) {
|
||||
cleanup_.value()();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const {
|
||||
|
@ -2199,6 +2199,8 @@ class CallProcedure : public memgraph::query::plan::LogicalOperator {
|
||||
Expression *memory_limit_{nullptr};
|
||||
size_t memory_scale_{1024U};
|
||||
bool is_write_;
|
||||
mutable utils::MonotonicBufferResource monotonic_memory{1024UL * 1024UL};
|
||||
utils::MemoryResource *memory_resource = &monotonic_memory;
|
||||
|
||||
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
|
||||
auto object = std::make_unique<CallProcedure>();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -2731,6 +2731,7 @@ mgp_error mgp_type_nullable(mgp_type *type, mgp_type **result) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// @throw std::bad_alloc, std::length_error
|
||||
mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
|
||||
const ProcedureInfo &procedure_info) {
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
@ -2741,9 +2742,24 @@ mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_pro
|
||||
};
|
||||
|
||||
auto *memory = module->procedures.get_allocator().GetMemoryResource();
|
||||
// May throw std::bad_alloc, std::length_error
|
||||
return &module->procedures.emplace(name, mgp_proc(name, cb, memory, procedure_info)).first->second;
|
||||
}
|
||||
|
||||
/// @throw std::bad_alloc, std::length_error
|
||||
mgp_proc *mgp_module_add_batch_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
const ProcedureInfo &procedure_info) {
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
throw std::invalid_argument{fmt::format("Invalid procedure name: {}", name)};
|
||||
}
|
||||
if (module->procedures.find(name) != module->procedures.end()) {
|
||||
throw std::logic_error{fmt::format("Procedure already exists with name '{}'", name)};
|
||||
};
|
||||
auto *memory = module->procedures.get_allocator().GetMemoryResource();
|
||||
return &module->procedures.emplace(name, mgp_proc(name, cb_batch, initializer, cleanup, memory, procedure_info))
|
||||
.first->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
mgp_error mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) {
|
||||
@ -2754,6 +2770,28 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m
|
||||
return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result);
|
||||
}
|
||||
|
||||
mgp_error mgp_module_add_batch_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
mgp_proc **result) {
|
||||
return WrapExceptions(
|
||||
[=] {
|
||||
return mgp_module_add_batch_procedure(module, name, cb_batch, initializer, cleanup,
|
||||
{.is_write = false, .is_batched = true});
|
||||
},
|
||||
result);
|
||||
}
|
||||
|
||||
mgp_error mgp_module_add_batch_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
|
||||
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
mgp_proc **result) {
|
||||
return WrapExceptions(
|
||||
[=] {
|
||||
return mgp_module_add_batch_procedure(module, name, cb_batch, initializer, cleanup,
|
||||
{.is_write = true, .is_batched = true});
|
||||
},
|
||||
result);
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
|
||||
|
@ -729,7 +729,8 @@ struct mgp_type {
|
||||
};
|
||||
|
||||
struct ProcedureInfo {
|
||||
bool is_write = false;
|
||||
bool is_write{false};
|
||||
bool is_batched{false};
|
||||
std::optional<memgraph::query::AuthQuery::Privilege> required_privilege = std::nullopt;
|
||||
};
|
||||
struct mgp_proc {
|
||||
@ -740,6 +741,33 @@ struct mgp_proc {
|
||||
mgp_proc(const char *name, mgp_proc_cb cb, memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
|
||||
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_proc(const char *name, mgp_proc_cb cb, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
|
||||
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
|
||||
: name(name, memory),
|
||||
cb(cb),
|
||||
initializer(initializer),
|
||||
cleanup(cleanup),
|
||||
args(memory),
|
||||
opt_args(memory),
|
||||
results(memory),
|
||||
info(info) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
||||
std::function<void(mgp_list *, mgp_graph *, mgp_memory *)> initializer, std::function<void()> cleanup,
|
||||
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
|
||||
: name(name, memory),
|
||||
cb(cb),
|
||||
initializer(initializer),
|
||||
cleanup(cleanup),
|
||||
args(memory),
|
||||
opt_args(memory),
|
||||
results(memory),
|
||||
info(info) {}
|
||||
|
||||
/// @throw std::bad_alloc
|
||||
/// @throw std::length_error
|
||||
mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
|
||||
@ -757,6 +785,8 @@ struct mgp_proc {
|
||||
mgp_proc(const mgp_proc &other, memgraph::utils::MemoryResource *memory)
|
||||
: name(other.name, memory),
|
||||
cb(other.cb),
|
||||
initializer(other.initializer),
|
||||
cleanup(other.cleanup),
|
||||
args(other.args, memory),
|
||||
opt_args(other.opt_args, memory),
|
||||
results(other.results, memory),
|
||||
@ -765,6 +795,8 @@ struct mgp_proc {
|
||||
mgp_proc(mgp_proc &&other, memgraph::utils::MemoryResource *memory)
|
||||
: name(std::move(other.name), memory),
|
||||
cb(std::move(other.cb)),
|
||||
initializer(other.initializer),
|
||||
cleanup(other.cleanup),
|
||||
args(std::move(other.args), memory),
|
||||
opt_args(std::move(other.opt_args), memory),
|
||||
results(std::move(other.results), memory),
|
||||
@ -782,6 +814,13 @@ struct mgp_proc {
|
||||
memgraph::utils::pmr::string name;
|
||||
/// Entry-point for the procedure.
|
||||
std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
|
||||
|
||||
/// Initializer for batched procedure.
|
||||
std::optional<std::function<void(mgp_list *, mgp_graph *, mgp_memory *)>> initializer;
|
||||
|
||||
/// Dtor for batched procedure.
|
||||
std::optional<std::function<void()>> cleanup;
|
||||
|
||||
/// Required, positional arguments as a (name, type) pair.
|
||||
memgraph::utils::pmr::vector<std::pair<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *>>
|
||||
args;
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "query/procedure/py_module.hpp"
|
||||
|
||||
#include <datetime.h>
|
||||
#include <methodobject.h>
|
||||
#include <pyerrors.h>
|
||||
#include <array>
|
||||
#include <optional>
|
||||
@ -54,6 +55,8 @@ PyObject *gMgpValueConversionError{nullptr}; // NOLINT(cppcoreguidelines-avo
|
||||
PyObject *gMgpSerializationError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
PyObject *gMgpAuthorizationError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
|
||||
constexpr bool kStartGarbageCollection{true};
|
||||
|
||||
// Returns true if an exception is raised
|
||||
bool RaiseExceptionFromErrorCode(const mgp_error error) {
|
||||
switch (error) {
|
||||
@ -945,20 +948,37 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::function<void()> PyObjectCleanup(py::Object &py_object) {
|
||||
return [py_object]() {
|
||||
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If the
|
||||
// user stored a reference to one of our `_mgp` instances then the
|
||||
// internally used `mgp_*` structs will stay unfreed and a memory leak
|
||||
// will be reported at the end of the query execution.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
std::optional<py::ExceptionInfo> AddMultipleBatchRecordsFromPython(mgp_result *result, py::Object py_seq,
|
||||
mgp_memory *memory) {
|
||||
Py_ssize_t len = PySequence_Size(py_seq.Ptr());
|
||||
if (len == -1) return py::FetchError();
|
||||
result->rows.reserve(len);
|
||||
for (Py_ssize_t i = 0; i < len; ++i) {
|
||||
py::Object py_record(PySequence_GetItem(py_seq.Ptr(), i));
|
||||
if (!py_record) return py::FetchError();
|
||||
auto maybe_exc = AddRecordFromPython(result, py_record, memory);
|
||||
if (maybe_exc) return maybe_exc;
|
||||
}
|
||||
PySequence_DelSlice(py_seq.Ptr(), 0, PySequence_Size(py_seq.Ptr()));
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
std::function<void()> PyObjectCleanup(py::Object &py_object, bool start_gc) {
|
||||
return [py_object, start_gc]() {
|
||||
if (start_gc) {
|
||||
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
|
||||
// sure the procedure cleaned up everything it held references to. If the
|
||||
// user stored a reference to one of our `_mgp` instances then the
|
||||
// internally used `mgp_*` structs will stay unfreed and a memory leak
|
||||
// will be reported at the end of the query execution.
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
}
|
||||
|
||||
// After making sure all references from our side have been cleared,
|
||||
@ -973,8 +993,7 @@ std::function<void()> PyObjectCleanup(py::Object &py_object) {
|
||||
}
|
||||
|
||||
void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
// *memory here is memory from `EvalContext`
|
||||
mgp_memory *memory, bool is_batched) {
|
||||
auto gil = py::EnsureGIL();
|
||||
|
||||
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
|
||||
@ -992,10 +1011,12 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
|
||||
auto py_res = py_cb.Call(py_graph, py_args);
|
||||
if (!py_res) return py::FetchError();
|
||||
if (PySequence_Check(py_res.Ptr())) {
|
||||
if (is_batched) {
|
||||
return AddMultipleBatchRecordsFromPython(result, py_res, memory);
|
||||
}
|
||||
return AddMultipleRecordsFromPython(result, py_res, memory);
|
||||
} else {
|
||||
return AddRecordFromPython(result, py_res, memory);
|
||||
}
|
||||
return AddRecordFromPython(result, py_res, memory);
|
||||
};
|
||||
|
||||
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
|
||||
@ -1010,7 +1031,7 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
|
||||
std::optional<std::string> maybe_msg;
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph, !is_batched));
|
||||
if (py_graph) {
|
||||
maybe_msg = error_to_msg(call(py_graph));
|
||||
} else {
|
||||
@ -1023,6 +1044,38 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
|
||||
}
|
||||
}
|
||||
|
||||
void CallPythonCleanup(const py::Object &py_cleanup) {
|
||||
auto gil = py::EnsureGIL();
|
||||
|
||||
auto py_res = py_cleanup.Call();
|
||||
|
||||
py::Object gc(PyImport_ImportModule("gc"));
|
||||
if (!gc) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
|
||||
if (!gc.CallMethod("collect")) {
|
||||
LOG_FATAL(py::FetchError().value());
|
||||
}
|
||||
}
|
||||
|
||||
void CallPythonInitializer(const py::Object &py_initializer, mgp_list *args, mgp_graph *graph, mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
|
||||
auto call = [&](py::Object py_graph) -> std::optional<py::ExceptionInfo> {
|
||||
py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
|
||||
if (!py_args) return py::FetchError();
|
||||
auto py_res = py_initializer.Call(py_graph, py_args);
|
||||
return {};
|
||||
};
|
||||
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph, !kStartGarbageCollection));
|
||||
call(py_graph);
|
||||
}
|
||||
}
|
||||
|
||||
void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_graph *graph, mgp_result *result,
|
||||
mgp_memory *memory) {
|
||||
auto gil = py::EnsureGIL();
|
||||
@ -1059,8 +1112,8 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
py::Object py_messages(MakePyMessages(msgs, memory));
|
||||
|
||||
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph));
|
||||
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages));
|
||||
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph, kStartGarbageCollection));
|
||||
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages, kStartGarbageCollection));
|
||||
|
||||
if (py_graph && py_messages) {
|
||||
maybe_msg = error_to_msg(call(py_graph, py_messages));
|
||||
@ -1111,7 +1164,7 @@ void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *grap
|
||||
std::optional<std::string> maybe_msg;
|
||||
{
|
||||
py::Object py_graph(MakePyGraph(graph, memory));
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
|
||||
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph, kStartGarbageCollection));
|
||||
if (py_graph) {
|
||||
auto maybe_result = call(py_graph);
|
||||
if (!maybe_result.HasError()) {
|
||||
@ -1147,7 +1200,7 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
|
||||
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
|
||||
mgp_proc proc(name,
|
||||
[py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
|
||||
CallPythonProcedure(py_cb, args, graph, result, memory);
|
||||
CallPythonProcedure(py_cb, args, graph, result, memory, false);
|
||||
},
|
||||
memory, {.is_write = is_write_procedure});
|
||||
const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
|
||||
@ -1160,6 +1213,52 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
|
||||
py_proc->callable = &proc_it->second;
|
||||
return reinterpret_cast<PyObject *>(py_proc);
|
||||
}
|
||||
|
||||
PyObject *PyQueryModuleAddBatchProcedure(PyQueryModule *self, PyObject *args, bool is_write_procedure) {
|
||||
MG_ASSERT(self->module);
|
||||
PyObject *cb{nullptr};
|
||||
PyObject *initializer{nullptr};
|
||||
PyObject *cleanup{nullptr};
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OOO", &cb, &initializer, &cleanup)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!PyCallable_Check(cb)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
|
||||
return nullptr;
|
||||
}
|
||||
auto py_cb = py::Object::FromBorrow(cb);
|
||||
auto py_initializer = py::Object::FromBorrow(initializer);
|
||||
auto py_cleanup = py::Object::FromBorrow(cleanup);
|
||||
py::Object py_name(py_cb.GetAttr("__name__"));
|
||||
const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
|
||||
if (!name) return nullptr;
|
||||
if (!IsValidIdentifierName(name)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Procedure name is not a valid identifier");
|
||||
return nullptr;
|
||||
}
|
||||
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
|
||||
mgp_proc proc(
|
||||
name,
|
||||
[py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
|
||||
CallPythonProcedure(py_cb, args, graph, result, memory, true);
|
||||
},
|
||||
[py_initializer](mgp_list *args, mgp_graph *graph, mgp_memory *memory) {
|
||||
CallPythonInitializer(py_initializer, args, graph, memory);
|
||||
},
|
||||
[py_cleanup] { CallPythonCleanup(py_cleanup); }, memory, {.is_write = is_write_procedure, .is_batched = true});
|
||||
const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
|
||||
if (!did_insert) {
|
||||
PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
|
||||
return nullptr;
|
||||
}
|
||||
auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
|
||||
if (!py_proc) return nullptr;
|
||||
py_proc->callable = &proc_it->second;
|
||||
return reinterpret_cast<PyObject *>(py_proc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
@ -1170,6 +1269,14 @@ PyObject *PyQueryModuleAddWriteProcedure(PyQueryModule *self, PyObject *cb) {
|
||||
return PyQueryModuleAddProcedure(self, cb, true);
|
||||
}
|
||||
|
||||
PyObject *PyQueryModuleAddBatchReadProcedure(PyQueryModule *self, PyObject *args) {
|
||||
return PyQueryModuleAddBatchProcedure(self, args, false);
|
||||
}
|
||||
|
||||
PyObject *PyQueryModuleAddBatchWriteProcedure(PyQueryModule *self, PyObject *args) {
|
||||
return PyQueryModuleAddBatchProcedure(self, args, true);
|
||||
}
|
||||
|
||||
PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
|
||||
MG_ASSERT(self->module);
|
||||
if (!PyCallable_Check(cb)) {
|
||||
@ -1238,6 +1345,10 @@ static PyMethodDef PyQueryModuleMethods[] = {
|
||||
"Register a read-only procedure with this module."},
|
||||
{"add_write_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddWriteProcedure), METH_O,
|
||||
"Register a writeable procedure with this module."},
|
||||
{"add_batch_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddBatchReadProcedure), METH_VARARGS,
|
||||
"Register a read-only batch procedure with this module."},
|
||||
{"add_batch_write_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddBatchWriteProcedure), METH_VARARGS,
|
||||
"Register a writeable batched procedure with this module."},
|
||||
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
|
||||
"Register a transformation with this module."},
|
||||
{"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O,
|
||||
|
@ -58,6 +58,7 @@ add_subdirectory(mock_api)
|
||||
add_subdirectory(load_csv)
|
||||
add_subdirectory(init_file_flags)
|
||||
add_subdirectory(analytical_mode)
|
||||
add_subdirectory(batched_procedures)
|
||||
|
||||
copy_e2e_python_files(pytest_runner pytest_runner.sh "")
|
||||
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/memgraph-selfsigned.crt DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
9
tests/e2e/batched_procedures/CMakeLists.txt
Normal file
9
tests/e2e/batched_procedures/CMakeLists.txt
Normal file
@ -0,0 +1,9 @@
|
||||
function(copy_batched_procedures_e2e_python_files FILE_NAME)
|
||||
copy_e2e_python_files(batched_procedures ${FILE_NAME})
|
||||
endfunction()
|
||||
|
||||
copy_batched_procedures_e2e_python_files(common.py)
|
||||
copy_batched_procedures_e2e_python_files(conftest.py)
|
||||
copy_batched_procedures_e2e_python_files(simple_read.py)
|
||||
|
||||
add_subdirectory(procedures)
|
0
tests/e2e/batched_procedures/__init__.py
Normal file
0
tests/e2e/batched_procedures/__init__.py
Normal file
34
tests/e2e/batched_procedures/common.py
Normal file
34
tests/e2e/batched_procedures/common.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import typing
|
||||
|
||||
import mgclient
|
||||
|
||||
|
||||
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
|
||||
cursor.execute(query, params)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def connect(**kwargs) -> mgclient.Connection:
|
||||
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
|
||||
connection.autocommit = True
|
||||
return connection
|
||||
|
||||
|
||||
def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int):
|
||||
results = execute_and_fetch_all(cursor, query)
|
||||
return len(results) == n
|
||||
|
||||
|
||||
def has_one_result_row(cursor: mgclient.Cursor, query: str):
|
||||
return has_n_result_row(cursor, query, 1)
|
26
tests/e2e/batched_procedures/conftest.py
Normal file
26
tests/e2e/batched_procedures/conftest.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import pytest
|
||||
from common import connect, execute_and_fetch_all
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def connection():
|
||||
connection = connect()
|
||||
yield connection
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
|
||||
|
||||
|
||||
def get_connection():
|
||||
connection = connect()
|
||||
return connection
|
7
tests/e2e/batched_procedures/procedures/CMakeLists.txt
Normal file
7
tests/e2e/batched_procedures/procedures/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
|
||||
copy_batched_procedures_e2e_python_files(batch_py_read.py)
|
||||
copy_batched_procedures_e2e_python_files(batch_py_write.py)
|
||||
|
||||
add_query_module(batch_c_read batch_c_read.cpp)
|
||||
|
||||
|
||||
add_subdirectory(common)
|
136
tests/e2e/batched_procedures/procedures/batch_c_read.cpp
Normal file
136
tests/e2e/batched_procedures/procedures/batch_c_read.cpp
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright 2023 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "mg_procedure.h"
|
||||
|
||||
#include "mgp.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using mgp_int = decltype(mgp::Value().ValueInt());
|
||||
static mgp_int num_ints{0};
|
||||
static mgp_int returned_ints{0};
|
||||
|
||||
static mgp_int num_strings{0};
|
||||
static int returned_strings{0};
|
||||
|
||||
const char *kReturnOutput = "output";
|
||||
|
||||
void NumsBatchInit(struct mgp_list *args, mgp_graph *graph, struct mgp_memory *memory) {
|
||||
mgp::memory = memory;
|
||||
const auto arguments = mgp::List(args);
|
||||
if (arguments.Empty()) {
|
||||
throw std::runtime_error("Expected to recieve argument");
|
||||
}
|
||||
if (arguments[0].Type() != mgp::Type::Int) {
|
||||
throw std::runtime_error("Wrong type of first arguments");
|
||||
}
|
||||
const auto num = arguments[0].ValueInt();
|
||||
num_ints = num;
|
||||
}
|
||||
|
||||
void NumsBatch(struct mgp_list *args, mgp_graph *graph, mgp_result *result, struct mgp_memory *memory) {
|
||||
mgp::memory = memory;
|
||||
const auto arguments = mgp::List(args);
|
||||
const auto record_factory = mgp::RecordFactory(result);
|
||||
if (returned_ints < num_ints) {
|
||||
auto record = record_factory.NewRecord();
|
||||
record.Insert(kReturnOutput, ++returned_ints);
|
||||
}
|
||||
}
|
||||
|
||||
void NumsBatchCleanup() {
|
||||
returned_ints = 0;
|
||||
num_ints = 0;
|
||||
}
|
||||
|
||||
void StringsBatchInit(struct mgp_list *args, mgp_graph *graph, struct mgp_memory *memory) {
|
||||
mgp::memory = memory;
|
||||
const auto arguments = mgp::List(args);
|
||||
if (arguments.Empty()) {
|
||||
throw std::runtime_error("Expected to recieve argument");
|
||||
}
|
||||
if (arguments[0].Type() != mgp::Type::Int) {
|
||||
throw std::runtime_error("Wrong type of first arguments");
|
||||
}
|
||||
const auto num = arguments[0].ValueInt();
|
||||
num_strings = num;
|
||||
}
|
||||
|
||||
void StringsBatch(struct mgp_list *args, mgp_graph *graph, mgp_result *result, struct mgp_memory *memory) {
|
||||
mgp::memory = memory;
|
||||
const auto arguments = mgp::List(args);
|
||||
const auto record_factory = mgp::RecordFactory(result);
|
||||
|
||||
if (returned_strings < num_strings) {
|
||||
auto record = record_factory.NewRecord();
|
||||
returned_strings++;
|
||||
record.Insert(kReturnOutput, "output");
|
||||
}
|
||||
}
|
||||
|
||||
void StringsBatchCleanup() {
|
||||
returned_strings = 0;
|
||||
num_strings = 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Each module needs to define mgp_init_module function.
|
||||
// Here you can register multiple functions/procedures your module supports.
|
||||
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
||||
{
|
||||
mgp_proc *proc{nullptr};
|
||||
auto err_code =
|
||||
mgp_module_add_batch_read_procedure(module, "batch_nums", NumsBatch, NumsBatchInit, NumsBatchCleanup, &proc);
|
||||
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *type_int{nullptr};
|
||||
static_cast<void>(mgp_type_int(&type_int));
|
||||
err_code = mgp_proc_add_arg(proc, "num_ints", type_int);
|
||||
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
mgp_type *return_type_int{nullptr};
|
||||
static_cast<void>(mgp_type_int(&return_type_int));
|
||||
err_code = mgp_proc_add_result(proc, "output", return_type_int);
|
||||
if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
try {
|
||||
mgp::memory = memory;
|
||||
mgp::AddBatchProcedure(StringsBatch, StringsBatchInit, StringsBatchCleanup, "batch_strings",
|
||||
mgp::ProcedureType::Read, {mgp::Parameter("num_strings", mgp::Type::Int)},
|
||||
{mgp::Return("output", mgp::Type::String)}, module, memory);
|
||||
|
||||
} catch (const std::exception &e) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// This is an optional function if you need to release any resources before the
|
||||
// module is unloaded. You will probably need this if you acquired some
|
||||
// resources in mgp_init_module.
|
||||
extern "C" int mgp_shutdown_module() { return 0; }
|
134
tests/e2e/batched_procedures/procedures/batch_py_read.py
Normal file
134
tests/e2e/batched_procedures/procedures/batch_py_read.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgp
|
||||
|
||||
# isort: off
|
||||
from common.shared import BaseClass, InitializationGraphMutable, InitializationUnderlyingGraphMutable
|
||||
|
||||
initialization_underlying_graph_mutable = InitializationUnderlyingGraphMutable()
|
||||
|
||||
|
||||
def cleanup_underlying():
|
||||
initialization_underlying_graph_mutable.reset()
|
||||
|
||||
|
||||
def init_underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any):
|
||||
initialization_underlying_graph_mutable.set()
|
||||
|
||||
|
||||
def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool, init_called=bool):
|
||||
if initialization_underlying_graph_mutable.get_to_return() > 0:
|
||||
initialization_underlying_graph_mutable.increment_returned(1)
|
||||
return mgp.Record(
|
||||
mutable=object.underlying_graph_is_mutable(), init_called=initialization_underlying_graph_mutable.get()
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_read_proc(underlying_graph_is_mutable, init_underlying_graph_is_mutable, cleanup_underlying)
|
||||
|
||||
|
||||
initialization_graph_mutable = InitializationGraphMutable()
|
||||
|
||||
|
||||
def init_graph_is_mutable(ctx: mgp.ProcCtx):
|
||||
initialization_graph_mutable.set()
|
||||
|
||||
|
||||
def graph_is_mutable(ctx: mgp.ProcCtx) -> mgp.Record(mutable=bool, init_called=bool):
|
||||
if initialization_graph_mutable.get_to_return() > 0:
|
||||
initialization_graph_mutable.increment_returned(1)
|
||||
return mgp.Record(mutable=ctx.graph.is_mutable(), init_called=initialization_graph_mutable.get())
|
||||
return []
|
||||
|
||||
|
||||
def cleanup_graph():
|
||||
initialization_graph_mutable.reset()
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_read_proc(graph_is_mutable, init_graph_is_mutable, cleanup_graph)
|
||||
|
||||
|
||||
class BatchingNums(BaseClass):
|
||||
def __init__(self, nums_to_return):
|
||||
super().__init__(nums_to_return)
|
||||
self._nums = []
|
||||
self._i = 0
|
||||
|
||||
|
||||
batching_nums = BatchingNums(10)
|
||||
|
||||
|
||||
def init_batching_nums(ctx: mgp.ProcCtx):
|
||||
batching_nums.set()
|
||||
batching_nums._nums = [i for i in range(1, 11)]
|
||||
batching_nums._i = 0
|
||||
|
||||
|
||||
def batch_nums(ctx: mgp.ProcCtx) -> mgp.Record(num=int, init_called=bool, is_valid=bool):
|
||||
if batching_nums.get_to_return() > 0:
|
||||
batching_nums.increment_returned(1)
|
||||
batching_nums._i += 1
|
||||
return mgp.Record(
|
||||
num=batching_nums._nums[batching_nums._i - 1],
|
||||
init_called=batching_nums.get(),
|
||||
is_valid=ctx.graph.is_valid(),
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def cleanup_batching_nums():
|
||||
batching_nums.reset()
|
||||
batching_nums._i = 0
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_read_proc(batch_nums, init_batching_nums, cleanup_batching_nums)
|
||||
|
||||
|
||||
class BatchingVertices(BaseClass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._vertices = []
|
||||
self._i = 0
|
||||
|
||||
|
||||
batching_vertices = BatchingVertices()
|
||||
|
||||
|
||||
def init_batching_vertices(ctx: mgp.ProcCtx):
|
||||
print("init called")
|
||||
print("graph is mutable", ctx.graph.is_mutable())
|
||||
batching_vertices.set()
|
||||
batching_vertices._vertices = list(ctx.graph.vertices)
|
||||
batching_vertices._i = 0
|
||||
batching_vertices._num_to_return = len(batching_vertices._vertices)
|
||||
|
||||
|
||||
def batch_vertices(ctx: mgp.ProcCtx) -> mgp.Record(vertex_id=int, init_called=bool):
|
||||
if batching_vertices.get_to_return() == 0:
|
||||
return []
|
||||
batching_vertices.increment_returned(1)
|
||||
return mgp.Record(vertex=batching_vertices._vertices[batching_vertices._i].id, init_called=batching_vertices.get())
|
||||
|
||||
|
||||
def cleanup_batching_vertices():
|
||||
batching_vertices.reset()
|
||||
batching_vertices._vertices = []
|
||||
batching_vertices._i = 0
|
||||
batching_vertices._num_to_return = 0
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_read_proc(batch_vertices, init_batching_vertices, cleanup_batching_vertices)
|
60
tests/e2e/batched_procedures/procedures/batch_py_write.py
Normal file
60
tests/e2e/batched_procedures/procedures/batch_py_write.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import mgp
|
||||
|
||||
# isort: off
|
||||
from common.shared import InitializationGraphMutable, InitializationUnderlyingGraphMutable
|
||||
|
||||
write_init_underlying_graph_mutable = InitializationUnderlyingGraphMutable()
|
||||
|
||||
|
||||
def cleanup_underlying():
|
||||
write_init_underlying_graph_mutable.reset()
|
||||
|
||||
|
||||
def init_underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any):
|
||||
write_init_underlying_graph_mutable.set()
|
||||
|
||||
|
||||
def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool, init_called=bool):
|
||||
if write_init_underlying_graph_mutable.get_to_return() == 0:
|
||||
return []
|
||||
write_init_underlying_graph_mutable.increment_returned(1)
|
||||
return mgp.Record(
|
||||
mutable=object.underlying_graph_is_mutable(), init_called=write_init_underlying_graph_mutable.get()
|
||||
)
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_write_proc(underlying_graph_is_mutable, init_underlying_graph_is_mutable, cleanup_underlying)
|
||||
|
||||
|
||||
write_init_graph_mutable = InitializationGraphMutable()
|
||||
|
||||
|
||||
def init_graph_is_mutable(ctx: mgp.ProcCtx):
|
||||
write_init_graph_mutable.set()
|
||||
|
||||
|
||||
def graph_is_mutable(ctx: mgp.ProcCtx) -> mgp.Record(mutable=bool, init_called=bool):
|
||||
if write_init_graph_mutable.get_to_return() > 0:
|
||||
write_init_graph_mutable.increment_returned(1)
|
||||
return mgp.Record(mutable=ctx.graph.is_mutable(), init_called=write_init_graph_mutable.get())
|
||||
return []
|
||||
|
||||
|
||||
def cleanup_graph():
|
||||
write_init_graph_mutable.reset()
|
||||
|
||||
|
||||
# Register batched
|
||||
mgp.add_batch_write_proc(graph_is_mutable, init_graph_is_mutable, cleanup_graph)
|
@ -0,0 +1 @@
|
||||
copy_batched_procedures_e2e_python_files(shared.py)
|
31
tests/e2e/batched_procedures/procedures/common/shared.py
Normal file
31
tests/e2e/batched_procedures/procedures/common/shared.py
Normal file
@ -0,0 +1,31 @@
|
||||
class BaseClass:
|
||||
def __init__(self, num_to_return=1) -> None:
|
||||
self._init_is_called = False
|
||||
self._num_to_return = num_to_return
|
||||
self._num_returned = 0
|
||||
|
||||
def reset(self):
|
||||
self._init_is_called = False
|
||||
self._num_returned = 0
|
||||
|
||||
def set(self):
|
||||
self._init_is_called = True
|
||||
|
||||
def get(self):
|
||||
return self._init_is_called
|
||||
|
||||
def increment_returned(self, returned: int):
|
||||
self._num_returned += returned
|
||||
|
||||
def get_to_return(self) -> int:
|
||||
return self._num_to_return - self._num_returned
|
||||
|
||||
|
||||
class InitializationUnderlyingGraphMutable(BaseClass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class InitializationGraphMutable(BaseClass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
143
tests/e2e/batched_procedures/simple_read.py
Normal file
143
tests/e2e/batched_procedures/simple_read.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2023 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
# isort: off
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
from common import execute_and_fetch_all, has_n_result_row, has_one_result_row
|
||||
from conftest import get_connection
|
||||
from mgclient import DatabaseError
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_write",
|
||||
[
|
||||
True,
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_graph_mutability(is_write: bool, connection):
|
||||
cursor = connection.cursor()
|
||||
|
||||
execute_and_fetch_all(cursor, f"MATCH (n) DETACH DELETE n")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
module = "write" if is_write else "read"
|
||||
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"CALL batch_py_{module}.graph_is_mutable() " "YIELD mutable, init_called RETURN mutable, init_called",
|
||||
)
|
||||
)
|
||||
assert result == [(is_write, True)]
|
||||
|
||||
execute_and_fetch_all(cursor, "CREATE ()")
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
"MATCH (n) "
|
||||
f"CALL batch_py_{module}.underlying_graph_is_mutable(n) "
|
||||
"YIELD mutable, init_called RETURN mutable, init_called",
|
||||
)
|
||||
)
|
||||
assert result == [(is_write, True)]
|
||||
|
||||
execute_and_fetch_all(cursor, "CREATE ()-[:TYPE]->()")
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
"MATCH (n)-[e]->(m) "
|
||||
f"CALL batch_py_{module}.underlying_graph_is_mutable(e) "
|
||||
"YIELD mutable, init_called RETURN mutable, init_called",
|
||||
)
|
||||
)
|
||||
assert result == [(is_write, True)]
|
||||
|
||||
|
||||
def test_batching_nums(connection):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, f"MATCH (n) DETACH DELETE n")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"CALL batch_py_read.batch_nums() " "YIELD num, init_called, is_valid RETURN num, init_called, is_valid",
|
||||
)
|
||||
)
|
||||
assert result == [(i, True, True) for i in range(1, 11)]
|
||||
|
||||
execute_and_fetch_all(cursor, "CREATE () CREATE ()")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 2)
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
"MATCH (n) "
|
||||
"CALL batch_py_read.batch_nums() "
|
||||
"YIELD num, init_called, is_valid RETURN num, init_called, is_valid ",
|
||||
)
|
||||
)
|
||||
assert result == [(i, True, True) for i in range(1, 11)] * 2
|
||||
|
||||
|
||||
def test_batching_vertices(connection):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, f"MATCH (n) DETACH DELETE n")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
execute_and_fetch_all(cursor, f"CREATE () CREATE ()")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 2)
|
||||
|
||||
with pytest.raises(DatabaseError):
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor, f"CALL batch_py_read.batch_vertices() " "YIELD vertex, init_called RETURN vertex, init_called"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_batching_nums_c(connection):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, f"MATCH (n) DETACH DELETE n")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
num_ints = 10
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"CALL batch_c_read.batch_nums({num_ints}) " "YIELD output RETURN output",
|
||||
)
|
||||
)
|
||||
result_list = [item[0] for item in result]
|
||||
print(result_list)
|
||||
print([i for i in range(1, num_ints + 1)])
|
||||
assert result_list == [i for i in range(1, num_ints + 1)]
|
||||
|
||||
|
||||
def test_batching_strings_c(connection):
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor, f"MATCH (n) DETACH DELETE n")
|
||||
assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
|
||||
|
||||
num_strings = 10
|
||||
result = list(
|
||||
execute_and_fetch_all(
|
||||
cursor,
|
||||
f"CALL batch_c_read.batch_strings({num_strings}) " "YIELD output RETURN output",
|
||||
)
|
||||
)
|
||||
assert len(result) == num_strings
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-rA"]))
|
14
tests/e2e/batched_procedures/workloads.yaml
Normal file
14
tests/e2e/batched_procedures/workloads.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
template_cluster: &template_cluster
|
||||
cluster:
|
||||
main:
|
||||
args: ["--bolt-port", "7687", "--log-level=TRACE"]
|
||||
log_file: "batched-procedures-e2e.log"
|
||||
setup_queries: []
|
||||
validation_queries: []
|
||||
|
||||
workloads:
|
||||
- name: "Batched procedures read"
|
||||
binary: "tests/e2e/pytest_runner.sh"
|
||||
proc: "tests/e2e/batched_procedures/procedures/"
|
||||
args: ["batched_procedures/simple_read.py"]
|
||||
<<: *template_cluster
|
Loading…
Reference in New Issue
Block a user