Adding authentication data replication (#1666)

* Add AUTH system tx deltas
* Add auth data RPC and handlers
* Support multiple system deltas in a single transaction
* Added e2e test
* Bugfix: KVStore segfault after move

---------

Co-authored-by: Gareth Lloyd <gareth.lloyd@memgraph.io>
This commit is contained in:
andrejtonev 2024-02-05 11:37:00 +01:00 committed by GitHub
parent c46dad18fe
commit 7ead00f23e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
95 changed files with 4430 additions and 1865 deletions

View File

@ -22,8 +22,10 @@ add_subdirectory(dbms)
add_subdirectory(flags) add_subdirectory(flags)
add_subdirectory(distributed) add_subdirectory(distributed)
add_subdirectory(replication) add_subdirectory(replication)
add_subdirectory(replication_handler)
add_subdirectory(coordination) add_subdirectory(coordination)
add_subdirectory(replication_coordination_glue) add_subdirectory(replication_coordination_glue)
add_subdirectory(system)
string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
@ -43,10 +45,10 @@ set(mg_single_node_v2_sources
add_executable(memgraph ${mg_single_node_v2_sources}) add_executable(memgraph ${mg_single_node_v2_sources})
target_include_directories(memgraph PUBLIC ${CMAKE_SOURCE_DIR}/include) target_include_directories(memgraph PUBLIC ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(memgraph stdc++fs Threads::Threads target_link_libraries(memgraph stdc++fs Threads::Threads
mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags) mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags mg::system mg::replication_handler)
# NOTE: `include/mg_procedure.syms` describes a pattern match for symbols which # NOTE: `include/mg_procedure.syms` describes a pattern match for symbols which
# should be dynamically exported, so that `dlopen` can correctly link the # should be dynamically exported, so that `dlopen` can correctly link th
# symbols in custom procedure module libraries. # symbols in custom procedure module libraries.
target_link_libraries(memgraph "-Wl,--dynamic-list=${CMAKE_SOURCE_DIR}/include/mg_procedure.syms") target_link_libraries(memgraph "-Wl,--dynamic-list=${CMAKE_SOURCE_DIR}/include/mg_procedure.syms")
set_target_properties(memgraph PROPERTIES set_target_properties(memgraph PROPERTIES

View File

@ -2,7 +2,9 @@ set(auth_src_files
auth.cpp auth.cpp
crypto.cpp crypto.cpp
models.cpp models.cpp
module.cpp) module.cpp
rpc.cpp
replication_handlers.cpp)
find_package(Seccomp REQUIRED) find_package(Seccomp REQUIRED)
find_package(fmt REQUIRED) find_package(fmt REQUIRED)
@ -11,7 +13,7 @@ find_package(gflags REQUIRED)
add_library(mg-auth STATIC ${auth_src_files}) add_library(mg-auth STATIC ${auth_src_files})
target_link_libraries(mg-auth json libbcrypt gflags fmt::fmt) target_link_libraries(mg-auth json libbcrypt gflags fmt::fmt)
target_link_libraries(mg-auth mg-utils mg-kvstore mg-license ) target_link_libraries(mg-auth mg-utils mg-kvstore mg-license mg::system mg-replication)
target_link_libraries(mg-auth ${Seccomp_LIBRARIES}) target_link_libraries(mg-auth ${Seccomp_LIBRARIES})
target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS}) target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS})

View File

@ -9,13 +9,16 @@
#include "auth/auth.hpp" #include "auth/auth.hpp"
#include <iostream> #include <iostream>
#include <optional>
#include <utility> #include <utility>
#include <fmt/format.h> #include <fmt/format.h>
#include "auth/crypto.hpp" #include "auth/crypto.hpp"
#include "auth/exceptions.hpp" #include "auth/exceptions.hpp"
#include "auth/rpc.hpp"
#include "license/license.hpp" #include "license/license.hpp"
#include "system/transaction.hpp"
#include "utils/flag_validation.hpp" #include "utils/flag_validation.hpp"
#include "utils/message.hpp" #include "utils/message.hpp"
#include "utils/settings.hpp" #include "utils/settings.hpp"
@ -41,12 +44,84 @@ DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000,
FLAG_IN_RANGE(100, 1800000)); FLAG_IN_RANGE(100, 1800000));
namespace memgraph::auth { namespace memgraph::auth {
namespace {
#ifdef MG_ENTERPRISE
/**
* REPLICATION SYSTEM ACTION IMPLEMENTATIONS
*/
struct UpdateAuthData : memgraph::system::ISystemAction {
explicit UpdateAuthData(User user) : user_{std::move(user)}, role_{std::nullopt} {}
explicit UpdateAuthData(Role role) : user_{std::nullopt}, role_{std::move(role)} {}
void DoDurability() override { /* Done during Auth execution */
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const replication::UpdateAuthDataRes &response) { return response.success; };
if (user_) {
return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *user_);
}
if (role_) {
return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *role_);
}
// Should never get here
MG_ASSERT(false, "Trying to update auth data that is not a user nor a role");
return {};
}
void PostReplication(replication::RoleMainData &mainData) const override {}
private:
std::optional<User> user_;
std::optional<Role> role_;
};
struct DropAuthData : memgraph::system::ISystemAction {
enum class AuthDataType { USER, ROLE };
explicit DropAuthData(AuthDataType type, std::string_view name) : type_{type}, name_{name} {}
void DoDurability() override { /* Done during Auth execution */
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const replication::DropAuthDataRes &response) { return response.success; };
memgraph::replication::DropAuthDataReq::DataType type{};
switch (type_) {
case AuthDataType::USER:
type = memgraph::replication::DropAuthDataReq::DataType::USER;
break;
case AuthDataType::ROLE:
type = memgraph::replication::DropAuthDataReq::DataType::ROLE;
break;
}
return client.SteamAndFinalizeDelta<replication::DropAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), type, name_);
}
void PostReplication(replication::RoleMainData &mainData) const override {}
private:
AuthDataType type_;
std::string name_;
};
#endif
/**
* CONSTANTS
*/
const std::string kUserPrefix = "user:"; const std::string kUserPrefix = "user:";
const std::string kRolePrefix = "role:"; const std::string kRolePrefix = "role:";
const std::string kLinkPrefix = "link:"; const std::string kLinkPrefix = "link:";
const std::string kVersion = "version"; const std::string kVersion = "version";
static constexpr auto kVersionV1 = "V1"; static constexpr auto kVersionV1 = "V1";
} // namespace
/** /**
* All data stored in the `Auth` storage is stored in an underlying * All data stored in the `Auth` storage is stored in an underlying
@ -148,6 +223,12 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::s
// Authenticate the user. // Authenticate the user.
if (!is_authenticated) return std::nullopt; if (!is_authenticated) return std::nullopt;
/**
* TODO
* The auth module should not update auth data.
* There is now way to replicate it and we should not be storing sensitive data if we don't have to.
*/
// Find or create the user and return it. // Find or create the user and return it.
auto user = GetUser(username); auto user = GetUser(username);
if (!user) { if (!user) {
@ -240,7 +321,7 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) const {
return user; return user;
} }
void Auth::SaveUser(const User &user) { void Auth::SaveUser(const User &user, system::Transaction *system_tx) {
bool success = false; bool success = false;
if (const auto *role = user.role(); role != nullptr) { if (const auto *role = user.role(); role != nullptr) {
success = storage_.PutMultiple( success = storage_.PutMultiple(
@ -252,6 +333,12 @@ void Auth::SaveUser(const User &user) {
if (!success) { if (!success) {
throw AuthException("Couldn't save user '{}'!", user.username()); throw AuthException("Couldn't save user '{}'!", user.username());
} }
// All changes to the user end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<UpdateAuthData>(user);
#endif
}
} }
void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &password) { void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &password) {
@ -284,7 +371,8 @@ void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &pa
user.UpdatePassword(password); user.UpdatePassword(password);
} }
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password) { std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
if (!NameRegexMatch(username)) { if (!NameRegexMatch(username)) {
throw AuthException("Invalid user name."); throw AuthException("Invalid user name.");
} }
@ -294,17 +382,23 @@ std::optional<User> Auth::AddUser(const std::string &username, const std::option
if (existing_role) return std::nullopt; if (existing_role) return std::nullopt;
auto new_user = User(username); auto new_user = User(username);
UpdatePassword(new_user, password); UpdatePassword(new_user, password);
SaveUser(new_user); SaveUser(new_user, system_tx);
return new_user; return new_user;
} }
bool Auth::RemoveUser(const std::string &username_orig) { bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) {
auto username = utils::ToLowerCase(username_orig); auto username = utils::ToLowerCase(username_orig);
if (!storage_.Get(kUserPrefix + username)) return false; if (!storage_.Get(kUserPrefix + username)) return false;
std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username}); std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username});
if (!storage_.DeleteMultiple(keys)) { if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove user '{}'!", username); throw AuthException("Couldn't remove user '{}'!", username);
} }
// Handling drop user delta
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<DropAuthData>(DropAuthData::AuthDataType::USER, username);
#endif
}
return true; return true;
} }
@ -321,6 +415,19 @@ std::vector<auth::User> Auth::AllUsers() const {
return ret; return ret;
} }
std::vector<std::string> Auth::AllUsernames() const {
std::vector<std::string> ret;
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
auto user = GetUser(username);
if (user) {
ret.push_back(username);
}
}
return ret;
}
bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const { std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
@ -338,24 +445,30 @@ std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
return Role::Deserialize(data); return Role::Deserialize(data);
} }
void Auth::SaveRole(const Role &role) { void Auth::SaveRole(const Role &role, system::Transaction *system_tx) {
if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) { if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) {
throw AuthException("Couldn't save role '{}'!", role.rolename()); throw AuthException("Couldn't save role '{}'!", role.rolename());
} }
// All changes to the role end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<UpdateAuthData>(role);
#endif
}
} }
std::optional<Role> Auth::AddRole(const std::string &rolename) { std::optional<Role> Auth::AddRole(const std::string &rolename, system::Transaction *system_tx) {
if (!NameRegexMatch(rolename)) { if (!NameRegexMatch(rolename)) {
throw AuthException("Invalid role name."); throw AuthException("Invalid role name.");
} }
if (auto existing_role = GetRole(rolename)) return std::nullopt; if (auto existing_role = GetRole(rolename)) return std::nullopt;
if (auto existing_user = GetUser(rolename)) return std::nullopt; if (auto existing_user = GetUser(rolename)) return std::nullopt;
auto new_role = Role(rolename); auto new_role = Role(rolename);
SaveRole(new_role); SaveRole(new_role, system_tx);
return new_role; return new_role;
} }
bool Auth::RemoveRole(const std::string &rolename_orig) { bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *system_tx) {
auto rolename = utils::ToLowerCase(rolename_orig); auto rolename = utils::ToLowerCase(rolename_orig);
if (!storage_.Get(kRolePrefix + rolename)) return false; if (!storage_.Get(kRolePrefix + rolename)) return false;
std::vector<std::string> keys; std::vector<std::string> keys;
@ -368,6 +481,12 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
if (!storage_.DeleteMultiple(keys)) { if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove role '{}'!", rolename); throw AuthException("Couldn't remove role '{}'!", rolename);
} }
// Handling drop role delta
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<DropAuthData>(DropAuthData::AuthDataType::ROLE, rolename);
#endif
}
return true; return true;
} }
@ -385,6 +504,18 @@ std::vector<auth::Role> Auth::AllRoles() const {
return ret; return ret;
} }
std::vector<std::string> Auth::AllRolenames() const {
std::vector<std::string> ret;
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
if (rolename != utils::ToLowerCase(rolename)) continue;
if (auto role = GetRole(rolename)) {
ret.push_back(rolename);
}
}
return ret;
}
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const { std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const {
const auto rolename = utils::ToLowerCase(rolename_orig); const auto rolename = utils::ToLowerCase(rolename_orig);
std::vector<auth::User> ret; std::vector<auth::User> ret;
@ -404,48 +535,48 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
} }
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) { bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) { if (auto user = GetUser(name)) {
if (db == kAllDatabases) { if (db == kAllDatabases) {
user->db_access().GrantAll(); user->db_access().GrantAll();
} else { } else {
user->db_access().Add(db); user->db_access().Add(db);
} }
SaveUser(*user); SaveUser(*user, system_tx);
return true; return true;
} }
return false; return false;
} }
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) { bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) { if (auto user = GetUser(name)) {
if (db == kAllDatabases) { if (db == kAllDatabases) {
user->db_access().DenyAll(); user->db_access().DenyAll();
} else { } else {
user->db_access().Remove(db); user->db_access().Remove(db);
} }
SaveUser(*user); SaveUser(*user, system_tx);
return true; return true;
} }
return false; return false;
} }
void Auth::DeleteDatabase(const std::string &db) { void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size()); auto username = it->first.substr(kUserPrefix.size());
if (auto user = GetUser(username)) { if (auto user = GetUser(username)) {
user->db_access().Delete(db); user->db_access().Delete(db);
SaveUser(*user); SaveUser(*user, system_tx);
} }
} }
} }
bool Auth::SetMainDatabase(std::string_view db, const std::string &name) { bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) { if (auto user = GetUser(name)) {
if (!user->db_access().SetDefault(db)) { if (!user->db_access().SetDefault(db)) {
throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name); throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name);
} }
SaveUser(*user); SaveUser(*user, system_tx);
return true; return true;
} }
return false; return false;

View File

@ -18,10 +18,15 @@
#include "auth/module.hpp" #include "auth/module.hpp"
#include "glue/auth_global.hpp" #include "glue/auth_global.hpp"
#include "kvstore/kvstore.hpp" #include "kvstore/kvstore.hpp"
#include "system/action.hpp"
#include "utils/settings.hpp" #include "utils/settings.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::auth { namespace memgraph::auth {
class Auth;
using SynchedAuth = memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>;
static const constexpr char *const kAllDatabases = "*"; static const constexpr char *const kAllDatabases = "*";
/** /**
@ -68,6 +73,13 @@ class Auth final {
config_ = std::move(config); config_ = std::move(config);
} }
/**
* @brief
*
* @return Config
*/
Config GetConfig() const { return config_; }
/** /**
* Authenticates a user using his username and password. * Authenticates a user using his username and password.
* *
@ -96,7 +108,7 @@ class Auth final {
* *
* @throw AuthException if unable to save the user. * @throw AuthException if unable to save the user.
*/ */
void SaveUser(const User &user); void SaveUser(const User &user, system::Transaction *system_tx = nullptr);
/** /**
* Creates a user if the user doesn't exist. * Creates a user if the user doesn't exist.
@ -107,7 +119,8 @@ class Auth final {
* @return a user when the user is created, nullopt if the user exists * @return a user when the user is created, nullopt if the user exists
* @throw AuthException if unable to save the user. * @throw AuthException if unable to save the user.
*/ */
std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt); std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt,
system::Transaction *system_tx = nullptr);
/** /**
* Removes a user from the storage. * Removes a user from the storage.
@ -118,7 +131,7 @@ class Auth final {
* doesn't exist * doesn't exist
* @throw AuthException if unable to remove the user. * @throw AuthException if unable to remove the user.
*/ */
bool RemoveUser(const std::string &username); bool RemoveUser(const std::string &username, system::Transaction *system_tx = nullptr);
/** /**
* @brief * @brief
@ -136,6 +149,13 @@ class Auth final {
*/ */
std::vector<User> AllUsers() const; std::vector<User> AllUsers() const;
/**
* @brief
*
* @return std::vector<std::string>
*/
std::vector<std::string> AllUsernames() const;
/** /**
* Returns whether there are users in the storage. * Returns whether there are users in the storage.
* *
@ -160,7 +180,7 @@ class Auth final {
* *
* @throw AuthException if unable to save the role. * @throw AuthException if unable to save the role.
*/ */
void SaveRole(const Role &role); void SaveRole(const Role &role, system::Transaction *system_tx = nullptr);
/** /**
* Creates a role if the role doesn't exist. * Creates a role if the role doesn't exist.
@ -170,7 +190,7 @@ class Auth final {
* @return a role when the role is created, nullopt if the role exists * @return a role when the role is created, nullopt if the role exists
* @throw AuthException if unable to save the role. * @throw AuthException if unable to save the role.
*/ */
std::optional<Role> AddRole(const std::string &rolename); std::optional<Role> AddRole(const std::string &rolename, system::Transaction *system_tx = nullptr);
/** /**
* Removes a role from the storage. * Removes a role from the storage.
@ -181,7 +201,7 @@ class Auth final {
* doesn't exist * doesn't exist
* @throw AuthException if unable to remove the role. * @throw AuthException if unable to remove the role.
*/ */
bool RemoveRole(const std::string &rolename); bool RemoveRole(const std::string &rolename, system::Transaction *system_tx = nullptr);
/** /**
* Gets all roles from the storage. * Gets all roles from the storage.
@ -191,6 +211,13 @@ class Auth final {
*/ */
std::vector<Role> AllRoles() const; std::vector<Role> AllRoles() const;
/**
* @brief
*
* @return std::vector<std::string>
*/
std::vector<std::string> AllRolenames() const;
/** /**
* Gets all users for a role from the storage. * Gets all users for a role from the storage.
* *
@ -210,7 +237,7 @@ class Auth final {
* @return true on success * @return true on success
* @throw AuthException if unable to find or update the user * @throw AuthException if unable to find or update the user
*/ */
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name); bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
/** /**
* @brief Grant access to individual database for a user. * @brief Grant access to individual database for a user.
@ -220,7 +247,7 @@ class Auth final {
* @return true on success * @return true on success
* @throw AuthException if unable to find or update the user * @throw AuthException if unable to find or update the user
*/ */
bool GrantDatabaseToUser(const std::string &db, const std::string &name); bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
/** /**
* @brief Delete a database from all users. * @brief Delete a database from all users.
@ -228,7 +255,7 @@ class Auth final {
* @param db name of the database to delete * @param db name of the database to delete
* @throw AuthException if unable to read data * @throw AuthException if unable to read data
*/ */
void DeleteDatabase(const std::string &db); void DeleteDatabase(const std::string &db, system::Transaction *system_tx = nullptr);
/** /**
* @brief Set main database for an individual user. * @brief Set main database for an individual user.
@ -238,7 +265,7 @@ class Auth final {
* @return true on success * @return true on success
* @throw AuthException if unable to find or update the user * @throw AuthException if unable to find or update the user
*/ */
bool SetMainDatabase(std::string_view db, const std::string &name); bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr);
#endif #endif
private: private:

View File

@ -611,27 +611,49 @@ Permissions User::GetPermissions() const {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
FineGrainedAccessPermissions User::GetFineGrainedAccessLabelPermissions() const { FineGrainedAccessPermissions User::GetFineGrainedAccessLabelPermissions() const {
return Merge(GetUserFineGrainedAccessLabelPermissions(), GetRoleFineGrainedAccessLabelPermissions());
}
FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const {
return Merge(GetUserFineGrainedAccessEdgeTypePermissions(), GetRoleFineGrainedAccessEdgeTypePermissions());
}
FineGrainedAccessPermissions User::GetUserFineGrainedAccessEdgeTypePermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{}; return FineGrainedAccessPermissions{};
} }
if (role_) { return fine_grained_access_handler_.edge_type_permissions();
return Merge(role()->fine_grained_access_handler().label_permissions(), }
fine_grained_access_handler_.label_permissions());
FineGrainedAccessPermissions User::GetUserFineGrainedAccessLabelPermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
} }
return fine_grained_access_handler_.label_permissions(); return fine_grained_access_handler_.label_permissions();
} }
FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const { FineGrainedAccessPermissions User::GetRoleFineGrainedAccessEdgeTypePermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{}; return FineGrainedAccessPermissions{};
} }
if (role_) { if (role_) {
return Merge(role()->fine_grained_access_handler().edge_type_permissions(), return role()->fine_grained_access_handler().edge_type_permissions();
fine_grained_access_handler_.edge_type_permissions());
} }
return fine_grained_access_handler_.edge_type_permissions(); return FineGrainedAccessPermissions{};
}
FineGrainedAccessPermissions User::GetRoleFineGrainedAccessLabelPermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
}
if (role_) {
return role()->fine_grained_access_handler().label_permissions();
}
return FineGrainedAccessPermissions{};
} }
#endif #endif

View File

@ -207,6 +207,8 @@ bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHa
class Role final { class Role final {
public: public:
Role() = default;
explicit Role(const std::string &rolename); explicit Role(const std::string &rolename);
Role(const std::string &rolename, const Permissions &permissions); Role(const std::string &rolename, const Permissions &permissions);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -369,6 +371,10 @@ class User final {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
FineGrainedAccessPermissions GetFineGrainedAccessLabelPermissions() const; FineGrainedAccessPermissions GetFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetFineGrainedAccessEdgeTypePermissions() const; FineGrainedAccessPermissions GetFineGrainedAccessEdgeTypePermissions() const;
FineGrainedAccessPermissions GetUserFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetUserFineGrainedAccessEdgeTypePermissions() const;
FineGrainedAccessPermissions GetRoleFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetRoleFineGrainedAccessEdgeTypePermissions() const;
const FineGrainedAccessHandler &fine_grained_access_handler() const; const FineGrainedAccessHandler &fine_grained_access_handler() const;
FineGrainedAccessHandler &fine_grained_access_handler(); FineGrainedAccessHandler &fine_grained_access_handler();
#endif #endif

View File

@ -0,0 +1,170 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "auth/replication_handlers.hpp"
#include "auth/auth.hpp"
#include "auth/rpc.hpp"
#include "license/license.hpp"
namespace memgraph::auth {
#ifdef MG_ENTERPRISE
void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder) {
replication::UpdateAuthDataReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::replication::UpdateAuthDataRes;
UpdateAuthDataRes res(false);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("UpdateAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Update
if (req.user) auth->SaveUser(*req.user);
if (req.role) auth->SaveRole(*req.role);
// Success
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = UpdateAuthDataRes(true);
spdlog::debug("UpdateAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
} catch (const auth::AuthException & /* not used */) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder) {
replication::DropAuthDataReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::replication::DropAuthDataRes;
DropAuthDataRes res(false);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("DropAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Remove
switch (req.type) {
case replication::DropAuthDataReq::DataType::USER:
auth->RemoveUser(req.name);
break;
case replication::DropAuthDataReq::DataType::ROLE:
auth->RemoveRole(req.name);
break;
}
// Success
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropAuthDataRes(true);
spdlog::debug("DropAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
} catch (const auth::AuthException & /* not used */) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config,
const std::vector<auth::User> &users, const std::vector<auth::Role> &roles) {
return auth.WithLock([&](auto &locked_auth) {
// Update config
locked_auth.SetConfig(std::move(auth_config));
// Get all current users
auto old_users = locked_auth.AllUsernames();
// Save incoming users
for (const auto &user : users) {
// Missing users
try {
locked_auth.SaveUser(user);
} catch (const auth::AuthException &) {
spdlog::debug("SystemRecoveryHandler: Failed to save user");
return false;
}
const auto it = std::find(old_users.begin(), old_users.end(), user.username());
if (it != old_users.end()) old_users.erase(it);
}
// Delete all the leftover users
for (const auto &user : old_users) {
if (!locked_auth.RemoveUser(user)) {
spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", user);
return false;
}
}
// Roles are only supported with a license
if (license::global_license_checker.IsEnterpriseValidFast()) {
// Get all current roles
auto old_roles = locked_auth.AllRolenames();
// Save incoming users
for (const auto &role : roles) {
// Missing users
try {
locked_auth.SaveRole(role);
} catch (const auth::AuthException &) {
spdlog::debug("SystemRecoveryHandler: Failed to save user");
return false;
}
const auto it = std::find(old_roles.begin(), old_roles.end(), role.rolename());
if (it != old_roles.end()) old_roles.erase(it);
}
// Delete all the leftover users
for (const auto &role : old_roles) {
if (!locked_auth.RemoveRole(role)) {
spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", role);
return false;
}
}
}
// Success
return true;
});
}
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
auth::SynchedAuth &auth) {
// NOTE: Register even without license as the user could add a license at run-time
data.server->rpc_server_.Register<replication::UpdateAuthDataRpc>(
[system_state_access, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received UpdateAuthDataRpc");
UpdateAuthDataHandler(system_state_access, auth, req_reader, res_builder);
});
data.server->rpc_server_.Register<replication::DropAuthDataRpc>(
[system_state_access, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received DropAuthDataRpc");
DropAuthDataHandler(system_state_access, auth, req_reader, res_builder);
});
}
#endif
} // namespace memgraph::auth

View File

@ -0,0 +1,31 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "replication/state.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::auth {
#ifdef MG_ENTERPRISE
void UpdateAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder);
void DropAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder);
bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config,
const std::vector<auth::User> &users, const std::vector<auth::Role> &roles);
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
auth::SynchedAuth &auth);
#endif
} // namespace memgraph::auth

178
src/auth/rpc.cpp Normal file
View File

