Add queries to show or terminate active transactions (#790)

This commit is contained in:
Andi 2023-03-27 15:46:00 +02:00 committed by GitHub
parent a9dc344b49
commit 029be10f1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 1255 additions and 172 deletions

View File

@ -55,6 +55,15 @@ class NotEnoughMemoryException : public std::exception {
const char *what() const throw() { return "Not enough memory!"; } const char *what() const throw() { return "Not enough memory!"; }
}; };
class MustAbortException : public std::exception {
public:
explicit MustAbortException(const std::string &message) : message_(message) {}
const char *what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
};
// Forward declarations // Forward declarations
class Nodes; class Nodes;
using GraphNodes = Nodes; using GraphNodes = Nodes;
@ -141,6 +150,10 @@ class Graph {
/// @brief Deletes a relationship from the graph. /// @brief Deletes a relationship from the graph.
void DeleteRelationship(const Relationship &relationship); void DeleteRelationship(const Relationship &relationship);
bool MustAbort() const;
void CheckMustAbort() const;
private: private:
mgp_graph *graph_; mgp_graph *graph_;
}; };
@ -1572,6 +1585,14 @@ inline Id::Id(int64_t id) : id_(id) {}
inline Graph::Graph(mgp_graph *graph) : graph_(graph) {} inline Graph::Graph(mgp_graph *graph) : graph_(graph) {}
inline bool Graph::MustAbort() const { return must_abort(graph_); }
inline void Graph::CheckMustAbort() const {
if (MustAbort()) {
throw MustAbortException("Query was asked to abort.");
}
}
inline int64_t Graph::Order() const { inline int64_t Graph::Order() const {
int64_t i = 0; int64_t i = 0;
for (const auto _ : Nodes()) { for (const auto _ : Nodes()) {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -34,13 +34,17 @@ namespace memgraph::auth {
namespace { namespace {
// Constant list of all available permissions. // Constant list of all available permissions.
const std::vector<Permission> kPermissionsAll = { const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permission::CREATE,
Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, Permission::MERGE, Permission::DELETE,
Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, Permission::SET, Permission::REMOVE,
Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, Permission::INDEX, Permission::STATS,
Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, Permission::CONSTRAINT, Permission::DUMP,
Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE, Permission::AUTH, Permission::REPLICATION,
Permission::WEBSOCKET}; Permission::DURABILITY, Permission::READ_FILE,
Permission::FREE_MEMORY, Permission::TRIGGER,
Permission::CONFIG, Permission::STREAM,
Permission::MODULE_READ, Permission::MODULE_WRITE,
Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT};
} // namespace } // namespace
std::string PermissionToString(Permission permission) { std::string PermissionToString(Permission permission) {
@ -87,6 +91,8 @@ std::string PermissionToString(Permission permission) {
return "MODULE_WRITE"; return "MODULE_WRITE";
case Permission::WEBSOCKET: case Permission::WEBSOCKET:
return "WEBSOCKET"; return "WEBSOCKET";
case Permission::TRANSACTION_MANAGEMENT:
return "TRANSACTION_MANAGEMENT";
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -40,7 +40,8 @@ enum class Permission : uint64_t {
STREAM = 1U << 17U, STREAM = 1U << 17U,
MODULE_READ = 1U << 18U, MODULE_READ = 1U << 18U,
MODULE_WRITE = 1U << 19U, MODULE_WRITE = 1U << 19U,
WEBSOCKET = 1U << 20U WEBSOCKET = 1U << 20U,
TRANSACTION_MANAGEMENT = 1U << 21U
}; };
// clang-format on // clang-format on

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -58,6 +58,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
return auth::Permission::MODULE_WRITE; return auth::Permission::MODULE_WRITE;
case query::AuthQuery::Privilege::WEBSOCKET: case query::AuthQuery::Privilege::WEBSOCKET:
return auth::Permission::WEBSOCKET; return auth::Permission::WEBSOCKET;
case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT:
return auth::Permission::TRANSACTION_MANAGEMENT;
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -84,6 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges); return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
} }
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGrainedAuthChecker( std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const { const std::string &username, const memgraph::query::DbAccessor *dba) const {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -26,9 +26,11 @@ class AuthChecker : public query::AuthChecker {
bool IsUserAuthorized(const std::optional<std::string> &username, bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const override; const std::vector<query::AuthQuery::Privilege> &privileges) const override;
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker( std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const override; const std::string &username, const memgraph::query::DbAccessor *dba) const override;
#endif #endif
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, [[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges); const std::vector<memgraph::query::AuthQuery::Privilege> &privileges);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -522,6 +522,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, : memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>(input_stream, output_stream), memgraph::communication::v2::OutputStream>(input_stream, output_stream),
db_(data->db), db_(data->db),
interpreter_context_(data->interpreter_context),
interpreter_(data->interpreter_context), interpreter_(data->interpreter_context),
auth_(data->auth), auth_(data->auth),
#if MG_ENTERPRISE #if MG_ENTERPRISE
@ -529,6 +530,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
#endif #endif
endpoint_(endpoint), endpoint_(endpoint),
run_id_(data->run_id) { run_id_(data->run_id) {
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
}
~BoltSession() override {
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); });
} }
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
@ -674,6 +680,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
// NOTE: Needed only for ToBoltValue conversions // NOTE: Needed only for ToBoltValue conversions
const memgraph::storage::Storage *db_; const memgraph::storage::Storage *db_;
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::query::Interpreter interpreter_; memgraph::query::Interpreter interpreter_;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
std::optional<memgraph::auth::User> user_; std::optional<memgraph::auth::User> user_;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -24,6 +24,15 @@
namespace memgraph::query { namespace memgraph::query {
enum class TransactionStatus {
IDLE,
ACTIVE,
VERIFYING,
TERMINATED,
STARTED_COMMITTING,
STARTED_ROLLBACK,
};
struct EvaluationContext { struct EvaluationContext {
/// Memory for allocations during evaluation of a *single* Pull call. /// Memory for allocations during evaluation of a *single* Pull call.
/// ///
@ -66,6 +75,7 @@ struct ExecutionContext {
SymbolTable symbol_table; SymbolTable symbol_table;
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
std::atomic<bool> *is_shutting_down{nullptr}; std::atomic<bool> *is_shutting_down{nullptr};
std::atomic<TransactionStatus> *transaction_status{nullptr};
bool is_profile_query{false}; bool is_profile_query{false};
std::chrono::duration<double> profile_execution_time; std::chrono::duration<double> profile_execution_time;
plan::ProfilingStats stats; plan::ProfilingStats stats;
@ -82,7 +92,9 @@ static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext mus
static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!"); static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!");
inline bool MustAbort(const ExecutionContext &context) noexcept { inline bool MustAbort(const ExecutionContext &context) noexcept {
return (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) || return (context.transaction_status != nullptr &&
context.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) ||
(context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) ||
context.timer.IsExpired(); context.timer.IsExpired();
} }

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -120,9 +120,8 @@ class HintedAbortError : public utils::BasicException {
using utils::BasicException::BasicException; using utils::BasicException::BasicException;
HintedAbortError() HintedAbortError()
: utils::BasicException( : utils::BasicException(
"Transaction was asked to abort, most likely because it was " "Transaction was asked to abort either because it was executing longer than time specified or another user "
"executing longer than time specified by " "asked it to abort.") {}
"--query-execution-timeout-sec flag.") {}
}; };
class ExplicitTransactionUsageException : public QueryRuntimeException { class ExplicitTransactionUsageException : public QueryRuntimeException {
@ -237,4 +236,11 @@ class ReplicationException : public utils::BasicException {
: utils::BasicException("Replication Exception: {} Check the status of the replicas using 'SHOW REPLICA' query.", : utils::BasicException("Replication Exception: {} Check the status of the replicas using 'SHOW REPLICA' query.",
message) {} message) {}
}; };
class TransactionQueueInMulticommandTxException : public QueryException {
public:
TransactionQueueInMulticommandTxException()
: QueryException("Transaction queue queries not allowed in multicommand transactions.") {}
};
} // namespace memgraph::query } // namespace memgraph::query

View File

@ -10,6 +10,7 @@
// licenses/APL.txt. // licenses/APL.txt.
#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/ast_visitor.hpp"
#include "utils/typeinfo.hpp" #include "utils/typeinfo.hpp"
namespace memgraph { namespace memgraph {
@ -259,5 +260,8 @@ constexpr utils::TypeInfo query::Foreach::kType{utils::TypeId::AST_FOREACH, "For
constexpr utils::TypeInfo query::ShowConfigQuery::kType{utils::TypeId::AST_SHOW_CONFIG_QUERY, "ShowConfigQuery", constexpr utils::TypeInfo query::ShowConfigQuery::kType{utils::TypeId::AST_SHOW_CONFIG_QUERY, "ShowConfigQuery",
&query::Query::kType}; &query::Query::kType};
constexpr utils::TypeInfo query::TransactionQueueQuery::kType{utils::TypeId::AST_TRANSACTION_QUEUE_QUERY,
"TransactionQueueQuery", &query::Query::kType};
constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType}; constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType};
} // namespace memgraph } // namespace memgraph

View File

@ -2699,7 +2699,8 @@ class AuthQuery : public memgraph::query::Query {
STREAM, STREAM,
MODULE_READ, MODULE_READ,
MODULE_WRITE, MODULE_WRITE,
WEBSOCKET WEBSOCKET,
TRANSACTION_MANAGEMENT
}; };
enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE }; enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE };
@ -2752,13 +2753,17 @@ class AuthQuery : public memgraph::query::Query {
/// Constant that holds all available privileges. /// Constant that holds all available privileges.
const std::vector<AuthQuery::Privilege> kPrivilegesAll = { const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, AuthQuery::Privilege::MATCH, AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE,
AuthQuery::Privilege::MERGE, AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE,
AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, AuthQuery::Privilege::AUTH, AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE,
AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION, AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS,
AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT,
AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION,
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, AuthQuery::Privilege::WEBSOCKET}; AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY,
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM,
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE,
AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT};
class InfoQuery : public memgraph::query::Query { class InfoQuery : public memgraph::query::Query {
public: public:
@ -3203,6 +3208,28 @@ class ShowConfigQuery : public memgraph::query::Query {
} }
}; };
class TransactionQueueQuery : public memgraph::query::Query {
public:
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }
enum class Action { SHOW_TRANSACTIONS, TERMINATE_TRANSACTIONS };
TransactionQueueQuery() = default;
DEFVISITABLE(QueryVisitor<void>);
memgraph::query::TransactionQueueQuery::Action action_;
std::vector<Expression *> transaction_id_list_;
TransactionQueueQuery *Clone(AstStorage *storage) const override {
auto *object = storage->Create<TransactionQueueQuery>();
object->action_ = action_;
object->transaction_id_list_ = transaction_id_list_;
return object;
}
};
class Exists : public memgraph::query::Expression { class Exists : public memgraph::query::Expression {
public: public:
static const utils::TypeInfo kType; static const utils::TypeInfo kType;

View File

@ -2284,7 +2284,7 @@ cpp<#
(lcp:define-enum privilege (lcp:define-enum privilege
(create delete match merge set remove index stats auth constraint (create delete match merge set remove index stats auth constraint
dump replication durability read_file free_memory trigger config stream module_read module_write dump replication durability read_file free_memory trigger config stream module_read module_write
websocket) websocket transaction_management)
(:serialize)) (:serialize))
(lcp:define-enum fine-grained-privilege (lcp:define-enum fine-grained-privilege
(nothing read update create_delete) (nothing read update create_delete)
@ -2333,7 +2333,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM,
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE,
AuthQuery::Privilege::WEBSOCKET}; AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT};
cpp<# cpp<#
(lcp:define-class info-query (query) (lcp:define-class info-query (query)
@ -2661,6 +2661,26 @@ cpp<#
(:serialize (:slk)) (:serialize (:slk))
(:clone)) (:clone))
(lcp:define-class transaction-queue-query (query)
((action "Action" :scope :public)
(transaction_id_list "std::vector<Expression*>" :scope :public))
(:public
(lcp:define-enum action
(show-transactions terminate-transactions)
(:serialize))
#>cpp
TransactionQueueQuery() = default;
DEFVISITABLE(QueryVisitor<void>);
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(lcp:define-class version-query (query) () (lcp:define-class version-query (query) ()
(:public (:public
#>cpp #>cpp

View File

@ -95,6 +95,7 @@ class SettingQuery;
class VersionQuery; class VersionQuery;
class Foreach; class Foreach;
class ShowConfigQuery; class ShowConfigQuery;
class TransactionQueueQuery;
class Exists; class Exists;
using TreeCompositeVisitor = utils::CompositeVisitor< using TreeCompositeVisitor = utils::CompositeVisitor<
@ -127,9 +128,10 @@ class ExpressionVisitor
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {}; None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {};
template <class TResult> template <class TResult>
class QueryVisitor : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, class QueryVisitor
InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
FreeMemoryQuery, TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
StreamQuery, SettingQuery, VersionQuery, ShowConfigQuery> {}; IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, TransactionQueueQuery,
VersionQuery, ShowConfigQuery> {};
} // namespace memgraph::query } // namespace memgraph::query

View File

@ -11,8 +11,10 @@
#include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp"
#include <support/Any.h> #include <support/Any.h>
#include <tree/ParseTreeVisitor.h>
#include <algorithm> #include <algorithm>
#include <any>
#include <climits> #include <climits>
#include <codecvt> #include <codecvt>
#include <cstring> #include <cstring>
@ -631,6 +633,7 @@ void GetTopicNames(auto &destination, MemgraphCypher::TopicNamesContext *topic_n
destination = std::any_cast<Expression *>(topic_names_ctx->accept(&visitor)); destination = std::any_cast<Expression *>(topic_names_ctx->accept(&visitor));
} }
} }
} // namespace } // namespace
antlrcpp::Any CypherMainVisitor::visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) { antlrcpp::Any CypherMainVisitor::visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) {
@ -883,6 +886,34 @@ antlrcpp::Any CypherMainVisitor::visitShowSettings(MemgraphCypher::ShowSettingsC
return setting_query; return setting_query;
} }
antlrcpp::Any CypherMainVisitor::visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) {
MG_ASSERT(ctx->children.size() == 1, "TransactionQueueQuery should have exactly one child!");
auto *transaction_queue_query = std::any_cast<TransactionQueueQuery *>(ctx->children[0]->accept(this));
query_ = transaction_queue_query;
return transaction_queue_query;
}
antlrcpp::Any CypherMainVisitor::visitShowTransactions(MemgraphCypher::ShowTransactionsContext * /*ctx*/) {
auto *transaction_shower = storage_->Create<TransactionQueueQuery>();
transaction_shower->action_ = TransactionQueueQuery::Action::SHOW_TRANSACTIONS;
return transaction_shower;
}
antlrcpp::Any CypherMainVisitor::visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) {
auto *terminator = storage_->Create<TransactionQueueQuery>();
terminator->action_ = TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS;
terminator->transaction_id_list_ = std::any_cast<std::vector<Expression *>>(ctx->transactionIdList()->accept(this));
return terminator;
}
antlrcpp::Any CypherMainVisitor::visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) {
std::vector<Expression *> transaction_ids;
for (auto *transaction_id : ctx->transactionId()) {
transaction_ids.push_back(std::any_cast<Expression *>(transaction_id->accept(this)));
}
return transaction_ids;
}
antlrcpp::Any CypherMainVisitor::visitVersionQuery(MemgraphCypher::VersionQueryContext * /*ctx*/) { antlrcpp::Any CypherMainVisitor::visitVersionQuery(MemgraphCypher::VersionQueryContext * /*ctx*/) {
auto *version_query = storage_->Create<VersionQuery>(); auto *version_query = storage_->Create<VersionQuery>();
query_ = version_query; query_ = version_query;
@ -1451,6 +1482,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext
if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ;
if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE;
if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET;
if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT;
LOG_FATAL("Should not get here - unknown privilege!"); LOG_FATAL("Should not get here - unknown privilege!");
} }

View File

@ -358,6 +358,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/ */
antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override; antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override;
/**
* @return TransactionQueueQuery*
*/
antlrcpp::Any visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) override;
/**
* @return ShowTransactions*
*/
antlrcpp::Any visitShowTransactions(MemgraphCypher::ShowTransactionsContext *ctx) override;
/**
* @return TerminateTransactions*
*/
antlrcpp::Any visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) override;
/**
* @return TransactionIdList*
*/
antlrcpp::Any visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) override;
/** /**
* @return VersionQuery* * @return VersionQuery*
*/ */

