diff --git a/query_modules/example.c b/query_modules/example.c index 2f4d0ca97..b29afbe26 100644 --- a/query_modules/example.c +++ b/query_modules/example.c @@ -12,8 +12,9 @@ // In case of memory errors, this function will report them and finish // executing. // -// The procedure can be invoked in openCypher using the following call: -// CALL example.procedure(1, 2, 3) YIELD args, result; +// The procedure can be invoked in openCypher using the following calls: +// CALL example.procedure(1, 2) YIELD args, result; +// CALL example.procedure(1) YIELD args, result; // Naturally, you may pass in different arguments or yield less fields. static void procedure(const struct mgp_list *args, const struct mgp_graph *graph, struct mgp_result *result, diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 7e4f6f1ee..9aeebbd0b 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -3778,21 +3778,63 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name, if (proc_it == module->procedures.end()) throw QueryRuntimeException("'{}' does not have a procedure named '{}'", module_name, proc_name); + const auto &proc = proc_it->second; + // Build and type check procedure arguments. mgp_graph graph{ctx.db_accessor, graph_view}; mgp_list proc_args(memory); proc_args.elems.reserve(args.size()); ExpressionEvaluator evaluator(frame, ctx.symbol_table, ctx.evaluation_context, ctx.db_accessor, graph_view); - for (auto *arg : args) { - proc_args.elems.emplace_back(arg->Accept(evaluator), &graph); + if (args.size() < proc.args.size() || + // Rely on `||` short circuit so we can avoid potential overflow of + // proc.args.size() + proc.opt_args.size() by subtracting. + (args.size() - proc.args.size() > proc.opt_args.size())) { + if (proc.args.empty() && proc.opt_args.empty()) { + throw QueryRuntimeException("'{}' requires no arguments.", + fully_qualified_procedure_name); + } else if (proc.opt_args.empty()) { + throw QueryRuntimeException( + "'{}' requires exactly {} {}.", fully_qualified_procedure_name, + proc.args.size(), proc.args.size() == 1U ? "argument" : "arguments"); + } else { + throw QueryRuntimeException("'{}' requires between {} and {} arguments.", + fully_qualified_procedure_name, + proc.args.size(), + proc.args.size() + proc.opt_args.size()); + } + } + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args[i]->Accept(evaluator); + std::string_view name; + const query::procedure::CypherType *type; + if (proc.args.size() > i) { + name = proc.args[i].first; + type = proc.args[i].second; + } else { + CHECK(proc.opt_args.size() > i - proc.args.size()); + name = std::get<0>(proc.opt_args[i - proc.args.size()]); + type = std::get<1>(proc.opt_args[i - proc.args.size()]); + } + if (!type->SatisfiesType(arg)) { + throw QueryRuntimeException( + "'{}' argument named '{}' at position {} must be of type {}.", + fully_qualified_procedure_name, name, i, type->GetPresentableName()); + } + proc_args.elems.emplace_back(std::move(arg), &graph); + } + // Fill missing optional arguments with their default values. + CHECK(args.size() >= proc.args.size()); + size_t passed_in_opt_args = args.size() - proc.args.size(); + CHECK(passed_in_opt_args <= proc.opt_args.size()); + for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) { + proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph); } // TODO: Add syntax for controlling procedure memory limits. utils::LimitedMemoryResource limited_mem(memory, 100 * 1024 * 1024 /* 100 MB */); mgp_memory proc_memory{&limited_mem}; // TODO: What about cross library boundary exceptions? OMG C++?! - // TODO: Type check both arguments and results against procedure signature. - proc_it->second.cb(&proc_args, &graph, result, &proc_memory); + proc.cb(&proc_args, &graph, result, &proc_memory); size_t leaked_bytes = limited_mem.GetAllocatedBytes(); LOG_IF(WARNING, leaked_bytes > 0U) << "Query procedure '" << fully_qualified_procedure_name << "' leaked " diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index a2457d7d3..46f2b5e0a 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -1405,7 +1405,8 @@ int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type, case MGP_VALUE_TYPE_MAP: break; } - // TODO: Check `default_value` satisfies `type`. + // Default value must be of required `type`. + if (!type->impl->SatisfiesType(*default_value)) return 0; auto *memory = proc->opt_args.get_allocator().GetMemoryResource(); try { proc->opt_args.emplace_back(utils::pmr::string(name, memory),