Add python & cpp batching option in procedures

* Add API for batching from the procedure 
* Use PoolResource for batched procedures
This commit is contained in:
Antonio Filipovic 2023-06-26 15:46:13 +02:00 committed by GitHub
parent 00226dee24
commit d573eda8bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1155 additions and 70 deletions

View File

@ -677,6 +677,16 @@ inline mgp_proc *module_add_write_procedure(mgp_module *module, const char *name
return MgInvoke<mgp_proc *>(mgp_module_add_write_procedure, module, name, cb);
}
/// C++ convenience wrapper: registers a batched read procedure (callback plus
/// initializer/cleanup hooks) on `module` through the C API and returns the
/// resulting mgp_proc handle.
inline mgp_proc *module_add_batch_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
                                                 mgp_proc_initializer initializer, mgp_proc_cleanup cleanup) {
  auto *proc = MgInvoke<mgp_proc *>(mgp_module_add_batch_read_procedure, module, name, cb, initializer, cleanup);
  return proc;
}
/// C++ convenience wrapper: registers a batched write procedure (callback plus
/// initializer/cleanup hooks) on `module` through the C API and returns the
/// resulting mgp_proc handle.
inline mgp_proc *module_add_batch_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
                                                  mgp_proc_initializer initializer, mgp_proc_cleanup cleanup) {
  auto *proc = MgInvoke<mgp_proc *>(mgp_module_add_batch_write_procedure, module, name, cb, initializer, cleanup);
  return proc;
}
/// Adds a required positional argument `name` of `type` to procedure `proc`.
inline void proc_add_arg(mgp_proc *proc, const char *name, mgp_type *type) {
  MgInvokeVoid(mgp_proc_add_arg, proc, name, type);
}

View File

@ -1318,6 +1318,13 @@ MGP_ENUM_CLASS mgp_log_level{
/// to allocate global resources.
typedef void (*mgp_proc_cb)(struct mgp_list *, struct mgp_graph *, struct mgp_result *, struct mgp_memory *);
/// Cleanup callback for a query module batched procedure. Can't be invoked through OpenCypher. Cleans up the batched
/// stream.
typedef void (*mgp_proc_cleanup)();
/// Initializer for a query module batched read procedure. Can't be invoked through OpenCypher. Initializes batched
/// stream.
typedef void (*mgp_proc_initializer)(struct mgp_list *, struct mgp_graph *, struct mgp_memory *);
/// Register a read-only procedure to a module.
///
/// The `name` must be a sequence of digits, underscores, lowercase and
@ -1342,6 +1349,30 @@ enum mgp_error mgp_module_add_read_procedure(struct mgp_module *module, const ch
enum mgp_error mgp_module_add_write_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
struct mgp_proc **result);
/// Register a readable batched procedure to a module.
///
/// The `name` must be a valid identifier, following the same rules as the
/// procedure `name` in mgp_module_add_read_procedure.
///
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_proc.
/// Return mgp_error::MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid procedure name.
/// Return mgp_error::MGP_ERROR_LOGIC_ERROR if a procedure with the same name was already registered.
enum mgp_error mgp_module_add_batch_read_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
struct mgp_proc **result);
/// Register a writeable batched procedure to a module.
///
/// The `name` must be a valid identifier, following the same rules as the
/// procedure `name` in mgp_module_add_read_procedure.
///
/// Return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate memory for mgp_proc.
/// Return mgp_error::MGP_ERROR_INVALID_ARGUMENT if `name` is not a valid procedure name.
/// Return mgp_error::MGP_ERROR_LOGIC_ERROR if a procedure with the same name was already registered.
enum mgp_error mgp_module_add_batch_write_procedure(struct mgp_module *module, const char *name, mgp_proc_cb cb,
mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
struct mgp_proc **result);
/// Add a required argument to a procedure.
///
/// The order of adding arguments will correspond to the order the procedure

View File

