diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0f9c8ee09..5138896bb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -201,6 +201,8 @@ set(mg_single_node_ha_sources query/plan/rewrite/index_lookup.cpp query/plan/rule_based_planner.cpp query/plan/variable_start_planner.cpp + query/procedure/mg_procedure_impl.cpp + query/procedure/module.cpp query/repl.cpp query/typed_value.cpp storage/common/constraints/record.cpp @@ -243,6 +245,7 @@ if (READLINE_FOUND) endif() add_library(mg-single-node-ha STATIC ${mg_single_node_ha_sources}) +target_include_directories(mg-single-node-ha PRIVATE ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-single-node-ha ${MG_SINGLE_NODE_HA_LIBS}) add_dependencies(mg-single-node-ha generate_opencypher_parser) add_dependencies(mg-single-node-ha generate_lcp_single_node_ha) diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 5d2391c44..b83986130 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -23,12 +23,15 @@ #include "query/interpret/eval.hpp" #include "query/path.hpp" #include "query/plan/scoped_profile.hpp" +#include "query/procedure/mg_procedure_impl.hpp" +#include "query/procedure/module.hpp" #include "utils/algorithm.hpp" #include "utils/exceptions.hpp" #include "utils/hashing/fnv.hpp" #include "utils/pmr/unordered_map.hpp" #include "utils/pmr/unordered_set.hpp" #include "utils/pmr/vector.hpp" +#include "utils/string.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -3661,4 +3664,158 @@ UniqueCursorPtr OutputTableStream::MakeCursor( return MakeUniqueCursorPtr<OutputTableStreamCursor>(mem, this); } +CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, + std::string name, std::vector<Expression *> args, + std::vector<std::string> fields, + std::vector<Symbol> symbols) + : input_(input ? input : std::make_shared<Once>()), + procedure_name_(name), + arguments_(args), + result_fields_(fields), + result_symbols_(symbols) {} + +ACCEPT_WITH_INPUT(CallProcedure); + +std::vector<Symbol> CallProcedure::OutputSymbols(const SymbolTable &) const { + return result_symbols_; +} + +std::vector<Symbol> CallProcedure::ModifiedSymbols( + const SymbolTable &table) const { + auto symbols = input_->ModifiedSymbols(table); + symbols.insert(symbols.end(), result_symbols_.begin(), result_symbols_.end()); + return symbols; +} + +namespace { + +void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name, + const std::vector<Expression *> &args, + storage::View graph_view, const ExecutionContext &ctx, + Frame *frame, mgp_result *result) { + // Use evaluation memory, as invoking a procedure is akin to a simple + // evaluation of an expression. + // TODO: This will probably need to be changed when we add support for + // generator like procedures which yield a new result on each invocation. + auto *memory = ctx.evaluation_context.memory; + utils::pmr::vector<std::string_view> name_parts(memory); + utils::Split(&name_parts, fully_qualified_procedure_name, "."); + // First try to handle special procedure invocations for loading a module. + // TODO: When we add registering multiple procedures in a single module, it + // might be a good idea to simply register these special procedures just like + // regular procedures. That way we won't have to have any special case logic. + if (name_parts.size() > 1U) { + auto pos = fully_qualified_procedure_name.find_last_of('.'); + CHECK(pos != std::string_view::npos); + const auto &module_name = fully_qualified_procedure_name.substr(0, pos); + const auto &proc_name = name_parts.back(); + if (proc_name == "__reload__") { + procedure::gModuleRegistry.ReloadModuleNamed(module_name); + return; + } + } + const auto &module_name = fully_qualified_procedure_name; + if (module_name == "reload-all-modules") { + procedure::gModuleRegistry.ReloadAllModules(); + return; + } + auto module = procedure::gModuleRegistry.GetModuleNamed(module_name); + if (!module) throw QueryRuntimeException("'{}' isn't loaded!", module_name); + static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>, + "Expected mgp_value to use custom allocator and makes STL " + "containers aware of that"); + mgp_graph graph{ctx.db_accessor, graph_view}; + mgp_list module_args(memory); + module_args.elems.reserve(args.size()); + ExpressionEvaluator evaluator(frame, ctx.symbol_table, ctx.evaluation_context, + ctx.db_accessor, graph_view); + for (auto *arg : args) { + module_args.elems.emplace_back(arg->Accept(evaluator), &graph); + } + // TODO: Add syntax for controlling procedure memory limits. + utils::LimitedMemoryResource limited_mem(memory, + 100 * 1024 * 1024 /* 100 MB */); + mgp_memory proc_memory{&limited_mem}; + // TODO: What about cross library boundary exceptions? OMG C++?! + module->main_fn(&module_args, &graph, result, &proc_memory); + size_t leaked_bytes = limited_mem.GetAllocatedBytes(); + LOG_IF(WARNING, leaked_bytes > 0U) + << "Query procedure '" << fully_qualified_procedure_name << "' leaked " + << leaked_bytes << " *tracked* bytes"; +} + +} // namespace + +class CallProcedureCursor : public Cursor { + const CallProcedure *self_; + UniqueCursorPtr input_cursor_; + mgp_result result_; + decltype(result_.rows.end()) result_row_it_{result_.rows.end()}; + + public: + CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem) + : self_(self), + input_cursor_(self_->input_->MakeCursor(mem)), + // result_ needs to live throughout multiple Pull evaluations, until all + // rows are produced. Therefore, we use the memory dedicated for the + // whole execution. + result_(mem) { + CHECK(self_->result_fields_.size() == self_->result_symbols_.size()) + << "Incorrectly constructed CallProcedure"; + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP("CallProcedure"); + + if (context.db_accessor->MustAbort()) throw HintedAbortError(); + + // We need to fetch new procedure results after pulling from input. + // TODO: Look into openCypher's distinction between procedures returning an + // empty result set vs procedures which return `void`. We currently don't + // have procedures registering what they return. + // This `while` loop will skip over empty results. + while (result_row_it_ == result_.rows.end()) { + if (!input_cursor_->Pull(frame, context)) return false; + result_.rows.clear(); + result_.error_msg.reset(); + // TODO: When we add support for write and eager procedures, we will need + // to plan this operator with Accumulate and pass in storage::View::NEW. + auto graph_view = storage::View::OLD; + CallCustomProcedure(self_->procedure_name_, self_->arguments_, graph_view, + context, &frame, &result_); + if (result_.error_msg) { + throw QueryRuntimeException("{}: {}", self_->procedure_name_, + *result_.error_msg); + } + result_row_it_ = result_.rows.begin(); + } + + for (size_t i = 0; i < self_->result_fields_.size(); ++i) { + const auto &values = result_row_it_->values; + std::string_view field_name(self_->result_fields_[i]); + auto result_it = values.find(field_name); + if (result_it == values.end()) { + throw QueryRuntimeException( + "Procedure '{}' does not yield a record with '{}' field.", + self_->procedure_name_, field_name); + } + frame[self_->result_symbols_[i]] = result_it->second; + } + ++result_row_it_; + return true; + } + + void Reset() override { + result_.rows.clear(); + result_.error_msg.reset(); + input_cursor_->Reset(); + } + + void Shutdown() override {} +}; + +UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const { + return MakeUniqueCursorPtr<CallProcedureCursor>(mem, this, mem); +} + } // namespace query::plan diff --git a/src/query/plan/operator.lcp b/src/query/plan/operator.lcp index 1a3817d59..4c5798b99 100644 --- a/src/query/plan/operator.lcp +++ b/src/query/plan/operator.lcp @@ -114,14 +114,15 @@ class Unwind; class Distinct; class Union; class Cartesian; +class CallProcedure; using LogicalOperatorCompositeVisitor = ::utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, - EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, - OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian>; + EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, + Optional, Unwind, Distinct, Union, Cartesian, CallProcedure>; using LogicalOperatorLeafVisitor = ::utils::LeafVisitor<Once>; @@ -162,9 +163,9 @@ can serve as inputs to others and thus a sequence of operations is formed.") /** Return @c Symbol vector where the query results will be stored. * - * Currently, output symbols are generated in @c Produce and @c Union - * operators. @c Skip, @c Limit, @c OrderBy and @c Distinct propagate the - * symbols from @c Produce (if it exists as input operator). + * Currently, output symbols are generated in @c Produce @c Union and + * @c CallProcedure operators. @c Skip, @c Limit, @c OrderBy and @c Distinct + * propagate the symbols from @c Produce (if it exists as input operator). * * @param SymbolTable used to find symbols for expressions. * @return std::vector<Symbol> used for results. @@ -2058,5 +2059,37 @@ at once. Instead, each call of the callback should return a single row of the ta (:serialize (:slk)) (:clone)) +(lcp:define-class call-procedure (logical-operator) + ((input "std::shared_ptr<LogicalOperator>" :scope :public + :slk-save #'slk-save-operator-pointer + :slk-load #'slk-load-operator-pointer) + (procedure-name "std::string" :scope :public) + (arguments "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (result-fields "std::vector<std::string>" :scope :public) + (result-symbols "std::vector<Symbol>" :scope :public)) + (:public + #>cpp + CallProcedure() = default; + CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, + std::vector<Expression *> arguments, + std::vector<std::string> fields, std::vector<Symbol> symbols); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { + input_ = input; + } + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; plan (lcp:pop-namespace) ;; query diff --git a/src/query/plan/rewrite/index_lookup.hpp b/src/query/plan/rewrite/index_lookup.hpp index eb119ef6a..1ad7a07d0 100644 --- a/src/query/plan/rewrite/index_lookup.hpp +++ b/src/query/plan/rewrite/index_lookup.hpp @@ -379,6 +379,15 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(CallProcedure &op) override { + prev_ops_.push_back(&op); + return true; + } + bool PostVisit(CallProcedure &) override { + prev_ops_.pop_back(); + return true; + } + std::shared_ptr<LogicalOperator> new_root_; private: diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 3fc82eb63..08f066c22 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -209,6 +209,21 @@ class RuleBasedPlanner { input_op = std::make_unique<plan::Unwind>( std::move(input_op), unwind->named_expression_->expression_, symbol); + } else if (auto *call_proc = + utils::Downcast<query::CallProcedure>(clause)) { + std::vector<Symbol> result_symbols; + result_symbols.reserve(call_proc->result_identifiers_.size()); + for (const auto *ident : call_proc->result_identifiers_) { + const auto &sym = context.symbol_table->at(*ident); + context.bound_symbols.insert(sym); + result_symbols.push_back(sym); + } + // TODO: When we add support for write and eager procedures, we will + // need to plan this operator with Accumulate and pass in + // storage::View::NEW. + input_op = std::make_unique<plan::CallProcedure>( + std::move(input_op), call_proc->procedure_name_, + call_proc->arguments_, call_proc->result_fields_, result_symbols); } else { throw utils::NotYetImplemented( "clause '{}' conversion to operator(s)", diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 25fd93a65..3b85418bb 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -133,34 +133,8 @@ add_unit_test(query_plan_create_set_remove_delete.cpp) target_link_libraries(${test_prefix}query_plan_create_set_remove_delete mg-single-node kvstore_dummy_lib) # Storage V2 in query execution -define_add_lcp(add_lcp_query_plan_v2_create_set_remove_delete - lcp_query_plan_v2_create_set_remove_delete - generated_lcp_query_plan_v2_create_set_remove_delete_files) - -add_lcp_query_plan_v2_create_set_remove_delete( - ${CMAKE_SOURCE_DIR}/src/query/frontend/ast/ast.lcp) -add_lcp_query_plan_v2_create_set_remove_delete( - ${CMAKE_SOURCE_DIR}/src/query/frontend/semantic/symbol.lcp) -add_lcp_query_plan_v2_create_set_remove_delete( - ${CMAKE_SOURCE_DIR}/src/query/plan/operator.lcp) - -add_custom_target(generate_lcp_query_plan_v2_create_set_remove_delete DEPENDS - ${generated_lcp_query_plan_v2_create_set_remove_delete_files}) - -add_unit_test(query_plan_v2_create_set_remove_delete.cpp - ${lcp_query_plan_v2_create_set_remove_delete} - ${CMAKE_SOURCE_DIR}/src/query/common.cpp - # ${CMAKE_SOURCE_DIR}/src/query/frontend/ast/ast.lcp.cpp - ${CMAKE_SOURCE_DIR}/src/query/frontend/ast/pretty_print.cpp - ${CMAKE_SOURCE_DIR}/src/query/plan/operator.cpp - # ${CMAKE_SOURCE_DIR}/src/query/plan/operator.lcp.cpp - ${CMAKE_SOURCE_DIR}/src/query/typed_value.cpp) -target_compile_definitions(${test_prefix}query_plan_v2_create_set_remove_delete PUBLIC MG_SINGLE_NODE_V2) -target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete glog cppitertools) -target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-storage-v2) -add_dependencies(${test_prefix}query_plan_v2_create_set_remove_delete - generate_lcp_query_plan_v2_create_set_remove_delete) - +add_unit_test(query_plan_v2_create_set_remove_delete.cpp) +target_link_libraries(${test_prefix}query_plan_v2_create_set_remove_delete mg-single-node-v2 mg-auth kvstore_dummy_lib) # END Storage V2 in query execution add_unit_test(query_plan_edge_cases.cpp) diff --git a/tests/unit/query_plan.cpp b/tests/unit/query_plan.cpp index 819e254cb..1b2d35d64 100644 --- a/tests/unit/query_plan.cpp +++ b/tests/unit/query_plan.cpp @@ -1470,4 +1470,50 @@ TYPED_TEST(TestPlanner, FilterRegexMatchPreferRangeIndex) { ExpectFilter(), ExpectProduce()); } +TYPED_TEST(TestPlanner, CallProcedureStandalone) { + // Test CALL proc(1,2,3) YIELD field AS result + AstStorage storage; + auto *ast_call = storage.Create<query::CallProcedure>(); + ast_call->procedure_name_ = "proc"; + ast_call->arguments_ = {LITERAL(1), LITERAL(2), LITERAL(3)}; + ast_call->result_fields_ = {"field"}; + ast_call->result_identifiers_ = {IDENT("result")}; + auto *query = QUERY(SINGLE_QUERY(ast_call)); + auto symbol_table = query::MakeSymbolTable(query); + std::vector<Symbol> result_syms; + result_syms.reserve(ast_call->result_identifiers_.size()); + for (const auto *ident : ast_call->result_identifiers_) { + result_syms.push_back(symbol_table.at(*ident)); + } + FakeDbAccessor dba; + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, + ExpectCallProcedure(ast_call->procedure_name_, ast_call->arguments_, + ast_call->result_fields_, result_syms)); +} + +TYPED_TEST(TestPlanner, CallProcedureAfterScanAll) { + // Test MATCH (n) CALL proc(n) YIELD field AS result RETURN result + AstStorage storage; + auto *ast_call = storage.Create<query::CallProcedure>(); + ast_call->procedure_name_ = "proc"; + ast_call->arguments_ = {IDENT("n")}; + ast_call->result_fields_ = {"field"}; + ast_call->result_identifiers_ = {IDENT("result")}; + auto *query = QUERY( + SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ast_call, RETURN("result"))); + auto symbol_table = query::MakeSymbolTable(query); + std::vector<Symbol> result_syms; + result_syms.reserve(ast_call->result_identifiers_.size()); + for (const auto *ident : ast_call->result_identifiers_) { + result_syms.push_back(symbol_table.at(*ident)); + } + FakeDbAccessor dba; + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), + ExpectCallProcedure(ast_call->procedure_name_, ast_call->arguments_, + ast_call->result_fields_, result_syms), + ExpectProduce()); +} + } // namespace diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 6b4f5081d..897d7a939 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -90,6 +90,8 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { return false; } + PRE_VISIT(CallProcedure); + #undef PRE_VISIT #undef VISIT @@ -325,6 +327,34 @@ class ExpectCartesian : public OpChecker<Cartesian> { const std::list<std::unique_ptr<BaseOpChecker>> &right_; }; +class ExpectCallProcedure : public OpChecker<CallProcedure> { + public: + ExpectCallProcedure(const std::string &name, + const std::vector<query::Expression *> &args, + const std::vector<std::string> &fields, + const std::vector<Symbol> &result_syms) + : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} + + void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { + EXPECT_EQ(op.procedure_name_, name_); + EXPECT_EQ(op.arguments_.size(), args_.size()); + for (size_t i = 0; i < args_.size(); ++i) { + const auto *op_arg = op.arguments_[i]; + const auto *expected_arg = args_[i]; + // TODO: Proper expression equality + EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); + } + EXPECT_EQ(op.result_fields_, fields_); + EXPECT_EQ(op.result_symbols_, result_syms_); + } + + private: + std::string name_; + std::vector<query::Expression *> args_; + std::vector<std::string> fields_; + std::vector<Symbol> result_syms_; +}; + template <class T> std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg) { std::list<std::unique_ptr<BaseOpChecker>> l;