diff --git a/.gitignore b/.gitignore index 83b9a5dc7..754661364 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,5 @@ src/raft/storage_info_rpc_messages.hpp src/stats/stats_rpc_messages.hpp src/storage/distributed/rpc/concurrent_id_mapper_rpc_messages.hpp src/transactions/distributed/engine_rpc_messages.hpp +/tests/manual/js/transaction_timeout/package-lock.json +/tests/manual/js/transaction_timeout/node_modules/ diff --git a/include/_mgp.hpp b/include/_mgp.hpp index a98303905..4cac0f85d 100644 --- a/include/_mgp.hpp +++ b/include/_mgp.hpp @@ -718,7 +718,7 @@ inline void proc_add_deprecated_result(mgp_proc *proc, const char *name, mgp_typ MgInvokeVoid(mgp_proc_add_deprecated_result, proc, name, type); } -inline bool must_abort(mgp_graph *graph) { return mgp_must_abort(graph); } +inline int must_abort(mgp_graph *graph) { return mgp_must_abort(graph); } // mgp_result diff --git a/include/mg_procedure.h b/include/mg_procedure.h index 19306655f..78e80cc28 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -1460,7 +1460,10 @@ enum mgp_error mgp_log(enum mgp_log_level log_level, const char *output); /// @{ /// Return non-zero if the currently executing procedure should abort as soon as -/// possible. +/// possible. If non-zero the reasons are: +/// (1) The transaction was requested to be terminated +/// (2) The server is gracefully shutting down +/// (3) The transaction has hit its timeout threshold /// /// Procedures which perform heavyweight processing run the risk of running too /// long and going over the query execution time limit. To prevent this, such diff --git a/include/mgp.hpp b/include/mgp.hpp index 7b5c9c65f..28fe7d152 100644 --- a/include/mgp.hpp +++ b/include/mgp.hpp @@ -67,6 +67,21 @@ class MustAbortException : public std::exception { std::string message_; }; +class TerminatedMustAbortException : public MustAbortException { + public: + explicit TerminatedMustAbortException() : MustAbortException("Query was asked to terminate directly.") {} +}; + +class ShutdownMustAbortException : public MustAbortException { + public: + explicit ShutdownMustAbortException() : MustAbortException("Query was asked to because of server shutdown.") {} +}; + +class TimeoutMustAbortException : public MustAbortException { + public: + explicit TimeoutMustAbortException() : MustAbortException("Query was asked to because of timeout was hit.") {} +}; + // Forward declarations class Nodes; using GraphNodes = Nodes; @@ -109,6 +124,19 @@ class Id { int64_t id_; }; +enum class AbortReason : uint8_t { + NO_ABORT = 0, + + // transaction has been requested to terminate, ie. "TERMINATE TRANSACTIONS ..." + TERMINATED = 1, + + // server is gracefully shutting down + SHUTDOWN = 2, + + // the transaction timeout has been reached. Either via "--query-execution-timeout-sec", or a per-transaction timeout + TIMEOUT = 3, +}; + /// @brief Wrapper class for @ref mgp_graph. class Graph { private: @@ -153,8 +181,13 @@ class Graph { /// @brief Deletes a relationship from the graph. void DeleteRelationship(const Relationship &relationship); - bool MustAbort() const; + /// @brief Checks if process must abort + /// @return AbortReason the reason to abort, if no need to abort then AbortReason::NO_ABORT is returned + AbortReason MustAbort() const; + /// @brief Checks if process must abort + /// @throws MustAbortException If process must abort for any reason + /// @note For the reason why the process must abort consider using MustAbort method instead void CheckMustAbort() const; private: @@ -1709,11 +1742,31 @@ inline Id::Id(int64_t id) : id_(id) {} inline Graph::Graph(mgp_graph *graph) : graph_(graph) {} -inline bool Graph::MustAbort() const { return must_abort(graph_); } +inline AbortReason Graph::MustAbort() const { + const auto reason = must_abort(graph_); + switch (reason) { + case 1: + return AbortReason::TERMINATED; + case 2: + return AbortReason::SHUTDOWN; + case 3: + return AbortReason::TIMEOUT; + default: + break; + } + return AbortReason::NO_ABORT; +} inline void Graph::CheckMustAbort() const { - if (MustAbort()) { - throw MustAbortException("Query was asked to abort."); + switch (MustAbort()) { + case AbortReason::TERMINATED: + throw TerminatedMustAbortException(); + case AbortReason::SHUTDOWN: + throw ShutdownMustAbortException(); + case AbortReason::TIMEOUT: + throw TimeoutMustAbortException(); + case AbortReason::NO_ABORT: + break; } } diff --git a/src/communication/bolt/v1/codes.hpp b/src/communication/bolt/v1/codes.hpp index 4e37e94aa..4b9be96d0 100644 --- a/src/communication/bolt/v1/codes.hpp +++ b/src/communication/bolt/v1/codes.hpp @@ -19,7 +19,7 @@ inline constexpr uint8_t kPreamble[4] = {0x60, 0x60, 0xB0, 0x17}; enum class Signature : uint8_t { Noop = 0x00, - Init = 0x01, + Init = 0x01, // v3+ now HELLO LogOn = 0x6A, LogOff = 0x6B, AckFailure = 0x0E, // only v1 diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index cf7717fe3..5e6aa4e39 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -64,7 +64,7 @@ class Session { */ virtual std::pair<std::vector<std::string>, std::optional<int>> Interpret( const std::string &query, const std::map<std::string, Value> ¶ms, - const std::map<std::string, memgraph::communication::bolt::Value> &metadata) = 0; + const std::map<std::string, memgraph::communication::bolt::Value> &extra) = 0; /** * Put results of the processed query in the `encoder`. diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index 518e976c9..b58b3c39b 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -152,6 +152,7 @@ State StateExecutingRun(TSession &session, State state) { return RunHandlerV4<TSession>(signature, session, state, marker); } case 5: + memgraph::metrics::IncrementCounter(memgraph::metrics::BoltMessages); return RunHandlerV5<TSession>(signature, session, state, marker); default: spdlog::trace("Unsupported bolt version:{}.{})!", session.version_.major, session.version_.minor); diff --git a/src/communication/bolt/v1/states/handlers.hpp b/src/communication/bolt/v1/states/handlers.hpp index 03524a490..b23f008ad 100644 --- a/src/communication/bolt/v1/states/handlers.hpp +++ b/src/communication/bolt/v1/states/handlers.hpp @@ -73,23 +73,6 @@ inline std::pair<std::string, std::string> ExceptionToErrorMessage(const std::ex "should be in database logs."}; } -namespace helpers { - -/** Extracts metadata from the extras field. - * NOTE: In order to avoid a copy, the metadata in moved. - * TODO: Update if extra field is used for anything else. - */ -inline std::map<std::string, Value> ConsumeMetadata(Value &extra) { - std::map<std::string, Value> md; - auto &md_tv = extra.ValueMap()["tx_metadata"]; - if (md_tv.IsMap()) { - md = std::move(md_tv.ValueMap()); - } - return md; -} - -} // namespace helpers - namespace details { template <bool is_pull, typename TSession> @@ -286,8 +269,7 @@ State HandleRunV4(TSession &session, const State state, const Marker marker) { try { // Interpret can throw. - const auto [header, qid] = - session.Interpret(query.ValueString(), params.ValueMap(), helpers::ConsumeMetadata(extra)); + const auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap(), extra.ValueMap()); // Convert std::string to Value std::vector<Value> vec; std::map<std::string, Value> data; @@ -399,7 +381,7 @@ State HandleBegin(TSession &session, const State state, const Marker marker) { } try { - session.BeginTransaction(helpers::ConsumeMetadata(extra)); + session.BeginTransaction(extra.ValueMap()); } catch (const std::exception &e) { return HandleFailure(session, e); } diff --git a/src/memgraph.cpp b/src/memgraph.cpp index d5e4d46f4..33899832d 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -503,6 +503,25 @@ namespace memgraph::metrics { extern const Event ActiveBoltSessions; } // namespace memgraph::metrics +auto ToQueryExtras(memgraph::communication::bolt::Value const &extra) -> memgraph::query::QueryExtras { + auto const &as_map = extra.ValueMap(); + + auto metadata_pv = std::map<std::string, memgraph::storage::PropertyValue>{}; + + if (auto const it = as_map.find("tx_metadata"); it != as_map.cend() && it->second.IsMap()) { + for (const auto &[key, bolt_md] : it->second.ValueMap()) { + metadata_pv.emplace(key, memgraph::glue::ToPropertyValue(bolt_md)); + } + } + + auto tx_timeout = std::optional<int64_t>{}; + if (auto const it = as_map.find("tx_timeout"); it != as_map.cend() && it->second.IsInt()) { + tx_timeout = it->second.ValueInt(); + } + + return memgraph::query::QueryExtras{std::move(metadata_pv), tx_timeout}; +} + class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, memgraph::communication::v2::OutputStream> { public: @@ -531,12 +550,8 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, memgraph::communication::v2::OutputStream>::TEncoder; - void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &metadata) override { - std::map<std::string, memgraph::storage::PropertyValue> metadata_pv; - for (const auto &[key, bolt_value] : metadata) { - metadata_pv.emplace(key, memgraph::glue::ToPropertyValue(bolt_value)); - } - interpreter_.BeginTransaction(metadata_pv); + void BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) override { + interpreter_.BeginTransaction(ToQueryExtras(extra)); } void CommitTransaction() override { interpreter_.CommitTransaction(); } @@ -545,15 +560,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph std::pair<std::vector<std::string>, std::optional<int>> Interpret( const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> ¶ms, - const std::map<std::string, memgraph::communication::bolt::Value> &metadata) override { + const std::map<std::string, memgraph::communication::bolt::Value> &extra) override { std::map<std::string, memgraph::storage::PropertyValue> params_pv; - std::map<std::string, memgraph::storage::PropertyValue> metadata_pv; for (const auto &[key, bolt_param] : params) { params_pv.emplace(key, memgraph::glue::ToPropertyValue(bolt_param)); } - for (const auto &[key, bolt_md] : metadata) { - metadata_pv.emplace(key, memgraph::glue::ToPropertyValue(bolt_md)); - } const std::string *username{nullptr}; if (user_) { username = &user_->username(); @@ -565,7 +576,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph } #endif try { - auto result = interpreter_.Prepare(query, params_pv, username, metadata_pv); + auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra)); if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges)) { interpreter_.Abort(); throw memgraph::communication::bolt::ClientError( diff --git a/src/query/context.hpp b/src/query/context.hpp index 6d84beeed..3040d6e10 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -85,7 +85,7 @@ struct ExecutionContext { ExecutionStats execution_stats; TriggerContextCollector *trigger_context_collector{nullptr}; FrameChangeCollector *frame_change_collector{nullptr}; - utils::AsyncTimer timer; + std::shared_ptr<utils::AsyncTimer> timer; #ifdef MG_ENTERPRISE std::unique_ptr<FineGrainedAuthChecker> auth_checker{nullptr}; #endif @@ -94,11 +94,18 @@ struct ExecutionContext { static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!"); static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!"); -inline bool MustAbort(const ExecutionContext &context) noexcept { - return (context.transaction_status != nullptr && - context.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) || - (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) || - context.timer.IsExpired(); +inline auto MustAbort(const ExecutionContext &context) noexcept -> AbortReason { + if (context.transaction_status != nullptr && + context.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) { + return AbortReason::TERMINATED; + } + if (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) { + return AbortReason::SHUTDOWN; + } + if (context.timer && context.timer->IsExpired()) { + return AbortReason::TIMEOUT; + } + return AbortReason::NO_ABORT; } inline plan::ProfilingStatsWithTotalTime GetStatsWithTotalTime(const ExecutionContext &context) { diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index 8ba00830b..0476559ed 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -112,16 +112,45 @@ class QueryRuntimeException : public QueryException { using QueryException::QueryException; }; +enum class AbortReason : uint8_t { + NO_ABORT = 0, + + // transaction has been requested to terminate, ie. "TERMINATE TRANSACTIONS ..." + TERMINATED = 1, + + // server is gracefully shutting down + SHUTDOWN = 2, + + // the transaction timeout has been reached. Either via "--query-execution-timeout-sec", or a per-transaction timeout + TIMEOUT = 3, +}; + // This one is inherited from BasicException and will be treated as // TransientError, i. e. client will be encouraged to retry execution because it // could succeed if executed again. class HintedAbortError : public utils::BasicException { public: using utils::BasicException::BasicException; - HintedAbortError() - : utils::BasicException( - "Transaction was asked to abort either because it was executing longer than time specified or another user " - "asked it to abort.") {} + explicit HintedAbortError(AbortReason reason) : utils::BasicException(AsMsg(reason)), reason_{reason} {} + + auto Reason() const -> AbortReason { return reason_; } + + private: + static auto AsMsg(AbortReason reason) -> std::string_view { + using namespace std::string_view_literals; + switch (reason) { + case AbortReason::TERMINATED: + return "Transaction was asked to abort by another user."sv; + case AbortReason::SHUTDOWN: + return "Transaction was asked to abort because of database shutdown."sv; + case AbortReason::TIMEOUT: + return "Transaction was asked to abort because of transaction timeout."sv; + default: + // should never happen + return "Transaction was asked to abort for an unknown reason."sv; + } + } + AbortReason reason_; }; class ExplicitTransactionUsageException : public QueryRuntimeException { diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 4a224626c..5c237b575 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -1018,10 +1018,28 @@ struct PullPlanVector { std::vector<std::vector<TypedValue>> values_; }; +struct TxTimeout { + TxTimeout() = default; + explicit TxTimeout(std::chrono::duration<double> value) noexcept : value_{std::in_place, value} { + // validation + // - negative timeout makes no sense + // - zero timeout means no timeout + if (value_ <= std::chrono::milliseconds{0}) value_.reset(); + }; + explicit operator bool() const { return value_.has_value(); } + + /// Must call operator bool() first to know if safe + auto ValueUnsafe() const -> std::chrono::duration<double> const & { return *value_; } + + private: + std::optional<std::chrono::duration<double>> value_; +}; + struct PullPlan { explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status, + std::shared_ptr<utils::AsyncTimer> tx_timer, TriggerContextCollector *trigger_context_collector = nullptr, std::optional<size_t> memory_limit = {}, bool use_monotonic_memory = true, FrameChangeCollector *frame_change_collector_ = nullptr); @@ -1061,8 +1079,9 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status, - TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit, - bool use_monotonic_memory, FrameChangeCollector *frame_change_collector) + std::shared_ptr<utils::AsyncTimer> tx_timer, TriggerContextCollector *trigger_context_collector, + const std::optional<size_t> memory_limit, bool use_monotonic_memory, + FrameChangeCollector *frame_change_collector) : plan_(plan), cursor_(plan->plan().MakeCursor(execution_memory)), frame_(plan->symbol_table().max_position(), execution_memory), @@ -1086,9 +1105,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par } } #endif - if (interpreter_context->config.execution_timeout_sec > 0) { - ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; - } + ctx_.timer = std::move(tx_timer); ctx_.is_shutting_down = &interpreter_context->is_shutting_down; ctx_.transaction_status = transaction_status; ctx_.is_profile_query = is_profile_query; @@ -1236,14 +1253,30 @@ Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_ MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); } -PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, - const std::map<std::string, storage::PropertyValue> &metadata) { +auto DetermineTxTimeout(std::optional<int64_t> tx_timeout_ms, InterpreterConfig const &config) -> TxTimeout { + using double_seconds = std::chrono::duration<double>; + + auto const global_tx_timeout = double_seconds{config.execution_timeout_sec}; + auto const valid_global_tx_timeout = global_tx_timeout > double_seconds{0}; + + if (tx_timeout_ms) { + auto const timeout = std::chrono::duration_cast<double_seconds>(std::chrono::milliseconds{*tx_timeout_ms}); + if (valid_global_tx_timeout) return TxTimeout{std::min(global_tx_timeout, timeout)}; + return TxTimeout{timeout}; + } + if (valid_global_tx_timeout) { + return TxTimeout{global_tx_timeout}; + } + return TxTimeout{}; +} + +PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, QueryExtras const &extras) { std::function<void()> handler; if (query_upper == "BEGIN") { - // TODO: Evaluate doing move(metadata). Currently the metadata is very small, but this will be important if it ever + // TODO: Evaluate doing move(extras). Currently the extras is very small, but this will be important if it ever // becomes large. - handler = [this, metadata] { + handler = [this, extras = extras] { if (in_explicit_transaction_) { throw ExplicitTransactionUsageException("Nested transactions are not supported."); } @@ -1252,7 +1285,11 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, in_explicit_transaction_ = true; expect_rollback_ = false; - metadata_ = GenOptional(metadata); + metadata_ = GenOptional(extras.metadata_pv); + + auto const timeout = DetermineTxTimeout(extras.tx_timeout, interpreter_context_->config); + explicit_transaction_timer_ = + timeout ? std::make_shared<utils::AsyncTimer>(timeout.ValueUnsafe().count()) : nullptr; db_accessor_ = interpreter_context_->db->Access(GetIsolationLevelOverride()); execution_db_accessor_.emplace(db_accessor_.get()); @@ -1283,6 +1320,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, expect_rollback_ = false; in_explicit_transaction_ = false; metadata_ = std::nullopt; + explicit_transaction_timer_.reset(); }; } else if (query_upper == "ROLLBACK") { handler = [this] { @@ -1296,6 +1334,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, expect_rollback_ = false; in_explicit_transaction_ = false; metadata_ = std::nullopt; + explicit_transaction_timer_.reset(); }; } else { LOG_FATAL("Should not get here -- unknown transaction query!"); @@ -1354,6 +1393,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, std::vector<Notification> *notifications, const std::string *username, std::atomic<TransactionStatus> *transaction_status, + std::shared_ptr<utils::AsyncTimer> tx_timer, TriggerContextCollector *trigger_context_collector = nullptr, FrameChangeCollector *frame_change_collector = nullptr) { auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); @@ -1403,10 +1443,11 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, header.push_back( utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first); } - auto pull_plan = std::make_shared<PullPlan>( - plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, - StringPointerToOptional(username), transaction_status, trigger_context_collector, memory_limit, - use_monotonic_memory, frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr); + auto pull_plan = + std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, + StringPointerToOptional(username), transaction_status, std::move(tx_timer), + trigger_context_collector, memory_limit, use_monotonic_memory, + frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { @@ -1468,6 +1509,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, std::atomic<TransactionStatus> *transaction_status, + std::shared_ptr<utils::AsyncTimer> tx_timer, FrameChangeCollector *frame_change_collector) { const std::string kProfileQueryStart = "profile "; @@ -1533,36 +1575,37 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan())); - return PreparedQuery{ - {"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, - std::move(parsed_query.required_privileges), - [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), summary, dba, - interpreter_context, execution_memory, memory_limit, optional_username, - // We want to execute the query we are profiling lazily, so we delay - // the construction of the corresponding context. - stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{}, - pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory, - frame_change_collector](AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { - // No output symbols are given so that nothing is streamed. - if (!stats_and_total_time) { - stats_and_total_time = - PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, optional_username, - transaction_status, nullptr, memory_limit, use_monotonic_memory, - frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr) - .Pull(stream, {}, {}, summary); - pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time)); - } + return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, + std::move(parsed_query.required_privileges), + [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), + summary, dba, interpreter_context, execution_memory, memory_limit, optional_username, + // We want to execute the query we are profiling lazily, so we delay + // the construction of the corresponding context. + stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{}, + pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory, + frame_change_collector, tx_timer = std::move(tx_timer)]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + // No output symbols are given so that nothing is streamed. + if (!stats_and_total_time) { + stats_and_total_time = + PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, + optional_username, transaction_status, std::move(tx_timer), nullptr, + memory_limit, use_monotonic_memory, + frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr) + .Pull(stream, {}, {}, summary); + pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time)); + } - MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); + MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); - if (pull_plan->Pull(stream, n)) { - summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); - return QueryHandlerResult::ABORT; - } + if (pull_plan->Pull(stream, n)) { + summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); + return QueryHandlerResult::ABORT; + } - return std::nullopt; - }, - rw_type_checker.type}; + return std::nullopt; + }, + rw_type_checker.type}; } PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, DbAccessor *dba, @@ -1947,7 +1990,8 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, - std::atomic<TransactionStatus> *transaction_status) { + std::atomic<TransactionStatus> *transaction_status, + std::shared_ptr<utils::AsyncTimer> tx_timer) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } @@ -1967,8 +2011,9 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), 0.0, AstStorage{}, symbol_table)); - auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, StringPointerToOptional(username), transaction_status); + auto pull_plan = + std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, + StringPointerToOptional(username), transaction_status, std::move(tx_timer)); return PreparedQuery{ callback.header, std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), @@ -2362,7 +2407,8 @@ Callback SwitchMemoryDevice(storage::StorageMode current_mode, storage::StorageM if (SwitchingFromDiskToInMemory(current_mode, requested_mode)) { throw utils::BasicException( "You cannot switch from the on-disk storage mode to an in-memory storage mode while the database is running. " - "To make the switch, delete the data directory and restart the database. Once restarted, Memgraph will automatically " + "To make the switch, delete the data directory and restart the database. Once restarted, Memgraph will " + "automatically " "start in the default in-memory transactional storage mode."); } if (SwitchingFromInMemoryToDisk(current_mode, requested_mode)) { @@ -3051,8 +3097,8 @@ std::optional<uint64_t> Interpreter::GetTransactionId() const { return {}; } -void Interpreter::BeginTransaction(const std::map<std::string, storage::PropertyValue> &metadata) { - const auto prepared_query = PrepareTransactionQuery("BEGIN", metadata); +void Interpreter::BeginTransaction(QueryExtras const &extras) { + const auto prepared_query = PrepareTransactionQuery("BEGIN", extras); prepared_query.query_handler(nullptr, {}); } @@ -3072,13 +3118,17 @@ void Interpreter::RollbackTransaction() { Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::map<std::string, storage::PropertyValue> ¶ms, - const std::string *username, - const std::map<std::string, storage::PropertyValue> &metadata) { + const std::string *username, QueryExtras const &extras) { + std::shared_ptr<utils::AsyncTimer> current_timer; if (!in_explicit_transaction_) { query_executions_.clear(); transaction_queries_->clear(); // Handle user-defined metadata in auto-transactions - metadata_ = GenOptional(metadata); + metadata_ = GenOptional(extras.metadata_pv); + auto const timeout = DetermineTxTimeout(extras.tx_timeout, interpreter_context_->config); + current_timer = timeout ? std::make_shared<utils::AsyncTimer>(timeout.ValueUnsafe().count()) : nullptr; + } else { + current_timer = explicit_transaction_timer_; } // This will be done in the handle transaction query. Our handler can save username and then send it to the kill and @@ -3098,7 +3148,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, std::optional<int> qid = in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{}; - query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query, metadata)); + query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query, extras)); return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; } @@ -3184,7 +3234,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (utils::Downcast<CypherQuery>(parsed_query.query)) { prepared_query = PrepareCypherQuery( std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - memory_resource, &query_execution->notifications, username, &transaction_status_, + memory_resource, &query_execution->notifications, username, &transaction_status_, std::move(current_timer), trigger_context_collector_ ? &*trigger_context_collector_ : nullptr, &*frame_change_collector_); } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, @@ -3193,7 +3243,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, - &transaction_status_, &*frame_change_collector_); + &transaction_status_, std::move(current_timer), &*frame_change_collector_); } else if (utils::Downcast<DumpQuery>(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, memory_resource); @@ -3204,9 +3254,10 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_, interpreter_context_); } else if (utils::Downcast<AuthQuery>(parsed_query.query)) { - prepared_query = PrepareAuthQuery( - std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, - &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_); + prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception, username, + &transaction_status_, std::move(current_timer)); } else if (utils::Downcast<InfoQuery>(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, interpreter_context_->db.get(), @@ -3299,6 +3350,7 @@ void Interpreter::Abort() { expect_rollback_ = false; in_explicit_transaction_ = false; metadata_ = std::nullopt; + explicit_transaction_timer_.reset(); memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveTransactions); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index f294fce3b..90724b979 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -204,6 +204,15 @@ struct PreparedQuery { plan::ReadWriteTypeChecker::RWType rw_type; }; +/** + * Holds data for the Query which is extra + * NOTE: maybe need to parse more in the future, ATM we ignore some parts from BOLT + */ +struct QueryExtras { + std::map<std::string, memgraph::storage::PropertyValue> metadata_pv; + std::optional<int64_t> tx_timeout; +}; + class Interpreter; /** @@ -268,6 +277,7 @@ class Interpreter final { std::optional<std::string> username_; bool in_explicit_transaction_{false}; bool expect_rollback_{false}; + std::shared_ptr<utils::AsyncTimer> explicit_transaction_timer_{}; std::optional<std::map<std::string, storage::PropertyValue>> metadata_{}; //!< User defined transaction metadata /** @@ -279,8 +289,7 @@ class Interpreter final { * @throw query::QueryException */ PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms, - const std::string *username, - const std::map<std::string, storage::PropertyValue> &metadata = {}); + const std::string *username, QueryExtras const &extras = {}); /** * Execute the last prepared query and stream *all* of the results into the @@ -324,7 +333,7 @@ class Interpreter final { std::map<std::string, TypedValue> Pull(TStream *result_stream, std::optional<int> n = {}, std::optional<int> qid = {}); - void BeginTransaction(const std::map<std::string, storage::PropertyValue> &metadata = {}); + void BeginTransaction(QueryExtras const &extras = {}); std::optional<uint64_t> GetTransactionId() const; @@ -419,8 +428,7 @@ class Interpreter final { std::optional<storage::IsolationLevel> interpreter_isolation_level; std::optional<storage::IsolationLevel> next_transaction_isolation_level; - PreparedQuery PrepareTransactionQuery(std::string_view query_upper, - const std::map<std::string, storage::PropertyValue> &metadata = {}); + PreparedQuery PrepareTransactionQuery(std::string_view query_upper, QueryExtras const &extras = {}); void Commit(); void AdvanceCommand(); void AbortCommand(std::unique_ptr<QueryExecution> *query_execution); diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 91ebad094..26d7cc8fc 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -158,6 +158,10 @@ uint64_t ComputeProfilingKey(const T *obj) { return reinterpret_cast<uint64_t>(obj); } +inline void AbortCheck(ExecutionContext const &context) { + if (auto const reason = MustAbort(context); reason != AbortReason::NO_ABORT) throw HintedAbortError(reason); +} + } // namespace #define SCOPED_PROFILE_OP(name) ScopedProfile profile{ComputeProfilingKey(this), name, &context}; @@ -430,7 +434,7 @@ class ScanAllCursor : public Cursor { bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP(op_name_); - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); while (!vertices_ || vertices_it_.value() == vertices_.value().end()) { if (!input_cursor_->Pull(frame, context)) return false; @@ -738,7 +742,7 @@ bool Expand::ExpandCursor::Pull(Frame &frame, ExecutionContext &context) { }; while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // attempt to get a value from the incoming edges if (in_edges_ && *in_edges_it_ != in_edges_->end()) { auto edge = *(*in_edges_it_)++; @@ -1001,7 +1005,7 @@ class ExpandVariableCursor : public Cursor { // Input Vertex could be null if it is created by a failed optional match. // In those cases we skip that input pull and continue with the next. while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; @@ -1071,7 +1075,7 @@ class ExpandVariableCursor : public Cursor { // existing_node criterions, so expand in a loop until either the input // vertex is exhausted or a valid variable-length expansion is available. while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // pop from the stack while there is stuff to pop and the current // level is exhausted while (!edges_.empty() && edges_it_.back() == edges_.back().end()) { @@ -1269,7 +1273,7 @@ class STShortestPathCursor : public query::plan::Cursor { out_edge[sink] = std::nullopt; while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // Top-down step (expansion from the source). ++current_length; if (current_length > upper_bound) return false; @@ -1475,7 +1479,7 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { // do it all in a loop because we skip some elements while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // if we have nothing to visit on the current depth, switch to next if (to_visit_current_.empty()) to_visit_current_.swap(to_visit_next_); @@ -1681,7 +1685,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { }; while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); if (pq_.empty()) { if (!input_cursor_->Pull(frame, context)) return false; const auto &vertex_value = frame[self_.input_symbol_]; @@ -1718,7 +1722,7 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { } while (!pq_.empty()) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); auto [current_weight, current_depth, current_vertex, current_edge] = pq_.top(); pq_.pop(); @@ -2017,7 +2021,7 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor { auto create_DFS_traversal_tree = [this, &context, &memory, &create_state, &expand_from_vertex]() { while (!pq_.empty()) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); const auto [current_weight, current_depth, current_vertex, directed_edge] = pq_.top(); pq_.pop(); @@ -2068,7 +2072,7 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor { // On each subsequent Pull run, paths are created from the traversal stack and returned. while (true) { // Check if there is an external error. - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // The algorithm is run all at once by create_DFS_traversal_tree, after which we // traverse the tree iteratively by preserving the traversal state on stack. @@ -2486,7 +2490,7 @@ bool Delete::DeleteCursor::Pull(Frame &frame, ExecutionContext &context) { auto &dba = *context.db_accessor; // delete edges first for (TypedValue &expression_result : expression_results) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); if (expression_result.type() == TypedValue::Type::Edge) { auto &ea = expression_result.ValueEdge(); #ifdef MG_ENTERPRISE @@ -2518,7 +2522,7 @@ bool Delete::DeleteCursor::Pull(Frame &frame, ExecutionContext &context) { // delete vertices for (TypedValue &expression_result : expression_results) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); switch (expression_result.type()) { case TypedValue::Type::Vertex: { auto &va = expression_result.ValueVertex(); @@ -3197,9 +3201,7 @@ class EmptyResultCursor : public Cursor { if (!pulled_all_input_) { while (input_cursor_->Pull(frame, context)) { - if (MustAbort(context)) { - throw HintedAbortError(); - } + AbortCheck(context); } pulled_all_input_ = true; } @@ -3255,7 +3257,7 @@ class AccumulateCursor : public Cursor { if (self_.advance_command_) dba.AdvanceCommand(); } - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); if (cache_it_ == cache_.end()) return false; auto row_it = (cache_it_++)->begin(); for (const Symbol &symbol : self_.symbols_) { @@ -3831,7 +3833,7 @@ class OrderByCursor : public Cursor { if (cache_it_ == cache_.end()) return false; - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // place the output values on the frame DMG_ASSERT(self_.output_symbols_.size() == cache_it_->remember.size(), @@ -4055,7 +4057,7 @@ class UnwindCursor : public Cursor { bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("Unwind"); while (true) { - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // if we reached the end of our list of values // pull from the input if (input_value_it_ == input_value_.end()) { @@ -4314,7 +4316,7 @@ class CartesianCursor : public Cursor { restore_frame(self_.right_symbols_, right_op_frame_); } - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); restore_frame(self_.left_symbols_, *left_op_frames_it_); left_op_frames_it_++; @@ -4565,7 +4567,7 @@ class CallProcedureCursor : public Cursor { bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("CallProcedure"); - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // We need to fetch new procedure results after pulling from input. // TODO: Look into openCypher's distinction between procedures returning an @@ -4786,7 +4788,7 @@ class LoadCsvCursor : public Cursor { bool Pull(Frame &frame, ExecutionContext &context) override { SCOPED_PROFILE_OP("LoadCsv"); - if (MustAbort(context)) throw HintedAbortError(); + AbortCheck(context); // ToDo(the-joksim): // - this is an ungodly hack because the pipeline of creating a plan diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 6bc6a3f67..dfe8f21cf 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -2962,7 +2962,18 @@ mgp_error mgp_proc_add_deprecated_result(mgp_proc *proc, const char *name, mgp_t int mgp_must_abort(mgp_graph *graph) { MG_ASSERT(graph->ctx); static_assert(noexcept(memgraph::query::MustAbort(*graph->ctx))); - return memgraph::query::MustAbort(*graph->ctx) ? 1 : 0; + auto const reason = memgraph::query::MustAbort(*graph->ctx); + // NOTE: deliberately decoupled to avoid accidental ABI breaks + switch (reason) { + case memgraph::query::AbortReason::TERMINATED: + return 1; + case memgraph::query::AbortReason::SHUTDOWN: + return 2; + case memgraph::query::AbortReason::TIMEOUT: + return 3; + case memgraph::query::AbortReason::NO_ABORT: + return 0; + } } namespace memgraph::query::procedure { diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index e8692ae5a..d682de677 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -213,7 +213,7 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution ctx.evaluation_context.parameters = parsed_statements_.parameters; ctx.evaluation_context.properties = NamesToProperties(plan.ast_storage().properties_, dba); ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba); - ctx.timer = utils::AsyncTimer(max_execution_time_sec); + ctx.timer = (max_execution_time_sec > 0.0) ? std::make_shared<utils::AsyncTimer>(max_execution_time_sec) : nullptr; ctx.is_shutting_down = is_shutting_down; ctx.transaction_status = transaction_status; ctx.is_profile_query = false; diff --git a/tests/manual/js/transaction_timeout/README.md b/tests/manual/js/transaction_timeout/README.md new file mode 100644 index 000000000..e4f54677a --- /dev/null +++ b/tests/manual/js/transaction_timeout/README.md @@ -0,0 +1,8 @@ +move into e2e when ts-node based runner has been setup + +To run +```bash +npm install +npm run test:auto +npm run test:explicit +``` diff --git a/tests/manual/js/transaction_timeout/auto-transaction.ts b/tests/manual/js/transaction_timeout/auto-transaction.ts new file mode 100644 index 000000000..3fd704725 --- /dev/null +++ b/tests/manual/js/transaction_timeout/auto-transaction.ts @@ -0,0 +1,39 @@ +import { Driver, RxSession, Session } from "neo4j-driver"; +import { finalize } from "rxjs"; + +var neo4j = require("neo4j-driver"); + +const driver: Driver = neo4j.driver("bolt://localhost:7687"); + +async function setup() { + const session: Session = driver.session(); + + try { + await session.run('MATCH (n) DETACH DELETE n'); + await session.run('UNWIND RANGE(1, 100) AS x CREATE ()'); + } finally { + session.close(); + } +} + +setup() + .then( + () => { + const session: RxSession = driver.rxSession({ defaultAccessMode: 'READ' }); + session + .run("MATCH (), (), (), () RETURN 42 AS thing;", // NOTE: A long query + undefined, + { timeout: 50 } // NOTE: with a short timeout + ) + .records() + .pipe(finalize(() => { + session.close(); + driver.close(); + })) + .subscribe({ + next: record => { }, + complete: () => { console.info('complete'); process.exit(1); }, // UNEXPECTED + error: msg => console.error('Error:', msg.message), // NOTE: expected to error with server side timeout + }); + } + ) diff --git a/tests/manual/js/transaction_timeout/explicit-transaction.ts b/tests/manual/js/transaction_timeout/explicit-transaction.ts new file mode 100644 index 000000000..df9c192ae --- /dev/null +++ b/tests/manual/js/transaction_timeout/explicit-transaction.ts @@ -0,0 +1,42 @@ +import { Driver, Session } from "neo4j-driver"; + +import { EMPTY } from "rxjs" +import { catchError, finalize, map, mergeMap, concatWith } from "rxjs/operators" + +const neo4j = require("neo4j-driver"); + +const driver: Driver = neo4j.driver("bolt://localhost:7687"); + +async function setup() { + const session: Session = driver.session(); + + try { + await session.run('MATCH (n) DETACH DELETE n'); + await session.run('UNWIND RANGE(1, 100) AS x CREATE ()'); + } finally { + session.close(); + } +} + +setup().then(() => { + const rxSession = driver.rxSession({ defaultAccessMode: 'READ' }); + rxSession + .beginTransaction({ timeout: 50 }) // NOTE: a short timeout + .pipe( + mergeMap(tx => + tx + .run('MATCH (),(),(),() RETURN 42 AS thing;') // NOTE: a long query + .records() + .pipe( + catchError(err => { tx.rollback(); throw err; }), + concatWith(EMPTY.pipe(finalize(() => tx.commit()))) + ) + ), + finalize(() => { rxSession.close(); driver.close() }) + ) + .subscribe({ + next: record => { }, + complete: () => { console.info('complete'); process.exit(1); }, // UNEXPECTED + error: msg => console.error('Error:', msg.message), // NOTE: expected to error with server side timeout + }) +}) diff --git a/tests/manual/js/transaction_timeout/package.json b/tests/manual/js/transaction_timeout/package.json new file mode 100644 index 000000000..a73ab1954 --- /dev/null +++ b/tests/manual/js/transaction_timeout/package.json @@ -0,0 +1,15 @@ +{ + "version": "1.0.0", + "description": "", + "scripts": { + "test:explicit": "ts-node explicit-transaction.ts", + "test:auto": "ts-node auto-transaction.ts" + }, + "dependencies": { + "neo4j-driver": "^5.9.2", + "rxjs": "^7.8.1" + }, + "devDependencies": { + "ts-node": "^10.0.0" + } +} diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 2ab896296..982018b33 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -17,6 +17,7 @@ #include "bolt_common.hpp" #include "communication/bolt/v1/session.hpp" #include "communication/exceptions.hpp" +#include "query/exceptions.hpp" #include "utils/logging.hpp" using memgraph::communication::bolt::ClientError; @@ -42,8 +43,11 @@ class TestSession : public Session<TestInputStream, TestOutputStream> { std::pair<std::vector<std::string>, std::optional<int>> Interpret( const std::string &query, const std::map<std::string, Value> ¶ms, - const std::map<std::string, Value> &metadata) override { - if (!metadata.empty()) md_ = metadata; + const std::map<std::string, Value> &extra) override { + if (extra.contains("tx_metadata")) { + auto const &metadata = extra.at("tx_metadata").ValueMap(); + if (!metadata.empty()) md_ = metadata; + } if (query == kQueryReturn42 || query == kQueryEmpty || query == kQueryReturnMultiple) { query_ = query; return {{"result_name"}, {}}; @@ -60,6 +64,9 @@ class TestSession : public Session<TestInputStream, TestOutputStream> { } std::map<std::string, Value> Pull(TEncoder *encoder, std::optional<int> n, std::optional<int> qid) override { + if (should_abort_) { + throw memgraph::query::HintedAbortError(memgraph::query::AbortReason::TERMINATED); + } if (query_ == kQueryReturn42) { encoder->MessageRecord(std::vector<Value>{Value(42)}); return {}; @@ -91,7 +98,12 @@ class TestSession : public Session<TestInputStream, TestOutputStream> { std::map<std::string, Value> Discard(std::optional<int>, std::optional<int>) override { return {}; } - void BeginTransaction(const std::map<std::string, Value> &metadata) override { md_ = metadata; } + void BeginTransaction(const std::map<std::string, Value> &extra) override { + if (extra.contains("tx_metadata")) { + auto const &metadata = extra.at("tx_metadata").ValueMap(); + if (!metadata.empty()) md_ = metadata; + } + } void CommitTransaction() override { md_.clear(); } void RollbackTransaction() override { md_.clear(); } @@ -101,9 +113,12 @@ class TestSession : public Session<TestInputStream, TestOutputStream> { std::optional<std::string> GetServerNameForInit() override { return std::nullopt; } + void TestHook_ShouldAbort() { should_abort_ = true; } + private: std::string query_; std::map<std::string, Value> md_; + bool should_abort_ = false; }; // TODO: This could be done in fixture. @@ -171,9 +186,22 @@ inline constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; inline constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x03, 0x04}; inline constexpr uint8_t route[]{0xb3, 0x66, 0xa0, 0x90, 0xc0}; -const std::string extra_w_metadata = - "\xa2\x8b\x74\x78\x5f\x6d\x65\x74\x61\x64\x61\x74\x61\xa2\x83\x73\x74\x72\x83\x61\x68\x61\x83\x6e\x75\x6d\x7b\x8a" - "\x74\x78\x5f\x74\x69\x6d\x65\x6f\x75\x74\xc9\x07\xd0"; +constexpr std::string_view extra_w_metadata = + "\xa2" // Map size 2 + "\x8b\x74\x78\x5f\x6d\x65\x74\x61\x64\x61\x74\x61" // "tx_metadata" + "\xa2" // Map size 2 + "\x83\x73\x74\x72" // "str" + "\x83\x61\x68\x61" // "aha" + "\x83\x6e\x75\x6d" // "num" + "\x7b" // 123 + "\x8a\x74\x78\x5f\x74\x69\x6d\x65\x6f\x75\x74" // "tx_timeout" + "\xc9\x07\xd0"; // INT_16 2000 + +constexpr std::string_view extra_w_127ms_timeout = + "\xa1" // Map size 1 + "\x8a\x74\x78\x5F\x74\x69\x6D\x65\x6F\x75\x74" // String size 10 "tx_timeout" + "\x7f"; // Integer 127 (representing 127ms) + inline constexpr uint8_t commit[] = {0xb0, 0x12}; } // namespace v4_3 @@ -248,7 +276,7 @@ void ExecuteInit(TestInputStream &input_stream, TestSession &session, std::vecto // Write bolt encoded run request void WriteRunRequest(TestInputStream &input_stream, const char *str, const bool is_v4 = false, - const std::string &extra = "\xA0") { + std::string_view extra = "\xA0") { // write chunk header auto len = strlen(str); WriteChunkHeader(input_stream, (3 + is_v4 * extra.size()) + 2 + len + 1); @@ -1165,3 +1193,32 @@ TEST(BoltSession, PassMetadata) { EXPECT_NE(find_str, end(output)); } } + +TEST(BoltSession, PartialStream) { + // v4+ + { + INIT_VARS; + + ExecuteHandshake(input_stream, session, output, v4_3::handshake_req, v4_3::handshake_resp); + ExecuteInit(input_stream, session, output, true); + + WriteRunRequest(input_stream, kQueryReturnMultiple, true, v4_3::extra_w_127ms_timeout); + session.Execute(); + ASSERT_EQ(session.state_, State::Result); + + ExecuteCommand(input_stream, session, v4::pull_one_req, sizeof(v4::pull_one_req)); + ASSERT_EQ(session.state_, State::Result); + constexpr std::array<uint8_t, 10> md_has_more_true{0x88, 0x68, 0x61, 0x73, 0x5F, 0x6D, 0x6F, 0x72, 0x65, 0xC3}; + auto find_has_more = std::search(cbegin(output), cend(output), cbegin(md_has_more_true), cend(md_has_more_true)); + EXPECT_NE(find_has_more, cend(output)); + + session.TestHook_ShouldAbort(); // pretend the 127ms timeout was hit + ExecuteCommand(input_stream, session, v4::pull_one_req, sizeof(v4::pull_one_req)); + + PrintOutput(output); + + auto const error_msg = std::u8string_view{u8"Transaction was asked to abort by another user."}; + auto const find_msg = std::search(cbegin(output), cend(output), cbegin(error_msg), cend(error_msg)); + EXPECT_NE(find_msg, cend(output)); + } +}