@ -0,0 +1,178 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "auth/rpc.hpp"
#include <json/json.hpp>
#include "auth/auth.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "utils/enum.hpp"
namespace memgraph::slk {
// Serialize code for auth::Role
void Save(const auth::Role &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.Serialize().dump(), builder);
}
namespace {
auth::Role LoadAuthRole(memgraph::slk::Reader *reader) {
std::string tmp;
memgraph::slk::Load(&tmp, reader);
const auto json = nlohmann::json::parse(tmp);
return memgraph::auth::Role::Deserialize(json);
}
} // namespace
// Deserialize code for auth::Role
void Load(auth::Role *self, memgraph::slk::Reader *reader) { *self = LoadAuthRole(reader); }
// Special case for optional<Role>
template <>
inline void Load<auth::Role>(std::optional<auth::Role> *obj, Reader *reader) {
bool exists = false;
Load(&exists, reader);
if (exists) {
obj->emplace(LoadAuthRole(reader));
} else {
*obj = std::nullopt;
}
}
// Serialize code for auth::User
void Save(const auth::User &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.Serialize().dump(), builder);
std::optional<auth::Role> role{};
if (const auto *role_ptr = self.role(); role_ptr) {
role.emplace(*role_ptr);
}
memgraph::slk::Save(role, builder);
}
// Deserialize code for auth::User
void Load(auth::User *self, memgraph::slk::Reader *reader) {
std::string tmp;
memgraph::slk::Load(&tmp, reader);
const auto json = nlohmann::json::parse(tmp);
*self = memgraph::auth::User::Deserialize(json);
std::optional<auth::Role> role{};
memgraph::slk::Load(&role, reader);
if (role)
self->SetRole(*role);
else
self->ClearRole();
}
// Serialize code for auth::Auth::Config
void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.name_regex_str, builder);
memgraph::slk::Save(self.password_regex_str, builder);
memgraph::slk::Save(self.password_permit_null, builder);
}
// Deserialize code for auth::Auth::Config
void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader) {
std::string name_regex_str{};
std::string password_regex_str{};
bool password_permit_null{};
memgraph::slk::Load(&name_regex_str, reader);
memgraph::slk::Load(&password_regex_str, reader);
memgraph::slk::Load(&password_permit_null, reader);
*self = auth::Auth::Config{std::move(name_regex_str), std::move(password_regex_str), password_permit_null};
}
// Serialize code for UpdateAuthDataReq
void Save(const memgraph::replication::UpdateAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.user, builder);
memgraph::slk::Save(self.role, builder);
}
void Load(memgraph::replication::UpdateAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->user, reader);
memgraph::slk::Load(&self->role, reader);
}
// Serialize code for UpdateAuthDataRes
void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
}
void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
}
// Serialize code for DropAuthDataReq
void Save(const memgraph::replication::DropAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(utils::EnumToNum<2, uint8_t>(self.type), builder);
memgraph::slk::Save(self.name, builder);
}
void Load(memgraph::replication::DropAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
uint8_t type_tmp = 0;
memgraph::slk::Load(&type_tmp, reader);
if (!utils::NumToEnum<2>(type_tmp, self->type)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
memgraph::slk::Load(&self->name, reader);
}
// Serialize code for DropAuthDataRes
void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
}
void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
}
} // namespace memgraph::slk
namespace memgraph::replication {
constexpr utils::TypeInfo UpdateAuthDataReq::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_REQ, "UpdateAuthDataReq",
nullptr};
constexpr utils::TypeInfo UpdateAuthDataRes::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_RES, "UpdateAuthDataRes",
nullptr};
constexpr utils::TypeInfo DropAuthDataReq::kType{utils::TypeId::REP_DROP_AUTH_DATA_REQ, "DropAuthDataReq", nullptr};
constexpr utils::TypeInfo DropAuthDataRes::kType{utils::TypeId::REP_DROP_AUTH_DATA_RES, "DropAuthDataRes", nullptr};
void UpdateAuthDataReq::Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void UpdateAuthDataReq::Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void UpdateAuthDataRes::Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void UpdateAuthDataRes::Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropAuthDataReq::Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropAuthDataReq::Load(DropAuthDataReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropAuthDataRes::Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropAuthDataRes::Load(DropAuthDataRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
} // namespace memgraph::replication

119
src/auth/rpc.hpp Normal file
View File

@ -0,0 +1,119 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "rpc/messages.hpp"
#include "slk/streams.hpp"
namespace memgraph::replication {
struct UpdateAuthDataReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader);
static void Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder);
UpdateAuthDataReq() = default;
UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::User user)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
user{std::move(user)} {}
UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::Role role)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
role{std::move(role)} {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
std::optional<auth::User> user;
std::optional<auth::Role> role;
};
struct UpdateAuthDataRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader);
static void Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder);
UpdateAuthDataRes() = default;
explicit UpdateAuthDataRes(bool success) : success{success} {}
bool success;
};
using UpdateAuthDataRpc = rpc::RequestResponse<UpdateAuthDataReq, UpdateAuthDataRes>;
struct DropAuthDataReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropAuthDataReq *self, memgraph::slk::Reader *reader);
static void Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder);
DropAuthDataReq() = default;
enum class DataType { USER, ROLE };
DropAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, DataType type, std::string_view name)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
type{type},
name{name} {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
DataType type;
std::string name;
};
struct DropAuthDataRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropAuthDataRes *self, memgraph::slk::Reader *reader);
static void Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder);
DropAuthDataRes() = default;
explicit DropAuthDataRes(bool success) : success{success} {}
bool success;
};
using DropAuthDataRpc = rpc::RequestResponse<DropAuthDataReq, DropAuthDataRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const auth::Role &self, memgraph::slk::Builder *builder);
void Load(auth::Role *self, memgraph::slk::Reader *reader);
void Save(const auth::User &self, memgraph::slk::Builder *builder);
void Load(auth::User *self, memgraph::slk::Reader *reader);
void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder);
void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::UpdateAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::UpdateAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/);
void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::DropAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::DropAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/);
} // namespace memgraph::slk

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -14,8 +14,6 @@
#include <string> #include <string>
#include "auth/auth.hpp" #include "auth/auth.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::communication::websocket { namespace memgraph::communication::websocket {
@ -30,7 +28,7 @@ class AuthenticationInterface {
class SafeAuth : public AuthenticationInterface { class SafeAuth : public AuthenticationInterface {
public: public:
explicit SafeAuth(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {} explicit SafeAuth(auth::SynchedAuth *auth) : auth_{auth} {}
bool Authenticate(const std::string &username, const std::string &password) const override; bool Authenticate(const std::string &username, const std::string &password) const override;
@ -39,6 +37,6 @@ class SafeAuth : public AuthenticationInterface {
bool HasAnyUsers() const override; bool HasAnyUsers() const override;
private: private:
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; auth::SynchedAuth *auth_;
}; };
} // namespace memgraph::communication::websocket } // namespace memgraph::communication::websocket

View File

@ -13,6 +13,7 @@ target_sources(mg-coordination
include/coordination/coordinator_data.hpp include/coordination/coordinator_data.hpp
include/coordination/constants.hpp include/coordination/constants.hpp
include/coordination/coordinator_cluster_config.hpp include/coordination/coordinator_cluster_config.hpp
include/coordination/coordinator_handlers.hpp
PRIVATE PRIVATE
coordinator_client.cpp coordinator_client.cpp
@ -21,9 +22,10 @@ target_sources(mg-coordination
coordinator_server.cpp coordinator_server.cpp
coordinator_data.cpp coordinator_data.cpp
coordinator_instance.cpp coordinator_instance.cpp
coordinator_handlers.cpp
) )
target_include_directories(mg-coordination PUBLIC include) target_include_directories(mg-coordination PUBLIC include)
target_link_libraries(mg-coordination target_link_libraries(mg-coordination
PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft mg-replication_handler
) )

View File

@ -183,7 +183,7 @@ auto CoordinatorData::RegisterInstance(CoordinatorClientConfig config) -> Regist
if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) { if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) {
return instance.SocketAddress() == config.SocketAddress(); return instance.SocketAddress() == config.SocketAddress();
})) { })) {
return RegisterInstanceCoordinatorStatus::END_POINT_EXISTS; return RegisterInstanceCoordinatorStatus::ENDPOINT_EXISTS;
} }
try { try {

View File

@ -10,41 +10,35 @@
// licenses/APL.txt. // licenses/APL.txt.
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include "coordination/coordinator_handlers.hpp"
#include "dbms/coordinator_handlers.hpp" #include <range/v3/view.hpp>
#include "coordination/coordinator_exceptions.hpp"
#include "coordination/coordinator_rpc.hpp" #include "coordination/coordinator_rpc.hpp"
#include "dbms/dbms_handler.hpp" #include "coordination/include/coordination/coordinator_server.hpp"
#include "dbms/replication_client.hpp"
#include "dbms/utils.hpp"
#include "range/v3/view.hpp"
namespace memgraph::dbms { namespace memgraph::dbms {
void CoordinatorHandlers::Register(DbmsHandler &dbms_handler) { void CoordinatorHandlers::Register(memgraph::coordination::CoordinatorServer &server,
auto &server = dbms_handler.CoordinatorState().GetCoordinatorServer(); replication::ReplicationHandler &replication_handler) {
server.Register<coordination::PromoteReplicaToMainRpc>( server.Register<coordination::PromoteReplicaToMainRpc>(
[&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { [&](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
spdlog::info("Received PromoteReplicaToMainRpc"); spdlog::info("Received PromoteReplicaToMainRpc");
CoordinatorHandlers::PromoteReplicaToMainHandler(dbms_handler, req_reader, res_builder); CoordinatorHandlers::PromoteReplicaToMainHandler(replication_handler, req_reader, res_builder);
}); });
server.Register<coordination::DemoteMainToReplicaRpc>( server.Register<coordination::DemoteMainToReplicaRpc>(
[&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
spdlog::info("Received DemoteMainToReplicaRpc from coordinator server"); spdlog::info("Received DemoteMainToReplicaRpc from coordinator server");
CoordinatorHandlers::DemoteMainToReplicaHandler(dbms_handler, req_reader, res_builder); CoordinatorHandlers::DemoteMainToReplicaHandler(replication_handler, req_reader, res_builder);
}); });
} }
void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler,
slk::Builder *res_builder) { slk::Reader *req_reader, slk::Builder *res_builder) {
auto &repl_state = dbms_handler.ReplicationState(); spdlog::info("Executing DemoteMainToReplicaHandler");
spdlog::info("Executing SetMainToReplicaHandler");
if (repl_state.IsReplica()) { if (!replication_handler.IsMain()) {
spdlog::error("Setting to replica must be performed on main."); spdlog::error("Setting to replica must be performed on main.");
slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder); slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder);
return; return;
@ -57,7 +51,7 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler,
.ip_address = req.replication_client_info.replication_ip_address, .ip_address = req.replication_client_info.replication_ip_address,
.port = req.replication_client_info.replication_port}; .port = req.replication_client_info.replication_port};
if (bool const success = memgraph::dbms::SetReplicationRoleReplica(dbms_handler, clients_config); !success) { if (!replication_handler.SetReplicationRoleReplica(clients_config)) {
spdlog::error("Demoting main to replica failed!"); spdlog::error("Demoting main to replica failed!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
@ -66,19 +60,17 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler,
slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder);
} }
void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler,
slk::Builder *res_builder) { slk::Reader *req_reader, slk::Builder *res_builder) {
auto &repl_state = dbms_handler.ReplicationState(); if (!replication_handler.IsReplica()) {
spdlog::error("Failover must be performed on replica!");
if (!repl_state.IsReplica()) {
spdlog::error("Only replica can be promoted to main!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
} }
// This can fail because of disk. If it does, the cluster state could get inconsistent. // This can fail because of disk. If it does, the cluster state could get inconsistent.
// We don't handle disk issues. // We don't handle disk issues.
if (bool const success = memgraph::dbms::DoReplicaToMainPromotion(dbms_handler); !success) { if (!replication_handler.DoReplicaToMainPromotion()) {
spdlog::error("Promoting replica to main failed!"); spdlog::error("Promoting replica to main failed!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
@ -96,53 +88,32 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler,
}; };
}; };
MG_ASSERT(
std::get<replication::RoleMainData>(repl_state.ReplicationData()).registered_replicas_.empty(),
"No replicas should be registered after promoting replica to main and before registering replication clients!");
// registering replicas // registering replicas
for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) { for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) {
auto instance_client = repl_state.RegisterReplica(config); auto instance_client = replication_handler.RegisterReplica(config);
if (instance_client.HasError()) { if (instance_client.HasError()) {
using enum memgraph::replication::RegisterReplicaError; using enum memgraph::replication::RegisterReplicaError;
switch (instance_client.GetError()) { switch (instance_client.GetError()) {
// Can't happen, we are already replica
case NOT_MAIN:
spdlog::error("Failover must be performed on main!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
// Can't happen, checked on the coordinator side // Can't happen, checked on the coordinator side
case NAME_EXISTS: case memgraph::query::RegisterReplicaError::NAME_EXISTS:
spdlog::error("Replica with the same name already exists!"); spdlog::error("Replica with the same name already exists!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
// Can't happen, checked on the coordinator side // Can't happen, checked on the coordinator side
case ENDPOINT_EXISTS: case memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS:
spdlog::error("Replica with the same endpoint already exists!"); spdlog::error("Replica with the same endpoint already exists!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
// We don't handle disk issues // We don't handle disk issues
case COULD_NOT_BE_PERSISTED: case memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
spdlog::error("Registered replica could not be persisted!"); spdlog::error("Registered replica could not be persisted!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return; return;
case SUCCESS: case memgraph::query::RegisterReplicaError::CONNECTION_FAILED:
// Connection failure is not a fatal error
break; break;
} }
} }
if (!allow_mt_repl && dbms_handler.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
auto &instance_client_ref = *instance_client.GetValue();
// Update system before enabling individual storage <-> replica clients
dbms_handler.SystemRestore(instance_client_ref);
const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients<true>(dbms_handler, instance_client_ref);
MG_ASSERT(all_clients_good, "Failed to register one or more databases to the REPLICA \"{}\".", config.name);
StartReplicaClient(dbms_handler, instance_client_ref);
} }
slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder);

View File

@ -13,7 +13,9 @@
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include "slk/serialization.hpp" #include "coordination/coordinator_server.hpp"
#include "replication_handler/replication_handler.hpp"
#include "slk/streams.hpp"
namespace memgraph::dbms { namespace memgraph::dbms {
@ -21,12 +23,14 @@ class DbmsHandler;
class CoordinatorHandlers { class CoordinatorHandlers {
public: public:
static void Register(DbmsHandler &dbms_handler); static void Register(memgraph::coordination::CoordinatorServer &server,
replication::ReplicationHandler &replication_handler);
private: private:
static void PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, static void PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader,
slk::Builder *res_builder);
static void DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader,
slk::Builder *res_builder); slk::Builder *res_builder);
static void DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
}; };
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -19,7 +19,7 @@ namespace memgraph::coordination {
enum class RegisterInstanceCoordinatorStatus : uint8_t { enum class RegisterInstanceCoordinatorStatus : uint8_t {
NAME_EXISTS, NAME_EXISTS,
END_POINT_EXISTS, ENDPOINT_EXISTS,
NOT_COORDINATOR, NOT_COORDINATOR,
RPC_FAILED, RPC_FAILED,
SUCCESS SUCCESS

View File

@ -1,2 +1,10 @@
add_library(mg-dbms STATIC dbms_handler.cpp database.cpp replication_handler.cpp coordinator_handler.cpp replication_client.cpp inmemory/replication_handlers.cpp coordinator_handlers.cpp) add_library(mg-dbms STATIC
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-replication mg-coordination) dbms_handler.cpp
database.cpp
coordinator_handler.cpp
inmemory/replication_handlers.cpp
replication_handlers.cpp
rpc.cpp
)
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-auth mg-replication mg-coordination)

View File

@ -18,20 +18,21 @@
namespace memgraph::dbms { namespace memgraph::dbms {
CoordinatorHandler::CoordinatorHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {} CoordinatorHandler::CoordinatorHandler(coordination::CoordinatorState &coordinator_state)
: coordinator_state_(coordinator_state) {}
auto CoordinatorHandler::RegisterInstance(memgraph::coordination::CoordinatorClientConfig config) auto CoordinatorHandler::RegisterInstance(memgraph::coordination::CoordinatorClientConfig config)
-> coordination::RegisterInstanceCoordinatorStatus { -> coordination::RegisterInstanceCoordinatorStatus {
return dbms_handler_.CoordinatorState().RegisterInstance(config); return coordinator_state_.RegisterInstance(config);
} }
auto CoordinatorHandler::SetInstanceToMain(std::string instance_name) auto CoordinatorHandler::SetInstanceToMain(std::string instance_name)
-> coordination::SetInstanceToMainCoordinatorStatus { -> coordination::SetInstanceToMainCoordinatorStatus {
return dbms_handler_.CoordinatorState().SetInstanceToMain(std::move(instance_name)); return coordinator_state_.SetInstanceToMain(std::move(instance_name));
} }
auto CoordinatorHandler::ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus> { auto CoordinatorHandler::ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus> {
return dbms_handler_.CoordinatorState().ShowInstances(); return coordinator_state_.ShowInstances();
} }
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -15,11 +15,9 @@
#include "coordination/coordinator_config.hpp" #include "coordination/coordinator_config.hpp"
#include "coordination/coordinator_instance_status.hpp" #include "coordination/coordinator_instance_status.hpp"
#include "coordination/coordinator_state.hpp"
#include "coordination/register_main_replica_coordinator_status.hpp" #include "coordination/register_main_replica_coordinator_status.hpp"
#include "utils/result.hpp"
#include <cstdint>
#include <optional>
#include <vector> #include <vector>
namespace memgraph::dbms { namespace memgraph::dbms {
@ -28,7 +26,7 @@ class DbmsHandler;
class CoordinatorHandler { class CoordinatorHandler {
public: public:
explicit CoordinatorHandler(DbmsHandler &dbms_handler); explicit CoordinatorHandler(coordination::CoordinatorState &coordinator_state);
auto RegisterInstance(coordination::CoordinatorClientConfig config) auto RegisterInstance(coordination::CoordinatorClientConfig config)
-> coordination::RegisterInstanceCoordinatorStatus; -> coordination::RegisterInstanceCoordinatorStatus;
@ -38,7 +36,7 @@ class CoordinatorHandler {
auto ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus>; auto ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus>;
private: private:
DbmsHandler &dbms_handler_; coordination::CoordinatorState &coordinator_state_;
}; };
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,10 +11,7 @@
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "dbms/inmemory/storage_helper.hpp" #include "dbms/inmemory/storage_helper.hpp"
#include "dbms/replication_handler.hpp"
#include "flags/storage_mode.hpp"
#include "storage/v2/disk/storage.hpp" #include "storage/v2/disk/storage.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/storage_mode.hpp" #include "storage/v2/storage_mode.hpp"
template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>; template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>;

View File

@ -11,29 +11,73 @@
#include "dbms/dbms_handler.hpp" #include "dbms/dbms_handler.hpp"
#include "dbms/coordinator_handlers.hpp"
#include "flags/replication.hpp"
#include <cstdint> #include <cstdint>
#include <filesystem> #include <filesystem>
#include "dbms/constants.hpp" #include "dbms/constants.hpp"
#include "dbms/global.hpp" #include "dbms/global.hpp"
#include "dbms/replication_client.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
#include "system/include/system/system.hpp"
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/uuid.hpp" #include "utils/uuid.hpp"
namespace memgraph::dbms { namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
namespace { namespace {
constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability
constexpr std::string_view kLastCommitedSystemTsKey = "last_commited_system_ts"; // Key for timestamp durability
std::string RegisterReplicaErrorToString(query::RegisterReplicaError error) {
switch (error) {
using enum query::RegisterReplicaError;
case NAME_EXISTS:
return "NAME_EXISTS";
case ENDPOINT_EXISTS:
return "ENDPOINT_EXISTS";
case CONNECTION_FAILED:
return "CONNECTION_FAILED";
case COULD_NOT_BE_PERSISTED:
return "COULD_NOT_BE_PERSISTED";
}
}
// Per storage
// NOTE Storage will connect to all replicas. Future work might change this
void RestoreReplication(replication::RoleMainData &mainData, DatabaseAccess db_acc) {
spdlog::info("Restoring replication role.");
// Each individual client has already been restored and started. Here we just go through each database and start its
// client
for (auto &instance_client : mainData.registered_replicas_) {
spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name());
const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock(
[&, db_acc](auto &storage_clients) mutable -> utils::BasicResult<query::RegisterReplicaError> {
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
auto *storage = db_acc->storage();
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
instance_client.name_);
}
storage_clients.push_back(std::move(client));
return {};
});
if (ret.HasError()) {
MG_ASSERT(query::RegisterReplicaError::CONNECTION_FAILED != ret.GetError());
LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_,
RegisterReplicaErrorToString(ret.GetError()));
}
spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name());
}
spdlog::info("Replication role restored to MAIN.");
}
} // namespace } // namespace
#ifdef MG_ENTERPRISE
struct Durability { struct Durability {
enum class DurabilityVersion : uint8_t { enum class DurabilityVersion : uint8_t {
V0 = 0, V0 = 0,
@ -112,11 +156,9 @@ struct Durability {
} }
}; };
DbmsHandler::DbmsHandler( DbmsHandler::DbmsHandler(storage::Config config, memgraph::system::System &system,
storage::Config config, replication::ReplicationState &repl_state, auth::SynchedAuth &auth, bool recovery_on_startup)
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth, : default_config_{std::move(config)}, auth_{auth}, repl_state_{repl_state}, system_{&system} {
bool recovery_on_startup)
: default_config_{std::move(config)}, repl_state_{ReplicationStateRootPath(default_config_)} {
// TODO: Decouple storage config from dbms config // TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there // TODO: Save individual db configs inside the kvstore and restore from there
@ -150,19 +192,13 @@ DbmsHandler::DbmsHandler(
const auto uuid = json.at("uuid").get<utils::UUID>(); const auto uuid = json.at("uuid").get<utils::UUID>();
const auto rel_dir = json.at("rel_dir").get<std::filesystem::path>(); const auto rel_dir = json.at("rel_dir").get<std::filesystem::path>();
spdlog::info("Restoring database {} at {}.", name, rel_dir); spdlog::info("Restoring database {} at {}.", name, rel_dir);
auto new_db = New_(name, uuid, rel_dir); auto new_db = New_(name, uuid, nullptr, rel_dir);
MG_ASSERT(!new_db.HasError(), "Failed while creating database {}.", name); MG_ASSERT(!new_db.HasError(), "Failed while creating database {}.", name);
directories.emplace(rel_dir.filename()); directories.emplace(rel_dir.filename());
spdlog::info("Database {} restored.", name); spdlog::info("Database {} restored.", name);
} }
// Read the last timestamp
auto lcst = durability_->Get(kLastCommitedSystemTsKey);
if (lcst) {
last_commited_system_timestamp_ = std::stoul(*lcst);
system_timestamp_ = last_commited_system_timestamp_;
}
} else { // Clear databases from the durability list and auth } else { // Clear databases from the durability list and auth
auto locked_auth = auth->Lock(); auto locked_auth = auth_.Lock();
auto it = durability_->begin(std::string{kDBPrefix}); auto it = durability_->begin(std::string{kDBPrefix});
auto end = durability_->end(std::string{kDBPrefix}); auto end = durability_->end(std::string{kDBPrefix});
for (; it != end; ++it) { for (; it != end; ++it) {
@ -172,8 +208,6 @@ DbmsHandler::DbmsHandler(
locked_auth->DeleteDatabase(name); locked_auth->DeleteDatabase(name);
durability_->Delete(key); durability_->Delete(key);
} }
// Delete the last timestamp
durability_->Delete(kLastCommitedSystemTsKey);
} }
/* /*
@ -198,45 +232,29 @@ DbmsHandler::DbmsHandler(
*/ */
// Setup the default DB // Setup the default DB
SetupDefault_(); SetupDefault_();
}
/* struct DropDatabase : memgraph::system::ISystemAction {
* REPLICATION RECOVERY AND STARTUP explicit DropDatabase(utils::UUID uuid) : uuid_{uuid} {}
*/ void DoDurability() override { /* Done during DBMS execution */
// Startup replication state (if recovered at startup)
auto replica = [this](replication::RoleReplicaData const &data) { return StartRpcServer(*this, data); };
// Replication recovery and frequent check start
auto main = [this](replication::RoleMainData &data) {
for (auto &client : data.registered_replicas_) {
SystemRestore(client);
} }
ForEach([this](DatabaseAccess db) { RecoverReplication(db); });
for (auto &client : data.registered_replicas_) { bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
StartReplicaClient(*this, client); memgraph::system::Transaction const &txn) const override {
} auto check_response = [](const storage::replication::DropDatabaseRes &response) {
return true; return response.result != storage::replication::DropDatabaseRes::Result::FAILURE;
}; };
// Startup proccess for main/replica
MG_ASSERT(std::visit(memgraph::utils::Overloaded{replica, main}, repl_state_.ReplicationData()),
"Replica recovery failure!");
// Warning return client.SteamAndFinalizeDelta<storage::replication::DropDatabaseRpc>(
if (default_config_.durability.snapshot_wal_mode == storage::Config::Durability::SnapshotWalMode::DISABLED && check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), uuid_);
repl_state_.IsMain()) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
} }
void PostReplication(replication::RoleMainData &mainData) const override {}
// MAIN or REPLICA instance private:
if (FLAGS_coordinator_server_port) { utils::UUID uuid_;
CoordinatorHandlers::Register(*this); };
MG_ASSERT(coordinator_state_.GetCoordinatorServer().Start(), "Failed to start coordinator server!");
}
}
DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) { DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name, system::Transaction *transaction) {
std::lock_guard<LockT> wr(lock_); std::lock_guard<LockT> wr(lock_);
if (db_name == kDefaultDB) { if (db_name == kDefaultDB) {
// MSG cannot delete the default db // MSG cannot delete the default db
@ -273,9 +291,10 @@ DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) {
// Success // Success
// Save delta // Save delta
if (system_transaction_) { if (transaction) {
system_transaction_->delta.emplace(SystemTransaction::Delta::drop_database, uuid); transaction->AddAction<DropDatabase>(uuid);
} }
return {}; return {};
} }
@ -296,18 +315,48 @@ DbmsHandler::DeleteResult DbmsHandler::Delete(utils::UUID uuid) {
return Delete_(db_name); return Delete_(db_name);
} }
DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config) { struct CreateDatabase : memgraph::system::ISystemAction {
explicit CreateDatabase(storage::SalientConfig config, DatabaseAccess db_acc)
: config_{std::move(config)}, db_acc(db_acc) {}
void DoDurability() override {
// Done during dbms execution
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const storage::replication::CreateDatabaseRes &response) {
return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE;
};
return client.SteamAndFinalizeDelta<storage::replication::CreateDatabaseRpc>(
check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), config_);
}
void PostReplication(replication::RoleMainData &mainData) const override {
// Sync database with REPLICAs
// NOTE: The function bellow is used to create ReplicationStorageClient, so it must be called on a new storage
// We don't need to have it here, since the function won't fail even if the replication client fails to
// connect We will just have everything ready, for recovery at some point.
dbms::DbmsHandler::RecoverStorageReplication(db_acc, mainData);
}
private:
storage::SalientConfig config_;
DatabaseAccess db_acc;
};
DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config, system::Transaction *txn) {
auto new_db = db_handler_.New(storage_config, repl_state_); auto new_db = db_handler_.New(storage_config, repl_state_);
if (new_db.HasValue()) { // Success if (new_db.HasValue()) { // Success
// Save delta // Save delta
if (system_transaction_) {
system_transaction_->delta.emplace(SystemTransaction::Delta::create_database, storage_config.salient);
}
UpdateDurability(storage_config); UpdateDurability(storage_config);
return new_db.GetValue(); if (txn) {
txn->AddAction<CreateDatabase>(storage_config.salient, new_db.GetValue());
} }
return new_db.GetError(); }
return new_db;
} }
DbmsHandler::DeleteResult DbmsHandler::Delete_(std::string_view db_name) { DbmsHandler::DeleteResult DbmsHandler::Delete_(std::string_view db_name) {
@ -361,89 +410,16 @@ void DbmsHandler::UpdateDurability(const storage::Config &config, std::optional<
durability_->Put(key, val); durability_->Put(key, val);
} }
AllSyncReplicaStatus DbmsHandler::Commit() {
if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt)
return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit
const auto &delta = *system_transaction_->delta;
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
// TODO Create a system client that can handle all of this automatically
switch (delta.action) {
using enum SystemTransaction::Delta::Action;
case CREATE_DATABASE: {
// Replication
auto main_handler = [&](memgraph::replication::RoleMainData &main_data) {
// TODO: data race issue? registered_replicas_ access not protected
// This is sync in any case, as this is the startup
for (auto &client : main_data.registered_replicas_) {
bool completed = SteamAndFinalizeDelta<storage::replication::CreateDatabaseRpc>(
client,
[](const storage::replication::CreateDatabaseRes &response) {
return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE;
},
std::string(main_data.epoch_.id()), last_commited_system_timestamp_,
system_transaction_->system_timestamp, delta.config);
// TODO: reduce duplicate code
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
// Sync database with REPLICAs
RecoverReplication(Get_(delta.config.name));
};
auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ };
std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
} break;
case DROP_DATABASE: {
// Replication
auto main_handler = [&](memgraph::replication::RoleMainData &main_data) {
// TODO: data race issue? registered_replicas_ access not protected
// This is sync in any case, as this is the startup
for (auto &client : main_data.registered_replicas_) {
bool completed = SteamAndFinalizeDelta<storage::replication::DropDatabaseRpc>(
client,
[](const storage::replication::DropDatabaseRes &response) {
return response.result != storage::replication::DropDatabaseRes::Result::FAILURE;
},
std::string(main_data.epoch_.id()), last_commited_system_timestamp_,
system_transaction_->system_timestamp, delta.uuid);
// TODO: reduce duplicate code
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
};
auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ };
std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
} break;
}
durability_->Put(kLastCommitedSystemTsKey, std::to_string(system_transaction_->system_timestamp));
last_commited_system_timestamp_ = system_transaction_->system_timestamp;
ResetSystemTransaction();
return sync_status;
}
#else // not MG_ENTERPRISE
AllSyncReplicaStatus DbmsHandler::Commit() {
if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt) {
return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit
}
const auto &delta = *system_transaction_->delta;
switch (delta.action) {
using enum SystemTransaction::Delta::Action;
case CREATE_DATABASE:
case DROP_DATABASE:
/* Community edition doesn't support multi-tenant replication */
break;
}
last_commited_system_timestamp_ = system_transaction_->system_timestamp;
ResetSystemTransaction();
return AllSyncReplicaStatus::AllCommitsConfirmed;
}
#endif #endif
void DbmsHandler::RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data) {
if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) {
// Handle global replication state
spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash.");
// RECOVER REPLICA CONNECTIONS
memgraph::dbms::RestoreReplication(role_main_data, db_acc);
} else if (!role_main_data.registered_replicas_.empty()) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
}
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -25,24 +25,24 @@
#include "constants.hpp" #include "constants.hpp"
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "dbms/inmemory/replication_handlers.hpp" #include "dbms/inmemory/replication_handlers.hpp"
#include "dbms/replication_handler.hpp" #include "dbms/rpc.hpp"
#include "kvstore/kvstore.hpp" #include "kvstore/kvstore.hpp"
#include "license/license.hpp"
#include "replication/replication_client.hpp" #include "replication/replication_client.hpp"
#include "storage/v2/config.hpp" #include "storage/v2/config.hpp"
#include "storage/v2/replication/enums.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "storage/v2/transaction.hpp" #include "storage/v2/transaction.hpp"
#include "system/system.hpp"
#include "utils/thread_pool.hpp" #include "utils/thread_pool.hpp"
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include "coordination/coordinator_state.hpp" #include "coordination/coordinator_state.hpp"
#include "dbms/database_handler.hpp" #include "dbms/database_handler.hpp"
#endif #endif
#include "dbms/transaction.hpp"
#include "global.hpp" #include "global.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/interpreter_context.hpp" #include "query/interpreter_context.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
#include "storage/v2/isolation_level.hpp" #include "storage/v2/isolation_level.hpp"
#include "system/system.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/result.hpp" #include "utils/result.hpp"
#include "utils/rw_lock.hpp" #include "utils/rw_lock.hpp"
@ -51,11 +51,6 @@
namespace memgraph::dbms { namespace memgraph::dbms {
enum class AllSyncReplicaStatus {
AllCommitsConfirmed,
SomeCommitsUnconfirmed,
};
struct Statistics { struct Statistics {
uint64_t num_vertex; //!< Sum of vertexes in every database uint64_t num_vertex; //!< Sum of vertexes in every database
uint64_t num_edges; //!< Sum of edges in every database uint64_t num_edges; //!< Sum of edges in every database
@ -111,8 +106,8 @@ class DbmsHandler {
* @param auth pointer to the global authenticator * @param auth pointer to the global authenticator
* @param recovery_on_startup restore databases (and its content) and authentication data * @param recovery_on_startup restore databases (and its content) and authentication data
*/ */
DbmsHandler(storage::Config config, DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth, auth::SynchedAuth &auth,
bool recovery_on_startup); // TODO If more arguments are added use a config struct bool recovery_on_startup); // TODO If more arguments are added use a config struct
#else #else
/** /**
@ -120,15 +115,14 @@ class DbmsHandler {
* *
* @param configs storage configuration * @param configs storage configuration
*/ */
DbmsHandler(storage::Config config) DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state)
: repl_state_{ReplicationStateRootPath(config)}, : repl_state_{repl_state},
system_{&system},
db_gatekeeper_{[&] { db_gatekeeper_{[&] {
config.salient.name = kDefaultDB; config.salient.name = kDefaultDB;
return std::move(config); return std::move(config);
}(), }(),
repl_state_} { repl_state_} {}
RecoverReplication(Get());
}
#endif #endif
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -138,10 +132,10 @@ class DbmsHandler {
* @param name name of the database * @param name name of the database
* @return NewResultT context on success, error on failure * @return NewResultT context on success, error on failure
*/ */
NewResultT New(const std::string &name) { NewResultT New(const std::string &name, system::Transaction *txn = nullptr) {
std::lock_guard<LockT> wr(lock_); std::lock_guard<LockT> wr(lock_);
const auto uuid = utils::UUID{}; const auto uuid = utils::UUID{};
return New_(name, uuid); return New_(name, uuid, txn);
} }
/** /**
@ -234,7 +228,7 @@ class DbmsHandler {
* @param db_name database name * @param db_name database name
* @return DeleteResult error on failure * @return DeleteResult error on failure
*/ */
DeleteResult TryDelete(std::string_view db_name); DeleteResult TryDelete(std::string_view db_name, system::Transaction *transaction = nullptr);
/** /**
* @brief Delete or defer deletion of database. * @brief Delete or defer deletion of database.
@ -267,23 +261,12 @@ class DbmsHandler {
#endif #endif
} }
replication::ReplicationState &ReplicationState() { return repl_state_; }
replication::ReplicationState const &ReplicationState() const { return repl_state_; }
bool IsMain() const { return repl_state_.IsMain(); }
bool IsReplica() const { return repl_state_.IsReplica(); }
#ifdef MG_ENTERPRISE
coordination::CoordinatorState &CoordinatorState() { return coordinator_state_; }
#endif
/** /**
* @brief Return the statistics all databases. * @brief Return the statistics all databases.
* *
* @return Statistics * @return Statistics
*/ */
Statistics Stats() { Statistics Stats(memgraph::replication_coordination_glue::ReplicationRole replication_role) {
auto const replication_role = repl_state_.GetRole();
Statistics stats{}; Statistics stats{};
// TODO: Handle overflow? // TODO: Handle overflow?
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -319,8 +302,7 @@ class DbmsHandler {
* *
* @return std::vector<DatabaseInfo> * @return std::vector<DatabaseInfo>
*/ */
std::vector<DatabaseInfo> Info() { std::vector<DatabaseInfo> Info(memgraph::replication_coordination_glue::ReplicationRole replication_role) {
auto const replication_role = repl_state_.GetRole();
std::vector<DatabaseInfo> res; std::vector<DatabaseInfo> res;
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
std::shared_lock<LockT> rd(lock_); std::shared_lock<LockT> rd(lock_);
@ -407,98 +389,17 @@ class DbmsHandler {
} }
} }
void NewSystemTransaction() { static void RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data);
DMG_ASSERT(!system_transaction_, "Already running a system transaction");
system_transaction_.emplace(++system_timestamp_);
}
void ResetSystemTransaction() { system_transaction_.reset(); }
//! \tparam RPC An rpc::RequestResponse
//! \tparam Args the args type
//! \param client the client to use for rpc communication
//! \param check predicate to check response is ok
//! \param args arguments to forward to the rpc request
//! \return If replica stream is completed or enqueued
template <typename RPC, typename... Args>
bool SteamAndFinalizeDelta(auto &client, auto &&check, Args &&...args) {
try {
auto stream = client.rpc_client_.template Stream<RPC>(std::forward<Args>(args)...);
auto task = [&client, check = std::forward<decltype(check)>(check), stream = std::move(stream)]() mutable {
if (stream.IsDefunct()) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
try {
if (check(stream.AwaitResponse())) {
return true;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// swallow error, fallthrough to error handling
}
// This replica needs SYSTEM recovery
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
};
if (client.mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) {
client.thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); });
return true;
}
return task();
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// This replica needs SYSTEM recovery
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
};
AllSyncReplicaStatus Commit();
auto LastCommitedTS() const -> uint64_t { return last_commited_system_timestamp_; }
void SetLastCommitedTS(uint64_t new_ts) { last_commited_system_timestamp_.store(new_ts); }
auto default_config() const -> storage::Config const & {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
// When being called by intepreter no need to gain lock, it should already be under a system transaction return default_config_;
// But concurrently the FrequentCheck is running and will need to lock before reading last_commited_system_timestamp_ #else
template <bool REQUIRE_LOCK = false> const auto acc = db_gatekeeper_.access();
void SystemRestore(replication::ReplicationClient &client) { MG_ASSERT(acc, "Failed to get default database!");
// Check if system is up to date return acc->get()->config();
if (client.state_.WithLock(
[](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; }))
return;
// Try to recover...
{
auto [database_configs, last_commited_system_timestamp] = std::invoke([&] {
auto sys_guard =
std::unique_lock{system_lock_, std::defer_lock}; // ensure no other system transaction in progress
if constexpr (REQUIRE_LOCK) {
sys_guard.lock();
}
auto configs = std::vector<storage::SalientConfig>{};
ForEach([&configs](DatabaseAccess acc) { configs.emplace_back(acc->config().salient); });
return std::pair{configs, last_commited_system_timestamp_.load()};
});
try {
auto stream = client.rpc_client_.Stream<storage::replication::SystemRecoveryRpc>(last_commited_system_timestamp,
std::move(database_configs));
const auto response = stream.AwaitResponse();
if (response.result == storage::replication::SystemRecoveryRes::Result::FAILURE) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
}
// Successfully recovered
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; });
}
#endif #endif
}
private: private:
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -524,7 +425,8 @@ class DbmsHandler {
* @param uuid undelying RocksDB directory * @param uuid undelying RocksDB directory
* @return NewResultT context on success, error on failure * @return NewResultT context on success, error on failure
*/ */
NewResultT New_(std::string_view name, utils::UUID uuid, std::optional<std::filesystem::path> rel_dir = {}) { NewResultT New_(std::string_view name, utils::UUID uuid, system::Transaction *txn = nullptr,
std::optional<std::filesystem::path> rel_dir = {}) {
auto config_copy = default_config_; auto config_copy = default_config_;
config_copy.salient.name = name; config_copy.salient.name = name;
config_copy.salient.uuid = uuid; config_copy.salient.uuid = uuid;
@ -535,7 +437,7 @@ class DbmsHandler {
storage::UpdatePaths(config_copy, storage::UpdatePaths(config_copy,
default_config_.durability.storage_directory / kMultiTenantDir / std::string{uuid}); default_config_.durability.storage_directory / kMultiTenantDir / std::string{uuid});
} }
return New_(std::move(config_copy)); return New_(std::move(config_copy), txn);
} }
/** /**
@ -544,11 +446,11 @@ class DbmsHandler {
* @param config configuration to be used * @param config configuration to be used
* @return NewResultT context on success, error on failure * @return NewResultT context on success, error on failure
*/ */
NewResultT New_(const storage::SalientConfig &config) { NewResultT New_(const storage::SalientConfig &config, system::Transaction *txn = nullptr) {
auto config_copy = default_config_; auto config_copy = default_config_;
config_copy.salient = config; // name, uuid, mode, etc config_copy.salient = config; // name, uuid, mode, etc
UpdatePaths(config_copy, config_copy.durability.storage_directory / kMultiTenantDir / std::string{config.uuid}); UpdatePaths(config_copy, config_copy.durability.storage_directory / kMultiTenantDir / std::string{config.uuid});
return New_(std::move(config_copy)); return New_(std::move(config_copy), txn);
} }
/** /**
@ -557,7 +459,7 @@ class DbmsHandler {
* @param storage_config storage configuration * @param storage_config storage configuration
* @return NewResultT context on success, error on failure * @return NewResultT context on success, error on failure
*/ */
NewResultT New_(storage::Config storage_config); DbmsHandler::NewResultT New_(storage::Config storage_config, system::Transaction *txn = nullptr);
// TODO: new overload of Delete_ with DatabaseAccess // TODO: new overload of Delete_ with DatabaseAccess
DeleteResult Delete_(std::string_view db_name); DeleteResult Delete_(std::string_view db_name);
@ -572,7 +474,8 @@ class DbmsHandler {
Get(kDefaultDB); Get(kDefaultDB);
} catch (const UnknownDatabaseException &) { } catch (const UnknownDatabaseException &) {
// No default DB restored, create it // No default DB restored, create it
MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, ".").HasValue(), "Failed while creating the default database"); MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, nullptr, ".").HasValue(),
"Failed while creating the default database");
} }
// For back-compatibility... // For back-compatibility...
@ -659,35 +562,24 @@ class DbmsHandler {
} }
#endif #endif
void RecoverReplication(DatabaseAccess db_acc) {
if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) {
// Handle global replication state
spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash.");
// RECOVER REPLICA CONNECTIONS
memgraph::dbms::RestoreReplication(repl_state_, std::move(db_acc));
} else if (const ::memgraph::replication::RoleMainData *data =
std::get_if<::memgraph::replication::RoleMainData>(&repl_state_.ReplicationData());
data && !data->registered_replicas_.empty()) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
}
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
mutable LockT lock_{utils::RWLock::Priority::READ}; //!< protective lock mutable LockT lock_{utils::RWLock::Priority::READ}; //!< protective lock
storage::Config default_config_; //!< Storage configuration used when creating new databases storage::Config default_config_; //!< Storage configuration used when creating new databases
DatabaseHandler db_handler_; //!< multi-tenancy storage handler DatabaseHandler db_handler_; //!< multi-tenancy storage handler
// TODO: move to be common
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation) std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
coordination::CoordinatorState coordinator_state_; //!< Replication coordinator auth::SynchedAuth &auth_; //!< Synchronized auth::Auth
#endif #endif
// TODO: Make an api
public:
utils::ResourceLock system_lock_{}; //!> Ensure exclusive access for system queries
private: private:
std::optional<SystemTransaction> system_transaction_; //!< Current system transaction (only one at a time) // NOTE: atm the only reason this exists here, is because we pass it into the construction of New Database's
uint64_t system_timestamp_{storage::kTimestampInitialId}; //!< System timestamp // Database only uses it as a convience to make the correct Access without out needing to be told the
std::atomic_uint64_t last_commited_system_timestamp_{ // current replication role. TODO: make Database Access explicit about the role and remove this from
storage::kTimestampInitialId}; //!< Last commited system timestamp // dbms stuff
replication::ReplicationState repl_state_; //!< Global replication state replication::ReplicationState &repl_state_; //!< Ref to global replication state
public:
// TODO fix to be non public/remove from dbms....maybe
system::System *system_;
#ifndef MG_ENTERPRISE #ifndef MG_ENTERPRISE
mutable utils::Gatekeeper<Database> db_gatekeeper_; //!< Single databases gatekeeper mutable utils::Gatekeeper<Database> db_gatekeeper_; //!< Single databases gatekeeper
#endif #endif