View File

@ -102,6 +102,8 @@ memgraphCypherKeyword : cypherKeyword
| USER | USER
| USERS | USERS
| VERSION | VERSION
| TERMINATE
| TRANSACTIONS
; ;
symbolicName : UnescapedSymbolicName symbolicName : UnescapedSymbolicName
@ -127,6 +129,7 @@ query : cypherQuery
| settingQuery | settingQuery
| versionQuery | versionQuery
| showConfigQuery | showConfigQuery
| transactionQueueQuery
; ;
authQuery : createRole authQuery : createRole
@ -197,6 +200,14 @@ settingQuery : setSetting
| showSettings | showSettings
; ;
transactionQueueQuery : showTransactions
| terminateTransactions
;
showTransactions : SHOW TRANSACTIONS ;
terminateTransactions : TERMINATE TRANSACTIONS transactionIdList;
loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER
( IGNORE BAD ) ? ( IGNORE BAD ) ?
( DELIMITER delimiter ) ? ( DELIMITER delimiter ) ?
@ -259,6 +270,7 @@ privilege : CREATE
| MODULE_READ | MODULE_READ
| MODULE_WRITE | MODULE_WRITE
| WEBSOCKET | WEBSOCKET
| TRANSACTION_MANAGEMENT
; ;
granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ; granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ;
@ -402,3 +414,7 @@ showSettings : SHOW DATABASE SETTINGS ;
showConfigQuery : SHOW CONFIG ; showConfigQuery : SHOW CONFIG ;
versionQuery : SHOW VERSION ; versionQuery : SHOW VERSION ;
transactionIdList : transactionId ( ',' transactionId )* ;
transactionId : literal ;

View File

