diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 16f6607b7..33d8b2cac 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -8,9 +8,7 @@ #include "auth/auth.hpp" -#include #include -#include #include #include @@ -18,7 +16,6 @@ #include "auth/exceptions.hpp" #include "license/license.hpp" #include "utils/flag_validation.hpp" -#include "utils/logging.hpp" #include "utils/message.hpp" #include "utils/settings.hpp" #include "utils/string.hpp" @@ -64,7 +61,8 @@ const std::string kLinkPrefix = "link:"; * key="link:", value="" */ -Auth::Auth(const std::string &storage_directory) : storage_(storage_directory), module_(FLAGS_auth_module_executable) {} +Auth::Auth(std::string storage_directory, Config config) + : storage_(std::move(storage_directory)), module_(FLAGS_auth_module_executable), config_{std::move(config)} {} std::optional Auth::Authenticate(const std::string &username, const std::string &password) { if (module_.IsUsed()) { @@ -113,7 +111,7 @@ std::optional Auth::Authenticate(const std::string &username, const std::s return std::nullopt; } } else { - user->UpdatePassword(password); + UpdatePassword(*user, password); } if (FLAGS_auth_module_manage_roles) { if (!rolename.empty()) { @@ -197,13 +195,46 @@ void Auth::SaveUser(const User &user) { } } +void Auth::UpdatePassword(auth::User &user, const std::optional &password) { + // Check if null + if (!password) { + if (!config_.password_permit_null) { + throw AuthException("Null passwords aren't permitted!"); + } + } else { + // Check if compliant with our filter + if (config_.custom_password_regex) { + if (const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings); + license_check_result.HasError()) { + throw AuthException( + "Custom password regex is a Memgraph Enterprise feature. Please set the config " + "(\"--auth-password-strength-regex\") to its default value (\"{}\") or remove the flag.\n{}", + glue::kDefaultPasswordRegex, + license::LicenseCheckErrorToString(license_check_result.GetError(), "password regex")); + } + } + if (!std::regex_match(*password, config_.password_regex)) { + throw AuthException( + "The user password doesn't conform to the required strength! Regex: " + "\"{}\"", + config_.password_regex_str); + } + } + + // All checks passed; update + user.UpdatePassword(password); +} + std::optional Auth::AddUser(const std::string &username, const std::optional &password) { + if (!NameRegexMatch(username)) { + throw AuthException("Invalid user name."); + } auto existing_user = GetUser(username); if (existing_user) return std::nullopt; auto existing_role = GetRole(username); if (existing_role) return std::nullopt; auto new_user = User(username); - new_user.UpdatePassword(password); + UpdatePassword(new_user, password); SaveUser(new_user); return new_user; } @@ -255,6 +286,9 @@ void Auth::SaveRole(const Role &role) { } std::optional Auth::AddRole(const std::string &rolename) { + if (!NameRegexMatch(rolename)) { + throw AuthException("Invalid role name."); + } if (auto existing_role = GetRole(rolename)) return std::nullopt; if (auto existing_user = GetUser(rolename)) return std::nullopt; auto new_role = Role(rolename); @@ -359,4 +393,19 @@ bool Auth::SetMainDatabase(std::string_view db, const std::string &name) { } #endif +bool Auth::NameRegexMatch(const std::string &user_or_role) const { + if (config_.custom_name_regex) { + if (const auto license_check_result = + memgraph::license::global_license_checker.IsEnterpriseValid(memgraph::utils::global_settings); + license_check_result.HasError()) { + throw memgraph::auth::AuthException( + "Custom user/role regex is a Memgraph Enterprise feature. Please set the config " + "(\"--auth-user-or-role-name-regex\") to its default value (\"{}\") or remove the flag.\n{}", + glue::kDefaultUserRoleRegex, + memgraph::license::LicenseCheckErrorToString(license_check_result.GetError(), "user/role regex")); + } + } + return std::regex_match(user_or_role, config_.name_regex); +} + } // namespace memgraph::auth diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index b9568c311..aa90c349a 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -10,11 +10,13 @@ #include #include +#include #include #include "auth/exceptions.hpp" #include "auth/models.hpp" #include "auth/module.hpp" +#include "glue/auth_global.hpp" #include "kvstore/kvstore.hpp" #include "utils/settings.hpp" @@ -31,7 +33,40 @@ static const constexpr char *const kAllDatabases = "*"; */ class Auth final { public: - explicit Auth(const std::string &storage_directory); + struct Config { + Config() {} + Config(std::string name_regex, std::string password_regex, bool password_permit_null) + : name_regex_str{std::move(name_regex)}, + password_regex_str{std::move(password_regex)}, + password_permit_null{password_permit_null}, + custom_name_regex{name_regex_str != glue::kDefaultUserRoleRegex}, + name_regex{name_regex_str}, + custom_password_regex{password_regex_str != glue::kDefaultPasswordRegex}, + password_regex{password_regex_str} {} + + std::string name_regex_str{glue::kDefaultUserRoleRegex}; + std::string password_regex_str{glue::kDefaultPasswordRegex}; + bool password_permit_null{true}; + + private: + friend class Auth; + bool custom_name_regex{false}; + std::regex name_regex{name_regex_str}; + bool custom_password_regex{false}; + std::regex password_regex{password_regex_str}; + }; + + explicit Auth(std::string storage_directory, Config config); + + /** + * @brief Set the Config object + * + * @param config + */ + void SetConfig(Config config) { + // NOTE: The Auth class itself is not thread-safe, higher-level code needs to synchronize it when using it. + config_ = std::move(config); + } /** * Authenticates a user using his username and password. @@ -85,6 +120,14 @@ class Auth final { */ bool RemoveUser(const std::string &username); + /** + * @brief + * + * @param user + * @param password + */ + void UpdatePassword(auth::User &user, const std::optional &password); + /** * Gets all users from the storage. * @@ -199,10 +242,20 @@ class Auth final { #endif private: + /** + * @brief + * + * @param user_or_role + * @return true + * @return false + */ + bool NameRegexMatch(const std::string &user_or_role) const; + // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe, // Auth is not thread-safe because modifying users and roles might require // more than one operation on the storage. kvstore::KVStore storage_; auth::Module module_; + Config config_; }; } // namespace memgraph::auth diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 4fad3965d..7ded6410d 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -9,7 +9,6 @@ #include "auth/models.hpp" #include -#include #include #include @@ -21,19 +20,8 @@ #include "query/constants.hpp" #include "spdlog/spdlog.h" #include "utils/cast.hpp" -#include "utils/logging.hpp" -#include "utils/settings.hpp" #include "utils/string.hpp" -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -DEFINE_bool(auth_password_permit_null, true, "Set to false to disable null passwords."); - -inline constexpr std::string_view default_password_regex = ".+"; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -DEFINE_string(auth_password_strength_regex, default_password_regex.data(), - "The regular expression that should be used to match the entire " - "entered password to ensure its strength."); - namespace memgraph::auth { namespace { @@ -587,31 +575,9 @@ bool User::CheckPassword(const std::string &password) { void User::UpdatePassword(const std::optional &password) { if (!password) { - if (!FLAGS_auth_password_permit_null) { - throw AuthException("Null passwords aren't permitted!"); - } password_hash_ = ""; return; } - - if (FLAGS_auth_password_strength_regex != default_password_regex) { - if (const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings); - license_check_result.HasError()) { - throw AuthException( - "Custom password regex is a Memgraph Enterprise feature. Please set the config " - "(\"--auth-password-strength-regex\") to its default value (\"{}\") or remove the flag.\n{}", - default_password_regex, - license::LicenseCheckErrorToString(license_check_result.GetError(), "password regex")); - } - } - std::regex re(FLAGS_auth_password_strength_regex); - if (!std::regex_match(*password, re)) { - throw AuthException( - "The user password doesn't conform to the required strength! Regex: " - "\"{}\"", - FLAGS_auth_password_strength_regex); - } - password_hash_ = EncryptPassword(*password); } diff --git a/src/flags/general.cpp b/src/flags/general.cpp index bcf1f7e1f..cd2c95c60 100644 --- a/src/flags/general.cpp +++ b/src/flags/general.cpp @@ -195,3 +195,9 @@ DEFINE_HIDDEN_string(organization_name, "", "Organization name."); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(), "Set to the regular expression that each user or role name must fulfill."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_bool(auth_password_permit_null, true, "Set to false to disable null passwords."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_string(auth_password_strength_regex, memgraph::glue::kDefaultPasswordRegex.data(), + "The regular expression that should be used to match the entire " + "entered password to ensure its strength."); diff --git a/src/flags/general.hpp b/src/flags/general.hpp index 5483d92cb..a1e8729ab 100644 --- a/src/flags/general.hpp +++ b/src/flags/general.hpp @@ -118,3 +118,7 @@ DECLARE_string(license_key); DECLARE_string(organization_name); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DECLARE_string(auth_user_or_role_name_regex); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_bool(auth_password_permit_null); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_string(auth_password_strength_regex); diff --git a/src/glue/auth_global.hpp b/src/glue/auth_global.hpp index 4675b6978..008960c76 100644 --- a/src/glue/auth_global.hpp +++ b/src/glue/auth_global.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,4 +13,5 @@ namespace memgraph::glue { inline constexpr std::string_view kDefaultUserRoleRegex = "[a-zA-Z0-9_.+-@]+"; +static constexpr std::string_view kDefaultPasswordRegex = ".+"; } // namespace memgraph::glue diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index a86dc5f48..f3efb6ba0 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -249,25 +249,10 @@ std::vector> ShowFineGrainedRolePrivile namespace memgraph::glue { AuthQueryHandler::AuthQueryHandler( - memgraph::utils::Synchronized *auth, - std::string name_regex_string) - : auth_(auth), name_regex_string_(std::move(name_regex_string)), name_regex_(name_regex_string_) {} + memgraph::utils::Synchronized *auth) + : auth_(auth) {} bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional &password) { - if (name_regex_string_ != kDefaultUserRoleRegex) { - if (const auto license_check_result = - memgraph::license::global_license_checker.IsEnterpriseValid(memgraph::utils::global_settings); - license_check_result.HasError()) { - throw memgraph::auth::AuthException( - "Custom user/role regex is a Memgraph Enterprise feature. Please set the config " - "(\"--auth-user-or-role-name-regex\") to its default value (\"{}\") or remove the flag.\n{}", - kDefaultUserRoleRegex, - memgraph::license::LicenseCheckErrorToString(license_check_result.GetError(), "user/role regex")); - } - } - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } try { const auto [first_user, user_added] = std::invoke([&, this] { auto locked_auth = auth_->Lock(); @@ -305,9 +290,6 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option } bool AuthQueryHandler::DropUser(const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -319,16 +301,13 @@ bool AuthQueryHandler::DropUser(const std::string &username) { } void AuthQueryHandler::SetPassword(const std::string &username, const std::optional &password) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); if (!user) { throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); } - user->UpdatePassword(password); + locked_auth->UpdatePassword(*user, password); locked_auth->SaveUser(*user); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); @@ -336,9 +315,6 @@ void AuthQueryHandler::SetPassword(const std::string &username, const std::optio } bool AuthQueryHandler::CreateRole(const std::string &rolename) { - if (!std::regex_match(rolename, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid role name."); - } try { auto locked_auth = auth_->Lock(); return locked_auth->AddRole(rolename).has_value(); @@ -349,9 +325,6 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename) { #ifdef MG_ENTERPRISE bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -363,9 +336,6 @@ bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std:: } bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -378,9 +348,6 @@ bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::str std::vector> AuthQueryHandler::GetDatabasePrivileges( const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user or role name."); - } try { auto locked_auth = auth_->ReadLock(); auto user = locked_auth->GetUser(username); @@ -394,9 +361,6 @@ std::vector> AuthQueryHandler::GetDatab } bool AuthQueryHandler::SetMainDatabase(std::string_view db, const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -417,9 +381,6 @@ void AuthQueryHandler::DeleteDatabase(std::string_view db) { #endif bool AuthQueryHandler::DropRole(const std::string &rolename) { - if (!std::regex_match(rolename, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid role name."); - } try { auto locked_auth = auth_->Lock(); auto role = locked_auth->GetRole(rolename); @@ -465,9 +426,6 @@ std::vector AuthQueryHandler::GetRolenames() { } std::optional AuthQueryHandler::GetRolenameForUser(const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->ReadLock(); auto user = locked_auth->GetUser(username); @@ -485,9 +443,6 @@ std::optional AuthQueryHandler::GetRolenameForUser(const std::strin } std::vector AuthQueryHandler::GetUsernamesForRole(const std::string &rolename) { - if (!std::regex_match(rolename, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid role name."); - } try { auto locked_auth = auth_->ReadLock(); auto role = locked_auth->GetRole(rolename); @@ -507,12 +462,6 @@ std::vector AuthQueryHandler::GetUsernamesForRole(c } void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } - if (!std::regex_match(rolename, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid role name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -535,9 +484,6 @@ void AuthQueryHandler::SetRole(const std::string &username, const std::string &r } void AuthQueryHandler::ClearRole(const std::string &username) { - if (!std::regex_match(username, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user name."); - } try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -552,9 +498,6 @@ void AuthQueryHandler::ClearRole(const std::string &username) { } std::vector> AuthQueryHandler::GetPrivileges(const std::string &user_or_role) { - if (!std::regex_match(user_or_role, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user or role name."); - } try { auto locked_auth = auth_->ReadLock(); std::vector> grants; @@ -704,9 +647,6 @@ void AuthQueryHandler::EditPermissions( const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun #endif ) { - if (!std::regex_match(user_or_role, name_regex_)) { - throw memgraph::query::QueryRuntimeException("Invalid user or role name."); - } try { std::vector permissions; permissions.reserve(privileges.size()); diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index e6b8724d4..c226a4560 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -24,12 +24,9 @@ namespace memgraph::glue { class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { memgraph::utils::Synchronized *auth_; - std::string name_regex_string_; - std::regex name_regex_; public: - AuthQueryHandler(memgraph::utils::Synchronized *auth, - std::string name_regex_string); + AuthQueryHandler(memgraph::utils::Synchronized *auth); bool CreateUser(const std::string &username, const std::optional &password) override; diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 1d13c4b76..cbd63490e 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -357,11 +357,10 @@ int main(int argc, char **argv) { .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}; auto auth_glue = - [flag = FLAGS_auth_user_or_role_name_regex]( - memgraph::utils::Synchronized *auth, - std::unique_ptr &ah, std::unique_ptr &ac) { + [](memgraph::utils::Synchronized *auth, + std::unique_ptr &ah, std::unique_ptr &ac) { // Glue high level auth implementations to the query side - ah = std::make_unique(auth, flag); + ah = std::make_unique(auth); ac = std::make_unique(auth); // Handle users passed via arguments auto *maybe_username = std::getenv(kMgUser); @@ -377,9 +376,10 @@ int main(int argc, char **argv) { } }; - // WIP - memgraph::utils::Synchronized auth_{data_directory / - "auth"}; + memgraph::auth::Auth::Config auth_config{FLAGS_auth_user_or_role_name_regex, FLAGS_auth_password_strength_regex, + FLAGS_auth_password_permit_null}; + memgraph::utils::Synchronized auth_{ + data_directory / "auth", auth_config}; std::unique_ptr auth_handler; std::unique_ptr auth_checker; auth_glue(&auth_, auth_handler, auth_checker); diff --git a/tests/integration/telemetry/client.cpp b/tests/integration/telemetry/client.cpp index 8c32664fb..b93b1ada5 100644 --- a/tests/integration/telemetry/client.cpp +++ b/tests/integration/telemetry/client.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,6 +13,7 @@ #include "dbms/dbms_handler.hpp" #include "glue/auth_checker.hpp" +#include "glue/auth_global.hpp" #include "glue/auth_handler.hpp" #include "requests/requests.hpp" #include "storage/v2/config.hpp" @@ -32,9 +33,10 @@ int main(int argc, char **argv) { // Memgraph backend std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_telemetry_integration_test"}; - memgraph::utils::Synchronized auth_{data_directory / - "auth"}; - memgraph::glue::AuthQueryHandler auth_handler(&auth_, ""); + memgraph::utils::Synchronized auth_{ + data_directory / "auth", + memgraph::auth::Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, "", true}}; + memgraph::glue::AuthQueryHandler auth_handler(&auth_); memgraph::glue::AuthChecker auth_checker(&auth_); memgraph::storage::Config db_config; diff --git a/tests/unit/auth.cpp b/tests/unit/auth.cpp index 6dbe20914..3c2931a77 100644 --- a/tests/unit/auth.cpp +++ b/tests/unit/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -19,6 +19,7 @@ #include "auth/auth.hpp" #include "auth/crypto.hpp" #include "auth/models.hpp" +#include "glue/auth_global.hpp" #include "license/license.hpp" #include "utils/cast.hpp" #include "utils/file.hpp" @@ -26,90 +27,70 @@ using namespace memgraph::auth; namespace fs = std::filesystem; -DECLARE_bool(auth_password_permit_null); -DECLARE_string(auth_password_strength_regex); DECLARE_string(password_encryption_algorithm); class AuthWithStorage : public ::testing::Test { protected: void SetUp() override { memgraph::utils::EnsureDir(test_folder_); - FLAGS_auth_password_permit_null = true; - FLAGS_auth_password_strength_regex = ".+"; - memgraph::license::global_license_checker.EnableTesting(); + auth.emplace(test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), auth_config); } void TearDown() override { fs::remove_all(test_folder_); } fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"}; - - Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid())))}; + Auth::Config auth_config{}; + std::optional auth{}; }; TEST_F(AuthWithStorage, AddRole) { - ASSERT_TRUE(auth.AddRole("admin")); - ASSERT_TRUE(auth.AddRole("user")); - ASSERT_FALSE(auth.AddRole("admin")); + ASSERT_TRUE(auth->AddRole("admin")); + ASSERT_TRUE(auth->AddRole("user")); + ASSERT_FALSE(auth->AddRole("admin")); } TEST_F(AuthWithStorage, RemoveRole) { - ASSERT_TRUE(auth.AddRole("admin")); - ASSERT_TRUE(auth.RemoveRole("admin")); - class AuthWithStorage : public ::testing::Test { - protected: - void SetUp() override { - memgraph::utils::EnsureDir(test_folder_); - FLAGS_auth_password_permit_null = true; - FLAGS_auth_password_strength_regex = ".+"; - - memgraph::license::global_license_checker.EnableTesting(); - } - - void TearDown() override { fs::remove_all(test_folder_); } - - fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"}; - - Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid())))}; - }; - ASSERT_FALSE(auth.HasUsers()); - ASSERT_FALSE(auth.RemoveUser("test2")); - ASSERT_FALSE(auth.RemoveUser("test")); - ASSERT_FALSE(auth.HasUsers()); + ASSERT_TRUE(auth->AddRole("admin")); + ASSERT_TRUE(auth->RemoveRole("admin")); + ASSERT_FALSE(auth->HasUsers()); + ASSERT_FALSE(auth->RemoveUser("test2")); + ASSERT_FALSE(auth->RemoveUser("test")); + ASSERT_FALSE(auth->HasUsers()); } TEST_F(AuthWithStorage, Authenticate) { - ASSERT_FALSE(auth.HasUsers()); + ASSERT_FALSE(auth->HasUsers()); - auto user = auth.AddUser("test"); + auto user = auth->AddUser("test"); ASSERT_NE(user, std::nullopt); - ASSERT_TRUE(auth.HasUsers()); + ASSERT_TRUE(auth->HasUsers()); - ASSERT_TRUE(auth.Authenticate("test", "123")); + ASSERT_TRUE(auth->Authenticate("test", "123")); user->UpdatePassword("123"); - auth.SaveUser(*user); + auth->SaveUser(*user); - ASSERT_NE(auth.Authenticate("test", "123"), std::nullopt); + ASSERT_NE(auth->Authenticate("test", "123"), std::nullopt); - ASSERT_EQ(auth.Authenticate("test", "456"), std::nullopt); - ASSERT_NE(auth.Authenticate("test", "123"), std::nullopt); + ASSERT_EQ(auth->Authenticate("test", "456"), std::nullopt); + ASSERT_NE(auth->Authenticate("test", "123"), std::nullopt); user->UpdatePassword(); - auth.SaveUser(*user); + auth->SaveUser(*user); - ASSERT_NE(auth.Authenticate("test", "123"), std::nullopt); - ASSERT_NE(auth.Authenticate("test", "456"), std::nullopt); + ASSERT_NE(auth->Authenticate("test", "123"), std::nullopt); + ASSERT_NE(auth->Authenticate("test", "456"), std::nullopt); - ASSERT_EQ(auth.Authenticate("nonexistant", "123"), std::nullopt); + ASSERT_EQ(auth->Authenticate("nonexistant", "123"), std::nullopt); } TEST_F(AuthWithStorage, UserRolePermissions) { - ASSERT_FALSE(auth.HasUsers()); - ASSERT_TRUE(auth.AddUser("test")); - ASSERT_TRUE(auth.HasUsers()); + ASSERT_FALSE(auth->HasUsers()); + ASSERT_TRUE(auth->AddUser("test")); + ASSERT_TRUE(auth->HasUsers()); - auto user = auth.GetUser("test"); + auto user = auth->GetUser("test"); ASSERT_NE(user, std::nullopt); // Test initial user permissions. @@ -130,8 +111,8 @@ TEST_F(AuthWithStorage, UserRolePermissions) { ASSERT_EQ(user->permissions(), user->GetPermissions()); // Create role. - ASSERT_TRUE(auth.AddRole("admin")); - auto role = auth.GetRole("admin"); + ASSERT_TRUE(auth->AddRole("admin")); + auto role = auth->GetRole("admin"); ASSERT_NE(role, std::nullopt); // Assign permissions to role and role to user. @@ -163,11 +144,11 @@ TEST_F(AuthWithStorage, UserRolePermissions) { #ifdef MG_ENTERPRISE TEST_F(AuthWithStorage, UserRoleFineGrainedAccessHandler) { - ASSERT_FALSE(auth.HasUsers()); - ASSERT_TRUE(auth.AddUser("test")); - ASSERT_TRUE(auth.HasUsers()); + ASSERT_FALSE(auth->HasUsers()); + ASSERT_TRUE(auth->AddUser("test")); + ASSERT_TRUE(auth->HasUsers()); - auto user = auth.GetUser("test"); + auto user = auth->GetUser("test"); ASSERT_NE(user, std::nullopt); // Test initial user fine grained access permissions. @@ -204,8 +185,8 @@ TEST_F(AuthWithStorage, UserRoleFineGrainedAccessHandler) { user->GetFineGrainedAccessEdgeTypePermissions()); // Create role. - ASSERT_TRUE(auth.AddRole("admin")); - auto role = auth.GetRole("admin"); + ASSERT_TRUE(auth->AddRole("admin")); + auto role = auth->GetRole("admin"); ASSERT_NE(role, std::nullopt); // Grant label and edge type to role and role to user. @@ -236,44 +217,44 @@ TEST_F(AuthWithStorage, UserRoleFineGrainedAccessHandler) { TEST_F(AuthWithStorage, RoleManipulations) { { - auto user1 = auth.AddUser("user1"); + auto user1 = auth->AddUser("user1"); ASSERT_TRUE(user1); - auto role1 = auth.AddRole("role1"); + auto role1 = auth->AddRole("role1"); ASSERT_TRUE(role1); user1->SetRole(*role1); - auth.SaveUser(*user1); + auth->SaveUser(*user1); - auto user2 = auth.AddUser("user2"); + auto user2 = auth->AddUser("user2"); ASSERT_TRUE(user2); - auto role2 = auth.AddRole("role2"); + auto role2 = auth->AddRole("role2"); ASSERT_TRUE(role2); user2->SetRole(*role2); - auth.SaveUser(*user2); + auth->SaveUser(*user2); } { - auto user1 = auth.GetUser("user1"); + auto user1 = auth->GetUser("user1"); ASSERT_TRUE(user1); const auto *role1 = user1->role(); ASSERT_NE(role1, nullptr); ASSERT_EQ(role1->rolename(), "role1"); - auto user2 = auth.GetUser("user2"); + auto user2 = auth->GetUser("user2"); ASSERT_TRUE(user2); const auto *role2 = user2->role(); ASSERT_NE(role2, nullptr); ASSERT_EQ(role2->rolename(), "role2"); } - ASSERT_TRUE(auth.RemoveRole("role1")); + ASSERT_TRUE(auth->RemoveRole("role1")); { - auto user1 = auth.GetUser("user1"); + auto user1 = auth->GetUser("user1"); ASSERT_TRUE(user1); const auto *role = user1->role(); ASSERT_EQ(role, nullptr); - auto user2 = auth.GetUser("user2"); + auto user2 = auth->GetUser("user2"); ASSERT_TRUE(user2); const auto *role2 = user2->role(); ASSERT_NE(role2, nullptr); @@ -281,17 +262,17 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { - auto role1 = auth.AddRole("role1"); + auto role1 = auth->AddRole("role1"); ASSERT_TRUE(role1); } { - auto user1 = auth.GetUser("user1"); + auto user1 = auth->GetUser("user1"); ASSERT_TRUE(user1); const auto *role1 = user1->role(); ASSERT_EQ(role1, nullptr); - auto user2 = auth.GetUser("user2"); + auto user2 = auth->GetUser("user2"); ASSERT_TRUE(user2); const auto *role2 = user2->role(); ASSERT_NE(role2, nullptr); @@ -299,7 +280,7 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { - auto users = auth.AllUsers(); + auto users = auth->AllUsers(); std::sort(users.begin(), users.end(), [](const User &a, const User &b) { return a.username() < b.username(); }); ASSERT_EQ(users.size(), 2); ASSERT_EQ(users[0].username(), "user1"); @@ -307,7 +288,7 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { - auto roles = auth.AllRoles(); + auto roles = auth->AllRoles(); std::sort(roles.begin(), roles.end(), [](const Role &a, const Role &b) { return a.rolename() < b.rolename(); }); ASSERT_EQ(roles.size(), 2); ASSERT_EQ(roles[0].rolename(), "role1"); @@ -315,7 +296,7 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { - auto users = auth.AllUsersForRole("role2"); + auto users = auth->AllUsersForRole("role2"); ASSERT_EQ(users.size(), 1); ASSERT_EQ(users[0].username(), "user2"); } @@ -323,16 +304,16 @@ TEST_F(AuthWithStorage, RoleManipulations) { TEST_F(AuthWithStorage, UserRoleLinkUnlink) { { - auto user = auth.AddUser("user"); + auto user = auth->AddUser("user"); ASSERT_TRUE(user); - auto role = auth.AddRole("role"); + auto role = auth->AddRole("role"); ASSERT_TRUE(role); user->SetRole(*role); - auth.SaveUser(*user); + auth->SaveUser(*user); } { - auto user = auth.GetUser("user"); + auto user = auth->GetUser("user"); ASSERT_TRUE(user); const auto *role = user->role(); ASSERT_NE(role, nullptr); @@ -340,14 +321,14 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) { } { - auto user = auth.GetUser("user"); + auto user = auth->GetUser("user"); ASSERT_TRUE(user); user->ClearRole(); - auth.SaveUser(*user); + auth->SaveUser(*user); } { - auto user = auth.GetUser("user"); + auto user = auth->GetUser("user"); ASSERT_TRUE(user); ASSERT_EQ(user->role(), nullptr); } @@ -355,19 +336,19 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) { TEST_F(AuthWithStorage, UserPasswordCreation) { { - auto user = auth.AddUser("test"); + auto user = auth->AddUser("test"); ASSERT_TRUE(user); - ASSERT_TRUE(auth.Authenticate("test", "123")); - ASSERT_TRUE(auth.Authenticate("test", "456")); - ASSERT_TRUE(auth.RemoveUser(user->username())); + ASSERT_TRUE(auth->Authenticate("test", "123")); + ASSERT_TRUE(auth->Authenticate("test", "456")); + ASSERT_TRUE(auth->RemoveUser(user->username())); } { - auto user = auth.AddUser("test", "123"); + auto user = auth->AddUser("test", "123"); ASSERT_TRUE(user); - ASSERT_TRUE(auth.Authenticate("test", "123")); - ASSERT_FALSE(auth.Authenticate("test", "456")); - ASSERT_TRUE(auth.RemoveUser(user->username())); + ASSERT_TRUE(auth->Authenticate("test", "123")); + ASSERT_FALSE(auth->Authenticate("test", "456")); + ASSERT_TRUE(auth->RemoveUser(user->username())); } } @@ -382,36 +363,53 @@ TEST_F(AuthWithStorage, PasswordStrength) { const std::string kAlmostStrongPassword = "ThisPasswordMeetsAllButOneCriterion1234"; const std::string kStrongPassword = "ThisIsAVeryStrongPassword123$"; - auto user = auth.AddUser("user"); - ASSERT_TRUE(user); + { + auth.reset(); + auth.emplace(test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), + Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, kWeakRegex, true}); + auto user = auth->AddUser("user1"); + ASSERT_TRUE(user); + ASSERT_NO_THROW(auth->UpdatePassword(*user, std::nullopt)); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kWeakPassword)); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kAlmostStrongPassword)); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kStrongPassword)); + } - FLAGS_auth_password_permit_null = true; - FLAGS_auth_password_strength_regex = kWeakRegex; - ASSERT_NO_THROW(user->UpdatePassword()); - ASSERT_NO_THROW(user->UpdatePassword(kWeakPassword)); - ASSERT_NO_THROW(user->UpdatePassword(kAlmostStrongPassword)); - ASSERT_NO_THROW(user->UpdatePassword(kStrongPassword)); + { + auth.reset(); + auth.emplace(test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), + Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, kWeakRegex, false}); + ASSERT_THROW(auth->AddUser("user2", std::nullopt), AuthException); + auto user = auth->AddUser("user2", kWeakPassword); + ASSERT_TRUE(user); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kWeakPassword)); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kAlmostStrongPassword)); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kStrongPassword)); + } - FLAGS_auth_password_permit_null = false; - FLAGS_auth_password_strength_regex = kWeakRegex; - ASSERT_THROW(user->UpdatePassword(), AuthException); - ASSERT_NO_THROW(user->UpdatePassword(kWeakPassword)); - ASSERT_NO_THROW(user->UpdatePassword(kAlmostStrongPassword)); - ASSERT_NO_THROW(user->UpdatePassword(kStrongPassword)); + { + auth.reset(); + auth.emplace(test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), + Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, kStrongRegex, true}); + auto user = auth->AddUser("user3"); + ASSERT_TRUE(user); + ASSERT_NO_THROW(auth->UpdatePassword(*user, std::nullopt)); + ASSERT_THROW(auth->UpdatePassword(*user, kWeakPassword), AuthException); + ASSERT_THROW(auth->UpdatePassword(*user, kAlmostStrongPassword), AuthException); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kStrongPassword)); + } - FLAGS_auth_password_permit_null = true; - FLAGS_auth_password_strength_regex = kStrongRegex; - ASSERT_NO_THROW(user->UpdatePassword()); - ASSERT_THROW(user->UpdatePassword(kWeakPassword), AuthException); - ASSERT_THROW(user->UpdatePassword(kAlmostStrongPassword), AuthException); - ASSERT_NO_THROW(user->UpdatePassword(kStrongPassword)); - - FLAGS_auth_password_permit_null = false; - FLAGS_auth_password_strength_regex = kStrongRegex; - ASSERT_THROW(user->UpdatePassword(), AuthException); - ASSERT_THROW(user->UpdatePassword(kWeakPassword), AuthException); - ASSERT_THROW(user->UpdatePassword(kAlmostStrongPassword), AuthException); - ASSERT_NO_THROW(user->UpdatePassword(kStrongPassword)); + { + auth.reset(); + auth.emplace(test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), + Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, kStrongRegex, false}); + ASSERT_THROW(auth->AddUser("user4", std::nullopt);, AuthException); + ASSERT_THROW(auth->AddUser("user4", kWeakPassword);, AuthException); + ASSERT_THROW(auth->AddUser("user4", kAlmostStrongPassword);, AuthException); + auto user = auth->AddUser("user4", kStrongPassword); + ASSERT_TRUE(user); + ASSERT_NO_THROW(auth->UpdatePassword(*user, kStrongPassword)); + } } TEST(AuthWithoutStorage, Permissions) { @@ -680,30 +678,30 @@ TEST(AuthWithoutStorage, RoleSerializeDeserialize) { } TEST_F(AuthWithStorage, UserWithRoleSerializeDeserialize) { - auto role = auth.AddRole("role"); + auto role = auth->AddRole("role"); ASSERT_TRUE(role); role->permissions().Grant(Permission::MATCH); role->permissions().Deny(Permission::MERGE); - auth.SaveRole(*role); + auth->SaveRole(*role); - auto user = auth.AddUser("user"); + auto user = auth->AddUser("user"); ASSERT_TRUE(user); user->permissions().Grant(Permission::MATCH); user->permissions().Deny(Permission::MERGE); user->UpdatePassword("world"); user->SetRole(*role); - auth.SaveUser(*user); + auth->SaveUser(*user); - auto new_user = auth.GetUser("user"); + auto new_user = auth->GetUser("user"); ASSERT_TRUE(new_user); ASSERT_EQ(*user, *new_user); } TEST_F(AuthWithStorage, UserRoleUniqueName) { - ASSERT_TRUE(auth.AddUser("user")); - ASSERT_TRUE(auth.AddRole("role")); - ASSERT_FALSE(auth.AddRole("user")); - ASSERT_FALSE(auth.AddUser("role")); + ASSERT_TRUE(auth->AddUser("user")); + ASSERT_TRUE(auth->AddRole("role")); + ASSERT_FALSE(auth->AddRole("user")); + ASSERT_FALSE(auth->AddUser("role")); } TEST(AuthWithoutStorage, CaseInsensitivity) { @@ -748,58 +746,58 @@ TEST(AuthWithoutStorage, CaseInsensitivity) { TEST_F(AuthWithStorage, CaseInsensitivity) { // AddUser { - auto user = auth.AddUser("Alice", "alice"); + auto user = auth->AddUser("Alice", "alice"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); - ASSERT_FALSE(auth.AddUser("alice")); - ASSERT_FALSE(auth.AddUser("alicE")); + ASSERT_FALSE(auth->AddUser("alice")); + ASSERT_FALSE(auth->AddUser("alicE")); } { - auto user = auth.AddUser("BoB", "bob"); + auto user = auth->AddUser("BoB", "bob"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "bob"); - ASSERT_FALSE(auth.AddUser("bob")); - ASSERT_FALSE(auth.AddUser("bOb")); + ASSERT_FALSE(auth->AddUser("bob")); + ASSERT_FALSE(auth->AddUser("bOb")); } // Authenticate { - auto user = auth.Authenticate("alice", "alice"); + auto user = auth->Authenticate("alice", "alice"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); } { - auto user = auth.Authenticate("alICe", "alice"); + auto user = auth->Authenticate("alICe", "alice"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); } // GetUser { - auto user = auth.GetUser("alice"); + auto user = auth->GetUser("alice"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); } { - auto user = auth.GetUser("aLicE"); + auto user = auth->GetUser("aLicE"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); } - ASSERT_FALSE(auth.GetUser("carol")); + ASSERT_FALSE(auth->GetUser("carol")); // RemoveUser { - auto user = auth.AddUser("caRol", "carol"); + auto user = auth->AddUser("caRol", "carol"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "carol"); - ASSERT_TRUE(auth.RemoveUser("cAROl")); - ASSERT_FALSE(auth.RemoveUser("carol")); - ASSERT_FALSE(auth.GetUser("CAROL")); + ASSERT_TRUE(auth->RemoveUser("cAROl")); + ASSERT_FALSE(auth->RemoveUser("carol")); + ASSERT_FALSE(auth->GetUser("CAROL")); } // AllUsers { - auto users = auth.AllUsers(); + auto users = auth->AllUsers(); ASSERT_EQ(users.size(), 2); std::sort(users.begin(), users.end(), [](const auto &a, const auto &b) { return a.username() < b.username(); }); ASSERT_EQ(users[0].username(), "alice"); @@ -808,48 +806,48 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // AddRole { - auto role = auth.AddRole("Moderator"); + auto role = auth->AddRole("Moderator"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); - ASSERT_FALSE(auth.AddRole("moderator")); - ASSERT_FALSE(auth.AddRole("MODERATOR")); + ASSERT_FALSE(auth->AddRole("moderator")); + ASSERT_FALSE(auth->AddRole("MODERATOR")); } { - auto role = auth.AddRole("adMIN"); + auto role = auth->AddRole("adMIN"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "admin"); - ASSERT_FALSE(auth.AddRole("Admin")); - ASSERT_FALSE(auth.AddRole("ADMIn")); + ASSERT_FALSE(auth->AddRole("Admin")); + ASSERT_FALSE(auth->AddRole("ADMIn")); } - ASSERT_FALSE(auth.AddRole("ALICE")); - ASSERT_FALSE(auth.AddUser("ModeRAtor")); + ASSERT_FALSE(auth->AddRole("ALICE")); + ASSERT_FALSE(auth->AddUser("ModeRAtor")); // GetRole { - auto role = auth.GetRole("moderator"); + auto role = auth->GetRole("moderator"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); } { - auto role = auth.GetRole("MoDERATOR"); + auto role = auth->GetRole("MoDERATOR"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); } - ASSERT_FALSE(auth.GetRole("root")); + ASSERT_FALSE(auth->GetRole("root")); // RemoveRole { - auto role = auth.AddRole("RooT"); + auto role = auth->AddRole("RooT"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "root"); - ASSERT_TRUE(auth.RemoveRole("rOOt")); - ASSERT_FALSE(auth.RemoveRole("RoOt")); - ASSERT_FALSE(auth.GetRole("RoOt")); + ASSERT_TRUE(auth->RemoveRole("rOOt")); + ASSERT_FALSE(auth->RemoveRole("RoOt")); + ASSERT_FALSE(auth->GetRole("RoOt")); } // AllRoles { - auto roles = auth.AllRoles(); + auto roles = auth->AllRoles(); ASSERT_EQ(roles.size(), 2); std::sort(roles.begin(), roles.end(), [](const auto &a, const auto &b) { return a.rolename() < b.rolename(); }); ASSERT_EQ(roles[0].rolename(), "admin"); @@ -858,14 +856,14 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // SaveRole { - auto role = auth.GetRole("MODErator"); + auto role = auth->GetRole("MODErator"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); role->permissions().Grant(memgraph::auth::Permission::MATCH); - auth.SaveRole(*role); + auth->SaveRole(*role); } { - auto role = auth.GetRole("modeRATOR"); + auto role = auth->GetRole("modeRATOR"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); ASSERT_EQ(role->permissions().Has(memgraph::auth::Permission::MATCH), memgraph::auth::PermissionLevel::GRANT); @@ -873,17 +871,17 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // SaveUser { - auto user = auth.GetUser("aLice"); + auto user = auth->GetUser("aLice"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); - auto role = auth.GetRole("moderAtor"); + auto role = auth->GetRole("moderAtor"); ASSERT_TRUE(role); ASSERT_EQ(role->rolename(), "moderator"); user->SetRole(*role); - auth.SaveUser(*user); + auth->SaveUser(*user); } { - auto user = auth.GetUser("aLIce"); + auto user = auth->GetUser("aLIce"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); const auto *role = user->role(); @@ -893,27 +891,27 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // AllUsersForRole { - auto carol = auth.AddUser("caROl"); + auto carol = auth->AddUser("caROl"); ASSERT_TRUE(carol); ASSERT_EQ(carol->username(), "carol"); - auto dave = auth.AddUser("daVe"); + auto dave = auth->AddUser("daVe"); ASSERT_TRUE(dave); ASSERT_EQ(dave->username(), "dave"); - auto admin = auth.GetRole("aDMin"); + auto admin = auth->GetRole("aDMin"); ASSERT_TRUE(admin); ASSERT_EQ(admin->rolename(), "admin"); carol->SetRole(*admin); - auth.SaveUser(*carol); + auth->SaveUser(*carol); dave->SetRole(*admin); - auth.SaveUser(*dave); + auth->SaveUser(*dave); } { - auto users = auth.AllUsersForRole("modeRAtoR"); + auto users = auth->AllUsersForRole("modeRAtoR"); ASSERT_EQ(users.size(), 1); ASSERT_EQ(users[0].username(), "alice"); } { - auto users = auth.AllUsersForRole("AdmiN"); + auto users = auth->AllUsersForRole("AdmiN"); ASSERT_EQ(users.size(), 2); std::sort(users.begin(), users.end(), [](const auto &a, const auto &b) { return a.username() < b.username(); }); ASSERT_EQ(users[0].username(), "carol"); @@ -966,18 +964,15 @@ class AuthWithStorageWithVariousEncryptionAlgorithms : public ::testing::Test { protected: void SetUp() override { memgraph::utils::EnsureDir(test_folder_); - FLAGS_auth_password_permit_null = true; - FLAGS_auth_password_strength_regex = ".+"; FLAGS_password_encryption_algorithm = "bcrypt"; - memgraph::license::global_license_checker.EnableTesting(); } void TearDown() override { fs::remove_all(test_folder_); } fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"}; - - Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid())))}; + Auth::Config auth_config{}; + Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast(getpid()))), auth_config}; }; TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserDefault) { diff --git a/tests/unit/auth_handler.cpp b/tests/unit/auth_handler.cpp index 6537575fd..1230fe4ba 100644 --- a/tests/unit/auth_handler.cpp +++ b/tests/unit/auth_handler.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -26,8 +26,9 @@ class AuthQueryHandlerFixture : public testing::Test { protected: std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / "MG_tests_unit_auth_handler"}; memgraph::utils::Synchronized auth{ - test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast(getpid())))}; - memgraph::glue::AuthQueryHandler auth_handler{&auth, memgraph::glue::kDefaultUserRoleRegex.data()}; + test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast(getpid()))), + memgraph::auth::Auth::Config{/* default */}}; + memgraph::glue::AuthQueryHandler auth_handler{&auth}; std::string user_name = "Mate"; std::string edge_type_repr = "EdgeType1"; diff --git a/tests/unit/dbms_handler.cpp b/tests/unit/dbms_handler.cpp index e0d566240..2abe0b77d 100644 --- a/tests/unit/dbms_handler.cpp +++ b/tests/unit/dbms_handler.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -66,7 +66,7 @@ class TestEnvironment : public ::testing::Environment { } auth = std::make_unique>( - storage_directory / "auth"); + storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); ptr_ = std::make_unique(storage_conf, auth.get(), false); } diff --git a/tests/unit/dbms_handler_community.cpp b/tests/unit/dbms_handler_community.cpp index 58f8dd2ad..4a47e018b 100644 --- a/tests/unit/dbms_handler_community.cpp +++ b/tests/unit/dbms_handler_community.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -51,7 +51,7 @@ class TestEnvironment : public ::testing::Environment { } auth = std::make_unique>( - storage_directory / "auth"); + storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); ptr_ = std::make_unique(storage_conf); } diff --git a/tests/unit/multi_tenancy.cpp b/tests/unit/multi_tenancy.cpp index 5581dcada..59364776a 100644 --- a/tests/unit/multi_tenancy.cpp +++ b/tests/unit/multi_tenancy.cpp @@ -98,10 +98,8 @@ class MultiTenantTest : public ::testing::Test { struct MinMemgraph { explicit MinMemgraph(const memgraph::storage::Config &conf) - : dbms{conf, - reinterpret_cast< - memgraph::utils::Synchronized *>(0), - true}, + : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, + dbms{conf, &auth, true}, interpreter_context{{}, &dbms, &dbms.ReplicationState()} { memgraph::utils::global_settings.Initialize(conf.durability.storage_directory / "settings"); memgraph::license::RegisterLicenseSettings(memgraph::license::global_license_checker, @@ -114,6 +112,7 @@ class MultiTenantTest : public ::testing::Test { auto NewInterpreter() { return InterpreterFaker{&interpreter_context, dbms.Get()}; } + memgraph::utils::Synchronized auth; memgraph::dbms::DbmsHandler dbms; memgraph::query::InterpreterContext interpreter_context; }; diff --git a/tests/unit/storage_v2_replication.cpp b/tests/unit/storage_v2_replication.cpp index 4a2515ec7..008494436 100644 --- a/tests/unit/storage_v2_replication.cpp +++ b/tests/unit/storage_v2_replication.cpp @@ -22,6 +22,7 @@ #include #include #include +#include "auth/auth.hpp" #include "dbms/database.hpp" #include "dbms/dbms_handler.hpp" #include "dbms/replication_handler.hpp" @@ -31,6 +32,7 @@ #include "storage/v2/indices/label_index_stats.hpp" #include "storage/v2/storage.hpp" #include "storage/v2/view.hpp" +#include "utils/rw_lock.hpp" #include "utils/synchronized.hpp" using testing::UnorderedElementsAre; @@ -111,12 +113,11 @@ class ReplicationTest : public ::testing::Test { struct MinMemgraph { MinMemgraph(const memgraph::storage::Config &conf) - : dbms{conf + : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, + dbms{conf #ifdef MG_ENTERPRISE , - reinterpret_cast< - memgraph::utils::Synchronized *>(0), - true + &auth, true #endif }, repl_state{dbms.ReplicationState()}, @@ -124,6 +125,8 @@ struct MinMemgraph { db{*db_acc.get()}, repl_handler(dbms) { } + + memgraph::utils::Synchronized auth; memgraph::dbms::DbmsHandler dbms; memgraph::replication::ReplicationState &repl_state; memgraph::dbms::DatabaseAccess db_acc;