View File

@ -11,10 +11,6 @@
#pragma once #pragma once
#include <variant>
#include "dbms/constants.hpp"
#include "dbms/replication_handler.hpp"
#include "replication/state.hpp" #include "replication/state.hpp"
#include "storage/v2/config.hpp" #include "storage/v2/config.hpp"
#include "storage/v2/inmemory/storage.hpp" #include "storage/v2/inmemory/storage.hpp"

View File

@ -1,43 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_client.hpp"
#include "replication/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client) {
// No client error, start instance level client
auto const &endpoint = client.rpc_client_.Endpoint();
spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port);
client.StartFrequentCheck([&dbms_handler](bool reconnect, replication::ReplicationClient &client) {
// Working connection
// Check if system needs restoration
if (reconnect) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
#ifdef MG_ENTERPRISE
dbms_handler.SystemRestore<true>(client);
#endif
// Check if any database has been left behind
dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) {
// Specific database <-> replica client
db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) {
if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
// Database <-> replica might be behind, check and recover
client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc);
}
});
});
});
} // namespace memgraph::dbms
} // namespace memgraph::dbms

View File

@ -1,400 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_handler.hpp"
#include <algorithm>
#include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/global.hpp"
#include "dbms/inmemory/replication_handlers.hpp"
#include "dbms/replication_client.hpp"
#include "dbms/utils.hpp"
#include "replication/messages.hpp"
#include "replication/state.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/config.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/on_scope_exit.hpp"
using memgraph::replication::RoleMainData;
using memgraph::replication::RoleReplicaData;
namespace memgraph::dbms {
namespace {
std::string RegisterReplicaErrorToString(RegisterReplicaError error) {
switch (error) {
using enum RegisterReplicaError;
case NAME_EXISTS:
return "NAME_EXISTS";
case ENDPOINT_EXISTS:
return "ENDPOINT_EXISTS";
case CONNECTION_FAILED:
return "CONNECTION_FAILED";
case COULD_NOT_BE_PERSISTED:
return "COULD_NOT_BE_PERSISTED";
}
}
} // namespace
ReplicationHandler::ReplicationHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {}
bool ReplicationHandler::SetReplicationRoleMain() {
auto const main_handler = [](RoleMainData &) {
// If we are already MAIN, we don't want to change anything
return false;
};
auto const replica_handler = [this](RoleReplicaData const &) {
return memgraph::dbms::DoReplicaToMainPromotion(dbms_handler_);
};
// TODO: under lock
return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
}
bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) {
// We don't want to restart the server if we're already a REPLICA
if (dbms_handler_.ReplicationState().IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler_.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<RoleMainData>(dbms_handler_.ReplicationState().ReplicationData()).registered_replicas_.clear();
// Creates the server
dbms_handler_.ReplicationState().SetReplicationRoleReplica(config);
// Start
const auto success =
std::visit(utils::Overloaded{[](RoleMainData const &) {
// ASSERT
return false;
},
[this](RoleReplicaData const &data) { return StartRpcServer(dbms_handler_, data); }},
dbms_handler_.ReplicationState().ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<RegisterReplicaError> {
MG_ASSERT(dbms_handler_.ReplicationState().IsMain(), "Only main instance can register a replica!");
auto maybe_client = dbms_handler_.ReplicationState().RegisterReplica(config);
if (maybe_client.HasError()) {
switch (maybe_client.GetError()) {
case memgraph::replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case memgraph::replication::RegisterReplicaError::NAME_EXISTS:
return memgraph::dbms::RegisterReplicaError::NAME_EXISTS;
case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS:
return memgraph::dbms::RegisterReplicaError::ENDPOINT_EXISTS;
case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return memgraph::dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case memgraph::replication::RegisterReplicaError::SUCCESS:
break;
}
}
if (!allow_mt_repl && dbms_handler_.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
#ifdef MG_ENTERPRISE
// Update system before enabling individual storage <-> replica clients
dbms_handler_.SystemRestore(*maybe_client.GetValue());
#endif
const auto dbms_error = memgraph::dbms::HandleRegisterReplicaStatus(maybe_client);
if (dbms_error.has_value()) {
return *dbms_error;
}
auto &instance_client_ptr = maybe_client.GetValue();
const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients(dbms_handler_, *instance_client_ptr);
// NOTE Currently if any databases fails, we revert back
if (!all_clients_good) {
spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name);
UnregisterReplica(config.name);
return RegisterReplicaError::CONNECTION_FAILED;
}
// No client error, start instance level client
StartReplicaClient(dbms_handler_, *instance_client_ptr);
return {};
}
auto ReplicationHandler::UnregisterReplica(std::string_view name) -> UnregisterReplicaResult {
auto const replica_handler = [](RoleReplicaData const &) -> UnregisterReplicaResult {
return UnregisterReplicaResult::NOT_MAIN;
};
auto const main_handler = [this, name](RoleMainData &mainData) -> UnregisterReplicaResult {
if (!dbms_handler_.ReplicationState().TryPersistUnregisterReplica(name)) {
return UnregisterReplicaResult::COULD_NOT_BE_PERSISTED;
}
// Remove database specific clients
dbms_handler_.ForEach([name](DatabaseAccess db_acc) {
db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) {
std::erase_if(clients, [name](const auto &client) { return client->Name() == name; });
});
});
// Remove instance level clients
auto const n_unregistered =
std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; });
return n_unregistered != 0 ? UnregisterReplicaResult::SUCCESS : UnregisterReplicaResult::CAN_NOT_UNREGISTER;
};
return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
}
auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole {
return dbms_handler_.ReplicationState().GetRole();
}
bool ReplicationHandler::IsMain() const { return dbms_handler_.ReplicationState().IsMain(); }
bool ReplicationHandler::IsReplica() const { return dbms_handler_.ReplicationState().IsReplica(); }
// Per storage
// NOTE Storage will connect to all replicas. Future work might change this
void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc) {
spdlog::info("Restoring replication role.");
/// MAIN
auto const recover_main = [db_acc = std::move(db_acc)](RoleMainData &mainData) mutable { // NOLINT
// Each individual client has already been restored and started. Here we just go through each database and start its
// client
for (auto &instance_client : mainData.registered_replicas_) {
spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name());
const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock(
[&, db_acc](auto &storage_clients) mutable -> utils::BasicResult<RegisterReplicaError> {
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
auto *storage = db_acc->storage();
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
instance_client.name_);
}
storage_clients.push_back(std::move(client));
return {};
});
if (ret.HasError()) {
MG_ASSERT(RegisterReplicaError::CONNECTION_FAILED != ret.GetError());
LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_,
RegisterReplicaErrorToString(ret.GetError()));
}
spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name());
}
spdlog::info("Replication role restored to MAIN.");
};
/// REPLICA
auto const recover_replica = [](RoleReplicaData const &data) { /*nothing to do*/ };
std::visit(
utils::Overloaded{
recover_main,
recover_replica,
},
repl_state.ReplicationData());
}
namespace system_replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) {
replication::SystemHeartbeatReq req;
replication::SystemHeartbeatReq::Load(&req, req_reader);
replication::SystemHeartbeatRes res(ts);
memgraph::slk::Save(res, res_builder);
}
void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
memgraph::storage::replication::CreateDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::CreateDatabaseRes;
CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) {
spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
dbms_handler.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Create new
auto new_db = dbms_handler.Update(req.config);
if (new_db.HasValue()) {
// Successfully create db
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS);
spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
memgraph::storage::replication::DropDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::DropDatabaseRes;
DropDatabaseRes res(DropDatabaseRes::Result::FAILURE);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) {
spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
dbms_handler.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment.
auto new_db = dbms_handler.Delete(req.uuid);
if (new_db.HasError()) {
if (new_db.GetError() == DeleteError::NON_EXISTENT) {
// Nothing to drop
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED);
}
} else {
// Successfully drop db
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS);
spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
// TODO Speed up
memgraph::storage::replication::SystemRecoveryReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::SystemRecoveryRes;
SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE);
utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); });
// Get all current dbs
auto old = dbms_handler.All();
// Check/create the incoming dbs
for (const auto &config : req.database_configs) {
// Missing db
try {
if (dbms_handler.Update(config).HasError()) {
spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name);
return; // Send failure on exit
}
} catch (const UnknownDatabaseException &) {
spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException");
return; // Send failure on exit
}
const auto it = std::find(old.begin(), old.end(), config.name);
if (it != old.end()) old.erase(it);
}
// Delete all the leftover old dbs
for (const auto &remove_db : old) {
const auto del = dbms_handler.Delete(remove_db);
if (del.HasError()) {
// Some errors are not terminal
if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) {
spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db);
continue;
}
spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db);
return; // Send failure on exit
}
}
// Successfully recovered
dbms_handler.SetLastCommitedTS(req.forced_group_timestamp);
spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp);
res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS);
}
#endif
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler) {
#ifdef MG_ENTERPRISE
data.server->rpc_server_.Register<replication::SystemHeartbeatRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemHeartbeatRpc");
SystemHeartbeatHandler(dbms_handler.LastCommitedTS(), req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::CreateDatabaseRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received CreateDatabaseRpc");
CreateDatabaseHandler(dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::DropDatabaseRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received DropDatabaseRpc");
DropDatabaseHandler(dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::SystemRecoveryRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemRecoveryRpc");
SystemRecoveryHandler(dbms_handler, req_reader, res_builder);
});
#endif
}
} // namespace system_replication
bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) {
// Register handlers
InMemoryReplicationHandlers::Register(&dbms_handler, *data.server);
system_replication::Register(data, dbms_handler);
// Start server
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
}
} // namespace memgraph::dbms

View File

@ -1,82 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication_coordination_glue/role.hpp"
#include "dbms/database.hpp"
#include "utils/result.hpp"
namespace memgraph::replication {
struct ReplicationState;
struct ReplicationServerConfig;
struct ReplicationClientConfig;
} // namespace memgraph::replication
namespace memgraph::dbms {
class DbmsHandler;
enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED };
enum class UnregisterReplicaResult : uint8_t {
NOT_MAIN,
COULD_NOT_BE_PERSISTED,
CAN_NOT_UNREGISTER,
SUCCESS,
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages
struct ReplicationHandler {
explicit ReplicationHandler(DbmsHandler &dbms_handler);
// as REPLICA, become MAIN
bool SetReplicationRoleMain();
// as MAIN, become REPLICA
bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config);
// as MAIN, define and connect to REPLICAs
auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError>;
// as MAIN, remove a REPLICA connection
auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult;
// Helper pass-through (TODO: remove)
auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole;
bool IsMain() const;
bool IsReplica() const;
private:
DbmsHandler &dbms_handler_;
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages
void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc);
namespace system_replication {
// System handlers
#ifdef MG_ENTERPRISE
void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
#endif
/// Register all DBMS level RPC handlers
void Register(replication::RoleReplicaData const &data, DbmsHandler &dbms_handler);
} // namespace system_replication
bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data);
} // namespace memgraph::dbms

View File

@ -0,0 +1,191 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_handlers.hpp"
#include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "storage/v2/storage.hpp"
#include "system/state.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
using memgraph::storage::replication::CreateDatabaseRes;
CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE);
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling CreateDatabase, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
memgraph::storage::replication::CreateDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Create new
auto new_db = dbms_handler.Update(req.config);
if (new_db.HasValue()) {
// Successfully create db
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS);
spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
slk::Reader *req_reader, slk::Builder *res_builder) {
using memgraph::storage::replication::DropDatabaseRes;
DropDatabaseRes res(DropDatabaseRes::Result::FAILURE);
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling DropDatabase, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
memgraph::storage::replication::DropDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment.
auto new_db = dbms_handler.Delete(req.uuid);
if (new_db.HasError()) {
if (new_db.GetError() == DeleteError::NON_EXISTENT) {
// Nothing to drop
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED);
}
} else {
// Successfully drop db
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS);
spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector<storage::SalientConfig> &database_configs) {
/*
* NO LICENSE
*/
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling SystemRecovery, an enterprise RPC message, without license.");
for (const auto &config : database_configs) {
// Only handle default DB
if (config.name != kDefaultDB) continue;
try {
if (dbms_handler.Update(config).HasError()) {
return false;
}
} catch (const UnknownDatabaseException &) {
return false;
}
}
return true;
}
/*
* MULTI-TENANCY
*/
// Get all current dbs
auto old = dbms_handler.All();
// Check/create the incoming dbs
for (const auto &config : database_configs) {
// Missing db
try {
if (dbms_handler.Update(config).HasError()) {
spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name);
return false;
}
} catch (const UnknownDatabaseException &) {
spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException");
return false;
}
const auto it = std::find(old.begin(), old.end(), config.name);
if (it != old.end()) old.erase(it);
}
// Delete all the leftover old dbs
for (const auto &remove_db : old) {
const auto del = dbms_handler.Delete(remove_db);
if (del.HasError()) {
// Some errors are not terminal
if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) {
spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db);
continue;
}
spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db);
return false;
}
}
/*
* SUCCESS
*/
return true;
}
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler) {
// NOTE: Register even without license as the user could add a license at run-time
data.server->rpc_server_.Register<storage::replication::CreateDatabaseRpc>(
[system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received CreateDatabaseRpc");
CreateDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::DropDatabaseRpc>(
[system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received DropDatabaseRpc");
DropDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
});
}
#endif
} // namespace memgraph::dbms

View File

@ -0,0 +1,32 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "replication/state.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
// RPC handlers
void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
slk::Reader *req_reader, slk::Builder *res_builder);
bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector<storage::SalientConfig> &database_configs);
// RPC registration
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler);
#endif
} // namespace memgraph::dbms

118
src/dbms/rpc.cpp Normal file
View File

@ -0,0 +1,118 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/rpc.hpp"
#include "slk/streams.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/enum.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph {
namespace storage::replication {
void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
const utils::TypeInfo CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ, "CreateDatabaseReq", nullptr};
const utils::TypeInfo CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES, "CreateDatabaseRes", nullptr};
const utils::TypeInfo DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ, "DropDatabaseReq", nullptr};
const utils::TypeInfo DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES, "DropDatabaseRes", nullptr};
} // namespace storage::replication
// Autogenerated SLK serialization code
namespace slk {
// Serialize code for CreateDatabaseReq
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.config, builder);
}
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->config, reader);
}
// Serialize code for CreateDatabaseRes
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for DropDatabaseReq
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.uuid, builder);
}
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->uuid, reader);
}
// Serialize code for DropDatabaseRes
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace slk
} // namespace memgraph

118
src/dbms/rpc.hpp Normal file
View File

@ -0,0 +1,118 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <string>
#include <utility>
#include "rpc/messages.hpp"
#include "slk/streams.hpp"
#include "storage/v2/config.hpp"
#include "utils/uuid.hpp"
namespace memgraph::storage::replication {
struct CreateDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder);
CreateDatabaseReq() = default;
CreateDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
storage::SalientConfig config)
: epoch_id(std::string(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
config(std::move(config)) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
storage::SalientConfig config;
};
struct CreateDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder);
CreateDatabaseRes() = default;
explicit CreateDatabaseRes(Result res) : result(res) {}
Result result;
};
using CreateDatabaseRpc = rpc::RequestResponse<CreateDatabaseReq, CreateDatabaseRes>;
struct DropDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder);
DropDatabaseReq() = default;
DropDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
const utils::UUID &uuid)
: epoch_id(std::string(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
uuid(uuid) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
utils::UUID uuid;
};
struct DropDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder);
DropDatabaseRes() = default;
explicit DropDatabaseRes(Result res) : result(res) {}
Result result;
};
using DropDatabaseRpc = rpc::RequestResponse<DropDatabaseReq, DropDatabaseRes>;
} // namespace memgraph::storage::replication
// SLK serialization declarations
namespace memgraph::slk {
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

View File

@ -11,7 +11,10 @@
#pragma once #pragma once
#include <list>
#include <memory> #include <memory>
#include <optional>
#include "auth/models.hpp"
#include "storage/v2/config.hpp" #include "storage/v2/config.hpp"
namespace memgraph::dbms { namespace memgraph::dbms {
@ -20,17 +23,70 @@ struct SystemTransaction {
enum class Action { enum class Action {
CREATE_DATABASE, CREATE_DATABASE,
DROP_DATABASE, DROP_DATABASE,
UPDATE_AUTH_DATA,
DROP_AUTH_DATA,
/**
*
* CREATE USER user_name [IDENTIFIED BY 'password'];
* SET PASSWORD FOR user_name TO 'new_password';
* ^ SaveUser
*
* DROP USER user_name;
* ^ Directly on KVStore
*
* CREATE ROLE role_name;
* ^ SaveRole
*
* DROP ROLE
* ^ RemoveRole
*
* SET ROLE FOR user_name TO role_name;
* CLEAR ROLE FOR user_name;
* ^ Do stuff then do SaveUser
*
* GRANT privilege_list TO user_or_role;
* DENY AUTH, INDEX TO moderator:
* REVOKE AUTH, INDEX TO moderator:
* GRANT permission_level ON (LABELS | EDGE_TYPES) label_list TO user_or_role;
* REVOKE (LABELS | EDGE_TYPES) label_or_edge_type_list FROM user_or_role
* DENY (LABELS | EDGE_TYPES) label_or_edge_type_list TO user_or_role
* ^ all of these are EditPermissions <-> SaveUser/Role
*
* Multi-tenant TODO Doc;
* ^ Should all call SaveUser
*
*/
}; };
static constexpr struct CreateDatabase { static constexpr struct CreateDatabase {
} create_database; } create_database;
static constexpr struct DropDatabase { static constexpr struct DropDatabase {
} drop_database; } drop_database;
static constexpr struct UpdateAuthData {
} update_auth_data;
static constexpr struct DropAuthData {
} drop_auth_data;
enum class AuthData { USER, ROLE };
// Multi-tenancy
Delta(CreateDatabase /*tag*/, storage::SalientConfig config) Delta(CreateDatabase /*tag*/, storage::SalientConfig config)
: action(Action::CREATE_DATABASE), config(std::move(config)) {} : action(Action::CREATE_DATABASE), config(std::move(config)) {}
Delta(DropDatabase /*tag*/, const utils::UUID &uuid) : action(Action::DROP_DATABASE), uuid(uuid) {} Delta(DropDatabase /*tag*/, const utils::UUID &uuid) : action(Action::DROP_DATABASE), uuid(uuid) {}
// Auth
Delta(UpdateAuthData /*tag*/, std::optional<auth::User> user)
: action(Action::UPDATE_AUTH_DATA), auth_data{std::move(user), std::nullopt} {}
Delta(UpdateAuthData /*tag*/, std::optional<auth::Role> role)
: action(Action::UPDATE_AUTH_DATA), auth_data{std::nullopt, std::move(role)} {}
Delta(DropAuthData /*tag*/, AuthData type, std::string_view name)
: action(Action::DROP_AUTH_DATA),
auth_data_key{
.type = type,
.name = std::string{name},
} {}
// Generic
Delta(const Delta &) = delete; Delta(const Delta &) = delete;
Delta(Delta &&) = delete; Delta(Delta &&) = delete;
Delta &operator=(const Delta &) = delete; Delta &operator=(const Delta &) = delete;
@ -42,8 +98,14 @@ struct SystemTransaction {
std::destroy_at(&config); std::destroy_at(&config);
break; break;
case Action::DROP_DATABASE: case Action::DROP_DATABASE:
std::destroy_at(&uuid);
break;
case Action::UPDATE_AUTH_DATA:
std::destroy_at(&auth_data);
break;
case Action::DROP_AUTH_DATA:
std::destroy_at(&auth_data_key);
break; break;
// Some deltas might have special destructor handling
} }
} }
@ -51,13 +113,20 @@ struct SystemTransaction {
union { union {
storage::SalientConfig config; storage::SalientConfig config;
utils::UUID uuid; utils::UUID uuid;
struct {
std::optional<auth::User> user;
std::optional<auth::Role> role;
} auth_data;
struct {
AuthData type;
std::string name;
} auth_data_key;
}; };
}; };
explicit SystemTransaction(uint64_t timestamp) : system_timestamp(timestamp) {} explicit SystemTransaction(uint64_t timestamp) : system_timestamp(timestamp) {}
// Currently system transitions support a single delta std::list<Delta> deltas{};
std::optional<Delta> delta{};
uint64_t system_timestamp; uint64_t system_timestamp;
}; };

