From bce48361ca887d22763294e437d7bb1c742e9daa Mon Sep 17 00:00:00 2001 From: andrejtonev <29177572+andrejtonev@users.noreply.github.com> Date: Wed, 20 Sep 2023 13:13:54 +0200 Subject: [PATCH] Decoupling Interpreter from Storage (#1186) Unique/global InterpreterContext that is Storage agnostic (has a reference to the DbmsHandler instead) * InterpreterContext is no longer the owner of Storage * New Database structure that handles Storage, Triggers, Streams * Renamed SessinContextHandler to DbmsHandler and simplified the multi-tenant logic * Added Gatekeeper and updated handlers to use it --------- Co-authored-by: Gareth Lloyd --- .gitignore | 3 +- src/CMakeLists.txt | 1 + src/communication/bolt/v1/session.hpp | 6 +- src/communication/v2/session.hpp | 27 +- src/dbms/CMakeLists.txt | 3 + src/dbms/database.cpp | 34 + src/dbms/database.hpp | 148 ++++ src/dbms/database_handler.hpp | 97 +++ src/dbms/dbms_handler.hpp | 390 ++++++++++ src/dbms/global.hpp | 47 -- src/dbms/handler.hpp | 109 ++- src/dbms/interp_handler.hpp | 106 --- src/dbms/session_context.hpp | 61 -- src/dbms/session_context_handler.hpp | 603 ---------------- src/glue/CMakeLists.txt | 9 +- src/glue/MonitoringServerT.cpp | 3 +- src/glue/MonitoringServerT.hpp | 9 +- src/glue/ServerT.cpp | 6 +- src/glue/ServerT.hpp | 39 +- src/glue/SessionHL.cpp | 237 +++--- src/glue/SessionHL.hpp | 95 +-- src/glue/auth_handler.cpp | 8 + src/glue/auth_handler.hpp | 4 +- src/glue/run_id.cpp | 15 + src/glue/run_id.hpp | 16 + src/http_handlers/metrics.hpp | 16 +- src/memgraph.cpp | 174 ++--- src/query/CMakeLists.txt | 7 +- src/query/auth_checker.hpp | 56 +- src/query/auth_query_handler.cpp | 12 + src/query/auth_query_handler.hpp | 126 ++++ src/query/cypher_query_interpreter.cpp | 2 + src/query/cypher_query_interpreter.hpp | 2 - src/query/exceptions.hpp | 5 + src/query/interpreter.cpp | 839 +++++++++++----------- src/query/interpreter.hpp | 174 +---- src/query/interpreter_context.cpp | 73 ++ src/query/interpreter_context.hpp | 85 +++ src/query/procedure/module.cpp | 1 + src/query/stream/common.cpp | 2 +- src/query/stream/common.hpp | 2 +- src/query/stream/sources.cpp | 10 +- src/query/stream/sources.hpp | 6 +- src/query/stream/streams.cpp | 74 +- src/query/stream/streams.hpp | 19 +- src/storage/v2/config.hpp | 1 + src/utils/gatekeeper.hpp | 213 ++++++ src/utils/sync_ptr.hpp | 189 ----- src/utils/tsc.cpp | 10 +- src/utils/tsc.hpp | 4 +- tests/benchmark/expansion.cpp | 24 +- tests/manual/single_query.cpp | 13 +- tests/unit/CMakeLists.txt | 19 +- tests/unit/bolt_session.cpp | 7 - tests/unit/dbms_database.cpp | 224 ++++++ tests/unit/dbms_handler.cpp | 193 +++++ tests/unit/dbms_interp.cpp | 431 ----------- tests/unit/dbms_sc_handler.cpp | 343 --------- tests/unit/interpreter.cpp | 119 +-- tests/unit/interpreter_faker.hpp | 7 +- tests/unit/query_dump.cpp | 271 ++++--- tests/unit/query_plan_edge_cases.cpp | 36 +- tests/unit/query_streams.cpp | 65 +- tests/unit/storage_v2_storage_mode.cpp | 39 +- tests/unit/transaction_queue.cpp | 36 +- tests/unit/transaction_queue_multiple.cpp | 35 +- tests/unit/utils_sync_ptr.cpp | 296 -------- 67 files changed, 2947 insertions(+), 3389 deletions(-) create mode 100644 src/dbms/CMakeLists.txt create mode 100644 src/dbms/database.cpp create mode 100644 src/dbms/database.hpp create mode 100644 src/dbms/database_handler.hpp create mode 100644 src/dbms/dbms_handler.hpp delete mode 100644 src/dbms/interp_handler.hpp delete mode 100644 src/dbms/session_context.hpp delete mode 100644 src/dbms/session_context_handler.hpp create mode 100644 src/glue/run_id.cpp create mode 100644 src/glue/run_id.hpp create mode 100644 src/query/auth_query_handler.cpp create mode 100644 src/query/auth_query_handler.hpp create mode 100644 src/query/interpreter_context.cpp create mode 100644 src/query/interpreter_context.hpp create mode 100644 src/utils/gatekeeper.hpp delete mode 100644 src/utils/sync_ptr.hpp create mode 100644 tests/unit/dbms_database.cpp create mode 100644 tests/unit/dbms_handler.cpp delete mode 100644 tests/unit/dbms_interp.cpp delete mode 100644 tests/unit/dbms_sc_handler.cpp delete mode 100644 tests/unit/utils_sync_ptr.cpp diff --git a/.gitignore b/.gitignore index 754661364..1534b2640 100644 --- a/.gitignore +++ b/.gitignore @@ -16,8 +16,7 @@ .ycm_extra_conf.pyc .temp/ Testing/ -build -build/ +/build*/ release/examples/build cmake-build-* cmake/DownloadProject/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 833072ec5..a3a53fcc2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,7 @@ add_subdirectory(rpc) add_subdirectory(license) add_subdirectory(auth) add_subdirectory(audit) +add_subdirectory(dbms) add_subdirectory(flags) string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index ab594722c..bcceddf45 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -54,7 +54,7 @@ class SessionException : public utils::BasicException { * @tparam TOutputStream type of output stream that will be used */ template -class Session : public dbms::SessionInterface { +class Session { public: using TEncoder = Encoder>; @@ -208,8 +208,8 @@ class Session : public dbms::SessionInterface { Version version_; - std::string GetDatabaseName() const override = 0; - std::string UUID() const final { return session_uuid_; } + virtual std::string GetDatabaseName() const = 0; + std::string UUID() const { return session_uuid_; } private: void ClientFailureInvalidData() { diff --git a/src/communication/v2/session.hpp b/src/communication/v2/session.hpp index f74740197..37e55e112 100644 --- a/src/communication/v2/session.hpp +++ b/src/communication/v2/session.hpp @@ -110,11 +110,7 @@ class WebsocketSession : public std::enable_shared_from_this(new WebsocketSession(std::forward(args)...)); } -#ifdef MG_ENTERPRISE - ~WebsocketSession() { session_context_->Delete(session_); } -#else ~WebsocketSession() = default; -#endif WebsocketSession(const WebsocketSession &) = delete; WebsocketSession &operator=(const WebsocketSession &) = delete; @@ -171,14 +167,15 @@ class WebsocketSession : public std::enable_shared_from_thisic, endpoint, input_buffer_.read_end(), &output_stream_, session_context->auth, +#ifdef MG_ENTERPRISE + session_context->audit_log +#endif + }, session_context_{session_context}, endpoint_{endpoint}, remote_endpoint_{ws_.next_layer().socket().remote_endpoint()}, service_name_{service_name} { -#ifdef MG_ENTERPRISE - session_context_->Register(session_); -#endif } void OnAccept(boost::beast::error_code ec) { @@ -286,11 +283,7 @@ class Session final : public std::enable_shared_from_this(new Session(std::forward(args)...)); } -#ifdef MG_ENTERPRISE - ~Session() { session_context_->Delete(session_); } -#else ~Session() = default; -#endif Session(const Session &) = delete; Session(Session &&) = delete; @@ -366,17 +359,17 @@ class Session final : public std::enable_shared_from_thisic, endpoint, input_buffer_.read_end(), &output_stream_, session_context->auth, +#ifdef MG_ENTERPRISE + session_context->audit_log +#endif + }, session_context_{session_context}, endpoint_{endpoint}, remote_endpoint_{GetRemoteEndpoint()}, service_name_{service_name}, timeout_seconds_(inactivity_timeout_sec), timeout_timer_(GetExecutor()) { -#ifdef MG_ENTERPRISE - // TODO Try to remove Register (see comment at SessionInterface declaration) - session_context_->Register(session_); -#endif ExecuteForSocket([](auto &&socket) { socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE diff --git a/src/dbms/CMakeLists.txt b/src/dbms/CMakeLists.txt new file mode 100644 index 000000000..8796790e4 --- /dev/null +++ b/src/dbms/CMakeLists.txt @@ -0,0 +1,3 @@ + +add_library(mg-dbms STATIC database.cpp) +target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query) diff --git a/src/dbms/database.cpp b/src/dbms/database.cpp new file mode 100644 index 000000000..89c3d2bf5 --- /dev/null +++ b/src/dbms/database.cpp @@ -0,0 +1,34 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "dbms/database.hpp" +#include "storage/v2/disk/storage.hpp" +#include "storage/v2/inmemory/storage.hpp" + +template struct memgraph::utils::Gatekeeper; + +namespace memgraph::dbms { + +Database::Database(const storage::Config &config) + : trigger_store_(config.durability.storage_directory / "triggers"), + streams_{config.durability.storage_directory / "streams"} { + if (config.force_on_disk || utils::DirExists(config.disk.main_storage_directory)) { + storage_ = std::make_unique(config); + } else { + storage_ = std::make_unique(config); + } +} + +void Database::SwitchToOnDisk() { + storage_ = std::make_unique(std::move(storage_->config_)); +} + +} // namespace memgraph::dbms diff --git a/src/dbms/database.hpp b/src/dbms/database.hpp new file mode 100644 index 000000000..a95cf6d91 --- /dev/null +++ b/src/dbms/database.hpp @@ -0,0 +1,148 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "query/cypher_query_interpreter.hpp" +#include "query/stream/streams.hpp" +#include "query/trigger.hpp" +#include "storage/v2/storage.hpp" +#include "utils/gatekeeper.hpp" + +namespace memgraph::dbms { + +/** + * @brief Class containing everything associated with a single Database + * + */ +class Database { + public: + /** + * @brief Construct a new Database object + * + * @param config storage configuration + */ + explicit Database(const storage::Config &config); + + /** + * @brief Returns the raw storage pointer. + * @note Ideally everybody would be using an accessor + * TODO: Remove + * + * @return storage::Storage* + */ + storage::Storage *storage() { return storage_.get(); } + + /** + * @brief Storage's Accessor + * + * @param override_isolation_level + * @return std::unique_ptr + */ + std::unique_ptr Access( + std::optional override_isolation_level = {}) { + return storage_->Access(override_isolation_level); + } + + /** + * @brief Unique storage identified (name) + * + * @return const std::string& + */ + const std::string &id() const { return storage_->id(); } + + /** + * @brief Returns the storage configuration + * + * @return const storage::Config& + */ + const storage::Config &config() const { return storage_->config_; } + + /** + * @brief Get the storage mode + * + * @return storage::StorageMode + */ + storage::StorageMode GetStorageMode() const { return storage_->GetStorageMode(); } + + /** + * @brief Get the storage info + * + * @return storage::StorageInfo + */ + storage::StorageInfo GetInfo() const { return storage_->GetInfo(); } + + /** + * @brief Switch storage to OnDisk + * + */ + void SwitchToOnDisk(); + + /** + * @brief Returns the raw TriggerStore pointer + * + * @return query::TriggerStore* + */ + query::TriggerStore *trigger_store() { return &trigger_store_; } + + /** + * @brief Returns the raw Streams pointer + * + * @return query::stream::Streams* + */ + query::stream::Streams *streams() { return &streams_; } + + /** + * @brief Returns the raw ThreadPool pointer (used for after commit triggers) + * + * @return utils::ThreadPool* + */ + utils::ThreadPool *thread_pool() { return &after_commit_trigger_pool_; } + + /** + * @brief Add task to the after commit trigger thread pool + * + * @param new_task + */ + void AddTask(std::function new_task) { after_commit_trigger_pool_.AddTask(std::move(new_task)); } + + /** + * @brief Returns the PlanCache vector raw pointer + * + * @return utils::SkipList* + */ + utils::SkipList *plan_cache() { return &plan_cache_; } + + private: + std::unique_ptr storage_; //!< Underlying storage + query::TriggerStore trigger_store_; //!< Triggers associated with the storage + utils::ThreadPool after_commit_trigger_pool_{1}; //!< Thread pool for executing after commit triggers + query::stream::Streams streams_; //!< Streams associated with the storage + + // TODO: Move to a better place + utils::SkipList plan_cache_; //!< Plan cache associated with the storage +}; + +} // namespace memgraph::dbms + +extern template struct memgraph::utils::Gatekeeper; + +namespace memgraph::dbms { +using DatabaseAccess = memgraph::utils::Gatekeeper::Accessor; +} // namespace memgraph::dbms diff --git a/src/dbms/database_handler.hpp b/src/dbms/database_handler.hpp new file mode 100644 index 000000000..4f142a341 --- /dev/null +++ b/src/dbms/database_handler.hpp @@ -0,0 +1,97 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#ifdef MG_ENTERPRISE + +#include +#include +#include +#include +#include +#include +#include + +#include "dbms/database.hpp" + +#include "handler.hpp" + +namespace memgraph::dbms { + +/* NOTE + * The Database object is shared. All the higher-level function calls should be protected. + * Storage function calls should already be protected; add protection where needed. + * + * Current implementation uses a handler of Database objects. It owns them and gives + * Gatekeeper::Accessor to it. These guarantee that the object won't be + * destroyed unless no one is using it. + */ + +/**Config + * @brief Multi-database storage handler + * + */ +class DatabaseHandler : public Handler { + public: + using HandlerT = Handler; + + /** + * @brief Generate new storage associated with the passed name. + * + * @param name Name associating the new interpreter context + * @param config Storage configuration + * @return HandlerT::NewResult + */ + HandlerT::NewResult New(std::string_view name, storage::Config config) { + // Control that no one is using the same data directory + if (std::any_of(begin(), end(), [&](auto &elem) { + auto db_acc = elem.second.access(); + MG_ASSERT(db_acc.has_value(), "Gatekeeper in invalid state"); + return db_acc->get()->config().durability.storage_directory == config.durability.storage_directory; + })) { + spdlog::info("Tried to generate new storage using a claimed directory."); + return NewError::EXISTS; + } + config.name = name; // Set storage id via config + return HandlerT::New(std::piecewise_construct, name, config); + } + + /** + * @brief All currently active storage. + * + * @return std::vector + */ + std::vector All() const { + std::vector res; + res.reserve(std::distance(cbegin(), cend())); + std::for_each(cbegin(), cend(), [&](const auto &elem) { res.push_back(elem.first); }); + return res; + } + + /** + * @brief Get the associated storage's configuration + * + * @param name + * @return std::optional + */ + std::optional GetConfig(std::string_view name) { + auto db = Get(name); + if (db) { + return (*db)->config(); + } + return std::nullopt; + } +}; + +} // namespace memgraph::dbms + +#endif diff --git a/src/dbms/dbms_handler.hpp b/src/dbms/dbms_handler.hpp new file mode 100644 index 000000000..a1aa51dae --- /dev/null +++ b/src/dbms/dbms_handler.hpp @@ -0,0 +1,390 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "auth/auth.hpp" +#include "constants.hpp" +#include "dbms/database_handler.hpp" +#include "global.hpp" +#include "query/config.hpp" +#include "query/interpreter_context.hpp" +#include "spdlog/spdlog.h" +#include "storage/v2/durability/durability.hpp" +#include "storage/v2/durability/paths.hpp" +#include "utils/exceptions.hpp" +#include "utils/file.hpp" +#include "utils/logging.hpp" +#include "utils/result.hpp" +#include "utils/rw_lock.hpp" +#include "utils/synchronized.hpp" +#include "utils/uuid.hpp" + +namespace memgraph::dbms { + +#ifdef MG_ENTERPRISE + +using DeleteResult = utils::BasicResult; + +/** + * @brief Multi-database session contexts handler. + */ +class DbmsHandler { + public: + using LockT = utils::RWLock; + using NewResultT = utils::BasicResult; + + struct Statistics { + uint64_t num_vertex; //!< Sum of vertexes in every database + uint64_t num_edges; //!< Sum of edges in every database + uint64_t num_databases; //! number of isolated databases + }; + + /** + * @brief Initialize the handler. + * + * @param configs storage and interpreter configurations + * @param auth pointer to the global authenticator + * @param recovery_on_startup restore databases (and its content) and authentication data + * @param delete_on_drop when dropping delete any associated directories on disk + */ + DbmsHandler(storage::Config config, auto *auth, bool recovery_on_startup, bool delete_on_drop) + : lock_{utils::RWLock::Priority::READ}, default_config_{std::move(config)}, delete_on_drop_(delete_on_drop) { + // TODO: Decouple storage config from dbms config + // TODO: Save individual db configs inside the kvstore and restore from there + storage::UpdatePaths(*default_config_, default_config_->durability.storage_directory / "databases"); + const auto &db_dir = default_config_->durability.storage_directory; + const auto durability_dir = db_dir / ".durability"; + utils::EnsureDirOrDie(db_dir); + utils::EnsureDirOrDie(durability_dir); + durability_ = std::make_unique(durability_dir); + + // Generate the default database + MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB."); + + // Recover previous databases + if (recovery_on_startup) { + for (const auto &[name, _] : *durability_) { + if (name == kDefaultDB) continue; // Already set + spdlog::info("Restoring database {}.", name); + MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name); + spdlog::info("Database {} restored.", name); + } + } else { // Clear databases from the durability list and auth + auto locked_auth = auth->Lock(); + for (const auto &[name, _] : *durability_) { + if (name == kDefaultDB) continue; + locked_auth->DeleteDatabase(name); + durability_->Delete(name); + } + } + } + + /** + * @brief Create a new Database associated with the "name" database + * + * @param name name of the database + * @return NewResultT context on success, error on failure + */ + NewResultT New(const std::string &name) { + std::lock_guard wr(lock_); + return New_(name, name); + } + + /** + * @brief Get the context associated with the "name" database + * + * @param name + * @return DatabaseAccess + * @throw UnknownDatabaseException if database not found + */ + DatabaseAccess Get(std::string_view name) { + std::shared_lock rd(lock_); + return Get_(name); + } + + /** + * @brief Delete database. + * + * @param db_name database name + * @return DeleteResult error on failure + */ + DeleteResult Delete(const std::string &db_name) { + std::lock_guard wr(lock_); + if (db_name == kDefaultDB) { + // MSG cannot delete the default db + return DeleteError::DEFAULT_DB; + } + + const auto storage_path = StorageDir_(db_name); + if (!storage_path) return DeleteError::NON_EXISTENT; + + // Check if db exists + try { + // Low level handlers + if (!db_handler_.Delete(db_name)) { + return DeleteError::USING; + } + } catch (utils::BasicException &) { + return DeleteError::NON_EXISTENT; + } + + // Remove from durability list + if (durability_) durability_->Delete(db_name); + + // Delete disk storage + if (delete_on_drop_) { + std::error_code ec; + (void)std::filesystem::remove_all(*storage_path, ec); + if (ec) { + spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name); + defunct_dbs_.emplace(db_name); + return DeleteError::DISK_FAIL; + } + } + + // Delete from defunct_dbs_ (in case a second delete call was successful) + defunct_dbs_.erase(db_name); + + return {}; // Success + } + + /** + * @brief Return all active databases. + * + * @return std::vector + */ + std::vector All() const { + std::shared_lock rd(lock_); + return db_handler_.All(); + } + + /** + * @brief Return the number of vertex across all databases. + * + * @return uint64_t + */ + Statistics Info() { + // TODO: Handle overflow? + uint64_t nv = 0; + uint64_t ne = 0; + std::shared_lock rd(lock_); + const uint64_t ndb = std::distance(db_handler_.cbegin(), db_handler_.cend()); + for (auto &[_, db_gk] : db_handler_) { + auto db_acc_opt = db_gk.access(); + if (!db_acc_opt) continue; + auto &db_acc = *db_acc_opt; + const auto &info = db_acc->GetInfo(); + nv += info.vertex_count; + ne += info.edge_count; + } + return {nv, ne, ndb}; + } + + /** + * @brief Restore triggers for all currently defined databases. + * @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers + * + * @param ic global InterpreterContext + */ + void RestoreTriggers(query::InterpreterContext *ic) { + std::lock_guard wr(lock_); + for (auto &[_, db_gk] : db_handler_) { + auto db_acc_opt = db_gk.access(); + if (!db_acc_opt) continue; + auto &db_acc = *db_acc_opt; + spdlog::debug("Restoring trigger for database \"{}\"", db_acc->id()); + auto storage_accessor = db_acc->Access(); + auto dba = memgraph::query::DbAccessor{storage_accessor.get()}; + db_acc->trigger_store()->RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker); + } + } + + /** + * @brief Restore streams of all currently defined databases. + * @note: Stream transformations are using modules, they have to be restored after the query modules are loaded. + * + * @param ic global InterpreterContext + */ + void RestoreStreams(query::InterpreterContext *ic) { + std::lock_guard wr(lock_); + for (auto &[_, db_gk] : db_handler_) { + auto db_acc = db_gk.access(); + if (!db_acc) continue; + auto *db = db_acc->get(); + spdlog::debug("Restoring streams for database \"{}\"", db->id()); + db->streams()->RestoreStreams(*db_acc, ic); + } + } + + private: + /** + * @brief return the storage directory of the associated database + * + * @param name Database name + * @return std::optional + */ + std::optional StorageDir_(const std::string &name) { + const auto conf = db_handler_.GetConfig(name); + if (conf) { + return conf->durability.storage_directory; + } + spdlog::debug("Failed to find storage dir for database \"{}\"", name); + return {}; + } + + /** + * @brief Create a new Database associated with the "name" database + * + * @param name name of the database + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name) { return New_(name, name); } + + /** + * @brief Create a new Database associated with the "name" database + * + * @param name name of the database + * @param storage_subdir undelying RocksDB directory + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) { + if (default_config_) { + auto config_copy = *default_config_; + storage::UpdatePaths(config_copy, default_config_->durability.storage_directory / storage_subdir); + return New_(name, config_copy); + } + spdlog::info("Trying to generate session context without any configurations."); + return NewError::NO_CONFIGS; + } + + /** + * @brief Create a new Database associated with the "name" database + * + * @param name name of the database + * @param storage_config storage configuration + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name, storage::Config &storage_config) { + if (defunct_dbs_.contains(name)) { + spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".", + name); + return NewError::DEFUNCT; + } + + auto new_db = db_handler_.New(name, storage_config); + if (new_db.HasValue()) { + // Success + if (durability_) durability_->Put(name, "ok"); // TODO: Serialize the configuration? + return new_db.GetValue(); + } + return new_db.GetError(); + } + + /** + * @brief Create a new Database associated with the default database + * + * @return NewResultT context on success, error on failure + */ + NewResultT NewDefault_() { + // Create the default DB in the root (this is how it was done pre multi-tenancy) + auto res = New_(kDefaultDB, ".."); + if (res.HasValue()) { + // For back-compatibility... + // Recreate the dbms layout for the default db and symlink to the root + const auto dir = StorageDir_(kDefaultDB); + MG_ASSERT(dir, "Failed to find storage path."); + const auto main_dir = *dir / "databases" / kDefaultDB; + + if (!std::filesystem::exists(main_dir)) { + std::filesystem::create_directory(main_dir); + } + + // Force link on-disk directories + const auto conf = db_handler_.GetConfig(kDefaultDB); + MG_ASSERT(conf, "No configuration for the default database."); + const auto &tmp_conf = conf->disk; + std::vector to_link{ + tmp_conf.main_storage_directory, tmp_conf.label_index_directory, + tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory, + tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory, + tmp_conf.durability_directory, tmp_conf.wal_directory, + }; + + // Add in-memory paths + // Some directories are redundant (skip those) + const std::vector skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"}; + for (auto const &item : std::filesystem::directory_iterator{*dir}) { + const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path()); + if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue; + to_link.push_back(item.path()); + } + + // Symlink to root dir + for (auto const &item : to_link) { + const auto dir_name = std::filesystem::relative(item, item.parent_path()); + const auto link = main_dir / dir_name; + const auto to = std::filesystem::relative(item, main_dir); + if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) { + std::filesystem::create_directory_symlink(to, link); + } else { // Check existing link + std::error_code ec; + const auto test_link = std::filesystem::read_symlink(link, ec); + if (ec || test_link != to) { + MG_ASSERT(false, + "Memgraph storage directory incompatible with new version.\n" + "Please use a clean directory or remove \"{}\" and try again.", + link.string()); + } + } + } + } + return res; + } + + /** + * @brief Get the DatabaseAccess for the database associated with the "name" + * + * @param name + * @return DatabaseAccess + * @throw UnknownDatabaseException if trying to get unknown database + */ + DatabaseAccess Get_(std::string_view name) { + auto db = db_handler_.Get(name); + if (db) { + return *db; + } + throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name); + } + + // Should storage objects ever be deleted? + mutable LockT lock_; //!< protective lock + DatabaseHandler db_handler_; //!< multi-tenancy storage handler + std::optional default_config_; //!< Storage configuration used when creating new databases + std::unique_ptr durability_; //!< list of active dbs (pointer so we can postpone its creation) + std::set defunct_dbs_; //!< Databases that are in an unknown state due to various failures + bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory +}; +#endif + +} // namespace memgraph::dbms diff --git a/src/dbms/global.hpp b/src/dbms/global.hpp index e0c3beae1..7b521d5d3 100644 --- a/src/dbms/global.hpp +++ b/src/dbms/global.hpp @@ -60,51 +60,4 @@ class UnknownDatabaseException : public utils::BasicException { using utils::BasicException::BasicException; }; -/** - * @brief Session interface used by the DBMS to handle the the active sessions. - * @todo Try to remove this dependency from SessionContextHandler. OnDelete could be removed, as it only does an assert. - * OnChange could be removed if SetFor returned the pointer and the called then handled the OnChange execution. - * However, the interface is very useful to decouple the interpreter's query execution and the sessions themselves. - */ -class SessionInterface { - public: - SessionInterface() = default; - virtual ~SessionInterface() = default; - - SessionInterface(const SessionInterface &) = default; - SessionInterface &operator=(const SessionInterface &) = default; - SessionInterface(SessionInterface &&) noexcept = default; - SessionInterface &operator=(SessionInterface &&) noexcept = default; - - /** - * @brief Return the unique string identifying the session. - * - * @return std::string - */ - virtual std::string UUID() const = 0; - - /** - * @brief Return the currently active database. - * - * @return std::string - */ - virtual std::string GetDatabaseName() const = 0; - -#ifdef MG_ENTERPRISE - /** - * @brief Gets called on database change. - * - * @return SetForResult enum (SUCCESS, ALREADY_SET or FAIL) - */ - virtual dbms::SetForResult OnChange(const std::string &) = 0; - - /** - * @brief Callback that gets called on database delete (drop). - * - * @return true on success - */ - virtual bool OnDelete(const std::string &) = 0; -#endif -}; - } // namespace memgraph::dbms diff --git a/src/dbms/handler.hpp b/src/dbms/handler.hpp index 16db5558a..6558fee85 100644 --- a/src/dbms/handler.hpp +++ b/src/dbms/handler.hpp @@ -18,21 +18,21 @@ #include #include "global.hpp" +#include "utils/exceptions.hpp" +#include "utils/gatekeeper.hpp" #include "utils/result.hpp" -#include "utils/sync_ptr.hpp" namespace memgraph::dbms { /** * @brief Generic multi-database content handler. * - * @tparam TContext - * @tparam TConfig + * @tparam T */ -template +template class Handler { public: - using NewResult = utils::BasicResult>; + using NewResult = utils::BasicResult::Accessor>; /** * @brief Empty Handler constructor. @@ -43,67 +43,65 @@ class Handler { /** * @brief Generate a new context and corresponding configuration. * - * @tparam T1 Variadic template of context constructor arguments - * @tparam T2 Variadic template of config constructor arguments - * @param name Name associated with the new context/config pair - * @param args1 Arguments passed (as a tuple) to the context constructor - * @param args2 Arguments passed (as a tuple) to the config constructor + * @tparam Args Variadic template of constructor arguments of T + * @param name Name associated with the new T + * @param args Arguments passed to the constructor of T * @return NewResult */ - template - NewResult New(std::string name, std::tuple args1, std::tuple args2) { - return New_(name, args1, args2, std::make_index_sequence{}, - std::make_index_sequence{}); + template + NewResult New(std::piecewise_construct_t /* marker */, std::string_view name, Args... args) { + // Make sure the emplace will succeed, since we don't want to create temporary objects that could break something + if (!Has(name)) { + auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(std::forward(args)...)); + auto db_acc = itr->second.access(); + if (db_acc) return std::move(*db_acc); + return NewError::DEFUNCT; + } + spdlog::info("Item with name \"{}\" already exists.", name); + return NewError::EXISTS; } /** * @brief Get pointer to context. * * @param name Name associated with the wanted context - * @return std::optional> + * @return std::optional::Accessor> */ - std::optional> Get(const std::string &name) { + std::optional::Accessor> Get(std::string_view name) { if (auto search = items_.find(name); search != items_.end()) { - return search->second.get(); + return search->second.access(); } - return {}; + return std::nullopt; } /** - * @brief Get the config. + * @brief Delete the context associated with the name. * - * @param name Name associated with the wanted config - * @return std::optional - */ - std::optional GetConfig(const std::string &name) const { - if (auto search = items_.find(name); search != items_.end()) { - return search->second.config(); - } - return {}; - } - - /** - * @brief Delete the context/config pair associated with the name. - * - * @param name Name associated with the context/config pair to delete + * @param name Name associated with the context to delete * @return true on success + * @throw BasicException */ bool Delete(const std::string &name) { if (auto itr = items_.find(name); itr != items_.end()) { - itr->second.DestroyAndSync(); - items_.erase(itr); - return true; + auto db_acc = itr->second.access(); + if (db_acc && db_acc->try_delete()) { + db_acc->reset(); + items_.erase(itr); + return true; + } + return false; } - return false; + throw utils::BasicException("Unknown item \"{}\".", name); } /** * @brief Check if a name is already used. * * @param name Name to check - * @return true if a context/config pair is already associated with the name + * @return true if a T is already associated with the name */ - bool Has(const std::string &name) const { return items_.find(name) != items_.end(); } + bool Has(std::string_view name) const { return items_.find(name) != items_.end(); } auto begin() { return items_.begin(); } auto end() { return items_.end(); } @@ -112,31 +110,16 @@ class Handler { auto cbegin() const { return items_.cbegin(); } auto cend() const { return items_.cend(); } - private: - /** - * @brief Lower level handler that hides some ugly code. - * - * @tparam T1 Variadic template of context constructor arguments - * @tparam T2 Variadic template of config constructor arguments - * @tparam I1 List of indexes associated with the first tuple - * @tparam I2 List of indexes associated with the second tuple - */ - template - NewResult New_(std::string name, std::tuple &args1, std::tuple &args2, - std::integer_sequence /*not-used*/, - std::integer_sequence /*not-used*/) { - // Make sure the emplace will succeed, since we don't want to create temporary objects that could break something - if (!Has(name)) { - auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name), - std::forward_as_tuple(TConfig{std::forward(std::get(args1))...}, - std::forward(std::get(args2))...)); - return itr->second.get(); - } - spdlog::info("Item with name \"{}\" already exists.", name); - return NewError::EXISTS; - } + struct string_hash { + using is_transparent = void; + [[nodiscard]] size_t operator()(const char *s) const { return std::hash{}(s); } + [[nodiscard]] size_t operator()(std::string_view s) const { return std::hash{}(s); } + [[nodiscard]] size_t operator()(const std::string &s) const { return std::hash{}(s); } + }; - std::unordered_map> items_; //!< map to all active items + private: + std::unordered_map, string_hash, std::equal_to<>> + items_; //!< map to all active items }; } // namespace memgraph::dbms diff --git a/src/dbms/interp_handler.hpp b/src/dbms/interp_handler.hpp deleted file mode 100644 index af457b6b0..000000000 --- a/src/dbms/interp_handler.hpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#ifdef MG_ENTERPRISE - -#include "global.hpp" -#include "query/auth_checker.hpp" -#include "query/config.hpp" -#include "query/interpreter.hpp" -#include "storage/v2/storage.hpp" - -#include "handler.hpp" - -namespace memgraph::dbms { - -/** - * @brief Simple class that adds useful information to the query's InterpreterContext - * - * @tparam T Multi-database handler type - */ -template -class ExpandedInterpContext : public query::InterpreterContext { - public: - template - explicit ExpandedInterpContext(T &ref, TArgs &&...args) - : query::InterpreterContext(std::forward(args)...), sc_handler_(ref) {} - - T &sc_handler_; //!< Multi-database/SessionContext handler (used in some queries) -}; - -/** - * @brief Simple structure that expands on the query's InterpreterConfig - * - */ -struct ExpandedInterpConfig { - storage::Config storage_config; //!< Storage configuration - query::InterpreterConfig interp_config; //!< Interpreter configuration -}; - -/** - * @brief Multi-database interpreter context handler - * - * @tparam TSCHandler High-level multi-database/SessionContext handler type - */ -template -class InterpContextHandler : public Handler, ExpandedInterpConfig> { - public: - using InterpContextT = ExpandedInterpContext; - using HandlerT = Handler; - - /** - * @brief Generate a new interpreter context associated with the passed name. - * - * @param name Name associating the new interpreter context - * @param sc_handler Multi-database/SessionContext handler used (some queries might use it) - * @param db Storage associated with the interpreter context - * @param config Interpreter's configuration - * @param dir Directory used by the interpreter - * @param auth_handler AuthQueryHandler used - * @param auth_checker AuthChecker used - * @return HandlerT::NewResult - */ - typename HandlerT::NewResult New(const std::string &name, TSCHandler &sc_handler, storage::Config storage_config, - const query::InterpreterConfig &interpreter_config, - query::AuthQueryHandler &auth_handler, query::AuthChecker &auth_checker) { - // Check if compatible with the existing interpreters - if (std::any_of(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { - const auto &config = elem.second.config().storage_config; - return config.durability.storage_directory == storage_config.durability.storage_directory; - })) { - spdlog::info("Tried to generate a new context using claimed directory and/or storage."); - return NewError::EXISTS; - } - const auto dir = storage_config.durability.storage_directory; - storage_config.name = name; // Set storage id via config - return HandlerT::New( - name, std::forward_as_tuple(storage_config, interpreter_config), - std::forward_as_tuple(sc_handler, storage_config, interpreter_config, dir, &auth_handler, &auth_checker)); - } - - /** - * @brief All currently active storage. - * - * @return std::vector - */ - std::vector All() const { - std::vector res; - res.reserve(std::distance(HandlerT::cbegin(), HandlerT::cend())); - std::for_each(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { res.push_back(elem.first); }); - return res; - } -}; - -} // namespace memgraph::dbms - -#endif diff --git a/src/dbms/session_context.hpp b/src/dbms/session_context.hpp deleted file mode 100644 index 691b9ee95..000000000 --- a/src/dbms/session_context.hpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#include "auth/auth.hpp" -#include "query/interpreter.hpp" -#include "storage/v2/storage.hpp" -#include "utils/synchronized.hpp" - -#if MG_ENTERPRISE -#include "audit/log.hpp" -#endif -namespace memgraph::dbms { - -/** - * @brief Structure encapsulating storage and interpreter context. - * - * @note Each session contains a copy. - */ -struct SessionContext { - // Explicit constructor here to ensure that pointers to all objects are - // supplied. - - SessionContext(std::shared_ptr interpreter_context, std::string run, - memgraph::utils::Synchronized *auth -#ifdef MG_ENTERPRISE - , - memgraph::audit::Log *audit_log -#endif - ) - : interpreter_context(interpreter_context), - run_id(run), - auth(auth) -#ifdef MG_ENTERPRISE - , - audit_log(audit_log) -#endif - { - } - - std::shared_ptr interpreter_context; - std::string run_id; - - // std::shared_ptr auth_context; - memgraph::utils::Synchronized *auth; - -#ifdef MG_ENTERPRISE - memgraph::audit::Log *audit_log; -#endif -}; - -} // namespace memgraph::dbms diff --git a/src/dbms/session_context_handler.hpp b/src/dbms/session_context_handler.hpp deleted file mode 100644 index 815c1088f..000000000 --- a/src/dbms/session_context_handler.hpp +++ /dev/null @@ -1,603 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "constants.hpp" -#include "global.hpp" -#include "interp_handler.hpp" -#include "query/auth_checker.hpp" -#include "query/config.hpp" -#include "query/interpreter.hpp" -#include "session_context.hpp" -#include "spdlog/spdlog.h" -#include "storage/v2/durability/durability.hpp" -#include "storage/v2/durability/paths.hpp" -#include "utils/exceptions.hpp" -#include "utils/file.hpp" -#include "utils/logging.hpp" -#include "utils/result.hpp" -#include "utils/rw_lock.hpp" -#include "utils/synchronized.hpp" -#include "utils/uuid.hpp" - -#include "handler.hpp" - -namespace memgraph::dbms { - -#ifdef MG_ENTERPRISE - -using DeleteResult = utils::BasicResult; - -/** - * @brief Multi-database session contexts handler. - */ -class SessionContextHandler { - public: - using StorageT = storage::Storage; - using StorageConfigT = storage::Config; - using LockT = utils::RWLock; - using NewResultT = utils::BasicResult; - - struct Config { - StorageConfigT storage_config; //!< Storage configuration - query::InterpreterConfig interp_config; //!< Interpreter context configuration - std::function *, - std::unique_ptr &, std::unique_ptr &)> - glue_auth; - }; - - struct Statistics { - uint64_t num_vertex; //!< Sum of vertexes in every database - uint64_t num_edges; //!< Sum of edges in every database - uint64_t num_databases; //! number of isolated databases - }; - - /** - * @brief Initialize the handler. - * - * @param audit_log pointer to the audit logger (ENTERPRISE only) - * @param configs storage and interpreter configurations - * @param recovery_on_startup restore databases (and its content) and authentication data - */ - SessionContextHandler(memgraph::audit::Log &audit_log, Config configs, bool recovery_on_startup, bool delete_on_drop) - : lock_{utils::RWLock::Priority::READ}, - default_configs_(configs), - run_id_{utils::GenerateUUID()}, - audit_log_(&audit_log), - delete_on_drop_(delete_on_drop) { - const auto &root = configs.storage_config.durability.storage_directory; - utils::EnsureDirOrDie(root); - // Verify that the user that started the process is the same user that is - // the owner of the storage directory. - storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(root); - - // Create the lock file and open a handle to it. This will crash the - // database if it can't open the file for writing or if any other process is - // holding the file opened. - lock_file_path_ = root / ".lock"; - lock_file_handle_.Open(lock_file_path_, utils::OutputFile::Mode::OVERWRITE_EXISTING); - MG_ASSERT(lock_file_handle_.AcquireLock(), - "Couldn't acquire lock on the storage directory {}" - "!\nAnother Memgraph process is currently running with the same " - "storage directory, please stop it first before starting this " - "process!", - root); - - // TODO: Figure out if this is needed/wanted - // Clear auth database since we are not recovering - // if (!recovery_on_startup) { - // const auto &auth_dir = root / "auth"; - // // Backup if auth present - // if (utils::DirExists(auth_dir)) { - // auto backup_dir = root / storage::durability::kBackupDirectory; - // std::error_code error_code; - // utils::EnsureDirOrDie(backup_dir); - // std::error_code ec; - // const auto now = std::chrono::system_clock::now(); - // std::ostringstream os; - // os << now.time_since_epoch().count(); - // std::filesystem::rename(auth_dir, backup_dir / ("auth-" + os.str()), ec); - // MG_ASSERT(!ec, "Couldn't backup auth directory because of: {}", ec.message()); - // spdlog::warn( - // "Since Memgraph was not supposed to recover on startup the authentication files will be " - // "overwritten. To prevent important data loss, Memgraph has stored those files into .backup directory " - // "inside the storage directory."); - // } - - // // Clear - // if (std::filesystem::exists(auth_dir)) { - // std::filesystem::remove_all(auth_dir); - // } - // } - - // Lazy initialization of auth_ - auth_ = std::make_unique>(root / "auth"); - configs.glue_auth(auth_.get(), auth_handler_, auth_checker_); - - // TODO: Decouple storage config from dbms config - // TODO: Save individual db configs inside the kvstore and restore from there - storage::UpdatePaths(default_configs_->storage_config, - default_configs_->storage_config.durability.storage_directory / "databases"); - const auto &db_dir = default_configs_->storage_config.durability.storage_directory; - const auto durability_dir = db_dir / ".durability"; - utils::EnsureDirOrDie(db_dir); - utils::EnsureDirOrDie(durability_dir); - durability_ = std::make_unique(durability_dir); - - // Generate the default database - MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB."); - - // Recover previous databases - if (recovery_on_startup) { - for (const auto &[name, _] : *durability_) { - if (name == kDefaultDB) continue; // Already set - spdlog::info("Restoring database {}.", name); - MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name); - spdlog::info("Database {} restored.", name); - } - } else { // Clear databases from the durability list and auth - auto locked_auth = auth_->Lock(); - for (const auto &[name, _] : *durability_) { - if (name == kDefaultDB) continue; - locked_auth->DeleteDatabase(name); - durability_->Delete(name); - } - } - } - - void Shutdown() { - for (auto &ic : interp_handler_) memgraph::query::Shutdown(ic.second.get().get()); - } - - /** - * @brief Create a new SessionContext associated with the "name" database - * - * @param name name of the database - * @return NewResultT context on success, error on failure - */ - NewResultT New(const std::string &name) { - std::lock_guard wr(lock_); - return New_(name, name); - } - - /** - * @brief Get the context associated with the "name" database - * - * @param name - * @return SessionContext - * @throw UnknownDatabaseException if getting unknown database - */ - SessionContext Get(const std::string &name) { - std::shared_lock rd(lock_); - return Get_(name); - } - - /** - * @brief Set the undelying database for a particular session. - * - * @param uuid unique session identifier - * @param db_name unique database name - * @return SetForResult enum - * @throws UnknownDatabaseException, UnknownSessionException or anything OnChange throws - */ - SetForResult SetFor(const std::string &uuid, const std::string &db_name) { - std::shared_lock rd(lock_); - (void)Get_( - db_name); // throws if db doesn't exist (TODO: Better to pass it via OnChange - but injecting dependency) - try { - auto &s = sessions_.at(uuid); - return s.OnChange(db_name); - } catch (std::out_of_range &) { - throw UnknownSessionException("Unknown session \"{}\"", uuid); - } - } - - /** - * @brief Set the undelying database from a session itself. SessionContext handler. - * - * @param db_name unique database name - * @param handler function that gets called in place with the appropriate SessionContext - * @return SetForResult enum - */ - template - requires std::invocable SetForResult SetInPlace(const std::string &db_name, - THandler handler) { - std::shared_lock rd(lock_); - return handler(Get_(db_name)); - } - - /** - * @brief Call void handler under a shared lock. - * - * @param handler function that gets called in place - */ - template - requires std::invocable - void CallInPlace(THandler handler) { - std::shared_lock rd(lock_); - handler(); - } - - /** - * @brief Register an active session (used to handle callbacks). - * - * @param session - * @return true on success - */ - bool Register(SessionInterface &session) { - std::lock_guard wr(lock_); - auto [_, success] = sessions_.emplace(session.UUID(), session); - return success; - } - - /** - * @brief Delete a session. - * - * @param session - */ - bool Delete(const SessionInterface &session) { - std::lock_guard wr(lock_); - return sessions_.erase(session.UUID()) > 0; - } - - /** - * @brief Delete database. - * - * @param db_name database name - * @return DeleteResult error on failure - */ - DeleteResult Delete(const std::string &db_name) { - std::lock_guard wr(lock_); - if (db_name == kDefaultDB) { - // MSG cannot delete the default db - return DeleteError::DEFAULT_DB; - } - // Check if db exists - try { - auto sc = Get_(db_name); - // Check if a session is using the db - if (!sc.interpreter_context->interpreters->empty()) { - return DeleteError::USING; - } - } catch (UnknownDatabaseException &) { - return DeleteError::NON_EXISTENT; - } - - // High level handlers - for (auto &[_, s] : sessions_) { - if (!s.OnDelete(db_name)) { - spdlog::error("Partial failure while deleting database \"{}\".", db_name); - defunct_dbs_.emplace(db_name); - return DeleteError::FAIL; - } - } - - // Low level handlers - const auto storage_path = StorageDir_(db_name); - MG_ASSERT(storage_path, "Missing storage for {}", db_name); - if (!interp_handler_.Delete(db_name)) { - spdlog::error("Partial failure while deleting database \"{}\".", db_name); - defunct_dbs_.emplace(db_name); - return DeleteError::FAIL; - } - - // Remove from auth - auth_->Lock()->DeleteDatabase(db_name); - // Remove from durability list - if (durability_) durability_->Delete(db_name); - - // Delete disk storage - if (delete_on_drop_) { - std::error_code ec; - (void)std::filesystem::remove_all(*storage_path, ec); - if (ec) { - spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name); - defunct_dbs_.emplace(db_name); - return DeleteError::DISK_FAIL; - } - } - - // Delete from defunct_dbs_ (in case a second delete call was successful) - defunct_dbs_.erase(db_name); - - return {}; // Success - } - - /** - * @brief Set the default configurations. - * - * @param configs storage, interpreter and authorization configurations - */ - void SetDefaultConfigs(Config configs) { - std::lock_guard wr(lock_); - default_configs_ = configs; - } - - /** - * @brief Get the default configurations. - * - * @return std::optional - */ - std::optional GetDefaultConfigs() const { - std::shared_lock rd(lock_); - return default_configs_; - } - - /** - * @brief Return all active databases. - * - * @return std::vector - */ - std::vector All() const { - std::shared_lock rd(lock_); - return interp_handler_.All(); - } - - /** - * @brief Return the number of vertex across all databases. - * - * @return uint64_t - */ - Statistics Info() const { - // TODO: Handle overflow - uint64_t nv = 0; - uint64_t ne = 0; - std::shared_lock rd(lock_); - const uint64_t ndb = std::distance(interp_handler_.cbegin(), interp_handler_.cend()); - for (const auto &ic : interp_handler_) { - const auto &info = ic.second.get()->db->GetInfo(); - nv += info.vertex_count; - ne += info.edge_count; - } - return {nv, ne, ndb}; - } - - /** - * @brief Return the currently active database for a particular session. - * - * @param uuid session's unique identifier - * @return std::string name of the database - * @throw - */ - std::string Current(const std::string &uuid) const { - std::shared_lock rd(lock_); - return sessions_.at(uuid).GetDatabaseName(); - } - - /** - * @brief Restore triggers for all currently defined databases. - * @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers - */ - void RestoreTriggers() { - std::lock_guard wr(lock_); - for (auto &ic_itr : interp_handler_) { - auto ic = ic_itr.second.get(); - spdlog::debug("Restoring trigger for database \"{}\"", ic->db->id()); - auto storage_accessor = ic->db->Access(); - auto dba = memgraph::query::DbAccessor{storage_accessor.get()}; - ic->trigger_store.RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker); - } - } - - /** - * @brief Restore streams of all currently defined databases. - * @note: Stream transformations are using modules, they have to be restored after the query modules are loaded. - */ - void RestoreStreams() { - std::lock_guard wr(lock_); - for (auto &ic_itr : interp_handler_) { - auto ic = ic_itr.second.get(); - spdlog::debug("Restoring streams for database \"{}\"", ic->db->id()); - ic->streams.RestoreStreams(); - } - } - - private: - std::optional StorageDir_(const std::string &name) const { - const auto conf = interp_handler_.GetConfig(name); - if (conf) { - return conf->storage_config.durability.storage_directory; - } - spdlog::debug("Failed to find storage dir for database \"{}\"", name); - return {}; - } - - /** - * @brief Create a new SessionContext associated with the "name" database - * - * @param name name of the database - * @return NewResultT context on success, error on failure - */ - NewResultT New_(const std::string &name) { return New_(name, name); } - - /** - * @brief Create a new SessionContext associated with the "name" database - * - * @param name name of the database - * @param storage_subdir undelying RocksDB directory - * @return NewResultT context on success, error on failure - */ - NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) { - if (default_configs_) { - auto storage = default_configs_->storage_config; - storage::UpdatePaths(storage, storage.durability.storage_directory / storage_subdir); - return New_(name, storage, default_configs_->interp_config); - } - spdlog::info("Trying to generate session context without any configurations."); - return NewError::NO_CONFIGS; - } - - /** - * @brief Create a new SessionContext associated with the "name" database - * - * @param name name of the database - * @param storage_config storage configuration - * @param inter_config interpreter configuration - * @return NewResultT context on success, error on failure - */ - NewResultT New_(const std::string &name, StorageConfigT &storage_config, query::InterpreterConfig &inter_config/*, - const std::string &ah_flags*/) { - MG_ASSERT(auth_handler_, "No high level AuthQueryHandler has been supplied."); - MG_ASSERT(auth_checker_, "No high level AuthChecker has been supplied."); - - if (defunct_dbs_.contains(name)) { - spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".", - name); - return NewError::DEFUNCT; - } - - auto new_interp = interp_handler_.New(name, *this, storage_config, inter_config, *auth_handler_, *auth_checker_); - - if (new_interp.HasValue()) { - // Success - if (durability_) durability_->Put(name, "ok"); - return SessionContext{new_interp.GetValue(), run_id_, auth_.get(), audit_log_}; - } - return new_interp.GetError(); - } - - /** - * @brief Create a new SessionContext associated with the default database - * - * @return NewResultT context on success, error on failure - */ - NewResultT NewDefault_() { - // Create the default DB in the root (this is how it was done pre multi-tenancy) - auto res = New_(kDefaultDB, ".."); - if (res.HasValue()) { - // For back-compatibility... - // Recreate the dbms layout for the default db and symlink to the root - const auto dir = StorageDir_(kDefaultDB); - MG_ASSERT(dir, "Failed to find storage path."); - const auto main_dir = *dir / "databases" / kDefaultDB; - - if (!std::filesystem::exists(main_dir)) { - std::filesystem::create_directory(main_dir); - } - - // Force link on-disk directories - const auto conf = interp_handler_.GetConfig(kDefaultDB); - MG_ASSERT(conf, "No configuration for the default database."); - const auto &tmp_conf = conf->storage_config.disk; - std::vector to_link{ - tmp_conf.main_storage_directory, tmp_conf.label_index_directory, - tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory, - tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory, - tmp_conf.durability_directory, tmp_conf.wal_directory, - }; - - // Add in-memory paths - // Some directories are redundant (skip those) - const std::vector skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"}; - for (auto const &item : std::filesystem::directory_iterator{*dir}) { - const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path()); - if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue; - to_link.push_back(item.path()); - } - - // Symlink to root dir - for (auto const &item : to_link) { - const auto dir_name = std::filesystem::relative(item, item.parent_path()); - const auto link = main_dir / dir_name; - const auto to = std::filesystem::relative(item, main_dir); - if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) { - std::filesystem::create_directory_symlink(to, link); - } else { // Check existing link - std::error_code ec; - const auto test_link = std::filesystem::read_symlink(link, ec); - if (ec || test_link != to) { - MG_ASSERT(false, - "Memgraph storage directory incompatible with new version.\n" - "Please use a clean directory or remove \"{}\" and try again.", - link.string()); - } - } - } - } - return res; - } - - /** - * @brief Get the context associated with the "name" database - * - * @param name - * @return SessionContext - * @throw UnknownDatabaseException if trying to get unknown database - */ - SessionContext Get_(const std::string &name) { - auto interp = interp_handler_.Get(name); - if (interp) { - return SessionContext{*interp, run_id_, auth_.get(), audit_log_}; - } - throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name); - } - - // Should storage objects ever be deleted? - mutable LockT lock_; //!< protective lock - std::filesystem::path lock_file_path_; //!< Lock file protecting the main storage - utils::OutputFile lock_file_handle_; //!< Handler the lock (crash if already open) - InterpContextHandler interp_handler_; //!< multi-tenancy interpreter handler - // AuthContextHandler auth_handler_; //!< multi-tenancy authorization handler (currently we use a single global - // auth) - std::unique_ptr> auth_; - std::unique_ptr auth_handler_; - std::unique_ptr auth_checker_; - std::optional default_configs_; //!< default storage and interpreter configurations - const std::string run_id_; //!< run's unique identifier (auto generated) - memgraph::audit::Log *audit_log_; //!< pointer to the audit logger - std::unordered_map sessions_; //!< map of active/registered sessions - std::unique_ptr durability_; //!< list of active dbs (pointer so we can postpone its creation) - - std::set defunct_dbs_; //!< Databases that are in an unknown state due to various failures - bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory - public: - static SessionContextHandler &ExtractSCH(query::InterpreterContext *interpreter_context) { - return static_cast(interpreter_context)->sc_handler_; - } -}; - -#else -/** - * @brief Initialize the handler. - * - * @param auth pointer to the authenticator - * @param configs storage and interpreter configurations - */ -static inline SessionContext Init(storage::Config &storage_config, query::InterpreterConfig &interp_config, - utils::Synchronized *auth, - query::AuthQueryHandler *auth_handler, query::AuthChecker *auth_checker) { - MG_ASSERT(auth, "Passed a nullptr auth"); - MG_ASSERT(auth_handler, "Passed a nullptr auth_handler"); - MG_ASSERT(auth_checker, "Passed a nullptr auth_checker"); - - storage_config.name = kDefaultDB; - auto interp_context = std::make_shared( - storage_config, interp_config, storage_config.durability.storage_directory, auth_handler, auth_checker); - MG_ASSERT(interp_context, "Failed to construct main interpret context."); - - return SessionContext{interp_context, utils::GenerateUUID(), auth}; -} -#endif - -} // namespace memgraph::dbms diff --git a/src/glue/CMakeLists.txt b/src/glue/CMakeLists.txt index eef05cd81..83b0168ac 100644 --- a/src/glue/CMakeLists.txt +++ b/src/glue/CMakeLists.txt @@ -1,4 +1,11 @@ add_library(mg-glue STATIC ) -target_sources(mg-glue PRIVATE auth.cpp auth_checker.cpp auth_handler.cpp communication.cpp SessionHL.cpp ServerT.cpp MonitoringServerT.cpp) +target_sources(mg-glue PRIVATE auth.cpp + auth_checker.cpp + auth_handler.cpp + communication.cpp + SessionHL.cpp + ServerT.cpp + MonitoringServerT.cpp + run_id.cpp) target_link_libraries(mg-glue mg-query mg-auth mg-audit) target_precompile_headers(mg-glue INTERFACE auth_checker.hpp auth_handler.hpp) diff --git a/src/glue/MonitoringServerT.cpp b/src/glue/MonitoringServerT.cpp index 2fde4f572..68ea7dfdc 100644 --- a/src/glue/MonitoringServerT.cpp +++ b/src/glue/MonitoringServerT.cpp @@ -10,5 +10,4 @@ // licenses/APL.txt. #include "glue/MonitoringServerT.hpp" -template class memgraph::communication::http::Server< - memgraph::http::MetricsRequestHandler, memgraph::dbms::SessionContext>; +template class memgraph::communication::http::Server; diff --git a/src/glue/MonitoringServerT.hpp b/src/glue/MonitoringServerT.hpp index aab7a1c0c..ed219d04e 100644 --- a/src/glue/MonitoringServerT.hpp +++ b/src/glue/MonitoringServerT.hpp @@ -11,15 +11,14 @@ #pragma once #include "communication/http/server.hpp" -#include "dbms/session_context.hpp" #include "http_handlers/metrics.hpp" +#include "storage/v2/storage.hpp" -extern template class memgraph::communication::http::Server< - memgraph::http::MetricsRequestHandler, memgraph::dbms::SessionContext>; +extern template class memgraph::communication::http::Server; namespace memgraph::glue { using MonitoringServerT = - memgraph::communication::http::Server, - memgraph::dbms::SessionContext>; + memgraph::communication::http::Server; } // namespace memgraph::glue diff --git a/src/glue/ServerT.cpp b/src/glue/ServerT.cpp index 94eafcb6b..ceaad7b6c 100644 --- a/src/glue/ServerT.cpp +++ b/src/glue/ServerT.cpp @@ -10,8 +10,4 @@ // licenses/APL.txt. #include "glue/ServerT.hpp" -#ifdef MG_ENTERPRISE -template class memgraph::communication::v2::Server; -#else -template class memgraph::communication::v2::Server; -#endif +template class memgraph::communication::v2::Server; diff --git a/src/glue/ServerT.hpp b/src/glue/ServerT.hpp index 3652f94d0..641553128 100644 --- a/src/glue/ServerT.hpp +++ b/src/glue/ServerT.hpp @@ -12,24 +12,35 @@ #include "communication/v2/server.hpp" #include "glue/SessionHL.hpp" +#include "utils/synchronized.hpp" -#ifdef MG_ENTERPRISE -#include "dbms/session_context_handler.hpp" -#else -#include "dbms/session_context.hpp" +namespace memgraph::query { +struct InterpreterContext; +} + +#if MG_ENTERPRISE +namespace memgraph::audit { +class Log; +} #endif -#ifdef MG_ENTERPRISE -extern template class memgraph::communication::v2::Server; -#else -extern template class memgraph::communication::v2::Server; +namespace memgraph::auth { +class Auth; +} +namespace memgraph::utils { +class WritePrioritizedRWLock; +} + +struct Context { + memgraph::query::InterpreterContext *ic; + memgraph::utils::Synchronized *auth; +#if MG_ENTERPRISE + memgraph::audit::Log *audit_log; #endif +}; + +extern template class memgraph::communication::v2::Server; namespace memgraph::glue { -#ifdef MG_ENTERPRISE -using ServerT = memgraph::communication::v2::Server; -#else -using ServerT = memgraph::communication::v2::Server; -#endif +using ServerT = memgraph::communication::v2::Server; } // namespace memgraph::glue diff --git a/src/glue/SessionHL.cpp b/src/glue/SessionHL.cpp index 2387f7e34..87481eb20 100644 --- a/src/glue/SessionHL.cpp +++ b/src/glue/SessionHL.cpp @@ -9,19 +9,21 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "glue/SessionHL.hpp" #include +#include "gflags/gflags.h" #include "audit/log.hpp" +#include "dbms/constants.hpp" #include "flags/run_time_configurable.hpp" +#include "glue/SessionHL.hpp" #include "glue/auth_checker.hpp" #include "glue/communication.hpp" +#include "glue/run_id.hpp" #include "license/license.hpp" #include "query/discard_value_stream.hpp" +#include "query/interpreter_context.hpp" #include "utils/spin_lock.hpp" -#include "gflags/gflags.h" - namespace memgraph::metrics { extern const Event ActiveBoltSessions; } // namespace memgraph::metrics @@ -47,14 +49,14 @@ auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgrap class TypedValueResultStreamBase { public: - explicit TypedValueResultStreamBase(memgraph::query::InterpreterContext *interpreterContext); + explicit TypedValueResultStreamBase(memgraph::storage::Storage *storage); std::vector DecodeValues( const std::vector &values) const; - private: + protected: // NOTE: Needed only for ToBoltValue conversions - memgraph::query::InterpreterContext *interpreter_context_; + memgraph::storage::Storage *storage_; }; /// Wrapper around TEncoder which converts TypedValue to Value @@ -62,8 +64,8 @@ class TypedValueResultStreamBase { template class TypedValueResultStream : public TypedValueResultStreamBase { public: - TypedValueResultStream(TEncoder *encoder, memgraph::query::InterpreterContext *ic) - : TypedValueResultStreamBase{ic}, encoder_(encoder) {} + TypedValueResultStream(TEncoder *encoder, memgraph::storage::Storage *storage) + : TypedValueResultStreamBase{storage}, encoder_(encoder) {} void Result(const std::vector &values) { encoder_->MessageRecord(DecodeValues(values)); } @@ -76,7 +78,7 @@ std::vector TypedValueResultStreamBase::De std::vector decoded_values; decoded_values.reserve(values.size()); for (const auto &v : values) { - auto maybe_value = memgraph::glue::ToBoltValue(v, *interpreter_context_->db, memgraph::storage::View::NEW); + auto maybe_value = memgraph::glue::ToBoltValue(v, *storage_, memgraph::storage::View::NEW); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case memgraph::storage::Error::DELETED_OBJECT: @@ -93,33 +95,13 @@ std::vector TypedValueResultStreamBase::De } return decoded_values; } -TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::query::InterpreterContext *interpreterContext) - : interpreter_context_(interpreterContext) {} +TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {} namespace memgraph::glue { #ifdef MG_ENTERPRISE - -void SessionHL::UpdateAndDefunct(const std::string &db_name) { - UpdateAndDefunct(ContextWrapper(sc_handler_.Get(db_name))); -} -void SessionHL::UpdateAndDefunct(ContextWrapper &&cntxt) { - defunct_.emplace(std::move(current_)); - Update(std::forward(cntxt)); - defunct_->Defunct(); -} -void SessionHL::Update(const std::string &db_name) { - ContextWrapper tmp(sc_handler_.Get(db_name)); - Update(std::move(tmp)); -} -void SessionHL::Update(ContextWrapper &&cntxt) { - current_ = std::move(cntxt); - interpreter_ = current_.interp(); - interpreter_->in_explicit_db_ = in_explicit_db_; - interpreter_context_ = current_.interpreter_context(); -} -void SessionHL::MultiDatabaseAuth(const std::string &db) { - if (user_ && !AuthChecker::IsUserAuthorized(*user_, {}, db)) { +inline static void MultiDatabaseAuth(const std::optional &user, std::string_view db) { + if (user && !AuthChecker::IsUserAuthorized(*user, {}, std::string(db))) { throw memgraph::communication::bolt::ClientError( "You are not authorized on the database \"{}\"! Please contact your database administrator.", db); } @@ -130,23 +112,13 @@ std::string SessionHL::GetDefaultDB() { } return memgraph::dbms::kDefaultDB; } - -bool SessionHL::OnDelete(const std::string &db_name) { - MG_ASSERT(current_.interpreter_context()->db->id() != db_name && (!defunct_ || defunct_->defunct()), - "Trying to delete a database while still in use."); - return true; -} -memgraph::dbms::SetForResult SessionHL::OnChange(const std::string &db_name) { - MultiDatabaseAuth(db_name); - if (db_name != current_.interpreter_context()->db->id()) { - UpdateAndDefunct(db_name); // Done during Pull, so we cannot just replace the current db - return memgraph::dbms::SetForResult::SUCCESS; - } - return memgraph::dbms::SetForResult::ALREADY_SET; -} - #endif -std::string SessionHL::GetDatabaseName() const { return interpreter_context_->db->id(); } + +std::string SessionHL::GetDatabaseName() const { + if (!interpreter_.db_acc_) return ""; + const auto *db = interpreter_.db_acc_->get(); + return db->id(); +} std::optional SessionHL::GetServerNameForInit() { auto locked_name = flags::run_time::bolt_server_name_.Lock(); @@ -154,30 +126,29 @@ std::optional SessionHL::GetServerNameForInit() { } bool SessionHL::Authenticate(const std::string &username, const std::string &password) { - auto locked_auth = auth_->Lock(); - if (!locked_auth->HasUsers()) { - return true; - } - user_ = locked_auth->Authenticate(username, password); -#ifdef MG_ENTERPRISE - if (user_.has_value()) { - const auto &db = user_->db_access().GetDefault(); - // Check if the underlying database needs to be updated - if (db != current_.interpreter_context()->db->id()) { - const auto &res = sc_handler_.SetFor(UUID(), db); - return res == memgraph::dbms::SetForResult::SUCCESS || res == memgraph::dbms::SetForResult::ALREADY_SET; + bool res = true; + { + auto locked_auth = auth_->Lock(); + if (locked_auth->HasUsers()) { + user_ = locked_auth->Authenticate(username, password); + res = user_.has_value(); } } +#ifdef MG_ENTERPRISE + // Start off with the default database + interpreter_.SetCurrentDB(GetDefaultDB()); #endif - return user_.has_value(); + implicit_db_.emplace(GetDatabaseName()); + return res; } -void SessionHL::Abort() { interpreter_->Abort(); } + +void SessionHL::Abort() { interpreter_.Abort(); } std::map SessionHL::Discard(std::optional n, std::optional qid) { try { memgraph::query::DiscardValueResultStream stream; - return DecodeSummary(interpreter_->Pull(&stream, n, qid)); + return DecodeSummary(interpreter_.Pull(&stream, n, qid)); } catch (const memgraph::query::QueryException &e) { // Wrap QueryException into ClientError, because we want to allow the // client to fix their query. @@ -187,15 +158,18 @@ std::map SessionHL::Discard(s std::map SessionHL::Pull(SessionHL::TEncoder *encoder, std::optional n, std::optional qid) { + // TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt) + auto *db = interpreter_.db_acc_->get(); try { - TypedValueResultStream stream(encoder, interpreter_context_); - return DecodeSummary(interpreter_->Pull(&stream, n, qid)); + TypedValueResultStream stream(encoder, db->storage()); + return DecodeSummary(interpreter_.Pull(&stream, n, qid)); } catch (const memgraph::query::QueryException &e) { // Wrap QueryException into ClientError, because we want to allow the // client to fix their query. throw memgraph::communication::bolt::ClientError(e.what()); } } + std::pair, std::optional> SessionHL::Interpret( const std::string &query, const std::map ¶ms, const std::map &extra) { @@ -209,16 +183,18 @@ std::pair, std::optional> SessionHL::Interpret( } #ifdef MG_ENTERPRISE + // TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt) + auto *db = interpreter_.db_acc_->get(); if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query, - memgraph::storage::PropertyValue(params_pv), interpreter_context_->db->id()); + memgraph::storage::PropertyValue(params_pv), db->id()); } #endif try { - auto result = interpreter_->Prepare(query, params_pv, username, ToQueryExtras(extra), UUID()); + auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra), UUID()); const std::string db_name = result.db ? *result.db : ""; if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) { - interpreter_->Abort(); + interpreter_.Abort(); if (db_name.empty()) { throw memgraph::communication::bolt::ClientError( "You are not authorized to execute this query! Please contact your database administrator."); @@ -238,10 +214,10 @@ std::pair, std::optional> SessionHL::Interpret( throw memgraph::communication::bolt::ClientError(e.what()); } } -void SessionHL::RollbackTransaction() { interpreter_->RollbackTransaction(); } -void SessionHL::CommitTransaction() { interpreter_->CommitTransaction(); } +void SessionHL::RollbackTransaction() { interpreter_.RollbackTransaction(); } +void SessionHL::CommitTransaction() { interpreter_.CommitTransaction(); } void SessionHL::BeginTransaction(const std::map &extra) { - interpreter_->BeginTransaction(ToQueryExtras(extra)); + interpreter_.BeginTransaction(ToQueryExtras(extra)); } void SessionHL::Configure(const std::map &run_time_info) { #ifdef MG_ENTERPRISE @@ -254,108 +230,68 @@ void SessionHL::Configure(const std::mapdb->id(); + const auto ¤t = GetDatabaseName(); + update = db != current; + if (!in_explicit_db_) implicit_db_.emplace(current); // Still not in an explicit database, save for recovery in_explicit_db_ = true; // NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit - } else if (in_explicit_db_ && !interpreter_->in_explicit_transaction_) { // Just on a switch - db = GetDefaultDB(); - update = db != current_.interpreter_context()->db->id(); + } else if (in_explicit_db_ && !interpreter_.in_explicit_transaction_) { // Just on a switch + if (implicit_db_) { + db = *implicit_db_; + } else { + db = GetDefaultDB(); + } + update = db != GetDatabaseName(); in_explicit_db_ = false; } // Check if the underlying database needs to be updated if (update) { - sc_handler_.SetInPlace(db, [this](auto new_sc) mutable { - const auto &db_name = new_sc.interpreter_context->db->id(); - MultiDatabaseAuth(db_name); - try { - Update(ContextWrapper(new_sc)); - return memgraph::dbms::SetForResult::SUCCESS; - } catch (memgraph::dbms::UnknownDatabaseException &e) { - throw memgraph::communication::bolt::ClientError("No database named \"{}\" found!", db_name); - } - }); + MultiDatabaseAuth(user_, db); + interpreter_.SetCurrentDB(db); } #endif } -SessionHL::~SessionHL() { memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); } -SessionHL::SessionHL( +SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context, + const memgraph::communication::v2::ServerEndpoint &endpoint, + memgraph::communication::v2::InputStream *input_stream, + memgraph::communication::v2::OutputStream *output_stream, + memgraph::utils::Synchronized *auth #ifdef MG_ENTERPRISE - memgraph::dbms::SessionContextHandler &sc_handler, -#else - memgraph::dbms::SessionContext sc, + , + memgraph::audit::Log *audit_log #endif - const memgraph::communication::v2::ServerEndpoint &endpoint, memgraph::communication::v2::InputStream *input_stream, - memgraph::communication::v2::OutputStream *output_stream, const std::string &default_db) // NOLINT + ) : Session(input_stream, output_stream), + interpreter_context_(interpreter_context), + interpreter_(interpreter_context_), #ifdef MG_ENTERPRISE - sc_handler_(sc_handler), - current_(sc_handler_.Get(default_db)), -#else - current_(sc), -#endif - interpreter_context_(current_.interpreter_context()), - interpreter_(current_.interp()), - auth_(current_.auth()), -#ifdef MG_ENTERPRISE - audit_log_(current_.audit_log()), + audit_log_(audit_log), #endif + auth_(auth), endpoint_(endpoint), - run_id_(current_.run_id()) { + implicit_db_(dbms::kDefaultDB) { // Metrics update memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions); +#ifdef MG_ENTERPRISE + interpreter_.OnChangeCB([&](std::string_view db_name) { MultiDatabaseAuth(user_, db_name); }); +#endif + interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); }); } -/// ContextWrapper -ContextWrapper::ContextWrapper(memgraph::dbms::SessionContext sc) - : session_context(sc), - interpreter(std::make_unique(session_context.interpreter_context.get())), - defunct_(false) { - session_context.interpreter_context->interpreters.WithLock( - [this](auto &interpreters) { interpreters.insert(interpreter.get()); }); +SessionHL::~SessionHL() { + memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); + interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); }); } -ContextWrapper::~ContextWrapper() { Defunct(); } -void ContextWrapper::Defunct() { - if (!defunct_) { - session_context.interpreter_context->interpreters.WithLock( - [this](auto &interpreters) { interpreters.erase(interpreter.get()); }); - defunct_ = true; - } -} -ContextWrapper::ContextWrapper(ContextWrapper &&in) noexcept - : session_context(std::move(in.session_context)), interpreter(std::move(in.interpreter)), defunct_(in.defunct_) { - in.defunct_ = true; -} -ContextWrapper &ContextWrapper::operator=(ContextWrapper &&in) noexcept { - if (this != &in) { - Defunct(); - session_context = std::move(in.session_context); - interpreter = std::move(in.interpreter); - defunct_ = in.defunct_; - in.defunct_ = true; - } - return *this; -} -memgraph::query::InterpreterContext *ContextWrapper::interpreter_context() { - return session_context.interpreter_context.get(); -} -memgraph::query::Interpreter *ContextWrapper::interp() { return interpreter.get(); } -memgraph::utils::Synchronized *ContextWrapper::auth() - const { - return session_context.auth; -} -std::string ContextWrapper::run_id() const { return session_context.run_id; } -bool ContextWrapper::defunct() const { return defunct_; } -#ifdef MG_ENTERPRISE -memgraph::audit::Log *ContextWrapper::audit_log() const { return session_context.audit_log; } -#endif std::map SessionHL::DecodeSummary( const std::map &summary) { + // TODO: Update once interpreter can handle non-database queries (db_acc will be nullopt) + auto *db = interpreter_.db_acc_->get(); std::map decoded_summary; for (const auto &kv : summary) { - auto maybe_value = ToBoltValue(kv.second, *interpreter_context_->db, memgraph::storage::View::NEW); + auto maybe_value = ToBoltValue(kv.second, *db->storage(), memgraph::storage::View::NEW); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case memgraph::storage::Error::DELETED_OBJECT: @@ -372,14 +308,7 @@ std::map SessionHL::DecodeSum // This is sent with every query, instead of only on bolt init inside // communication/bolt/v1/states/init.hpp because neo4jdriver does not // read the init message. - if (auto run_id = run_id_; run_id) { - decoded_summary.emplace("run_id", *run_id); - } - - // Clean up previous session (session gets defunct when switching between databases) - if (defunct_) { - defunct_.reset(); - } + decoded_summary.emplace("run_id", memgraph::glue::run_id_); return decoded_summary; } diff --git a/src/glue/SessionHL.hpp b/src/glue/SessionHL.hpp index 5f15ed4cc..d6c095a04 100644 --- a/src/glue/SessionHL.hpp +++ b/src/glue/SessionHL.hpp @@ -10,56 +10,28 @@ // licenses/APL.txt. #pragma once +#include "audit/log.hpp" +#include "auth/auth.hpp" #include "communication/v2/server.hpp" #include "communication/v2/session.hpp" -#include "dbms/session_context.hpp" - -#ifdef MG_ENTERPRISE -#include "dbms/session_context_handler.hpp" -#else -#include "dbms/session_context.hpp" -#endif +#include "dbms/database.hpp" +#include "query/interpreter.hpp" namespace memgraph::glue { -struct ContextWrapper { - explicit ContextWrapper(memgraph::dbms::SessionContext sc); - ~ContextWrapper(); - - ContextWrapper(const ContextWrapper &) = delete; - ContextWrapper &operator=(const ContextWrapper &) = delete; - - ContextWrapper(ContextWrapper &&in) noexcept; - ContextWrapper &operator=(ContextWrapper &&in) noexcept; - - void Defunct(); - memgraph::query::InterpreterContext *interpreter_context(); - memgraph::query::Interpreter *interp(); - memgraph::utils::Synchronized *auth() const; - std::string run_id() const; - bool defunct() const; -#ifdef MG_ENTERPRISE - memgraph::audit::Log *audit_log() const; -#endif - - private: - memgraph::dbms::SessionContext session_context; - std::unique_ptr interpreter; - bool defunct_; -}; - class SessionHL final : public memgraph::communication::bolt::Session { public: - SessionHL( + SessionHL(memgraph::query::InterpreterContext *interpreter_context, + const memgraph::communication::v2::ServerEndpoint &endpoint, + memgraph::communication::v2::InputStream *input_stream, + memgraph::communication::v2::OutputStream *output_stream, + memgraph::utils::Synchronized *auth #ifdef MG_ENTERPRISE - memgraph::dbms::SessionContextHandler &sc_handler, -#else - memgraph::dbms::SessionContext sc, + , + memgraph::audit::Log *audit_log #endif - const memgraph::communication::v2::ServerEndpoint &endpoint, - memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::OutputStream *output_stream, - const std::string &default_db = memgraph::dbms::kDefaultDB); + ); ~SessionHL() override; @@ -92,14 +64,8 @@ class SessionHL final : public memgraph::communication::bolt::Session GetServerNameForInit() override; std::string GetDatabaseName() const override; @@ -108,54 +74,23 @@ class SessionHL final : public memgraph::communication::bolt::Session DecodeSummary( const std::map &summary); -#ifdef MG_ENTERPRISE - /** - * @brief Update setup to the new database. - * - * @param db_name name of the target database - * @throws UnknownDatabaseException if handler cannot get it - */ - void UpdateAndDefunct(const std::string &db_name); - - void UpdateAndDefunct(ContextWrapper &&cntxt); - - void Update(const std::string &db_name); - - void Update(ContextWrapper &&cntxt); - - /** - * @brief Authenticate user on passed database. - * - * @param db database to check against - * @throws bolt::ClientError when user is not authorized - */ - void MultiDatabaseAuth(const std::string &db); - /** * @brief Get the user's default database * * @return std::string */ std::string GetDefaultDB(); -#endif - -#ifdef MG_ENTERPRISE - memgraph::dbms::SessionContextHandler &sc_handler_; -#endif - ContextWrapper current_; - std::optional defunct_; memgraph::query::InterpreterContext *interpreter_context_; - memgraph::query::Interpreter *interpreter_; - memgraph::utils::Synchronized *auth_; + memgraph::query::Interpreter interpreter_; std::optional user_; #ifdef MG_ENTERPRISE memgraph::audit::Log *audit_log_; bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata #endif + memgraph::utils::Synchronized *auth_; memgraph::communication::v2::ServerEndpoint endpoint_; - // NOTE: run_id should be const but that complicates code a lot. - std::optional run_id_; + std::optional implicit_db_; }; } // namespace memgraph::glue diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index 0abc8053d..b4ebfcd2a 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -406,6 +406,14 @@ bool AuthQueryHandler::SetMainDatabase(const std::string &db, const std::string throw memgraph::query::QueryRuntimeException(e.what()); } } + +void AuthQueryHandler::DeleteDatabase(std::string_view db) { + try { + auth_->Lock()->DeleteDatabase(std::string(db)); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} #endif bool AuthQueryHandler::DropRole(const std::string &rolename) { diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index b2f278a47..8798c150a 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -17,7 +17,7 @@ #include "auth_global.hpp" #include "glue/auth.hpp" #include "license/license.hpp" -#include "query/interpreter.hpp" +#include "query/auth_query_handler.hpp" #include "utils/string.hpp" namespace memgraph::glue { @@ -45,6 +45,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { std::vector> GetDatabasePrivileges(const std::string &username) override; bool SetMainDatabase(const std::string &db, const std::string &username) override; + + void DeleteDatabase(std::string_view db) override; #endif bool CreateRole(const std::string &rolename) override; diff --git a/src/glue/run_id.cpp b/src/glue/run_id.cpp new file mode 100644 index 000000000..84d00d61a --- /dev/null +++ b/src/glue/run_id.cpp @@ -0,0 +1,15 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/run_id.hpp" +#include "utils/uuid.hpp" + +const std::string memgraph::glue::run_id_ = memgraph::utils::GenerateUUID(); diff --git a/src/glue/run_id.hpp b/src/glue/run_id.hpp new file mode 100644 index 000000000..6616c49bd --- /dev/null +++ b/src/glue/run_id.hpp @@ -0,0 +1,16 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +namespace memgraph::glue { +extern const std::string run_id_; +} // namespace memgraph::glue diff --git a/src/http_handlers/metrics.hpp b/src/http_handlers/metrics.hpp index a852639a3..93c114ce3 100644 --- a/src/http_handlers/metrics.hpp +++ b/src/http_handlers/metrics.hpp @@ -47,10 +47,9 @@ struct MetricsResponse { std::vector> event_histograms{}; }; -template class MetricsService { public: - explicit MetricsService(TSessionContext *session_context) : db_(session_context->interpreter_context->db.get()) {} + explicit MetricsService(storage::Storage *storage) : db_(storage) {} nlohmann::json GetMetricsJSON() { auto response = GetMetrics(); @@ -98,7 +97,7 @@ class MetricsService { return metrics_response; } - auto GetEventCounters() { + inline static std::vector> GetEventCounters() { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector> event_counters{}; event_counters.reserve(memgraph::metrics::CounterEnd()); @@ -111,7 +110,7 @@ class MetricsService { return event_counters; } - auto GetEventGauges() { + inline static std::vector> GetEventGauges() { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector> event_gauges{}; event_gauges.reserve(memgraph::metrics::GaugeEnd()); @@ -124,7 +123,7 @@ class MetricsService { return event_gauges; } - auto GetEventHistograms() { + inline static std::vector> GetEventHistograms() { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector> event_histograms{}; @@ -143,10 +142,11 @@ class MetricsService { } }; -template +// TODO: Should this be inside Database? +// Raw pointer could be dangerous class MetricsRequestHandler final { public: - explicit MetricsRequestHandler(TSessionContext *session_context) : service_(session_context) { + explicit MetricsRequestHandler(storage::Storage *storage) : service_(storage) { spdlog::info("Basic request handler started!"); } @@ -208,6 +208,6 @@ class MetricsRequestHandler final { } private: - MetricsService service_; + MetricsService service_; }; } // namespace memgraph::http diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 6a5010d21..b79a3c565 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -9,22 +9,22 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. -#include "flags/run_time_configurable.hpp" -#ifndef MG_ENTERPRISE -#include "dbms/session_context_handler.hpp" -#endif - #include "audit/log.hpp" #include "communication/websocket/auth.hpp" #include "communication/websocket/server.hpp" +#include "dbms/constants.hpp" #include "flags/all.hpp" +#include "flags/run_time_configurable.hpp" #include "glue/MonitoringServerT.hpp" #include "glue/ServerT.hpp" #include "glue/auth_checker.hpp" #include "glue/auth_handler.hpp" +#include "glue/run_id.hpp" #include "helpers.hpp" #include "license/license_sender.hpp" +#include "query/config.hpp" #include "query/discard_value_stream.hpp" +#include "query/interpreter.hpp" #include "query/procedure/callable_alias_mapper.hpp" #include "query/procedure/module.hpp" #include "query/procedure/py_module.hpp" @@ -36,13 +36,18 @@ #include "utils/terminate_handler.hpp" #include "version.hpp" +#include "dbms/dbms_handler.hpp" +#include "query/auth_query_handler.hpp" +#include "query/interpreter_context.hpp" + constexpr const char *kMgUser = "MEMGRAPH_USER"; constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD"; constexpr const char *kMgPassfile = "MEMGRAPH_PASSFILE"; -void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string cypherl_file_path, - memgraph::audit::Log *audit_log = nullptr) { - memgraph::query::Interpreter interpreter(&ctx); +// TODO: move elsewhere so that we can remove need of interpreter.hpp +void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbms::DatabaseAccess &db_acc, + std::string cypherl_file_path, memgraph::audit::Log *audit_log = nullptr) { + memgraph::query::Interpreter interpreter(&ctx, db_acc); std::ifstream file(cypherl_file_path); if (!file.is_open()) { @@ -198,6 +203,22 @@ int main(int argc, char **argv) { auto data_directory = std::filesystem::path(FLAGS_data_directory); + memgraph::utils::EnsureDirOrDie(data_directory); + // Verify that the user that started the process is the same user that is + // the owner of the storage directory. + memgraph::storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(data_directory); + // Create the lock file and open a handle to it. This will crash the + // database if it can't open the file for writing or if any other process is + // holding the file opened. + memgraph::utils::OutputFile lock_file_handle; + lock_file_handle.Open(data_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING); + MG_ASSERT(lock_file_handle.AcquireLock(), + "Couldn't acquire lock on the storage directory {}" + "!\nAnother Memgraph process is currently running with the same " + "storage directory, please stop it first before starting this " + "process!", + data_directory); + const auto memory_limit = memgraph::flags::GetMemoryLimit(); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) spdlog::info("Memory limit in config is set to {}", memgraph::utils::GetReadableSize(memory_limit)); @@ -320,60 +341,63 @@ int main(int argc, char **argv) { } }; -#ifdef MG_ENTERPRISE - // SessionContext handler (multi-tenancy) - memgraph::dbms::SessionContextHandler sc_handler(audit_log, {db_config, interp_config, auth_glue}, - FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup, - FLAGS_storage_delete_on_drop); - // Just for current support... TODO remove - auto session_context = sc_handler.Get(memgraph::dbms::kDefaultDB); -#else - + // WIP memgraph::utils::Synchronized auth_{data_directory / "auth"}; std::unique_ptr auth_handler; std::unique_ptr auth_checker; auth_glue(&auth_, auth_handler, auth_checker); - auto session_context = memgraph::dbms::Init(db_config, interp_config, &auth_, auth_handler.get(), auth_checker.get()); +#ifdef MG_ENTERPRISE + memgraph::dbms::DbmsHandler new_handler(db_config, &auth_, FLAGS_data_recovery_on_startup, + FLAGS_storage_delete_on_drop); + auto db_acc = new_handler.Get(memgraph::dbms::kDefaultDB); + memgraph::query::InterpreterContext interpreter_context_(interp_config, &new_handler, auth_handler.get(), + auth_checker.get()); +#else + memgraph::utils::Gatekeeper db_gatekeeper{db_config}; + auto db_acc_opt = db_gatekeeper.access(); + MG_ASSERT(db_acc_opt, "Failed to access the main database"); + auto &db_acc = *db_acc_opt; + memgraph::query::InterpreterContext interpreter_context_(interp_config, nullptr, auth_handler.get(), + auth_checker.get()); #endif - - auto *auth = session_context.auth; - auto &interpreter_context = *session_context.interpreter_context; // TODO remove + MG_ASSERT(db_acc, "Failed to access the main database"); memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(), FLAGS_data_directory); memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories(); memgraph::query::procedure::gCallableAliasMapper.LoadMapping(FLAGS_query_callable_mappings_path); + // TODO Make multi-tenant if (!FLAGS_init_file.empty()) { spdlog::info("Running init file..."); #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { - InitFromCypherlFile(interpreter_context, FLAGS_init_file, &audit_log); + InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file, &audit_log); } else { - InitFromCypherlFile(interpreter_context, FLAGS_init_file); + InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file); } #else - InitFromCypherlFile(interpreter_context, FLAGS_init_file); + InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_file); #endif } #ifdef MG_ENTERPRISE - sc_handler.RestoreTriggers(); - sc_handler.RestoreStreams(); + new_handler.RestoreTriggers(&interpreter_context_); + new_handler.RestoreStreams(&interpreter_context_); #else { // Triggers can execute query procedures, so we need to reload the modules first and then // the triggers - auto storage_accessor = interpreter_context.db->Access(); + auto storage_accessor = db_acc->Access(); auto dba = memgraph::query::DbAccessor{storage_accessor.get()}; - interpreter_context.trigger_store.RestoreTriggers( - &interpreter_context.ast_cache, &dba, interpreter_context.config.query, interpreter_context.auth_checker); + db_acc->trigger_store()->RestoreTriggers(&interpreter_context_.ast_cache, &dba, interpreter_context_.config.query, + interpreter_context_.auth_checker); } // As the Stream transformations are using modules, they have to be restored after the query modules are loaded. - interpreter_context.streams.RestoreStreams(); + db_acc->streams()->RestoreStreams(db_acc, &interpreter_context_); #endif ServerContext context; @@ -389,29 +413,31 @@ int main(int argc, char **argv) { auto server_endpoint = memgraph::communication::v2::ServerEndpoint{ boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast(FLAGS_bolt_port)}; #ifdef MG_ENTERPRISE - memgraph::glue::ServerT server(server_endpoint, &sc_handler, &context, FLAGS_bolt_session_inactivity_timeout, - service_name, FLAGS_bolt_num_workers); + Context session_context{&interpreter_context_, &auth_, &audit_log}; #else + Context session_context{&interpreter_context_, &auth_}; +#endif memgraph::glue::ServerT server(server_endpoint, &session_context, &context, FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers); -#endif const auto machine_id = memgraph::utils::GetMachineId(); - const auto run_id = session_context.run_id; // For current compatibility // Setup telemetry static constexpr auto telemetry_server{"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/"}; std::optional telemetry; if (FLAGS_telemetry_enabled) { - telemetry.emplace(telemetry_server, data_directory / "telemetry", run_id, machine_id, std::chrono::minutes(10)); + telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id, + std::chrono::minutes(10)); #ifdef MG_ENTERPRISE - telemetry->AddCollector("storage", [&sc_handler]() -> nlohmann::json { - const auto &info = sc_handler.Info(); + telemetry->AddCollector("storage", [&new_handler]() -> nlohmann::json { + const auto &info = new_handler.Info(); return {{"vertices", info.num_vertex}, {"edges", info.num_edges}, {"databases", info.num_databases}}; }); #else - telemetry->AddCollector("storage", [&interpreter_context]() -> nlohmann::json { - auto info = interpreter_context.db->GetInfo(); + telemetry->AddCollector("storage", [gk = &db_gatekeeper]() -> nlohmann::json { + auto db_acc = gk->access(); + MG_ASSERT(db_acc, "Failed to get access to the default database"); + auto info = db_acc->get()->GetInfo(); return {{"vertices", info.vertex_count}, {"edges", info.edge_count}}; }); #endif @@ -427,67 +453,46 @@ int main(int argc, char **argv) { return memgraph::query::plan::CallProcedure::GetAndResetCounters(); }); } - memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, run_id, machine_id, memory_limit, + memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, memgraph::glue::run_id_, machine_id, + memory_limit, memgraph::license::global_license_checker.GetLicenseInfo()); - memgraph::communication::websocket::SafeAuth websocket_auth{auth}; + memgraph::communication::websocket::SafeAuth websocket_auth{&auth_}; memgraph::communication::websocket::Server websocket_server{ {FLAGS_monitoring_address, static_cast(FLAGS_monitoring_port)}, &context, websocket_auth}; memgraph::flags::AddLoggerSink(websocket_server.GetLoggingSink()); - memgraph::glue::MonitoringServerT metrics_server{ - {FLAGS_metrics_address, static_cast(FLAGS_metrics_port)}, &session_context, &context}; - #ifdef MG_ENTERPRISE - if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { - // Handler for regular termination signals - auto shutdown = [&metrics_server, &websocket_server, &server, &sc_handler] { - // Server needs to be shutdown first and then the database. This prevents - // a race condition when a transaction is accepted during server shutdown. - server.Shutdown(); - // After the server is notified to stop accepting and processing - // connections we tell the execution engine to stop processing all pending - // queries. - sc_handler.Shutdown(); + // TODO: Make multi-tenant + memgraph::glue::MonitoringServerT metrics_server{ + {FLAGS_metrics_address, static_cast(FLAGS_metrics_port)}, db_acc->storage(), &context}; +#endif - websocket_server.Shutdown(); - metrics_server.Shutdown(); - }; - - InitSignalHandlers(shutdown); - } else { - // Handler for regular termination signals - auto shutdown = [&websocket_server, &server, &interpreter_context] { - // Server needs to be shutdown first and then the database. This prevents - // a race condition when a transaction is accepted during server shutdown. - server.Shutdown(); - // After the server is notified to stop accepting and processing - // connections we tell the execution engine to stop processing all pending - // queries. - memgraph::query::Shutdown(&interpreter_context); - - websocket_server.Shutdown(); - }; - - InitSignalHandlers(shutdown); - } -#else // Handler for regular termination signals - auto shutdown = [&websocket_server, &server, &interpreter_context] { + auto shutdown = [ +#ifdef MG_ENTERPRISE + &metrics_server, +#endif + &websocket_server, &server, &interpreter_context_] { // Server needs to be shutdown first and then the database. This prevents // a race condition when a transaction is accepted during server shutdown. server.Shutdown(); // After the server is notified to stop accepting and processing // connections we tell the execution engine to stop processing all pending // queries. - memgraph::query::Shutdown(&interpreter_context); - + interpreter_context_.Shutdown(); websocket_server.Shutdown(); +#ifdef MG_ENTERPRISE + metrics_server.Shutdown(); +#endif }; InitSignalHandlers(shutdown); -#endif + // Release the temporary database access + db_acc.reset(); + + // Startup the main server MG_ASSERT(server.Start(), "Couldn't start the Bolt server!"); websocket_server.Start(); @@ -500,13 +505,16 @@ int main(int argc, char **argv) { if (!FLAGS_init_data_file.empty()) { spdlog::info("Running init data file."); #ifdef MG_ENTERPRISE + auto db_acc = new_handler.Get(memgraph::dbms::kDefaultDB); if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { - InitFromCypherlFile(interpreter_context, FLAGS_init_data_file, &audit_log); + InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_data_file, &audit_log); } else { - InitFromCypherlFile(interpreter_context, FLAGS_init_data_file); + InitFromCypherlFile(interpreter_context_, db_acc, FLAGS_init_data_file); } #else - InitFromCypherlFile(interpreter_context, FLAGS_init_data_file); + auto db_acc_2 = db_gatekeeper.access(); + MG_ASSERT(db_acc_2, "Failed to gain access to the main database"); + InitFromCypherlFile(interpreter_context_, *db_acc_2, FLAGS_init_data_file); #endif } diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt index c1370cc0b..a66516c34 100644 --- a/src/query/CMakeLists.txt +++ b/src/query/CMakeLists.txt @@ -36,7 +36,9 @@ set(mg_query_sources trigger_context.cpp typed_value.cpp graph.cpp - db_accessor.cpp) + db_accessor.cpp + auth_query_handler.cpp + interpreter_context.cpp) add_library(mg-query STATIC ${mg_query_sources}) target_include_directories(mg-query PUBLIC ${CMAKE_SOURCE_DIR}/include) @@ -51,7 +53,8 @@ target_link_libraries(mg-query PUBLIC dl mg-kvstore mg-memory mg::csv - mg-flags) + mg-flags + mg-dbms) if(NOT "${MG_PYTHON_PATH}" STREQUAL "") set(Python3_ROOT_DIR "${MG_PYTHON_PATH}") endif() diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index f64c16b1e..1eb9d02e9 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -11,7 +11,11 @@ #pragma once -#include "query/db_accessor.hpp" +#include +#include +#include +#include + #include "query/frontend/ast/ast.hpp" #include "storage/v2/id_types.hpp" @@ -19,17 +23,19 @@ namespace memgraph::query { class FineGrainedAuthChecker; +class DbAccessor; + class AuthChecker { public: virtual ~AuthChecker() = default; [[nodiscard]] virtual bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges, + const std::vector &privileges, const std::string &db_name) const = 0; #ifdef MG_ENTERPRISE [[nodiscard]] virtual std::unique_ptr GetFineGrainedAuthChecker( - const std::string &username, const memgraph::query::DbAccessor *db_accessor) const = 0; + const std::string &username, const DbAccessor *db_accessor) const = 0; virtual void ClearCache() const = 0; #endif @@ -39,75 +45,73 @@ class FineGrainedAuthChecker { public: virtual ~FineGrainedAuthChecker() = default; - [[nodiscard]] virtual bool Has(const query::VertexAccessor &vertex, memgraph::storage::View view, - query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + [[nodiscard]] virtual bool Has(const VertexAccessor &vertex, memgraph::storage::View view, + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; - [[nodiscard]] virtual bool Has(const query::EdgeAccessor &edge, - query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + [[nodiscard]] virtual bool Has(const EdgeAccessor &edge, + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; [[nodiscard]] virtual bool Has(const std::vector &labels, - query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; [[nodiscard]] virtual bool Has(const memgraph::storage::EdgeTypeId &edge_type, - query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; [[nodiscard]] virtual bool HasGlobalPrivilegeOnVertices( - memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; [[nodiscard]] virtual bool HasGlobalPrivilegeOnEdges( - memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; + AuthQuery::FineGrainedPrivilege fine_grained_privilege) const = 0; }; -class AllowEverythingFineGrainedAuthChecker final : public query::FineGrainedAuthChecker { +class AllowEverythingFineGrainedAuthChecker final : public FineGrainedAuthChecker { public: bool Has(const VertexAccessor & /*vertex*/, const memgraph::storage::View /*view*/, - const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } - bool Has(const memgraph::query::EdgeAccessor & /*edge*/, - const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + bool Has(const EdgeAccessor & /*edge*/, + const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } bool Has(const std::vector & /*labels*/, - const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } bool Has(const memgraph::storage::EdgeTypeId & /*edge_type*/, - const query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } - bool HasGlobalPrivilegeOnVertices( - const memgraph::query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + bool HasGlobalPrivilegeOnVertices(const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } - bool HasGlobalPrivilegeOnEdges( - const memgraph::query::AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { + bool HasGlobalPrivilegeOnEdges(const AuthQuery::FineGrainedPrivilege /*fine_grained_privilege*/) const override { return true; } -}; // namespace memgraph::query +}; #endif -class AllowEverythingAuthChecker final : public query::AuthChecker { +class AllowEverythingAuthChecker final : public AuthChecker { public: bool IsUserAuthorized(const std::optional & /*username*/, - const std::vector & /*privileges*/, + const std::vector & /*privileges*/, const std::string & /*db*/) const override { return true; } #ifdef MG_ENTERPRISE std::unique_ptr GetFineGrainedAuthChecker(const std::string & /*username*/, - const query::DbAccessor * /*dba*/) const override { + const DbAccessor * /*dba*/) const override { return std::make_unique(); } void ClearCache() const override {} #endif -}; // namespace memgraph::query +}; } // namespace memgraph::query diff --git a/src/query/auth_query_handler.cpp b/src/query/auth_query_handler.cpp new file mode 100644 index 000000000..a5337ca8f --- /dev/null +++ b/src/query/auth_query_handler.cpp @@ -0,0 +1,12 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/auth_query_handler.hpp" diff --git a/src/query/auth_query_handler.hpp b/src/query/auth_query_handler.hpp new file mode 100644 index 000000000..908dd3ebc --- /dev/null +++ b/src/query/auth_query_handler.hpp @@ -0,0 +1,126 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include + +#include "query/frontend/ast/ast.hpp" // overkill +#include "query/typed_value.hpp" + +namespace memgraph::query { + +class AuthQueryHandler { + public: + AuthQueryHandler() = default; + virtual ~AuthQueryHandler() = default; + + AuthQueryHandler(const AuthQueryHandler &) = delete; + AuthQueryHandler(AuthQueryHandler &&) = delete; + AuthQueryHandler &operator=(const AuthQueryHandler &) = delete; + AuthQueryHandler &operator=(AuthQueryHandler &&) = delete; + + /// Return false if the user already exists. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool CreateUser(const std::string &username, const std::optional &password) = 0; + + /// Return false if the user does not exist. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool DropUser(const std::string &username) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void SetPassword(const std::string &username, const std::optional &password) = 0; + +#ifdef MG_ENTERPRISE + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0; + + /// Return true if access granted successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0; + + /// Returns database access rights for the user + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector> GetDatabasePrivileges(const std::string &username) = 0; + + /// Return true if main database set successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0; + + /// Delete database from all users + /// @throw QueryRuntimeException if an error ocurred. + virtual void DeleteDatabase(std::string_view db) = 0; +#endif + + /// Return false if the role already exists. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool CreateRole(const std::string &rolename) = 0; + + /// Return false if the role does not exist. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool DropRole(const std::string &rolename) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector GetUsernames() = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector GetRolenames() = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual std::optional GetRolenameForUser(const std::string &username) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector GetUsernamesForRole(const std::string &rolename) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void SetRole(const std::string &username, const std::string &rolename) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void ClearRole(const std::string &username) = 0; + + virtual std::vector> GetPrivileges(const std::string &user_or_role) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void GrantPrivilege( + const std::string &user_or_role, const std::vector &privileges +#ifdef MG_ENTERPRISE + , + const std::vector>> + &label_privileges, + + const std::vector>> + &edge_type_privileges +#endif + ) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void DenyPrivilege(const std::string &user_or_role, + const std::vector &privileges) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void RevokePrivilege( + const std::string &user_or_role, const std::vector &privileges +#ifdef MG_ENTERPRISE + , + const std::vector>> + &label_privileges, + + const std::vector>> + &edge_type_privileges +#endif + ) = 0; +}; + +} // namespace memgraph::query diff --git a/src/query/cypher_query_interpreter.cpp b/src/query/cypher_query_interpreter.cpp index 8759333bc..3deb5ccb5 100644 --- a/src/query/cypher_query_interpreter.cpp +++ b/src/query/cypher_query_interpreter.cpp @@ -10,6 +10,8 @@ // licenses/APL.txt. #include "query/cypher_query_interpreter.hpp" +#include "query/frontend/ast/cypher_main_visitor.hpp" +#include "query/frontend/opencypher/parser.hpp" // NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) DEFINE_bool(query_cost_planner, true, "Use the cost-estimating query planner."); diff --git a/src/query/cypher_query_interpreter.hpp b/src/query/cypher_query_interpreter.hpp index ab8f9c4dd..b920beb8b 100644 --- a/src/query/cypher_query_interpreter.hpp +++ b/src/query/cypher_query_interpreter.hpp @@ -12,8 +12,6 @@ #pragma once #include "query/config.hpp" -#include "query/frontend/ast/cypher_main_visitor.hpp" -#include "query/frontend/opencypher/parser.hpp" #include "query/frontend/semantic/required_privileges.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/stripped.hpp" diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index fdb374398..7ba12c393 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -158,6 +158,11 @@ class ExplicitTransactionUsageException : public QueryRuntimeException { using QueryRuntimeException::QueryRuntimeException; }; +class DatabaseContextRequiredException : public QueryRuntimeException { + public: + using QueryRuntimeException::QueryRuntimeException; +}; + class WriteVertexOperationInEdgeImportModeException : public QueryException { public: WriteVertexOperationInEdgeImportModeException() diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 0b70536bd..b8e980cb7 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -33,12 +33,14 @@ #include "auth/auth.hpp" #include "auth/models.hpp" #include "csv/parsing.hpp" +#include "dbms/database.hpp" +#include "dbms/dbms_handler.hpp" #include "dbms/global.hpp" -#include "dbms/session_context_handler.hpp" #include "flags/run_time_configurable.hpp" #include "glue/communication.hpp" #include "license/license.hpp" #include "memory/memory_control.hpp" +#include "query/config.hpp" #include "query/constants.hpp" #include "query/context.hpp" #include "query/cypher_query_interpreter.hpp" @@ -60,6 +62,7 @@ #include "query/procedure/module.hpp" #include "query/stream.hpp" #include "query/stream/common.hpp" +#include "query/stream/streams.hpp" #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "spdlog/spdlog.h" @@ -90,6 +93,11 @@ #include "utils/tsc.hpp" #include "utils/typeinfo.hpp" #include "utils/variant_helpers.hpp" + +#include "dbms/dbms_handler.hpp" +#include "query/auth_query_handler.hpp" +#include "query/interpreter_context.hpp" + namespace memgraph::metrics { extern Event ReadQuery; extern Event WriteQuery; @@ -112,6 +120,7 @@ constexpr auto kAlwaysFalse = false; namespace { template + void Sort(std::vector &vec) { std::sort(vec.begin(), vec.end()); } @@ -126,9 +135,11 @@ void Sort(std::vector &vec) { bool Same(const TypedValue &lv, const TypedValue &rv) { return TypedValue(lv).ValueString() == TypedValue(rv).ValueString(); } +// NOLINTNEXTLINE (misc-unused-parameters) bool Same(const TypedValue &lv, const std::string &rv) { return std::string(TypedValue(lv).ValueString()) == rv; } // NOLINTNEXTLINE (misc-unused-parameters) bool Same(const std::string &lv, const TypedValue &rv) { return lv == std::string(TypedValue(rv).ValueString()); } +// NOLINTNEXTLINE (misc-unused-parameters) bool Same(const std::string &lv, const std::string &rv) { return lv == rv; } void UpdateTypeCount(const plan::ReadWriteTypeChecker::RWType type) { @@ -357,7 +368,7 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters) { AuthQueryHandler *auth = interpreter_context->auth; #ifdef MG_ENTERPRISE - auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); + auto *db_handler = interpreter_context->db_handler; #endif // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. @@ -566,11 +577,12 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ return callback; case AuthQuery::Action::GRANT_DATABASE_TO_USER: #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, &sc_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler] { // NOLINT try { - memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr); + std::optional db = + std::nullopt; // Hold pointer to database to protect it until query is done if (database != memgraph::auth::kAllDatabases) { - sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } if (!auth->GrantDatabaseToUser(database, username)) { throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); @@ -586,11 +598,12 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ return callback; case AuthQuery::Action::REVOKE_DATABASE_FROM_USER: #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, &sc_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler] { // NOLINT try { - memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr); + std::optional db = + std::nullopt; // Hold pointer to database to protect it until query is done if (database != memgraph::auth::kAllDatabases) { - sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } if (!auth->RevokeDatabaseFromUser(database, username)) { throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); @@ -606,19 +619,22 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ return callback; case AuthQuery::Action::SHOW_DATABASE_PRIVILEGES: callback.header = {"grants", "denies"}; - callback.fn = [auth, username] { // NOLINT #ifdef MG_ENTERPRISE + callback.fn = [auth, username] { // NOLINT return auth->GetDatabasePrivileges(username); -#else - return std::vector>(); -#endif }; +#else + callback.fn = [] { // NOLINT + return std::vector>(); + }; +#endif return callback; case AuthQuery::Action::SET_MAIN_DATABASE: #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, &sc_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler] { // NOLINT try { - const auto sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + const auto db = + db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull if (!auth->SetMainDatabase(database, username)) { throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); } @@ -636,8 +652,8 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ } } // namespace -Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters ¶meters, - InterpreterContext *interpreter_context, std::vector *notifications) { +Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters ¶meters, storage::Storage *storage, + const query::InterpreterConfig &config, std::vector *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. EvaluationContext evaluation_context; @@ -657,8 +673,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING, "Be careful the replication port must be different from the memgraph port!"); } - callback.fn = [handler = ReplQueryHandler{interpreter_context->db.get()}, role = repl_query->role_, - maybe_port]() mutable { + callback.fn = [handler = ReplQueryHandler{storage}, role = repl_query->role_, maybe_port]() mutable { handler.SetReplicationRole(role, maybe_port); return std::vector>(); }; @@ -670,7 +685,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } case ReplicationQuery::Action::SHOW_REPLICATION_ROLE: { callback.header = {"replication role"}; - callback.fn = [handler = ReplQueryHandler{interpreter_context->db.get()}] { + callback.fn = [handler = ReplQueryHandler{storage}] { auto mode = handler.ShowReplicationRole(); switch (mode) { case ReplicationQuery::ReplicationRole::MAIN: { @@ -687,9 +702,9 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & const auto &name = repl_query->replica_name_; const auto &sync_mode = repl_query->sync_mode_; auto socket_address = repl_query->socket_address_->Accept(evaluator); - const auto replica_check_frequency = interpreter_context->config.replication_replica_check_frequency; + const auto replica_check_frequency = config.replication_replica_check_frequency; - callback.fn = [handler = ReplQueryHandler{interpreter_context->db.get()}, name, socket_address, sync_mode, + callback.fn = [handler = ReplQueryHandler{storage}, name, socket_address, sync_mode, replica_check_frequency]() mutable { handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency); return std::vector>(); @@ -701,7 +716,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & case ReplicationQuery::Action::DROP_REPLICA: { const auto &name = repl_query->replica_name_; - callback.fn = [handler = ReplQueryHandler{interpreter_context->db.get()}, name]() mutable { + callback.fn = [handler = ReplQueryHandler{storage}, name]() mutable { handler.DropReplica(name); return std::vector>(); }; @@ -714,8 +729,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & callback.header = { "name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master", "state"}; - callback.fn = [handler = ReplQueryHandler{interpreter_context->db.get()}, - replica_nfields = callback.header.size()] { + callback.fn = [handler = ReplQueryHandler{storage}, replica_nfields = callback.header.size()] { const auto &replicas = handler.ShowReplicas(); auto typed_replicas = std::vector>{}; typed_replicas.reserve(replicas.size()); @@ -787,6 +801,7 @@ std::vector EvaluateTopicNames(ExpressionVisitor &evalu } Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, + memgraph::dbms::DatabaseAccess db_acc, InterpreterContext *interpreter_context, const std::string *username) { static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer"; @@ -815,29 +830,30 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); - return [interpreter_context, stream_name = stream_query->stream_name_, + return [db_acc = std::move(db_acc), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), consumer_group = std::move(consumer_group), common_stream_info = std::move(common_stream_info), bootstrap_servers = std::move(bootstrap), owner = StringPointerToOptional(username), configs = get_config_map(stream_query->configs_, "Configs"), - credentials = get_config_map(stream_query->credentials_, "Credentials")]() mutable { - std::string bootstrap = bootstrap_servers - ? std::move(*bootstrap_servers) - : std::string{interpreter_context->config.default_kafka_bootstrap_servers}; - interpreter_context->streams.Create(stream_name, - {.common_info = std::move(common_stream_info), - .topics = std::move(topic_names), - .consumer_group = std::move(consumer_group), - .bootstrap_servers = std::move(bootstrap), - .configs = std::move(configs), - .credentials = std::move(credentials)}, - std::move(owner)); + credentials = get_config_map(stream_query->credentials_, "Credentials"), + default_server = interpreter_context->config.default_kafka_bootstrap_servers]() mutable { + std::string bootstrap = bootstrap_servers ? std::move(*bootstrap_servers) : std::move(default_server); + + db_acc->streams()->Create(stream_name, + {.common_info = std::move(common_stream_info), + .topics = std::move(topic_names), + .consumer_group = std::move(consumer_group), + .bootstrap_servers = std::move(bootstrap), + .configs = std::move(configs), + .credentials = std::move(credentials)}, + std::move(owner), db_acc, interpreter_context); return std::vector>{}; }; } Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, + memgraph::dbms::DatabaseAccess db, InterpreterContext *interpreter_context, const std::string *username) { auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator); @@ -847,24 +863,24 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator); memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); - return [interpreter_context, stream_name = stream_query->stream_name_, + return [db = std::move(db), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), - owner = StringPointerToOptional(username)]() mutable { - std::string url = - service_url ? std::move(*service_url) : std::string{interpreter_context->config.default_pulsar_service_url}; - interpreter_context->streams.Create( + owner = StringPointerToOptional(username), + default_service = interpreter_context->config.default_pulsar_service_url]() mutable { + std::string url = service_url ? std::move(*service_url) : std::move(default_service); + db->streams()->Create( stream_name, {.common_info = std::move(common_stream_info), .topics = std::move(topic_names), .service_url = std::move(url)}, - std::move(owner)); + std::move(owner), db, interpreter_context); return std::vector>{}; }; } Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, - InterpreterContext *interpreter_context, const std::string *username, - std::vector *notifications) { + memgraph::dbms::DatabaseAccess &db_acc, InterpreterContext *interpreter_context, + const std::string *username, std::vector *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. EvaluationContext evaluation_context; @@ -877,10 +893,10 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete case StreamQuery::Action::CREATE_STREAM: { switch (stream_query->type_) { case StreamQuery::Type::KAFKA: - callback.fn = GetKafkaCreateCallback(stream_query, evaluator, interpreter_context, username); + callback.fn = GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); break; case StreamQuery::Type::PULSAR: - callback.fn = GetPulsarCreateCallback(stream_query, evaluator, interpreter_context, username); + callback.fn = GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); break; } notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_STREAM, @@ -896,13 +912,13 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete throw utils::BasicException("Parameter BATCH_LIMIT cannot hold negative value"); } - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_, batch_limit, timeout]() { - interpreter_context->streams.StartWithLimit(stream_name, static_cast(batch_limit.value()), timeout); + callback.fn = [streams = db_acc->streams(), stream_name = stream_query->stream_name_, batch_limit, timeout]() { + streams->StartWithLimit(stream_name, static_cast(batch_limit.value()), timeout); return std::vector>{}; }; } else { - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { - interpreter_context->streams.Start(stream_name); + callback.fn = [streams = db_acc->streams(), stream_name = stream_query->stream_name_]() { + streams->Start(stream_name); return std::vector>{}; }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::START_STREAM, @@ -911,16 +927,16 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete return callback; } case StreamQuery::Action::START_ALL_STREAMS: { - callback.fn = [interpreter_context]() { - interpreter_context->streams.StartAll(); + callback.fn = [streams = db_acc->streams()]() { + streams->StartAll(); return std::vector>{}; }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::START_ALL_STREAMS, "Started all streams."); return callback; } case StreamQuery::Action::STOP_STREAM: { - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { - interpreter_context->streams.Stop(stream_name); + callback.fn = [streams = db_acc->streams(), stream_name = stream_query->stream_name_]() { + streams->Stop(stream_name); return std::vector>{}; }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::STOP_STREAM, @@ -928,16 +944,16 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete return callback; } case StreamQuery::Action::STOP_ALL_STREAMS: { - callback.fn = [interpreter_context]() { - interpreter_context->streams.StopAll(); + callback.fn = [streams = db_acc->streams()]() { + streams->StopAll(); return std::vector>{}; }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::STOP_ALL_STREAMS, "Stopped all streams."); return callback; } case StreamQuery::Action::DROP_STREAM: { - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() { - interpreter_context->streams.Drop(stream_name); + callback.fn = [streams = db_acc->streams(), stream_name = stream_query->stream_name_]() { + streams->Drop(stream_name); return std::vector>{}; }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::DROP_STREAM, @@ -946,8 +962,8 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete } case StreamQuery::Action::SHOW_STREAMS: { callback.header = {"name", "type", "batch_interval", "batch_size", "transformation_name", "owner", "is running"}; - callback.fn = [interpreter_context]() { - auto streams_status = interpreter_context->streams.GetStreamInfo(); + callback.fn = [streams = db_acc->streams()]() { + auto streams_status = streams->GetStreamInfo(); std::vector> results; results.reserve(streams_status.size()); auto stream_info_as_typed_stream_info_emplace_in = [](auto &typed_status, const auto &stream_info) { @@ -983,10 +999,11 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete throw utils::BasicException("Parameter BATCH_LIMIT cannot hold negative value"); } - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_, + callback.fn = [db_acc, stream_name = stream_query->stream_name_, timeout = GetOptionalValue(stream_query->timeout_, evaluator), batch_limit]() mutable { - return interpreter_context->streams.Check(stream_name, timeout, batch_limit); + // TODO Is this safe + return db_acc->streams()->Check(stream_name, db_acc, timeout, batch_limit); }; notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CHECK_STREAM, fmt::format("Checked stream {}.", stream_query->stream_name_)); @@ -1346,33 +1363,30 @@ storage::replication::ReplicationRole GetReplicaRole(storage::Storage *storage) } // namespace -InterpreterContext::InterpreterContext(const storage::Config storage_config, const InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory, query::AuthQueryHandler *ah, - query::AuthChecker *ac) - : auth(ah), - auth_checker(ac), - trigger_store(data_directory / "triggers"), - config(interpreter_config), - streams{this, data_directory / "streams"} { - if (utils::DirExists(storage_config.disk.main_storage_directory)) { - db = std::make_unique(storage_config); - } else { - db = std::make_unique(storage_config); - } -} - -InterpreterContext::InterpreterContext(std::unique_ptr &&db, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory, query::AuthQueryHandler *ah, - query::AuthChecker *ac) - : db(std::move(db)), - auth(ah), - auth_checker(ac), - trigger_store(data_directory / "triggers"), - config(interpreter_config), - streams{this, data_directory / "streams"} {} +#ifdef MG_ENTERPRISE +InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, memgraph::dbms::DbmsHandler *handler, + query::AuthQueryHandler *ah, query::AuthChecker *ac) + : db_handler(handler), config(interpreter_config), auth(ah), auth_checker(ac) {} +#else +InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, + memgraph::utils::Gatekeeper *db_gatekeeper, + query::AuthQueryHandler *ah, query::AuthChecker *ac) + : db_gatekeeper(db_gatekeeper), config(interpreter_config), auth(ah), auth_checker(ac) {} +#endif Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) { MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); +#ifndef MG_ENTERPRISE + auto db_acc = interpreter_context_->db_gatekeeper->access(); + MG_ASSERT(db_acc, "Database accessor needs to be valid"); + db_acc_ = std::move(db_acc); +#endif +} + +Interpreter::Interpreter(InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db) + : db_acc_(std::move(db)), interpreter_context_(interpreter_context) { + MG_ASSERT(db_acc_, "Database accessor needs to be valid"); + MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); } auto DetermineTxTimeout(std::optional tx_timeout_ms, InterpreterConfig const &config) -> TxTimeout { @@ -1413,12 +1427,15 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper, explicit_transaction_timer_ = timeout ? std::make_shared(timeout.ValueUnsafe().count()) : nullptr; - db_accessor_ = interpreter_context_->db->Access(GetIsolationLevelOverride()); + if (!db_acc_) throw DatabaseContextRequiredException("No current database for transaction defined."); + + auto &db_acc = *db_acc_; + db_accessor_ = db_acc->Access(GetIsolationLevelOverride()); execution_db_accessor_.emplace(db_accessor_.get()); transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); - if (interpreter_context_->trigger_store.HasTriggers()) { - trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); + if (db_acc->trigger_store()->HasTriggers()) { + trigger_context_collector_.emplace(db_acc->trigger_store()->GetEventTypes()); } }; } else if (query_upper == "COMMIT") { @@ -1515,7 +1532,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *notifications, const std::string *username, std::atomic *transaction_status, - std::shared_ptr tx_timer, + std::shared_ptr tx_timer, auto *plan_cache, TriggerContextCollector *trigger_context_collector = nullptr, FrameChangeCollector *frame_change_collector = nullptr) { auto *cypher_query = utils::Downcast(parsed_query.query); @@ -1548,8 +1565,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::mapplan_cache : nullptr, dba); + parsed_query.parameters, plan_cache, dba); TryCaching(plan->ast_storage(), frame_change_collector); summary->insert_or_assign("cost_estimate", plan->cost()); @@ -1585,8 +1601,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *summary, - InterpreterContext *interpreter_context, DbAccessor *dba, - utils::MemoryResource *execution_memory) { + InterpreterContext *interpreter_context, DbAccessor *dba, auto *plan_cache) { const std::string kExplainQueryStart = "explain "; MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kExplainQueryStart), "Expected stripped query to start with '{}'", kExplainQueryStart); @@ -1606,7 +1621,7 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::mapplan_cache : nullptr, dba); + parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? plan_cache : nullptr, dba); std::stringstream printed_plan; plan::PrettyPrint(*dba, &cypher_query_plan->plan(), &printed_plan); @@ -1634,7 +1649,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra std::map *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, std::atomic *transaction_status, - std::shared_ptr tx_timer, + std::shared_ptr tx_timer, auto *plan_cache, FrameChangeCollector *frame_change_collector) { const std::string kProfileQueryStart = "profile "; @@ -1657,7 +1672,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra throw ProfileInMulticommandTxException(); } - if (!interpreter_context->tsc_frequency) { + if (!memgraph::utils::IsAvailableTSC()) { throw QueryException("TSC support is missing for PROFILE"); } @@ -1694,7 +1709,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra auto cypher_query_plan = CypherQueryToPlan( parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query, - parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); + parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? plan_cache : nullptr, dba); TryCaching(cypher_query_plan->ast_storage(), frame_change_collector); auto rw_type_checker = plan::ReadWriteTypeChecker(); auto optional_username = StringPointerToOptional(username); @@ -1956,13 +1971,13 @@ Callback HandleAnalyzeGraphQuery(AnalyzeGraphQuery *analyze_graph_query, DbAcces } PreparedQuery PrepareAnalyzeGraphQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - DbAccessor *execution_db_accessor, InterpreterContext *interpreter_context) { + DbAccessor *execution_db_accessor, auto *plan_cache) { if (in_explicit_transaction) { throw AnalyzeGraphInMulticommandTxException(); } // Creating an index influences computed plan costs. - auto invalidate_plan_cache = [plan_cache = &interpreter_context->plan_cache] { + auto invalidate_plan_cache = [plan_cache] { auto access = plan_cache->access(); for (auto &kv : access) { access.remove(kv.first); @@ -1990,7 +2005,7 @@ PreparedQuery PrepareAnalyzeGraphQuery(ParsedQuery parsed_query, bool in_explici } PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, InterpreterContext *interpreter_context) { + std::vector *notifications, storage::Storage *storage, auto *plan_cache) { if (in_explicit_transaction) { throw IndexInMulticommandTxException(); } @@ -1999,21 +2014,21 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans std::function handler; // Creating an index influences computed plan costs. - auto invalidate_plan_cache = [plan_cache = &interpreter_context->plan_cache] { + auto invalidate_plan_cache = [plan_cache] { auto access = plan_cache->access(); for (auto &kv : access) { access.remove(kv.first); } }; - auto label = interpreter_context->db->NameToLabel(index_query->label_.name); + auto label = storage->NameToLabel(index_query->label_.name); std::vector properties; std::vector properties_string; properties.reserve(index_query->properties_.size()); properties_string.reserve(index_query->properties_.size()); for (const auto &prop : index_query->properties_) { - properties.push_back(interpreter_context->db->NameToProperty(prop.name)); + properties.push_back(storage->NameToProperty(prop.name)); properties_string.push_back(prop.name); } auto properties_stringified = utils::Join(properties_string, ", "); @@ -2029,12 +2044,12 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans index_notification.title = fmt::format("Created index on label {} on properties {}.", index_query->label_.name, properties_stringified); - handler = [interpreter_context, label, properties_stringified = std::move(properties_stringified), + handler = [storage, label, properties_stringified = std::move(properties_stringified), label_name = index_query->label_.name, properties = std::move(properties), invalidate_plan_cache = std::move(invalidate_plan_cache)](Notification &index_notification) { MG_ASSERT(properties.size() <= 1U); - auto maybe_index_error = properties.empty() ? interpreter_context->db->CreateIndex(label) - : interpreter_context->db->CreateIndex(label, properties[0]); + auto maybe_index_error = + properties.empty() ? storage->CreateIndex(label) : storage->CreateIndex(label, properties[0]); utils::OnScopeExit invalidator(invalidate_plan_cache); if (maybe_index_error.HasError()) { @@ -2066,12 +2081,12 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans index_notification.code = NotificationCode::DROP_INDEX; index_notification.title = fmt::format("Dropped index on label {} on properties {}.", index_query->label_.name, utils::Join(properties_string, ", ")); - handler = [interpreter_context, label, properties_stringified = std::move(properties_stringified), + handler = [storage, label, properties_stringified = std::move(properties_stringified), label_name = index_query->label_.name, properties = std::move(properties), invalidate_plan_cache = std::move(invalidate_plan_cache)](Notification &index_notification) { MG_ASSERT(properties.size() <= 1U); - auto maybe_index_error = properties.empty() ? interpreter_context->db->DropIndex(label) - : interpreter_context->db->DropIndex(label, properties[0]); + auto maybe_index_error = + properties.empty() ? storage->DropIndex(label) : storage->DropIndex(label, properties[0]); utils::OnScopeExit invalidator(invalidate_plan_cache); if (maybe_index_error.HasError()) { @@ -2123,42 +2138,41 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters); - return PreparedQuery{ - std::move(callback.header), std::move(parsed_query.required_privileges), - [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr), interpreter_context]( - AnyStream *stream, std::optional n) mutable -> std::optional { - if (!pull_plan) { - // Run the specific query - auto results = handler(); - pull_plan = std::make_shared(std::move(results)); + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr), + interpreter_context]( // NOLINT + AnyStream *stream, std::optional n) mutable -> std::optional { + if (!pull_plan) { + // Run the specific query + auto results = handler(); + pull_plan = std::make_shared(std::move(results)); #ifdef MG_ENTERPRISE - // Invalidate auth cache after every type of AuthQuery - interpreter_context->auth_checker->ClearCache(); + // Invalidate auth cache after every type of AuthQuery + interpreter_context->auth_checker->ClearCache(); #endif - } + } - if (pull_plan->Pull(stream, n)) { - return QueryHandlerResult::COMMIT; - } - return std::nullopt; - }, - RWType::NONE}; + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; } PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, - InterpreterContext *interpreter_context) { + std::vector *notifications, storage::Storage *storage, + const InterpreterConfig &config) { if (in_explicit_transaction) { throw ReplicationModificationInMulticommandTxException(); } - if (interpreter_context->db->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { + if (storage->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { throw ReplicationDisabledOnDiskStorage(); } auto *replication_query = utils::Downcast(parsed_query.query); - auto callback = - HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, notifications); + auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, storage, config, notifications); return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -2177,13 +2191,12 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } -PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - InterpreterContext *interpreter_context) { +PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, storage::Storage *storage) { if (in_explicit_transaction) { throw LockPathModificationInMulticommandTxException(); } - if (interpreter_context->db->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { + if (storage->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { throw LockPathDisabledOnDiskStorage(); } @@ -2192,9 +2205,9 @@ PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_tr return PreparedQuery{ {"STATUS"}, std::move(parsed_query.required_privileges), - [interpreter_context, action = lock_path_query->action_]( - AnyStream *stream, std::optional n) -> std::optional { - auto *mem_storage = static_cast(interpreter_context->db.get()); + [storage, action = lock_path_query->action_](AnyStream *stream, + std::optional n) -> std::optional { + auto *mem_storage = static_cast(storage); std::vector> status; std::string res; @@ -2236,24 +2249,23 @@ PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_tr } PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - InterpreterContext *interpreter_context) { + storage::Storage *storage) { if (in_explicit_transaction) { throw FreeMemoryModificationInMulticommandTxException(); } - if (interpreter_context->db->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { + if (storage->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { throw FreeMemoryDisabledOnDiskStorage(); } - return PreparedQuery{ - {}, - std::move(parsed_query.required_privileges), - [interpreter_context](AnyStream *stream, std::optional n) -> std::optional { - interpreter_context->db->FreeMemory(); - memory::PurgeUnusedMemory(); - return QueryHandlerResult::COMMIT; - }, - RWType::NONE}; + return PreparedQuery{{}, + std::move(parsed_query.required_privileges), + [storage](AnyStream *stream, std::optional n) -> std::optional { + storage->FreeMemory(); + memory::PurgeUnusedMemory(); + return QueryHandlerResult::COMMIT; + }, + RWType::NONE}; } PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, bool in_explicit_transaction) { @@ -2314,34 +2326,36 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) { Callback CreateTrigger(TriggerQuery *trigger_query, const std::map &user_parameters, - InterpreterContext *interpreter_context, DbAccessor *dba, std::optional owner) { - return { - {}, - [trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_), - event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, interpreter_context, dba, - user_parameters, owner = std::move(owner)]() mutable -> std::vector> { - interpreter_context->trigger_store.AddTrigger( - std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type), - before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache, - dba, interpreter_context->config.query, std::move(owner), interpreter_context->auth_checker); - memgraph::metrics::IncrementCounter(memgraph::metrics::TriggersCreated); - return {}; - }}; -} - -Callback DropTrigger(TriggerQuery *trigger_query, InterpreterContext *interpreter_context) { + TriggerStore *trigger_store, InterpreterContext *interpreter_context, DbAccessor *dba, + std::optional owner) { return {{}, [trigger_name = std::move(trigger_query->trigger_name_), - interpreter_context]() -> std::vector> { - interpreter_context->trigger_store.DropTrigger(trigger_name); + trigger_statement = std::move(trigger_query->statement_), event_type = trigger_query->event_type_, + before_commit = trigger_query->before_commit_, trigger_store, interpreter_context, dba, user_parameters, + owner = std::move(owner)]() mutable -> std::vector> { + trigger_store->AddTrigger(std::move(trigger_name), trigger_statement, user_parameters, + ToTriggerEventType(event_type), + before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, + &interpreter_context->ast_cache, dba, interpreter_context->config.query, + std::move(owner), interpreter_context->auth_checker); + memgraph::metrics::IncrementCounter(memgraph::metrics::TriggersCreated); return {}; }}; } -Callback ShowTriggers(InterpreterContext *interpreter_context) { - return {{"trigger name", "statement", "event type", "phase", "owner"}, [interpreter_context] { +Callback DropTrigger(TriggerQuery *trigger_query, TriggerStore *trigger_store) { + return {{}, + [trigger_name = std::move(trigger_query->trigger_name_), + trigger_store]() -> std::vector> { + trigger_store->DropTrigger(trigger_name); + return {}; + }}; +} + +Callback ShowTriggers(TriggerStore *trigger_store) { + return {{"trigger name", "statement", "event type", "phase", "owner"}, [trigger_store] { std::vector> results; - auto trigger_infos = interpreter_context->trigger_store.GetTriggerInfo(); + auto trigger_infos = trigger_store->GetTriggerInfo(); results.reserve(trigger_infos.size()); for (auto &trigger_info : trigger_infos) { std::vector typed_trigger_info; @@ -2362,8 +2376,9 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) { } PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, InterpreterContext *interpreter_context, - DbAccessor *dba, const std::map &user_parameters, + std::vector *notifications, TriggerStore *trigger_store, + InterpreterContext *interpreter_context, DbAccessor *dba, + const std::map &user_parameters, const std::string *username) { if (in_explicit_transaction) { throw TriggerModificationInMulticommandTxException(); @@ -2373,19 +2388,19 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra MG_ASSERT(trigger_query); std::optional trigger_notification; - auto callback = std::invoke([trigger_query, interpreter_context, dba, &user_parameters, + auto callback = std::invoke([trigger_query, trigger_store, interpreter_context, dba, &user_parameters, owner = StringPointerToOptional(username), &trigger_notification]() mutable { switch (trigger_query->action_) { case TriggerQuery::Action::CREATE_TRIGGER: trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::CREATE_TRIGGER, fmt::format("Created trigger {}.", trigger_query->trigger_name_)); - return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba, std::move(owner)); + return CreateTrigger(trigger_query, user_parameters, trigger_store, interpreter_context, dba, std::move(owner)); case TriggerQuery::Action::DROP_TRIGGER: trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::DROP_TRIGGER, fmt::format("Dropped trigger {}.", trigger_query->trigger_name_)); - return DropTrigger(trigger_query, interpreter_context); + return DropTrigger(trigger_query, trigger_store); case TriggerQuery::Action::SHOW_TRIGGERS: - return ShowTriggers(interpreter_context); + return ShowTriggers(trigger_store); } }); @@ -2411,8 +2426,8 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra } PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, InterpreterContext *interpreter_context, - const std::string *username) { + std::vector *notifications, memgraph::dbms::DatabaseAccess &db_acc, + InterpreterContext *interpreter_context, const std::string *username) { if (in_explicit_transaction) { throw StreamQueryInMulticommandTxException(); } @@ -2420,7 +2435,7 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran auto *stream_query = utils::Downcast(parsed_query.query); MG_ASSERT(stream_query); auto callback = - HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, username, notifications); + HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, username, notifications); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -2481,7 +2496,7 @@ bool SwitchingFromDiskToInMemory(storage::StorageMode current_mode, storage::Sto } PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, - InterpreterContext *interpreter_context, Interpreter *interpreter) { + storage::Storage *storage, Interpreter *interpreter) { if (in_explicit_transaction) { throw IsolationLevelModificationInMulticommandTxException(); } @@ -2491,12 +2506,11 @@ PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in const auto isolation_level = ToStorageIsolationLevel(isolation_level_query->isolation_level_); - auto callback = [isolation_level_query, isolation_level, interpreter_context, - interpreter]() -> std::function { + auto callback = [isolation_level_query, isolation_level, storage, interpreter]() -> std::function { switch (isolation_level_query->isolation_level_scope_) { case IsolationLevelQuery::IsolationLevelScope::GLOBAL: - return [interpreter_context, isolation_level] { - if (auto maybe_error = interpreter_context->db->SetIsolationLevel(isolation_level); maybe_error.HasError()) { + return [storage, isolation_level] { + if (auto maybe_error = storage->SetIsolationLevel(isolation_level); maybe_error.HasError()) { switch (maybe_error.GetError()) { case storage::Storage::SetIsolationLevelError::DisabledForAnalyticalMode: throw IsolationLevelModificationInAnalyticsException(); @@ -2522,9 +2536,9 @@ PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in } Callback SwitchMemoryDevice(storage::StorageMode current_mode, storage::StorageMode requested_mode, - InterpreterContext *interpreter_context) { + memgraph::dbms::DatabaseAccess &db) { Callback callback; - callback.fn = [current_mode, requested_mode, interpreter_context]() mutable { + callback.fn = [current_mode, requested_mode, &db]() mutable { if (current_mode == requested_mode) { return std::vector>(); } @@ -2535,29 +2549,38 @@ Callback SwitchMemoryDevice(storage::StorageMode current_mode, storage::StorageM "automatically start in the default in-memory transactional storage mode."); } if (SwitchingFromInMemoryToDisk(current_mode, requested_mode)) { - std::unique_lock main_guard{interpreter_context->db->main_lock_}; + if (!db.try_exclusively([](auto &in) { + if (!in.streams()->GetStreamInfo().empty()) { + throw utils::BasicException( + "You cannot switch from an in-memory storage mode to the on-disk storage mode when there are " + "associated streams. Drop all streams and retry."); + } - if (auto vertex_cnt_approx = interpreter_context->db->GetInfo().vertex_count; vertex_cnt_approx > 0) { - throw utils::BasicException( - "You cannot switch from an in-memory storage mode to the on-disk storage mode when the database " - "contains data. Delete all entries from the database, run FREE MEMORY and then repeat this " - "query. "); - } + if (!in.trigger_store()->GetTriggerInfo().empty()) { + throw utils::BasicException( + "You cannot switch from an in-memory storage mode to the on-disk storage mode when there are " + "associated triggers. Drop all triggers and retry."); + } - main_guard.unlock(); - if (interpreter_context->interpreters->size() > 1) { + std::unique_lock main_guard{in.storage()->main_lock_}; // do we need this? + if (auto vertex_cnt_approx = in.storage()->GetInfo().vertex_count; vertex_cnt_approx > 0) { + throw utils::BasicException( + "You cannot switch from an in-memory storage mode to the on-disk storage mode when the database " + "contains data. Delete all entries from the database, run FREE MEMORY and then repeat this " + "query. "); + } + main_guard.unlock(); + in.SwitchToOnDisk(); + })) { // Try exclusively failed throw utils::BasicException( "You cannot switch from an in-memory storage mode to the on-disk storage mode when there are " "multiple sessions active. Close all other sessions and try again. As Memgraph Lab uses " - "multiple sessions to run queries in parallel, " + "multiple sessions to run queries in parallel, " "it is currently impossible to switch to the on-disk storage mode within Lab. " "Close it, connect to the instance with mgconsole " "and change the storage mode to on-disk from there. Then, you can reconnect with the Lab " "and continue to use the instance as usual."); } - - auto db_config = interpreter_context->db->config_; - interpreter_context->db = std::make_unique(db_config); } return std::vector>(); }; @@ -2574,7 +2597,7 @@ bool ActiveTransactionsExist(InterpreterContext *interpreter_context) { } PreparedQuery PrepareStorageModeQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, - InterpreterContext *interpreter_context) { + memgraph::dbms::DatabaseAccess &db, InterpreterContext *interpreter_context) { if (in_explicit_transaction) { throw StorageModeModificationInMulticommandTxException(); } @@ -2582,13 +2605,13 @@ PreparedQuery PrepareStorageModeQuery(ParsedQuery parsed_query, const bool in_ex auto *storage_mode_query = utils::Downcast(parsed_query.query); MG_ASSERT(storage_mode_query); const auto requested_mode = ToStorageMode(storage_mode_query->storage_mode_); - auto current_mode = interpreter_context->db->GetStorageMode(); + auto current_mode = db->GetStorageMode(); std::function callback; if (current_mode == storage::StorageMode::ON_DISK_TRANSACTIONAL || requested_mode == storage::StorageMode::ON_DISK_TRANSACTIONAL) { - callback = SwitchMemoryDevice(current_mode, requested_mode, interpreter_context).fn; + callback = SwitchMemoryDevice(current_mode, requested_mode, db).fn; } else { if (ActiveTransactionsExist(interpreter_context)) { spdlog::info( @@ -2596,8 +2619,9 @@ PreparedQuery PrepareStorageModeQuery(ParsedQuery parsed_query, const bool in_ex "transactions using 'SHOW TRANSACTIONS' query and ensure no other transactions are active."); } - callback = [requested_mode, interpreter_context]() -> std::function { - return [interpreter_context, requested_mode] { interpreter_context->db->SetStorageMode(requested_mode); }; + callback = [requested_mode, storage = db->storage()]() -> std::function { + // SetStorageMode will probably be handled at the Database level + return [storage, requested_mode] { storage->SetStorageMode(requested_mode); }; }(); } @@ -2612,12 +2636,12 @@ PreparedQuery PrepareStorageModeQuery(ParsedQuery parsed_query, const bool in_ex } PreparedQuery PrepareEdgeImportModeQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, - InterpreterContext *interpreter_context) { + storage::Storage *db) { if (in_explicit_transaction) { throw EdgeImportModeModificationInMulticommandTxException(); } - if (interpreter_context->db->GetStorageMode() != storage::StorageMode::ON_DISK_TRANSACTIONAL) { + if (db->GetStorageMode() != storage::StorageMode::ON_DISK_TRANSACTIONAL) { throw EdgeImportModeQueryDisabledOnDiskStorage(); } @@ -2625,9 +2649,9 @@ PreparedQuery PrepareEdgeImportModeQuery(ParsedQuery parsed_query, const bool in MG_ASSERT(edge_import_mode_query); const auto requested_status = ToEdgeImportMode(edge_import_mode_query->status_); - auto callback = [requested_status, interpreter_context]() -> std::function { - return [interpreter_context, requested_status] { - auto *disk_storage = static_cast(interpreter_context->db.get()); + auto callback = [requested_status, db]() -> std::function { + return [db, requested_status] { + auto *disk_storage = static_cast(db); disk_storage->SetEdgeImportMode(requested_status); }; }(); @@ -2643,20 +2667,20 @@ PreparedQuery PrepareEdgeImportModeQuery(ParsedQuery parsed_query, const bool in } PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - InterpreterContext *interpreter_context) { + storage::Storage *storage) { if (in_explicit_transaction) { throw CreateSnapshotInMulticommandTxException(); } - if (interpreter_context->db->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { + if (storage->GetStorageMode() == storage::StorageMode::ON_DISK_TRANSACTIONAL) { throw CreateSnapshotDisabledOnDiskStorage(); } return PreparedQuery{ {}, std::move(parsed_query.required_privileges), - [interpreter_context](AnyStream *stream, std::optional n) -> std::optional { - auto *mem_storage = static_cast(interpreter_context->db.get()); + [storage](AnyStream * /*stream*/, std::optional /*n*/) -> std::optional { + auto *mem_storage = static_cast(storage); if (auto maybe_error = mem_storage->CreateSnapshot({}); maybe_error.HasError()) { switch (maybe_error.GetError()) { case storage::InMemoryStorage::CreateSnapshotError::DisabledForReplica: @@ -2704,7 +2728,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra std::vector> TransactionQueueQueryHandler::ShowTransactions( const std::unordered_set &interpreters, const std::optional &username, - bool hasTransactionManagementPrivilege) { + bool hasTransactionManagementPrivilege, std::optional &filter_db_acc) { std::vector> results; results.reserve(interpreters.size()); for (Interpreter *interpreter : interpreters) { @@ -2717,6 +2741,7 @@ std::vector> TransactionQueueQueryHandler::ShowTransacti utils::OnScopeExit clean_status([interpreter]() { interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); }); + if (interpreter->db_acc_ != filter_db_acc) continue; std::optional transaction_id = interpreter->GetTransactionId(); if (transaction_id.has_value() && (interpreter->username_ == username || hasTransactionManagementPrivilege)) { const auto &typed_queries = interpreter->GetQueries(); @@ -2735,57 +2760,10 @@ std::vector> TransactionQueueQueryHandler::ShowTransacti return results; } -std::vector> TransactionQueueQueryHandler::KillTransactions( - InterpreterContext *interpreter_context, const std::vector &maybe_kill_transaction_ids, - const std::optional &username, bool hasTransactionManagementPrivilege) { - std::vector> results; - for (const std::string &transaction_id : maybe_kill_transaction_ids) { - bool killed = false; - bool transaction_found = false; - // Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed - // TERMINATE and SHOW TRANSACTIONS are mutually exclusive - interpreter_context->interpreters.WithLock([&transaction_id, &killed, &transaction_found, username, - hasTransactionManagementPrivilege](const auto &interpreters) { - for (Interpreter *interpreter : interpreters) { - TransactionStatus alive_status = TransactionStatus::ACTIVE; - // if it is just checking kill, commit and abort should wait for the end of the check - // The only way to start checking if the transaction will get killed is if the transaction_status is - // active - if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) { - continue; - } - utils::OnScopeExit clean_status([interpreter, &killed]() { - if (killed) { - interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release); - } else { - interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); - } - }); - - std::optional intr_trans = interpreter->GetTransactionId(); - if (intr_trans.has_value() && std::to_string(intr_trans.value()) == transaction_id) { - transaction_found = true; - if (interpreter->username_ == username || hasTransactionManagementPrivilege) { - killed = true; - spdlog::warn("Transaction {} successfully killed", transaction_id); - } else { - spdlog::warn("Not enough rights to kill the transaction"); - } - break; - } - } - }); - if (!transaction_found) { - spdlog::warn("Transaction {} not found", transaction_id); - } - results.push_back({TypedValue(transaction_id), TypedValue(killed)}); - } - return results; -} - Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, const std::optional &username, const Parameters ¶meters, - InterpreterContext *interpreter_context) { + InterpreterContext *interpreter_context, + memgraph::query::Interpreter &interpreter) { EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parameters; @@ -2794,17 +2772,23 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized( username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, ""); + if (!interpreter.db_acc_) { + // TODO: remove in future when we have cypher execution which can happen without any database transaction + // ie. when we have a transaction ID not tied to storage transaction + throw DatabaseContextRequiredException("No current database for transaction defined."); + } + Callback callback; switch (transaction_query->action_) { case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: { callback.header = {"username", "transaction_id", "query", "metadata"}; callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, username, - hasTransactionManagementPrivilege]() mutable { + hasTransactionManagementPrivilege, db_acc = &interpreter.db_acc_]() mutable { std::vector> results; // Multiple simultaneous SHOW TRANSACTIONS aren't allowed interpreter_context->interpreters.WithLock( - [&results, handler, username, hasTransactionManagementPrivilege](const auto &interpreters) { - results = handler.ShowTransactions(interpreters, username, hasTransactionManagementPrivilege); + [&results, handler, username, hasTransactionManagementPrivilege, db_acc](const auto &interpreters) { + results = handler.ShowTransactions(interpreters, username, hasTransactionManagementPrivilege, *db_acc); }); return results; }; @@ -2817,10 +2801,11 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, return std::string(expression->Accept(evaluator).ValueString()); }); callback.header = {"transaction_id", "killed"}; - callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, maybe_kill_transaction_ids, - username, hasTransactionManagementPrivilege]() mutable { - return handler.KillTransactions(interpreter_context, maybe_kill_transaction_ids, username, - hasTransactionManagementPrivilege); + callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, + maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids), username, + hasTransactionManagementPrivilege, interpreter = &interpreter]() mutable { + return interpreter_context->KillTransactions(std::move(maybe_kill_transaction_ids), username, + hasTransactionManagementPrivilege, *interpreter); }; break; } @@ -2831,15 +2816,15 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional &username, bool in_explicit_transaction, InterpreterContext *interpreter_context, - DbAccessor *dba) { + memgraph::query::Interpreter &interpreter) { if (in_explicit_transaction) { throw TransactionQueueInMulticommandTxException(); } auto *transaction_queue_query = utils::Downcast(parsed_query.query); MG_ASSERT(transaction_queue_query); - auto callback = - HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context); + auto callback = HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, + interpreter_context, interpreter); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -2875,8 +2860,8 @@ PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, bool in_explicit_tra } PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::map * /*summary*/, InterpreterContext *interpreter_context, - storage::Storage *db, utils::MemoryResource * /*execution_memory*/, + std::map * /*summary*/, storage::Storage *storage, + utils::MemoryResource * /*execution_memory*/, std::optional interpreter_isolation_level, std::optional next_transaction_isolation_level) { if (in_explicit_transaction) { @@ -2891,10 +2876,10 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa case InfoQuery::InfoType::STORAGE: header = {"storage info", "value"}; - handler = [db, interpreter_isolation_level, next_transaction_isolation_level] { - auto info = db->GetInfo(); + handler = [storage, interpreter_isolation_level, next_transaction_isolation_level] { + auto info = storage->GetInfo(); std::vector> results{ - {TypedValue("name"), TypedValue(db->id())}, + {TypedValue("name"), TypedValue(storage->id())}, {TypedValue("vertex_count"), TypedValue(static_cast(info.vertex_count))}, {TypedValue("edge_count"), TypedValue(static_cast(info.edge_count))}, {TypedValue("average_degree"), TypedValue(info.average_degree)}, @@ -2902,29 +2887,28 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa {TypedValue("disk_usage"), TypedValue(static_cast(info.disk_usage))}, {TypedValue("memory_allocated"), TypedValue(static_cast(utils::total_memory_tracker.Amount()))}, {TypedValue("allocation_limit"), TypedValue(static_cast(utils::total_memory_tracker.HardLimit()))}, - {TypedValue("global_isolation_level"), TypedValue(IsolationLevelToString(db->GetIsolationLevel()))}, + {TypedValue("global_isolation_level"), TypedValue(IsolationLevelToString(storage->GetIsolationLevel()))}, {TypedValue("session_isolation_level"), TypedValue(IsolationLevelToString(interpreter_isolation_level))}, {TypedValue("next_session_isolation_level"), TypedValue(IsolationLevelToString(next_transaction_isolation_level))}, - {TypedValue("storage_mode"), TypedValue(StorageModeToString(db->GetStorageMode()))}}; + {TypedValue("storage_mode"), TypedValue(StorageModeToString(storage->GetStorageMode()))}}; return std::pair{results, QueryHandlerResult::COMMIT}; }; break; case InfoQuery::InfoType::INDEX: header = {"index type", "label", "property"}; - handler = [interpreter_context] { + handler = [storage] { const std::string_view label_index_mark{"label"}; const std::string_view label_property_index_mark{"label+property"}; - auto *db = interpreter_context->db.get(); - auto info = db->ListAllIndices(); + auto info = storage->ListAllIndices(); std::vector> results; results.reserve(info.label.size() + info.label_property.size()); for (const auto &item : info.label) { - results.push_back({TypedValue(label_index_mark), TypedValue(db->LabelToName(item)), TypedValue()}); + results.push_back({TypedValue(label_index_mark), TypedValue(storage->LabelToName(item)), TypedValue()}); } for (const auto &item : info.label_property) { - results.push_back({TypedValue(label_property_index_mark), TypedValue(db->LabelToName(item.first)), - TypedValue(db->PropertyToName(item.second))}); + results.push_back({TypedValue(label_property_index_mark), TypedValue(storage->LabelToName(item.first)), + TypedValue(storage->PropertyToName(item.second))}); } std::sort(results.begin(), results.end(), [&label_index_mark](const auto &record_1, const auto &record_2) { @@ -2949,23 +2933,22 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa break; case InfoQuery::InfoType::CONSTRAINT: header = {"constraint type", "label", "properties"}; - handler = [interpreter_context] { - auto *db = interpreter_context->db.get(); - auto info = db->ListAllConstraints(); + handler = [storage] { + auto info = storage->ListAllConstraints(); std::vector> results; results.reserve(info.existence.size() + info.unique.size()); for (const auto &item : info.existence) { - results.push_back({TypedValue("exists"), TypedValue(db->LabelToName(item.first)), - TypedValue(db->PropertyToName(item.second))}); + results.push_back({TypedValue("exists"), TypedValue(storage->LabelToName(item.first)), + TypedValue(storage->PropertyToName(item.second))}); } for (const auto &item : info.unique) { std::vector properties; properties.reserve(item.second.size()); for (const auto &property : item.second) { - properties.emplace_back(db->PropertyToName(property)); + properties.emplace_back(storage->PropertyToName(property)); } results.push_back( - {TypedValue("unique"), TypedValue(db->LabelToName(item.first)), TypedValue(std::move(properties))}); + {TypedValue("unique"), TypedValue(storage->LabelToName(item.first)), TypedValue(std::move(properties))}); } return std::pair{results, QueryHandlerResult::NOTHING}; }; @@ -3000,8 +2983,7 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa } PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, - InterpreterContext *interpreter_context) { + std::vector *notifications, storage::Storage *storage) { if (in_explicit_transaction) { throw ConstraintInMulticommandTxException(); } @@ -3009,13 +2991,13 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ auto *constraint_query = utils::Downcast(parsed_query.query); std::function handler; - auto label = interpreter_context->db->NameToLabel(constraint_query->constraint_.label.name); + auto label = storage->NameToLabel(constraint_query->constraint_.label.name); std::vector properties; std::vector properties_string; properties.reserve(constraint_query->constraint_.properties.size()); properties_string.reserve(constraint_query->constraint_.properties.size()); for (const auto &prop : constraint_query->constraint_.properties) { - properties.push_back(interpreter_context->db->NameToProperty(prop.name)); + properties.push_back(storage->NameToProperty(prop.name)); properties_string.push_back(prop.name); } auto properties_stringified = utils::Join(properties_string, ", "); @@ -3034,21 +3016,20 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ } constraint_notification.title = fmt::format("Created EXISTS constraint on label {} on properties {}.", constraint_query->constraint_.label.name, properties_stringified); - handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + handler = [storage, label, label_name = constraint_query->constraint_.label.name, properties_stringified = std::move(properties_stringified), properties = std::move(properties)](Notification &constraint_notification) { - auto maybe_constraint_error = interpreter_context->db->CreateExistenceConstraint(label, properties[0], {}); + auto maybe_constraint_error = storage->CreateExistenceConstraint(label, properties[0], {}); if (maybe_constraint_error.HasError()) { const auto &error = maybe_constraint_error.GetError(); std::visit( - [&interpreter_context, &label_name, &properties_stringified, - &constraint_notification](T &&arg) { + [storage, &label_name, &properties_stringified, &constraint_notification](T &&arg) { using ErrorType = std::remove_cvref_t; if constexpr (std::is_same_v) { auto &violation = arg; MG_ASSERT(violation.properties.size() == 1U); - auto property_name = interpreter_context->db->PropertyToName(*violation.properties.begin()); + auto property_name = storage->PropertyToName(*violation.properties.begin()); throw QueryRuntimeException( "Unable to create existence constraint :{}({}), because an " "existing node violates it.", @@ -3085,24 +3066,22 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ constraint_notification.title = fmt::format("Created UNIQUE constraint on label {} on properties {}.", constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); - handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + handler = [storage, label, label_name = constraint_query->constraint_.label.name, properties_stringified = std::move(properties_stringified), property_set = std::move(property_set)](Notification &constraint_notification) { - auto maybe_constraint_error = interpreter_context->db->CreateUniqueConstraint(label, property_set, {}); + auto maybe_constraint_error = storage->CreateUniqueConstraint(label, property_set, {}); if (maybe_constraint_error.HasError()) { const auto &error = maybe_constraint_error.GetError(); std::visit( - [&interpreter_context, &label_name, &properties_stringified, - &constraint_notification](T &&arg) { + [storage, &label_name, &properties_stringified, &constraint_notification](T &&arg) { using ErrorType = std::remove_cvref_t; if constexpr (std::is_same_v) { auto &violation = arg; - auto violation_label_name = interpreter_context->db->LabelToName(violation.label); + auto violation_label_name = storage->LabelToName(violation.label); std::stringstream property_names_stream; - utils::PrintIterable(property_names_stream, violation.properties, ", ", - [&interpreter_context](auto &stream, const auto &prop) { - stream << interpreter_context->db->PropertyToName(prop); - }); + utils::PrintIterable( + property_names_stream, violation.properties, ", ", + [storage](auto &stream, const auto &prop) { stream << storage->PropertyToName(prop); }); throw QueryRuntimeException( "Unable to create unique constraint :{}({}), because an " "existing node violates it.", @@ -3161,10 +3140,10 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ constraint_notification.title = fmt::format("Dropped EXISTS constraint on label {} on properties {}.", constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); - handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + handler = [storage, label, label_name = constraint_query->constraint_.label.name, properties_stringified = std::move(properties_stringified), properties = std::move(properties)](Notification &constraint_notification) { - auto maybe_constraint_error = interpreter_context->db->DropExistenceConstraint(label, properties[0], {}); + auto maybe_constraint_error = storage->DropExistenceConstraint(label, properties[0], {}); if (maybe_constraint_error.HasError()) { const auto &error = maybe_constraint_error.GetError(); std::visit( @@ -3202,10 +3181,10 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ constraint_notification.title = fmt::format("Dropped UNIQUE constraint on label {} on properties {}.", constraint_query->constraint_.label.name, utils::Join(properties_string, ", ")); - handler = [interpreter_context, label, label_name = constraint_query->constraint_.label.name, + handler = [storage, label, label_name = constraint_query->constraint_.label.name, properties_stringified = std::move(properties_stringified), property_set = std::move(property_set)](Notification &constraint_notification) { - auto maybe_constraint_error = interpreter_context->db->DropUniqueConstraint(label, property_set, {}); + auto maybe_constraint_error = storage->DropUniqueConstraint(label, property_set, {}); if (maybe_constraint_error.HasError()) { const auto &error = maybe_constraint_error.GetError(); std::visit( @@ -3264,13 +3243,17 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ } PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explicit_transaction, bool in_explicit_db, - InterpreterContext *interpreter_context, const std::string &session_uuid) { + InterpreterContext *interpreter_context, + memgraph::query::Interpreter &interpreter, + std::optional> on_change_cb) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); } // TODO: Remove once replicas support multi-tenant replication - if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) { + if (!interpreter.db_acc_) + throw DatabaseContextRequiredException("Multi database queries require a defined database."); + if (GetReplicaRole(interpreter.db_acc_->get()->storage()) == storage::replication::ReplicationRole::REPLICA) { throw QueryException("Query forbidden on the replica!"); } if (in_explicit_transaction) { @@ -3278,19 +3261,19 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic } auto *query = utils::Downcast(parsed_query.query); - auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); + auto *db_handler = interpreter_context->db_handler; switch (query->action_) { case MultiDatabaseQuery::Action::CREATE: return PreparedQuery{ {"STATUS"}, std::move(parsed_query.required_privileges), - [db_name = query->db_name_, session_uuid, &sc_handler]( - AnyStream *stream, std::optional n) -> std::optional { + [db_name = query->db_name_, db_handler](AnyStream *stream, + std::optional n) -> std::optional { std::vector> status; std::string res; - const auto success = sc_handler.New(db_name); + const auto success = db_handler->New(db_name); if (success.HasError()) { switch (success.GetError()) { case dbms::NewError::EXISTS: @@ -3326,30 +3309,24 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic } return PreparedQuery{{"STATUS"}, std::move(parsed_query.required_privileges), - [db_name = query->db_name_, session_uuid, &sc_handler]( + [db_name = query->db_name_, db_handler, &interpreter, on_change_cb]( AnyStream *stream, std::optional n) -> std::optional { std::vector> status; std::string res; - memgraph::dbms::SetForResult set = memgraph::dbms::SetForResult::SUCCESS; - try { - set = sc_handler.SetFor(session_uuid, db_name); + if (interpreter.db_acc_ && db_name == interpreter.db_acc_->get()->id()) { + res = "Already using " + db_name; + } else { + auto tmp = db_handler->Get(db_name); + if (on_change_cb) (*on_change_cb)(db_name); // Will trow if cb fails + interpreter.SetCurrentDB(std::move(tmp)); + res = "Using " + db_name; + } } catch (const utils::BasicException &e) { throw QueryRuntimeException(e.what()); } - switch (set) { - case dbms::SetForResult::SUCCESS: - res = "Using " + db_name; - break; - case dbms::SetForResult::ALREADY_SET: - res = "Already using " + db_name; - break; - case dbms::SetForResult::FAIL: - throw QueryRuntimeException("Failed to start using {}", db_name); - } - status.emplace_back(std::vector{TypedValue(res)}); auto pull_plan = std::make_shared(std::move(status)); if (pull_plan->Pull(stream, n)) { @@ -3364,33 +3341,36 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic return PreparedQuery{ {"STATUS"}, std::move(parsed_query.required_privileges), - [db_name = query->db_name_, session_uuid, &sc_handler]( + [db_name = query->db_name_, db_handler, auth = interpreter_context->auth]( AnyStream *stream, std::optional n) -> std::optional { std::vector> status; memgraph::dbms::DeleteResult success{}; try { - success = sc_handler.Delete(db_name); + // Remove database + success = db_handler->Delete(db_name); + if (!success.HasError()) { + // Remove from auth + auth->DeleteDatabase(db_name); + } else { + switch (success.GetError()) { + case dbms::DeleteError::DEFAULT_DB: + throw QueryRuntimeException("Cannot delete the default database."); + case dbms::DeleteError::NON_EXISTENT: + throw QueryRuntimeException("{} does not exist.", db_name); + case dbms::DeleteError::USING: + throw QueryRuntimeException("Cannot delete {}, it is currently being used.", db_name); + case dbms::DeleteError::FAIL: + throw QueryRuntimeException("Failed while deleting {}", db_name); + case dbms::DeleteError::DISK_FAIL: + throw QueryRuntimeException("Failed to clean storage of {}", db_name); + } + } } catch (const utils::BasicException &e) { throw QueryRuntimeException(e.what()); } - if (success.HasError()) { - switch (success.GetError()) { - case dbms::DeleteError::DEFAULT_DB: - throw QueryRuntimeException("Cannot delete the default database."); - case dbms::DeleteError::NON_EXISTENT: - throw QueryRuntimeException("{} does not exist.", db_name); - case dbms::DeleteError::USING: - throw QueryRuntimeException("Cannot delete {}, it is currently being used.", db_name); - case dbms::DeleteError::FAIL: - throw QueryRuntimeException("Failed while deleting {}", db_name); - case dbms::DeleteError::DISK_FAIL: - throw QueryRuntimeException("Failed to clean storage of {}", db_name); - } - } - status.emplace_back(std::vector{TypedValue("Successfully deleted " + db_name)}); auto pull_plan = std::make_shared(std::move(status)); if (pull_plan->Pull(stream, n)) { @@ -3406,25 +3386,27 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic #endif } -PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context, - const std::string &session_uuid, const std::optional &username) { +PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, storage::Storage *storage, + InterpreterContext *interpreter_context, + const std::optional &username) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); } // TODO: Remove once replicas support multi-tenant replication - if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) { + if (GetReplicaRole(storage) == storage::replication::ReplicationRole::REPLICA) { throw QueryException("SHOW DATABASES forbidden on the replica!"); } - auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); + // TODO pick directly from ic + auto *db_handler = interpreter_context->db_handler; AuthQueryHandler *auth = interpreter_context->auth; Callback callback; callback.header = {"Name", "Current"}; - callback.fn = [auth, session_uuid, &sc_handler, username]() mutable -> std::vector> { + callback.fn = [auth, storage, db_handler, username]() mutable -> std::vector> { std::vector> status; - const auto in_use = sc_handler.Current(session_uuid); + const auto &in_use = storage->id(); bool found_current = false; auto gen_status = [&](T all, K denied) { @@ -3454,7 +3436,7 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon if (!username) { // No user, return all - gen_status(sc_handler.All(), std::vector{}); + gen_status(db_handler->All(), std::vector{}); } else { // User has a subset of accessible dbs; this is synched with the SessionContextHandler const auto &db_priv = auth->GetDatabasePrivileges(*username); @@ -3462,7 +3444,7 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon const auto &denied = db_priv[0][1].ValueList(); if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) { // All databases are allowed - gen_status(sc_handler.All(), denied); + gen_status(db_handler->All(), denied); } else { gen_status(allowed.ValueList(), denied); } @@ -3518,11 +3500,30 @@ void Interpreter::RollbackTransaction() { transaction_queries_->clear(); } +#if MG_ENTERPRISE +// Before Prepare or during Prepare, but single-threaded. +// TODO: Is there any cleanup? +void Interpreter::SetCurrentDB(std::string_view db_name) { + // Can throw + // do we lock here? + db_acc_ = interpreter_context_->db_handler->Get(db_name); +} +void Interpreter::SetCurrentDB(memgraph::dbms::DatabaseAccess new_db) { + // do we lock here? + db_acc_ = std::move(new_db); +} +#endif + Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::map ¶ms, const std::string *username, QueryExtras const &extras, const std::string &session_uuid) { std::shared_ptr current_timer; + + // TODO: Remove once the interpreter is storage/tx independent and could run without an associated database + if (!db_acc_) throw DatabaseContextRequiredException("Database required for the query."); + auto *db = db_acc_->get(); + if (!in_explicit_transaction_) { query_executions_.clear(); transaction_queries_->clear(); @@ -3620,15 +3621,22 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, utils::Downcast(parsed_query.query) || utils::Downcast(parsed_query.query) || utils::Downcast(parsed_query.query))) { memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveTransactions); - db_accessor_ = interpreter_context_->db->Access(GetIsolationLevelOverride()); + auto &db_acc = *db_acc_; + db_accessor_ = db_acc->Access(GetIsolationLevelOverride()); execution_db_accessor_.emplace(db_accessor_.get()); transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); - if (utils::Downcast(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) { - trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); + if (utils::Downcast(parsed_query.query) && db_acc->trigger_store()->HasTriggers()) { + trigger_context_collector_.emplace(db_acc->trigger_store()->GetEventTypes()); } } + const auto is_cacheable = parsed_query.is_cacheable; + auto *plan_cache = db->plan_cache(); + auto get_plan_cache = [&]() { + return is_cacheable ? plan_cache : nullptr; + }; // Some queries run additional parsing and may need the plan_cache even if the outer query is not cacheable + utils::Timer planning_timer; PreparedQuery prepared_query; utils::MemoryResource *memory_resource = @@ -3637,77 +3645,77 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, frame_change_collector_.reset(); frame_change_collector_.emplace(memory_resource); if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareCypherQuery( - std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - memory_resource, &query_execution->notifications, username, &transaction_status_, std::move(current_timer), - trigger_context_collector_ ? &*trigger_context_collector_ : nullptr, &*frame_change_collector_); + prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, memory_resource, &query_execution->notifications, + username, &transaction_status_, std::move(current_timer), get_plan_cache(), + trigger_context_collector_ ? &*trigger_context_collector_ : nullptr, + &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, - &*execution_db_accessor_, &query_execution->execution_memory_with_exception); + &*execution_db_accessor_, plan_cache); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception, username, - &transaction_status_, std::move(current_timer), &*frame_change_collector_); + prepared_query = PrepareProfileQuery( + std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_, + std::move(current_timer), plan_cache, &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, memory_resource); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_, - &query_execution->notifications, interpreter_context_); + &query_execution->notifications, db->storage(), get_plan_cache()); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, - &*execution_db_accessor_, interpreter_context_); + &*execution_db_accessor_, get_plan_cache()); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, interpreter_context_->db.get(), - &query_execution->execution_memory_with_exception, interpreter_isolation_level, - next_transaction_isolation_level); + db->storage(), &query_execution->execution_memory_with_exception, + interpreter_isolation_level, next_transaction_isolation_level); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareConstraintQuery(std::move(parsed_query), in_explicit_transaction_, - &query_execution->notifications, interpreter_context_); + &query_execution->notifications, db->storage()); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, - &query_execution->notifications, interpreter_context_); + prepared_query = + PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, + db->storage(), interpreter_context_->config); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, db->storage()); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, db->storage()); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareShowConfigQuery(std::move(parsed_query), in_explicit_transaction_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - interpreter_context_, &*execution_db_accessor_, params, username); + db->trigger_store(), interpreter_context_, &*execution_db_accessor_, params, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, - &query_execution->notifications, interpreter_context_, username); + &query_execution->notifications, *db_acc_, interpreter_context_, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = - PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this); + PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, db->storage(), this); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = - PrepareCreateSnapshotQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = PrepareCreateSnapshotQuery(std::move(parsed_query), in_explicit_transaction_, db->storage()); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareStorageModeQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = + PrepareStorageModeQuery(std::move(parsed_query), in_explicit_transaction_, *db_acc_, interpreter_context_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_, - interpreter_context_, &*execution_db_accessor_); + interpreter_context_, *this); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_, - interpreter_context_, session_uuid); + interpreter_context_, *this, on_change_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = - PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid, username_); + PrepareShowDatabasesQuery(std::move(parsed_query), db->storage(), interpreter_context_, username_); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = - PrepareEdgeImportModeQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = PrepareEdgeImportModeQuery(std::move(parsed_query), in_explicit_transaction_, db->storage()); } else { LOG_FATAL("Should not get here -- unknown query type!"); } @@ -3720,14 +3728,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, UpdateTypeCount(rw_type); - if (IsWriteQueryOnMainMemoryReplica(interpreter_context_->db.get(), rw_type)) { + if (IsWriteQueryOnMainMemoryReplica(db->storage(), rw_type)) { query_execution = nullptr; throw QueryException("Write query forbidden on the replica!"); } // Set the target db to the current db (some queries have different target from the current db) if (!query_execution->prepared_query->db) { - query_execution->prepared_query->db = interpreter_context_->db->id(); + query_execution->prepared_query->db = db->id(); } query_execution->summary["db"] = *query_execution->prepared_query->db; @@ -3783,16 +3791,16 @@ void Interpreter::Abort() { } namespace { -void RunTriggersIndividually(const utils::SkipList &triggers, InterpreterContext *interpreter_context, - TriggerContext original_trigger_context, - std::atomic *transaction_status) { +void RunTriggersAfterCommit(dbms::DatabaseAccess db_acc, InterpreterContext *interpreter_context, + TriggerContext original_trigger_context, + std::atomic *transaction_status) { // Run the triggers - for (const auto &trigger : triggers.access()) { + for (const auto &trigger : db_acc->trigger_store()->AfterCommitTriggers().access()) { utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; // create a new transaction for each trigger - auto storage_acc = interpreter_context->db->Access(); - DbAccessor db_accessor{storage_acc.get()}; + auto tx_acc = db_acc->Access(); + DbAccessor db_accessor{tx_acc.get()}; // On-disk storage removes all Vertex/Edge Accessors because previous trigger tx finished. // So we need to adapt TriggerContext based on user transaction which is still alive. @@ -3858,6 +3866,10 @@ void Interpreter::Commit() { // a query. if (!db_accessor_) return; + // TODO: Better (or removed) check + if (!db_acc_) return; + auto *db = db_acc_->get(); + /* At this point we must check that the transaction is alive to start committing. The only other possible state is verifying and in that case we must check if the transaction was terminated and if yes abort committing. Exception @@ -3877,7 +3889,7 @@ void Interpreter::Commit() { utils::OnScopeExit clean_status( [this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); }); - auto current_storage_mode = interpreter_context_->db->GetStorageMode(); + auto current_storage_mode = db->GetStorageMode(); auto creation_mode = db_accessor_->GetCreationStorageMode(); if (creation_mode != storage::StorageMode::ON_DISK_TRANSACTIONAL && current_storage_mode == storage::StorageMode::ON_DISK_TRANSACTIONAL) { @@ -3902,7 +3914,7 @@ void Interpreter::Commit() { if (trigger_context) { // Run the triggers - for (const auto &trigger : interpreter_context_->trigger_store.BeforeCommitTriggers().access()) { + for (const auto &trigger : db->trigger_store()->BeforeCommitTriggers().access()) { utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; AdvanceCommand(); try { @@ -3973,15 +3985,14 @@ void Interpreter::Commit() { // finished, that transaction probably will schedule its after commit triggers, because the other transactions that // want to commit are still waiting for commiting or one of them just started commiting its changes. This means the // ordered execution of after commit triggers are not guaranteed. - if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) { - interpreter_context_->after_commit_trigger_pool.AddTask( - [this, trigger_context = std::move(*trigger_context), - user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { - RunTriggersIndividually(this->interpreter_context_->trigger_store.AfterCommitTriggers(), - this->interpreter_context_, std::move(trigger_context), &this->transaction_status_); - user_transaction->FinalizeTransaction(); - SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) - }); + if (trigger_context && db->trigger_store()->AfterCommitTriggers().size() > 0) { + db->AddTask([this, trigger_context = std::move(*trigger_context), + user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { + // TODO: Should this take the db_ and not Access()? + RunTriggersAfterCommit(*db_acc_, interpreter_context_, std::move(trigger_context), &this->transaction_status_); + user_transaction->FinalizeTransaction(); + SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) + }); } SPDLOG_DEBUG("Finished committing the transaction"); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 11504a909..ccb13373b 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -15,7 +15,9 @@ #include +#include "dbms/database.hpp" #include "query/auth_checker.hpp" +#include "query/auth_query_handler.hpp" #include "query/config.hpp" #include "query/context.hpp" #include "query/cypher_query_interpreter.hpp" @@ -53,106 +55,11 @@ extern const Event FailedQuery; namespace memgraph::query { +struct InterpreterContext; + inline constexpr size_t kExecutionMemoryBlockSize = 1UL * 1024UL * 1024UL; inline constexpr size_t kExecutionPoolMaxBlockSize = 1024UL; // 2 ^ 10 -class AuthQueryHandler { - public: - AuthQueryHandler() = default; - virtual ~AuthQueryHandler() = default; - - AuthQueryHandler(const AuthQueryHandler &) = delete; - AuthQueryHandler(AuthQueryHandler &&) = delete; - AuthQueryHandler &operator=(const AuthQueryHandler &) = delete; - AuthQueryHandler &operator=(AuthQueryHandler &&) = delete; - - /// Return false if the user already exists. - /// @throw QueryRuntimeException if an error ocurred. - virtual bool CreateUser(const std::string &username, const std::optional &password) = 0; - - /// Return false if the user does not exist. - /// @throw QueryRuntimeException if an error ocurred. - virtual bool DropUser(const std::string &username) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void SetPassword(const std::string &username, const std::optional &password) = 0; - -#ifdef MG_ENTERPRISE - /// Return true if access revoked successfully - /// @throw QueryRuntimeException if an error ocurred. - virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0; - - /// Return true if access granted successfully - /// @throw QueryRuntimeException if an error ocurred. - virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0; - - /// Returns database access rights for the user - /// @throw QueryRuntimeException if an error ocurred. - virtual std::vector> GetDatabasePrivileges(const std::string &username) = 0; - - /// Return true if main database set successfully - /// @throw QueryRuntimeException if an error ocurred. - virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0; -#endif - - /// Return false if the role already exists. - /// @throw QueryRuntimeException if an error ocurred. - virtual bool CreateRole(const std::string &rolename) = 0; - - /// Return false if the role does not exist. - /// @throw QueryRuntimeException if an error ocurred. - virtual bool DropRole(const std::string &rolename) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual std::vector GetUsernames() = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual std::vector GetRolenames() = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual std::optional GetRolenameForUser(const std::string &username) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual std::vector GetUsernamesForRole(const std::string &rolename) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void SetRole(const std::string &username, const std::string &rolename) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void ClearRole(const std::string &username) = 0; - - virtual std::vector> GetPrivileges(const std::string &user_or_role) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void GrantPrivilege( - const std::string &user_or_role, const std::vector &privileges -#ifdef MG_ENTERPRISE - , - const std::vector>> - &label_privileges, - - const std::vector>> - &edge_type_privileges -#endif - ) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void DenyPrivilege(const std::string &user_or_role, const std::vector &privileges) = 0; - - /// @throw QueryRuntimeException if an error ocurred. - virtual void RevokePrivilege( - const std::string &user_or_role, const std::vector &privileges -#ifdef MG_ENTERPRISE - , - const std::vector>> - &label_privileges, - - const std::vector>> - &edge_type_privileges -#endif - ) = 0; -}; - enum class QueryHandlerResult { COMMIT, ABORT, NOTHING }; class ReplicationQueryHandler { @@ -232,57 +139,10 @@ struct QueryExtras { std::optional tx_timeout; }; -class Interpreter; - -/** - * Holds data shared between multiple `Interpreter` instances (which might be - * running concurrently). - * - */ -/// TODO: andi decouple in a separate file why here? -struct InterpreterContext { - explicit InterpreterContext(storage::Config storage_config, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr, - query::AuthChecker *ac = nullptr); - - InterpreterContext(std::unique_ptr &&db, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr, - query::AuthChecker *ac = nullptr); - - std::unique_ptr db; - - // ANTLR has singleton instance that is shared between threads. It is - // protected by locks inside of ANTLR. Unfortunately, they are not protected - // in a very good way. Once we have ANTLR version without race conditions we - // can remove this lock. This will probably never happen since ANTLR - // developers introduce more bugs in each version. Fortunately, we have - // cache so this lock probably won't impact performance much... - utils::SpinLock antlr_lock; - std::optional tsc_frequency{utils::GetTSCFrequency()}; - std::atomic is_shutting_down{false}; - - AuthQueryHandler *auth; - AuthChecker *auth_checker; - - utils::SkipList ast_cache; - utils::SkipList plan_cache; - - TriggerStore trigger_store; - utils::ThreadPool after_commit_trigger_pool{1}; - - const InterpreterConfig config; - - query::stream::Streams streams; - utils::Synchronized, utils::SpinLock> interpreters; -}; - -/// Function that is used to tell all active interpreters that they should stop -/// their ongoing execution. -inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); } - class Interpreter final { public: - explicit Interpreter(InterpreterContext *interpreter_context); + Interpreter(InterpreterContext *interpreter_context); + Interpreter(InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db); Interpreter(const Interpreter &) = delete; Interpreter &operator=(const Interpreter &) = delete; Interpreter(Interpreter &&) = delete; @@ -303,6 +163,14 @@ class Interpreter final { std::shared_ptr explicit_transaction_timer_{}; std::optional> metadata_{}; //!< User defined transaction metadata + std::optional db_acc_; // Current db (TODO: expand to support multiple) + +#ifdef MG_ENTERPRISE + void SetCurrentDB(std::string_view db_name); + void SetCurrentDB(memgraph::dbms::DatabaseAccess new_db); + void OnChangeCB(auto cb) { on_change_.emplace(cb); } +#endif + /** * Prepare a query for execution. * @@ -435,6 +303,8 @@ class Interpreter final { // To avoid this, we use unique_ptr with which we manualy control construction // and deletion of a single query execution, i.e. when a query finishes, // we reset the corresponding unique_ptr. + // TODO Figure out how this would work for multi-database + // Exists only during a single transaction (for now should be okay as is) std::vector> query_executions_; // all queries that are run as part of the current transaction utils::Synchronized, utils::SpinLock> transaction_queries_; @@ -462,6 +332,8 @@ class Interpreter final { return std::count_if(query_executions_.begin(), query_executions_.end(), [](const auto &execution) { return execution && execution->prepared_query; }); } + + std::optional> on_change_{}; }; class TransactionQueueQueryHandler { @@ -475,13 +347,9 @@ class TransactionQueueQueryHandler { TransactionQueueQueryHandler(TransactionQueueQueryHandler &&) = default; TransactionQueueQueryHandler &operator=(TransactionQueueQueryHandler &&) = default; - static std::vector> ShowTransactions(const std::unordered_set &interpreters, - const std::optional &username, - bool hasTransactionManagementPrivilege); - - static std::vector> KillTransactions( - InterpreterContext *interpreter_context, const std::vector &maybe_kill_transaction_ids, - const std::optional &username, bool hasTransactionManagementPrivilege); + static std::vector> ShowTransactions( + const std::unordered_set &interpreters, const std::optional &username, + bool hasTransactionManagementPrivilege, std::optional &filter_db_acc); }; template diff --git a/src/query/interpreter_context.cpp b/src/query/interpreter_context.cpp new file mode 100644 index 000000000..eeedcf591 --- /dev/null +++ b/src/query/interpreter_context.cpp @@ -0,0 +1,73 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/interpreter_context.hpp" + +#include "query/interpreter.hpp" +namespace memgraph::query { +std::vector> InterpreterContext::KillTransactions( + std::vector maybe_kill_transaction_ids, const std::optional &username, + bool hasTransactionManagementPrivilege, Interpreter &calling_interpreter) { + auto not_found_midpoint = maybe_kill_transaction_ids.end(); + + // Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed + // TERMINATE and SHOW TRANSACTIONS are mutually exclusive + interpreters.WithLock([¬_found_midpoint, &maybe_kill_transaction_ids, username, hasTransactionManagementPrivilege, + filter_db_acc = &calling_interpreter.db_acc_](const auto &interpreters) { + for (Interpreter *interpreter : interpreters) { + TransactionStatus alive_status = TransactionStatus::ACTIVE; + // if it is just checking kill, commit and abort should wait for the end of the check + // The only way to start checking if the transaction will get killed is if the transaction_status is + // active + if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) { + continue; + } + bool killed = false; + utils::OnScopeExit clean_status([interpreter, &killed]() { + if (killed) { + interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release); + } else { + interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); + } + }); + if (interpreter->db_acc_ != *filter_db_acc) continue; + std::optional intr_trans = interpreter->GetTransactionId(); + if (!intr_trans.has_value()) continue; + + auto transaction_id = std::to_string(intr_trans.value()); + + auto it = std::find(maybe_kill_transaction_ids.begin(), not_found_midpoint, transaction_id); + if (it != not_found_midpoint) { + // update the maybe_kill_transaction_ids (partitioning not found + killed) + --not_found_midpoint; + std::iter_swap(it, not_found_midpoint); + if (interpreter->username_ == username || hasTransactionManagementPrivilege) { + killed = true; // Note: this is used by the above `clean_status` (OnScopeExit) + spdlog::warn("Transaction {} successfully killed", transaction_id); + } else { + spdlog::warn("Not enough rights to kill the transaction"); + } + } + } + }); + + std::vector> results; + for (auto it = maybe_kill_transaction_ids.begin(); it != not_found_midpoint; ++it) { + results.push_back({TypedValue(*it), TypedValue(false)}); + spdlog::warn("Transaction {} not found", *it); + } + for (auto it = not_found_midpoint; it != maybe_kill_transaction_ids.end(); ++it) { + results.push_back({TypedValue(*it), TypedValue(true)}); + } + + return results; +} +} // namespace memgraph::query diff --git a/src/query/interpreter_context.hpp b/src/query/interpreter_context.hpp new file mode 100644 index 000000000..010c85220 --- /dev/null +++ b/src/query/interpreter_context.hpp @@ -0,0 +1,85 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include + +#include "query/config.hpp" +#include "query/cypher_query_interpreter.hpp" +#include "query/typed_value.hpp" +#include "utils/gatekeeper.hpp" +#include "utils/skip_list.hpp" +#include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" + +namespace memgraph::dbms { +#ifdef MG_ENTERPRISE +class DbmsHandler; +#else +class Database; +#endif +} // namespace memgraph::dbms + +namespace memgraph::query { + +class AuthQueryHandler; +class AuthChecker; +class Interpreter; + +/** + * Holds data shared between multiple `Interpreter` instances (which might be + * running concurrently). + * + */ +struct InterpreterContext { +#ifdef MG_ENTERPRISE + InterpreterContext(InterpreterConfig interpreter_config, memgraph::dbms::DbmsHandler *db_handler, + AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr); +#else + InterpreterContext(InterpreterConfig interpreter_config, + memgraph::utils::Gatekeeper *db_gatekeeper, + query::AuthQueryHandler *ah = nullptr, query::AuthChecker *ac = nullptr); +#endif + +#ifdef MG_ENTERPRISE + memgraph::dbms::DbmsHandler *db_handler; +#else + memgraph::utils::Gatekeeper *db_gatekeeper; +#endif + + // Internal + const InterpreterConfig config; + std::atomic is_shutting_down{false}; // TODO: Do we even need this, since there is a global one also + memgraph::utils::SkipList ast_cache; + + // GLOBAL + AuthQueryHandler *auth; + AuthChecker *auth_checker; + + // Used to check active transactions + // TODO: Have a way to read the current database + memgraph::utils::Synchronized, memgraph::utils::SpinLock> interpreters; + + /// Function that is used to tell all active interpreters that they should stop + /// their ongoing execution. + void Shutdown() { is_shutting_down.store(true, std::memory_order_release); } + + std::vector> KillTransactions(std::vector maybe_kill_transaction_ids, + const std::optional &username, + bool hasTransactionManagementPrivilege, + Interpreter &calling_interpreter); +}; + +} // namespace memgraph::query diff --git a/src/query/procedure/module.cpp b/src/query/procedure/module.cpp index dc6b5ac73..338ccbcaa 100644 --- a/src/query/procedure/module.cpp +++ b/src/query/procedure/module.cpp @@ -12,6 +12,7 @@ #include "query/procedure/module.hpp" #include +#include #include extern "C" { diff --git a/src/query/stream/common.cpp b/src/query/stream/common.cpp index 50b554f4d..a89f1406d 100644 --- a/src/query/stream/common.cpp +++ b/src/query/stream/common.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/stream/common.hpp b/src/query/stream/common.hpp index 0882d9cd4..26ec66907 100644 --- a/src/query/stream/common.hpp +++ b/src/query/stream/common.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/stream/sources.cpp b/src/query/stream/sources.cpp index 82ddc6216..c8a111a1d 100644 --- a/src/query/stream/sources.cpp +++ b/src/query/stream/sources.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -51,8 +51,8 @@ void KafkaStream::Stop() { consumer_->Stop(); } bool KafkaStream::IsRunning() const { return consumer_->IsRunning(); } void KafkaStream::Check(std::optional timeout, std::optional batch_limit, - const ConsumerFunction &consumer_function) const { - consumer_->Check(timeout, batch_limit, consumer_function); + ConsumerFunction consumer_function) const { + consumer_->Check(timeout, batch_limit, std::move(consumer_function)); } utils::BasicResult KafkaStream::SetStreamOffset(const int64_t offset) { @@ -115,8 +115,8 @@ void PulsarStream::StartWithLimit(uint64_t batch_limit, std::optionalStop(); } bool PulsarStream::IsRunning() const { return consumer_->IsRunning(); } void PulsarStream::Check(std::optional timeout, std::optional batch_limit, - const ConsumerFunction &consumer_function) const { - consumer_->Check(timeout, batch_limit, consumer_function); + ConsumerFunction consumer_function) const { + consumer_->Check(timeout, batch_limit, std::move(consumer_function)); } namespace { diff --git a/src/query/stream/sources.hpp b/src/query/stream/sources.hpp index 2eeaa39d8..d4a7aede4 100644 --- a/src/query/stream/sources.hpp +++ b/src/query/stream/sources.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -41,7 +41,7 @@ struct KafkaStream { bool IsRunning() const; void Check(std::optional timeout, std::optional batch_limit, - const ConsumerFunction &consumer_function) const; + ConsumerFunction consumer_function) const; utils::BasicResult SetStreamOffset(int64_t offset); @@ -77,7 +77,7 @@ struct PulsarStream { bool IsRunning() const; void Check(std::optional timeout, std::optional batch_limit, - const ConsumerFunction &consumer_function) const; + ConsumerFunction consumer_function) const; private: using Consumer = integrations::pulsar::Consumer; diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index 5577b6b49..16e7d7522 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -18,6 +18,8 @@ #include #include +#include "dbms/database.hpp" +#include "dbms/dbms_handler.hpp" #include "integrations/constants.hpp" #include "mg_procedure.h" #include "query/db_accessor.hpp" @@ -161,10 +163,7 @@ void from_json(const nlohmann::json &data, StreamStatus &status) { from_json(data, status.info); } -Streams::Streams(InterpreterContext *interpreter_context, std::filesystem::path directory) - : interpreter_context_(interpreter_context), storage_(std::move(directory)) { - RegisterProcedures(); -} +Streams::Streams(std::filesystem::path directory) : storage_(std::move(directory)) { RegisterProcedures(); } void Streams::RegisterProcedures() { RegisterKafkaProcedures(); @@ -448,11 +447,12 @@ void Streams::RegisterPulsarProcedures() { } } -template +template void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info, - std::optional owner) { + std::optional owner, TDbAccess db_acc, InterpreterContext *ic) { auto locked_streams = streams_.Lock(); - auto it = CreateConsumer(*locked_streams, stream_name, std::move(info), std::move(owner)); + auto it = CreateConsumer(*locked_streams, stream_name, std::move(info), std::move(owner), + std::move(db_acc), ic); try { std::visit( @@ -467,30 +467,38 @@ void Streams::Create(const std::string &stream_name, typename TStream::StreamInf } } -template void Streams::Create(const std::string &stream_name, KafkaStream::StreamInfo info, - std::optional owner); -template void Streams::Create(const std::string &stream_name, PulsarStream::StreamInfo info, - std::optional owner); +template void Streams::Create(const std::string &stream_name, + KafkaStream::StreamInfo info, + std::optional owner, + dbms::DatabaseAccess db, InterpreterContext *ic); +template void Streams::Create(const std::string &stream_name, + PulsarStream::StreamInfo info, + std::optional owner, + dbms::DatabaseAccess db, InterpreterContext *ic); -template +template Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name, typename TStream::StreamInfo stream_info, - std::optional owner) { + std::optional owner, TDbAccess db_acc, + InterpreterContext *interpreter_context) { if (map.contains(stream_name)) { throw StreamsException{"Stream already exists with name '{}'", stream_name}; } auto *memory_resource = utils::NewDeleteResource(); - auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name, + auto consumer_function = [interpreter_context, memory_resource, stream_name, transformation_name = stream_info.common_info.transformation_name, owner = owner, - interpreter = std::make_shared(interpreter_context_), + interpreter = std::make_shared(interpreter_context, std::move(db_acc)), result = mgp_result{nullptr, memory_resource}, - total_retries = interpreter_context_->config.stream_transaction_conflict_retries, - retry_interval = interpreter_context_->config.stream_transaction_retry_interval]( + total_retries = interpreter_context->config.stream_transaction_conflict_retries, + retry_interval = interpreter_context->config.stream_transaction_retry_interval]( const std::vector &messages) mutable { - auto accessor = interpreter_context->db->Access(); - // register new interpreter into interpreter_context_ +#ifdef MG_ENTERPRISE + interpreter->OnChangeCB([](auto) { return false; }); // Disable database change +#endif + auto accessor = interpreter->db_acc_->get()->Access(); + // register new interpreter into interpreter_context interpreter_context->interpreters->insert(interpreter.get()); utils::OnScopeExit interpreter_cleanup{ [interpreter_context, interpreter]() { interpreter_context->interpreters->erase(interpreter.get()); }}; @@ -552,7 +560,8 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std return insert_result.first; } -void Streams::RestoreStreams() { +template +void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) { spdlog::info("Loading streams..."); auto locked_streams_map = streams_.Lock(); MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!"); @@ -563,8 +572,8 @@ void Streams::RestoreStreams() { return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message); }; - const auto create_consumer = [&, &stream_name = stream_name, this](StreamStatus status, - auto &&stream_json_data) { + const auto create_consumer = [&, &stream_name = stream_name](StreamStatus status, + auto &&stream_json_data) { try { stream_json_data.get_to(status); } catch (const nlohmann::json::type_error &exception) { @@ -577,7 +586,8 @@ void Streams::RestoreStreams() { MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name); try { - auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner)); + auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner), + db, ic); if (status.is_running) { std::visit( [&](const auto &stream_data) { @@ -613,6 +623,8 @@ void Streams::RestoreStreams() { } } +template void Streams::RestoreStreams(dbms::DatabaseAccess db, InterpreterContext *ic); + void Streams::Drop(const std::string &stream_name) { auto locked_streams = streams_.Lock(); @@ -722,7 +734,9 @@ std::vector> Streams::GetStreamInfo() const { return result; } -TransformationResult Streams::Check(const std::string &stream_name, std::optional timeout, +template +TransformationResult Streams::Check(const std::string &stream_name, TDbAccess db_acc, + std::optional timeout, std::optional batch_limit) const { std::optional locked_streams{streams_.ReadLock()}; auto it = GetStream(**locked_streams, stream_name); @@ -739,10 +753,9 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona mgp_result result{nullptr, memory_resource}; TransformationResult test_result; - auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name, - &transformation_name = transformation_name, &result, - &test_result](const std::vector &messages) mutable { - auto accessor = interpreter_context->db->Access(); + auto consumer_function = [&db_acc, memory_resource, &stream_name, &transformation_name = transformation_name, + &result, &test_result](const std::vector &messages) mutable { + auto accessor = db_acc->Access(); CallCustomTransformation(transformation_name, messages, result, *accessor, *memory_resource, stream_name); auto result_row = std::vector(); @@ -774,4 +787,9 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona it->second); } +template TransformationResult Streams::Check(const std::string &stream_name, + dbms::DatabaseAccess db_acc, + std::optional timeout, + std::optional batch_limit) const; + } // namespace memgraph::query::stream diff --git a/src/query/stream/streams.hpp b/src/query/stream/streams.hpp index 9db28f808..3d2ac2a67 100644 --- a/src/query/stream/streams.hpp +++ b/src/query/stream/streams.hpp @@ -81,13 +81,14 @@ class Streams final { /// /// @param interpreter_context context to use to run the result of transformations /// @param directory a directory path to store the persisted streams metadata - Streams(InterpreterContext *interpreter_context, std::filesystem::path directory); + explicit Streams(std::filesystem::path directory); /// Restores the streams from the persisted metadata. /// The restoration is done in a best effort manner, therefore no exception is thrown on failure, but the error is /// logged. If a stream was running previously, then after restoration it will be started. /// This function should only be called when there are no existing streams. - void RestoreStreams(); + template + void RestoreStreams(TDbAccess db, InterpreterContext *interpreter_context); /// Creates a new import stream. /// The create implies connecting to the server to get metadata necessary to initialize the stream. This @@ -97,8 +98,9 @@ class Streams final { /// @param stream_info the necessary informations needed to create the Kafka consumer and transform the messages /// /// @throws StreamsException if the stream with the same name exists or if the creation of Kafka consumer fails - template - void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional owner); + template + void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional owner, + TDbAccess db, InterpreterContext *interpreter_context); /// Deletes an existing stream and all the data that was persisted. /// @@ -161,7 +163,8 @@ class Streams final { /// @throws StreamsException if the stream doesn't exist /// @throws ConsumerRunningException if the consumer is already running /// @throws ConsumerCheckFailedException if the transformation function throws any std::exception during processing - TransformationResult Check(const std::string &stream_name, + template + TransformationResult Check(const std::string &stream_name, TDbAccess db, std::optional timeout = std::nullopt, std::optional batch_limit = std::nullopt) const; @@ -180,9 +183,10 @@ class Streams final { using StreamsMap = std::unordered_map; using SynchronizedStreamsMap = utils::Synchronized; - template + template StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, - typename TStream::StreamInfo stream_info, std::optional owner); + typename TStream::StreamInfo stream_info, std::optional owner, + TDbAccess db, InterpreterContext *interpreter_context); template void Persist(StreamStatus &&status) { @@ -196,7 +200,6 @@ class Streams final { void RegisterKafkaProcedures(); void RegisterPulsarProcedures(); - InterpreterContext *interpreter_context_; kvstore::KVStore storage_; SynchronizedStreamsMap streams_; diff --git a/src/storage/v2/config.hpp b/src/storage/v2/config.hpp index bdb2e586b..cd1590beb 100644 --- a/src/storage/v2/config.hpp +++ b/src/storage/v2/config.hpp @@ -78,6 +78,7 @@ struct Config { } disk; std::string name; + bool force_on_disk{false}; }; static inline void UpdatePaths(Config &config, const std::filesystem::path &storage_dir) { diff --git a/src/utils/gatekeeper.hpp b/src/utils/gatekeeper.hpp new file mode 100644 index 000000000..d187d3528 --- /dev/null +++ b/src/utils/gatekeeper.hpp @@ -0,0 +1,213 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace memgraph::utils { + +struct run_t {}; +struct not_run_t {}; + +template +struct EvalResult; +template <> +struct EvalResult { + template + EvalResult(run_t /* marker */, Func &&func, T &arg) : was_run{true} { + std::invoke(std::forward(func), arg); + } + EvalResult(not_run_t /* marker */) : was_run{false} {} + + ~EvalResult() = default; + + EvalResult(EvalResult const &) = delete; + EvalResult(EvalResult &&) = delete; + EvalResult &operator=(EvalResult const &) = delete; + EvalResult &operator=(EvalResult &&) = delete; + + explicit operator bool() const { return was_run; } + + private: + bool was_run; +}; + +template +struct EvalResult { + template + EvalResult(run_t /* marker */, Func &&func, T &arg) : return_result{std::invoke(std::forward(func), arg)} {} + EvalResult(not_run_t /* marker */) {} + + ~EvalResult() = default; + + EvalResult(EvalResult const &) = delete; + EvalResult(EvalResult &&) = delete; + EvalResult &operator=(EvalResult const &) = delete; + EvalResult &operator=(EvalResult &&) = delete; + + explicit operator bool() const { return return_result.has_value(); } + + constexpr const Ret &value() const & { return return_result.value(); } + constexpr Ret &value() & { return return_result.value(); } + constexpr Ret &&value() && { return return_result.value(); } + constexpr const Ret &&value() const && { return return_result.value(); } + + private: + std::optional return_result = std::nullopt; +}; + +template +EvalResult(run_t, Func &&, T &) -> EvalResult>; + +template +struct Gatekeeper { + template + explicit Gatekeeper(Args &&...args) : value_{std::forward(args)...} {} + + Gatekeeper(Gatekeeper const &) = delete; + Gatekeeper(Gatekeeper &&) noexcept = delete; + Gatekeeper &operator=(Gatekeeper const &) = delete; + Gatekeeper &operator=(Gatekeeper &&) = delete; + + struct Accessor { + friend Gatekeeper; + + private: + explicit Accessor(Gatekeeper *owner) : owner_{owner} { ++owner_->count_; } + + public: + Accessor(Accessor const &other) : owner_{other.owner_} { + if (owner_) { + auto guard = std::unique_lock{owner_->mutex_}; + ++owner_->count_; + } + }; + Accessor(Accessor &&other) noexcept : owner_{std::exchange(other.owner_, nullptr)} {}; + Accessor &operator=(Accessor const &other) { + // no change assignment + if (owner_ == other.owner_) { + return *this; + } + + // gain ownership + if (other.owner_) { + auto guard = std::unique_lock{other.owner_->mutex_}; + ++other.owner_->count_; + } + + // reliquish ownership + if (owner_) { + auto guard = std::unique_lock{owner_->mutex_}; + --owner_->count_; + } + + // correct owner + owner_ = other.owner_; + return *this; + }; + Accessor &operator=(Accessor &&other) noexcept { + // self assignment + if (&other == this) return *this; + + // reliquish ownership + if (owner_) { + auto guard = std::unique_lock{owner_->mutex_}; + --owner_->count_; + } + + // correct owners + owner_ = std::exchange(other.owner_, nullptr); + return *this; + } + + ~Accessor() { reset(); } + + auto get() -> T * { return std::addressof(*owner_->value_); } + auto get() const -> const T * { return std::addressof(*owner_->value_); } + T *operator->() { return std::addressof(*owner_->value_); } + const T *operator->() const { return std::addressof(*owner_->value_); } + + template + [[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult> { + // Prevent new access + auto guard = std::unique_lock{owner_->mutex_}; + // Only invoke if we have exclusive access + if (owner_->count_ != 1) { + return {not_run_t{}}; + } + // Invoke and hold result in wrapper type + return {run_t{}, std::forward(func), *owner_->value_}; + } + + // Completely invalidated the accessor if return true + [[nodiscard]] bool try_delete(std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { + // Prevent new access + auto guard = std::unique_lock{owner_->mutex_}; + if (!owner_->cv_.wait_for(guard, timeout, [this] { return owner_->count_ == 1; })) { + return false; + } + // Delete value + owner_->value_ = std::nullopt; + return true; + } + + explicit operator bool() const { return owner_ != nullptr; } + + void reset() { + if (owner_) { + { + auto guard = std::unique_lock{owner_->mutex_}; + --owner_->count_; + } + owner_->cv_.notify_all(); + } + owner_ = nullptr; + } + + friend bool operator==(Accessor const &lhs, Accessor const &rhs) { return lhs.owner_ == rhs.owner_; } + + private: + Gatekeeper *owner_ = nullptr; + }; + + std::optional access() { + auto guard = std::unique_lock{mutex_}; + if (value_) { + return Accessor{this}; + } + return std::nullopt; + } + + ~Gatekeeper() { + // wait for count to drain to 0 + auto lock = std::unique_lock{mutex_}; + cv_.wait(lock, [this] { return count_ == 0; }); + } + + private: + std::optional value_; + uint64_t count_ = 0; + std::mutex mutex_; // TODO change to something cheaper? + std::condition_variable cv_; +}; + +} // namespace memgraph::utils diff --git a/src/utils/sync_ptr.hpp b/src/utils/sync_ptr.hpp deleted file mode 100644 index 33ccdc26c..000000000 --- a/src/utils/sync_ptr.hpp +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#include -#include -#include -#include -#include -#include "utils/exceptions.hpp" - -namespace memgraph::utils { - -/** - * @brief - * - * @tparam TContext - * @tparam TConfig - */ -template -struct SyncPtr { - /** - * @brief Construct a new synched pointer. - * - * @tparam TArgs variable templates used by the TContext constructor - * @param config Additional metadata associated with context - * @param args Arguments to pass to TContext constructor - */ - template - explicit SyncPtr(TConfig config, TArgs &&...args) - : timeout_{1000}, config_{config}, ptr_{new TContext(std::forward(args)...), [this](TContext *ptr) { - this->OnDelete(ptr); - }} {} - - ~SyncPtr() = default; - - SyncPtr(const SyncPtr &) = delete; - SyncPtr &operator=(const SyncPtr &) = delete; - SyncPtr(SyncPtr &&) noexcept = delete; - SyncPtr &operator=(SyncPtr &&) noexcept = delete; - - /** - * @brief Destroy the synched pointer and wait for all copies to get destroyed. - * - */ - void DestroyAndSync() { - ptr_.reset(); - SyncOnDelete(); - } - - /** - * @brief Get (copy) the underlying shared pointer. - * - * @return std::shared_ptr - */ - std::shared_ptr get() { return ptr_; } - std::shared_ptr get() const { return ptr_; } - - /** - * @brief Return saved configuration (metadata) - * - * @return TConfig - */ - TConfig config() { return config_; } - const TConfig &config() const { return config_; } - - void timeout(const std::chrono::milliseconds to) { timeout_ = to; } - std::chrono::milliseconds timeout() const { return timeout_; } - - private: - /** - * @brief Block until OnDelete gets called. - * - */ - void SyncOnDelete() { - std::unique_lock lock(in_use_mtx_); - if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) { - throw utils::BasicException("Syncronization timeout!"); - } - } - - /** - * @brief Custom destructor used to sync the shared_ptr release. - * - * @param p Pointer to the undelying object. - */ - void OnDelete(TContext *p) { - delete p; - { - std::lock_guard lock(in_use_mtx_); - in_use_ = false; - } - in_use_cv_.notify_all(); - } - - bool in_use_{true}; //!< Flag used to signal sync - mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync - mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync - std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms - TConfig config_; //!< Additional metadata associated with the context - std::shared_ptr ptr_; //!< Pointer being synced -}; - -template -class SyncPtr { - public: - /** - * @brief Construct a new synched pointer. - * - * @tparam TArgs variable templates used by the TContext constructor - * @param config Additional metadata associated with context - * @param args Arguments to pass to TContext constructor - */ - template - explicit SyncPtr(TArgs &&...args) - : timeout_{1000}, ptr_{new TContext(std::forward(args)...), [this](TContext *ptr) { - this->OnDelete(ptr); - }} {} - - ~SyncPtr() = default; - - SyncPtr(const SyncPtr &) = delete; - SyncPtr &operator=(const SyncPtr &) = delete; - SyncPtr(SyncPtr &&) noexcept = delete; - SyncPtr &operator=(SyncPtr &&) noexcept = delete; - - /** - * @brief Destroy the synched pointer and wait for all copies to get destroyed. - * - */ - void DestroyAndSync() { - ptr_.reset(); - SyncOnDelete(); - } - - /** - * @brief Get (copy) the underlying shared pointer. - * - * @return std::shared_ptr - */ - std::shared_ptr get() { return ptr_; } - std::shared_ptr get() const { return ptr_; } - - void timeout(const std::chrono::milliseconds to) { timeout_ = to; } - std::chrono::milliseconds timeout() const { return timeout_; } - - private: - /** - * @brief Block until OnDelete gets called. - * - */ - void SyncOnDelete() { - std::unique_lock lock(in_use_mtx_); - if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) { - throw utils::BasicException("Syncronization timeout!"); - } - } - - /** - * @brief Custom destructor used to sync the shared_ptr release. - * - * @param p Pointer to the undelying object. - */ - void OnDelete(TContext *p) { - delete p; - { - std::lock_guard lock(in_use_mtx_); - in_use_ = false; - } - in_use_cv_.notify_all(); - } - - bool in_use_{true}; //!< Flag used to signal sync - mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync - mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync - std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms - std::shared_ptr ptr_; //!< Pointer being synced -}; - -} // namespace memgraph::utils diff --git a/src/utils/tsc.cpp b/src/utils/tsc.cpp index fa87f9754..02bc9a9b2 100644 --- a/src/utils/tsc.cpp +++ b/src/utils/tsc.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -20,12 +20,14 @@ extern "C" { namespace memgraph::utils { uint64_t ReadTSC() { return rdtsc(); } -std::optional GetTSCFrequency() { +bool IsAvailableTSC() { // init is only needed for fetching frequency - static auto result = std::invoke([] { return rdtsc_init(); }); - return result == 0 ? std::optional{rdtsc_get_tsc_hz()} : std::nullopt; + static bool available = [] { return rdtsc_init() == 0; }(); // iile + return available; } +std::optional GetTSCFrequency() { return IsAvailableTSC() ? std::optional{rdtsc_get_tsc_hz()} : std::nullopt; } + TSCTimer::TSCTimer(std::optional frequency) : frequency_(frequency) { if (!frequency_) return; start_value_ = utils::ReadTSC(); diff --git a/src/utils/tsc.hpp b/src/utils/tsc.hpp index 40344ddaf..d6fbb6c14 100644 --- a/src/utils/tsc.hpp +++ b/src/utils/tsc.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,8 @@ namespace memgraph::utils { // TSC stands for Time-Stamp Counter +bool IsAvailableTSC(); + uint64_t ReadTSC(); std::optional GetTSCFrequency(); diff --git a/tests/benchmark/expansion.cpp b/tests/benchmark/expansion.cpp index 5b7988f3a..9bbb02579 100644 --- a/tests/benchmark/expansion.cpp +++ b/tests/benchmark/expansion.cpp @@ -10,28 +10,35 @@ // licenses/APL.txt. #include +#include #include "communication/result_stream_faker.hpp" #include "query/config.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" #include "query/typed_value.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/isolation_level.hpp" +#include "utils/logging.hpp" class ExpansionBenchFixture : public benchmark::Fixture { protected: std::optional interpreter_context; std::optional interpreter; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "expansion-benchmark"}; + std::optional> db_gk{memgraph::storage::Config{ + .durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"}}; void SetUp(const benchmark::State &state) override { - interpreter_context.emplace(memgraph::storage::Config{}, memgraph::query::InterpreterConfig{}, data_directory); - auto *db = interpreter_context->db.get(); + auto db_acc_opt = db_gk->access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr); - auto label = db->NameToLabel("Starting"); + auto label = db_acc->storage()->NameToLabel("Starting"); { - auto dba = db->Access(); + auto dba = db_acc->Access(); for (int i = 0; i < state.range(0); i++) dba->CreateVertex(); // the fixed part is one vertex expanding to 1000 others @@ -45,14 +52,15 @@ class ExpansionBenchFixture : public benchmark::Fixture { MG_ASSERT(!dba->Commit().HasError()); } - MG_ASSERT(!db->CreateIndex(label).HasError()); + MG_ASSERT(!db_acc->storage()->CreateIndex(label).HasError()); - interpreter.emplace(&*interpreter_context); + interpreter.emplace(&*interpreter_context, std::move(db_acc)); } void TearDown(const benchmark::State &) override { interpreter = std::nullopt; interpreter_context = std::nullopt; + db_gk.reset(); std::filesystem::remove_all(data_directory); } }; @@ -61,7 +69,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Match)(benchmark::State &state) { auto query = "MATCH (s:Starting) return s"; while (state.KeepRunning()) { - ResultStreamFaker results(interpreter_context->db.get()); + ResultStreamFaker results(interpreter->db_acc_->get()->storage()); interpreter->Prepare(query, {}, nullptr); interpreter->PullAll(&results); } @@ -76,7 +84,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Expand)(benchmark::State &state) { auto query = "MATCH (s:Starting) WITH s MATCH (s)--(d) RETURN count(d)"; while (state.KeepRunning()) { - ResultStreamFaker results(interpreter_context->db.get()); + ResultStreamFaker results(interpreter->db_acc_->get()->storage()); interpreter->Prepare(query, {}, nullptr); interpreter->PullAll(&results); } diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index d630d1171..5b0e138de 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -13,6 +13,7 @@ #include "license/license.hpp" #include "query/config.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" #include "storage/v2/config.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/isolation_level.hpp" @@ -31,11 +32,15 @@ int main(int argc, char *argv[]) { memgraph::utils::OnScopeExit([&data_directory] { std::filesystem::remove_all(data_directory); }); memgraph::license::global_license_checker.EnableTesting(); - memgraph::query::InterpreterContext interpreter_context{memgraph::storage::Config{}, - memgraph::query::InterpreterConfig{}, data_directory}; - memgraph::query::Interpreter interpreter{&interpreter_context}; + memgraph::utils::Gatekeeper db_gk(memgraph::storage::Config{ + .durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"}); + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr); + memgraph::query::Interpreter interpreter{&interpreter_context, db_acc}; - ResultStreamFaker stream(interpreter_context.db.get()); + ResultStreamFaker stream(db_acc->storage()); auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, nullptr); stream.Header(header); auto summary = interpreter.PullAll(&stream); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 75a090fa6..64f95b880 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -257,9 +257,6 @@ target_link_libraries(${test_prefix}utils_signals mg-utils) add_unit_test(utils_string.cpp) target_link_libraries(${test_prefix}utils_string mg-utils) -add_unit_test(utils_sync_ptr.cpp) -target_link_libraries(${test_prefix}utils_sync_ptr) - add_unit_test(utils_synchronized.cpp) target_link_libraries(${test_prefix}utils_synchronized mg-utils) @@ -404,17 +401,9 @@ target_link_libraries(${test_prefix}monitoring mg-communication Boost::headers) # Test multi-database if(MG_ENTERPRISE) - # add_unit_test(dbms_storage.cpp) - # target_link_libraries(${test_prefix}dbms_storage mg-storage-v2 mg-query mg-glue) + add_unit_test(dbms_database.cpp) + target_link_libraries(${test_prefix}dbms_database mg-storage-v2 mg-query mg-glue mg-dbms) - add_unit_test(dbms_interp.cpp) - target_link_libraries(${test_prefix}dbms_interp mg-query) -endif() - -# add_unit_test(dbms_auth.cpp) -# target_link_libraries(${test_prefix}dbms_auth mg-glue) - -if(MG_ENTERPRISE) - add_unit_test_with_custom_main(dbms_sc_handler.cpp) - target_link_libraries(${test_prefix}dbms_sc_handler mg-query mg-audit mg-glue) + add_unit_test_with_custom_main(dbms_handler.cpp) + target_link_libraries(${test_prefix}dbms_handler mg-query mg-auth mg-glue mg-dbms) endif() diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index eaf7c8d93..684a7e2e2 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -119,13 +119,6 @@ class TestSession final : public Session { void Configure(const std::map &) override {} std::string GetDatabaseName() const override { return ""; } -#ifdef MG_ENTERPRISE - memgraph::dbms::SetForResult OnChange(const std::string &db_name) override { - return memgraph::dbms::SetForResult::SUCCESS; - } - bool OnDelete(const std::string &) override { return true; } -#endif - void TestHook_ShouldAbort() { should_abort_ = true; } private: diff --git a/tests/unit/dbms_database.cpp b/tests/unit/dbms_database.cpp new file mode 100644 index 000000000..e2a6bf449 --- /dev/null +++ b/tests/unit/dbms_database.cpp @@ -0,0 +1,224 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include +#include + +#include "dbms/database_handler.hpp" +#include "dbms/global.hpp" + +#include "license/license.hpp" +#include "query_plan_common.hpp" +#include "storage/v2/view.hpp" + +std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_database"}; + +memgraph::storage::Config default_conf(std::string name = "") { + return {.durability = {.storage_directory = storage_directory / name, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / name / "disk"}}; +} + +class DBMS_Database : public ::testing::Test { + protected: + void SetUp() override { Clear(); } + + void TearDown() override { Clear(); } + + private: + void Clear() { + if (std::filesystem::exists(storage_directory)) { + std::filesystem::remove_all(storage_directory); + } + } +}; + +#ifdef MG_ENTERPRISE +TEST_F(DBMS_Database, New) { + memgraph::dbms::DatabaseHandler db_handler; + { ASSERT_FALSE(db_handler.GetConfig("db1")); } + { // With custom config + memgraph::storage::Config db_config{ + .durability = {.storage_directory = storage_directory / "db2", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + auto db2 = db_handler.New("db2", db_config); + ASSERT_TRUE(db2.HasValue() && db2.GetValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "db2")); + } + { + // With default config + auto db3 = db_handler.New("db3", default_conf("db3")); + ASSERT_TRUE(db3.HasValue() && db3.GetValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "db3")); + auto db4 = db_handler.New("db4", default_conf("four")); + ASSERT_TRUE(db4.HasValue() && db4.GetValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "four")); + auto db5 = db_handler.New("db5", default_conf("db3")); + ASSERT_TRUE(db5.HasError() && db5.GetError() == memgraph::dbms::NewError::EXISTS); + } + + auto all = db_handler.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 3); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + ASSERT_EQ(all[2], "db4"); +} + +TEST_F(DBMS_Database, Get) { + memgraph::dbms::DatabaseHandler db_handler; + + auto db1 = db_handler.New("db1", default_conf("db1")); + auto db2 = db_handler.New("db2", default_conf("db2")); + auto db3 = db_handler.New("db3", default_conf("db3")); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + auto get_db1 = db_handler.Get("db1"); + auto get_db2 = db_handler.Get("db2"); + auto get_db3 = db_handler.Get("db3"); + + ASSERT_TRUE(get_db1 && *get_db1 == db1.GetValue()); + ASSERT_TRUE(get_db2 && *get_db2 == db2.GetValue()); + ASSERT_TRUE(get_db3 && *get_db3 == db3.GetValue()); + + ASSERT_FALSE(db_handler.Get("db123")); + ASSERT_FALSE(db_handler.Get("db2 ")); + ASSERT_FALSE(db_handler.Get(" db3")); +} + +TEST_F(DBMS_Database, Delete) { + memgraph::dbms::DatabaseHandler db_handler; + + auto db1 = db_handler.New("db1", default_conf("db1")); + auto db2 = db_handler.New("db2", default_conf("db2")); + auto db3 = db_handler.New("db3", default_conf("db3")); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + { + // Release accessor to storage + db1.GetValue().reset(); + // Delete from handler + ASSERT_TRUE(db_handler.Delete("db1")); + ASSERT_FALSE(db_handler.Get("db1")); + auto all = db_handler.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 2); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + } + + { + ASSERT_THROW(db_handler.Delete("db0"), memgraph::utils::BasicException); + ASSERT_THROW(db_handler.Delete("db1"), memgraph::utils::BasicException); + auto all = db_handler.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 2); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + } +} + +TEST_F(DBMS_Database, DeleteAndRecover) { + memgraph::license::global_license_checker.EnableTesting(); + memgraph::dbms::DatabaseHandler db_handler; + + { + auto db1 = db_handler.New("db1", default_conf("db1")); + auto db2 = db_handler.New("db2", default_conf("db2")); + + memgraph::storage::Config conf_w_snap{ + .durability = {.storage_directory = storage_directory / "db3", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL, + .snapshot_on_exit = true}, + .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; + + auto db3 = db_handler.New("db3", conf_w_snap); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + // Add data to graphs + { + auto storage_dba = db1.GetValue()->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + memgraph::query::VertexAccessor v1{dba.InsertVertex()}; + memgraph::query::VertexAccessor v2{dba.InsertVertex()}; + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l11")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l12")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + { + auto storage_dba = db3.GetValue()->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + memgraph::query::VertexAccessor v1{dba.InsertVertex()}; + memgraph::query::VertexAccessor v2{dba.InsertVertex()}; + memgraph::query::VertexAccessor v3{dba.InsertVertex()}; + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l31")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l32")).HasValue()); + ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l33")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + } + + // Delete from handler + ASSERT_TRUE(db_handler.Delete("db1")); + ASSERT_TRUE(db_handler.Delete("db2")); + ASSERT_TRUE(db_handler.Delete("db3")); + + { + // Recover graphs (only db3) + auto db1 = db_handler.New("db1", default_conf("db1")); + auto db2 = db_handler.New("db2", default_conf("db2")); + + memgraph::storage::Config conf_w_rec{ + .durability = {.storage_directory = storage_directory / "db3", + .recover_on_startup = true, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; + + auto db3 = db_handler.New("db3", conf_w_rec); + + // Check content + { + // Empty + auto storage_dba = db1.GetValue()->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 0); + } + { + // Empty + auto storage_dba = db2.GetValue()->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 0); + } + { + // Full + auto storage_dba = db3.GetValue()->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 3); + } + } +} + +#endif diff --git a/tests/unit/dbms_handler.cpp b/tests/unit/dbms_handler.cpp new file mode 100644 index 000000000..13fe7e69b --- /dev/null +++ b/tests/unit/dbms_handler.cpp @@ -0,0 +1,193 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/auth_query_handler.hpp" +#ifdef MG_ENTERPRISE +#include +#include +#include +#include + +#include "dbms/constants.hpp" +#include "dbms/dbms_handler.hpp" +#include "dbms/global.hpp" +#include "glue/auth_checker.hpp" +#include "glue/auth_handler.hpp" +#include "query/config.hpp" +#include "query/interpreter.hpp" + +// Global +std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"}; +static memgraph::storage::Config storage_conf; +std::unique_ptr> auth; + +// Let this be global so we can test it different states throughout + +class TestEnvironment : public ::testing::Environment { + public: + static memgraph::dbms::DbmsHandler *get() { return ptr_.get(); } + + void SetUp() override { + // Setup config + memgraph::storage::UpdatePaths(storage_conf, storage_directory); + storage_conf.durability.snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL; + // Clean storage directory (running multiple parallel test, run only if the first process) + if (std::filesystem::exists(storage_directory)) { + memgraph::utils::OutputFile lock_file_handle_; + lock_file_handle_.Open(storage_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING); + if (lock_file_handle_.AcquireLock()) { + std::filesystem::remove_all(storage_directory); + } + } + auth = + std::make_unique>( + storage_directory / "auth"); + ptr_ = std::make_unique(storage_conf, auth.get(), false, true); + } + + void TearDown() override { + ptr_.reset(); + auth.reset(); + } + + static std::unique_ptr ptr_; +}; + +std::unique_ptr TestEnvironment::ptr_ = nullptr; + +class DBMS_Handler : public testing::Test {}; +using DBMS_HandlerDeath = DBMS_Handler; + +TEST(DBMS_Handler, Init) { + // Check that the default db has been created successfully + std::vector dirs = {"snapshots", "streams", "triggers", "wal"}; + for (const auto &dir : dirs) + ASSERT_TRUE(std::filesystem::exists(storage_directory / dir)) << (storage_directory / dir); + const auto db_path = storage_directory / "databases" / memgraph::dbms::kDefaultDB; + ASSERT_TRUE(std::filesystem::exists(db_path)); + for (const auto &dir : dirs) { + std::error_code ec; + const auto test_link = std::filesystem::read_symlink(db_path / dir, ec); + ASSERT_TRUE(!ec) << ec.message(); + ASSERT_EQ(test_link, "../../" + dir); + } +} + +TEST(DBMS_Handler, New) { + auto &dbms = *TestEnvironment::get(); + { + const auto all = dbms.All(); + ASSERT_EQ(all.size(), 1); + ASSERT_EQ(all[0], memgraph::dbms::kDefaultDB); + } + { + auto db1 = dbms.New("db1"); + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db1.GetValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "db1")); + ASSERT_TRUE(db1.GetValue()->storage() != nullptr); + ASSERT_TRUE(db1.GetValue()->streams() != nullptr); + ASSERT_TRUE(db1.GetValue()->trigger_store() != nullptr); + ASSERT_TRUE(db1.GetValue()->thread_pool() != nullptr); + const auto all = dbms.All(); + ASSERT_EQ(all.size(), 2); + ASSERT_TRUE(std::find(all.begin(), all.end(), memgraph::dbms::kDefaultDB) != all.end()); + ASSERT_TRUE(std::find(all.begin(), all.end(), "db1") != all.end()); + } + { + // Fail if name exists + auto db2 = dbms.New("db1"); + ASSERT_TRUE(db2.HasError() && db2.GetError() == memgraph::dbms::NewError::EXISTS); + } + { + auto db3 = dbms.New("db3"); + ASSERT_TRUE(db3.HasValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "db3")); + ASSERT_TRUE(db3.GetValue()->storage() != nullptr); + ASSERT_TRUE(db3.GetValue()->streams() != nullptr); + ASSERT_TRUE(db3.GetValue()->trigger_store() != nullptr); + ASSERT_TRUE(db3.GetValue()->thread_pool() != nullptr); + const auto all = dbms.All(); + ASSERT_EQ(all.size(), 3); + ASSERT_TRUE(std::find(all.begin(), all.end(), "db3") != all.end()); + } +} + +TEST(DBMS_Handler, Get) { + auto &dbms = *TestEnvironment::get(); + auto default_db = dbms.Get(memgraph::dbms::kDefaultDB); + ASSERT_TRUE(default_db); + ASSERT_TRUE(default_db->storage() != nullptr); + ASSERT_TRUE(default_db->streams() != nullptr); + ASSERT_TRUE(default_db->trigger_store() != nullptr); + ASSERT_TRUE(default_db->thread_pool() != nullptr); + + ASSERT_ANY_THROW(dbms.Get("non-existent")); + + auto db1 = dbms.Get("db1"); + ASSERT_TRUE(db1); + ASSERT_TRUE(db1->storage() != nullptr); + ASSERT_TRUE(db1->streams() != nullptr); + ASSERT_TRUE(db1->trigger_store() != nullptr); + ASSERT_TRUE(db1->thread_pool() != nullptr); + + auto db3 = dbms.Get("db3"); + ASSERT_TRUE(db3); + ASSERT_TRUE(db3->storage() != nullptr); + ASSERT_TRUE(db3->streams() != nullptr); + ASSERT_TRUE(db3->trigger_store() != nullptr); + ASSERT_TRUE(db3->thread_pool() != nullptr); +} + +TEST(DBMS_Handler, Delete) { + auto &dbms = *TestEnvironment::get(); + + auto db1_acc = dbms.Get("db1"); // Holds access to database + + { + auto del = dbms.Delete(memgraph::dbms::kDefaultDB); + ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::DEFAULT_DB); + } + { + auto del = dbms.Delete("non-existent"); + ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); + } + { + // db1_acc is using db1 + auto del = dbms.Delete("db1"); + ASSERT_TRUE(del.HasError()); + ASSERT_TRUE(del.GetError() == memgraph::dbms::DeleteError::USING); + } + { + // Reset db1_acc (releases access) so delete will succeed + db1_acc.reset(); + ASSERT_FALSE(db1_acc); + auto del = dbms.Delete("db1"); + ASSERT_FALSE(del.HasError()) << (int)del.GetError(); + auto del2 = dbms.Delete("db1"); + ASSERT_TRUE(del2.HasError() && del2.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); + } + { + auto del = dbms.Delete("db3"); + ASSERT_FALSE(del.HasError()); + ASSERT_FALSE(std::filesystem::exists(storage_directory / "databases" / "db3")); + } +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + // gtest takes ownership of the TestEnvironment ptr - we don't delete it. + ::testing::AddGlobalTestEnvironment(new TestEnvironment); + return RUN_ALL_TESTS(); +} + +#endif diff --git a/tests/unit/dbms_interp.cpp b/tests/unit/dbms_interp.cpp deleted file mode 100644 index 5dbd51150..000000000 --- a/tests/unit/dbms_interp.cpp +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#ifdef MG_ENTERPRISE - -#include -#include -#include - -#include "dbms/global.hpp" -#include "dbms/interp_handler.hpp" - -#include "query/auth_checker.hpp" -#include "query/frontend/ast/ast.hpp" -#include "query/interpreter.hpp" - -class TestAuthHandler : public memgraph::query::AuthQueryHandler { - public: - TestAuthHandler() = default; - - bool CreateUser(const std::string & /*username*/, const std::optional & /*password*/) override { - return true; - } - bool DropUser(const std::string & /*username*/) override { return true; } - void SetPassword(const std::string & /*username*/, const std::optional & /*password*/) override {} - bool RevokeDatabaseFromUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; } - bool GrantDatabaseToUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; } - bool SetMainDatabase(const std::string & /*db*/, const std::string & /*username*/) override { return true; } - std::vector> GetDatabasePrivileges(const std::string & /*user*/) override { - return {}; - } - bool CreateRole(const std::string & /*rolename*/) override { return true; } - bool DropRole(const std::string & /*rolename*/) override { return true; } - std::vector GetUsernames() override { return {}; } - std::vector GetRolenames() override { return {}; } - std::optional GetRolenameForUser(const std::string & /*username*/) override { return {}; } - std::vector GetUsernamesForRole(const std::string & /*rolename*/) override { return {}; } - void SetRole(const std::string &username, const std::string & /*rolename*/) override {} - void ClearRole(const std::string &username) override {} - std::vector> GetPrivileges(const std::string & /*user_or_role*/) override { - return {}; - } - void GrantPrivilege( - const std::string & /*user_or_role*/, const std::vector & /*privileges*/, - const std::vector>> - & /*label_privileges*/, - const std::vector>> - & /*edge_type_privileges*/) override {} - void DenyPrivilege(const std::string & /*user_or_role*/, - const std::vector & /*privileges*/) override {} - void RevokePrivilege( - const std::string & /*user_or_role*/, const std::vector & /*privileges*/, - const std::vector>> - & /*label_privileges*/, - const std::vector>> - & /*edge_type_privileges*/) override {} -}; - -class TestAuthChecker : public memgraph::query::AuthChecker { - public: - bool IsUserAuthorized(const std::optional & /*username*/, - const std::vector & /*privileges*/, - const std::string & /*db*/) const override { - return true; - } - - std::unique_ptr GetFineGrainedAuthChecker( - const std::string & /*username*/, const memgraph::query::DbAccessor * /*db_accessor*/) const override { - return {}; - } - - void ClearCache() const override {} -}; - -std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_interp"}; - -memgraph::query::InterpreterConfig default_conf{}; - -memgraph::storage::Config default_storage_conf(std::string name = "") { - return {.durability = {.storage_directory = storage_directory / name, - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / name / "disk"}}; -} - -class TestHandler { -} test_handler; - -class DBMS_Interp : public ::testing::Test { - protected: - void SetUp() override { Clear(); } - - void TearDown() override { Clear(); } - - private: - void Clear() { - if (std::filesystem::exists(storage_directory)) { - std::filesystem::remove_all(storage_directory); - } - } -}; - -TEST_F(DBMS_Interp, New) { - memgraph::dbms::InterpContextHandler ih; - memgraph::storage::Config db_conf{ - .durability = {.storage_directory = storage_directory, - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - TestAuthHandler ah; - TestAuthChecker ac; - - { - // Clean initialization - auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); - ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "triggers")); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "streams")); - ASSERT_NE(ic1.GetValue()->db, nullptr); - ASSERT_EQ(&ic1.GetValue()->sc_handler_, &test_handler); - ASSERT_EQ(ih.GetConfig("ic1")->storage_config.durability.storage_directory, storage_directory); - } - { - memgraph::storage::Config db_conf2{ - .durability = {.storage_directory = storage_directory, - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - // Try to override data directory - auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); - ASSERT_TRUE(ic2.HasError() && ic2.GetError() == memgraph::dbms::NewError::EXISTS); - } - { - memgraph::storage::Config db_conf3{ - .durability = {.storage_directory = storage_directory / "ic3", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - // Try to override the name "ic1" - auto ic3 = ih.New("ic1", test_handler, db_conf3, default_conf, ah, ac); - ASSERT_TRUE(ic3.HasError() && ic3.GetError() == memgraph::dbms::NewError::EXISTS); - } - { - // Another clean initialization - memgraph::storage::Config db_conf4{ - .durability = {.storage_directory = storage_directory / "ic4", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - auto ic4 = ih.New("ic4", test_handler, db_conf4, default_conf, ah, ac); - ASSERT_TRUE(ic4.HasValue() && ic4.GetValue() != nullptr); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "triggers")); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "streams")); - ASSERT_EQ(&ic4.GetValue()->sc_handler_, &test_handler); - ASSERT_EQ(ih.GetConfig("ic4")->storage_config.durability.storage_directory, storage_directory / "ic4"); - } -} - -TEST_F(DBMS_Interp, Get) { - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - memgraph::storage::Config db_conf{ - .durability = {.storage_directory = storage_directory / "ic1", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); - ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); - - auto ic1_get = ih.Get("ic1"); - ASSERT_TRUE(ic1_get && *ic1_get == ic1.GetValue()); - - memgraph::storage::Config db_conf2{ - .durability = {.storage_directory = storage_directory / "ic2", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); - ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr); - - auto ic2_get = ih.Get("ic2"); - ASSERT_TRUE(ic2_get && *ic2_get == ic2.GetValue()); - - ASSERT_FALSE(ih.Get("aa")); - ASSERT_FALSE(ih.Get("ic1 ")); - ASSERT_FALSE(ih.Get("ic21")); - ASSERT_FALSE(ih.Get(" ic2")); -} - -TEST_F(DBMS_Interp, Delete) { - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - memgraph::storage::Config db_conf{ - .durability = {.storage_directory = storage_directory / "ic1", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - { - auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); - ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); - } - - memgraph::storage::Config db_conf2{ - .durability = {.storage_directory = storage_directory / "ic2", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - { - auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); - ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr); - } - - ASSERT_TRUE(ih.Delete("ic1")); - ASSERT_FALSE(ih.Get("ic1")); - ASSERT_FALSE(ih.Delete("ic1")); - ASSERT_FALSE(ih.Delete("ic3")); -} - -/** - * - * - * - * - * - * Test storage (previous StorageHandler, now handled via InterpretContext) - * - * - * - * - * - */ -TEST_F(DBMS_Interp, StorageNew) { - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - { ASSERT_FALSE(ih.GetConfig("db1")); } - { - // With custom config - memgraph::storage::Config db_config{ - .durability = {.storage_directory = storage_directory / "db2", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "disk"}}; - auto db2 = ih.New("db2", test_handler, db_config, default_conf, ah, ac); - ASSERT_TRUE(db2.HasValue() && db2.GetValue() != nullptr); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "db2")); - } - { - // With default config - auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); - ASSERT_TRUE(db3.HasValue() && db3.GetValue() != nullptr); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "db3")); - auto db4 = ih.New("db4", test_handler, default_storage_conf("four"), default_conf, ah, ac); - ASSERT_TRUE(db4.HasValue() && db4.GetValue() != nullptr); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "four")); - auto db5 = ih.New("db5", test_handler, default_storage_conf("db3"), default_conf, ah, ac); - ASSERT_TRUE(db5.HasError() && db5.GetError() == memgraph::dbms::NewError::EXISTS); - } - - auto all = ih.All(); - std::sort(all.begin(), all.end()); - ASSERT_EQ(all.size(), 3); - ASSERT_EQ(all[0], "db2"); - ASSERT_EQ(all[1], "db3"); - ASSERT_EQ(all[2], "db4"); -} - -TEST_F(DBMS_Interp, StorageGet) { - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); - auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); - auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); - - ASSERT_TRUE(db1.HasValue()); - ASSERT_TRUE(db2.HasValue()); - ASSERT_TRUE(db3.HasValue()); - - auto get_db1 = ih.Get("db1"); - auto get_db2 = ih.Get("db2"); - auto get_db3 = ih.Get("db3"); - - ASSERT_TRUE(get_db1 && *get_db1 == db1.GetValue()); - ASSERT_TRUE(get_db2 && *get_db2 == db2.GetValue()); - ASSERT_TRUE(get_db3 && *get_db3 == db3.GetValue()); - - ASSERT_FALSE(ih.Get("db123")); - ASSERT_FALSE(ih.Get("db2 ")); - ASSERT_FALSE(ih.Get(" db3")); -} - -TEST_F(DBMS_Interp, StorageDelete) { - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); - auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); - auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); - - ASSERT_TRUE(db1.HasValue()); - ASSERT_TRUE(db2.HasValue()); - ASSERT_TRUE(db3.HasValue()); - - { - // Release pointer to storage - db1.GetValue().reset(); - // Delete from handler - ASSERT_TRUE(ih.Delete("db1")); - ASSERT_FALSE(ih.Get("db1")); - auto all = ih.All(); - std::sort(all.begin(), all.end()); - ASSERT_EQ(all.size(), 2); - ASSERT_EQ(all[0], "db2"); - ASSERT_EQ(all[1], "db3"); - } - - { - ASSERT_FALSE(ih.Delete("db0")); - ASSERT_FALSE(ih.Delete("db1")); - auto all = ih.All(); - std::sort(all.begin(), all.end()); - ASSERT_EQ(all.size(), 2); - ASSERT_EQ(all[0], "db2"); - ASSERT_EQ(all[1], "db3"); - } -} - -TEST_F(DBMS_Interp, StorageDeleteAndRecover) { - // memgraph::license::global_license_checker.EnableTesting(); - memgraph::dbms::InterpContextHandler ih; - TestAuthHandler ah; - TestAuthChecker ac; - - { - auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); - auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); - - memgraph::storage::Config conf_w_snap{ - .durability = {.storage_directory = storage_directory / "db3", - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL, - .snapshot_on_exit = true}, - .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; - - auto db3 = ih.New("db3", test_handler, conf_w_snap, default_conf, ah, ac); - - ASSERT_TRUE(db1.HasValue()); - ASSERT_TRUE(db2.HasValue()); - ASSERT_TRUE(db3.HasValue()); - - // Add data to graphs - { - auto storage_dba = db1.GetValue()->db->Access(); - memgraph::query::DbAccessor dba{storage_dba.get()}; - memgraph::query::VertexAccessor v1{dba.InsertVertex()}; - memgraph::query::VertexAccessor v2{dba.InsertVertex()}; - ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l11")).HasValue()); - ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l12")).HasValue()); - ASSERT_FALSE(dba.Commit().HasError()); - } - { - auto storage_dba = db3.GetValue()->db->Access(); - memgraph::query::DbAccessor dba{storage_dba.get()}; - memgraph::query::VertexAccessor v1{dba.InsertVertex()}; - memgraph::query::VertexAccessor v2{dba.InsertVertex()}; - memgraph::query::VertexAccessor v3{dba.InsertVertex()}; - ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l31")).HasValue()); - ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l32")).HasValue()); - ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l33")).HasValue()); - ASSERT_FALSE(dba.Commit().HasError()); - } - } - - // Delete from handler - ASSERT_TRUE(ih.Delete("db1")); - ASSERT_TRUE(ih.Delete("db2")); - ASSERT_TRUE(ih.Delete("db3")); - - { - // Recover graphs (only db3) - auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); - auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); - - memgraph::storage::Config conf_w_rec{ - .durability = {.storage_directory = storage_directory / "db3", - .recover_on_startup = true, - .snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, - .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; - - auto db3 = ih.New("db3", test_handler, conf_w_rec, default_conf, ah, ac); - - // Check content - { - // Empty - auto storage_dba = db1.GetValue()->db->Access(); - memgraph::query::DbAccessor dba{storage_dba.get()}; - ASSERT_EQ(dba.VerticesCount(), 0); - } - { - // Empty - auto storage_dba = db2.GetValue()->db->Access(); - memgraph::query::DbAccessor dba{storage_dba.get()}; - ASSERT_EQ(dba.VerticesCount(), 0); - } - { - // Full - auto storage_dba = db3.GetValue()->db->Access(); - memgraph::query::DbAccessor dba{storage_dba.get()}; - ASSERT_EQ(dba.VerticesCount(), 3); - } - } -} - -#endif diff --git a/tests/unit/dbms_sc_handler.cpp b/tests/unit/dbms_sc_handler.cpp deleted file mode 100644 index a19d769cb..000000000 --- a/tests/unit/dbms_sc_handler.cpp +++ /dev/null @@ -1,343 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#include -#include "query/interpreter.hpp" -#ifdef MG_ENTERPRISE - -#include -#include -#include - -#include "dbms/constants.hpp" -#include "dbms/global.hpp" -#include "dbms/session_context_handler.hpp" -#include "glue/auth_checker.hpp" -#include "glue/auth_handler.hpp" -#include "query/config.hpp" - -std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_sc_handler"}; - -static memgraph::storage::Config storage_conf; - -memgraph::query::InterpreterConfig interp_conf; - -// Global -memgraph::audit::Log audit_log{storage_directory / "audit", 100, 1000}; - -class TestInterface : public memgraph::dbms::SessionInterface { - public: - TestInterface(std::string name, auto on_change, auto on_delete) : id_(id++), db_(name) { - on_change_ = on_change; - on_delete_ = on_delete; - } - std::string UUID() const override { return std::to_string(id_); } - std::string GetDatabaseName() const override { return db_; } - memgraph::dbms::SetForResult OnChange(const std::string &name) override { return on_change_(name); } - bool OnDelete(const std::string &name) override { return on_delete_(name); } - - static int id; - int id_; - std::string db_; - std::function on_change_; - std::function on_delete_; -}; - -int TestInterface::id{0}; - -// Let this be global so we can test it different states throughout - -class TestEnvironment : public ::testing::Environment { - public: - static memgraph::dbms::SessionContextHandler *get() { return ptr_.get(); } - - void SetUp() override { - // Setup config - memgraph::storage::UpdatePaths(storage_conf, storage_directory); - storage_conf.durability.snapshot_wal_mode = - memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL; - // Clean storage directory (running multiple parallel test, run only if the first process) - if (std::filesystem::exists(storage_directory)) { - memgraph::utils::OutputFile lock_file_handle_; - lock_file_handle_.Open(storage_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING); - if (lock_file_handle_.AcquireLock()) { - std::filesystem::remove_all(storage_directory); - } - } - ptr_ = std::make_unique( - audit_log, - memgraph::dbms::SessionContextHandler::Config{ - storage_conf, interp_conf, - [](memgraph::utils::Synchronized *auth, - std::unique_ptr &ah, - std::unique_ptr &ac) { - // Glue high level auth implementations to the query side - ah = std::make_unique(auth, ""); - ac = std::make_unique(auth); - }}, - false, true); - } - - void TearDown() override { ptr_.reset(); } - - static std::unique_ptr ptr_; -}; - -std::unique_ptr TestEnvironment::ptr_ = nullptr; - -class DBMS_Handler : public testing::Test {}; -using DBMS_HandlerDeath = DBMS_Handler; - -TEST(DBMS_Handler, Init) { - // Check that the default db has been created successfully - std::vector dirs = {"snapshots", "streams", "triggers", "wal"}; - for (const auto &dir : dirs) - ASSERT_TRUE(std::filesystem::exists(storage_directory / dir)) << (storage_directory / dir); - const auto db_path = storage_directory / "databases" / memgraph::dbms::kDefaultDB; - ASSERT_TRUE(std::filesystem::exists(db_path)); - for (const auto &dir : dirs) { - std::error_code ec; - const auto test_link = std::filesystem::read_symlink(db_path / dir, ec); - ASSERT_TRUE(!ec) << ec.message(); - ASSERT_EQ(test_link, "../../" + dir); - } -} - -TEST(DBMS_HandlerDeath, InitSameDir) { - // This will be executed in a clean process (so the singleton will NOT be initalized) - (void)(::testing::GTEST_FLAG(death_test_style) = "threadsafe"); - // NOTE: Init test has ran in another process (so holds the lock) - ASSERT_DEATH( - { - memgraph::dbms::SessionContextHandler sch( - audit_log, - {storage_conf, interp_conf, - [](memgraph::utils::Synchronized *auth, - std::unique_ptr &ah, - std::unique_ptr &ac) { - // Glue high level auth implementations to the query side - ah = std::make_unique(auth, ""); - ac = std::make_unique(auth); - }}, - false, true); - }, - R"(\b.*\b)"); -} - -TEST(DBMS_Handler, New) { - auto &sch = *TestEnvironment::get(); - { - const auto all = sch.All(); - ASSERT_EQ(all.size(), 1); - ASSERT_EQ(all[0], memgraph::dbms::kDefaultDB); - } - { - auto sc1 = sch.New("sc1"); - ASSERT_TRUE(sc1.HasValue()); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc1")); - ASSERT_TRUE(sc1.GetValue().interpreter_context->db != nullptr); - ASSERT_TRUE(sc1.GetValue().interpreter_context != nullptr); - ASSERT_TRUE(sc1.GetValue().audit_log != nullptr); - ASSERT_TRUE(sc1.GetValue().auth != nullptr); - const auto all = sch.All(); - ASSERT_EQ(all.size(), 2); - ASSERT_TRUE(std::find(all.begin(), all.end(), memgraph::dbms::kDefaultDB) != all.end()); - ASSERT_TRUE(std::find(all.begin(), all.end(), "sc1") != all.end()); - } - { - // Fail if name exists - auto sc2 = sch.New("sc1"); - ASSERT_TRUE(sc2.HasError() && sc2.GetError() == memgraph::dbms::NewError::EXISTS); - } - { - auto sc3 = sch.New("sc3"); - ASSERT_TRUE(sc3.HasValue()); - ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc3")); - ASSERT_TRUE(sc3.GetValue().interpreter_context->db != nullptr); - ASSERT_TRUE(sc3.GetValue().interpreter_context != nullptr); - ASSERT_TRUE(sc3.GetValue().audit_log != nullptr); - ASSERT_TRUE(sc3.GetValue().auth != nullptr); - const auto all = sch.All(); - ASSERT_EQ(all.size(), 3); - ASSERT_TRUE(std::find(all.begin(), all.end(), "sc3") != all.end()); - } -} - -TEST(DBMS_Handler, Get) { - auto &sch = *TestEnvironment::get(); - auto default_sc = sch.Get(memgraph::dbms::kDefaultDB); - ASSERT_TRUE(default_sc.interpreter_context->db != nullptr); - ASSERT_TRUE(default_sc.interpreter_context != nullptr); - ASSERT_TRUE(default_sc.audit_log != nullptr); - ASSERT_TRUE(default_sc.auth != nullptr); - - ASSERT_ANY_THROW(sch.Get("non-existent")); - - auto sc1 = sch.Get("sc1"); - ASSERT_TRUE(sc1.interpreter_context->db != nullptr); - ASSERT_TRUE(sc1.interpreter_context != nullptr); - ASSERT_TRUE(sc1.audit_log != nullptr); - ASSERT_TRUE(sc1.auth != nullptr); - - auto sc3 = sch.Get("sc3"); - ASSERT_TRUE(sc3.interpreter_context->db != nullptr); - ASSERT_TRUE(sc3.interpreter_context != nullptr); - ASSERT_TRUE(sc3.audit_log != nullptr); - ASSERT_TRUE(sc3.auth != nullptr); -} - -TEST(DBMS_Handler, SetFor) { - auto &sch = *TestEnvironment::get(); - - ASSERT_TRUE(sch.New("db1").HasValue()); - - bool ti0_on_change_ = false; - bool ti0_on_delete_ = false; - TestInterface ti0( - "memgraph", - [&ti0, &ti0_on_change_](const std::string &name) -> memgraph::dbms::SetForResult { - ti0_on_change_ = true; - if (name != ti0.db_) { - ti0.db_ = name; - return memgraph::dbms::SetForResult::SUCCESS; - } - return memgraph::dbms::SetForResult::ALREADY_SET; - }, - [&](const std::string &name) -> bool { - ti0_on_delete_ = true; - return true; - }); - - bool ti1_on_change_ = false; - bool ti1_on_delete_ = false; - TestInterface ti1( - "db1", - [&](const std::string &name) -> memgraph::dbms::SetForResult { - ti1_on_change_ = true; - return memgraph::dbms::SetForResult::SUCCESS; - }, - [&](const std::string &name) -> bool { - ti1_on_delete_ = true; - return true; - }); - - ASSERT_TRUE(sch.Register(ti0)); - ASSERT_FALSE(sch.Register(ti0)); - - { - ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::SUCCESS); - ASSERT_TRUE(ti0_on_change_); - ti0_on_change_ = false; - ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::ALREADY_SET); - ASSERT_TRUE(ti0_on_change_); - ti0_on_change_ = false; - ASSERT_ANY_THROW(sch.SetFor(std::to_string(TestInterface::id), "db1")); // Session does not exist - ASSERT_ANY_THROW(sch.SetFor("1", "db1")); // Session not registered - ASSERT_ANY_THROW(sch.SetFor("0", "db2")); // No db2 - ASSERT_EQ(sch.SetFor("0", "memgraph"), memgraph::dbms::SetForResult::SUCCESS); - ASSERT_TRUE(ti0_on_change_); - } - - ASSERT_TRUE(sch.Delete(ti0)); - ASSERT_FALSE(sch.Delete(ti1)); -} - -TEST(DBMS_Handler, Delete) { - auto &sch = *TestEnvironment::get(); - - bool ti0_on_change_ = false; - bool ti0_on_delete_ = false; - TestInterface ti0( - "memgraph", - [&](const std::string &name) -> memgraph::dbms::SetForResult { - ti0_on_change_ = true; - if (name != "sc3") return memgraph::dbms::SetForResult::SUCCESS; - return memgraph::dbms::SetForResult::FAIL; - }, - [&](const std::string &name) -> bool { - ti0_on_delete_ = true; - return (name != "sc3"); - }); - - bool ti1_on_change_ = false; - bool ti1_on_delete_ = false; - TestInterface ti1( - "sc1", - [&](const std::string &name) -> memgraph::dbms::SetForResult { - ti1_on_change_ = true; - ti1.db_ = name; - return memgraph::dbms::SetForResult::SUCCESS; - }, - [&](const std::string &name) -> bool { - ti1_on_delete_ = true; - return ti1.db_ != name; - }); - - ASSERT_TRUE(sch.Register(ti0)); - ASSERT_TRUE(sch.Register(ti1)); - - { - auto del = sch.Delete(memgraph::dbms::kDefaultDB); - ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::DEFAULT_DB); - } - { - auto del = sch.Delete("non-existent"); - ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); - } - { - // ti1 is using sc1 - auto del = sch.Delete("sc1"); - ASSERT_TRUE(del.HasError()); - ASSERT_TRUE(del.GetError() == memgraph::dbms::DeleteError::FAIL); - } - { - // Delete ti1 so delete will succeed - ASSERT_EQ(sch.SetFor(ti1.UUID(), "memgraph"), memgraph::dbms::SetForResult::SUCCESS); - auto del = sch.Delete("sc1"); - ASSERT_FALSE(del.HasError()) << (int)del.GetError(); - auto del2 = sch.Delete("sc1"); - ASSERT_TRUE(del2.HasError() && del2.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); - } - { - // Using based on the active interpreters - auto new_sc = sch.New("sc1"); - ASSERT_TRUE(new_sc.HasValue()) << (int)new_sc.GetError(); - auto sc = sch.Get("sc1"); - memgraph::query::Interpreter interpreter(sc.interpreter_context.get()); - sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.insert(&interpreter); }); - auto del = sch.Delete("sc1"); - ASSERT_TRUE(del.HasError()); - ASSERT_EQ(del.GetError(), memgraph::dbms::DeleteError::USING); - sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.erase(&interpreter); }); - } - { - // Interpreter deactivated, so we should be able to delete - auto del = sch.Delete("sc1"); - ASSERT_FALSE(del.HasError()) << (int)del.GetError(); - } - { - ASSERT_TRUE(sch.Delete(ti0)); - auto del = sch.Delete("sc3"); - ASSERT_FALSE(del.HasError()); - ASSERT_FALSE(std::filesystem::exists(storage_directory / "databases" / "sc3")); - } - - ASSERT_TRUE(sch.Delete(ti1)); -} - -int main(int argc, char *argv[]) { - ::testing::InitGoogleTest(&argc, argv); - // gtest takes ownership of the TestEnvironment ptr - we don't delete it. - ::testing::AddGlobalTestEnvironment(new TestEnvironment); - return RUN_ALL_TESTS(); -} - -#endif diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index 1115fd3be..ff8df198e 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -26,12 +26,14 @@ #include "query/config.hpp" #include "query/exceptions.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" #include "query/stream.hpp" #include "query/typed_value.hpp" #include "query_common.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/property_value.hpp" +#include "storage/v2/storage_mode.hpp" #include "utils/logging.hpp" namespace { @@ -49,20 +51,43 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) { // TODO: This is not a unit test, but tests/integration dir is chaotic at the // moment. After tests refactoring is done, move/rename this. +constexpr auto kNoHandler = nullptr; + template class InterpreterTest : public ::testing::Test { public: const std::string testSuite = "interpreter"; const std::string testSuiteCsv = "interpreter_csv"; + std::filesystem::path data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"; - InterpreterTest() - : data_directory(std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"), - interpreter_context(std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), {}, - data_directory) { - memgraph::flags::run_time::execution_timeout_sec_ = 600.0; - } + InterpreterTest() : interpreter_context({}, kNoHandler) { memgraph::flags::run_time::execution_timeout_sec_ = 600.0; } + + memgraph::utils::Gatekeeper db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; + + memgraph::dbms::DatabaseAccess db{ + [&]() { + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + return db_acc; + }() // iile + }; - std::filesystem::path data_directory; memgraph::query::InterpreterContext interpreter_context; void TearDown() override { @@ -72,7 +97,7 @@ class InterpreterTest : public ::testing::Test { } } - InterpreterFaker default_interpreter{&interpreter_context}; + InterpreterFaker default_interpreter{&interpreter_context, db}; auto Prepare(const std::string &query, const std::map ¶ms = {}) { return default_interpreter.Prepare(query, params); @@ -320,7 +345,7 @@ TYPED_TEST(InterpreterTest, Bfs) { // Set up. { - auto storage_dba = this->interpreter_context.db->Access(); + auto storage_dba = this->db->Access(); memgraph::query::DbAccessor dba(storage_dba.get()); auto add_node = [&](int level, bool reachable) { auto node = dba.InsertVertex(); @@ -633,7 +658,7 @@ TYPED_TEST(InterpreterTest, UniqueConstraintTest) { } TYPED_TEST(InterpreterTest, ExplainQuery) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto stream = this->Interpret("EXPLAIN MATCH (n) RETURN *;"); ASSERT_EQ(stream.GetHeader().size(), 1U); @@ -647,16 +672,16 @@ TYPED_TEST(InterpreterTest, ExplainQuery) { ++expected_it; } // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for EXPLAIN ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) RETURN *;"); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ExplainQueryMultiplePulls) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto [stream, qid] = this->Prepare("EXPLAIN MATCH (n) RETURN *;"); ASSERT_EQ(stream.GetHeader().size(), 1U); @@ -680,16 +705,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryMultiplePulls) { ASSERT_EQ(stream.GetResults()[2].size(), 1U); EXPECT_EQ(stream.GetResults()[2].front().ValueString(), *expected_it); // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for EXPLAIN ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) RETURN *;"); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ExplainQueryInMulticommandTransaction) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); this->Interpret("BEGIN"); auto stream = this->Interpret("EXPLAIN MATCH (n) RETURN *;"); @@ -705,16 +730,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryInMulticommandTransaction) { ++expected_it; } // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for EXPLAIN ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) RETURN *;"); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ExplainQueryWithParams) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto stream = this->Interpret("EXPLAIN MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); @@ -729,16 +754,16 @@ TYPED_TEST(InterpreterTest, ExplainQueryWithParams) { ++expected_it; } // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for EXPLAIN ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ProfileQuery) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto stream = this->Interpret("PROFILE MATCH (n) RETURN *;"); std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; @@ -752,16 +777,16 @@ TYPED_TEST(InterpreterTest, ProfileQuery) { ++expected_it; } // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for PROFILE ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) RETURN *;"); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ProfileQueryMultiplePulls) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto [stream, qid] = this->Prepare("PROFILE MATCH (n) RETURN *;"); std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; @@ -788,11 +813,11 @@ TYPED_TEST(InterpreterTest, ProfileQueryMultiplePulls) { ASSERT_EQ(stream.GetResults()[2][0].ValueString(), *expected_it); // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for PROFILE ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) RETURN *;"); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } @@ -803,7 +828,7 @@ TYPED_TEST(InterpreterTest, ProfileQueryInMulticommandTransaction) { } TYPED_TEST(InterpreterTest, ProfileQueryWithParams) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto stream = this->Interpret("PROFILE MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue(42)}}); @@ -818,16 +843,16 @@ TYPED_TEST(InterpreterTest, ProfileQueryWithParams) { ++expected_it; } // We should have a plan cache for MATCH ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for PROFILE ... and for inner MATCH ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("MATCH (n) WHERE n.id = $id RETURN *;", {{"id", memgraph::storage::PropertyValue("something else")}}); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } TYPED_TEST(InterpreterTest, ProfileQueryWithLiterals) { - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 0U); + EXPECT_EQ(this->db->plan_cache()->size(), 0U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 0U); auto stream = this->Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {}); std::vector expected_header{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}; @@ -841,11 +866,11 @@ TYPED_TEST(InterpreterTest, ProfileQueryWithLiterals) { ++expected_it; } // We should have a plan cache for UNWIND ... - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); // We should have AST cache for PROFILE ... and for inner UNWIND ... EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); this->Interpret("UNWIND range(42, 4242) AS x CREATE (:Node {id: x});", {}); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 2U); } @@ -1071,7 +1096,7 @@ TYPED_TEST(InterpreterTest, CacheableQueries) { SCOPED_TRACE("Cacheable query"); this->Interpret("RETURN 1"); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 1U); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); } { @@ -1080,7 +1105,7 @@ TYPED_TEST(InterpreterTest, CacheableQueries) { // result signature could be changed this->Interpret("CALL mg.load_all()"); EXPECT_EQ(this->interpreter_context.ast_cache.size(), 1U); - EXPECT_EQ(this->interpreter_context.plan_cache.size(), 1U); + EXPECT_EQ(this->db->plan_cache()->size(), 1U); } } @@ -1098,11 +1123,25 @@ TYPED_TEST(InterpreterTest, AllowLoadCsvConfig) { "CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN " "row"}; - memgraph::query::InterpreterContext csv_interpreter_context{ - std::make_unique(disk_test_utils::GenerateOnDiskConfig(this->testSuiteCsv)), - {.query = {.allow_load_csv = allow_load_csv}}, - directory_manager.Path()}; - InterpreterFaker interpreter_faker{&csv_interpreter_context}; + memgraph::storage::Config config2{}; + config2.durability.storage_directory = directory_manager.Path(); + config2.disk.main_storage_directory = config2.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config2.disk = disk_test_utils::GenerateOnDiskConfig(this->testSuiteCsv).disk; + config2.force_on_disk = true; + } + + memgraph::utils::Gatekeeper db_gk2(config2); + auto db_acc_opt = db_gk2.access(); + ASSERT_TRUE(db_acc_opt) << "Failed to access db2"; + auto &db_acc = *db_acc_opt; + ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) + << "Wrong storage mode!"; + + memgraph::query::InterpreterContext csv_interpreter_context{{.query = {.allow_load_csv = allow_load_csv}}, nullptr}; + InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc}; for (const auto &query : queries) { if (allow_load_csv) { SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query)); diff --git a/tests/unit/interpreter_faker.hpp b/tests/unit/interpreter_faker.hpp index 63d4516f3..a303007ec 100644 --- a/tests/unit/interpreter_faker.hpp +++ b/tests/unit/interpreter_faker.hpp @@ -11,16 +11,17 @@ #include "communication/result_stream_faker.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" struct InterpreterFaker { - InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context) - : interpreter_context(interpreter_context), interpreter(interpreter_context) { + InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db) + : interpreter_context(interpreter_context), interpreter(interpreter_context, db) { interpreter_context->auth_checker = &auth_checker; interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); }); } auto Prepare(const std::string &query, const std::map ¶ms = {}) { - ResultStreamFaker stream(interpreter_context->db.get()); + ResultStreamFaker stream(interpreter.db_acc_->get()->storage()); const auto [header, _1, qid, _2] = interpreter.Prepare(query, params, nullptr); stream.Header(header); return std::make_pair(std::move(stream), qid); diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index a11c1e65b..de7530274 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -17,10 +17,13 @@ #include #include "communication/result_stream_faker.hpp" +#include "dbms/database.hpp" #include "disk_test_utils.hpp" #include "query/config.hpp" #include "query/dump.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" +#include "query/stream/streams.hpp" #include "query/typed_value.hpp" #include "storage/v2/disk/storage.hpp" #include "storage/v2/edge_accessor.hpp" @@ -205,9 +208,10 @@ DatabaseState GetState(memgraph::storage::Storage *db) { return {vertices, edges, label_indices, label_property_indices, existence_constraints, unique_constraints}; } -auto Execute(memgraph::query::InterpreterContext *context, const std::string &query) { - memgraph::query::Interpreter interpreter(context); - ResultStreamFaker stream(context->db.get()); +auto Execute(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db, + const std::string &query) { + memgraph::query::Interpreter interpreter(context, db); + ResultStreamFaker stream(db->storage()); auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, nullptr); stream.Header(header); @@ -275,14 +279,40 @@ class DumpTest : public ::testing::Test { public: const std::string testSuite = "query_dump"; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump_class"}; - memgraph::query::InterpreterContext context{ - std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), - memgraph::query::InterpreterConfig{}, data_directory}; + + memgraph::utils::Gatekeeper db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; + + memgraph::dbms::DatabaseAccess db{ + [&]() { + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + return db_acc; + }() // iile + }; + + memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr}; void TearDown() override { if (std::is_same::value) { disk_test_utils::RemoveRocksDbDirs(testSuite); } + std::filesystem::remove_all(data_directory); } }; @@ -291,10 +321,10 @@ TYPED_TEST_CASE(DumpTest, StorageTypes); // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, EmptyGraph) { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -304,16 +334,16 @@ TYPED_TEST(DumpTest, EmptyGraph) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, SingleVertex) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {}, {}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -325,16 +355,16 @@ TYPED_TEST(DumpTest, SingleVertex) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, VertexWithSingleLabel) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {"Label1"}, {}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -346,16 +376,16 @@ TYPED_TEST(DumpTest, VertexWithSingleLabel) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, VertexWithMultipleLabels) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {"Label1", "Label 2"}, {}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -368,16 +398,16 @@ TYPED_TEST(DumpTest, VertexWithMultipleLabels) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, VertexWithSingleProperty) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {}, {{"prop", memgraph::storage::PropertyValue(42)}}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -389,7 +419,7 @@ TYPED_TEST(DumpTest, VertexWithSingleProperty) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, MultipleVertices) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {}, {}, false); CreateVertex(dba.get(), {}, {}, false); CreateVertex(dba.get(), {}, {}, false); @@ -397,10 +427,10 @@ TYPED_TEST(DumpTest, MultipleVertices) { } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -412,7 +442,7 @@ TYPED_TEST(DumpTest, MultipleVertices) { TYPED_TEST(DumpTest, PropertyValue) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto null_value = memgraph::storage::PropertyValue(); auto int_value = memgraph::storage::PropertyValue(13); auto bool_value = memgraph::storage::PropertyValue(true); @@ -435,10 +465,10 @@ TYPED_TEST(DumpTest, PropertyValue) { } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -454,7 +484,7 @@ TYPED_TEST(DumpTest, PropertyValue) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, SingleEdge) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto u = CreateVertex(dba.get(), {}, {}, false); auto v = CreateVertex(dba.get(), {}, {}, false); CreateEdge(dba.get(), &u, &v, "EdgeType", {}, false); @@ -462,10 +492,10 @@ TYPED_TEST(DumpTest, SingleEdge) { } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -480,7 +510,7 @@ TYPED_TEST(DumpTest, SingleEdge) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, MultipleEdges) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto u = CreateVertex(dba.get(), {}, {}, false); auto v = CreateVertex(dba.get(), {}, {}, false); auto w = CreateVertex(dba.get(), {}, {}, false); @@ -491,10 +521,10 @@ TYPED_TEST(DumpTest, MultipleEdges) { } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -513,7 +543,7 @@ TYPED_TEST(DumpTest, MultipleEdges) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, EdgeWithProperties) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto u = CreateVertex(dba.get(), {}, {}, false); auto v = CreateVertex(dba.get(), {}, {}, false); CreateEdge(dba.get(), &u, &v, "EdgeType", {{"prop", memgraph::storage::PropertyValue(13)}}, false); @@ -521,10 +551,10 @@ TYPED_TEST(DumpTest, EdgeWithProperties) { } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -539,22 +569,24 @@ TYPED_TEST(DumpTest, EdgeWithProperties) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, IndicesKeys) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {"Label1", "Label 2"}, {{"p", memgraph::storage::PropertyValue(1)}}, false); ASSERT_FALSE(dba->Commit().HasError()); } ASSERT_FALSE( - this->context.db->CreateIndex(this->context.db->NameToLabel("Label1"), this->context.db->NameToProperty("prop")) + this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("Label1"), this->db->storage()->NameToProperty("prop")) + .HasError()); + ASSERT_FALSE( + this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("Label 2"), this->db->storage()->NameToProperty("prop `")) .HasError()); - ASSERT_FALSE(this->context.db - ->CreateIndex(this->context.db->NameToLabel("Label 2"), this->context.db->NameToProperty("prop `")) - .HasError()); { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -567,21 +599,21 @@ TYPED_TEST(DumpTest, IndicesKeys) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, ExistenceConstraints) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {"L`abel 1"}, {{"prop", memgraph::storage::PropertyValue(1)}}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("L`abel 1"), - this->context.db->NameToProperty("prop"), {}); + auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("L`abel 1"), + this->db->storage()->NameToProperty("prop"), {}); ASSERT_FALSE(res.HasError()); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -593,7 +625,7 @@ TYPED_TEST(DumpTest, ExistenceConstraints) { TYPED_TEST(DumpTest, UniqueConstraints) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {"Label"}, {{"prop", memgraph::storage::PropertyValue(1)}, {"prop2", memgraph::storage::PropertyValue(2)}}, false); @@ -603,18 +635,18 @@ TYPED_TEST(DumpTest, UniqueConstraints) { ASSERT_FALSE(dba->Commit().HasError()); } { - auto res = this->context.db->CreateUniqueConstraint( - this->context.db->NameToLabel("Label"), - {this->context.db->NameToProperty("prop"), this->context.db->NameToProperty("prop2")}, {}); + auto res = this->db->storage()->CreateUniqueConstraint( + this->db->storage()->NameToLabel("Label"), + {this->db->storage()->NameToProperty("prop"), this->db->storage()->NameToProperty("prop2")}, {}); ASSERT_TRUE(res.HasValue()); ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS); } { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -633,7 +665,7 @@ TYPED_TEST(DumpTest, UniqueConstraints) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); std::map prop1 = { {"nested1", memgraph::storage::PropertyValue(1337)}, {"nested2", memgraph::storage::PropertyValue(3.14)}}; @@ -644,15 +676,30 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) { ASSERT_FALSE(dba->Commit().HasError()); } - auto data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump"; - memgraph::query::InterpreterContext interpreter_context(std::make_unique(), - memgraph::query::InterpreterConfig{}, data_directory); + memgraph::storage::Config config{}; + config.durability.storage_directory = this->data_directory / "s1"; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig("query-dump-s1").disk; + config.force_on_disk = true; + } + + memgraph::utils::Gatekeeper db_gk(config); + auto db_acc_opt = db_gk.access(); + ASSERT_TRUE(db_acc_opt) << "Failed to access db"; + auto &db_acc = *db_acc_opt; + ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) + << "Wrong storage mode!"; + + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr); { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -661,7 +708,7 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) { for (const auto &item : results) { ASSERT_EQ(item.size(), 1); ASSERT_TRUE(item[0].IsString()); - Execute(&interpreter_context, item[0].ValueString()); + Execute(&interpreter_context, db_acc, item[0].ValueString()); } } } @@ -669,7 +716,7 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) { // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, CheckStateSimpleGraph) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto u = CreateVertex(dba.get(), {"Person"}, {{"name", memgraph::storage::PropertyValue("Ivan")}}); auto v = CreateVertex(dba.get(), {"Person"}, {{"name", memgraph::storage::PropertyValue("Josko")}}); auto w = CreateVertex( @@ -710,33 +757,48 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) { ASSERT_FALSE(dba->Commit().HasError()); } { - auto ret = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("Person"), - this->context.db->NameToProperty("name"), {}); + auto ret = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("Person"), + this->db->storage()->NameToProperty("name"), {}); ASSERT_FALSE(ret.HasError()); } { - auto ret = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("Person"), - {this->context.db->NameToProperty("name")}, {}); + auto ret = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("Person"), + {this->db->storage()->NameToProperty("name")}, {}); ASSERT_TRUE(ret.HasValue()); ASSERT_EQ(ret.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS); } - ASSERT_FALSE( - this->context.db->CreateIndex(this->context.db->NameToLabel("Person"), this->context.db->NameToProperty("id")) - .HasError()); - ASSERT_FALSE(this->context.db - ->CreateIndex(this->context.db->NameToLabel("Person"), - this->context.db->NameToProperty("unexisting_property")) + ASSERT_FALSE(this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("Person"), this->db->storage()->NameToProperty("id")) + .HasError()); + ASSERT_FALSE(this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("Person"), + this->db->storage()->NameToProperty("unexisting_property")) .HasError()); - const auto &db_initial_state = GetState(this->context.db.get()); - auto data_directory = std::filesystem::temp_directory_path() / "MG_tests_unit_query_dump"; - memgraph::query::InterpreterContext interpreter_context(std::make_unique(), - memgraph::query::InterpreterConfig{}, data_directory); + const auto &db_initial_state = GetState(this->db->storage()); + memgraph::storage::Config config{}; + config.durability.storage_directory = this->data_directory / "s2"; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig("query-dump-s2").disk; + config.force_on_disk = true; + } + + memgraph::utils::Gatekeeper db_gk(config); + auto db_acc_opt = db_gk.access(); + ASSERT_TRUE(db_acc_opt) << "Failed to access db"; + auto &db_acc = *db_acc_opt; + ASSERT_TRUE(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) + << "Wrong storage mode!"; + + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr); { - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); { - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::DumpDatabaseToCypherQueries(&dba, &query_stream); } @@ -749,23 +811,23 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) { ASSERT_EQ(item.size(), 1); ASSERT_TRUE(item[0].IsString()); spdlog::debug("Query: {}", item[0].ValueString()); - Execute(&interpreter_context, item[0].ValueString()); + Execute(&interpreter_context, db_acc, item[0].ValueString()); ++i; } } - ASSERT_EQ(GetState(this->context.db.get()), db_initial_state); + ASSERT_EQ(GetState(this->db->storage()), db_initial_state); } // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, ExecuteDumpDatabase) { { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {}, {}, false); ASSERT_FALSE(dba->Commit().HasError()); } { - auto stream = Execute(&this->context, "DUMP DATABASE"); + auto stream = Execute(&this->context, this->db, "DUMP DATABASE"); const auto &header = stream.GetHeader(); const auto &results = stream.GetResults(); ASSERT_EQ(header.size(), 1U); @@ -784,11 +846,11 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabase) { class StatefulInterpreter { public: - explicit StatefulInterpreter(memgraph::query::InterpreterContext *context) - : context_(context), interpreter_(context_) {} + explicit StatefulInterpreter(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db) + : context_(context), interpreter_(context_, db) {} auto Execute(const std::string &query) { - ResultStreamFaker stream(context_->db.get()); + ResultStreamFaker stream(interpreter_.db_acc_->get()->storage()); auto [header, _1, qid, _2] = interpreter_.Prepare(query, {}, nullptr); stream.Header(header); @@ -799,18 +861,13 @@ class StatefulInterpreter { } private: - static const std::filesystem::path data_directory_; - memgraph::query::InterpreterContext *context_; memgraph::query::Interpreter interpreter_; }; -const std::filesystem::path StatefulInterpreter::data_directory_{std::filesystem::temp_directory_path() / - "MG_tests_unit_query_dump_stateful"}; - // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(DumpTest, ExecuteDumpDatabaseInMulticommandTransaction) { - StatefulInterpreter interpreter(&this->context); + StatefulInterpreter interpreter(&this->context, this->db); // Begin the transaction before the vertex is created. interpreter.Execute("BEGIN"); @@ -827,7 +884,7 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabaseInMulticommandTransaction) { // Create the vertex. { - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); CreateVertex(dba.get(), {}, {}, false); ASSERT_FALSE(dba->Commit().HasError()); } @@ -875,39 +932,41 @@ TYPED_TEST(DumpTest, MultiplePartialPulls) { { // Create indices ASSERT_FALSE( - this->context.db->CreateIndex(this->context.db->NameToLabel("PERSON"), this->context.db->NameToProperty("name")) + this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("PERSON"), this->db->storage()->NameToProperty("name")) + .HasError()); + ASSERT_FALSE( + this->db->storage() + ->CreateIndex(this->db->storage()->NameToLabel("PERSON"), this->db->storage()->NameToProperty("surname")) .HasError()); - ASSERT_FALSE(this->context.db - ->CreateIndex(this->context.db->NameToLabel("PERSON"), this->context.db->NameToProperty("surname")) - .HasError()); // Create existence constraints { - auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("PERSON"), - this->context.db->NameToProperty("name"), {}); + auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("PERSON"), + this->db->storage()->NameToProperty("name"), {}); ASSERT_FALSE(res.HasError()); } { - auto res = this->context.db->CreateExistenceConstraint(this->context.db->NameToLabel("PERSON"), - this->context.db->NameToProperty("surname"), {}); + auto res = this->db->storage()->CreateExistenceConstraint(this->db->storage()->NameToLabel("PERSON"), + this->db->storage()->NameToProperty("surname"), {}); ASSERT_FALSE(res.HasError()); } // Create unique constraints { - auto res = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("PERSON"), - {this->context.db->NameToProperty("name")}, {}); + auto res = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("PERSON"), + {this->db->storage()->NameToProperty("name")}, {}); ASSERT_TRUE(res.HasValue()); ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS); } { - auto res = this->context.db->CreateUniqueConstraint(this->context.db->NameToLabel("PERSON"), - {this->context.db->NameToProperty("surname")}, {}); + auto res = this->db->storage()->CreateUniqueConstraint(this->db->storage()->NameToLabel("PERSON"), + {this->db->storage()->NameToProperty("surname")}, {}); ASSERT_TRUE(res.HasValue()); ASSERT_EQ(res.GetValue(), memgraph::storage::UniqueConstraints::CreationStatus::SUCCESS); } - auto dba = this->context.db->Access(); + auto dba = this->db->storage()->Access(); auto p1 = CreateVertex(dba.get(), {"PERSON"}, {{"name", memgraph::storage::PropertyValue("Person1")}, {"surname", memgraph::storage::PropertyValue("Unique1")}}, @@ -935,9 +994,9 @@ TYPED_TEST(DumpTest, MultiplePartialPulls) { ASSERT_FALSE(dba->Commit().HasError()); } - ResultStreamFaker stream(this->context.db.get()); + ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); - auto acc = this->context.db->Access(); + auto acc = this->db->storage()->Access(); memgraph::query::DbAccessor dba(acc.get()); memgraph::query::PullPlanDump pullPlan{&dba}; diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index 26b49b7f4..a81b517b5 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -23,6 +23,8 @@ #include "communication/result_stream_faker.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" +#include "query/stream/streams.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/storage.hpp" @@ -32,24 +34,48 @@ template class QueryExecution : public testing::Test { protected: const std::string testSuite = "query_plan_edge_cases"; + std::optional db_acc_; std::optional interpreter_context_; std::optional interpreter_; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"}; + std::optional> db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; + void SetUp() { - interpreter_context_.emplace(std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), - memgraph::query::InterpreterConfig{}, data_directory); - interpreter_.emplace(&*interpreter_context_); + auto db_acc_opt = db_gk->access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + db_acc_ = std::move(db_acc); + + interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr); + interpreter_.emplace(&*interpreter_context_, *db_acc_); } void TearDown() { interpreter_ = std::nullopt; interpreter_context_ = std::nullopt; - + db_acc_.reset(); + db_gk.reset(); if (std::is_same::value) { disk_test_utils::RemoveRocksDbDirs(testSuite); } + std::filesystem::remove_all(data_directory); } /** @@ -58,7 +84,7 @@ class QueryExecution : public testing::Test { * Return the query results. */ auto Execute(const std::string &query) { - ResultStreamFaker stream(this->interpreter_context_->db.get()); + ResultStreamFaker stream(this->db_acc_->get()->storage()); auto [header, _1, qid, _2] = interpreter_->Prepare(query, {}, nullptr); stream.Header(header); diff --git a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp index 3c78b5590..a64c2b090 100644 --- a/tests/unit/query_streams.cpp +++ b/tests/unit/query_streams.cpp @@ -22,6 +22,7 @@ #include "kafka_mock.hpp" #include "query/config.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" #include "query/stream/streams.hpp" #include "storage/v2/disk/storage.hpp" #include "storage/v2/inmemory/storage.hpp" @@ -74,9 +75,32 @@ class StreamsTestFixture : public ::testing::Test { // Streams constructor. // InterpreterContext::auth_checker_ is used in the Streams object, but only in the message processing part. Because // these tests don't send any messages, the auth_checker_ pointer can be left as nullptr. - memgraph::query::InterpreterContext interpreter_context_{ - std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), - memgraph::query::InterpreterConfig{}, data_directory_}; + + memgraph::utils::Gatekeeper db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory_; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; + memgraph::dbms::DatabaseAccess db_{ + [&]() { + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + return db_acc; + }() // iile + }; + memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr}; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; std::optional proxyStreams_; @@ -84,11 +108,12 @@ class StreamsTestFixture : public ::testing::Test { if (std::is_same::value) { disk_test_utils::RemoveRocksDbDirs(testSuite); } + std::filesystem::remove_all(data_directory_); } void ResetStreamsObject() { proxyStreams_.emplace(); - proxyStreams_->streams_.emplace(&interpreter_context_, streams_data_directory_); + proxyStreams_->streams_.emplace(streams_data_directory_); } void CheckStreamStatus(const StreamCheckData &check_data) { @@ -151,8 +176,8 @@ TYPED_TEST_CASE(StreamsTestFixture, StorageTypes); TYPED_TEST(StreamsTestFixture, SimpleStreamManagement) { auto check_data = this->CreateDefaultStreamCheckData(); - this->proxyStreams_->streams_->template Create(check_data.name, check_data.info, - check_data.owner); + this->proxyStreams_->streams_->template Create( + check_data.name, check_data.info, check_data.owner, this->db_, &this->interpreter_context_); EXPECT_NO_FATAL_FAILURE(this->CheckStreamStatus(check_data)); EXPECT_NO_THROW(this->proxyStreams_->streams_->Start(check_data.name)); @@ -178,12 +203,12 @@ TYPED_TEST(StreamsTestFixture, SimpleStreamManagement) { TYPED_TEST(StreamsTestFixture, CreateAlreadyExisting) { auto stream_info = this->CreateDefaultStreamInfo(); auto stream_name = GetDefaultStreamName(); - this->proxyStreams_->streams_->template Create(stream_name, stream_info, - std::nullopt); + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); try { - this->proxyStreams_->streams_->template Create(stream_name, stream_info, - std::nullopt); + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); FAIL() << "Creating already existing stream should throw\n"; } catch (memgraph::query::stream::StreamsException &exception) { EXPECT_EQ(exception.what(), fmt::format("Stream already exists with name '{}'", stream_name)); @@ -194,8 +219,8 @@ TYPED_TEST(StreamsTestFixture, DropNotExistingStream) { const auto stream_info = this->CreateDefaultStreamInfo(); const auto stream_name = GetDefaultStreamName(); const std::string not_existing_stream_name{"ThisDoesn'tExists"}; - this->proxyStreams_->streams_->template Create(stream_name, stream_info, - std::nullopt); + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); try { this->proxyStreams_->streams_->Drop(not_existing_stream_name); @@ -250,7 +275,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { // Reset the Streams object to trigger reloading this->ResetStreamsObject(); EXPECT_TRUE(this->proxyStreams_->streams_->GetStreamInfo().empty()); - this->proxyStreams_->streams_->RestoreStreams(); + this->proxyStreams_->streams_->RestoreStreams(this->db_, &this->interpreter_context_); EXPECT_EQ(stream_check_datas.size(), this->proxyStreams_->streams_->GetStreamInfo().size()); for (const auto &check_data : stream_check_datas) { ASSERT_NO_FATAL_FAILURE(this->CheckStreamStatus(check_data)); @@ -258,12 +283,12 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { } }; - this->proxyStreams_->streams_->RestoreStreams(); + this->proxyStreams_->streams_->RestoreStreams(this->db_, &this->interpreter_context_); EXPECT_TRUE(this->proxyStreams_->streams_->GetStreamInfo().empty()); for (auto &check_data : stream_check_datas) { this->proxyStreams_->streams_->template Create( - check_data.name, check_data.info, check_data.owner); + check_data.name, check_data.info, check_data.owner, this->db_, &this->interpreter_context_); } { SCOPED_TRACE("After streams are created"); @@ -299,13 +324,13 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { TYPED_TEST(StreamsTestFixture, CheckWithTimeout) { const auto stream_info = this->CreateDefaultStreamInfo(); const auto stream_name = GetDefaultStreamName(); - this->proxyStreams_->streams_->template Create(stream_name, stream_info, - std::nullopt); + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); std::chrono::milliseconds timeout{3000}; const auto start = std::chrono::steady_clock::now(); - EXPECT_THROW(this->proxyStreams_->streams_->Check(stream_name, timeout, std::nullopt), + EXPECT_THROW(this->proxyStreams_->streams_->Check(stream_name, this->db_, timeout, std::nullopt), memgraph::integrations::kafka::ConsumerCheckFailedException); const auto end = std::chrono::steady_clock::now(); @@ -325,7 +350,7 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidConfig) { EXPECT_TRUE(message.find(kConfigValue) != std::string::npos) << message; }; EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt), + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } @@ -341,6 +366,6 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) { EXPECT_TRUE(message.find(kCredentialValue) == std::string::npos) << message; }; EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt), + stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } diff --git a/tests/unit/storage_v2_storage_mode.cpp b/tests/unit/storage_v2_storage_mode.cpp index ad0def8ce..6656acf02 100644 --- a/tests/unit/storage_v2_storage_mode.cpp +++ b/tests/unit/storage_v2_storage_mode.cpp @@ -18,6 +18,7 @@ #include "interpreter_faker.hpp" #include "query/exceptions.hpp" +#include "query/interpreter_context.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/storage_mode.hpp" @@ -68,10 +69,26 @@ INSTANTIATE_TEST_CASE_P(ParameterizedStorageModeTests, StorageModeTest, ::testin class StorageModeMultiTxTest : public ::testing::Test { protected: - std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_storage_mode"}; - memgraph::query::InterpreterContext interpreter_context{ - std::make_unique(), {}, data_directory}; - InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context}; + std::filesystem::path data_directory = []() { + const auto tmp = std::filesystem::temp_directory_path() / "MG_tests_unit_storage_mode"; + std::filesystem::remove_all(tmp); + return tmp; + }(); // iile + + memgraph::utils::Gatekeeper db_gk{memgraph::storage::Config{ + .durability.storage_directory = data_directory, .disk.main_storage_directory = data_directory / "disk"}}; + + memgraph::dbms::DatabaseAccess db{ + [&]() { + auto db_acc_opt = db_gk.access(); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc, "Failed to access db"); + return db_acc; + }() // iile + }; + + memgraph::query::InterpreterContext interpreter_context{{}, nullptr}; + InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; }; TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) { @@ -87,11 +104,11 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) { while (!started) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); } - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); main_interpreter.Interpret("STORAGE MODE IN_MEMORY_ANALYTICAL"); // should change state - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); // finish thread running_thread.request_stop(); @@ -100,7 +117,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchInactiveTransaction) { TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) { // transactional state - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); main_interpreter.Interpret("BEGIN"); bool started = false; @@ -119,7 +136,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); } // should not change still - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); main_interpreter.Interpret("COMMIT"); @@ -127,7 +144,7 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); } // should change state - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); // finish thread running_thread.request_stop(); @@ -135,11 +152,11 @@ TEST_F(StorageModeMultiTxTest, ModeSwitchActiveTransaction) { } TEST_F(StorageModeMultiTxTest, ErrorChangeIsolationLevel) { - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL); main_interpreter.Interpret("STORAGE MODE IN_MEMORY_ANALYTICAL"); // should change state - ASSERT_EQ(interpreter_context.db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); + ASSERT_EQ(db->GetStorageMode(), memgraph::storage::StorageMode::IN_MEMORY_ANALYTICAL); ASSERT_THROW(running_interpreter.Interpret("SET GLOBAL TRANSACTION ISOLATION LEVEL READ COMMITTED;"), memgraph::query::IsolationLevelModificationInAnalyticsException); diff --git a/tests/unit/transaction_queue.cpp b/tests/unit/transaction_queue.cpp index d04006a32..45aad1588 100644 --- a/tests/unit/transaction_queue.cpp +++ b/tests/unit/transaction_queue.cpp @@ -19,6 +19,7 @@ #include "disk_test_utils.hpp" #include "interpreter_faker.hpp" +#include "query/interpreter_context.hpp" #include "storage/v2/inmemory/storage.hpp" /* @@ -30,11 +31,38 @@ class TransactionQueueSimpleTest : public ::testing::Test { protected: const std::string testSuite = "transactin_queue"; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_intr"}; - memgraph::query::InterpreterContext interpreter_context{ - std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), {}, data_directory}; - InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context}; + memgraph::utils::Gatekeeper db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; - void TearDown() override { disk_test_utils::RemoveRocksDbDirs(testSuite); } + memgraph::dbms::DatabaseAccess db{ + [&]() { + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + return db_acc; + }() // iile + }; + memgraph::query::InterpreterContext interpreter_context{{}, nullptr}; + InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; + + void TearDown() override { + disk_test_utils::RemoveRocksDbDirs(testSuite); + std::filesystem::remove_all(data_directory); + } }; using StorageTypes = ::testing::Types; diff --git a/tests/unit/transaction_queue_multiple.cpp b/tests/unit/transaction_queue_multiple.cpp index 7ce7a4ae4..f5f9941d1 100644 --- a/tests/unit/transaction_queue_multiple.cpp +++ b/tests/unit/transaction_queue_multiple.cpp @@ -22,6 +22,7 @@ #include "disk_test_utils.hpp" #include "interpreter_faker.hpp" #include "query/exceptions.hpp" +#include "query/interpreter_context.hpp" #include "storage/v2/config.hpp" #include "storage/v2/disk/storage.hpp" #include "storage/v2/inmemory/storage.hpp" @@ -38,14 +39,39 @@ class TransactionQueueMultipleTest : public ::testing::Test { const std::string testSuite = "transactin_queue_multiple"; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_multiple_intr"}; - memgraph::query::InterpreterContext interpreter_context{ - std::make_unique(disk_test_utils::GenerateOnDiskConfig(testSuite)), {}, data_directory}; - InterpreterFaker main_interpreter{&interpreter_context}; + memgraph::utils::Gatekeeper db_gk{ + [&]() { + memgraph::storage::Config config{}; + config.durability.storage_directory = data_directory; + config.disk.main_storage_directory = config.durability.storage_directory / "disk"; + if constexpr (std::is_same_v) { + config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; + config.force_on_disk = true; + } + return config; + }() // iile + }; + + memgraph::dbms::DatabaseAccess db{ + [&]() { + auto db_acc_opt = db_gk.access(); + MG_ASSERT(db_acc_opt, "Failed to access db"); + auto &db_acc = *db_acc_opt; + MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v + ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL + : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), + "Wrong storage mode!"); + return db_acc; + }() // iile + }; + + memgraph::query::InterpreterContext interpreter_context{{}, nullptr}; + InterpreterFaker main_interpreter{&interpreter_context, db}; std::vector running_interpreters; TransactionQueueMultipleTest() { for (int i = 0; i < NUM_INTERPRETERS; ++i) { - InterpreterFaker *faker = new InterpreterFaker(&interpreter_context); + InterpreterFaker *faker = new InterpreterFaker(&interpreter_context, db); running_interpreters.push_back(faker); } } @@ -55,6 +81,7 @@ class TransactionQueueMultipleTest : public ::testing::Test { delete running_interpreters[i]; } disk_test_utils::RemoveRocksDbDirs(testSuite); + std::filesystem::remove_all(data_directory); } }; diff --git a/tests/unit/utils_sync_ptr.cpp b/tests/unit/utils_sync_ptr.cpp deleted file mode 100644 index 62111c8a8..000000000 --- a/tests/unit/utils_sync_ptr.cpp +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#include - -#include -#include -#include - -#include -#include "utils/exceptions.hpp" - -using namespace std::chrono_literals; - -TEST(SyncPtr, Basic) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(alive); - ASSERT_TRUE(alive); - auto sp_copy1 = sp.get(); - auto sp_copy2 = sp.get(); - - sp_copy1.reset(); - ASSERT_TRUE(alive); - sp_copy2.reset(); - ASSERT_TRUE(alive); - - sp.DestroyAndSync(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, BasicWConfig) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - struct TestConf { - TestConf(int i) : conf_(i) {} - int conf_; - }; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(123, alive); - ASSERT_TRUE(alive); - ASSERT_EQ(sp.config().conf_, 123); - auto sp_copy1 = sp.get(); - auto sp_copy2 = sp.get(); - - sp_copy1.reset(); - ASSERT_TRUE(alive); - sp_copy2.reset(); - ASSERT_TRUE(alive); - - sp.DestroyAndSync(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, Sync) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(alive); - - { - using namespace std::chrono_literals; - sp.timeout(10000ms); // 10sec - - ASSERT_TRUE(alive); - auto sp_copy1 = sp.get(); - auto sp_copy2 = sp.get(); - - th = std::thread([&alive, p = sp.get()]() mutable { - // Wait for a second and then release the pointer - // SyncPtr will be destroyed in the mean time (and block) - std::this_thread::sleep_for(std::chrono::milliseconds(200)); - ASSERT_TRUE(alive); - p.reset(); - ASSERT_FALSE(alive); - }); - } - - ASSERT_TRUE(alive); - sp.DestroyAndSync(); - - th.join(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, SyncWConfig) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - struct TestConf { - TestConf(int i) : conf_(i) {} - int conf_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(456, alive); - - { - using namespace std::chrono_literals; - sp.timeout(10000ms); // 10sec - ASSERT_TRUE(alive); - ASSERT_EQ(sp.config().conf_, 456); - auto sp_copy1 = sp.get(); - auto sp_copy2 = sp.get(); - - th = std::thread([&alive, p = sp.get()]() mutable { - // Wait for a second and then release the pointer - // SyncPtr will be destroyed in the mean time (and block) - std::this_thread::sleep_for(std::chrono::milliseconds(200)); - ASSERT_TRUE(alive); - p.reset(); - ASSERT_FALSE(alive); - }); - } - - ASSERT_TRUE(alive); - sp.DestroyAndSync(); - - th.join(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, Timeout100ms) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(alive); - using namespace std::chrono_literals; - sp.timeout(100ms); - - ASSERT_TRUE(alive); - - auto p = sp.get(); - - ASSERT_TRUE(alive); - - auto start_100ms = std::chrono::system_clock::now(); - ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); - auto end_100ms = std::chrono::system_clock::now(); - auto delta_100ms = std::chrono::duration_cast(end_100ms - start_100ms).count(); - ASSERT_NEAR(delta_100ms, 100, 100); - - p.reset(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, Timeout567ms) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(alive); - using namespace std::chrono_literals; - sp.timeout(567ms); - - ASSERT_TRUE(alive); - - auto p = sp.get(); - - ASSERT_TRUE(alive); - - auto start = std::chrono::system_clock::now(); - ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); - auto end = std::chrono::system_clock::now(); - auto delta_ms = std::chrono::duration_cast(end - start).count(); - ASSERT_NEAR(delta_ms, 567, 100); - - p.reset(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, Timeout100msWConfig) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - struct TestConf { - TestConf(int i) : conf_(i) {} - int conf_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(0, alive); - using namespace std::chrono_literals; - sp.timeout(100ms); - - ASSERT_TRUE(alive); - - auto p = sp.get(); - - ASSERT_TRUE(alive); - - auto start_100ms = std::chrono::system_clock::now(); - ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); - auto end_100ms = std::chrono::system_clock::now(); - auto delta_100ms = std::chrono::duration_cast(end_100ms - start_100ms).count(); - ASSERT_NEAR(delta_100ms, 100, 100); - - p.reset(); - ASSERT_FALSE(alive); -} - -TEST(SyncPtr, Timeout567msWConfig) { - std::atomic_bool alive{false}; - struct Test { - Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } - ~Test() { alive_ = false; } - std::atomic_bool &alive_; - }; - - struct TestConf { - TestConf(int i) : conf_(i) {} - int conf_; - }; - - std::thread th; - - ASSERT_FALSE(alive); - - memgraph::utils::SyncPtr sp(2, alive); - using namespace std::chrono_literals; - sp.timeout(567ms); - - ASSERT_TRUE(alive); - - auto p = sp.get(); - - ASSERT_TRUE(alive); - - auto start = std::chrono::system_clock::now(); - ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); - auto end = std::chrono::system_clock::now(); - auto delta_ms = std::chrono::duration_cast(end - start).count(); - ASSERT_NEAR(delta_ms, 567, 100); - - p.reset(); - ASSERT_FALSE(alive); -}