@ -53,6 +53,7 @@ DIRECTORY : D I R E C T O R Y ;
DROP : D R O P ; DROP : D R O P ;
DUMP : D U M P ; DUMP : D U M P ;
DURABILITY : D U R A B I L I T Y ; DURABILITY : D U R A B I L I T Y ;
EDGE_TYPES : E D G E UNDERSCORE T Y P E S ;
EXECUTE : E X E C U T E ; EXECUTE : E X E C U T E ;
FOR : F O R ; FOR : F O R ;
FOREACH : F O R E A C H; FOREACH : F O R E A C H;
@ -103,10 +104,13 @@ STOP : S T O P ;
STREAM : S T R E A M ; STREAM : S T R E A M ;
STREAMS : S T R E A M S ; STREAMS : S T R E A M S ;
SYNC : S Y N C ; SYNC : S Y N C ;
TERMINATE : T E R M I N A T E ;
TIMEOUT : T I M E O U T ; TIMEOUT : T I M E O U T ;
TO : T O ; TO : T O ;
TOPICS : T O P I C S; TOPICS : T O P I C S;
TRANSACTION : T R A N S A C T I O N ; TRANSACTION : T R A N S A C T I O N ;
TRANSACTION_MANAGEMENT : T R A N S A C T I O N UNDERSCORE M A N A G E M E N T ;
TRANSACTIONS : T R A N S A C T I O N S ;
TRANSFORM : T R A N S F O R M ; TRANSFORM : T R A N S F O R M ;
TRIGGER : T R I G G E R ; TRIGGER : T R I G G E R ;
TRIGGERS : T R I G G E R S ; TRIGGERS : T R I G G E R S ;
@ -117,4 +121,3 @@ USER : U S E R ;
USERS : U S E R S ; USERS : U S E R S ;
VERSION : V E R S I O N ; VERSION : V E R S I O N ;
WEBSOCKET : W E B S O C K E T ; WEBSOCKET : W E B S O C K E T ;
EDGE_TYPES : E D G E UNDERSCORE T Y P E S ;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); } void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); }
void Visit(TransactionQueueQuery & /*transaction_queue_query*/) override {}
void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); }
bool PreVisit(Create & /*unused*/) override { bool PreVisit(Create & /*unused*/) override {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,9 +18,12 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <iterator>
#include <limits> #include <limits>
#include <optional> #include <optional>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <variant> #include <variant>
#include "auth/models.hpp" #include "auth/models.hpp"
@ -59,6 +62,7 @@
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/memory.hpp" #include "utils/memory.hpp"
#include "utils/memory_tracker.hpp" #include "utils/memory_tracker.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/readable_size.hpp" #include "utils/readable_size.hpp"
#include "utils/settings.hpp" #include "utils/settings.hpp"
#include "utils/string.hpp" #include "utils/string.hpp"
@ -975,7 +979,8 @@ struct PullPlanVector {
struct PullPlan { struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query, explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr, std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {}); std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n, std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols, const std::vector<Symbol> &output_symbols,
@ -1004,8 +1009,8 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query, PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector, std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
const std::optional<size_t> memory_limit) TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
: plan_(plan), : plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)), cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory), frame_(plan->symbol_table().max_position(), execution_memory),
@ -1025,6 +1030,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
} }
ctx_.is_shutting_down = &interpreter_context->is_shutting_down; ctx_.is_shutting_down = &interpreter_context->is_shutting_down;
ctx_.transaction_status = transaction_status;
ctx_.is_profile_query = is_profile_query; ctx_.is_profile_query = is_profile_query;
ctx_.trigger_context_collector = trigger_context_collector; ctx_.trigger_context_collector = trigger_context_collector;
} }
@ -1137,12 +1143,14 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
if (in_explicit_transaction_) { if (in_explicit_transaction_) {
throw ExplicitTransactionUsageException("Nested transactions are not supported."); throw ExplicitTransactionUsageException("Nested transactions are not supported.");
} }
in_explicit_transaction_ = true; in_explicit_transaction_ = true;
expect_rollback_ = false; expect_rollback_ = false;
db_accessor_ = db_accessor_ =
std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride())); std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get()); execution_db_accessor_.emplace(db_accessor_.get());
transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
if (interpreter_context_->trigger_store.HasTriggers()) { if (interpreter_context_->trigger_store.HasTriggers()) {
trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes());
@ -1194,7 +1202,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba, InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications, utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username, const std::string *username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr) { TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
@ -1239,9 +1247,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back( header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first); utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
} }
auto pull_plan = auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, execution_memory, StringPointerToOptional(username), transaction_status,
StringPointerToOptional(username), trigger_context_collector, memory_limit); trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
@ -1301,8 +1309,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory, DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
const std::string *username) { std::atomic<TransactionStatus> *transaction_status) {
const std::string kProfileQueryStart = "profile "; const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart), MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1363,13 +1371,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// We want to execute the query we are profiling lazily, so we delay // We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context. // the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{}, stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed. // No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) { if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context, stats_and_total_time =
execution_memory, optional_username, nullptr, memory_limit) PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory,
.Pull(stream, {}, {}, summary); optional_username, transaction_status, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time)); pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
} }
@ -1524,7 +1533,8 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) { DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
std::atomic<TransactionStatus> *transaction_status) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException(); throw UserModificationInMulticommandTxException();
} }
@ -1545,7 +1555,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
0.0, AstStorage{}, symbol_table)); 0.0, AstStorage{}, symbol_table));
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username)); execution_memory, StringPointerToOptional(username), transaction_status);
return PreparedQuery{ return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges), callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
@ -1558,7 +1568,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
RWType::NONE}; RWType::NONE};
} }
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context, std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba) { DbAccessor *dba) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
@ -1586,7 +1596,7 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_ex
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
} }
PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba) { InterpreterContext *interpreter_context, DbAccessor *dba) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw LockPathModificationInMulticommandTxException(); throw LockPathModificationInMulticommandTxException();
@ -1615,7 +1625,7 @@ PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_expli
RWType::NONE}; RWType::NONE};
} }
PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context) { InterpreterContext *interpreter_context) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw FreeMemoryModificationInMulticommandTxException(); throw FreeMemoryModificationInMulticommandTxException();
@ -1632,7 +1642,7 @@ PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_exp
RWType::NONE}; RWType::NONE};
} }
PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) { PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, bool in_explicit_transaction) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw ShowConfigModificationInMulticommandTxException(); throw ShowConfigModificationInMulticommandTxException();
} }
@ -1736,7 +1746,7 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {
}}; }};
} }
PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context, std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba, const std::map<std::string, storage::PropertyValue> &user_parameters, DbAccessor *dba, const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) { const std::string *username) {
@ -1786,7 +1796,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
} }
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context, std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> & /*user_parameters*/, const std::map<std::string, storage::PropertyValue> & /*user_parameters*/,
@ -1828,7 +1838,7 @@ constexpr auto ToStorageIsolationLevel(const IsolationLevelQuery::IsolationLevel
} }
} }
PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context, Interpreter *interpreter) { InterpreterContext *interpreter_context, Interpreter *interpreter) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw IsolationLevelModificationInMulticommandTxException(); throw IsolationLevelModificationInMulticommandTxException();
@ -1883,7 +1893,7 @@ PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_expli
RWType::NONE}; RWType::NONE};
} }
PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, DbAccessor *dba) { PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_transaction, DbAccessor *dba) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw SettingConfigInMulticommandTxException{}; throw SettingConfigInMulticommandTxException{};
} }
@ -1909,7 +1919,155 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explic
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
} }
PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) { std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::ShowTransactions(
const std::unordered_set<Interpreter *> &interpreters, const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege) {
std::vector<std::vector<TypedValue>> results;
results.reserve(interpreters.size());
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
// if it is just checking status, commit and abort should wait for the end of the check
// ignore interpreters that already started committing or rollback
if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) {
continue;
}
utils::OnScopeExit clean_status([interpreter]() {
interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
});
std::optional<uint64_t> transaction_id = interpreter->GetTransactionId();
if (transaction_id.has_value() && (interpreter->username_ == username || hasTransactionManagementPrivilege)) {
const auto &typed_queries = interpreter->GetQueries();
results.push_back({TypedValue(interpreter->username_.value_or("")),
TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)});
}
}
return results;
}
std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::KillTransactions(
InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids,
const std::optional<std::string> &username, bool hasTransactionManagementPrivilege) {
std::vector<std::vector<TypedValue>> results;
for (const std::string &transaction_id : maybe_kill_transaction_ids) {
bool killed = false;
bool transaction_found = false;
// Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed
// TERMINATE and SHOW TRANSACTIONS are mutually exclusive
interpreter_context->interpreters.WithLock([&transaction_id, &killed, &transaction_found, username,
hasTransactionManagementPrivilege](const auto &interpreters) {
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
// if it is just checking kill, commit and abort should wait for the end of the check
// The only way to start checking if the transaction will get killed is if the transaction_status is
// active
if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) {
continue;
}
utils::OnScopeExit clean_status([interpreter, &killed]() {
if (killed) {
interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release);
} else {
interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
}
});
std::optional<uint64_t> intr_trans = interpreter->GetTransactionId();
if (intr_trans.has_value() && std::to_string(intr_trans.value()) == transaction_id) {
transaction_found = true;
if (interpreter->username_ == username || hasTransactionManagementPrivilege) {
killed = true;
spdlog::warn("Transaction {} successfully killed", transaction_id);
} else {
spdlog::warn("Not enough rights to kill the transaction");
}
break;
}
}
});
if (!transaction_found) {
spdlog::warn("Transaction {} not found", transaction_id);
}
results.push_back({TypedValue(transaction_id), TypedValue(killed)});
}
return results;
}
Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
const std::optional<std::string> &username, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized(
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT});
Callback callback;
switch (transaction_query->action_) {
case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: {
callback.header = {"username", "transaction_id", "query"};
callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, username,
hasTransactionManagementPrivilege]() mutable {
std::vector<std::vector<TypedValue>> results;
// Multiple simultaneous SHOW TRANSACTIONS aren't allowed
interpreter_context->interpreters.WithLock(
[&results, handler, username, hasTransactionManagementPrivilege](const auto &interpreters) {
results = handler.ShowTransactions(interpreters, username, hasTransactionManagementPrivilege);
});
return results;
};
break;
}
case TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS: {
std::vector<std::string> maybe_kill_transaction_ids;
std::transform(transaction_query->transaction_id_list_.begin(), transaction_query->transaction_id_list_.end(),
std::back_inserter(maybe_kill_transaction_ids), [&evaluator](Expression *expression) {
return std::string(expression->Accept(evaluator).ValueString());
});
callback.header = {"transaction_id", "killed"};
callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, maybe_kill_transaction_ids,
username, hasTransactionManagementPrivilege]() mutable {
return handler.KillTransactions(interpreter_context, maybe_kill_transaction_ids, username,
hasTransactionManagementPrivilege);
};
break;
}
}
return callback;
}
PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional<std::string> &username,
bool in_explicit_transaction, InterpreterContext *interpreter_context,
DbAccessor *dba) {
if (in_explicit_transaction) {
throw TransactionQueueInMulticommandTxException();
}
auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query);
MG_ASSERT(transaction_queue_query);
auto callback =
HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context, dba);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (UNLIKELY(!pull_plan)) {
pull_plan = std::make_shared<PullPlanVector>(callback_fn());
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE};
}
PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, bool in_explicit_transaction) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw VersionInfoInMulticommandTxException(); throw VersionInfoInMulticommandTxException();
} }
@ -2263,6 +2421,13 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
RWType::NONE}; RWType::NONE};
} }
std::optional<uint64_t> Interpreter::GetTransactionId() const {
if (db_accessor_) {
return db_accessor_->GetTransactionId();
}
return {};
}
void Interpreter::BeginTransaction() { void Interpreter::BeginTransaction() {
const auto prepared_query = PrepareTransactionQuery("BEGIN"); const auto prepared_query = PrepareTransactionQuery("BEGIN");
prepared_query.query_handler(nullptr, {}); prepared_query.query_handler(nullptr, {});
@ -2272,12 +2437,14 @@ void Interpreter::CommitTransaction() {
const auto prepared_query = PrepareTransactionQuery("COMMIT"); const auto prepared_query = PrepareTransactionQuery("COMMIT");
prepared_query.query_handler(nullptr, {}); prepared_query.query_handler(nullptr, {});
query_executions_.clear(); query_executions_.clear();
transaction_queries_->clear();
} }
void Interpreter::RollbackTransaction() { void Interpreter::RollbackTransaction() {
const auto prepared_query = PrepareTransactionQuery("ROLLBACK"); const auto prepared_query = PrepareTransactionQuery("ROLLBACK");
prepared_query.query_handler(nullptr, {}); prepared_query.query_handler(nullptr, {});
query_executions_.clear(); query_executions_.clear();
transaction_queries_->clear();
} }
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
@ -2285,10 +2452,17 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::string *username) { const std::string *username) {
if (!in_explicit_transaction_) { if (!in_explicit_transaction_) {
query_executions_.clear(); query_executions_.clear();
transaction_queries_->clear();
} }
// This will be done in the handle transaction query. Our handler can save username and then send it to the kill and
// show transactions.
std::optional<std::string> user = StringPointerToOptional(username);
username_ = user;
query_executions_.emplace_back(std::make_unique<QueryExecution>()); query_executions_.emplace_back(std::make_unique<QueryExecution>());
auto &query_execution = query_executions_.back(); auto &query_execution = query_executions_.back();
std::optional<int> qid = std::optional<int> qid =
in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{}; in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{};
@ -2302,6 +2476,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
} }
// Don't save BEGIN, COMMIT or ROLLBACK
transaction_queries_->push_back(query_string);
// All queries other than transaction control queries advance the command in // All queries other than transaction control queries advance the command in
// an explicit transaction block. // an explicit transaction block.
if (in_explicit_transaction_) { if (in_explicit_transaction_) {
@ -2327,10 +2504,12 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (!in_explicit_transaction_ && if (!in_explicit_transaction_ &&
(utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) || (utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) ||
utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) || utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) ||
utils::Downcast<TriggerQuery>(parsed_query.query))) { utils::Downcast<TriggerQuery>(parsed_query.query) ||
utils::Downcast<TransactionQueueQuery>(parsed_query.query))) {
db_accessor_ = db_accessor_ =
std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride())); std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get()); execution_db_accessor_.emplace(db_accessor_.get());
transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
if (utils::Downcast<CypherQuery>(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) { if (utils::Downcast<CypherQuery>(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) {
trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes());
@ -2343,15 +2522,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (utils::Downcast<CypherQuery>(parsed_query.query)) { if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory, &*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications, username, &query_execution->notifications, username, &transaction_status_,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr); trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception); &*execution_db_accessor_, &query_execution->execution_memory_with_exception);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) { } else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, prepared_query = PrepareProfileQuery(
interpreter_context_, &*execution_db_accessor_, std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&query_execution->execution_memory_with_exception, username); &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) { } else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory); &query_execution->execution_memory);
@ -2359,9 +2538,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_, prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->notifications, interpreter_context_); &query_execution->notifications, interpreter_context_);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) { } else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, prepared_query = PrepareAuthQuery(
interpreter_context_, &*execution_db_accessor_, std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&query_execution->execution_memory_with_exception, username); &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) { } else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db, interpreter_context_, interpreter_context_->db,
@ -2398,6 +2577,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_);
} else if (utils::Downcast<VersionQuery>(parsed_query.query)) { } else if (utils::Downcast<VersionQuery>(parsed_query.query)) {
prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_);
} else if (utils::Downcast<TransactionQueueQuery>(parsed_query.query)) {
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_,
interpreter_context_, &*execution_db_accessor_);
} else { } else {
LOG_FATAL("Should not get here -- unknown query type!"); LOG_FATAL("Should not get here -- unknown query type!");
} }
@ -2425,7 +2607,29 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} }
} }
std::vector<TypedValue> Interpreter::GetQueries() {
auto typed_queries = std::vector<TypedValue>();
transaction_queries_.WithLock([&typed_queries](const auto &transaction_queries) {
std::for_each(transaction_queries.begin(), transaction_queries.end(),
[&typed_queries](const auto &query) { typed_queries.emplace_back(query); });
});
return typed_queries;
}
void Interpreter::Abort() { void Interpreter::Abort() {
auto expected = TransactionStatus::ACTIVE;
while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) {
if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) {
transaction_status_.store(TransactionStatus::STARTED_ROLLBACK);
break;
}
expected = TransactionStatus::ACTIVE;
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
utils::OnScopeExit clean_status(
[this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); });
expect_rollback_ = false; expect_rollback_ = false;
in_explicit_transaction_ = false; in_explicit_transaction_ = false;
if (!db_accessor_) return; if (!db_accessor_) return;
@ -2437,7 +2641,7 @@ void Interpreter::Abort() {
namespace { namespace {
void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context, void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context,
TriggerContext trigger_context) { TriggerContext trigger_context, std::atomic<TransactionStatus> *transaction_status) {
// Run the triggers // Run the triggers
for (const auto &trigger : triggers.access()) { for (const auto &trigger : triggers.access()) {
utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize};
@ -2449,7 +2653,8 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret
trigger_context.AdaptForAccessor(&db_accessor); trigger_context.AdaptForAccessor(&db_accessor);
try { try {
trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec, trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec,
&interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker); &interpreter_context->is_shutting_down, transaction_status, trigger_context,
interpreter_context->auth_checker);
} catch (const utils::BasicException &exception) { } catch (const utils::BasicException &exception) {
spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what());
db_accessor.Abort(); db_accessor.Abort();
@ -2504,6 +2709,25 @@ void Interpreter::Commit() {
// a query. // a query.
if (!db_accessor_) return; if (!db_accessor_) return;
/*
At this point we must check that the transaction is alive to start committing. The only other possible state is
verifying and in that case we must check if the transaction was terminated and if yes abort committing. Exception
should suffice.
*/
auto expected = TransactionStatus::ACTIVE;
while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_COMMITTING)) {
if (expected == TransactionStatus::TERMINATED) {
throw memgraph::utils::BasicException(
"Aborting transaction commit because the transaction was requested to stop from other session. ");
}
expected = TransactionStatus::ACTIVE;
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
// Clean transaction status if something went wrong
utils::OnScopeExit clean_status(
[this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); });
std::optional<TriggerContext> trigger_context = std::nullopt; std::optional<TriggerContext> trigger_context = std::nullopt;
if (trigger_context_collector_) { if (trigger_context_collector_) {
trigger_context.emplace(std::move(*trigger_context_collector_).TransformToTriggerContext()); trigger_context.emplace(std::move(*trigger_context_collector_).TransformToTriggerContext());
@ -2517,7 +2741,8 @@ void Interpreter::Commit() {
AdvanceCommand(); AdvanceCommand();
try { try {
trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec, trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec,
&interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker); &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context,
interpreter_context_->auth_checker);
} catch (const utils::BasicException &e) { } catch (const utils::BasicException &e) {
throw utils::BasicException( throw utils::BasicException(
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what()));
@ -2579,10 +2804,10 @@ void Interpreter::Commit() {
// This means the ordered execution of after commit triggers are not guaranteed. // This means the ordered execution of after commit triggers are not guaranteed.
if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) { if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) {
interpreter_context_->after_commit_trigger_pool.AddTask( interpreter_context_->after_commit_trigger_pool.AddTask(
[trigger_context = std::move(*trigger_context), interpreter_context = this->interpreter_context_, [this, trigger_context = std::move(*trigger_context),
user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable {
RunTriggersIndividually(interpreter_context->trigger_store.AfterCommitTriggers(), interpreter_context, RunTriggersIndividually(this->interpreter_context_->trigger_store.AfterCommitTriggers(),
std::move(trigger_context)); this->interpreter_context_, std::move(trigger_context), &this->transaction_status_);
user_transaction->FinalizeTransaction(); user_transaction->FinalizeTransaction();
SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name)
}); });

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,6 +11,8 @@
#pragma once #pragma once
#include <unordered_set>
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include "query/auth_checker.hpp" #include "query/auth_checker.hpp"
@ -37,6 +39,7 @@
#include "utils/settings.hpp" #include "utils/settings.hpp"
#include "utils/skip_list.hpp" #include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp" #include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
#include "utils/thread_pool.hpp" #include "utils/thread_pool.hpp"
#include "utils/timer.hpp" #include "utils/timer.hpp"
#include "utils/tsc.hpp" #include "utils/tsc.hpp"
@ -179,12 +182,12 @@ struct PreparedQuery {
plan::ReadWriteTypeChecker::RWType rw_type; plan::ReadWriteTypeChecker::RWType rw_type;
}; };
class Interpreter;
/** /**
* Holds data shared between multiple `Interpreter` instances (which might be * Holds data shared between multiple `Interpreter` instances (which might be
* running concurrently). * running concurrently).
* *
* Users should initialize the context but should not modify it after it has
* been passed to an `Interpreter` instance.
*/ */
struct InterpreterContext { struct InterpreterContext {
explicit InterpreterContext(storage::Storage *db, InterpreterConfig config, explicit InterpreterContext(storage::Storage *db, InterpreterConfig config,
@ -214,6 +217,7 @@ struct InterpreterContext {
const InterpreterConfig config; const InterpreterConfig config;
query::stream::Streams streams; query::stream::Streams streams;
utils::Synchronized<std::unordered_set<Interpreter *>, utils::SpinLock> interpreters;
}; };
/// Function that is used to tell all active interpreters that they should stop /// Function that is used to tell all active interpreters that they should stop
@ -235,6 +239,10 @@ class Interpreter final {
std::optional<int> qid; std::optional<int> qid;
}; };
std::optional<std::string> username_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
/** /**
* Prepare a query for execution. * Prepare a query for execution.
* *
@ -290,6 +298,11 @@ class Interpreter final {
void BeginTransaction(); void BeginTransaction();
/*
Returns transaction id or empty if the db_accessor is not initialized.
*/
std::optional<uint64_t> GetTransactionId() const;
void CommitTransaction(); void CommitTransaction();
void RollbackTransaction(); void RollbackTransaction();
@ -297,11 +310,15 @@ class Interpreter final {
void SetNextTransactionIsolationLevel(storage::IsolationLevel isolation_level); void SetNextTransactionIsolationLevel(storage::IsolationLevel isolation_level);
void SetSessionIsolationLevel(storage::IsolationLevel isolation_level); void SetSessionIsolationLevel(storage::IsolationLevel isolation_level);
std::vector<TypedValue> GetQueries();
/** /**
* Abort the current multicommand transaction. * Abort the current multicommand transaction.
*/ */
void Abort(); void Abort();
std::atomic<TransactionStatus> transaction_status_{TransactionStatus::IDLE};
private: private:
struct QueryExecution { struct QueryExecution {
std::optional<PreparedQuery> prepared_query; std::optional<PreparedQuery> prepared_query;
@ -338,6 +355,8 @@ class Interpreter final {
// and deletion of a single query execution, i.e. when a query finishes, // and deletion of a single query execution, i.e. when a query finishes,
// we reset the corresponding unique_ptr. // we reset the corresponding unique_ptr.
std::vector<std::unique_ptr<QueryExecution>> query_executions_; std::vector<std::unique_ptr<QueryExecution>> query_executions_;
// all queries that are run as part of the current transaction
utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_;
InterpreterContext *interpreter_context_; InterpreterContext *interpreter_context_;
@ -347,8 +366,6 @@ class Interpreter final {
std::unique_ptr<storage::Storage::Accessor> db_accessor_; std::unique_ptr<storage::Storage::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_; std::optional<DbAccessor> execution_db_accessor_;
std::optional<TriggerContextCollector> trigger_context_collector_; std::optional<TriggerContextCollector> trigger_context_collector_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
std::optional<storage::IsolationLevel> interpreter_isolation_level; std::optional<storage::IsolationLevel> interpreter_isolation_level;
std::optional<storage::IsolationLevel> next_transaction_isolation_level; std::optional<storage::IsolationLevel> next_transaction_isolation_level;
@ -365,12 +382,32 @@ class Interpreter final {
} }
}; };
class TransactionQueueQueryHandler {
public:
TransactionQueueQueryHandler() = default;
virtual ~TransactionQueueQueryHandler() = default;
TransactionQueueQueryHandler(const TransactionQueueQueryHandler &) = default;
TransactionQueueQueryHandler &operator=(const TransactionQueueQueryHandler &) = default;
TransactionQueueQueryHandler(TransactionQueueQueryHandler &&) = default;
TransactionQueueQueryHandler &operator=(TransactionQueueQueryHandler &&) = default;
static std::vector<std::vector<TypedValue>> ShowTransactions(const std::unordered_set<Interpreter *> &interpreters,
const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege);
static std::vector<std::vector<TypedValue>> KillTransactions(
InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids,
const std::optional<std::string> &username, bool hasTransactionManagementPrivilege);
};
template <typename TStream> template <typename TStream>
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n, std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
std::optional<int> qid) { std::optional<int> qid) {
MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!"); MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
if (qid_value < 0 || qid_value >= query_executions_.size()) { if (qid_value < 0 || qid_value >= query_executions_.size()) {
throw InvalidArgumentsException("qid", "Query with specified ID does not exist!"); throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
} }
@ -430,6 +467,7 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
// methods as we will delete summary contained in them which we need // methods as we will delete summary contained in them which we need
// after our query finished executing. // after our query finished executing.
query_executions_.clear(); query_executions_.clear();
transaction_queries_->clear();
} else { } else {
// We can only clear this execution as some of the queries // We can only clear this execution as some of the queries
// in the transaction can be in unfinished state // in the transaction can be in unfinished state

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -490,6 +490,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
retry_interval = interpreter_context_->config.stream_transaction_retry_interval]( retry_interval = interpreter_context_->config.stream_transaction_retry_interval](
const std::vector<typename TStream::Message> &messages) mutable { const std::vector<typename TStream::Message> &messages) mutable {
auto accessor = interpreter_context->db->Access(); auto accessor = interpreter_context->db->Access();
// register new interpreter into interpreter_context_
interpreter_context->interpreters->insert(interpreter.get());
utils::OnScopeExit interpreter_cleanup{
[interpreter_context, interpreter]() { interpreter_context->interpreters->erase(interpreter.get()); }};
EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size()); EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size());
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -195,7 +195,8 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory,
const double max_execution_time_sec, std::atomic<bool> *is_shutting_down, const double max_execution_time_sec, std::atomic<bool> *is_shutting_down,
const TriggerContext &context, const AuthChecker *auth_checker) const { std::atomic<TransactionStatus> *transaction_status, const TriggerContext &context,
const AuthChecker *auth_checker) const {
if (!context.ShouldEventTrigger(event_type_)) { if (!context.ShouldEventTrigger(event_type_)) {
return; return;
} }
@ -214,6 +215,7 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution
ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba); ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba);
ctx.timer = utils::AsyncTimer(max_execution_time_sec); ctx.timer = utils::AsyncTimer(max_execution_time_sec);
ctx.is_shutting_down = is_shutting_down; ctx.is_shutting_down = is_shutting_down;
ctx.transaction_status = transaction_status;
ctx.is_profile_query = false; ctx.is_profile_query = false;
// Set up temporary memory for a single Pull. Initial memory comes from the // Set up temporary memory for a single Pull. Initial memory comes from the

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -31,6 +31,8 @@
#include "utils/spin_lock.hpp" #include "utils/spin_lock.hpp"
namespace memgraph::query { namespace memgraph::query {
enum class TransactionStatus;
struct Trigger { struct Trigger {
explicit Trigger(std::string name, const std::string &query, explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type, const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
@ -39,8 +41,8 @@ struct Trigger {
const query::AuthChecker *auth_checker); const query::AuthChecker *auth_checker);
void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec,
std::atomic<bool> *is_shutting_down, const TriggerContext &context, std::atomic<bool> *is_shutting_down, std::atomic<TransactionStatus> *transaction_status,
const AuthChecker *auth_checker) const; const TriggerContext &context, const AuthChecker *auth_checker) const;
bool operator==(const Trigger &other) const { return name_ == other.name_; } bool operator==(const Trigger &other) const { return name_ == other.name_; }
// NOLINTNEXTLINE (modernize-use-nullptr) // NOLINTNEXTLINE (modernize-use-nullptr)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,7 @@
#include "storage/v2/constraints.hpp" #include "storage/v2/constraints.hpp"
#include <algorithm> #include <algorithm>
#include <atomic>
#include <cstring> #include <cstring>
#include <map> #include <map>
@ -71,7 +72,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c
while (delta != nullptr) { while (delta != nullptr) {
auto ts = delta->timestamp->load(std::memory_order_acquire); auto ts = delta->timestamp->load(std::memory_order_acquire);
if (ts < commit_timestamp || ts == transaction.transaction_id) { if (ts < commit_timestamp || ts == transaction.transaction_id.load(std::memory_order_acquire)) {
break; break;
} }

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,6 +11,7 @@
#pragma once #pragma once
#include <atomic>
#include "storage/v2/property_value.hpp" #include "storage/v2/property_value.hpp"
#include "storage/v2/transaction.hpp" #include "storage/v2/transaction.hpp"
#include "storage/v2/view.hpp" #include "storage/v2/view.hpp"
@ -30,7 +31,7 @@ inline void ApplyDeltasForRead(Transaction *transaction, const Delta *delta, Vie
// This allows the transaction to see its changes even though it's committed. // This allows the transaction to see its changes even though it's committed.
const auto commit_timestamp = transaction->commit_timestamp const auto commit_timestamp = transaction->commit_timestamp
? transaction->commit_timestamp->load(std::memory_order_acquire) ? transaction->commit_timestamp->load(std::memory_order_acquire)
: transaction->transaction_id; : transaction->transaction_id.load(std::memory_order_acquire);
while (delta != nullptr) { while (delta != nullptr) {
auto ts = delta->timestamp->load(std::memory_order_acquire); auto ts = delta->timestamp->load(std::memory_order_acquire);
auto cid = delta->command_id; auto cid = delta->command_id;
@ -80,7 +81,7 @@ inline bool PrepareForWrite(Transaction *transaction, TObj *object) {
if (object->delta == nullptr) return true; if (object->delta == nullptr) return true;
auto ts = object->delta->timestamp->load(std::memory_order_acquire); auto ts = object->delta->timestamp->load(std::memory_order_acquire);
if (ts == transaction->transaction_id || ts < transaction->start_timestamp) { if (ts == transaction->transaction_id.load(std::memory_order_acquire) || ts < transaction->start_timestamp) {
return true; return true;
} }

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -985,8 +985,8 @@ void Storage::Accessor::Abort() {
auto vertex = prev.vertex; auto vertex = prev.vertex;
std::lock_guard<utils::SpinLock> guard(vertex->lock); std::lock_guard<utils::SpinLock> guard(vertex->lock);
Delta *current = vertex->delta; Delta *current = vertex->delta;
while (current != nullptr && while (current != nullptr && current->timestamp->load(std::memory_order_acquire) ==
current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { transaction_.transaction_id.load(std::memory_order_acquire)) {
switch (current->action) { switch (current->action) {
case Delta::Action::REMOVE_LABEL: { case Delta::Action::REMOVE_LABEL: {
auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label); auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label);
@ -1072,8 +1072,8 @@ void Storage::Accessor::Abort() {
auto edge = prev.edge; auto edge = prev.edge;
std::lock_guard<utils::SpinLock> guard(edge->lock); std::lock_guard<utils::SpinLock> guard(edge->lock);
Delta *current = edge->delta; Delta *current = edge->delta;
while (current != nullptr && while (current != nullptr && current->timestamp->load(std::memory_order_acquire) ==
current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { transaction_.transaction_id.load(std::memory_order_acquire)) {
switch (current->action) { switch (current->action) {
case Delta::Action::SET_PROPERTY: { case Delta::Action::SET_PROPERTY: {
edge->properties.SetProperty(current->property.key, current->property.value); edge->properties.SetProperty(current->property.key, current->property.value);
@ -1144,6 +1144,13 @@ void Storage::Accessor::FinalizeTransaction() {
} }
} }
std::optional<uint64_t> Storage::Accessor::GetTransactionId() const {
if (is_transaction_active_) {
return transaction_.transaction_id.load(std::memory_order_acquire);
}
return {};
}
const std::string &Storage::LabelToName(LabelId label) const { return name_id_mapper_.IdToName(label.AsUint()); } const std::string &Storage::LabelToName(LabelId label) const { return name_id_mapper_.IdToName(label.AsUint()); }
const std::string &Storage::PropertyToName(PropertyId property) const { const std::string &Storage::PropertyToName(PropertyId property) const {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,7 @@
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <cstdint>
#include <filesystem> #include <filesystem>
#include <optional> #include <optional>
#include <shared_mutex> #include <shared_mutex>
@ -324,6 +325,8 @@ class Storage final {
void FinalizeTransaction(); void FinalizeTransaction();
std::optional<uint64_t> GetTransactionId() const;
private: private:
/// @throw std::bad_alloc /// @throw std::bad_alloc
VertexAccessor CreateVertex(storage::Gid gid); VertexAccessor CreateVertex(storage::Gid gid);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,7 @@ struct Transaction {
isolation_level(isolation_level) {} isolation_level(isolation_level) {}
Transaction(Transaction &&other) noexcept Transaction(Transaction &&other) noexcept
: transaction_id(other.transaction_id), : transaction_id(other.transaction_id.load(std::memory_order_acquire)),
start_timestamp(other.start_timestamp), start_timestamp(other.start_timestamp),
commit_timestamp(std::move(other.commit_timestamp)), commit_timestamp(std::move(other.commit_timestamp)),
command_id(other.command_id), command_id(other.command_id),
@ -56,10 +56,10 @@ struct Transaction {
/// @throw std::bad_alloc if failed to create the `commit_timestamp` /// @throw std::bad_alloc if failed to create the `commit_timestamp`
void EnsureCommitTimestampExists() { void EnsureCommitTimestampExists() {
if (commit_timestamp != nullptr) return; if (commit_timestamp != nullptr) return;
commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id); commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id.load(std::memory_order_relaxed));
} }
uint64_t transaction_id; std::atomic<uint64_t> transaction_id;
uint64_t start_timestamp; uint64_t start_timestamp;
// The `Transaction` object is stack allocated, but the `commit_timestamp` // The `Transaction` object is stack allocated, but the `commit_timestamp`
// must be heap allocated because `Delta`s have a pointer to it, and that // must be heap allocated because `Delta`s have a pointer to it, and that
@ -73,12 +73,16 @@ struct Transaction {
}; };
inline bool operator==(const Transaction &first, const Transaction &second) { inline bool operator==(const Transaction &first, const Transaction &second) {
return first.transaction_id == second.transaction_id; return first.transaction_id.load(std::memory_order_acquire) == second.transaction_id.load(std::memory_order_acquire);
} }
inline bool operator<(const Transaction &first, const Transaction &second) { inline bool operator<(const Transaction &first, const Transaction &second) {
return first.transaction_id < second.transaction_id; return first.transaction_id.load(std::memory_order_acquire) < second.transaction_id.load(std::memory_order_acquire);
}
inline bool operator==(const Transaction &first, const uint64_t &second) {
return first.transaction_id.load(std::memory_order_acquire) == second;
}
inline bool operator<(const Transaction &first, const uint64_t &second) {
return first.transaction_id.load(std::memory_order_acquire) < second;
} }
inline bool operator==(const Transaction &first, const uint64_t &second) { return first.transaction_id == second; }
inline bool operator<(const Transaction &first, const uint64_t &second) { return first.transaction_id < second; }
} // namespace memgraph::storage } // namespace memgraph::storage

View File

@ -176,8 +176,8 @@ enum class TypeId : uint64_t {
AST_VERSION_QUERY, AST_VERSION_QUERY,
AST_FOREACH, AST_FOREACH,
AST_SHOW_CONFIG_QUERY, AST_SHOW_CONFIG_QUERY,
AST_TRANSACTION_QUEUE_QUERY,
AST_EXISTS, AST_EXISTS,
// Symbol // Symbol
SYMBOL, SYMBOL,
}; };

View File

@ -44,6 +44,7 @@ add_subdirectory(module_file_manager)
add_subdirectory(monitoring_server) add_subdirectory(monitoring_server)
add_subdirectory(lba_procedures) add_subdirectory(lba_procedures)
add_subdirectory(python_query_modules_reloading) add_subdirectory(python_query_modules_reloading)
add_subdirectory(transaction_queue)
add_subdirectory(mock_api) add_subdirectory(mock_api)
copy_e2e_python_files(pytest_runner pytest_runner.sh "") copy_e2e_python_files(pytest_runner pytest_runner.sh "")

View File

@ -10,6 +10,7 @@
# licenses/APL.txt. # licenses/APL.txt.
import sys import sys
import pytest import pytest
from common import connect, execute_and_fetch_all from common import connect, execute_and_fetch_all
@ -35,6 +36,7 @@ BASIC_PRIVILEGES = [
"MODULE_READ", "MODULE_READ",
"WEBSOCKET", "WEBSOCKET",
"MODULE_WRITE", "MODULE_WRITE",
"TRANSACTION_MANAGEMENT",
] ]
@ -58,7 +60,7 @@ def test_lba_procedures_show_privileges_first_user():
cursor = connect(username="Josip", password="").cursor() cursor = connect(username="Josip", password="").cursor()
result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;") result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;")
assert len(result) == 30 assert len(result) == 31
fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES] fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES]

View File

@ -0,0 +1,8 @@
function(copy_query_modules_reloading_procedures_e2e_python_files FILE_NAME)
copy_e2e_python_files(transaction_queue ${FILE_NAME})
endfunction()
copy_query_modules_reloading_procedures_e2e_python_files(common.py)
copy_query_modules_reloading_procedures_e2e_python_files(test_transaction_queue.py)
add_subdirectory(procedures)

View File

@ -0,0 +1,26 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import pytest
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
cursor.execute(query, params)
return cursor.fetchall()
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
return connection

View File

@ -0,0 +1 @@
copy_e2e_python_files(transaction_queue infinite_query.py)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import threading
import time
import mgp
@mgp.read_proc
def long_query(ctx: mgp.ProcCtx) -> mgp.Record(my_id=int):
id = 1
try:
while True:
if ctx.check_must_abort():
break
id += 1
except mgp.AbortError:
return mgp.Record(my_id=id)

View File

@ -0,0 +1,338 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import multiprocessing
import sys
import threading
import time
from typing import List
import mgclient
import pytest
from common import connect, execute_and_fetch_all
# Utility functions
# -------------------------
def get_non_show_transaction_id(results):
"""Returns transaction id of the first transaction that is not SHOW TRANSACTIONS;"""
for res in results:
if res[2] != ["SHOW TRANSACTIONS"]:
return res[1]
def show_transactions_test(cursor, expected_num_results: int):
results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS")
assert len(results) == expected_num_results
return results
def process_function(cursor, queries: List[str]):
try:
for query in queries:
cursor.execute(query, {})
except mgclient.DatabaseError:
pass
# Tests
# -------------------------
def test_self_transaction():
"""Tests that simple show transactions work when no other is running."""
cursor = connect().cursor()
results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS")
assert len(results) == 1
def test_admin_has_one_transaction():
"""Creates admin and tests that he sees only one transaction."""
# a_cursor is used for creating admin user, simulates main thread
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
admin_cursor = connect(username="admin", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1))
process.start()
process.join()
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_user_can_see_its_transaction():
"""Tests that user without privileges can see its own transaction"""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user")
user_cursor = connect(username="user", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1))
process.start()
process.join()
admin_cursor = connect(username="admin", password="").cursor()
execute_and_fetch_all(admin_cursor, "DROP USER user")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_explicit_transaction_output():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
# Admin starts running explicit transaction
process = multiprocessing.Process(
target=process_function,
args=(superadmin_cursor, ["BEGIN", "CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]),
)
process.start()
time.sleep(0.5)
show_results = show_transactions_test(admin_cursor, 2)
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]
execute_and_fetch_all(superadmin_cursor, "ROLLBACK")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_superadmin_cannot_see_admin_can_see_admin():
"""Tests that superadmin cannot see the transaction created by admin but two admins can see and kill each other's transactions."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
# Admin starts running infinite query
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
admin_connection_2 = connect(username="admin2", password="")
admin_cursor_2 = admin_connection_2.cursor()
process = multiprocessing.Process(
target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# Superadmin shouldn't see the execution of the admin
show_transactions_test(superadmin_cursor, 1)
show_results = show_transactions_test(admin_cursor_2, 2)
# Don't rely on the order of intepreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin2"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "admin1"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
admin_connection_1.close()
admin_connection_2.close()
def test_admin_sees_superadmin():
"""Tests that admin created by superadmin can see the superadmin's transaction."""
superadmin_connection = connect()
superadmin_cursor = superadmin_connection.cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
# Admin starts running infinite query
process = multiprocessing.Process(
target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
admin_cursor = connect(username="admin", password="").cursor()
show_results = show_transactions_test(admin_cursor, 2)
# show_results_2 = show_transactions_test(admin_cursor, 2)
# Don't rely on the order of intepreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == ""
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
superadmin_connection.close()
def test_admin_can_see_user_transaction():
"""Tests that admin can see user's transaction and kill it."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
# Admin starts running infinite query
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
user_connection = connect(username="user", password="")
user_cursor = user_connection.cursor()
process = multiprocessing.Process(
target=process_function, args=(user_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# Admin should see the user's transaction.
show_results = show_transactions_test(admin_cursor, 2)
# Don't rely on the order of intepreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "user"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection.close()
user_connection.close()
def test_user_cannot_see_admin_transaction():
"""User cannot see admin's transaction but other admin can and he can kill it."""
# Superadmin creates two admins and one user
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
admin_connection_2 = connect(username="admin2", password="")
admin_cursor_2 = admin_connection_2.cursor()
user_connection = connect(username="user", password="")
user_cursor = user_connection.cursor()
# Admin1 starts running long running query
process = multiprocessing.Process(
target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# User should not see the admin's transaction.
show_transactions_test(user_cursor, 1)
# Second admin should see other admin's transactions
show_results = show_transactions_test(admin_cursor_2, 2)
# Don't rely on the order of intepreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin2"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "admin1"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection_1.close()
admin_connection_2.close()
user_connection.close()
def test_killing_non_existing_transaction():
cursor = connect().cursor()
results = execute_and_fetch_all(cursor, "TERMINATE TRANSACTIONS '1'")
assert len(results) == 1
assert results[0][0] == "1" # transaction id
assert results[0][1] == False # not killed
def test_killing_multiple_non_existing_transactions():
cursor = connect().cursor()
transactions_id = ["'1'", "'2'", "'3'"]
results = execute_and_fetch_all(cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}")
assert len(results) == 3
for i in range(len(results)):
assert results[i][0] == eval(transactions_id[i]) # transaction id
assert results[i][1] == False # not killed
def test_admin_killing_multiple_non_existing_transactions():
# Starting, superadmin admin
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
# Connect with admin
admin_cursor = connect(username="admin", password="").cursor()
transactions_id = ["'1'", "'2'", "'3'"]
results = execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}")
assert len(results) == 3
for i in range(len(results)):
assert results[i][0] == eval(transactions_id[i]) # transaction id
assert results[i][1] == False # not killed
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_user_killing_some_transactions():
"""Tests what happens when user can kill only some of the transactions given."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user1")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1")
# Connect with user in two different sessions
admin_cursor = connect(username="admin", password="").cursor()
execute_and_fetch_all(admin_cursor, "CREATE USER user2")
execute_and_fetch_all(admin_cursor, "GRANT ALL PRIVILEGES TO user2")
user_connection_1 = connect(username="user1", password="")
user_cursor_1 = user_connection_1.cursor()
user_connection_2 = connect(username="user2", password="")
user_cursor_2 = user_connection_2.cursor()
process_1 = multiprocessing.Process(
target=process_function, args=(user_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process_2 = multiprocessing.Process(target=process_function, args=(user_cursor_2, ["BEGIN", "MATCH (n) RETURN n"]))
process_1.start()
process_2.start()
# Create another user1 connections
user_connection_1_copy = connect(username="user1", password="")
user_cursor_1_copy = user_connection_1_copy.cursor()
show_user_1_results = show_transactions_test(user_cursor_1_copy, 2)
if show_user_1_results[0][2] == ["SHOW TRANSACTIONS"]:
execution_index = 0
else:
execution_index = 1
assert show_user_1_results[1 - execution_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Connect with admin
time.sleep(0.5)
show_admin_results = show_transactions_test(admin_cursor, 3)
for show_admin_res in show_admin_results:
if show_admin_res[2] != "[SHOW TRANSACTIONS]":
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{show_admin_res[1]}'")
user_connection_1.close()
user_connection_2.close()
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -0,0 +1,14 @@
test_transaction_queue: &test_transaction_queue
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE", "--also-log-to-stderr"]
log_file: "transaction_queue.log"
setup_queries: []
validation_queries: []
workloads:
- name: "test-transaction-queue" # should be the same as the python file
binary: "tests/e2e/pytest_runner.sh"
proc: "tests/e2e/transaction_queue/procedures/"
args: ["transaction_queue/test_transaction_queue.py"]
<<: *test_transaction_queue

View File

@ -130,6 +130,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
add_unit_test(query_streams.cpp) add_unit_test(query_streams.cpp)
target_link_libraries(${test_prefix}query_streams mg-query kafka-mock) target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
add_unit_test(transaction_queue.cpp)
target_link_libraries(${test_prefix}transaction_queue mg-communication mg-query mg-glue)
add_unit_test(transaction_queue_multiple.cpp)
target_link_libraries(${test_prefix}transaction_queue_multiple mg-communication mg-query mg-glue)
# Test query functions # Test query functions
add_unit_test(query_function_mgp_module.cpp) add_unit_test(query_function_mgp_module.cpp)
target_link_libraries(${test_prefix}query_function_mgp_module mg-query) target_link_libraries(${test_prefix}query_function_mgp_module mg-query)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@
#include "glue/communication.hpp" #include "glue/communication.hpp"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "interpreter_faker.hpp"
#include "query/auth_checker.hpp" #include "query/auth_checker.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/exceptions.hpp" #include "query/exceptions.hpp"
@ -40,57 +41,18 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) {
return list; return list;
}; };
struct InterpreterFaker {
InterpreterFaker(memgraph::storage::Storage *db, const memgraph::query::InterpreterConfig config,
const std::filesystem::path &data_directory)
: interpreter_context(db, config, data_directory), interpreter(&interpreter_context) {
interpreter_context.auth_checker = &auth_checker;
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context.db);
const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);
}
void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) {
const auto summary = interpreter.Pull(stream, n, qid);
stream->Summary(summary);
}
/**
* Execute the given query and commit the transaction.
*
* Return the query stream.
*/
auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
auto prepare_result = Prepare(query, params);
auto &stream = prepare_result.first;
auto summary = interpreter.Pull(&stream, {}, prepare_result.second);
stream.Summary(summary);
return std::move(stream);
}
memgraph::query::AllowEverythingAuthChecker auth_checker;
memgraph::query::InterpreterContext interpreter_context;
memgraph::query::Interpreter interpreter;
};
} // namespace } // namespace
// TODO: This is not a unit test, but tests/integration dir is chaotic at the // TODO: This is not a unit test, but tests/integration dir is chaotic at the
// moment. After tests refactoring is done, move/rename this. // moment. After tests refactoring is done, move/rename this.
class InterpreterTest : public ::testing::Test { class InterpreterTest : public ::testing::Test {
protected: public:
memgraph::storage::Storage db_; memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"}; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker default_interpreter{&db_, {}, data_directory}; InterpreterFaker default_interpreter{&interpreter_context};
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
return default_interpreter.Prepare(query, params); return default_interpreter.Prepare(query, params);
@ -638,8 +600,6 @@ TEST_F(InterpreterTest, UniqueConstraintTest) {
} }
TEST_F(InterpreterTest, ExplainQuery) { TEST_F(InterpreterTest, ExplainQuery) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;");
@ -663,8 +623,6 @@ TEST_F(InterpreterTest, ExplainQuery) {
} }
TEST_F(InterpreterTest, ExplainQueryMultiplePulls) { TEST_F(InterpreterTest, ExplainQueryMultiplePulls) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;"); auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;");
@ -698,8 +656,6 @@ TEST_F(InterpreterTest, ExplainQueryMultiplePulls) {
} }
TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) { TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
Interpret("BEGIN"); Interpret("BEGIN");
@ -725,8 +681,6 @@ TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) {
} }
TEST_F(InterpreterTest, ExplainQueryWithParams) { TEST_F(InterpreterTest, ExplainQueryWithParams) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = auto stream =
@ -751,8 +705,6 @@ TEST_F(InterpreterTest, ExplainQueryWithParams) {
} }
TEST_F(InterpreterTest, ProfileQuery) { TEST_F(InterpreterTest, ProfileQuery) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("PROFILE MATCH (n) RETURN *;"); auto stream = Interpret("PROFILE MATCH (n) RETURN *;");
@ -776,8 +728,6 @@ TEST_F(InterpreterTest, ProfileQuery) {
} }
TEST_F(InterpreterTest, ProfileQueryMultiplePulls) { TEST_F(InterpreterTest, ProfileQueryMultiplePulls) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;"); auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;");
@ -820,8 +770,6 @@ TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) {
} }
TEST_F(InterpreterTest, ProfileQueryWithParams) { TEST_F(InterpreterTest, ProfileQueryWithParams) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = auto stream =
@ -846,8 +794,6 @@ TEST_F(InterpreterTest, ProfileQueryWithParams) {
} }
TEST_F(InterpreterTest, ProfileQueryWithLiterals) { TEST_F(InterpreterTest, ProfileQueryWithLiterals) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {}); auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {});
@ -1087,7 +1033,6 @@ TEST_F(InterpreterTest, LoadCsvClause) {
} }
TEST_F(InterpreterTest, CacheableQueries) { TEST_F(InterpreterTest, CacheableQueries) {
const auto &interpreter_context = default_interpreter.interpreter_context;
// This should be cached // This should be cached
{ {
SCOPED_TRACE("Cacheable query"); SCOPED_TRACE("Cacheable query");
@ -1120,7 +1065,9 @@ TEST_F(InterpreterTest, AllowLoadCsvConfig) {
"CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN " "CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN "
"row"}; "row"};
InterpreterFaker interpreter_faker{&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()}; memgraph::query::InterpreterContext csv_interpreter_context{
&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()};
InterpreterFaker interpreter_faker{&csv_interpreter_context};
for (const auto &query : queries) { for (const auto &query : queries) {
if (allow_load_csv) { if (allow_load_csv) {
SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query)); SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query));

View File

@ -0,0 +1,49 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "communication/result_stream_faker.hpp"
#include "query/interpreter.hpp"
struct InterpreterFaker {
InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context)
: interpreter_context(interpreter_context), interpreter(interpreter_context) {
interpreter_context->auth_checker = &auth_checker;
interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); });
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context->db);
const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);
}
void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) {
const auto summary = interpreter.Pull(stream, n, qid);
stream->Summary(summary);
}
/**
* Execute the given query and commit the transaction.
*
* Return the query stream.
*/
auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
auto prepare_result = Prepare(query, params);
auto &stream = prepare_result.first;
auto summary = interpreter.Pull(&stream, {}, prepare_result.second);
stream.Summary(summary);
return std::move(stream);
}
memgraph::query::AllowEverythingAuthChecker auth_checker;
memgraph::query::InterpreterContext *interpreter_context;
memgraph::query::Interpreter interpreter;
};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,6 @@ class MockAuthChecker : public memgraph::query::AuthChecker {
public: public:
MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username, MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges)); const std::vector<memgraph::query::AuthQuery::Privilege> &privileges));
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, MOCK_CONST_METHOD2(GetFineGrainedAuthChecker,
std::unique_ptr<memgraph::query::FineGrainedAuthChecker>( std::unique_ptr<memgraph::query::FineGrainedAuthChecker>(

View File

@ -0,0 +1,75 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <stop_token>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include "gmock/gmock.h"
#include "interpreter_faker.hpp"
/*
Tests rely on the fact that interpreters are sequentially added to runninng_interpreters to get transaction_id of its
corresponding interpreter/.
*/
class TransactionQueueSimpleTest : public ::testing::Test {
protected:
memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_intr"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context};
};
TEST_F(TransactionQueueSimpleTest, TwoInterpretersInterleaving) {
bool started = false;
std::jthread running_thread = std::jthread(
[this, &started](std::stop_token st, int thread_index) {
running_interpreter.Interpret("BEGIN");
started = true;
},
0);
{
while (!started) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
main_interpreter.Interpret("CREATE (:Person {prop: 1})");
auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream.GetResults().size(), 2U);
// superadmin executing the transaction
EXPECT_EQ(show_stream.GetResults()[0][0].ValueString(), "");
ASSERT_TRUE(show_stream.GetResults()[0][1].IsString());
EXPECT_EQ(show_stream.GetResults()[0][2].ValueList().at(0).ValueString(), "SHOW TRANSACTIONS");
// Also anonymous user executing
EXPECT_EQ(show_stream.GetResults()[1][0].ValueString(), "");
ASSERT_TRUE(show_stream.GetResults()[1][1].IsString());
// Kill the other transaction
std::string run_trans_id = show_stream.GetResults()[1][1].ValueString();
std::string esc_run_trans_id = "'" + run_trans_id + "'";
auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id);
// check result of killing
ASSERT_EQ(terminate_stream.GetResults().size(), 1U);
EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id);
ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually killed
// check the number of transactions now
auto show_stream_after_killing = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream_after_killing.GetResults().size(), 1U);
// test the state of the database
auto results_stream = main_interpreter.Interpret("MATCH (n) RETURN n");
ASSERT_EQ(results_stream.GetResults().size(), 1U); // from the main interpreter
main_interpreter.Interpret("MATCH (n) DETACH DELETE n");
// finish thread
running_thread.request_stop();
}
}

View File

@ -0,0 +1,118 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <random>
#include <stop_token>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include "gmock/gmock.h"
#include "spdlog/spdlog.h"
#include "interpreter_faker.hpp"
#include "query/exceptions.hpp"
constexpr int NUM_INTERPRETERS = 4, INSERTIONS = 4000;
/*
Tests rely on the fact that interpreters are sequentially added to running_interpreters to get transaction_id of its
corresponding interpreter.
*/
class TransactionQueueMultipleTest : public ::testing::Test {
protected:
memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() /
"MG_tests_unit_transaction_queue_multiple_intr"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker main_interpreter{&interpreter_context};
std::vector<InterpreterFaker *> running_interpreters;
TransactionQueueMultipleTest() {
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
InterpreterFaker *faker = new InterpreterFaker(&interpreter_context);
running_interpreters.push_back(faker);
}
}
~TransactionQueueMultipleTest() override {
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
delete running_interpreters[i];
}
}
};
// Tests whether admin can see transaction of superadmin
TEST_F(TransactionQueueMultipleTest, TerminateTransaction) {
std::vector<bool> started(NUM_INTERPRETERS, false);
auto thread_func = [this, &started](int thread_index) {
try {
running_interpreters[thread_index]->Interpret("BEGIN");
started[thread_index] = true;
// add try-catch block
for (int j = 0; j < INSERTIONS; ++j) {
running_interpreters[thread_index]->Interpret("CREATE (:Person {prop: " + std::to_string(thread_index) + "})");
}
} catch (memgraph::query::HintedAbortError &e) {
}
};
{
std::vector<std::jthread> running_threads;
running_threads.reserve(NUM_INTERPRETERS);
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
running_threads.emplace_back(thread_func, i);
}
while (!std::all_of(started.begin(), started.end(), [](const bool v) { return v; })) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream.GetResults().size(), NUM_INTERPRETERS + 1);
// Choose random transaction to kill
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> distr(0, NUM_INTERPRETERS - 1);
int index_to_terminate = distr(gen);
// Kill random transaction
std::string run_trans_id =
std::to_string(running_interpreters[index_to_terminate]->interpreter.GetTransactionId().value());
std::string esc_run_trans_id = "'" + run_trans_id + "'";
auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id);
// check result of killing
ASSERT_EQ(terminate_stream.GetResults().size(), 1U);
EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id);
ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually killed
// test here show transactions
auto show_stream_after_kill = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream_after_kill.GetResults().size(), NUM_INTERPRETERS);
// wait to finish for threads
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
running_threads[i].join();
}
// test the state of the database
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
if (i != index_to_terminate) {
running_interpreters[i]->Interpret("COMMIT");
}
std::string fetch_query = "MATCH (n:Person) WHERE n.prop=" + std::to_string(i) + " RETURN n";
auto results_stream = main_interpreter.Interpret(fetch_query);
if (i == index_to_terminate) {
ASSERT_EQ(results_stream.GetResults().size(), 0);
} else {
ASSERT_EQ(results_stream.GetResults().size(), INSERTIONS);
}
}
main_interpreter.Interpret("MATCH (n) DETACH DELETE n");
}
}