View File

@ -1,131 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "replication/include/replication/state.hpp"
#include "utils/result.hpp"
namespace memgraph::dbms {
inline bool DoReplicaToMainPromotion(dbms::DbmsHandler &dbms_handler) {
auto &repl_state = dbms_handler.ReplicationState();
// STEP 1) bring down all REPLICA servers
dbms_handler.ForEach([](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
// Remember old epoch + storage timestamp association
storage->PrepareForNewEpoch();
});
// STEP 2) Change to MAIN
// TODO: restore replication servers if false?
if (!repl_state.SetReplicationRoleMain()) {
// TODO: Handle recovery on failure???
return false;
}
// STEP 3) We are now MAIN, update storage local epoch
const auto &epoch =
std::get<replication::RoleMainData>(std::as_const(dbms_handler.ReplicationState()).ReplicationData()).epoch_;
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.epoch_ = epoch;
});
return true;
};
inline bool SetReplicationRoleReplica(dbms::DbmsHandler &dbms_handler,
const memgraph::replication::ReplicationServerConfig &config) {
if (dbms_handler.ReplicationState().IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<replication::RoleMainData>(dbms_handler.ReplicationState().ReplicationData()).registered_replicas_.clear();
// Creates the server
dbms_handler.ReplicationState().SetReplicationRoleReplica(config);
// Start
const auto success = std::visit(utils::Overloaded{[](replication::RoleMainData const &) {
// ASSERT
return false;
},
[&dbms_handler](replication::RoleReplicaData const &data) {
return StartRpcServer(dbms_handler, data);
}},
dbms_handler.ReplicationState().ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
template <bool AllowRPCFailure = false>
inline bool RegisterAllDatabasesClients(dbms::DbmsHandler &dbms_handler,
replication::ReplicationClient &instance_client) {
if (!allow_mt_repl && dbms_handler.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
bool all_clients_good = true;
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
if (!allow_mt_repl && storage->name() != kDefaultDB) {
return;
}
// TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes
if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return;
using enum storage::replication::ReplicaState;
all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock(
[storage, &instance_client, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
client->Start(storage, std::move(db_acc));
if (client->State() == MAYBE_BEHIND && !AllowRPCFailure) {
return false;
}
storage_clients.push_back(std::move(client));
return true;
});
});
return all_clients_good;
}
inline std::optional<RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client) {
if (instance_client.HasError()) switch (instance_client.GetError()) {
case replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case replication::RegisterReplicaError::NAME_EXISTS:
return dbms::RegisterReplicaError::NAME_EXISTS;
case replication::RegisterReplicaError::ENDPOINT_EXISTS:
return dbms::RegisterReplicaError::ENDPOINT_EXISTS;
case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case replication::RegisterReplicaError::SUCCESS:
break;
}
return {};
}
} // namespace memgraph::dbms

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -33,7 +33,7 @@ class WritePrioritizedRWLock;
struct Context { struct Context {
memgraph::query::InterpreterContext *ic; memgraph::query::InterpreterContext *ic;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth; memgraph::auth::SynchedAuth *auth;
#if MG_ENTERPRISE #if MG_ENTERPRISE
memgraph::audit::Log *audit_log; memgraph::audit::Log *audit_log;
#endif #endif

View File

@ -319,8 +319,7 @@ void SessionHL::Configure(const std::map<std::string, memgraph::communication::b
SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context, SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
memgraph::communication::v2::ServerEndpoint endpoint, memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream, memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
memgraph::audit::Log *audit_log memgraph::audit::Log *audit_log

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -25,8 +25,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
SessionHL(memgraph::query::InterpreterContext *interpreter_context, SessionHL(memgraph::query::InterpreterContext *interpreter_context,
memgraph::communication::v2::ServerEndpoint endpoint, memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream, memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
memgraph::audit::Log *audit_log memgraph::audit::Log *audit_log
@ -88,7 +87,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
memgraph::audit::Log *audit_log_; memgraph::audit::Log *audit_log_;
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata
#endif #endif
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; memgraph::auth::SynchedAuth *auth_;
memgraph::communication::v2::ServerEndpoint endpoint_; memgraph::communication::v2::ServerEndpoint endpoint_;
std::optional<std::string> implicit_db_; std::optional<std::string> implicit_db_;
}; };

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -66,9 +66,7 @@ bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph::
#endif #endif
namespace memgraph::glue { namespace memgraph::glue {
AuthChecker::AuthChecker( AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {}
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: auth_(auth) {}
bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username, bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -22,8 +22,7 @@ namespace memgraph::glue {
class AuthChecker : public query::AuthChecker { class AuthChecker : public query::AuthChecker {
public: public:
explicit AuthChecker( explicit AuthChecker(memgraph::auth::SynchedAuth *auth);
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth);
bool IsUserAuthorized(const std::optional<std::string> &username, bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges, const std::vector<query::AuthQuery::Privilege> &privileges,
@ -41,7 +40,7 @@ class AuthChecker : public query::AuthChecker {
const std::string &db_name = ""); const std::string &db_name = "");
private: private:
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; memgraph::auth::SynchedAuth *auth_;
mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user
}; };
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE

View File

@ -210,16 +210,25 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedUserPrivile
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return {}; return {};
} }
const auto &label_permissions = user->GetFineGrainedAccessLabelPermissions();
const auto &edge_type_permissions = user->GetFineGrainedAccessEdgeTypePermissions();
auto all_fine_grained_permissions = auto all_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER"); user->GetUserFineGrainedAccessLabelPermissions(), "LABEL", "USER");
auto edge_type_fine_grained_permissions = auto all_role_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER"); user->GetRoleFineGrainedAccessLabelPermissions(), "LABEL", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(all_role_fine_grained_permissions.begin()),
std::make_move_iterator(all_role_fine_grained_permissions.end()));
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(), auto edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
edge_type_fine_grained_permissions.end()); user->GetUserFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "USER");
auto role_edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
user->GetRoleFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(edge_type_fine_grained_permissions.begin()),
std::make_move_iterator(edge_type_fine_grained_permissions.end()));
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(role_edge_type_fine_grained_permissions.begin()),
std::make_move_iterator(role_edge_type_fine_grained_permissions.end()));
return ConstructFineGrainedPrivilegesResult(all_fine_grained_permissions); return ConstructFineGrainedPrivilegesResult(all_fine_grained_permissions);
} }
@ -233,9 +242,9 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedRolePrivile
const auto &edge_type_permissions = role->GetFineGrainedAccessEdgeTypePermissions(); const auto &edge_type_permissions = role->GetFineGrainedAccessEdgeTypePermissions();
auto all_fine_grained_permissions = auto all_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER"); GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "ROLE");
auto edge_type_fine_grained_permissions = auto edge_type_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER"); GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(), all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(),
edge_type_fine_grained_permissions.end()); edge_type_fine_grained_permissions.end());
@ -248,16 +257,15 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedRolePrivile
namespace memgraph::glue { namespace memgraph::glue {
AuthQueryHandler::AuthQueryHandler( AuthQueryHandler::AuthQueryHandler(memgraph::auth::SynchedAuth *auth) : auth_(auth) {}
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: auth_(auth) {}
bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional<std::string> &password) { bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
try { try {
const auto [first_user, user_added] = std::invoke([&, this] { const auto [first_user, user_added] = std::invoke([&, this] {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
const auto first_user = !locked_auth->HasUsers(); const auto first_user = !locked_auth->HasUsers();
const auto user_added = locked_auth->AddUser(username, password).has_value(); const auto user_added = locked_auth->AddUser(username, password, system_tx).has_value();
return std::make_pair(first_user, user_added); return std::make_pair(first_user, user_added);
}); });
@ -276,10 +284,11 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
} }
} }
#endif #endif
); ,
system_tx);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
GrantDatabaseToUser(auth::kAllDatabases, username); GrantDatabaseToUser(auth::kAllDatabases, username, system_tx);
SetMainDatabase(dbms::kDefaultDB, username); SetMainDatabase(dbms::kDefaultDB, username, system_tx);
#endif #endif
} }
@ -289,18 +298,19 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
} }
} }
bool AuthQueryHandler::DropUser(const std::string &username) { bool AuthQueryHandler::DropUser(const std::string &username, system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
if (!user) return false; if (!user) return false;
return locked_auth->RemoveUser(username); return locked_auth->RemoveUser(username, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
void AuthQueryHandler::SetPassword(const std::string &username, const std::optional<std::string> &password) { void AuthQueryHandler::SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
@ -308,39 +318,41 @@ void AuthQueryHandler::SetPassword(const std::string &username, const std::optio
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
} }
locked_auth->UpdatePassword(*user, password); locked_auth->UpdatePassword(*user, password);
locked_auth->SaveUser(*user); locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
bool AuthQueryHandler::CreateRole(const std::string &rolename) { bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
return locked_auth->AddRole(rolename).has_value(); return locked_auth->AddRole(rolename, system_tx).has_value();
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) { bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
if (!user) return false; if (!user) return false;
return locked_auth->RevokeDatabaseFromUser(db, username); return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) { bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
if (!user) return false; if (!user) return false;
return locked_auth->GrantDatabaseToUser(db, username); return locked_auth->GrantDatabaseToUser(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
@ -360,27 +372,28 @@ std::vector<std::vector<memgraph::query::TypedValue>> AuthQueryHandler::GetDatab
} }
} }
bool AuthQueryHandler::SetMainDatabase(std::string_view db, const std::string &username) { bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username,
system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
if (!user) return false; if (!user) return false;
return locked_auth->SetMainDatabase(db, username); return locked_auth->SetMainDatabase(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
void AuthQueryHandler::DeleteDatabase(std::string_view db) { void AuthQueryHandler::DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) {
try { try {
auth_->Lock()->DeleteDatabase(std::string(db)); auth_->Lock()->DeleteDatabase(std::string(db_name), system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
#endif #endif
bool AuthQueryHandler::DropRole(const std::string &rolename) { bool AuthQueryHandler::DropRole(const std::string &rolename, system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto role = locked_auth->GetRole(rolename); auto role = locked_auth->GetRole(rolename);
@ -389,7 +402,7 @@ bool AuthQueryHandler::DropRole(const std::string &rolename) {
return false; return false;
}; };
return locked_auth->RemoveRole(rolename); return locked_auth->RemoveRole(rolename, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
@ -461,7 +474,8 @@ std::vector<memgraph::query::TypedValue> AuthQueryHandler::GetUsernamesForRole(c
} }
} }
void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename) { void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename,
system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
@ -477,13 +491,13 @@ void AuthQueryHandler::SetRole(const std::string &username, const std::string &r
current_role->rolename()); current_role->rolename());
} }
user->SetRole(*role); user->SetRole(*role);
locked_auth->SaveUser(*user); locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
} }
void AuthQueryHandler::ClearRole(const std::string &username) { void AuthQueryHandler::ClearRole(const std::string &username, system::Transaction *system_tx) {
try { try {
auto locked_auth = auth_->Lock(); auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username); auto user = locked_auth->GetUser(username);
@ -491,7 +505,7 @@ void AuthQueryHandler::ClearRole(const std::string &username) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
} }
user->ClearRole(); user->ClearRole();
locked_auth->SaveUser(*user); locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());
} }
@ -545,7 +559,8 @@ void AuthQueryHandler::GrantPrivilege(
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) { ,
system::Transaction *system_tx) {
EditPermissions( EditPermissions(
user_or_role, privileges, user_or_role, privileges,
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -568,11 +583,13 @@ void AuthQueryHandler::GrantPrivilege(
} }
} }
#endif #endif
); ,
system_tx);
} // namespace memgraph::glue } // namespace memgraph::glue
void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role, void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) { const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) {
EditPermissions( EditPermissions(
user_or_role, privileges, user_or_role, privileges,
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -588,7 +605,8 @@ void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role,
, ,
[](auto &fine_grained_permissions, const auto &privilege_collection) {} [](auto &fine_grained_permissions, const auto &privilege_collection) {}
#endif #endif
); ,
system_tx);
} }
void AuthQueryHandler::RevokePrivilege( void AuthQueryHandler::RevokePrivilege(
@ -600,7 +618,8 @@ void AuthQueryHandler::RevokePrivilege(
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) { ,
system::Transaction *system_tx) {
EditPermissions( EditPermissions(
user_or_role, privileges, user_or_role, privileges,
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -622,7 +641,8 @@ void AuthQueryHandler::RevokePrivilege(
} }
} }
#endif #endif
); ,
system_tx);
} // namespace memgraph::glue } // namespace memgraph::glue
template <class TEditPermissionsFun template <class TEditPermissionsFun
@ -646,7 +666,8 @@ void AuthQueryHandler::EditPermissions(
, ,
const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun
#endif #endif
) { ,
system::Transaction *system_tx) {
try { try {
std::vector<memgraph::auth::Permission> permissions; std::vector<memgraph::auth::Permission> permissions;
permissions.reserve(privileges.size()); permissions.reserve(privileges.size());
@ -675,7 +696,7 @@ void AuthQueryHandler::EditPermissions(
} }
} }
#endif #endif
locked_auth->SaveUser(*user); locked_auth->SaveUser(*user, system_tx);
} else { } else {
for (const auto &permission : permissions) { for (const auto &permission : permissions) {
edit_permissions_fun(role->permissions(), permission); edit_permissions_fun(role->permissions(), permission);
@ -691,7 +712,7 @@ void AuthQueryHandler::EditPermissions(
} }
} }
#endif #endif
locked_auth->SaveRole(*role); locked_auth->SaveRole(*role, system_tx);
} }
} catch (const memgraph::auth::AuthException &e) { } catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what()); throw memgraph::query::QueryRuntimeException(e.what());

View File

@ -23,32 +23,36 @@
namespace memgraph::glue { namespace memgraph::glue {
class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; memgraph::auth::SynchedAuth *auth_;
public: public:
AuthQueryHandler(memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth); explicit AuthQueryHandler(memgraph::auth::SynchedAuth *auth);
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override; bool CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) override;
bool DropUser(const std::string &username) override; bool DropUser(const std::string &username, system::Transaction *system_tx) override;
void SetPassword(const std::string &username, const std::optional<std::string> &password) override; void SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) override;
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override; bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
bool GrantDatabaseToUser(const std::string &db, const std::string &username) override; bool GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override; std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
bool SetMainDatabase(std::string_view db, const std::string &username) override; bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override;
void DeleteDatabase(std::string_view db) override; void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override;
#endif #endif
bool CreateRole(const std::string &rolename) override; bool CreateRole(const std::string &rolename, system::Transaction *system_tx) override;
bool DropRole(const std::string &rolename) override; bool DropRole(const std::string &rolename, system::Transaction *system_tx) override;
std::vector<memgraph::query::TypedValue> GetUsernames() override; std::vector<memgraph::query::TypedValue> GetUsernames() override;
@ -58,9 +62,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) override; std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) override;
void SetRole(const std::string &username, const std::string &rolename) override; void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) override;
void ClearRole(const std::string &username) override; void ClearRole(const std::string &username, system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) override; std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) override;
@ -74,10 +78,12 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) override; ,
system::Transaction *system_tx) override;
void DenyPrivilege(const std::string &user_or_role, void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override; const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) override;
void RevokePrivilege( void RevokePrivilege(
const std::string &user_or_role, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges const std::string &user_or_role, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges
@ -88,7 +94,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) override; ,
system::Transaction *system_tx) override;
private: private:
template <class TEditPermissionsFun template <class TEditPermissionsFun
@ -112,6 +119,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
, ,
const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun
#endif #endif
); ,
system::Transaction *system_tx);
}; };
} // namespace memgraph::glue } // namespace memgraph::glue

View File

@ -37,6 +37,7 @@ KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>(
} }
KVStore::~KVStore() { KVStore::~KVStore() {
if (pimpl_ == nullptr) return;
spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string()); spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string());
const auto sync = pimpl_->db->SyncWAL(); const auto sync = pimpl_->db->SyncWAL();
if (!sync.ok()) spdlog::error("KVStore sync failed!"); if (!sync.ok()) spdlog::error("KVStore sync failed!");

View File

@ -11,9 +11,12 @@
#include <cstdint> #include <cstdint>
#include "audit/log.hpp" #include "audit/log.hpp"
#include "auth/auth.hpp"
#include "communication/websocket/auth.hpp" #include "communication/websocket/auth.hpp"
#include "communication/websocket/server.hpp" #include "communication/websocket/server.hpp"
#include "coordination/coordinator_handlers.hpp"
#include "dbms/constants.hpp" #include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/inmemory/replication_handlers.hpp" #include "dbms/inmemory/replication_handlers.hpp"
#include "flags/all.hpp" #include "flags/all.hpp"
#include "glue/MonitoringServerT.hpp" #include "glue/MonitoringServerT.hpp"
@ -24,14 +27,19 @@
#include "helpers.hpp" #include "helpers.hpp"
#include "license/license_sender.hpp" #include "license/license_sender.hpp"
#include "memory/global_memory_control.hpp" #include "memory/global_memory_control.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/discard_value_stream.hpp" #include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp" #include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/procedure/callable_alias_mapper.hpp" #include "query/procedure/callable_alias_mapper.hpp"
#include "query/procedure/module.hpp" #include "query/procedure/module.hpp"
#include "query/procedure/py_module.hpp" #include "query/procedure/py_module.hpp"
#include "replication_handler/replication_handler.hpp"
#include "replication_handler/system_replication.hpp"
#include "requests/requests.hpp" #include "requests/requests.hpp"
#include "storage/v2/durability/durability.hpp" #include "storage/v2/durability/durability.hpp"
#include "system/system.hpp"
#include "telemetry/telemetry.hpp" #include "telemetry/telemetry.hpp"
#include "utils/signals.hpp" #include "utils/signals.hpp"
#include "utils/sysinfo/memory.hpp" #include "utils/sysinfo/memory.hpp"
@ -39,10 +47,6 @@
#include "utils/terminate_handler.hpp" #include "utils/terminate_handler.hpp"
#include "version.hpp" #include "version.hpp"
#include "dbms/dbms_handler.hpp"
#include "query/auth_query_handler.hpp"
#include "query/interpreter_context.hpp"
namespace { namespace {
constexpr const char *kMgUser = "MEMGRAPH_USER"; constexpr const char *kMgUser = "MEMGRAPH_USER";
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD"; constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
@ -356,9 +360,8 @@ int main(int argc, char **argv) {
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}; .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)};
auto auth_glue = auto auth_glue = [](memgraph::auth::SynchedAuth *auth, std::unique_ptr<memgraph::query::AuthQueryHandler> &ah,
[](memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth, std::unique_ptr<memgraph::query::AuthChecker> &ac) {
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah, std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side // Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth); ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth);
ac = std::make_unique<memgraph::glue::AuthChecker>(auth); ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
@ -367,33 +370,65 @@ int main(int argc, char **argv) {
auto *maybe_password = std::getenv(kMgPassword); auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile); auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) { if (maybe_username && maybe_password) {
ah->CreateUser(maybe_username, maybe_password); ah->CreateUser(maybe_username, maybe_password, nullptr);
} else if (maybe_pass_file) { } else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file); const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) { if (!username.empty() && !password.empty()) {
ah->CreateUser(username, password); ah->CreateUser(username, password, nullptr);
} }
} }
}; };
memgraph::auth::Auth::Config auth_config{FLAGS_auth_user_or_role_name_regex, FLAGS_auth_password_strength_regex, memgraph::auth::Auth::Config auth_config{FLAGS_auth_user_or_role_name_regex, FLAGS_auth_password_strength_regex,
FLAGS_auth_password_permit_null}; FLAGS_auth_password_permit_null};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{ memgraph::auth::SynchedAuth auth_{data_directory / "auth", auth_config};
data_directory / "auth", auth_config};
std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler; std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler;
std::unique_ptr<memgraph::query::AuthChecker> auth_checker; std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
auth_glue(&auth_, auth_handler, auth_checker); auth_glue(&auth_, auth_handler, auth_checker);
memgraph::dbms::DbmsHandler dbms_handler(db_config auto system = memgraph::system::System{db_config.durability.storage_directory, FLAGS_data_recovery_on_startup};
// singleton replication state
memgraph::replication::ReplicationState repl_state{ReplicationStateRootPath(db_config)};
// singleton coordinator state
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState coordinator_state;
#endif
memgraph::dbms::DbmsHandler dbms_handler(db_config, system, repl_state
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
&auth_, FLAGS_data_recovery_on_startup auth_, FLAGS_data_recovery_on_startup
#endif #endif
); );
// Note: Now that all system's subsystems are initialised (dbms & auth)
// We can now initialise the recovery of replication (which will include those subsystems)
// ReplicationHandler will handle the recovery
auto replication_handler = memgraph::replication::ReplicationHandler{repl_state, dbms_handler
#ifdef MG_ENTERPRISE
,
&system, auth_
#endif
};
#ifdef MG_ENTERPRISE
// MAIN or REPLICA instance
if (FLAGS_coordinator_server_port) {
memgraph::dbms::CoordinatorHandlers::Register(coordinator_state.GetCoordinatorServer(), replication_handler);
MG_ASSERT(coordinator_state.GetCoordinatorServer().Start(), "Failed to start coordinator server!");
}
#endif
auto db_acc = dbms_handler.Get(); auto db_acc = dbms_handler.Get();
memgraph::query::InterpreterContext interpreter_context_( memgraph::query::InterpreterContext interpreter_context_(interp_config, &dbms_handler, &repl_state, system,
interp_config, &dbms_handler, &dbms_handler.ReplicationState(), auth_handler.get(), auth_checker.get()); #ifdef MG_ENTERPRISE
&coordinator_state,
#endif
auth_handler.get(), auth_checker.get(),
&replication_handler);
MG_ASSERT(db_acc, "Failed to access the main database"); MG_ASSERT(db_acc, "Failed to access the main database");
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(), memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(),
@ -460,9 +495,9 @@ int main(int argc, char **argv) {
if (FLAGS_telemetry_enabled) { if (FLAGS_telemetry_enabled) {
telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id, telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id,
service_name == "BoltS", FLAGS_data_directory, std::chrono::minutes(10)); service_name == "BoltS", FLAGS_data_directory, std::chrono::minutes(10));
telemetry->AddStorageCollector(dbms_handler, auth_); telemetry->AddStorageCollector(dbms_handler, auth_, repl_state);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
telemetry->AddDatabaseCollector(dbms_handler); telemetry->AddDatabaseCollector(dbms_handler, repl_state);
#else #else
telemetry->AddDatabaseCollector(); telemetry->AddDatabaseCollector();
#endif #endif

View File

@ -56,6 +56,7 @@ target_link_libraries(mg-query PUBLIC dl
mg-kvstore mg-kvstore
mg-memory mg-memory
mg::csv mg::csv
mg::system
mg-flags mg-flags
mg-dbms mg-dbms
mg-events) mg-events)

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@
#include "query/frontend/ast/ast.hpp" // overkill #include "query/frontend/ast/ast.hpp" // overkill
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "system/system.hpp"
namespace memgraph::query { namespace memgraph::query {
@ -33,23 +34,27 @@ class AuthQueryHandler {
/// Return false if the user already exists. /// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0; virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) = 0;
/// Return false if the user does not exist. /// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0; virtual bool DropUser(const std::string &username, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0; virtual void SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) = 0;
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
/// Return true if access revoked successfully /// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0; virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
/// Return true if access granted successfully /// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0; virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
/// Returns database access rights for the user /// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
@ -57,20 +62,20 @@ class AuthQueryHandler {
/// Return true if main database set successfully /// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(std::string_view db, const std::string &username) = 0; virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0;
/// Delete database from all users /// Delete database from all users
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void DeleteDatabase(std::string_view db) = 0; virtual void DeleteDatabase(std::string_view db, system::Transaction *system_tx) = 0;
#endif #endif
/// Return false if the role already exists. /// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0; virtual bool CreateRole(const std::string &rolename, system::Transaction *system_tx) = 0;
/// Return false if the role does not exist. /// Return false if the role does not exist.
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual bool DropRole(const std::string &rolename) = 0; virtual bool DropRole(const std::string &rolename, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<memgraph::query::TypedValue> GetUsernames() = 0; virtual std::vector<memgraph::query::TypedValue> GetUsernames() = 0;
@ -85,10 +90,10 @@ class AuthQueryHandler {
virtual std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) = 0; virtual std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0; virtual void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0; virtual void ClearRole(const std::string &username, system::Transaction *system_tx) = 0;
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) = 0; virtual std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
@ -103,11 +108,13 @@ class AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) = 0; ,
system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role, virtual void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) = 0; const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege( virtual void RevokePrivilege(
@ -120,7 +127,8 @@ class AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges &edge_type_privileges
#endif #endif
) = 0; ,
system::Transaction *system_tx) = 0;
}; };
} // namespace memgraph::query } // namespace memgraph::query

View File

