Rework CallProcedure to validate result sets

Summary:
This diff renames `__reload__` procedure to be `mg.reload` accepting a
module name. The main CallCustomProcedure function is now split into
multiple parts, so that there's more control over finding a procedure,
type checking its arguments and finally checking the returned result
set.

Depends on D2572

Reviewers: mferencevic, ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2573
This commit is contained in:
Teon Banek 2019-11-26 16:07:55 +01:00
parent d71f1bfa35
commit e31331aae4
4 changed files with 113 additions and 33 deletions

View File

@ -68,7 +68,8 @@ int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
} }
mgp_value_destroy(null_value); mgp_value_destroy(null_value);
if (!mgp_proc_add_result(proc, "result", mgp_type_string())) return 1; if (!mgp_proc_add_result(proc, "result", mgp_type_string())) return 1;
if (!mgp_proc_add_result(proc, "args", mgp_type_list(mgp_type_any()))) if (!mgp_proc_add_result(proc, "args",
mgp_type_list(mgp_type_nullable(mgp_type_any()))))
return 1; return 1;
return 0; return 0;
} }

View File

@ -3735,24 +3735,47 @@ std::vector<Symbol> CallProcedure::ModifiedSymbols(
namespace { namespace {
void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name, // Return true if we handled one of the special `mg` module procedures for
const std::vector<Expression *> &args, // reloading query modules.
storage::View graph_view, const ExecutionContext &ctx, // @throw QueryRuntimeException in case of error during procedure invocation.
Frame *frame, mgp_result *result) { bool HandleReloadProcedures(
// Use evaluation memory, as invoking a procedure is akin to a simple const std::string_view &fully_qualified_procedure_name,
// evaluation of an expression. const std::vector<Expression *> &args, ExpressionEvaluator *evaluator) {
// TODO: This will probably need to be changed when we add support for
// generator like procedures which yield a new result on each invocation.
auto *memory = ctx.evaluation_context.memory;
// First try to handle special procedure invocations for (re)loading modules.
// It would be great to simply register `reload_all_modules` as a // It would be great to simply register `reload_all_modules` as a
// regular procedure on a `mg` module, so we don't have a special case here. // regular procedure on a `mg` module, so we don't have a special case here.
// Unfortunately, reloading requires taking a write lock, and we would // Unfortunately, reloading requires taking a write lock, and we would
// acquire a read lock by getting the module. // acquire a read lock by getting the module.
if (fully_qualified_procedure_name == "mg.reload_all_modules") { if (fully_qualified_procedure_name == "mg.reload_all_modules") {
if (!args.empty())
throw QueryRuntimeException(
"'mg.reload_all_modules' requires no arguments.");
procedure::gModuleRegistry.ReloadAllModules(); procedure::gModuleRegistry.ReloadAllModules();
return; return true;
} else if (fully_qualified_procedure_name == "mg.reload") {
// This is a special case for the same reasons as `mg.reload_all_modules`.
if (args.size() != 1U)
throw QueryRuntimeException("'mg.reload' requires exactly 1 argument.");
const auto &arg = args.front()->Accept(*evaluator);
if (!arg.IsString()) {
throw QueryRuntimeException(
"'mg.reload' argument named 'module_name' at position 0 must be of "
"type STRING.");
}
const auto &module_name = arg.ValueString();
procedure::gModuleRegistry.ReloadModuleNamed(module_name);
return true;
} }
return false;
}
// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving
// `fully_qualified_procedure_name`. `memory` is used for temporary allocations
// inside this function. ModulePtr must be kept alive to make sure it won't be
// unloaded.
// @throw QueryRuntimeException if unable to find the procedure.
std::pair<procedure::ModulePtr, const mgp_proc *> FindProcedureOrThrow(
const std::string_view &fully_qualified_procedure_name,
utils::MemoryResource *memory) {
utils::pmr::vector<std::string_view> name_parts(memory); utils::pmr::vector<std::string_view> name_parts(memory);
utils::Split(&name_parts, fully_qualified_procedure_name, "."); utils::Split(&name_parts, fully_qualified_procedure_name, ".");
if (name_parts.size() == 1U) { if (name_parts.size() == 1U) {
@ -3764,27 +3787,26 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
const auto &module_name = const auto &module_name =
fully_qualified_procedure_name.substr(0, last_dot_pos); fully_qualified_procedure_name.substr(0, last_dot_pos);
const auto &proc_name = name_parts.back(); const auto &proc_name = name_parts.back();
// This is a special case for the same reasons as `mg.reload_all_modules`. auto module = procedure::gModuleRegistry.GetModuleNamed(module_name);
if (proc_name == "__reload__") {
procedure::gModuleRegistry.ReloadModuleNamed(module_name);
return;
}
const auto &module = procedure::gModuleRegistry.GetModuleNamed(module_name);
if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name); if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name);
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
"Expected mgp_value to use custom allocator and makes STL "
"containers aware of that");
const auto &proc_it = module->procedures.find(proc_name); const auto &proc_it = module->procedures.find(proc_name);
if (proc_it == module->procedures.end()) if (proc_it == module->procedures.end())
throw QueryRuntimeException("'{}' does not have a procedure named '{}'", throw QueryRuntimeException("'{}' does not have a procedure named '{}'",
module_name, proc_name); module_name, proc_name);
const auto &proc = proc_it->second; return {std::move(module), &proc_it->second};
}
void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
const mgp_proc &proc,
const std::vector<Expression *> &args,
const mgp_graph &graph, ExpressionEvaluator *evaluator,
utils::MemoryResource *memory, mgp_result *result) {
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
"Expected mgp_value to use custom allocator and makes STL "
"containers aware of that");
// Build and type check procedure arguments. // Build and type check procedure arguments.
mgp_graph graph{ctx.db_accessor, graph_view};
mgp_list proc_args(memory); mgp_list proc_args(memory);
proc_args.elems.reserve(args.size()); proc_args.elems.reserve(args.size());
ExpressionEvaluator evaluator(frame, ctx.symbol_table, ctx.evaluation_context,
ctx.db_accessor, graph_view);
if (args.size() < proc.args.size() || if (args.size() < proc.args.size() ||
// Rely on `||` short circuit so we can avoid potential overflow of // Rely on `||` short circuit so we can avoid potential overflow of
// proc.args.size() + proc.opt_args.size() by subtracting. // proc.args.size() + proc.opt_args.size() by subtracting.
@ -3804,7 +3826,7 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
} }
} }
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
auto arg = args[i]->Accept(evaluator); auto arg = args[i]->Accept(*evaluator);
std::string_view name; std::string_view name;
const query::procedure::CypherType *type; const query::procedure::CypherType *type;
if (proc.args.size() > i) { if (proc.args.size() > i) {
@ -3833,6 +3855,7 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
utils::LimitedMemoryResource limited_mem(memory, utils::LimitedMemoryResource limited_mem(memory,
100 * 1024 * 1024 /* 100 MB */); 100 * 1024 * 1024 /* 100 MB */);
mgp_memory proc_memory{&limited_mem}; mgp_memory proc_memory{&limited_mem};
CHECK(result->signature == &proc.results);
// TODO: What about cross library boundary exceptions? OMG C++?! // TODO: What about cross library boundary exceptions? OMG C++?!
proc.cb(&proc_args, &graph, result, &proc_memory); proc.cb(&proc_args, &graph, result, &proc_memory);
size_t leaked_bytes = limited_mem.GetAllocatedBytes(); size_t leaked_bytes = limited_mem.GetAllocatedBytes();
@ -3856,7 +3879,7 @@ class CallProcedureCursor : public Cursor {
// result_ needs to live throughout multiple Pull evaluations, until all // result_ needs to live throughout multiple Pull evaluations, until all
// rows are produced. Therefore, we use the memory dedicated for the // rows are produced. Therefore, we use the memory dedicated for the
// whole execution. // whole execution.
result_(mem) { result_(nullptr, mem) {
CHECK(self_->result_fields_.size() == self_->result_symbols_.size()) CHECK(self_->result_fields_.size() == self_->result_symbols_.size())
<< "Incorrectly constructed CallProcedure"; << "Incorrectly constructed CallProcedure";
} }
@ -3866,6 +3889,7 @@ class CallProcedureCursor : public Cursor {
if (MustAbort(context)) throw HintedAbortError(); if (MustAbort(context)) throw HintedAbortError();
size_t result_signature_size = 0;
// We need to fetch new procedure results after pulling from input. // We need to fetch new procedure results after pulling from input.
// TODO: Look into openCypher's distinction between procedures returning an // TODO: Look into openCypher's distinction between procedures returning an
// empty result set vs procedures which return `void`. We currently don't // empty result set vs procedures which return `void`. We currently don't
@ -3873,13 +3897,40 @@ class CallProcedureCursor : public Cursor {
// This `while` loop will skip over empty results. // This `while` loop will skip over empty results.
while (result_row_it_ == result_.rows.end()) { while (result_row_it_ == result_.rows.end()) {
if (!input_cursor_->Pull(frame, context)) return false; if (!input_cursor_->Pull(frame, context)) return false;
result_.signature = nullptr;
result_.rows.clear(); result_.rows.clear();
result_.error_msg.reset(); result_.error_msg.reset();
// TODO: When we add support for write and eager procedures, we will need // TODO: When we add support for write and eager procedures, we will need
// to plan this operator with Accumulate and pass in storage::View::NEW. // to plan this operator with Accumulate and pass in storage::View::NEW.
auto graph_view = storage::View::OLD; auto graph_view = storage::View::OLD;
CallCustomProcedure(self_->procedure_name_, self_->arguments_, graph_view, ExpressionEvaluator evaluator(&frame, context.symbol_table,
context, &frame, &result_); context.evaluation_context,
context.db_accessor, graph_view);
// First try to handle special procedures for (re)loading modules.
if (HandleReloadProcedures(self_->procedure_name_, self_->arguments_,
&evaluator))
continue;
// Nothing special, so find the regular procedure and invoke it.
// It might be a good idea to resolve the procedure name once, at the
// start. Unfortunately, this could deadlock if we tried to invoke a
// procedure from a module (read lock) and reload a module (write lock)
// inside the same execution thread.
const auto &[module, proc] = FindProcedureOrThrow(
self_->procedure_name_, context.evaluation_context.memory);
result_.signature = &proc->results;
// Use evaluation memory, as invoking a procedure is akin to a simple
// evaluation of an expression.
// TODO: This will probably need to be changed when we add support for
// generator like procedures which yield a new result on each invocation.
auto *memory = context.evaluation_context.memory;
mgp_graph graph{context.db_accessor, graph_view};
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_,
graph, &evaluator, memory, &result_);
// Reset result_.signature to nullptr, because outside of this scope we
// will no longer hold a lock on the `module`. If someone were to reload
// it, the pointer would be invalid.
result_signature_size = result_.signature->size();
result_.signature = nullptr;
if (result_.error_msg) { if (result_.error_msg) {
throw QueryRuntimeException("{}: {}", self_->procedure_name_, throw QueryRuntimeException("{}: {}", self_->procedure_name_,
*result_.error_msg); *result_.error_msg);
@ -3887,8 +3938,17 @@ class CallProcedureCursor : public Cursor {
result_row_it_ = result_.rows.begin(); result_row_it_ = result_.rows.begin();
} }
const auto &values = result_row_it_->values;
// Check that the row has all fields as required by the result signature.
// C API guarantees that it's impossible to set fields which are not part of
// the result record, but it does not gurantee that some may be missing. See
// `mgp_result_record_insert`.
if (values.size() != result_signature_size) {
throw QueryRuntimeException(
"Procedure '{}' did not yield all fields as required by its "
"signature.", self_->procedure_name_);
}
for (size_t i = 0; i < self_->result_fields_.size(); ++i) { for (size_t i = 0; i < self_->result_fields_.size(); ++i) {
const auto &values = result_row_it_->values;
std::string_view field_name(self_->result_fields_[i]); std::string_view field_name(self_->result_fields_[i]);
auto result_it = values.find(field_name); auto result_it = values.find(field_name);
if (result_it == values.end()) { if (result_it == values.end()) {

View File

@ -792,8 +792,10 @@ int mgp_result_set_error_msg(mgp_result *res, const char *msg) {
mgp_result_record *mgp_result_new_record(mgp_result *res) { mgp_result_record *mgp_result_new_record(mgp_result *res) {
auto *memory = res->rows.get_allocator().GetMemoryResource(); auto *memory = res->rows.get_allocator().GetMemoryResource();
CHECK(res->signature) << "Expected to have a valid signature";
try { try {
res->rows.push_back(mgp_result_record{ res->rows.push_back(mgp_result_record{
res->signature,
utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)}); utils::pmr::map<utils::pmr::string, query::TypedValue>(memory)});
} catch (...) { } catch (...) {
return nullptr; return nullptr;
@ -804,8 +806,12 @@ mgp_result_record *mgp_result_new_record(mgp_result *res) {
int mgp_result_record_insert(mgp_result_record *record, const char *field_name, int mgp_result_record_insert(mgp_result_record *record, const char *field_name,
const mgp_value *val) { const mgp_value *val) {
auto *memory = record->values.get_allocator().GetMemoryResource(); auto *memory = record->values.get_allocator().GetMemoryResource();
// TODO: Result validation when we add registering procedures with result // Validate field_name & val satisfy the procedure's result signature.
// signature description. CHECK(record->signature) << "Expected to have a valid signature";
auto find_it = record->signature->find(field_name);
if (find_it == record->signature->end()) return 0;
const auto *type = find_it->second.first;
if (!type->SatisfiesType(*val)) return 0;
try { try {
record->values.emplace(field_name, ToTypedValue(*val, memory)); record->values.emplace(field_name, ToTypedValue(*val, memory));
} catch (...) { } catch (...) {

View File

@ -327,12 +327,25 @@ struct mgp_path {
}; };
struct mgp_result_record { struct mgp_result_record {
/// Result record signature as defined for mgp_proc.
const utils::pmr::map<utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>>
*signature;
utils::pmr::map<utils::pmr::string, query::TypedValue> values; utils::pmr::map<utils::pmr::string, query::TypedValue> values;
}; };
struct mgp_result { struct mgp_result {
explicit mgp_result(utils::MemoryResource *mem) : rows(mem) {} explicit mgp_result(
const utils::pmr::map<
utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>> *signature,
utils::MemoryResource *mem)
: signature(signature), rows(mem) {}
/// Result record signature as defined for mgp_proc.
const utils::pmr::map<utils::pmr::string,
std::pair<const query::procedure::CypherType *, bool>>
*signature;
utils::pmr::vector<mgp_result_record> rows; utils::pmr::vector<mgp_result_record> rows;
std::optional<utils::pmr::string> error_msg; std::optional<utils::pmr::string> error_msg;
}; };