From e8850549d23685e125dfa9d31e564ad0ed03e4c6 Mon Sep 17 00:00:00 2001 From: andrejtonev <29177572+andrejtonev@users.noreply.github.com> Date: Tue, 1 Aug 2023 18:49:11 +0200 Subject: [PATCH] Add multi-tenancy v1 (#952) * Decouple BoltSession and communication::bolt::Session * Add CREATE/USE/DROP DATABASE * Add SHOW DATABASES * Cover WebSocket session * Simple session safety implemented via RWLock * Storage symlinks for backward. compatibility * Extend the audit log with the DB info * Add auth part * Add tenant recovery --- src/audit/log.cpp | 10 +- src/audit/log.hpp | 5 +- src/auth/auth.cpp | 55 +- src/auth/auth.hpp | 45 +- src/auth/models.cpp | 137 +++- src/auth/models.hpp | 94 ++- src/communication/bolt/v1/session.hpp | 37 +- src/communication/bolt/v1/states/handlers.hpp | 14 +- src/communication/bolt/v1/states/init.hpp | 11 +- src/communication/http/listener.hpp | 17 +- src/communication/http/server.hpp | 11 +- src/communication/http/session.hpp | 8 +- src/communication/listener.hpp | 10 +- src/communication/server.hpp | 12 +- src/communication/session.hpp | 6 +- src/communication/v2/listener.hpp | 18 +- src/communication/v2/server.hpp | 18 +- src/communication/v2/session.hpp | 64 +- src/dbms/constants.hpp | 18 + src/dbms/global.hpp | 110 ++++ src/dbms/handler.hpp | 142 +++++ src/dbms/interp_handler.hpp | 106 +++ src/dbms/session_context.hpp | 61 ++ src/dbms/session_context_handler.hpp | 603 ++++++++++++++++++ src/glue/auth.cpp | 4 + src/glue/auth_checker.cpp | 13 +- src/glue/auth_checker.hpp | 6 +- src/glue/auth_handler.cpp | 89 +++ src/glue/auth_handler.hpp | 12 +- src/http_handlers/metrics.hpp | 10 +- src/kvstore/kvstore.cpp | 10 +- src/memgraph.cpp | 452 +++++++++---- src/query/auth_checker.hpp | 6 +- src/query/db_accessor.hpp | 2 + src/query/exceptions.hpp | 6 + src/query/frontend/ast/ast.cpp | 6 + src/query/frontend/ast/ast.hpp | 86 ++- src/query/frontend/ast/ast_visitor.hpp | 6 +- .../frontend/ast/cypher_main_visitor.cpp | 114 +++- .../frontend/ast/cypher_main_visitor.hpp | 45 ++ .../opencypher/grammar/MemgraphCypher.g4 | 34 + .../opencypher/grammar/MemgraphCypherLexer.g4 | 4 + .../frontend/semantic/required_privileges.cpp | 16 + .../frontend/stripped_lexer_constants.hpp | 2 + src/query/interpreter.cpp | 409 +++++++++++- src/query/interpreter.hpp | 34 +- src/query/stream/streams.cpp | 2 +- src/query/trigger.cpp | 2 +- src/storage/v2/config.hpp | 56 +- src/storage/v2/storage.cpp | 3 +- src/storage/v2/storage.hpp | 5 + src/utils/stat.hpp | 3 +- src/utils/sync_ptr.hpp | 189 ++++++ src/utils/typeinfo.hpp | 2 + tests/concurrent/network_server.cpp | 6 +- tests/concurrent/network_session_leak.cpp | 6 +- tests/e2e/analyze_graph/common.py | 7 +- tests/e2e/analyze_graph/optimize_indexes.py | 46 +- tests/e2e/configuration/default_config.py | 15 +- tests/e2e/configuration/storage_info.py | 3 +- tests/e2e/fine_grained_access/CMakeLists.txt | 1 + tests/e2e/fine_grained_access/common.py | 19 +- .../create_delete_filtering_tests.py | 368 ++++++++--- .../edge_type_filtering_tests.py | 75 ++- .../path_filtering_tests.py | 288 ++++++--- tests/e2e/fine_grained_access/show_db.py | 36 ++ tests/e2e/fine_grained_access/workloads.yaml | 63 +- .../graphql/graphql_library_config/crud.js | 9 +- .../e2e/isolation_levels/isolation_levels.cpp | 89 ++- tests/e2e/lba_procedures/common.py | 18 +- .../create_delete_query_modules.py | 179 +++++- .../lba_procedures/read_permission_queries.py | 28 +- .../e2e/lba_procedures/read_query_modules.py | 133 +++- tests/e2e/lba_procedures/show_privileges.py | 4 +- .../lba_procedures/update_query_modules.py | 106 ++- tests/e2e/lba_procedures/workloads.yaml | 10 + tests/e2e/magic_functions/function_example.py | 59 +- tests/e2e/memory/memory_control.cpp | 12 +- .../e2e/memory/memory_limit_global_alloc.cpp | 12 +- .../memory/memory_limit_global_alloc_proc.cpp | 13 +- tests/e2e/memory/workloads.yaml | 17 + .../module_file_manager.cpp | 12 +- tests/e2e/module_file_manager/workloads.yaml | 5 + .../python_query_modules_reloading/common.py | 13 + .../test_reload_query_module.py | 28 +- tests/e2e/transaction_queue/common.py | 1 - .../test_transaction_queue.py | 27 +- tests/e2e/triggers/common.py | 9 +- .../e2e/triggers/triggers_properties_false.py | 39 +- tests/e2e/write_procedures/read_subgraph.py | 98 ++- tests/e2e/write_procedures/simple_write.py | 50 +- tests/integration/audit/runner.py | 82 ++- tests/integration/audit/tester.cpp | 5 +- tests/integration/auth/runner.py | 376 +++++------ tests/integration/auth/tester.cpp | 30 +- .../fine_grained_access/filtering.cpp | 4 +- .../integration/fine_grained_access/runner.py | 75 ++- tests/integration/transactions/runner.sh | 16 +- tests/integration/transactions/tester.cpp | 18 +- tests/manual/single_query.cpp | 2 +- tests/setup.sh | 2 +- tests/unit/CMakeLists.txt | 20 + tests/unit/bolt_session.cpp | 37 +- tests/unit/dbms_interp.cpp | 429 +++++++++++++ tests/unit/dbms_sc_handler.cpp | 343 ++++++++++ tests/unit/interpreter_faker.hpp | 2 +- tests/unit/query_common.hpp | 6 +- tests/unit/query_dump.cpp | 4 +- tests/unit/query_plan_edge_cases.cpp | 2 +- tests/unit/query_procedure_mgp_type.cpp | 6 +- tests/unit/query_procedure_py_module.cpp | 6 +- tests/unit/query_required_privileges.cpp | 4 +- tests/unit/query_trigger.cpp | 33 +- tests/unit/utils_sync_ptr.cpp | 296 +++++++++ 114 files changed, 5927 insertions(+), 1015 deletions(-) create mode 100644 src/dbms/constants.hpp create mode 100644 src/dbms/global.hpp create mode 100644 src/dbms/handler.hpp create mode 100644 src/dbms/interp_handler.hpp create mode 100644 src/dbms/session_context.hpp create mode 100644 src/dbms/session_context_handler.hpp create mode 100644 src/utils/sync_ptr.hpp create mode 100644 tests/e2e/fine_grained_access/show_db.py create mode 100644 tests/unit/dbms_interp.cpp create mode 100644 tests/unit/dbms_sc_handler.cpp create mode 100644 tests/unit/utils_sync_ptr.cpp diff --git a/src/audit/log.cpp b/src/audit/log.cpp index c7cedce08..635898a16 100644 --- a/src/audit/log.cpp +++ b/src/audit/log.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -116,12 +116,12 @@ Log::~Log() { } void Log::Record(const std::string &address, const std::string &username, const std::string &query, - const storage::PropertyValue ¶ms) { + const storage::PropertyValue ¶ms, const std::string &db) { if (!started_.load(std::memory_order_relaxed)) return; auto timestamp = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()) .count(); - buffer_->emplace(Item{timestamp, address, username, query, params}); + buffer_->emplace(Item{timestamp, address, username, query, params, db}); } void Log::ReopenLog() { @@ -136,8 +136,8 @@ void Log::Flush() { for (uint64_t i = 0; i < buffer_size_; ++i) { auto item = buffer_->pop(); if (!item) break; - log_.Write(fmt::format("{}.{:06d},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000, - item->address, item->username, utils::Escape(item->query), + log_.Write(fmt::format("{}.{:06d},{},{},{},{},{}\n", item->timestamp / 1000000, item->timestamp % 1000000, + item->address, item->username, item->db, utils::Escape(item->query), utils::Escape(PropertyValueToJson(item->params).dump()))); } log_.Sync(); diff --git a/src/audit/log.hpp b/src/audit/log.hpp index a26636330..8def3ede5 100644 --- a/src/audit/log.hpp +++ b/src/audit/log.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -32,6 +32,7 @@ class Log { std::string username; std::string query; storage::PropertyValue params; + std::string db; }; public: @@ -51,7 +52,7 @@ class Log { /// Adds an entry to the audit log. Thread-safe. void Record(const std::string &address, const std::string &username, const std::string &query, - const storage::PropertyValue ¶ms); + const storage::PropertyValue ¶ms, const std::string &db); /// Reopens the log file. Used for log file rotation. Thread-safe. void ReopenLog(); diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 03ce329df..cfe9dbdbe 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -314,4 +314,57 @@ std::vector Auth::AllUsersForRole(const std::string &rolename_orig) return ret; } +#ifdef MG_ENTERPRISE +bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) { + auto user = GetUser(name); + if (user) { + if (db == kAllDatabases) { + user->db_access().GrantAll(); + } else { + user->db_access().Add(db); + } + SaveUser(*user); + return true; + } + return false; +} + +bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) { + auto user = GetUser(name); + if (user) { + if (db == kAllDatabases) { + user->db_access().DenyAll(); + } else { + user->db_access().Remove(db); + } + SaveUser(*user); + return true; + } + return false; +} + +void Auth::DeleteDatabase(const std::string &db) { + for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { + auto username = it->first.substr(kUserPrefix.size()); + auto user = GetUser(username); + if (user) { + user->db_access().Delete(db); + SaveUser(*user); + } + } +} + +bool Auth::SetMainDatabase(const std::string &db, const std::string &name) { + auto user = GetUser(name); + if (user) { + if (!user->db_access().SetDefault(db)) { + throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name); + } + SaveUser(*user); + return true; + } + return false; +} +#endif + } // namespace memgraph::auth diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index 11590e21a..8d2a9d91c 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -19,6 +19,9 @@ #include "utils/settings.hpp" namespace memgraph::auth { + +static const constexpr char *const kAllDatabases = "*"; + /** * This class serves as the main Authentication/Authorization storage. * It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions. @@ -155,6 +158,46 @@ class Auth final { */ std::vector AllUsersForRole(const std::string &rolename) const; +#ifdef MG_ENTERPRISE + /** + * @brief Revoke access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + bool RevokeDatabaseFromUser(const std::string &db, const std::string &name); + + /** + * @brief Grant access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + bool GrantDatabaseToUser(const std::string &db, const std::string &name); + + /** + * @brief Delete a database from all users. + * + * @param db name of the database to delete + * @throw AuthException if unable to read data + */ + void DeleteDatabase(const std::string &db); + + /** + * @brief Set main database for an individual user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + bool SetMainDatabase(const std::string &db, const std::string &name); +#endif + private: // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe, // Auth is not thread-safe because modifying users and roles might require diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 3bd35833e..990ae8f0f 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -15,8 +15,10 @@ #include "auth/crypto.hpp" #include "auth/exceptions.hpp" +#include "dbms/constants.hpp" #include "license/license.hpp" #include "query/constants.hpp" +#include "spdlog/spdlog.h" #include "utils/cast.hpp" #include "utils/logging.hpp" #include "utils/settings.hpp" @@ -35,18 +37,31 @@ namespace memgraph::auth { namespace { // Constant list of all available permissions. -const std::vector kPermissionsAll = {Permission::MATCH, Permission::CREATE, - Permission::MERGE, Permission::DELETE, - Permission::SET, Permission::REMOVE, - Permission::INDEX, Permission::STATS, - Permission::CONSTRAINT, Permission::DUMP, - Permission::AUTH, Permission::REPLICATION, - Permission::DURABILITY, Permission::READ_FILE, - Permission::FREE_MEMORY, Permission::TRIGGER, - Permission::CONFIG, Permission::STREAM, - Permission::MODULE_READ, Permission::MODULE_WRITE, - Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT, - Permission::STORAGE_MODE}; +const std::vector kPermissionsAll = {Permission::MATCH, + Permission::CREATE, + Permission::MERGE, + Permission::DELETE, + Permission::SET, + Permission::REMOVE, + Permission::INDEX, + Permission::STATS, + Permission::CONSTRAINT, + Permission::DUMP, + Permission::AUTH, + Permission::REPLICATION, + Permission::DURABILITY, + Permission::READ_FILE, + Permission::FREE_MEMORY, + Permission::TRIGGER, + Permission::CONFIG, + Permission::STREAM, + Permission::MODULE_READ, + Permission::MODULE_WRITE, + Permission::WEBSOCKET, + Permission::TRANSACTION_MANAGEMENT, + Permission::STORAGE_MODE, + Permission::MULTI_DATABASE_EDIT, + Permission::MULTI_DATABASE_USE}; } // namespace @@ -98,6 +113,10 @@ std::string PermissionToString(Permission permission) { return "TRANSACTION_MANAGEMENT"; case Permission::STORAGE_MODE: return "STORAGE_MODE"; + case Permission::MULTI_DATABASE_EDIT: + return "MULTI_DATABASE_EDIT"; + case Permission::MULTI_DATABASE_USE: + return "MULTI_DATABASE_USE"; } } @@ -464,6 +483,82 @@ bool operator==(const Role &first, const Role &second) { return first.rolename_ == second.rolename_ && first.permissions_ == second.permissions_; } +#ifdef MG_ENTERPRISE +void Databases::Add(const std::string &db) { + if (allow_all_) { + grants_dbs_.clear(); + allow_all_ = false; + } + grants_dbs_.emplace(db); + denies_dbs_.erase(db); +} + +void Databases::Remove(const std::string &db) { + denies_dbs_.emplace(db); + grants_dbs_.erase(db); +} + +void Databases::Delete(const std::string &db) { + denies_dbs_.erase(db); + if (!allow_all_) { + grants_dbs_.erase(db); + } + // Reset if default deleted + if (default_db_ == db) { + default_db_ = ""; + } +} + +void Databases::GrantAll() { + allow_all_ = true; + grants_dbs_.clear(); + denies_dbs_.clear(); +} + +void Databases::DenyAll() { + allow_all_ = false; + grants_dbs_.clear(); + denies_dbs_.clear(); +} + +bool Databases::SetDefault(const std::string &db) { + if (!Contains(db)) return false; + default_db_ = db; + return true; +} + +[[nodiscard]] bool Databases::Contains(const std::string &db) const { + return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db)); +} + +const std::string &Databases::GetDefault() const { + if (!Contains(default_db_)) { + throw AuthException("No access to the set default database \"{}\".", default_db_); + } + return default_db_; +} + +nlohmann::json Databases::Serialize() const { + nlohmann::json data = nlohmann::json::object(); + data["grants"] = grants_dbs_; + data["denies"] = denies_dbs_; + data["allow_all"] = allow_all_; + data["default"] = default_db_; + return data; +} + +Databases Databases::Deserialize(const nlohmann::json &data) { + if (!data.is_object()) { + throw AuthException("Couldn't load database data!"); + } + if (!data["grants"].is_structured() || !data["denies"].is_structured() || !data["allow_all"].is_boolean() || + !data["default"].is_string()) { + throw AuthException("Couldn't load database data!"); + } + return {data["allow_all"], data["grants"], data["denies"], data["default"]}; +} +#endif + User::User() {} User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {} @@ -472,11 +567,12 @@ User::User(const std::string &username, const std::string &password_hash, const #ifdef MG_ENTERPRISE User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler) + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access) : username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions), - fine_grained_access_handler_(std::move(fine_grained_access_handler)) {} + fine_grained_access_handler_(std::move(fine_grained_access_handler)), + database_access_(db_access) {} #endif bool User::CheckPassword(const std::string &password) { @@ -576,8 +672,10 @@ nlohmann::json User::Serialize() const { #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { data["fine_grained_access_handler"] = fine_grained_access_handler_.Serialize(); + data["databases"] = database_access_.Serialize(); } else { data["fine_grained_access_handler"] = {}; + data["databases"] = {}; } #endif // The role shouldn't be serialized here, it is stored as a foreign key. @@ -594,11 +692,20 @@ User User::Deserialize(const nlohmann::json &data) { auto permissions = Permissions::Deserialize(data["permissions"]); #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { + Databases db_access; + if (data["databases"].is_structured()) { + db_access = Databases::Deserialize(data["databases"]); + } else { + // Back-compatibility + spdlog::warn("User without specified database access. Given access to the default database."); + db_access.Add(dbms::kDefaultDB); + db_access.SetDefault(dbms::kDefaultDB); + } if (!data["fine_grained_access_handler"].is_object()) { throw AuthException("Couldn't load user data!"); } auto fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data["fine_grained_access_handler"]); - return {data["username"], data["password_hash"], permissions, fine_grained_access_handler}; + return {data["username"], data["password_hash"], permissions, fine_grained_access_handler, db_access}; } #endif return {data["username"], data["password_hash"], permissions}; diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 4501c18c4..33ba28f80 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -9,10 +9,13 @@ #pragma once #include +#include #include #include #include +#include "dbms/constants.hpp" +#include "utils/logging.hpp" namespace memgraph::auth { // These permissions must have values that are applicable for usage in a @@ -41,7 +44,9 @@ enum class Permission : uint64_t { MODULE_WRITE = 1U << 19U, WEBSOCKET = 1U << 20U, TRANSACTION_MANAGEMENT = 1U << 21U, - STORAGE_MODE = 1U << 22U + STORAGE_MODE = 1U << 22U, + MULTI_DATABASE_EDIT = 1U << 23U, + MULTI_DATABASE_USE = 1U << 24U, }; // clang-format on @@ -237,6 +242,85 @@ class Role final { bool operator==(const Role &first, const Role &second); +#ifdef MG_ENTERPRISE +class Databases final { + public: + Databases() : grants_dbs_({dbms::kDefaultDB}), allow_all_(false), default_db_(dbms::kDefaultDB) {} + + Databases(const Databases &) = default; + Databases &operator=(const Databases &) = default; + Databases(Databases &&) noexcept = default; + Databases &operator=(Databases &&) noexcept = default; + ~Databases() = default; + + /** + * @brief Add database to the list of granted access. @note allow_all_ will be false after execution + * + * @param db name of the database to grant access to + */ + void Add(const std::string &db); + + /** + * @brief Remove database to the list of granted access. + * @note if allow_all_ is set, the flag will remain set and the + * database will be added to the set of denied databases. + * + * @param db name of the database to grant access to + */ + void Remove(const std::string &db); + + /** + * @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set. + * @note allow_all_ is not changed + * + * @param db name of the database to grant access to + */ + void Delete(const std::string &db); + + /** + * @brief Set allow_all_ to true and clears grants and denied sets. + */ + void GrantAll(); + + /** + * @brief Set allow_all_ to false and clears grants and denied sets. + */ + void DenyAll(); + + /** + * @brief Set the default database. + */ + bool SetDefault(const std::string &db); + + /** + * @brief Checks if access is grated to the database. + * + * @param db name of the database + * @return true if allow_all and not denied or granted + */ + bool Contains(const std::string &db) const; + + bool GetAllowAll() const { return allow_all_; } + const std::set &GetGrants() const { return grants_dbs_; } + const std::set &GetDenies() const { return denies_dbs_; } + const std::string &GetDefault() const; + + nlohmann::json Serialize() const; + /// @throw AuthException if unable to deserialize. + static Databases Deserialize(const nlohmann::json &data); + + private: + Databases(bool allow_all, std::set grant, std::set deny, + const std::string &default_db = dbms::kDefaultDB) + : grants_dbs_(grant), denies_dbs_(deny), allow_all_(allow_all), default_db_(default_db) {} + + std::set grants_dbs_; //!< set of databases with granted access + std::set denies_dbs_; //!< set of databases with denied access + bool allow_all_; //!< flag to allow access to everything (denied overrides this) + std::string default_db_; //!< user's default database +}; +#endif + // TODO (mferencevic): Implement password expiry. class User final { public: @@ -246,7 +330,7 @@ class User final { User(const std::string &username, const std::string &password_hash, const Permissions &permissions); #ifdef MG_ENTERPRISE User(const std::string &username, const std::string &password_hash, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler); + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {}); #endif User(const User &) = default; User &operator=(const User &) = default; @@ -279,6 +363,11 @@ class User final { const Role *role() const; +#ifdef MG_ENTERPRISE + Databases &db_access() { return database_access_; } + const Databases &db_access() const { return database_access_; } +#endif + nlohmann::json Serialize() const; /// @throw AuthException if unable to deserialize. @@ -292,6 +381,7 @@ class User final { Permissions permissions_; #ifdef MG_ENTERPRISE FineGrainedAccessHandler fine_grained_access_handler_; + Databases database_access_; #endif std::optional role_; }; diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 5e6aa4e39..bc968dab8 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -11,6 +11,8 @@ #pragma once +#include +#include #include #include @@ -24,8 +26,12 @@ #include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" +#include "communication/bolt/v1/value.hpp" +#include "dbms/constants.hpp" +#include "dbms/global.hpp" #include "utils/exceptions.hpp" #include "utils/logging.hpp" +#include "utils/uuid.hpp" namespace memgraph::communication::bolt { @@ -48,14 +54,26 @@ class SessionException : public utils::BasicException { * @tparam TOutputStream type of output stream that will be used */ template -class Session { +class Session : public dbms::SessionInterface { public: using TEncoder = Encoder>; + /** + * @brief Construct a new Session object + * + * @param input_stream stream to read from + * @param output_stream stream to write to + * @param impl a default high-level implementation to use (has to be defined) + */ Session(TInputStream *input_stream, TOutputStream *output_stream) - : input_stream_(*input_stream), output_stream_(*output_stream) {} + : input_stream_(*input_stream), output_stream_(*output_stream), session_uuid_(utils::GenerateUUID()) {} - virtual ~Session() {} + virtual ~Session() = default; + + Session(const Session &) = delete; + Session &operator=(const Session &) = delete; + Session(Session &&) noexcept = delete; + Session &operator=(Session &&) noexcept = delete; /** * Process the given `query` with `params`. @@ -66,6 +84,8 @@ class Session { const std::string &query, const std::map ¶ms, const std::map &extra) = 0; + virtual void Configure(const std::map &run_time_info) = 0; + /** * Put results of the processed query in the `encoder`. * @@ -86,7 +106,7 @@ class Session { */ virtual std::map Discard(std::optional n, std::optional qid) = 0; - virtual void BeginTransaction(const std::map &) = 0; + virtual void BeginTransaction(const std::map ¶ms) = 0; virtual void CommitTransaction() = 0; virtual void RollbackTransaction() = 0; @@ -99,7 +119,6 @@ class Session { /** Return the name of the server that should be used for the Bolt INIT * message. */ virtual std::optional GetServerNameForInit() = 0; - /** * Executes the session after data has been read into the buffer. * Goes through the bolt states in order to execute commands from the client. @@ -161,8 +180,7 @@ class Session { } } - // TODO: Rethink if there is a way to hide some members. At the momement all - // of them are public. + // TODO: Rethink if there is a way to hide some members. At the momement all of them are public. TInputStream &input_stream_; TOutputStream &output_stream_; @@ -182,6 +200,9 @@ class Session { Version version_; + std::string GetDatabaseName() const override = 0; + std::string UUID() const final { return session_uuid_; } + private: void ClientFailureInvalidData() { // Set the state to Close. @@ -197,6 +218,8 @@ class Session { // of the session to trigger session cleanup and socket close. throw SessionException("Something went wrong during session execution!"); } + + const std::string session_uuid_; //!< unique identifier of the session (auto generated) }; } // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/v1/states/handlers.hpp b/src/communication/bolt/v1/states/handlers.hpp index b23f008ad..4243a7fb9 100644 --- a/src/communication/bolt/v1/states/handlers.hpp +++ b/src/communication/bolt/v1/states/handlers.hpp @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include @@ -207,7 +208,7 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) { DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); - spdlog::debug("[Run] '{}'", query.ValueString()); + spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString()); try { // Interpret can throw. @@ -265,7 +266,13 @@ State HandleRunV4(TSession &session, const State state, const Marker marker) { DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); - spdlog::debug("[Run] '{}'", query.ValueString()); + try { + session.Configure(extra.ValueMap()); + } catch (const std::exception &e) { + return HandleFailure(session, e); + } + + spdlog::debug("[Run - {}] '{}'", session.GetDatabaseName(), query.ValueString()); try { // Interpret can throw. @@ -381,6 +388,7 @@ State HandleBegin(TSession &session, const State state, const Marker marker) { } try { + session.Configure(extra.ValueMap()); session.BeginTransaction(extra.ValueMap()); } catch (const std::exception &e) { return HandleFailure(session, e); @@ -489,7 +497,7 @@ State HandleRoute(TSession &session, const Marker marker) { template State HandleLogOff() { - // Not arguments sent, the user just needs to reauthenticate + // No arguments sent, the user just needs to reauthenticate return State::Init; } } // namespace memgraph::communication::bolt diff --git a/src/communication/bolt/v1/states/init.hpp b/src/communication/bolt/v1/states/init.hpp index 3e77b632c..955223467 100644 --- a/src/communication/bolt/v1/states/init.hpp +++ b/src/communication/bolt/v1/states/init.hpp @@ -18,6 +18,7 @@ #include "communication/bolt/v1/state.hpp" #include "communication/bolt/v1/value.hpp" #include "communication/exceptions.hpp" +#include "spdlog/spdlog.h" #include "utils/likely.hpp" #include "utils/logging.hpp" @@ -248,8 +249,9 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) { } // Stay in Init return State::Init; + } - } else if (signature == Signature::LogOn) { + if (signature == Signature::LogOn) { if (marker != Marker::TinyStruct1) [[unlikely]] { spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker)); spdlog::trace( @@ -273,11 +275,10 @@ State StateInitRunV5(TSession &session, Marker marker, Signature signature) { return State::Close; } return State::Idle; - - } else [[unlikely]] { - spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature)); - return State::Close; } + + spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature)); + return State::Close; } } // namespace details diff --git a/src/communication/http/listener.hpp b/src/communication/http/listener.hpp index 54214e7cd..029bf5ca1 100644 --- a/src/communication/http/listener.hpp +++ b/src/communication/http/listener.hpp @@ -27,11 +27,11 @@ namespace memgraph::communication::http { -template -class Listener final : public std::enable_shared_from_this> { +template +class Listener final : public std::enable_shared_from_this> { using tcp = boost::asio::ip::tcp; - using SessionHandler = Session; - using std::enable_shared_from_this>::shared_from_this; + using SessionHandler = Session; + using std::enable_shared_from_this>::shared_from_this; public: Listener(const Listener &) = delete; @@ -50,8 +50,9 @@ class Listener final : public std::enable_shared_from_thisRun(); + SessionHandler::Create(std::move(socket), session_context_, *context_)->Run(); DoAccept(); } boost::asio::io_context &ioc_; - TSessionData *data_; + TSessionContext *session_context_; ServerContext *context_; tcp::acceptor acceptor_; }; diff --git a/src/communication/http/server.hpp b/src/communication/http/server.hpp index 76b5ffcc4..7b23e5788 100644 --- a/src/communication/http/server.hpp +++ b/src/communication/http/server.hpp @@ -21,14 +21,15 @@ namespace memgraph::communication::http { -template +template class Server final { using tcp = boost::asio::ip::tcp; public: - explicit Server(io::network::Endpoint endpoint, TSessionData *data, ServerContext *context) - : listener_{Listener::Create( - ioc_, data, context, tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {} + explicit Server(io::network::Endpoint endpoint, TSessionContext *session_context, ServerContext *context) + : listener_{Listener::Create( + ioc_, session_context, context, + tcp::endpoint{boost::asio::ip::make_address(endpoint.address), endpoint.port})} {} Server(const Server &) = delete; Server(Server &&) = delete; @@ -59,7 +60,7 @@ class Server final { private: boost::asio::io_context ioc_; - std::shared_ptr> listener_; + std::shared_ptr> listener_; std::optional background_thread_; }; } // namespace memgraph::communication::http diff --git a/src/communication/http/session.hpp b/src/communication/http/session.hpp index 90bcf9964..b08ce8f30 100644 --- a/src/communication/http/session.hpp +++ b/src/communication/http/session.hpp @@ -42,10 +42,10 @@ inline void LogError(boost::beast::error_code ec, const std::string_view what) { spdlog::warn("HTTP session failed on {}: {}", what, ec.message()); } -template -class Session : public std::enable_shared_from_this> { +template +class Session : public std::enable_shared_from_this> { using tcp = boost::asio::ip::tcp; - using std::enable_shared_from_this>::shared_from_this; + using std::enable_shared_from_this>::shared_from_this; public: template @@ -72,7 +72,7 @@ class Session : public std::enable_shared_from_this; - explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &context) + explicit Session(tcp::socket &&socket, TSessionContext *data, ServerContext &context) : stream_(CreateSocket(std::move(socket), context)), handler_(data), strand_{boost::asio::make_strand(GetExecutor())} {} diff --git a/src/communication/listener.hpp b/src/communication/listener.hpp index f3a53cb15..cbb6c0b2f 100644 --- a/src/communication/listener.hpp +++ b/src/communication/listener.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -39,7 +39,7 @@ namespace memgraph::communication { * second, checks all sessions for expiration and shuts them down if they have * expired. */ -template +template class Listener final { private: // The maximum number of events handled per execution thread is 1. This is @@ -48,10 +48,10 @@ class Listener final { // can take a long time. static const int kMaxEvents = 1; - using SessionHandler = Session; + using SessionHandler = Session; public: - Listener(TSessionData *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name, + Listener(TSessionContext *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name, size_t workers_count) : data_(data), alive_(false), @@ -259,7 +259,7 @@ class Listener final { io::network::Epoll epoll_; - TSessionData *data_; + TSessionContext *data_; utils::SpinLock lock_; std::vector> sessions_; diff --git a/src/communication/server.hpp b/src/communication/server.hpp index 0e4f5bae2..d958f265d 100644 --- a/src/communication/server.hpp +++ b/src/communication/server.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -46,10 +46,10 @@ namespace memgraph::communication { * @tparam TSession the server can handle different Sessions, each session * represents a different protocol so the same network infrastructure * can be used for handling different protocols - * @tparam TSessionData the class with objects that will be forwarded to the + * @tparam TSessionContext the class with objects that will be forwarded to the * session */ -template +template class Server final { public: using Socket = io::network::Socket; @@ -58,12 +58,12 @@ class Server final { * Constructs and binds server to endpoint, operates on session data and * invokes workers_count workers */ - Server(const io::network::Endpoint &endpoint, TSessionData *session_data, ServerContext *context, + Server(const io::network::Endpoint &endpoint, TSessionContext *session_context, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name, size_t workers_count = std::thread::hardware_concurrency()) : alive_(false), endpoint_(endpoint), - listener_(session_data, context, inactivity_timeout_sec, service_name, workers_count), + listener_(session_context, context, inactivity_timeout_sec, service_name, workers_count), service_name_(service_name) {} ~Server() { @@ -156,7 +156,7 @@ class Server final { Socket socket_; io::network::Endpoint endpoint_; - Listener listener_; + Listener listener_; const std::string service_name_; }; diff --git a/src/communication/session.hpp b/src/communication/session.hpp index 98ab417ef..d61929b51 100644 --- a/src/communication/session.hpp +++ b/src/communication/session.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -69,10 +69,10 @@ class OutputStream final { * sessions. It handles socket ownership, inactivity timeout and protocol * wrapping. */ -template +template class Session final { public: - Session(io::network::Socket &&socket, TSessionData *data, ServerContext *context, int inactivity_timeout_sec) + Session(io::network::Socket &&socket, TSessionContext *data, ServerContext *context, int inactivity_timeout_sec) : socket_(std::move(socket)), output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }), session_(data, socket_.endpoint(), input_buffer_.read_end(), &output_stream_), diff --git a/src/communication/v2/listener.hpp b/src/communication/v2/listener.hpp index 0116c0790..82d6fc2cb 100644 --- a/src/communication/v2/listener.hpp +++ b/src/communication/v2/listener.hpp @@ -36,11 +36,11 @@ namespace memgraph::communication::v2 { -template -class Listener final : public std::enable_shared_from_this> { +template +class Listener final : public std::enable_shared_from_this> { using tcp = boost::asio::ip::tcp; - using SessionHandler = Session; - using std::enable_shared_from_this>::shared_from_this; + using SessionHandler = Session; + using std::enable_shared_from_this>::shared_from_this; public: Listener(const Listener &) = delete; @@ -59,10 +59,10 @@ class Listener final : public std::enable_shared_from_thisStart(); DoAccept(); } @@ -123,7 +123,7 @@ class Listener final : public std::enable_shared_from_this +template class Server final { - using ServerHandler = Server; + using ServerHandler = Server; public: /** * Constructs and binds server to endpoint, operates on session data and * invokes workers_count workers */ - Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context, + Server(ServerEndpoint &endpoint, TSessionContext *session_context, ServerContext *server_context, const int inactivity_timeout_sec, const std::string_view service_name, size_t workers_count = std::thread::hardware_concurrency()) : endpoint_{endpoint}, service_name_{service_name}, context_thread_pool_{workers_count}, - listener_{Listener::Create(context_thread_pool_.GetIOContext(), session_data, - server_context, endpoint_, service_name_, - inactivity_timeout_sec)} {} + listener_{Listener::Create(context_thread_pool_.GetIOContext(), session_context, + server_context, endpoint_, service_name_, + inactivity_timeout_sec)} {} ~Server() { MG_ASSERT(!IsRunning(), "Server wasn't shutdown properly"); } @@ -122,7 +122,7 @@ class Server final { std::string service_name_; IOContextThreadPool context_thread_pool_; - std::shared_ptr> listener_; + std::shared_ptr> listener_; }; } // namespace memgraph::communication::v2 diff --git a/src/communication/v2/session.hpp b/src/communication/v2/session.hpp index 076f3f3e9..f069546fd 100644 --- a/src/communication/v2/session.hpp +++ b/src/communication/v2/session.hpp @@ -16,10 +16,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -41,9 +43,11 @@ #include #include +#include "communication/bolt/v1/session.hpp" #include "communication/buffer.hpp" #include "communication/context.hpp" #include "communication/exceptions.hpp" +#include "dbms/global.hpp" #include "utils/event_counter.hpp" #include "utils/logging.hpp" #include "utils/on_scope_exit.hpp" @@ -95,10 +99,10 @@ class OutputStream final { * Websocket Sessions. It handles socket ownership, inactivity timeout and protocol * wrapping. */ -template -class WebsocketSession : public std::enable_shared_from_this> { +template +class WebsocketSession : public std::enable_shared_from_this> { using WebSocket = boost::beast::websocket::stream; - using std::enable_shared_from_this>::shared_from_this; + using std::enable_shared_from_this>::shared_from_this; public: template @@ -106,6 +110,17 @@ class WebsocketSession : public std::enable_shared_from_this(new WebsocketSession(std::forward(args)...)); } +#ifdef MG_ENTERPRISE + ~WebsocketSession() { session_context_->Delete(session_); } +#else + ~WebsocketSession() = default; +#endif + + WebsocketSession(const WebsocketSession &) = delete; + WebsocketSession &operator=(const WebsocketSession &) = delete; + WebsocketSession(WebsocketSession &&) noexcept = delete; + WebsocketSession &operator=(WebsocketSession &&) noexcept = delete; + // Start the asynchronous accept operation template void DoAccept(boost::beast::http::request> req) { @@ -151,15 +166,20 @@ class WebsocketSession : public std::enable_shared_from_thisRegister(session_); +#endif + } void OnAccept(boost::beast::error_code ec) { if (ec) { @@ -242,6 +262,7 @@ class WebsocketSession : public std::enable_shared_from_this -class Session final : public std::enable_shared_from_this> { +template +class Session final : public std::enable_shared_from_this> { using TCPSocket = tcp::socket; using SSLSocket = boost::asio::ssl::stream; - using std::enable_shared_from_this>::shared_from_this; + using std::enable_shared_from_this>::shared_from_this; public: template @@ -265,11 +286,16 @@ class Session final : public std::enable_shared_from_this(new Session(std::forward(args)...)); } +#ifdef MG_ENTERPRISE + ~Session() { session_context_->Delete(session_); } +#else + ~Session() = default; +#endif + Session(const Session &) = delete; Session(Session &&) = delete; Session &operator=(const Session &) = delete; Session &operator=(Session &&) = delete; - ~Session() = default; bool Start() { if (execution_active_) { @@ -334,18 +360,23 @@ class Session final : public std::enable_shared_from_thisRegister(session_); +#endif ExecuteForSocket([](auto &&socket) { socket.lowest_layer().set_option(tcp::no_delay(true)); // enable PSH socket.lowest_layer().set_option(boost::asio::socket_base::keep_alive(true)); // enable SO_KEEPALIVE @@ -396,7 +427,8 @@ class Session final : public std::enable_shared_from_this(socket_)) { auto sock = std::get(std::move(socket_)); - WebsocketSession::Create(std::move(sock), data_, endpoint_, service_name_) + WebsocketSession::Create(std::move(sock), session_context_, endpoint_, + service_name_) ->DoAccept(parser.release()); execution_active_ = false; return; @@ -535,7 +567,7 @@ class Session final : public std::enable_shared_from_this +#include +#include + +#include "utils/exceptions.hpp" + +namespace memgraph::dbms { + +enum class DeleteError : uint8_t { + DEFAULT_DB, + USING, + NON_EXISTENT, + FAIL, + DISK_FAIL, +}; + +enum class NewError : uint8_t { + NO_CONFIGS, + EXISTS, + DEFUNCT, + GENERIC, +}; + +enum class SetForResult : uint8_t { + SUCCESS, + ALREADY_SET, + FAIL, +}; + +/** + * UnknownSession Exception + * + * Used to indicate that an unknown session was used. + */ +class UnknownSessionException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + +/** + * UnknownDatabase Exception + * + * Used to indicate that an unknown database was used. + */ +class UnknownDatabaseException : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + +/** + * @brief Session interface used by the DBMS to handle the the active sessions. + * @todo Try to remove this dependency from SessionContextHandler. OnDelete could be removed, as it only does an assert. + * OnChange could be removed if SetFor returned the pointer and the called then handled the OnChange execution. + * However, the interface is very useful to decouple the interpreter's query execution and the sessions themselves. + */ +class SessionInterface { + public: + SessionInterface() = default; + virtual ~SessionInterface() = default; + + SessionInterface(const SessionInterface &) = default; + SessionInterface &operator=(const SessionInterface &) = default; + SessionInterface(SessionInterface &&) noexcept = default; + SessionInterface &operator=(SessionInterface &&) noexcept = default; + + /** + * @brief Return the unique string identifying the session. + * + * @return std::string + */ + virtual std::string UUID() const = 0; + + /** + * @brief Return the currently active database. + * + * @return std::string + */ + virtual std::string GetDatabaseName() const = 0; + +#ifdef MG_ENTERPRISE + /** + * @brief Gets called on database change. + * + * @return SetForResult enum (SUCCESS, ALREADY_SET or FAIL) + */ + virtual dbms::SetForResult OnChange(const std::string &) = 0; + + /** + * @brief Callback that gets called on database delete (drop). + * + * @return true on success + */ + virtual bool OnDelete(const std::string &) = 0; +#endif +}; + +} // namespace memgraph::dbms diff --git a/src/dbms/handler.hpp b/src/dbms/handler.hpp new file mode 100644 index 000000000..16db5558a --- /dev/null +++ b/src/dbms/handler.hpp @@ -0,0 +1,142 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include + +#include "global.hpp" +#include "utils/result.hpp" +#include "utils/sync_ptr.hpp" + +namespace memgraph::dbms { + +/** + * @brief Generic multi-database content handler. + * + * @tparam TContext + * @tparam TConfig + */ +template +class Handler { + public: + using NewResult = utils::BasicResult>; + + /** + * @brief Empty Handler constructor. + * + */ + Handler() {} + + /** + * @brief Generate a new context and corresponding configuration. + * + * @tparam T1 Variadic template of context constructor arguments + * @tparam T2 Variadic template of config constructor arguments + * @param name Name associated with the new context/config pair + * @param args1 Arguments passed (as a tuple) to the context constructor + * @param args2 Arguments passed (as a tuple) to the config constructor + * @return NewResult + */ + template + NewResult New(std::string name, std::tuple args1, std::tuple args2) { + return New_(name, args1, args2, std::make_index_sequence{}, + std::make_index_sequence{}); + } + + /** + * @brief Get pointer to context. + * + * @param name Name associated with the wanted context + * @return std::optional> + */ + std::optional> Get(const std::string &name) { + if (auto search = items_.find(name); search != items_.end()) { + return search->second.get(); + } + return {}; + } + + /** + * @brief Get the config. + * + * @param name Name associated with the wanted config + * @return std::optional + */ + std::optional GetConfig(const std::string &name) const { + if (auto search = items_.find(name); search != items_.end()) { + return search->second.config(); + } + return {}; + } + + /** + * @brief Delete the context/config pair associated with the name. + * + * @param name Name associated with the context/config pair to delete + * @return true on success + */ + bool Delete(const std::string &name) { + if (auto itr = items_.find(name); itr != items_.end()) { + itr->second.DestroyAndSync(); + items_.erase(itr); + return true; + } + return false; + } + + /** + * @brief Check if a name is already used. + * + * @param name Name to check + * @return true if a context/config pair is already associated with the name + */ + bool Has(const std::string &name) const { return items_.find(name) != items_.end(); } + + auto begin() { return items_.begin(); } + auto end() { return items_.end(); } + auto begin() const { return items_.begin(); } + auto end() const { return items_.end(); } + auto cbegin() const { return items_.cbegin(); } + auto cend() const { return items_.cend(); } + + private: + /** + * @brief Lower level handler that hides some ugly code. + * + * @tparam T1 Variadic template of context constructor arguments + * @tparam T2 Variadic template of config constructor arguments + * @tparam I1 List of indexes associated with the first tuple + * @tparam I2 List of indexes associated with the second tuple + */ + template + NewResult New_(std::string name, std::tuple &args1, std::tuple &args2, + std::integer_sequence /*not-used*/, + std::integer_sequence /*not-used*/) { + // Make sure the emplace will succeed, since we don't want to create temporary objects that could break something + if (!Has(name)) { + auto [itr, _] = items_.emplace(std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(TConfig{std::forward(std::get(args1))...}, + std::forward(std::get(args2))...)); + return itr->second.get(); + } + spdlog::info("Item with name \"{}\" already exists.", name); + return NewError::EXISTS; + } + + std::unordered_map> items_; //!< map to all active items +}; + +} // namespace memgraph::dbms diff --git a/src/dbms/interp_handler.hpp b/src/dbms/interp_handler.hpp new file mode 100644 index 000000000..af457b6b0 --- /dev/null +++ b/src/dbms/interp_handler.hpp @@ -0,0 +1,106 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#ifdef MG_ENTERPRISE + +#include "global.hpp" +#include "query/auth_checker.hpp" +#include "query/config.hpp" +#include "query/interpreter.hpp" +#include "storage/v2/storage.hpp" + +#include "handler.hpp" + +namespace memgraph::dbms { + +/** + * @brief Simple class that adds useful information to the query's InterpreterContext + * + * @tparam T Multi-database handler type + */ +template +class ExpandedInterpContext : public query::InterpreterContext { + public: + template + explicit ExpandedInterpContext(T &ref, TArgs &&...args) + : query::InterpreterContext(std::forward(args)...), sc_handler_(ref) {} + + T &sc_handler_; //!< Multi-database/SessionContext handler (used in some queries) +}; + +/** + * @brief Simple structure that expands on the query's InterpreterConfig + * + */ +struct ExpandedInterpConfig { + storage::Config storage_config; //!< Storage configuration + query::InterpreterConfig interp_config; //!< Interpreter configuration +}; + +/** + * @brief Multi-database interpreter context handler + * + * @tparam TSCHandler High-level multi-database/SessionContext handler type + */ +template +class InterpContextHandler : public Handler, ExpandedInterpConfig> { + public: + using InterpContextT = ExpandedInterpContext; + using HandlerT = Handler; + + /** + * @brief Generate a new interpreter context associated with the passed name. + * + * @param name Name associating the new interpreter context + * @param sc_handler Multi-database/SessionContext handler used (some queries might use it) + * @param db Storage associated with the interpreter context + * @param config Interpreter's configuration + * @param dir Directory used by the interpreter + * @param auth_handler AuthQueryHandler used + * @param auth_checker AuthChecker used + * @return HandlerT::NewResult + */ + typename HandlerT::NewResult New(const std::string &name, TSCHandler &sc_handler, storage::Config storage_config, + const query::InterpreterConfig &interpreter_config, + query::AuthQueryHandler &auth_handler, query::AuthChecker &auth_checker) { + // Check if compatible with the existing interpreters + if (std::any_of(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { + const auto &config = elem.second.config().storage_config; + return config.durability.storage_directory == storage_config.durability.storage_directory; + })) { + spdlog::info("Tried to generate a new context using claimed directory and/or storage."); + return NewError::EXISTS; + } + const auto dir = storage_config.durability.storage_directory; + storage_config.name = name; // Set storage id via config + return HandlerT::New( + name, std::forward_as_tuple(storage_config, interpreter_config), + std::forward_as_tuple(sc_handler, storage_config, interpreter_config, dir, &auth_handler, &auth_checker)); + } + + /** + * @brief All currently active storage. + * + * @return std::vector + */ + std::vector All() const { + std::vector res; + res.reserve(std::distance(HandlerT::cbegin(), HandlerT::cend())); + std::for_each(HandlerT::cbegin(), HandlerT::cend(), [&](const auto &elem) { res.push_back(elem.first); }); + return res; + } +}; + +} // namespace memgraph::dbms + +#endif diff --git a/src/dbms/session_context.hpp b/src/dbms/session_context.hpp new file mode 100644 index 000000000..691b9ee95 --- /dev/null +++ b/src/dbms/session_context.hpp @@ -0,0 +1,61 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/auth.hpp" +#include "query/interpreter.hpp" +#include "storage/v2/storage.hpp" +#include "utils/synchronized.hpp" + +#if MG_ENTERPRISE +#include "audit/log.hpp" +#endif +namespace memgraph::dbms { + +/** + * @brief Structure encapsulating storage and interpreter context. + * + * @note Each session contains a copy. + */ +struct SessionContext { + // Explicit constructor here to ensure that pointers to all objects are + // supplied. + + SessionContext(std::shared_ptr interpreter_context, std::string run, + memgraph::utils::Synchronized *auth +#ifdef MG_ENTERPRISE + , + memgraph::audit::Log *audit_log +#endif + ) + : interpreter_context(interpreter_context), + run_id(run), + auth(auth) +#ifdef MG_ENTERPRISE + , + audit_log(audit_log) +#endif + { + } + + std::shared_ptr interpreter_context; + std::string run_id; + + // std::shared_ptr auth_context; + memgraph::utils::Synchronized *auth; + +#ifdef MG_ENTERPRISE + memgraph::audit::Log *audit_log; +#endif +}; + +} // namespace memgraph::dbms diff --git a/src/dbms/session_context_handler.hpp b/src/dbms/session_context_handler.hpp new file mode 100644 index 000000000..815c1088f --- /dev/null +++ b/src/dbms/session_context_handler.hpp @@ -0,0 +1,603 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "constants.hpp" +#include "global.hpp" +#include "interp_handler.hpp" +#include "query/auth_checker.hpp" +#include "query/config.hpp" +#include "query/interpreter.hpp" +#include "session_context.hpp" +#include "spdlog/spdlog.h" +#include "storage/v2/durability/durability.hpp" +#include "storage/v2/durability/paths.hpp" +#include "utils/exceptions.hpp" +#include "utils/file.hpp" +#include "utils/logging.hpp" +#include "utils/result.hpp" +#include "utils/rw_lock.hpp" +#include "utils/synchronized.hpp" +#include "utils/uuid.hpp" + +#include "handler.hpp" + +namespace memgraph::dbms { + +#ifdef MG_ENTERPRISE + +using DeleteResult = utils::BasicResult; + +/** + * @brief Multi-database session contexts handler. + */ +class SessionContextHandler { + public: + using StorageT = storage::Storage; + using StorageConfigT = storage::Config; + using LockT = utils::RWLock; + using NewResultT = utils::BasicResult; + + struct Config { + StorageConfigT storage_config; //!< Storage configuration + query::InterpreterConfig interp_config; //!< Interpreter context configuration + std::function *, + std::unique_ptr &, std::unique_ptr &)> + glue_auth; + }; + + struct Statistics { + uint64_t num_vertex; //!< Sum of vertexes in every database + uint64_t num_edges; //!< Sum of edges in every database + uint64_t num_databases; //! number of isolated databases + }; + + /** + * @brief Initialize the handler. + * + * @param audit_log pointer to the audit logger (ENTERPRISE only) + * @param configs storage and interpreter configurations + * @param recovery_on_startup restore databases (and its content) and authentication data + */ + SessionContextHandler(memgraph::audit::Log &audit_log, Config configs, bool recovery_on_startup, bool delete_on_drop) + : lock_{utils::RWLock::Priority::READ}, + default_configs_(configs), + run_id_{utils::GenerateUUID()}, + audit_log_(&audit_log), + delete_on_drop_(delete_on_drop) { + const auto &root = configs.storage_config.durability.storage_directory; + utils::EnsureDirOrDie(root); + // Verify that the user that started the process is the same user that is + // the owner of the storage directory. + storage::durability::VerifyStorageDirectoryOwnerAndProcessUserOrDie(root); + + // Create the lock file and open a handle to it. This will crash the + // database if it can't open the file for writing or if any other process is + // holding the file opened. + lock_file_path_ = root / ".lock"; + lock_file_handle_.Open(lock_file_path_, utils::OutputFile::Mode::OVERWRITE_EXISTING); + MG_ASSERT(lock_file_handle_.AcquireLock(), + "Couldn't acquire lock on the storage directory {}" + "!\nAnother Memgraph process is currently running with the same " + "storage directory, please stop it first before starting this " + "process!", + root); + + // TODO: Figure out if this is needed/wanted + // Clear auth database since we are not recovering + // if (!recovery_on_startup) { + // const auto &auth_dir = root / "auth"; + // // Backup if auth present + // if (utils::DirExists(auth_dir)) { + // auto backup_dir = root / storage::durability::kBackupDirectory; + // std::error_code error_code; + // utils::EnsureDirOrDie(backup_dir); + // std::error_code ec; + // const auto now = std::chrono::system_clock::now(); + // std::ostringstream os; + // os << now.time_since_epoch().count(); + // std::filesystem::rename(auth_dir, backup_dir / ("auth-" + os.str()), ec); + // MG_ASSERT(!ec, "Couldn't backup auth directory because of: {}", ec.message()); + // spdlog::warn( + // "Since Memgraph was not supposed to recover on startup the authentication files will be " + // "overwritten. To prevent important data loss, Memgraph has stored those files into .backup directory " + // "inside the storage directory."); + // } + + // // Clear + // if (std::filesystem::exists(auth_dir)) { + // std::filesystem::remove_all(auth_dir); + // } + // } + + // Lazy initialization of auth_ + auth_ = std::make_unique>(root / "auth"); + configs.glue_auth(auth_.get(), auth_handler_, auth_checker_); + + // TODO: Decouple storage config from dbms config + // TODO: Save individual db configs inside the kvstore and restore from there + storage::UpdatePaths(default_configs_->storage_config, + default_configs_->storage_config.durability.storage_directory / "databases"); + const auto &db_dir = default_configs_->storage_config.durability.storage_directory; + const auto durability_dir = db_dir / ".durability"; + utils::EnsureDirOrDie(db_dir); + utils::EnsureDirOrDie(durability_dir); + durability_ = std::make_unique(durability_dir); + + // Generate the default database + MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB."); + + // Recover previous databases + if (recovery_on_startup) { + for (const auto &[name, _] : *durability_) { + if (name == kDefaultDB) continue; // Already set + spdlog::info("Restoring database {}.", name); + MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name); + spdlog::info("Database {} restored.", name); + } + } else { // Clear databases from the durability list and auth + auto locked_auth = auth_->Lock(); + for (const auto &[name, _] : *durability_) { + if (name == kDefaultDB) continue; + locked_auth->DeleteDatabase(name); + durability_->Delete(name); + } + } + } + + void Shutdown() { + for (auto &ic : interp_handler_) memgraph::query::Shutdown(ic.second.get().get()); + } + + /** + * @brief Create a new SessionContext associated with the "name" database + * + * @param name name of the database + * @return NewResultT context on success, error on failure + */ + NewResultT New(const std::string &name) { + std::lock_guard wr(lock_); + return New_(name, name); + } + + /** + * @brief Get the context associated with the "name" database + * + * @param name + * @return SessionContext + * @throw UnknownDatabaseException if getting unknown database + */ + SessionContext Get(const std::string &name) { + std::shared_lock rd(lock_); + return Get_(name); + } + + /** + * @brief Set the undelying database for a particular session. + * + * @param uuid unique session identifier + * @param db_name unique database name + * @return SetForResult enum + * @throws UnknownDatabaseException, UnknownSessionException or anything OnChange throws + */ + SetForResult SetFor(const std::string &uuid, const std::string &db_name) { + std::shared_lock rd(lock_); + (void)Get_( + db_name); // throws if db doesn't exist (TODO: Better to pass it via OnChange - but injecting dependency) + try { + auto &s = sessions_.at(uuid); + return s.OnChange(db_name); + } catch (std::out_of_range &) { + throw UnknownSessionException("Unknown session \"{}\"", uuid); + } + } + + /** + * @brief Set the undelying database from a session itself. SessionContext handler. + * + * @param db_name unique database name + * @param handler function that gets called in place with the appropriate SessionContext + * @return SetForResult enum + */ + template + requires std::invocable SetForResult SetInPlace(const std::string &db_name, + THandler handler) { + std::shared_lock rd(lock_); + return handler(Get_(db_name)); + } + + /** + * @brief Call void handler under a shared lock. + * + * @param handler function that gets called in place + */ + template + requires std::invocable + void CallInPlace(THandler handler) { + std::shared_lock rd(lock_); + handler(); + } + + /** + * @brief Register an active session (used to handle callbacks). + * + * @param session + * @return true on success + */ + bool Register(SessionInterface &session) { + std::lock_guard wr(lock_); + auto [_, success] = sessions_.emplace(session.UUID(), session); + return success; + } + + /** + * @brief Delete a session. + * + * @param session + */ + bool Delete(const SessionInterface &session) { + std::lock_guard wr(lock_); + return sessions_.erase(session.UUID()) > 0; + } + + /** + * @brief Delete database. + * + * @param db_name database name + * @return DeleteResult error on failure + */ + DeleteResult Delete(const std::string &db_name) { + std::lock_guard wr(lock_); + if (db_name == kDefaultDB) { + // MSG cannot delete the default db + return DeleteError::DEFAULT_DB; + } + // Check if db exists + try { + auto sc = Get_(db_name); + // Check if a session is using the db + if (!sc.interpreter_context->interpreters->empty()) { + return DeleteError::USING; + } + } catch (UnknownDatabaseException &) { + return DeleteError::NON_EXISTENT; + } + + // High level handlers + for (auto &[_, s] : sessions_) { + if (!s.OnDelete(db_name)) { + spdlog::error("Partial failure while deleting database \"{}\".", db_name); + defunct_dbs_.emplace(db_name); + return DeleteError::FAIL; + } + } + + // Low level handlers + const auto storage_path = StorageDir_(db_name); + MG_ASSERT(storage_path, "Missing storage for {}", db_name); + if (!interp_handler_.Delete(db_name)) { + spdlog::error("Partial failure while deleting database \"{}\".", db_name); + defunct_dbs_.emplace(db_name); + return DeleteError::FAIL; + } + + // Remove from auth + auth_->Lock()->DeleteDatabase(db_name); + // Remove from durability list + if (durability_) durability_->Delete(db_name); + + // Delete disk storage + if (delete_on_drop_) { + std::error_code ec; + (void)std::filesystem::remove_all(*storage_path, ec); + if (ec) { + spdlog::error("Failed to clean disk while deleting database \"{}\".", db_name); + defunct_dbs_.emplace(db_name); + return DeleteError::DISK_FAIL; + } + } + + // Delete from defunct_dbs_ (in case a second delete call was successful) + defunct_dbs_.erase(db_name); + + return {}; // Success + } + + /** + * @brief Set the default configurations. + * + * @param configs storage, interpreter and authorization configurations + */ + void SetDefaultConfigs(Config configs) { + std::lock_guard wr(lock_); + default_configs_ = configs; + } + + /** + * @brief Get the default configurations. + * + * @return std::optional + */ + std::optional GetDefaultConfigs() const { + std::shared_lock rd(lock_); + return default_configs_; + } + + /** + * @brief Return all active databases. + * + * @return std::vector + */ + std::vector All() const { + std::shared_lock rd(lock_); + return interp_handler_.All(); + } + + /** + * @brief Return the number of vertex across all databases. + * + * @return uint64_t + */ + Statistics Info() const { + // TODO: Handle overflow + uint64_t nv = 0; + uint64_t ne = 0; + std::shared_lock rd(lock_); + const uint64_t ndb = std::distance(interp_handler_.cbegin(), interp_handler_.cend()); + for (const auto &ic : interp_handler_) { + const auto &info = ic.second.get()->db->GetInfo(); + nv += info.vertex_count; + ne += info.edge_count; + } + return {nv, ne, ndb}; + } + + /** + * @brief Return the currently active database for a particular session. + * + * @param uuid session's unique identifier + * @return std::string name of the database + * @throw + */ + std::string Current(const std::string &uuid) const { + std::shared_lock rd(lock_); + return sessions_.at(uuid).GetDatabaseName(); + } + + /** + * @brief Restore triggers for all currently defined databases. + * @note: Triggers can execute query procedures, so we need to reload the modules first and then the triggers + */ + void RestoreTriggers() { + std::lock_guard wr(lock_); + for (auto &ic_itr : interp_handler_) { + auto ic = ic_itr.second.get(); + spdlog::debug("Restoring trigger for database \"{}\"", ic->db->id()); + auto storage_accessor = ic->db->Access(); + auto dba = memgraph::query::DbAccessor{storage_accessor.get()}; + ic->trigger_store.RestoreTriggers(&ic->ast_cache, &dba, ic->config.query, ic->auth_checker); + } + } + + /** + * @brief Restore streams of all currently defined databases. + * @note: Stream transformations are using modules, they have to be restored after the query modules are loaded. + */ + void RestoreStreams() { + std::lock_guard wr(lock_); + for (auto &ic_itr : interp_handler_) { + auto ic = ic_itr.second.get(); + spdlog::debug("Restoring streams for database \"{}\"", ic->db->id()); + ic->streams.RestoreStreams(); + } + } + + private: + std::optional StorageDir_(const std::string &name) const { + const auto conf = interp_handler_.GetConfig(name); + if (conf) { + return conf->storage_config.durability.storage_directory; + } + spdlog::debug("Failed to find storage dir for database \"{}\"", name); + return {}; + } + + /** + * @brief Create a new SessionContext associated with the "name" database + * + * @param name name of the database + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name) { return New_(name, name); } + + /** + * @brief Create a new SessionContext associated with the "name" database + * + * @param name name of the database + * @param storage_subdir undelying RocksDB directory + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name, std::filesystem::path storage_subdir) { + if (default_configs_) { + auto storage = default_configs_->storage_config; + storage::UpdatePaths(storage, storage.durability.storage_directory / storage_subdir); + return New_(name, storage, default_configs_->interp_config); + } + spdlog::info("Trying to generate session context without any configurations."); + return NewError::NO_CONFIGS; + } + + /** + * @brief Create a new SessionContext associated with the "name" database + * + * @param name name of the database + * @param storage_config storage configuration + * @param inter_config interpreter configuration + * @return NewResultT context on success, error on failure + */ + NewResultT New_(const std::string &name, StorageConfigT &storage_config, query::InterpreterConfig &inter_config/*, + const std::string &ah_flags*/) { + MG_ASSERT(auth_handler_, "No high level AuthQueryHandler has been supplied."); + MG_ASSERT(auth_checker_, "No high level AuthChecker has been supplied."); + + if (defunct_dbs_.contains(name)) { + spdlog::warn("Failed to generate database due to the unknown state of the previously defunct database \"{}\".", + name); + return NewError::DEFUNCT; + } + + auto new_interp = interp_handler_.New(name, *this, storage_config, inter_config, *auth_handler_, *auth_checker_); + + if (new_interp.HasValue()) { + // Success + if (durability_) durability_->Put(name, "ok"); + return SessionContext{new_interp.GetValue(), run_id_, auth_.get(), audit_log_}; + } + return new_interp.GetError(); + } + + /** + * @brief Create a new SessionContext associated with the default database + * + * @return NewResultT context on success, error on failure + */ + NewResultT NewDefault_() { + // Create the default DB in the root (this is how it was done pre multi-tenancy) + auto res = New_(kDefaultDB, ".."); + if (res.HasValue()) { + // For back-compatibility... + // Recreate the dbms layout for the default db and symlink to the root + const auto dir = StorageDir_(kDefaultDB); + MG_ASSERT(dir, "Failed to find storage path."); + const auto main_dir = *dir / "databases" / kDefaultDB; + + if (!std::filesystem::exists(main_dir)) { + std::filesystem::create_directory(main_dir); + } + + // Force link on-disk directories + const auto conf = interp_handler_.GetConfig(kDefaultDB); + MG_ASSERT(conf, "No configuration for the default database."); + const auto &tmp_conf = conf->storage_config.disk; + std::vector to_link{ + tmp_conf.main_storage_directory, tmp_conf.label_index_directory, + tmp_conf.label_property_index_directory, tmp_conf.unique_constraints_directory, + tmp_conf.name_id_mapper_directory, tmp_conf.id_name_mapper_directory, + tmp_conf.durability_directory, tmp_conf.wal_directory, + }; + + // Add in-memory paths + // Some directories are redundant (skip those) + const std::vector skip{".lock", "audit_log", "auth", "databases", "internal_modules", "settings"}; + for (auto const &item : std::filesystem::directory_iterator{*dir}) { + const auto dir_name = std::filesystem::relative(item.path(), item.path().parent_path()); + if (std::find(skip.begin(), skip.end(), dir_name) != skip.end()) continue; + to_link.push_back(item.path()); + } + + // Symlink to root dir + for (auto const &item : to_link) { + const auto dir_name = std::filesystem::relative(item, item.parent_path()); + const auto link = main_dir / dir_name; + const auto to = std::filesystem::relative(item, main_dir); + if (!std::filesystem::is_symlink(link) && !std::filesystem::exists(link)) { + std::filesystem::create_directory_symlink(to, link); + } else { // Check existing link + std::error_code ec; + const auto test_link = std::filesystem::read_symlink(link, ec); + if (ec || test_link != to) { + MG_ASSERT(false, + "Memgraph storage directory incompatible with new version.\n" + "Please use a clean directory or remove \"{}\" and try again.", + link.string()); + } + } + } + } + return res; + } + + /** + * @brief Get the context associated with the "name" database + * + * @param name + * @return SessionContext + * @throw UnknownDatabaseException if trying to get unknown database + */ + SessionContext Get_(const std::string &name) { + auto interp = interp_handler_.Get(name); + if (interp) { + return SessionContext{*interp, run_id_, auth_.get(), audit_log_}; + } + throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name); + } + + // Should storage objects ever be deleted? + mutable LockT lock_; //!< protective lock + std::filesystem::path lock_file_path_; //!< Lock file protecting the main storage + utils::OutputFile lock_file_handle_; //!< Handler the lock (crash if already open) + InterpContextHandler interp_handler_; //!< multi-tenancy interpreter handler + // AuthContextHandler auth_handler_; //!< multi-tenancy authorization handler (currently we use a single global + // auth) + std::unique_ptr> auth_; + std::unique_ptr auth_handler_; + std::unique_ptr auth_checker_; + std::optional default_configs_; //!< default storage and interpreter configurations + const std::string run_id_; //!< run's unique identifier (auto generated) + memgraph::audit::Log *audit_log_; //!< pointer to the audit logger + std::unordered_map sessions_; //!< map of active/registered sessions + std::unique_ptr durability_; //!< list of active dbs (pointer so we can postpone its creation) + + std::set defunct_dbs_; //!< Databases that are in an unknown state due to various failures + bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory + public: + static SessionContextHandler &ExtractSCH(query::InterpreterContext *interpreter_context) { + return static_cast(interpreter_context)->sc_handler_; + } +}; + +#else +/** + * @brief Initialize the handler. + * + * @param auth pointer to the authenticator + * @param configs storage and interpreter configurations + */ +static inline SessionContext Init(storage::Config &storage_config, query::InterpreterConfig &interp_config, + utils::Synchronized *auth, + query::AuthQueryHandler *auth_handler, query::AuthChecker *auth_checker) { + MG_ASSERT(auth, "Passed a nullptr auth"); + MG_ASSERT(auth_handler, "Passed a nullptr auth_handler"); + MG_ASSERT(auth_checker, "Passed a nullptr auth_checker"); + + storage_config.name = kDefaultDB; + auto interp_context = std::make_shared( + storage_config, interp_config, storage_config.durability.storage_directory, auth_handler, auth_checker); + MG_ASSERT(interp_context, "Failed to construct main interpret context."); + + return SessionContext{interp_context, utils::GenerateUUID(), auth}; +} +#endif + +} // namespace memgraph::dbms diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 0ac58e844..8344ad49d 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -62,6 +62,10 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::STORAGE_MODE; case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT: return auth::Permission::TRANSACTION_MANAGEMENT; + case query::AuthQuery::Privilege::MULTI_DATABASE_EDIT: + return auth::Permission::MULTI_DATABASE_EDIT; + case query::AuthQuery::Privilege::MULTI_DATABASE_USE: + return auth::Permission::MULTI_DATABASE_USE; } } diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 011a4bb3b..3ea3f998b 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -71,7 +71,8 @@ AuthChecker::AuthChecker( : auth_(auth) {} bool AuthChecker::IsUserAuthorized(const std::optional &username, - const std::vector &privileges) const { + const std::vector &privileges, + const std::string &db_name) const { std::optional maybe_user; { auto locked_auth = auth_->ReadLock(); @@ -83,7 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional &username, } } - return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges); + return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name); } #ifdef MG_ENTERPRISE @@ -108,7 +109,13 @@ std::unique_ptr AuthChecker::GetFineGra #endif bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, - const std::vector &privileges) { + const std::vector &privileges, + const std::string &db_name) { // NOLINT +#ifdef MG_ENTERPRISE + if (!db_name.empty() && !user.db_access().Contains(db_name)) { + return false; + } +#endif const auto user_permissions = user.GetPermissions(); return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) { return user_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) == diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index 22f6515c3..e0f917723 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -25,7 +25,8 @@ class AuthChecker : public query::AuthChecker { memgraph::utils::Synchronized *auth); bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges) const override; + const std::vector &privileges, + const std::string &db_name) const override; #ifdef MG_ENTERPRISE std::unique_ptr GetFineGrainedAuthChecker( @@ -33,7 +34,8 @@ class AuthChecker : public query::AuthChecker { #endif [[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, - const std::vector &privileges); + const std::vector &privileges, + const std::string &db_name = ""); private: memgraph::utils::Synchronized *auth_; diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index 12cc533b7..0abc8053d 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -16,6 +16,7 @@ #include #include "auth/models.hpp" +#include "dbms/constants.hpp" #include "glue/auth.hpp" #include "license/license.hpp" #include "query/constants.hpp" @@ -122,6 +123,29 @@ std::vector> ShowRolePrivileges( } #ifdef MG_ENTERPRISE +std::vector> ShowDatabasePrivileges( + const std::optional &user) { + if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) { + return {}; + } + + const auto &db = user->db_access(); + const auto &allows = db.GetAllowAll(); + const auto &grants = db.GetGrants(); + const auto &denies = db.GetDenies(); + + std::vector res; // First element is a list of granted databases, second of revoked ones + if (allows) { + res.emplace_back("*"); + } else { + std::vector grants_vec(grants.cbegin(), grants.cend()); + res.emplace_back(std::move(grants_vec)); + } + std::vector denies_vec(denies.cbegin(), denies.cend()); + res.emplace_back(std::move(denies_vec)); + return {res}; +} + std::vector GetFineGrainedPermissionForPrivilegeForUserOrRole( const memgraph::auth::FineGrainedAccessPermissions &permissions, const std::string &permission_type, const std::string &user_or_role) { @@ -268,6 +292,10 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option } #endif ); +#ifdef MG_ENTERPRISE + GrantDatabaseToUser(auth::kAllDatabases, username); + SetMainDatabase(username, dbms::kDefaultDB); +#endif } return user_added; @@ -319,6 +347,67 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename) { } } +#ifdef MG_ENTERPRISE +bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) return false; + return locked_auth->RevokeDatabaseFromUser(db, username); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} + +bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) return false; + return locked_auth->GrantDatabaseToUser(db, username); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} + +std::vector> AuthQueryHandler::GetDatabasePrivileges( + const std::string &username) { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user or role name."); + } + try { + auto locked_auth = auth_->ReadLock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); + } + return ShowDatabasePrivileges(user); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} + +bool AuthQueryHandler::SetMainDatabase(const std::string &db, const std::string &username) { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) return false; + return locked_auth->SetMainDatabase(db, username); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} +#endif + bool AuthQueryHandler::DropRole(const std::string &rolename) { if (!std::regex_match(rolename, name_regex_)) { throw memgraph::query::QueryRuntimeException("Invalid role name."); diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index 508d770ab..716095add 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -38,6 +38,16 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { void SetPassword(const std::string &username, const std::optional &password) override; +#ifdef MG_ENTERPRISE + bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override; + + bool GrantDatabaseToUser(const std::string &db, const std::string &username) override; + + std::vector> GetDatabasePrivileges(const std::string &username) override; + + bool SetMainDatabase(const std::string &db, const std::string &username) override; +#endif + bool CreateRole(const std::string &rolename) override; bool DropRole(const std::string &rolename) override; diff --git a/src/http_handlers/metrics.hpp b/src/http_handlers/metrics.hpp index a83f6ee06..43970e616 100644 --- a/src/http_handlers/metrics.hpp +++ b/src/http_handlers/metrics.hpp @@ -47,10 +47,10 @@ struct MetricsResponse { std::vector> event_histograms{}; }; -template +template class MetricsService { public: - explicit MetricsService(TSessionData *data) : db_(data->interpreter_context->db.get()) {} + explicit MetricsService(TSessionContext *session_context) : db_(session_context->interpreter_context->db.get()) {} nlohmann::json GetMetricsJSON() { auto response = GetMetrics(); @@ -141,10 +141,10 @@ class MetricsService { } }; -template +template class MetricsRequestHandler final { public: - explicit MetricsRequestHandler(TSessionData *data) : service_(data) { + explicit MetricsRequestHandler(TSessionContext *session_context) : service_(session_context) { spdlog::info("Basic request handler started!"); } @@ -206,6 +206,6 @@ class MetricsRequestHandler final { } private: - MetricsService service_; + MetricsService service_; }; } // namespace memgraph::http diff --git a/src/kvstore/kvstore.cpp b/src/kvstore/kvstore.cpp index d13358d52..37f9f503d 100644 --- a/src/kvstore/kvstore.cpp +++ b/src/kvstore/kvstore.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -36,7 +36,13 @@ KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique( pimpl_->db.reset(db); } -KVStore::~KVStore() {} +KVStore::~KVStore() { + spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string()); + const auto sync = pimpl_->db->SyncWAL(); + if (!sync.ok()) spdlog::error("KVStore sync failed!"); + const auto close = pimpl_->db->Close(); + if (!close.ok()) spdlog::error("KVStore close failed!"); +} KVStore::KVStore(KVStore &&other) { pimpl_ = std::move(other.pimpl_); } diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 1c8750553..ee6fac3be 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -40,6 +40,9 @@ #include "communication/http/server.hpp" #include "communication/websocket/auth.hpp" #include "communication/websocket/server.hpp" +#include "dbms/constants.hpp" +#include "dbms/global.hpp" +#include "dbms/session_context.hpp" #include "glue/auth_checker.hpp" #include "glue/auth_handler.hpp" #include "helpers.hpp" @@ -98,6 +101,7 @@ #include "communication/init.hpp" #include "communication/v2/server.hpp" #include "communication/v2/session.hpp" +#include "dbms/session_context_handler.hpp" #include "glue/communication.hpp" #include "auth/auth.hpp" @@ -157,6 +161,10 @@ DEFINE_string(init_data_file, "", "Path to cypherl file that is used for creatin // `mg_import_csv`. If you change it, make sure to change it there as well. // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_string(data_directory, "mg_data", "Path to directory in which to save all permanent data."); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_bool(data_recovery_on_startup, false, "Controls whether the database recovers persisted data on startup."); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_uint64(memory_warning_threshold, 1024, "Memory warning threshold, in MB. If Memgraph detects there is " @@ -174,8 +182,11 @@ DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector int // `mg_import_csv`. If you change it, make sure to change it there as well. // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_bool(storage_properties_on_edges, false, "Controls whether edges have properties."); + +// storage_recover_on_startup deprecated; use data_recovery_on_startup instead // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -DEFINE_bool(storage_recover_on_startup, false, "Controls whether the storage recovers persisted data on startup."); +DEFINE_HIDDEN_bool(storage_recover_on_startup, false, + "Controls whether the storage recovers persisted data on startup."); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_VALIDATED_uint64(storage_snapshot_interval_sec, 0, "Storage snapshot creation interval (in seconds). Set " @@ -215,6 +226,12 @@ DEFINE_uint64(storage_recovery_thread_count, memgraph::storage::Config::Durability().recovery_thread_count), "The number of threads used to recover persisted data from disk."); +#ifdef MG_ENTERPRISE +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_bool(storage_delete_on_drop, true, + "If set to true the query 'DROP DATABASE x' will delete the underlying storage as well."); +#endif + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_bool(telemetry_enabled, false, "Set to true to enable telemetry. We collect information about the " @@ -447,35 +464,6 @@ void AddLoggerSink(spdlog::sink_ptr new_sink) { DEFINE_HIDDEN_string(license_key, "", "License key for Memgraph Enterprise."); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_HIDDEN_string(organization_name, "", "Organization name."); - -/// Encapsulates Dbms and Interpreter that are passed through the network server -/// and worker to the session. -struct SessionData { - // Explicit constructor here to ensure that pointers to all objects are - // supplied. -#if MG_ENTERPRISE - - SessionData(memgraph::query::InterpreterContext *interpreter_context, - memgraph::utils::Synchronized *auth, - memgraph::audit::Log *audit_log) - : interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {} - memgraph::query::InterpreterContext *interpreter_context; - memgraph::utils::Synchronized *auth; - memgraph::audit::Log *audit_log; - -#else - - SessionData(memgraph::query::InterpreterContext *interpreter_context, - memgraph::utils::Synchronized *auth) - : interpreter_context(interpreter_context), auth(auth) {} - memgraph::query::InterpreterContext *interpreter_context; - memgraph::utils::Synchronized *auth; - -#endif - // NOTE: run_id should be const but that complicates code a lot. - std::optional run_id; -}; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(), "Set to the regular expression that each user or role name must fulfill."); @@ -498,7 +486,7 @@ void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string c interpreter.Pull(&stream, {}, results.qid); if (audit_log) { - audit_log->Record("", "", line, {}); + audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB); } } } @@ -529,41 +517,146 @@ auto ToQueryExtras(memgraph::communication::bolt::Value const &extra) -> memgrap return memgraph::query::QueryExtras{std::move(metadata_pv), tx_timeout}; } -class BoltSession final : public memgraph::communication::bolt::Session { +class SessionHL final : public memgraph::communication::bolt::Session { public: - BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint, - memgraph::communication::v2::InputStream *input_stream, - memgraph::communication::v2::OutputStream *output_stream) + struct ContextWrapper { + explicit ContextWrapper(memgraph::dbms::SessionContext sc) + : session_context(sc), + interpreter(std::make_unique(session_context.interpreter_context.get())), + defunct_(false) { + session_context.interpreter_context->interpreters.WithLock( + [this](auto &interpreters) { interpreters.insert(interpreter.get()); }); + } + ~ContextWrapper() { Defunct(); } + + void Defunct() { + if (!defunct_) { + session_context.interpreter_context->interpreters.WithLock( + [this](auto &interpreters) { interpreters.erase(interpreter.get()); }); + defunct_ = true; + } + } + + ContextWrapper(const ContextWrapper &) = delete; + ContextWrapper &operator=(const ContextWrapper &) = delete; + + ContextWrapper(ContextWrapper &&in) noexcept + : session_context(std::move(in.session_context)), + interpreter(std::move(in.interpreter)), + defunct_(in.defunct_) { + in.defunct_ = true; + } + + ContextWrapper &operator=(ContextWrapper &&in) noexcept { + if (this != &in) { + Defunct(); + session_context = std::move(in.session_context); + interpreter = std::move(in.interpreter); + defunct_ = in.defunct_; + in.defunct_ = true; + } + return *this; + } + + memgraph::query::InterpreterContext *interpreter_context() { return session_context.interpreter_context.get(); } + memgraph::query::Interpreter *interp() { return interpreter.get(); } + memgraph::utils::Synchronized *auth() const { + return session_context.auth; + } +#ifdef MG_ENTERPRISE + memgraph::audit::Log *audit_log() const { return session_context.audit_log; } +#endif + std::string run_id() const { return session_context.run_id; } + bool defunct() const { return defunct_; } + + private: + memgraph::dbms::SessionContext session_context; + std::unique_ptr interpreter; + bool defunct_; + }; + + SessionHL( +#ifdef MG_ENTERPRISE + memgraph::dbms::SessionContextHandler &sc_handler, +#else + memgraph::dbms::SessionContext sc, +#endif + const memgraph::communication::v2::ServerEndpoint &endpoint, + memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::OutputStream *output_stream, + const std::string &default_db = memgraph::dbms::kDefaultDB) // NOLINT : memgraph::communication::bolt::Session(input_stream, output_stream), - interpreter_context_(data->interpreter_context), - interpreter_(data->interpreter_context), - auth_(data->auth), -#if MG_ENTERPRISE - audit_log_(data->audit_log), +#ifdef MG_ENTERPRISE + sc_handler_(sc_handler), + current_(sc_handler_.Get(default_db)), +#else + current_(sc), +#endif + interpreter_context_(current_.interpreter_context()), + interpreter_(current_.interp()), + auth_(current_.auth()), +#ifdef MG_ENTERPRISE + audit_log_(current_.audit_log()), #endif endpoint_(endpoint), - run_id_(data->run_id) { + run_id_(current_.run_id()) { memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions); - interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); }); } - ~BoltSession() override { - memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); - interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); }); + ~SessionHL() override { memgraph::metrics::DecrementCounter(memgraph::metrics::ActiveBoltSessions); } + + SessionHL(const SessionHL &) = delete; + SessionHL &operator=(const SessionHL &) = delete; + SessionHL(SessionHL &&) = delete; + SessionHL &operator=(SessionHL &&) = delete; + + void Configure(const std::map &run_time_info) override { +#ifdef MG_ENTERPRISE + std::string db; + bool update = false; + // Check if user explicitly defined the database to use + if (run_time_info.contains("db")) { + const auto &db_info = run_time_info.at("db"); + if (!db_info.IsString()) { + throw memgraph::communication::bolt::ClientError("Malformed database name."); + } + db = db_info.ValueString(); + update = db != current_.interpreter_context()->db->id(); + in_explicit_db_ = true; + // NOTE: Once in a transaction, the drivers stop explicitly sending the db and count on using it until commit + } else if (in_explicit_db_ && !interpreter_->in_explicit_transaction_) { // Just on a switch + db = GetDefaultDB(); + update = db != current_.interpreter_context()->db->id(); + in_explicit_db_ = false; + } + + // Check if the underlying database needs to be updated + if (update) { + sc_handler_.SetInPlace(db, [this](auto new_sc) mutable { + const auto &db_name = new_sc.interpreter_context->db->id(); + MultiDatabaseAuth(db_name); + try { + Update(ContextWrapper(new_sc)); + return memgraph::dbms::SetForResult::SUCCESS; + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw memgraph::communication::bolt::ClientError("No database named \"{}\" found!", db_name); + } + }); + } +#endif } - using memgraph::communication::bolt::Session::TEncoder; + using TEncoder = memgraph::communication::bolt::Encoder< + memgraph::communication::bolt::ChunkedEncoderBuffer>; void BeginTransaction(const std::map &extra) override { - interpreter_.BeginTransaction(ToQueryExtras(extra)); + interpreter_->BeginTransaction(ToQueryExtras(extra)); } - void CommitTransaction() override { interpreter_.CommitTransaction(); } + void CommitTransaction() override { interpreter_->CommitTransaction(); } - void RollbackTransaction() override { interpreter_.RollbackTransaction(); } + void RollbackTransaction() override { interpreter_->RollbackTransaction(); } std::pair, std::optional> Interpret( const std::string &query, const std::map ¶ms, @@ -580,16 +673,22 @@ class BoltSession final : public memgraph::communication::bolt::SessionRecord(endpoint_.address().to_string(), user_ ? *username : "", query, - memgraph::storage::PropertyValue(params_pv)); + memgraph::storage::PropertyValue(params_pv), interpreter_context_->db->id()); } #endif try { - auto result = interpreter_.Prepare(query, params_pv, username, ToQueryExtras(extra)); - if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges)) { - interpreter_.Abort(); + auto result = interpreter_->Prepare(query, params_pv, username, ToQueryExtras(extra), UUID()); + const std::string db_name = result.db ? *result.db : ""; + if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) { + interpreter_->Abort(); + if (db_name.empty()) { + throw memgraph::communication::bolt::ClientError( + "You are not authorized to execute this query! Please contact your database administrator."); + } throw memgraph::communication::bolt::ClientError( - "You are not authorized to execute this query! Please contact " - "your database administrator."); + "You are not authorized to execute this query on database \"{}\"! Please contact your database " + "administrator.", + db_name); } return {result.headers, result.qid}; @@ -604,7 +703,7 @@ class BoltSession final : public memgraph::communication::bolt::Session Pull(TEncoder *encoder, std::optional n, std::optional qid) override { - TypedValueResultStream stream(encoder, interpreter_context_->db.get()); + TypedValueResultStream stream(encoder, interpreter_context_); return PullResults(stream, n, qid); } @@ -614,14 +713,26 @@ class BoltSession final : public memgraph::communication::bolt::SessionAbort(); } + // Called during Init + // During Init, the user cannot choose the landing DB (switch is done during query execution) bool Authenticate(const std::string &username, const std::string &password) override { auto locked_auth = auth_->Lock(); if (!locked_auth->HasUsers()) { return true; } user_ = locked_auth->Authenticate(username, password); +#ifdef MG_ENTERPRISE + if (user_.has_value()) { + const auto &db = user_->db_access().GetDefault(); + // Check if the underlying database needs to be updated + if (db != current_.interpreter_context()->db->id()) { + const auto &res = sc_handler_.SetFor(UUID(), db); + return res == memgraph::dbms::SetForResult::SUCCESS || res == memgraph::dbms::SetForResult::ALREADY_SET; + } + } +#endif return user_.has_value(); } @@ -630,12 +741,31 @@ class BoltSession final : public memgraph::communication::bolt::Sessiondb->id()) { + UpdateAndDefunct(db_name); // Done during Pull, so we cannot just replace the current db + return memgraph::dbms::SetForResult::SUCCESS; + } + return memgraph::dbms::SetForResult::ALREADY_SET; + } + + bool OnDelete(const std::string &db_name) override { + MG_ASSERT(current_.interpreter_context()->db->id() != db_name && (!defunct_ || defunct_->defunct()), + "Trying to delete a database while still in use."); + return true; + } +#endif + + std::string GetDatabaseName() const override { return interpreter_context_->db->id(); } + private: template std::map PullResults(TStream &stream, std::optional n, std::optional qid) { try { - const auto &summary = interpreter_.Pull(&stream, n, qid); + const auto &summary = interpreter_->Pull(&stream, n, qid); std::map decoded_summary; for (const auto &kv : summary) { auto maybe_value = @@ -660,6 +790,11 @@ class BoltSession final : public memgraph::communication::bolt::Session(cntxt)); + defunct_->Defunct(); + } + + void Update(const std::string &db_name) { + ContextWrapper tmp(sc_handler_.Get(db_name)); + Update(std::move(tmp)); + } + + void Update(ContextWrapper &&cntxt) { + current_ = std::move(cntxt); + interpreter_ = current_.interp(); + interpreter_->in_explicit_db_ = in_explicit_db_; + interpreter_context_ = current_.interpreter_context(); + } + + /** + * @brief Authenticate user on passed database. + * + * @param db database to check against + * @throws bolt::ClientError when user is not authorized + */ + void MultiDatabaseAuth(const std::string &db) { + if (user_ && !memgraph::glue::AuthChecker::IsUserAuthorized(*user_, {}, db)) { + throw memgraph::communication::bolt::ClientError( + "You are not authorized on the database \"{}\"! Please contact your database administrator.", db); + } + } + + /** + * @brief Get the user's default database + * + * @return std::string + */ + std::string GetDefaultDB() { + if (user_.has_value()) { + return user_->db_access().GetDefault(); + } + return memgraph::dbms::kDefaultDB; + } +#endif + /// Wrapper around TEncoder which converts TypedValue to Value /// before forwarding the calls to original TEncoder. class TypedValueResultStream { public: - TypedValueResultStream(TEncoder *encoder, const memgraph::storage::Storage *db) : encoder_(encoder), db_(db) {} + TypedValueResultStream(TEncoder *encoder, memgraph::query::InterpreterContext *ic) + : encoder_(encoder), interpreter_context_(ic) {} void Result(const std::vector &values) { std::vector decoded_values; decoded_values.reserve(values.size()); for (const auto &v : values) { - auto maybe_value = memgraph::glue::ToBoltValue(v, *db_, memgraph::storage::View::NEW); + auto maybe_value = memgraph::glue::ToBoltValue(v, *interpreter_context_->db, memgraph::storage::View::NEW); if (maybe_value.HasError()) { switch (maybe_value.GetError()) { case memgraph::storage::Error::DELETED_OBJECT: @@ -699,25 +888,36 @@ class BoltSession final : public memgraph::communication::bolt::Session defunct_; + memgraph::query::InterpreterContext *interpreter_context_; - memgraph::query::Interpreter interpreter_; + memgraph::query::Interpreter *interpreter_; memgraph::utils::Synchronized *auth_; std::optional user_; #ifdef MG_ENTERPRISE memgraph::audit::Log *audit_log_; + bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata #endif memgraph::communication::v2::ServerEndpoint endpoint_; // NOTE: run_id should be const but that complicates code a lot. std::optional run_id_; }; -using ServerT = memgraph::communication::v2::Server; +#ifdef MG_ENTERPRISE +using ServerT = memgraph::communication::v2::Server; +#else +using ServerT = memgraph::communication::v2::Server; +#endif using MonitoringServerT = - memgraph::communication::http::Server, SessionData>; + memgraph::communication::http::Server, + memgraph::dbms::SessionContext>; using memgraph::communication::ServerContext; // Needed to correctly handle memgraph destruction from a signal handler. @@ -880,10 +1080,6 @@ int main(int argc, char **argv) { // Begin enterprise features initialization - // Auth - memgraph::utils::Synchronized auth{data_directory / - "auth"}; - #ifdef MG_ENTERPRISE // Audit log memgraph::audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, @@ -907,7 +1103,7 @@ int main(int argc, char **argv) { .interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)}, .items = {.properties_on_edges = FLAGS_storage_properties_on_edges}, .durability = {.storage_directory = FLAGS_data_directory, - .recover_on_startup = FLAGS_storage_recover_on_startup, + .recover_on_startup = FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup, .snapshot_retention_count = FLAGS_storage_snapshot_retention_count, .wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib, .wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx, @@ -944,31 +1140,62 @@ int main(int argc, char **argv) { db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec); } - memgraph::query::InterpreterContext interpreter_context{ - db_config, - {.query = {.allow_load_csv = FLAGS_allow_load_csv}, - .execution_timeout_sec = FLAGS_query_execution_timeout_sec, - .replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec), - .default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers, - .default_pulsar_service_url = FLAGS_pulsar_service_url, - .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, - .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}, - FLAGS_data_directory}; + // Default interpreter configuration + memgraph::query::InterpreterConfig interp_config{ + .query = {.allow_load_csv = FLAGS_allow_load_csv}, + .execution_timeout_sec = FLAGS_query_execution_timeout_sec, + .replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec), + .default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers, + .default_pulsar_service_url = FLAGS_pulsar_service_url, + .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, + .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}; + + auto auth_glue = + [flag = FLAGS_auth_user_or_role_name_regex]( + memgraph::utils::Synchronized *auth, + std::unique_ptr &ah, std::unique_ptr &ac) { + // Glue high level auth implementations to the query side + ah = std::make_unique(auth, flag); + ac = std::make_unique(auth); + // Handle users passed via arguments + auto *maybe_username = std::getenv(kMgUser); + auto *maybe_password = std::getenv(kMgPassword); + auto *maybe_pass_file = std::getenv(kMgPassfile); + if (maybe_username && maybe_password) { + ah->CreateUser(maybe_username, maybe_password); + } else if (maybe_pass_file) { + const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file); + if (!username.empty() && !password.empty()) { + ah->CreateUser(username, password); + } + } + }; + #ifdef MG_ENTERPRISE - SessionData session_data{&interpreter_context, &auth, &audit_log}; + // SessionContext handler (multi-tenancy) + memgraph::dbms::SessionContextHandler sc_handler(audit_log, {db_config, interp_config, auth_glue}, + FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup, + FLAGS_storage_delete_on_drop); + // Just for current support... TODO remove + auto session_context = sc_handler.Get(memgraph::dbms::kDefaultDB); #else - SessionData session_data{&interpreter_context, &auth}; + + memgraph::utils::Synchronized auth_{data_directory / + "auth"}; + std::unique_ptr auth_handler; + std::unique_ptr auth_checker; + auth_glue(&auth_, auth_handler, auth_checker); + auto session_context = memgraph::dbms::Init(db_config, interp_config, &auth_, auth_handler.get(), auth_checker.get()); + #endif + auto *auth = session_context.auth; + auto &interpreter_context = *session_context.interpreter_context; // TODO remove + memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories, FLAGS_data_directory); memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories(); memgraph::query::procedure::gCallableAliasMapper.LoadMapping(FLAGS_query_callable_mappings_path); - memgraph::glue::AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex); - memgraph::glue::AuthChecker auth_checker{&auth}; - interpreter_context.auth = &auth_handler; - interpreter_context.auth_checker = &auth_checker; - if (!FLAGS_init_file.empty()) { spdlog::info("Running init file..."); #ifdef MG_ENTERPRISE @@ -982,18 +1209,10 @@ int main(int argc, char **argv) { #endif } - auto *maybe_username = std::getenv(kMgUser); - auto *maybe_password = std::getenv(kMgPassword); - auto *maybe_pass_file = std::getenv(kMgPassfile); - if (maybe_username && maybe_password) { - auth_handler.CreateUser(maybe_username, maybe_password); - } else if (maybe_pass_file) { - const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file); - if (!username.empty() && !password.empty()) { - auth_handler.CreateUser(username, password); - } - } - +#ifdef MG_ENTERPRISE + sc_handler.RestoreTriggers(); + sc_handler.RestoreStreams(); +#else { // Triggers can execute query procedures, so we need to reload the modules first and then // the triggers @@ -1005,6 +1224,7 @@ int main(int argc, char **argv) { // As the Stream transformations are using modules, they have to be restored after the query modules are loaded. interpreter_context.streams.RestoreStreams(); +#endif ServerContext context; std::string service_name = "Bolt"; @@ -1016,25 +1236,35 @@ int main(int argc, char **argv) { spdlog::warn( memgraph::utils::MessageWithLink("Using non-secure Bolt connection (without SSL).", "https://memgr.ph/ssl")); } - auto server_endpoint = memgraph::communication::v2::ServerEndpoint{ boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast(FLAGS_bolt_port)}; - ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name, +#ifdef MG_ENTERPRISE + ServerT server(server_endpoint, &sc_handler, &context, FLAGS_bolt_session_inactivity_timeout, service_name, FLAGS_bolt_num_workers); +#else + ServerT server(server_endpoint, &session_context, &context, FLAGS_bolt_session_inactivity_timeout, service_name, + FLAGS_bolt_num_workers); +#endif - const auto run_id = memgraph::utils::GenerateUUID(); const auto machine_id = memgraph::utils::GetMachineId(); - session_data.run_id = run_id; + const auto run_id = session_context.run_id; // For current compatibility // Setup telemetry static constexpr auto telemetry_server{"https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/"}; std::optional telemetry; if (FLAGS_telemetry_enabled) { telemetry.emplace(telemetry_server, data_directory / "telemetry", run_id, machine_id, std::chrono::minutes(10)); - telemetry->AddCollector("storage", [db_ = interpreter_context.db.get()]() -> nlohmann::json { - auto info = db_->GetInfo(); +#ifdef MG_ENTERPRISE + telemetry->AddCollector("storage", [&sc_handler]() -> nlohmann::json { + const auto &info = sc_handler.Info(); + return {{"vertices", info.num_vertex}, {"edges", info.num_edges}, {"databases", info.num_databases}}; + }); +#else + telemetry->AddCollector("storage", [&interpreter_context]() -> nlohmann::json { + auto info = interpreter_context.db->GetInfo(); return {{"vertices", info.vertex_count}, {"edges", info.edge_count}}; }); +#endif telemetry->AddCollector("event_counters", []() -> nlohmann::json { nlohmann::json ret; for (size_t i = 0; i < memgraph::metrics::CounterEnd(); ++i) { @@ -1050,25 +1280,25 @@ int main(int argc, char **argv) { memgraph::license::LicenseInfoSender license_info_sender(telemetry_server, run_id, machine_id, memory_limit, memgraph::license::global_license_checker.GetLicenseInfo()); - memgraph::communication::websocket::SafeAuth websocket_auth{&auth}; + memgraph::communication::websocket::SafeAuth websocket_auth{auth}; memgraph::communication::websocket::Server websocket_server{ {FLAGS_monitoring_address, static_cast(FLAGS_monitoring_port)}, &context, websocket_auth}; AddLoggerSink(websocket_server.GetLoggingSink()); MonitoringServerT metrics_server{ - {FLAGS_metrics_address, static_cast(FLAGS_metrics_port)}, &session_data, &context}; + {FLAGS_metrics_address, static_cast(FLAGS_metrics_port)}, &session_context, &context}; #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { // Handler for regular termination signals - auto shutdown = [&metrics_server, &websocket_server, &server, &interpreter_context] { + auto shutdown = [&metrics_server, &websocket_server, &server, &sc_handler] { // Server needs to be shutdown first and then the database. This prevents // a race condition when a transaction is accepted during server shutdown. server.Shutdown(); // After the server is notified to stop accepting and processing // connections we tell the execution engine to stop processing all pending // queries. - memgraph::query::Shutdown(&interpreter_context); + sc_handler.Shutdown(); websocket_server.Shutdown(); metrics_server.Shutdown(); diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index 4f6cb1419..cb1be8985 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -24,7 +24,8 @@ class AuthChecker { virtual ~AuthChecker() = default; [[nodiscard]] virtual bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges) const = 0; + const std::vector &privileges, + const std::string &db_name) const = 0; #ifdef MG_ENTERPRISE [[nodiscard]] virtual std::unique_ptr GetFineGrainedAuthChecker( @@ -92,7 +93,8 @@ class AllowEverythingFineGrainedAuthChecker final : public query::FineGrainedAut class AllowEverythingAuthChecker final : public query::AuthChecker { public: bool IsUserAuthorized(const std::optional & /*username*/, - const std::vector & /*privileges*/) const override { + const std::vector & /*privileges*/, + const std::string & /*db*/) const override { return true; } diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index f5ee4f0c1..b14f9ac0b 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -495,6 +495,8 @@ class DbAccessor final { storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } + + const std::string &id() const { return accessor_->id(); } }; class SubgraphDbAccessor final { diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index 0476559ed..5c249f7cb 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -324,4 +324,10 @@ class ConstraintsPersistenceException : public QueryException { ConstraintsPersistenceException() : QueryException("Persisting constraints on disk failed.") {} }; +class MultiDatabaseQueryInMulticommandTxException : public QueryException { + public: + MultiDatabaseQueryInMulticommandTxException() + : QueryException("Multi-database queries are not allowed in multicommand transactions.") {} +}; + } // namespace memgraph::query diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 1aed24e84..ca53f0752 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -279,4 +279,10 @@ constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exist constexpr utils::TypeInfo query::CallSubquery::kType{utils::TypeId::AST_CALL_SUBQUERY, "CallSubquery", &query::Clause::kType}; + +constexpr utils::TypeInfo query::MultiDatabaseQuery::kType{utils::TypeId::AST_MULTI_DATABASE_QUERY, + "MultiDatabaseQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::ShowDatabasesQuery::kType{utils::TypeId::AST_SHOW_DATABASES, "ShowDatabasesQuery", + &query::Query::kType}; } // namespace memgraph diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 90a787005..e2ce6061c 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -2778,7 +2778,11 @@ class AuthQuery : public memgraph::query::Query { REVOKE_PRIVILEGE, SHOW_PRIVILEGES, SHOW_ROLE_FOR_USER, - SHOW_USERS_FOR_ROLE + SHOW_USERS_FOR_ROLE, + GRANT_DATABASE_TO_USER, + REVOKE_DATABASE_FROM_USER, + SHOW_DATABASE_PRIVILEGES, + SET_MAIN_DATABASE, }; enum class Privilege { @@ -2804,7 +2808,9 @@ class AuthQuery : public memgraph::query::Query { MODULE_WRITE, WEBSOCKET, STORAGE_MODE, - TRANSACTION_MANAGEMENT + TRANSACTION_MANAGEMENT, + MULTI_DATABASE_EDIT, + MULTI_DATABASE_USE, }; enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE }; @@ -2818,6 +2824,7 @@ class AuthQuery : public memgraph::query::Query { std::string role_; std::string user_or_role_; memgraph::query::Expression *password_{nullptr}; + std::string database_; std::vector privileges_; std::vector>> label_privileges_; @@ -2831,6 +2838,7 @@ class AuthQuery : public memgraph::query::Query { object->role_ = role_; object->user_or_role_ = user_or_role_; object->password_ = password_ ? password_->Clone(storage) : nullptr; + object->database_ = database_; object->privileges_ = privileges_; object->label_privileges_ = label_privileges_; object->edge_type_privileges_ = edge_type_privileges_; @@ -2839,7 +2847,7 @@ class AuthQuery : public memgraph::query::Query { protected: AuthQuery(Action action, std::string user, std::string role, std::string user_or_role, Expression *password, - std::vector privileges, + std::string database, std::vector privileges, std::vector>> label_privileges, std::vector>> edge_type_privileges) : action_(action), @@ -2847,6 +2855,7 @@ class AuthQuery : public memgraph::query::Query { role_(role), user_or_role_(user_or_role), password_(password), + database_(database), privileges_(privileges), label_privileges_(label_privileges), edge_type_privileges_(edge_type_privileges) {} @@ -2856,19 +2865,31 @@ class AuthQuery : public memgraph::query::Query { }; /// Constant that holds all available privileges. -const std::vector kPrivilegesAll = { - AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, - AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE, - AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, - AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, - AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT, - AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION, - AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, - AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, - AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, - AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT, - AuthQuery::Privilege::STORAGE_MODE}; +const std::vector kPrivilegesAll = {AuthQuery::Privilege::CREATE, + AuthQuery::Privilege::DELETE, + AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::MERGE, + AuthQuery::Privilege::SET, + AuthQuery::Privilege::REMOVE, + AuthQuery::Privilege::INDEX, + AuthQuery::Privilege::STATS, + AuthQuery::Privilege::AUTH, + AuthQuery::Privilege::CONSTRAINT, + AuthQuery::Privilege::DUMP, + AuthQuery::Privilege::REPLICATION, + AuthQuery::Privilege::READ_FILE, + AuthQuery::Privilege::DURABILITY, + AuthQuery::Privilege::FREE_MEMORY, + AuthQuery::Privilege::TRIGGER, + AuthQuery::Privilege::CONFIG, + AuthQuery::Privilege::STREAM, + AuthQuery::Privilege::MODULE_READ, + AuthQuery::Privilege::MODULE_WRITE, + AuthQuery::Privilege::WEBSOCKET, + AuthQuery::Privilege::TRANSACTION_MANAGEMENT, + AuthQuery::Privilege::STORAGE_MODE, + AuthQuery::Privilege::MULTI_DATABASE_EDIT, + AuthQuery::Privilege::MULTI_DATABASE_USE}; class InfoQuery : public memgraph::query::Query { public: @@ -3446,5 +3467,38 @@ class CallSubquery : public memgraph::query::Clause { friend class AstStorage; }; +class MultiDatabaseQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(QueryVisitor); + + enum class Action { CREATE, USE, DROP }; + + memgraph::query::MultiDatabaseQuery::Action action_; + std::string db_name_; + + MultiDatabaseQuery *Clone(AstStorage *storage) const override { + auto *object = storage->Create(); + object->action_ = action_; + object->db_name_ = db_name_; + return object; + } +}; + +class ShowDatabasesQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(QueryVisitor); + + ShowDatabasesQuery *Clone(AstStorage *storage) const override { + auto *object = storage->Create(); + return object; + } +}; + } // namespace query } // namespace memgraph diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 68e4be6f9..9bb2cddc6 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -11,6 +11,7 @@ #pragma once +#include "query/frontend/ast/ast.hpp" #include "utils/visitor.hpp" namespace memgraph::query { @@ -102,6 +103,8 @@ class CallSubquery; class AnalyzeGraphQuery; class TransactionQueueQuery; class Exists; +class MultiDatabaseQuery; +class ShowDatabasesQuery; using TreeCompositeVisitor = utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -139,6 +142,7 @@ class QueryVisitor : public utils::Visitor {}; + ShowConfigQuery, TransactionQueueQuery, StorageModeQuery, AnalyzeGraphQuery, + MultiDatabaseQuery, ShowDatabasesQuery> {}; } // namespace memgraph::query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 33647b0fa..0b1b875ce 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1272,7 +1272,7 @@ antlrcpp::Any CypherMainVisitor::visitUserOrRoleName(MemgraphCypher::UserOrRoleN * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::CREATE_ROLE; auth->role_ = std::any_cast(ctx->role->accept(this)); return auth; @@ -1282,7 +1282,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateRole(MemgraphCypher::CreateRoleConte * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::DROP_ROLE; auth->role_ = std::any_cast(ctx->role->accept(this)); return auth; @@ -1292,7 +1292,7 @@ antlrcpp::Any CypherMainVisitor::visitDropRole(MemgraphCypher::DropRoleContext * * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_ROLES; return auth; } @@ -1301,7 +1301,7 @@ antlrcpp::Any CypherMainVisitor::visitShowRoles(MemgraphCypher::ShowRolesContext * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::CREATE_USER; auth->user_ = std::any_cast(ctx->user->accept(this)); if (ctx->password) { @@ -1317,7 +1317,7 @@ antlrcpp::Any CypherMainVisitor::visitCreateUser(MemgraphCypher::CreateUserConte * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SET_PASSWORD; auth->user_ = std::any_cast(ctx->user->accept(this)); if (!ctx->password->StringLiteral() && !ctx->literal()->CYPHERNULL()) { @@ -1331,7 +1331,7 @@ antlrcpp::Any CypherMainVisitor::visitSetPassword(MemgraphCypher::SetPasswordCon * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::DROP_USER; auth->user_ = std::any_cast(ctx->user->accept(this)); return auth; @@ -1341,7 +1341,7 @@ antlrcpp::Any CypherMainVisitor::visitDropUser(MemgraphCypher::DropUserContext * * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_USERS; return auth; } @@ -1350,7 +1350,7 @@ antlrcpp::Any CypherMainVisitor::visitShowUsers(MemgraphCypher::ShowUsersContext * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SET_ROLE; auth->user_ = std::any_cast(ctx->user->accept(this)); auth->role_ = std::any_cast(ctx->role->accept(this)); @@ -1361,7 +1361,7 @@ antlrcpp::Any CypherMainVisitor::visitSetRole(MemgraphCypher::SetRoleContext *ct * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::CLEAR_ROLE; auth->user_ = std::any_cast(ctx->user->accept(this)); return auth; @@ -1371,7 +1371,7 @@ antlrcpp::Any CypherMainVisitor::visitClearRole(MemgraphCypher::ClearRoleContext * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::GRANT_PRIVILEGE; auth->user_or_role_ = std::any_cast(ctx->userOrRole->accept(this)); if (ctx->grantPrivilegesList()) { @@ -1393,7 +1393,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivil * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::DENY_PRIVILEGE; auth->user_or_role_ = std::any_cast(ctx->userOrRole->accept(this)); if (ctx->privilegesList()) { @@ -1453,7 +1453,7 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilegesList(MemgraphCypher::GrantP * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::REVOKE_PRIVILEGE; auth->user_or_role_ = std::any_cast(ctx->userOrRole->accept(this)); if (ctx->revokePrivilegesList()) { @@ -1526,6 +1526,16 @@ antlrcpp::Any CypherMainVisitor::visitEntitiesList(MemgraphCypher::EntitiesListC return entities; } +/** + * @return std::string + */ +antlrcpp::Any CypherMainVisitor::visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) { + if (ctx->symbolicName()) { + return ctx->symbolicName()->accept(this); + } + return std::string("*"); +} + /** * @return AuthQuery::Privilege */ @@ -1553,6 +1563,8 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT; if (ctx->STORAGE_MODE()) return AuthQuery::Privilege::STORAGE_MODE; + if (ctx->MULTI_DATABASE_EDIT()) return AuthQuery::Privilege::MULTI_DATABASE_EDIT; + if (ctx->MULTI_DATABASE_USE()) return AuthQuery::Privilege::MULTI_DATABASE_USE; LOG_FATAL("Should not get here - unknown privilege!"); } @@ -1580,7 +1592,7 @@ antlrcpp::Any CypherMainVisitor::visitEntityType(MemgraphCypher::EntityTypeConte * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_PRIVILEGES; auth->user_or_role_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; @@ -1590,7 +1602,7 @@ antlrcpp::Any CypherMainVisitor::visitShowPrivileges(MemgraphCypher::ShowPrivile * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_ROLE_FOR_USER; auth->user_ = std::any_cast(ctx->user->accept(this)); return auth; @@ -1600,12 +1612,55 @@ antlrcpp::Any CypherMainVisitor::visitShowRoleForUser(MemgraphCypher::ShowRoleFo * @return AuthQuery* */ antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) { - AuthQuery *auth = storage_->Create(); + auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_USERS_FOR_ROLE; auth->role_ = std::any_cast(ctx->role->accept(this)); return auth; } +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER; + auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); + auth->user_ = std::any_cast(ctx->user->accept(this)); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER; + auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); + auth->user_ = std::any_cast(ctx->user->accept(this)); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES; + auth->user_ = std::any_cast(ctx->user->accept(this)); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE; + auth->database_ = std::any_cast(ctx->db->accept(this)); + auth->user_ = std::any_cast(ctx->user->accept(this)); + return auth; +} + antlrcpp::Any CypherMainVisitor::visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) { auto *return_clause = storage_->Create(); return_clause->body_ = std::any_cast(ctx->returnBody()->accept(this)); @@ -2671,4 +2726,33 @@ PropertyIx CypherMainVisitor::AddProperty(const std::string &name) { return stor EdgeTypeIx CypherMainVisitor::AddEdgeType(const std::string &name) { return storage_->GetEdgeTypeIx(name); } +antlrcpp::Any CypherMainVisitor::visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) { + auto *mdb_query = storage_->Create(); + mdb_query->db_name_ = std::any_cast(ctx->databaseName()->accept(this)); + mdb_query->action_ = MultiDatabaseQuery::Action::CREATE; + query_ = mdb_query; + return mdb_query; +} + +antlrcpp::Any CypherMainVisitor::visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) { + auto *mdb_query = storage_->Create(); + mdb_query->db_name_ = std::any_cast(ctx->databaseName()->accept(this)); + mdb_query->action_ = MultiDatabaseQuery::Action::USE; + query_ = mdb_query; + return mdb_query; +} + +antlrcpp::Any CypherMainVisitor::visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) { + auto *mdb_query = storage_->Create(); + mdb_query->db_name_ = std::any_cast(ctx->databaseName()->accept(this)); + mdb_query->action_ = MultiDatabaseQuery::Action::DROP; + query_ = mdb_query; + return mdb_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowDatabases(MemgraphCypher::ShowDatabasesContext * /*ctx*/) { + query_ = storage_->Create(); + return query_; +} + } // namespace memgraph::query::frontend diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 56f9725b8..181fa773e 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -524,6 +524,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitEntitiesList(MemgraphCypher::EntitiesListContext *ctx) override; + /** + * @return std::string + */ + antlrcpp::Any visitWildcardName(MemgraphCypher::WildcardNameContext *ctx) override; + /** * @return AuthQuery::FineGrainedPrivilege */ @@ -554,6 +559,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override; + /** + * @return AuthQuery* + */ + antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitSetMainDatabase(MemgraphCypher::SetMainDatabaseContext *ctx) override; + /** * @return Return* */ @@ -935,6 +960,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitCallSubquery(MemgraphCypher::CallSubqueryContext *ctx) override; + /** + * @return MultiDatabaseQuery* + */ + antlrcpp::Any visitCreateDatabase(MemgraphCypher::CreateDatabaseContext *ctx) override; + + /** + * @return MultiDatabaseQuery* + */ + antlrcpp::Any visitUseDatabase(MemgraphCypher::UseDatabaseContext *ctx) override; + + /** + * @return MultiDatabaseQuery* + */ + antlrcpp::Any visitDropDatabase(MemgraphCypher::DropDatabaseContext *ctx) override; + + /** + * @return ShowDatabasesQuery* + */ + antlrcpp::Any visitShowDatabases(MemgraphCypher::ShowDatabasesContext *ctx) override; + public: Query *query() { return query_; } const static std::string kAnonPrefix; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index a55d1281a..bd9ce6e95 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -107,6 +107,7 @@ memgraphCypherKeyword : cypherKeyword | UNCOMMITTED | UNLOCK | UPDATE + | USE | USER | USERS | VERSION @@ -140,6 +141,8 @@ query : cypherQuery | versionQuery | showConfigQuery | transactionQueueQuery + | multiDatabaseQuery + | showDatabases ; authQuery : createRole @@ -157,6 +160,10 @@ authQuery : createRole | showPrivileges | showRoleForUser | showUsersForRole + | grantDatabaseToUser + | revokeDatabaseFromUser + | showDatabasePrivileges + | setMainDatabase ; replicationQuery : setReplicationRole @@ -208,6 +215,10 @@ streamQuery : checkStream | showStreams ; +databaseName : symbolicName ; + +wildcardName : ASTERISK | symbolicName ; + settingQuery : setSetting | showSetting | showSettings @@ -265,6 +276,14 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ; +grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ; + +revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ; + +showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ; + +setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ; + privilege : CREATE | DELETE | MATCH @@ -288,6 +307,8 @@ privilege : CREATE | WEBSOCKET | TRANSACTION_MANAGEMENT | STORAGE_MODE + | MULTI_DATABASE_EDIT + | MULTI_DATABASE_USE ; granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ; @@ -441,3 +462,16 @@ versionQuery : SHOW VERSION ; transactionIdList : transactionId ( ',' transactionId )* ; transactionId : literal ; + +multiDatabaseQuery : createDatabase + | useDatabase + | dropDatabase + ; + +createDatabase : CREATE DATABASE databaseName ; + +useDatabase : USE DATABASE databaseName ; + +dropDatabase : DROP DATABASE databaseName ; + +showDatabases: SHOW DATABASES ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index 1f07e74f0..37b0015ce 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -49,6 +49,7 @@ CSV : C S V ; DATA : D A T A ; DELIMITER : D E L I M I T E R ; DATABASE : D A T A B A S E ; +DATABASES : D A T A B A S E S ; DENY : D E N Y ; DIRECTORY : D I R E C T O R Y ; DROP : D R O P ; @@ -80,6 +81,8 @@ MAIN : M A I N ; MODE : M O D E ; MODULE_READ : M O D U L E UNDERSCORE R E A D ; MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ; +MULTI_DATABASE_EDIT : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE E D I T ; +MULTI_DATABASE_USE : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE U S E ; NEXT : N E X T ; NO : N O ; NOTHING : N O T H I N G ; @@ -127,6 +130,7 @@ TRIGGERS : T R I G G E R S ; UNCOMMITTED : U N C O M M I T T E D ; UNLOCK : U N L O C K ; UPDATE : U P D A T E ; +USE : U S E ; USER : U S E R ; USERS : U S E R S ; VERSION : V E R S I O N ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 29027bca4..d7d432410 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -89,6 +89,22 @@ class PrivilegeExtractor : public QueryVisitor, public HierarchicalTreeVis void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } + void Visit(MultiDatabaseQuery &query) override { + switch (query.action_) { + case MultiDatabaseQuery::Action::CREATE: + case MultiDatabaseQuery::Action::DROP: + AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_EDIT); + break; + case MultiDatabaseQuery::Action::USE: + AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE); + break; + } + } + + void Visit(ShowDatabasesQuery & /*unused*/) override { + AddPrivilege(AuthQuery::Privilege::MULTI_DATABASE_USE); /* OR EDIT */ + } + bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 456c704e6..775081941 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -139,6 +139,7 @@ const trie::Trie kKeywords = {"union", "false", "reduce", "coalesce", + "use", "user", "password", "alter", @@ -159,6 +160,7 @@ const trie::Trie kKeywords = {"union", "key", "dump", "database", + "databases", "call", "yield", "memory", diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index fe7d878af..9a158f91e 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -24,13 +24,17 @@ #include #include #include +#include #include #include #include #include +#include "auth/auth.hpp" #include "auth/models.hpp" #include "csv/parsing.hpp" +#include "dbms/global.hpp" +#include "dbms/session_context_handler.hpp" #include "glue/communication.hpp" #include "license/license.hpp" #include "memory/memory_control.hpp" @@ -105,6 +109,26 @@ template constexpr auto kAlwaysFalse = false; namespace { +template +void Sort(std::vector &vec) { + std::sort(vec.begin(), vec.end()); +} + +template +void Sort(std::vector &vec) { + std::sort(vec.begin(), vec.end(), + [](const TypedValue &lv, const TypedValue &rv) { return lv.ValueString() < rv.ValueString(); }); +} + +// NOLINTNEXTLINE (misc-unused-parameters) +bool Same(const TypedValue &lv, const TypedValue &rv) { + return TypedValue(lv).ValueString() == TypedValue(rv).ValueString(); +} +bool Same(const TypedValue &lv, const std::string &rv) { return std::string(TypedValue(lv).ValueString()) == rv; } +// NOLINTNEXTLINE (misc-unused-parameters) +bool Same(const std::string &lv, const TypedValue &rv) { return lv == std::string(TypedValue(rv).ValueString()); } +bool Same(const std::string &lv, const std::string &rv) { return lv == rv; } + void UpdateTypeCount(const plan::ReadWriteTypeChecker::RWType type) { switch (type) { case plan::ReadWriteTypeChecker::RWType::R: @@ -333,8 +357,12 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { /// returns false if the replication role can't be set /// @throw QueryRuntimeException if an error ocurred. -Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters ¶meters, +Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters, DbAccessor *db_accessor) { + AuthQueryHandler *auth = interpreter_context->auth; +#ifdef MG_ENTERPRISE + auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); +#endif // Empty frame for evaluation of password expression. This is OK since // password should be either null or string literal and it's evaluation // should not depend on frame. @@ -351,6 +379,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa std::string username = auth_query->user_; std::string rolename = auth_query->role_; std::string user_or_role = auth_query->user_or_role_; + std::string database = auth_query->database_; std::vector privileges = auth_query->privileges_; #ifdef MG_ENTERPRISE std::vector>> label_privileges = @@ -364,11 +393,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings); - static const std::unordered_set enterprise_only_methods{ - AuthQuery::Action::CREATE_ROLE, AuthQuery::Action::DROP_ROLE, AuthQuery::Action::SET_ROLE, - AuthQuery::Action::CLEAR_ROLE, AuthQuery::Action::GRANT_PRIVILEGE, AuthQuery::Action::DENY_PRIVILEGE, - AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, - AuthQuery::Action::SHOW_ROLE_FOR_USER}; + static const std::unordered_set enterprise_only_methods{AuthQuery::Action::CREATE_ROLE, + AuthQuery::Action::DROP_ROLE, + AuthQuery::Action::SET_ROLE, + AuthQuery::Action::CLEAR_ROLE, + AuthQuery::Action::GRANT_PRIVILEGE, + AuthQuery::Action::DENY_PRIVILEGE, + AuthQuery::Action::REVOKE_PRIVILEGE, + AuthQuery::Action::SHOW_PRIVILEGES, + AuthQuery::Action::SHOW_USERS_FOR_ROLE, + AuthQuery::Action::SHOW_ROLE_FOR_USER, + AuthQuery::Action::GRANT_DATABASE_TO_USER, + AuthQuery::Action::REVOKE_DATABASE_FROM_USER, + AuthQuery::Action::SHOW_DATABASE_PRIVILEGES, + AuthQuery::Action::SET_MAIN_DATABASE}; if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { throw utils::BasicException( @@ -536,6 +574,73 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa return rows; }; return callback; + case AuthQuery::Action::GRANT_DATABASE_TO_USER: +#ifdef MG_ENTERPRISE + callback.fn = [auth, database, username, &sc_handler] { // NOLINT + try { + memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr); + if (database != memgraph::auth::kAllDatabases) { + sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + } + if (!auth->GrantDatabaseToUser(database, username)) { + throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); + } + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw QueryRuntimeException(e.what()); + } +#else + callback.fn = [] { +#endif + return std::vector>(); + }; + return callback; + case AuthQuery::Action::REVOKE_DATABASE_FROM_USER: +#ifdef MG_ENTERPRISE + callback.fn = [auth, database, username, &sc_handler] { // NOLINT + try { + memgraph::dbms::SessionContext sc(nullptr, "", nullptr, nullptr); + if (database != memgraph::auth::kAllDatabases) { + sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + } + if (!auth->RevokeDatabaseFromUser(database, username)) { + throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); + } + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw QueryRuntimeException(e.what()); + } +#else + callback.fn = [] { +#endif + return std::vector>(); + }; + return callback; + case AuthQuery::Action::SHOW_DATABASE_PRIVILEGES: + callback.header = {"grants", "denies"}; + callback.fn = [auth, username] { // NOLINT +#ifdef MG_ENTERPRISE + return auth->GetDatabasePrivileges(username); +#else + return std::vector>(); +#endif + }; + return callback; + case AuthQuery::Action::SET_MAIN_DATABASE: +#ifdef MG_ENTERPRISE + callback.fn = [auth, database, username, &sc_handler] { // NOLINT + try { + const auto sc = sc_handler.Get(database); // Will throw if databases doesn't exist and protect it during pull + if (!auth->SetMainDatabase(database, username)) { + throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); + } + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw QueryRuntimeException(e.what()); + } +#else + callback.fn = [] { +#endif + return std::vector>(); + }; + return callback; default: break; } @@ -1249,11 +1354,23 @@ bool IsWriteQueryOnMainMemoryReplica(storage::Storage *storage, return false; } +storage::replication::ReplicationRole GetReplicaRole(storage::Storage *storage) { + if (auto storage_mode = storage->GetStorageMode(); storage_mode == storage::StorageMode::IN_MEMORY_ANALYTICAL || + storage_mode == storage::StorageMode::IN_MEMORY_TRANSACTIONAL) { + auto *mem_storage = static_cast(storage); + return mem_storage->GetReplicationRole(); + } + return storage::replication::ReplicationRole::MAIN; +} + } // namespace InterpreterContext::InterpreterContext(const storage::Config storage_config, const InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory) - : trigger_store(data_directory / "triggers"), + const std::filesystem::path &data_directory, query::AuthQueryHandler *ah, + query::AuthChecker *ac) + : auth(ah), + auth_checker(ac), + trigger_store(data_directory / "triggers"), config(interpreter_config), streams{this, data_directory / "streams"} { if (utils::DirExists(storage_config.disk.main_storage_directory)) { @@ -1264,8 +1381,11 @@ InterpreterContext::InterpreterContext(const storage::Config storage_config, con } InterpreterContext::InterpreterContext(std::unique_ptr db, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory) + const std::filesystem::path &data_directory, query::AuthQueryHandler *ah, + query::AuthChecker *ac) : db(std::move(db)), + auth(ah), + auth_checker(ac), trigger_store(data_directory / "triggers"), config(interpreter_config), streams{this, data_directory / "streams"} {} @@ -2027,7 +2147,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa auto *auth_query = utils::Downcast(parsed_query.query); - auto callback = HandleAuthQuery(auth_query, interpreter_context->auth, parsed_query.parameters, dba); + auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, dba); SymbolTable symbol_table; std::vector output_symbols; @@ -2668,7 +2788,7 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized( - username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}); + username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, ""); Callback callback; switch (transaction_query->action_) { @@ -2770,6 +2890,7 @@ PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transa handler = [db, interpreter_isolation_level, next_transaction_isolation_level] { auto info = db->GetInfo(); std::vector> results{ + {TypedValue("name"), TypedValue(db->id())}, {TypedValue("vertex_count"), TypedValue(static_cast(info.vertex_count))}, {TypedValue("edge_count"), TypedValue(static_cast(info.edge_count))}, {TypedValue("average_degree"), TypedValue(info.average_degree)}, @@ -3118,6 +3239,250 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } +PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explicit_transaction, bool in_explicit_db, + InterpreterContext *interpreter_context, const std::string &session_uuid) { +#ifdef MG_ENTERPRISE + if (!license::global_license_checker.IsEnterpriseValidFast()) { + throw QueryException("Trying to use enterprise feature without a valid license."); + } + // TODO: Remove once replicas support multi-tenant replication + if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) { + throw QueryException("Query forbidden on the replica!"); + } + if (in_explicit_transaction) { + throw MultiDatabaseQueryInMulticommandTxException(); + } + + auto *query = utils::Downcast(parsed_query.query); + auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); + + switch (query->action_) { + case MultiDatabaseQuery::Action::CREATE: + return PreparedQuery{ + {"STATUS"}, + std::move(parsed_query.required_privileges), + [db_name = query->db_name_, session_uuid, &sc_handler]( + AnyStream *stream, std::optional n) -> std::optional { + std::vector> status; + std::string res; + + const auto success = sc_handler.New(db_name); + if (success.HasError()) { + switch (success.GetError()) { + case dbms::NewError::EXISTS: + res = db_name + " already exists."; + break; + case dbms::NewError::DEFUNCT: + throw QueryRuntimeException( + "{} is defunct and in an unknown state. Try to delete it again or clean up storage and restart " + "Memgraph.", + db_name); + case dbms::NewError::GENERIC: + throw QueryRuntimeException("Failed while creating {}", db_name); + case dbms::NewError::NO_CONFIGS: + throw QueryRuntimeException("No configuration found while trying to create {}", db_name); + } + } else { + res = "Successfully created database " + db_name; + } + status.emplace_back(std::vector{TypedValue(res)}); + auto pull_plan = std::make_shared(std::move(status)); + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::W, + "" // No target DB possible + }; + + case MultiDatabaseQuery::Action::USE: + if (in_explicit_db) { + throw QueryException("Database switching is prohibited if session explicitly defines the used database"); + } + return PreparedQuery{{"STATUS"}, + std::move(parsed_query.required_privileges), + [db_name = query->db_name_, session_uuid, &sc_handler]( + AnyStream *stream, std::optional n) -> std::optional { + std::vector> status; + std::string res; + + memgraph::dbms::SetForResult set = memgraph::dbms::SetForResult::SUCCESS; + + try { + set = sc_handler.SetFor(session_uuid, db_name); + } catch (const utils::BasicException &e) { + throw QueryRuntimeException(e.what()); + } + + switch (set) { + case dbms::SetForResult::SUCCESS: + res = "Using " + db_name; + break; + case dbms::SetForResult::ALREADY_SET: + res = "Already using " + db_name; + break; + case dbms::SetForResult::FAIL: + throw QueryRuntimeException("Failed to start using {}", db_name); + } + + status.emplace_back(std::vector{TypedValue(res)}); + auto pull_plan = std::make_shared(std::move(status)); + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE, + query->db_name_}; + + case MultiDatabaseQuery::Action::DROP: + return PreparedQuery{ + {"STATUS"}, + std::move(parsed_query.required_privileges), + [db_name = query->db_name_, session_uuid, &sc_handler]( + AnyStream *stream, std::optional n) -> std::optional { + std::vector> status; + + memgraph::dbms::DeleteResult success{}; + + try { + success = sc_handler.Delete(db_name); + } catch (const utils::BasicException &e) { + throw QueryRuntimeException(e.what()); + } + + if (success.HasError()) { + switch (success.GetError()) { + case dbms::DeleteError::DEFAULT_DB: + throw QueryRuntimeException("Cannot delete the default database."); + case dbms::DeleteError::NON_EXISTENT: + throw QueryRuntimeException("{} does not exist.", db_name); + case dbms::DeleteError::USING: + throw QueryRuntimeException("Cannot delete {}, it is currently being used.", db_name); + case dbms::DeleteError::FAIL: + throw QueryRuntimeException("Failed while deleting {}", db_name); + case dbms::DeleteError::DISK_FAIL: + throw QueryRuntimeException("Failed to clean storage of {}", db_name); + } + } + + status.emplace_back(std::vector{TypedValue("Successfully deleted " + db_name)}); + auto pull_plan = std::make_shared(std::move(status)); + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::W, + query->db_name_}; + } +#else + throw QueryException("Query not supported."); +#endif +} + +PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context, + const std::string &session_uuid, std::map *summary, + DbAccessor *dba, utils::MemoryResource *execution_memory, + const std::optional &username, + std::atomic *transaction_status, + std::shared_ptr tx_timer) { +#ifdef MG_ENTERPRISE + if (!license::global_license_checker.IsEnterpriseValidFast()) { + throw QueryException("Trying to use enterprise feature without a valid license."); + } + // TODO: Remove once replicas support multi-tenant replication + if (GetReplicaRole(interpreter_context->db.get()) == storage::replication::ReplicationRole::REPLICA) { + throw QueryException("SHOW DATABASES forbidden on the replica!"); + } + + auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); + AuthQueryHandler *auth = interpreter_context->auth; + + Callback callback; + callback.header = {"Name", "Current"}; + callback.fn = [auth, session_uuid, &sc_handler, username]() mutable -> std::vector> { + std::vector> status; + const auto in_use = sc_handler.Current(session_uuid); + bool found_current = false; + + auto gen_status = [&](T all, K denied) { + Sort(all); + Sort(denied); + + status.reserve(all.size()); + for (const auto &name : all) { + TypedValue use(""); + if (!found_current && Same(name, in_use)) { + use = TypedValue("*"); + found_current = true; + } + status.push_back({TypedValue(name), std::move(use)}); + } + + // No denied databases (no need to filter them out) + if (denied.empty()) return; + + auto denied_itr = denied.begin(); + auto iter = std::remove_if(status.begin(), status.end(), [&denied_itr, &denied](auto &in) -> bool { + while (denied_itr != denied.end() && denied_itr->ValueString() < in[0].ValueString()) ++denied_itr; + return (denied_itr != denied.end() && denied_itr->ValueString() == in[0].ValueString()); + }); + status.erase(iter, status.end()); + }; + + if (!username) { + // No user, return all + gen_status(sc_handler.All(), std::vector{}); + } else { + // User has a subset of accessible dbs; this is synched with the SessionContextHandler + const auto &db_priv = auth->GetDatabasePrivileges(*username); + const auto &allowed = db_priv[0][0]; + const auto &denied = db_priv[0][1].ValueList(); + if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) { + // All databases are allowed + gen_status(sc_handler.All(), denied); + } else { + gen_status(allowed.ValueList(), denied); + } + } + + if (!found_current) throw QueryRuntimeException("Missing current database!"); + return status; + }; + + SymbolTable symbol_table; + std::vector output_symbols; + for (const auto &column : callback.header) { + output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false")); + } + + auto plan = std::make_shared(std::make_unique( + std::make_unique(output_symbols, + [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), + 0.0, AstStorage{}, symbol_table)); + + auto pull_plan = std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, + execution_memory, username, transaction_status, std::move(tx_timer)); + + return PreparedQuery{ + callback.header, std::move(parsed_query.required_privileges), + [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), + summary](AnyStream *stream, std::optional n) -> std::optional { + if (pull_plan->Pull(stream, n, output_symbols, summary)) { + return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE, + "" // No target DB + }; +#else + throw QueryException("Query not supported."); +#endif +} + std::optional Interpreter::GetTransactionId() const { if (db_accessor_) { return db_accessor_->GetTransactionId(); @@ -3146,7 +3511,8 @@ void Interpreter::RollbackTransaction() { Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::map ¶ms, - const std::string *username, QueryExtras const &extras) { + const std::string *username, QueryExtras const &extras, + const std::string &session_uuid) { std::shared_ptr current_timer; if (!in_explicit_transaction_) { query_executions_.clear(); @@ -3177,7 +3543,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, in_explicit_transaction_ ? static_cast(query_executions_.size() - 1) : std::optional{}; query_execution->prepared_query.emplace(PrepareTransactionQuery(trimmed_query, extras)); - return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; + return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, {}}; } // Don't save BEGIN, COMMIT or ROLLBACK @@ -3329,6 +3695,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_, interpreter_context_, &*execution_db_accessor_); + } else if (utils::Downcast(parsed_query.query)) { + prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_, + interpreter_context_, session_uuid); + } else if (utils::Downcast(parsed_query.query)) { + prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid, + &query_execution->summary, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception, username_, + &transaction_status_, std::move(current_timer)); } else { LOG_FATAL("Should not get here -- unknown query type!"); } @@ -3346,7 +3720,14 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, throw QueryException("Write query forbidden on the replica!"); } - return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; + // Set the target db to the current db (some queries have different target from the current db) + if (!query_execution->prepared_query->db) { + query_execution->prepared_query->db = interpreter_context_->db->id(); + } + query_execution->summary["db"] = *query_execution->prepared_query->db; + + return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, + query_execution->prepared_query->db}; } catch (const utils::BasicException &) { memgraph::metrics::IncrementCounter(memgraph::metrics::FailedQuery); AbortCommand(query_execution_ptr); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 90724b979..78a19a32c 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -77,6 +77,24 @@ class AuthQueryHandler { /// @throw QueryRuntimeException if an error ocurred. virtual void SetPassword(const std::string &username, const std::optional &password) = 0; +#ifdef MG_ENTERPRISE + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0; + + /// Return true if access granted successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0; + + /// Returns database access rights for the user + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector> GetDatabasePrivileges(const std::string &username) = 0; + + /// Return true if main database set successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual bool SetMainDatabase(const std::string &db, const std::string &username) = 0; +#endif + /// Return false if the role already exists. /// @throw QueryRuntimeException if an error ocurred. virtual bool CreateRole(const std::string &rolename) = 0; @@ -202,6 +220,7 @@ struct PreparedQuery { std::vector privileges; std::function(AnyStream *stream, std::optional n)> query_handler; plan::ReadWriteTypeChecker::RWType rw_type; + std::optional db{}; }; /** @@ -223,10 +242,12 @@ class Interpreter; /// TODO: andi decouple in a separate file why here? struct InterpreterContext { explicit InterpreterContext(storage::Config storage_config, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory); + const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr, + query::AuthChecker *ac = nullptr); InterpreterContext(std::unique_ptr db, InterpreterConfig interpreter_config, - const std::filesystem::path &data_directory); + const std::filesystem::path &data_directory, query::AuthQueryHandler *ah = nullptr, + query::AuthChecker *ac = nullptr); std::unique_ptr db; @@ -240,8 +261,8 @@ struct InterpreterContext { std::optional tsc_frequency{utils::GetTSCFrequency()}; std::atomic is_shutting_down{false}; - AuthQueryHandler *auth{nullptr}; - AuthChecker *auth_checker{nullptr}; + AuthQueryHandler *auth; + AuthChecker *auth_checker; utils::SkipList ast_cache; utils::SkipList plan_cache; @@ -272,10 +293,12 @@ class Interpreter final { std::vector headers; std::vector privileges; std::optional qid; + std::optional db; }; std::optional username_; bool in_explicit_transaction_{false}; + bool in_explicit_db_{false}; bool expect_rollback_{false}; std::shared_ptr explicit_transaction_timer_{}; std::optional> metadata_{}; //!< User defined transaction metadata @@ -289,7 +312,8 @@ class Interpreter final { * @throw query::QueryException */ PrepareResult Prepare(const std::string &query, const std::map ¶ms, - const std::string *username, QueryExtras const &extras = {}); + const std::string *username, QueryExtras const &extras = {}, + const std::string &session_uuid = {}); /** * Execute the last prepared query and stream *all* of the results into the diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index 8e4620d6a..5577b6b49 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -520,7 +520,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); auto prepare_result = interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr); - if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) { + if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) { throw StreamsException{ "Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the " "query!", diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index d682de677..91aceb079 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -187,7 +187,7 @@ std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, trigger_plan_ = std::make_shared(std::move(logical_plan), std::move(identifiers)); } - if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) { + if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) { throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_); } return trigger_plan_; diff --git a/src/storage/v2/config.hpp b/src/storage/v2/config.hpp index 04ce5882d..8a07c144d 100644 --- a/src/storage/v2/config.hpp +++ b/src/storage/v2/config.hpp @@ -16,9 +16,15 @@ #include #include "storage/v2/isolation_level.hpp" #include "storage/v2/transaction.hpp" +#include "utils/exceptions.hpp" namespace memgraph::storage { +/// Exception used to signal configuration errors. +class StorageConfigException : public utils::BasicException { + using utils::BasicException::BasicException; +}; + /// Pass this class to the \ref Storage constructor to change the behavior of /// the storage. This class also defines the default behavior. struct Config { @@ -62,15 +68,49 @@ struct Config { } transaction; struct DiskConfig { - std::filesystem::path main_storage_directory{"rocksdb_main_storage"}; - std::filesystem::path label_index_directory{"rocksdb_label_index"}; - std::filesystem::path label_property_index_directory{"rocksdb_label_property_index"}; - std::filesystem::path unique_constraints_directory{"rocksdb_unique_constraints"}; - std::filesystem::path name_id_mapper_directory{"rocksdb_name_id_mapper"}; - std::filesystem::path id_name_mapper_directory{"rocksdb_id_name_mapper"}; - std::filesystem::path durability_directory{"rocksdb_durability"}; - std::filesystem::path wal_directory{"rocksdb_wal"}; + std::filesystem::path main_storage_directory{"storage/rocksdb_main_storage"}; + std::filesystem::path label_index_directory{"storage/rocksdb_label_index"}; + std::filesystem::path label_property_index_directory{"storage/rocksdb_label_property_index"}; + std::filesystem::path unique_constraints_directory{"storage/rocksdb_unique_constraints"}; + std::filesystem::path name_id_mapper_directory{"storage/rocksdb_name_id_mapper"}; + std::filesystem::path id_name_mapper_directory{"storage/rocksdb_id_name_mapper"}; + std::filesystem::path durability_directory{"storage/rocksdb_durability"}; + std::filesystem::path wal_directory{"storage/rocksdb_wal"}; } disk; + + std::string name; }; +static inline void UpdatePaths(Config &config, const std::filesystem::path &storage_dir) { + auto contained = [](const auto &path, const auto &base) -> std::optional { + auto rel = std::filesystem::relative(path, base); + if (!rel.empty() && rel.native()[0] != '.') { // Contained + return rel; + } + return {}; + }; + + const auto old_base = + std::filesystem::weakly_canonical(std::filesystem::absolute(config.durability.storage_directory)); + config.durability.storage_directory = std::filesystem::weakly_canonical(std::filesystem::absolute(storage_dir)); + + auto UPDATE_PATH = [&](auto to_update) { + const auto old_path = std::filesystem::weakly_canonical(std::filesystem::absolute(to_update(config.disk))); + const auto contained_path = contained(old_path, old_base); + if (!contained_path) { + throw StorageConfigException("On-disk directories not contained in root."); + } + to_update(config.disk) = config.durability.storage_directory / *contained_path; + }; + + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::main_storage_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_index_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::label_property_index_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::unique_constraints_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::name_id_mapper_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::id_name_mapper_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::durability_directory)); + UPDATE_PATH(std::mem_fn(&Config::DiskConfig::wal_directory)); +} + } // namespace memgraph::storage diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index 8b6555e8a..e3831eb05 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -50,7 +50,8 @@ Storage::Storage(Config config, StorageMode storage_mode) isolation_level_(config.transaction.isolation_level), storage_mode_(storage_mode), indices_(&constraints_, config, storage_mode), - constraints_(config, storage_mode) {} + constraints_(config, storage_mode), + id_(config.name) {} Storage::Accessor::Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode) : storage_(storage), diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 34985ccad..0ff286d15 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -74,6 +74,8 @@ class Storage { virtual ~Storage() {} + const std::string &id() const { return id_; } + class Accessor { public: Accessor(Storage *storage, IsolationLevel isolation_level, StorageMode storage_mode); @@ -179,6 +181,8 @@ class Storage { StorageMode GetCreationStorageMode() const; + const std::string &id() const { return storage_->id(); } + protected: Storage *storage_; std::shared_lock storage_guard_; @@ -301,6 +305,7 @@ class Storage { std::atomic vertex_id_{0}; std::atomic edge_id_{0}; + const std::string id_; //!< High-level assigned ID }; } // namespace memgraph::storage diff --git a/src/utils/stat.hpp b/src/utils/stat.hpp index a9b9bbee5..92402cd89 100644 --- a/src/utils/stat.hpp +++ b/src/utils/stat.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -27,6 +27,7 @@ inline uint64_t GetDirDiskUsage(const std::filesystem::path &path) { uint64_t size = 0; for (auto &p : std::filesystem::directory_iterator(path)) { + if (std::filesystem::is_symlink(p)) continue; if (std::filesystem::is_directory(p)) { size += GetDirDiskUsage(p); } else if (std::filesystem::is_regular_file(p)) { diff --git a/src/utils/sync_ptr.hpp b/src/utils/sync_ptr.hpp new file mode 100644 index 000000000..33ccdc26c --- /dev/null +++ b/src/utils/sync_ptr.hpp @@ -0,0 +1,189 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include "utils/exceptions.hpp" + +namespace memgraph::utils { + +/** + * @brief + * + * @tparam TContext + * @tparam TConfig + */ +template +struct SyncPtr { + /** + * @brief Construct a new synched pointer. + * + * @tparam TArgs variable templates used by the TContext constructor + * @param config Additional metadata associated with context + * @param args Arguments to pass to TContext constructor + */ + template + explicit SyncPtr(TConfig config, TArgs &&...args) + : timeout_{1000}, config_{config}, ptr_{new TContext(std::forward(args)...), [this](TContext *ptr) { + this->OnDelete(ptr); + }} {} + + ~SyncPtr() = default; + + SyncPtr(const SyncPtr &) = delete; + SyncPtr &operator=(const SyncPtr &) = delete; + SyncPtr(SyncPtr &&) noexcept = delete; + SyncPtr &operator=(SyncPtr &&) noexcept = delete; + + /** + * @brief Destroy the synched pointer and wait for all copies to get destroyed. + * + */ + void DestroyAndSync() { + ptr_.reset(); + SyncOnDelete(); + } + + /** + * @brief Get (copy) the underlying shared pointer. + * + * @return std::shared_ptr + */ + std::shared_ptr get() { return ptr_; } + std::shared_ptr get() const { return ptr_; } + + /** + * @brief Return saved configuration (metadata) + * + * @return TConfig + */ + TConfig config() { return config_; } + const TConfig &config() const { return config_; } + + void timeout(const std::chrono::milliseconds to) { timeout_ = to; } + std::chrono::milliseconds timeout() const { return timeout_; } + + private: + /** + * @brief Block until OnDelete gets called. + * + */ + void SyncOnDelete() { + std::unique_lock lock(in_use_mtx_); + if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) { + throw utils::BasicException("Syncronization timeout!"); + } + } + + /** + * @brief Custom destructor used to sync the shared_ptr release. + * + * @param p Pointer to the undelying object. + */ + void OnDelete(TContext *p) { + delete p; + { + std::lock_guard lock(in_use_mtx_); + in_use_ = false; + } + in_use_cv_.notify_all(); + } + + bool in_use_{true}; //!< Flag used to signal sync + mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync + mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync + std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms + TConfig config_; //!< Additional metadata associated with the context + std::shared_ptr ptr_; //!< Pointer being synced +}; + +template +class SyncPtr { + public: + /** + * @brief Construct a new synched pointer. + * + * @tparam TArgs variable templates used by the TContext constructor + * @param config Additional metadata associated with context + * @param args Arguments to pass to TContext constructor + */ + template + explicit SyncPtr(TArgs &&...args) + : timeout_{1000}, ptr_{new TContext(std::forward(args)...), [this](TContext *ptr) { + this->OnDelete(ptr); + }} {} + + ~SyncPtr() = default; + + SyncPtr(const SyncPtr &) = delete; + SyncPtr &operator=(const SyncPtr &) = delete; + SyncPtr(SyncPtr &&) noexcept = delete; + SyncPtr &operator=(SyncPtr &&) noexcept = delete; + + /** + * @brief Destroy the synched pointer and wait for all copies to get destroyed. + * + */ + void DestroyAndSync() { + ptr_.reset(); + SyncOnDelete(); + } + + /** + * @brief Get (copy) the underlying shared pointer. + * + * @return std::shared_ptr + */ + std::shared_ptr get() { return ptr_; } + std::shared_ptr get() const { return ptr_; } + + void timeout(const std::chrono::milliseconds to) { timeout_ = to; } + std::chrono::milliseconds timeout() const { return timeout_; } + + private: + /** + * @brief Block until OnDelete gets called. + * + */ + void SyncOnDelete() { + std::unique_lock lock(in_use_mtx_); + if (!in_use_cv_.wait_for(lock, timeout_, [this] { return !in_use_; })) { + throw utils::BasicException("Syncronization timeout!"); + } + } + + /** + * @brief Custom destructor used to sync the shared_ptr release. + * + * @param p Pointer to the undelying object. + */ + void OnDelete(TContext *p) { + delete p; + { + std::lock_guard lock(in_use_mtx_); + in_use_ = false; + } + in_use_cv_.notify_all(); + } + + bool in_use_{true}; //!< Flag used to signal sync + mutable std::mutex in_use_mtx_; //!< Mutex used in the cv sync + mutable std::condition_variable in_use_cv_; //!< cv used to signal a sync + std::chrono::milliseconds timeout_; //!< Synchronization timeout in ms + std::shared_ptr ptr_; //!< Pointer being synced +}; + +} // namespace memgraph::utils diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index 9bce30aff..405200f15 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -184,6 +184,8 @@ enum class TypeId : uint64_t { AST_TRANSACTION_QUEUE_QUERY, AST_EXISTS, AST_CALL_SUBQUERY, + AST_MULTI_DATABASE_QUERY, + AST_SHOW_DATABASES, // Symbol SYMBOL, }; diff --git a/tests/concurrent/network_server.cpp b/tests/concurrent/network_server.cpp index 386d06524..3d73eabf1 100644 --- a/tests/concurrent/network_server.cpp +++ b/tests/concurrent/network_server.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -30,10 +30,10 @@ TEST(Network, Server) { std::cout << endpoint << std::endl; // initialize server - TestData session_data; + TestData session_context; int N = (std::thread::hardware_concurrency() + 1) / 2; ContextT context; - ServerT server(endpoint, &session_data, &context, -1, "Test", N); + ServerT server(endpoint, &session_context, &context, -1, "Test", N); ASSERT_TRUE(server.Start()); const auto &ep = server.endpoint(); diff --git a/tests/concurrent/network_session_leak.cpp b/tests/concurrent/network_session_leak.cpp index 495cf92e0..9007ef82c 100644 --- a/tests/concurrent/network_session_leak.cpp +++ b/tests/concurrent/network_session_leak.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -32,9 +32,9 @@ TEST(Network, SessionLeak) { Endpoint endpoint(interface, 0); // initialize server - TestData session_data; + TestData session_context; ContextT context; - ServerT server(endpoint, &session_data, &context, -1, "Test", 2); + ServerT server(endpoint, &session_context, &context, -1, "Test", 2); ASSERT_TRUE(server.Start()); // start clients diff --git a/tests/e2e/analyze_graph/common.py b/tests/e2e/analyze_graph/common.py index c43d9cd79..ba5729e0a 100644 --- a/tests/e2e/analyze_graph/common.py +++ b/tests/e2e/analyze_graph/common.py @@ -25,9 +25,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {} def connect(**kwargs) -> mgclient.Connection: connection = mgclient.connect(host="localhost", port=7687, **kwargs) connection.autocommit = True - yield connection cursor = connection.cursor() + execute_and_fetch_all(cursor, "USE DATABASE memgraph") + try: + execute_and_fetch_all(cursor, "DROP DATABASE clean") + except: + pass execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + yield connection @pytest.fixture diff --git a/tests/e2e/analyze_graph/optimize_indexes.py b/tests/e2e/analyze_graph/optimize_indexes.py index be6a72c1e..6358e4ddd 100644 --- a/tests/e2e/analyze_graph/optimize_indexes.py +++ b/tests/e2e/analyze_graph/optimize_indexes.py @@ -21,6 +21,18 @@ QUERY_PLAN = "QUERY PLAN" # ------------------------------------ +@pytest.fixture(scope="function") +def multi_db(request, connect): + cursor = connect.cursor() + if request.param: + execute_and_fetch_all(cursor, "CREATE DATABASE clean") + execute_and_fetch_all(cursor, "USE DATABASE clean") + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + pass + yield connect + + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize( "delete_query", [ @@ -30,9 +42,9 @@ QUERY_PLAN = "QUERY PLAN" "ANALYZE GRAPH ON LABELS :Label, :NONEXISTING DELETE STATISTICS", ], ) -def test_analyze_graph_delete_statistics(delete_query, connect): +def test_analyze_graph_delete_statistics(delete_query, multi_db): """Tests that all variants of delete queries work as expected.""" - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -62,6 +74,7 @@ def test_analyze_graph_delete_statistics(delete_query, connect): execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);") +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize( "analyze_query", [ @@ -71,11 +84,11 @@ def test_analyze_graph_delete_statistics(delete_query, connect): "ANALYZE GRAPH ON LABELS :Label, :NONEXISTING", ], ) -def test_analyze_full_graph(analyze_query, connect): +def test_analyze_full_graph(analyze_query, multi_db): """Tests analyzing full graph and choosing better index based on the smaller average group size. It also tests querying based on labels and that nothing bad will happen by providing non-existing label. """ - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i % 5}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -121,9 +134,10 @@ def test_analyze_full_graph(analyze_query, connect): # ----------------------------- -def test_cardinality_different_avg_group_size_uniform_dist(connect): +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_cardinality_different_avg_group_size_uniform_dist(multi_db): """Tests index optimization with indices both having uniform distribution but one has smaller avg. group size.""" - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id2: i % 20}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -151,9 +165,10 @@ def test_cardinality_different_avg_group_size_uniform_dist(connect): execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);") -def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect): +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(multi_db): """Tests index choosing where both indices have uniform key distribution with same avg. group size but one has less vertices.""" - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 100) | CREATE (n:Label {id1: i}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 50) | CREATE (n:Label {id2: i}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -181,9 +196,10 @@ def test_cardinality_same_avg_group_size_uniform_dist_diff_vertex_count(connect) execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);") -def test_large_diff_in_num_vertices_v1(connect): +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_large_diff_in_num_vertices_v1(multi_db): """Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution.""" - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id1: i}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id2: 1}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -211,9 +227,10 @@ def test_large_diff_in_num_vertices_v1(connect): execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);") -def test_large_diff_in_num_vertices_v2(connect): +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_large_diff_in_num_vertices_v2(multi_db): """Tests that when one index has > 10x vertices than the other one, it should be chosen no matter avg group size and uniform distribution.""" - cursor = connect.cursor() + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 99) | CREATE (n:Label {id1: 1}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 1000) | CREATE (n:Label {id2: i}));") execute_and_fetch_all(cursor, "CREATE INDEX ON :Label(id1);") @@ -241,9 +258,10 @@ def test_large_diff_in_num_vertices_v2(connect): execute_and_fetch_all(cursor, "DROP INDEX ON :Label(id2);") -def test_same_avg_group_size_diff_distribution(connect): +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_same_avg_group_size_diff_distribution(multi_db): """Tests index choice decision based on key distribution.""" - cursor = connect.cursor() + cursor = multi_db.cursor() # Setup first key distribution execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 10) | CREATE (n:Label {id1: 1}));") execute_and_fetch_all(cursor, "FOREACH (i IN range(1, 30) | CREATE (n:Label {id1: 2}));") diff --git a/tests/e2e/configuration/default_config.py b/tests/e2e/configuration/default_config.py index a0a7153fb..dccf8a54b 100644 --- a/tests/e2e/configuration/default_config.py +++ b/tests/e2e/configuration/default_config.py @@ -66,6 +66,11 @@ startup_config_dict = { "Time in seconds after which inactive Bolt sessions will be closed.", ), "data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."), + "data_recovery_on_startup": ( + "false", + "false", + "Controls whether the database recovers persisted data on startup.", + ), "isolation_level": ( "SNAPSHOT_ISOLATION", "SNAPSHOT_ISOLATION", @@ -133,11 +138,6 @@ startup_config_dict = { "The number of edges and vertices stored in a batch in a snapshot file.", ), "storage_properties_on_edges": ("false", "true", "Controls whether edges have properties."), - "storage_recover_on_startup": ( - "false", - "false", - "Controls whether the storage recovers persisted data on startup.", - ), "storage_recovery_thread_count": ("12", "12", "The number of threads used to recover persisted data from disk."), "storage_snapshot_interval_sec": ( "0", @@ -157,6 +157,11 @@ startup_config_dict = { "Issue a 'fsync' call after this amount of transactions are written to the WAL file. Set to 1 for fully synchronous operation.", ), "storage_wal_file_size_kib": ("20480", "20480", "Minimum file size of each WAL file."), + "storage_delete_on_drop": ( + "true", + "true", + "If set to true the query 'DROP DATABASE x' will delete the underlying storage as well.", + ), "stream_transaction_conflict_retries": ( "30", "30", diff --git a/tests/e2e/configuration/storage_info.py b/tests/e2e/configuration/storage_info.py index 7644c8d3c..614d644cd 100644 --- a/tests/e2e/configuration/storage_info.py +++ b/tests/e2e/configuration/storage_info.py @@ -16,6 +16,7 @@ import mgclient import pytest default_storage_info_dict = { + "name": "memgraph", "vertex_count": 0, "edge_count": 0, "average_degree": 0, @@ -55,7 +56,7 @@ def test_does_default_config_match(): machine_dependent_configurations = ["memory_usage", "disk_usage", "memory_allocated", "allocation_limit"] # Number of different data-points returned by SHOW STORAGE INFO - assert len(config) == 11 + assert len(config) == 12 for conf in config: conf_name = conf[0] diff --git a/tests/e2e/fine_grained_access/CMakeLists.txt b/tests/e2e/fine_grained_access/CMakeLists.txt index e124076f2..6b277694f 100644 --- a/tests/e2e/fine_grained_access/CMakeLists.txt +++ b/tests/e2e/fine_grained_access/CMakeLists.txt @@ -6,3 +6,4 @@ copy_fine_grained_access_e2e_python_files(common.py) copy_fine_grained_access_e2e_python_files(create_delete_filtering_tests.py) copy_fine_grained_access_e2e_python_files(edge_type_filtering_tests.py) copy_fine_grained_access_e2e_python_files(path_filtering_tests.py) +copy_fine_grained_access_e2e_python_files(show_db.py) diff --git a/tests/e2e/fine_grained_access/common.py b/tests/e2e/fine_grained_access/common.py index 21ea2faaa..33768ebbc 100644 --- a/tests/e2e/fine_grained_access/common.py +++ b/tests/e2e/fine_grained_access/common.py @@ -1,4 +1,4 @@ -# Copyright 2021 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,6 +12,23 @@ import mgclient +def switch_db(cursor): + execute_and_fetch_all(cursor, "USE DATABASE clean;") + + +def create_multi_db(cursor, switch): + execute_and_fetch_all(cursor, "USE DATABASE memgraph;") + try: + execute_and_fetch_all(cursor, "DROP DATABASE clean;") + except: + pass + execute_and_fetch_all(cursor, "CREATE DATABASE clean;") + if switch: + switch_db(cursor) + reset_and_prepare(cursor) + execute_and_fetch_all(cursor, "USE DATABASE memgraph;") + + def reset_and_prepare(admin_cursor): execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;") execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;") diff --git a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py index 9966c3039..b087fa6c5 100644 --- a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py +++ b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,168 +16,224 @@ import pytest from mgclient import DatabaseError -def test_create_node_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") assert len(results) == 1 -def test_create_node_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;") + common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") -def test_create_node_specific_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") assert len(results) == 1 -def test_create_node_specific_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "CREATE (n:label1) RETURN n;") + common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") -def test_delete_node_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") - common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;") + if switch: + common.switch_db(user_connection.cursor()) + common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) RETURN n;") + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) RETURN n;") assert len(results) == 0 -def test_delete_node_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n") + common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") -def test_delete_node_specific_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") + if switch: + common.switch_db(admin_connection.cursor()) results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete) RETURN n;") assert len(results) == 0 -def test_delete_node_specific_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n:test_delete) DELETE n;") + common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") -def test_create_edge_all_labels_all_edge_types_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_all_labels_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) assert len(results) == 1 -def test_create_edge_all_labels_all_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_create_edge_all_labels_denied_all_edge_types_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_create_edge_all_labels_granted_all_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_create_edge_all_labels_granted_specific_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_create_edge_first_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;") common.execute_and_fetch_all( @@ -185,17 +241,21 @@ def test_create_edge_first_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_create_edge_second_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_create_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") common.execute_and_fetch_all( @@ -203,62 +263,78 @@ def test_create_edge_second_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "CREATE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_delete_edge_all_labels_denied_all_edge_types_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r", ) -def test_delete_edge_all_labels_granted_all_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r", ) -def test_delete_edge_all_labels_granted_specific_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r", ) -def test_delete_edge_first_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_2 TO user;") common.execute_and_fetch_all( @@ -266,17 +342,21 @@ def test_delete_edge_first_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r", ) -def test_delete_edge_second_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete_1 TO user;") common.execute_and_fetch_all( @@ -284,159 +364,209 @@ def test_delete_edge_second_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n:test_delete_1)-[r:edge_type_delete]->(m:test_delete_2) DELETE r", ) -def test_delete_node_with_edge_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_with_edge_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;") + common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;") -def test_delete_node_with_edge_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_delete_node_with_edge_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;", ) - common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n) DETACH DELETE n;") + if switch: + common.switch_db(user_connection.cursor()) + common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n) DETACH DELETE n;") + if switch: + common.switch_db(admin_connection.cursor()) results = common.execute_and_fetch_all(admin_connection.cursor(), "MATCH (n:test_delete_1) RETURN n;") assert len(results) == 0 -def test_merge_node_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") assert len(results) == 1 -def test_merge_node_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;") + common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") -def test_merge_node_specific_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") assert len(results) == 1 -def test_merge_node_specific_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): - common.execute_and_fetch_all(user_connnection.cursor(), "MERGE (n:label1) RETURN n;") + common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") -def test_merge_edge_all_labels_all_edge_types_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_all_labels_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) assert len(results) == 1 -def test_merge_edge_all_labels_all_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_merge_edge_all_labels_denied_all_edge_types_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_merge_edge_all_labels_granted_all_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_merge_edge_all_labels_granted_specific_edge_types_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_merge_edge_first_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label2 TO user;") common.execute_and_fetch_all( @@ -444,17 +574,21 @@ def test_merge_edge_first_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_merge_edge_second_node_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") common.execute_and_fetch_all( @@ -462,64 +596,86 @@ def test_merge_edge_second_node_label_granted(): "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MERGE (n:label1)-[r:edge_type]->(m:label2) RETURN n,r,m;", ) -def test_set_label_when_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_set_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;") + if switch: + common.switch_db(user_connection.cursor()) common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;") -def test_set_label_when_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_set_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) SET p:update_label_2;") -def test_remove_label_when_label_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_remove_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") + if switch: + common.switch_db(user_connection.cursor()) common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;") -def test_remove_label_when_label_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_remove_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + if switch: + common.switch_db(user_connection.cursor()) with pytest.raises(DatabaseError): common.execute_and_fetch_all(user_connection.cursor(), "MATCH (p:test_delete) REMOVE p:test_delete;") -def test_merge_nodes_pass_when_having_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_merge_nodes_pass_when_having_create_delete(switch): admin_connection = common.connect(username="admin", password="test") user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) + common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( user_connection.cursor(), "UNWIND [{id: '1', lat: 10, lng: 10}, {id: '2', lat: 10, lng: 10}, {id: '3', lat: 10, lng: 10}] AS row MERGE (o:Location {id: row.id}) RETURN o;", diff --git a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py index 9a190af8b..f2071c54b 100644 --- a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py +++ b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py @@ -1,91 +1,116 @@ -import common import sys + +import common import pytest -def test_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 3 -def test_deny_all_edge_types_and_all_labels(): +@pytest.mark.parametrize("switch", [False, True]) +def test_deny_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 0 -def test_revoke_all_edge_types_and_all_labels(): +@pytest.mark.parametrize("switch", [False, True]) +def test_revoke_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 0 -def test_deny_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_deny_edge_type(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 2 -def test_denied_node_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_denied_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 2 -def test_denied_one_of_node_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_denied_one_of_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 1 -def test_revoke_all_labels(): +@pytest.mark.parametrize("switch", [False, True]) +def test_revoke_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 0 -def test_revoke_all_edge_types(): +@pytest.mark.parametrize("switch", [False, True]) +def test_revoke_all_edge_types(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") - results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") + if switch: + common.switch_db(user_connection.cursor()) + results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") assert len(results) == 0 diff --git a/tests/e2e/fine_grained_access/path_filtering_tests.py b/tests/e2e/fine_grained_access/path_filtering_tests.py index f69b9188c..c5b873972 100644 --- a/tests/e2e/fine_grained_access/path_filtering_tests.py +++ b/tests/e2e/fine_grained_access/path_filtering_tests.py @@ -1,22 +1,26 @@ -import common import sys + +import common import pytest -def test_weighted_shortest_path_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);", ) @@ -47,24 +51,28 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted(): assert all(node.id in expected_path for node in path_result[0][1]) -def test_weighted_shortest_path_all_edge_types_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;" + user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN p;" ) assert len(results) == 0 -def test_weighted_shortest_path_denied_start(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -73,17 +81,20 @@ def test_weighted_shortest_path_denied_start(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_length_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;", ) assert len(path_length_result) == 0 -def test_weighted_shortest_path_denied_destination(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -92,17 +103,20 @@ def test_weighted_shortest_path_denied_destination(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_length_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;", ) assert len(path_length_result) == 0 -def test_weighted_shortest_path_denied_label_1(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -111,13 +125,15 @@ def test_weighted_shortest_path_denied_label_1(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);", ) @@ -143,9 +159,10 @@ def test_weighted_shortest_path_denied_label_1(): assert all(node.id in expected_path for node in path_result[0][1]) -def test_weighted_shortest_path_denied_edge_type_3(): +@pytest.mark.parametrize("switch", [False, True]) +def test_weighted_shortest_path_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -154,13 +171,15 @@ def test_weighted_shortest_path_denied_edge_type_3(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *wShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);", ) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *wShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) @@ -191,16 +210,19 @@ def test_weighted_shortest_path_denied_edge_type_3(): assert all(node.id in expected_path for node in path_result[0][1]) -def test_dfs_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_paths = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH path=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(path) | node.id);", ) @@ -210,22 +232,26 @@ def test_dfs_all_edge_types_all_labels_granted(): assert all(path[0] in expected_paths for path in source_destination_paths) -def test_dfs_all_edge_types_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") - total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;") + if switch: + common.switch_db(user_connection.cursor()) + total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[*]->(m) RETURN p;") assert len(total_paths_results) == 0 -def test_dfs_denied_start(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -234,16 +260,19 @@ def test_dfs_denied_start(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_dfs_denied_destination(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -252,16 +281,19 @@ def test_dfs_denied_destination(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH p=(n:label0)-[*]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_dfs_denied_label_1(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -269,8 +301,11 @@ def test_dfs_denied_label_1(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + + if switch: + common.switch_db(user_connection.cursor()) source_destination_paths = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[* 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -280,9 +315,10 @@ def test_dfs_denied_label_1(): assert all(path[0] in expected_paths for path in source_destination_paths) -def test_dfs_denied_edge_type_3(): +@pytest.mark.parametrize("switch", [False, True]) +def test_dfs_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") @@ -292,8 +328,10 @@ def test_dfs_denied_edge_type_3(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r * 1..3]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -303,16 +341,19 @@ def test_dfs_denied_edge_type_3(): assert source_destination_path[0][0] == expected_path -def test_bfs_sts_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -322,24 +363,28 @@ def test_bfs_sts_all_edge_types_all_labels_granted(): assert source_destination_path[0][0] == expected_path -def test_bfs_sts_all_edge_types_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;" + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n)-[r *BFS]->(m) RETURN p;" ) assert len(total_paths_results) == 0 -def test_bfs_sts_denied_start(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -348,16 +393,19 @@ def test_bfs_sts_denied_start(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_bfs_sts_denied_destination(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -366,16 +414,19 @@ def test_bfs_sts_denied_destination(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_bfs_sts_denied_label_1(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -383,8 +434,11 @@ def test_bfs_sts_denied_label_1(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) expected_path = [0, 2, 4, 5] @@ -393,9 +447,10 @@ def test_bfs_sts_denied_label_1(): assert source_destination_path[0][0] == expected_path -def test_bfs_sts_denied_edge_type_3(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_sts_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -404,8 +459,10 @@ def test_bfs_sts_denied_edge_type_3(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH (n), (m) WITH n, m MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) expected_path = [0, 2, 4, 5] @@ -414,16 +471,19 @@ def test_bfs_sts_denied_edge_type_3(): assert source_destination_path[0][0] == expected_path -def test_bfs_single_source_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -433,22 +493,26 @@ def test_bfs_single_source_all_edge_types_all_labels_granted(): assert source_destination_path[0][0] == expected_path -def test_bfs_single_source_all_edge_types_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") - total_paths_results = common.execute_and_fetch_all(user_connnection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;") + if switch: + common.switch_db(user_connection.cursor()) + total_paths_results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH p=(n)-[r *BFS]->(m) RETURN p;") assert len(total_paths_results) == 0 -def test_bfs_single_source_denied_start(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -457,16 +521,19 @@ def test_bfs_single_source_denied_start(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_bfs_single_source_denied_destination(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -475,16 +542,19 @@ def test_bfs_single_source_denied_destination(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" + user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN p;" ) assert len(source_destination_path) == 0 -def test_bfs_single_source_denied_label_1(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -492,8 +562,11 @@ def test_bfs_single_source_denied_label_1(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -503,9 +576,10 @@ def test_bfs_single_source_denied_label_1(): assert source_destination_path[0][0] == expected_path -def test_bfs_single_source_denied_edge_type_3(): +@pytest.mark.parametrize("switch", [False, True]) +def test_bfs_single_source_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -514,8 +588,10 @@ def test_bfs_single_source_denied_edge_type_3(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + if switch: + common.switch_db(user_connection.cursor()) source_destination_path = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *BFS]->(m:label4) RETURN extract( node in nodes(p) | node.id);", ) @@ -525,20 +601,23 @@ def test_bfs_single_source_denied_edge_type_3(): assert source_destination_path[0][0] == expected_path -def test_all_shortest_paths_when_all_edge_types_all_labels_granted(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length,nodes(p);", ) @@ -569,24 +648,28 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted(): assert all(node.id in expected_path for node in path_result[0][1]) -def test_all_shortest_paths_when_all_edge_types_all_labels_denied(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( - user_connnection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;" + user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN p;" ) assert len(results) == 0 -def test_all_shortest_paths_when_denied_start(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -595,17 +678,20 @@ def test_all_shortest_paths_when_denied_start(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_length_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;", ) assert len(path_length_result) == 0 -def test_all_shortest_paths_when_denied_destination(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -614,17 +700,20 @@ def test_all_shortest_paths_when_denied_destination(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_length_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length;", ) assert len(path_length_result) == 0 -def test_all_shortest_paths_when_denied_label_1(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -633,13 +722,15 @@ def test_all_shortest_paths_when_denied_label_1(): common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + if switch: + common.switch_db(user_connection.cursor()) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);", ) @@ -665,9 +756,10 @@ def test_all_shortest_paths_when_denied_label_1(): assert all(node.id in expected_path for node in path_result[0][1]) -def test_all_shortest_paths_when_denied_edge_type_3(): +@pytest.mark.parametrize("switch", [False, True]) +def test_all_shortest_paths_when_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connnection = common.connect(username="user", password="test") + user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -676,13 +768,15 @@ def test_all_shortest_paths_when_denied_edge_type_3(): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + if switch: + common.switch_db(user_connection.cursor()) path_result = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n:label0)-[r *allShortest (r, n | r.weight) path_length]->(m:label4) RETURN path_length, nodes(p);", ) total_paths_results = common.execute_and_fetch_all( - user_connnection.cursor(), + user_connection.cursor(), "MATCH p=(n)-[r *allShortest (r, n | r.weight)]->(m) RETURN extract( node in nodes(p) | node.id);", ) diff --git a/tests/e2e/fine_grained_access/show_db.py b/tests/e2e/fine_grained_access/show_db.py new file mode 100644 index 000000000..546d4e24a --- /dev/null +++ b/tests/e2e/fine_grained_access/show_db.py @@ -0,0 +1,36 @@ +# Copyright 2023 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import sys + +import common +import pytest +from mgclient import DatabaseError + + +def test_show_databases_w_user(): + admin_connection = common.connect(username="admin", password="test") + user_connection = common.connect(username="user", password="test") + user2_connection = common.connect(username="user2", password="test") + user3_connection = common.connect(username="user3", password="test") + + assert common.execute_and_fetch_all(admin_connection.cursor(), "SHOW DATABASES") == [ + ("db1", ""), + ("db2", ""), + ("memgraph", "*"), + ] + assert common.execute_and_fetch_all(user_connection.cursor(), "SHOW DATABASES") == [("db1", ""), ("memgraph", "*")] + assert common.execute_and_fetch_all(user2_connection.cursor(), "SHOW DATABASES") == [("db2", "*")] + assert common.execute_and_fetch_all(user3_connection.cursor(), "SHOW DATABASES") == [("db1", "*"), ("db2", "")] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/fine_grained_access/workloads.yaml b/tests/e2e/fine_grained_access/workloads.yaml index c7b14b3d4..aed6eb8bb 100644 --- a/tests/e2e/fine_grained_access/workloads.yaml +++ b/tests/e2e/fine_grained_access/workloads.yaml @@ -9,7 +9,9 @@ create_delete_filtering_cluster: &create_delete_filtering_cluster "CREATE USER admin IDENTIFIED BY 'test';", "CREATE USER user IDENTIFIED BY 'test';", "GRANT ALL PRIVILEGES TO admin;", + "GRANT DATABASE * TO admin;", "GRANT ALL PRIVILEGES TO user;", + "GRANT DATABASE * TO user;", ] edge_type_filtering_cluster: &edge_type_filtering_cluster @@ -22,7 +24,9 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster "CREATE USER admin IDENTIFIED BY 'test';", "CREATE USER user IDENTIFIED BY 'test';", "GRANT ALL PRIVILEGES TO admin;", + "GRANT DATABASE * TO admin;", "GRANT ALL PRIVILEGES TO user;", + "GRANT DATABASE * TO user;", "GRANT CREATE_DELETE ON LABELS * TO admin;", "GRANT CREATE_DELETE ON EDGE_TYPES * TO admin;", "MERGE (l1:label1 {name: 'test1'});", @@ -32,6 +36,17 @@ edge_type_filtering_cluster: &edge_type_filtering_cluster "MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);", "MERGE (mix:label3:label1 {name: 'test4'});", "MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);", + "CREATE DATABASE clean;", + "USE DATABASE clean", + "MATCH (n) DETACH DELETE n;", + "MERGE (l1:label1 {name: 'test1'});", + "MERGE (l2:label2 {name: 'test2'});", + "MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2);", + "MERGE (l3:label3 {name: 'test3'});", + "MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3);", + "MERGE (mix:label3:label1 {name: 'test4'});", + "MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix);", + "USE DATABASE memgraph", ] validation_queries: [] @@ -45,7 +60,9 @@ path_filtering_cluster: &path_filtering_cluster "CREATE USER admin IDENTIFIED BY 'test';", "CREATE USER user IDENTIFIED BY 'test';", "GRANT ALL PRIVILEGES TO admin;", + "GRANT DATABASE * TO admin;", "GRANT ALL PRIVILEGES TO user;", + "GRANT DATABASE * TO user;", "MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);", "MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);", "MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);", @@ -56,6 +73,47 @@ path_filtering_cluster: &path_filtering_cluster "MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);", "MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);", "MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);", + "CREATE DATABASE clean;", + "USE DATABASE clean", + "MATCH (n) DETACH DELETE n;", + "MERGE (a:label0 {id: 0}) MERGE (b:label1 {id: 1}) CREATE (a)-[:edge_type_1 {weight: 6}]->(b);", + "MERGE (a:label0 {id: 0}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_1 {weight: 14}]->(b);", + "MERGE (a:label1 {id: 1}) MERGE (b:label2 {id: 2}) CREATE (a)-[:edge_type_2 {weight: 1}]->(b);", + "MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_2 {weight: 10}]->(b);", + "MERGE (a:label1 {id: 1}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 5}]->(b);", + "MERGE (a:label2 {id: 2}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_3 {weight: 7}]->(b);", + "MERGE (a:label3 {id: 3}) MERGE (b:label3 {id: 4}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);", + "MERGE (a:label3 {id: 4}) MERGE (b:label3 {id: 3}) CREATE (a)-[:edge_type_4 {weight: 1}]->(b);", + "MERGE (a:label3 {id: 3}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 14}]->(b);", + "MERGE (a:label3 {id: 4}) MERGE (b:label4 {id: 5}) CREATE (a)-[:edge_type_4 {weight: 8}]->(b);", + "USE DATABASE memgraph", + ] + +show_databases_w_user: &show_databases_w_user + cluster: + main: + args: ["--bolt-port", "7687", "--log-level=TRACE"] + log_file: "fine_grained_access.log" + setup_queries: + [ + "CREATE USER admin IDENTIFIED BY 'test';", + "CREATE USER user IDENTIFIED BY 'test';", + "CREATE USER user2 IDENTIFIED BY 'test';", + "CREATE USER user3 IDENTIFIED BY 'test';", + "CREATE DATABASE db1;", + "CREATE DATABASE db2;", + "GRANT ALL PRIVILEGES TO admin;", + "GRANT DATABASE * TO admin;", + "GRANT ALL PRIVILEGES TO user;", + "GRANT DATABASE db1 TO user;", + "GRANT ALL PRIVILEGES TO user2;", + "GRANT DATABASE db2 TO user2;", + "REVOKE DATABASE memgraph FROM user2;", + "SET MAIN DATABASE db2 FOR user2", + "GRANT ALL PRIVILEGES TO user3;", + "GRANT DATABASE * TO user3;", + "REVOKE DATABASE memgraph FROM user3;", + "SET MAIN DATABASE db1 FOR user3", ] workloads: @@ -63,7 +121,6 @@ workloads: binary: "tests/e2e/pytest_runner.sh" args: ["fine_grained_access/create_delete_filtering_tests.py"] <<: *create_delete_filtering_cluster - - name: "EdgeType filtering" binary: "tests/e2e/pytest_runner.sh" args: ["fine_grained_access/edge_type_filtering_tests.py"] @@ -72,3 +129,7 @@ workloads: binary: "tests/e2e/pytest_runner.sh" args: ["fine_grained_access/path_filtering_tests.py"] <<: *path_filtering_cluster + - name: "Show databases with users" + binary: "tests/e2e/pytest_runner.sh" + args: ["fine_grained_access/show_db.py"] + <<: *show_databases_w_user diff --git a/tests/e2e/graphql/graphql_library_config/crud.js b/tests/e2e/graphql/graphql_library_config/crud.js index 65c65c79f..2e95040dc 100644 --- a/tests/e2e/graphql/graphql_library_config/crud.js +++ b/tests/e2e/graphql/graphql_library_config/crud.js @@ -21,7 +21,14 @@ const driver = neo4j.driver( neo4j.auth.basic("", "") ); -const neoSchema = new Neo4jGraphQL({ typeDefs, driver }); +const neoSchema = new Neo4jGraphQL({ + typeDefs, driver, + config: { + driverConfig: { + database: "memgraph", + }, + } +}); neoSchema.getSchema().then((schema) => { const server = new ApolloServer({ diff --git a/tests/e2e/isolation_levels/isolation_levels.cpp b/tests/e2e/isolation_levels/isolation_levels.cpp index 4cec5b13d..15151b106 100644 --- a/tests/e2e/isolation_levels/isolation_levels.cpp +++ b/tests/e2e/isolation_levels/isolation_levels.cpp @@ -9,12 +9,16 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include #include #include +#include "query/exceptions.hpp" #include "utils/logging.hpp" #include "utils/timer.hpp" +#include + DEFINE_uint64(bolt_port, 7687, "Bolt port"); DEFINE_uint64(timeout, 120, "Timeout seconds"); @@ -53,16 +57,55 @@ bool IsDiskStorageMode(std::unique_ptr &client) { return false; } -void CleanDatabase() { - auto client = GetClient(); +void CleanDatabase(std::unique_ptr &client) { MG_ASSERT(client->Execute("MATCH (n) DETACH DELETE n;")); client->DiscardAll(); } +void SetupCleanDB() { + auto client = GetClient(); + MG_ASSERT(client->Execute("USE DATABASE memgraph;")); + client->DiscardAll(); + try { + client->Execute("DROP DATABASE clean;"); + client->DiscardAll(); + } catch (const mg::ClientException &) { + // In case clean doesn't exist + } + MG_ASSERT(client->Execute("CREATE DATABASE clean;")); + client->DiscardAll(); + MG_ASSERT(client->Execute("USE DATABASE clean;")); + client->DiscardAll(); + CleanDatabase(client); +} + +void SwitchToDB(const std::string &name, std::unique_ptr &client) { + MG_ASSERT(client->Execute(fmt::format("USE DATABASE {};", name))); + client->DiscardAll(); +} + +void SwitchToCleanDB(std::unique_ptr &client) { SwitchToDB("clean", client); } + +void SwitchToSameDB(std::unique_ptr &main, std::unique_ptr &client) { + MG_ASSERT(main->Execute("SHOW DATABASES;")); + auto dbs = main->FetchAll(); + MG_ASSERT(dbs, "Failed to show databases"); + for (const auto &elem : *dbs) { + MG_ASSERT(elem.size(), "Show databases wrong output"); + const auto &active = elem[1].ValueString(); + if (active == "*") { + const auto &name = elem[0].ValueString(); + SwitchToDB(std::string(name), client); + break; + } + } +} + void TestSnapshotIsolation(std::unique_ptr &client) { spdlog::info("Verifying SNAPSHOT ISOLATION"); auto creator = GetClient(); + SwitchToSameDB(client, creator); MG_ASSERT(client->BeginTransaction()); MG_ASSERT(creator->BeginTransaction()); @@ -89,13 +132,14 @@ void TestSnapshotIsolation(std::unique_ptr &client) { "at a later point.", current_vertex_count, 0); MG_ASSERT(client->CommitTransaction()); - CleanDatabase(); + CleanDatabase(creator); } void TestReadCommitted(std::unique_ptr &client) { spdlog::info("Verifying READ COMMITTED"); auto creator = GetClient(); + SwitchToSameDB(client, creator); MG_ASSERT(client->BeginTransaction()); MG_ASSERT(creator->BeginTransaction()); @@ -121,13 +165,14 @@ void TestReadCommitted(std::unique_ptr &client) { "from a committed transaction", current_vertex_count, vertex_count); MG_ASSERT(client->CommitTransaction()); - CleanDatabase(); + CleanDatabase(creator); } void TestReadUncommitted(std::unique_ptr &client) { spdlog::info("Verifying READ UNCOMMITTED"); auto creator = GetClient(); + SwitchToSameDB(client, creator); MG_ASSERT(client->BeginTransaction()); MG_ASSERT(creator->BeginTransaction()); @@ -152,18 +197,23 @@ void TestReadUncommitted(std::unique_ptr &client) { "from a different transaction", current_vertex_count, vertex_count); MG_ASSERT(client->CommitTransaction()); - CleanDatabase(); + CleanDatabase(creator); } inline constexpr std::array isolation_levels{std::pair{"SNAPSHOT ISOLATION", &TestSnapshotIsolation}, std::pair{"READ COMMITTED", &TestReadCommitted}, std::pair{"READ UNCOMMITTED", &TestReadUncommitted}}; -void TestGlobalIsolationLevel(bool isDiskStorage) { +void TestGlobalIsolationLevel(bool isDiskStorage, bool mdb = false) { spdlog::info("\n\n----Test global isolation levels----\n"); auto first_client = GetClient(); auto second_client = GetClient(); + if (mdb) { + SwitchToCleanDB(first_client); + SwitchToCleanDB(second_client); + } + for (const auto &[isolation_level, verification_function] : isolation_levels) { spdlog::info("--------------------------"); @@ -183,11 +233,17 @@ void TestGlobalIsolationLevel(bool isDiskStorage) { } } -void TestSessionIsolationLevel(bool isDiskStorage) { +void TestSessionIsolationLevel(bool isDiskStorage, bool mdb = false) { spdlog::info("\n\n----Test session isolation levels----\n"); auto global_client = GetClient(); auto session_client = GetClient(); + + if (mdb) { + SwitchToCleanDB(global_client); + SwitchToCleanDB(session_client); + } + for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) { if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) { spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level); @@ -218,11 +274,17 @@ void TestSessionIsolationLevel(bool isDiskStorage) { } // Priority of applying the isolation level from highest priority NEXT -> SESSION -> GLOBAL -void TestNextIsolationLevel(bool isDiskStorage) { +void TestNextIsolationLevel(bool isDiskStorage, bool mdb = false) { spdlog::info("\n\n----Test next isolation levels----\n"); auto global_client = GetClient(); auto session_client = GetClient(); + + if (mdb) { + SwitchToCleanDB(global_client); + SwitchToCleanDB(session_client); + } + for (const auto &[global_isolation_level, global_verification_function] : isolation_levels) { if (isDiskStorage && strcmp(global_isolation_level, "SNAPSHOT ISOLATION") != 0) { spdlog::info("Skipping for disk storage unsupported global isolation level {}", global_isolation_level); @@ -289,10 +351,21 @@ int main(int argc, char **argv) { auto client = GetClient(); bool isDiskStorage = IsDiskStorageMode(client); client->DiscardAll(); + bool multiDB = false; TestGlobalIsolationLevel(isDiskStorage); TestSessionIsolationLevel(isDiskStorage); TestNextIsolationLevel(isDiskStorage); + // MultiDB tests + multiDB = true; + spdlog::info("--------------------------"); + spdlog::info("---- RUNNING MULTI DB ----"); + spdlog::info("--------------------------"); + SetupCleanDB(); + TestGlobalIsolationLevel(isDiskStorage, multiDB); + TestSessionIsolationLevel(isDiskStorage, multiDB); + TestNextIsolationLevel(isDiskStorage, multiDB); + return 0; } diff --git a/tests/e2e/lba_procedures/common.py b/tests/e2e/lba_procedures/common.py index 553307b4c..b55a433cf 100644 --- a/tests/e2e/lba_procedures/common.py +++ b/tests/e2e/lba_procedures/common.py @@ -1,4 +1,4 @@ -# Copyright 2021 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,9 +9,10 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import mgclient import typing +import mgclient + def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]: cursor.execute(query, params) @@ -24,6 +25,19 @@ def connect(**kwargs) -> mgclient.Connection: return connection +def switch_db(cursor): + execute_and_fetch_all(cursor, "USE DATABASE clean;") + + +def create_multi_db(cursor): + execute_and_fetch_all(cursor, "USE DATABASE memgraph;") + try: + execute_and_fetch_all(cursor, "DROP DATABASE clean;") + except: + pass + execute_and_fetch_all(cursor, "CREATE DATABASE clean;") + + def reset_permissions(admin_cursor: mgclient.Cursor, create_index: bool = False): execute_and_fetch_all(admin_cursor, "REVOKE LABELS * FROM user;") execute_and_fetch_all(admin_cursor, "REVOKE EDGE_TYPES * FROM user;") diff --git a/tests/e2e/lba_procedures/create_delete_query_modules.py b/tests/e2e/lba_procedures/create_delete_query_modules.py index e63105bdb..175ae2390 100644 --- a/tests/e2e/lba_procedures/create_delete_query_modules.py +++ b/tests/e2e/lba_procedures/create_delete_query_modules.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,15 +9,10 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import pytest import sys -from common import ( - connect, - execute_and_fetch_all, - mgclient, - reset_create_delete_permissions, -) +import pytest +from common import * AUTHORIZATION_ERROR_IDENTIFIER = "AuthorizationError" @@ -28,55 +23,83 @@ create_edge_query = "MATCH (n:create_delete_label_1), (m:create_delete_label_2) delete_edge_query = "CALL create_delete.delete_edge() YIELD * RETURN *;" -def test_can_not_create_vertex_when_given_nothing(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_vertex_when_given_nothing(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_vertex_query) -def test_can_create_vertex_when_given_global_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_create_vertex_when_given_global_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, create_vertex_query) len(result[0][0]) == 1 -def test_can_not_create_vertex_when_given_global_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_vertex_when_given_global_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_vertex_query) -def test_can_not_create_vertex_when_given_global_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_vertex_when_given_global_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_vertex_query) -def test_can_add_vertex_label_when_given_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_add_vertex_label_when_given_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -85,14 +108,20 @@ def test_can_add_vertex_label_when_given_create_delete(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_label_vertex_query) assert "create_delete_label" in result[0][0] assert "new_create_delete_label" in result[0][0] -def test_can_not_add_vertex_label_when_given_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_add_vertex_label_when_given_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -100,12 +129,18 @@ def test_can_not_add_vertex_label_when_given_update(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, set_label_vertex_query) -def test_can_not_add_vertex_label_when_given_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_add_vertex_label_when_given_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -113,118 +148,178 @@ def test_can_not_add_vertex_label_when_given_read(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, set_label_vertex_query) -def test_can_remove_vertex_label_when_given_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_remove_vertex_label_when_given_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :create_delete_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, remove_label_vertex_query) assert result[0][0] != ":create_delete_label" -def test_can_remove_vertex_label_when_given_global_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_remove_vertex_label_when_given_global_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, remove_label_vertex_query) assert result[0][0] != ":create_delete_label" -def test_can_not_remove_vertex_label_when_given_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_remove_vertex_label_when_given_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :create_delete_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, remove_label_vertex_query) -def test_can_not_remove_vertex_label_when_given_global_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_remove_vertex_label_when_given_global_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, remove_label_vertex_query) -def test_can_not_remove_vertex_label_when_given_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_remove_vertex_label_when_given_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :create_delete_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, remove_label_vertex_query) -def test_can_not_remove_vertex_label_when_given_global_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_remove_vertex_label_when_given_global_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, remove_label_vertex_query) -def test_can_not_create_edge_when_given_nothing(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_edge_when_given_nothing(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_edge_query) -def test_can_not_create_edge_when_given_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_edge_when_given_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :new_create_delete_edge_type TO user") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_edge_query) -def test_can_not_create_edge_when_given_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_create_edge_when_given_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :new_create_delete_edge_type TO user") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, create_edge_query) -def test_can_create_edge_when_given_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_create_edge_when_given_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -233,24 +328,36 @@ def test_can_create_edge_when_given_create_delete(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) no_of_edges = execute_and_fetch_all(test_cursor, create_edge_query) assert no_of_edges[0][0] == 2 -def test_can_not_delete_edge_when_given_nothing(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_delete_edge_when_given_nothing(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, delete_edge_query) -def test_can_not_delete_edge_when_given_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_delete_edge_when_given_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -259,13 +366,19 @@ def test_can_not_delete_edge_when_given_read(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, delete_edge_query) -def test_can_not_delete_edge_when_given_update(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_delete_edge_when_given_update(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -274,13 +387,19 @@ def test_can_not_delete_edge_when_given_update(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) with pytest.raises(mgclient.DatabaseError, match=AUTHORIZATION_ERROR_IDENTIFIER): execute_and_fetch_all(test_cursor, delete_edge_query) -def test_can_delete_edge_when_given_create_delete(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_delete_edge_when_given_create_delete(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_create_delete_permissions(admin_cursor) execute_and_fetch_all( @@ -289,6 +408,8 @@ def test_can_delete_edge_when_given_create_delete(): ) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) no_of_edges = execute_and_fetch_all(test_cursor, delete_edge_query) diff --git a/tests/e2e/lba_procedures/read_permission_queries.py b/tests/e2e/lba_procedures/read_permission_queries.py index 42b8ab240..4348e6bda 100644 --- a/tests/e2e/lba_procedures/read_permission_queries.py +++ b/tests/e2e/lba_procedures/read_permission_queries.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,12 +9,11 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import pytest import sys - from typing import List -from common import connect, execute_and_fetch_all, reset_permissions +import pytest +from common import * match_query = "MATCH (n) RETURN n;" match_by_id_query = "MATCH (n) WHERE ID(n) >= 0 RETURN n;" @@ -105,11 +104,16 @@ def get_user_cursor(): def execute_read_node_assertion( - operation_case: List[str], queries: List[str], create_index: bool, expected_size: int + operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool ) -> None: admin_cursor = get_admin_cursor() user_cursor = get_user_cursor() + if switch: + create_multi_db(admin_cursor) + switch_db(admin_cursor) + switch_db(user_cursor) + reset_permissions(admin_cursor, create_index) for operation in operation_case: @@ -120,7 +124,8 @@ def execute_read_node_assertion( assert len(results) == expected_size -def test_can_read_node_when_authorized(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_node_when_authorized(switch): match_queries_without_index = [match_query, match_by_id_query] match_queries_with_index = [ match_by_label_query, @@ -132,14 +137,15 @@ def test_can_read_node_when_authorized(): for expected_size, operation_case in zip( read_node_without_index_operation_cases_expected_size, read_node_without_index_operation_cases ): - execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size) + execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch) for expected_size, operation_case in zip( read_node_with_index_operation_cases_expected_sizes, read_node_with_index_operation_cases ): - execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size) + execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch) -def test_can_not_read_node_when_authorized(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_node_when_authorized(switch): match_queries_without_index = [match_query, match_by_id_query] match_queries_with_index = [ match_by_label_query, @@ -151,11 +157,11 @@ def test_can_not_read_node_when_authorized(): for expected_size, operation_case in zip( not_read_node_without_index_operation_cases_expected_sizes, not_read_node_without_index_operation_cases ): - execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size) + execute_read_node_assertion(operation_case, match_queries_without_index, False, expected_size, switch) for expected_size, operation_case in zip( not_read_node_with_index_operation_cases_expexted_sizes, not_read_node_with_index_operation_cases ): - execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size) + execute_read_node_assertion(operation_case, match_queries_with_index, True, expected_size, switch) if __name__ == "__main__": diff --git a/tests/e2e/lba_procedures/read_query_modules.py b/tests/e2e/lba_procedures/read_query_modules.py index 9760be5b4..a05ffcb7f 100644 --- a/tests/e2e/lba_procedures/read_query_modules.py +++ b/tests/e2e/lba_procedures/read_query_modules.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -10,177 +10,260 @@ # licenses/APL.txt. import sys + import pytest -from common import connect, execute_and_fetch_all, reset_permissions +from common import * get_number_of_vertices_query = "CALL read.number_of_visible_nodes() YIELD nr_of_nodes RETURN nr_of_nodes;" get_number_of_edges_query = "CALL read.number_of_visible_edges() YIELD nr_of_edges RETURN nr_of_edges;" -def test_can_read_vertex_through_c_api_when_given_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_vertex_through_c_api_when_given_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 1 -def test_can_read_vertex_through_c_api_when_given_update_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_vertex_through_c_api_when_given_update_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :read_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 1 -def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_vertex_through_c_api_when_given_create_delete_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :read_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 1 -def test_can_not_read_vertex_through_c_api_when_given_nothing(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_vertex_through_c_api_when_given_nothing(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 0 -def test_can_not_read_vertex_through_c_api_when_given_deny_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_vertex_through_c_api_when_given_deny_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 0 -def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_partial_vertices_through_c_api_when_given_global_read_but_deny_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;") execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 2 -def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_partial_vertices_through_c_api_when_given_global_update_but_deny_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;") execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 2 -def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_partial_vertices_through_c_api_when_given_global_create_delete_but_deny_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON LABELS :read_label TO user;") execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_vertices_query) assert result[0][0] == 2 -def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 1 -def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_edge_through_c_api_when_given_deny_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") execute_and_fetch_all(admin_cursor, "GRANT NOTHING ON EDGE_TYPES :read_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 0 -def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_edge_through_c_api_when_given_grant_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :read_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 1 -def test_can_read_edge_through_c_api_when_given_update_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_edge_through_c_api_when_given_update_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :read_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 1 -def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_read_edge_through_c_api_when_given_create_delete_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :read_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 1 -def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") @@ -188,13 +271,19 @@ def test_can_not_read_edge_through_c_api_when_given_read_global_but_deny_on_edge execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 0 -def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") @@ -202,13 +291,19 @@ def test_can_not_read_edge_through_c_api_when_given_update_global_but_deny_on_ed execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 0 -def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_deny_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :read_label_1, :read_label_2 TO user;") @@ -216,6 +311,8 @@ def test_can_not_read_edge_through_c_api_when_given_create_delete_global_but_den execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, get_number_of_edges_query) assert result[0][0] == 0 diff --git a/tests/e2e/lba_procedures/show_privileges.py b/tests/e2e/lba_procedures/show_privileges.py index 83a2eae97..1021ea7be 100644 --- a/tests/e2e/lba_procedures/show_privileges.py +++ b/tests/e2e/lba_procedures/show_privileges.py @@ -38,6 +38,8 @@ BASIC_PRIVILEGES = [ "MODULE_WRITE", "TRANSACTION_MANAGEMENT", "STORAGE_MODE", + "MULTI_DATABASE_EDIT", + "MULTI_DATABASE_USE", ] @@ -61,7 +63,7 @@ def test_lba_procedures_show_privileges_first_user(): cursor = connect(username="Josip", password="").cursor() result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;") - assert len(result) == 32 + assert len(result) == 34 fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES] diff --git a/tests/e2e/lba_procedures/update_query_modules.py b/tests/e2e/lba_procedures/update_query_modules.py index 6d33cd8b8..538c4e917 100644 --- a/tests/e2e/lba_procedures/update_query_modules.py +++ b/tests/e2e/lba_procedures/update_query_modules.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,107 +9,149 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import pytest import sys -from common import ( - connect, - execute_and_fetch_all, - reset_update_permissions, -) +import pytest +from common import * set_vertex_property_query = "MATCH (n:update_label) CALL update.set_property(n) YIELD * RETURN n.prop;" set_edge_property_query = "MATCH (n:update_label_1)-[r:update_edge_type]->(m:update_label_2) CALL update.set_property(r) YIELD * RETURN r.prop;" -def test_can_not_update_vertex_when_given_read(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_vertex_when_given_read(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 1 -def test_can_update_vertex_when_given_update_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_vertex_when_given_update_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS :update_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 2 -def test_can_update_vertex_when_given_create_delete_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_vertex_when_given_create_delete_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS :update_label TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 2 -def test_can_update_vertex_when_given_update_global_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_vertex_when_given_update_global_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 2 -def test_can_update_vertex_when_given_create_delete_global_grant_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_vertex_when_given_create_delete_global_grant_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 2 -def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_vertex_when_denied_update_and_granted_global_update_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;") execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 1 -def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_vertex_when_denied_update_and_granted_global_create_delete_on_label(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label TO user;") execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON LABELS * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_vertex_property_query) assert result[0][0] == 1 -def test_can_update_edge_when_given_update_grant_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_edge_when_given_update_grant_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;") @@ -117,13 +159,19 @@ def test_can_update_edge_when_given_update_grant_on_edge_type(): execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES :update_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_edge_property_query) assert result[0][0] == 2 -def test_can_not_update_edge_when_given_read_grant_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_edge_when_given_read_grant_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;") @@ -131,13 +179,19 @@ def test_can_not_update_edge_when_given_read_grant_on_edge_type(): execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES :update_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_edge_property_query) assert result[0][0] == 1 -def test_can_update_edge_when_given_create_delete_grant_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_update_edge_when_given_create_delete_grant_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;") @@ -145,13 +199,19 @@ def test_can_update_edge_when_given_create_delete_grant_on_edge_type(): execute_and_fetch_all(admin_cursor, "GRANT CREATE_DELETE ON EDGE_TYPES :update_edge_type TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_edge_property_query) assert result[0][0] == 2 -def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_update_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;") @@ -160,13 +220,19 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_upd execute_and_fetch_all(admin_cursor, "GRANT READ ON EDGE_TYPES * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_edge_property_query) assert result[0][0] == 1 -def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type(): +@pytest.mark.parametrize("switch", [False, True]) +def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_create_delete_on_edge_type(switch): admin_cursor = connect(username="admin", password="test").cursor() + create_multi_db(admin_cursor) + if switch: + switch_db(admin_cursor) reset_update_permissions(admin_cursor) execute_and_fetch_all(admin_cursor, "GRANT READ ON LABELS :update_label_1 TO user;") @@ -175,6 +241,8 @@ def test_can_not_update_edge_when_denied_update_edge_type_but_granted_global_cre execute_and_fetch_all(admin_cursor, "GRANT UPDATE ON EDGE_TYPES * TO user;") test_cursor = connect(username="user", password="test").cursor() + if switch: + switch_db(test_cursor) result = execute_and_fetch_all(test_cursor, set_edge_property_query) assert result[0][0] == 1 diff --git a/tests/e2e/lba_procedures/workloads.yaml b/tests/e2e/lba_procedures/workloads.yaml index 1e7030fa9..d7de0de2a 100644 --- a/tests/e2e/lba_procedures/workloads.yaml +++ b/tests/e2e/lba_procedures/workloads.yaml @@ -6,8 +6,10 @@ read_query_modules_cluster: &read_query_modules_cluster setup_queries: - "CREATE USER admin IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO admin" + - "GRANT DATABASE * TO admin" - "CREATE USER user IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO user" + - "GRANT DATABASE * TO user" validation_queries: [] update_query_modules_cluster: &update_query_modules_cluster @@ -18,8 +20,10 @@ update_query_modules_cluster: &update_query_modules_cluster setup_queries: - "CREATE USER admin IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO admin" + - "GRANT DATABASE * TO admin" - "CREATE USER user IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO user" + - "GRANT DATABASE * TO user" validation_queries: [] show_privileges_cluster: &show_privileges_cluster @@ -67,8 +71,10 @@ read_permission_queries: &read_permission_queries setup_queries: - "CREATE USER admin IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO admin" + - "GRANT DATABASE * TO admin" - "CREATE USER user IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO user" + - "GRANT DATABASE * TO user" validation_queries: [] create_delete_query_modules_cluster: &create_delete_query_modules_cluster @@ -79,8 +85,10 @@ create_delete_query_modules_cluster: &create_delete_query_modules_cluster setup_queries: - "CREATE USER admin IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO admin;" + - "GRANT DATABASE * TO admin" - "CREATE USER user IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO user;" + - "GRANT DATABASE * TO user" validation_queries: [] update_permission_queries_cluster: &update_permission_queries_cluster @@ -91,8 +99,10 @@ update_permission_queries_cluster: &update_permission_queries_cluster setup_queries: - "CREATE USER admin IDENTIFIED BY 'test';" - "GRANT ALL PRIVILEGES TO admin;" + - "GRANT DATABASE * TO admin" - "CREATE USER user IDENTIFIED BY 'test'" - "GRANT ALL PRIVILEGES TO user;" + - "GRANT DATABASE * TO user" validation_queries: [] workloads: diff --git a/tests/e2e/magic_functions/function_example.py b/tests/e2e/magic_functions/function_example.py index 0ca208b7d..f63e60b79 100644 --- a/tests/e2e/magic_functions/function_example.py +++ b/tests/e2e/magic_functions/function_example.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,16 +9,29 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import typing -import mgclient import sys +import typing + +import mgclient import pytest from common import execute_and_fetch_all, has_n_result_row -@pytest.mark.parametrize("function_type", ["py", "c"]) -def test_return_argument(connection, function_type): +@pytest.fixture(scope="function") +def multi_db(request, connection): cursor = connection.cursor() + if request.param: + execute_and_fetch_all(cursor, "CREATE DATABASE clean") + execute_and_fetch_all(cursor, "USE DATABASE clean") + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + pass + yield connection + + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +@pytest.mark.parametrize("function_type", ["py", "c"]) +def test_return_argument(multi_db, function_type): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});") assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1) result = execute_and_fetch_all( @@ -31,9 +44,10 @@ def test_return_argument(connection, function_type): assert vertex.properties == {"id": 1} +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_return_optional_argument(connection, function_type): - cursor = connection.cursor() +def test_return_optional_argument(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) result = execute_and_fetch_all( cursor, @@ -44,9 +58,10 @@ def test_return_optional_argument(connection, function_type): assert result == 42 +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_return_optional_argument_no_arg(connection, function_type): - cursor = connection.cursor() +def test_return_optional_argument_no_arg(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) result = execute_and_fetch_all( cursor, @@ -57,9 +72,10 @@ def test_return_optional_argument_no_arg(connection, function_type): assert result == 42 +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_add_two_numbers(connection, function_type): - cursor = connection.cursor() +def test_add_two_numbers(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) result = execute_and_fetch_all( cursor, @@ -70,9 +86,10 @@ def test_add_two_numbers(connection, function_type): assert result_sum == 6 +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_return_null(connection, function_type): - cursor = connection.cursor() +def test_return_null(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) result = execute_and_fetch_all( cursor, @@ -82,9 +99,10 @@ def test_return_null(connection, function_type): assert result_null is None +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_too_many_arguments(connection, function_type): - cursor = connection.cursor() +def test_too_many_arguments(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) # Should raise too many arguments with pytest.raises(mgclient.DatabaseError): @@ -94,9 +112,10 @@ def test_too_many_arguments(connection, function_type): ) +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_try_to_write(connection, function_type): - cursor = connection.cursor() +def test_try_to_write(multi_db, function_type): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "CREATE (n:Label {id: 1});") assert has_n_result_row(cursor, "MATCH (n) RETURN n", 1) # Should raise non mutable @@ -106,9 +125,11 @@ def test_try_to_write(connection, function_type): f"MATCH (n) RETURN {function_type}_write.try_to_write(n, 'property', 1);", ) + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("function_type", ["py", "c"]) -def test_case_sensitivity(connection, function_type): - cursor = connection.cursor() +def test_case_sensitivity(multi_db, function_type): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) # Should raise function does not exist with pytest.raises(mgclient.DatabaseError): diff --git a/tests/e2e/memory/memory_control.cpp b/tests/e2e/memory/memory_control.cpp index fc154684b..0d969220b 100644 --- a/tests/e2e/memory/memory_control.cpp +++ b/tests/e2e/memory/memory_control.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ DEFINE_uint64(bolt_port, 7687, "Bolt port"); DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); int main(int argc, char **argv) { google::SetUsageMessage("Memgraph E2E Memory Control"); @@ -34,6 +35,15 @@ int main(int argc, char **argv) { client->Execute("MATCH (n) DETACH DELETE n;"); client->DiscardAll(); + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + const auto *create_query = "UNWIND range(1, 50) as u CREATE (n {string: \"Some longer string\"}) RETURN n;"; memgraph::utils::Timer timer; diff --git a/tests/e2e/memory/memory_limit_global_alloc.cpp b/tests/e2e/memory/memory_limit_global_alloc.cpp index 611e2554a..44938ab1d 100644 --- a/tests/e2e/memory/memory_limit_global_alloc.cpp +++ b/tests/e2e/memory/memory_limit_global_alloc.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ DEFINE_uint64(bolt_port, 7687, "Bolt port"); DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); int main(int argc, char **argv) { google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators"); @@ -31,6 +32,15 @@ int main(int argc, char **argv) { LOG_FATAL("Failed to connect!"); } + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + bool result = client->Execute("CALL libglobal_memory_limit.procedure() YIELD *"); MG_ASSERT(result == false); return 0; diff --git a/tests/e2e/memory/memory_limit_global_alloc_proc.cpp b/tests/e2e/memory/memory_limit_global_alloc_proc.cpp index 7efb435c3..e1f530123 100644 --- a/tests/e2e/memory/memory_limit_global_alloc_proc.cpp +++ b/tests/e2e/memory/memory_limit_global_alloc_proc.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ DEFINE_uint64(bolt_port, 7687, "Bolt port"); DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); int main(int argc, char **argv) { google::SetUsageMessage("Memgraph E2E Memory Limit For Global Allocators"); @@ -31,6 +32,16 @@ int main(int argc, char **argv) { if (!client) { LOG_FATAL("Failed to connect!"); } + + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + MG_ASSERT(client->Execute("CALL libglobal_memory_limit_proc.error() YIELD *")); MG_ASSERT(std::invoke([&] { try { diff --git a/tests/e2e/memory/workloads.yaml b/tests/e2e/memory/workloads.yaml index 88573c761..64e45979d 100644 --- a/tests/e2e/memory/workloads.yaml +++ b/tests/e2e/memory/workloads.yaml @@ -21,14 +21,31 @@ workloads: args: ["--bolt-port", *bolt_port, "--timeout", "180"] <<: *template_cluster + - name: "Memory control multi database" + binary: "tests/e2e/memory/memgraph__e2e__memory__control" + args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] + <<: *template_cluster + - name: "Memory limit for modules upon loading" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc" args: ["--bolt-port", *bolt_port, "--timeout", "180"] proc: "tests/e2e/memory/procedures/" <<: *template_cluster + - name: "Memory limit for modules upon loading multi database" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc" + args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] + proc: "tests/e2e/memory/procedures/" + <<: *template_cluster + - name: "Memory limit for modules inside a procedure" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc" args: ["--bolt-port", *bolt_port, "--timeout", "180"] proc: "tests/e2e/memory/procedures/" <<: *template_cluster + + - name: "Memory limit for modules inside a procedure multi database" + binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc" + args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] + proc: "tests/e2e/memory/procedures/" + <<: *template_cluster diff --git a/tests/e2e/module_file_manager/module_file_manager.cpp b/tests/e2e/module_file_manager/module_file_manager.cpp index 5966b5701..20641b9d7 100644 --- a/tests/e2e/module_file_manager/module_file_manager.cpp +++ b/tests/e2e/module_file_manager/module_file_manager.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -21,6 +21,7 @@ DEFINE_uint64(bolt_port, 7687, "Bolt port"); DEFINE_uint64(timeout, 120, "Timeout seconds"); +DEFINE_bool(multi_db, false, "Run test in multi db environment"); namespace { auto GetClient() { @@ -181,6 +182,15 @@ int main(int argc, char **argv) { mg::Client::Init(); auto client = GetClient(); + if (FLAGS_multi_db) { + client->Execute("CREATE DATABASE clean;"); + client->DiscardAll(); + client->Execute("USE DATABASE clean;"); + client->DiscardAll(); + client->Execute("MATCH (n) DETACH DELETE n;"); + client->DiscardAll(); + } + AssertQueryFails(client, CreateModuleFileQuery("some.cpp", "some content"), "mg.create_module_file: The specified file isn't in the supported format."); diff --git a/tests/e2e/module_file_manager/workloads.yaml b/tests/e2e/module_file_manager/workloads.yaml index b6670c6ef..b40242e21 100644 --- a/tests/e2e/module_file_manager/workloads.yaml +++ b/tests/e2e/module_file_manager/workloads.yaml @@ -12,3 +12,8 @@ workloads: binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager" args: ["--bolt-port", *bolt_port] <<: *template_cluster + + - name: "Module File Manager multi database" + binary: "tests/e2e/module_file_manager/memgraph__e2e__module_file_manager" + args: ["--bolt-port", *bolt_port, "--multi-db", "true"] + <<: *template_cluster diff --git a/tests/e2e/python_query_modules_reloading/common.py b/tests/e2e/python_query_modules_reloading/common.py index 3a166a1d5..0faff71b8 100644 --- a/tests/e2e/python_query_modules_reloading/common.py +++ b/tests/e2e/python_query_modules_reloading/common.py @@ -14,6 +14,19 @@ import typing import mgclient +def switch_db(cursor): + execute_and_fetch_all(cursor, "USE DATABASE clean;") + + +def create_multi_db(cursor): + execute_and_fetch_all(cursor, "USE DATABASE memgraph;") + try: + execute_and_fetch_all(cursor, "DROP DATABASE clean;") + except: + pass + execute_and_fetch_all(cursor, "CREATE DATABASE clean;") + + def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]: cursor.execute(query, params) return cursor.fetchall() diff --git a/tests/e2e/python_query_modules_reloading/test_reload_query_module.py b/tests/e2e/python_query_modules_reloading/test_reload_query_module.py index 7f7205d82..d8e8a05d1 100644 --- a/tests/e2e/python_query_modules_reloading/test_reload_query_module.py +++ b/tests/e2e/python_query_modules_reloading/test_reload_query_module.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -14,7 +14,7 @@ import os # To be removed import sys import pytest -from common import connect, execute_and_fetch_all +from common import connect, create_multi_db, execute_and_fetch_all, switch_db COMMON_PATH_PREFIX_TEST1 = "procedures/mage/test_module" COMMON_PATH_PREFIX_TEST2 = "procedures/new_test_module_utils" @@ -76,9 +76,13 @@ def postprocess_functions(path1: str, path2: str): ) -def test_mg_load_reload_submodule_root_utils(): +@pytest.mark.parametrize("switch", [False, True]) +def test_mg_load_reload_submodule_root_utils(switch): """Tests whether mg.load reloads content of some submodule code.""" cursor = connect().cursor() + if switch: + create_multi_db(cursor) + switch_db(cursor) # First do a simple experiment test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;") try: @@ -101,9 +105,13 @@ def test_mg_load_reload_submodule_root_utils(): execute_and_fetch_all(cursor, "CALL mg.load('new_test_module');") -def test_mg_load_all_reload_submodule_root_utils(): +@pytest.mark.parametrize("switch", [False, True]) +def test_mg_load_all_reload_submodule_root_utils(switch): """Tests whether mg.load_all reloads content of some submodule code""" cursor = connect().cursor() + if switch: + create_multi_db(cursor) + switch_db(cursor) # First do a simple experiment test_module_res = execute_and_fetch_all(cursor, "CALL new_test_module.test(10, 2) YIELD * RETURN *;") try: @@ -126,9 +134,13 @@ def test_mg_load_all_reload_submodule_root_utils(): execute_and_fetch_all(cursor, "CALL mg.load_all();") -def test_mg_load_reload_submodule(): +@pytest.mark.parametrize("switch", [False, True]) +def test_mg_load_reload_submodule(switch): """Tests whether mg.load reloads content of some submodule code.""" cursor = connect().cursor() + if switch: + create_multi_db(cursor) + switch_db(cursor) # First do a simple experiment test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;") try: @@ -151,9 +163,13 @@ def test_mg_load_reload_submodule(): execute_and_fetch_all(cursor, "CALL mg.load('test_module');") -def test_mg_load_all_reload_submodule(): +@pytest.mark.parametrize("switch", [False, True]) +def test_mg_load_all_reload_submodule(switch): """Tests whether mg.load_all reloads content of some submodule code""" cursor = connect().cursor() + if switch: + create_multi_db(cursor) + switch_db(cursor) # First do a simple experiment test_module_res = execute_and_fetch_all(cursor, "CALL test_module.test(10, 2) YIELD * RETURN *;") try: diff --git a/tests/e2e/transaction_queue/common.py b/tests/e2e/transaction_queue/common.py index fdab67a72..3a166a1d5 100644 --- a/tests/e2e/transaction_queue/common.py +++ b/tests/e2e/transaction_queue/common.py @@ -12,7 +12,6 @@ import typing import mgclient -import pytest def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]: diff --git a/tests/e2e/transaction_queue/test_transaction_queue.py b/tests/e2e/transaction_queue/test_transaction_queue.py index a75fb40cd..221243c50 100644 --- a/tests/e2e/transaction_queue/test_transaction_queue.py +++ b/tests/e2e/transaction_queue/test_transaction_queue.py @@ -12,7 +12,6 @@ import multiprocessing import sys -import threading import time from typing import List @@ -56,12 +55,28 @@ def test_self_transaction(): assert len(results) == 1 +def test_multitenant_transactions(): + """Tests that show transactions work on another database""" + test_cursor = connect().cursor() + execute_and_fetch_all(test_cursor, "CREATE DATABASE testing") + tx_connection = connect() + tx_cursor = tx_connection.cursor() + tx_process = multiprocessing.Process( + target=process_function, args=(tx_cursor, ["USE DATABASE testing", "MATCH (n) RETURN n"]) + ) + tx_process.start() + time.sleep(0.5) + show_transactions_test(test_cursor, 1) + # TODO Add SHOW TRANSACTIONS ON * that should return all transactions + + def test_admin_has_one_transaction(): """Creates admin and tests that he sees only one transaction.""" # a_cursor is used for creating admin user, simulates main thread superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") admin_cursor = connect(username="admin", password="").cursor() process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1)) process.start() @@ -74,6 +89,7 @@ def test_user_can_see_its_transaction(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user") user_cursor = connect(username="user", password="").cursor() @@ -89,6 +105,7 @@ def test_explicit_transaction_output(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") admin_connection = connect(username="admin", password="") admin_cursor = admin_connection.cursor() # Admin starts running explicit transaction @@ -114,8 +131,10 @@ def test_superadmin_cannot_see_admin_can_see_admin(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1") execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") # Admin starts running infinite query admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() @@ -153,6 +172,7 @@ def test_admin_sees_superadmin(): superadmin_cursor = superadmin_connection.cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") # Admin starts running infinite query process = multiprocessing.Process( target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) @@ -183,6 +203,7 @@ def test_admin_can_see_user_transaction(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") # Admin starts running infinite query admin_connection = connect(username="admin", password="") @@ -220,8 +241,10 @@ def test_user_cannot_see_admin_transaction(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin1") execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() @@ -282,6 +305,7 @@ def test_admin_killing_multiple_non_existing_transactions(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") # Connect with admin admin_cursor = connect(username="admin", password="").cursor() transactions_id = ["'1'", "'2'", "'3'"] @@ -298,6 +322,7 @@ def test_user_killing_some_transactions(): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin") + execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user1") execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1") diff --git a/tests/e2e/triggers/common.py b/tests/e2e/triggers/common.py index 89e7e5a88..9609445e1 100644 --- a/tests/e2e/triggers/common.py +++ b/tests/e2e/triggers/common.py @@ -24,11 +24,14 @@ def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {} def connect(**kwargs) -> mgclient.Connection: connection = mgclient.connect(host="localhost", port=7687, **kwargs) connection.autocommit = True + execute_and_fetch_all(connection.cursor(), "USE DATABASE memgraph") + try: + execute_and_fetch_all(connection.cursor(), "DROP DATABASE clean") + except: + pass + execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n") triggers_list = execute_and_fetch_all(connection.cursor(), "SHOW TRIGGERS;") for trigger in triggers_list: execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}") execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n") yield connection - for trigger in triggers_list: - execute_and_fetch_all(connection.cursor(), f"DROP TRIGGER {trigger[0]}") - execute_and_fetch_all(connection.cursor(), "MATCH (n) DETACH DELETE n") diff --git a/tests/e2e/triggers/triggers_properties_false.py b/tests/e2e/triggers/triggers_properties_false.py index 2f5b69feb..c22ddd0d6 100644 --- a/tests/e2e/triggers/triggers_properties_false.py +++ b/tests/e2e/triggers/triggers_properties_false.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,13 +16,25 @@ import pytest from common import connect, execute_and_fetch_all +@pytest.fixture(scope="function") +def multi_db(request, connect): + cursor = connect.cursor() + if request.param: + execute_and_fetch_all(cursor, "CREATE DATABASE clean") + execute_and_fetch_all(cursor, "USE DATABASE clean") + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + pass + yield connect + + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"]) -def test_create_on_create(ba_commit, connect): +def test_create_on_create(ba_commit, multi_db): """ Args: ba_commit (str): BEFORE OR AFTER commit """ - cursor = connect.cursor() + cursor = multi_db.cursor() QUERY_TRIGGER_CREATE = f""" CREATE TRIGGER CreateTriggerEdgesCount ON --> CREATE @@ -30,6 +42,7 @@ def test_create_on_create(ba_commit, connect): EXECUTE CREATE (n:CreatedEdge {{count: size(createdEdges)}}) """ + execute_and_fetch_all(cursor, QUERY_TRIGGER_CREATE) execute_and_fetch_all(cursor, "CREATE (n:Node {id: 1})") execute_and_fetch_all(cursor, "CREATE (n:Node {id: 2})") @@ -50,14 +63,22 @@ def test_create_on_create(ba_commit, connect): # execute_and_fetch_all(cursor, "DROP TRIGGER CreateTriggerEdgesCount") # execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") + # check that there is no cross contamination between databases + nodes = execute_and_fetch_all(cursor, "SHOW DATABASES") + if len(nodes) == 2: # multi db mode + execute_and_fetch_all(cursor, "USE DATABASE memgraph") + created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n") + assert len(created_edges) == 0 + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("ba_commit", ["AFTER COMMIT", "BEFORE COMMIT"]) -def test_create_on_delete(ba_commit, connect): +def test_create_on_delete(ba_commit, multi_db): """ Args: ba_commit (str): BEFORE OR AFTER commit """ - cursor = connect.cursor() + cursor = multi_db.cursor() QUERY_TRIGGER_CREATE = f""" CREATE TRIGGER DeleteTriggerEdgesCount ON --> DELETE @@ -102,7 +123,15 @@ def test_create_on_delete(ba_commit, connect): # execute_and_fetch_all(cursor, "DROP TRIGGER DeleteTriggerEdgesCount") # execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")`` + # check that there is no cross contamination between databases + nodes = execute_and_fetch_all(cursor, "SHOW DATABASES") + if len(nodes) == 2: # multi db mode + execute_and_fetch_all(cursor, "USE DATABASE memgraph") + created_edges = execute_and_fetch_all(cursor, "MATCH (n:CreatedEdge) RETURN n") + assert len(created_edges) == 0 + +# @pytest.mark.parametrize("multi_db", [False, True], indirect=True) @pytest.mark.parametrize("ba_commit", ["BEFORE COMMIT", "AFTER COMMIT"]) def test_create_on_delete_explicit_transaction(ba_commit): """ diff --git a/tests/e2e/write_procedures/read_subgraph.py b/tests/e2e/write_procedures/read_subgraph.py index 40f315050..2ebf9c956 100644 --- a/tests/e2e/write_procedures/read_subgraph.py +++ b/tests/e2e/write_procedures/read_subgraph.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,13 +9,25 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import typing -import mgclient import sys +import typing + +import mgclient import pytest from common import execute_and_fetch_all, has_n_result_row +@pytest.fixture(scope="function") +def multi_db(request, connection): + cursor = connection.cursor() + if request.param: + execute_and_fetch_all(cursor, "CREATE DATABASE clean") + execute_and_fetch_all(cursor, "USE DATABASE clean") + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + pass + yield connection + + def create_subgraph(cursor): execute_and_fetch_all(cursor, "CREATE (n:Person {id: 1});") execute_and_fetch_all(cursor, "CREATE (n:Person {id: 2});") @@ -41,8 +53,9 @@ def create_smaller_subgraph(cursor): execute_and_fetch_all(cursor, "MATCH (p:Person {id: 2}) MATCH (t:Team {id:6}) CREATE (p)-[:SUPPORTS]->(t);") -def test_is_callable(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_is_callable(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -59,8 +72,9 @@ def test_is_callable(connection): ) -def test_incorrect_graph_argument_placement(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_incorrect_graph_argument_placement(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -79,8 +93,9 @@ def test_incorrect_graph_argument_placement(connection): ) -def test_get_vertices(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_get_vertices(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -97,8 +112,9 @@ def test_get_vertices(connection): ) -def test_get_out_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_get_out_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -115,8 +131,9 @@ def test_get_out_edges(connection): ) -def test_get_in_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_get_in_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -133,8 +150,9 @@ def test_get_in_edges(connection): ) -def test_get_2_hop_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_get_2_hop_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -150,8 +168,9 @@ def test_get_2_hop_edges(connection): ) -def test_get_out_edges_vertex_id(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_get_out_edges_vertex_id(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor=cursor) @@ -168,8 +187,9 @@ def test_get_out_edges_vertex_id(connection): ) -def test_subgraph_get_path_vertices(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_get_path_vertices(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -185,8 +205,9 @@ def test_subgraph_get_path_vertices(connection): ) -def test_subgraph_get_path_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_get_path_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) @@ -202,8 +223,9 @@ def test_subgraph_get_path_edges(connection): ) -def test_subgraph_get_path_vertices_in_subgraph(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_get_path_vertices_in_subgraph(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -218,8 +240,9 @@ def test_subgraph_get_path_vertices_in_subgraph(connection): ) -def test_subgraph_insert_vertex_get_vertices(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_insert_vertex_get_vertices(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -234,8 +257,9 @@ def test_subgraph_insert_vertex_get_vertices(connection): ) -def test_subgraph_insert_edge_get_vertex_out_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_insert_edge_get_vertex_out_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -250,8 +274,9 @@ def test_subgraph_insert_edge_get_vertex_out_edges(connection): ) -def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -267,8 +292,9 @@ def test_subgraph_create_edge_both_vertices_not_in_projected_graph_error(connect ) -def test_subgraph_remove_edge_get_vertex_out_edges(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_remove_edge_get_vertex_out_edges(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -283,8 +309,9 @@ def test_subgraph_remove_edge_get_vertex_out_edges(connection): ) -def test_subgraph_remove_edge_not_in_subgraph_error(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_remove_edge_not_in_subgraph_error(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 6) @@ -299,8 +326,9 @@ def test_subgraph_remove_edge_not_in_subgraph_error(connection): ) -def test_subgraph_remove_vertex_and_out_edges_get_vertices(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_subgraph_remove_vertex_and_out_edges_get_vertices(multi_db): + cursor = multi_db.cursor() execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;") create_smaller_subgraph(cursor) assert has_n_result_row(cursor, "MATCH (n) RETURN n;", 4) diff --git a/tests/e2e/write_procedures/simple_write.py b/tests/e2e/write_procedures/simple_write.py index 804e2edfd..4078419f7 100644 --- a/tests/e2e/write_procedures/simple_write.py +++ b/tests/e2e/write_procedures/simple_write.py @@ -1,4 +1,4 @@ -# Copyright 2021 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,17 +9,30 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. -import typing -import mgclient import sys +import typing + +import mgclient import pytest -from common import execute_and_fetch_all, has_one_result_row, has_n_result_row +from common import execute_and_fetch_all, has_n_result_row, has_one_result_row -def test_is_write(connection): +@pytest.fixture(scope="function") +def multi_db(request, connection): + cursor = connection.cursor() + if request.param: + execute_and_fetch_all(cursor, "CREATE DATABASE clean") + execute_and_fetch_all(cursor, "USE DATABASE clean") + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + pass + yield connection + + +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_is_write(multi_db): is_write = 2 result_order = "name, signature, is_write" - cursor = connection.cursor() + cursor = multi_db.cursor() for proc in execute_and_fetch_all( cursor, "CALL mg.procedures() YIELD * WITH name, signature, " @@ -41,8 +54,9 @@ def test_is_write(connection): assert cursor.description[2].name == "is_write" -def test_single_vertex(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_single_vertex(multi_db): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) result = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v") vertex = result[0][0] @@ -93,8 +107,9 @@ def test_single_vertex(connection): assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) -def test_single_edge(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_single_edge(multi_db): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id @@ -134,8 +149,9 @@ def test_single_edge(connection): assert has_n_result_row(cursor, "MATCH ()-[e]->() RETURN e", 0) -def test_detach_delete_vertex(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_detach_delete_vertex(multi_db): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id @@ -156,8 +172,9 @@ def test_detach_delete_vertex(connection): assert has_one_result_row(cursor, f"MATCH (n) WHERE id(n) = {v2_id} RETURN n") -def test_graph_mutability(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_graph_mutability(multi_db): + cursor = multi_db.cursor() assert has_n_result_row(cursor, "MATCH (n) RETURN n", 0) v1_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id v2_id = execute_and_fetch_all(cursor, "CALL write.create_vertex() YIELD v RETURN v")[0][0].id @@ -193,8 +210,9 @@ def test_graph_mutability(connection): test_mutability(False) -def test_log_message(connection): - cursor = connection.cursor() +@pytest.mark.parametrize("multi_db", [False, True], indirect=True) +def test_log_message(multi_db): + cursor = multi_db.cursor() success = execute_and_fetch_all(cursor, f"CALL read.log_message('message') YIELD success RETURN success")[0][0] assert (success) is True diff --git a/tests/integration/audit/runner.py b/tests/integration/audit/runner.py index 92e175a1b..466c91d9a 100755 --- a/tests/integration/audit/runner.py +++ b/tests/integration/audit/runner.py @@ -21,6 +21,7 @@ import sys import tempfile import time +DEFAULT_DB = "memgraph" SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) @@ -37,26 +38,24 @@ QUERIES = [ ("CREATE (n {name: $name})", {"name": 5, "leftover": 42}), ("MATCH (n), (m) CREATE (n)-[:e {when: $when}]->(m)", {"when": 42}), ("MATCH (n) RETURN n", {}), - ( - "MATCH (n), (m {type: $type}) RETURN count(n), count(m)", - {"type": "dadada"} - ), + ("MATCH (n), (m {type: $type}) RETURN count(n), count(m)", {"type": "dadada"}), ( "MERGE (n) ON CREATE SET n.created = timestamp() " "ON MATCH SET n.lastSeen = timestamp() " "RETURN n.name, n.created, n.lastSeen", - {} - ), - ( - "MATCH (n {value: $value}) SET n.value = 0 RETURN n", - {"value": "nandare!"} + {}, ), + ("MATCH (n {value: $value}) SET n.value = 0 RETURN n", {"value": "nandare!"}), ("MATCH (n), (m) SET n.value = m.value", {}), ("MATCH (n {test: $test}) REMOVE n.value", {"test": 48}), ("MATCH (n), (m) REMOVE n.value, m.value", {}), ("CREATE INDEX ON :User (id)", {}), ] +CREATE_DB_QUERIES = [ + ("CREATE DATABASE clean", {}), +] + def wait_for_server(port, delay=0.1): cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] @@ -65,6 +64,13 @@ def wait_for_server(port, delay=0.1): time.sleep(delay) +def gen_mt_queries(queries, db): + out = [] + for query, params in queries: + out.append((db, query, params)) + return out + + def execute_test(memgraph_binary, tester_binary): storage_directory = tempfile.TemporaryDirectory() memgraph_args = [ @@ -74,7 +80,8 @@ def execute_test(memgraph_binary, tester_binary): storage_directory.name, "--audit-enabled", "--log-file=memgraph.log", - "--log-level=TRACE"] + "--log-level=TRACE", + ] # Start the memgraph binary memgraph = subprocess.Popen(list(map(str, memgraph_args))) @@ -90,17 +97,31 @@ def execute_test(memgraph_binary, tester_binary): assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" def execute_queries(queries): - for query, params in queries: + for db, query, params in queries: print(query, params) - args = [tester_binary, "--query", query, - "--params-json", json.dumps(params)] + args = [tester_binary, "--query", query, "--use-db", db, "--params-json", json.dumps(params)] subprocess.run(args).check_returncode() + # Test default db + mt_queries = gen_mt_queries(QUERIES, DEFAULT_DB) + # Execute all queries print("\033[1;36m~~ Starting query execution ~~\033[0m") - execute_queries(QUERIES) + execute_queries(mt_queries) print("\033[1;36m~~ Finished query execution ~~\033[0m\n") + # Test new db + print("\033[1;36m~~ Creating clean database ~~\033[0m") + mt_queries2 = gen_mt_queries(CREATE_DB_QUERIES, DEFAULT_DB) + execute_queries(mt_queries2) + print("\033[1;36m~~ Finished creating clean database ~~\033[0m\n") + + # Execute all queries on clean database + mt_queries3 = gen_mt_queries(QUERIES, "clean") + print("\033[1;36m~~ Starting query execution on clean database ~~\033[0m") + execute_queries(mt_queries3) + print("\033[1;36m~~ Finished query execution on clean database ~~\033[0m\n") + # Shutdown the memgraph binary memgraph.terminate() @@ -109,26 +130,37 @@ def execute_test(memgraph_binary, tester_binary): # Verify the written log print("\033[1;36m~~ Starting log verification ~~\033[0m") with open(os.path.join(storage_directory.name, "audit", "audit.log")) as f: - reader = csv.reader(f, delimiter=',', doublequote=False, - escapechar='\\', lineterminator='\n', - quotechar='"', quoting=csv.QUOTE_MINIMAL, - skipinitialspace=False, strict=True) + reader = csv.reader( + f, + delimiter=",", + doublequote=False, + escapechar="\\", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + skipinitialspace=False, + strict=True, + ) queries = [] for line in reader: - timestamp, address, username, query, params = line + timestamp, address, username, database, query, params = line params = json.loads(params) - queries.append((query, params)) - print(query, params) + if query.startswith("USE DATABASE"): + continue # Skip all databases switching queries + queries.append((database, query, params)) + print(database, query, params) - assert queries == QUERIES, "Logged queries don't match " \ - "executed queries!" + # Combine all queries executed + all_queries = mt_queries + all_queries += mt_queries2 + all_queries += mt_queries3 + assert queries == all_queries, "Logged queries don't match " "executed queries!" print("\033[1;36m~~ Finished log verification ~~\033[0m\n") if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "audit", "tester") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "audit", "tester") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) diff --git a/tests/integration/audit/tester.cpp b/tests/integration/audit/tester.cpp index 973227253..a2e4c32f4 100644 --- a/tests/integration/audit/tester.cpp +++ b/tests/integration/audit/tester.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include #include #include @@ -25,6 +26,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); DEFINE_string(query, "", "Query to execute"); DEFINE_string(params_json, "{}", "Params for the query"); +DEFINE_string(use_db, "memgraph", "Database to run the query against"); memgraph::communication::bolt::Value JsonToValue(const nlohmann::json &jv) { memgraph::communication::bolt::Value ret; @@ -89,6 +91,7 @@ int main(int argc, char **argv) { memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); + client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {}); client.Execute(FLAGS_query, JsonToValue(nlohmann::json::parse(FLAGS_params_json)).ValueMap()); return 0; diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py index 7953bc1ee..9c4ab8ca7 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -29,15 +29,8 @@ PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) QUERIES = [ # CREATE - ( - "CREATE (n)", - ("CREATE",) - ), - ( - "MATCH (n), (m) CREATE (n)-[:e]->(m)", - ("CREATE", "MATCH") - ), - + ("CREATE (n)", ("CREATE",)), + ("MATCH (n), (m) CREATE (n)-[:e]->(m)", ("CREATE", "MATCH")), # DELETE ( "MATCH (n) DELETE n", @@ -47,116 +40,43 @@ QUERIES = [ "MATCH (n) DETACH DELETE n", ("DELETE", "MATCH"), ), - # MATCH - ( - "MATCH (n) RETURN n", - ("MATCH",) - ), - ( - "MATCH (n), (m) RETURN count(n), count(m)", - ("MATCH",) - ), - + ("MATCH (n) RETURN n", ("MATCH",)), + ("MATCH (n), (m) RETURN count(n), count(m)", ("MATCH",)), # MERGE ( "MERGE (n) ON CREATE SET n.created = timestamp() " "ON MATCH SET n.lastSeen = timestamp() " "RETURN n.name, n.created, n.lastSeen", - ("MERGE",) + ("MERGE",), ), - # SET - ( - "MATCH (n) SET n.value = 0 RETURN n", - ("SET", "MATCH") - ), - ( - "MATCH (n), (m) SET n.value = m.value", - ("SET", "MATCH") - ), - + ("MATCH (n) SET n.value = 0 RETURN n", ("SET", "MATCH")), + ("MATCH (n), (m) SET n.value = m.value", ("SET", "MATCH")), # REMOVE - ( - "MATCH (n) REMOVE n.value", - ("REMOVE", "MATCH") - ), - ( - "MATCH (n), (m) REMOVE n.value, m.value", - ("REMOVE", "MATCH") - ), - + ("MATCH (n) REMOVE n.value", ("REMOVE", "MATCH")), + ("MATCH (n), (m) REMOVE n.value, m.value", ("REMOVE", "MATCH")), # INDEX - ( - "CREATE INDEX ON :User (id)", - ("INDEX",) - ), - + ("CREATE INDEX ON :User (id)", ("INDEX",)), # AUTH - ( - "CREATE ROLE test_role", - ("AUTH",) - ), - ( - "DROP ROLE test_role", - ("AUTH",) - ), - ( - "SHOW ROLES", - ("AUTH",) - ), - ( - "CREATE USER test_user", - ("AUTH",) - ), - ( - "SET PASSWORD FOR test_user TO '1234'", - ("AUTH",) - ), - ( - "DROP USER test_user", - ("AUTH",) - ), - ( - "SHOW USERS", - ("AUTH",) - ), - ( - "SET ROLE FOR test_user TO test_role", - ("AUTH",) - ), - ( - "CLEAR ROLE FOR test_user", - ("AUTH",) - ), - ( - "GRANT ALL PRIVILEGES TO test_user", - ("AUTH",) - ), - ( - "DENY ALL PRIVILEGES TO test_user", - ("AUTH",) - ), - ( - "REVOKE ALL PRIVILEGES FROM test_user", - ("AUTH",) - ), - ( - "SHOW PRIVILEGES FOR test_user", - ("AUTH",) - ), - ( - "SHOW ROLE FOR test_user", - ("AUTH",) - ), - ( - "SHOW USERS FOR test_role", - ("AUTH",) - ), + ("CREATE ROLE test_role", ("AUTH",)), + ("DROP ROLE test_role", ("AUTH",)), + ("SHOW ROLES", ("AUTH",)), + ("CREATE USER test_user", ("AUTH",)), + ("SET PASSWORD FOR test_user TO '1234'", ("AUTH",)), + ("DROP USER test_user", ("AUTH",)), + ("SHOW USERS", ("AUTH",)), + ("SET ROLE FOR test_user TO test_role", ("AUTH",)), + ("CLEAR ROLE FOR test_user", ("AUTH",)), + ("GRANT ALL PRIVILEGES TO test_user", ("AUTH",)), + ("DENY ALL PRIVILEGES TO test_user", ("AUTH",)), + ("REVOKE ALL PRIVILEGES FROM test_user", ("AUTH",)), + ("SHOW PRIVILEGES FOR test_user", ("AUTH",)), + ("SHOW ROLE FOR test_user", ("AUTH",)), + ("SHOW USERS FOR test_role", ("AUTH",)), ] -UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " \ - "contact your database administrator." +UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\." def wait_for_server(port, delay=0.1): @@ -166,8 +86,16 @@ def wait_for_server(port, delay=0.1): time.sleep(delay) -def execute_tester(binary, queries, should_fail=False, failure_message="", - username="", password="", check_failure=True): +def execute_tester( + binary, + queries, + should_fail=False, + failure_message="", + username="", + password="", + check_failure=True, + connection_should_fail=False, +): args = [binary, "--username", username, "--password", password] if should_fail: args.append("--should-fail") @@ -175,6 +103,8 @@ def execute_tester(binary, queries, should_fail=False, failure_message="", args.extend(["--failure-message", failure_message]) if check_failure: args.append("--check-failure") + if connection_should_fail: + args.append("--connection-should-fail") args.extend(queries) subprocess.run(args).check_returncode() @@ -200,18 +130,31 @@ def check_permissions(query_perms, user_perms): def execute_test(memgraph_binary, tester_binary, checker_binary): storage_directory = tempfile.TemporaryDirectory() - memgraph_args = [memgraph_binary, - "--data-directory", storage_directory.name] + memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name] def execute_admin_queries(queries): - return execute_tester(tester_binary, queries, should_fail=False, - check_failure=True, username="admin", - password="admin") + return execute_tester( + tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin" + ) - def execute_user_queries(queries, should_fail=False, failure_message="", - check_failure=True): - return execute_tester(tester_binary, queries, should_fail, - failure_message, "user", "user", check_failure) + def execute_user_queries( + queries, + should_fail=False, + failure_message="", + check_failure=True, + username="user", + connection_should_fail=False, + ): + return execute_tester( + tester_binary, + queries, + should_fail, + failure_message, + username, + "user", + check_failure, + connection_should_fail, + ) # Start the memgraph binary memgraph = subprocess.Popen(list(map(str, memgraph_args))) @@ -226,12 +169,33 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): memgraph.terminate() assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + # Prepare the multi database environment + execute_admin_queries( + [ + "CREATE DATABASE db1", + "CREATE DATABASE db2", + ] + ) + # Prepare all users - execute_admin_queries([ - "CREATE USER ADmin IDENTIFIED BY 'admin'", - "GRANT ALL PRIVILEGES TO admIN", - "CREATE USER usEr IDENTIFIED BY 'user'", - ]) + execute_admin_queries( + [ + "CREATE USER ADmin IDENTIFIED BY 'admin'", + "GRANT ALL PRIVILEGES TO admIN", + "GRANT DATABASE * TO admin", + "CREATE USER usEr IDENTIFIED BY 'user'", + "GRANT DATABASE db1 TO user", + "GRANT DATABASE db2 TO user", + "CREATE USER useR2 IDENTIFIED BY 'user'", + "GRANT DATABASE db2 TO user2", + "REVOKE DATABASE memgraph FROM user2", + "SET MAIN DATABASE db2 FOR user2", + "CREATE USER user3 IDENTIFIED BY 'user'", + "GRANT ALL PRIVILEGES TO user3", + "GRANT DATABASE * TO user3", + "REVOKE DATABASE memgraph FROM user3", + ] + ) # Find all existing permissions permissions = set() @@ -241,14 +205,99 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): # Run the test with all combinations of permissions print("\033[1;36m~~ Starting query test ~~\033[0m") + for db in ["memgraph", "db1"]: + print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db)) + execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR) + execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"]) + execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR) + for mask in range(0, 2 ** len(permissions)): + user_perms = get_permissions(permissions, mask) + print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m") + admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"] + if len(user_perms) > 0: + admin_queries.append("GRANT {} TO User".format(", ".join(user_perms))) + execute_admin_queries(admin_queries) + authorized, unauthorized = [], [] + for query, query_perms in QUERIES: + if check_permissions(query_perms, user_perms): + authorized.append(query) + else: + unauthorized.append(query) + execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR) + execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR) + print("\033[1;36m~~ Finished query test ~~\033[0m\n") + + # Run the user/role permissions test + print("\033[1;36m~~ Starting permissions test ~~\033[0m") + execute_admin_queries( + [ + "CREATE ROLE roLe", + "REVOKE ALL PRIVILEGES FROM uSeR", + ] + ) + execute_checker(checker_binary, []) + for db in ["memgraph", "db1"]: + print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db)) + execute_user_queries(["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR) + execute_admin_queries(["GRANT MULTI_DATABASE_USE TO User"]) + execute_user_queries(["USE DATABASE {}".format(db)], check_failure=False, failure_message=UNAUTHORIZED_ERROR) + execute_admin_queries(["REVOKE MULTI_DATABASE_USE FROM User"]) + for user_perm in ["GRANT", "DENY", "REVOKE"]: + for role_perm in ["GRANT", "DENY", "REVOKE"]: + for mapped in [True, False]: + print( + "\033[1;34m~~ Checking permissions with user ", + user_perm, + ", role ", + role_perm, + "user mapped to role:", + mapped, + " ~~\033[0m", + ) + if mapped: + execute_admin_queries(["SET ROLE FOR USER TO roLE"]) + else: + execute_admin_queries(["CLEAR ROLE FOR user"]) + user_prep = "FROM" if user_perm == "REVOKE" else "TO" + role_prep = "FROM" if role_perm == "REVOKE" else "TO" + execute_admin_queries( + [ + "{} MATCH {} user".format(user_perm, user_prep), + "{} MATCH {} rOLe".format(role_perm, role_prep), + ] + ) + expected = [] + perms = [user_perm, role_perm] if mapped else [user_perm] + if "DENY" in perms: + expected = ["MATCH", "DENY"] + elif "GRANT" in perms: + expected = ["MATCH", "GRANT"] + if len(expected) > 0: + details = [] + if user_perm == "GRANT": + details.append("GRANTED TO USER") + elif user_perm == "DENY": + details.append("DENIED TO USER") + if mapped: + if role_perm == "GRANT": + details.append("GRANTED TO ROLE") + elif role_perm == "DENY": + details.append("DENIED TO ROLE") + expected.append(", ".join(details)) + execute_checker(checker_binary, expected) + print("\033[1;36m~~ Finished permissions test ~~\033[0m\n") + + # Check database access + # user has access to every db (with global privileges) <- tested above + # user2 has access only to db2 (and it set to default) + # user3 has access only to db2, but the default db is set to default (shouldn't even connect) + print("\033[1;36m~~ Checking privileges with custom default db ~~\033[0m\n") for mask in range(0, 2 ** len(permissions)): user_perms = get_permissions(permissions, mask) - print("\033[1;34m~~ Checking queries with privileges: ", - ", ".join(user_perms), " ~~\033[0m") - admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer"] + print("\033[1;34m~~ Checking queries with privileges: ", ", ".join(user_perms), " ~~\033[0m") + admin_queries = ["REVOKE ALL PRIVILEGES FROM uSer2"] if len(user_perms) > 0: - admin_queries.append( - "GRANT {} TO User".format(", ".join(user_perms))) + admin_queries.append("GRANT {} TO User2".format(", ".join(user_perms))) execute_admin_queries(admin_queries) authorized, unauthorized = [], [] for query, query_perms in QUERIES: @@ -256,55 +305,26 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): authorized.append(query) else: unauthorized.append(query) - execute_user_queries(authorized, check_failure=False, - failure_message=UNAUTHORIZED_ERROR) - execute_user_queries(unauthorized, should_fail=True, - failure_message=UNAUTHORIZED_ERROR) - print("\033[1;36m~~ Finished query test ~~\033[0m\n") + execute_user_queries(authorized, check_failure=False, failure_message=UNAUTHORIZED_ERROR, username="user2") + execute_user_queries(unauthorized, should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2") + print("\033[1;36m~~ Finished custom default db checks ~~\033[0m\n") - # Run the user/role permissions test - print("\033[1;36m~~ Starting permissions test ~~\033[0m") - execute_admin_queries([ - "CREATE ROLE roLe", - "REVOKE ALL PRIVILEGES FROM uSeR", - ]) - execute_checker(checker_binary, []) - for user_perm in ["GRANT", "DENY", "REVOKE"]: - for role_perm in ["GRANT", "DENY", "REVOKE"]: - for mapped in [True, False]: - print("\033[1;34m~~ Checking permissions with user ", - user_perm, ", role ", role_perm, - "user mapped to role:", mapped, " ~~\033[0m") - if mapped: - execute_admin_queries(["SET ROLE FOR USER TO roLE"]) - else: - execute_admin_queries(["CLEAR ROLE FOR user"]) - user_prep = "FROM" if user_perm == "REVOKE" else "TO" - role_prep = "FROM" if role_perm == "REVOKE" else "TO" - execute_admin_queries([ - "{} MATCH {} user".format(user_perm, user_prep), - "{} MATCH {} rOLe".format(role_perm, role_prep) - ]) - expected = [] - perms = [user_perm, role_perm] if mapped else [user_perm] - if "DENY" in perms: - expected = ["MATCH", "DENY"] - elif "GRANT" in perms: - expected = ["MATCH", "GRANT"] - if len(expected) > 0: - details = [] - if user_perm == "GRANT": - details.append("GRANTED TO USER") - elif user_perm == "DENY": - details.append("DENIED TO USER") - if mapped: - if role_perm == "GRANT": - details.append("GRANTED TO ROLE") - elif role_perm == "DENY": - details.append("DENIED TO ROLE") - expected.append(", ".join(details)) - execute_checker(checker_binary, expected) - print("\033[1;36m~~ Finished permissions test ~~\033[0m\n") + print("\033[1;36m~~ Checking connections and database switching ~~\033[0m\n") + for db in ["memgraph", "db1"]: + print("\033[1;36m~~ Running against db {} ~~\033[0m".format(db)) + execute_admin_queries(["GRANT {} TO User2".format("MULTI_DATABASE_USE")]) + execute_user_queries( + ["USE DATABASE {}".format(db)], should_fail=True, failure_message=UNAUTHORIZED_ERROR, username="user2" + ) + print("\033[1;36m~~ Running with user3 (shouldn't even connect) ~~\033[0m") + execute_admin_queries(["GRANT {} TO User3".format("MULTI_DATABASE_USE")]) + execute_user_queries( + ["USE DATABASE db2"], + connection_should_fail=True, + failure_message="Couldn't communicate with the server!", + username="user3", + ) + print("\033[1;36m~~ Finished checking connections and database switching ~~\033[0m\n") # Shutdown the memgraph binary memgraph.terminate() @@ -313,10 +333,8 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): if __name__ == "__main__": memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph") - tester_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "auth", "tester") - checker_binary = os.path.join(PROJECT_DIR, "build", "tests", - "integration", "auth", "checker") + tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "tester") + checker_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "auth", "checker") parser = argparse.ArgumentParser() parser.add_argument("--memgraph", default=memgraph_binary) diff --git a/tests/integration/auth/tester.cpp b/tests/integration/auth/tester.cpp index 3ef7392a9..0bd3295ab 100644 --- a/tests/integration/auth/tester.cpp +++ b/tests/integration/auth/tester.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,6 +9,8 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include + #include #include "communication/bolt/client.hpp" @@ -23,6 +25,7 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); DEFINE_bool(check_failure, false, "Set to true to enable failure checking."); DEFINE_bool(should_fail, false, "Set to true to expect a failure."); +DEFINE_bool(connection_should_fail, false, "Set to true to expect a connection failure."); DEFINE_string(failure_message, "", "Set to the expected failure message."); /** @@ -40,7 +43,26 @@ int main(int argc, char **argv) { memgraph::communication::ClientContext context(FLAGS_use_ssl); memgraph::communication::bolt::Client client(context); - client.Connect(endpoint, FLAGS_username, FLAGS_password); + std::regex re(FLAGS_failure_message); + + try { + client.Connect(endpoint, FLAGS_username, FLAGS_password); + } catch (const memgraph::communication::bolt::ClientFatalException &e) { + if (FLAGS_connection_should_fail) { + if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) { + LOG_FATAL( + "The connection should have failed with an error message of '{}'' but " + "instead it failed with '{}'", + FLAGS_failure_message, e.what()); + } + return 0; + } else { + LOG_FATAL( + "The connection shoudn't have failed but it failed with an " + "error message '{}'", + e.what()); + } + } for (int i = 1; i < argc; ++i) { std::string query(argv[i]); @@ -48,7 +70,7 @@ int main(int argc, char **argv) { client.Execute(query, {}); } catch (const memgraph::communication::bolt::ClientQueryException &e) { if (!FLAGS_check_failure) { - if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) { + if (!FLAGS_failure_message.empty() && std::regex_match(e.what(), re)) { LOG_FATAL( "The query should have succeeded or failed with an error " "message that isn't equal to '{}' but it failed with that error " @@ -58,7 +80,7 @@ int main(int argc, char **argv) { continue; } if (FLAGS_should_fail) { - if (!FLAGS_failure_message.empty() && e.what() != FLAGS_failure_message) { + if (!FLAGS_failure_message.empty() && !std::regex_match(e.what(), re)) { LOG_FATAL( "The query should have failed with an error message of '{}'' but " "instead it failed with '{}'", diff --git a/tests/integration/fine_grained_access/filtering.cpp b/tests/integration/fine_grained_access/filtering.cpp index cf44333f3..d8e20601e 100644 --- a/tests/integration/fine_grained_access/filtering.cpp +++ b/tests/integration/fine_grained_access/filtering.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,6 +22,7 @@ DEFINE_int32(port, 7687, "Server port"); DEFINE_string(username, "admin", "Username for the database"); DEFINE_string(password, "admin", "Password for the database"); DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); +DEFINE_string(use_db, "memgraph", "Database to run the query against"); /** * Verifies that user 'user' has privileges that are given as positional @@ -38,6 +39,7 @@ int main(int argc, char **argv) { memgraph::communication::bolt::Client client(context); client.Connect(endpoint, FLAGS_username, FLAGS_password); + client.Execute(fmt::format("USE DATABASE {}", FLAGS_use_db), {}); try { std::string query(argv[1]); diff --git a/tests/integration/fine_grained_access/runner.py b/tests/integration/fine_grained_access/runner.py index 574139fe7..6f284aa39 100644 --- a/tests/integration/fine_grained_access/runner.py +++ b/tests/integration/fine_grained_access/runner.py @@ -23,7 +23,7 @@ from typing import List SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", "..")) -UNAUTHORIZED_ERROR = "You are not authorized to execute this query! Please " "contact your database administrator." +UNAUTHORIZED_ERROR = r"^You are not authorized to execute this query.*?Please contact your database administrator\." def wait_for_server(port, delay=0.1): @@ -47,8 +47,10 @@ def execute_tester( subprocess.run(args).check_returncode() -def execute_filtering(binary: str, queries: List[str], expected: int, username: str = "", password: str = "") -> None: - args = [binary, "--username", username, "--password", password] +def execute_filtering( + binary: str, queries: List[str], expected: int, username: str = "", password: str = "", db: str = "memgraph" +) -> None: + args = [binary, "--username", username, "--password", password, "--use-db", db] args.extend(queries) args.append(str(expected)) @@ -82,35 +84,48 @@ def execute_test(memgraph_binary: str, tester_binary: str, filtering_binary: str assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" # Prepare all users - execute_admin_queries( - [ - "CREATE USER admin IDENTIFIED BY 'admin'", - "GRANT ALL PRIVILEGES TO admin", - "CREATE USER user IDENTIFIED BY 'user'", - "GRANT ALL PRIVILEGES TO user", - "GRANT LABELS :label1, :label2, :label3 TO user", - "GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user", - "MERGE (l1:label1 {name: 'test1'})", - "MERGE (l2:label2 {name: 'test2'})", - "MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)", - "MERGE (l3:label3 {name: 'test3'})", - "MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)", - "MERGE (mix:label3:label1 {name: 'test4'})", - "MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)", - ] - ) + def setup_user(): + execute_admin_queries( + [ + "CREATE USER admin IDENTIFIED BY 'admin'", + "GRANT ALL PRIVILEGES TO admin", + "CREATE USER user IDENTIFIED BY 'user'", + "GRANT ALL PRIVILEGES TO user", + "GRANT LABELS :label1, :label2, :label3 TO user", + "GRANT EDGE_TYPES :edgeType1, :edgeType2 TO user", + ] + ) + + def db_setup(): + execute_admin_queries( + [ + "MERGE (l1:label1 {name: 'test1'})", + "MERGE (l2:label2 {name: 'test2'})", + "MATCH (l1:label1),(l2:label2) WHERE l1.name = 'test1' AND l2.name = 'test2' CREATE (l1)-[r:edgeType1]->(l2)", + "MERGE (l3:label3 {name: 'test3'})", + "MATCH (l1:label1),(l3:label3) WHERE l1.name = 'test1' AND l3.name = 'test3' CREATE (l1)-[r:edgeType2]->(l3)", + "MERGE (mix:label3:label1 {name: 'test4'})", + "MATCH (l1:label1),(mix:label3) WHERE l1.name = 'test1' AND mix.name = 'test4' CREATE (l1)-[r:edgeType2]->(mix)", + ] + ) + + db_setup() # default db setup + execute_admin_queries(["CREATE DATABASE db1", "USE DATABASE db1"]) + db_setup() # db1 setup - # Run the test with all combinations of permissions print("\033[1;36m~~ Starting edge filtering test ~~\033[0m") - execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user") - execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"]) - execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user") - execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"]) - execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user") - execute_admin_queries(["DENY LABELS :label1 TO user"]) - execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user") - execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"]) - execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user") + for db in ["memgraph", "db1"]: + setup_user() + # Run the test with all combinations of permissions + execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 3, "user", "user", db) + execute_admin_queries(["DENY EDGE_TYPES :edgeType1 TO user"]) + execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 2, "user", "user", db) + execute_admin_queries(["GRANT EDGE_TYPES :edgeType1 TO user", "DENY LABELS :label3 TO user"]) + execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 1, "user", "user", db) + execute_admin_queries(["DENY LABELS :label1 TO user"]) + execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db) + execute_admin_queries(["REVOKE LABELS * FROM user", "REVOKE EDGE_TYPES * FROM user"]) + execute_filtering(filtering_binary, ["MATCH (n)-[r]->(m) RETURN n,r,m"], 0, "user", "user", db) print("\033[1;36m~~ Finished edge filtering test ~~\033[0m\n") diff --git a/tests/integration/transactions/runner.sh b/tests/integration/transactions/runner.sh index e6b7e32cc..372e33f16 100755 --- a/tests/integration/transactions/runner.sh +++ b/tests/integration/transactions/runner.sh @@ -14,10 +14,14 @@ while ! nc -z -w 1 127.0.0.1 7687; do sleep 0.5 done -# Start the test. +# Start the test on default db. $binary_dir/tests/integration/transactions/tester code=$? +# Start the test on another db. +$binary_dir/tests/integration/transactions/tester --use-db db1 +code2=$? + # Shutdown the memgraph process. kill $pid wait $pid @@ -30,4 +34,12 @@ if [ $code_mg -ne 0 ]; then fi # Exit with the exitcode of the test. -exit $code +if [ $code -ne 0 ]; then + echo "Default database tests failed!" + exit $code +fi + +if [ $code2 -ne 0 ]; then + echo "Non default database tests failed!" + exit $code2 +fi diff --git a/tests/integration/transactions/tester.cpp b/tests/integration/transactions/tester.cpp index 407b11892..216596e59 100644 --- a/tests/integration/transactions/tester.cpp +++ b/tests/integration/transactions/tester.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,6 +11,7 @@ #include +#include #include #include @@ -24,12 +25,16 @@ DEFINE_int32(port, 7687, "Server port"); DEFINE_string(username, "", "Username for the database"); DEFINE_string(password, "", "Password for the database"); DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); +DEFINE_string(use_db, "memgraph", "Database to run the query against"); using namespace memgraph::communication::bolt; class BoltClient : public ::testing::Test { protected: - virtual void SetUp() { client_.Connect(endpoint_, FLAGS_username, FLAGS_password); } + virtual void SetUp() { + client_.Connect(endpoint_, FLAGS_username, FLAGS_password); + Execute("CREATE DATABASE db1"); + } virtual void TearDown() {} @@ -90,6 +95,15 @@ const std::string kCommitInvalid = "Transaction can't be committed because there was a previous error. Please " "invoke a rollback instead."; +TEST_F(BoltClient, SelectDB) { Execute(fmt::format("USE DATABASE {}", FLAGS_use_db)); } + +TEST_F(BoltClient, SelectDBUnderTx) { + EXPECT_TRUE(Execute("begin")); + EXPECT_THROW(Execute("USE DATABASE memgraph", "Multi-database queries are not allowed in multicommand transactions."), + ClientQueryException); + EXPECT_FALSE(TransactionActive()); +} + TEST_F(BoltClient, CommitWithoutTransaction) { EXPECT_THROW(Execute("commit", kNoCurrentTransactionToCommit), ClientQueryException); EXPECT_FALSE(TransactionActive()); diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index cb0bf6f80..d630d1171 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -36,7 +36,7 @@ int main(int argc, char *argv[]) { memgraph::query::Interpreter interpreter{&interpreter_context}; ResultStreamFaker stream(interpreter_context.db.get()); - auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr); + auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, nullptr); stream.Header(header); auto summary = interpreter.PullAll(&stream); stream.Summary(summary); diff --git a/tests/setup.sh b/tests/setup.sh index c4b2ceb10..bc9d614c2 100755 --- a/tests/setup.sh +++ b/tests/setup.sh @@ -13,7 +13,7 @@ PIP_DEPS=( "neo4j-driver==4.1.1" "parse==1.18.0" "parse-type==0.5.2" - "pytest==6.2.3" + "pytest==7.3.2" "pyyaml==5.4.1" "six==1.15.0" "networkx==2.4" diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index da6987260..34a094c12 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -254,6 +254,9 @@ target_link_libraries(${test_prefix}utils_signals mg-utils) add_unit_test(utils_string.cpp) target_link_libraries(${test_prefix}utils_string mg-utils) +add_unit_test(utils_sync_ptr.cpp) +target_link_libraries(${test_prefix}utils_sync_ptr) + add_unit_test(utils_synchronized.cpp) target_link_libraries(${test_prefix}utils_synchronized mg-utils) @@ -392,3 +395,20 @@ find_package(Boost REQUIRED) add_unit_test(monitoring.cpp) target_link_libraries(${test_prefix}monitoring mg-communication Boost::headers) + +# Test multi-database +if(MG_ENTERPRISE) + # add_unit_test(dbms_storage.cpp) + # target_link_libraries(${test_prefix}dbms_storage mg-storage-v2 mg-query mg-glue) + + add_unit_test(dbms_interp.cpp) + target_link_libraries(${test_prefix}dbms_interp mg-query) +endif() + +# add_unit_test(dbms_auth.cpp) +# target_link_libraries(${test_prefix}dbms_auth mg-glue) + +if(MG_ENTERPRISE) + add_unit_test_with_custom_main(dbms_sc_handler.cpp) + target_link_libraries(${test_prefix}dbms_sc_handler mg-query mg-audit mg-glue) +endif() diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 982018b33..eaf7c8d93 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -20,7 +20,9 @@ #include "query/exceptions.hpp" #include "utils/logging.hpp" +using memgraph::communication::bolt::ChunkedEncoderBuffer; using memgraph::communication::bolt::ClientError; +using memgraph::communication::bolt::Encoder; using memgraph::communication::bolt::Session; using memgraph::communication::bolt::SessionException; using memgraph::communication::bolt::State; @@ -32,15 +34,14 @@ static const char *kQueryReturnMultiple = "UNWIND [1,2,3] as n RETURN n"; static const char *kQueryShowTx = "SHOW TRANSACTIONS"; static const char *kQueryEmpty = "no results"; -class TestSessionData {}; +class TestSessionContext {}; -class TestSession : public Session { +class TestSession final : public Session { public: - using Session::TEncoder; + using TEncoder = Encoder>; - TestSession(TestSessionData *data, TestInputStream *input_stream, TestOutputStream *output_stream) + TestSession(TestSessionContext *data, TestInputStream *input_stream, TestOutputStream *output_stream) : Session(input_stream, output_stream) {} - std::pair, std::optional> Interpret( const std::string &query, const std::map ¶ms, const std::map &extra) override { @@ -96,7 +97,9 @@ class TestSession : public Session { } } - std::map Discard(std::optional, std::optional) override { return {}; } + std::map Discard(std::optional /*unused*/, std::optional /*unused*/) override { + return {}; + } void BeginTransaction(const std::map &extra) override { if (extra.contains("tx_metadata")) { @@ -109,10 +112,20 @@ class TestSession : public Session { void Abort() override { md_.clear(); } - bool Authenticate(const std::string &username, const std::string &password) override { return true; } + bool Authenticate(const std::string & /*username*/, const std::string & /*password*/) override { return true; } std::optional GetServerNameForInit() override { return std::nullopt; } + void Configure(const std::map &) override {} + std::string GetDatabaseName() const override { return ""; } + +#ifdef MG_ENTERPRISE + memgraph::dbms::SetForResult OnChange(const std::string &db_name) override { + return memgraph::dbms::SetForResult::SUCCESS; + } + bool OnDelete(const std::string &) override { return true; } +#endif + void TestHook_ShouldAbort() { should_abort_ = true; } private: @@ -123,11 +136,11 @@ class TestSession : public Session { // TODO: This could be done in fixture. // Shortcuts for writing variable initializations in tests -#define INIT_VARS \ - TestInputStream input_stream; \ - TestOutputStream output_stream; \ - TestSessionData session_data; \ - TestSession session(&session_data, &input_stream, &output_stream); \ +#define INIT_VARS \ + TestInputStream input_stream; \ + TestOutputStream output_stream; \ + TestSessionContext session_context; \ + TestSession session(&session_context, &input_stream, &output_stream); \ std::vector &output = output_stream.output; // Sample testdata that has correct inputs and outputs. diff --git a/tests/unit/dbms_interp.cpp b/tests/unit/dbms_interp.cpp new file mode 100644 index 000000000..a71bcb87c --- /dev/null +++ b/tests/unit/dbms_interp.cpp @@ -0,0 +1,429 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#ifdef MG_ENTERPRISE + +#include +#include +#include + +#include "dbms/global.hpp" +#include "dbms/interp_handler.hpp" + +#include "query/auth_checker.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/interpreter.hpp" + +class TestAuthHandler : public memgraph::query::AuthQueryHandler { + public: + TestAuthHandler() = default; + + bool CreateUser(const std::string & /*username*/, const std::optional & /*password*/) override { + return true; + } + bool DropUser(const std::string & /*username*/) override { return true; } + void SetPassword(const std::string & /*username*/, const std::optional & /*password*/) override {} + bool RevokeDatabaseFromUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; } + bool GrantDatabaseToUser(const std::string & /*db*/, const std::string & /*username*/) override { return true; } + bool SetMainDatabase(const std::string & /*db*/, const std::string & /*username*/) override { return true; } + std::vector> GetDatabasePrivileges(const std::string & /*user*/) override { + return {}; + } + bool CreateRole(const std::string & /*rolename*/) override { return true; } + bool DropRole(const std::string & /*rolename*/) override { return true; } + std::vector GetUsernames() override { return {}; } + std::vector GetRolenames() override { return {}; } + std::optional GetRolenameForUser(const std::string & /*username*/) override { return {}; } + std::vector GetUsernamesForRole(const std::string & /*rolename*/) override { return {}; } + void SetRole(const std::string &username, const std::string & /*rolename*/) override {} + void ClearRole(const std::string &username) override {} + std::vector> GetPrivileges(const std::string & /*user_or_role*/) override { + return {}; + } + void GrantPrivilege( + const std::string & /*user_or_role*/, const std::vector & /*privileges*/, + const std::vector>> + & /*label_privileges*/, + const std::vector>> + & /*edge_type_privileges*/) override {} + void DenyPrivilege(const std::string & /*user_or_role*/, + const std::vector & /*privileges*/) override {} + void RevokePrivilege( + const std::string & /*user_or_role*/, const std::vector & /*privileges*/, + const std::vector>> + & /*label_privileges*/, + const std::vector>> + & /*edge_type_privileges*/) override {} +}; + +class TestAuthChecker : public memgraph::query::AuthChecker { + public: + bool IsUserAuthorized(const std::optional & /*username*/, + const std::vector & /*privileges*/, + const std::string & /*db*/) const override { + return true; + } + + std::unique_ptr GetFineGrainedAuthChecker( + const std::string & /*username*/, const memgraph::query::DbAccessor * /*db_accessor*/) const override { + return {}; + } +}; + +std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_interp"}; + +memgraph::query::InterpreterConfig default_conf{}; + +memgraph::storage::Config default_storage_conf(std::string name = "") { + return {.durability = {.storage_directory = storage_directory / name, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / name / "disk"}}; +} + +class TestHandler { +} test_handler; + +class DBMS_Interp : public ::testing::Test { + protected: + void SetUp() override { Clear(); } + + void TearDown() override { Clear(); } + + private: + void Clear() { + if (std::filesystem::exists(storage_directory)) { + std::filesystem::remove_all(storage_directory); + } + } +}; + +TEST_F(DBMS_Interp, New) { + memgraph::dbms::InterpContextHandler ih; + memgraph::storage::Config db_conf{ + .durability = {.storage_directory = storage_directory, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + TestAuthHandler ah; + TestAuthChecker ac; + + { + // Clean initialization + auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); + ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "triggers")); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "streams")); + ASSERT_NE(ic1.GetValue()->db, nullptr); + ASSERT_EQ(&ic1.GetValue()->sc_handler_, &test_handler); + ASSERT_EQ(ih.GetConfig("ic1")->storage_config.durability.storage_directory, storage_directory); + } + { + memgraph::storage::Config db_conf2{ + .durability = {.storage_directory = storage_directory, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + // Try to override data directory + auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); + ASSERT_TRUE(ic2.HasError() && ic2.GetError() == memgraph::dbms::NewError::EXISTS); + } + { + memgraph::storage::Config db_conf3{ + .durability = {.storage_directory = storage_directory / "ic3", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + // Try to override the name "ic1" + auto ic3 = ih.New("ic1", test_handler, db_conf3, default_conf, ah, ac); + ASSERT_TRUE(ic3.HasError() && ic3.GetError() == memgraph::dbms::NewError::EXISTS); + } + { + // Another clean initialization + memgraph::storage::Config db_conf4{ + .durability = {.storage_directory = storage_directory / "ic4", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + auto ic4 = ih.New("ic4", test_handler, db_conf4, default_conf, ah, ac); + ASSERT_TRUE(ic4.HasValue() && ic4.GetValue() != nullptr); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "triggers")); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "ic4" / "streams")); + ASSERT_EQ(&ic4.GetValue()->sc_handler_, &test_handler); + ASSERT_EQ(ih.GetConfig("ic4")->storage_config.durability.storage_directory, storage_directory / "ic4"); + } +} + +TEST_F(DBMS_Interp, Get) { + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + memgraph::storage::Config db_conf{ + .durability = {.storage_directory = storage_directory / "ic1", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); + ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); + + auto ic1_get = ih.Get("ic1"); + ASSERT_TRUE(ic1_get && *ic1_get == ic1.GetValue()); + + memgraph::storage::Config db_conf2{ + .durability = {.storage_directory = storage_directory / "ic2", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); + ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr); + + auto ic2_get = ih.Get("ic2"); + ASSERT_TRUE(ic2_get && *ic2_get == ic2.GetValue()); + + ASSERT_FALSE(ih.Get("aa")); + ASSERT_FALSE(ih.Get("ic1 ")); + ASSERT_FALSE(ih.Get("ic21")); + ASSERT_FALSE(ih.Get(" ic2")); +} + +TEST_F(DBMS_Interp, Delete) { + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + memgraph::storage::Config db_conf{ + .durability = {.storage_directory = storage_directory / "ic1", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + { + auto ic1 = ih.New("ic1", test_handler, db_conf, default_conf, ah, ac); + ASSERT_TRUE(ic1.HasValue() && ic1.GetValue() != nullptr); + } + + memgraph::storage::Config db_conf2{ + .durability = {.storage_directory = storage_directory / "ic2", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + { + auto ic2 = ih.New("ic2", test_handler, db_conf2, default_conf, ah, ac); + ASSERT_TRUE(ic2.HasValue() && ic2.GetValue() != nullptr); + } + + ASSERT_TRUE(ih.Delete("ic1")); + ASSERT_FALSE(ih.Get("ic1")); + ASSERT_FALSE(ih.Delete("ic1")); + ASSERT_FALSE(ih.Delete("ic3")); +} + +/** + * + * + * + * + * + * Test storage (previous StorageHandler, now handled via InterpretContext) + * + * + * + * + * + */ +TEST_F(DBMS_Interp, StorageNew) { + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + { ASSERT_FALSE(ih.GetConfig("db1")); } + { + // With custom config + memgraph::storage::Config db_config{ + .durability = {.storage_directory = storage_directory / "db2", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "disk"}}; + auto db2 = ih.New("db2", test_handler, db_config, default_conf, ah, ac); + ASSERT_TRUE(db2.HasValue() && db2.GetValue() != nullptr); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "db2")); + } + { + // With default config + auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); + ASSERT_TRUE(db3.HasValue() && db3.GetValue() != nullptr); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "db3")); + auto db4 = ih.New("db4", test_handler, default_storage_conf("four"), default_conf, ah, ac); + ASSERT_TRUE(db4.HasValue() && db4.GetValue() != nullptr); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "four")); + auto db5 = ih.New("db5", test_handler, default_storage_conf("db3"), default_conf, ah, ac); + ASSERT_TRUE(db5.HasError() && db5.GetError() == memgraph::dbms::NewError::EXISTS); + } + + auto all = ih.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 3); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + ASSERT_EQ(all[2], "db4"); +} + +TEST_F(DBMS_Interp, StorageGet) { + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); + auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); + auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + auto get_db1 = ih.Get("db1"); + auto get_db2 = ih.Get("db2"); + auto get_db3 = ih.Get("db3"); + + ASSERT_TRUE(get_db1 && *get_db1 == db1.GetValue()); + ASSERT_TRUE(get_db2 && *get_db2 == db2.GetValue()); + ASSERT_TRUE(get_db3 && *get_db3 == db3.GetValue()); + + ASSERT_FALSE(ih.Get("db123")); + ASSERT_FALSE(ih.Get("db2 ")); + ASSERT_FALSE(ih.Get(" db3")); +} + +TEST_F(DBMS_Interp, StorageDelete) { + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); + auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); + auto db3 = ih.New("db3", test_handler, default_storage_conf("db3"), default_conf, ah, ac); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + { + // Release pointer to storage + db1.GetValue().reset(); + // Delete from handler + ASSERT_TRUE(ih.Delete("db1")); + ASSERT_FALSE(ih.Get("db1")); + auto all = ih.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 2); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + } + + { + ASSERT_FALSE(ih.Delete("db0")); + ASSERT_FALSE(ih.Delete("db1")); + auto all = ih.All(); + std::sort(all.begin(), all.end()); + ASSERT_EQ(all.size(), 2); + ASSERT_EQ(all[0], "db2"); + ASSERT_EQ(all[1], "db3"); + } +} + +TEST_F(DBMS_Interp, StorageDeleteAndRecover) { + // memgraph::license::global_license_checker.EnableTesting(); + memgraph::dbms::InterpContextHandler ih; + TestAuthHandler ah; + TestAuthChecker ac; + + { + auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); + auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); + + memgraph::storage::Config conf_w_snap{ + .durability = {.storage_directory = storage_directory / "db3", + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL, + .snapshot_on_exit = true}, + .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; + + auto db3 = ih.New("db3", test_handler, conf_w_snap, default_conf, ah, ac); + + ASSERT_TRUE(db1.HasValue()); + ASSERT_TRUE(db2.HasValue()); + ASSERT_TRUE(db3.HasValue()); + + // Add data to graphs + { + auto storage_dba = db1.GetValue()->db->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + memgraph::query::VertexAccessor v1{dba.InsertVertex()}; + memgraph::query::VertexAccessor v2{dba.InsertVertex()}; + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l11")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l12")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + { + auto storage_dba = db3.GetValue()->db->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + memgraph::query::VertexAccessor v1{dba.InsertVertex()}; + memgraph::query::VertexAccessor v2{dba.InsertVertex()}; + memgraph::query::VertexAccessor v3{dba.InsertVertex()}; + ASSERT_TRUE(v1.AddLabel(dba.NameToLabel("l31")).HasValue()); + ASSERT_TRUE(v2.AddLabel(dba.NameToLabel("l32")).HasValue()); + ASSERT_TRUE(v3.AddLabel(dba.NameToLabel("l33")).HasValue()); + ASSERT_FALSE(dba.Commit().HasError()); + } + } + + // Delete from handler + ASSERT_TRUE(ih.Delete("db1")); + ASSERT_TRUE(ih.Delete("db2")); + ASSERT_TRUE(ih.Delete("db3")); + + { + // Recover graphs (only db3) + auto db1 = ih.New("db1", test_handler, default_storage_conf("db1"), default_conf, ah, ac); + auto db2 = ih.New("db2", test_handler, default_storage_conf("db2"), default_conf, ah, ac); + + memgraph::storage::Config conf_w_rec{ + .durability = {.storage_directory = storage_directory / "db3", + .recover_on_startup = true, + .snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL}, + .disk = {.main_storage_directory = storage_directory / "db3" / "disk"}}; + + auto db3 = ih.New("db3", test_handler, conf_w_rec, default_conf, ah, ac); + + // Check content + { + // Empty + auto storage_dba = db1.GetValue()->db->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 0); + } + { + // Empty + auto storage_dba = db2.GetValue()->db->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 0); + } + { + // Full + auto storage_dba = db3.GetValue()->db->Access(); + memgraph::query::DbAccessor dba{storage_dba.get()}; + ASSERT_EQ(dba.VerticesCount(), 3); + } + } +} + +#endif diff --git a/tests/unit/dbms_sc_handler.cpp b/tests/unit/dbms_sc_handler.cpp new file mode 100644 index 000000000..a19d769cb --- /dev/null +++ b/tests/unit/dbms_sc_handler.cpp @@ -0,0 +1,343 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include +#include "query/interpreter.hpp" +#ifdef MG_ENTERPRISE + +#include +#include +#include + +#include "dbms/constants.hpp" +#include "dbms/global.hpp" +#include "dbms/session_context_handler.hpp" +#include "glue/auth_checker.hpp" +#include "glue/auth_handler.hpp" +#include "query/config.hpp" + +std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_sc_handler"}; + +static memgraph::storage::Config storage_conf; + +memgraph::query::InterpreterConfig interp_conf; + +// Global +memgraph::audit::Log audit_log{storage_directory / "audit", 100, 1000}; + +class TestInterface : public memgraph::dbms::SessionInterface { + public: + TestInterface(std::string name, auto on_change, auto on_delete) : id_(id++), db_(name) { + on_change_ = on_change; + on_delete_ = on_delete; + } + std::string UUID() const override { return std::to_string(id_); } + std::string GetDatabaseName() const override { return db_; } + memgraph::dbms::SetForResult OnChange(const std::string &name) override { return on_change_(name); } + bool OnDelete(const std::string &name) override { return on_delete_(name); } + + static int id; + int id_; + std::string db_; + std::function on_change_; + std::function on_delete_; +}; + +int TestInterface::id{0}; + +// Let this be global so we can test it different states throughout + +class TestEnvironment : public ::testing::Environment { + public: + static memgraph::dbms::SessionContextHandler *get() { return ptr_.get(); } + + void SetUp() override { + // Setup config + memgraph::storage::UpdatePaths(storage_conf, storage_directory); + storage_conf.durability.snapshot_wal_mode = + memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL; + // Clean storage directory (running multiple parallel test, run only if the first process) + if (std::filesystem::exists(storage_directory)) { + memgraph::utils::OutputFile lock_file_handle_; + lock_file_handle_.Open(storage_directory / ".lock", memgraph::utils::OutputFile::Mode::OVERWRITE_EXISTING); + if (lock_file_handle_.AcquireLock()) { + std::filesystem::remove_all(storage_directory); + } + } + ptr_ = std::make_unique( + audit_log, + memgraph::dbms::SessionContextHandler::Config{ + storage_conf, interp_conf, + [](memgraph::utils::Synchronized *auth, + std::unique_ptr &ah, + std::unique_ptr &ac) { + // Glue high level auth implementations to the query side + ah = std::make_unique(auth, ""); + ac = std::make_unique(auth); + }}, + false, true); + } + + void TearDown() override { ptr_.reset(); } + + static std::unique_ptr ptr_; +}; + +std::unique_ptr TestEnvironment::ptr_ = nullptr; + +class DBMS_Handler : public testing::Test {}; +using DBMS_HandlerDeath = DBMS_Handler; + +TEST(DBMS_Handler, Init) { + // Check that the default db has been created successfully + std::vector dirs = {"snapshots", "streams", "triggers", "wal"}; + for (const auto &dir : dirs) + ASSERT_TRUE(std::filesystem::exists(storage_directory / dir)) << (storage_directory / dir); + const auto db_path = storage_directory / "databases" / memgraph::dbms::kDefaultDB; + ASSERT_TRUE(std::filesystem::exists(db_path)); + for (const auto &dir : dirs) { + std::error_code ec; + const auto test_link = std::filesystem::read_symlink(db_path / dir, ec); + ASSERT_TRUE(!ec) << ec.message(); + ASSERT_EQ(test_link, "../../" + dir); + } +} + +TEST(DBMS_HandlerDeath, InitSameDir) { + // This will be executed in a clean process (so the singleton will NOT be initalized) + (void)(::testing::GTEST_FLAG(death_test_style) = "threadsafe"); + // NOTE: Init test has ran in another process (so holds the lock) + ASSERT_DEATH( + { + memgraph::dbms::SessionContextHandler sch( + audit_log, + {storage_conf, interp_conf, + [](memgraph::utils::Synchronized *auth, + std::unique_ptr &ah, + std::unique_ptr &ac) { + // Glue high level auth implementations to the query side + ah = std::make_unique(auth, ""); + ac = std::make_unique(auth); + }}, + false, true); + }, + R"(\b.*\b)"); +} + +TEST(DBMS_Handler, New) { + auto &sch = *TestEnvironment::get(); + { + const auto all = sch.All(); + ASSERT_EQ(all.size(), 1); + ASSERT_EQ(all[0], memgraph::dbms::kDefaultDB); + } + { + auto sc1 = sch.New("sc1"); + ASSERT_TRUE(sc1.HasValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc1")); + ASSERT_TRUE(sc1.GetValue().interpreter_context->db != nullptr); + ASSERT_TRUE(sc1.GetValue().interpreter_context != nullptr); + ASSERT_TRUE(sc1.GetValue().audit_log != nullptr); + ASSERT_TRUE(sc1.GetValue().auth != nullptr); + const auto all = sch.All(); + ASSERT_EQ(all.size(), 2); + ASSERT_TRUE(std::find(all.begin(), all.end(), memgraph::dbms::kDefaultDB) != all.end()); + ASSERT_TRUE(std::find(all.begin(), all.end(), "sc1") != all.end()); + } + { + // Fail if name exists + auto sc2 = sch.New("sc1"); + ASSERT_TRUE(sc2.HasError() && sc2.GetError() == memgraph::dbms::NewError::EXISTS); + } + { + auto sc3 = sch.New("sc3"); + ASSERT_TRUE(sc3.HasValue()); + ASSERT_TRUE(std::filesystem::exists(storage_directory / "databases" / "sc3")); + ASSERT_TRUE(sc3.GetValue().interpreter_context->db != nullptr); + ASSERT_TRUE(sc3.GetValue().interpreter_context != nullptr); + ASSERT_TRUE(sc3.GetValue().audit_log != nullptr); + ASSERT_TRUE(sc3.GetValue().auth != nullptr); + const auto all = sch.All(); + ASSERT_EQ(all.size(), 3); + ASSERT_TRUE(std::find(all.begin(), all.end(), "sc3") != all.end()); + } +} + +TEST(DBMS_Handler, Get) { + auto &sch = *TestEnvironment::get(); + auto default_sc = sch.Get(memgraph::dbms::kDefaultDB); + ASSERT_TRUE(default_sc.interpreter_context->db != nullptr); + ASSERT_TRUE(default_sc.interpreter_context != nullptr); + ASSERT_TRUE(default_sc.audit_log != nullptr); + ASSERT_TRUE(default_sc.auth != nullptr); + + ASSERT_ANY_THROW(sch.Get("non-existent")); + + auto sc1 = sch.Get("sc1"); + ASSERT_TRUE(sc1.interpreter_context->db != nullptr); + ASSERT_TRUE(sc1.interpreter_context != nullptr); + ASSERT_TRUE(sc1.audit_log != nullptr); + ASSERT_TRUE(sc1.auth != nullptr); + + auto sc3 = sch.Get("sc3"); + ASSERT_TRUE(sc3.interpreter_context->db != nullptr); + ASSERT_TRUE(sc3.interpreter_context != nullptr); + ASSERT_TRUE(sc3.audit_log != nullptr); + ASSERT_TRUE(sc3.auth != nullptr); +} + +TEST(DBMS_Handler, SetFor) { + auto &sch = *TestEnvironment::get(); + + ASSERT_TRUE(sch.New("db1").HasValue()); + + bool ti0_on_change_ = false; + bool ti0_on_delete_ = false; + TestInterface ti0( + "memgraph", + [&ti0, &ti0_on_change_](const std::string &name) -> memgraph::dbms::SetForResult { + ti0_on_change_ = true; + if (name != ti0.db_) { + ti0.db_ = name; + return memgraph::dbms::SetForResult::SUCCESS; + } + return memgraph::dbms::SetForResult::ALREADY_SET; + }, + [&](const std::string &name) -> bool { + ti0_on_delete_ = true; + return true; + }); + + bool ti1_on_change_ = false; + bool ti1_on_delete_ = false; + TestInterface ti1( + "db1", + [&](const std::string &name) -> memgraph::dbms::SetForResult { + ti1_on_change_ = true; + return memgraph::dbms::SetForResult::SUCCESS; + }, + [&](const std::string &name) -> bool { + ti1_on_delete_ = true; + return true; + }); + + ASSERT_TRUE(sch.Register(ti0)); + ASSERT_FALSE(sch.Register(ti0)); + + { + ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::SUCCESS); + ASSERT_TRUE(ti0_on_change_); + ti0_on_change_ = false; + ASSERT_EQ(sch.SetFor("0", "db1"), memgraph::dbms::SetForResult::ALREADY_SET); + ASSERT_TRUE(ti0_on_change_); + ti0_on_change_ = false; + ASSERT_ANY_THROW(sch.SetFor(std::to_string(TestInterface::id), "db1")); // Session does not exist + ASSERT_ANY_THROW(sch.SetFor("1", "db1")); // Session not registered + ASSERT_ANY_THROW(sch.SetFor("0", "db2")); // No db2 + ASSERT_EQ(sch.SetFor("0", "memgraph"), memgraph::dbms::SetForResult::SUCCESS); + ASSERT_TRUE(ti0_on_change_); + } + + ASSERT_TRUE(sch.Delete(ti0)); + ASSERT_FALSE(sch.Delete(ti1)); +} + +TEST(DBMS_Handler, Delete) { + auto &sch = *TestEnvironment::get(); + + bool ti0_on_change_ = false; + bool ti0_on_delete_ = false; + TestInterface ti0( + "memgraph", + [&](const std::string &name) -> memgraph::dbms::SetForResult { + ti0_on_change_ = true; + if (name != "sc3") return memgraph::dbms::SetForResult::SUCCESS; + return memgraph::dbms::SetForResult::FAIL; + }, + [&](const std::string &name) -> bool { + ti0_on_delete_ = true; + return (name != "sc3"); + }); + + bool ti1_on_change_ = false; + bool ti1_on_delete_ = false; + TestInterface ti1( + "sc1", + [&](const std::string &name) -> memgraph::dbms::SetForResult { + ti1_on_change_ = true; + ti1.db_ = name; + return memgraph::dbms::SetForResult::SUCCESS; + }, + [&](const std::string &name) -> bool { + ti1_on_delete_ = true; + return ti1.db_ != name; + }); + + ASSERT_TRUE(sch.Register(ti0)); + ASSERT_TRUE(sch.Register(ti1)); + + { + auto del = sch.Delete(memgraph::dbms::kDefaultDB); + ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::DEFAULT_DB); + } + { + auto del = sch.Delete("non-existent"); + ASSERT_TRUE(del.HasError() && del.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); + } + { + // ti1 is using sc1 + auto del = sch.Delete("sc1"); + ASSERT_TRUE(del.HasError()); + ASSERT_TRUE(del.GetError() == memgraph::dbms::DeleteError::FAIL); + } + { + // Delete ti1 so delete will succeed + ASSERT_EQ(sch.SetFor(ti1.UUID(), "memgraph"), memgraph::dbms::SetForResult::SUCCESS); + auto del = sch.Delete("sc1"); + ASSERT_FALSE(del.HasError()) << (int)del.GetError(); + auto del2 = sch.Delete("sc1"); + ASSERT_TRUE(del2.HasError() && del2.GetError() == memgraph::dbms::DeleteError::NON_EXISTENT); + } + { + // Using based on the active interpreters + auto new_sc = sch.New("sc1"); + ASSERT_TRUE(new_sc.HasValue()) << (int)new_sc.GetError(); + auto sc = sch.Get("sc1"); + memgraph::query::Interpreter interpreter(sc.interpreter_context.get()); + sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.insert(&interpreter); }); + auto del = sch.Delete("sc1"); + ASSERT_TRUE(del.HasError()); + ASSERT_EQ(del.GetError(), memgraph::dbms::DeleteError::USING); + sc.interpreter_context->interpreters.WithLock([&](auto &interpreters) { interpreters.erase(&interpreter); }); + } + { + // Interpreter deactivated, so we should be able to delete + auto del = sch.Delete("sc1"); + ASSERT_FALSE(del.HasError()) << (int)del.GetError(); + } + { + ASSERT_TRUE(sch.Delete(ti0)); + auto del = sch.Delete("sc3"); + ASSERT_FALSE(del.HasError()); + ASSERT_FALSE(std::filesystem::exists(storage_directory / "databases" / "sc3")); + } + + ASSERT_TRUE(sch.Delete(ti1)); +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + // gtest takes ownership of the TestEnvironment ptr - we don't delete it. + ::testing::AddGlobalTestEnvironment(new TestEnvironment); + return RUN_ALL_TESTS(); +} + +#endif diff --git a/tests/unit/interpreter_faker.hpp b/tests/unit/interpreter_faker.hpp index 9f6ee5c76..63d4516f3 100644 --- a/tests/unit/interpreter_faker.hpp +++ b/tests/unit/interpreter_faker.hpp @@ -21,7 +21,7 @@ struct InterpreterFaker { auto Prepare(const std::string &query, const std::map ¶ms = {}) { ResultStreamFaker stream(interpreter_context->db.get()); - const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); + const auto [header, _1, qid, _2] = interpreter.Prepare(query, params, nullptr); stream.Header(header); return std::make_pair(std::move(stream), qid); } diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index a01d21f7c..3b920c125 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -632,9 +632,9 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec this->storage.template Create( \ this->storage.template Create(variable), list, expr) #define EXISTS(pattern) this->storage.template Create(pattern) -#define AUTH_QUERY(action, user, role, user_or_role, password, privileges, labels, edgeTypes) \ - storage.Create((action), (user), (role), (user_or_role), password, (privileges), \ - (labels), (edgeTypes)) +#define AUTH_QUERY(action, user, role, user_or_role, password, database, privileges, labels, edgeTypes) \ + storage.Create((action), (user), (role), (user_or_role), password, (database), \ + (privileges), (labels), (edgeTypes)) #define DROP_USER(usernames) storage.Create((usernames)) #define CALL_PROCEDURE(...) memgraph::query::test_common::GetCallProcedure(storage, __VA_ARGS__) #define CALL_SUBQUERY(...) memgraph::query::test_common::GetCallSubquery(this->storage, __VA_ARGS__) diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index 6c992c38f..64491eb7f 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -209,7 +209,7 @@ auto Execute(memgraph::query::InterpreterContext *context, const std::string &qu memgraph::query::Interpreter interpreter(context); ResultStreamFaker stream(context->db.get()); - auto [header, _, qid] = interpreter.Prepare(query, {}, nullptr); + auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter.PullAll(&stream); stream.Summary(summary); @@ -790,7 +790,7 @@ class StatefulInterpreter { auto Execute(const std::string &query) { ResultStreamFaker stream(context_->db.get()); - auto [header, _, qid] = interpreter_.Prepare(query, {}, nullptr); + auto [header, _1, qid, _2] = interpreter_.Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter_.PullAll(&stream); stream.Summary(summary); diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index 9d09918e2..26b49b7f4 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -60,7 +60,7 @@ class QueryExecution : public testing::Test { auto Execute(const std::string &query) { ResultStreamFaker stream(this->interpreter_context_->db.get()); - auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr); + auto [header, _1, qid, _2] = interpreter_->Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter_->PullAll(&stream); stream.Summary(summary); diff --git a/tests/unit/query_procedure_mgp_type.cpp b/tests/unit/query_procedure_mgp_type.cpp index 3ebdcbec2..4ed9f4926 100644 --- a/tests/unit/query_procedure_mgp_type.cpp +++ b/tests/unit/query_procedure_mgp_type.cpp @@ -249,7 +249,7 @@ TYPED_TEST(CypherType, VertexSatisfiesType) { auto vertex = dba.InsertVertex(); mgp_memory memory{memgraph::utils::NewDeleteResource()}; memgraph::utils::Allocator alloc(memory.impl); - mgp_graph graph{&dba, memgraph::storage::View::NEW}; + mgp_graph graph{&dba, memgraph::storage::View::NEW, nullptr}; auto *mgp_vertex_v = EXPECT_MGP_NO_ERROR(mgp_value *, mgp_value_make_vertex, alloc.new_object(vertex, &graph)); const memgraph::query::TypedValue tv_vertex(vertex); @@ -274,7 +274,7 @@ TYPED_TEST(CypherType, EdgeSatisfiesType) { auto edge = *dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("edge_type")); mgp_memory memory{memgraph::utils::NewDeleteResource()}; memgraph::utils::Allocator alloc(memory.impl); - mgp_graph graph{&dba, memgraph::storage::View::NEW}; + mgp_graph graph{&dba, memgraph::storage::View::NEW, nullptr}; auto *mgp_edge_v = EXPECT_MGP_NO_ERROR(mgp_value *, mgp_value_make_edge, alloc.new_object(edge, &graph)); const memgraph::query::TypedValue tv_edge(edge); CheckSatisfiesTypesAndNullable( @@ -298,7 +298,7 @@ TYPED_TEST(CypherType, PathSatisfiesType) { auto edge = *dba.InsertEdge(&v1, &v2, dba.NameToEdgeType("edge_type")); mgp_memory memory{memgraph::utils::NewDeleteResource()}; memgraph::utils::Allocator alloc(memory.impl); - mgp_graph graph{&dba, memgraph::storage::View::NEW}; + mgp_graph graph{&dba, memgraph::storage::View::NEW, nullptr}; auto *mgp_vertex_v = alloc.new_object(v1, &graph); auto path = EXPECT_MGP_NO_ERROR(mgp_path *, mgp_path_make_with_start, mgp_vertex_v, &memory); ASSERT_TRUE(path); diff --git a/tests/unit/query_procedure_py_module.cpp b/tests/unit/query_procedure_py_module.cpp index abe8b7f27..9487ebb2c 100644 --- a/tests/unit/query_procedure_py_module.cpp +++ b/tests/unit/query_procedure_py_module.cpp @@ -132,7 +132,7 @@ TYPED_TEST(PyModule, PyVertex) { auto storage_dba = this->db->Access(); memgraph::query::DbAccessor dba(storage_dba.get()); mgp_memory memory{memgraph::utils::NewDeleteResource()}; - mgp_graph graph{&dba, memgraph::storage::View::OLD}; + mgp_graph graph{&dba, memgraph::storage::View::OLD, nullptr}; auto *vertex = EXPECT_MGP_NO_ERROR(mgp_vertex *, mgp_graph_get_vertex_by_id, &graph, mgp_vertex_id{0}, &memory); ASSERT_TRUE(vertex); auto *vertex_value = EXPECT_MGP_NO_ERROR(mgp_value *, mgp_value_make_vertex, @@ -182,7 +182,7 @@ TYPED_TEST(PyModule, PyEdge) { auto storage_dba = this->db->Access(); memgraph::query::DbAccessor dba(storage_dba.get()); mgp_memory memory{memgraph::utils::NewDeleteResource()}; - mgp_graph graph{&dba, memgraph::storage::View::OLD}; + mgp_graph graph{&dba, memgraph::storage::View::OLD, nullptr}; auto *start_v = EXPECT_MGP_NO_ERROR(mgp_vertex *, mgp_graph_get_vertex_by_id, &graph, mgp_vertex_id{0}, &memory); ASSERT_TRUE(start_v); auto *edges_it = EXPECT_MGP_NO_ERROR(mgp_edges_iterator *, mgp_vertex_iter_out_edges, start_v, &memory); @@ -228,7 +228,7 @@ TYPED_TEST(PyModule, PyPath) { auto storage_dba = this->db->Access(); memgraph::query::DbAccessor dba(storage_dba.get()); mgp_memory memory{memgraph::utils::NewDeleteResource()}; - mgp_graph graph{&dba, memgraph::storage::View::OLD}; + mgp_graph graph{&dba, memgraph::storage::View::OLD, nullptr}; auto *start_v = EXPECT_MGP_NO_ERROR(mgp_vertex *, mgp_graph_get_vertex_by_id, &graph, mgp_vertex_id{0}, &memory); ASSERT_TRUE(start_v); auto *path = EXPECT_MGP_NO_ERROR(mgp_path *, mgp_path_make_with_start, start_v, &memory); diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp index 4bdef60dc..a648e396b 100644 --- a/tests/unit/query_required_privileges.cpp +++ b/tests/unit/query_required_privileges.cpp @@ -104,8 +104,8 @@ TEST_F(TestPrivilegeExtractor, AuthQuery) { auto label_privileges = std::vector>>{}; auto edge_type_privileges = std::vector>>{}; - auto *query = AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, std::vector{}, - label_privileges, edge_type_privileges); + auto *query = AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, "", + std::vector{}, label_privileges, edge_type_privileges); EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::AUTH)); } #endif diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index af73fe32c..31a28445f 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -40,8 +40,9 @@ const std::unordered_set kAllEventTypes{ class MockAuthChecker : public memgraph::query::AuthChecker { public: - MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional &username, - const std::vector &privileges)); + MOCK_CONST_METHOD3(IsUserAuthorized, + bool(const std::optional &username, + const std::vector &privileges, const std::string &db)); #ifdef MG_ENTERPRISE MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr( @@ -1074,8 +1075,11 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); - expected_info.push_back({"trigger", "RETURN 1", memgraph::query::TriggerEventType::VERTEX_CREATE, - memgraph::query::TriggerPhase::BEFORE_COMMIT}); + expected_info.push_back({"trigger", + "RETURN 1", + memgraph::query::TriggerEventType::VERTEX_CREATE, + memgraph::query::TriggerPhase::BEFORE_COMMIT, + {/* no owner */}}); const auto check_trigger_info = [&] { const auto trigger_info = store.GetTriggerInfo(); @@ -1094,8 +1098,11 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { store.AddTrigger("edge_update_trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); - expected_info.push_back({"edge_update_trigger", "RETURN 1", memgraph::query::TriggerEventType::EDGE_UPDATE, - memgraph::query::TriggerPhase::AFTER_COMMIT}); + expected_info.push_back({"edge_update_trigger", + "RETURN 1", + memgraph::query::TriggerEventType::EDGE_UPDATE, + memgraph::query::TriggerPhase::AFTER_COMMIT, + {/* no owner */}}); check_trigger_info(); @@ -1224,10 +1231,12 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) { ::testing::InSequence s; - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE))) + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) .Times(1) .WillOnce(Return(true)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true)); ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, @@ -1239,7 +1248,7 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) { memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, owner, &mock_checker)); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::MATCH))) + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::MATCH), "")) .Times(1) .WillOnce(Return(false)); @@ -1250,10 +1259,12 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) { , memgraph::utils::BasicException); store.emplace(this->testing_directory); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE))) + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) .Times(1) .WillOnce(Return(false)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true)); + EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) + .Times(1) + .WillOnce(Return(true)); ASSERT_NO_THROW(store->RestoreTriggers(&this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, &mock_checker)); diff --git a/tests/unit/utils_sync_ptr.cpp b/tests/unit/utils_sync_ptr.cpp new file mode 100644 index 000000000..62111c8a8 --- /dev/null +++ b/tests/unit/utils_sync_ptr.cpp @@ -0,0 +1,296 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include + +#include +#include +#include + +#include +#include "utils/exceptions.hpp" + +using namespace std::chrono_literals; + +TEST(SyncPtr, Basic) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(alive); + ASSERT_TRUE(alive); + auto sp_copy1 = sp.get(); + auto sp_copy2 = sp.get(); + + sp_copy1.reset(); + ASSERT_TRUE(alive); + sp_copy2.reset(); + ASSERT_TRUE(alive); + + sp.DestroyAndSync(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, BasicWConfig) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + struct TestConf { + TestConf(int i) : conf_(i) {} + int conf_; + }; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(123, alive); + ASSERT_TRUE(alive); + ASSERT_EQ(sp.config().conf_, 123); + auto sp_copy1 = sp.get(); + auto sp_copy2 = sp.get(); + + sp_copy1.reset(); + ASSERT_TRUE(alive); + sp_copy2.reset(); + ASSERT_TRUE(alive); + + sp.DestroyAndSync(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, Sync) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(alive); + + { + using namespace std::chrono_literals; + sp.timeout(10000ms); // 10sec + + ASSERT_TRUE(alive); + auto sp_copy1 = sp.get(); + auto sp_copy2 = sp.get(); + + th = std::thread([&alive, p = sp.get()]() mutable { + // Wait for a second and then release the pointer + // SyncPtr will be destroyed in the mean time (and block) + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + ASSERT_TRUE(alive); + p.reset(); + ASSERT_FALSE(alive); + }); + } + + ASSERT_TRUE(alive); + sp.DestroyAndSync(); + + th.join(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, SyncWConfig) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + struct TestConf { + TestConf(int i) : conf_(i) {} + int conf_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(456, alive); + + { + using namespace std::chrono_literals; + sp.timeout(10000ms); // 10sec + ASSERT_TRUE(alive); + ASSERT_EQ(sp.config().conf_, 456); + auto sp_copy1 = sp.get(); + auto sp_copy2 = sp.get(); + + th = std::thread([&alive, p = sp.get()]() mutable { + // Wait for a second and then release the pointer + // SyncPtr will be destroyed in the mean time (and block) + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + ASSERT_TRUE(alive); + p.reset(); + ASSERT_FALSE(alive); + }); + } + + ASSERT_TRUE(alive); + sp.DestroyAndSync(); + + th.join(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, Timeout100ms) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(alive); + using namespace std::chrono_literals; + sp.timeout(100ms); + + ASSERT_TRUE(alive); + + auto p = sp.get(); + + ASSERT_TRUE(alive); + + auto start_100ms = std::chrono::system_clock::now(); + ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); + auto end_100ms = std::chrono::system_clock::now(); + auto delta_100ms = std::chrono::duration_cast(end_100ms - start_100ms).count(); + ASSERT_NEAR(delta_100ms, 100, 100); + + p.reset(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, Timeout567ms) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(alive); + using namespace std::chrono_literals; + sp.timeout(567ms); + + ASSERT_TRUE(alive); + + auto p = sp.get(); + + ASSERT_TRUE(alive); + + auto start = std::chrono::system_clock::now(); + ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); + auto end = std::chrono::system_clock::now(); + auto delta_ms = std::chrono::duration_cast(end - start).count(); + ASSERT_NEAR(delta_ms, 567, 100); + + p.reset(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, Timeout100msWConfig) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + struct TestConf { + TestConf(int i) : conf_(i) {} + int conf_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(0, alive); + using namespace std::chrono_literals; + sp.timeout(100ms); + + ASSERT_TRUE(alive); + + auto p = sp.get(); + + ASSERT_TRUE(alive); + + auto start_100ms = std::chrono::system_clock::now(); + ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); + auto end_100ms = std::chrono::system_clock::now(); + auto delta_100ms = std::chrono::duration_cast(end_100ms - start_100ms).count(); + ASSERT_NEAR(delta_100ms, 100, 100); + + p.reset(); + ASSERT_FALSE(alive); +} + +TEST(SyncPtr, Timeout567msWConfig) { + std::atomic_bool alive{false}; + struct Test { + Test(std::atomic_bool &alive) : alive_(alive) { alive_ = true; } + ~Test() { alive_ = false; } + std::atomic_bool &alive_; + }; + + struct TestConf { + TestConf(int i) : conf_(i) {} + int conf_; + }; + + std::thread th; + + ASSERT_FALSE(alive); + + memgraph::utils::SyncPtr sp(2, alive); + using namespace std::chrono_literals; + sp.timeout(567ms); + + ASSERT_TRUE(alive); + + auto p = sp.get(); + + ASSERT_TRUE(alive); + + auto start = std::chrono::system_clock::now(); + ASSERT_THROW(sp.DestroyAndSync(), memgraph::utils::BasicException); + auto end = std::chrono::system_clock::now(); + auto delta_ms = std::chrono::duration_cast(end - start).count(); + ASSERT_NEAR(delta_ms, 567, 100); + + p.reset(); + ASSERT_FALSE(alive); +}