@ -34,7 +34,9 @@
#include "auth/auth.hpp" #include "auth/auth.hpp"
#include "auth/models.hpp" #include "auth/models.hpp"
#include "csv/parsing.hpp" #include "csv/parsing.hpp"
#include "dbms/coordinator_handler.hpp"
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/global.hpp" #include "dbms/global.hpp"
#include "dbms/inmemory/storage_helper.hpp" #include "dbms/inmemory/storage_helper.hpp"
#include "flags/replication.hpp" #include "flags/replication.hpp"
@ -43,6 +45,7 @@
#include "license/license.hpp" #include "license/license.hpp"
#include "memory/global_memory_control.hpp" #include "memory/global_memory_control.hpp"
#include "memory/query_memory_control.hpp" #include "memory/query_memory_control.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/constants.hpp" #include "query/constants.hpp"
#include "query/context.hpp" #include "query/context.hpp"
@ -58,12 +61,14 @@
#include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_generator.hpp"
#include "query/interpret/eval.hpp" #include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp" #include "query/interpret/frame.hpp"
#include "query/interpreter_context.hpp"
#include "query/metadata.hpp" #include "query/metadata.hpp"
#include "query/plan/hint_provider.hpp" #include "query/plan/hint_provider.hpp"
#include "query/plan/planner.hpp" #include "query/plan/planner.hpp"
#include "query/plan/profile.hpp" #include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp" #include "query/plan/vertex_count_cache.hpp"
#include "query/procedure/module.hpp" #include "query/procedure/module.hpp"
#include "query/replication_query_handler.hpp"
#include "query/stream.hpp" #include "query/stream.hpp"
#include "query/stream/common.hpp" #include "query/stream/common.hpp"
#include "query/stream/sources.hpp" #include "query/stream/sources.hpp"
@ -71,6 +76,7 @@
#include "query/trigger.hpp" #include "query/trigger.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "replication/config.hpp" #include "replication/config.hpp"
#include "replication/state.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
#include "storage/v2/disk/storage.hpp" #include "storage/v2/disk/storage.hpp"
#include "storage/v2/edge.hpp" #include "storage/v2/edge.hpp"
@ -101,13 +107,6 @@
#include "utils/typeinfo.hpp" #include "utils/typeinfo.hpp"
#include "utils/variant_helpers.hpp" #include "utils/variant_helpers.hpp"
#include "dbms/coordinator_handler.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "query/auth_query_handler.hpp"
#include "query/interpreter_context.hpp"
#include "replication/state.hpp"
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include "coordination/constants.hpp" #include "coordination/constants.hpp"
#endif #endif
@ -306,17 +305,18 @@ class ReplQueryHandler {
ReplicationQuery::ReplicaState state; ReplicationQuery::ReplicaState state;
}; };
explicit ReplQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_{*dbms_handler} {} explicit ReplQueryHandler(query::ReplicationQueryHandler &replication_query_handler)
: handler_{&replication_query_handler} {}
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) { void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) {
auto ValidatePort = [](std::optional<int64_t> port) -> void { auto ValidatePort = [](std::optional<int64_t> port) -> void {
if (*port < 0 || *port > std::numeric_limits<uint16_t>::max()) { if (!port || *port < 0 || *port > std::numeric_limits<uint16_t>::max()) {
throw QueryRuntimeException("Port number invalid!"); throw QueryRuntimeException("Port number invalid!");
} }
}; };
if (replication_role == ReplicationQuery::ReplicationRole::MAIN) { if (replication_role == ReplicationQuery::ReplicationRole::MAIN) {
if (!handler_.SetReplicationRoleMain()) { if (!handler_->SetReplicationRoleMain()) {
throw QueryRuntimeException("Couldn't set replication role to main!"); throw QueryRuntimeException("Couldn't set replication role to main!");
} }
} else { } else {
@ -327,7 +327,7 @@ class ReplQueryHandler {
.port = static_cast<uint16_t>(*port), .port = static_cast<uint16_t>(*port),
}; };
if (!handler_.SetReplicationRoleReplica(config)) { if (!handler_->SetReplicationRoleReplica(config)) {
throw QueryRuntimeException("Couldn't set role to replica!"); throw QueryRuntimeException("Couldn't set role to replica!");
} }
} }
@ -335,7 +335,7 @@ class ReplQueryHandler {
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
ReplicationQuery::ReplicationRole ShowReplicationRole() const { ReplicationQuery::ReplicationRole ShowReplicationRole() const {
switch (handler_.GetRole()) { switch (handler_->GetRole()) {
case memgraph::replication_coordination_glue::ReplicationRole::MAIN: case memgraph::replication_coordination_glue::ReplicationRole::MAIN:
return ReplicationQuery::ReplicationRole::MAIN; return ReplicationQuery::ReplicationRole::MAIN;
case memgraph::replication_coordination_glue::ReplicationRole::REPLICA: case memgraph::replication_coordination_glue::ReplicationRole::REPLICA:
@ -349,7 +349,7 @@ class ReplQueryHandler {
const ReplicationQuery::SyncMode sync_mode, const std::chrono::seconds replica_check_frequency) { const ReplicationQuery::SyncMode sync_mode, const std::chrono::seconds replica_check_frequency) {
// Coordinator is main by default so this check is OK although it should actually be nothing (neither main nor // Coordinator is main by default so this check is OK although it should actually be nothing (neither main nor
// replica) // replica)
if (handler_.IsReplica()) { if (handler_->IsReplica()) {
// replica can't register another replica // replica can't register another replica
throw QueryRuntimeException("Replica can't register another replica!"); throw QueryRuntimeException("Replica can't register another replica!");
} }
@ -368,7 +368,7 @@ class ReplQueryHandler {
.replica_check_frequency = replica_check_frequency, .replica_check_frequency = replica_check_frequency,
.ssl = std::nullopt}; .ssl = std::nullopt};
const auto error = handler_.RegisterReplica(replication_config).HasError(); const auto error = handler_->TryRegisterReplica(replication_config).HasError();
if (error) { if (error) {
throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name)); throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name));
@ -381,9 +381,9 @@ class ReplQueryHandler {
/// @throw QueryRuntimeException if an error occurred. /// @throw QueryRuntimeException if an error occurred.
void DropReplica(std::string_view replica_name) { void DropReplica(std::string_view replica_name) {
auto const result = handler_.UnregisterReplica(replica_name); auto const result = handler_->UnregisterReplica(replica_name);
switch (result) { switch (result) {
using enum memgraph::dbms::UnregisterReplicaResult; using enum memgraph::query::UnregisterReplicaResult;
case NOT_MAIN: case NOT_MAIN:
throw QueryRuntimeException("Replica can't unregister a replica!"); throw QueryRuntimeException("Replica can't unregister a replica!");
case COULD_NOT_BE_PERSISTED: case COULD_NOT_BE_PERSISTED:
@ -396,7 +396,7 @@ class ReplQueryHandler {
} }
std::vector<ReplicaInfo> ShowReplicas(const dbms::Database &db) const { std::vector<ReplicaInfo> ShowReplicas(const dbms::Database &db) const {
if (handler_.IsReplica()) { if (handler_->IsReplica()) {
// replica can't show registered replicas (it shouldn't have any) // replica can't show registered replicas (it shouldn't have any)
throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!"); throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!");
} }
@ -447,19 +447,16 @@ class ReplQueryHandler {
} }
private: private:
dbms::ReplicationHandler handler_; query::ReplicationQueryHandler *handler_;
}; };
#ifdef MG_ENTERPRISE
class CoordQueryHandler final : public query::CoordinatorQueryHandler { class CoordQueryHandler final : public query::CoordinatorQueryHandler {
public: public:
explicit CoordQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_ { *dbms_handler } explicit CoordQueryHandler(coordination::CoordinatorState &coordinator_state)
#ifdef MG_ENTERPRISE
, coordinator_handler_(*dbms_handler) : coordinator_handler_(coordinator_state) {}
#endif
{
}
#ifdef MG_ENTERPRISE
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
void RegisterInstance(const std::string &coordinator_socket_address, const std::string &replication_socket_address, void RegisterInstance(const std::string &coordinator_socket_address, const std::string &replication_socket_address,
const std::chrono::seconds instance_check_frequency, const std::string &instance_name, const std::chrono::seconds instance_check_frequency, const std::string &instance_name,
@ -497,7 +494,7 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
using enum memgraph::coordination::RegisterInstanceCoordinatorStatus; using enum memgraph::coordination::RegisterInstanceCoordinatorStatus;
case NAME_EXISTS: case NAME_EXISTS:
throw QueryRuntimeException("Couldn't register replica instance since instance with such name already exists!"); throw QueryRuntimeException("Couldn't register replica instance since instance with such name already exists!");
case END_POINT_EXISTS: case ENDPOINT_EXISTS:
throw QueryRuntimeException( throw QueryRuntimeException(
"Couldn't register replica instance since instance with such endpoint already exists!"); "Couldn't register replica instance since instance with such endpoint already exists!");
case NOT_COORDINATOR: case NOT_COORDINATOR:
@ -527,26 +524,20 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
} }
} }
#endif
#ifdef MG_ENTERPRISE
std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const override { std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const override {
return coordinator_handler_.ShowInstances(); return coordinator_handler_.ShowInstances();
} }
#endif
private: private:
dbms::ReplicationHandler handler_;
#ifdef MG_ENTERPRISE
dbms::CoordinatorHandler coordinator_handler_; dbms::CoordinatorHandler coordinator_handler_;
#endif
}; };
#endif
/// returns false if the replication role can't be set /// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters) { Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters,
Interpreter &interpreter) {
AuthQueryHandler *auth = interpreter_context->auth; AuthQueryHandler *auth = interpreter_context->auth;
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
auto *db_handler = interpreter_context->dbms_handler; auto *db_handler = interpreter_context->dbms_handler;
@ -595,19 +586,45 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features"));
} }
const auto forbid_on_replica = [has_license = license_check_result.HasError(),
is_replica = interpreter_context->repl_state->IsReplica()]() {
if (is_replica) {
#if MG_ENTERPRISE
if (has_license) {
throw QueryException(
"Query forbidden on the replica! Update on MAIN, as it is the only source of truth for authentication "
"data. MAIN will then replicate the update to connected REPLICAs");
}
throw QueryException(
"Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA.");
#else
throw QueryException(
"Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA.");
#endif
}
};
switch (auth_query->action_) { switch (auth_query->action_) {
case AuthQuery::Action::CREATE_USER: case AuthQuery::Action::CREATE_USER:
callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError()] { forbid_on_replica();
callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError(),
interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
MG_ASSERT(password.IsString() || password.IsNull()); MG_ASSERT(password.IsString() || password.IsNull());
if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) if (!auth->CreateUser(
: std::nullopt)) { username, password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt,
&*interpreter->system_transaction_)) {
throw UserAlreadyExistsException("User '{}' already exists.", username); throw UserAlreadyExistsException("User '{}' already exists.", username);
} }
// If the license is not valid we create users with admin access // If the license is not valid we create users with admin access
if (!valid_enterprise_license) { if (!valid_enterprise_license) {
spdlog::warn("Granting all the privileges to {}.", username); spdlog::warn("Granting all the privileges to {}.", username);
auth->GrantPrivilege(username, kPrivilegesAll auth->GrantPrivilege(
username, kPrivilegesAll
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
{{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}}, {{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}},
@ -619,39 +636,61 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
} }
} }
#endif #endif
); ,
&*interpreter->system_transaction_);
} }
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::DROP_USER: case AuthQuery::Action::DROP_USER:
callback.fn = [auth, username] { forbid_on_replica();
if (!auth->DropUser(username)) { callback.fn = [auth, username, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->DropUser(username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("User '{}' doesn't exist.", username); throw QueryRuntimeException("User '{}' doesn't exist.", username);
} }
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::SET_PASSWORD: case AuthQuery::Action::SET_PASSWORD:
callback.fn = [auth, username, password] { forbid_on_replica();
callback.fn = [auth, username, password, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
MG_ASSERT(password.IsString() || password.IsNull()); MG_ASSERT(password.IsString() || password.IsNull());
auth->SetPassword(username, auth->SetPassword(username,
password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt); password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::CREATE_ROLE: case AuthQuery::Action::CREATE_ROLE:
callback.fn = [auth, rolename] { forbid_on_replica();
if (!auth->CreateRole(rolename)) { callback.fn = [auth, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->CreateRole(rolename, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Role '{}' already exists.", rolename); throw QueryRuntimeException("Role '{}' already exists.", rolename);
} }
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::DROP_ROLE: case AuthQuery::Action::DROP_ROLE:
callback.fn = [auth, rolename] { forbid_on_replica();
if (!auth->DropRole(rolename)) { callback.fn = [auth, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->DropRole(rolename, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Role '{}' doesn't exist.", rolename); throw QueryRuntimeException("Role '{}' doesn't exist.", rolename);
} }
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
@ -682,52 +721,79 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
}; };
return callback; return callback;
case AuthQuery::Action::SET_ROLE: case AuthQuery::Action::SET_ROLE:
callback.fn = [auth, username, rolename] { forbid_on_replica();
auth->SetRole(username, rolename); callback.fn = [auth, username, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->SetRole(username, rolename, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::CLEAR_ROLE: case AuthQuery::Action::CLEAR_ROLE:
callback.fn = [auth, username] { forbid_on_replica();
auth->ClearRole(username); callback.fn = [auth, username, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->ClearRole(username, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::GRANT_PRIVILEGE: case AuthQuery::Action::GRANT_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges forbid_on_replica();
callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
label_privileges, edge_type_privileges label_privileges, edge_type_privileges
#endif #endif
] { ] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->GrantPrivilege(user_or_role, privileges auth->GrantPrivilege(user_or_role, privileges
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
label_privileges, edge_type_privileges label_privileges, edge_type_privileges
#endif #endif
); ,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::DENY_PRIVILEGE: case AuthQuery::Action::DENY_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges] { forbid_on_replica();
auth->DenyPrivilege(user_or_role, privileges); callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->DenyPrivilege(user_or_role, privileges, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
case AuthQuery::Action::REVOKE_PRIVILEGE: { case AuthQuery::Action::REVOKE_PRIVILEGE: {
callback.fn = [auth, user_or_role, privileges forbid_on_replica();
callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
label_privileges, edge_type_privileges label_privileges, edge_type_privileges
#endif #endif
] { ] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->RevokePrivilege(user_or_role, privileges auth->RevokePrivilege(user_or_role, privileges
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
label_privileges, edge_type_privileges label_privileges, edge_type_privileges
#endif #endif
); ,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
@ -757,15 +823,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
}; };
return callback; return callback;
case AuthQuery::Action::GRANT_DATABASE_TO_USER: case AuthQuery::Action::GRANT_DATABASE_TO_USER:
forbid_on_replica();
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try { try {
std::optional<memgraph::dbms::DatabaseAccess> db = std::optional<memgraph::dbms::DatabaseAccess> db =
std::nullopt; // Hold pointer to database to protect it until query is done std::nullopt; // Hold pointer to database to protect it until query is done
if (database != memgraph::auth::kAllDatabases) { if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
} }
if (!auth->GrantDatabaseToUser(database, username)) { if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username);
} }
} catch (memgraph::dbms::UnknownDatabaseException &e) { } catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -778,15 +849,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
}; };
return callback; return callback;
case AuthQuery::Action::REVOKE_DATABASE_FROM_USER: case AuthQuery::Action::REVOKE_DATABASE_FROM_USER:
forbid_on_replica();
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try { try {
std::optional<memgraph::dbms::DatabaseAccess> db = std::optional<memgraph::dbms::DatabaseAccess> db =
std::nullopt; // Hold pointer to database to protect it until query is done std::nullopt; // Hold pointer to database to protect it until query is done
if (database != memgraph::auth::kAllDatabases) { if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
} }
if (!auth->RevokeDatabaseFromUser(database, username)) { if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username);
} }
} catch (memgraph::dbms::UnknownDatabaseException &e) { } catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -811,12 +887,17 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
#endif #endif
return callback; return callback;
case AuthQuery::Action::SET_MAIN_DATABASE: case AuthQuery::Action::SET_MAIN_DATABASE:
forbid_on_replica();
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try { try {
const auto db = const auto db =
db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
if (!auth->SetMainDatabase(database, username)) { if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username);
} }
} catch (memgraph::dbms::UnknownDatabaseException &e) { } catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -834,7 +915,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
} // namespace } // namespace
Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters, Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
dbms::DbmsHandler *dbms_handler, CurrentDB &current_db, ReplicationQueryHandler &replication_query_handler, CurrentDB &current_db,
const query::InterpreterConfig &config, std::vector<Notification> *notifications) { const query::InterpreterConfig &config, std::vector<Notification> *notifications) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as // TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback. // the argument to Callback.
@ -864,7 +945,8 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING, notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING,
"Be careful the replication port must be different from the memgraph port!"); "Be careful the replication port must be different from the memgraph port!");
} }
callback.fn = [handler = ReplQueryHandler{dbms_handler}, role = repl_query->role_, maybe_port]() mutable { callback.fn = [handler = ReplQueryHandler{replication_query_handler}, role = repl_query->role_,
maybe_port]() mutable {
handler.SetReplicationRole(role, maybe_port); handler.SetReplicationRole(role, maybe_port);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
@ -882,7 +964,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
#endif #endif
callback.header = {"replication role"}; callback.header = {"replication role"};
callback.fn = [handler = ReplQueryHandler{dbms_handler}] { callback.fn = [handler = ReplQueryHandler{replication_query_handler}] {
auto mode = handler.ShowReplicationRole(); auto mode = handler.ShowReplicationRole();
switch (mode) { switch (mode) {
case ReplicationQuery::ReplicationRole::MAIN: { case ReplicationQuery::ReplicationRole::MAIN: {
@ -906,7 +988,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
auto socket_address = repl_query->socket_address_->Accept(evaluator); auto socket_address = repl_query->socket_address_->Accept(evaluator);
const auto replica_check_frequency = config.replication_replica_check_frequency; const auto replica_check_frequency = config.replication_replica_check_frequency;
callback.fn = [handler = ReplQueryHandler{dbms_handler}, name, socket_address, sync_mode, callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name, socket_address, sync_mode,
replica_check_frequency]() mutable { replica_check_frequency]() mutable {
handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency); handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
@ -923,7 +1005,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
} }
#endif #endif
const auto &name = repl_query->instance_name_; const auto &name = repl_query->instance_name_;
callback.fn = [handler = ReplQueryHandler{dbms_handler}, name]() mutable { callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name]() mutable {
handler.DropReplica(name); handler.DropReplica(name);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
@ -941,7 +1023,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
callback.header = { callback.header = {
"name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master", "name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master",
"state"}; "state"};
callback.fn = [handler = ReplQueryHandler{dbms_handler}, replica_nfields = callback.header.size(), callback.fn = [handler = ReplQueryHandler{replication_query_handler}, replica_nfields = callback.header.size(),
db_acc = current_db.db_acc_] { db_acc = current_db.db_acc_] {
const auto &replicas = handler.ShowReplicas(*db_acc->get()); const auto &replicas = handler.ShowReplicas(*db_acc->get());
auto typed_replicas = std::vector<std::vector<TypedValue>>{}; auto typed_replicas = std::vector<std::vector<TypedValue>>{};
@ -989,16 +1071,17 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
} }
} }
#ifdef MG_ENTERPRISE
Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Parameters &parameters, Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Parameters &parameters,
dbms::DbmsHandler *dbms_handler, const query::InterpreterConfig &config, coordination::CoordinatorState *coordinator_state,
std::vector<Notification> *notifications) { const query::InterpreterConfig &config, std::vector<Notification> *notifications) {
Callback callback; Callback callback;
switch (coordinator_query->action_) { switch (coordinator_query->action_) {
case CoordinatorQuery::Action::REGISTER_INSTANCE: { case CoordinatorQuery::Action::REGISTER_INSTANCE: {
if (!license::global_license_checker.IsEnterpriseValidFast()) { if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license."); throw QueryException("Trying to use enterprise feature without a valid license.");
} }
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) { if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException( throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1014,7 +1097,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
auto coordinator_socket_address_tv = coordinator_query->coordinator_socket_address_->Accept(evaluator); auto coordinator_socket_address_tv = coordinator_query->coordinator_socket_address_->Accept(evaluator);
auto replication_socket_address_tv = coordinator_query->replication_socket_address_->Accept(evaluator); auto replication_socket_address_tv = coordinator_query->replication_socket_address_->Accept(evaluator);
callback.fn = [handler = CoordQueryHandler{dbms_handler}, coordinator_socket_address_tv, callback.fn = [handler = CoordQueryHandler{*coordinator_state}, coordinator_socket_address_tv,
replication_socket_address_tv, main_check_frequency = config.replication_replica_check_frequency, replication_socket_address_tv, main_check_frequency = config.replication_replica_check_frequency,
instance_name = coordinator_query->instance_name_, instance_name = coordinator_query->instance_name_,
sync_mode = coordinator_query->sync_mode_]() mutable { sync_mode = coordinator_query->sync_mode_]() mutable {
@ -1029,13 +1112,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
fmt::format("Coordinator has registered coordinator server on {} for instance {}.", fmt::format("Coordinator has registered coordinator server on {} for instance {}.",
coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_)); coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_));
return callback; return callback;
#endif
} }
case CoordinatorQuery::Action::SET_INSTANCE_TO_MAIN: { case CoordinatorQuery::Action::SET_INSTANCE_TO_MAIN: {
if (!license::global_license_checker.IsEnterpriseValidFast()) { if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license."); throw QueryException("Trying to use enterprise feature without a valid license.");
} }
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) { if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException( throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1049,20 +1130,18 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
EvaluationContext evaluation_context{.timestamp = QueryTimestamp(), .parameters = parameters}; EvaluationContext evaluation_context{.timestamp = QueryTimestamp(), .parameters = parameters};
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
callback.fn = [handler = CoordQueryHandler{dbms_handler}, callback.fn = [handler = CoordQueryHandler{*coordinator_state},
instance_name = coordinator_query->instance_name_]() mutable { instance_name = coordinator_query->instance_name_]() mutable {
handler.SetInstanceToMain(instance_name); handler.SetInstanceToMain(instance_name);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
return callback; return callback;
#endif
} }
case CoordinatorQuery::Action::SHOW_REPLICATION_CLUSTER: { case CoordinatorQuery::Action::SHOW_REPLICATION_CLUSTER: {
if (!license::global_license_checker.IsEnterpriseValidFast()) { if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license."); throw QueryException("Trying to use enterprise feature without a valid license.");
} }
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) { if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException( throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1073,7 +1152,8 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
} }
callback.header = {"name", "socket_address", "alive", "role"}; callback.header = {"name", "socket_address", "alive", "role"};
callback.fn = [handler = CoordQueryHandler{dbms_handler}, replica_nfields = callback.header.size()]() mutable { callback.fn = [handler = CoordQueryHandler{*coordinator_state},
replica_nfields = callback.header.size()]() mutable {
auto const instances = handler.ShowInstances(); auto const instances = handler.ShowInstances();
std::vector<std::vector<TypedValue>> result{}; std::vector<std::vector<TypedValue>> result{};
result.reserve(result.size()); result.reserve(result.size());
@ -1087,11 +1167,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
return result; return result;
}; };
return callback; return callback;
#endif
} }
return callback; return callback;
} }
} }
#endif
stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator) { stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator) {
return { return {
@ -2493,14 +2573,14 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
} }
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context) { InterpreterContext *interpreter_context, Interpreter &interpreter) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException(); throw UserModificationInMulticommandTxException();
} }
auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query); auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters); auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr), [handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr),
@ -2525,15 +2605,16 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
} }
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler, std::vector<Notification> *notifications,
CurrentDB &current_db, const InterpreterConfig &config) { ReplicationQueryHandler &replication_query_handler, CurrentDB &current_db,
const InterpreterConfig &config) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw ReplicationModificationInMulticommandTxException(); throw ReplicationModificationInMulticommandTxException();
} }
auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query); auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query);
auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, &dbms_handler, current_db, config, auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, replication_query_handler,
notifications); current_db, config, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2552,8 +2633,10 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
} }
#ifdef MG_ENTERPRISE
PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler, std::vector<Notification> *notifications,
coordination::CoordinatorState &coordinator_state,
const InterpreterConfig &config) { const InterpreterConfig &config) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw CoordinatorModificationInMulticommandTxException(); throw CoordinatorModificationInMulticommandTxException();
@ -2561,7 +2644,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit
auto *coordinator_query = utils::Downcast<CoordinatorQuery>(parsed_query.query); auto *coordinator_query = utils::Downcast<CoordinatorQuery>(parsed_query.query);
auto callback = auto callback =
HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &dbms_handler, config, notifications); HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &coordinator_state, config, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2579,6 +2662,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit
// False positive report for the std::make_shared above // False positive report for the std::make_shared above
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
} }
#endif
PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, CurrentDB &current_db) { PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, CurrentDB &current_db) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
@ -3681,7 +3765,8 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &current_db, PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &current_db,
InterpreterContext *interpreter_context, InterpreterContext *interpreter_context,
std::optional<std::function<void(std::string_view)>> on_change_cb) { std::optional<std::function<void(std::string_view)>> on_change_cb,
Interpreter &interpreter) {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) { if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license."); throw QueryException("Trying to use enterprise feature without a valid license.");
@ -3700,12 +3785,16 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
return PreparedQuery{ return PreparedQuery{
{"STATUS"}, {"STATUS"},
std::move(parsed_query.required_privileges), std::move(parsed_query.required_privileges),
[db_name = query->db_name_, db_handler](AnyStream *stream, [db_name = query->db_name_, db_handler, interpreter = &interpreter](
std::optional<int> n) -> std::optional<QueryHandlerResult> { AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
std::vector<std::vector<TypedValue>> status; std::vector<std::vector<TypedValue>> status;
std::string res; std::string res;
const auto success = db_handler->New(db_name); const auto success = db_handler->New(db_name, &*interpreter->system_transaction_);
if (success.HasError()) { if (success.HasError()) {
switch (success.GetError()) { switch (success.GetError()) {
case dbms::NewError::EXISTS: case dbms::NewError::EXISTS:
@ -3780,16 +3869,20 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
return PreparedQuery{ return PreparedQuery{
{"STATUS"}, {"STATUS"},
std::move(parsed_query.required_privileges), std::move(parsed_query.required_privileges),
[db_name = query->db_name_, db_handler, auth = interpreter_context->auth]( [db_name = query->db_name_, db_handler, auth = interpreter_context->auth, interpreter = &interpreter](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
std::vector<std::vector<TypedValue>> status; std::vector<std::vector<TypedValue>> status;
try { try {
// Remove database // Remove database
auto success = db_handler->TryDelete(db_name); auto success = db_handler->TryDelete(db_name, &*interpreter->system_transaction_);
if (!success.HasError()) { if (!success.HasError()) {
// Remove from auth // Remove from auth
if (auth) auth->DeleteDatabase(db_name); if (auth) auth->DeleteDatabase(db_name, &*interpreter->system_transaction_);
} else { } else {
switch (success.GetError()) { switch (success.GetError()) {
case dbms::DeleteError::DEFAULT_DB: case dbms::DeleteError::DEFAULT_DB:
@ -4040,18 +4133,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
utils::Downcast<ReplicationQuery>(parsed_query.query); utils::Downcast<ReplicationQuery>(parsed_query.query);
// TODO Split SHOW REPLICAS (which needs the db) and other replication queries // TODO Split SHOW REPLICAS (which needs the db) and other replication queries
auto system_transaction_guard = std::invoke([&]() -> std::optional<SystemTransactionGuard> { auto system_transaction = std::invoke([&]() -> std::optional<memgraph::system::Transaction> {
if (system_queries) { if (!system_queries) return std::nullopt;
// TODO: Ordering between system and data queries // TODO: Ordering between system and data queries
// Start a system transaction auto system_txn = interpreter_context_->system_->TryCreateTransaction(std::chrono::milliseconds(kSystemTxTryMS));
auto system_unique = std::unique_lock{interpreter_context_->dbms_handler->system_lock_, std::defer_lock}; if (!system_txn) {
if (!system_unique.try_lock_for(std::chrono::milliseconds(kSystemTxTryMS))) {
throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported."); throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported.");
} }
return std::optional<SystemTransactionGuard>{std::in_place, std::move(system_unique), return system_txn;
*interpreter_context_->dbms_handler};
}
return std::nullopt;
}); });
// Some queries do not require a database to be executed (current_db_ won't be passed on to the Prepare*; special // Some queries do not require a database to be executed (current_db_ won't be passed on to the Prepare*; special
@ -4117,7 +4207,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) { } else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
/// SYSTEM (Replication) PURE /// SYSTEM (Replication) PURE
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, *this);
} else if (utils::Downcast<DatabaseInfoQuery>(parsed_query.query)) { } else if (utils::Downcast<DatabaseInfoQuery>(parsed_query.query)) {
prepared_query = PrepareDatabaseInfoQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); prepared_query = PrepareDatabaseInfoQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<SystemInfoQuery>(parsed_query.query)) { } else if (utils::Downcast<SystemInfoQuery>(parsed_query.query)) {
@ -4128,13 +4218,18 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
&query_execution->notifications, current_db_); &query_execution->notifications, current_db_);
} else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) {
/// TODO: make replication DB agnostic /// TODO: make replication DB agnostic
prepared_query = prepared_query = PrepareReplicationQuery(
PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
*interpreter_context_->dbms_handler, current_db_, interpreter_context_->config); *interpreter_context_->replication_handler_, current_db_, interpreter_context_->config);
} else if (utils::Downcast<CoordinatorQuery>(parsed_query.query)) { } else if (utils::Downcast<CoordinatorQuery>(parsed_query.query)) {
#ifdef MG_ENTERPRISE
prepared_query = prepared_query =
PrepareCoordinatorQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, PrepareCoordinatorQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
*interpreter_context_->dbms_handler, interpreter_context_->config); *interpreter_context_->coordinator_state_, interpreter_context_->config);
#else
throw QueryRuntimeException("Coordinator queries are not part of community edition");
#endif
} else if (utils::Downcast<LockPathQuery>(parsed_query.query)) { } else if (utils::Downcast<LockPathQuery>(parsed_query.query)) {
prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) { } else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) {
@ -4177,8 +4272,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} }
/// SYSTEM (Replication) + INTERPRETER /// SYSTEM (Replication) + INTERPRETER
// DMG_ASSERT(system_guard); // DMG_ASSERT(system_guard);
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_ prepared_query =
/*, *system_guard*/); PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) { } else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_); prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_);
} else if (utils::Downcast<EdgeImportModeQuery>(parsed_query.query)) { } else if (utils::Downcast<EdgeImportModeQuery>(parsed_query.query)) {
@ -4210,7 +4305,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
query_execution->summary["db"] = *query_execution->prepared_query->db; query_execution->summary["db"] = *query_execution->prepared_query->db;
// prepare is done, move system txn guard to be owned by interpreter // prepare is done, move system txn guard to be owned by interpreter
system_transaction_guard_ = std::move(system_transaction_guard); system_transaction_ = std::move(system_transaction);
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid, return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid,
query_execution->prepared_query->db}; query_execution->prepared_query->db};
} catch (const utils::BasicException &) { } catch (const utils::BasicException &) {
@ -4360,13 +4455,13 @@ void Interpreter::Commit() {
current_transaction_.reset(); current_transaction_.reset();
if (!current_db_.db_transactional_accessor_ || !current_db_.db_acc_) { if (!current_db_.db_transactional_accessor_ || !current_db_.db_acc_) {
// No database nor db transaction; check for system transaction // No database nor db transaction; check for system transaction
if (!system_transaction_guard_) return; if (!system_transaction_) return;
// TODO Distinguish between data and system transaction state // TODO Distinguish between data and system transaction state
// Think about updating the status to a struct with bitfield // Think about updating the status to a struct with bitfield
// Clean transaction status on exit // Clean transaction status on exit
utils::OnScopeExit clean_status([this]() { utils::OnScopeExit clean_status([this]() {
system_transaction_guard_.reset(); system_transaction_.reset();
// System transactions are not terminable // System transactions are not terminable
// Durability has happened at time of PULL // Durability has happened at time of PULL
// Commit is doing replication and timestamp update // Commit is doing replication and timestamp update
@ -4384,7 +4479,23 @@ void Interpreter::Commit() {
} }
}); });
system_transaction_guard_->Commit(); auto const main_commit = [&](replication::RoleMainData &mainData) {
// Only enterprise can do system replication
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast()) {
return system_transaction_->Commit(memgraph::system::DoReplication{mainData});
}
#endif
return system_transaction_->Commit(memgraph::system::DoNothing{});
};
auto const replica_commit = [&](replication::RoleReplicaData &) {
return system_transaction_->Commit(memgraph::system::DoNothing{});
};
auto const commit_method = utils::Overloaded{main_commit, replica_commit};
[[maybe_unused]] auto sync_result = std::visit(commit_method, interpreter_context_->repl_state->ReplicationData());
// TODO: something with sync_result
return; return;
} }
auto *db = current_db_.db_acc_->get(); auto *db = current_db_.db_acc_->get();

