Rework CallProcedure to validate result sets
Summary: This diff renames `__reload__` procedure to be `mg.reload` accepting a module name. The main CallCustomProcedure function is now split into multiple parts, so that there's more control over finding a procedure, type checking its arguments and finally checking the returned result set. Depends on D2572 Reviewers: mferencevic, ipaljak Reviewed By: ipaljak Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2573
This commit is contained in:
parent
d71f1bfa35
commit
e31331aae4
@ -68,7 +68,8 @@ int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
|
|||||||
}
|
}
|
||||||
mgp_value_destroy(null_value);
|
mgp_value_destroy(null_value);
|
||||||
if (!mgp_proc_add_result(proc, "result", mgp_type_string())) return 1;
|
if (!mgp_proc_add_result(proc, "result", mgp_type_string())) return 1;
|
||||||
if (!mgp_proc_add_result(proc, "args", mgp_type_list(mgp_type_any())))
|
if (!mgp_proc_add_result(proc, "args",
|
||||||
|
mgp_type_list(mgp_type_nullable(mgp_type_any()))))
|
||||||
return 1;
|
return 1;
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -3735,24 +3735,47 @@ std::vector<Symbol> CallProcedure::ModifiedSymbols(
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
// Return true if we handled one of the special `mg` module procedures for
|
||||||
const std::vector<Expression *> &args,
|
// reloading query modules.
|
||||||
storage::View graph_view, const ExecutionContext &ctx,
|
// @throw QueryRuntimeException in case of error during procedure invocation.
|
||||||
Frame *frame, mgp_result *result) {
|
bool HandleReloadProcedures(
|
||||||
// Use evaluation memory, as invoking a procedure is akin to a simple
|
const std::string_view &fully_qualified_procedure_name,
|
||||||
// evaluation of an expression.
|
const std::vector<Expression *> &args, ExpressionEvaluator *evaluator) {
|
||||||
// TODO: This will probably need to be changed when we add support for
|
|
||||||
// generator like procedures which yield a new result on each invocation.
|
|
||||||
auto *memory = ctx.evaluation_context.memory;
|
|
||||||
// First try to handle special procedure invocations for (re)loading modules.
|
|
||||||
// It would be great to simply register `reload_all_modules` as a
|
// It would be great to simply register `reload_all_modules` as a
|
||||||
// regular procedure on a `mg` module, so we don't have a special case here.
|
// regular procedure on a `mg` module, so we don't have a special case here.
|
||||||
// Unfortunately, reloading requires taking a write lock, and we would
|
// Unfortunately, reloading requires taking a write lock, and we would
|
||||||
// acquire a read lock by getting the module.
|
// acquire a read lock by getting the module.
|
||||||
if (fully_qualified_procedure_name == "mg.reload_all_modules") {
|
if (fully_qualified_procedure_name == "mg.reload_all_modules") {
|
||||||
|
if (!args.empty())
|
||||||
|
throw QueryRuntimeException(
|
||||||
|
"'mg.reload_all_modules' requires no arguments.");
|
||||||
procedure::gModuleRegistry.ReloadAllModules();
|
procedure::gModuleRegistry.ReloadAllModules();
|
||||||
return;
|
return true;
|
||||||
|
} else if (fully_qualified_procedure_name == "mg.reload") {
|
||||||
|
// This is a special case for the same reasons as `mg.reload_all_modules`.
|
||||||
|
if (args.size() != 1U)
|
||||||
|
throw QueryRuntimeException("'mg.reload' requires exactly 1 argument.");
|
||||||
|
const auto &arg = args.front()->Accept(*evaluator);
|
||||||
|
if (!arg.IsString()) {
|
||||||
|
throw QueryRuntimeException(
|
||||||
|
"'mg.reload' argument named 'module_name' at position 0 must be of "
|
||||||
|
"type STRING.");
|
||||||
|
}
|
||||||
|
const auto &module_name = arg.ValueString();
|
||||||
|
procedure::gModuleRegistry.ReloadModuleNamed(module_name);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving
|
||||||
|
// `fully_qualified_procedure_name`. `memory` is used for temporary allocations
|
||||||
|
// inside this function. ModulePtr must be kept alive to make sure it won't be
|
||||||
|
// unloaded.
|
||||||
|
// @throw QueryRuntimeException if unable to find the procedure.
|
||||||
|
std::pair<procedure::ModulePtr, const mgp_proc *> FindProcedureOrThrow(
|
||||||
|
const std::string_view &fully_qualified_procedure_name,
|
||||||
|
utils::MemoryResource *memory) {
|
||||||
utils::pmr::vector<std::string_view> name_parts(memory);
|
utils::pmr::vector<std::string_view> name_parts(memory);
|
||||||
utils::Split(&name_parts, fully_qualified_procedure_name, ".");
|
utils::Split(&name_parts, fully_qualified_procedure_name, ".");
|
||||||
if (name_parts.size() == 1U) {
|
if (name_parts.size() == 1U) {
|
||||||
@ -3764,27 +3787,26 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
|||||||
const auto &module_name =
|
const auto &module_name =
|
||||||
fully_qualified_procedure_name.substr(0, last_dot_pos);
|
fully_qualified_procedure_name.substr(0, last_dot_pos);
|
||||||
const auto &proc_name = name_parts.back();
|
const auto &proc_name = name_parts.back();
|
||||||
// This is a special case for the same reasons as `mg.reload_all_modules`.
|
auto module = procedure::gModuleRegistry.GetModuleNamed(module_name);
|
||||||
if (proc_name == "__reload__") {
|
|
||||||
procedure::gModuleRegistry.ReloadModuleNamed(module_name);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const auto &module = procedure::gModuleRegistry.GetModuleNamed(module_name);
|
|
||||||
if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name);
|
if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name);
|
||||||
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
|
|
||||||
"Expected mgp_value to use custom allocator and makes STL "
|
|
||||||
"containers aware of that");
|
|
||||||
const auto &proc_it = module->procedures.find(proc_name);
|
const auto &proc_it = module->procedures.find(proc_name);
|
||||||
if (proc_it == module->procedures.end())
|
if (proc_it == module->procedures.end())
|
||||||
throw QueryRuntimeException("'{}' does not have a procedure named '{}'",
|
throw QueryRuntimeException("'{}' does not have a procedure named '{}'",
|
||||||
module_name, proc_name);
|
module_name, proc_name);
|
||||||
const auto &proc = proc_it->second;
|
return {std::move(module), &proc_it->second};
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
||||||
|
const mgp_proc &proc,
|
||||||
|
const std::vector<Expression *> &args,
|
||||||
|
const mgp_graph &graph, ExpressionEvaluator *evaluator,
|
||||||
|
utils::MemoryResource *memory, mgp_result *result) {
|
||||||
|
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
|
||||||
|
"Expected mgp_value to use custom allocator and makes STL "
|
||||||
|
"containers aware of that");
|
||||||
// Build and type check procedure arguments.
|
// Build and type check procedure arguments.
|
||||||
mgp_graph graph{ctx.db_accessor, graph_view};
|
|
||||||
mgp_list proc_args(memory);
|
mgp_list proc_args(memory);
|
||||||
proc_args.elems.reserve(args.size());
|
proc_args.elems.reserve(args.size());
|
||||||
ExpressionEvaluator evaluator(frame, ctx.symbol_table, ctx.evaluation_context,
|
|
||||||
ctx.db_accessor, graph_view);
|
|
||||||
if (args.size() < proc.args.size() ||
|
if (args.size() < proc.args.size() ||
|
||||||
// Rely on `||` short circuit so we can avoid potential overflow of
|
// Rely on `||` short circuit so we can avoid potential overflow of
|
||||||
// proc.args.size() + proc.opt_args.size() by subtracting.
|
// proc.args.size() + proc.opt_args.size() by subtracting.
|
||||||
@ -3804,7 +3826,7 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
auto arg = args[i]->Accept(evaluator);
|
auto arg = args[i]->Accept(*evaluator);
|
||||||
std::string_view name;
|
std::string_view name;
|
||||||
const query::procedure::CypherType *type;
|
const query::procedure::CypherType *type;
|
||||||
if (proc.args.size() > i) {
|
if (proc.args.size() > i) {
|
||||||
@ -3833,6 +3855,7 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
|
|||||||
utils::LimitedMemoryResource limited_mem(memory,
|
utils::LimitedMemoryResource limited_mem(memory,
|
||||||
100 * 1024 * 1024 /* 100 MB */);
|
100 * 1024 * 1024 /* 100 MB */);
|
||||||
mgp_memory proc_memory{&limited_mem};
|
mgp_memory proc_memory{&limited_mem};
|
||||||
|
CHECK(result->signature == &proc.results);
|
||||||
// TODO: What about cross library boundary exceptions? OMG C++?!
|
// TODO: What about cross library boundary exceptions? OMG C++?!
|
||||||
proc.cb(&proc_args, &graph, result, &proc_memory);
|
proc.cb(&proc_args, &graph, result, &proc_memory);
|
||||||
size_t leaked_bytes = limited_mem.GetAllocatedBytes();
|
size_t leaked_bytes = limited_mem.GetAllocatedBytes();
|
||||||
@ -3856,7 +3879,7 @@ class CallProcedureCursor : public Cursor {
|
|||||||
// result_ needs to live throughout multiple Pull evaluations, until all
|
// result_ needs to live throughout multiple Pull evaluations, until all
|
||||||
// rows are produced. Therefore, we use the memory dedicated for the
|
// rows are produced. Therefore, we use the memory dedicated for the
|
||||||
// whole execution.
|
// whole execution.
|
||||||
result_(mem) {
|
result_(nullptr, mem) {
|
||||||
CHECK(self_->result_fields_.size() == self_->result_symbols_.size())
|
CHECK(self_->result_fields_.size() == self_->result_symbols_.size())
|
||||||
<< "Incorrectly constructed CallProcedure";
|
<< "Incorrectly constructed CallProcedure";
|
||||||
}
|
}
|
||||||
@ -3866,6 +3889,7 @@ class CallProcedureCursor : public Cursor {
|
|||||||
|
|
||||||
if (MustAbort(context)) throw HintedAbortError();
|
if (MustAbort(context)) throw HintedAbortError();
|
||||||
|
|
||||||
|
size_t result_signature_size = 0;
|
||||||
// We need to fetch new procedure results after pulling from input.
|
// We need to fetch new procedure results after pulling from input.
|
||||||
// TODO: Look into openCypher's distinction between procedures returning an
|
// TODO: Look into openCypher's distinction between procedures returning an
|
||||||
// empty result set vs procedures which return `void`. We currently don't
|
// empty result set vs procedures which return `void`. We currently don't
|
||||||
@ -3873,13 +3897,40 @@ class CallProcedureCursor : public Cursor {
|
|||||||
// This `while` loop will skip over empty results.
|
// This `while` loop will skip over empty results.
|
||||||
while (result_row_it_ == result_.rows.end()) {
|
while (result_row_it_ == result_.rows.end()) {
|
||||||
if (!input_cursor_->Pull(frame, context)) return false;
|
if (!input_cursor_->Pull(frame, context)) return false;
|
||||||
|
result_.signature = nullptr;
|
||||||
result_.rows.clear();
|
result_.rows.clear();
|
||||||
result_.error_msg.reset();
|
result_.error_msg.reset();
|
||||||
// TODO: When we add support for write and eager procedures, we will need
|
// TODO: When we add support for write and eager procedures, we will need
|
||||||
// to plan this operator with Accumulate and pass in storage::View::NEW.
|
// to plan this operator with Accumulate and pass in storage::View::NEW.
|
||||||
auto graph_view = storage::View::OLD;
|
auto graph_view = storage::View::OLD;
|
||||||
CallCustomProcedure(self_->procedure_name_, self_->arguments_, graph_view,
|
ExpressionEvaluator evaluator(&frame, context.symbol_table,
|
||||||
context, &frame, &result_);
|
context.evaluation_context,
|
||||||
|
context.db_accessor, graph_view);
|
||||||
|
// First try to handle special procedures for (re)loading modules.
|
||||||
|
if (HandleReloadProcedures(self_->procedure_name_, self_->arguments_,
|
||||||
|
&evaluator))
|
||||||
|
continue;
|
||||||
|
// Nothing special, so find the regular procedure and invoke it.
|
||||||
|
// It might be a good idea to resolve the procedure name once, at the
|
||||||
|
// start. Unfortunately, this could deadlock if we tried to invoke a
|
||||||
|
// procedure from a module (read lock) and reload a module (write lock)
|
||||||
|
// inside the same execution thread.
|
||||||
|
const auto &[module, proc] = FindProcedureOrThrow(
|
||||||
|
self_->procedure_name_, context.evaluation_context.memory);
|
||||||
|
result_.signature = &proc->results;
|
||||||
|
// Use evaluation memory, as invoking a procedure is akin to a simple
|
||||||
|
// evaluation of an expression.
|
||||||
|
// TODO: This will probably need to be changed when we add support for
|
||||||
|
// generator like procedures which yield a new result on each invocation.
|
||||||
|
auto *memory = context.evaluation_context.memory;
|
||||||
|
mgp_graph graph{context.db_accessor, graph_view};
|
||||||
|
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_,
|
||||||
|
graph, &evaluator, memory, &result_);
|
||||||
|
// Reset result_.signature to nullptr, because outside of this scope we
|
||||||
|
// will no longer hold a lock on the `module`. If someone were to reload
|
||||||
|
// it, the pointer would be invalid.
|
||||||
|
result_signature_size = result_.signature->size();
|
||||||
|
result_.signature = nullptr;
|
||||||
if (result_.error_msg) {
|
if (result_.error_msg) {
|
||||||
throw QueryRuntimeException("{}: {}", self_->procedure_name_,
|
throw QueryRuntimeException("{}: {}", self_->procedure_name_,
|
||||||
*result_.error_msg);
|
*result_.error_msg);
|
||||||
@ -3887,8 +3938,17 @@ class CallProcedureCursor : public Cursor {
|
|||||||
result_row_it_ = result_.rows.begin();
|
result_row_it_ = result_.rows.begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto &values = result_row_it_->values;
|
||||||
|
// Check that the row has all fields as required by the result signature.
|
||||||
|
// C API guarantees that it's impossible to set fields which are not part of
|
||||||
|
// the result record, but it does not gurantee that some may be missing. See
|
||||||
|
// `mgp_result_record_insert`.
|
||||||
|
if (values.size() != result_signature_size) {
|
||||||
|
throw QueryRuntimeException(
|
||||||
|
"Procedure '{}' did not yield all fields as required by its "
|
||||||
|
"signature.", self_->procedure_name_);
|
||||||
|
}
|
||||||
for (size_t i = 0; i < self_->result_fields_.size(); ++i) {
|
for (size_t i = 0; i < self_->result_fields_.size(); ++i) {
|
||||||
const auto &values = result_row_it_->values;
|
|
||||||
std::string_view field_name(self_->result_fields_[i]);
|
std::string_view field_name(self_->result_fields_[i]);
|
||||||
auto result_it = values.find(field_name);
|
auto result_it = values.find(field_name);
|
||||||
if (result_it == values.end()) {
|
if (result_it == values.end()) {
|
||||||
|
@ -792,8 +792,10 @@ int mgp_result_set_error_msg(mgp_result *res, const char *msg) {
|
|||||||
|
|
||||||
mgp_result_record *mgp_result_new_record(mgp_result *res) {
|
mgp_result_record *mgp_result_new_record(mgp_result *res) {
|
||||||
auto *memory = res->rows.get_allocator().GetMemoryResource();
|
auto *memory = res->rows.get_allocator().GetMemoryResource();
|
||||||
|
CHECK(res->signature) << "Expected to have a valid signature";
|
||||||
try {
|
try {
|
||||||
res->rows.push_back(mgp_result_record{
|
res->rows.push_back(mgp_result_record{
|
||||||
|
res->signature,
|
||||||
utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
|
utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -804,8 +806,12 @@ mgp_result_record *mgp_result_new_record(mgp_result *res) {
|
|||||||
int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
|
int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
|
||||||
const mgp_value *val) {
|
const mgp_value *val) {
|
||||||
auto *memory = record->values.get_allocator().GetMemoryResource();
|
auto *memory = record->values.get_allocator().GetMemoryResource();
|
||||||
// TODO: Result validation when we add registering procedures with result
|
// Validate field_name & val satisfy the procedure's result signature.
|
||||||
// signature description.
|
CHECK(record->signature) << "Expected to have a valid signature";
|
||||||
|
auto find_it = record->signature->find(field_name);
|
||||||
|
if (find_it == record->signature->end()) return 0;
|
||||||
|
const auto *type = find_it->second.first;
|
||||||
|
if (!type->SatisfiesType(*val)) return 0;
|
||||||
try {
|
try {
|
||||||
record->values.emplace(field_name, ToTypedValue(*val, memory));
|
record->values.emplace(field_name, ToTypedValue(*val, memory));
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
|
@ -327,12 +327,25 @@ struct mgp_path {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct mgp_result_record {
|
struct mgp_result_record {
|
||||||
|
/// Result record signature as defined for mgp_proc.
|
||||||
|
const utils::pmr::map<utils::pmr::string,
|
||||||
|
std::pair<const query::procedure::CypherType *, bool>>
|
||||||
|
*signature;
|
||||||
utils::pmr::map<utils::pmr::string, query::TypedValue> values;
|
utils::pmr::map<utils::pmr::string, query::TypedValue> values;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mgp_result {
|
struct mgp_result {
|
||||||
explicit mgp_result(utils::MemoryResource *mem) : rows(mem) {}
|
explicit mgp_result(
|
||||||
|
const utils::pmr::map<
|
||||||
|
utils::pmr::string,
|
||||||
|
std::pair<const query::procedure::CypherType *, bool>> *signature,
|
||||||
|
utils::MemoryResource *mem)
|
||||||
|
: signature(signature), rows(mem) {}
|
||||||
|
|
||||||
|
/// Result record signature as defined for mgp_proc.
|
||||||
|
const utils::pmr::map<utils::pmr::string,
|
||||||
|
std::pair<const query::procedure::CypherType *, bool>>
|
||||||
|
*signature;
|
||||||
utils::pmr::vector<mgp_result_record> rows;
|
utils::pmr::vector<mgp_result_record> rows;
|
||||||
std::optional<utils::pmr::string> error_msg;
|
std::optional<utils::pmr::string> error_msg;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user