diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4685da727..4d5d523c6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,8 +22,10 @@ add_subdirectory(dbms) add_subdirectory(flags) add_subdirectory(distributed) add_subdirectory(replication) +add_subdirectory(replication_handler) add_subdirectory(coordination) add_subdirectory(replication_coordination_glue) +add_subdirectory(system) string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) @@ -43,10 +45,10 @@ set(mg_single_node_v2_sources add_executable(memgraph ${mg_single_node_v2_sources}) target_include_directories(memgraph PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(memgraph stdc++fs Threads::Threads - mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags) + mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags mg::system mg::replication_handler) # NOTE: `include/mg_procedure.syms` describes a pattern match for symbols which -# should be dynamically exported, so that `dlopen` can correctly link the +# should be dynamically exported, so that `dlopen` can correctly link th # symbols in custom procedure module libraries. target_link_libraries(memgraph "-Wl,--dynamic-list=${CMAKE_SOURCE_DIR}/include/mg_procedure.syms") set_target_properties(memgraph PROPERTIES diff --git a/src/auth/CMakeLists.txt b/src/auth/CMakeLists.txt index 4e5b5697a..49c8258c4 100644 --- a/src/auth/CMakeLists.txt +++ b/src/auth/CMakeLists.txt @@ -2,7 +2,9 @@ set(auth_src_files auth.cpp crypto.cpp models.cpp - module.cpp) + module.cpp + rpc.cpp + replication_handlers.cpp) find_package(Seccomp REQUIRED) find_package(fmt REQUIRED) @@ -11,7 +13,7 @@ find_package(gflags REQUIRED) add_library(mg-auth STATIC ${auth_src_files}) target_link_libraries(mg-auth json libbcrypt gflags fmt::fmt) -target_link_libraries(mg-auth mg-utils mg-kvstore mg-license ) +target_link_libraries(mg-auth mg-utils mg-kvstore mg-license mg::system mg-replication) target_link_libraries(mg-auth ${Seccomp_LIBRARIES}) target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS}) diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 88f0c4410..405c04c45 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -9,13 +9,16 @@ #include "auth/auth.hpp" #include +#include #include #include #include "auth/crypto.hpp" #include "auth/exceptions.hpp" +#include "auth/rpc.hpp" #include "license/license.hpp" +#include "system/transaction.hpp" #include "utils/flag_validation.hpp" #include "utils/message.hpp" #include "utils/settings.hpp" @@ -41,12 +44,84 @@ DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000, FLAG_IN_RANGE(100, 1800000)); namespace memgraph::auth { + +namespace { +#ifdef MG_ENTERPRISE +/** + * REPLICATION SYSTEM ACTION IMPLEMENTATIONS + */ +struct UpdateAuthData : memgraph::system::ISystemAction { + explicit UpdateAuthData(User user) : user_{std::move(user)}, role_{std::nullopt} {} + explicit UpdateAuthData(Role role) : user_{std::nullopt}, role_{std::move(role)} {} + + void DoDurability() override { /* Done during Auth execution */ + } + + bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + memgraph::system::Transaction const &txn) const override { + auto check_response = [](const replication::UpdateAuthDataRes &response) { return response.success; }; + if (user_) { + return client.SteamAndFinalizeDelta( + check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *user_); + } + if (role_) { + return client.SteamAndFinalizeDelta( + check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *role_); + } + // Should never get here + MG_ASSERT(false, "Trying to update auth data that is not a user nor a role"); + return {}; + } + + void PostReplication(replication::RoleMainData &mainData) const override {} + + private: + std::optional user_; + std::optional role_; +}; + +struct DropAuthData : memgraph::system::ISystemAction { + enum class AuthDataType { USER, ROLE }; + + explicit DropAuthData(AuthDataType type, std::string_view name) : type_{type}, name_{name} {} + + void DoDurability() override { /* Done during Auth execution */ + } + + bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + memgraph::system::Transaction const &txn) const override { + auto check_response = [](const replication::DropAuthDataRes &response) { return response.success; }; + + memgraph::replication::DropAuthDataReq::DataType type{}; + switch (type_) { + case AuthDataType::USER: + type = memgraph::replication::DropAuthDataReq::DataType::USER; + break; + case AuthDataType::ROLE: + type = memgraph::replication::DropAuthDataReq::DataType::ROLE; + break; + } + return client.SteamAndFinalizeDelta( + check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), type, name_); + } + void PostReplication(replication::RoleMainData &mainData) const override {} + + private: + AuthDataType type_; + std::string name_; +}; +#endif + +/** + * CONSTANTS + */ const std::string kUserPrefix = "user:"; const std::string kRolePrefix = "role:"; const std::string kLinkPrefix = "link:"; const std::string kVersion = "version"; static constexpr auto kVersionV1 = "V1"; +} // namespace /** * All data stored in the `Auth` storage is stored in an underlying @@ -148,6 +223,12 @@ std::optional Auth::Authenticate(const std::string &username, const std::s // Authenticate the user. if (!is_authenticated) return std::nullopt; + /** + * TODO + * The auth module should not update auth data. + * There is now way to replicate it and we should not be storing sensitive data if we don't have to. + */ + // Find or create the user and return it. auto user = GetUser(username); if (!user) { @@ -240,7 +321,7 @@ std::optional Auth::GetUser(const std::string &username_orig) const { return user; } -void Auth::SaveUser(const User &user) { +void Auth::SaveUser(const User &user, system::Transaction *system_tx) { bool success = false; if (const auto *role = user.role(); role != nullptr) { success = storage_.PutMultiple( @@ -252,6 +333,12 @@ void Auth::SaveUser(const User &user) { if (!success) { throw AuthException("Couldn't save user '{}'!", user.username()); } + // All changes to the user end up calling this function, so no need to add a delta anywhere else + if (system_tx) { +#ifdef MG_ENTERPRISE + system_tx->AddAction(user); +#endif + } } void Auth::UpdatePassword(auth::User &user, const std::optional &password) { @@ -284,7 +371,8 @@ void Auth::UpdatePassword(auth::User &user, const std::optional &pa user.UpdatePassword(password); } -std::optional Auth::AddUser(const std::string &username, const std::optional &password) { +std::optional Auth::AddUser(const std::string &username, const std::optional &password, + system::Transaction *system_tx) { if (!NameRegexMatch(username)) { throw AuthException("Invalid user name."); } @@ -294,17 +382,23 @@ std::optional Auth::AddUser(const std::string &username, const std::option if (existing_role) return std::nullopt; auto new_user = User(username); UpdatePassword(new_user, password); - SaveUser(new_user); + SaveUser(new_user, system_tx); return new_user; } -bool Auth::RemoveUser(const std::string &username_orig) { +bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) { auto username = utils::ToLowerCase(username_orig); if (!storage_.Get(kUserPrefix + username)) return false; std::vector keys({kLinkPrefix + username, kUserPrefix + username}); if (!storage_.DeleteMultiple(keys)) { throw AuthException("Couldn't remove user '{}'!", username); } + // Handling drop user delta + if (system_tx) { +#ifdef MG_ENTERPRISE + system_tx->AddAction(DropAuthData::AuthDataType::USER, username); +#endif + } return true; } @@ -321,6 +415,19 @@ std::vector Auth::AllUsers() const { return ret; } +std::vector Auth::AllUsernames() const { + std::vector ret; + for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { + auto username = it->first.substr(kUserPrefix.size()); + if (username != utils::ToLowerCase(username)) continue; + auto user = GetUser(username); + if (user) { + ret.push_back(username); + } + } + return ret; +} + bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } std::optional Auth::GetRole(const std::string &rolename_orig) const { @@ -338,24 +445,30 @@ std::optional Auth::GetRole(const std::string &rolename_orig) const { return Role::Deserialize(data); } -void Auth::SaveRole(const Role &role) { +void Auth::SaveRole(const Role &role, system::Transaction *system_tx) { if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) { throw AuthException("Couldn't save role '{}'!", role.rolename()); } + // All changes to the role end up calling this function, so no need to add a delta anywhere else + if (system_tx) { +#ifdef MG_ENTERPRISE + system_tx->AddAction(role); +#endif + } } -std::optional Auth::AddRole(const std::string &rolename) { +std::optional Auth::AddRole(const std::string &rolename, system::Transaction *system_tx) { if (!NameRegexMatch(rolename)) { throw AuthException("Invalid role name."); } if (auto existing_role = GetRole(rolename)) return std::nullopt; if (auto existing_user = GetUser(rolename)) return std::nullopt; auto new_role = Role(rolename); - SaveRole(new_role); + SaveRole(new_role, system_tx); return new_role; } -bool Auth::RemoveRole(const std::string &rolename_orig) { +bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *system_tx) { auto rolename = utils::ToLowerCase(rolename_orig); if (!storage_.Get(kRolePrefix + rolename)) return false; std::vector keys; @@ -368,6 +481,12 @@ bool Auth::RemoveRole(const std::string &rolename_orig) { if (!storage_.DeleteMultiple(keys)) { throw AuthException("Couldn't remove role '{}'!", rolename); } + // Handling drop role delta + if (system_tx) { +#ifdef MG_ENTERPRISE + system_tx->AddAction(DropAuthData::AuthDataType::ROLE, rolename); +#endif + } return true; } @@ -385,6 +504,18 @@ std::vector Auth::AllRoles() const { return ret; } +std::vector Auth::AllRolenames() const { + std::vector ret; + for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { + auto rolename = it->first.substr(kRolePrefix.size()); + if (rolename != utils::ToLowerCase(rolename)) continue; + if (auto role = GetRole(rolename)) { + ret.push_back(rolename); + } + } + return ret; +} + std::vector Auth::AllUsersForRole(const std::string &rolename_orig) const { const auto rolename = utils::ToLowerCase(rolename_orig); std::vector ret; @@ -404,48 +535,48 @@ std::vector Auth::AllUsersForRole(const std::string &rolename_orig) } #ifdef MG_ENTERPRISE -bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) { +bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { if (auto user = GetUser(name)) { if (db == kAllDatabases) { user->db_access().GrantAll(); } else { user->db_access().Add(db); } - SaveUser(*user); + SaveUser(*user, system_tx); return true; } return false; } -bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) { +bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { if (auto user = GetUser(name)) { if (db == kAllDatabases) { user->db_access().DenyAll(); } else { user->db_access().Remove(db); } - SaveUser(*user); + SaveUser(*user, system_tx); return true; } return false; } -void Auth::DeleteDatabase(const std::string &db) { +void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); if (auto user = GetUser(username)) { user->db_access().Delete(db); - SaveUser(*user); + SaveUser(*user, system_tx); } } } -bool Auth::SetMainDatabase(std::string_view db, const std::string &name) { +bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) { if (auto user = GetUser(name)) { if (!user->db_access().SetDefault(db)) { throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name); } - SaveUser(*user); + SaveUser(*user, system_tx); return true; } return false; diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index aa90c349a..4b1bcd479 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -18,10 +18,15 @@ #include "auth/module.hpp" #include "glue/auth_global.hpp" #include "kvstore/kvstore.hpp" +#include "system/action.hpp" #include "utils/settings.hpp" +#include "utils/synchronized.hpp" namespace memgraph::auth { +class Auth; +using SynchedAuth = memgraph::utils::Synchronized; + static const constexpr char *const kAllDatabases = "*"; /** @@ -68,6 +73,13 @@ class Auth final { config_ = std::move(config); } + /** + * @brief + * + * @return Config + */ + Config GetConfig() const { return config_; } + /** * Authenticates a user using his username and password. * @@ -96,7 +108,7 @@ class Auth final { * * @throw AuthException if unable to save the user. */ - void SaveUser(const User &user); + void SaveUser(const User &user, system::Transaction *system_tx = nullptr); /** * Creates a user if the user doesn't exist. @@ -107,7 +119,8 @@ class Auth final { * @return a user when the user is created, nullopt if the user exists * @throw AuthException if unable to save the user. */ - std::optional AddUser(const std::string &username, const std::optional &password = std::nullopt); + std::optional AddUser(const std::string &username, const std::optional &password = std::nullopt, + system::Transaction *system_tx = nullptr); /** * Removes a user from the storage. @@ -118,7 +131,7 @@ class Auth final { * doesn't exist * @throw AuthException if unable to remove the user. */ - bool RemoveUser(const std::string &username); + bool RemoveUser(const std::string &username, system::Transaction *system_tx = nullptr); /** * @brief @@ -136,6 +149,13 @@ class Auth final { */ std::vector AllUsers() const; + /** + * @brief + * + * @return std::vector + */ + std::vector AllUsernames() const; + /** * Returns whether there are users in the storage. * @@ -160,7 +180,7 @@ class Auth final { * * @throw AuthException if unable to save the role. */ - void SaveRole(const Role &role); + void SaveRole(const Role &role, system::Transaction *system_tx = nullptr); /** * Creates a role if the role doesn't exist. @@ -170,7 +190,7 @@ class Auth final { * @return a role when the role is created, nullopt if the role exists * @throw AuthException if unable to save the role. */ - std::optional AddRole(const std::string &rolename); + std::optional AddRole(const std::string &rolename, system::Transaction *system_tx = nullptr); /** * Removes a role from the storage. @@ -181,7 +201,7 @@ class Auth final { * doesn't exist * @throw AuthException if unable to remove the role. */ - bool RemoveRole(const std::string &rolename); + bool RemoveRole(const std::string &rolename, system::Transaction *system_tx = nullptr); /** * Gets all roles from the storage. @@ -191,6 +211,13 @@ class Auth final { */ std::vector AllRoles() const; + /** + * @brief + * + * @return std::vector + */ + std::vector AllRolenames() const; + /** * Gets all users for a role from the storage. * @@ -210,7 +237,7 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool RevokeDatabaseFromUser(const std::string &db, const std::string &name); + bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); /** * @brief Grant access to individual database for a user. @@ -220,7 +247,7 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool GrantDatabaseToUser(const std::string &db, const std::string &name); + bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); /** * @brief Delete a database from all users. @@ -228,7 +255,7 @@ class Auth final { * @param db name of the database to delete * @throw AuthException if unable to read data */ - void DeleteDatabase(const std::string &db); + void DeleteDatabase(const std::string &db, system::Transaction *system_tx = nullptr); /** * @brief Set main database for an individual user. @@ -238,7 +265,7 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool SetMainDatabase(std::string_view db, const std::string &name); + bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr); #endif private: diff --git a/src/auth/models.cpp b/src/auth/models.cpp index a59a73c7b..f75e6fe32 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -611,27 +611,49 @@ Permissions User::GetPermissions() const { #ifdef MG_ENTERPRISE FineGrainedAccessPermissions User::GetFineGrainedAccessLabelPermissions() const { + return Merge(GetUserFineGrainedAccessLabelPermissions(), GetRoleFineGrainedAccessLabelPermissions()); +} + +FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const { + return Merge(GetUserFineGrainedAccessEdgeTypePermissions(), GetRoleFineGrainedAccessEdgeTypePermissions()); +} + +FineGrainedAccessPermissions User::GetUserFineGrainedAccessEdgeTypePermissions() const { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return FineGrainedAccessPermissions{}; } - if (role_) { - return Merge(role()->fine_grained_access_handler().label_permissions(), - fine_grained_access_handler_.label_permissions()); + return fine_grained_access_handler_.edge_type_permissions(); +} + +FineGrainedAccessPermissions User::GetUserFineGrainedAccessLabelPermissions() const { + if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { + return FineGrainedAccessPermissions{}; } return fine_grained_access_handler_.label_permissions(); } -FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const { +FineGrainedAccessPermissions User::GetRoleFineGrainedAccessEdgeTypePermissions() const { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return FineGrainedAccessPermissions{}; } + if (role_) { - return Merge(role()->fine_grained_access_handler().edge_type_permissions(), - fine_grained_access_handler_.edge_type_permissions()); + return role()->fine_grained_access_handler().edge_type_permissions(); } - return fine_grained_access_handler_.edge_type_permissions(); + return FineGrainedAccessPermissions{}; +} + +FineGrainedAccessPermissions User::GetRoleFineGrainedAccessLabelPermissions() const { + if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { + return FineGrainedAccessPermissions{}; + } + + if (role_) { + return role()->fine_grained_access_handler().label_permissions(); + } + return FineGrainedAccessPermissions{}; } #endif diff --git a/src/auth/models.hpp b/src/auth/models.hpp index bb6dd2a7a..b65d172ff 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -207,6 +207,8 @@ bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHa class Role final { public: + Role() = default; + explicit Role(const std::string &rolename); Role(const std::string &rolename, const Permissions &permissions); #ifdef MG_ENTERPRISE @@ -369,6 +371,10 @@ class User final { #ifdef MG_ENTERPRISE FineGrainedAccessPermissions GetFineGrainedAccessLabelPermissions() const; FineGrainedAccessPermissions GetFineGrainedAccessEdgeTypePermissions() const; + FineGrainedAccessPermissions GetUserFineGrainedAccessLabelPermissions() const; + FineGrainedAccessPermissions GetUserFineGrainedAccessEdgeTypePermissions() const; + FineGrainedAccessPermissions GetRoleFineGrainedAccessLabelPermissions() const; + FineGrainedAccessPermissions GetRoleFineGrainedAccessEdgeTypePermissions() const; const FineGrainedAccessHandler &fine_grained_access_handler() const; FineGrainedAccessHandler &fine_grained_access_handler(); #endif diff --git a/src/auth/replication_handlers.cpp b/src/auth/replication_handlers.cpp new file mode 100644 index 000000000..8ee0cd7f3 --- /dev/null +++ b/src/auth/replication_handlers.cpp @@ -0,0 +1,170 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "auth/replication_handlers.hpp" + +#include "auth/auth.hpp" +#include "auth/rpc.hpp" +#include "license/license.hpp" + +namespace memgraph::auth { + +#ifdef MG_ENTERPRISE +void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, + slk::Reader *req_reader, slk::Builder *res_builder) { + replication::UpdateAuthDataReq req; + memgraph::slk::Load(&req, req_reader); + + using memgraph::replication::UpdateAuthDataRes; + UpdateAuthDataRes res(false); + + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot + // of the set of databases. Hence no history exists to maintain regarding epoch change. + // If MAIN has changed we need to check this new group_timestamp is consistent with + // what we have so far. + + if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) { + spdlog::debug("UpdateAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp, + system_state_access.LastCommitedTS()); + memgraph::slk::Save(res, res_builder); + return; + } + + try { + // Update + if (req.user) auth->SaveUser(*req.user); + if (req.role) auth->SaveRole(*req.role); + // Success + system_state_access.SetLastCommitedTS(req.new_group_timestamp); + res = UpdateAuthDataRes(true); + spdlog::debug("UpdateAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); + } catch (const auth::AuthException & /* not used */) { + // Failure + } + + memgraph::slk::Save(res, res_builder); +} + +void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, + slk::Reader *req_reader, slk::Builder *res_builder) { + replication::DropAuthDataReq req; + memgraph::slk::Load(&req, req_reader); + + using memgraph::replication::DropAuthDataRes; + DropAuthDataRes res(false); + + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot + // of the set of databases. Hence no history exists to maintain regarding epoch change. + // If MAIN has changed we need to check this new group_timestamp is consistent with + // what we have so far. + + if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) { + spdlog::debug("DropAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp, + system_state_access.LastCommitedTS()); + memgraph::slk::Save(res, res_builder); + return; + } + + try { + // Remove + switch (req.type) { + case replication::DropAuthDataReq::DataType::USER: + auth->RemoveUser(req.name); + break; + case replication::DropAuthDataReq::DataType::ROLE: + auth->RemoveRole(req.name); + break; + } + // Success + system_state_access.SetLastCommitedTS(req.new_group_timestamp); + res = DropAuthDataRes(true); + spdlog::debug("DropAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); + } catch (const auth::AuthException & /* not used */) { + // Failure + } + + memgraph::slk::Save(res, res_builder); +} + +bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config, + const std::vector &users, const std::vector &roles) { + return auth.WithLock([&](auto &locked_auth) { + // Update config + locked_auth.SetConfig(std::move(auth_config)); + // Get all current users + auto old_users = locked_auth.AllUsernames(); + // Save incoming users + for (const auto &user : users) { + // Missing users + try { + locked_auth.SaveUser(user); + } catch (const auth::AuthException &) { + spdlog::debug("SystemRecoveryHandler: Failed to save user"); + return false; + } + const auto it = std::find(old_users.begin(), old_users.end(), user.username()); + if (it != old_users.end()) old_users.erase(it); + } + // Delete all the leftover users + for (const auto &user : old_users) { + if (!locked_auth.RemoveUser(user)) { + spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", user); + return false; + } + } + + // Roles are only supported with a license + if (license::global_license_checker.IsEnterpriseValidFast()) { + // Get all current roles + auto old_roles = locked_auth.AllRolenames(); + // Save incoming users + for (const auto &role : roles) { + // Missing users + try { + locked_auth.SaveRole(role); + } catch (const auth::AuthException &) { + spdlog::debug("SystemRecoveryHandler: Failed to save user"); + return false; + } + const auto it = std::find(old_roles.begin(), old_roles.end(), role.rolename()); + if (it != old_roles.end()) old_roles.erase(it); + } + // Delete all the leftover users + for (const auto &role : old_roles) { + if (!locked_auth.RemoveRole(role)) { + spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", role); + return false; + } + } + } + + // Success + return true; + }); +} + +void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access, + auth::SynchedAuth &auth) { + // NOTE: Register even without license as the user could add a license at run-time + data.server->rpc_server_.Register( + [system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { + spdlog::debug("Received UpdateAuthDataRpc"); + UpdateAuthDataHandler(system_state_access, auth, req_reader, res_builder); + }); + data.server->rpc_server_.Register( + [system_state_access, &auth](auto *req_reader, auto *res_builder) mutable { + spdlog::debug("Received DropAuthDataRpc"); + DropAuthDataHandler(system_state_access, auth, req_reader, res_builder); + }); +} +#endif + +} // namespace memgraph::auth diff --git a/src/auth/replication_handlers.hpp b/src/auth/replication_handlers.hpp new file mode 100644 index 000000000..0d46e957f --- /dev/null +++ b/src/auth/replication_handlers.hpp @@ -0,0 +1,31 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/auth.hpp" +#include "replication/state.hpp" +#include "slk/streams.hpp" +#include "system/state.hpp" + +namespace memgraph::auth { +#ifdef MG_ENTERPRISE +void UpdateAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, + slk::Reader *req_reader, slk::Builder *res_builder); +void DropAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth, + slk::Reader *req_reader, slk::Builder *res_builder); + +bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config, + const std::vector &users, const std::vector &roles); +void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access, + auth::SynchedAuth &auth); +#endif +} // namespace memgraph::auth diff --git a/src/auth/rpc.cpp b/src/auth/rpc.cpp new file mode 100644 index 000000000..f1d09eb01 --- /dev/null +++ b/src/auth/rpc.cpp @@ -0,0 +1,178 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "auth/rpc.hpp" + +#include +#include "auth/auth.hpp" +#include "slk/serialization.hpp" +#include "slk/streams.hpp" +#include "utils/enum.hpp" + +namespace memgraph::slk { + +// Serialize code for auth::Role +void Save(const auth::Role &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.Serialize().dump(), builder); +} +namespace { +auth::Role LoadAuthRole(memgraph::slk::Reader *reader) { + std::string tmp; + memgraph::slk::Load(&tmp, reader); + const auto json = nlohmann::json::parse(tmp); + return memgraph::auth::Role::Deserialize(json); +} +} // namespace +// Deserialize code for auth::Role +void Load(auth::Role *self, memgraph::slk::Reader *reader) { *self = LoadAuthRole(reader); } +// Special case for optional +template <> +inline void Load(std::optional *obj, Reader *reader) { + bool exists = false; + Load(&exists, reader); + if (exists) { + obj->emplace(LoadAuthRole(reader)); + } else { + *obj = std::nullopt; + } +} + +// Serialize code for auth::User +void Save(const auth::User &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.Serialize().dump(), builder); + std::optional role{}; + if (const auto *role_ptr = self.role(); role_ptr) { + role.emplace(*role_ptr); + } + memgraph::slk::Save(role, builder); +} +// Deserialize code for auth::User +void Load(auth::User *self, memgraph::slk::Reader *reader) { + std::string tmp; + memgraph::slk::Load(&tmp, reader); + const auto json = nlohmann::json::parse(tmp); + *self = memgraph::auth::User::Deserialize(json); + std::optional role{}; + memgraph::slk::Load(&role, reader); + if (role) + self->SetRole(*role); + else + self->ClearRole(); +} + +// Serialize code for auth::Auth::Config +void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.name_regex_str, builder); + memgraph::slk::Save(self.password_regex_str, builder); + memgraph::slk::Save(self.password_permit_null, builder); +} +// Deserialize code for auth::Auth::Config +void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader) { + std::string name_regex_str{}; + std::string password_regex_str{}; + bool password_permit_null{}; + + memgraph::slk::Load(&name_regex_str, reader); + memgraph::slk::Load(&password_regex_str, reader); + memgraph::slk::Load(&password_permit_null, reader); + + *self = auth::Auth::Config{std::move(name_regex_str), std::move(password_regex_str), password_permit_null}; +} + +// Serialize code for UpdateAuthDataReq +void Save(const memgraph::replication::UpdateAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.epoch_id, builder); + memgraph::slk::Save(self.expected_group_timestamp, builder); + memgraph::slk::Save(self.new_group_timestamp, builder); + memgraph::slk::Save(self.user, builder); + memgraph::slk::Save(self.role, builder); +} +void Load(memgraph::replication::UpdateAuthDataReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->epoch_id, reader); + memgraph::slk::Load(&self->expected_group_timestamp, reader); + memgraph::slk::Load(&self->new_group_timestamp, reader); + memgraph::slk::Load(&self->user, reader); + memgraph::slk::Load(&self->role, reader); +} + +// Serialize code for UpdateAuthDataRes +void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} +void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +// Serialize code for DropAuthDataReq +void Save(const memgraph::replication::DropAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.epoch_id, builder); + memgraph::slk::Save(self.expected_group_timestamp, builder); + memgraph::slk::Save(self.new_group_timestamp, builder); + memgraph::slk::Save(utils::EnumToNum<2, uint8_t>(self.type), builder); + memgraph::slk::Save(self.name, builder); +} +void Load(memgraph::replication::DropAuthDataReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->epoch_id, reader); + memgraph::slk::Load(&self->expected_group_timestamp, reader); + memgraph::slk::Load(&self->new_group_timestamp, reader); + uint8_t type_tmp = 0; + memgraph::slk::Load(&type_tmp, reader); + if (!utils::NumToEnum<2>(type_tmp, self->type)) { + throw SlkReaderException("Unexpected result line:{}!", __LINE__); + } + memgraph::slk::Load(&self->name, reader); +} + +// Serialize code for DropAuthDataRes +void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} +void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +} // namespace memgraph::slk + +namespace memgraph::replication { + +constexpr utils::TypeInfo UpdateAuthDataReq::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_REQ, "UpdateAuthDataReq", + nullptr}; + +constexpr utils::TypeInfo UpdateAuthDataRes::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_RES, "UpdateAuthDataRes", + nullptr}; + +constexpr utils::TypeInfo DropAuthDataReq::kType{utils::TypeId::REP_DROP_AUTH_DATA_REQ, "DropAuthDataReq", nullptr}; + +constexpr utils::TypeInfo DropAuthDataRes::kType{utils::TypeId::REP_DROP_AUTH_DATA_RES, "DropAuthDataRes", nullptr}; + +void UpdateAuthDataReq::Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void UpdateAuthDataReq::Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void UpdateAuthDataRes::Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void UpdateAuthDataRes::Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void DropAuthDataReq::Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void DropAuthDataReq::Load(DropAuthDataReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void DropAuthDataRes::Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void DropAuthDataRes::Load(DropAuthDataRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } + +} // namespace memgraph::replication diff --git a/src/auth/rpc.hpp b/src/auth/rpc.hpp new file mode 100644 index 000000000..55bd403c7 --- /dev/null +++ b/src/auth/rpc.hpp @@ -0,0 +1,119 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include + +#include "auth/auth.hpp" +#include "auth/models.hpp" +#include "rpc/messages.hpp" +#include "slk/streams.hpp" + +namespace memgraph::replication { + +struct UpdateAuthDataReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader); + static void Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder); + UpdateAuthDataReq() = default; + UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::User user) + : epoch_id{std::move(epoch_id)}, + expected_group_timestamp{expected_ts}, + new_group_timestamp{new_ts}, + user{std::move(user)} {} + UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::Role role) + : epoch_id{std::move(epoch_id)}, + expected_group_timestamp{expected_ts}, + new_group_timestamp{new_ts}, + role{std::move(role)} {} + + std::string epoch_id; + uint64_t expected_group_timestamp; + uint64_t new_group_timestamp; + std::optional user; + std::optional role; +}; + +struct UpdateAuthDataRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader); + static void Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder); + UpdateAuthDataRes() = default; + explicit UpdateAuthDataRes(bool success) : success{success} {} + + bool success; +}; + +using UpdateAuthDataRpc = rpc::RequestResponse; + +struct DropAuthDataReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(DropAuthDataReq *self, memgraph::slk::Reader *reader); + static void Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder); + DropAuthDataReq() = default; + + enum class DataType { USER, ROLE }; + + DropAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, DataType type, std::string_view name) + : epoch_id{std::move(epoch_id)}, + expected_group_timestamp{expected_ts}, + new_group_timestamp{new_ts}, + type{type}, + name{name} {} + + std::string epoch_id; + uint64_t expected_group_timestamp; + uint64_t new_group_timestamp; + DataType type; + std::string name; +}; + +struct DropAuthDataRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(DropAuthDataRes *self, memgraph::slk::Reader *reader); + static void Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder); + DropAuthDataRes() = default; + explicit DropAuthDataRes(bool success) : success{success} {} + + bool success; +}; + +using DropAuthDataRpc = rpc::RequestResponse; + +} // namespace memgraph::replication + +namespace memgraph::slk { + +void Save(const auth::Role &self, memgraph::slk::Builder *builder); +void Load(auth::Role *self, memgraph::slk::Reader *reader); +void Save(const auth::User &self, memgraph::slk::Builder *builder); +void Load(auth::User *self, memgraph::slk::Reader *reader); +void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder); +void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader); +void Save(const memgraph::replication::UpdateAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/); +void Load(memgraph::replication::UpdateAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/); +void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader); +void Save(const memgraph::replication::DropAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/); +void Load(memgraph::replication::DropAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/); +} // namespace memgraph::slk diff --git a/src/communication/websocket/auth.hpp b/src/communication/websocket/auth.hpp index b3d59ade7..1ab865a2a 100644 --- a/src/communication/websocket/auth.hpp +++ b/src/communication/websocket/auth.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -14,8 +14,6 @@ #include #include "auth/auth.hpp" -#include "utils/spin_lock.hpp" -#include "utils/synchronized.hpp" namespace memgraph::communication::websocket { @@ -30,7 +28,7 @@ class AuthenticationInterface { class SafeAuth : public AuthenticationInterface { public: - explicit SafeAuth(utils::Synchronized *auth) : auth_{auth} {} + explicit SafeAuth(auth::SynchedAuth *auth) : auth_{auth} {} bool Authenticate(const std::string &username, const std::string &password) const override; @@ -39,6 +37,6 @@ class SafeAuth : public AuthenticationInterface { bool HasAnyUsers() const override; private: - utils::Synchronized *auth_; + auth::SynchedAuth *auth_; }; } // namespace memgraph::communication::websocket diff --git a/src/coordination/CMakeLists.txt b/src/coordination/CMakeLists.txt index b46843639..d44cbcd26 100644 --- a/src/coordination/CMakeLists.txt +++ b/src/coordination/CMakeLists.txt @@ -13,6 +13,7 @@ target_sources(mg-coordination include/coordination/coordinator_data.hpp include/coordination/constants.hpp include/coordination/coordinator_cluster_config.hpp + include/coordination/coordinator_handlers.hpp PRIVATE coordinator_client.cpp @@ -21,9 +22,10 @@ target_sources(mg-coordination coordinator_server.cpp coordinator_data.cpp coordinator_instance.cpp + coordinator_handlers.cpp ) target_include_directories(mg-coordination PUBLIC include) target_link_libraries(mg-coordination - PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft + PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft mg-replication_handler ) diff --git a/src/coordination/coordinator_data.cpp b/src/coordination/coordinator_data.cpp index 2af21949c..856c3e84d 100644 --- a/src/coordination/coordinator_data.cpp +++ b/src/coordination/coordinator_data.cpp @@ -183,7 +183,7 @@ auto CoordinatorData::RegisterInstance(CoordinatorClientConfig config) -> Regist if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) { return instance.SocketAddress() == config.SocketAddress(); })) { - return RegisterInstanceCoordinatorStatus::END_POINT_EXISTS; + return RegisterInstanceCoordinatorStatus::ENDPOINT_EXISTS; } try { diff --git a/src/dbms/coordinator_handlers.cpp b/src/coordination/coordinator_handlers.cpp similarity index 54% rename from src/dbms/coordinator_handlers.cpp rename to src/coordination/coordinator_handlers.cpp index 42f3a336b..63e1e4f8f 100644 --- a/src/dbms/coordinator_handlers.cpp +++ b/src/coordination/coordinator_handlers.cpp @@ -10,41 +10,35 @@ // licenses/APL.txt. #ifdef MG_ENTERPRISE +#include "coordination/coordinator_handlers.hpp" -#include "dbms/coordinator_handlers.hpp" +#include -#include "coordination/coordinator_exceptions.hpp" #include "coordination/coordinator_rpc.hpp" -#include "dbms/dbms_handler.hpp" -#include "dbms/replication_client.hpp" -#include "dbms/utils.hpp" - -#include "range/v3/view.hpp" +#include "coordination/include/coordination/coordinator_server.hpp" namespace memgraph::dbms { -void CoordinatorHandlers::Register(DbmsHandler &dbms_handler) { - auto &server = dbms_handler.CoordinatorState().GetCoordinatorServer(); - +void CoordinatorHandlers::Register(memgraph::coordination::CoordinatorServer &server, + replication::ReplicationHandler &replication_handler) { server.Register( - [&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + [&](slk::Reader *req_reader, slk::Builder *res_builder) -> void { spdlog::info("Received PromoteReplicaToMainRpc"); - CoordinatorHandlers::PromoteReplicaToMainHandler(dbms_handler, req_reader, res_builder); + CoordinatorHandlers::PromoteReplicaToMainHandler(replication_handler, req_reader, res_builder); }); server.Register( - [&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { spdlog::info("Received DemoteMainToReplicaRpc from coordinator server"); - CoordinatorHandlers::DemoteMainToReplicaHandler(dbms_handler, req_reader, res_builder); + CoordinatorHandlers::DemoteMainToReplicaHandler(replication_handler, req_reader, res_builder); }); } -void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, - slk::Builder *res_builder) { - auto &repl_state = dbms_handler.ReplicationState(); - spdlog::info("Executing SetMainToReplicaHandler"); +void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, + slk::Reader *req_reader, slk::Builder *res_builder) { + spdlog::info("Executing DemoteMainToReplicaHandler"); - if (repl_state.IsReplica()) { + if (!replication_handler.IsMain()) { spdlog::error("Setting to replica must be performed on main."); slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder); return; @@ -57,7 +51,7 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, .ip_address = req.replication_client_info.replication_ip_address, .port = req.replication_client_info.replication_port}; - if (bool const success = memgraph::dbms::SetReplicationRoleReplica(dbms_handler, clients_config); !success) { + if (!replication_handler.SetReplicationRoleReplica(clients_config)) { spdlog::error("Demoting main to replica failed!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; @@ -66,19 +60,17 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); } -void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, - slk::Builder *res_builder) { - auto &repl_state = dbms_handler.ReplicationState(); - - if (!repl_state.IsReplica()) { - spdlog::error("Only replica can be promoted to main!"); +void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler, + slk::Reader *req_reader, slk::Builder *res_builder) { + if (!replication_handler.IsReplica()) { + spdlog::error("Failover must be performed on replica!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; } // This can fail because of disk. If it does, the cluster state could get inconsistent. // We don't handle disk issues. - if (bool const success = memgraph::dbms::DoReplicaToMainPromotion(dbms_handler); !success) { + if (!replication_handler.DoReplicaToMainPromotion()) { spdlog::error("Promoting replica to main failed!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; @@ -96,53 +88,32 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, }; }; - MG_ASSERT( - std::get(repl_state.ReplicationData()).registered_replicas_.empty(), - "No replicas should be registered after promoting replica to main and before registering replication clients!"); - // registering replicas for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) { - auto instance_client = repl_state.RegisterReplica(config); + auto instance_client = replication_handler.RegisterReplica(config); if (instance_client.HasError()) { using enum memgraph::replication::RegisterReplicaError; switch (instance_client.GetError()) { - // Can't happen, we are already replica - case NOT_MAIN: - spdlog::error("Failover must be performed on main!"); - slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); - return; // Can't happen, checked on the coordinator side - case NAME_EXISTS: + case memgraph::query::RegisterReplicaError::NAME_EXISTS: spdlog::error("Replica with the same name already exists!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; // Can't happen, checked on the coordinator side - case ENDPOINT_EXISTS: + case memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS: spdlog::error("Replica with the same endpoint already exists!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; // We don't handle disk issues - case COULD_NOT_BE_PERSISTED: + case memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED: spdlog::error("Registered replica could not be persisted!"); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); return; - case SUCCESS: + case memgraph::query::RegisterReplicaError::CONNECTION_FAILED: + // Connection failure is not a fatal error break; } } - if (!allow_mt_repl && dbms_handler.All().size() > 1) { - spdlog::warn("Multi-tenant replication is currently not supported!"); - } - - auto &instance_client_ref = *instance_client.GetValue(); - - // Update system before enabling individual storage <-> replica clients - dbms_handler.SystemRestore(instance_client_ref); - - const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients(dbms_handler, instance_client_ref); - MG_ASSERT(all_clients_good, "Failed to register one or more databases to the REPLICA \"{}\".", config.name); - - StartReplicaClient(dbms_handler, instance_client_ref); } slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); diff --git a/src/dbms/coordinator_handlers.hpp b/src/coordination/include/coordination/coordinator_handlers.hpp similarity index 57% rename from src/dbms/coordinator_handlers.hpp rename to src/coordination/include/coordination/coordinator_handlers.hpp index f41de50a9..a5cd4929e 100644 --- a/src/dbms/coordinator_handlers.hpp +++ b/src/coordination/include/coordination/coordinator_handlers.hpp @@ -13,7 +13,9 @@ #ifdef MG_ENTERPRISE -#include "slk/serialization.hpp" +#include "coordination/coordinator_server.hpp" +#include "replication_handler/replication_handler.hpp" +#include "slk/streams.hpp" namespace memgraph::dbms { @@ -21,12 +23,14 @@ class DbmsHandler; class CoordinatorHandlers { public: - static void Register(DbmsHandler &dbms_handler); + static void Register(memgraph::coordination::CoordinatorServer &server, + replication::ReplicationHandler &replication_handler); private: - static void PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, + static void PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); - static void DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder); + static void DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, + slk::Builder *res_builder); }; } // namespace memgraph::dbms diff --git a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp index 2a8d199af..3e742fb3b 100644 --- a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp +++ b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp @@ -19,7 +19,7 @@ namespace memgraph::coordination { enum class RegisterInstanceCoordinatorStatus : uint8_t { NAME_EXISTS, - END_POINT_EXISTS, + ENDPOINT_EXISTS, NOT_COORDINATOR, RPC_FAILED, SUCCESS diff --git a/src/dbms/CMakeLists.txt b/src/dbms/CMakeLists.txt index 9cd94c44c..e1750f4dc 100644 --- a/src/dbms/CMakeLists.txt +++ b/src/dbms/CMakeLists.txt @@ -1,2 +1,10 @@ -add_library(mg-dbms STATIC dbms_handler.cpp database.cpp replication_handler.cpp coordinator_handler.cpp replication_client.cpp inmemory/replication_handlers.cpp coordinator_handlers.cpp) -target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-replication mg-coordination) +add_library(mg-dbms STATIC + dbms_handler.cpp + database.cpp + coordinator_handler.cpp + inmemory/replication_handlers.cpp + replication_handlers.cpp + rpc.cpp + +) +target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-auth mg-replication mg-coordination) diff --git a/src/dbms/coordinator_handler.cpp b/src/dbms/coordinator_handler.cpp index 1c062c074..958de0f91 100644 --- a/src/dbms/coordinator_handler.cpp +++ b/src/dbms/coordinator_handler.cpp @@ -18,20 +18,21 @@ namespace memgraph::dbms { -CoordinatorHandler::CoordinatorHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {} +CoordinatorHandler::CoordinatorHandler(coordination::CoordinatorState &coordinator_state) + : coordinator_state_(coordinator_state) {} auto CoordinatorHandler::RegisterInstance(memgraph::coordination::CoordinatorClientConfig config) -> coordination::RegisterInstanceCoordinatorStatus { - return dbms_handler_.CoordinatorState().RegisterInstance(config); + return coordinator_state_.RegisterInstance(config); } auto CoordinatorHandler::SetInstanceToMain(std::string instance_name) -> coordination::SetInstanceToMainCoordinatorStatus { - return dbms_handler_.CoordinatorState().SetInstanceToMain(std::move(instance_name)); + return coordinator_state_.SetInstanceToMain(std::move(instance_name)); } auto CoordinatorHandler::ShowInstances() const -> std::vector { - return dbms_handler_.CoordinatorState().ShowInstances(); + return coordinator_state_.ShowInstances(); } } // namespace memgraph::dbms diff --git a/src/dbms/coordinator_handler.hpp b/src/dbms/coordinator_handler.hpp index 6f7ad8ce5..04cfe8032 100644 --- a/src/dbms/coordinator_handler.hpp +++ b/src/dbms/coordinator_handler.hpp @@ -15,11 +15,9 @@ #include "coordination/coordinator_config.hpp" #include "coordination/coordinator_instance_status.hpp" +#include "coordination/coordinator_state.hpp" #include "coordination/register_main_replica_coordinator_status.hpp" -#include "utils/result.hpp" -#include -#include #include namespace memgraph::dbms { @@ -28,7 +26,7 @@ class DbmsHandler; class CoordinatorHandler { public: - explicit CoordinatorHandler(DbmsHandler &dbms_handler); + explicit CoordinatorHandler(coordination::CoordinatorState &coordinator_state); auto RegisterInstance(coordination::CoordinatorClientConfig config) -> coordination::RegisterInstanceCoordinatorStatus; @@ -38,7 +36,7 @@ class CoordinatorHandler { auto ShowInstances() const -> std::vector; private: - DbmsHandler &dbms_handler_; + coordination::CoordinatorState &coordinator_state_; }; } // namespace memgraph::dbms diff --git a/src/dbms/database.cpp b/src/dbms/database.cpp index 9a56d400a..4226456eb 100644 --- a/src/dbms/database.cpp +++ b/src/dbms/database.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,10 +11,7 @@ #include "dbms/database.hpp" #include "dbms/inmemory/storage_helper.hpp" -#include "dbms/replication_handler.hpp" -#include "flags/storage_mode.hpp" #include "storage/v2/disk/storage.hpp" -#include "storage/v2/inmemory/storage.hpp" #include "storage/v2/storage_mode.hpp" template struct memgraph::utils::Gatekeeper; diff --git a/src/dbms/dbms_handler.cpp b/src/dbms/dbms_handler.cpp index 7222c4461..861bcf701 100644 --- a/src/dbms/dbms_handler.cpp +++ b/src/dbms/dbms_handler.cpp @@ -11,29 +11,73 @@ #include "dbms/dbms_handler.hpp" -#include "dbms/coordinator_handlers.hpp" -#include "flags/replication.hpp" - #include #include #include "dbms/constants.hpp" #include "dbms/global.hpp" -#include "dbms/replication_client.hpp" #include "spdlog/spdlog.h" +#include "system/include/system/system.hpp" #include "utils/exceptions.hpp" #include "utils/logging.hpp" #include "utils/uuid.hpp" namespace memgraph::dbms { -#ifdef MG_ENTERPRISE - namespace { -constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability -constexpr std::string_view kLastCommitedSystemTsKey = "last_commited_system_ts"; // Key for timestamp durability +constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability + +std::string RegisterReplicaErrorToString(query::RegisterReplicaError error) { + switch (error) { + using enum query::RegisterReplicaError; + case NAME_EXISTS: + return "NAME_EXISTS"; + case ENDPOINT_EXISTS: + return "ENDPOINT_EXISTS"; + case CONNECTION_FAILED: + return "CONNECTION_FAILED"; + case COULD_NOT_BE_PERSISTED: + return "COULD_NOT_BE_PERSISTED"; + } +} + +// Per storage +// NOTE Storage will connect to all replicas. Future work might change this +void RestoreReplication(replication::RoleMainData &mainData, DatabaseAccess db_acc) { + spdlog::info("Restoring replication role."); + + // Each individual client has already been restored and started. Here we just go through each database and start its + // client + for (auto &instance_client : mainData.registered_replicas_) { + spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name()); + const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock( + [&, db_acc](auto &storage_clients) mutable -> utils::BasicResult { + auto client = std::make_unique(instance_client); + auto *storage = db_acc->storage(); + client->Start(storage, std::move(db_acc)); + // After start the storage <-> replica state should be READY or RECOVERING (if correctly started) + // MAYBE_BEHIND isn't a statement of the current state, this is the default value + // Failed to start due to branching of MAIN and REPLICA + if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) { + spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.", + instance_client.name_); + } + storage_clients.push_back(std::move(client)); + return {}; + }); + + if (ret.HasError()) { + MG_ASSERT(query::RegisterReplicaError::CONNECTION_FAILED != ret.GetError()); + LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_, + RegisterReplicaErrorToString(ret.GetError())); + } + spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name()); + } + spdlog::info("Replication role restored to MAIN."); +} } // namespace +#ifdef MG_ENTERPRISE struct Durability { enum class DurabilityVersion : uint8_t { V0 = 0, @@ -112,11 +156,9 @@ struct Durability { } }; -DbmsHandler::DbmsHandler( - storage::Config config, - memgraph::utils::Synchronized *auth, - bool recovery_on_startup) - : default_config_{std::move(config)}, repl_state_{ReplicationStateRootPath(default_config_)} { +DbmsHandler::DbmsHandler(storage::Config config, memgraph::system::System &system, + replication::ReplicationState &repl_state, auth::SynchedAuth &auth, bool recovery_on_startup) + : default_config_{std::move(config)}, auth_{auth}, repl_state_{repl_state}, system_{&system} { // TODO: Decouple storage config from dbms config // TODO: Save individual db configs inside the kvstore and restore from there @@ -150,19 +192,13 @@ DbmsHandler::DbmsHandler( const auto uuid = json.at("uuid").get(); const auto rel_dir = json.at("rel_dir").get(); spdlog::info("Restoring database {} at {}.", name, rel_dir); - auto new_db = New_(name, uuid, rel_dir); + auto new_db = New_(name, uuid, nullptr, rel_dir); MG_ASSERT(!new_db.HasError(), "Failed while creating database {}.", name); directories.emplace(rel_dir.filename()); spdlog::info("Database {} restored.", name); } - // Read the last timestamp - auto lcst = durability_->Get(kLastCommitedSystemTsKey); - if (lcst) { - last_commited_system_timestamp_ = std::stoul(*lcst); - system_timestamp_ = last_commited_system_timestamp_; - } } else { // Clear databases from the durability list and auth - auto locked_auth = auth->Lock(); + auto locked_auth = auth_.Lock(); auto it = durability_->begin(std::string{kDBPrefix}); auto end = durability_->end(std::string{kDBPrefix}); for (; it != end; ++it) { @@ -172,8 +208,6 @@ DbmsHandler::DbmsHandler( locked_auth->DeleteDatabase(name); durability_->Delete(key); } - // Delete the last timestamp - durability_->Delete(kLastCommitedSystemTsKey); } /* @@ -198,45 +232,29 @@ DbmsHandler::DbmsHandler( */ // Setup the default DB SetupDefault_(); - - /* - * REPLICATION RECOVERY AND STARTUP - */ - // Startup replication state (if recovered at startup) - auto replica = [this](replication::RoleReplicaData const &data) { return StartRpcServer(*this, data); }; - // Replication recovery and frequent check start - auto main = [this](replication::RoleMainData &data) { - for (auto &client : data.registered_replicas_) { - SystemRestore(client); - } - ForEach([this](DatabaseAccess db) { RecoverReplication(db); }); - for (auto &client : data.registered_replicas_) { - StartReplicaClient(*this, client); - } - return true; - }; - // Startup proccess for main/replica - MG_ASSERT(std::visit(memgraph::utils::Overloaded{replica, main}, repl_state_.ReplicationData()), - "Replica recovery failure!"); - - // Warning - if (default_config_.durability.snapshot_wal_mode == storage::Config::Durability::SnapshotWalMode::DISABLED && - repl_state_.IsMain()) { - spdlog::warn( - "The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please " - "consider " - "enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because " - "without write-ahead logs this instance is not replicating any data."); - } - - // MAIN or REPLICA instance - if (FLAGS_coordinator_server_port) { - CoordinatorHandlers::Register(*this); - MG_ASSERT(coordinator_state_.GetCoordinatorServer().Start(), "Failed to start coordinator server!"); - } } -DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) { +struct DropDatabase : memgraph::system::ISystemAction { + explicit DropDatabase(utils::UUID uuid) : uuid_{uuid} {} + void DoDurability() override { /* Done during DBMS execution */ + } + + bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + memgraph::system::Transaction const &txn) const override { + auto check_response = [](const storage::replication::DropDatabaseRes &response) { + return response.result != storage::replication::DropDatabaseRes::Result::FAILURE; + }; + + return client.SteamAndFinalizeDelta( + check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), uuid_); + } + void PostReplication(replication::RoleMainData &mainData) const override {} + + private: + utils::UUID uuid_; +}; + +DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name, system::Transaction *transaction) { std::lock_guard wr(lock_); if (db_name == kDefaultDB) { // MSG cannot delete the default db @@ -273,9 +291,10 @@ DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) { // Success // Save delta - if (system_transaction_) { - system_transaction_->delta.emplace(SystemTransaction::Delta::drop_database, uuid); + if (transaction) { + transaction->AddAction(uuid); } + return {}; } @@ -296,18 +315,48 @@ DbmsHandler::DeleteResult DbmsHandler::Delete(utils::UUID uuid) { return Delete_(db_name); } -DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config) { +struct CreateDatabase : memgraph::system::ISystemAction { + explicit CreateDatabase(storage::SalientConfig config, DatabaseAccess db_acc) + : config_{std::move(config)}, db_acc(db_acc) {} + + void DoDurability() override { + // Done during dbms execution + } + + bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch, + memgraph::system::Transaction const &txn) const override { + auto check_response = [](const storage::replication::CreateDatabaseRes &response) { + return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE; + }; + + return client.SteamAndFinalizeDelta( + check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), config_); + } + + void PostReplication(replication::RoleMainData &mainData) const override { + // Sync database with REPLICAs + // NOTE: The function bellow is used to create ReplicationStorageClient, so it must be called on a new storage + // We don't need to have it here, since the function won't fail even if the replication client fails to + // connect We will just have everything ready, for recovery at some point. + dbms::DbmsHandler::RecoverStorageReplication(db_acc, mainData); + } + + private: + storage::SalientConfig config_; + DatabaseAccess db_acc; +}; + +DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config, system::Transaction *txn) { auto new_db = db_handler_.New(storage_config, repl_state_); if (new_db.HasValue()) { // Success // Save delta - if (system_transaction_) { - system_transaction_->delta.emplace(SystemTransaction::Delta::create_database, storage_config.salient); - } UpdateDurability(storage_config); - return new_db.GetValue(); + if (txn) { + txn->AddAction(storage_config.salient, new_db.GetValue()); + } } - return new_db.GetError(); + return new_db; } DbmsHandler::DeleteResult DbmsHandler::Delete_(std::string_view db_name) { @@ -361,89 +410,16 @@ void DbmsHandler::UpdateDurability(const storage::Config &config, std::optional< durability_->Put(key, val); } -AllSyncReplicaStatus DbmsHandler::Commit() { - if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt) - return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit - const auto &delta = *system_transaction_->delta; - - auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed; - // TODO Create a system client that can handle all of this automatically - switch (delta.action) { - using enum SystemTransaction::Delta::Action; - case CREATE_DATABASE: { - // Replication - auto main_handler = [&](memgraph::replication::RoleMainData &main_data) { - // TODO: data race issue? registered_replicas_ access not protected - // This is sync in any case, as this is the startup - for (auto &client : main_data.registered_replicas_) { - bool completed = SteamAndFinalizeDelta( - client, - [](const storage::replication::CreateDatabaseRes &response) { - return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE; - }, - std::string(main_data.epoch_.id()), last_commited_system_timestamp_, - system_transaction_->system_timestamp, delta.config); - // TODO: reduce duplicate code - if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) { - sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed; - } - } - // Sync database with REPLICAs - RecoverReplication(Get_(delta.config.name)); - }; - auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ }; - std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); - } break; - case DROP_DATABASE: { - // Replication - auto main_handler = [&](memgraph::replication::RoleMainData &main_data) { - // TODO: data race issue? registered_replicas_ access not protected - // This is sync in any case, as this is the startup - for (auto &client : main_data.registered_replicas_) { - bool completed = SteamAndFinalizeDelta( - client, - [](const storage::replication::DropDatabaseRes &response) { - return response.result != storage::replication::DropDatabaseRes::Result::FAILURE; - }, - std::string(main_data.epoch_.id()), last_commited_system_timestamp_, - system_transaction_->system_timestamp, delta.uuid); - // TODO: reduce duplicate code - if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) { - sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed; - } - } - }; - auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ }; - std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); - } break; - } - - durability_->Put(kLastCommitedSystemTsKey, std::to_string(system_transaction_->system_timestamp)); - last_commited_system_timestamp_ = system_transaction_->system_timestamp; - ResetSystemTransaction(); - return sync_status; -} - -#else // not MG_ENTERPRISE - -AllSyncReplicaStatus DbmsHandler::Commit() { - if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt) { - return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit - } - const auto &delta = *system_transaction_->delta; - - switch (delta.action) { - using enum SystemTransaction::Delta::Action; - case CREATE_DATABASE: - case DROP_DATABASE: - /* Community edition doesn't support multi-tenant replication */ - break; - } - - last_commited_system_timestamp_ = system_transaction_->system_timestamp; - ResetSystemTransaction(); - return AllSyncReplicaStatus::AllCommitsConfirmed; -} - #endif + +void DbmsHandler::RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data) { + if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) { + // Handle global replication state + spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash."); + // RECOVER REPLICA CONNECTIONS + memgraph::dbms::RestoreReplication(role_main_data, db_acc); + } else if (!role_main_data.registered_replicas_.empty()) { + spdlog::warn("Multi-tenant replication is currently not supported!"); + } +} } // namespace memgraph::dbms diff --git a/src/dbms/dbms_handler.hpp b/src/dbms/dbms_handler.hpp index 2066321e2..24a7599a2 100644 --- a/src/dbms/dbms_handler.hpp +++ b/src/dbms/dbms_handler.hpp @@ -25,24 +25,24 @@ #include "constants.hpp" #include "dbms/database.hpp" #include "dbms/inmemory/replication_handlers.hpp" -#include "dbms/replication_handler.hpp" +#include "dbms/rpc.hpp" #include "kvstore/kvstore.hpp" +#include "license/license.hpp" #include "replication/replication_client.hpp" #include "storage/v2/config.hpp" -#include "storage/v2/replication/enums.hpp" -#include "storage/v2/replication/rpc.hpp" #include "storage/v2/transaction.hpp" +#include "system/system.hpp" #include "utils/thread_pool.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_state.hpp" #include "dbms/database_handler.hpp" #endif -#include "dbms/transaction.hpp" #include "global.hpp" #include "query/config.hpp" #include "query/interpreter_context.hpp" #include "spdlog/spdlog.h" #include "storage/v2/isolation_level.hpp" +#include "system/system.hpp" #include "utils/logging.hpp" #include "utils/result.hpp" #include "utils/rw_lock.hpp" @@ -51,11 +51,6 @@ namespace memgraph::dbms { -enum class AllSyncReplicaStatus { - AllCommitsConfirmed, - SomeCommitsUnconfirmed, -}; - struct Statistics { uint64_t num_vertex; //!< Sum of vertexes in every database uint64_t num_edges; //!< Sum of edges in every database @@ -111,8 +106,8 @@ class DbmsHandler { * @param auth pointer to the global authenticator * @param recovery_on_startup restore databases (and its content) and authentication data */ - DbmsHandler(storage::Config config, - memgraph::utils::Synchronized *auth, + DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state, + auth::SynchedAuth &auth, bool recovery_on_startup); // TODO If more arguments are added use a config struct #else /** @@ -120,15 +115,14 @@ class DbmsHandler { * * @param configs storage configuration */ - DbmsHandler(storage::Config config) - : repl_state_{ReplicationStateRootPath(config)}, + DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state) + : repl_state_{repl_state}, + system_{&system}, db_gatekeeper_{[&] { config.salient.name = kDefaultDB; return std::move(config); }(), - repl_state_} { - RecoverReplication(Get()); - } + repl_state_} {} #endif #ifdef MG_ENTERPRISE @@ -138,10 +132,10 @@ class DbmsHandler { * @param name name of the database * @return NewResultT context on success, error on failure */ - NewResultT New(const std::string &name) { + NewResultT New(const std::string &name, system::Transaction *txn = nullptr) { std::lock_guard wr(lock_); const auto uuid = utils::UUID{}; - return New_(name, uuid); + return New_(name, uuid, txn); } /** @@ -234,7 +228,7 @@ class DbmsHandler { * @param db_name database name * @return DeleteResult error on failure */ - DeleteResult TryDelete(std::string_view db_name); + DeleteResult TryDelete(std::string_view db_name, system::Transaction *transaction = nullptr); /** * @brief Delete or defer deletion of database. @@ -267,23 +261,12 @@ class DbmsHandler { #endif } - replication::ReplicationState &ReplicationState() { return repl_state_; } - replication::ReplicationState const &ReplicationState() const { return repl_state_; } - - bool IsMain() const { return repl_state_.IsMain(); } - bool IsReplica() const { return repl_state_.IsReplica(); } - -#ifdef MG_ENTERPRISE - coordination::CoordinatorState &CoordinatorState() { return coordinator_state_; } -#endif - /** * @brief Return the statistics all databases. * * @return Statistics */ - Statistics Stats() { - auto const replication_role = repl_state_.GetRole(); + Statistics Stats(memgraph::replication_coordination_glue::ReplicationRole replication_role) { Statistics stats{}; // TODO: Handle overflow? #ifdef MG_ENTERPRISE @@ -319,8 +302,7 @@ class DbmsHandler { * * @return std::vector */ - std::vector Info() { - auto const replication_role = repl_state_.GetRole(); + std::vector Info(memgraph::replication_coordination_glue::ReplicationRole replication_role) { std::vector res; #ifdef MG_ENTERPRISE std::shared_lock rd(lock_); @@ -407,98 +389,17 @@ class DbmsHandler { } } - void NewSystemTransaction() { - DMG_ASSERT(!system_transaction_, "Already running a system transaction"); - system_transaction_.emplace(++system_timestamp_); - } - - void ResetSystemTransaction() { system_transaction_.reset(); } - - //! \tparam RPC An rpc::RequestResponse - //! \tparam Args the args type - //! \param client the client to use for rpc communication - //! \param check predicate to check response is ok - //! \param args arguments to forward to the rpc request - //! \return If replica stream is completed or enqueued - template - bool SteamAndFinalizeDelta(auto &client, auto &&check, Args &&...args) { - try { - auto stream = client.rpc_client_.template Stream(std::forward(args)...); - auto task = [&client, check = std::forward(check), stream = std::move(stream)]() mutable { - if (stream.IsDefunct()) { - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - return false; - } - try { - if (check(stream.AwaitResponse())) { - return true; - } - } catch (memgraph::rpc::GenericRpcFailedException const &e) { - // swallow error, fallthrough to error handling - } - // This replica needs SYSTEM recovery - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - return false; - }; - - if (client.mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) { - client.thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); }); - return true; - } - - return task(); - } catch (memgraph::rpc::GenericRpcFailedException const &e) { - // This replica needs SYSTEM recovery - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - return false; - } - }; - - AllSyncReplicaStatus Commit(); - - auto LastCommitedTS() const -> uint64_t { return last_commited_system_timestamp_; } - void SetLastCommitedTS(uint64_t new_ts) { last_commited_system_timestamp_.store(new_ts); } + static void RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data); + auto default_config() const -> storage::Config const & { #ifdef MG_ENTERPRISE - // When being called by intepreter no need to gain lock, it should already be under a system transaction - // But concurrently the FrequentCheck is running and will need to lock before reading last_commited_system_timestamp_ - template - void SystemRestore(replication::ReplicationClient &client) { - // Check if system is up to date - if (client.state_.WithLock( - [](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; })) - return; - - // Try to recover... - { - auto [database_configs, last_commited_system_timestamp] = std::invoke([&] { - auto sys_guard = - std::unique_lock{system_lock_, std::defer_lock}; // ensure no other system transaction in progress - if constexpr (REQUIRE_LOCK) { - sys_guard.lock(); - } - auto configs = std::vector{}; - ForEach([&configs](DatabaseAccess acc) { configs.emplace_back(acc->config().salient); }); - return std::pair{configs, last_commited_system_timestamp_.load()}; - }); - try { - auto stream = client.rpc_client_.Stream(last_commited_system_timestamp, - std::move(database_configs)); - const auto response = stream.AwaitResponse(); - if (response.result == storage::replication::SystemRecoveryRes::Result::FAILURE) { - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - return; - } - } catch (memgraph::rpc::GenericRpcFailedException const &e) { - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - return; - } - } - - // Successfully recovered - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; }); - } + return default_config_; +#else + const auto acc = db_gatekeeper_.access(); + MG_ASSERT(acc, "Failed to get default database!"); + return acc->get()->config(); #endif + } private: #ifdef MG_ENTERPRISE @@ -524,7 +425,8 @@ class DbmsHandler { * @param uuid undelying RocksDB directory * @return NewResultT context on success, error on failure */ - NewResultT New_(std::string_view name, utils::UUID uuid, std::optional rel_dir = {}) { + NewResultT New_(std::string_view name, utils::UUID uuid, system::Transaction *txn = nullptr, + std::optional rel_dir = {}) { auto config_copy = default_config_; config_copy.salient.name = name; config_copy.salient.uuid = uuid; @@ -535,7 +437,7 @@ class DbmsHandler { storage::UpdatePaths(config_copy, default_config_.durability.storage_directory / kMultiTenantDir / std::string{uuid}); } - return New_(std::move(config_copy)); + return New_(std::move(config_copy), txn); } /** @@ -544,11 +446,11 @@ class DbmsHandler { * @param config configuration to be used * @return NewResultT context on success, error on failure */ - NewResultT New_(const storage::SalientConfig &config) { + NewResultT New_(const storage::SalientConfig &config, system::Transaction *txn = nullptr) { auto config_copy = default_config_; config_copy.salient = config; // name, uuid, mode, etc UpdatePaths(config_copy, config_copy.durability.storage_directory / kMultiTenantDir / std::string{config.uuid}); - return New_(std::move(config_copy)); + return New_(std::move(config_copy), txn); } /** @@ -557,7 +459,7 @@ class DbmsHandler { * @param storage_config storage configuration * @return NewResultT context on success, error on failure */ - NewResultT New_(storage::Config storage_config); + DbmsHandler::NewResultT New_(storage::Config storage_config, system::Transaction *txn = nullptr); // TODO: new overload of Delete_ with DatabaseAccess DeleteResult Delete_(std::string_view db_name); @@ -572,7 +474,8 @@ class DbmsHandler { Get(kDefaultDB); } catch (const UnknownDatabaseException &) { // No default DB restored, create it - MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, ".").HasValue(), "Failed while creating the default database"); + MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, nullptr, ".").HasValue(), + "Failed while creating the default database"); } // For back-compatibility... @@ -659,35 +562,24 @@ class DbmsHandler { } #endif - void RecoverReplication(DatabaseAccess db_acc) { - if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) { - // Handle global replication state - spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash."); - // RECOVER REPLICA CONNECTIONS - memgraph::dbms::RestoreReplication(repl_state_, std::move(db_acc)); - } else if (const ::memgraph::replication::RoleMainData *data = - std::get_if<::memgraph::replication::RoleMainData>(&repl_state_.ReplicationData()); - data && !data->registered_replicas_.empty()) { - spdlog::warn("Multi-tenant replication is currently not supported!"); - } - } - #ifdef MG_ENTERPRISE mutable LockT lock_{utils::RWLock::Priority::READ}; //!< protective lock storage::Config default_config_; //!< Storage configuration used when creating new databases DatabaseHandler db_handler_; //!< multi-tenancy storage handler - std::unique_ptr durability_; //!< list of active dbs (pointer so we can postpone its creation) - coordination::CoordinatorState coordinator_state_; //!< Replication coordinator + // TODO: move to be common + std::unique_ptr durability_; //!< list of active dbs (pointer so we can postpone its creation) + auth::SynchedAuth &auth_; //!< Synchronized auth::Auth #endif - // TODO: Make an api - public: - utils::ResourceLock system_lock_{}; //!> Ensure exclusive access for system queries private: - std::optional system_transaction_; //!< Current system transaction (only one at a time) - uint64_t system_timestamp_{storage::kTimestampInitialId}; //!< System timestamp - std::atomic_uint64_t last_commited_system_timestamp_{ - storage::kTimestampInitialId}; //!< Last commited system timestamp - replication::ReplicationState repl_state_; //!< Global replication state + // NOTE: atm the only reason this exists here, is because we pass it into the construction of New Database's + // Database only uses it as a convience to make the correct Access without out needing to be told the + // current replication role. TODO: make Database Access explicit about the role and remove this from + // dbms stuff + replication::ReplicationState &repl_state_; //!< Ref to global replication state + public: + // TODO fix to be non public/remove from dbms....maybe + system::System *system_; + #ifndef MG_ENTERPRISE mutable utils::Gatekeeper db_gatekeeper_; //!< Single databases gatekeeper #endif diff --git a/src/dbms/inmemory/storage_helper.hpp b/src/dbms/inmemory/storage_helper.hpp index fa1b9646a..900ad5356 100644 --- a/src/dbms/inmemory/storage_helper.hpp +++ b/src/dbms/inmemory/storage_helper.hpp @@ -11,10 +11,6 @@ #pragma once -#include - -#include "dbms/constants.hpp" -#include "dbms/replication_handler.hpp" #include "replication/state.hpp" #include "storage/v2/config.hpp" #include "storage/v2/inmemory/storage.hpp" diff --git a/src/dbms/replication_client.cpp b/src/dbms/replication_client.cpp deleted file mode 100644 index fa0c30daa..000000000 --- a/src/dbms/replication_client.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#include "dbms/replication_client.hpp" -#include "replication/replication_client.hpp" - -namespace memgraph::dbms { - -void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client) { - // No client error, start instance level client - auto const &endpoint = client.rpc_client_.Endpoint(); - spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port); - client.StartFrequentCheck([&dbms_handler](bool reconnect, replication::ReplicationClient &client) { - // Working connection - // Check if system needs restoration - if (reconnect) { - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); - } -#ifdef MG_ENTERPRISE - dbms_handler.SystemRestore(client); -#endif - // Check if any database has been left behind - dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) { - // Specific database <-> replica client - db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) { - if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) { - // Database <-> replica might be behind, check and recover - client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc); - } - }); - }); - }); -} // namespace memgraph::dbms - -} // namespace memgraph::dbms diff --git a/src/dbms/replication_handler.cpp b/src/dbms/replication_handler.cpp deleted file mode 100644 index 285752f76..000000000 --- a/src/dbms/replication_handler.cpp +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#include "dbms/replication_handler.hpp" - -#include - -#include "dbms/constants.hpp" -#include "dbms/dbms_handler.hpp" -#include "dbms/global.hpp" -#include "dbms/inmemory/replication_handlers.hpp" -#include "dbms/replication_client.hpp" -#include "dbms/utils.hpp" -#include "replication/messages.hpp" -#include "replication/state.hpp" -#include "spdlog/spdlog.h" -#include "storage/v2/config.hpp" -#include "storage/v2/replication/rpc.hpp" -#include "utils/on_scope_exit.hpp" - -using memgraph::replication::RoleMainData; -using memgraph::replication::RoleReplicaData; - -namespace memgraph::dbms { - -namespace { - -std::string RegisterReplicaErrorToString(RegisterReplicaError error) { - switch (error) { - using enum RegisterReplicaError; - case NAME_EXISTS: - return "NAME_EXISTS"; - case ENDPOINT_EXISTS: - return "ENDPOINT_EXISTS"; - case CONNECTION_FAILED: - return "CONNECTION_FAILED"; - case COULD_NOT_BE_PERSISTED: - return "COULD_NOT_BE_PERSISTED"; - } -} -} // namespace - -ReplicationHandler::ReplicationHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {} - -bool ReplicationHandler::SetReplicationRoleMain() { - auto const main_handler = [](RoleMainData &) { - // If we are already MAIN, we don't want to change anything - return false; - }; - - auto const replica_handler = [this](RoleReplicaData const &) { - return memgraph::dbms::DoReplicaToMainPromotion(dbms_handler_); - }; - - // TODO: under lock - return std::visit(utils::Overloaded{main_handler, replica_handler}, - dbms_handler_.ReplicationState().ReplicationData()); -} - -bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) { - // We don't want to restart the server if we're already a REPLICA - if (dbms_handler_.ReplicationState().IsReplica()) { - return false; - } - - // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are - // deleting the replica. - // Remove database specific clients - dbms_handler_.ForEach([&](DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); - }); - // Remove instance level clients - std::get(dbms_handler_.ReplicationState().ReplicationData()).registered_replicas_.clear(); - - // Creates the server - dbms_handler_.ReplicationState().SetReplicationRoleReplica(config); - - // Start - const auto success = - std::visit(utils::Overloaded{[](RoleMainData const &) { - // ASSERT - return false; - }, - [this](RoleReplicaData const &data) { return StartRpcServer(dbms_handler_, data); }}, - dbms_handler_.ReplicationState().ReplicationData()); - // TODO Handle error (restore to main?) - return success; -} - -auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) - -> memgraph::utils::BasicResult { - MG_ASSERT(dbms_handler_.ReplicationState().IsMain(), "Only main instance can register a replica!"); - - auto maybe_client = dbms_handler_.ReplicationState().RegisterReplica(config); - if (maybe_client.HasError()) { - switch (maybe_client.GetError()) { - case memgraph::replication::RegisterReplicaError::NOT_MAIN: - MG_ASSERT(false, "Only main instance can register a replica!"); - return {}; - case memgraph::replication::RegisterReplicaError::NAME_EXISTS: - return memgraph::dbms::RegisterReplicaError::NAME_EXISTS; - case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS: - return memgraph::dbms::RegisterReplicaError::ENDPOINT_EXISTS; - case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED: - return memgraph::dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED; - case memgraph::replication::RegisterReplicaError::SUCCESS: - break; - } - } - - if (!allow_mt_repl && dbms_handler_.All().size() > 1) { - spdlog::warn("Multi-tenant replication is currently not supported!"); - } - -#ifdef MG_ENTERPRISE - // Update system before enabling individual storage <-> replica clients - dbms_handler_.SystemRestore(*maybe_client.GetValue()); -#endif - - const auto dbms_error = memgraph::dbms::HandleRegisterReplicaStatus(maybe_client); - if (dbms_error.has_value()) { - return *dbms_error; - } - auto &instance_client_ptr = maybe_client.GetValue(); - const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients(dbms_handler_, *instance_client_ptr); - - // NOTE Currently if any databases fails, we revert back - if (!all_clients_good) { - spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name); - UnregisterReplica(config.name); - return RegisterReplicaError::CONNECTION_FAILED; - } - - // No client error, start instance level client - StartReplicaClient(dbms_handler_, *instance_client_ptr); - return {}; -} - -auto ReplicationHandler::UnregisterReplica(std::string_view name) -> UnregisterReplicaResult { - auto const replica_handler = [](RoleReplicaData const &) -> UnregisterReplicaResult { - return UnregisterReplicaResult::NOT_MAIN; - }; - auto const main_handler = [this, name](RoleMainData &mainData) -> UnregisterReplicaResult { - if (!dbms_handler_.ReplicationState().TryPersistUnregisterReplica(name)) { - return UnregisterReplicaResult::COULD_NOT_BE_PERSISTED; - } - // Remove database specific clients - dbms_handler_.ForEach([name](DatabaseAccess db_acc) { - db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) { - std::erase_if(clients, [name](const auto &client) { return client->Name() == name; }); - }); - }); - // Remove instance level clients - auto const n_unregistered = - std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; }); - return n_unregistered != 0 ? UnregisterReplicaResult::SUCCESS : UnregisterReplicaResult::CAN_NOT_UNREGISTER; - }; - - return std::visit(utils::Overloaded{main_handler, replica_handler}, - dbms_handler_.ReplicationState().ReplicationData()); -} - -auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole { - return dbms_handler_.ReplicationState().GetRole(); -} - -bool ReplicationHandler::IsMain() const { return dbms_handler_.ReplicationState().IsMain(); } - -bool ReplicationHandler::IsReplica() const { return dbms_handler_.ReplicationState().IsReplica(); } - -// Per storage -// NOTE Storage will connect to all replicas. Future work might change this -void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc) { - spdlog::info("Restoring replication role."); - - /// MAIN - auto const recover_main = [db_acc = std::move(db_acc)](RoleMainData &mainData) mutable { // NOLINT - // Each individual client has already been restored and started. Here we just go through each database and start its - // client - for (auto &instance_client : mainData.registered_replicas_) { - spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name()); - const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock( - [&, db_acc](auto &storage_clients) mutable -> utils::BasicResult { - auto client = std::make_unique(instance_client); - auto *storage = db_acc->storage(); - client->Start(storage, std::move(db_acc)); - // After start the storage <-> replica state should be READY or RECOVERING (if correctly started) - // MAYBE_BEHIND isn't a statement of the current state, this is the default value - // Failed to start due to branching of MAIN and REPLICA - if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) { - spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.", - instance_client.name_); - } - storage_clients.push_back(std::move(client)); - return {}; - }); - - if (ret.HasError()) { - MG_ASSERT(RegisterReplicaError::CONNECTION_FAILED != ret.GetError()); - LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_, - RegisterReplicaErrorToString(ret.GetError())); - } - spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name()); - } - spdlog::info("Replication role restored to MAIN."); - }; - - /// REPLICA - auto const recover_replica = [](RoleReplicaData const &data) { /*nothing to do*/ }; - - std::visit( - utils::Overloaded{ - recover_main, - recover_replica, - }, - repl_state.ReplicationData()); -} - -namespace system_replication { -#ifdef MG_ENTERPRISE -void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) { - replication::SystemHeartbeatReq req; - replication::SystemHeartbeatReq::Load(&req, req_reader); - - replication::SystemHeartbeatRes res(ts); - memgraph::slk::Save(res, res_builder); -} - -void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) { - memgraph::storage::replication::CreateDatabaseReq req; - memgraph::slk::Load(&req, req_reader); - - using memgraph::storage::replication::CreateDatabaseRes; - CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE); - - // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot - // of the set of databases. Hence no history exists to maintain regarding epoch change. - // If MAIN has changed we need to check this new group_timestamp is consistent with - // what we have so far. - - if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) { - spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp, - dbms_handler.LastCommitedTS()); - memgraph::slk::Save(res, res_builder); - return; - } - - try { - // Create new - auto new_db = dbms_handler.Update(req.config); - if (new_db.HasValue()) { - // Successfully create db - dbms_handler.SetLastCommitedTS(req.new_group_timestamp); - res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS); - spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); - } - } catch (...) { - // Failure - } - - memgraph::slk::Save(res, res_builder); -} - -void DropDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) { - memgraph::storage::replication::DropDatabaseReq req; - memgraph::slk::Load(&req, req_reader); - - using memgraph::storage::replication::DropDatabaseRes; - DropDatabaseRes res(DropDatabaseRes::Result::FAILURE); - - // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot - // of the set of databases. Hence no history exists to maintain regarding epoch change. - // If MAIN has changed we need to check this new group_timestamp is consistent with - // what we have so far. - - if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) { - spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp, - dbms_handler.LastCommitedTS()); - memgraph::slk::Save(res, res_builder); - return; - } - - try { - // NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment. - auto new_db = dbms_handler.Delete(req.uuid); - if (new_db.HasError()) { - if (new_db.GetError() == DeleteError::NON_EXISTENT) { - // Nothing to drop - dbms_handler.SetLastCommitedTS(req.new_group_timestamp); - res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED); - } - } else { - // Successfully drop db - dbms_handler.SetLastCommitedTS(req.new_group_timestamp); - res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS); - spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); - } - } catch (...) { - // Failure - } - - memgraph::slk::Save(res, res_builder); -} - -void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) { - // TODO Speed up - memgraph::storage::replication::SystemRecoveryReq req; - memgraph::slk::Load(&req, req_reader); - - using memgraph::storage::replication::SystemRecoveryRes; - SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE); - - utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); }); - - // Get all current dbs - auto old = dbms_handler.All(); - - // Check/create the incoming dbs - for (const auto &config : req.database_configs) { - // Missing db - try { - if (dbms_handler.Update(config).HasError()) { - spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name); - return; // Send failure on exit - } - } catch (const UnknownDatabaseException &) { - spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException"); - return; // Send failure on exit - } - const auto it = std::find(old.begin(), old.end(), config.name); - if (it != old.end()) old.erase(it); - } - - // Delete all the leftover old dbs - for (const auto &remove_db : old) { - const auto del = dbms_handler.Delete(remove_db); - if (del.HasError()) { - // Some errors are not terminal - if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) { - spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db); - continue; - } - spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db); - return; // Send failure on exit - } - } - // Successfully recovered - dbms_handler.SetLastCommitedTS(req.forced_group_timestamp); - spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp); - res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS); -} -#endif - -void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler) { -#ifdef MG_ENTERPRISE - data.server->rpc_server_.Register( - [&dbms_handler](auto *req_reader, auto *res_builder) { - spdlog::debug("Received SystemHeartbeatRpc"); - SystemHeartbeatHandler(dbms_handler.LastCommitedTS(), req_reader, res_builder); - }); - data.server->rpc_server_.Register( - [&dbms_handler](auto *req_reader, auto *res_builder) { - spdlog::debug("Received CreateDatabaseRpc"); - CreateDatabaseHandler(dbms_handler, req_reader, res_builder); - }); - data.server->rpc_server_.Register( - [&dbms_handler](auto *req_reader, auto *res_builder) { - spdlog::debug("Received DropDatabaseRpc"); - DropDatabaseHandler(dbms_handler, req_reader, res_builder); - }); - data.server->rpc_server_.Register( - [&dbms_handler](auto *req_reader, auto *res_builder) { - spdlog::debug("Received SystemRecoveryRpc"); - SystemRecoveryHandler(dbms_handler, req_reader, res_builder); - }); -#endif -} -} // namespace system_replication - -bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) { - // Register handlers - InMemoryReplicationHandlers::Register(&dbms_handler, *data.server); - system_replication::Register(data, dbms_handler); - // Start server - if (!data.server->Start()) { - spdlog::error("Unable to start the replication server."); - return false; - } - return true; -} -} // namespace memgraph::dbms diff --git a/src/dbms/replication_handler.hpp b/src/dbms/replication_handler.hpp deleted file mode 100644 index 53c64e34b..000000000 --- a/src/dbms/replication_handler.hpp +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#include "replication_coordination_glue/role.hpp" -#include "dbms/database.hpp" -#include "utils/result.hpp" - -namespace memgraph::replication { -struct ReplicationState; -struct ReplicationServerConfig; -struct ReplicationClientConfig; -} // namespace memgraph::replication - -namespace memgraph::dbms { - -class DbmsHandler; - -enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED }; - -enum class UnregisterReplicaResult : uint8_t { - NOT_MAIN, - COULD_NOT_BE_PERSISTED, - CAN_NOT_UNREGISTER, - SUCCESS, -}; - -/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage -/// TODO: extend to do multiple storages -struct ReplicationHandler { - explicit ReplicationHandler(DbmsHandler &dbms_handler); - - // as REPLICA, become MAIN - bool SetReplicationRoleMain(); - - // as MAIN, become REPLICA - bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config); - - // as MAIN, define and connect to REPLICAs - auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) - -> utils::BasicResult; - - // as MAIN, remove a REPLICA connection - auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult; - - // Helper pass-through (TODO: remove) - auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole; - bool IsMain() const; - bool IsReplica() const; - - private: - DbmsHandler &dbms_handler_; -}; - -/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage -/// TODO: extend to do multiple storages -void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc); - -namespace system_replication { -// System handlers -#ifdef MG_ENTERPRISE -void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder); -void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder); -void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder); -#endif - -/// Register all DBMS level RPC handlers -void Register(replication::RoleReplicaData const &data, DbmsHandler &dbms_handler); -} // namespace system_replication - -bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data); - -} // namespace memgraph::dbms diff --git a/src/dbms/replication_handlers.cpp b/src/dbms/replication_handlers.cpp new file mode 100644 index 000000000..2c77262fa --- /dev/null +++ b/src/dbms/replication_handlers.cpp @@ -0,0 +1,191 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "dbms/replication_handlers.hpp" + +#include "dbms/database.hpp" +#include "dbms/dbms_handler.hpp" +#include "storage/v2/storage.hpp" +#include "system/state.hpp" + +namespace memgraph::dbms { + +#ifdef MG_ENTERPRISE + +void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) { + using memgraph::storage::replication::CreateDatabaseRes; + CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE); + + // Ignore if no license + if (!license::global_license_checker.IsEnterpriseValidFast()) { + spdlog::error("Handling CreateDatabase, an enterprise RPC message, without license."); + memgraph::slk::Save(res, res_builder); + return; + } + + memgraph::storage::replication::CreateDatabaseReq req; + memgraph::slk::Load(&req, req_reader); + + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot + // of the set of databases. Hence no history exists to maintain regarding epoch change. + // If MAIN has changed we need to check this new group_timestamp is consistent with + // what we have so far. + + if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) { + spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp, + system_state_access.LastCommitedTS()); + memgraph::slk::Save(res, res_builder); + return; + } + + try { + // Create new + auto new_db = dbms_handler.Update(req.config); + if (new_db.HasValue()) { + // Successfully create db + system_state_access.SetLastCommitedTS(req.new_group_timestamp); + res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS); + spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); + } + } catch (...) { + // Failure + } + + memgraph::slk::Save(res, res_builder); +} + +void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler, + slk::Reader *req_reader, slk::Builder *res_builder) { + using memgraph::storage::replication::DropDatabaseRes; + DropDatabaseRes res(DropDatabaseRes::Result::FAILURE); + + // Ignore if no license + if (!license::global_license_checker.IsEnterpriseValidFast()) { + spdlog::error("Handling DropDatabase, an enterprise RPC message, without license."); + memgraph::slk::Save(res, res_builder); + return; + } + + memgraph::storage::replication::DropDatabaseReq req; + memgraph::slk::Load(&req, req_reader); + + // Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot + // of the set of databases. Hence no history exists to maintain regarding epoch change. + // If MAIN has changed we need to check this new group_timestamp is consistent with + // what we have so far. + + if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) { + spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp, + system_state_access.LastCommitedTS()); + memgraph::slk::Save(res, res_builder); + return; + } + + try { + // NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment. + auto new_db = dbms_handler.Delete(req.uuid); + if (new_db.HasError()) { + if (new_db.GetError() == DeleteError::NON_EXISTENT) { + // Nothing to drop + system_state_access.SetLastCommitedTS(req.new_group_timestamp); + res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED); + } + } else { + // Successfully drop db + system_state_access.SetLastCommitedTS(req.new_group_timestamp); + res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS); + spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp); + } + } catch (...) { + // Failure + } + + memgraph::slk::Save(res, res_builder); +} + +bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector &database_configs) { + /* + * NO LICENSE + */ + if (!license::global_license_checker.IsEnterpriseValidFast()) { + spdlog::error("Handling SystemRecovery, an enterprise RPC message, without license."); + for (const auto &config : database_configs) { + // Only handle default DB + if (config.name != kDefaultDB) continue; + try { + if (dbms_handler.Update(config).HasError()) { + return false; + } + } catch (const UnknownDatabaseException &) { + return false; + } + } + return true; + } + + /* + * MULTI-TENANCY + */ + // Get all current dbs + auto old = dbms_handler.All(); + // Check/create the incoming dbs + for (const auto &config : database_configs) { + // Missing db + try { + if (dbms_handler.Update(config).HasError()) { + spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name); + return false; + } + } catch (const UnknownDatabaseException &) { + spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException"); + return false; + } + const auto it = std::find(old.begin(), old.end(), config.name); + if (it != old.end()) old.erase(it); + } + + // Delete all the leftover old dbs + for (const auto &remove_db : old) { + const auto del = dbms_handler.Delete(remove_db); + if (del.HasError()) { + // Some errors are not terminal + if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) { + spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db); + continue; + } + spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db); + return false; + } + } + + /* + * SUCCESS + */ + return true; +} + +void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access, + dbms::DbmsHandler &dbms_handler) { + // NOTE: Register even without license as the user could add a license at run-time + data.server->rpc_server_.Register( + [system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable { + spdlog::debug("Received CreateDatabaseRpc"); + CreateDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder); + }); + data.server->rpc_server_.Register( + [system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable { + spdlog::debug("Received DropDatabaseRpc"); + DropDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder); + }); +} +#endif +} // namespace memgraph::dbms diff --git a/src/dbms/replication_handlers.hpp b/src/dbms/replication_handlers.hpp new file mode 100644 index 000000000..48e91e384 --- /dev/null +++ b/src/dbms/replication_handlers.hpp @@ -0,0 +1,32 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "dbms/dbms_handler.hpp" +#include "replication/state.hpp" +#include "slk/streams.hpp" +#include "system/state.hpp" + +namespace memgraph::dbms { +#ifdef MG_ENTERPRISE +// RPC handlers +void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder); +void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler, + slk::Reader *req_reader, slk::Builder *res_builder); +bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector &database_configs); + +// RPC registration +void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access, + dbms::DbmsHandler &dbms_handler); +#endif +} // namespace memgraph::dbms diff --git a/src/dbms/rpc.cpp b/src/dbms/rpc.cpp new file mode 100644 index 000000000..18a425cbf --- /dev/null +++ b/src/dbms/rpc.cpp @@ -0,0 +1,118 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "dbms/rpc.hpp" + +#include "slk/streams.hpp" +#include "storage/v2/replication/rpc.hpp" +#include "utils/enum.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph { + +namespace storage::replication { + +void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } + +const utils::TypeInfo CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ, "CreateDatabaseReq", nullptr}; + +const utils::TypeInfo CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES, "CreateDatabaseRes", nullptr}; + +const utils::TypeInfo DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ, "DropDatabaseReq", nullptr}; + +const utils::TypeInfo DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES, "DropDatabaseRes", nullptr}; + +} // namespace storage::replication + +// Autogenerated SLK serialization code +namespace slk { + +// Serialize code for CreateDatabaseReq + +void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.epoch_id, builder); + memgraph::slk::Save(self.expected_group_timestamp, builder); + memgraph::slk::Save(self.new_group_timestamp, builder); + memgraph::slk::Save(self.config, builder); +} + +void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->epoch_id, reader); + memgraph::slk::Load(&self->expected_group_timestamp, reader); + memgraph::slk::Load(&self->new_group_timestamp, reader); + memgraph::slk::Load(&self->config, reader); +} + +// Serialize code for CreateDatabaseRes + +void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(utils::EnumToNum(self.result), builder); +} + +void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) { + uint8_t res = 0; + memgraph::slk::Load(&res, reader); + if (!utils::NumToEnum(res, self->result)) { + throw SlkReaderException("Unexpected result line:{}!", __LINE__); + } +} + +// Serialize code for DropDatabaseReq + +void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.epoch_id, builder); + memgraph::slk::Save(self.expected_group_timestamp, builder); + memgraph::slk::Save(self.new_group_timestamp, builder); + memgraph::slk::Save(self.uuid, builder); +} + +void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->epoch_id, reader); + memgraph::slk::Load(&self->expected_group_timestamp, reader); + memgraph::slk::Load(&self->new_group_timestamp, reader); + memgraph::slk::Load(&self->uuid, reader); +} + +// Serialize code for DropDatabaseRes + +void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(utils::EnumToNum(self.result), builder); +} + +void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) { + uint8_t res = 0; + memgraph::slk::Load(&res, reader); + if (!utils::NumToEnum(res, self->result)) { + throw SlkReaderException("Unexpected result line:{}!", __LINE__); + } +} + +} // namespace slk +} // namespace memgraph diff --git a/src/dbms/rpc.hpp b/src/dbms/rpc.hpp new file mode 100644 index 000000000..b08928e80 --- /dev/null +++ b/src/dbms/rpc.hpp @@ -0,0 +1,118 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include + +#include "rpc/messages.hpp" +#include "slk/streams.hpp" +#include "storage/v2/config.hpp" +#include "utils/uuid.hpp" + +namespace memgraph::storage::replication { + +struct CreateDatabaseReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader); + static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder); + CreateDatabaseReq() = default; + CreateDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp, + storage::SalientConfig config) + : epoch_id(std::string(epoch_id)), + expected_group_timestamp{expected_group_timestamp}, + new_group_timestamp(new_group_timestamp), + config(std::move(config)) {} + + std::string epoch_id; + uint64_t expected_group_timestamp; + uint64_t new_group_timestamp; + storage::SalientConfig config; +}; + +struct CreateDatabaseRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; + + static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader); + static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder); + CreateDatabaseRes() = default; + explicit CreateDatabaseRes(Result res) : result(res) {} + + Result result; +}; + +using CreateDatabaseRpc = rpc::RequestResponse; + +struct DropDatabaseReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader); + static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder); + DropDatabaseReq() = default; + DropDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp, + const utils::UUID &uuid) + : epoch_id(std::string(epoch_id)), + expected_group_timestamp{expected_group_timestamp}, + new_group_timestamp(new_group_timestamp), + uuid(uuid) {} + + std::string epoch_id; + uint64_t expected_group_timestamp; + uint64_t new_group_timestamp; + utils::UUID uuid; +}; + +struct DropDatabaseRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; + + static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader); + static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder); + DropDatabaseRes() = default; + explicit DropDatabaseRes(Result res) : result(res) {} + + Result result; +}; + +using DropDatabaseRpc = rpc::RequestResponse; + +} // namespace memgraph::storage::replication + +// SLK serialization declarations +namespace memgraph::slk { + +void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader); + +} // namespace memgraph::slk diff --git a/src/dbms/transaction.hpp b/src/dbms/transaction.hpp index 7167d9ec5..394cafceb 100644 --- a/src/dbms/transaction.hpp +++ b/src/dbms/transaction.hpp @@ -11,7 +11,10 @@ #pragma once +#include #include +#include +#include "auth/models.hpp" #include "storage/v2/config.hpp" namespace memgraph::dbms { @@ -20,17 +23,70 @@ struct SystemTransaction { enum class Action { CREATE_DATABASE, DROP_DATABASE, + UPDATE_AUTH_DATA, + DROP_AUTH_DATA, + /** + * + * CREATE USER user_name [IDENTIFIED BY 'password']; + * SET PASSWORD FOR user_name TO 'new_password'; + * ^ SaveUser + * + * DROP USER user_name; + * ^ Directly on KVStore + * + * CREATE ROLE role_name; + * ^ SaveRole + * + * DROP ROLE + * ^ RemoveRole + * + * SET ROLE FOR user_name TO role_name; + * CLEAR ROLE FOR user_name; + * ^ Do stuff then do SaveUser + * + * GRANT privilege_list TO user_or_role; + * DENY AUTH, INDEX TO moderator: + * REVOKE AUTH, INDEX TO moderator: + * GRANT permission_level ON (LABELS | EDGE_TYPES) label_list TO user_or_role; + * REVOKE (LABELS | EDGE_TYPES) label_or_edge_type_list FROM user_or_role + * DENY (LABELS | EDGE_TYPES) label_or_edge_type_list TO user_or_role + * ^ all of these are EditPermissions <-> SaveUser/Role + * + * Multi-tenant TODO Doc; + * ^ Should all call SaveUser + * + */ }; static constexpr struct CreateDatabase { } create_database; static constexpr struct DropDatabase { } drop_database; + static constexpr struct UpdateAuthData { + } update_auth_data; + static constexpr struct DropAuthData { + } drop_auth_data; + enum class AuthData { USER, ROLE }; + + // Multi-tenancy Delta(CreateDatabase /*tag*/, storage::SalientConfig config) : action(Action::CREATE_DATABASE), config(std::move(config)) {} Delta(DropDatabase /*tag*/, const utils::UUID &uuid) : action(Action::DROP_DATABASE), uuid(uuid) {} + // Auth + Delta(UpdateAuthData /*tag*/, std::optional user) + : action(Action::UPDATE_AUTH_DATA), auth_data{std::move(user), std::nullopt} {} + Delta(UpdateAuthData /*tag*/, std::optional role) + : action(Action::UPDATE_AUTH_DATA), auth_data{std::nullopt, std::move(role)} {} + Delta(DropAuthData /*tag*/, AuthData type, std::string_view name) + : action(Action::DROP_AUTH_DATA), + auth_data_key{ + .type = type, + .name = std::string{name}, + } {} + + // Generic Delta(const Delta &) = delete; Delta(Delta &&) = delete; Delta &operator=(const Delta &) = delete; @@ -42,8 +98,14 @@ struct SystemTransaction { std::destroy_at(&config); break; case Action::DROP_DATABASE: + std::destroy_at(&uuid); + break; + case Action::UPDATE_AUTH_DATA: + std::destroy_at(&auth_data); + break; + case Action::DROP_AUTH_DATA: + std::destroy_at(&auth_data_key); break; - // Some deltas might have special destructor handling } } @@ -51,13 +113,20 @@ struct SystemTransaction { union { storage::SalientConfig config; utils::UUID uuid; + struct { + std::optional user; + std::optional role; + } auth_data; + struct { + AuthData type; + std::string name; + } auth_data_key; }; }; explicit SystemTransaction(uint64_t timestamp) : system_timestamp(timestamp) {} - // Currently system transitions support a single delta - std::optional delta{}; + std::list deltas{}; uint64_t system_timestamp; }; diff --git a/src/dbms/utils.hpp b/src/dbms/utils.hpp deleted file mode 100644 index 801ac9be3..000000000 --- a/src/dbms/utils.hpp +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. -#pragma once - -#include "dbms/dbms_handler.hpp" -#include "dbms/replication_handler.hpp" -#include "replication/include/replication/state.hpp" -#include "utils/result.hpp" - -namespace memgraph::dbms { - -inline bool DoReplicaToMainPromotion(dbms::DbmsHandler &dbms_handler) { - auto &repl_state = dbms_handler.ReplicationState(); - // STEP 1) bring down all REPLICA servers - dbms_handler.ForEach([](DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - // Remember old epoch + storage timestamp association - storage->PrepareForNewEpoch(); - }); - - // STEP 2) Change to MAIN - // TODO: restore replication servers if false? - if (!repl_state.SetReplicationRoleMain()) { - // TODO: Handle recovery on failure??? - return false; - } - - // STEP 3) We are now MAIN, update storage local epoch - const auto &epoch = - std::get(std::as_const(dbms_handler.ReplicationState()).ReplicationData()).epoch_; - dbms_handler.ForEach([&](DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - storage->repl_storage_state_.epoch_ = epoch; - }); - - return true; -}; - -inline bool SetReplicationRoleReplica(dbms::DbmsHandler &dbms_handler, - const memgraph::replication::ReplicationServerConfig &config) { - if (dbms_handler.ReplicationState().IsReplica()) { - return false; - } - - // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are - // deleting the replica. - // Remove database specific clients - dbms_handler.ForEach([&](DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); - }); - // Remove instance level clients - std::get(dbms_handler.ReplicationState().ReplicationData()).registered_replicas_.clear(); - - // Creates the server - dbms_handler.ReplicationState().SetReplicationRoleReplica(config); - - // Start - const auto success = std::visit(utils::Overloaded{[](replication::RoleMainData const &) { - // ASSERT - return false; - }, - [&dbms_handler](replication::RoleReplicaData const &data) { - return StartRpcServer(dbms_handler, data); - }}, - dbms_handler.ReplicationState().ReplicationData()); - // TODO Handle error (restore to main?) - return success; -} - -template -inline bool RegisterAllDatabasesClients(dbms::DbmsHandler &dbms_handler, - replication::ReplicationClient &instance_client) { - if (!allow_mt_repl && dbms_handler.All().size() > 1) { - spdlog::warn("Multi-tenant replication is currently not supported!"); - } - - bool all_clients_good = true; - - dbms_handler.ForEach([&](DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - if (!allow_mt_repl && storage->name() != kDefaultDB) { - return; - } - // TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes - if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; - - using enum storage::replication::ReplicaState; - - all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock( - [storage, &instance_client, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT - auto client = std::make_unique(instance_client); - client->Start(storage, std::move(db_acc)); - if (client->State() == MAYBE_BEHIND && !AllowRPCFailure) { - return false; - } - storage_clients.push_back(std::move(client)); - return true; - }); - }); - - return all_clients_good; -} - -inline std::optional HandleRegisterReplicaStatus( - utils::BasicResult &instance_client) { - if (instance_client.HasError()) switch (instance_client.GetError()) { - case replication::RegisterReplicaError::NOT_MAIN: - MG_ASSERT(false, "Only main instance can register a replica!"); - return {}; - case replication::RegisterReplicaError::NAME_EXISTS: - return dbms::RegisterReplicaError::NAME_EXISTS; - case replication::RegisterReplicaError::ENDPOINT_EXISTS: - return dbms::RegisterReplicaError::ENDPOINT_EXISTS; - case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED: - return dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED; - case replication::RegisterReplicaError::SUCCESS: - break; - } - return {}; -} - -} // namespace memgraph::dbms diff --git a/src/glue/ServerT.hpp b/src/glue/ServerT.hpp index 641553128..0d0dbe700 100644 --- a/src/glue/ServerT.hpp +++ b/src/glue/ServerT.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -33,7 +33,7 @@ class WritePrioritizedRWLock; struct Context { memgraph::query::InterpreterContext *ic; - memgraph::utils::Synchronized *auth; + memgraph::auth::SynchedAuth *auth; #if MG_ENTERPRISE memgraph::audit::Log *audit_log; #endif diff --git a/src/glue/SessionHL.cpp b/src/glue/SessionHL.cpp index 61c1ab26f..07e1bf6e8 100644 --- a/src/glue/SessionHL.cpp +++ b/src/glue/SessionHL.cpp @@ -319,8 +319,7 @@ void SessionHL::Configure(const std::map *auth + memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth #ifdef MG_ENTERPRISE , memgraph::audit::Log *audit_log diff --git a/src/glue/SessionHL.hpp b/src/glue/SessionHL.hpp index 374d2464e..64dcddda5 100644 --- a/src/glue/SessionHL.hpp +++ b/src/glue/SessionHL.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,8 +25,7 @@ class SessionHL final : public memgraph::communication::bolt::Session *auth + memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth #ifdef MG_ENTERPRISE , memgraph::audit::Log *audit_log @@ -88,7 +87,7 @@ class SessionHL final : public memgraph::communication::bolt::Session *auth_; + memgraph::auth::SynchedAuth *auth_; memgraph::communication::v2::ServerEndpoint endpoint_; std::optional implicit_db_; }; diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 981ab8cca..4db6c827e 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -66,9 +66,7 @@ bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph:: #endif namespace memgraph::glue { -AuthChecker::AuthChecker( - memgraph::utils::Synchronized *auth) - : auth_(auth) {} +AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {} bool AuthChecker::IsUserAuthorized(const std::optional &username, const std::vector &privileges, diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index e926c120b..217ac0c74 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,8 +22,7 @@ namespace memgraph::glue { class AuthChecker : public query::AuthChecker { public: - explicit AuthChecker( - memgraph::utils::Synchronized *auth); + explicit AuthChecker(memgraph::auth::SynchedAuth *auth); bool IsUserAuthorized(const std::optional &username, const std::vector &privileges, @@ -41,7 +40,7 @@ class AuthChecker : public query::AuthChecker { const std::string &db_name = ""); private: - memgraph::utils::Synchronized *auth_; + memgraph::auth::SynchedAuth *auth_; mutable memgraph::utils::Synchronized user_; // cached user }; #ifdef MG_ENTERPRISE diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index f3efb6ba0..2d7260b3c 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -210,16 +210,25 @@ std::vector> ShowFineGrainedUserPrivile if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return {}; } - const auto &label_permissions = user->GetFineGrainedAccessLabelPermissions(); - const auto &edge_type_permissions = user->GetFineGrainedAccessEdgeTypePermissions(); - auto all_fine_grained_permissions = - GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER"); - auto edge_type_fine_grained_permissions = - GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER"); + auto all_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole( + user->GetUserFineGrainedAccessLabelPermissions(), "LABEL", "USER"); + auto all_role_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole( + user->GetRoleFineGrainedAccessLabelPermissions(), "LABEL", "ROLE"); + all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), + std::make_move_iterator(all_role_fine_grained_permissions.begin()), + std::make_move_iterator(all_role_fine_grained_permissions.end())); - all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(), - edge_type_fine_grained_permissions.end()); + auto edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole( + user->GetUserFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "USER"); + auto role_edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole( + user->GetRoleFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "ROLE"); + all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), + std::make_move_iterator(edge_type_fine_grained_permissions.begin()), + std::make_move_iterator(edge_type_fine_grained_permissions.end())); + all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), + std::make_move_iterator(role_edge_type_fine_grained_permissions.begin()), + std::make_move_iterator(role_edge_type_fine_grained_permissions.end())); return ConstructFineGrainedPrivilegesResult(all_fine_grained_permissions); } @@ -233,9 +242,9 @@ std::vector> ShowFineGrainedRolePrivile const auto &edge_type_permissions = role->GetFineGrainedAccessEdgeTypePermissions(); auto all_fine_grained_permissions = - GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER"); + GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "ROLE"); auto edge_type_fine_grained_permissions = - GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER"); + GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "ROLE"); all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(), edge_type_fine_grained_permissions.end()); @@ -248,16 +257,15 @@ std::vector> ShowFineGrainedRolePrivile namespace memgraph::glue { -AuthQueryHandler::AuthQueryHandler( - memgraph::utils::Synchronized *auth) - : auth_(auth) {} +AuthQueryHandler::AuthQueryHandler(memgraph::auth::SynchedAuth *auth) : auth_(auth) {} -bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional &password) { +bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional &password, + system::Transaction *system_tx) { try { const auto [first_user, user_added] = std::invoke([&, this] { auto locked_auth = auth_->Lock(); const auto first_user = !locked_auth->HasUsers(); - const auto user_added = locked_auth->AddUser(username, password).has_value(); + const auto user_added = locked_auth->AddUser(username, password, system_tx).has_value(); return std::make_pair(first_user, user_added); }); @@ -276,10 +284,11 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option } } #endif - ); + , + system_tx); #ifdef MG_ENTERPRISE - GrantDatabaseToUser(auth::kAllDatabases, username); - SetMainDatabase(dbms::kDefaultDB, username); + GrantDatabaseToUser(auth::kAllDatabases, username, system_tx); + SetMainDatabase(dbms::kDefaultDB, username, system_tx); #endif } @@ -289,18 +298,19 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option } } -bool AuthQueryHandler::DropUser(const std::string &username) { +bool AuthQueryHandler::DropUser(const std::string &username, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); if (!user) return false; - return locked_auth->RemoveUser(username); + return locked_auth->RemoveUser(username, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -void AuthQueryHandler::SetPassword(const std::string &username, const std::optional &password) { +void AuthQueryHandler::SetPassword(const std::string &username, const std::optional &password, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -308,39 +318,41 @@ void AuthQueryHandler::SetPassword(const std::string &username, const std::optio throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); } locked_auth->UpdatePassword(*user, password); - locked_auth->SaveUser(*user); + locked_auth->SaveUser(*user, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::CreateRole(const std::string &rolename) { +bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - return locked_auth->AddRole(rolename).has_value(); + return locked_auth->AddRole(rolename, system_tx).has_value(); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } #ifdef MG_ENTERPRISE -bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) { +bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); if (!user) return false; - return locked_auth->RevokeDatabaseFromUser(db, username); + return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) { +bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); if (!user) return false; - return locked_auth->GrantDatabaseToUser(db, username); + return locked_auth->GrantDatabaseToUser(db_name, username, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } @@ -360,27 +372,28 @@ std::vector> AuthQueryHandler::GetDatab } } -bool AuthQueryHandler::SetMainDatabase(std::string_view db, const std::string &username) { +bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); if (!user) return false; - return locked_auth->SetMainDatabase(db, username); + return locked_auth->SetMainDatabase(db_name, username, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -void AuthQueryHandler::DeleteDatabase(std::string_view db) { +void AuthQueryHandler::DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) { try { - auth_->Lock()->DeleteDatabase(std::string(db)); + auth_->Lock()->DeleteDatabase(std::string(db_name), system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } #endif -bool AuthQueryHandler::DropRole(const std::string &rolename) { +bool AuthQueryHandler::DropRole(const std::string &rolename, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto role = locked_auth->GetRole(rolename); @@ -389,7 +402,7 @@ bool AuthQueryHandler::DropRole(const std::string &rolename) { return false; }; - return locked_auth->RemoveRole(rolename); + return locked_auth->RemoveRole(rolename, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } @@ -461,7 +474,8 @@ std::vector AuthQueryHandler::GetUsernamesForRole(c } } -void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename) { +void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -477,13 +491,13 @@ void AuthQueryHandler::SetRole(const std::string &username, const std::string &r current_role->rolename()); } user->SetRole(*role); - locked_auth->SaveUser(*user); + locked_auth->SaveUser(*user, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -void AuthQueryHandler::ClearRole(const std::string &username) { +void AuthQueryHandler::ClearRole(const std::string &username, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(username); @@ -491,7 +505,7 @@ void AuthQueryHandler::ClearRole(const std::string &username) { throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); } user->ClearRole(); - locked_auth->SaveUser(*user); + locked_auth->SaveUser(*user, system_tx); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } @@ -545,7 +559,8 @@ void AuthQueryHandler::GrantPrivilege( const std::vector>> &edge_type_privileges #endif -) { + , + system::Transaction *system_tx) { EditPermissions( user_or_role, privileges, #ifdef MG_ENTERPRISE @@ -568,11 +583,13 @@ void AuthQueryHandler::GrantPrivilege( } } #endif - ); + , + system_tx); } // namespace memgraph::glue void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role, - const std::vector &privileges) { + const std::vector &privileges, + system::Transaction *system_tx) { EditPermissions( user_or_role, privileges, #ifdef MG_ENTERPRISE @@ -588,7 +605,8 @@ void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role, , [](auto &fine_grained_permissions, const auto &privilege_collection) {} #endif - ); + , + system_tx); } void AuthQueryHandler::RevokePrivilege( @@ -600,7 +618,8 @@ void AuthQueryHandler::RevokePrivilege( const std::vector>> &edge_type_privileges #endif -) { + , + system::Transaction *system_tx) { EditPermissions( user_or_role, privileges, #ifdef MG_ENTERPRISE @@ -622,7 +641,8 @@ void AuthQueryHandler::RevokePrivilege( } } #endif - ); + , + system_tx); } // namespace memgraph::glue template permissions; permissions.reserve(privileges.size()); @@ -675,7 +696,7 @@ void AuthQueryHandler::EditPermissions( } } #endif - locked_auth->SaveUser(*user); + locked_auth->SaveUser(*user, system_tx); } else { for (const auto &permission : permissions) { edit_permissions_fun(role->permissions(), permission); @@ -691,7 +712,7 @@ void AuthQueryHandler::EditPermissions( } } #endif - locked_auth->SaveRole(*role); + locked_auth->SaveRole(*role, system_tx); } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index c226a4560..52db6075f 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -23,32 +23,36 @@ namespace memgraph::glue { class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { - memgraph::utils::Synchronized *auth_; + memgraph::auth::SynchedAuth *auth_; public: - AuthQueryHandler(memgraph::utils::Synchronized *auth); + explicit AuthQueryHandler(memgraph::auth::SynchedAuth *auth); - bool CreateUser(const std::string &username, const std::optional &password) override; + bool CreateUser(const std::string &username, const std::optional &password, + system::Transaction *system_tx) override; - bool DropUser(const std::string &username) override; + bool DropUser(const std::string &username, system::Transaction *system_tx) override; - void SetPassword(const std::string &username, const std::optional &password) override; + void SetPassword(const std::string &username, const std::optional &password, + system::Transaction *system_tx) override; #ifdef MG_ENTERPRISE - bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override; + bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, + system::Transaction *system_tx) override; - bool GrantDatabaseToUser(const std::string &db, const std::string &username) override; + bool GrantDatabaseToUser(const std::string &db_name, const std::string &username, + system::Transaction *system_tx) override; std::vector> GetDatabasePrivileges(const std::string &username) override; - bool SetMainDatabase(std::string_view db, const std::string &username) override; + bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override; - void DeleteDatabase(std::string_view db) override; + void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override; #endif - bool CreateRole(const std::string &rolename) override; + bool CreateRole(const std::string &rolename, system::Transaction *system_tx) override; - bool DropRole(const std::string &rolename) override; + bool DropRole(const std::string &rolename, system::Transaction *system_tx) override; std::vector GetUsernames() override; @@ -58,9 +62,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { std::vector GetUsernamesForRole(const std::string &rolename) override; - void SetRole(const std::string &username, const std::string &rolename) override; + void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) override; - void ClearRole(const std::string &username) override; + void ClearRole(const std::string &username, system::Transaction *system_tx) override; std::vector> GetPrivileges(const std::string &user_or_role) override; @@ -74,10 +78,12 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { const std::vector>> &edge_type_privileges #endif - ) override; + , + system::Transaction *system_tx) override; void DenyPrivilege(const std::string &user_or_role, - const std::vector &privileges) override; + const std::vector &privileges, + system::Transaction *system_tx) override; void RevokePrivilege( const std::string &user_or_role, const std::vector &privileges @@ -88,7 +94,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { const std::vector>> &edge_type_privileges #endif - ) override; + , + system::Transaction *system_tx) override; private: template ( } KVStore::~KVStore() { + if (pimpl_ == nullptr) return; spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string()); const auto sync = pimpl_->db->SyncWAL(); if (!sync.ok()) spdlog::error("KVStore sync failed!"); diff --git a/src/memgraph.cpp b/src/memgraph.cpp index cbd63490e..fa2f9c1f2 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -11,9 +11,12 @@ #include #include "audit/log.hpp" +#include "auth/auth.hpp" #include "communication/websocket/auth.hpp" #include "communication/websocket/server.hpp" +#include "coordination/coordinator_handlers.hpp" #include "dbms/constants.hpp" +#include "dbms/dbms_handler.hpp" #include "dbms/inmemory/replication_handlers.hpp" #include "flags/all.hpp" #include "glue/MonitoringServerT.hpp" @@ -24,14 +27,19 @@ #include "helpers.hpp" #include "license/license_sender.hpp" #include "memory/global_memory_control.hpp" +#include "query/auth_query_handler.hpp" #include "query/config.hpp" #include "query/discard_value_stream.hpp" #include "query/interpreter.hpp" +#include "query/interpreter_context.hpp" #include "query/procedure/callable_alias_mapper.hpp" #include "query/procedure/module.hpp" #include "query/procedure/py_module.hpp" +#include "replication_handler/replication_handler.hpp" +#include "replication_handler/system_replication.hpp" #include "requests/requests.hpp" #include "storage/v2/durability/durability.hpp" +#include "system/system.hpp" #include "telemetry/telemetry.hpp" #include "utils/signals.hpp" #include "utils/sysinfo/memory.hpp" @@ -39,10 +47,6 @@ #include "utils/terminate_handler.hpp" #include "version.hpp" -#include "dbms/dbms_handler.hpp" -#include "query/auth_query_handler.hpp" -#include "query/interpreter_context.hpp" - namespace { constexpr const char *kMgUser = "MEMGRAPH_USER"; constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD"; @@ -356,44 +360,75 @@ int main(int argc, char **argv) { .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}; - auto auth_glue = - [](memgraph::utils::Synchronized *auth, - std::unique_ptr &ah, std::unique_ptr &ac) { - // Glue high level auth implementations to the query side - ah = std::make_unique(auth); - ac = std::make_unique(auth); - // Handle users passed via arguments - auto *maybe_username = std::getenv(kMgUser); - auto *maybe_password = std::getenv(kMgPassword); - auto *maybe_pass_file = std::getenv(kMgPassfile); - if (maybe_username && maybe_password) { - ah->CreateUser(maybe_username, maybe_password); - } else if (maybe_pass_file) { - const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file); - if (!username.empty() && !password.empty()) { - ah->CreateUser(username, password); - } - } - }; + auto auth_glue = [](memgraph::auth::SynchedAuth *auth, std::unique_ptr &ah, + std::unique_ptr &ac) { + // Glue high level auth implementations to the query side + ah = std::make_unique(auth); + ac = std::make_unique(auth); + // Handle users passed via arguments + auto *maybe_username = std::getenv(kMgUser); + auto *maybe_password = std::getenv(kMgPassword); + auto *maybe_pass_file = std::getenv(kMgPassfile); + if (maybe_username && maybe_password) { + ah->CreateUser(maybe_username, maybe_password, nullptr); + } else if (maybe_pass_file) { + const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file); + if (!username.empty() && !password.empty()) { + ah->CreateUser(username, password, nullptr); + } + } + }; memgraph::auth::Auth::Config auth_config{FLAGS_auth_user_or_role_name_regex, FLAGS_auth_password_strength_regex, FLAGS_auth_password_permit_null}; - memgraph::utils::Synchronized auth_{ - data_directory / "auth", auth_config}; + memgraph::auth::SynchedAuth auth_{data_directory / "auth", auth_config}; std::unique_ptr auth_handler; std::unique_ptr auth_checker; auth_glue(&auth_, auth_handler, auth_checker); - memgraph::dbms::DbmsHandler dbms_handler(db_config + auto system = memgraph::system::System{db_config.durability.storage_directory, FLAGS_data_recovery_on_startup}; + + // singleton replication state + memgraph::replication::ReplicationState repl_state{ReplicationStateRootPath(db_config)}; + + // singleton coordinator state +#ifdef MG_ENTERPRISE + memgraph::coordination::CoordinatorState coordinator_state; +#endif + + memgraph::dbms::DbmsHandler dbms_handler(db_config, system, repl_state #ifdef MG_ENTERPRISE , - &auth_, FLAGS_data_recovery_on_startup + auth_, FLAGS_data_recovery_on_startup #endif ); + + // Note: Now that all system's subsystems are initialised (dbms & auth) + // We can now initialise the recovery of replication (which will include those subsystems) + // ReplicationHandler will handle the recovery + auto replication_handler = memgraph::replication::ReplicationHandler{repl_state, dbms_handler +#ifdef MG_ENTERPRISE + , + &system, auth_ +#endif + }; + +#ifdef MG_ENTERPRISE + // MAIN or REPLICA instance + if (FLAGS_coordinator_server_port) { + memgraph::dbms::CoordinatorHandlers::Register(coordinator_state.GetCoordinatorServer(), replication_handler); + MG_ASSERT(coordinator_state.GetCoordinatorServer().Start(), "Failed to start coordinator server!"); + } +#endif + auto db_acc = dbms_handler.Get(); - memgraph::query::InterpreterContext interpreter_context_( - interp_config, &dbms_handler, &dbms_handler.ReplicationState(), auth_handler.get(), auth_checker.get()); + memgraph::query::InterpreterContext interpreter_context_(interp_config, &dbms_handler, &repl_state, system, +#ifdef MG_ENTERPRISE + &coordinator_state, +#endif + auth_handler.get(), auth_checker.get(), + &replication_handler); MG_ASSERT(db_acc, "Failed to access the main database"); memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(), @@ -460,9 +495,9 @@ int main(int argc, char **argv) { if (FLAGS_telemetry_enabled) { telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id, service_name == "BoltS", FLAGS_data_directory, std::chrono::minutes(10)); - telemetry->AddStorageCollector(dbms_handler, auth_); + telemetry->AddStorageCollector(dbms_handler, auth_, repl_state); #ifdef MG_ENTERPRISE - telemetry->AddDatabaseCollector(dbms_handler); + telemetry->AddDatabaseCollector(dbms_handler, repl_state); #else telemetry->AddDatabaseCollector(); #endif diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt index 39e508ed1..3bc7c9499 100644 --- a/src/query/CMakeLists.txt +++ b/src/query/CMakeLists.txt @@ -56,6 +56,7 @@ target_link_libraries(mg-query PUBLIC dl mg-kvstore mg-memory mg::csv + mg::system mg-flags mg-dbms mg-events) diff --git a/src/query/auth_query_handler.hpp b/src/query/auth_query_handler.hpp index 693103354..0258005c3 100644 --- a/src/query/auth_query_handler.hpp +++ b/src/query/auth_query_handler.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include "query/frontend/ast/ast.hpp" // overkill #include "query/typed_value.hpp" +#include "system/system.hpp" namespace memgraph::query { @@ -33,23 +34,27 @@ class AuthQueryHandler { /// Return false if the user already exists. /// @throw QueryRuntimeException if an error ocurred. - virtual bool CreateUser(const std::string &username, const std::optional &password) = 0; + virtual bool CreateUser(const std::string &username, const std::optional &password, + system::Transaction *system_tx) = 0; /// Return false if the user does not exist. /// @throw QueryRuntimeException if an error ocurred. - virtual bool DropUser(const std::string &username) = 0; + virtual bool DropUser(const std::string &username, system::Transaction *system_tx) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual void SetPassword(const std::string &username, const std::optional &password) = 0; + virtual void SetPassword(const std::string &username, const std::optional &password, + system::Transaction *system_tx) = 0; #ifdef MG_ENTERPRISE /// Return true if access revoked successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0; + virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username, + system::Transaction *system_tx) = 0; /// Return true if access granted successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0; + virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username, + system::Transaction *system_tx) = 0; /// Returns database access rights for the user /// @throw QueryRuntimeException if an error ocurred. @@ -57,20 +62,20 @@ class AuthQueryHandler { /// Return true if main database set successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool SetMainDatabase(std::string_view db, const std::string &username) = 0; + virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0; /// Delete database from all users /// @throw QueryRuntimeException if an error ocurred. - virtual void DeleteDatabase(std::string_view db) = 0; + virtual void DeleteDatabase(std::string_view db, system::Transaction *system_tx) = 0; #endif /// Return false if the role already exists. /// @throw QueryRuntimeException if an error ocurred. - virtual bool CreateRole(const std::string &rolename) = 0; + virtual bool CreateRole(const std::string &rolename, system::Transaction *system_tx) = 0; /// Return false if the role does not exist. /// @throw QueryRuntimeException if an error ocurred. - virtual bool DropRole(const std::string &rolename) = 0; + virtual bool DropRole(const std::string &rolename, system::Transaction *system_tx) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual std::vector GetUsernames() = 0; @@ -85,10 +90,10 @@ class AuthQueryHandler { virtual std::vector GetUsernamesForRole(const std::string &rolename) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual void SetRole(const std::string &username, const std::string &rolename) = 0; + virtual void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual void ClearRole(const std::string &username) = 0; + virtual void ClearRole(const std::string &username, system::Transaction *system_tx) = 0; virtual std::vector> GetPrivileges(const std::string &user_or_role) = 0; @@ -103,11 +108,13 @@ class AuthQueryHandler { const std::vector>> &edge_type_privileges #endif - ) = 0; + , + system::Transaction *system_tx) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual void DenyPrivilege(const std::string &user_or_role, - const std::vector &privileges) = 0; + const std::vector &privileges, + system::Transaction *system_tx) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual void RevokePrivilege( @@ -120,7 +127,8 @@ class AuthQueryHandler { const std::vector>> &edge_type_privileges #endif - ) = 0; + , + system::Transaction *system_tx) = 0; }; } // namespace memgraph::query diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index af46865db..783ff6ae9 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -34,7 +34,9 @@ #include "auth/auth.hpp" #include "auth/models.hpp" #include "csv/parsing.hpp" +#include "dbms/coordinator_handler.hpp" #include "dbms/database.hpp" +#include "dbms/dbms_handler.hpp" #include "dbms/global.hpp" #include "dbms/inmemory/storage_helper.hpp" #include "flags/replication.hpp" @@ -43,6 +45,7 @@ #include "license/license.hpp" #include "memory/global_memory_control.hpp" #include "memory/query_memory_control.hpp" +#include "query/auth_query_handler.hpp" #include "query/config.hpp" #include "query/constants.hpp" #include "query/context.hpp" @@ -58,12 +61,14 @@ #include "query/frontend/semantic/symbol_generator.hpp" #include "query/interpret/eval.hpp" #include "query/interpret/frame.hpp" +#include "query/interpreter_context.hpp" #include "query/metadata.hpp" #include "query/plan/hint_provider.hpp" #include "query/plan/planner.hpp" #include "query/plan/profile.hpp" #include "query/plan/vertex_count_cache.hpp" #include "query/procedure/module.hpp" +#include "query/replication_query_handler.hpp" #include "query/stream.hpp" #include "query/stream/common.hpp" #include "query/stream/sources.hpp" @@ -71,6 +76,7 @@ #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "replication/config.hpp" +#include "replication/state.hpp" #include "spdlog/spdlog.h" #include "storage/v2/disk/storage.hpp" #include "storage/v2/edge.hpp" @@ -101,13 +107,6 @@ #include "utils/typeinfo.hpp" #include "utils/variant_helpers.hpp" -#include "dbms/coordinator_handler.hpp" -#include "dbms/dbms_handler.hpp" -#include "dbms/replication_handler.hpp" -#include "query/auth_query_handler.hpp" -#include "query/interpreter_context.hpp" -#include "replication/state.hpp" - #ifdef MG_ENTERPRISE #include "coordination/constants.hpp" #endif @@ -306,17 +305,18 @@ class ReplQueryHandler { ReplicationQuery::ReplicaState state; }; - explicit ReplQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_{*dbms_handler} {} + explicit ReplQueryHandler(query::ReplicationQueryHandler &replication_query_handler) + : handler_{&replication_query_handler} {} /// @throw QueryRuntimeException if an error ocurred. void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional port) { auto ValidatePort = [](std::optional port) -> void { - if (*port < 0 || *port > std::numeric_limits::max()) { + if (!port || *port < 0 || *port > std::numeric_limits::max()) { throw QueryRuntimeException("Port number invalid!"); } }; if (replication_role == ReplicationQuery::ReplicationRole::MAIN) { - if (!handler_.SetReplicationRoleMain()) { + if (!handler_->SetReplicationRoleMain()) { throw QueryRuntimeException("Couldn't set replication role to main!"); } } else { @@ -327,7 +327,7 @@ class ReplQueryHandler { .port = static_cast(*port), }; - if (!handler_.SetReplicationRoleReplica(config)) { + if (!handler_->SetReplicationRoleReplica(config)) { throw QueryRuntimeException("Couldn't set role to replica!"); } } @@ -335,7 +335,7 @@ class ReplQueryHandler { /// @throw QueryRuntimeException if an error ocurred. ReplicationQuery::ReplicationRole ShowReplicationRole() const { - switch (handler_.GetRole()) { + switch (handler_->GetRole()) { case memgraph::replication_coordination_glue::ReplicationRole::MAIN: return ReplicationQuery::ReplicationRole::MAIN; case memgraph::replication_coordination_glue::ReplicationRole::REPLICA: @@ -349,7 +349,7 @@ class ReplQueryHandler { const ReplicationQuery::SyncMode sync_mode, const std::chrono::seconds replica_check_frequency) { // Coordinator is main by default so this check is OK although it should actually be nothing (neither main nor // replica) - if (handler_.IsReplica()) { + if (handler_->IsReplica()) { // replica can't register another replica throw QueryRuntimeException("Replica can't register another replica!"); } @@ -368,7 +368,7 @@ class ReplQueryHandler { .replica_check_frequency = replica_check_frequency, .ssl = std::nullopt}; - const auto error = handler_.RegisterReplica(replication_config).HasError(); + const auto error = handler_->TryRegisterReplica(replication_config).HasError(); if (error) { throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name)); @@ -381,9 +381,9 @@ class ReplQueryHandler { /// @throw QueryRuntimeException if an error occurred. void DropReplica(std::string_view replica_name) { - auto const result = handler_.UnregisterReplica(replica_name); + auto const result = handler_->UnregisterReplica(replica_name); switch (result) { - using enum memgraph::dbms::UnregisterReplicaResult; + using enum memgraph::query::UnregisterReplicaResult; case NOT_MAIN: throw QueryRuntimeException("Replica can't unregister a replica!"); case COULD_NOT_BE_PERSISTED: @@ -396,7 +396,7 @@ class ReplQueryHandler { } std::vector ShowReplicas(const dbms::Database &db) const { - if (handler_.IsReplica()) { + if (handler_->IsReplica()) { // replica can't show registered replicas (it shouldn't have any) throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!"); } @@ -447,19 +447,16 @@ class ReplQueryHandler { } private: - dbms::ReplicationHandler handler_; + query::ReplicationQueryHandler *handler_; }; +#ifdef MG_ENTERPRISE class CoordQueryHandler final : public query::CoordinatorQueryHandler { public: - explicit CoordQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_ { *dbms_handler } -#ifdef MG_ENTERPRISE - , coordinator_handler_(*dbms_handler) -#endif - { - } + explicit CoordQueryHandler(coordination::CoordinatorState &coordinator_state) + + : coordinator_handler_(coordinator_state) {} -#ifdef MG_ENTERPRISE /// @throw QueryRuntimeException if an error ocurred. void RegisterInstance(const std::string &coordinator_socket_address, const std::string &replication_socket_address, const std::chrono::seconds instance_check_frequency, const std::string &instance_name, @@ -497,7 +494,7 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { using enum memgraph::coordination::RegisterInstanceCoordinatorStatus; case NAME_EXISTS: throw QueryRuntimeException("Couldn't register replica instance since instance with such name already exists!"); - case END_POINT_EXISTS: + case ENDPOINT_EXISTS: throw QueryRuntimeException( "Couldn't register replica instance since instance with such endpoint already exists!"); case NOT_COORDINATOR: @@ -527,26 +524,20 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { } } -#endif - -#ifdef MG_ENTERPRISE std::vector ShowInstances() const override { return coordinator_handler_.ShowInstances(); } -#endif - private: - dbms::ReplicationHandler handler_; -#ifdef MG_ENTERPRISE dbms::CoordinatorHandler coordinator_handler_; -#endif }; +#endif /// returns false if the replication role can't be set /// @throw QueryRuntimeException if an error ocurred. -Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters) { +Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters, + Interpreter &interpreter) { AuthQueryHandler *auth = interpreter_context->auth; #ifdef MG_ENTERPRISE auto *db_handler = interpreter_context->dbms_handler; @@ -595,63 +586,111 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); } + const auto forbid_on_replica = [has_license = license_check_result.HasError(), + is_replica = interpreter_context->repl_state->IsReplica()]() { + if (is_replica) { +#if MG_ENTERPRISE + if (has_license) { + throw QueryException( + "Query forbidden on the replica! Update on MAIN, as it is the only source of truth for authentication " + "data. MAIN will then replicate the update to connected REPLICAs"); + } + throw QueryException( + "Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA."); +#else + throw QueryException( + "Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA."); +#endif + } + }; + switch (auth_query->action_) { case AuthQuery::Action::CREATE_USER: - callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError()] { + forbid_on_replica(); + callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError(), + interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + MG_ASSERT(password.IsString() || password.IsNull()); - if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) - : std::nullopt)) { + if (!auth->CreateUser( + username, password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt, + &*interpreter->system_transaction_)) { throw UserAlreadyExistsException("User '{}' already exists.", username); } // If the license is not valid we create users with admin access if (!valid_enterprise_license) { spdlog::warn("Granting all the privileges to {}.", username); - auth->GrantPrivilege(username, kPrivilegesAll + auth->GrantPrivilege( + username, kPrivilegesAll #ifdef MG_ENTERPRISE - , - {{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}}, - { - { - { - AuthQuery::FineGrainedPrivilege::CREATE_DELETE, { query::kAsterisk } - } - } - } + , + {{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}}, + { + { + { + AuthQuery::FineGrainedPrivilege::CREATE_DELETE, { query::kAsterisk } + } + } + } #endif - ); + , + &*interpreter->system_transaction_); } return std::vector>(); }; return callback; case AuthQuery::Action::DROP_USER: - callback.fn = [auth, username] { - if (!auth->DropUser(username)) { + forbid_on_replica(); + callback.fn = [auth, username, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + if (!auth->DropUser(username, &*interpreter->system_transaction_)) { throw QueryRuntimeException("User '{}' doesn't exist.", username); } return std::vector>(); }; return callback; case AuthQuery::Action::SET_PASSWORD: - callback.fn = [auth, username, password] { + forbid_on_replica(); + callback.fn = [auth, username, password, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + MG_ASSERT(password.IsString() || password.IsNull()); auth->SetPassword(username, - password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt); + password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt, + &*interpreter->system_transaction_); return std::vector>(); }; return callback; case AuthQuery::Action::CREATE_ROLE: - callback.fn = [auth, rolename] { - if (!auth->CreateRole(rolename)) { + forbid_on_replica(); + callback.fn = [auth, rolename, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + if (!auth->CreateRole(rolename, &*interpreter->system_transaction_)) { throw QueryRuntimeException("Role '{}' already exists.", rolename); } return std::vector>(); }; return callback; case AuthQuery::Action::DROP_ROLE: - callback.fn = [auth, rolename] { - if (!auth->DropRole(rolename)) { + forbid_on_replica(); + callback.fn = [auth, rolename, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + if (!auth->DropRole(rolename, &*interpreter->system_transaction_)) { throw QueryRuntimeException("Role '{}' doesn't exist.", rolename); } return std::vector>(); @@ -682,52 +721,79 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ }; return callback; case AuthQuery::Action::SET_ROLE: - callback.fn = [auth, username, rolename] { - auth->SetRole(username, rolename); + forbid_on_replica(); + callback.fn = [auth, username, rolename, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + auth->SetRole(username, rolename, &*interpreter->system_transaction_); return std::vector>(); }; return callback; case AuthQuery::Action::CLEAR_ROLE: - callback.fn = [auth, username] { - auth->ClearRole(username); + forbid_on_replica(); + callback.fn = [auth, username, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + auth->ClearRole(username, &*interpreter->system_transaction_); return std::vector>(); }; return callback; case AuthQuery::Action::GRANT_PRIVILEGE: - callback.fn = [auth, user_or_role, privileges + forbid_on_replica(); + callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter #ifdef MG_ENTERPRISE , label_privileges, edge_type_privileges #endif ] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + auth->GrantPrivilege(user_or_role, privileges #ifdef MG_ENTERPRISE , label_privileges, edge_type_privileges #endif - ); + , + &*interpreter->system_transaction_); return std::vector>(); }; return callback; case AuthQuery::Action::DENY_PRIVILEGE: - callback.fn = [auth, user_or_role, privileges] { - auth->DenyPrivilege(user_or_role, privileges); + forbid_on_replica(); + callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + auth->DenyPrivilege(user_or_role, privileges, &*interpreter->system_transaction_); return std::vector>(); }; return callback; case AuthQuery::Action::REVOKE_PRIVILEGE: { - callback.fn = [auth, user_or_role, privileges + forbid_on_replica(); + callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter #ifdef MG_ENTERPRISE , label_privileges, edge_type_privileges #endif ] { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + auth->RevokePrivilege(user_or_role, privileges #ifdef MG_ENTERPRISE , label_privileges, edge_type_privileges #endif - ); + , + &*interpreter->system_transaction_); return std::vector>(); }; return callback; @@ -757,15 +823,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ }; return callback; case AuthQuery::Action::GRANT_DATABASE_TO_USER: + forbid_on_replica(); #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, db_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + try { std::optional db = std::nullopt; // Hold pointer to database to protect it until query is done if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->GrantDatabaseToUser(database, username)) { + if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) { throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); } } catch (memgraph::dbms::UnknownDatabaseException &e) { @@ -778,15 +849,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ }; return callback; case AuthQuery::Action::REVOKE_DATABASE_FROM_USER: + forbid_on_replica(); #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, db_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + try { std::optional db = std::nullopt; // Hold pointer to database to protect it until query is done if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->RevokeDatabaseFromUser(database, username)) { + if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) { throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); } } catch (memgraph::dbms::UnknownDatabaseException &e) { @@ -811,12 +887,17 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ #endif return callback; case AuthQuery::Action::SET_MAIN_DATABASE: + forbid_on_replica(); #ifdef MG_ENTERPRISE - callback.fn = [auth, database, username, db_handler] { // NOLINT + callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + try { const auto db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull - if (!auth->SetMainDatabase(database, username)) { + if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) { throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); } } catch (memgraph::dbms::UnknownDatabaseException &e) { @@ -834,7 +915,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ } // namespace Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters ¶meters, - dbms::DbmsHandler *dbms_handler, CurrentDB ¤t_db, + ReplicationQueryHandler &replication_query_handler, CurrentDB ¤t_db, const query::InterpreterConfig &config, std::vector *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. @@ -864,7 +945,8 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING, "Be careful the replication port must be different from the memgraph port!"); } - callback.fn = [handler = ReplQueryHandler{dbms_handler}, role = repl_query->role_, maybe_port]() mutable { + callback.fn = [handler = ReplQueryHandler{replication_query_handler}, role = repl_query->role_, + maybe_port]() mutable { handler.SetReplicationRole(role, maybe_port); return std::vector>(); }; @@ -882,7 +964,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & #endif callback.header = {"replication role"}; - callback.fn = [handler = ReplQueryHandler{dbms_handler}] { + callback.fn = [handler = ReplQueryHandler{replication_query_handler}] { auto mode = handler.ShowReplicationRole(); switch (mode) { case ReplicationQuery::ReplicationRole::MAIN: { @@ -906,7 +988,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & auto socket_address = repl_query->socket_address_->Accept(evaluator); const auto replica_check_frequency = config.replication_replica_check_frequency; - callback.fn = [handler = ReplQueryHandler{dbms_handler}, name, socket_address, sync_mode, + callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name, socket_address, sync_mode, replica_check_frequency]() mutable { handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency); return std::vector>(); @@ -923,7 +1005,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } #endif const auto &name = repl_query->instance_name_; - callback.fn = [handler = ReplQueryHandler{dbms_handler}, name]() mutable { + callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name]() mutable { handler.DropReplica(name); return std::vector>(); }; @@ -941,7 +1023,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & callback.header = { "name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master", "state"}; - callback.fn = [handler = ReplQueryHandler{dbms_handler}, replica_nfields = callback.header.size(), + callback.fn = [handler = ReplQueryHandler{replication_query_handler}, replica_nfields = callback.header.size(), db_acc = current_db.db_acc_] { const auto &replicas = handler.ShowReplicas(*db_acc->get()); auto typed_replicas = std::vector>{}; @@ -989,16 +1071,17 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } } +#ifdef MG_ENTERPRISE Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Parameters ¶meters, - dbms::DbmsHandler *dbms_handler, const query::InterpreterConfig &config, - std::vector *notifications) { + coordination::CoordinatorState *coordinator_state, + const query::InterpreterConfig &config, std::vector *notifications) { Callback callback; switch (coordinator_query->action_) { case CoordinatorQuery::Action::REGISTER_INSTANCE: { if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); } -#ifdef MG_ENTERPRISE + if constexpr (!coordination::allow_ha) { throw QueryRuntimeException( "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " @@ -1014,7 +1097,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param auto coordinator_socket_address_tv = coordinator_query->coordinator_socket_address_->Accept(evaluator); auto replication_socket_address_tv = coordinator_query->replication_socket_address_->Accept(evaluator); - callback.fn = [handler = CoordQueryHandler{dbms_handler}, coordinator_socket_address_tv, + callback.fn = [handler = CoordQueryHandler{*coordinator_state}, coordinator_socket_address_tv, replication_socket_address_tv, main_check_frequency = config.replication_replica_check_frequency, instance_name = coordinator_query->instance_name_, sync_mode = coordinator_query->sync_mode_]() mutable { @@ -1029,13 +1112,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param fmt::format("Coordinator has registered coordinator server on {} for instance {}.", coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_)); return callback; -#endif } case CoordinatorQuery::Action::SET_INSTANCE_TO_MAIN: { if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); } -#ifdef MG_ENTERPRISE if constexpr (!coordination::allow_ha) { throw QueryRuntimeException( "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " @@ -1049,20 +1130,18 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param EvaluationContext evaluation_context{.timestamp = QueryTimestamp(), .parameters = parameters}; auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; - callback.fn = [handler = CoordQueryHandler{dbms_handler}, + callback.fn = [handler = CoordQueryHandler{*coordinator_state}, instance_name = coordinator_query->instance_name_]() mutable { handler.SetInstanceToMain(instance_name); return std::vector>(); }; return callback; -#endif } case CoordinatorQuery::Action::SHOW_REPLICATION_CLUSTER: { if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); } -#ifdef MG_ENTERPRISE if constexpr (!coordination::allow_ha) { throw QueryRuntimeException( "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " @@ -1073,7 +1152,8 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param } callback.header = {"name", "socket_address", "alive", "role"}; - callback.fn = [handler = CoordQueryHandler{dbms_handler}, replica_nfields = callback.header.size()]() mutable { + callback.fn = [handler = CoordQueryHandler{*coordinator_state}, + replica_nfields = callback.header.size()]() mutable { auto const instances = handler.ShowInstances(); std::vector> result{}; result.reserve(result.size()); @@ -1087,11 +1167,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param return result; }; return callback; -#endif } return callback; } } +#endif stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor &evaluator) { return { @@ -2493,14 +2573,14 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans } PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - InterpreterContext *interpreter_context) { + InterpreterContext *interpreter_context, Interpreter &interpreter) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } auto *auth_query = utils::Downcast(parsed_query.query); - auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters); + auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr), @@ -2525,15 +2605,16 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa } PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, dbms::DbmsHandler &dbms_handler, - CurrentDB ¤t_db, const InterpreterConfig &config) { + std::vector *notifications, + ReplicationQueryHandler &replication_query_handler, CurrentDB ¤t_db, + const InterpreterConfig &config) { if (in_explicit_transaction) { throw ReplicationModificationInMulticommandTxException(); } auto *replication_query = utils::Downcast(parsed_query.query); - auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, &dbms_handler, current_db, config, - notifications); + auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, replication_query_handler, + current_db, config, notifications); return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -2552,8 +2633,10 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } +#ifdef MG_ENTERPRISE PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector *notifications, dbms::DbmsHandler &dbms_handler, + std::vector *notifications, + coordination::CoordinatorState &coordinator_state, const InterpreterConfig &config) { if (in_explicit_transaction) { throw CoordinatorModificationInMulticommandTxException(); @@ -2561,7 +2644,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit auto *coordinator_query = utils::Downcast(parsed_query.query); auto callback = - HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &dbms_handler, config, notifications); + HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &coordinator_state, config, notifications); return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -2579,6 +2662,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit // False positive report for the std::make_shared above // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } +#endif PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, CurrentDB ¤t_db) { if (in_explicit_transaction) { @@ -3681,7 +3765,8 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB ¤t_db, InterpreterContext *interpreter_context, - std::optional> on_change_cb) { + std::optional> on_change_cb, + Interpreter &interpreter) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); @@ -3700,12 +3785,16 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur return PreparedQuery{ {"STATUS"}, std::move(parsed_query.required_privileges), - [db_name = query->db_name_, db_handler](AnyStream *stream, - std::optional n) -> std::optional { + [db_name = query->db_name_, db_handler, interpreter = &interpreter]( + AnyStream *stream, std::optional n) -> std::optional { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + std::vector> status; std::string res; - const auto success = db_handler->New(db_name); + const auto success = db_handler->New(db_name, &*interpreter->system_transaction_); if (success.HasError()) { switch (success.GetError()) { case dbms::NewError::EXISTS: @@ -3780,16 +3869,20 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur return PreparedQuery{ {"STATUS"}, std::move(parsed_query.required_privileges), - [db_name = query->db_name_, db_handler, auth = interpreter_context->auth]( + [db_name = query->db_name_, db_handler, auth = interpreter_context->auth, interpreter = &interpreter]( AnyStream *stream, std::optional n) -> std::optional { + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + std::vector> status; try { // Remove database - auto success = db_handler->TryDelete(db_name); + auto success = db_handler->TryDelete(db_name, &*interpreter->system_transaction_); if (!success.HasError()) { // Remove from auth - if (auth) auth->DeleteDatabase(db_name); + if (auth) auth->DeleteDatabase(db_name, &*interpreter->system_transaction_); } else { switch (success.GetError()) { case dbms::DeleteError::DEFAULT_DB: @@ -4040,18 +4133,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, utils::Downcast(parsed_query.query); // TODO Split SHOW REPLICAS (which needs the db) and other replication queries - auto system_transaction_guard = std::invoke([&]() -> std::optional { - if (system_queries) { - // TODO: Ordering between system and data queries - // Start a system transaction - auto system_unique = std::unique_lock{interpreter_context_->dbms_handler->system_lock_, std::defer_lock}; - if (!system_unique.try_lock_for(std::chrono::milliseconds(kSystemTxTryMS))) { - throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported."); - } - return std::optional{std::in_place, std::move(system_unique), - *interpreter_context_->dbms_handler}; + auto system_transaction = std::invoke([&]() -> std::optional { + if (!system_queries) return std::nullopt; + + // TODO: Ordering between system and data queries + auto system_txn = interpreter_context_->system_->TryCreateTransaction(std::chrono::milliseconds(kSystemTxTryMS)); + if (!system_txn) { + throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported."); } - return std::nullopt; + return system_txn; }); // Some queries do not require a database to be executed (current_db_ won't be passed on to the Prepare*; special @@ -4117,7 +4207,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); } else if (utils::Downcast(parsed_query.query)) { /// SYSTEM (Replication) PURE - prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, *this); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDatabaseInfoQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); } else if (utils::Downcast(parsed_query.query)) { @@ -4128,13 +4218,18 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, &query_execution->notifications, current_db_); } else if (utils::Downcast(parsed_query.query)) { /// TODO: make replication DB agnostic - prepared_query = - PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - *interpreter_context_->dbms_handler, current_db_, interpreter_context_->config); + prepared_query = PrepareReplicationQuery( + std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, + *interpreter_context_->replication_handler_, current_db_, interpreter_context_->config); + } else if (utils::Downcast(parsed_query.query)) { +#ifdef MG_ENTERPRISE prepared_query = PrepareCoordinatorQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - *interpreter_context_->dbms_handler, interpreter_context_->config); + *interpreter_context_->coordinator_state_, interpreter_context_->config); +#else + throw QueryRuntimeException("Coordinator queries are not part of community edition"); +#endif } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); } else if (utils::Downcast(parsed_query.query)) { @@ -4177,8 +4272,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } /// SYSTEM (Replication) + INTERPRETER // DMG_ASSERT(system_guard); - prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_ - /*, *system_guard*/); + prepared_query = + PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_); } else if (utils::Downcast(parsed_query.query)) { @@ -4210,7 +4305,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, query_execution->summary["db"] = *query_execution->prepared_query->db; // prepare is done, move system txn guard to be owned by interpreter - system_transaction_guard_ = std::move(system_transaction_guard); + system_transaction_ = std::move(system_transaction); return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, query_execution->prepared_query->db}; } catch (const utils::BasicException &) { @@ -4360,13 +4455,13 @@ void Interpreter::Commit() { current_transaction_.reset(); if (!current_db_.db_transactional_accessor_ || !current_db_.db_acc_) { // No database nor db transaction; check for system transaction - if (!system_transaction_guard_) return; + if (!system_transaction_) return; // TODO Distinguish between data and system transaction state // Think about updating the status to a struct with bitfield // Clean transaction status on exit utils::OnScopeExit clean_status([this]() { - system_transaction_guard_.reset(); + system_transaction_.reset(); // System transactions are not terminable // Durability has happened at time of PULL // Commit is doing replication and timestamp update @@ -4384,7 +4479,23 @@ void Interpreter::Commit() { } }); - system_transaction_guard_->Commit(); + auto const main_commit = [&](replication::RoleMainData &mainData) { + // Only enterprise can do system replication +#ifdef MG_ENTERPRISE + if (license::global_license_checker.IsEnterpriseValidFast()) { + return system_transaction_->Commit(memgraph::system::DoReplication{mainData}); + } +#endif + return system_transaction_->Commit(memgraph::system::DoNothing{}); + }; + + auto const replica_commit = [&](replication::RoleReplicaData &) { + return system_transaction_->Commit(memgraph::system::DoNothing{}); + }; + + auto const commit_method = utils::Overloaded{main_commit, replica_commit}; + [[maybe_unused]] auto sync_result = std::visit(commit_method, interpreter_context_->repl_state->ReplicationData()); + // TODO: something with sync_result return; } auto *db = current_db_.db_acc_->get(); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 42100059c..cf822d8b9 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -72,6 +72,7 @@ inline constexpr size_t kExecutionPoolMaxBlockSize = 1024UL; // 2 ^ 10 enum class QueryHandlerResult { COMMIT, ABORT, NOTHING }; +#ifdef MG_ENTERPRISE class CoordinatorQueryHandler { public: CoordinatorQueryHandler() = default; @@ -93,7 +94,6 @@ class CoordinatorQueryHandler { ReplicationQuery::ReplicaState state; }; -#ifdef MG_ENTERPRISE struct MainReplicaStatus { std::string_view name; std::string_view socket_address; @@ -103,9 +103,7 @@ class CoordinatorQueryHandler { MainReplicaStatus(std::string_view name, std::string_view socket_address, bool alive, bool is_main) : name{name}, socket_address{socket_address}, alive{alive}, is_main{is_main} {} }; -#endif -#ifdef MG_ENTERPRISE /// @throw QueryRuntimeException if an error ocurred. virtual void RegisterInstance(const std::string &coordinator_socket_address, const std::string &replication_socket_address, @@ -117,9 +115,8 @@ class CoordinatorQueryHandler { /// @throw QueryRuntimeException if an error ocurred. virtual std::vector ShowInstances() const = 0; - -#endif }; +#endif class AnalyzeGraphQueryHandler { public: @@ -296,32 +293,12 @@ class Interpreter final { void SetUser(std::string_view username); - struct SystemTransactionGuard { - explicit SystemTransactionGuard(std::unique_lock guard, dbms::DbmsHandler &dbms_handler) - : system_guard_(std::move(guard)), dbms_handler_{&dbms_handler} { - dbms_handler_->NewSystemTransaction(); - } - SystemTransactionGuard &operator=(SystemTransactionGuard &&) = default; - SystemTransactionGuard(SystemTransactionGuard &&) = default; - - ~SystemTransactionGuard() { - if (system_guard_.owns_lock()) dbms_handler_->ResetSystemTransaction(); - } - - dbms::AllSyncReplicaStatus Commit() { return dbms_handler_->Commit(); } - - private: - std::unique_lock system_guard_; - dbms::DbmsHandler *dbms_handler_; - }; - - std::optional system_transaction_guard_{}; + std::optional system_transaction_{}; private: void ResetInterpreter() { query_executions_.clear(); - system_guard.reset(); - system_transaction_guard_.reset(); + system_transaction_.reset(); transaction_queries_->clear(); if (current_db_.db_acc_ && current_db_.db_acc_->is_deleting()) { current_db_.db_acc_.reset(); @@ -386,8 +363,6 @@ class Interpreter final { // TODO Figure out how this would work for multi-database // Exists only during a single transaction (for now should be okay as is) std::vector> query_executions_; - // TODO: our upgradable lock guard for system - std::optional system_guard; // all queries that are run as part of the current transaction utils::Synchronized, utils::SpinLock> transaction_queries_; diff --git a/src/query/interpreter_context.cpp b/src/query/interpreter_context.cpp index cace25ec6..f7b4584ba 100644 --- a/src/query/interpreter_context.cpp +++ b/src/query/interpreter_context.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,12 +12,27 @@ #include "query/interpreter_context.hpp" #include "query/interpreter.hpp" +#include "system/include/system/system.hpp" namespace memgraph::query { InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler, - replication::ReplicationState *rs, query::AuthQueryHandler *ah, - query::AuthChecker *ac) - : dbms_handler(dbms_handler), config(interpreter_config), repl_state(rs), auth(ah), auth_checker(ac) {} + replication::ReplicationState *rs, memgraph::system::System &system, +#ifdef MG_ENTERPRISE + memgraph::coordination::CoordinatorState *coordinator_state, +#endif + AuthQueryHandler *ah, AuthChecker *ac, + ReplicationQueryHandler *replication_handler) + : dbms_handler(dbms_handler), + config(interpreter_config), + repl_state(rs), +#ifdef MG_ENTERPRISE + coordinator_state_{coordinator_state}, +#endif + auth(ah), + auth_checker(ac), + replication_handler_{replication_handler}, + system_{&system} { +} std::vector> InterpreterContext::TerminateTransactions( std::vector maybe_kill_transaction_ids, const std::optional &username, diff --git a/src/query/interpreter_context.hpp b/src/query/interpreter_context.hpp index 9b54dbd3a..c5fe00d2d 100644 --- a/src/query/interpreter_context.hpp +++ b/src/query/interpreter_context.hpp @@ -20,14 +20,20 @@ #include "query/config.hpp" #include "query/cypher_query_interpreter.hpp" +#include "query/replication_query_handler.hpp" #include "query/typed_value.hpp" #include "replication/state.hpp" #include "storage/v2/config.hpp" #include "storage/v2/transaction.hpp" +#include "system/state.hpp" +#include "system/system.hpp" #include "utils/gatekeeper.hpp" #include "utils/skip_list.hpp" #include "utils/spin_lock.hpp" #include "utils/synchronized.hpp" +#ifdef MG_ENTERPRISE +#include "coordination/coordinator_state.hpp" +#endif namespace memgraph::dbms { class DbmsHandler; @@ -48,7 +54,12 @@ class Interpreter; */ struct InterpreterContext { InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler, - replication::ReplicationState *rs, AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr); + replication::ReplicationState *rs, memgraph::system::System &system, +#ifdef MG_ENTERPRISE + memgraph::coordination::CoordinatorState *coordinator_state, +#endif + AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr, + ReplicationQueryHandler *replication_handler = nullptr); memgraph::dbms::DbmsHandler *dbms_handler; @@ -59,9 +70,14 @@ struct InterpreterContext { // GLOBAL memgraph::replication::ReplicationState *repl_state; +#ifdef MG_ENTERPRISE + memgraph::coordination::CoordinatorState *coordinator_state_; +#endif AuthQueryHandler *auth; AuthChecker *auth_checker; + ReplicationQueryHandler *replication_handler_; + system::System *system_; // Used to check active transactions // TODO: Have a way to read the current database diff --git a/src/query/replication_query_handler.hpp b/src/query/replication_query_handler.hpp new file mode 100644 index 000000000..f391b4867 --- /dev/null +++ b/src/query/replication_query_handler.hpp @@ -0,0 +1,60 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "replication_coordination_glue/role.hpp" +#include "utils/result.hpp" + +// BEGIN fwd declares +namespace memgraph::replication { +struct ReplicationState; +struct ReplicationServerConfig; +struct ReplicationClientConfig; +} // namespace memgraph::replication + +namespace memgraph::query { + +enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED }; +enum class UnregisterReplicaResult : uint8_t { + NOT_MAIN, + COULD_NOT_BE_PERSISTED, + CAN_NOT_UNREGISTER, + SUCCESS, +}; + +/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage +struct ReplicationQueryHandler { + virtual ~ReplicationQueryHandler() = default; + + // as REPLICA, become MAIN + virtual bool SetReplicationRoleMain() = 0; + + // as MAIN, become REPLICA + virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) = 0; + + // as MAIN, define and connect to REPLICAs + virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> utils::BasicResult = 0; + + virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> utils::BasicResult = 0; + + // as MAIN, remove a REPLICA connection + virtual auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult = 0; + + // Helper pass-through (TODO: remove) + virtual auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole = 0; + virtual bool IsMain() const = 0; + virtual bool IsReplica() const = 0; +}; + +} // namespace memgraph::query diff --git a/src/replication/CMakeLists.txt b/src/replication/CMakeLists.txt index e19ba7061..676dce744 100644 --- a/src/replication/CMakeLists.txt +++ b/src/replication/CMakeLists.txt @@ -6,7 +6,6 @@ target_sources(mg-replication include/replication/epoch.hpp include/replication/config.hpp include/replication/status.hpp - include/replication/messages.hpp include/replication/replication_client.hpp include/replication/replication_server.hpp @@ -15,7 +14,6 @@ target_sources(mg-replication epoch.cpp config.cpp status.cpp - messages.cpp replication_client.cpp replication_server.cpp ) diff --git a/src/replication/include/replication/messages.hpp b/src/replication/include/replication/messages.hpp deleted file mode 100644 index b4e0b51c7..000000000 --- a/src/replication/include/replication/messages.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -#pragma once - -#include "rpc/messages.hpp" -#include "slk/serialization.hpp" - -namespace memgraph::replication { -struct SystemHeartbeatReq { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader); - static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder); - SystemHeartbeatReq() = default; -}; - -struct SystemHeartbeatRes { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader); - static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder); - SystemHeartbeatRes() = default; - explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {} - - uint64_t system_timestamp; -}; - -using SystemHeartbeatRpc = rpc::RequestResponse; -} // namespace memgraph::replication - -namespace memgraph::slk { -void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder); -void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader); -void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/); -void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/); -} // namespace memgraph::slk diff --git a/src/replication/include/replication/replication_client.hpp b/src/replication/include/replication/replication_client.hpp index 0c64ae625..0bf0e424f 100644 --- a/src/replication/include/replication/replication_client.hpp +++ b/src/replication/include/replication/replication_client.hpp @@ -41,26 +41,67 @@ struct ReplicationClient { void StartFrequentCheck(F &&callback) { // Help the user to get the most accurate replica state possible. if (replica_check_frequency_ > std::chrono::seconds(0)) { - replica_checker_.Run("Replica Checker", replica_check_frequency_, - [this, cb = std::forward(callback), reconnect = false]() mutable { - try { - { - auto stream{rpc_client_.Stream()}; - stream.AwaitResponse(); - } - cb(reconnect, *this); - reconnect = false; - } catch (const rpc::RpcFailedException &) { - // Nothing to do...wait for a reconnect - // NOTE: Here we are communicating with the instance connection. - // We don't have access to the undelying client; so the only thing we can do it - // tell the callback that this is a reconnection and to check the state - reconnect = true; - } - }); + replica_checker_.Run( + "Replica Checker", replica_check_frequency_, + [this, cb = std::forward(callback), reconnect = false]() mutable { + try { + { + auto stream{rpc_client_.Stream()}; + stream.AwaitResponse(); + } + cb(reconnect, *this); + reconnect = false; + } catch (const rpc::RpcFailedException &) { + // Nothing to do...wait for a reconnect + // NOTE: Here we are communicating with the instance connection. + // We don't have access to the undelying client; so the only thing we can do it + // tell the callback that this is a reconnection and to check the state + reconnect = true; + } + }); } } + //! \tparam RPC An rpc::RequestResponse + //! \tparam Args the args type + //! \param client the client to use for rpc communication + //! \param check predicate to check response is ok + //! \param args arguments to forward to the rpc request + //! \return If replica stream is completed or enqueued + template + bool SteamAndFinalizeDelta(auto &&check, Args &&...args) { + try { + auto stream = rpc_client_.template Stream(std::forward(args)...); + auto task = [this, check = std::forward(check), stream = std::move(stream)]() mutable { + if (stream.IsDefunct()) { + state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + return false; + } + try { + if (check(stream.AwaitResponse())) { + return true; + } + } catch (memgraph::rpc::GenericRpcFailedException const &e) { + // swallow error, fallthrough to error handling + } + // This replica needs SYSTEM recovery + state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + return false; + }; + + if (mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) { + thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); }); + return true; + } + + return task(); + } catch (memgraph::rpc::GenericRpcFailedException const &e) { + // This replica needs SYSTEM recovery + state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + return false; + } + }; + std::string name_; communication::ClientContext rpc_context_; rpc::Client rpc_client_; diff --git a/src/replication/messages.cpp b/src/replication/messages.cpp deleted file mode 100644 index b2dca374e..000000000 --- a/src/replication/messages.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2024 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. -#include "replication/messages.hpp" - -namespace memgraph::replication { - -constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq", - nullptr}; - -constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes", - nullptr}; - -void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} -void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} - -} // namespace memgraph::replication - -namespace memgraph::slk { -// Serialize code for SystemHeartbeatRes -void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.system_timestamp, builder); -} -void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(&self->system_timestamp, reader); -} - -// Serialize code for SystemHeartbeatReq -void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) { - /* Nothing to serialize */ -} -void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) { - /* Nothing to serialize */ -} -} // namespace memgraph::slk diff --git a/src/replication_handler/CMakeLists.txt b/src/replication_handler/CMakeLists.txt new file mode 100644 index 000000000..a0cd3734c --- /dev/null +++ b/src/replication_handler/CMakeLists.txt @@ -0,0 +1,17 @@ +add_library(mg-replication_handler STATIC) +add_library(mg::replication_handler ALIAS mg-replication_handler) +target_sources(mg-replication_handler + PUBLIC + include/replication_handler/replication_handler.hpp + include/replication_handler/system_replication.hpp + include/replication_handler/system_rpc.hpp + + PRIVATE + replication_handler.cpp + system_replication.cpp + system_rpc.cpp + ) +target_include_directories(mg-replication_handler PUBLIC include) + +target_link_libraries(mg-replication_handler + PUBLIC mg-auth mg-dbms mg-replication) diff --git a/src/replication_handler/include/replication_handler/replication_handler.hpp b/src/replication_handler/include/replication_handler/replication_handler.hpp new file mode 100644 index 000000000..1ae9ceb6d --- /dev/null +++ b/src/replication_handler/include/replication_handler/replication_handler.hpp @@ -0,0 +1,220 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#pragma once + +#include "auth/auth.hpp" +#include "dbms/dbms_handler.hpp" +#include "replication/include/replication/state.hpp" +#include "replication_handler/system_rpc.hpp" +#include "utils/result.hpp" + +namespace memgraph::replication { + +inline std::optional HandleRegisterReplicaStatus( + utils::BasicResult &instance_client); + +#ifdef MG_ENTERPRISE +void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, + auth::SynchedAuth &auth); +#else +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler); +#endif + +#ifdef MG_ENTERPRISE +// TODO: Split into 2 functions: dbms and auth +// When being called by interpreter no need to gain lock, it should already be under a system transaction +// But concurrently the FrequentCheck is running and will need to lock before reading last_committed_system_timestamp_ +template +void SystemRestore(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, + auth::SynchedAuth &auth) { + // Check if system is up to date + if (client.state_.WithLock( + [](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; })) + return; + + // Try to recover... + { + struct DbInfo { + std::vector configs; + uint64_t last_committed_timestamp; + }; + DbInfo db_info = std::invoke([&] { + auto guard = std::invoke([&]() -> std::optional { + if constexpr (REQUIRE_LOCK) { + return system->GenTransactionGuard(); + } + return std::nullopt; + }); + + if (license::global_license_checker.IsEnterpriseValidFast()) { + auto configs = std::vector{}; + dbms_handler.ForEach([&configs](dbms::DatabaseAccess acc) { configs.emplace_back(acc->config().salient); }); + // TODO: This is `SystemRestore` maybe DbInfo is incorrect as it will need Auth also + return DbInfo{configs, system->LastCommittedSystemTimestamp()}; + } + + // No license -> send only default config + return DbInfo{{dbms_handler.Get()->config().salient}, system->LastCommittedSystemTimestamp()}; + }); + try { + auto stream = std::invoke([&]() { + // Handle only default database is no license + if (!license::global_license_checker.IsEnterpriseValidFast()) { + return client.rpc_client_.Stream( + db_info.last_committed_timestamp, std::move(db_info.configs), auth::Auth::Config{}, + std::vector{}, std::vector{}); + } + return auth.WithLock([&](auto &locked_auth) { + return client.rpc_client_.Stream( + db_info.last_committed_timestamp, std::move(db_info.configs), locked_auth.GetConfig(), + locked_auth.AllUsers(), locked_auth.AllRoles()); + }); + }); + const auto response = stream.AwaitResponse(); + if (response.result == replication::SystemRecoveryRes::Result::FAILURE) { + client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + return; + } + } catch (memgraph::rpc::GenericRpcFailedException const &e) { + client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + return; + } + } + + // Successfully recovered + client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; }); +} +#endif + +/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage +struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { +#ifdef MG_ENTERPRISE + explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state, + memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system, + memgraph::auth::SynchedAuth &auth); +#else + explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state, + memgraph::dbms::DbmsHandler &dbms_handler); +#endif + + // as REPLICA, become MAIN + bool SetReplicationRoleMain() override; + + // as MAIN, become REPLICA + bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) override; + + // as MAIN, define and connect to REPLICAs + auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> memgraph::utils::BasicResult override; + + auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> memgraph::utils::BasicResult override; + + // as MAIN, remove a REPLICA connection + auto UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult override; + + bool DoReplicaToMainPromotion(); + + // Helper pass-through (TODO: remove) + auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole override; + bool IsMain() const override; + bool IsReplica() const override; + + private: + template + auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config) + -> memgraph::utils::BasicResult { + MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!"); + + auto maybe_client = repl_state_.RegisterReplica(config); + if (maybe_client.HasError()) { + switch (maybe_client.GetError()) { + case memgraph::replication::RegisterReplicaError::NOT_MAIN: + MG_ASSERT(false, "Only main instance can register a replica!"); + return {}; + case memgraph::replication::RegisterReplicaError::NAME_EXISTS: + return memgraph::query::RegisterReplicaError::NAME_EXISTS; + case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS: + return memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS; + case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED: + return memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED; + case memgraph::replication::RegisterReplicaError::SUCCESS: + break; + } + } + + if (!memgraph::dbms::allow_mt_repl && dbms_handler_.All().size() > 1) { + spdlog::warn("Multi-tenant replication is currently not supported!"); + } + +#ifdef MG_ENTERPRISE + // Update system before enabling individual storage <-> replica clients + SystemRestore(*maybe_client.GetValue(), system_, dbms_handler_, auth_); +#endif + + const auto dbms_error = HandleRegisterReplicaStatus(maybe_client); + if (dbms_error.has_value()) { + return *dbms_error; + } + auto &instance_client_ptr = maybe_client.GetValue(); + + bool all_clients_good = true; + // Add database specific clients (NOTE Currently all databases are connected to each replica) + dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + if (!dbms::allow_mt_repl && storage->name() != dbms::kDefaultDB) { + return; + } + // TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes + if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; + + all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock( + [storage, &instance_client_ptr, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT + auto client = std::make_unique(*instance_client_ptr); + // All good, start replica client + client->Start(storage, std::move(db_acc)); + // After start the storage <-> replica state should be READY or RECOVERING (if correctly started) + // MAYBE_BEHIND isn't a statement of the current state, this is the default value + // Failed to start due an error like branching of MAIN and REPLICA + const bool success = client->State() != storage::replication::ReplicaState::MAYBE_BEHIND; + if (HandleFailure || success) { + storage_clients.push_back(std::move(client)); + } + return success; + }); + }); + + // NOTE Currently if any databases fails, we revert back + if (!HandleFailure && !all_clients_good) { + spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name); + UnregisterReplica(config.name); + return memgraph::query::RegisterReplicaError::CONNECTION_FAILED; + } + + // No client error, start instance level client +#ifdef MG_ENTERPRISE + StartReplicaClient(*instance_client_ptr, system_, dbms_handler_, auth_); +#else + StartReplicaClient(*instance_client_ptr, dbms_handler_); +#endif + return {}; + } + + memgraph::replication::ReplicationState &repl_state_; + memgraph::dbms::DbmsHandler &dbms_handler_; + +#ifdef MG_ENTERPRISE + memgraph::system::System *system_; + memgraph::auth::SynchedAuth &auth_; +#endif +}; + +} // namespace memgraph::replication diff --git a/src/replication_handler/include/replication_handler/system_replication.hpp b/src/replication_handler/include/replication_handler/system_replication.hpp new file mode 100644 index 000000000..e1d177fc6 --- /dev/null +++ b/src/replication_handler/include/replication_handler/system_replication.hpp @@ -0,0 +1,31 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/auth.hpp" +#include "dbms/dbms_handler.hpp" +#include "slk/streams.hpp" +#include "system/state.hpp" + +namespace memgraph::replication { +#ifdef MG_ENTERPRISE +void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder); +void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader, + slk::Builder *res_builder); +void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth); +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, auth::SynchedAuth &auth); +#else +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data); +#endif + +} // namespace memgraph::replication diff --git a/src/replication_handler/include/replication_handler/system_rpc.hpp b/src/replication_handler/include/replication_handler/system_rpc.hpp new file mode 100644 index 000000000..a2469fc5d --- /dev/null +++ b/src/replication_handler/include/replication_handler/system_rpc.hpp @@ -0,0 +1,95 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include + +#include "auth/auth.hpp" +#include "auth/models.hpp" +#include "rpc/messages.hpp" +#include "storage/v2/config.hpp" + +namespace memgraph::replication { +struct SystemHeartbeatReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader); + static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder); + SystemHeartbeatReq() = default; +}; + +struct SystemHeartbeatRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader); + static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder); + SystemHeartbeatRes() = default; + explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {} + + uint64_t system_timestamp; +}; + +using SystemHeartbeatRpc = rpc::RequestResponse; + +struct SystemRecoveryReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader); + static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder); + SystemRecoveryReq() = default; + SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector database_configs, + auth::Auth::Config auth_config, std::vector users, std::vector roles) + : forced_group_timestamp{forced_group_timestamp}, + database_configs(std::move(database_configs)), + auth_config(std::move(auth_config)), + users{std::move(users)}, + roles{std::move(roles)} {} + + uint64_t forced_group_timestamp; + std::vector database_configs; + auth::Auth::Config auth_config; + std::vector users; + std::vector roles; +}; + +struct SystemRecoveryRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; + + static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader); + static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder); + SystemRecoveryRes() = default; + explicit SystemRecoveryRes(Result res) : result(res) {} + + Result result; +}; + +using SystemRecoveryRpc = rpc::RequestResponse; + +} // namespace memgraph::replication + +namespace memgraph::slk { +void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader); +void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/); +void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/); +void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader); +void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder); +void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader); +} // namespace memgraph::slk diff --git a/src/replication_handler/replication_handler.cpp b/src/replication_handler/replication_handler.cpp new file mode 100644 index 000000000..cf1800168 --- /dev/null +++ b/src/replication_handler/replication_handler.cpp @@ -0,0 +1,291 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "replication_handler/replication_handler.hpp" +#include "dbms/dbms_handler.hpp" +#include "replication_handler/system_replication.hpp" + +namespace memgraph::replication { + +namespace { +#ifdef MG_ENTERPRISE +void RecoverReplication(memgraph::replication::ReplicationState &repl_state, memgraph::system::System *system, + memgraph::dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth) { + /* + * REPLICATION RECOVERY AND STARTUP + */ + + // Startup replication state (if recovered at startup) + auto replica = [&dbms_handler, &auth](memgraph::replication::RoleReplicaData const &data) { + return memgraph::replication::StartRpcServer(dbms_handler, data, auth); + }; + + // Replication recovery and frequent check start + auto main = [system, &dbms_handler, &auth](memgraph::replication::RoleMainData &mainData) { + for (auto &client : mainData.registered_replicas_) { + memgraph::replication::SystemRestore(client, system, dbms_handler, auth); + } + // DBMS here + dbms_handler.ForEach([&mainData](memgraph::dbms::DatabaseAccess db_acc) { + dbms::DbmsHandler::RecoverStorageReplication(std::move(db_acc), mainData); + }); + + for (auto &client : mainData.registered_replicas_) { + memgraph::replication::StartReplicaClient(client, system, dbms_handler, auth); + } + + // Warning + if (dbms_handler.default_config().durability.snapshot_wal_mode == + memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) { + spdlog::warn( + "The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please " + "consider " + "enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because " + "without write-ahead logs this instance is not replicating any data."); + } + + return true; + }; + + auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData()); + MG_ASSERT(result, "Replica recovery failure!"); +} +#else +void RecoverReplication(memgraph::replication::ReplicationState &repl_state, + memgraph::dbms::DbmsHandler &dbms_handler) { + // Startup replication state (if recovered at startup) + auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData const &data) { + return memgraph::replication::StartRpcServer(dbms_handler, data); + }; + + // Replication recovery and frequent check start + auto main = [&dbms_handler](memgraph::replication::RoleMainData &mainData) { + dbms::DbmsHandler::RecoverStorageReplication(dbms_handler.Get(), mainData); + + for (auto &client : mainData.registered_replicas_) { + memgraph::replication::StartReplicaClient(client, dbms_handler); + } + + // Warning + if (dbms_handler.default_config().durability.snapshot_wal_mode == + memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) { + spdlog::warn( + "The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please " + "consider " + "enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because " + "without write-ahead logs this instance is not replicating any data."); + } + + return true; + }; + + auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData()); + MG_ASSERT(result, "Replica recovery failure!"); +} +#endif +} // namespace + +inline std::optional HandleRegisterReplicaStatus( + utils::BasicResult &instance_client) { + if (instance_client.HasError()) switch (instance_client.GetError()) { + case replication::RegisterReplicaError::NOT_MAIN: + MG_ASSERT(false, "Only main instance can register a replica!"); + return {}; + case replication::RegisterReplicaError::NAME_EXISTS: + return query::RegisterReplicaError::NAME_EXISTS; + case replication::RegisterReplicaError::ENDPOINT_EXISTS: + return query::RegisterReplicaError::ENDPOINT_EXISTS; + case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED: + return query::RegisterReplicaError::COULD_NOT_BE_PERSISTED; + case replication::RegisterReplicaError::SUCCESS: + break; + } + return {}; +} + +#ifdef MG_ENTERPRISE +void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler, + auth::SynchedAuth &auth) { +#else +void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler) { +#endif + // No client error, start instance level client + auto const &endpoint = client.rpc_client_.Endpoint(); + spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port); + client.StartFrequentCheck([&, +#ifdef MG_ENTERPRISE + system = system, +#endif + license = license::global_license_checker.IsEnterpriseValidFast()]( + bool reconnect, replication::ReplicationClient &client) mutable { + // Working connection + // Check if system needs restoration + if (reconnect) { + client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + } + // Check if license has changed + const auto new_license = license::global_license_checker.IsEnterpriseValidFast(); + if (new_license != license) { + license = new_license; + client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + } +#ifdef MG_ENTERPRISE + SystemRestore(client, system, dbms_handler, auth); +#endif + // Check if any database has been left behind + dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) { + // Specific database <-> replica client + db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) { + if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) { + // Database <-> replica might be behind, check and recover + client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc); + } + }); + }); + }); +} + +#ifdef MG_ENTERPRISE +ReplicationHandler::ReplicationHandler(memgraph::replication::ReplicationState &repl_state, + memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system, + memgraph::auth::SynchedAuth &auth) + : repl_state_{repl_state}, dbms_handler_{dbms_handler}, system_{system}, auth_{auth} { + RecoverReplication(repl_state_, system_, dbms_handler_, auth_); +} +#else +ReplicationHandler::ReplicationHandler(replication::ReplicationState &repl_state, dbms::DbmsHandler &dbms_handler) + : repl_state_{repl_state}, dbms_handler_{dbms_handler} { + RecoverReplication(repl_state_, dbms_handler_); +} +#endif + +bool ReplicationHandler::SetReplicationRoleMain() { + auto const main_handler = [](memgraph::replication::RoleMainData &) { + // If we are already MAIN, we don't want to change anything + return false; + }; + + auto const replica_handler = [this](memgraph::replication::RoleReplicaData const &) { + return DoReplicaToMainPromotion(); + }; + + // TODO: under lock + return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); +} + +bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) { + // We don't want to restart the server if we're already a REPLICA + if (repl_state_.IsReplica()) { + return false; + } + + // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are + // deleting the replica. + // Remove database specific clients + dbms_handler_.ForEach([&](memgraph::dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); + }); + // Remove instance level clients + std::get(repl_state_.ReplicationData()).registered_replicas_.clear(); + + // Creates the server + repl_state_.SetReplicationRoleReplica(config); + + // Start + const auto success = + std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData const &) { + // ASSERT + return false; + }, + [this](memgraph::replication::RoleReplicaData const &data) { +#ifdef MG_ENTERPRISE + return StartRpcServer(dbms_handler_, data, auth_); +#else + return StartRpcServer(dbms_handler_, data); +#endif + }}, + repl_state_.ReplicationData()); + // TODO Handle error (restore to main?) + return success; +} + +bool ReplicationHandler::DoReplicaToMainPromotion() { + // STEP 1) bring down all REPLICA servers + dbms_handler_.ForEach([](dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + // Remember old epoch + storage timestamp association + storage->PrepareForNewEpoch(); + }); + + // STEP 2) Change to MAIN + // TODO: restore replication servers if false? + if (!repl_state_.SetReplicationRoleMain()) { + // TODO: Handle recovery on failure??? + return false; + } + + // STEP 3) We are now MAIN, update storage local epoch + const auto &epoch = std::get(std::as_const(repl_state_).ReplicationData()).epoch_; + dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + storage->repl_storage_state_.epoch_ = epoch; + }); + + return true; +}; + +// as MAIN, define and connect to REPLICAs +auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> memgraph::utils::BasicResult { + return RegisterReplica_(config); +} + +auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) + -> memgraph::utils::BasicResult { + return RegisterReplica_(config); +} + +auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult { + auto const replica_handler = + [](memgraph::replication::RoleReplicaData const &) -> memgraph::query::UnregisterReplicaResult { + return memgraph::query::UnregisterReplicaResult::NOT_MAIN; + }; + auto const main_handler = + [this, name](memgraph::replication::RoleMainData &mainData) -> memgraph::query::UnregisterReplicaResult { + if (!repl_state_.TryPersistUnregisterReplica(name)) { + return memgraph::query::UnregisterReplicaResult::COULD_NOT_BE_PERSISTED; + } + // Remove database specific clients + dbms_handler_.ForEach([name](memgraph::dbms::DatabaseAccess db_acc) { + db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) { + std::erase_if(clients, [name](const auto &client) { return client->Name() == name; }); + }); + }); + // Remove instance level clients + auto const n_unregistered = + std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; }); + return n_unregistered != 0 ? memgraph::query::UnregisterReplicaResult::SUCCESS + : memgraph::query::UnregisterReplicaResult::CAN_NOT_UNREGISTER; + }; + + return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); +} + +auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole { + return repl_state_.GetRole(); +} + +bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); } + +bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); } + +} // namespace memgraph::replication diff --git a/src/replication_handler/system_replication.cpp b/src/replication_handler/system_replication.cpp new file mode 100644 index 000000000..4f818a567 --- /dev/null +++ b/src/replication_handler/system_replication.cpp @@ -0,0 +1,115 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "replication_handler/system_replication.hpp" + +#include + +#include "auth/replication_handlers.hpp" +#include "dbms/replication_handlers.hpp" +#include "license/license.hpp" +#include "replication_handler/system_rpc.hpp" + +namespace memgraph::replication { + +#ifdef MG_ENTERPRISE +void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) { + replication::SystemHeartbeatRes res{0}; + + // Ignore if no license + if (!license::global_license_checker.IsEnterpriseValidFast()) { + spdlog::error("Handling SystemHeartbeat, an enterprise RPC message, without license."); + memgraph::slk::Save(res, res_builder); + return; + } + + replication::SystemHeartbeatReq req; + replication::SystemHeartbeatReq::Load(&req, req_reader); + + res = replication::SystemHeartbeatRes{ts}; + memgraph::slk::Save(res, res_builder); +} + +void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, + dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader, + slk::Builder *res_builder) { + using memgraph::replication::SystemRecoveryRes; + SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE); + + utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); }); + + memgraph::replication::SystemRecoveryReq req; + memgraph::slk::Load(&req, req_reader); + + /* + * DBMS + */ + if (!dbms::SystemRecoveryHandler(dbms_handler, req.database_configs)) return; // Failure sent on exit + + /* + * AUTH + */ + if (!auth::SystemRecoveryHandler(auth, req.auth_config, req.users, req.roles)) return; // Failure sent on exit + + /* + * SUCCESSFUL RECOVERY + */ + system_state_access.SetLastCommitedTS(req.forced_group_timestamp); + spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp); + res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS); +} + +void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth) { + // NOTE: Register even without license as the user could add a license at run-time + // TODO: fix Register when system is removed from DbmsHandler + + auto system_state_access = dbms_handler.system_->CreateSystemStateAccess(); + + // System + data.server->rpc_server_.Register( + [system_state_access](auto *req_reader, auto *res_builder) { + spdlog::debug("Received SystemHeartbeatRpc"); + SystemHeartbeatHandler(system_state_access.LastCommitedTS(), req_reader, res_builder); + }); + data.server->rpc_server_.Register( + [system_state_access, &dbms_handler, &auth](auto *req_reader, auto *res_builder) mutable { + spdlog::debug("Received SystemRecoveryRpc"); + SystemRecoveryHandler(system_state_access, dbms_handler, auth, req_reader, res_builder); + }); + + // DBMS + dbms::Register(data, system_state_access, dbms_handler); + + // Auth + auth::Register(data, system_state_access, auth); +} +#endif + +#ifdef MG_ENTERPRISE +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, + auth::SynchedAuth &auth) { +#else +bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) { +#endif + // Register storage handlers + dbms::InMemoryReplicationHandlers::Register(&dbms_handler, *data.server); +#ifdef MG_ENTERPRISE + // Register system handlers + Register(data, dbms_handler, auth); +#endif + // Start server + if (!data.server->Start()) { + spdlog::error("Unable to start the replication server."); + return false; + } + return true; +} +} // namespace memgraph::replication diff --git a/src/replication_handler/system_rpc.cpp b/src/replication_handler/system_rpc.cpp new file mode 100644 index 000000000..0a0bd5e05 --- /dev/null +++ b/src/replication_handler/system_rpc.cpp @@ -0,0 +1,111 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "replication_handler/system_rpc.hpp" + +#include + +#include "auth/rpc.hpp" +#include "slk/serialization.hpp" +#include "slk/streams.hpp" +#include "storage/v2/replication/rpc.hpp" +#include "utils/enum.hpp" + +namespace memgraph::slk { +// Serialize code for SystemHeartbeatRes +void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.system_timestamp, builder); +} +void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->system_timestamp, reader); +} + +// Serialize code for SystemHeartbeatReq +void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) { + /* Nothing to serialize */ +} +void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) { + /* Nothing to serialize */ +} + +// Serialize code for SystemRecoveryReq +void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.forced_group_timestamp, builder); + memgraph::slk::Save(self.database_configs, builder); + memgraph::slk::Save(self.auth_config, builder); + memgraph::slk::Save(self.users, builder); + memgraph::slk::Save(self.roles, builder); +} + +void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->forced_group_timestamp, reader); + memgraph::slk::Load(&self->database_configs, reader); + memgraph::slk::Load(&self->auth_config, reader); + memgraph::slk::Load(&self->users, reader); + memgraph::slk::Load(&self->roles, reader); +} + +// Serialize code for SystemRecoveryRes +void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(utils::EnumToNum(self.result), builder); +} + +void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) { + uint8_t res = 0; + memgraph::slk::Load(&res, reader); + if (!utils::NumToEnum(res, self->result)) { + throw SlkReaderException("Unexpected result line:{}!", __LINE__); + } +} + +} // namespace memgraph::slk + +namespace memgraph::replication { + +constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq", + nullptr}; + +constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes", + nullptr}; + +constexpr utils::TypeInfo SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ, "SystemRecoveryReq", + nullptr}; + +constexpr utils::TypeInfo SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES, "SystemRecoveryRes", + nullptr}; + +void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +} // namespace memgraph::replication diff --git a/src/storage/v2/replication/rpc.cpp b/src/storage/v2/replication/rpc.cpp index 27fc1a0d6..59d1a02b9 100644 --- a/src/storage/v2/replication/rpc.cpp +++ b/src/storage/v2/replication/rpc.cpp @@ -59,39 +59,6 @@ void TimestampRes::Save(const TimestampRes &self, memgraph::slk::Builder *builde memgraph::slk::Save(self, builder); } void TimestampRes::Load(TimestampRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } -void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} -void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} -void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } -void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } -void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} -void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self, builder); -} -void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(self, reader); -} - } // namespace storage::replication constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ, @@ -130,24 +97,6 @@ constexpr utils::TypeInfo storage::replication::TimestampReq::kType{utils::TypeI constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes", nullptr}; -constexpr utils::TypeInfo storage::replication::CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ, - "CreateDatabaseReq", nullptr}; - -constexpr utils::TypeInfo storage::replication::CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES, - "CreateDatabaseRes", nullptr}; - -constexpr utils::TypeInfo storage::replication::DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ, - "DropDatabaseReq", nullptr}; - -constexpr utils::TypeInfo storage::replication::DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES, - "DropDatabaseRes", nullptr}; - -constexpr utils::TypeInfo storage::replication::SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ, - "SystemRecoveryReq", nullptr}; - -constexpr utils::TypeInfo storage::replication::SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES, - "SystemRecoveryRes", nullptr}; - // Autogenerated SLK serialization code namespace slk { // Serialize code for TimestampRes @@ -316,91 +265,5 @@ void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader) memgraph::slk::Load(&self->items.enable_schema_metadata, reader); } -// Serialize code for CreateDatabaseReq - -void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.epoch_id, builder); - memgraph::slk::Save(self.expected_group_timestamp, builder); - memgraph::slk::Save(self.new_group_timestamp, builder); - memgraph::slk::Save(self.config, builder); -} - -void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(&self->epoch_id, reader); - memgraph::slk::Load(&self->expected_group_timestamp, reader); - memgraph::slk::Load(&self->new_group_timestamp, reader); - memgraph::slk::Load(&self->config, reader); -} - -// Serialize code for CreateDatabaseRes - -void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(utils::EnumToNum(self.result), builder); -} - -void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) { - uint8_t res = 0; - memgraph::slk::Load(&res, reader); - if (!utils::NumToEnum(res, self->result)) { - throw SlkReaderException("Unexpected result line:{}!", __LINE__); - } -} - -// Serialize code for DropDatabaseReq - -void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.epoch_id, builder); - memgraph::slk::Save(self.expected_group_timestamp, builder); - memgraph::slk::Save(self.new_group_timestamp, builder); - memgraph::slk::Save(self.uuid, builder); -} - -void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(&self->epoch_id, reader); - memgraph::slk::Load(&self->expected_group_timestamp, reader); - memgraph::slk::Load(&self->new_group_timestamp, reader); - memgraph::slk::Load(&self->uuid, reader); -} - -// Serialize code for DropDatabaseRes - -void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(utils::EnumToNum(self.result), builder); -} - -void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) { - uint8_t res = 0; - memgraph::slk::Load(&res, reader); - if (!utils::NumToEnum(res, self->result)) { - throw SlkReaderException("Unexpected result line:{}!", __LINE__); - } -} - -// Serialize code for SystemRecoveryReq - -void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.forced_group_timestamp, builder); - memgraph::slk::Save(self.database_configs, builder); -} - -void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) { - memgraph::slk::Load(&self->forced_group_timestamp, reader); - memgraph::slk::Load(&self->database_configs, reader); -} - -// Serialize code for SystemRecoveryRes - -void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(utils::EnumToNum(self.result), builder); -} - -void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) { - uint8_t res = 0; - memgraph::slk::Load(&res, reader); - if (!utils::NumToEnum(res, self->result)) { - throw SlkReaderException("Unexpected result line:{}!", __LINE__); - } -} - } // namespace slk } // namespace memgraph diff --git a/src/storage/v2/replication/rpc.hpp b/src/storage/v2/replication/rpc.hpp index 62f8b680c..9c9f5c285 100644 --- a/src/storage/v2/replication/rpc.hpp +++ b/src/storage/v2/replication/rpc.hpp @@ -201,108 +201,6 @@ struct TimestampRes { using TimestampRpc = rpc::RequestResponse; -struct CreateDatabaseReq { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader); - static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder); - CreateDatabaseReq() = default; - CreateDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp, - storage::SalientConfig config) - : epoch_id(std::move(epoch_id)), - expected_group_timestamp{expected_group_timestamp}, - new_group_timestamp(new_group_timestamp), - config(std::move(config)) {} - - std::string epoch_id; - uint64_t expected_group_timestamp; - uint64_t new_group_timestamp; - storage::SalientConfig config; -}; - -struct CreateDatabaseRes { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; - - static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader); - static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder); - CreateDatabaseRes() = default; - explicit CreateDatabaseRes(Result res) : result(res) {} - - Result result; -}; - -using CreateDatabaseRpc = rpc::RequestResponse; - -struct DropDatabaseReq { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader); - static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder); - DropDatabaseReq() = default; - DropDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp, - const utils::UUID &uuid) - : epoch_id(std::move(epoch_id)), - expected_group_timestamp{expected_group_timestamp}, - new_group_timestamp(new_group_timestamp), - uuid(uuid) {} - - std::string epoch_id; - uint64_t expected_group_timestamp; - uint64_t new_group_timestamp; - utils::UUID uuid; -}; - -struct DropDatabaseRes { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; - - static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader); - static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder); - DropDatabaseRes() = default; - explicit DropDatabaseRes(Result res) : result(res) {} - - Result result; -}; - -using DropDatabaseRpc = rpc::RequestResponse; - -struct SystemRecoveryReq { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader); - static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder); - SystemRecoveryReq() = default; - SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector database_configs) - : forced_group_timestamp{forced_group_timestamp}, database_configs(std::move(database_configs)) {} - - uint64_t forced_group_timestamp; - std::vector database_configs; -}; - -struct SystemRecoveryRes { - static const utils::TypeInfo kType; - static const utils::TypeInfo &GetTypeInfo() { return kType; } - - enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N }; - - static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader); - static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder); - SystemRecoveryRes() = default; - explicit SystemRecoveryRes(Result res) : result(res) {} - - Result result; -}; - -using SystemRecoveryRpc = rpc::RequestResponse; - } // namespace memgraph::storage::replication // SLK serialization declarations @@ -356,28 +254,8 @@ void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph: void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader); -void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder); +void Save(const memgraph::storage::SalientConfig &self, memgraph::slk::Builder *builder); -void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader); - -void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder); - -void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader); - -void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder); - -void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader); - -void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder); - -void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader); - -void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder); - -void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader); - -void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder); - -void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader); +void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader); } // namespace memgraph::slk diff --git a/src/system/CMakeLists.txt b/src/system/CMakeLists.txt new file mode 100644 index 000000000..339ca399c --- /dev/null +++ b/src/system/CMakeLists.txt @@ -0,0 +1,23 @@ +add_library(mg-system STATIC) +add_library(mg::system ALIAS mg-system) +target_sources(mg-system + PUBLIC + include/system/action.hpp + include/system/system.hpp + include/system/transaction.hpp + include/system/state.hpp + + PRIVATE + action.cpp + system.cpp + transaction.cpp + state.cpp + +) +target_include_directories(mg-system PUBLIC include) + +target_link_libraries(mg-system + PUBLIC + mg::replication + +) diff --git a/src/dbms/replication_client.hpp b/src/system/action.cpp similarity index 64% rename from src/dbms/replication_client.hpp rename to src/system/action.cpp index c1bac91a2..00e1c1be9 100644 --- a/src/dbms/replication_client.hpp +++ b/src/system/action.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -8,14 +8,4 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. - -#pragma once - -#include "dbms/dbms_handler.hpp" -#include "replication/replication_client.hpp" - -namespace memgraph::dbms { - -void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client); - -} // namespace memgraph::dbms +#include "system/include/system/action.hpp" diff --git a/src/system/include/system/action.hpp b/src/system/include/system/action.hpp new file mode 100644 index 000000000..77f4cb3e8 --- /dev/null +++ b/src/system/include/system/action.hpp @@ -0,0 +1,38 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "replication/epoch.hpp" +#include "replication/replication_client.hpp" +#include "replication/state.hpp" + +namespace memgraph::system { + +struct Transaction; + +/// The system action interface that subsystems will implement. This OO-style separation is needed so that one common +/// mechanism can be used for all subsystem replication within a system transaction, without the need for System to +/// know about all the subsystems. +struct ISystemAction { + /// Durability step which is defered until commit time + virtual void DoDurability() = 0; + + /// Prepare the RPC payload that will be sent to all replicas clients + virtual bool DoReplication(memgraph::replication::ReplicationClient &client, + memgraph::replication::ReplicationEpoch const &epoch, + Transaction const &system_tx) const = 0; + + virtual void PostReplication(memgraph::replication::RoleMainData &main_data) const = 0; + + virtual ~ISystemAction() = default; +}; +} // namespace memgraph::system diff --git a/src/system/include/system/state.hpp b/src/system/include/system/state.hpp new file mode 100644 index 000000000..4ef699c1e --- /dev/null +++ b/src/system/include/system/state.hpp @@ -0,0 +1,57 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include + +#include "kvstore/kvstore.hpp" +#include "utils/file.hpp" + +namespace memgraph::system { + +namespace { +constexpr std::string_view kLastCommitedSystemTsKey = "last_committed_system_ts"; // Key for timestamp durability +} + +struct State { + explicit State(std::optional storage, bool recovery_on_startup); + + void FinalizeTransaction(std::uint64_t timestamp) { + if (durability_) { + durability_->Put(kLastCommitedSystemTsKey, std::to_string(timestamp)); + } + last_committed_system_timestamp_.store(timestamp); + } + + auto LastCommittedSystemTimestamp() -> uint64_t { return last_committed_system_timestamp_.load(); } + + private: + friend struct ReplicaHandlerAccessToState; + friend struct Transaction; + + std::optional durability_; + std::atomic_uint64_t last_committed_system_timestamp_{}; +}; + +struct ReplicaHandlerAccessToState { + explicit ReplicaHandlerAccessToState(memgraph::system::State &state) : state_{&state} {} + + auto LastCommitedTS() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); } + + void SetLastCommitedTS(uint64_t new_timestamp) { state_->last_committed_system_timestamp_.store(new_timestamp); } + + private: + State *state_; +}; + +} // namespace memgraph::system diff --git a/src/system/include/system/system.hpp b/src/system/include/system/system.hpp new file mode 100644 index 000000000..eb15a553f --- /dev/null +++ b/src/system/include/system/system.hpp @@ -0,0 +1,52 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "system/state.hpp" +#include "system/transaction.hpp" + +namespace memgraph::system { + +struct TransactionGuard { + explicit TransactionGuard(std::unique_lock guard) : guard_(std::move(guard)) {} + + private: + std::unique_lock guard_; +}; + +struct System { + // NOTE: default arguments to make testing easier. + System(std::optional storage = std::nullopt, bool recovery_on_startup = false) + : state_(std::move(storage), recovery_on_startup), timestamp_{state_.LastCommittedSystemTimestamp()} {} + + auto TryCreateTransaction(std::chrono::microseconds try_time = std::chrono::milliseconds{100}) + -> std::optional { + auto system_unique = std::unique_lock{mtx_, std::defer_lock}; + if (!system_unique.try_lock_for(try_time)) { + return std::nullopt; + } + return Transaction{state_, std::move(system_unique), timestamp_++}; + } + + // TODO: this and LastCommittedSystemTimestamp maybe not needed + auto GenTransactionGuard() -> TransactionGuard { return TransactionGuard{std::unique_lock{mtx_}}; } + auto LastCommittedSystemTimestamp() -> uint64_t { return state_.LastCommittedSystemTimestamp(); } + + auto CreateSystemStateAccess() -> ReplicaHandlerAccessToState { return ReplicaHandlerAccessToState{state_}; } + + private: + State state_; + std::timed_mutex mtx_{}; + std::uint64_t timestamp_{}; +}; + +} // namespace memgraph::system diff --git a/src/system/include/system/transaction.hpp b/src/system/include/system/transaction.hpp new file mode 100644 index 000000000..af03fe434 --- /dev/null +++ b/src/system/include/system/transaction.hpp @@ -0,0 +1,124 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include +#include +#include +#include +#include "replication/state.hpp" +#include "system/action.hpp" +#include "system/state.hpp" + +namespace memgraph::system { + +enum class AllSyncReplicaStatus : std::uint8_t { + AllCommitsConfirmed, + SomeCommitsUnconfirmed, +}; + +struct Transaction; + +template +concept ReplicationPolicy = requires(T handler, ISystemAction const &action, Transaction const &txn) { + { handler.ApplyAction(action, txn) } -> std::same_as; +}; + +struct System; + +struct Transaction { + template TAction, typename... Args> + requires std::constructible_from + void AddAction(Args &&...args) { actions_.emplace_back(std::make_unique(std::forward(args)...)); } + + template + auto Commit(Handler handler) -> AllSyncReplicaStatus { + if (!lock_.owns_lock() || actions_.empty()) { + // If no actions, we do not increment the last commited ts, since there is no delta to send to the REPLICA + Abort(); + return AllSyncReplicaStatus::AllCommitsConfirmed; // TODO: some kind of error + } + + auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed; + + while (!actions_.empty()) { + auto &action = actions_.front(); + + /// durability + action->DoDurability(); + + /// replication prep + auto action_sync_status = handler.ApplyAction(*action, *this); + if (action_sync_status != AllSyncReplicaStatus::AllCommitsConfirmed) { + sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed; + } + + actions_.pop_front(); + } + + state_->FinalizeTransaction(timestamp_); + lock_.unlock(); + + return sync_status; + } + + void Abort() { + if (lock_.owns_lock()) { + lock_.unlock(); + } + actions_.clear(); + } + + auto last_committed_system_timestamp() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); } + auto timestamp() const -> uint64_t { return timestamp_; } + + private: + friend struct System; + Transaction(State &state, std::unique_lock lock, std::uint64_t timestamp) + : state_{std::addressof(state)}, lock_(std::move(lock)), timestamp_{timestamp} {} + + State *state_; + std::unique_lock lock_; + std::uint64_t timestamp_; + std::list> actions_; +}; + +struct DoReplication { + explicit DoReplication(replication::RoleMainData &main_data) : main_data_{main_data} {} + auto ApplyAction(ISystemAction const &action, Transaction const &system_tx) -> AllSyncReplicaStatus { + auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed; + + for (auto &client : main_data_.registered_replicas_) { + bool completed = action.DoReplication(client, main_data_.epoch_, system_tx); + if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) { + sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed; + } + } + + action.PostReplication(main_data_); + return sync_status; + } + + private: + replication::RoleMainData &main_data_; +}; +static_assert(ReplicationPolicy); + +struct DoNothing { + auto ApplyAction(ISystemAction const & /*action*/, Transaction const & /*system_tx*/) -> AllSyncReplicaStatus { + return AllSyncReplicaStatus::AllCommitsConfirmed; + } +}; +static_assert(ReplicationPolicy); + +} // namespace memgraph::system diff --git a/src/system/state.cpp b/src/system/state.cpp new file mode 100644 index 000000000..fb256ad89 --- /dev/null +++ b/src/system/state.cpp @@ -0,0 +1,57 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "system/state.hpp" + +namespace memgraph::system { + +namespace { + +constexpr std::string_view kSystemDir = ".system"; +constexpr std::string_view kVersion = "version"; // Key for version durability +constexpr std::string_view kVersionV1 = "V1"; // Value for version 1 + +auto InitializeSystemDurability(std::optional storage, bool recovery_on_startup) + -> std::optional { + if (!storage) return std::nullopt; + + auto const &path = *storage; + memgraph::utils::EnsureDir(path); + auto system_dir = path / kSystemDir; + memgraph::utils::EnsureDir(system_dir); + auto durability = memgraph::kvstore::KVStore{std::move(system_dir)}; + + auto version = durability.Get(kVersion); + // TODO: migration schemes here in the future + if (!version || *version != kVersionV1) { + // ensure we start out with V1 + durability.Put(kVersion, kVersionV1); + } + + if (!recovery_on_startup) { + // reset last_committed_system_ts + durability.Delete(kLastCommitedSystemTsKey); + } + + return durability; +} + +auto LoadLastCommittedSystemTimestamp(std::optional const &store) -> uint64_t { + auto lcst = store ? store->Get(kLastCommitedSystemTsKey) : std::nullopt; + return lcst ? std::stoul(*lcst) : 0U; +} + +} // namespace + +State::State(std::optional storage, bool recovery_on_startup) + : durability_{InitializeSystemDurability(std::move(storage), recovery_on_startup)}, + last_committed_system_timestamp_{LoadLastCommittedSystemTimestamp(durability_)} {} +} // namespace memgraph::system diff --git a/src/system/system.cpp b/src/system/system.cpp new file mode 100644 index 000000000..a3e82fdc8 --- /dev/null +++ b/src/system/system.cpp @@ -0,0 +1,11 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#include "system/include/system/system.hpp" diff --git a/src/system/transaction.cpp b/src/system/transaction.cpp new file mode 100644 index 000000000..bfc00e86d --- /dev/null +++ b/src/system/transaction.cpp @@ -0,0 +1,11 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. +#include "system/include/system/transaction.hpp" diff --git a/src/telemetry/telemetry.cpp b/src/telemetry/telemetry.cpp index e5e779f31..0b554a3bf 100644 --- a/src/telemetry/telemetry.cpp +++ b/src/telemetry/telemetry.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -149,9 +149,9 @@ void Telemetry::AddClientCollector() { } #ifdef MG_ENTERPRISE -void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) { - AddCollector("database", [&dbms_handler]() -> nlohmann::json { - const auto &infos = dbms_handler.Info(); +void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state) { + AddCollector("database", [&dbms_handler, &repl_state]() -> nlohmann::json { + const auto &infos = dbms_handler.Info(repl_state.GetRole()); auto dbs = nlohmann::json::array(); for (const auto &db_info : infos) { dbs.push_back(memgraph::dbms::ToJson(db_info)); @@ -162,11 +162,10 @@ void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) { #else #endif -void Telemetry::AddStorageCollector( - dbms::DbmsHandler &dbms_handler, - memgraph::utils::Synchronized &auth) { - AddCollector("storage", [&dbms_handler, &auth]() -> nlohmann::json { - auto stats = dbms_handler.Stats(); +void Telemetry::AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth, + memgraph::replication::ReplicationState &repl_state) { + AddCollector("storage", [&dbms_handler, &auth, &repl_state]() -> nlohmann::json { + auto stats = dbms_handler.Stats(repl_state.GetRole()); stats.users = auth->AllUsers().size(); return ToJson(stats); }); diff --git a/src/telemetry/telemetry.hpp b/src/telemetry/telemetry.hpp index c9b82f9ef..ad41c3097 100644 --- a/src/telemetry/telemetry.hpp +++ b/src/telemetry/telemetry.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -43,12 +43,11 @@ class Telemetry final { void AddCollector(const std::string &name, const std::function &func); // Specialized collectors - void AddStorageCollector( - dbms::DbmsHandler &dbms_handler, - memgraph::utils::Synchronized &auth); + void AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth, + memgraph::replication::ReplicationState &repl_state); #ifdef MG_ENTERPRISE - void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler); + void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state); #else void AddDatabaseCollector() { AddCollector("database", []() -> nlohmann::json { return nlohmann::json::array(); }); diff --git a/src/utils/gatekeeper.hpp b/src/utils/gatekeeper.hpp index 862cad982..fcc3b5842 100644 --- a/src/utils/gatekeeper.hpp +++ b/src/utils/gatekeeper.hpp @@ -161,10 +161,22 @@ struct Gatekeeper { ~Accessor() { reset(); } - auto get() -> T * { return std::addressof(*owner_->value_); } - auto get() const -> const T * { return std::addressof(*owner_->value_); } - T *operator->() { return std::addressof(*owner_->value_); } - const T *operator->() const { return std::addressof(*owner_->value_); } + auto get() -> T * { + if (owner_ == nullptr) return nullptr; + return std::addressof(*owner_->value_); + } + auto get() const -> const T * { + if (owner_ == nullptr) return nullptr; + return std::addressof(*owner_->value_); + } + T *operator->() { + if (owner_ == nullptr) return nullptr; + return std::addressof(*owner_->value_); + } + const T *operator->() const { + if (owner_ == nullptr) return nullptr; + return std::addressof(*owner_->value_); + } template [[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult> { diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index fd0d1fdeb..6919e8e5c 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -93,6 +93,10 @@ enum class TypeId : uint64_t { REP_SYSTEM_HEARTBEAT_RES, REP_SYSTEM_RECOVERY_REQ, REP_SYSTEM_RECOVERY_RES, + REP_UPDATE_AUTH_DATA_REQ, + REP_UPDATE_AUTH_DATA_RES, + REP_DROP_AUTH_DATA_REQ, + REP_DROP_AUTH_DATA_RES, // Coordinator COORD_FAILOVER_REQ, diff --git a/tests/benchmark/expansion.cpp b/tests/benchmark/expansion.cpp index 51f77f310..0c4579476 100644 --- a/tests/benchmark/expansion.cpp +++ b/tests/benchmark/expansion.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -26,6 +26,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e class ExpansionBenchFixture : public benchmark::Fixture { protected: + std::optional system; std::optional interpreter_context; std::optional interpreter; std::optional> db_gk; @@ -40,7 +41,14 @@ class ExpansionBenchFixture : public benchmark::Fixture { auto db_acc_opt = db_gk->access(); MG_ASSERT(db_acc_opt, "Failed to access db"); auto &db_acc = *db_acc_opt; - interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value()); + + system.emplace(); + interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system +#ifdef MG_ENTERPRISE + , + nullptr +#endif + ); auto label = db_acc->storage()->NameToLabel("Starting"); @@ -70,6 +78,7 @@ class ExpansionBenchFixture : public benchmark::Fixture { void TearDown(const benchmark::State &) override { interpreter = std::nullopt; interpreter_context = std::nullopt; + system.reset(); db_gk.reset(); std::filesystem::remove_all(data_directory); } diff --git a/tests/e2e/interactive_mg_runner.py b/tests/e2e/interactive_mg_runner.py index f0e4e6da1..93bfc5fe6 100755 --- a/tests/e2e/interactive_mg_runner.py +++ b/tests/e2e/interactive_mg_runner.py @@ -105,7 +105,9 @@ def is_port_in_use(port: int) -> bool: return s.connect_ex(("localhost", port)) == 0 -def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_directory): +def _start_instance( + name, args, log_file, setup_queries, use_ssl, procdir, data_directory, username=None, password=None +): assert ( name not in MEMGRAPH_INSTANCES.keys() ), "If this raises, you are trying to start an instance with the same name than one already running." @@ -115,7 +117,9 @@ def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_ log_file_path = os.path.join(BUILD_DIR, "logs", log_file) data_directory_path = os.path.join(BUILD_DIR, data_directory) - mg_instance = MemgraphInstanceRunner(MEMGRAPH_BINARY, use_ssl, {data_directory_path}) + mg_instance = MemgraphInstanceRunner( + MEMGRAPH_BINARY, use_ssl, {data_directory_path}, username=username, password=password + ) MEMGRAPH_INSTANCES[name] = mg_instance binary_args = args + ["--log-file", log_file_path] + ["--data-directory", data_directory_path] @@ -185,8 +189,14 @@ def start_instance(context, name, procdir): data_directory = value["data_directory"] else: data_directory = tempfile.TemporaryDirectory().name + username = None + if "username" in value: + username = value["username"] + password = None + if "password" in value: + password = value["password"] - instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory) + instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory, username, password) mg_instances[name] = instance assert len(mg_instances) == 1 diff --git a/tests/e2e/memgraph.py b/tests/e2e/memgraph.py index d5a62a388..92c0a8343 100755 --- a/tests/e2e/memgraph.py +++ b/tests/e2e/memgraph.py @@ -57,7 +57,7 @@ def replace_paths(path): class MemgraphInstanceRunner: - def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None): + def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None, username=None, password=None): self.host = "127.0.0.1" self.bolt_port = None self.binary_path = binary_path @@ -65,12 +65,19 @@ class MemgraphInstanceRunner: self.proc_mg = None self.ssl = use_ssl self.delete_on_stop = delete_on_stop + self.username = username + self.password = password def execute_setup_queries(self, setup_queries): if setup_queries is None: return - # An assumption being database instance is fresh, no need for the auth. - conn = mgclient.connect(host=self.host, port=self.bolt_port, sslmode=self.ssl) + conn = mgclient.connect( + host=self.host, + port=self.bolt_port, + sslmode=self.ssl, + username=(self.username or ""), + password=(self.password or ""), + ) conn.autocommit = True cursor = conn.cursor() for query_coll in setup_queries: diff --git a/tests/e2e/replication_experimental/CMakeLists.txt b/tests/e2e/replication_experimental/CMakeLists.txt index cd6e09f38..5037a3605 100644 --- a/tests/e2e/replication_experimental/CMakeLists.txt +++ b/tests/e2e/replication_experimental/CMakeLists.txt @@ -3,6 +3,7 @@ find_package(gflags REQUIRED) copy_e2e_python_files(replication_experiment common.py) copy_e2e_python_files(replication_experiment conftest.py) copy_e2e_python_files(replication_experiment multitenancy.py) +copy_e2e_python_files(replication_experiment auth.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." memgraph.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." interactive_mg_runner.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." mg_utils.py) diff --git a/tests/e2e/replication_experimental/auth.py b/tests/e2e/replication_experimental/auth.py new file mode 100644 index 000000000..950572f80 --- /dev/null +++ b/tests/e2e/replication_experimental/auth.py @@ -0,0 +1,831 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import atexit +import os +import shutil +import sys +import tempfile +import time +from functools import partial + +import interactive_mg_runner +import mgclient +import pytest +from common import execute_and_fetch_all +from mg_utils import mg_sleep_and_assert + +interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +interactive_mg_runner.PROJECT_DIR = os.path.normpath( + os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") +) +interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) +interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) + +BOLT_PORTS = {"main": 7687, "replica_1": 7688, "replica_2": 7689} +REPLICATION_PORTS = {"replica_1": 10001, "replica_2": 10002} +TEMP_DIR = tempfile.TemporaryDirectory().name + + +def update_to_main(cursor): + execute_and_fetch_all(cursor, "SET REPLICATION ROLE TO MAIN;") + + +def add_user(cursor, username, password=None): + if password is not None: + return execute_and_fetch_all(cursor, f"CREATE USER {username} IDENTIFIED BY '{password}';") + return execute_and_fetch_all(cursor, f"CREATE USER {username};") + + +def show_users_func(cursor): + def func(): + return set(execute_and_fetch_all(cursor, "SHOW USERS;")) + + return func + + +def show_roles_func(cursor): + def func(): + return set(execute_and_fetch_all(cursor, "SHOW ROLES;")) + + return func + + +def show_users_for_role_func(cursor, rolename): + def func(): + return set(execute_and_fetch_all(cursor, f"SHOW USERS FOR {rolename};")) + + return func + + +def show_role_for_user_func(cursor, username): + def func(): + return set(execute_and_fetch_all(cursor, f"SHOW ROLE FOR {username};")) + + return func + + +def show_privileges_func(cursor, user_or_role): + def func(): + return set(execute_and_fetch_all(cursor, f"SHOW PRIVILEGES FOR {user_or_role};")) + + return func + + +def show_database_privileges_func(cursor, user): + def func(): + return execute_and_fetch_all(cursor, f"SHOW DATABASE PRIVILEGES FOR {user};") + + return func + + +def show_database_func(cursor): + def func(): + return execute_and_fetch_all(cursor, f"SHOW DATABASE;") + + return func + + +def try_and_count(cursor, query): + try: + execute_and_fetch_all(cursor, query) + except: + return 1 + return 0 + + +def only_main_queries(cursor): + n_exceptions = 0 + + n_exceptions += try_and_count(cursor, f"CREATE USER user_name") + n_exceptions += try_and_count(cursor, f"SET PASSWORD FOR user_name TO 'new_password'") + n_exceptions += try_and_count(cursor, f"DROP USER user_name") + n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name") + n_exceptions += try_and_count(cursor, f"DROP ROLE role_name") + n_exceptions += try_and_count(cursor, f"CREATE USER user_name") + n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name") + n_exceptions += try_and_count(cursor, f"SET ROLE FOR user_name TO role_name") + n_exceptions += try_and_count(cursor, f"CLEAR ROLE FOR user_name") + n_exceptions += try_and_count(cursor, f"GRANT AUTH TO role_name") + n_exceptions += try_and_count(cursor, f"DENY AUTH, INDEX TO user_name") + n_exceptions += try_and_count(cursor, f"REVOKE AUTH FROM role_name") + n_exceptions += try_and_count(cursor, f"GRANT READ ON LABELS :l TO role_name;") + n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name") + n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;") + n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name") + n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;") + + return n_exceptions + + +def main_and_repl_queries(cursor): + n_exceptions = 0 + + try_and_count(cursor, f"SHOW USERS") + try_and_count(cursor, f"SHOW ROLES") + try_and_count(cursor, f"SHOW USERS FOR ROLE role_name") + try_and_count(cursor, f"SHOW ROLE FOR user_name") + try_and_count(cursor, f"SHOW PRIVILEGES FOR role_name") + try_and_count(cursor, f"SHOW DATABASE PRIVILEGES FOR user_name") + + return n_exceptions + + +def test_auth_queries_on_replica(connection): + # Goal: check that write auth queries are forbidden on REPLICAs + # 0/ Setup replication cluster + # 1/ Check queries + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False) + cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor() + cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor() + cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor() + + # 1/ + assert only_main_queries(cursor_main) == 0 + assert only_main_queries(cursor_replica_1) == 17 + assert only_main_queries(cursor_replica_2) == 17 + assert main_and_repl_queries(cursor_main) == 0 + assert main_and_repl_queries(cursor_replica_1) == 0 + assert main_and_repl_queries(cursor_replica_2) == 0 + + +def test_manual_users_recovery(connection): + # Goal: show system recovery in action at registration time + # 0/ MAIN CREATE USER user1, user2 + # REPLICA CREATE USER user3, user4 + # Setup replication cluster + # 1/ Check that both MAIN and REPLICA have user1 and user2 + # 2/ Check connections on REPLICAS + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + "CREATE USER user3;", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + "CREATE USER user4 IDENTIFIED BY 'password';", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + "CREATE USER user1;", + "CREATE USER user2 IDENTIFIED BY 'password';", + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False) + cursor = connection(BOLT_PORTS["main"], "main", "user1").cursor() + + # 1/ + expected_data = {("user2",), ("user1",)} + mg_sleep_and_assert( + expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor()) + ) + mg_sleep_and_assert( + expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor()) + ) + + # 2/ + connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor() + connection(BOLT_PORTS["replica_1"], "replica", "user2", "password").cursor() + connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor() + connection(BOLT_PORTS["replica_2"], "replica", "user2", "password").cursor() + + +def test_env_users_recovery(connection): + # Goal: show system recovery in action at registration time + # 0/ Set users from the environment + # MAIN gets users from the environment + # Setup replication cluster + # 1/ Check that both MAIN and REPLICA have user1 + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "username": "user1", + "password": "password", + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + # Start only replicas without the env user + interactive_mg_runner.stop_all(keep_directories=False) + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_1") + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_2") + # Setup user + try: + os.environ["MEMGRAPH_USER"] = "user1" + os.environ["MEMGRAPH_PASSWORD"] = "password" + # Start main + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "main") + finally: + # Cleanup + del os.environ["MEMGRAPH_USER"] + del os.environ["MEMGRAPH_PASSWORD"] + + # 1/ + expected_data = {("user1",)} + assert expected_data == show_users_func(connection(BOLT_PORTS["main"], "main", "user1", "password").cursor())() + mg_sleep_and_assert( + expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1", "password").cursor()) + ) + mg_sleep_and_assert( + expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1", "password").cursor()) + ) + + +def test_manual_roles_recovery(connection): + # Goal: show system recovery in action at registration time + # 0/ MAIN CREATE USER user1, user2 + # REPLICA CREATE USER user3, user4 + # Setup replication cluster + # 1/ Check that both MAIN and REPLICA have user1 and user2 + # 2/ Check that role1 and role2 are replicated + # 3/ Check that user1 has role1 + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + "CREATE ROLE role3;", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + "CREATE ROLE role4;", + "CREATE USER user4;", + "SET ROLE FOR user4 TO role4;", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + "CREATE ROLE role1;", + "CREATE ROLE role2;", + "CREATE USER user2;", + "SET ROLE FOR user2 TO role2;", + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False) + connection(BOLT_PORTS["main"], "main", "user2").cursor() # Just check if it connects + cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "user2").cursor() + cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "user2").cursor() + + # 1/ + expected_data = { + ("user2",), + } + mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_1)) + mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_2)) + + # 2/ + expected_data = {("role2",), ("role1",)} + mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_1)) + mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_2)) + + # 3/ + expected_data = {("role2",)} + mg_sleep_and_assert( + expected_data, + show_role_for_user_func(cursor_replica_1, "user2"), + ) + mg_sleep_and_assert( + expected_data, + show_role_for_user_func(cursor_replica_2, "user2"), + ) + + +def test_auth_config_recovery(connection): + # Goal: show we are replicating Auth::Config + # 0/ Setup auth configuration and compliant users + # 1/ Check that both MAIN and REPLICA have the same users + # 2/ Check that REPLICAS have the same config + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--auth-password-strength-regex", + "^[A-Z]+$", + "--auth-password-permit-null=false", + "--auth-user-or-role-name-regex", + "^[O-o]+$", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + "CREATE USER OPQabc IDENTIFIED BY 'PASSWORD';", + "CREATE ROLE defRST;", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--auth-password-strength-regex", + "^[0-9]+$", + "--auth-password-permit-null=true", + "--auth-user-or-role-name-regex", + "^[A-Np-z]+$", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + "CREATE ROLE ABCpqr;", + "CREATE USER stuDEF;", + "CREATE USER GvHwI IDENTIFIED BY '123456';", + "SET ROLE FOR GvHwI TO ABCpqr;", + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--auth-password-strength-regex", + "^[a-z]+$", + "--auth-password-permit-null=false", + "--auth-user-or-role-name-regex", + "^[A-z]+$", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + "CREATE USER UsErA IDENTIFIED BY 'pass';", + "CREATE ROLE rOlE;", + "CREATE USER uSeRB IDENTIFIED BY 'word';", + "SET ROLE FOR uSeRB TO rOlE;", + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False) + + # 1/ + cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor() + cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor() + cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor() + + # 2/ Only MAIN can update users + def user_test(cursor): + with pytest.raises(mgclient.DatabaseError, match="Invalid user name."): + add_user(cursor, "UsEr1", "abcdef") + with pytest.raises(mgclient.DatabaseError, match="Null passwords aren't permitted!"): + add_user(cursor, "UsErC") + with pytest.raises(mgclient.DatabaseError, match="The user password doesn't conform to the required strength!"): + add_user(cursor, "UsErC", "123456") + + user_test(cursor_main) + update_to_main(cursor_replica_1) + user_test(cursor_replica_1) + update_to_main(cursor_replica_2) + user_test(cursor_replica_2) + + +def test_auth_replication(connection): + # Goal: show that individual auth queries get replicated + + MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = { + "replica_1": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_1']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica1", + ], + "log_file": "replica1.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};", + ], + }, + "replica_2": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['replica_2']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/replica2", + ], + "log_file": "replica2.log", + "setup_queries": [ + f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};", + ], + }, + "main": { + "args": [ + "--bolt-port", + f"{BOLT_PORTS['main']}", + "--log-level=TRACE", + "--data_directory", + TEMP_DIR + "/main", + ], + "log_file": "main.log", + "setup_queries": [ + f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';", + f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';", + ], + }, + } + + # 0/ + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False) + cursor_main = connection(BOLT_PORTS["main"], "main", "user1").cursor() + cursor_replica1 = connection(BOLT_PORTS["replica_1"], "replica").cursor() + cursor_replica2 = connection(BOLT_PORTS["replica_2"], "replica").cursor() + + # 1/ + def check(f, expected_data): + # mg_sleep_and_assert( + # REPLICA 1 is SYNC, should already be ready + assert expected_data == f(cursor_replica1)() + # ) + mg_sleep_and_assert(expected_data, f(cursor_replica2)) + + # CREATE USER + execute_and_fetch_all(cursor_main, "CREATE USER user1") + check( + show_users_func, + { + ("user1",), + }, + ) + execute_and_fetch_all(cursor_main, "CREATE USER user2 IDENTIFIED BY 'pass'") + check( + show_users_func, + { + ("user2",), + ("user1",), + }, + ) + connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor() # Just check connection + connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor() # Just check connection + connection(BOLT_PORTS["replica_1"], "replica", "user2", "pass").cursor() # Just check connection + connection(BOLT_PORTS["replica_2"], "replica", "user2", "pass").cursor() # Just check connection + + # SET PASSWORD + execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user1 TO '1234'") + execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user2 TO 'new_pass'") + connection(BOLT_PORTS["replica_1"], "replica", "user1", "1234").cursor() # Just check connection + connection(BOLT_PORTS["replica_2"], "replica", "user1", "1234").cursor() # Just check connection + connection(BOLT_PORTS["replica_1"], "replica", "user2", "new_pass").cursor() # Just check connection + connection(BOLT_PORTS["replica_2"], "replica", "user2", "new_pass").cursor() # Just check connection + + # DROP USER + execute_and_fetch_all(cursor_main, "DROP USER user2") + check( + show_users_func, + { + ("user1",), + }, + ) + execute_and_fetch_all(cursor_main, "DROP USER user1") + check(show_users_func, set()) + connection(BOLT_PORTS["replica_1"], "replica").cursor() # Just check connection + connection(BOLT_PORTS["replica_2"], "replica").cursor() # Just check connection + + # CREATE ROLE + execute_and_fetch_all(cursor_main, "CREATE ROLE role1") + check( + show_roles_func, + { + ("role1",), + }, + ) + execute_and_fetch_all(cursor_main, "CREATE ROLE role2") + check( + show_roles_func, + { + ("role2",), + ("role1",), + }, + ) + + # DROP ROLE + execute_and_fetch_all(cursor_main, "DROP ROLE role2") + check( + show_roles_func, + { + ("role1",), + }, + ) + execute_and_fetch_all(cursor_main, "DROP ROLE role1") + check(show_roles_func, set()) + + # SET ROLE + execute_and_fetch_all(cursor_main, "CREATE USER user3") + execute_and_fetch_all(cursor_main, "CREATE ROLE role3") + execute_and_fetch_all(cursor_main, "SET ROLE FOR user3 TO role3") + check(partial(show_role_for_user_func, username="user3"), {("role3",)}) + execute_and_fetch_all(cursor_main, "CREATE USER user3b") + execute_and_fetch_all(cursor_main, "SET ROLE FOR user3b TO role3") + check(partial(show_role_for_user_func, username="user3b"), {("role3",)}) + check( + partial(show_users_for_role_func, rolename="role3"), + { + ("user3",), + ("user3b",), + }, + ) + + # CLEAR ROLE + execute_and_fetch_all(cursor_main, "CLEAR ROLE FOR user3") + check(partial(show_role_for_user_func, username="user3"), {("null",)}) + check( + partial(show_users_for_role_func, rolename="role3"), + { + ("user3b",), + }, + ) + + # GRANT/REVOKE/DENY privileges TO user + execute_and_fetch_all(cursor_main, "CREATE USER user4") + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4") + execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO user4") + check( + partial(show_privileges_func, user_or_role="user4"), + { + ("CREATE", "GRANT", "GRANTED TO USER"), + ("DELETE", "GRANT", "GRANTED TO USER"), + ("SET", "GRANT", "GRANTED TO USER"), + }, + ) + execute_and_fetch_all(cursor_main, "REVOKE SET FROM user4") + check( + partial(show_privileges_func, user_or_role="user4"), + {("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "GRANT", "GRANTED TO USER")}, + ) + execute_and_fetch_all(cursor_main, "DENY DELETE TO user4") + check( + partial(show_privileges_func, user_or_role="user4"), + {("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "DENY", "DENIED TO USER")}, + ) + + # GRANT/REVOKE/DENY privileges TO role + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3") + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b") + execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO role3") + check( + partial(show_privileges_func, user_or_role="role3"), + { + ("CREATE", "GRANT", "GRANTED TO ROLE"), + ("DELETE", "GRANT", "GRANTED TO ROLE"), + ("SET", "GRANT", "GRANTED TO ROLE"), + }, + ) + check( + partial(show_privileges_func, user_or_role="user3b"), + { + ("CREATE", "GRANT", "GRANTED TO ROLE"), + ("DELETE", "GRANT", "GRANTED TO ROLE"), + ("SET", "GRANT", "GRANTED TO ROLE"), + }, + ) + execute_and_fetch_all(cursor_main, "REVOKE SET FROM role3") + check( + partial(show_privileges_func, user_or_role="role3"), + {("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")}, + ) + check( + partial(show_privileges_func, user_or_role="user3b"), + {("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")}, + ) + execute_and_fetch_all(cursor_main, "DENY DELETE TO role3") + check( + partial(show_privileges_func, user_or_role="role3"), + {("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")}, + ) + check( + partial(show_privileges_func, user_or_role="user3b"), + {("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")}, + ) + + # GRANT permission ON LABEL/EDGE to user/role + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3") + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4") + execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b") + execute_and_fetch_all(cursor_main, "GRANT READ ON LABELS :l1 TO user4") + execute_and_fetch_all(cursor_main, "GRANT UPDATE ON LABELS :l2, :l3 TO role3") + check( + partial(show_privileges_func, user_or_role="user4"), + { + ("LABEL :l1", "READ", "LABEL PERMISSION GRANTED TO USER"), + }, + ) + check( + partial(show_privileges_func, user_or_role="role3"), + { + ("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"), + ("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"), + }, + ) + check( + partial(show_privileges_func, user_or_role="user3b"), + { + ("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"), + ("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"), + }, + ) + execute_and_fetch_all(cursor_main, "REVOKE LABELS :l1 FROM user4") + execute_and_fetch_all(cursor_main, "REVOKE LABELS :l2 FROM role3") + check(partial(show_privileges_func, user_or_role="user4"), set()) + check( + partial(show_privileges_func, user_or_role="role3"), + {("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")}, + ) + check( + partial(show_privileges_func, user_or_role="user3b"), + {("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")}, + ) + + # GRANT/REVOKE DATABASE + execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test") + execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2") + execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4") + check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])]) + execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4") + check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])]) + + # SET MAIN DATABASE + execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4") + execute_and_fetch_all(cursor_main, "SET MAIN DATABASE auth_test FOR user4") + # Reconnect and check current db + assert ( + execute_and_fetch_all(connection(BOLT_PORTS["main"], "main", "user4").cursor(), "SHOW DATABASE")[0][0] + == "auth_test" + ) + assert ( + execute_and_fetch_all(connection(BOLT_PORTS["replica_1"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0] + == "auth_test" + ) + assert ( + execute_and_fetch_all(connection(BOLT_PORTS["replica_2"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0] + == "auth_test" + ) + + +if __name__ == "__main__": + interactive_mg_runner.cleanup_directories_on_exit() + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/replication_experimental/conftest.py b/tests/e2e/replication_experimental/conftest.py index f91333cbf..90b2b0b86 100644 --- a/tests/e2e/replication_experimental/conftest.py +++ b/tests/e2e/replication_experimental/conftest.py @@ -18,9 +18,9 @@ def connection(): connection_holder = None role_holder = None - def inner_connection(port, role): + def inner_connection(port, role, username="", password=""): nonlocal connection_holder, role_holder - connection_holder = connect(host="localhost", port=port) + connection_holder = connect(host="localhost", port=port, username=username, password=password) role_holder = role return connection_holder diff --git a/tests/e2e/replication_experimental/workloads.yaml b/tests/e2e/replication_experimental/workloads.yaml index e48515f4f..fdb7e0674 100644 --- a/tests/e2e/replication_experimental/workloads.yaml +++ b/tests/e2e/replication_experimental/workloads.yaml @@ -2,3 +2,6 @@ workloads: - name: "Replicate multitenancy" binary: "tests/e2e/pytest_runner.sh" args: ["replication_experimental/multitenancy.py"] + - name: "Replicate auth data" + binary: "tests/e2e/pytest_runner.sh" + args: ["replication_experimental/auth.py"] diff --git a/tests/integration/telemetry/client.cpp b/tests/integration/telemetry/client.cpp index b93b1ada5..cff623d23 100644 --- a/tests/integration/telemetry/client.cpp +++ b/tests/integration/telemetry/client.cpp @@ -33,7 +33,7 @@ int main(int argc, char **argv) { // Memgraph backend std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_telemetry_integration_test"}; - memgraph::utils::Synchronized auth_{ + memgraph::auth::SynchedAuth auth_{ data_directory / "auth", memgraph::auth::Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, "", true}}; memgraph::glue::AuthQueryHandler auth_handler(&auth_); @@ -43,14 +43,20 @@ int main(int argc, char **argv) { memgraph::storage::UpdatePaths(db_config, data_directory); memgraph::replication::ReplicationState repl_state(ReplicationStateRootPath(db_config)); - memgraph::dbms::DbmsHandler dbms_handler(db_config + memgraph::system::System system_state; + memgraph::dbms::DbmsHandler dbms_handler(db_config, system_state, repl_state #ifdef MG_ENTERPRISE , - &auth_, false + auth_, false #endif ); - memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, &auth_handler, - &auth_checker); + memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + , + &auth_handler, &auth_checker); memgraph::requests::Init(); memgraph::telemetry::Telemetry telemetry(FLAGS_endpoint, FLAGS_storage_directory, memgraph::utils::GenerateUUID(), @@ -65,9 +71,9 @@ int main(int argc, char **argv) { }); // Memgraph specific collectors - telemetry.AddStorageCollector(dbms_handler, auth_); + telemetry.AddStorageCollector(dbms_handler, auth_, repl_state); #ifdef MG_ENTERPRISE - telemetry.AddDatabaseCollector(dbms_handler); + telemetry.AddDatabaseCollector(dbms_handler, repl_state); #else telemetry.AddDatabaseCollector(); #endif diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index f58af9ae7..f2ce9c572 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -39,7 +39,14 @@ int main(int argc, char *argv[]) { auto db_acc_opt = db_gk.access(); MG_ASSERT(db_acc_opt, "Failed to access db"); auto &db_acc = *db_acc_opt; - memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + ); memgraph::query::Interpreter interpreter{&interpreter_context, db_acc}; ResultStreamFaker stream(db_acc->storage()); diff --git a/tests/unit/auth_handler.cpp b/tests/unit/auth_handler.cpp index a162d1838..2b3c39734 100644 --- a/tests/unit/auth_handler.cpp +++ b/tests/unit/auth_handler.cpp @@ -25,7 +25,7 @@ class AuthQueryHandlerFixture : public testing::Test { protected: std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / "MG_tests_unit_auth_handler"}; - memgraph::utils::Synchronized auth{ + memgraph::auth::SynchedAuth auth{ test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast(getpid()))), memgraph::auth::Auth::Config{/* default */}}; memgraph::glue::AuthQueryHandler auth_handler{&auth}; diff --git a/tests/unit/dbms_handler.cpp b/tests/unit/dbms_handler.cpp index 2abe0b77d..a20b7dc89 100644 --- a/tests/unit/dbms_handler.cpp +++ b/tests/unit/dbms_handler.cpp @@ -10,6 +10,7 @@ // licenses/APL.txt. #include "query/auth_query_handler.hpp" +#include "replication/state.hpp" #include "storage/v2/config.hpp" #ifdef MG_ENTERPRISE #include @@ -43,7 +44,9 @@ std::set GetDirs(auto path) { std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"}; std::filesystem::path db_dir{storage_directory / "databases"}; static memgraph::storage::Config storage_conf; -std::unique_ptr> auth; +std::unique_ptr auth; +std::unique_ptr system_state; +std::unique_ptr repl_state; // Let this be global so we can test it different states throughout @@ -64,14 +67,18 @@ class TestEnvironment : public ::testing::Environment { std::filesystem::remove_all(storage_directory); } } - auth = - std::make_unique>( - storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); - ptr_ = std::make_unique(storage_conf, auth.get(), false); + auth = std::make_unique(storage_directory / "auth", + memgraph::auth::Auth::Config{/* default */}); + system_state = std::make_unique(); + repl_state = std::make_unique(ReplicationStateRootPath(storage_conf)); + ptr_ = std::make_unique(storage_conf, *system_state.get(), *repl_state.get(), + *auth.get(), false); } void TearDown() override { ptr_.reset(); + repl_state.reset(); + system_state.reset(); auth.reset(); std::filesystem::remove_all(storage_directory); } diff --git a/tests/unit/dbms_handler_community.cpp b/tests/unit/dbms_handler_community.cpp index 4a47e018b..1af4445b3 100644 --- a/tests/unit/dbms_handler_community.cpp +++ b/tests/unit/dbms_handler_community.cpp @@ -28,7 +28,9 @@ // Global std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler_community"}; static memgraph::storage::Config storage_conf; -std::unique_ptr> auth; +std::unique_ptr auth; +std::unique_ptr system_state; +std::unique_ptr repl_state; // Let this be global so we can test it different states throughout @@ -49,14 +51,17 @@ class TestEnvironment : public ::testing::Environment { std::filesystem::remove_all(storage_directory); } } - auth = - std::make_unique>( - storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); - ptr_ = std::make_unique(storage_conf); + auth = std::make_unique(storage_directory / "auth", + memgraph::auth::Auth::Config{/* default */}); + system_state = std::make_unique(); + repl_state = std::make_unique(ReplicationStateRootPath(storage_conf)); + ptr_ = std::make_unique(storage_conf, *system_state.get(), *repl_state.get()); } void TearDown() override { ptr_.reset(); + repl_state.reset(); + system_state.reset(); auth.reset(); std::filesystem::remove_all(storage_directory); } diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index bd587e7df..dfed72fbd 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -94,7 +94,16 @@ class InterpreterTest : public ::testing::Test { }() // iile }; - memgraph::query::InterpreterContext interpreter_context{{}, kNoHandler, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context{{}, + kNoHandler, + &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; void TearDown() override { if (std::is_same::value) { @@ -1150,8 +1159,16 @@ TYPED_TEST(InterpreterTest, AllowLoadCsvConfig) { << "Wrong storage mode!"; memgraph::replication::ReplicationState repl_state{std::nullopt}; - memgraph::query::InterpreterContext csv_interpreter_context{ - {.query = {.allow_load_csv = allow_load_csv}}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext csv_interpreter_context{{.query = {.allow_load_csv = allow_load_csv}}, + nullptr, + &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc}; for (const auto &query : queries) { if (allow_load_csv) { diff --git a/tests/unit/multi_tenancy.cpp b/tests/unit/multi_tenancy.cpp index 59364776a..e5ea4dd05 100644 --- a/tests/unit/multi_tenancy.cpp +++ b/tests/unit/multi_tenancy.cpp @@ -14,6 +14,7 @@ #include #include +#include "auth/auth.hpp" #include "communication/bolt/v1/value.hpp" #include "communication/result_stream_faker.hpp" #include "csv/parsing.hpp" @@ -99,8 +100,17 @@ class MultiTenantTest : public ::testing::Test { struct MinMemgraph { explicit MinMemgraph(const memgraph::storage::Config &conf) : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, - dbms{conf, &auth, true}, - interpreter_context{{}, &dbms, &dbms.ReplicationState()} { + repl_state{ReplicationStateRootPath(conf)}, + dbms{conf, system, repl_state, auth, true}, + interpreter_context{{}, + &dbms, + &repl_state, + system +#ifdef MG_ENTERPRISE + , + nullptr +#endif + } { memgraph::utils::global_settings.Initialize(conf.durability.storage_directory / "settings"); memgraph::license::RegisterLicenseSettings(memgraph::license::global_license_checker, memgraph::utils::global_settings); @@ -112,7 +122,9 @@ class MultiTenantTest : public ::testing::Test { auto NewInterpreter() { return InterpreterFaker{&interpreter_context, dbms.Get()}; } - memgraph::utils::Synchronized auth; + memgraph::auth::SynchedAuth auth; + memgraph::system::System system; + memgraph::replication::ReplicationState repl_state; memgraph::dbms::DbmsHandler dbms; memgraph::query::InterpreterContext interpreter_context; }; diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index 23eab17e0..5ecf598b2 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -267,7 +267,7 @@ memgraph::storage::EdgeAccessor CreateEdge(memgraph::storage::Storage::Accessor } template -void VerifyQueries(const std::vector> &results, TArgs &&... args) { +void VerifyQueries(const std::vector> &results, TArgs &&...args) { std::vector expected{std::forward(args)...}; std::vector got; got.reserve(results.size()); @@ -314,8 +314,13 @@ class DumpTest : public ::testing::Test { return db_acc; }() // iile }; - - memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; void TearDown() override { if (std::is_same::value) { @@ -722,7 +727,14 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) { : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) << "Wrong storage mode!"; - memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + ); { ResultStreamFaker stream(this->db->storage()); @@ -842,7 +854,14 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) { : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) << "Wrong storage mode!"; - memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + ); { ResultStreamFaker stream(this->db->storage()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index d0953651e..ac04cabdd 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -42,6 +42,7 @@ class QueryExecution : public testing::Test { std::optional repl_state; std::optional> db_gk; + std::optional system_state; void SetUp() override { auto config = [&]() { @@ -65,14 +66,20 @@ class QueryExecution : public testing::Test { : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), "Wrong storage mode!"); db_acc_ = std::move(db_acc); - - interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value()); + system_state.emplace(); + interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + ); interpreter_.emplace(&*interpreter_context_, *db_acc_); } void TearDown() override { interpreter_ = std::nullopt; interpreter_context_ = std::nullopt; + system_state.reset(); db_acc_.reset(); db_gk.reset(); repl_state.reset(); diff --git a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp index 5dfd0a8f1..cde3d937a 100644 --- a/tests/unit/query_streams.cpp +++ b/tests/unit/query_streams.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -104,7 +104,14 @@ class StreamsTestFixture : public ::testing::Test { return db_acc; }() // iile }; - memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; std::optional proxyStreams_; diff --git a/tests/unit/storage_v2_replication.cpp b/tests/unit/storage_v2_replication.cpp index e572440ca..b2adf3588 100644 --- a/tests/unit/storage_v2_replication.cpp +++ b/tests/unit/storage_v2_replication.cpp @@ -25,22 +25,20 @@ #include "auth/auth.hpp" #include "dbms/database.hpp" #include "dbms/dbms_handler.hpp" -#include "dbms/replication_handler.hpp" #include "query/interpreter_context.hpp" #include "replication/config.hpp" #include "replication/state.hpp" +#include "replication_handler/replication_handler.hpp" #include "storage/v2/indices/label_index_stats.hpp" #include "storage/v2/storage.hpp" #include "storage/v2/view.hpp" -#include "utils/rw_lock.hpp" -#include "utils/synchronized.hpp" using testing::UnorderedElementsAre; -using memgraph::dbms::RegisterReplicaError; -using memgraph::dbms::ReplicationHandler; -using memgraph::dbms::UnregisterReplicaResult; +using memgraph::query::RegisterReplicaError; +using memgraph::query::UnregisterReplicaResult; using memgraph::replication::ReplicationClientConfig; +using memgraph::replication::ReplicationHandler; using memgraph::replication::ReplicationServerConfig; using memgraph::replication_coordination_glue::ReplicationMode; using memgraph::replication_coordination_glue::ReplicationRole; @@ -114,21 +112,26 @@ class ReplicationTest : public ::testing::Test { struct MinMemgraph { MinMemgraph(const memgraph::storage::Config &conf) : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, - dbms{conf + repl_state{ReplicationStateRootPath(conf)}, + dbms{conf, system_, repl_state #ifdef MG_ENTERPRISE , - &auth, true + auth, true #endif }, - repl_state{dbms.ReplicationState()}, db_acc{dbms.Get()}, db{*db_acc.get()}, - repl_handler(dbms) { + repl_handler(repl_state, dbms +#ifdef MG_ENTERPRISE + , + &system_, auth +#endif + ) { } - - memgraph::utils::Synchronized auth; + memgraph::auth::SynchedAuth auth; + memgraph::system::System system_; + memgraph::replication::ReplicationState repl_state; memgraph::dbms::DbmsHandler dbms; - memgraph::replication::ReplicationState &repl_state; memgraph::dbms::DatabaseAccess db_acc; memgraph::dbms::Database &db; ReplicationHandler repl_handler; @@ -144,7 +147,7 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) { .port = ports[0], }); - const auto ® = main.repl_handler.RegisterReplica(ReplicationClientConfig{ + const auto ® = main.repl_handler.TryRegisterReplica(ReplicationClientConfig{ .name = "REPLICA", .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -442,7 +445,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) { }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -450,7 +453,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) { }) .HasError()); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -587,7 +590,7 @@ TEST_F(ReplicationTest, RecoveryProcess) { .port = ports[0], }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -663,7 +666,7 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) { }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = "REPLICA_ASYNC", .mode = ReplicationMode::ASYNC, .ip_address = local_host, @@ -715,7 +718,7 @@ TEST_F(ReplicationTest, EpochTest) { }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -724,7 +727,7 @@ TEST_F(ReplicationTest, EpochTest) { .HasError()); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -758,7 +761,7 @@ TEST_F(ReplicationTest, EpochTest) { ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain()); ASSERT_FALSE(replica1.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -791,7 +794,7 @@ TEST_F(ReplicationTest, EpochTest) { .port = ports[0], }); ASSERT_TRUE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -834,7 +837,7 @@ TEST_F(ReplicationTest, ReplicationInformation) { }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -844,7 +847,7 @@ TEST_F(ReplicationTest, ReplicationInformation) { .HasError()); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::ASYNC, .ip_address = local_host, @@ -890,7 +893,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) { .port = replica2_port, }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -899,7 +902,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) { .HasError()); ASSERT_TRUE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::ASYNC, .ip_address = local_host, @@ -925,7 +928,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) { }); ASSERT_FALSE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -934,7 +937,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) { .HasError()); ASSERT_TRUE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::ASYNC, .ip_address = local_host, @@ -973,14 +976,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) { .port = ports[1], }); - auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ + auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, .port = ports[0], }); ASSERT_FALSE(res.HasError()) << (int)res.GetError(); - res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ + res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -1030,14 +1033,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) { .ip_address = local_host, .port = ports[1], }); - auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ + auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ .name = replicas[0], .mode = ReplicationMode::SYNC, .ip_address = local_host, .port = ports[0], }); ASSERT_FALSE(res.HasError()); - res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ + res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ .name = replicas[1], .mode = ReplicationMode::SYNC, .ip_address = local_host, @@ -1080,7 +1083,7 @@ TEST_F(ReplicationTest, AddingInvalidReplica) { MinMemgraph main(main_conf); ASSERT_TRUE(main.repl_handler - .RegisterReplica(ReplicationClientConfig{ + .TryRegisterReplica(ReplicationClientConfig{ .name = "REPLICA", .mode = ReplicationMode::SYNC, .ip_address = local_host, diff --git a/tests/unit/storage_v2_storage_mode.cpp b/tests/unit/storage_v2_storage_mode.cpp index 487319d3c..03ade41f8 100644 --- a/tests/unit/storage_v2_storage_mode.cpp +++ b/tests/unit/storage_v2_storage_mode.cpp @@ -90,7 +90,16 @@ class StorageModeMultiTxTest : public ::testing::Test { return db_acc; }() // iile }; - memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context{{}, + nullptr, + &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; }; diff --git a/tests/unit/transaction_queue.cpp b/tests/unit/transaction_queue.cpp index d031b76b0..a90fe2c59 100644 --- a/tests/unit/transaction_queue.cpp +++ b/tests/unit/transaction_queue.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -59,7 +59,16 @@ class TransactionQueueSimpleTest : public ::testing::Test { return db_acc; }() // iile }; - memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context{{}, + nullptr, + &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; void TearDown() override { diff --git a/tests/unit/transaction_queue_multiple.cpp b/tests/unit/transaction_queue_multiple.cpp index 0b6cdf635..da8aabd02 100644 --- a/tests/unit/transaction_queue_multiple.cpp +++ b/tests/unit/transaction_queue_multiple.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -68,7 +68,16 @@ class TransactionQueueMultipleTest : public ::testing::Test { }() // iile }; - memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; + memgraph::system::System system_state; + memgraph::query::InterpreterContext interpreter_context{{}, + nullptr, + &repl_state, + system_state +#ifdef MG_ENTERPRISE + , + nullptr +#endif + }; InterpreterFaker main_interpreter{&interpreter_context, db}; std::vector running_interpreters;