View File

@ -72,6 +72,7 @@ inline constexpr size_t kExecutionPoolMaxBlockSize = 1024UL; // 2 ^ 10
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING }; enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
#ifdef MG_ENTERPRISE
class CoordinatorQueryHandler { class CoordinatorQueryHandler {
public: public:
CoordinatorQueryHandler() = default; CoordinatorQueryHandler() = default;
@ -93,7 +94,6 @@ class CoordinatorQueryHandler {
ReplicationQuery::ReplicaState state; ReplicationQuery::ReplicaState state;
}; };
#ifdef MG_ENTERPRISE
struct MainReplicaStatus { struct MainReplicaStatus {
std::string_view name; std::string_view name;
std::string_view socket_address; std::string_view socket_address;
@ -103,9 +103,7 @@ class CoordinatorQueryHandler {
MainReplicaStatus(std::string_view name, std::string_view socket_address, bool alive, bool is_main) MainReplicaStatus(std::string_view name, std::string_view socket_address, bool alive, bool is_main)
: name{name}, socket_address{socket_address}, alive{alive}, is_main{is_main} {} : name{name}, socket_address{socket_address}, alive{alive}, is_main{is_main} {}
}; };
#endif
#ifdef MG_ENTERPRISE
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual void RegisterInstance(const std::string &coordinator_socket_address, virtual void RegisterInstance(const std::string &coordinator_socket_address,
const std::string &replication_socket_address, const std::string &replication_socket_address,
@ -117,9 +115,8 @@ class CoordinatorQueryHandler {
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const = 0; virtual std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const = 0;
#endif
}; };
#endif
class AnalyzeGraphQueryHandler { class AnalyzeGraphQueryHandler {
public: public:
@ -296,32 +293,12 @@ class Interpreter final {
void SetUser(std::string_view username); void SetUser(std::string_view username);
struct SystemTransactionGuard { std::optional<memgraph::system::Transaction> system_transaction_{};
explicit SystemTransactionGuard(std::unique_lock<utils::ResourceLock> guard, dbms::DbmsHandler &dbms_handler)
: system_guard_(std::move(guard)), dbms_handler_{&dbms_handler} {
dbms_handler_->NewSystemTransaction();
}
SystemTransactionGuard &operator=(SystemTransactionGuard &&) = default;
SystemTransactionGuard(SystemTransactionGuard &&) = default;
~SystemTransactionGuard() {
if (system_guard_.owns_lock()) dbms_handler_->ResetSystemTransaction();
}
dbms::AllSyncReplicaStatus Commit() { return dbms_handler_->Commit(); }
private:
std::unique_lock<utils::ResourceLock> system_guard_;
dbms::DbmsHandler *dbms_handler_;
};
std::optional<SystemTransactionGuard> system_transaction_guard_{};
private: private:
void ResetInterpreter() { void ResetInterpreter() {
query_executions_.clear(); query_executions_.clear();
system_guard.reset(); system_transaction_.reset();
system_transaction_guard_.reset();
transaction_queries_->clear(); transaction_queries_->clear();
if (current_db_.db_acc_ && current_db_.db_acc_->is_deleting()) { if (current_db_.db_acc_ && current_db_.db_acc_->is_deleting()) {
current_db_.db_acc_.reset(); current_db_.db_acc_.reset();
@ -386,8 +363,6 @@ class Interpreter final {
// TODO Figure out how this would work for multi-database // TODO Figure out how this would work for multi-database
// Exists only during a single transaction (for now should be okay as is) // Exists only during a single transaction (for now should be okay as is)
std::vector<std::unique_ptr<QueryExecution>> query_executions_; std::vector<std::unique_ptr<QueryExecution>> query_executions_;
// TODO: our upgradable lock guard for system
std::optional<utils::ResourceLockGuard> system_guard;
// all queries that are run as part of the current transaction // all queries that are run as part of the current transaction
utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_; utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_;

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,12 +12,27 @@
#include "query/interpreter_context.hpp" #include "query/interpreter_context.hpp"
#include "query/interpreter.hpp" #include "query/interpreter.hpp"
#include "system/include/system/system.hpp"
namespace memgraph::query { namespace memgraph::query {
InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler, InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler,
replication::ReplicationState *rs, query::AuthQueryHandler *ah, replication::ReplicationState *rs, memgraph::system::System &system,
query::AuthChecker *ac) #ifdef MG_ENTERPRISE
: dbms_handler(dbms_handler), config(interpreter_config), repl_state(rs), auth(ah), auth_checker(ac) {} memgraph::coordination::CoordinatorState *coordinator_state,
#endif
AuthQueryHandler *ah, AuthChecker *ac,
ReplicationQueryHandler *replication_handler)
: dbms_handler(dbms_handler),
config(interpreter_config),
repl_state(rs),
#ifdef MG_ENTERPRISE
coordinator_state_{coordinator_state},
#endif
auth(ah),
auth_checker(ac),
replication_handler_{replication_handler},
system_{&system} {
}
std::vector<std::vector<TypedValue>> InterpreterContext::TerminateTransactions( std::vector<std::vector<TypedValue>> InterpreterContext::TerminateTransactions(
std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username, std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username,

View File

@ -20,14 +20,20 @@
#include "query/config.hpp" #include "query/config.hpp"
#include "query/cypher_query_interpreter.hpp" #include "query/cypher_query_interpreter.hpp"
#include "query/replication_query_handler.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "replication/state.hpp" #include "replication/state.hpp"
#include "storage/v2/config.hpp" #include "storage/v2/config.hpp"
#include "storage/v2/transaction.hpp" #include "storage/v2/transaction.hpp"
#include "system/state.hpp"
#include "system/system.hpp"
#include "utils/gatekeeper.hpp" #include "utils/gatekeeper.hpp"
#include "utils/skip_list.hpp" #include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp" #include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp" #include "utils/synchronized.hpp"
#ifdef MG_ENTERPRISE
#include "coordination/coordinator_state.hpp"
#endif
namespace memgraph::dbms { namespace memgraph::dbms {
class DbmsHandler; class DbmsHandler;
@ -48,7 +54,12 @@ class Interpreter;
*/ */
struct InterpreterContext { struct InterpreterContext {
InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler, InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler,
replication::ReplicationState *rs, AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr); replication::ReplicationState *rs, memgraph::system::System &system,
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState *coordinator_state,
#endif
AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr,
ReplicationQueryHandler *replication_handler = nullptr);
memgraph::dbms::DbmsHandler *dbms_handler; memgraph::dbms::DbmsHandler *dbms_handler;
@ -59,9 +70,14 @@ struct InterpreterContext {
// GLOBAL // GLOBAL
memgraph::replication::ReplicationState *repl_state; memgraph::replication::ReplicationState *repl_state;
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState *coordinator_state_;
#endif
AuthQueryHandler *auth; AuthQueryHandler *auth;
AuthChecker *auth_checker; AuthChecker *auth_checker;
ReplicationQueryHandler *replication_handler_;
system::System *system_;
// Used to check active transactions // Used to check active transactions
// TODO: Have a way to read the current database // TODO: Have a way to read the current database

View File

@ -0,0 +1,60 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication_coordination_glue/role.hpp"
#include "utils/result.hpp"
// BEGIN fwd declares
namespace memgraph::replication {
struct ReplicationState;
struct ReplicationServerConfig;
struct ReplicationClientConfig;
} // namespace memgraph::replication
namespace memgraph::query {
enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED };
enum class UnregisterReplicaResult : uint8_t {
NOT_MAIN,
COULD_NOT_BE_PERSISTED,
CAN_NOT_UNREGISTER,
SUCCESS,
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
struct ReplicationQueryHandler {
virtual ~ReplicationQueryHandler() = default;
// as REPLICA, become MAIN
virtual bool SetReplicationRoleMain() = 0;
// as MAIN, become REPLICA
virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) = 0;
// as MAIN, define and connect to REPLICAs
virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError> = 0;
virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError> = 0;
// as MAIN, remove a REPLICA connection
virtual auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult = 0;
// Helper pass-through (TODO: remove)
virtual auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole = 0;
virtual bool IsMain() const = 0;
virtual bool IsReplica() const = 0;
};
} // namespace memgraph::query

View File

@ -6,7 +6,6 @@ target_sources(mg-replication
include/replication/epoch.hpp include/replication/epoch.hpp
include/replication/config.hpp include/replication/config.hpp
include/replication/status.hpp include/replication/status.hpp
include/replication/messages.hpp
include/replication/replication_client.hpp include/replication/replication_client.hpp
include/replication/replication_server.hpp include/replication/replication_server.hpp
@ -15,7 +14,6 @@ target_sources(mg-replication
epoch.cpp epoch.cpp
config.cpp config.cpp
status.cpp status.cpp
messages.cpp
replication_client.cpp replication_client.cpp
replication_server.cpp replication_server.cpp
) )

View File

@ -1,47 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "rpc/messages.hpp"
#include "slk/serialization.hpp"
namespace memgraph::replication {
struct SystemHeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder);
SystemHeartbeatReq() = default;
};
struct SystemHeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
SystemHeartbeatRes() = default;
explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {}
uint64_t system_timestamp;
};
using SystemHeartbeatRpc = rpc::RequestResponse<SystemHeartbeatReq, SystemHeartbeatRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/);
} // namespace memgraph::slk

View File

@ -41,7 +41,8 @@ struct ReplicationClient {
void StartFrequentCheck(F &&callback) { void StartFrequentCheck(F &&callback) {
// Help the user to get the most accurate replica state possible. // Help the user to get the most accurate replica state possible.
if (replica_check_frequency_ > std::chrono::seconds(0)) { if (replica_check_frequency_ > std::chrono::seconds(0)) {
replica_checker_.Run("Replica Checker", replica_check_frequency_, replica_checker_.Run(
"Replica Checker", replica_check_frequency_,
[this, cb = std::forward<F>(callback), reconnect = false]() mutable { [this, cb = std::forward<F>(callback), reconnect = false]() mutable {
try { try {
{ {
@ -61,6 +62,46 @@ struct ReplicationClient {
} }
} }
//! \tparam RPC An rpc::RequestResponse
//! \tparam Args the args type
//! \param client the client to use for rpc communication
//! \param check predicate to check response is ok
//! \param args arguments to forward to the rpc request
//! \return If replica stream is completed or enqueued
template <typename RPC, typename... Args>
bool SteamAndFinalizeDelta(auto &&check, Args &&...args) {
try {
auto stream = rpc_client_.template Stream<RPC>(std::forward<Args>(args)...);
auto task = [this, check = std::forward<decltype(check)>(check), stream = std::move(stream)]() mutable {
if (stream.IsDefunct()) {
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
try {
if (check(stream.AwaitResponse())) {
return true;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// swallow error, fallthrough to error handling
}
// This replica needs SYSTEM recovery
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
};
if (mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) {
thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); });
return true;
}
return task();
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// This replica needs SYSTEM recovery
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
};
std::string name_; std::string name_;
communication::ClientContext rpc_context_; communication::ClientContext rpc_context_;
rpc::Client rpc_client_; rpc::Client rpc_client_;

View File

@ -1,52 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication/messages.hpp"
namespace memgraph::replication {
constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq",
nullptr};
constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes",
nullptr};
void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace memgraph::replication
namespace memgraph::slk {
// Serialize code for SystemHeartbeatRes
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.system_timestamp, builder);
}
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->system_timestamp, reader);
}
// Serialize code for SystemHeartbeatReq
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) {
/* Nothing to serialize */
}
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) {
/* Nothing to serialize */
}
} // namespace memgraph::slk

View File

@ -0,0 +1,17 @@
add_library(mg-replication_handler STATIC)
add_library(mg::replication_handler ALIAS mg-replication_handler)
target_sources(mg-replication_handler
PUBLIC
include/replication_handler/replication_handler.hpp
include/replication_handler/system_replication.hpp
include/replication_handler/system_rpc.hpp
PRIVATE
replication_handler.cpp
system_replication.cpp
system_rpc.cpp
)
target_include_directories(mg-replication_handler PUBLIC include)
target_link_libraries(mg-replication_handler
PUBLIC mg-auth mg-dbms mg-replication)

View File

@ -0,0 +1,220 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "dbms/dbms_handler.hpp"
#include "replication/include/replication/state.hpp"
#include "replication_handler/system_rpc.hpp"
#include "utils/result.hpp"
namespace memgraph::replication {
inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client);
#ifdef MG_ENTERPRISE
void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth);
#else
void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler);
#endif
#ifdef MG_ENTERPRISE
// TODO: Split into 2 functions: dbms and auth
// When being called by interpreter no need to gain lock, it should already be under a system transaction
// But concurrently the FrequentCheck is running and will need to lock before reading last_committed_system_timestamp_
template <bool REQUIRE_LOCK = false>
void SystemRestore(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth) {
// Check if system is up to date
if (client.state_.WithLock(
[](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; }))
return;
// Try to recover...
{
struct DbInfo {
std::vector<storage::SalientConfig> configs;
uint64_t last_committed_timestamp;
};
DbInfo db_info = std::invoke([&] {
auto guard = std::invoke([&]() -> std::optional<memgraph::system::TransactionGuard> {
if constexpr (REQUIRE_LOCK) {
return system->GenTransactionGuard();
}
return std::nullopt;
});
if (license::global_license_checker.IsEnterpriseValidFast()) {
auto configs = std::vector<storage::SalientConfig>{};
dbms_handler.ForEach([&configs](dbms::DatabaseAccess acc) { configs.emplace_back(acc->config().salient); });
// TODO: This is `SystemRestore` maybe DbInfo is incorrect as it will need Auth also
return DbInfo{configs, system->LastCommittedSystemTimestamp()};
}
// No license -> send only default config
return DbInfo{{dbms_handler.Get()->config().salient}, system->LastCommittedSystemTimestamp()};
});
try {
auto stream = std::invoke([&]() {
// Handle only default database is no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
return client.rpc_client_.Stream<replication::SystemRecoveryRpc>(
db_info.last_committed_timestamp, std::move(db_info.configs), auth::Auth::Config{},
std::vector<auth::User>{}, std::vector<auth::Role>{});
}
return auth.WithLock([&](auto &locked_auth) {
return client.rpc_client_.Stream<replication::SystemRecoveryRpc>(
db_info.last_committed_timestamp, std::move(db_info.configs), locked_auth.GetConfig(),
locked_auth.AllUsers(), locked_auth.AllRoles());
});
});
const auto response = stream.AwaitResponse();
if (response.result == replication::SystemRecoveryRes::Result::FAILURE) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
}
// Successfully recovered
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; });
}
#endif
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler {
#ifdef MG_ENTERPRISE
explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system,
memgraph::auth::SynchedAuth &auth);
#else
explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler);
#endif
// as REPLICA, become MAIN
bool SetReplicationRoleMain() override;
// as MAIN, become REPLICA
bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) override;
// as MAIN, define and connect to REPLICAs
auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override;
auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override;
// as MAIN, remove a REPLICA connection
auto UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult override;
bool DoReplicaToMainPromotion();
// Helper pass-through (TODO: remove)
auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole override;
bool IsMain() const override;
bool IsReplica() const override;
private:
template <bool HandleFailure>
auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!");
auto maybe_client = repl_state_.RegisterReplica(config);
if (maybe_client.HasError()) {
switch (maybe_client.GetError()) {
case memgraph::replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case memgraph::replication::RegisterReplicaError::NAME_EXISTS:
return memgraph::query::RegisterReplicaError::NAME_EXISTS;
case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS:
return memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS;
case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case memgraph::replication::RegisterReplicaError::SUCCESS:
break;
}
}
if (!memgraph::dbms::allow_mt_repl && dbms_handler_.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
#ifdef MG_ENTERPRISE
// Update system before enabling individual storage <-> replica clients
SystemRestore(*maybe_client.GetValue(), system_, dbms_handler_, auth_);
#endif
const auto dbms_error = HandleRegisterReplicaStatus(maybe_client);
if (dbms_error.has_value()) {
return *dbms_error;
}
auto &instance_client_ptr = maybe_client.GetValue();
bool all_clients_good = true;
// Add database specific clients (NOTE Currently all databases are connected to each replica)
dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
if (!dbms::allow_mt_repl && storage->name() != dbms::kDefaultDB) {
return;
}
// TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes
if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return;
all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock(
[storage, &instance_client_ptr, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT
auto client = std::make_unique<storage::ReplicationStorageClient>(*instance_client_ptr);
// All good, start replica client
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due an error like branching of MAIN and REPLICA
const bool success = client->State() != storage::replication::ReplicaState::MAYBE_BEHIND;
if (HandleFailure || success) {
storage_clients.push_back(std::move(client));
}
return success;
});
});
// NOTE Currently if any databases fails, we revert back
if (!HandleFailure && !all_clients_good) {
spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name);
UnregisterReplica(config.name);
return memgraph::query::RegisterReplicaError::CONNECTION_FAILED;
}
// No client error, start instance level client
#ifdef MG_ENTERPRISE
StartReplicaClient(*instance_client_ptr, system_, dbms_handler_, auth_);
#else
StartReplicaClient(*instance_client_ptr, dbms_handler_);
#endif
return {};
}
memgraph::replication::ReplicationState &repl_state_;
memgraph::dbms::DbmsHandler &dbms_handler_;
#ifdef MG_ENTERPRISE
memgraph::system::System *system_;
memgraph::auth::SynchedAuth &auth_;
#endif
};
} // namespace memgraph::replication

View File

@ -0,0 +1,31 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "dbms/dbms_handler.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader,
slk::Builder *res_builder);
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth);
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, auth::SynchedAuth &auth);
#else
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data);
#endif
} // namespace memgraph::replication

View File

@ -0,0 +1,95 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <vector>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "rpc/messages.hpp"
#include "storage/v2/config.hpp"
namespace memgraph::replication {
struct SystemHeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder);
SystemHeartbeatReq() = default;
};
struct SystemHeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
SystemHeartbeatRes() = default;
explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {}
uint64_t system_timestamp;
};
using SystemHeartbeatRpc = rpc::RequestResponse<SystemHeartbeatReq, SystemHeartbeatRes>;
struct SystemRecoveryReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder);
SystemRecoveryReq() = default;
SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector<storage::SalientConfig> database_configs,
auth::Auth::Config auth_config, std::vector<auth::User> users, std::vector<auth::Role> roles)
: forced_group_timestamp{forced_group_timestamp},
database_configs(std::move(database_configs)),
auth_config(std::move(auth_config)),
users{std::move(users)},
roles{std::move(roles)} {}
uint64_t forced_group_timestamp;
std::vector<storage::SalientConfig> database_configs;
auth::Auth::Config auth_config;
std::vector<auth::User> users;
std::vector<auth::Role> roles;
};
struct SystemRecoveryRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder);
SystemRecoveryRes() = default;
explicit SystemRecoveryRes(Result res) : result(res) {}
Result result;
};
using SystemRecoveryRpc = rpc::RequestResponse<SystemRecoveryReq, SystemRecoveryRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/);
void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

View File

@ -0,0 +1,291 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/replication_handler.hpp"
#include "dbms/dbms_handler.hpp"
#include "replication_handler/system_replication.hpp"
namespace memgraph::replication {
namespace {
#ifdef MG_ENTERPRISE
void RecoverReplication(memgraph::replication::ReplicationState &repl_state, memgraph::system::System *system,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth) {
/*
* REPLICATION RECOVERY AND STARTUP
*/
// Startup replication state (if recovered at startup)
auto replica = [&dbms_handler, &auth](memgraph::replication::RoleReplicaData const &data) {
return memgraph::replication::StartRpcServer(dbms_handler, data, auth);
};
// Replication recovery and frequent check start
auto main = [system, &dbms_handler, &auth](memgraph::replication::RoleMainData &mainData) {
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::SystemRestore(client, system, dbms_handler, auth);
}
// DBMS here
dbms_handler.ForEach([&mainData](memgraph::dbms::DatabaseAccess db_acc) {
dbms::DbmsHandler::RecoverStorageReplication(std::move(db_acc), mainData);
});
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::StartReplicaClient(client, system, dbms_handler, auth);
}
// Warning
if (dbms_handler.default_config().durability.snapshot_wal_mode ==
memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
}
return true;
};
auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData());
MG_ASSERT(result, "Replica recovery failure!");
}
#else
void RecoverReplication(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler) {
// Startup replication state (if recovered at startup)
auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData const &data) {
return memgraph::replication::StartRpcServer(dbms_handler, data);
};
// Replication recovery and frequent check start
auto main = [&dbms_handler](memgraph::replication::RoleMainData &mainData) {
dbms::DbmsHandler::RecoverStorageReplication(dbms_handler.Get(), mainData);
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::StartReplicaClient(client, dbms_handler);
}
// Warning
if (dbms_handler.default_config().durability.snapshot_wal_mode ==
memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
}
return true;
};
auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData());
MG_ASSERT(result, "Replica recovery failure!");
}
#endif
} // namespace
inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client) {
if (instance_client.HasError()) switch (instance_client.GetError()) {
case replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case replication::RegisterReplicaError::NAME_EXISTS:
return query::RegisterReplicaError::NAME_EXISTS;
case replication::RegisterReplicaError::ENDPOINT_EXISTS:
return query::RegisterReplicaError::ENDPOINT_EXISTS;
case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return query::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case replication::RegisterReplicaError::SUCCESS:
break;
}
return {};
}
#ifdef MG_ENTERPRISE
void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth) {
#else
void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler) {
#endif
// No client error, start instance level client
auto const &endpoint = client.rpc_client_.Endpoint();
spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port);
client.StartFrequentCheck([&,
#ifdef MG_ENTERPRISE
system = system,
#endif
license = license::global_license_checker.IsEnterpriseValidFast()](
bool reconnect, replication::ReplicationClient &client) mutable {
// Working connection
// Check if system needs restoration
if (reconnect) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
// Check if license has changed
const auto new_license = license::global_license_checker.IsEnterpriseValidFast();
if (new_license != license) {
license = new_license;
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
#ifdef MG_ENTERPRISE
SystemRestore<true>(client, system, dbms_handler, auth);
#endif
// Check if any database has been left behind
dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) {
// Specific database <-> replica client
db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) {
if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
// Database <-> replica might be behind, check and recover
client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc);
}
});
});
});
}
#ifdef MG_ENTERPRISE
ReplicationHandler::ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system,
memgraph::auth::SynchedAuth &auth)
: repl_state_{repl_state}, dbms_handler_{dbms_handler}, system_{system}, auth_{auth} {
RecoverReplication(repl_state_, system_, dbms_handler_, auth_);
}
#else
ReplicationHandler::ReplicationHandler(replication::ReplicationState &repl_state, dbms::DbmsHandler &dbms_handler)
: repl_state_{repl_state}, dbms_handler_{dbms_handler} {
RecoverReplication(repl_state_, dbms_handler_);
}
#endif
bool ReplicationHandler::SetReplicationRoleMain() {
auto const main_handler = [](memgraph::replication::RoleMainData &) {
// If we are already MAIN, we don't want to change anything
return false;
};
auto const replica_handler = [this](memgraph::replication::RoleReplicaData const &) {
return DoReplicaToMainPromotion();
};
// TODO: under lock
return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
}
bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) {
// We don't want to restart the server if we're already a REPLICA
if (repl_state_.IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler_.ForEach([&](memgraph::dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<memgraph::replication::RoleMainData>(repl_state_.ReplicationData()).registered_replicas_.clear();
// Creates the server
repl_state_.SetReplicationRoleReplica(config);
// Start
const auto success =
std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData const &) {
// ASSERT
return false;
},
[this](memgraph::replication::RoleReplicaData const &data) {
#ifdef MG_ENTERPRISE
return StartRpcServer(dbms_handler_, data, auth_);
#else
return StartRpcServer(dbms_handler_, data);
#endif
}},
repl_state_.ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
bool ReplicationHandler::DoReplicaToMainPromotion() {
// STEP 1) bring down all REPLICA servers
dbms_handler_.ForEach([](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
// Remember old epoch + storage timestamp association
storage->PrepareForNewEpoch();
});
// STEP 2) Change to MAIN
// TODO: restore replication servers if false?
if (!repl_state_.SetReplicationRoleMain()) {
// TODO: Handle recovery on failure???
return false;
}
// STEP 3) We are now MAIN, update storage local epoch
const auto &epoch = std::get<replication::RoleMainData>(std::as_const(repl_state_).ReplicationData()).epoch_;
dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.epoch_ = epoch;
});
return true;
};
// as MAIN, define and connect to REPLICAs
auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
return RegisterReplica_<false>(config);
}
auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
return RegisterReplica_<true>(config);
}
auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult {
auto const replica_handler =
[](memgraph::replication::RoleReplicaData const &) -> memgraph::query::UnregisterReplicaResult {
return memgraph::query::UnregisterReplicaResult::NOT_MAIN;
};
auto const main_handler =
[this, name](memgraph::replication::RoleMainData &mainData) -> memgraph::query::UnregisterReplicaResult {
if (!repl_state_.TryPersistUnregisterReplica(name)) {
return memgraph::query::UnregisterReplicaResult::COULD_NOT_BE_PERSISTED;
}
// Remove database specific clients
dbms_handler_.ForEach([name](memgraph::dbms::DatabaseAccess db_acc) {
db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) {
std::erase_if(clients, [name](const auto &client) { return client->Name() == name; });
});
});
// Remove instance level clients
auto const n_unregistered =
std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; });
return n_unregistered != 0 ? memgraph::query::UnregisterReplicaResult::SUCCESS
: memgraph::query::UnregisterReplicaResult::CAN_NOT_UNREGISTER;
};
return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
}
auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole {
return repl_state_.GetRole();
}
bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); }
bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); }
} // namespace memgraph::replication

View File

