diff --git a/include/mg_procedure.h b/include/mg_procedure.h index 88a11e7c3..4d0ce72de 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -650,6 +650,111 @@ const struct mgp_type *mgp_type_list(const struct mgp_type *element_type); /// NULL is returned if unable to allocate the new type. const struct mgp_type *mgp_type_nullable(const struct mgp_type *type); ///@} + +/// @name Query Module & Procedures +/// +/// The following structures and functions are used to build a query module. You +/// will receive an empty instance of mgp_module through your +/// `int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory)` +/// function. Each shared library that wishes to provide a query module needs to +/// have the said function. Inside you can fill the module with procedures, +/// which can then be called through openCypher. +/// +/// Arguments to `mgp_init_module` will not live longer than the function's +/// execution, so you must not store them globally. Additionally, you must not +/// use the passed in mgp_memory to allocate global resources. +///@{ + +/// Stores information on your query module. +struct mgp_module; + +/// Describes a procedure of a query module. +struct mgp_proc; + +/// Entry-point for a query module procedure, invoked through openCypher. +/// +/// Passed in arguments will not live longer than the callback's execution. +/// Therefore, you must not store them globally or use the passed in mgp_memory +/// to allocate global resources. +typedef void (*mgp_proc_cb)(const struct mgp_list *, const struct mgp_graph *, + struct mgp_result *, struct mgp_memory *); + +/// Register a read-only procedure with a module. +/// +/// The `name` must be a sequence of digits, underscores, lowercase and +/// uppercase Latin letters. The name must begin with a non-digit character. +/// Note that Unicode characters are not allowed. Additionally, names are +/// case-sensitive. +/// +/// NULL is returned if unable to allocate memory for mgp_proc; if `name` is +/// not valid or a procedure with the same name was already registered. +struct mgp_proc *mgp_module_add_read_procedure(struct mgp_module *module, + const char *name, + mgp_proc_cb cb); + +/// Add a required argument to a procedure. +/// +/// The order of adding arguments will correspond to the order the procedure +/// must receive them through openCypher. Required arguments will be followed by +/// optional arguments. +/// +/// The `name` must be a valid identifier, following the same rules as the +/// procedure`name` in mgp_module_add_read_procedure. +/// +/// Passed in `type` describes what kind of values can be used as the argument. +/// +/// 0 is returned if unable to allocate memory for an argument; if invoking this +/// function after setting an optional argument or if `name` is not valid. +/// Non-zero is returned on success. +int mgp_proc_add_arg(struct mgp_proc *proc, const char *name, + const struct mgp_type *type); + +/// Add an optional argument with a default value to a procedure. +/// +/// The order of adding arguments will correspond to the order the procedure +/// must receive them through openCypher. Optional arguments must follow the +/// required arguments. +/// +/// The `name` must be a valid identifier, following the same rules as the +/// procedure `name` in mgp_module_add_read_procedure. +/// +/// Passed in `type` describes what kind of values can be used as the argument. +/// +/// `default_value` is copied and set as the default value for the argument. +/// Don't forget to call mgp_value_destroy when you are done using +/// `default_value`. When the procedure is called, if this argument is not +/// provided, `default_value` will be used instead. `default_value` must satisfy +/// the given `type`. +/// +/// 0 is returned if unable to allocate memory for an argument; if `name` is +/// not valid or `default_value` does not satisfy `type`. Non-zero is returned +/// on success. +int mgp_proc_add_opt_arg(struct mgp_proc *proc, const char *name, + const struct mgp_type *type, + const struct mgp_value *default_value); + +/// Add a result field to a procedure. +/// +/// The `name` must be a valid identifier, following the same rules as the +/// procedure `name` in mgp_module_add_read_procedure. +/// +/// Passed in `type` describes what kind of values can be returned through the +/// result field. +/// +/// 0 is returned if unable to allocate memory for a result field; if +/// `name` is not valid or if a result field with the same name was already +/// added. Non-zero is returned on success. +int mgp_proc_add_result(struct mgp_proc *proc, const char *name, + const struct mgp_type *type); + +/// Add a result field to a procedure and mark it as deprecated. +/// +/// This is the same as mgp_proc_add_result, but the result field will be marked +/// as deprecated. +int mgp_proc_add_deprecated_result(struct mgp_proc *proc, const char *name, + const struct mgp_type *type); +///@} + #ifdef __cplusplus } // extern "C" #endif diff --git a/query_modules/example.c b/query_modules/example.c index 4d50da66a..2f4d0ca97 100644 --- a/query_modules/example.c +++ b/query_modules/example.c @@ -9,13 +9,15 @@ // This example procedure returns 2 fields: `args` and `result`. // * `args` is a copy of arguments passed to the procedure. // * `result` is the result of this procedure, a "Hello World!" string. -// In case of memory errors, this function will report them and finish executing. +// In case of memory errors, this function will report them and finish +// executing. // // The procedure can be invoked in openCypher using the following call: -// CALL example(1, 2, 3) YIELD args, result; +// CALL example.procedure(1, 2, 3) YIELD args, result; // Naturally, you may pass in different arguments or yield less fields. -void mgp_main(const struct mgp_list *args, const struct mgp_graph *graph, - struct mgp_result *result, struct mgp_memory *memory) { +static void procedure(const struct mgp_list *args, + const struct mgp_graph *graph, struct mgp_result *result, + struct mgp_memory *memory) { struct mgp_list *args_copy = mgp_list_make_empty(mgp_list_size(args), memory); if (args_copy == NULL) goto error_memory; for (size_t i = 0; i < mgp_list_size(args); ++i) { @@ -48,10 +50,25 @@ error_memory: return; } -// This is an optional function if you need to initialize any global state when -// your module is loaded. -int mgp_init_module() { - // Return 0 to indicate success. +// Each module needs to define mgp_init_module function. +// Here you can register multiple procedures your module supports. +int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { + struct mgp_proc *proc = + mgp_module_add_read_procedure(module, "procedure", procedure); + if (!proc) return 1; + if (!mgp_proc_add_arg(proc, "required_arg", + mgp_type_nullable(mgp_type_any()))) + return 1; + struct mgp_value *null_value = mgp_value_make_null(memory); + if (!mgp_proc_add_opt_arg(proc, "optional_arg", + mgp_type_nullable(mgp_type_any()), null_value)) { + mgp_value_destroy(null_value); + return 1; + } + mgp_value_destroy(null_value); + if (!mgp_proc_add_result(proc, "result", mgp_type_string())) return 1; + if (!mgp_proc_add_result(proc, "args", mgp_type_list(mgp_type_any()))) + return 1; return 0; } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9648db853..a4ce9885e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,7 +95,7 @@ if (USE_LTALLOC) endif() add_library(mg-single-node STATIC ${mg_single_node_sources}) -target_include_directories(mg-single-node PRIVATE ${CMAKE_SOURCE_DIR}/include) +target_include_directories(mg-single-node PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-single-node ${MG_SINGLE_NODE_LIBS}) add_dependencies(mg-single-node generate_opencypher_parser) add_dependencies(mg-single-node generate_lcp_single_node) @@ -153,7 +153,7 @@ if (USE_LTALLOC) endif() add_library(mg-single-node-v2 STATIC ${mg_single_node_v2_sources}) -target_include_directories(mg-single-node-v2 PRIVATE ${CMAKE_SOURCE_DIR}/include) +target_include_directories(mg-single-node-v2 PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-single-node-v2 ${MG_SINGLE_NODE_V2_LIBS}) add_dependencies(mg-single-node-v2 generate_opencypher_parser) add_dependencies(mg-single-node-v2 generate_lcp_common) @@ -243,7 +243,7 @@ if (USE_LTALLOC) endif() add_library(mg-single-node-ha STATIC ${mg_single_node_ha_sources}) -target_include_directories(mg-single-node-ha PRIVATE ${CMAKE_SOURCE_DIR}/include) +target_include_directories(mg-single-node-ha PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-single-node-ha ${MG_SINGLE_NODE_HA_LIBS}) add_dependencies(mg-single-node-ha generate_opencypher_parser) add_dependencies(mg-single-node-ha generate_lcp_single_node_ha) diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 0cee78719..d2cacbcf4 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -3717,46 +3717,55 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name, // TODO: This will probably need to be changed when we add support for // generator like procedures which yield a new result on each invocation. auto *memory = ctx.evaluation_context.memory; - utils::pmr::vector name_parts(memory); - utils::Split(&name_parts, fully_qualified_procedure_name, "."); - // First try to handle special procedure invocations for loading a module. - // TODO: When we add registering multiple procedures in a single module, it - // might be a good idea to simply register these special procedures just like - // regular procedures. That way we won't have to have any special case logic. - if (name_parts.size() > 1U) { - auto pos = fully_qualified_procedure_name.find_last_of('.'); - CHECK(pos != std::string_view::npos); - const auto &module_name = fully_qualified_procedure_name.substr(0, pos); - const auto &proc_name = name_parts.back(); - if (proc_name == "__reload__") { - procedure::gModuleRegistry.ReloadModuleNamed(module_name); - return; - } - } - const auto &module_name = fully_qualified_procedure_name; - if (module_name == "reload-all-modules") { + // First try to handle special procedure invocations for (re)loading modules. + // It would be great to simply register `reload_all_modules` as a + // regular procedure on a `mg` module, so we don't have a special case here. + // Unfortunately, reloading requires taking a write lock, and we would + // acquire a read lock by getting the module. + if (fully_qualified_procedure_name == "mg.reload_all_modules") { procedure::gModuleRegistry.ReloadAllModules(); return; } - auto module = procedure::gModuleRegistry.GetModuleNamed(module_name); + utils::pmr::vector name_parts(memory); + utils::Split(&name_parts, fully_qualified_procedure_name, "."); + if (name_parts.size() == 1U) { + throw QueryRuntimeException("There's no top-level procedure '{}'", + fully_qualified_procedure_name); + } + auto last_dot_pos = fully_qualified_procedure_name.find_last_of('.'); + CHECK(last_dot_pos != std::string_view::npos); + const auto &module_name = + fully_qualified_procedure_name.substr(0, last_dot_pos); + const auto &proc_name = name_parts.back(); + // This is a special case for the same reasons as `mg.reload_all_modules`. + if (proc_name == "__reload__") { + procedure::gModuleRegistry.ReloadModuleNamed(module_name); + return; + } + const auto &module = procedure::gModuleRegistry.GetModuleNamed(module_name); if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name); static_assert(std::uses_allocator_v>, "Expected mgp_value to use custom allocator and makes STL " "containers aware of that"); + const auto &proc_it = module->procedures.find(proc_name); + if (proc_it == module->procedures.end()) + throw QueryRuntimeException("'{}' does not have a procedure named '{}'", + module_name, proc_name); mgp_graph graph{ctx.db_accessor, graph_view}; - mgp_list module_args(memory); - module_args.elems.reserve(args.size()); + mgp_list proc_args(memory); + proc_args.elems.reserve(args.size()); ExpressionEvaluator evaluator(frame, ctx.symbol_table, ctx.evaluation_context, ctx.db_accessor, graph_view); for (auto *arg : args) { - module_args.elems.emplace_back(arg->Accept(evaluator), &graph); + proc_args.elems.emplace_back(arg->Accept(evaluator), &graph); } // TODO: Add syntax for controlling procedure memory limits. utils::LimitedMemoryResource limited_mem(memory, 100 * 1024 * 1024 /* 100 MB */); mgp_memory proc_memory{&limited_mem}; // TODO: What about cross library boundary exceptions? OMG C++?! - module->main_fn(&module_args, &graph, result, &proc_memory); + // TODO: Type check both arguments and results against procedure signature. + proc_it->second.cb(&proc_args, &graph, result, &proc_memory); size_t leaked_bytes = limited_mem.GetAllocatedBytes(); LOG_IF(WARNING, leaked_bytes > 0U) << "Query procedure '" << fully_qualified_procedure_name << "' leaked " diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index fa487c4c9..d83f15f62 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -3,10 +3,12 @@ #include #include #include +#include #include #include +#include "utils/algorithm.hpp" #include "utils/math.hpp" // This file contains implementation of top level C API functions, but this is @@ -1290,6 +1292,7 @@ const mgp_type *mgp_type_path() { } const mgp_type *mgp_type_list(const mgp_type *type) { + if (!type) return nullptr; // Maps `type` to corresponding instance of ListType. static utils::pmr::map list_types( utils::NewDeleteResource()); @@ -1314,6 +1317,7 @@ const mgp_type *mgp_type_list(const mgp_type *type) { } const mgp_type *mgp_type_nullable(const mgp_type *type) { + if (!type) return nullptr; // Maps `type` to corresponding instance of NullableType. static utils::pmr::map gNullableTypes( utils::NewDeleteResource()); @@ -1332,3 +1336,109 @@ const mgp_type *mgp_type_nullable(const mgp_type *type) { return nullptr; } } + +namespace { +bool IsValidIdentifierName(const char *name) { + if (!name) return false; + std::regex regex("[_[:alpha:]][_[:alnum:]]*"); + return std::regex_match(name, regex); +} +} // namespace + +mgp_proc *mgp_module_add_read_procedure(mgp_module *module, const char *name, + mgp_proc_cb cb) { + if (!module || !cb) return nullptr; + if (!IsValidIdentifierName(name)) return nullptr; + if (module->procedures.find(name) != module->procedures.end()) return nullptr; + try { + auto *memory = module->procedures.get_allocator().GetMemoryResource(); + // May throw std::bad_alloc, std::length_error + return &module->procedures.emplace(name, mgp_proc(name, cb, memory)) + .first->second; + } catch (...) { + return nullptr; + } +} + +int mgp_proc_add_arg(mgp_proc *proc, const char *name, const mgp_type *type) { + if (!proc || !type) return 0; + if (!proc->opt_args.empty()) return 0; + if (!IsValidIdentifierName(name)) return 0; + try { + proc->args.emplace_back(name, type->impl.get()); + return 1; + } catch (...) { + return 0; + } +} + +int mgp_proc_add_opt_arg(mgp_proc *proc, const char *name, const mgp_type *type, + const mgp_value *default_value) { + if (!proc || !type || !default_value) return 0; + if (!IsValidIdentifierName(name)) return 0; + // TODO: Check `default_value` satisfies `type`. + auto *memory = proc->opt_args.get_allocator().GetMemoryResource(); + try { + proc->opt_args.emplace_back(utils::pmr::string(name, memory), + type->impl.get(), + ToTypedValue(*default_value, memory)); + return 1; + } catch (...) { + return 0; + } +} + +namespace { + +int AddResultToProc(mgp_proc *proc, const char *name, const mgp_type *type, + bool is_deprecated) { + if (!proc || !type) return 0; + if (!IsValidIdentifierName(name)) return 0; + if (proc->results.find(name) != proc->results.end()) return 0; + try { + auto *memory = proc->results.get_allocator().GetMemoryResource(); + proc->results.emplace(utils::pmr::string(name, memory), + std::make_pair(type->impl.get(), is_deprecated)); + return 1; + } catch (...) { + return 0; + } +} + +} // namespace + +int mgp_proc_add_result(mgp_proc *proc, const char *name, + const mgp_type *type) { + return AddResultToProc(proc, name, type, false); +} + +int mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, + const mgp_type *type) { + return AddResultToProc(proc, name, type, true); +} + +namespace query::procedure { + +void PrintProcSignature(const mgp_proc &proc, std::ostream *stream) { + (*stream) << proc.name << "("; + utils::PrintIterable( + *stream, proc.args, ", ", [](auto &stream, const auto &arg) { + stream << arg.first << " :: " << arg.second->GetPresentableName(); + }); + if (!proc.opt_args.empty()) (*stream) << ", "; + utils::PrintIterable( + *stream, proc.opt_args, ", ", [](auto &stream, const auto &arg) { + stream << std::get<0>(arg) << " = " << std::get<2>(arg) + << " :: " << std::get<1>(arg)->GetPresentableName(); + }); + (*stream) << ") :: ("; + utils::PrintIterable( + *stream, proc.results, ", ", [](auto &stream, const auto &name_result) { + const auto &[type, is_deprecated] = name_result.second; + if (is_deprecated) stream << "DEPRECATED "; + stream << name_result.first << " :: " << type->GetPresentableName(); + }); + (*stream) << ")"; +} + +} // namespace query::procedure diff --git a/src/query/procedure/mg_procedure_impl.hpp b/src/query/procedure/mg_procedure_impl.hpp index 510eb8406..38e2e0b1a 100644 --- a/src/query/procedure/mg_procedure_impl.hpp +++ b/src/query/procedure/mg_procedure_impl.hpp @@ -6,6 +6,7 @@ #include "mg_procedure.h" #include +#include #include "query/db_accessor.hpp" #include "query/procedure/cypher_types.hpp" @@ -460,3 +461,89 @@ struct mgp_vertices_iterator { struct mgp_type { query::procedure::CypherTypePtr impl; }; + +struct mgp_proc { + using allocator_type = utils::Allocator; + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const char *name, mgp_proc_cb cb, utils::MemoryResource *memory) + : name(name, memory), + cb(cb), + args(memory), + opt_args(memory), + results(memory) {} + + /// @throw std::bad_alloc + /// @throw std::length_error + mgp_proc(const mgp_proc &other, utils::MemoryResource *memory) + : name(other.name, memory), + cb(other.cb), + args(other.args, memory), + opt_args(other.opt_args, memory), + results(other.results, memory) {} + + mgp_proc(mgp_proc &&other, utils::MemoryResource *memory) + : name(std::move(other.name), memory), + cb(std::move(other.cb)), + args(std::move(other.args), memory), + opt_args(std::move(other.opt_args), memory), + results(std::move(other.results), memory) {} + + mgp_proc(const mgp_proc &other) = default; + mgp_proc(mgp_proc &&other) = default; + + mgp_proc &operator=(const mgp_proc &) = delete; + mgp_proc &operator=(mgp_proc &&) = delete; + + ~mgp_proc() = default; + + /// Name of the procedure. + utils::pmr::string name; + /// Entry-point for the procedure. + std::function + cb; + /// Required, positional arguments as a (name, type) pair. + utils::pmr::vector< + std::pair> + args; + /// Optional positional arguments as a (name, type, default_value) tuple. + utils::pmr::vector< + std::tuple> + opt_args; + /// Fields this procedure returns, as a (name -> (type, is_deprecated)) map. + utils::pmr::map> + results; +}; + +struct mgp_module { + using allocator_type = utils::Allocator; + + explicit mgp_module(utils::MemoryResource *memory) : procedures(memory) {} + + mgp_module(const mgp_module &other, utils::MemoryResource *memory) + : procedures(other.procedures, memory) {} + + mgp_module(mgp_module &&other, utils::MemoryResource *memory) + : procedures(std::move(other.procedures), memory) {} + + mgp_module(const mgp_module &) = default; + mgp_module(mgp_module &&) = default; + + mgp_module &operator=(const mgp_module &) = delete; + mgp_module &operator=(mgp_module &&) = delete; + + ~mgp_module() = default; + + utils::pmr::map procedures; +}; + +namespace query::procedure { + +/// @throw anything std::ostream::operator<< may throw. +void PrintProcSignature(const mgp_proc &, std::ostream *); + +} // namespace query::procedure diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index dc28da21c..6bec77f73 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -21,31 +21,32 @@ std::optional LoadModuleFromSharedLibrary(std::filesystem::path path) { LOG(ERROR) << "Unable to load module " << path << "; " << dlerror(); return std::nullopt; } - // Get required mgp_main - module.main_fn = reinterpret_cast( - dlsym(module.handle, "mgp_main")); + // Get required mgp_init_module + module.init_fn = reinterpret_cast( + dlsym(module.handle, "mgp_init_module")); const char *error = dlerror(); - if (!module.main_fn || error) { + if (!module.init_fn || error) { LOG(ERROR) << "Unable to load module " << path << "; " << error; dlclose(module.handle); return std::nullopt; } - // Get optional mgp_init_module - module.init_fn = - reinterpret_cast(dlsym(module.handle, "mgp_init_module")); - error = dlerror(); - if (error) LOG(WARNING) << "When loading module " << path << "; " << error; + // We probably don't need more than 256KB for module initialazation. + constexpr size_t stack_bytes = 256 * 1024; + unsigned char stack_memory[stack_bytes]; + utils::MonotonicBufferResource monotonic_memory(stack_memory, stack_bytes); + mgp_memory memory{&monotonic_memory}; + mgp_module module_def{memory.impl}; // Run mgp_init_module which must succeed. - if (module.init_fn) { - int init_res = module.init_fn(); - if (init_res != 0) { - LOG(ERROR) << "Unable to load module " << path - << "; mgp_init_module returned " << init_res; - dlclose(module.handle); - return std::nullopt; - } + int init_res = module.init_fn(&module_def, &memory); + if (init_res != 0) { + LOG(ERROR) << "Unable to load module " << path + << "; mgp_init_module returned " << init_res; + dlclose(module.handle); + return std::nullopt; } + // Copy procedures into our memory. + for (const auto &proc : module_def.procedures) + module.procedures.emplace(proc); // Get optional mgp_shutdown_module module.shutdown_fn = reinterpret_cast(dlsym(module.handle, "mgp_shutdown_module")); diff --git a/src/query/procedure/module.hpp b/src/query/procedure/module.hpp index 9518ed287..9be152451 100644 --- a/src/query/procedure/module.hpp +++ b/src/query/procedure/module.hpp @@ -9,13 +9,9 @@ #include #include +#include "query/procedure/mg_procedure_impl.hpp" #include "utils/rw_lock.hpp" -struct mgp_graph; -struct mgp_list; -struct mgp_memory; -struct mgp_result; - namespace query::procedure { struct Module final { @@ -23,14 +19,12 @@ struct Module final { std::filesystem::path file_path; /// System handle to shared library. void *handle; - /// Entry-point for module's custom procedure. - std::function - main_fn; - /// Optional initialization function called on module load. - std::function init_fn; + /// Required initialization function called on module load. + std::function init_fn; /// Optional shutdown function called on module unload. std::function shutdown_fn; + /// Registered procedures + std::map> procedures; }; /// Proxy for a registered Module, acquires a read lock from ModuleRegistry. diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 031fa8df4..e854c8894 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -150,6 +150,10 @@ target_link_libraries(${test_prefix}query_plan mg-single-node kvstore_dummy_lib) add_unit_test(query_procedure_mgp_type.cpp) target_link_libraries(${test_prefix}query_procedure_mgp_type mg-single-node kvstore_dummy_lib) target_include_directories(${test_prefix}query_procedure_mgp_type PRIVATE ${CMAKE_SOURCE_DIR}/include) + +add_unit_test(query_procedure_mgp_module.cpp) +target_link_libraries(${test_prefix}query_procedure_mgp_module mg-single-node kvstore_dummy_lib) +target_include_directories(${test_prefix}query_procedure_mgp_module PRIVATE ${CMAKE_SOURCE_DIR}/include) # END query/procedure add_unit_test(query_required_privileges.cpp) diff --git a/tests/unit/query_procedure_mgp_module.cpp b/tests/unit/query_procedure_mgp_module.cpp new file mode 100644 index 000000000..00a4a3041 --- /dev/null +++ b/tests/unit/query_procedure_mgp_module.cpp @@ -0,0 +1,93 @@ +#include + +#include +#include + +#include "query/procedure/mg_procedure_impl.hpp" + +static void DummyCallback(const mgp_list *, const mgp_graph *, mgp_result *, + mgp_memory *) {} + +TEST(Module, InvalidProcedureRegistration) { + mgp_module module(utils::NewDeleteResource()); + EXPECT_FALSE(mgp_module_add_read_procedure(&module, "dashes-not-supported", + DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure( + &module, u8"unicode\u22c6not\u2014supported", DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure( + &module, u8"`backticks⋆\u22c6won't-save\u2014you`", DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure( + &module, "42_name_must_not_start_with_number", DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure(&module, "div/", DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure(&module, "mul*", DummyCallback)); + EXPECT_FALSE(mgp_module_add_read_procedure( + &module, "question_mark_is_not_valid?", DummyCallback)); +} + +TEST(Module, RegisteringTheSameProcedureMultipleTimes) { + mgp_module module(utils::NewDeleteResource()); + EXPECT_EQ(module.procedures.find("same_name"), module.procedures.end()); + EXPECT_TRUE( + mgp_module_add_read_procedure(&module, "same_name", DummyCallback)); + EXPECT_NE(module.procedures.find("same_name"), module.procedures.end()); + EXPECT_FALSE( + mgp_module_add_read_procedure(&module, "same_name", DummyCallback)); + EXPECT_FALSE( + mgp_module_add_read_procedure(&module, "same_name", DummyCallback)); + EXPECT_NE(module.procedures.find("same_name"), module.procedures.end()); +} + +TEST(Module, CaseSensitiveProcedureNames) { + mgp_module module(utils::NewDeleteResource()); + EXPECT_TRUE(module.procedures.empty()); + EXPECT_TRUE( + mgp_module_add_read_procedure(&module, "not_same", DummyCallback)); + EXPECT_TRUE( + mgp_module_add_read_procedure(&module, "NoT_saME", DummyCallback)); + EXPECT_TRUE( + mgp_module_add_read_procedure(&module, "NOT_SAME", DummyCallback)); + EXPECT_EQ(module.procedures.size(), 3U); +} + +static void CheckSignature(const mgp_proc *proc, const std::string &expected) { + std::stringstream ss; + query::procedure::PrintProcSignature(*proc, &ss); + EXPECT_EQ(ss.str(), expected); +} + +TEST(Module, ProcedureSignature) { + mgp_memory memory{utils::NewDeleteResource()}; + mgp_module module(utils::NewDeleteResource()); + auto *proc = mgp_module_add_read_procedure(&module, "proc", DummyCallback); + CheckSignature(proc, "proc() :: ()"); + mgp_proc_add_arg(proc, "arg1", mgp_type_number()); + CheckSignature(proc, "proc(arg1 :: NUMBER) :: ()"); + mgp_proc_add_opt_arg(proc, "opt1", mgp_type_nullable(mgp_type_any()), + mgp_value_make_null(&memory)); + CheckSignature(proc, "proc(arg1 :: NUMBER, opt1 = Null :: ANY?) :: ()"); + mgp_proc_add_result(proc, "res1", mgp_type_list(mgp_type_int())); + CheckSignature( + proc, + "proc(arg1 :: NUMBER, opt1 = Null :: ANY?) :: (res1 :: LIST OF INTEGER)"); + EXPECT_FALSE(mgp_proc_add_arg(proc, "arg2", mgp_type_number())); + CheckSignature( + proc, + "proc(arg1 :: NUMBER, opt1 = Null :: ANY?) :: (res1 :: LIST OF INTEGER)"); + EXPECT_FALSE(mgp_proc_add_arg(proc, "arg2", mgp_type_map())); + CheckSignature(proc, + "proc(arg1 :: NUMBER, opt1 = Null :: ANY?) :: " + "(res1 :: LIST OF INTEGER)"); + mgp_proc_add_opt_arg(proc, "opt2", mgp_type_int(), + mgp_value_make_int(42, &memory)); + CheckSignature( + proc, + "proc(arg1 :: NUMBER, opt1 = Null :: ANY?, opt2 = 42 :: INTEGER) :: " + "(res1 :: LIST OF INTEGER)"); + mgp_proc_add_deprecated_result(proc, "res2", mgp_type_string()); + CheckSignature( + proc, + "proc(arg1 :: NUMBER, opt1 = Null :: ANY?, opt2 = 42 :: INTEGER) :: " + "(res1 :: LIST OF INTEGER, DEPRECATED res2 :: STRING)"); + EXPECT_FALSE(mgp_proc_add_result(proc, "res2", mgp_type_any())); + EXPECT_FALSE(mgp_proc_add_deprecated_result(proc, "res1", mgp_type_any())); +}