@ -1319,6 +1319,20 @@ inline void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureT
std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
mgp_memory *memory);
/// @brief Adds a batch procedure to the query module.
/// @param callback - procedure callback
/// @param initializer - procedure initializer
/// @param cleanup - procedure cleanup
/// @param name - procedure name
/// @param proc_type - procedure type (read/write)
/// @param parameters - procedure parameters
/// @param returns - procedure return values
/// @param module - the query module that the procedure is added to
/// @param memory - access to memory
inline void AddBatchProcedure(mgp_proc_cb callback, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
std::string_view name, ProcedureType proc_type, std::vector<Parameter> parameters,
std::vector<Return> returns, mgp_module *module, mgp_memory *memory);
/// @brief Adds a function to the query module.
/// @param callback - function callback
/// @param name - function name
@ -3430,14 +3444,11 @@ inline mgp_type *Return::GetMGPType() const {
return util::ToMGPType(type_);
}
void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType proc_type,
std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
mgp_memory *memory) {
auto proc = (proc_type == ProcedureType::Read) ? mgp::module_add_read_procedure(module, name.data(), callback)
: mgp::module_add_write_procedure(module, name.data(), callback);
// do not enter
namespace detail {
void AddParamsReturnsToProc(mgp_proc *proc, std::vector<Parameter> &parameters, const std::vector<Return> &returns) {
for (const auto &parameter : parameters) {
auto parameter_name = parameter.name.data();
const auto *parameter_name = parameter.name.data();
if (!parameter.optional) {
mgp::proc_add_arg(proc, parameter_name, parameter.GetMGPType());
} else {
@ -3446,18 +3457,35 @@ void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType pro
}
for (const auto return_ : returns) {
auto return_name = return_.name.data();
const auto *return_name = return_.name.data();
mgp::proc_add_result(proc, return_name, return_.GetMGPType());
}
}
} // namespace detail
/// @brief Registers a (non-batched) read or write procedure on the query
/// module, then registers its parameters and result fields.
/// @param memory - unused here; kept for a uniform registration signature.
void AddProcedure(mgp_proc_cb callback, std::string_view name, ProcedureType proc_type,
                  std::vector<Parameter> parameters, std::vector<Return> returns, mgp_module *module,
                  mgp_memory *memory) {
  mgp_proc *proc = nullptr;
  if (proc_type == ProcedureType::Read) {
    proc = mgp::module_add_read_procedure(module, name.data(), callback);
  } else {
    proc = mgp::module_add_write_procedure(module, name.data(), callback);
  }
  detail::AddParamsReturnsToProc(proc, parameters, returns);
}
/// @brief Registers a batched read or write procedure (with its initializer
/// and cleanup hooks) on the query module, then registers its parameters and
/// result fields.
/// @param memory - unused here; kept for a uniform registration signature.
void AddBatchProcedure(mgp_proc_cb callback, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
                       std::string_view name, ProcedureType proc_type, std::vector<Parameter> parameters,
                       std::vector<Return> returns, mgp_module *module, mgp_memory *memory) {
  mgp_proc *proc = nullptr;
  if (proc_type == ProcedureType::Read) {
    proc = mgp::module_add_batch_read_procedure(module, name.data(), callback, initializer, cleanup);
  } else {
    proc = mgp::module_add_batch_write_procedure(module, name.data(), callback, initializer, cleanup);
  }
  detail::AddParamsReturnsToProc(proc, parameters, returns);
}
void AddFunction(mgp_func_cb callback, std::string_view name, std::vector<Parameter> parameters, mgp_module *module,
mgp_memory *memory) {
auto func = mgp::module_add_function(module, name.data(), callback);
auto *func = mgp::module_add_function(module, name.data(), callback);
for (const auto &parameter : parameters) {
auto parameter_name = parameter.name.data();
const auto *parameter_name = parameter.name.data();
if (!parameter.optional) {
mgp::func_add_arg(func, parameter_name, parameter.GetMGPType());

View File

@ -1402,6 +1402,13 @@ class UnsupportedTypingError(Exception):
super().__init__("Unsupported typing annotation '{}'".format(type_))
class UnequalTypesError(Exception):
    """Raised when two typing annotations that are required to match differ."""

    def __init__(self, type1_: typing.Any, type2_: typing.Any):
        message = f"Unequal typing annotation '{type1_}' and '{type2_}'"
        super().__init__(message)
def _typing_to_cypher_type(type_):
"""Convert typing annotation to a _mgp.CypherType instance."""
simple_types = {
@ -1514,6 +1521,72 @@ def _typing_to_cypher_type(type_):
return parse_typing(str(type_))
def _is_typing_same(type1_, type2_):
    """Return True if the two typing annotations describe the same type.

    Used to check that a batched procedure and its initializer declare
    matching parameter types. Raises UnequalTypesError when the annotations
    are provably different; returns True when they match or when the check
    cannot be performed (Python < 3.8 lacks typing.get_origin/get_args).
    """
    # Distinct ids per supported simple type; equal ids <=> same annotation.
    simple_types = {
        typing.Any: 1,
        object: 2,
        list: 3,
        Any: 4,
        bool: 5,
        str: 6,
        int: 7,
        float: 8,
        Number: 9,
        Map: 10,
        Vertex: 11,
        Edge: 12,
        Path: 13,
        Date: 14,
        LocalTime: 15,
        LocalDateTime: 16,
        Duration: 17,
    }
    try:
        return simple_types[type1_] == simple_types[type2_]
    except KeyError:
        pass
    if sys.version_info < (3, 8):
        # typing.get_origin/get_args are unavailable; skip deep type checks.
        return True
    # BUGFIX: the original read `typing.get_args(type2_)` for type_args1 and
    # `typing.get_origin(type1_)` for complex_type2, so type2_ was effectively
    # compared against itself and real mismatches went undetected.
    complex_type1 = typing.get_origin(type1_)
    type_args1 = typing.get_args(type1_)
    complex_type2 = typing.get_origin(type2_)
    type_args2 = typing.get_args(type2_)
    if complex_type2 != complex_type1:
        raise UnequalTypesError(type1_, type2_)
    if complex_type1 == typing.Union:
        contains_none_arg1 = type(None) in type_args1
        contains_none_arg2 = type(None) in type_args2
        if contains_none_arg1 != contains_none_arg2:
            raise UnequalTypesError(type1_, type2_)
        if contains_none_arg1:
            # Strip NoneType and compare the remaining Union members.
            types1 = tuple(t for t in type_args1 if t is not type(None))  # noqa E721
            types2 = tuple(t for t in type_args2 if t is not type(None))  # noqa E721
            if len(types1) != len(types2):
                raise UnequalTypesError(types1, types2)
            if len(types1) == 1:
                (type_arg1,) = types1
                (type_arg2,) = types2
            else:
                type_arg1 = typing.Union.__getitem__(types1)
                type_arg2 = typing.Union.__getitem__(types2)
            return _is_typing_same(type_arg1, type_arg2)
    elif complex_type1 == list:
        # Compare the element type of List[...] annotations.
        (type_arg1,) = type_args1
        (type_arg2,) = type_args2
        return _is_typing_same(type_arg1, type_arg2)
    # Anything else (e.g. Union without None): skip deeper type checks.
    return True
# Procedure registration
@ -1673,6 +1746,92 @@ def write_proc(func: typing.Callable[..., Record]):
return _register_proc(func, True)
def _register_batch_proc(
    func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable, is_write: bool
):
    """Register `func` as a batched procedure with `initializer`/`cleanup` hooks.

    Validates that `func` and `initializer` have compatible signatures (same
    number of parameters with matching type annotations, and a leading ProcCtx
    parameter on both or on neither), registers the procedure with the module,
    then registers its arguments and result fields. Returns `func` unchanged so
    callers can use this as a decorator-style helper.
    """
    raise_if_does_not_meet_requirements(func)
    register_func = _mgp.Module.add_batch_write_procedure if is_write else _mgp.Module.add_batch_read_procedure
    func_sig = inspect.signature(func)
    func_params = tuple(func_sig.parameters.values())
    initializer_sig = inspect.signature(initializer)
    initializer_params = tuple(initializer_sig.parameters.values())
    # Either both callables take parameters or neither does.
    assert (
        func_params and initializer_params or not func_params and not initializer_params
    ), "Both function params and initializer params must exist or not exist"
    assert len(func_params) == len(initializer_params), "Number of params must be same"
    assert initializer_sig.return_annotation is initializer_sig.empty, "Initializer can't return anything"
    if func_params and func_params[0].annotation is ProcCtx:
        # Procedure wants graph access; the initializer must accept ProcCtx too.
        assert (
            initializer_params and initializer_params[0].annotation is ProcCtx
        ), "Initializer must have mgp.ProcCtx as first parameter"

        @wraps(func)
        def wrapper_func(graph, args):
            return func(ProcCtx(graph), *args)

        @wraps(initializer)
        def wrapper_initializer(graph, args):
            return initializer(ProcCtx(graph), *args)

        # Drop the ProcCtx parameter; only user-facing args are registered below.
        func_params = func_params[1:]
        initializer_params = initializer_params[1:]
        mgp_proc = register_func(_mgp._MODULE, wrapper_func, wrapper_initializer, cleanup)
    else:

        @wraps(func)
        def wrapper_func(graph, args):
            return func(*args)

        @wraps(initializer)
        def wrapper_initializer(graph, args):
            return initializer(*args)

        mgp_proc = register_func(_mgp._MODULE, wrapper_func, wrapper_initializer, cleanup)
    for func_param, initializer_param in zip(func_params, initializer_params):
        func_param_name = func_param.name
        func_param_type_ = func_param.annotation
        if func_param_type_ is func_param.empty:
            # A missing annotation means "accept any value".
            func_param_type_ = object
        initializer_param_type_ = initializer_param.annotation
        if initializer_param.annotation is initializer_param.empty:
            initializer_param_type_ = object
        assert _is_typing_same(
            func_param_type_, initializer_param_type_
        ), "Types of initializer and function must be same"
        func_cypher_type = _typing_to_cypher_type(func_param_type_)
        if func_param.default is func_param.empty:
            mgp_proc.add_arg(func_param_name, func_cypher_type)
        else:
            mgp_proc.add_opt_arg(func_param_name, func_cypher_type, func_param.default)
    if func_sig.return_annotation is not func_sig.empty:
        record = func_sig.return_annotation
        if not isinstance(record, Record):
            raise TypeError("Expected '{}' to return 'mgp.Record', got '{}'".format(func.__name__, type(record)))
        for name, type_ in record.fields.items():
            if isinstance(type_, Deprecated):
                # Deprecated fields are still registered, but flagged for clients.
                cypher_type = _typing_to_cypher_type(type_.field_type)
                mgp_proc.add_deprecated_result(name, cypher_type)
            else:
                mgp_proc.add_result(name, _typing_to_cypher_type(type_))
    return func
def add_batch_write_proc(func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable):
    """Register `func` as a batched write procedure with `initializer` and `cleanup` hooks; returns `func`."""
    return _register_batch_proc(func, initializer, cleanup, True)
def add_batch_read_proc(func: typing.Callable[..., Record], initializer: typing.Callable, cleanup: typing.Callable):
    """Register `func` as a batched read procedure with `initializer` and `cleanup` hooks; returns `func`."""
    return _register_batch_proc(func, initializer, cleanup, False)
class InvalidMessageError(Exception):
"""
Signals using a message instance outside of the registered transformation.

View File

@ -10,6 +10,7 @@
// licenses/APL.txt.
#include "query/interpreter.hpp"
#include <bits/ranges_algo.h>
#include <fmt/core.h>
#include <algorithm>
@ -50,6 +51,7 @@
#include "query/plan/planner.hpp"
#include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp"
#include "query/procedure/module.hpp"
#include "query/stream.hpp"
#include "query/stream/common.hpp"
#include "query/trigger.hpp"
@ -78,7 +80,6 @@
#include "utils/tsc.hpp"
#include "utils/typeinfo.hpp"
#include "utils/variant_helpers.hpp"
namespace memgraph::metrics {
extern Event ReadQuery;
extern Event WriteQuery;
@ -1288,6 +1289,30 @@ inline static void TryCaching(const AstStorage &ast_storage, FrameChangeCollecto
}
}
/// Returns true if any of `clauses` is a CALL of a procedure registered as
/// batched — in which case the interpreter prefers PoolResource over
/// monotonic memory for the query execution.
/// @throw QueryRuntimeException if a called procedure cannot be found.
bool IsCallBatchedProcedureQuery(const std::vector<memgraph::query::Clause *> &clauses) {
  EvaluationContext evaluation_context;
  // BUGFIX: the original had an unreachable `return false;` after the
  // `return std::ranges::any_of(...)` statement; the dead code is removed.
  return std::ranges::any_of(clauses, [&evaluation_context](const auto &clause) -> bool {
    if (clause->GetTypeInfo() != CallProcedure::kType) {
      return false;
    }
    auto *call_procedure_clause = utils::Downcast<CallProcedure>(clause);
    const auto &maybe_found = memgraph::query::procedure::FindProcedure(
        procedure::gModuleRegistry, call_procedure_clause->procedure_name_, evaluation_context.memory);
    if (!maybe_found) {
      throw QueryRuntimeException("There is no procedure named '{}'.", call_procedure_clause->procedure_name_);
    }
    const auto &[module, proc] = *maybe_found;
    if (!proc->info.is_batched) {
      return false;
    }
    spdlog::trace("Using PoolResource for batched query procedure");
    return true;
  });
}
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
@ -1319,7 +1344,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
contains_csv = true;
}
// If this is a LOAD CSV or batched-procedure query, use PoolResource without MonotonicMemoryResource as we want to
// reuse allocated memory
auto use_monotonic_memory = !contains_csv;
auto use_monotonic_memory = !contains_csv && !IsCallBatchedProcedureQuery(clauses);
auto plan = CypherQueryToPlan(parsed_query.stripped_query.hash(), std::move(parsed_query.ast_storage), cypher_query,
parsed_query.parameters,
parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
@ -1451,7 +1476,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
contains_csv = true;
}
// If this is a LOAD CSV or batched-procedure query, use PoolResource without MonotonicMemoryResource as we want to
// reuse allocated memory
auto use_monotonic_memory = !contains_csv;
auto use_monotonic_memory = !contains_csv && !IsCallBatchedProcedureQuery(clauses);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE");
Frame frame(0);
@ -2858,14 +2883,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
auto *profile_query = utils::Downcast<ProfileQuery>(parsed_query.query);
cypher_query = profile_query->cypher_query_;
}
if (const auto &clauses = cypher_query->single_query_->clauses_;
std::any_of(clauses.begin(), clauses.end(),
[](const auto *clause) { return clause->GetTypeInfo() == LoadCsv::kType; })) {
IsCallBatchedProcedureQuery(clauses) || std::any_of(clauses.begin(), clauses.end(), [](const auto *clause) {
return clause->GetTypeInfo() == LoadCsv::kType;
})) {
// Using PoolResource without MonotonicMemoryResouce for LOAD CSV reduces memory usage.
// QueryExecution MemoryResource is mostly used for allocations done on Frame and storing `row`s
query_executions_[query_executions_.size() - 1] = std::make_unique<QueryExecution>(
utils::PoolResource(8, kExecutionPoolMaxBlockSize, utils::NewDeleteResource(), utils::NewDeleteResource()));
query_executions_[query_executions_.size() - 1] = std::make_unique<QueryExecution>(utils::PoolResource(
128, kExecutionPoolMaxBlockSize, utils::NewDeleteResource(), utils::NewDeleteResource()));
query_execution_ptr = &query_executions_.back();
}
}

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>
#include <queue>
#include <random>
#include <string>
@ -4465,7 +4466,8 @@ namespace {
void CallCustomProcedure(const std::string_view fully_qualified_procedure_name, const mgp_proc &proc,
const std::vector<Expression *> &args, mgp_graph &graph, ExpressionEvaluator *evaluator,
utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result) {
utils::MemoryResource *memory, std::optional<size_t> memory_limit, mgp_result *result,
const bool call_initializer = false) {
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
"Expected mgp_value to use custom allocator and makes STL "
"containers aware of that");
@ -4489,6 +4491,11 @@ void CallCustomProcedure(const std::string_view fully_qualified_procedure_name,
}
procedure::ConstructArguments(args_list, proc, fully_qualified_procedure_name, proc_args, graph);
if (call_initializer) {
MG_ASSERT(proc.initializer);
mgp_memory initializer_memory{memory};
proc.initializer.value()(&proc_args, &graph, &initializer_memory);
}
if (memory_limit) {
SPDLOG_INFO("Running '{}' with memory limit of {}", fully_qualified_procedure_name,
utils::GetReadableSize(*memory_limit));
@ -4516,18 +4523,22 @@ void CallCustomProcedure(const std::string_view fully_qualified_procedure_name,
class CallProcedureCursor : public Cursor {
const CallProcedure *self_;
UniqueCursorPtr input_cursor_;
mgp_result result_;
decltype(result_.rows.end()) result_row_it_{result_.rows.end()};
mgp_result *result_;
decltype(result_->rows.end()) result_row_it_{result_->rows.end()};
size_t result_signature_size_{0};
bool stream_exhausted{true};
bool call_initializer{false};
std::optional<std::function<void()>> cleanup_{std::nullopt};
public:
CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem)
: self_(self),
input_cursor_(self_->input_->MakeCursor(mem)),
// result_ needs to live throughout multiple Pull evaluations, until all
// rows are produced. Therefore, we use the memory dedicated for the
// whole execution.
result_(nullptr, mem) {
// rows are produced. We don't use the memory dedicated for QueryExecution (and Frame),
// but memory dedicated for procedure to wipe result_ and everything allocated in procedure all at once.
result_(utils::Allocator<mgp_result>(self_->memory_resource)
.new_object<mgp_result>(nullptr, self_->memory_resource)) {
MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure");
}
@ -4541,11 +4552,7 @@ class CallProcedureCursor : public Cursor {
// empty result set vs procedures which return `void`. We currently don't
// have procedures registering what they return.
// This `while` loop will skip over empty results.
while (result_row_it_ == result_.rows.end()) {
if (!input_cursor_->Pull(frame, context)) return false;
result_.signature = nullptr;
result_.rows.clear();
result_.error_msg.reset();
while (result_row_it_ == result_->rows.end()) {
// It might be a good idea to resolve the procedure name once, at the
// start. Unfortunately, this could deadlock if we tried to invoke a
// procedure from a module (read lock) and reload a module (write lock)
@ -4565,30 +4572,61 @@ class CallProcedureCursor : public Cursor {
self_->procedure_name_, get_proc_type_str(self_->is_write_),
get_proc_type_str(proc->info.is_write));
}
if (!proc->info.is_batched) {
stream_exhausted = true;
}
if (stream_exhausted) {
if (!input_cursor_->Pull(frame, context)) {
if (proc->cleanup) {
proc->cleanup.value()();
}
return false;
}
stream_exhausted = false;
if (proc->initializer) {
call_initializer = true;
MG_ASSERT(proc->cleanup);
proc->cleanup.value()();
}
}
if (!cleanup_ && proc->cleanup) [[unlikely]] {
cleanup_.emplace(*proc->cleanup);
}
// Releasing the memory resource without destroying each object individually, since everything was allocated from
// this resource and is reclaimed all at once
self_->monotonic_memory.Release();
result_ =
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
const auto graph_view = proc->info.is_write ? storage::View::NEW : storage::View::OLD;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
graph_view);
result_.signature = &proc->results;
// Use evaluation memory, as invoking a procedure is akin to a simple
// evaluation of an expression.
result_->signature = &proc->results;
// Use special memory as invoking procedure is complex
// TODO: This will probably need to be changed when we add support for
// generator like procedures which yield a new result on each invocation.
auto *memory = context.evaluation_context.memory;
// generator like procedures which yield a new result on new query calls.
auto *memory = self_->memory_resource;
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
&result_);
result_, call_initializer);
if (call_initializer) call_initializer = false;
// Reset result_.signature to nullptr, because outside of this scope we
// will no longer hold a lock on the `module`. If someone were to reload
// it, the pointer would be invalid.
result_signature_size_ = result_.signature->size();
result_.signature = nullptr;
if (result_.error_msg) {
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_.error_msg);
result_signature_size_ = result_->signature->size();
result_->signature = nullptr;
if (result_->error_msg) {
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg);
}
result_row_it_ = result_.rows.begin();
result_row_it_ = result_->rows.begin();
stream_exhausted = result_row_it_ == result_->rows.end();
}
auto &values = result_row_it_->values;
@ -4621,12 +4659,20 @@ class CallProcedureCursor : public Cursor {
}
void Reset() override {
result_.rows.clear();
result_.error_msg.reset();
input_cursor_->Reset();
self_->monotonic_memory.Release();
result_ =
utils::Allocator<mgp_result>(self_->memory_resource).new_object<mgp_result>(nullptr, self_->memory_resource);
if (cleanup_) {
cleanup_.value()();
}
}
void Shutdown() override {}
void Shutdown() override {
self_->monotonic_memory.Release();
if (cleanup_) {
cleanup_.value()();
}
}
};
UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const {

View File

@ -2199,6 +2199,8 @@ class CallProcedure : public memgraph::query::plan::LogicalOperator {
Expression *memory_limit_{nullptr};
size_t memory_scale_{1024U};
bool is_write_;
mutable utils::MonotonicBufferResource monotonic_memory{1024UL * 1024UL};
utils::MemoryResource *memory_resource = &monotonic_memory;
std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
auto object = std::make_unique<CallProcedure>();

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -2731,6 +2731,7 @@ mgp_error mgp_type_nullable(mgp_type *type, mgp_type **result) {
}
namespace {
/// @throw std::bad_alloc, std::length_error
mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_proc_cb cb,
const ProcedureInfo &procedure_info) {
if (!IsValidIdentifierName(name)) {
@ -2741,9 +2742,24 @@ mgp_proc *mgp_module_add_procedure(mgp_module *module, const char *name, mgp_pro
};
auto *memory = module->procedures.get_allocator().GetMemoryResource();
// May throw std::bad_alloc, std::length_error
return &module->procedures.emplace(name, mgp_proc(name, cb, memory, procedure_info)).first->second;
}
/// Common implementation for registering a batched (read or write) procedure.
/// Validates the name, rejects duplicates, and stores the mgp_proc (with its
/// initializer and cleanup hooks) in the module's procedure map.
/// @throw std::invalid_argument if `name` is not a valid identifier.
/// @throw std::logic_error if a procedure named `name` is already registered.
/// @throw std::bad_alloc, std::length_error
mgp_proc *mgp_module_add_batch_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
                                         mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
                                         const ProcedureInfo &procedure_info) {
  if (!IsValidIdentifierName(name)) {
    throw std::invalid_argument{fmt::format("Invalid procedure name: {}", name)};
  }
  if (module->procedures.find(name) != module->procedures.end()) {
    throw std::logic_error{fmt::format("Procedure already exists with name '{}'", name)};
  };
  // Procedures are stored with the module's own memory resource.
  auto *memory = module->procedures.get_allocator().GetMemoryResource();
  // May throw std::bad_alloc, std::length_error.
  return &module->procedures.emplace(name, mgp_proc(name, cb_batch, initializer, cleanup, memory, procedure_info))
      .first->second;
}
} // namespace
mgp_error mgp_module_add_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb, mgp_proc **result) {
@ -2754,6 +2770,28 @@ mgp_error mgp_module_add_write_procedure(mgp_module *module, const char *name, m
return WrapExceptions([=] { return mgp_module_add_procedure(module, name, cb, {.is_write = true}); }, result);
}
/// C API entry point: registers `cb_batch` as a batched read procedure on
/// `module`, translating any thrown exception into an mgp_error code.
mgp_error mgp_module_add_batch_read_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
                                              mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
                                              mgp_proc **result) {
  return WrapExceptions(
      [module, name, cb_batch, initializer, cleanup] {
        const ProcedureInfo info{.is_write = false, .is_batched = true};
        return mgp_module_add_batch_procedure(module, name, cb_batch, initializer, cleanup, info);
      },
      result);
}
/// C API entry point: registers `cb_batch` as a batched write procedure on
/// `module`, translating any thrown exception into an mgp_error code.
mgp_error mgp_module_add_batch_write_procedure(mgp_module *module, const char *name, mgp_proc_cb cb_batch,
                                               mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
                                               mgp_proc **result) {
  return WrapExceptions(
      [module, name, cb_batch, initializer, cleanup] {
        const ProcedureInfo info{.is_write = true, .is_batched = true};
        return mgp_module_add_batch_procedure(module, name, cb_batch, initializer, cleanup, info);
      },
      result);
}
namespace {
template <typename T>
concept IsCallable = memgraph::utils::SameAsAnyOf<T, mgp_proc, mgp_func>;

View File

@ -729,7 +729,8 @@ struct mgp_type {
};
struct ProcedureInfo {
bool is_write = false;
bool is_write{false};
bool is_batched{false};
std::optional<memgraph::query::AuthQuery::Privilege> required_privilege = std::nullopt;
};
struct mgp_proc {
@ -740,6 +741,33 @@ struct mgp_proc {
mgp_proc(const char *name, mgp_proc_cb cb, memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, mgp_proc_cb cb, mgp_proc_initializer initializer, mgp_proc_cleanup cleanup,
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory),
cb(cb),
initializer(initializer),
cleanup(cleanup),
args(memory),
opt_args(memory),
results(memory),
info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
std::function<void(mgp_list *, mgp_graph *, mgp_memory *)> initializer, std::function<void()> cleanup,
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory),
cb(cb),
initializer(initializer),
cleanup(cleanup),
args(memory),
opt_args(memory),
results(memory),
info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
@ -757,6 +785,8 @@ struct mgp_proc {
mgp_proc(const mgp_proc &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory),
cb(other.cb),
initializer(other.initializer),
cleanup(other.cleanup),
args(other.args, memory),
opt_args(other.opt_args, memory),
results(other.results, memory),
@ -765,6 +795,8 @@ struct mgp_proc {
mgp_proc(mgp_proc &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory),
cb(std::move(other.cb)),
initializer(other.initializer),
cleanup(other.cleanup),
args(std::move(other.args), memory),
opt_args(std::move(other.opt_args), memory),
results(std::move(other.results), memory),
@ -782,6 +814,13 @@ struct mgp_proc {
memgraph::utils::pmr::string name;
/// Entry-point for the procedure.
std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
/// Initializer for batched procedure.
std::optional<std::function<void(mgp_list *, mgp_graph *, mgp_memory *)>> initializer;
/// Dtor for batched procedure.
std::optional<std::function<void()>> cleanup;
/// Required, positional arguments as a (name, type) pair.
memgraph::utils::pmr::vector<std::pair<memgraph::utils::pmr::string, const memgraph::query::procedure::CypherType *>>
args;

View File

@ -12,6 +12,7 @@
#include "query/procedure/py_module.hpp"
#include <datetime.h>
#include <methodobject.h>
#include <pyerrors.h>
#include <array>
#include <optional>
@ -54,6 +55,8 @@ PyObject *gMgpValueConversionError{nullptr}; // NOLINT(cppcoreguidelines-avo
PyObject *gMgpSerializationError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
PyObject *gMgpAuthorizationError{nullptr}; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
constexpr bool kStartGarbageCollection{true};
// Returns true if an exception is raised
bool RaiseExceptionFromErrorCode(const mgp_error error) {
switch (error) {
@ -945,20 +948,37 @@ std::optional<py::ExceptionInfo> AddMultipleRecordsFromPython(mgp_result *result
return std::nullopt;
}
std::function<void()> PyObjectCleanup(py::Object &py_object) {
return [py_object]() {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
/// Appends every record from the Python sequence `py_seq` to `result`, then
/// empties the sequence in place. Returns the pending Python exception info on
/// failure, std::nullopt on success.
std::optional<py::ExceptionInfo> AddMultipleBatchRecordsFromPython(mgp_result *result, py::Object py_seq,
                                                                   mgp_memory *memory) {
  Py_ssize_t len = PySequence_Size(py_seq.Ptr());
  if (len == -1) return py::FetchError();
  // One reservation for the whole batch avoids repeated growth below.
  result->rows.reserve(len);
  for (Py_ssize_t i = 0; i < len; ++i) {
    py::Object py_record(PySequence_GetItem(py_seq.Ptr(), i));
    if (!py_record) return py::FetchError();
    auto maybe_exc = AddRecordFromPython(result, py_record, memory);
    if (maybe_exc) return maybe_exc;
  }
  // Drop the consumed records from the (mutable) Python sequence — presumably
  // so the next batched invocation starts from an empty stream; confirm with
  // the batching callers.
  PySequence_DelSlice(py_seq.Ptr(), 0, PySequence_Size(py_seq.Ptr()));
  return std::nullopt;
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
std::function<void()> PyObjectCleanup(py::Object &py_object, bool start_gc) {
return [py_object, start_gc]() {
if (start_gc) {
// Run `gc.collect` (reference cycle-detection) explicitly, so that we are
// sure the procedure cleaned up everything it held references to. If the
// user stored a reference to one of our `_mgp` instances then the
// internally used `mgp_*` structs will stay unfreed and a memory leak
// will be reported at the end of the query execution.
py::Object gc(PyImport_ImportModule("gc"));
if (!gc) {
LOG_FATAL(py::FetchError().value());
}
if (!gc.CallMethod("collect")) {
LOG_FATAL(py::FetchError().value());
}
}
// After making sure all references from our side have been cleared,
@ -973,8 +993,7 @@ std::function<void()> PyObjectCleanup(py::Object &py_object) {
}
void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
// *memory here is memory from `EvalContext`
mgp_memory *memory, bool is_batched) {
auto gil = py::EnsureGIL();
auto error_to_msg = [](const std::optional<py::ExceptionInfo> &exc_info) -> std::optional<std::string> {
@ -992,10 +1011,12 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
auto py_res = py_cb.Call(py_graph, py_args);
if (!py_res) return py::FetchError();
if (PySequence_Check(py_res.Ptr())) {
if (is_batched) {
return AddMultipleBatchRecordsFromPython(result, py_res, memory);
}
return AddMultipleRecordsFromPython(result, py_res, memory);
} else {
return AddRecordFromPython(result, py_res, memory);
}
return AddRecordFromPython(result, py_res, memory);
};
// It is *VERY IMPORTANT* to note that this code takes great care not to keep
@ -1010,7 +1031,7 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph, !is_batched));
if (py_graph) {
maybe_msg = error_to_msg(call(py_graph));
} else {
@ -1023,6 +1044,38 @@ void CallPythonProcedure(const py::Object &py_cb, mgp_list *args, mgp_graph *gra
}
}
// Invokes the registered Python cleanup callback of a batched procedure, then
// forces a garbage-collection cycle so any `_mgp` wrappers the procedure kept
// references to are released before the query finishes.
void CallPythonCleanup(const py::Object &py_cleanup) {
  auto gil = py::EnsureGIL();
  // NOTE(review): the call result -- and any Python exception the cleanup
  // raised -- is discarded here; confirm failures are intentionally ignored.
  auto py_res = py_cleanup.Call();
  py::Object gc(PyImport_ImportModule("gc"));
  if (!gc) {
    LOG_FATAL(py::FetchError().value());
  }
  if (!gc.CallMethod("collect")) {
    LOG_FATAL(py::FetchError().value());
  }
}
// Invokes the registered Python initializer callback of a batched procedure.
// The initializer receives the procedure arguments and a graph wrapper; it
// produces no records.
void CallPythonInitializer(const py::Object &py_initializer, mgp_list *args, mgp_graph *graph, mgp_memory *memory) {
  auto gil = py::EnsureGIL();
  auto call = [&](py::Object py_graph) -> std::optional<py::ExceptionInfo> {
    py::Object py_args(MgpListToPyTuple(args, py_graph.Ptr()));
    if (!py_args) return py::FetchError();
    // NOTE(review): the initializer's return value and any exception it
    // raised are dropped -- confirm errors should not abort the procedure.
    auto py_res = py_initializer.Call(py_graph, py_args);
    return {};
  };
  {
    py::Object py_graph(MakePyGraph(graph, memory));
    // gc.collect() is skipped here (presumably the batched stream's cleanup
    // step runs it once at the end -- confirm).
    utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph, !kStartGarbageCollection));
    call(py_graph);
  }
}
void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_graph *graph, mgp_result *result,
mgp_memory *memory) {
auto gil = py::EnsureGIL();
@ -1059,8 +1112,8 @@ void CallPythonTransformation(const py::Object &py_cb, mgp_messages *msgs, mgp_g
py::Object py_graph(MakePyGraph(graph, memory));
py::Object py_messages(MakePyMessages(msgs, memory));
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph));
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages));
utils::OnScopeExit clean_up_graph(PyObjectCleanup(py_graph, kStartGarbageCollection));
utils::OnScopeExit clean_up_messages(PyObjectCleanup(py_messages, kStartGarbageCollection));
if (py_graph && py_messages) {
maybe_msg = error_to_msg(call(py_graph, py_messages));
@ -1111,7 +1164,7 @@ void CallPythonFunction(const py::Object &py_cb, mgp_list *args, mgp_graph *grap
std::optional<std::string> maybe_msg;
{
py::Object py_graph(MakePyGraph(graph, memory));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph));
utils::OnScopeExit clean_up(PyObjectCleanup(py_graph, kStartGarbageCollection));
if (py_graph) {
auto maybe_result = call(py_graph);
if (!maybe_result.HasError()) {
@ -1147,7 +1200,7 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
mgp_proc proc(name,
[py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
CallPythonProcedure(py_cb, args, graph, result, memory);
CallPythonProcedure(py_cb, args, graph, result, memory, false);
},
memory, {.is_write = is_write_procedure});
const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
@ -1160,6 +1213,52 @@ PyObject *PyQueryModuleAddProcedure(PyQueryModule *self, PyObject *cb, bool is_w
py_proc->callable = &proc_it->second;
return reinterpret_cast<PyObject *>(py_proc);
}
// Registers a batched procedure given a (callback, initializer, cleanup)
// triple packed in `args`. The callback's __name__ becomes the procedure
// name. Returns a new PyQueryProc, or nullptr with a Python error set.
PyObject *PyQueryModuleAddBatchProcedure(PyQueryModule *self, PyObject *args, bool is_write_procedure) {
  MG_ASSERT(self->module);
  PyObject *cb{nullptr};
  PyObject *initializer{nullptr};
  PyObject *cleanup{nullptr};
  if (!PyArg_ParseTuple(args, "OOO", &cb, &initializer, &cleanup)) {
    return nullptr;
  }
  // Validate all three callbacks up front; previously only `cb` was checked,
  // so a non-callable initializer/cleanup would only fail at invocation time.
  if (!PyCallable_Check(cb) || !PyCallable_Check(initializer) || !PyCallable_Check(cleanup)) {
    PyErr_SetString(PyExc_TypeError, "Expected a callable object.");
    return nullptr;
  }
  auto py_cb = py::Object::FromBorrow(cb);
  auto py_initializer = py::Object::FromBorrow(initializer);
  auto py_cleanup = py::Object::FromBorrow(cleanup);
  py::Object py_name(py_cb.GetAttr("__name__"));
  // GetAttr failure leaves the Python error set; without this guard
  // PyUnicode_AsUTF8 would be called on a null pointer.
  if (!py_name) return nullptr;
  const auto *name = PyUnicode_AsUTF8(py_name.Ptr());
  if (!name) return nullptr;
  if (!IsValidIdentifierName(name)) {
    PyErr_SetString(PyExc_ValueError, "Procedure name is not a valid identifier");
    return nullptr;
  }
  auto *memory = self->module->procedures.get_allocator().GetMemoryResource();
  mgp_proc proc(
      name,
      [py_cb](mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory *memory) {
        CallPythonProcedure(py_cb, args, graph, result, memory, true);
      },
      [py_initializer](mgp_list *args, mgp_graph *graph, mgp_memory *memory) {
        CallPythonInitializer(py_initializer, args, graph, memory);
      },
      [py_cleanup] { CallPythonCleanup(py_cleanup); }, memory, {.is_write = is_write_procedure, .is_batched = true});
  const auto &[proc_it, did_insert] = self->module->procedures.emplace(name, std::move(proc));
  if (!did_insert) {
    PyErr_SetString(PyExc_ValueError, "Already registered a procedure with the same name.");
    return nullptr;
  }
  auto *py_proc = PyObject_New(PyQueryProc, &PyQueryProcType);  // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  if (!py_proc) return nullptr;
  py_proc->callable = &proc_it->second;
  return reinterpret_cast<PyObject *>(py_proc);
}
} // namespace
PyObject *PyQueryModuleAddReadProcedure(PyQueryModule *self, PyObject *cb) {
@ -1170,6 +1269,14 @@ PyObject *PyQueryModuleAddWriteProcedure(PyQueryModule *self, PyObject *cb) {
return PyQueryModuleAddProcedure(self, cb, true);
}
// Python-exposed entry point: registers a read-only batched procedure from a
// (callback, initializer, cleanup) argument tuple.
PyObject *PyQueryModuleAddBatchReadProcedure(PyQueryModule *self, PyObject *args) {
  return PyQueryModuleAddBatchProcedure(self, args, false);
}
// Python-exposed entry point: registers a writeable batched procedure from a
// (callback, initializer, cleanup) argument tuple.
PyObject *PyQueryModuleAddBatchWriteProcedure(PyQueryModule *self, PyObject *args) {
  return PyQueryModuleAddBatchProcedure(self, args, true);
}
PyObject *PyQueryModuleAddTransformation(PyQueryModule *self, PyObject *cb) {
MG_ASSERT(self->module);
if (!PyCallable_Check(cb)) {
@ -1238,6 +1345,10 @@ static PyMethodDef PyQueryModuleMethods[] = {
"Register a read-only procedure with this module."},
{"add_write_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddWriteProcedure), METH_O,
"Register a writeable procedure with this module."},
{"add_batch_read_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddBatchReadProcedure), METH_VARARGS,
"Register a read-only batch procedure with this module."},
{"add_batch_write_procedure", reinterpret_cast<PyCFunction>(PyQueryModuleAddBatchWriteProcedure), METH_VARARGS,
"Register a writeable batched procedure with this module."},
{"add_transformation", reinterpret_cast<PyCFunction>(PyQueryModuleAddTransformation), METH_O,
"Register a transformation with this module."},
{"add_function", reinterpret_cast<PyCFunction>(PyQueryModuleAddFunction), METH_O,

View File

@ -58,6 +58,7 @@ add_subdirectory(mock_api)
add_subdirectory(load_csv)
add_subdirectory(init_file_flags)
add_subdirectory(analytical_mode)
add_subdirectory(batched_procedures)
copy_e2e_python_files(pytest_runner pytest_runner.sh "")
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/memgraph-selfsigned.crt DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

View File

@ -0,0 +1,9 @@
# Copies one batched-procedures e2e test file into the build directory,
# grouped under the `batched_procedures` e2e target.
function(copy_batched_procedures_e2e_python_files FILE_NAME)
  copy_e2e_python_files(batched_procedures ${FILE_NAME})
endfunction()
# Test-harness files shared by all batched-procedures workloads.
copy_batched_procedures_e2e_python_files(common.py)
copy_batched_procedures_e2e_python_files(conftest.py)
copy_batched_procedures_e2e_python_files(simple_read.py)

add_subdirectory(procedures)

View File

View File

@ -0,0 +1,34 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: typing.Optional[dict] = None) -> typing.List[tuple]:
    """Execute `query` with optional bound `params` and return all rows.

    The previous signature used a mutable default argument (`params: dict = {}`),
    a classic Python pitfall; `None` now serves as the sentinel and is
    substituted with a fresh empty dict per call.
    """
    cursor.execute(query, params if params is not None else {})
    return cursor.fetchall()
def connect(**kwargs) -> mgclient.Connection:
    """Open an autocommit Bolt connection to the Memgraph instance on
    localhost:7687; extra kwargs are forwarded to mgclient.connect."""
    connection = mgclient.connect(host="localhost", port=7687, **kwargs)
    connection.autocommit = True
    return connection
def has_n_result_row(cursor: mgclient.Cursor, query: str, n: int):
    """Return True when `query` yields exactly `n` rows."""
    return len(execute_and_fetch_all(cursor, query)) == n
def has_one_result_row(cursor: mgclient.Cursor, query: str):
    """Return True when `query` yields exactly one row."""
    return has_n_result_row(cursor, query, n=1)

View File

@ -0,0 +1,26 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import pytest
from common import connect, execute_and_fetch_all
@pytest.fixture(autouse=True)
def connection():
    # Fresh connection per test; the database is wiped after each test so
    # individual tests do not leak state into one another.
    connection = connect()
    yield connection
    cursor = connection.cursor()
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
def get_connection():
    """Return a new autocommit connection (non-fixture helper for tests that
    need an extra connection beyond the autouse one)."""
    return connect()

View File

@ -0,0 +1,7 @@
# Query modules exercised by the batched-procedures e2e tests.
copy_batched_procedures_e2e_python_files(batch_py_read.py)
copy_batched_procedures_e2e_python_files(batch_py_write.py)
add_query_module(batch_c_read batch_c_read.cpp)

add_subdirectory(common)

View File

@ -0,0 +1,136 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <exception>
#include <stdexcept>
#include <functional>
#include "mg_procedure.h"
#include "mgp.hpp"
namespace {
using mgp_int = decltype(mgp::Value().ValueInt());

// Per-stream state of the batched procedures below; reset by the cleanup
// hooks so the procedures can be invoked again.
static mgp_int num_ints{0};
static mgp_int returned_ints{0};
static mgp_int num_strings{0};
// Was plain `int`; use mgp_int so all four counters share one type.
static mgp_int returned_strings{0};

const char *kReturnOutput = "output";
// Initializer for `batch_nums`: validates the single int argument and stores
// how many integers the batched procedure should emit in total.
void NumsBatchInit(struct mgp_list *args, mgp_graph *graph, struct mgp_memory *memory) {
  mgp::memory = memory;
  const auto arguments = mgp::List(args);
  if (arguments.Empty()) {
    // Fixed typo in the error message ("recieve").
    throw std::runtime_error("Expected to receive an argument");
  }
  if (arguments[0].Type() != mgp::Type::Int) {
    throw std::runtime_error("Wrong type of first argument");
  }
  const auto num = arguments[0].ValueInt();
  num_ints = num;
}
// One batch step of `batch_nums`: emits at most one record (the next integer
// in 1..num_ints) per invocation. An invocation that produces no records
// presumably signals the end of the stream -- confirm against the
// batched-procedure contract.
void NumsBatch(struct mgp_list *args, mgp_graph *graph, mgp_result *result, struct mgp_memory *memory) {
  mgp::memory = memory;
  const auto arguments = mgp::List(args);
  const auto record_factory = mgp::RecordFactory(result);
  if (returned_ints < num_ints) {
    auto record = record_factory.NewRecord();
    // Pre-increment: the first emitted value is 1, the last is num_ints.
    record.Insert(kReturnOutput, ++returned_ints);
  }
}
// Resets the `batch_nums` stream state so the procedure can run again.
void NumsBatchCleanup() {
  num_ints = 0;
  returned_ints = 0;
}
// Initializer for `batch_strings`: validates the single int argument and
// stores how many strings the batched procedure should emit in total.
void StringsBatchInit(struct mgp_list *args, mgp_graph *graph, struct mgp_memory *memory) {
  mgp::memory = memory;
  const auto arguments = mgp::List(args);
  if (arguments.Empty()) {
    // Fixed typo in the error message ("recieve").
    throw std::runtime_error("Expected to receive an argument");
  }
  if (arguments[0].Type() != mgp::Type::Int) {
    throw std::runtime_error("Wrong type of first argument");
  }
  const auto num = arguments[0].ValueInt();
  num_strings = num;
}
// One batch step of `batch_strings`: emits at most one record (the constant
// string "output") per invocation, up to num_strings records in total.
void StringsBatch(struct mgp_list *args, mgp_graph *graph, mgp_result *result, struct mgp_memory *memory) {
  mgp::memory = memory;
  const auto arguments = mgp::List(args);
  const auto record_factory = mgp::RecordFactory(result);
  if (returned_strings < num_strings) {
    auto record = record_factory.NewRecord();
    returned_strings++;
    record.Insert(kReturnOutput, "output");
  }
}
// Resets the `batch_strings` stream state so the procedure can run again.
void StringsBatchCleanup() {
  num_strings = 0;
  returned_strings = 0;
}
} // namespace
// Each module needs to define mgp_init_module function.
// Here you can register multiple functions/procedures your module supports.
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
  {
    // Register `batch_nums` through the raw C API: one int argument
    // (how many ints to stream) and one int result field "output".
    mgp_proc *proc{nullptr};
    auto err_code =
        mgp_module_add_batch_read_procedure(module, "batch_nums", NumsBatch, NumsBatchInit, NumsBatchCleanup, &proc);
    if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
      return 1;
    }
    mgp_type *type_int{nullptr};
    // NOTE(review): the mgp_type_int error is discarded; on failure type_int
    // stays null and the next call would receive a null type.
    static_cast<void>(mgp_type_int(&type_int));
    err_code = mgp_proc_add_arg(proc, "num_ints", type_int);
    if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
      return 1;
    }
    mgp_type *return_type_int{nullptr};
    static_cast<void>(mgp_type_int(&return_type_int));
    err_code = mgp_proc_add_result(proc, "output", return_type_int);
    if (err_code != mgp_error::MGP_ERROR_NO_ERROR) {
      return 1;
    }
  }
  {
    // Register `batch_strings` through the C++ wrapper API, which throws on
    // failure instead of returning error codes.
    try {
      mgp::memory = memory;
      mgp::AddBatchProcedure(StringsBatch, StringsBatchInit, StringsBatchCleanup, "batch_strings",
                             mgp::ProcedureType::Read, {mgp::Parameter("num_strings", mgp::Type::Int)},
                             {mgp::Return("output", mgp::Type::String)}, module, memory);
    } catch (const std::exception &e) {
      return 1;
    }
  }
  // A non-zero return tells Memgraph the module failed to load.
  return 0;
}
// This is an optional function if you need to release any resources before the
// module is unloaded. You will probably need this if you acquired some
// resources in mgp_init_module.
extern "C" int mgp_shutdown_module() { return 0; }  // Nothing to release here.

View File

@ -0,0 +1,134 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
# isort: off
from common.shared import BaseClass, InitializationGraphMutable, InitializationUnderlyingGraphMutable
initialization_underlying_graph_mutable = InitializationUnderlyingGraphMutable()
def cleanup_underlying():
    """Cleanup hook: reset the shared init/emission bookkeeping."""
    initialization_underlying_graph_mutable.reset()
def init_underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any):
    """Init hook: record that initialization ran for this stream."""
    initialization_underlying_graph_mutable.set()
def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool, init_called=bool):
    """One batch step: report whether `object`'s underlying graph is mutable
    and whether the init hook ran; return [] once the quota is spent, which
    ends the batched stream."""
    if initialization_underlying_graph_mutable.get_to_return() > 0:
        initialization_underlying_graph_mutable.increment_returned(1)
        return mgp.Record(
            mutable=object.underlying_graph_is_mutable(), init_called=initialization_underlying_graph_mutable.get()
        )
    return []
# Register batched
mgp.add_batch_read_proc(underlying_graph_is_mutable, init_underlying_graph_is_mutable, cleanup_underlying)
initialization_graph_mutable = InitializationGraphMutable()
def init_graph_is_mutable(ctx: mgp.ProcCtx):
    """Init hook: record that initialization ran for this stream."""
    initialization_graph_mutable.set()
def graph_is_mutable(ctx: mgp.ProcCtx) -> mgp.Record(mutable=bool, init_called=bool):
    """One batch step: report whether the query graph is mutable and whether
    the init hook ran; return [] once the quota is spent."""
    if initialization_graph_mutable.get_to_return() > 0:
        initialization_graph_mutable.increment_returned(1)
        return mgp.Record(mutable=ctx.graph.is_mutable(), init_called=initialization_graph_mutable.get())
    return []
def cleanup_graph():
    """Cleanup hook: reset the shared init/emission bookkeeping."""
    initialization_graph_mutable.reset()
# Register batched
mgp.add_batch_read_proc(graph_is_mutable, init_graph_is_mutable, cleanup_graph)
class BatchingNums(BaseClass):
    # Stream state for the `batch_nums` procedure: the numbers to emit and a
    # cursor into them.
    def __init__(self, nums_to_return):
        super().__init__(nums_to_return)
        self._nums = []  # values produced by the initializer
        self._i = 0  # index of the next value to emit
batching_nums = BatchingNums(10)
def init_batching_nums(ctx: mgp.ProcCtx):
    """Init hook for `batch_nums`: mark init as called and (re)build the list
    of numbers to stream."""
    batching_nums.set()
    # Derive the range from the configured count instead of hard-coding 11,
    # so BatchingNums(10) above stays the single source of truth.
    batching_nums._nums = list(range(1, batching_nums._num_to_return + 1))
    batching_nums._i = 0
def batch_nums(ctx: mgp.ProcCtx) -> mgp.Record(num=int, init_called=bool, is_valid=bool):
    """One batch step: emit the next number from the initializer's list along
    with init/graph-validity flags; return [] once the quota is spent."""
    if batching_nums.get_to_return() > 0:
        batching_nums.increment_returned(1)
        batching_nums._i += 1
        return mgp.Record(
            num=batching_nums._nums[batching_nums._i - 1],
            init_called=batching_nums.get(),
            is_valid=ctx.graph.is_valid(),
        )
    return []
def cleanup_batching_nums():
    """Cleanup hook: reset the counters and the emission cursor."""
    batching_nums.reset()
    batching_nums._i = 0
# Register batched
mgp.add_batch_read_proc(batch_nums, init_batching_nums, cleanup_batching_nums)
class BatchingVertices(BaseClass):
    # Stream state for the `batch_vertices` procedure: a snapshot of the
    # graph's vertices and a cursor into it.
    def __init__(self):
        super().__init__()
        self._vertices = []  # snapshot taken by the initializer
        self._i = 0  # index of the next vertex to emit
batching_vertices = BatchingVertices()
def init_batching_vertices(ctx: mgp.ProcCtx):
    """Init hook for `batch_vertices`: snapshot all graph vertices."""
    # NOTE(review): leftover debug prints -- consider removing or demoting to
    # a logger before this lands.
    print("init called")
    print("graph is mutable", ctx.graph.is_mutable())
    batching_vertices.set()
    batching_vertices._vertices = list(ctx.graph.vertices)
    batching_vertices._i = 0
    batching_vertices._num_to_return = len(batching_vertices._vertices)
def batch_vertices(ctx: mgp.ProcCtx) -> mgp.Record(vertex_id=int, init_called=bool):
    """One batch step: emit the id of the next snapshotted vertex.

    Fixes two defects: the cursor `_i` was never advanced (every batch
    re-emitted the first vertex), and the record used the field name `vertex`
    while the declared return signature is `vertex_id`.
    """
    if batching_vertices.get_to_return() == 0:
        return []
    batching_vertices.increment_returned(1)
    batching_vertices._i += 1
    return mgp.Record(
        vertex_id=batching_vertices._vertices[batching_vertices._i - 1].id,
        init_called=batching_vertices.get(),
    )
def cleanup_batching_vertices():
    """Cleanup hook: drop the vertex snapshot and reset all counters."""
    batching_vertices.reset()
    batching_vertices._vertices = []
    batching_vertices._i = 0
    batching_vertices._num_to_return = 0
# Register batched
mgp.add_batch_read_proc(batch_vertices, init_batching_vertices, cleanup_batching_vertices)

View File

@ -0,0 +1,60 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import mgp
# isort: off
from common.shared import InitializationGraphMutable, InitializationUnderlyingGraphMutable
write_init_underlying_graph_mutable = InitializationUnderlyingGraphMutable()
def cleanup_underlying():
    """Cleanup hook: reset the shared init/emission bookkeeping."""
    write_init_underlying_graph_mutable.reset()
def init_underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any):
    """Init hook: record that initialization ran for this stream."""
    write_init_underlying_graph_mutable.set()
def underlying_graph_is_mutable(ctx: mgp.ProcCtx, object: mgp.Any) -> mgp.Record(mutable=bool, init_called=bool):
    """One batch step: report whether `object`'s underlying graph is mutable
    and whether the init hook ran; return [] once the quota is spent, which
    ends the batched stream."""
    if write_init_underlying_graph_mutable.get_to_return() == 0:
        return []
    write_init_underlying_graph_mutable.increment_returned(1)
    return mgp.Record(
        mutable=object.underlying_graph_is_mutable(), init_called=write_init_underlying_graph_mutable.get()
    )
# Register batched
mgp.add_batch_write_proc(underlying_graph_is_mutable, init_underlying_graph_is_mutable, cleanup_underlying)
write_init_graph_mutable = InitializationGraphMutable()
def init_graph_is_mutable(ctx: mgp.ProcCtx):
    """Init hook: record that initialization ran for this stream."""
    write_init_graph_mutable.set()
def graph_is_mutable(ctx: mgp.ProcCtx) -> mgp.Record(mutable=bool, init_called=bool):
    """One batch step: report whether the query graph is mutable and whether
    the init hook ran; return [] once the quota is spent."""
    if write_init_graph_mutable.get_to_return() > 0:
        write_init_graph_mutable.increment_returned(1)
        return mgp.Record(mutable=ctx.graph.is_mutable(), init_called=write_init_graph_mutable.get())
    return []
def cleanup_graph():
    """Cleanup hook: reset the shared init/emission bookkeeping."""
    write_init_graph_mutable.reset()
# Register batched
mgp.add_batch_write_proc(graph_is_mutable, init_graph_is_mutable, cleanup_graph)

View File

@ -0,0 +1 @@
# Shared helper module used by both the read and write batched test modules.
copy_batched_procedures_e2e_python_files(shared.py)

View File

@ -0,0 +1,31 @@
class BaseClass:
    """Bookkeeping shared by the batched e2e procedures: whether the init hook
    ran and how many records the stream may still emit."""

    def __init__(self, num_to_return=1) -> None:
        self._init_is_called = False  # init hook invoked since last reset?
        self._num_to_return = num_to_return  # total records the stream may emit
        self._num_returned = 0  # records emitted so far

    def reset(self):
        """Forget the init flag and the emission counter (cleanup hook)."""
        self._init_is_called = False
        self._num_returned = 0

    def set(self):
        """Mark the init hook as having run."""
        self._init_is_called = True

    def get(self):
        """Whether the init hook has run since the last reset."""
        return self._init_is_called

    def increment_returned(self, returned: int):
        """Account for `returned` newly emitted records."""
        self._num_returned = self._num_returned + returned

    def get_to_return(self) -> int:
        """Number of records that may still be emitted."""
        return self._num_to_return - self._num_returned
class InitializationUnderlyingGraphMutable(BaseClass):
    # State tracker for the `underlying_graph_is_mutable` procedures; keeps
    # the inherited default of one returnable record per stream.
    def __init__(self):
        super().__init__()
class InitializationGraphMutable(BaseClass):
    # State tracker for the `graph_is_mutable` procedures; keeps the inherited
    # default of one returnable record per stream.
    def __init__(self):
        super().__init__()

View File

@ -0,0 +1,143 @@
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
# isort: off
import sys
import pytest
from common import execute_and_fetch_all, has_n_result_row, has_one_result_row
from conftest import get_connection
from mgclient import DatabaseError
@pytest.mark.parametrize(
    "is_write",
    [
        True,
        False,
    ],
)
def test_graph_mutability(is_write: bool, connection):
    """Batched read/write procedures report graph mutability matching their
    registration kind, and their initializers run before the first batch."""
    cursor = connection.cursor()
    # Removed pointless f-prefixes from placeholder-free query strings.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
    module = "write" if is_write else "read"
    result = list(
        execute_and_fetch_all(
            cursor,
            f"CALL batch_py_{module}.graph_is_mutable() " "YIELD mutable, init_called RETURN mutable, init_called",
        )
    )
    assert result == [(is_write, True)]
    execute_and_fetch_all(cursor, "CREATE ()")
    result = list(
        execute_and_fetch_all(
            cursor,
            "MATCH (n) "
            f"CALL batch_py_{module}.underlying_graph_is_mutable(n) "
            "YIELD mutable, init_called RETURN mutable, init_called",
        )
    )
    assert result == [(is_write, True)]
    execute_and_fetch_all(cursor, "CREATE ()-[:TYPE]->()")
    result = list(
        execute_and_fetch_all(
            cursor,
            "MATCH (n)-[e]->(m) "
            f"CALL batch_py_{module}.underlying_graph_is_mutable(e) "
            "YIELD mutable, init_called RETURN mutable, init_called",
        )
    )
    assert result == [(is_write, True)]
def test_batching_nums(connection):
    """The Python batched procedure streams 1..10; when called once per
    matched row it restarts the stream for each row."""
    cursor = connection.cursor()
    # Removed pointless f-prefixes from placeholder-free query strings.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
    result = list(
        execute_and_fetch_all(
            cursor,
            "CALL batch_py_read.batch_nums() " "YIELD num, init_called, is_valid RETURN num, init_called, is_valid",
        )
    )
    assert result == [(i, True, True) for i in range(1, 11)]
    execute_and_fetch_all(cursor, "CREATE () CREATE ()")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 2)
    result = list(
        execute_and_fetch_all(
            cursor,
            "MATCH (n) "
            "CALL batch_py_read.batch_nums() "
            "YIELD num, init_called, is_valid RETURN num, init_called, is_valid ",
        )
    )
    assert result == [(i, True, True) for i in range(1, 11)] * 2
def test_batching_vertices(connection):
    """batch_vertices declares a `vertex_id` field but the query YIELDs
    `vertex`, so the call must fail with a DatabaseError."""
    cursor = connection.cursor()
    # Removed pointless f-prefixes and the unused `result =` binding inside
    # pytest.raises (the value is never inspected).
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
    execute_and_fetch_all(cursor, "CREATE () CREATE ()")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 2)
    with pytest.raises(DatabaseError):
        list(
            execute_and_fetch_all(
                cursor, "CALL batch_py_read.batch_vertices() " "YIELD vertex, init_called RETURN vertex, init_called"
            )
        )
def test_batching_nums_c(connection):
    """The C++ batched procedure streams 1..num_ints, one value per batch."""
    cursor = connection.cursor()
    # Removed pointless f-prefix and leftover debug prints.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
    num_ints = 10
    result = list(
        execute_and_fetch_all(
            cursor,
            f"CALL batch_c_read.batch_nums({num_ints}) " "YIELD output RETURN output",
        )
    )
    result_list = [item[0] for item in result]
    assert result_list == list(range(1, num_ints + 1))
def test_batching_strings_c(connection):
    """The C++ batched procedure emits exactly num_strings string records."""
    cursor = connection.cursor()
    # Removed pointless f-prefix from the placeholder-free query string.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0)
    num_strings = 10
    result = list(
        execute_and_fetch_all(
            cursor,
            f"CALL batch_c_read.batch_strings({num_strings}) " "YIELD output RETURN output",
        )
    )
    assert len(result) == num_strings
if __name__ == "__main__":
    # Allow running this module directly; -rA prints a summary for all tests.
    sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -0,0 +1,14 @@
template_cluster: &template_cluster
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE"]
log_file: "batched-procedures-e2e.log"
setup_queries: []
validation_queries: []
workloads:
- name: "Batched procedures read"
binary: "tests/e2e/pytest_runner.sh"
proc: "tests/e2e/batched_procedures/procedures/"
args: ["batched_procedures/simple_read.py"]
<<: *template_cluster