@ -0,0 +1,115 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/system_replication.hpp"
#include <spdlog/spdlog.h>
#include "auth/replication_handlers.hpp"
#include "dbms/replication_handlers.hpp"
#include "license/license.hpp"
#include "replication_handler/system_rpc.hpp"
namespace memgraph::replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) {
replication::SystemHeartbeatRes res{0};
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling SystemHeartbeat, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
replication::SystemHeartbeatReq req;
replication::SystemHeartbeatReq::Load(&req, req_reader);
res = replication::SystemHeartbeatRes{ts};
memgraph::slk::Save(res, res_builder);
}
void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader,
slk::Builder *res_builder) {
using memgraph::replication::SystemRecoveryRes;
SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE);
utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); });
memgraph::replication::SystemRecoveryReq req;
memgraph::slk::Load(&req, req_reader);
/*
* DBMS
*/
if (!dbms::SystemRecoveryHandler(dbms_handler, req.database_configs)) return; // Failure sent on exit
/*
* AUTH
*/
if (!auth::SystemRecoveryHandler(auth, req.auth_config, req.users, req.roles)) return; // Failure sent on exit
/*
* SUCCESSFUL RECOVERY
*/
system_state_access.SetLastCommitedTS(req.forced_group_timestamp);
spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp);
res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS);
}
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth) {
// NOTE: Register even without license as the user could add a license at run-time
// TODO: fix Register when system is removed from DbmsHandler
auto system_state_access = dbms_handler.system_->CreateSystemStateAccess();
// System
data.server->rpc_server_.Register<replication::SystemHeartbeatRpc>(
[system_state_access](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemHeartbeatRpc");
SystemHeartbeatHandler(system_state_access.LastCommitedTS(), req_reader, res_builder);
});
data.server->rpc_server_.Register<replication::SystemRecoveryRpc>(
[system_state_access, &dbms_handler, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received SystemRecoveryRpc");
SystemRecoveryHandler(system_state_access, dbms_handler, auth, req_reader, res_builder);
});
// DBMS
dbms::Register(data, system_state_access, dbms_handler);
// Auth
auth::Register(data, system_state_access, auth);
}
#endif
#ifdef MG_ENTERPRISE
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data,
auth::SynchedAuth &auth) {
#else
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) {
#endif
// Register storage handlers
dbms::InMemoryReplicationHandlers::Register(&dbms_handler, *data.server);
#ifdef MG_ENTERPRISE
// Register system handlers
Register(data, dbms_handler, auth);
#endif
// Start server
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
}
} // namespace memgraph::replication

View File

@ -0,0 +1,111 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/system_rpc.hpp"
#include <json/json.hpp>
#include "auth/rpc.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/enum.hpp"
namespace memgraph::slk {
// Serialize code for SystemHeartbeatRes
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.system_timestamp, builder);
}
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->system_timestamp, reader);
}
// Serialize code for SystemHeartbeatReq
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) {
/* Nothing to serialize */
}
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) {
/* Nothing to serialize */
}
// Serialize code for SystemRecoveryReq
void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.forced_group_timestamp, builder);
memgraph::slk::Save(self.database_configs, builder);
memgraph::slk::Save(self.auth_config, builder);
memgraph::slk::Save(self.users, builder);
memgraph::slk::Save(self.roles, builder);
}
void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->forced_group_timestamp, reader);
memgraph::slk::Load(&self->database_configs, reader);
memgraph::slk::Load(&self->auth_config, reader);
memgraph::slk::Load(&self->users, reader);
memgraph::slk::Load(&self->roles, reader);
}
// Serialize code for SystemRecoveryRes
void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace memgraph::slk
namespace memgraph::replication {
constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq",
nullptr};
constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes",
nullptr};
constexpr utils::TypeInfo SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ, "SystemRecoveryReq",
nullptr};
constexpr utils::TypeInfo SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES, "SystemRecoveryRes",
nullptr};
void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace memgraph::replication

View File

@ -59,39 +59,6 @@ void TimestampRes::Save(const TimestampRes &self, memgraph::slk::Builder *builde
memgraph::slk::Save(self, builder); memgraph::slk::Save(self, builder);
} }
void TimestampRes::Load(TimestampRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } void TimestampRes::Load(TimestampRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace storage::replication } // namespace storage::replication
constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ, constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ,
@ -130,24 +97,6 @@ constexpr utils::TypeInfo storage::replication::TimestampReq::kType{utils::TypeI
constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes", constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes",
nullptr}; nullptr};
constexpr utils::TypeInfo storage::replication::CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ,
"CreateDatabaseReq", nullptr};
constexpr utils::TypeInfo storage::replication::CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES,
"CreateDatabaseRes", nullptr};
constexpr utils::TypeInfo storage::replication::DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ,
"DropDatabaseReq", nullptr};
constexpr utils::TypeInfo storage::replication::DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES,
"DropDatabaseRes", nullptr};
constexpr utils::TypeInfo storage::replication::SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ,
"SystemRecoveryReq", nullptr};
constexpr utils::TypeInfo storage::replication::SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES,
"SystemRecoveryRes", nullptr};
// Autogenerated SLK serialization code // Autogenerated SLK serialization code
namespace slk { namespace slk {
// Serialize code for TimestampRes // Serialize code for TimestampRes
@ -316,91 +265,5 @@ void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader)
memgraph::slk::Load(&self->items.enable_schema_metadata, reader); memgraph::slk::Load(&self->items.enable_schema_metadata, reader);
} }
// Serialize code for CreateDatabaseReq
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.config, builder);
}
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->config, reader);
}
// Serialize code for CreateDatabaseRes
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for DropDatabaseReq
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.uuid, builder);
}
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->uuid, reader);
}
// Serialize code for DropDatabaseRes
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for SystemRecoveryReq
void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.forced_group_timestamp, builder);
memgraph::slk::Save(self.database_configs, builder);
}
void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->forced_group_timestamp, reader);
memgraph::slk::Load(&self->database_configs, reader);
}
// Serialize code for SystemRecoveryRes
void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace slk } // namespace slk
} // namespace memgraph } // namespace memgraph

View File

@ -201,108 +201,6 @@ struct TimestampRes {
using TimestampRpc = rpc::RequestResponse<TimestampReq, TimestampRes>; using TimestampRpc = rpc::RequestResponse<TimestampReq, TimestampRes>;
struct CreateDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder);
CreateDatabaseReq() = default;
CreateDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
storage::SalientConfig config)
: epoch_id(std::move(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
config(std::move(config)) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
storage::SalientConfig config;
};
struct CreateDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder);
CreateDatabaseRes() = default;
explicit CreateDatabaseRes(Result res) : result(res) {}
Result result;
};
using CreateDatabaseRpc = rpc::RequestResponse<CreateDatabaseReq, CreateDatabaseRes>;
struct DropDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder);
DropDatabaseReq() = default;
DropDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
const utils::UUID &uuid)
: epoch_id(std::move(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
uuid(uuid) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
utils::UUID uuid;
};
struct DropDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder);
DropDatabaseRes() = default;
explicit DropDatabaseRes(Result res) : result(res) {}
Result result;
};
using DropDatabaseRpc = rpc::RequestResponse<DropDatabaseReq, DropDatabaseRes>;
struct SystemRecoveryReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder);
SystemRecoveryReq() = default;
SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector<storage::SalientConfig> database_configs)
: forced_group_timestamp{forced_group_timestamp}, database_configs(std::move(database_configs)) {}
uint64_t forced_group_timestamp;
std::vector<storage::SalientConfig> database_configs;
};
struct SystemRecoveryRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder);
SystemRecoveryRes() = default;
explicit SystemRecoveryRes(Result res) : result(res) {}
Result result;
};
using SystemRecoveryRpc = rpc::RequestResponse<SystemRecoveryReq, SystemRecoveryRes>;
} // namespace memgraph::storage::replication } // namespace memgraph::storage::replication
// SLK serialization declarations // SLK serialization declarations
@ -356,28 +254,8 @@ void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph:
void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader); void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder); void Save(const memgraph::storage::SalientConfig &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader); void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk } // namespace memgraph::slk

23
src/system/CMakeLists.txt Normal file
View File

@ -0,0 +1,23 @@
add_library(mg-system STATIC)
add_library(mg::system ALIAS mg-system)
target_sources(mg-system
PUBLIC
include/system/action.hpp
include/system/system.hpp
include/system/transaction.hpp
include/system/state.hpp
PRIVATE
action.cpp
system.cpp
transaction.cpp
state.cpp
)
target_include_directories(mg-system PUBLIC include)
target_link_libraries(mg-system
PUBLIC
mg::replication
)

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -8,14 +8,4 @@
// the Business Source License, use of this software will be governed // the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file // by the Apache License, Version 2.0, included in the file
// licenses/APL.txt. // licenses/APL.txt.
#include "system/include/system/action.hpp"
#pragma once
#include "dbms/dbms_handler.hpp"
#include "replication/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client);
} // namespace memgraph::dbms

View File

@ -0,0 +1,38 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication/epoch.hpp"
#include "replication/replication_client.hpp"
#include "replication/state.hpp"
namespace memgraph::system {
struct Transaction;
/// The system action interface that subsystems will implement. This OO-style separation is needed so that one common
/// mechanism can be used for all subsystem replication within a system transaction, without the need for System to
/// know about all the subsystems.
struct ISystemAction {
/// Durability step which is defered until commit time
virtual void DoDurability() = 0;
/// Prepare the RPC payload that will be sent to all replicas clients
virtual bool DoReplication(memgraph::replication::ReplicationClient &client,
memgraph::replication::ReplicationEpoch const &epoch,
Transaction const &system_tx) const = 0;
virtual void PostReplication(memgraph::replication::RoleMainData &main_data) const = 0;
virtual ~ISystemAction() = default;
};
} // namespace memgraph::system

View File

@ -0,0 +1,57 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <atomic>
#include <cstdint>
#include "kvstore/kvstore.hpp"
#include "utils/file.hpp"
namespace memgraph::system {
namespace {
constexpr std::string_view kLastCommitedSystemTsKey = "last_committed_system_ts"; // Key for timestamp durability
}
struct State {
explicit State(std::optional<std::filesystem::path> storage, bool recovery_on_startup);
void FinalizeTransaction(std::uint64_t timestamp) {
if (durability_) {
durability_->Put(kLastCommitedSystemTsKey, std::to_string(timestamp));
}
last_committed_system_timestamp_.store(timestamp);
}
auto LastCommittedSystemTimestamp() -> uint64_t { return last_committed_system_timestamp_.load(); }
private:
friend struct ReplicaHandlerAccessToState;
friend struct Transaction;
std::optional<kvstore::KVStore> durability_;
std::atomic_uint64_t last_committed_system_timestamp_{};
};
struct ReplicaHandlerAccessToState {
explicit ReplicaHandlerAccessToState(memgraph::system::State &state) : state_{&state} {}
auto LastCommitedTS() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); }
void SetLastCommitedTS(uint64_t new_timestamp) { state_->last_committed_system_timestamp_.store(new_timestamp); }
private:
State *state_;
};
} // namespace memgraph::system

View File

@ -0,0 +1,52 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "system/state.hpp"
#include "system/transaction.hpp"
namespace memgraph::system {
struct TransactionGuard {
explicit TransactionGuard(std::unique_lock<std::timed_mutex> guard) : guard_(std::move(guard)) {}
private:
std::unique_lock<std::timed_mutex> guard_;
};
struct System {
// NOTE: default arguments to make testing easier.
System(std::optional<std::filesystem::path> storage = std::nullopt, bool recovery_on_startup = false)
: state_(std::move(storage), recovery_on_startup), timestamp_{state_.LastCommittedSystemTimestamp()} {}
auto TryCreateTransaction(std::chrono::microseconds try_time = std::chrono::milliseconds{100})
-> std::optional<Transaction> {
auto system_unique = std::unique_lock{mtx_, std::defer_lock};
if (!system_unique.try_lock_for(try_time)) {
return std::nullopt;
}
return Transaction{state_, std::move(system_unique), timestamp_++};
}
// TODO: this and LastCommittedSystemTimestamp maybe not needed
auto GenTransactionGuard() -> TransactionGuard { return TransactionGuard{std::unique_lock{mtx_}}; }
auto LastCommittedSystemTimestamp() -> uint64_t { return state_.LastCommittedSystemTimestamp(); }
auto CreateSystemStateAccess() -> ReplicaHandlerAccessToState { return ReplicaHandlerAccessToState{state_}; }
private:
State state_;
std::timed_mutex mtx_{};
std::uint64_t timestamp_{};
};
} // namespace memgraph::system

View File

@ -0,0 +1,124 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <list>
#include <memory>
#include <mutex>
#include <optional>
#include "replication/state.hpp"
#include "system/action.hpp"
#include "system/state.hpp"
namespace memgraph::system {
enum class AllSyncReplicaStatus : std::uint8_t {
AllCommitsConfirmed,
SomeCommitsUnconfirmed,
};
struct Transaction;
template <typename T>
concept ReplicationPolicy = requires(T handler, ISystemAction const &action, Transaction const &txn) {
{ handler.ApplyAction(action, txn) } -> std::same_as<AllSyncReplicaStatus>;
};
struct System;
struct Transaction {
template <std::derived_from<ISystemAction> TAction, typename... Args>
requires std::constructible_from<TAction, Args...>
void AddAction(Args &&...args) { actions_.emplace_back(std::make_unique<TAction>(std::forward<Args>(args)...)); }
template <ReplicationPolicy Handler>
auto Commit(Handler handler) -> AllSyncReplicaStatus {
if (!lock_.owns_lock() || actions_.empty()) {
// If no actions, we do not increment the last commited ts, since there is no delta to send to the REPLICA
Abort();
return AllSyncReplicaStatus::AllCommitsConfirmed; // TODO: some kind of error
}
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
while (!actions_.empty()) {
auto &action = actions_.front();
/// durability
action->DoDurability();
/// replication prep
auto action_sync_status = handler.ApplyAction(*action, *this);
if (action_sync_status != AllSyncReplicaStatus::AllCommitsConfirmed) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
actions_.pop_front();
}
state_->FinalizeTransaction(timestamp_);
lock_.unlock();
return sync_status;
}
void Abort() {
if (lock_.owns_lock()) {
lock_.unlock();
}
actions_.clear();
}
auto last_committed_system_timestamp() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); }
auto timestamp() const -> uint64_t { return timestamp_; }
private:
friend struct System;
Transaction(State &state, std::unique_lock<std::timed_mutex> lock, std::uint64_t timestamp)
: state_{std::addressof(state)}, lock_(std::move(lock)), timestamp_{timestamp} {}
State *state_;
std::unique_lock<std::timed_mutex> lock_;
std::uint64_t timestamp_;
std::list<std::unique_ptr<ISystemAction>> actions_;
};
struct DoReplication {
explicit DoReplication(replication::RoleMainData &main_data) : main_data_{main_data} {}
auto ApplyAction(ISystemAction const &action, Transaction const &system_tx) -> AllSyncReplicaStatus {
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
for (auto &client : main_data_.registered_replicas_) {
bool completed = action.DoReplication(client, main_data_.epoch_, system_tx);
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
action.PostReplication(main_data_);
return sync_status;
}
private:
replication::RoleMainData &main_data_;
};
static_assert(ReplicationPolicy<DoReplication>);
struct DoNothing {
auto ApplyAction(ISystemAction const & /*action*/, Transaction const & /*system_tx*/) -> AllSyncReplicaStatus {
return AllSyncReplicaStatus::AllCommitsConfirmed;
}
};
static_assert(ReplicationPolicy<DoNothing>);
} // namespace memgraph::system

57
src/system/state.cpp Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/state.hpp"
namespace memgraph::system {
namespace {
constexpr std::string_view kSystemDir = ".system";
constexpr std::string_view kVersion = "version"; // Key for version durability
constexpr std::string_view kVersionV1 = "V1"; // Value for version 1
auto InitializeSystemDurability(std::optional<std::filesystem::path> storage, bool recovery_on_startup)
-> std::optional<memgraph::kvstore::KVStore> {
if (!storage) return std::nullopt;
auto const &path = *storage;
memgraph::utils::EnsureDir(path);
auto system_dir = path / kSystemDir;
memgraph::utils::EnsureDir(system_dir);
auto durability = memgraph::kvstore::KVStore{std::move(system_dir)};
auto version = durability.Get(kVersion);
// TODO: migration schemes here in the future
if (!version || *version != kVersionV1) {
// ensure we start out with V1
durability.Put(kVersion, kVersionV1);
}
if (!recovery_on_startup) {
// reset last_committed_system_ts
durability.Delete(kLastCommitedSystemTsKey);
}
return durability;
}
auto LoadLastCommittedSystemTimestamp(std::optional<kvstore::KVStore> const &store) -> uint64_t {
auto lcst = store ? store->Get(kLastCommitedSystemTsKey) : std::nullopt;
return lcst ? std::stoul(*lcst) : 0U;
}
} // namespace
State::State(std::optional<std::filesystem::path> storage, bool recovery_on_startup)
: durability_{InitializeSystemDurability(std::move(storage), recovery_on_startup)},
last_committed_system_timestamp_{LoadLastCommittedSystemTimestamp(durability_)} {}
} // namespace memgraph::system

11
src/system/system.cpp Normal file
View File

@ -0,0 +1,11 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/include/system/system.hpp"

View File

@ -0,0 +1,11 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/include/system/transaction.hpp"

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -149,9 +149,9 @@ void Telemetry::AddClientCollector() {
} }
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) { void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state) {
AddCollector("database", [&dbms_handler]() -> nlohmann::json { AddCollector("database", [&dbms_handler, &repl_state]() -> nlohmann::json {
const auto &infos = dbms_handler.Info(); const auto &infos = dbms_handler.Info(repl_state.GetRole());
auto dbs = nlohmann::json::array(); auto dbs = nlohmann::json::array();
for (const auto &db_info : infos) { for (const auto &db_info : infos) {
dbs.push_back(memgraph::dbms::ToJson(db_info)); dbs.push_back(memgraph::dbms::ToJson(db_info));
@ -162,11 +162,10 @@ void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) {
#else #else
#endif #endif
void Telemetry::AddStorageCollector( void Telemetry::AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth,
dbms::DbmsHandler &dbms_handler, memgraph::replication::ReplicationState &repl_state) {
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> &auth) { AddCollector("storage", [&dbms_handler, &auth, &repl_state]() -> nlohmann::json {
AddCollector("storage", [&dbms_handler, &auth]() -> nlohmann::json { auto stats = dbms_handler.Stats(repl_state.GetRole());
auto stats = dbms_handler.Stats();
stats.users = auth->AllUsers().size(); stats.users = auth->AllUsers().size();
return ToJson(stats); return ToJson(stats);
}); });

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -43,12 +43,11 @@ class Telemetry final {
void AddCollector(const std::string &name, const std::function<const nlohmann::json(void)> &func); void AddCollector(const std::string &name, const std::function<const nlohmann::json(void)> &func);
// Specialized collectors // Specialized collectors
void AddStorageCollector( void AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth,
dbms::DbmsHandler &dbms_handler, memgraph::replication::ReplicationState &repl_state);
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> &auth);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler); void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state);
#else #else
void AddDatabaseCollector() { void AddDatabaseCollector() {
AddCollector("database", []() -> nlohmann::json { return nlohmann::json::array(); }); AddCollector("database", []() -> nlohmann::json { return nlohmann::json::array(); });

View File

@ -161,10 +161,22 @@ struct Gatekeeper {
~Accessor() { reset(); } ~Accessor() { reset(); }
auto get() -> T * { return std::addressof(*owner_->value_); } auto get() -> T * {
auto get() const -> const T * { return std::addressof(*owner_->value_); } if (owner_ == nullptr) return nullptr;
T *operator->() { return std::addressof(*owner_->value_); } return std::addressof(*owner_->value_);
const T *operator->() const { return std::addressof(*owner_->value_); } }
auto get() const -> const T * {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
T *operator->() {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
const T *operator->() const {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
template <typename Func> template <typename Func>
[[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult<std::invoke_result_t<Func, T &>> { [[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult<std::invoke_result_t<Func, T &>> {

View File

@ -93,6 +93,10 @@ enum class TypeId : uint64_t {
REP_SYSTEM_HEARTBEAT_RES, REP_SYSTEM_HEARTBEAT_RES,
REP_SYSTEM_RECOVERY_REQ, REP_SYSTEM_RECOVERY_REQ,
REP_SYSTEM_RECOVERY_RES, REP_SYSTEM_RECOVERY_RES,
REP_UPDATE_AUTH_DATA_REQ,
REP_UPDATE_AUTH_DATA_RES,
REP_DROP_AUTH_DATA_REQ,
REP_DROP_AUTH_DATA_RES,
// Coordinator // Coordinator
COORD_FAILOVER_REQ, COORD_FAILOVER_REQ,

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -26,6 +26,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e
class ExpansionBenchFixture : public benchmark::Fixture { class ExpansionBenchFixture : public benchmark::Fixture {
protected: protected:
std::optional<memgraph::system::System> system;
std::optional<memgraph::query::InterpreterContext> interpreter_context; std::optional<memgraph::query::InterpreterContext> interpreter_context;
std::optional<memgraph::query::Interpreter> interpreter; std::optional<memgraph::query::Interpreter> interpreter;
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk; std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk;
@ -40,7 +41,14 @@ class ExpansionBenchFixture : public benchmark::Fixture {
auto db_acc_opt = db_gk->access(); auto db_acc_opt = db_gk->access();
MG_ASSERT(db_acc_opt, "Failed to access db"); MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt; auto &db_acc = *db_acc_opt;
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value());
system.emplace();
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
auto label = db_acc->storage()->NameToLabel("Starting"); auto label = db_acc->storage()->NameToLabel("Starting");
@ -70,6 +78,7 @@ class ExpansionBenchFixture : public benchmark::Fixture {
void TearDown(const benchmark::State &) override { void TearDown(const benchmark::State &) override {
interpreter = std::nullopt; interpreter = std::nullopt;
interpreter_context = std::nullopt; interpreter_context = std::nullopt;
system.reset();
db_gk.reset(); db_gk.reset();
std::filesystem::remove_all(data_directory); std::filesystem::remove_all(data_directory);
} }

View File

@ -105,7 +105,9 @@ def is_port_in_use(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0 return s.connect_ex(("localhost", port)) == 0
def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_directory): def _start_instance(
name, args, log_file, setup_queries, use_ssl, procdir, data_directory, username=None, password=None
):
assert ( assert (
name not in MEMGRAPH_INSTANCES.keys() name not in MEMGRAPH_INSTANCES.keys()
), "If this raises, you are trying to start an instance with the same name than one already running." ), "If this raises, you are trying to start an instance with the same name than one already running."
@ -115,7 +117,9 @@ def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_
log_file_path = os.path.join(BUILD_DIR, "logs", log_file) log_file_path = os.path.join(BUILD_DIR, "logs", log_file)
data_directory_path = os.path.join(BUILD_DIR, data_directory) data_directory_path = os.path.join(BUILD_DIR, data_directory)
mg_instance = MemgraphInstanceRunner(MEMGRAPH_BINARY, use_ssl, {data_directory_path}) mg_instance = MemgraphInstanceRunner(
MEMGRAPH_BINARY, use_ssl, {data_directory_path}, username=username, password=password
)
MEMGRAPH_INSTANCES[name] = mg_instance MEMGRAPH_INSTANCES[name] = mg_instance
binary_args = args + ["--log-file", log_file_path] + ["--data-directory", data_directory_path] binary_args = args + ["--log-file", log_file_path] + ["--data-directory", data_directory_path]
@ -185,8 +189,14 @@ def start_instance(context, name, procdir):
data_directory = value["data_directory"] data_directory = value["data_directory"]
else: else:
data_directory = tempfile.TemporaryDirectory().name data_directory = tempfile.TemporaryDirectory().name
username = None
if "username" in value:
username = value["username"]
password = None
if "password" in value:
password = value["password"]
instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory) instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory, username, password)
mg_instances[name] = instance mg_instances[name] = instance
assert len(mg_instances) == 1 assert len(mg_instances) == 1

View File

@ -57,7 +57,7 @@ def replace_paths(path):
class MemgraphInstanceRunner: class MemgraphInstanceRunner:
def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None): def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None, username=None, password=None):
self.host = "127.0.0.1" self.host = "127.0.0.1"
self.bolt_port = None self.bolt_port = None
self.binary_path = binary_path self.binary_path = binary_path
@ -65,12 +65,19 @@ class MemgraphInstanceRunner:
self.proc_mg = None self.proc_mg = None
self.ssl = use_ssl self.ssl = use_ssl
self.delete_on_stop = delete_on_stop self.delete_on_stop = delete_on_stop
self.username = username
self.password = password
def execute_setup_queries(self, setup_queries): def execute_setup_queries(self, setup_queries):
if setup_queries is None: if setup_queries is None:
return return
# An assumption being database instance is fresh, no need for the auth. conn = mgclient.connect(
conn = mgclient.connect(host=self.host, port=self.bolt_port, sslmode=self.ssl) host=self.host,
port=self.bolt_port,
sslmode=self.ssl,
username=(self.username or ""),
password=(self.password or ""),
)
conn.autocommit = True conn.autocommit = True
cursor = conn.cursor() cursor = conn.cursor()
for query_coll in setup_queries: for query_coll in setup_queries:

View File

@ -3,6 +3,7 @@ find_package(gflags REQUIRED)
copy_e2e_python_files(replication_experiment common.py) copy_e2e_python_files(replication_experiment common.py)
copy_e2e_python_files(replication_experiment conftest.py) copy_e2e_python_files(replication_experiment conftest.py)
copy_e2e_python_files(replication_experiment multitenancy.py) copy_e2e_python_files(replication_experiment multitenancy.py)
copy_e2e_python_files(replication_experiment auth.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." memgraph.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." memgraph.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." interactive_mg_runner.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." interactive_mg_runner.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." mg_utils.py) copy_e2e_python_files_from_parent_folder(replication_experiment ".." mg_utils.py)

View File

@ -0,0 +1,831 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import atexit
import os
import shutil
import sys
import tempfile
import time
from functools import partial
import interactive_mg_runner
import mgclient
import pytest
from common import execute_and_fetch_all
from mg_utils import mg_sleep_and_assert
interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
interactive_mg_runner.PROJECT_DIR = os.path.normpath(
os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..")
)
interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build"))
interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph"))
BOLT_PORTS = {"main": 7687, "replica_1": 7688, "replica_2": 7689}
REPLICATION_PORTS = {"replica_1": 10001, "replica_2": 10002}
TEMP_DIR = tempfile.TemporaryDirectory().name
def update_to_main(cursor):
execute_and_fetch_all(cursor, "SET REPLICATION ROLE TO MAIN;")
def add_user(cursor, username, password=None):
if password is not None:
return execute_and_fetch_all(cursor, f"CREATE USER {username} IDENTIFIED BY '{password}';")
return execute_and_fetch_all(cursor, f"CREATE USER {username};")
def show_users_func(cursor):
def func():
return set(execute_and_fetch_all(cursor, "SHOW USERS;"))
return func
def show_roles_func(cursor):
def func():
return set(execute_and_fetch_all(cursor, "SHOW ROLES;"))
return func
def show_users_for_role_func(cursor, rolename):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW USERS FOR {rolename};"))
return func
def show_role_for_user_func(cursor, username):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW ROLE FOR {username};"))
return func
def show_privileges_func(cursor, user_or_role):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW PRIVILEGES FOR {user_or_role};"))
return func
def show_database_privileges_func(cursor, user):
def func():
return execute_and_fetch_all(cursor, f"SHOW DATABASE PRIVILEGES FOR {user};")
return func
def show_database_func(cursor):
def func():
return execute_and_fetch_all(cursor, f"SHOW DATABASE;")
return func
def try_and_count(cursor, query):
try:
execute_and_fetch_all(cursor, query)
except:
return 1
return 0
def only_main_queries(cursor):
n_exceptions = 0
n_exceptions += try_and_count(cursor, f"CREATE USER user_name")
n_exceptions += try_and_count(cursor, f"SET PASSWORD FOR user_name TO 'new_password'")
n_exceptions += try_and_count(cursor, f"DROP USER user_name")
n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name")
n_exceptions += try_and_count(cursor, f"DROP ROLE role_name")
n_exceptions += try_and_count(cursor, f"CREATE USER user_name")
n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name")
n_exceptions += try_and_count(cursor, f"SET ROLE FOR user_name TO role_name")
n_exceptions += try_and_count(cursor, f"CLEAR ROLE FOR user_name")
n_exceptions += try_and_count(cursor, f"GRANT AUTH TO role_name")
n_exceptions += try_and_count(cursor, f"DENY AUTH, INDEX TO user_name")
n_exceptions += try_and_count(cursor, f"REVOKE AUTH FROM role_name")
n_exceptions += try_and_count(cursor, f"GRANT READ ON LABELS :l TO role_name;")
n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name")
n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;")
n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name")
n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;")
return n_exceptions
def main_and_repl_queries(cursor):
n_exceptions = 0
try_and_count(cursor, f"SHOW USERS")
try_and_count(cursor, f"SHOW ROLES")
try_and_count(cursor, f"SHOW USERS FOR ROLE role_name")
try_and_count(cursor, f"SHOW ROLE FOR user_name")
try_and_count(cursor, f"SHOW PRIVILEGES FOR role_name")
try_and_count(cursor, f"SHOW DATABASE PRIVILEGES FOR user_name")
return n_exceptions
def test_auth_queries_on_replica(connection):
# Goal: check that write auth queries are forbidden on REPLICAs
# 0/ Setup replication cluster
# 1/ Check queries
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor()
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor()
# 1/
assert only_main_queries(cursor_main) == 0
assert only_main_queries(cursor_replica_1) == 17
assert only_main_queries(cursor_replica_2) == 17
assert main_and_repl_queries(cursor_main) == 0
assert main_and_repl_queries(cursor_replica_1) == 0
assert main_and_repl_queries(cursor_replica_2) == 0
def test_manual_users_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ MAIN CREATE USER user1, user2
# REPLICA CREATE USER user3, user4
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1 and user2
# 2/ Check connections on REPLICAS
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE USER user3;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE USER user4 IDENTIFIED BY 'password';",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE USER user1;",
"CREATE USER user2 IDENTIFIED BY 'password';",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor = connection(BOLT_PORTS["main"], "main", "user1").cursor()
# 1/
expected_data = {("user2",), ("user1",)}
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor())
)
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor())
)
# 2/
connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor()
connection(BOLT_PORTS["replica_1"], "replica", "user2", "password").cursor()
connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor()
connection(BOLT_PORTS["replica_2"], "replica", "user2", "password").cursor()
def test_env_users_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ Set users from the environment
# MAIN gets users from the environment
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"username": "user1",
"password": "password",
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
# Start only replicas without the env user
interactive_mg_runner.stop_all(keep_directories=False)
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_1")
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_2")
# Setup user
try:
os.environ["MEMGRAPH_USER"] = "user1"
os.environ["MEMGRAPH_PASSWORD"] = "password"
# Start main
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "main")
finally:
# Cleanup
del os.environ["MEMGRAPH_USER"]
del os.environ["MEMGRAPH_PASSWORD"]
# 1/
expected_data = {("user1",)}
assert expected_data == show_users_func(connection(BOLT_PORTS["main"], "main", "user1", "password").cursor())()
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1", "password").cursor())
)
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1", "password").cursor())
)
def test_manual_roles_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ MAIN CREATE USER user1, user2
# REPLICA CREATE USER user3, user4
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1 and user2
# 2/ Check that role1 and role2 are replicated
# 3/ Check that user1 has role1
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE ROLE role3;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE ROLE role4;",
"CREATE USER user4;",
"SET ROLE FOR user4 TO role4;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE ROLE role1;",
"CREATE ROLE role2;",
"CREATE USER user2;",
"SET ROLE FOR user2 TO role2;",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
connection(BOLT_PORTS["main"], "main", "user2").cursor() # Just check if it connects
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "user2").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "user2").cursor()
# 1/
expected_data = {
("user2",),
}
mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_1))
mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_2))
# 2/
expected_data = {("role2",), ("role1",)}
mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_1))
mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_2))
# 3/
expected_data = {("role2",)}
mg_sleep_and_assert(
expected_data,
show_role_for_user_func(cursor_replica_1, "user2"),
)
mg_sleep_and_assert(
expected_data,
show_role_for_user_func(cursor_replica_2, "user2"),
)
def test_auth_config_recovery(connection):
# Goal: show we are replicating Auth::Config
# 0/ Setup auth configuration and compliant users
# 1/ Check that both MAIN and REPLICA have the same users
# 2/ Check that REPLICAS have the same config
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[A-Z]+$",
"--auth-password-permit-null=false",
"--auth-user-or-role-name-regex",
"^[O-o]+$",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE USER OPQabc IDENTIFIED BY 'PASSWORD';",
"CREATE ROLE defRST;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[0-9]+$",
"--auth-password-permit-null=true",
"--auth-user-or-role-name-regex",
"^[A-Np-z]+$",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE ROLE ABCpqr;",
"CREATE USER stuDEF;",
"CREATE USER GvHwI IDENTIFIED BY '123456';",
"SET ROLE FOR GvHwI TO ABCpqr;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[a-z]+$",
"--auth-password-permit-null=false",
"--auth-user-or-role-name-regex",
"^[A-z]+$",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE USER UsErA IDENTIFIED BY 'pass';",
"CREATE ROLE rOlE;",
"CREATE USER uSeRB IDENTIFIED BY 'word';",
"SET ROLE FOR uSeRB TO rOlE;",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
# 1/
cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor()
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor()
# 2/ Only MAIN can update users
def user_test(cursor):
with pytest.raises(mgclient.DatabaseError, match="Invalid user name."):
add_user(cursor, "UsEr1", "abcdef")
with pytest.raises(mgclient.DatabaseError, match="Null passwords aren't permitted!"):
add_user(cursor, "UsErC")
with pytest.raises(mgclient.DatabaseError, match="The user password doesn't conform to the required strength!"):
add_user(cursor, "UsErC", "123456")
user_test(cursor_main)
update_to_main(cursor_replica_1)
user_test(cursor_replica_1)
update_to_main(cursor_replica_2)
user_test(cursor_replica_2)
def test_auth_replication(connection):
# Goal: show that individual auth queries get replicated
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor_main = connection(BOLT_PORTS["main"], "main", "user1").cursor()
cursor_replica1 = connection(BOLT_PORTS["replica_1"], "replica").cursor()
cursor_replica2 = connection(BOLT_PORTS["replica_2"], "replica").cursor()
# 1/
def check(f, expected_data):
# mg_sleep_and_assert(
# REPLICA 1 is SYNC, should already be ready
assert expected_data == f(cursor_replica1)()
# )
mg_sleep_and_assert(expected_data, f(cursor_replica2))
# CREATE USER
execute_and_fetch_all(cursor_main, "CREATE USER user1")
check(
show_users_func,
{
("user1",),
},
)
execute_and_fetch_all(cursor_main, "CREATE USER user2 IDENTIFIED BY 'pass'")
check(
show_users_func,
{
("user2",),
("user1",),
},
)
connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor() # Just check connection
connection(BOLT_PORTS["replica_1"], "replica", "user2", "pass").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user2", "pass").cursor() # Just check connection
# SET PASSWORD
execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user1 TO '1234'")
execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user2 TO 'new_pass'")
connection(BOLT_PORTS["replica_1"], "replica", "user1", "1234").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user1", "1234").cursor() # Just check connection
connection(BOLT_PORTS["replica_1"], "replica", "user2", "new_pass").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user2", "new_pass").cursor() # Just check connection
# DROP USER
execute_and_fetch_all(cursor_main, "DROP USER user2")
check(
show_users_func,
{
("user1",),
},
)
execute_and_fetch_all(cursor_main, "DROP USER user1")
check(show_users_func, set())
connection(BOLT_PORTS["replica_1"], "replica").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica").cursor() # Just check connection
# CREATE ROLE
execute_and_fetch_all(cursor_main, "CREATE ROLE role1")
check(
show_roles_func,
{
("role1",),
},
)
execute_and_fetch_all(cursor_main, "CREATE ROLE role2")
check(
show_roles_func,
{
("role2",),
("role1",),
},
)
# DROP ROLE
execute_and_fetch_all(cursor_main, "DROP ROLE role2")
check(
show_roles_func,
{
("role1",),
},
)
execute_and_fetch_all(cursor_main, "DROP ROLE role1")
check(show_roles_func, set())
# SET ROLE
execute_and_fetch_all(cursor_main, "CREATE USER user3")
execute_and_fetch_all(cursor_main, "CREATE ROLE role3")
execute_and_fetch_all(cursor_main, "SET ROLE FOR user3 TO role3")
check(partial(show_role_for_user_func, username="user3"), {("role3",)})
execute_and_fetch_all(cursor_main, "CREATE USER user3b")
execute_and_fetch_all(cursor_main, "SET ROLE FOR user3b TO role3")
check(partial(show_role_for_user_func, username="user3b"), {("role3",)})
check(
partial(show_users_for_role_func, rolename="role3"),
{
("user3",),
("user3b",),
},
)
# CLEAR ROLE
execute_and_fetch_all(cursor_main, "CLEAR ROLE FOR user3")
check(partial(show_role_for_user_func, username="user3"), {("null",)})
check(
partial(show_users_for_role_func, rolename="role3"),
{
("user3b",),
},
)
# GRANT/REVOKE/DENY privileges TO user
execute_and_fetch_all(cursor_main, "CREATE USER user4")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4")
execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{
("CREATE", "GRANT", "GRANTED TO USER"),
("DELETE", "GRANT", "GRANTED TO USER"),
("SET", "GRANT", "GRANTED TO USER"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE SET FROM user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "GRANT", "GRANTED TO USER")},
)
execute_and_fetch_all(cursor_main, "DENY DELETE TO user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "DENY", "DENIED TO USER")},
)
# GRANT/REVOKE/DENY privileges TO role
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b")
execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{
("CREATE", "GRANT", "GRANTED TO ROLE"),
("DELETE", "GRANT", "GRANTED TO ROLE"),
("SET", "GRANT", "GRANTED TO ROLE"),
},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{
("CREATE", "GRANT", "GRANTED TO ROLE"),
("DELETE", "GRANT", "GRANTED TO ROLE"),
("SET", "GRANT", "GRANTED TO ROLE"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE SET FROM role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")},
)
execute_and_fetch_all(cursor_main, "DENY DELETE TO role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")},
)
# GRANT permission ON LABEL/EDGE to user/role
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b")
execute_and_fetch_all(cursor_main, "GRANT READ ON LABELS :l1 TO user4")
execute_and_fetch_all(cursor_main, "GRANT UPDATE ON LABELS :l2, :l3 TO role3")
check(
partial(show_privileges_func, user_or_role="user4"),
{
("LABEL :l1", "READ", "LABEL PERMISSION GRANTED TO USER"),
},
)
check(
partial(show_privileges_func, user_or_role="role3"),
{
("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{
("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE LABELS :l1 FROM user4")
execute_and_fetch_all(cursor_main, "REVOKE LABELS :l2 FROM role3")
check(partial(show_privileges_func, user_or_role="user4"), set())
check(
partial(show_privileges_func, user_or_role="role3"),
{("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")},
)
# GRANT/REVOKE DATABASE
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test")
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2")
execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])])
execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])])
# SET MAIN DATABASE
execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4")
execute_and_fetch_all(cursor_main, "SET MAIN DATABASE auth_test FOR user4")
# Reconnect and check current db
assert (
execute_and_fetch_all(connection(BOLT_PORTS["main"], "main", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
assert (
execute_and_fetch_all(connection(BOLT_PORTS["replica_1"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
assert (
execute_and_fetch_all(connection(BOLT_PORTS["replica_2"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
if __name__ == "__main__":
interactive_mg_runner.cleanup_directories_on_exit()
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -18,9 +18,9 @@ def connection():
connection_holder = None connection_holder = None
role_holder = None role_holder = None
def inner_connection(port, role): def inner_connection(port, role, username="", password=""):
nonlocal connection_holder, role_holder nonlocal connection_holder, role_holder
connection_holder = connect(host="localhost", port=port) connection_holder = connect(host="localhost", port=port, username=username, password=password)
role_holder = role role_holder = role
return connection_holder return connection_holder

View File

@ -2,3 +2,6 @@ workloads:
- name: "Replicate multitenancy" - name: "Replicate multitenancy"
binary: "tests/e2e/pytest_runner.sh" binary: "tests/e2e/pytest_runner.sh"
args: ["replication_experimental/multitenancy.py"] args: ["replication_experimental/multitenancy.py"]
- name: "Replicate auth data"
binary: "tests/e2e/pytest_runner.sh"
args: ["replication_experimental/auth.py"]

View File

@ -33,7 +33,7 @@ int main(int argc, char **argv) {
// Memgraph backend // Memgraph backend
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_telemetry_integration_test"}; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_telemetry_integration_test"};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{ memgraph::auth::SynchedAuth auth_{
data_directory / "auth", data_directory / "auth",
memgraph::auth::Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, "", true}}; memgraph::auth::Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, "", true}};
memgraph::glue::AuthQueryHandler auth_handler(&auth_); memgraph::glue::AuthQueryHandler auth_handler(&auth_);
@ -43,14 +43,20 @@ int main(int argc, char **argv) {
memgraph::storage::UpdatePaths(db_config, data_directory); memgraph::storage::UpdatePaths(db_config, data_directory);
memgraph::replication::ReplicationState repl_state(ReplicationStateRootPath(db_config)); memgraph::replication::ReplicationState repl_state(ReplicationStateRootPath(db_config));
memgraph::dbms::DbmsHandler dbms_handler(db_config memgraph::system::System system_state;
memgraph::dbms::DbmsHandler dbms_handler(db_config, system_state, repl_state
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
&auth_, false auth_, false
#endif #endif
); );
memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, &auth_handler, memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, system_state
&auth_checker); #ifdef MG_ENTERPRISE
,
nullptr
#endif
,
&auth_handler, &auth_checker);
memgraph::requests::Init(); memgraph::requests::Init();
memgraph::telemetry::Telemetry telemetry(FLAGS_endpoint, FLAGS_storage_directory, memgraph::utils::GenerateUUID(), memgraph::telemetry::Telemetry telemetry(FLAGS_endpoint, FLAGS_storage_directory, memgraph::utils::GenerateUUID(),
@ -65,9 +71,9 @@ int main(int argc, char **argv) {
}); });
// Memgraph specific collectors // Memgraph specific collectors
telemetry.AddStorageCollector(dbms_handler, auth_); telemetry.AddStorageCollector(dbms_handler, auth_, repl_state);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
telemetry.AddDatabaseCollector(dbms_handler); telemetry.AddDatabaseCollector(dbms_handler, repl_state);
#else #else
telemetry.AddDatabaseCollector(); telemetry.AddDatabaseCollector();
#endif #endif

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,14 @@ int main(int argc, char *argv[]) {
auto db_acc_opt = db_gk.access(); auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db"); MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt; auto &db_acc = *db_acc_opt;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
memgraph::query::Interpreter interpreter{&interpreter_context, db_acc}; memgraph::query::Interpreter interpreter{&interpreter_context, db_acc};
ResultStreamFaker stream(db_acc->storage()); ResultStreamFaker stream(db_acc->storage());

View File

@ -25,7 +25,7 @@
class AuthQueryHandlerFixture : public testing::Test { class AuthQueryHandlerFixture : public testing::Test {
protected: protected:
std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / "MG_tests_unit_auth_handler"}; std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / "MG_tests_unit_auth_handler"};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth{ memgraph::auth::SynchedAuth auth{
test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast<int>(getpid()))), test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast<int>(getpid()))),
memgraph::auth::Auth::Config{/* default */}}; memgraph::auth::Auth::Config{/* default */}};
memgraph::glue::AuthQueryHandler auth_handler{&auth}; memgraph::glue::AuthQueryHandler auth_handler{&auth};

View File

@ -10,6 +10,7 @@
// licenses/APL.txt. // licenses/APL.txt.
#include "query/auth_query_handler.hpp" #include "query/auth_query_handler.hpp"
#include "replication/state.hpp"
#include "storage/v2/config.hpp" #include "storage/v2/config.hpp"
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include <gmock/gmock.h> #include <gmock/gmock.h>
@ -43,7 +44,9 @@ std::set<std::string> GetDirs(auto path) {
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"}; std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"};
std::filesystem::path db_dir{storage_directory / "databases"}; std::filesystem::path db_dir{storage_directory / "databases"};
static memgraph::storage::Config storage_conf; static memgraph::storage::Config storage_conf;
std::unique_ptr<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>> auth; std::unique_ptr<memgraph::auth::SynchedAuth> auth;
std::unique_ptr<memgraph::system::System> system_state;
std::unique_ptr<memgraph::replication::ReplicationState> repl_state;
// Let this be global so we can test it different states throughout // Let this be global so we can test it different states throughout
@ -64,14 +67,18 @@ class TestEnvironment : public ::testing::Environment {
std::filesystem::remove_all(storage_directory); std::filesystem::remove_all(storage_directory);
} }
} }
auth = auth = std::make_unique<memgraph::auth::SynchedAuth>(storage_directory / "auth",
std::make_unique<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>>( memgraph::auth::Auth::Config{/* default */});
storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); system_state = std::make_unique<memgraph::system::System>();
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, auth.get(), false); repl_state = std::make_unique<memgraph::replication::ReplicationState>(ReplicationStateRootPath(storage_conf));
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, *system_state.get(), *repl_state.get(),
*auth.get(), false);
} }
void TearDown() override { void TearDown() override {
ptr_.reset(); ptr_.reset();
repl_state.reset();
system_state.reset();
auth.reset(); auth.reset();
std::filesystem::remove_all(storage_directory); std::filesystem::remove_all(storage_directory);
} }

View File

@ -28,7 +28,9 @@
// Global // Global
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler_community"}; std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler_community"};
static memgraph::storage::Config storage_conf; static memgraph::storage::Config storage_conf;
std::unique_ptr<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>> auth; std::unique_ptr<memgraph::auth::SynchedAuth> auth;
std::unique_ptr<memgraph::system::System> system_state;
std::unique_ptr<memgraph::replication::ReplicationState> repl_state;
// Let this be global so we can test it different states throughout // Let this be global so we can test it different states throughout
@ -49,14 +51,17 @@ class TestEnvironment : public ::testing::Environment {
std::filesystem::remove_all(storage_directory); std::filesystem::remove_all(storage_directory);
} }
} }
auth = auth = std::make_unique<memgraph::auth::SynchedAuth>(storage_directory / "auth",
std::make_unique<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>>( memgraph::auth::Auth::Config{/* default */});
storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}); system_state = std::make_unique<memgraph::system::System>();
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf); repl_state = std::make_unique<memgraph::replication::ReplicationState>(ReplicationStateRootPath(storage_conf));
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, *system_state.get(), *repl_state.get());
} }
void TearDown() override { void TearDown() override {
ptr_.reset(); ptr_.reset();
repl_state.reset();
system_state.reset();
auth.reset(); auth.reset();
std::filesystem::remove_all(storage_directory); std::filesystem::remove_all(storage_directory);
} }

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -94,7 +94,16 @@ class InterpreterTest : public ::testing::Test {
}() // iile }() // iile
}; };
memgraph::query::InterpreterContext interpreter_context{{}, kNoHandler, &repl_state}; memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
kNoHandler,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
void TearDown() override { void TearDown() override {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) { if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
@ -1150,8 +1159,16 @@ TYPED_TEST(InterpreterTest, AllowLoadCsvConfig) {
<< "Wrong storage mode!"; << "Wrong storage mode!";
memgraph::replication::ReplicationState repl_state{std::nullopt}; memgraph::replication::ReplicationState repl_state{std::nullopt};
memgraph::query::InterpreterContext csv_interpreter_context{ memgraph::system::System system_state;
{.query = {.allow_load_csv = allow_load_csv}}, nullptr, &repl_state}; memgraph::query::InterpreterContext csv_interpreter_context{{.query = {.allow_load_csv = allow_load_csv}},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc}; InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc};
for (const auto &query : queries) { for (const auto &query : queries) {
if (allow_load_csv) { if (allow_load_csv) {

View File

@ -14,6 +14,7 @@
#include <filesystem> #include <filesystem>
#include <thread> #include <thread>
#include "auth/auth.hpp"
#include "communication/bolt/v1/value.hpp" #include "communication/bolt/v1/value.hpp"
#include "communication/result_stream_faker.hpp" #include "communication/result_stream_faker.hpp"
#include "csv/parsing.hpp" #include "csv/parsing.hpp"
@ -99,8 +100,17 @@ class MultiTenantTest : public ::testing::Test {
struct MinMemgraph { struct MinMemgraph {
explicit MinMemgraph(const memgraph::storage::Config &conf) explicit MinMemgraph(const memgraph::storage::Config &conf)
: auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}},
dbms{conf, &auth, true}, repl_state{ReplicationStateRootPath(conf)},
interpreter_context{{}, &dbms, &dbms.ReplicationState()} { dbms{conf, system, repl_state, auth, true},
interpreter_context{{},
&dbms,
&repl_state,
system
#ifdef MG_ENTERPRISE
,
nullptr
#endif
} {
memgraph::utils::global_settings.Initialize(conf.durability.storage_directory / "settings"); memgraph::utils::global_settings.Initialize(conf.durability.storage_directory / "settings");
memgraph::license::RegisterLicenseSettings(memgraph::license::global_license_checker, memgraph::license::RegisterLicenseSettings(memgraph::license::global_license_checker,
memgraph::utils::global_settings); memgraph::utils::global_settings);
@ -112,7 +122,9 @@ class MultiTenantTest : public ::testing::Test {
auto NewInterpreter() { return InterpreterFaker{&interpreter_context, dbms.Get()}; } auto NewInterpreter() { return InterpreterFaker{&interpreter_context, dbms.Get()}; }
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth; memgraph::auth::SynchedAuth auth;
memgraph::system::System system;
memgraph::replication::ReplicationState repl_state;
memgraph::dbms::DbmsHandler dbms; memgraph::dbms::DbmsHandler dbms;
memgraph::query::InterpreterContext interpreter_context; memgraph::query::InterpreterContext interpreter_context;
}; };

View File

@ -314,8 +314,13 @@ class DumpTest : public ::testing::Test {
return db_acc; return db_acc;
}() // iile }() // iile
}; };
memgraph::system::System system_state;
memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state}; memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
void TearDown() override { void TearDown() override {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) { if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
@ -722,7 +727,14 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!"; << "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
{ {
ResultStreamFaker stream(this->db->storage()); ResultStreamFaker stream(this->db->storage());
@ -842,7 +854,14 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL)) : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!"; << "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state); memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
{ {
ResultStreamFaker stream(this->db->storage()); ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource()); memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -42,6 +42,7 @@ class QueryExecution : public testing::Test {
std::optional<memgraph::replication::ReplicationState> repl_state; std::optional<memgraph::replication::ReplicationState> repl_state;
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk; std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk;
std::optional<memgraph::system::System> system_state;
void SetUp() override { void SetUp() override {
auto config = [&]() { auto config = [&]() {
@ -65,14 +66,20 @@ class QueryExecution : public testing::Test {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!"); "Wrong storage mode!");
db_acc_ = std::move(db_acc); db_acc_ = std::move(db_acc);
system_state.emplace();
interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value()); interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
interpreter_.emplace(&*interpreter_context_, *db_acc_); interpreter_.emplace(&*interpreter_context_, *db_acc_);
} }
void TearDown() override { void TearDown() override {
interpreter_ = std::nullopt; interpreter_ = std::nullopt;
interpreter_context_ = std::nullopt; interpreter_context_ = std::nullopt;
system_state.reset();
db_acc_.reset(); db_acc_.reset();
db_gk.reset(); db_gk.reset();
repl_state.reset(); repl_state.reset();

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -104,7 +104,14 @@ class StreamsTestFixture : public ::testing::Test {
return db_acc; return db_acc;
}() // iile }() // iile
}; };
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state}; memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
std::optional<StreamsTest> proxyStreams_; std::optional<StreamsTest> proxyStreams_;

View File

@ -25,22 +25,20 @@
#include "auth/auth.hpp" #include "auth/auth.hpp"
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp" #include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "query/interpreter_context.hpp" #include "query/interpreter_context.hpp"
#include "replication/config.hpp" #include "replication/config.hpp"
#include "replication/state.hpp" #include "replication/state.hpp"
#include "replication_handler/replication_handler.hpp"
#include "storage/v2/indices/label_index_stats.hpp" #include "storage/v2/indices/label_index_stats.hpp"
#include "storage/v2/storage.hpp" #include "storage/v2/storage.hpp"
#include "storage/v2/view.hpp" #include "storage/v2/view.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
using testing::UnorderedElementsAre; using testing::UnorderedElementsAre;
using memgraph::dbms::RegisterReplicaError; using memgraph::query::RegisterReplicaError;
using memgraph::dbms::ReplicationHandler; using memgraph::query::UnregisterReplicaResult;
using memgraph::dbms::UnregisterReplicaResult;
using memgraph::replication::ReplicationClientConfig; using memgraph::replication::ReplicationClientConfig;
using memgraph::replication::ReplicationHandler;
using memgraph::replication::ReplicationServerConfig; using memgraph::replication::ReplicationServerConfig;
using memgraph::replication_coordination_glue::ReplicationMode; using memgraph::replication_coordination_glue::ReplicationMode;
using memgraph::replication_coordination_glue::ReplicationRole; using memgraph::replication_coordination_glue::ReplicationRole;
@ -114,21 +112,26 @@ class ReplicationTest : public ::testing::Test {
struct MinMemgraph { struct MinMemgraph {
MinMemgraph(const memgraph::storage::Config &conf) MinMemgraph(const memgraph::storage::Config &conf)
: auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}}, : auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}},
dbms{conf repl_state{ReplicationStateRootPath(conf)},
dbms{conf, system_, repl_state
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
&auth, true auth, true
#endif #endif
}, },
repl_state{dbms.ReplicationState()},
db_acc{dbms.Get()}, db_acc{dbms.Get()},
db{*db_acc.get()}, db{*db_acc.get()},
repl_handler(dbms) { repl_handler(repl_state, dbms
#ifdef MG_ENTERPRISE
,
&system_, auth
#endif
) {
} }
memgraph::auth::SynchedAuth auth;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth; memgraph::system::System system_;
memgraph::replication::ReplicationState repl_state;
memgraph::dbms::DbmsHandler dbms; memgraph::dbms::DbmsHandler dbms;
memgraph::replication::ReplicationState &repl_state;
memgraph::dbms::DatabaseAccess db_acc; memgraph::dbms::DatabaseAccess db_acc;
memgraph::dbms::Database &db; memgraph::dbms::Database &db;
ReplicationHandler repl_handler; ReplicationHandler repl_handler;
@ -144,7 +147,7 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) {
.port = ports[0], .port = ports[0],
}); });
const auto &reg = main.repl_handler.RegisterReplica(ReplicationClientConfig{ const auto &reg = main.repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA", .name = "REPLICA",
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -442,7 +445,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -450,7 +453,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
}) })
.HasError()); .HasError());
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -587,7 +590,7 @@ TEST_F(ReplicationTest, RecoveryProcess) {
.port = ports[0], .port = ports[0],
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -663,7 +666,7 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) {
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA_ASYNC", .name = "REPLICA_ASYNC",
.mode = ReplicationMode::ASYNC, .mode = ReplicationMode::ASYNC,
.ip_address = local_host, .ip_address = local_host,
@ -715,7 +718,7 @@ TEST_F(ReplicationTest, EpochTest) {
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -724,7 +727,7 @@ TEST_F(ReplicationTest, EpochTest) {
.HasError()); .HasError());
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -758,7 +761,7 @@ TEST_F(ReplicationTest, EpochTest) {
ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain()); ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain());
ASSERT_FALSE(replica1.repl_handler ASSERT_FALSE(replica1.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -791,7 +794,7 @@ TEST_F(ReplicationTest, EpochTest) {
.port = ports[0], .port = ports[0],
}); });
ASSERT_TRUE(main.repl_handler ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -834,7 +837,7 @@ TEST_F(ReplicationTest, ReplicationInformation) {
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -844,7 +847,7 @@ TEST_F(ReplicationTest, ReplicationInformation) {
.HasError()); .HasError());
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::ASYNC, .mode = ReplicationMode::ASYNC,
.ip_address = local_host, .ip_address = local_host,
@ -890,7 +893,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) {
.port = replica2_port, .port = replica2_port,
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -899,7 +902,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) {
.HasError()); .HasError());
ASSERT_TRUE(main.repl_handler ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::ASYNC, .mode = ReplicationMode::ASYNC,
.ip_address = local_host, .ip_address = local_host,
@ -925,7 +928,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) {
}); });
ASSERT_FALSE(main.repl_handler ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -934,7 +937,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) {
.HasError()); .HasError());
ASSERT_TRUE(main.repl_handler ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::ASYNC, .mode = ReplicationMode::ASYNC,
.ip_address = local_host, .ip_address = local_host,
@ -973,14 +976,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) {
.port = ports[1], .port = ports[1],
}); });
auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
.port = ports[0], .port = ports[0],
}); });
ASSERT_FALSE(res.HasError()) << (int)res.GetError(); ASSERT_FALSE(res.HasError()) << (int)res.GetError();
res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -1030,14 +1033,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) {
.ip_address = local_host, .ip_address = local_host,
.port = ports[1], .port = ports[1],
}); });
auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0], .name = replicas[0],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
.port = ports[0], .port = ports[0],
}); });
ASSERT_FALSE(res.HasError()); ASSERT_FALSE(res.HasError());
res = main->repl_handler.RegisterReplica(ReplicationClientConfig{ res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1], .name = replicas[1],
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,
@ -1080,7 +1083,7 @@ TEST_F(ReplicationTest, AddingInvalidReplica) {
MinMemgraph main(main_conf); MinMemgraph main(main_conf);
ASSERT_TRUE(main.repl_handler ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{ .TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA", .name = "REPLICA",
.mode = ReplicationMode::SYNC, .mode = ReplicationMode::SYNC,
.ip_address = local_host, .ip_address = local_host,

View File

@ -90,7 +90,16 @@ class StorageModeMultiTxTest : public ::testing::Test {
return db_acc; return db_acc;
}() // iile }() // iile
}; };
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
}; };

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -59,7 +59,16 @@ class TransactionQueueSimpleTest : public ::testing::Test {
return db_acc; return db_acc;
}() // iile }() // iile
}; };
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db}; InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
void TearDown() override { void TearDown() override {

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd. // Copyright 2024 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -68,7 +68,16 @@ class TransactionQueueMultipleTest : public ::testing::Test {
}() // iile }() // iile
}; };
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state}; memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker main_interpreter{&interpreter_context, db}; InterpreterFaker main_interpreter{&interpreter_context, db};
std::vector<InterpreterFaker *> running_interpreters; std::vector<InterpreterFaker *> running_interpreters;