Adding authentication data replication (#1666)

* Add AUTH system tx deltas
* Add auth data RPC and handlers
* Support multiple system deltas in a single transaction
* Added e2e test
* Bugfix: KVStore segfault after move

---------

Co-authored-by: Gareth Lloyd <gareth.lloyd@memgraph.io>
This commit is contained in:
andrejtonev 2024-02-05 11:37:00 +01:00 committed by GitHub
parent c46dad18fe
commit 7ead00f23e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
95 changed files with 4430 additions and 1865 deletions

View File

@ -22,8 +22,10 @@ add_subdirectory(dbms)
add_subdirectory(flags)
add_subdirectory(distributed)
add_subdirectory(replication)
add_subdirectory(replication_handler)
add_subdirectory(coordination)
add_subdirectory(replication_coordination_glue)
add_subdirectory(system)
string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
@ -43,10 +45,10 @@ set(mg_single_node_v2_sources
add_executable(memgraph ${mg_single_node_v2_sources})
target_include_directories(memgraph PUBLIC ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(memgraph stdc++fs Threads::Threads
mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags)
mg-telemetry mg-communication mg-communication-metrics mg-memory mg-utils mg-license mg-settings mg-glue mg-flags mg::system mg::replication_handler)
# NOTE: `include/mg_procedure.syms` describes a pattern match for symbols which
# should be dynamically exported, so that `dlopen` can correctly link the
# should be dynamically exported, so that `dlopen` can correctly link th
# symbols in custom procedure module libraries.
target_link_libraries(memgraph "-Wl,--dynamic-list=${CMAKE_SOURCE_DIR}/include/mg_procedure.syms")
set_target_properties(memgraph PROPERTIES

View File

@ -2,7 +2,9 @@ set(auth_src_files
auth.cpp
crypto.cpp
models.cpp
module.cpp)
module.cpp
rpc.cpp
replication_handlers.cpp)
find_package(Seccomp REQUIRED)
find_package(fmt REQUIRED)
@ -11,7 +13,7 @@ find_package(gflags REQUIRED)
add_library(mg-auth STATIC ${auth_src_files})
target_link_libraries(mg-auth json libbcrypt gflags fmt::fmt)
target_link_libraries(mg-auth mg-utils mg-kvstore mg-license )
target_link_libraries(mg-auth mg-utils mg-kvstore mg-license mg::system mg-replication)
target_link_libraries(mg-auth ${Seccomp_LIBRARIES})
target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS})

View File

@ -9,13 +9,16 @@
#include "auth/auth.hpp"
#include <iostream>
#include <optional>
#include <utility>
#include <fmt/format.h>
#include "auth/crypto.hpp"
#include "auth/exceptions.hpp"
#include "auth/rpc.hpp"
#include "license/license.hpp"
#include "system/transaction.hpp"
#include "utils/flag_validation.hpp"
#include "utils/message.hpp"
#include "utils/settings.hpp"
@ -41,12 +44,84 @@ DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000,
FLAG_IN_RANGE(100, 1800000));
namespace memgraph::auth {
namespace {
#ifdef MG_ENTERPRISE
/**
* REPLICATION SYSTEM ACTION IMPLEMENTATIONS
*/
struct UpdateAuthData : memgraph::system::ISystemAction {
explicit UpdateAuthData(User user) : user_{std::move(user)}, role_{std::nullopt} {}
explicit UpdateAuthData(Role role) : user_{std::nullopt}, role_{std::move(role)} {}
void DoDurability() override { /* Done during Auth execution */
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const replication::UpdateAuthDataRes &response) { return response.success; };
if (user_) {
return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *user_);
}
if (role_) {
return client.SteamAndFinalizeDelta<replication::UpdateAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), *role_);
}
// Should never get here
MG_ASSERT(false, "Trying to update auth data that is not a user nor a role");
return {};
}
void PostReplication(replication::RoleMainData &mainData) const override {}
private:
std::optional<User> user_;
std::optional<Role> role_;
};
struct DropAuthData : memgraph::system::ISystemAction {
enum class AuthDataType { USER, ROLE };
explicit DropAuthData(AuthDataType type, std::string_view name) : type_{type}, name_{name} {}
void DoDurability() override { /* Done during Auth execution */
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const replication::DropAuthDataRes &response) { return response.success; };
memgraph::replication::DropAuthDataReq::DataType type{};
switch (type_) {
case AuthDataType::USER:
type = memgraph::replication::DropAuthDataReq::DataType::USER;
break;
case AuthDataType::ROLE:
type = memgraph::replication::DropAuthDataReq::DataType::ROLE;
break;
}
return client.SteamAndFinalizeDelta<replication::DropAuthDataRpc>(
check_response, std::string{epoch.id()}, txn.last_committed_system_timestamp(), txn.timestamp(), type, name_);
}
void PostReplication(replication::RoleMainData &mainData) const override {}
private:
AuthDataType type_;
std::string name_;
};
#endif
/**
* CONSTANTS
*/
const std::string kUserPrefix = "user:";
const std::string kRolePrefix = "role:";
const std::string kLinkPrefix = "link:";
const std::string kVersion = "version";
static constexpr auto kVersionV1 = "V1";
} // namespace
/**
* All data stored in the `Auth` storage is stored in an underlying
@ -148,6 +223,12 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::s
// Authenticate the user.
if (!is_authenticated) return std::nullopt;
/**
* TODO
* The auth module should not update auth data.
* There is now way to replicate it and we should not be storing sensitive data if we don't have to.
*/
// Find or create the user and return it.
auto user = GetUser(username);
if (!user) {
@ -240,7 +321,7 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) const {
return user;
}
void Auth::SaveUser(const User &user) {
void Auth::SaveUser(const User &user, system::Transaction *system_tx) {
bool success = false;
if (const auto *role = user.role(); role != nullptr) {
success = storage_.PutMultiple(
@ -252,6 +333,12 @@ void Auth::SaveUser(const User &user) {
if (!success) {
throw AuthException("Couldn't save user '{}'!", user.username());
}
// All changes to the user end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<UpdateAuthData>(user);
#endif
}
}
void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &password) {
@ -284,7 +371,8 @@ void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &pa
user.UpdatePassword(password);
}
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password) {
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
if (!NameRegexMatch(username)) {
throw AuthException("Invalid user name.");
}
@ -294,17 +382,23 @@ std::optional<User> Auth::AddUser(const std::string &username, const std::option
if (existing_role) return std::nullopt;
auto new_user = User(username);
UpdatePassword(new_user, password);
SaveUser(new_user);
SaveUser(new_user, system_tx);
return new_user;
}
bool Auth::RemoveUser(const std::string &username_orig) {
bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) {
auto username = utils::ToLowerCase(username_orig);
if (!storage_.Get(kUserPrefix + username)) return false;
std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username});
if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove user '{}'!", username);
}
// Handling drop user delta
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<DropAuthData>(DropAuthData::AuthDataType::USER, username);
#endif
}
return true;
}
@ -321,6 +415,19 @@ std::vector<auth::User> Auth::AllUsers() const {
return ret;
}
std::vector<std::string> Auth::AllUsernames() const {
std::vector<std::string> ret;
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
auto user = GetUser(username);
if (user) {
ret.push_back(username);
}
}
return ret;
}
bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
@ -338,24 +445,30 @@ std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
return Role::Deserialize(data);
}
void Auth::SaveRole(const Role &role) {
void Auth::SaveRole(const Role &role, system::Transaction *system_tx) {
if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) {
throw AuthException("Couldn't save role '{}'!", role.rolename());
}
// All changes to the role end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<UpdateAuthData>(role);
#endif
}
}
std::optional<Role> Auth::AddRole(const std::string &rolename) {
std::optional<Role> Auth::AddRole(const std::string &rolename, system::Transaction *system_tx) {
if (!NameRegexMatch(rolename)) {
throw AuthException("Invalid role name.");
}
if (auto existing_role = GetRole(rolename)) return std::nullopt;
if (auto existing_user = GetUser(rolename)) return std::nullopt;
auto new_role = Role(rolename);
SaveRole(new_role);
SaveRole(new_role, system_tx);
return new_role;
}
bool Auth::RemoveRole(const std::string &rolename_orig) {
bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *system_tx) {
auto rolename = utils::ToLowerCase(rolename_orig);
if (!storage_.Get(kRolePrefix + rolename)) return false;
std::vector<std::string> keys;
@ -368,6 +481,12 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove role '{}'!", rolename);
}
// Handling drop role delta
if (system_tx) {
#ifdef MG_ENTERPRISE
system_tx->AddAction<DropAuthData>(DropAuthData::AuthDataType::ROLE, rolename);
#endif
}
return true;
}
@ -385,6 +504,18 @@ std::vector<auth::Role> Auth::AllRoles() const {
return ret;
}
std::vector<std::string> Auth::AllRolenames() const {
std::vector<std::string> ret;
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
if (rolename != utils::ToLowerCase(rolename)) continue;
if (auto role = GetRole(rolename)) {
ret.push_back(rolename);
}
}
return ret;
}
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const {
const auto rolename = utils::ToLowerCase(rolename_orig);
std::vector<auth::User> ret;
@ -404,48 +535,48 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
}
#ifdef MG_ENTERPRISE
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name) {
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (db == kAllDatabases) {
user->db_access().GrantAll();
} else {
user->db_access().Add(db);
}
SaveUser(*user);
SaveUser(*user, system_tx);
return true;
}
return false;
}
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name) {
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (db == kAllDatabases) {
user->db_access().DenyAll();
} else {
user->db_access().Remove(db);
}
SaveUser(*user);
SaveUser(*user, system_tx);
return true;
}
return false;
}
void Auth::DeleteDatabase(const std::string &db) {
void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (auto user = GetUser(username)) {
user->db_access().Delete(db);
SaveUser(*user);
SaveUser(*user, system_tx);
}
}
}
bool Auth::SetMainDatabase(std::string_view db, const std::string &name) {
bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (!user->db_access().SetDefault(db)) {
throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name);
}
SaveUser(*user);
SaveUser(*user, system_tx);
return true;
}
return false;

View File

@ -18,10 +18,15 @@
#include "auth/module.hpp"
#include "glue/auth_global.hpp"
#include "kvstore/kvstore.hpp"
#include "system/action.hpp"
#include "utils/settings.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::auth {
class Auth;
using SynchedAuth = memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>;
static const constexpr char *const kAllDatabases = "*";
/**
@ -68,6 +73,13 @@ class Auth final {
config_ = std::move(config);
}
/**
* @brief
*
* @return Config
*/
Config GetConfig() const { return config_; }
/**
* Authenticates a user using his username and password.
*
@ -96,7 +108,7 @@ class Auth final {
*
* @throw AuthException if unable to save the user.
*/
void SaveUser(const User &user);
void SaveUser(const User &user, system::Transaction *system_tx = nullptr);
/**
* Creates a user if the user doesn't exist.
@ -107,7 +119,8 @@ class Auth final {
* @return a user when the user is created, nullopt if the user exists
* @throw AuthException if unable to save the user.
*/
std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt);
std::optional<User> AddUser(const std::string &username, const std::optional<std::string> &password = std::nullopt,
system::Transaction *system_tx = nullptr);
/**
* Removes a user from the storage.
@ -118,7 +131,7 @@ class Auth final {
* doesn't exist
* @throw AuthException if unable to remove the user.
*/
bool RemoveUser(const std::string &username);
bool RemoveUser(const std::string &username, system::Transaction *system_tx = nullptr);
/**
* @brief
@ -136,6 +149,13 @@ class Auth final {
*/
std::vector<User> AllUsers() const;
/**
* @brief
*
* @return std::vector<std::string>
*/
std::vector<std::string> AllUsernames() const;
/**
* Returns whether there are users in the storage.
*
@ -160,7 +180,7 @@ class Auth final {
*
* @throw AuthException if unable to save the role.
*/
void SaveRole(const Role &role);
void SaveRole(const Role &role, system::Transaction *system_tx = nullptr);
/**
* Creates a role if the role doesn't exist.
@ -170,7 +190,7 @@ class Auth final {
* @return a role when the role is created, nullopt if the role exists
* @throw AuthException if unable to save the role.
*/
std::optional<Role> AddRole(const std::string &rolename);
std::optional<Role> AddRole(const std::string &rolename, system::Transaction *system_tx = nullptr);
/**
* Removes a role from the storage.
@ -181,7 +201,7 @@ class Auth final {
* doesn't exist
* @throw AuthException if unable to remove the role.
*/
bool RemoveRole(const std::string &rolename);
bool RemoveRole(const std::string &rolename, system::Transaction *system_tx = nullptr);
/**
* Gets all roles from the storage.
@ -191,6 +211,13 @@ class Auth final {
*/
std::vector<Role> AllRoles() const;
/**
* @brief
*
* @return std::vector<std::string>
*/
std::vector<std::string> AllRolenames() const;
/**
* Gets all users for a role from the storage.
*
@ -210,7 +237,7 @@ class Auth final {
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name);
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
/**
* @brief Grant access to individual database for a user.
@ -220,7 +247,7 @@ class Auth final {
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool GrantDatabaseToUser(const std::string &db, const std::string &name);
bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
/**
* @brief Delete a database from all users.
@ -228,7 +255,7 @@ class Auth final {
* @param db name of the database to delete
* @throw AuthException if unable to read data
*/
void DeleteDatabase(const std::string &db);
void DeleteDatabase(const std::string &db, system::Transaction *system_tx = nullptr);
/**
* @brief Set main database for an individual user.
@ -238,7 +265,7 @@ class Auth final {
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool SetMainDatabase(std::string_view db, const std::string &name);
bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr);
#endif
private:

View File

@ -611,27 +611,49 @@ Permissions User::GetPermissions() const {
#ifdef MG_ENTERPRISE
FineGrainedAccessPermissions User::GetFineGrainedAccessLabelPermissions() const {
return Merge(GetUserFineGrainedAccessLabelPermissions(), GetRoleFineGrainedAccessLabelPermissions());
}
FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const {
return Merge(GetUserFineGrainedAccessEdgeTypePermissions(), GetRoleFineGrainedAccessEdgeTypePermissions());
}
FineGrainedAccessPermissions User::GetUserFineGrainedAccessEdgeTypePermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
}
if (role_) {
return Merge(role()->fine_grained_access_handler().label_permissions(),
fine_grained_access_handler_.label_permissions());
return fine_grained_access_handler_.edge_type_permissions();
}
FineGrainedAccessPermissions User::GetUserFineGrainedAccessLabelPermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
}
return fine_grained_access_handler_.label_permissions();
}
FineGrainedAccessPermissions User::GetFineGrainedAccessEdgeTypePermissions() const {
FineGrainedAccessPermissions User::GetRoleFineGrainedAccessEdgeTypePermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
}
if (role_) {
return Merge(role()->fine_grained_access_handler().edge_type_permissions(),
fine_grained_access_handler_.edge_type_permissions());
return role()->fine_grained_access_handler().edge_type_permissions();
}
return fine_grained_access_handler_.edge_type_permissions();
return FineGrainedAccessPermissions{};
}
FineGrainedAccessPermissions User::GetRoleFineGrainedAccessLabelPermissions() const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return FineGrainedAccessPermissions{};
}
if (role_) {
return role()->fine_grained_access_handler().label_permissions();
}
return FineGrainedAccessPermissions{};
}
#endif

View File

@ -207,6 +207,8 @@ bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHa
class Role final {
public:
Role() = default;
explicit Role(const std::string &rolename);
Role(const std::string &rolename, const Permissions &permissions);
#ifdef MG_ENTERPRISE
@ -369,6 +371,10 @@ class User final {
#ifdef MG_ENTERPRISE
FineGrainedAccessPermissions GetFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetFineGrainedAccessEdgeTypePermissions() const;
FineGrainedAccessPermissions GetUserFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetUserFineGrainedAccessEdgeTypePermissions() const;
FineGrainedAccessPermissions GetRoleFineGrainedAccessLabelPermissions() const;
FineGrainedAccessPermissions GetRoleFineGrainedAccessEdgeTypePermissions() const;
const FineGrainedAccessHandler &fine_grained_access_handler() const;
FineGrainedAccessHandler &fine_grained_access_handler();
#endif

View File

@ -0,0 +1,170 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "auth/replication_handlers.hpp"
#include "auth/auth.hpp"
#include "auth/rpc.hpp"
#include "license/license.hpp"
namespace memgraph::auth {
#ifdef MG_ENTERPRISE
void UpdateAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder) {
replication::UpdateAuthDataReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::replication::UpdateAuthDataRes;
UpdateAuthDataRes res(false);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("UpdateAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Update
if (req.user) auth->SaveUser(*req.user);
if (req.role) auth->SaveRole(*req.role);
// Success
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = UpdateAuthDataRes(true);
spdlog::debug("UpdateAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
} catch (const auth::AuthException & /* not used */) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropAuthDataHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder) {
replication::DropAuthDataReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::replication::DropAuthDataRes;
DropAuthDataRes res(false);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("DropAuthDataHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Remove
switch (req.type) {
case replication::DropAuthDataReq::DataType::USER:
auth->RemoveUser(req.name);
break;
case replication::DropAuthDataReq::DataType::ROLE:
auth->RemoveRole(req.name);
break;
}
// Success
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropAuthDataRes(true);
spdlog::debug("DropAuthDataHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
} catch (const auth::AuthException & /* not used */) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config,
const std::vector<auth::User> &users, const std::vector<auth::Role> &roles) {
return auth.WithLock([&](auto &locked_auth) {
// Update config
locked_auth.SetConfig(std::move(auth_config));
// Get all current users
auto old_users = locked_auth.AllUsernames();
// Save incoming users
for (const auto &user : users) {
// Missing users
try {
locked_auth.SaveUser(user);
} catch (const auth::AuthException &) {
spdlog::debug("SystemRecoveryHandler: Failed to save user");
return false;
}
const auto it = std::find(old_users.begin(), old_users.end(), user.username());
if (it != old_users.end()) old_users.erase(it);
}
// Delete all the leftover users
for (const auto &user : old_users) {
if (!locked_auth.RemoveUser(user)) {
spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", user);
return false;
}
}
// Roles are only supported with a license
if (license::global_license_checker.IsEnterpriseValidFast()) {
// Get all current roles
auto old_roles = locked_auth.AllRolenames();
// Save incoming users
for (const auto &role : roles) {
// Missing users
try {
locked_auth.SaveRole(role);
} catch (const auth::AuthException &) {
spdlog::debug("SystemRecoveryHandler: Failed to save user");
return false;
}
const auto it = std::find(old_roles.begin(), old_roles.end(), role.rolename());
if (it != old_roles.end()) old_roles.erase(it);
}
// Delete all the leftover users
for (const auto &role : old_roles) {
if (!locked_auth.RemoveRole(role)) {
spdlog::debug("SystemRecoveryHandler: Failed to remove user \"{}\".", role);
return false;
}
}
}
// Success
return true;
});
}
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
auth::SynchedAuth &auth) {
// NOTE: Register even without license as the user could add a license at run-time
data.server->rpc_server_.Register<replication::UpdateAuthDataRpc>(
[system_state_access, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received UpdateAuthDataRpc");
UpdateAuthDataHandler(system_state_access, auth, req_reader, res_builder);
});
data.server->rpc_server_.Register<replication::DropAuthDataRpc>(
[system_state_access, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received DropAuthDataRpc");
DropAuthDataHandler(system_state_access, auth, req_reader, res_builder);
});
}
#endif
} // namespace memgraph::auth

View File

@ -0,0 +1,31 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "replication/state.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::auth {
#ifdef MG_ENTERPRISE
void UpdateAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder);
void DropAuthDataHandler(system::ReplicaHandlerAccessToState &system_state_access, auth::SynchedAuth &auth,
slk::Reader *req_reader, slk::Builder *res_builder);
bool SystemRecoveryHandler(auth::SynchedAuth &auth, auth::Auth::Config auth_config,
const std::vector<auth::User> &users, const std::vector<auth::Role> &roles);
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
auth::SynchedAuth &auth);
#endif
} // namespace memgraph::auth

178
src/auth/rpc.cpp Normal file
View File

@ -0,0 +1,178 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "auth/rpc.hpp"
#include <json/json.hpp>
#include "auth/auth.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "utils/enum.hpp"
namespace memgraph::slk {
// Serialize code for auth::Role
void Save(const auth::Role &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.Serialize().dump(), builder);
}
namespace {
auth::Role LoadAuthRole(memgraph::slk::Reader *reader) {
std::string tmp;
memgraph::slk::Load(&tmp, reader);
const auto json = nlohmann::json::parse(tmp);
return memgraph::auth::Role::Deserialize(json);
}
} // namespace
// Deserialize code for auth::Role
void Load(auth::Role *self, memgraph::slk::Reader *reader) { *self = LoadAuthRole(reader); }
// Special case for optional<Role>
template <>
inline void Load<auth::Role>(std::optional<auth::Role> *obj, Reader *reader) {
bool exists = false;
Load(&exists, reader);
if (exists) {
obj->emplace(LoadAuthRole(reader));
} else {
*obj = std::nullopt;
}
}
// Serialize code for auth::User
void Save(const auth::User &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.Serialize().dump(), builder);
std::optional<auth::Role> role{};
if (const auto *role_ptr = self.role(); role_ptr) {
role.emplace(*role_ptr);
}
memgraph::slk::Save(role, builder);
}
// Deserialize code for auth::User
void Load(auth::User *self, memgraph::slk::Reader *reader) {
std::string tmp;
memgraph::slk::Load(&tmp, reader);
const auto json = nlohmann::json::parse(tmp);
*self = memgraph::auth::User::Deserialize(json);
std::optional<auth::Role> role{};
memgraph::slk::Load(&role, reader);
if (role)
self->SetRole(*role);
else
self->ClearRole();
}
// Serialize code for auth::Auth::Config
void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.name_regex_str, builder);
memgraph::slk::Save(self.password_regex_str, builder);
memgraph::slk::Save(self.password_permit_null, builder);
}
// Deserialize code for auth::Auth::Config
void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader) {
std::string name_regex_str{};
std::string password_regex_str{};
bool password_permit_null{};
memgraph::slk::Load(&name_regex_str, reader);
memgraph::slk::Load(&password_regex_str, reader);
memgraph::slk::Load(&password_permit_null, reader);
*self = auth::Auth::Config{std::move(name_regex_str), std::move(password_regex_str), password_permit_null};
}
// Serialize code for UpdateAuthDataReq
void Save(const memgraph::replication::UpdateAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.user, builder);
memgraph::slk::Save(self.role, builder);
}
void Load(memgraph::replication::UpdateAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->user, reader);
memgraph::slk::Load(&self->role, reader);
}
// Serialize code for UpdateAuthDataRes
void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
}
void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
}
// Serialize code for DropAuthDataReq
void Save(const memgraph::replication::DropAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(utils::EnumToNum<2, uint8_t>(self.type), builder);
memgraph::slk::Save(self.name, builder);
}
void Load(memgraph::replication::DropAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
uint8_t type_tmp = 0;
memgraph::slk::Load(&type_tmp, reader);
if (!utils::NumToEnum<2>(type_tmp, self->type)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
memgraph::slk::Load(&self->name, reader);
}
// Serialize code for DropAuthDataRes
void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
}
void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
}
} // namespace memgraph::slk
namespace memgraph::replication {
constexpr utils::TypeInfo UpdateAuthDataReq::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_REQ, "UpdateAuthDataReq",
nullptr};
constexpr utils::TypeInfo UpdateAuthDataRes::kType{utils::TypeId::REP_UPDATE_AUTH_DATA_RES, "UpdateAuthDataRes",
nullptr};
constexpr utils::TypeInfo DropAuthDataReq::kType{utils::TypeId::REP_DROP_AUTH_DATA_REQ, "DropAuthDataReq", nullptr};
constexpr utils::TypeInfo DropAuthDataRes::kType{utils::TypeId::REP_DROP_AUTH_DATA_RES, "DropAuthDataRes", nullptr};
void UpdateAuthDataReq::Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void UpdateAuthDataReq::Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void UpdateAuthDataRes::Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void UpdateAuthDataRes::Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropAuthDataReq::Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropAuthDataReq::Load(DropAuthDataReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropAuthDataRes::Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropAuthDataRes::Load(DropAuthDataRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
} // namespace memgraph::replication

119
src/auth/rpc.hpp Normal file
View File

@ -0,0 +1,119 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "rpc/messages.hpp"
#include "slk/streams.hpp"
namespace memgraph::replication {
struct UpdateAuthDataReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(UpdateAuthDataReq *self, memgraph::slk::Reader *reader);
static void Save(const UpdateAuthDataReq &self, memgraph::slk::Builder *builder);
UpdateAuthDataReq() = default;
UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::User user)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
user{std::move(user)} {}
UpdateAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, auth::Role role)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
role{std::move(role)} {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
std::optional<auth::User> user;
std::optional<auth::Role> role;
};
struct UpdateAuthDataRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(UpdateAuthDataRes *self, memgraph::slk::Reader *reader);
static void Save(const UpdateAuthDataRes &self, memgraph::slk::Builder *builder);
UpdateAuthDataRes() = default;
explicit UpdateAuthDataRes(bool success) : success{success} {}
bool success;
};
using UpdateAuthDataRpc = rpc::RequestResponse<UpdateAuthDataReq, UpdateAuthDataRes>;
struct DropAuthDataReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropAuthDataReq *self, memgraph::slk::Reader *reader);
static void Save(const DropAuthDataReq &self, memgraph::slk::Builder *builder);
DropAuthDataReq() = default;
enum class DataType { USER, ROLE };
DropAuthDataReq(std::string epoch_id, uint64_t expected_ts, uint64_t new_ts, DataType type, std::string_view name)
: epoch_id{std::move(epoch_id)},
expected_group_timestamp{expected_ts},
new_group_timestamp{new_ts},
type{type},
name{name} {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
DataType type;
std::string name;
};
struct DropAuthDataRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropAuthDataRes *self, memgraph::slk::Reader *reader);
static void Save(const DropAuthDataRes &self, memgraph::slk::Builder *builder);
DropAuthDataRes() = default;
explicit DropAuthDataRes(bool success) : success{success} {}
bool success;
};
using DropAuthDataRpc = rpc::RequestResponse<DropAuthDataReq, DropAuthDataRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const auth::Role &self, memgraph::slk::Builder *builder);
void Load(auth::Role *self, memgraph::slk::Reader *reader);
void Save(const auth::User &self, memgraph::slk::Builder *builder);
void Load(auth::User *self, memgraph::slk::Reader *reader);
void Save(const auth::Auth::Config &self, memgraph::slk::Builder *builder);
void Load(auth::Auth::Config *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::UpdateAuthDataRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::UpdateAuthDataRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::UpdateAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::UpdateAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/);
void Save(const memgraph::replication::DropAuthDataRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::DropAuthDataRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::DropAuthDataReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::DropAuthDataReq * /*self*/, memgraph::slk::Reader * /*reader*/);
} // namespace memgraph::slk

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -14,8 +14,6 @@
#include <string>
#include "auth/auth.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
namespace memgraph::communication::websocket {
@ -30,7 +28,7 @@ class AuthenticationInterface {
class SafeAuth : public AuthenticationInterface {
public:
explicit SafeAuth(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {}
explicit SafeAuth(auth::SynchedAuth *auth) : auth_{auth} {}
bool Authenticate(const std::string &username, const std::string &password) const override;
@ -39,6 +37,6 @@ class SafeAuth : public AuthenticationInterface {
bool HasAnyUsers() const override;
private:
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
auth::SynchedAuth *auth_;
};
} // namespace memgraph::communication::websocket

View File

@ -13,6 +13,7 @@ target_sources(mg-coordination
include/coordination/coordinator_data.hpp
include/coordination/constants.hpp
include/coordination/coordinator_cluster_config.hpp
include/coordination/coordinator_handlers.hpp
PRIVATE
coordinator_client.cpp
@ -21,9 +22,10 @@ target_sources(mg-coordination
coordinator_server.cpp
coordinator_data.cpp
coordinator_instance.cpp
coordinator_handlers.cpp
)
target_include_directories(mg-coordination PUBLIC include)
target_link_libraries(mg-coordination
PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft
PUBLIC mg::utils mg::rpc mg::slk mg::io mg::repl_coord_glue lib::rangev3 nuraft mg-replication_handler
)

View File

@ -183,7 +183,7 @@ auto CoordinatorData::RegisterInstance(CoordinatorClientConfig config) -> Regist
if (std::ranges::any_of(registered_instances_, [&config](CoordinatorInstance const &instance) {
return instance.SocketAddress() == config.SocketAddress();
})) {
return RegisterInstanceCoordinatorStatus::END_POINT_EXISTS;
return RegisterInstanceCoordinatorStatus::ENDPOINT_EXISTS;
}
try {

View File

@ -10,41 +10,35 @@
// licenses/APL.txt.
#ifdef MG_ENTERPRISE
#include "coordination/coordinator_handlers.hpp"
#include "dbms/coordinator_handlers.hpp"
#include <range/v3/view.hpp>
#include "coordination/coordinator_exceptions.hpp"
#include "coordination/coordinator_rpc.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_client.hpp"
#include "dbms/utils.hpp"
#include "range/v3/view.hpp"
#include "coordination/include/coordination/coordinator_server.hpp"
namespace memgraph::dbms {
void CoordinatorHandlers::Register(DbmsHandler &dbms_handler) {
auto &server = dbms_handler.CoordinatorState().GetCoordinatorServer();
void CoordinatorHandlers::Register(memgraph::coordination::CoordinatorServer &server,
replication::ReplicationHandler &replication_handler) {
server.Register<coordination::PromoteReplicaToMainRpc>(
[&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
[&](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
spdlog::info("Received PromoteReplicaToMainRpc");
CoordinatorHandlers::PromoteReplicaToMainHandler(dbms_handler, req_reader, res_builder);
CoordinatorHandlers::PromoteReplicaToMainHandler(replication_handler, req_reader, res_builder);
});
server.Register<coordination::DemoteMainToReplicaRpc>(
[&dbms_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
[&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void {
spdlog::info("Received DemoteMainToReplicaRpc from coordinator server");
CoordinatorHandlers::DemoteMainToReplicaHandler(dbms_handler, req_reader, res_builder);
CoordinatorHandlers::DemoteMainToReplicaHandler(replication_handler, req_reader, res_builder);
});
}
void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader,
slk::Builder *res_builder) {
auto &repl_state = dbms_handler.ReplicationState();
spdlog::info("Executing SetMainToReplicaHandler");
void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler,
slk::Reader *req_reader, slk::Builder *res_builder) {
spdlog::info("Executing DemoteMainToReplicaHandler");
if (repl_state.IsReplica()) {
if (!replication_handler.IsMain()) {
spdlog::error("Setting to replica must be performed on main.");
slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder);
return;
@ -57,7 +51,7 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler,
.ip_address = req.replication_client_info.replication_ip_address,
.port = req.replication_client_info.replication_port};
if (bool const success = memgraph::dbms::SetReplicationRoleReplica(dbms_handler, clients_config); !success) {
if (!replication_handler.SetReplicationRoleReplica(clients_config)) {
spdlog::error("Demoting main to replica failed!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
@ -66,19 +60,17 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(DbmsHandler &dbms_handler,
slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder);
}
void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader,
slk::Builder *res_builder) {
auto &repl_state = dbms_handler.ReplicationState();
if (!repl_state.IsReplica()) {
spdlog::error("Only replica can be promoted to main!");
void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler,
slk::Reader *req_reader, slk::Builder *res_builder) {
if (!replication_handler.IsReplica()) {
spdlog::error("Failover must be performed on replica!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
}
// This can fail because of disk. If it does, the cluster state could get inconsistent.
// We don't handle disk issues.
if (bool const success = memgraph::dbms::DoReplicaToMainPromotion(dbms_handler); !success) {
if (!replication_handler.DoReplicaToMainPromotion()) {
spdlog::error("Promoting replica to main failed!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
@ -96,53 +88,32 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(DbmsHandler &dbms_handler,
};
};
MG_ASSERT(
std::get<replication::RoleMainData>(repl_state.ReplicationData()).registered_replicas_.empty(),
"No replicas should be registered after promoting replica to main and before registering replication clients!");
// registering replicas
for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) {
auto instance_client = repl_state.RegisterReplica(config);
auto instance_client = replication_handler.RegisterReplica(config);
if (instance_client.HasError()) {
using enum memgraph::replication::RegisterReplicaError;
switch (instance_client.GetError()) {
// Can't happen, we are already replica
case NOT_MAIN:
spdlog::error("Failover must be performed on main!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
// Can't happen, checked on the coordinator side
case NAME_EXISTS:
case memgraph::query::RegisterReplicaError::NAME_EXISTS:
spdlog::error("Replica with the same name already exists!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
// Can't happen, checked on the coordinator side
case ENDPOINT_EXISTS:
case memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS:
spdlog::error("Replica with the same endpoint already exists!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
// We don't handle disk issues
case COULD_NOT_BE_PERSISTED:
case memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
spdlog::error("Registered replica could not be persisted!");
slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder);
return;
case SUCCESS:
case memgraph::query::RegisterReplicaError::CONNECTION_FAILED:
// Connection failure is not a fatal error
break;
}
}
if (!allow_mt_repl && dbms_handler.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
auto &instance_client_ref = *instance_client.GetValue();
// Update system before enabling individual storage <-> replica clients
dbms_handler.SystemRestore(instance_client_ref);
const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients<true>(dbms_handler, instance_client_ref);
MG_ASSERT(all_clients_good, "Failed to register one or more databases to the REPLICA \"{}\".", config.name);
StartReplicaClient(dbms_handler, instance_client_ref);
}
slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder);

View File

@ -13,7 +13,9 @@
#ifdef MG_ENTERPRISE
#include "slk/serialization.hpp"
#include "coordination/coordinator_server.hpp"
#include "replication_handler/replication_handler.hpp"
#include "slk/streams.hpp"
namespace memgraph::dbms {
@ -21,12 +23,14 @@ class DbmsHandler;
class CoordinatorHandlers {
public:
static void Register(DbmsHandler &dbms_handler);
static void Register(memgraph::coordination::CoordinatorServer &server,
replication::ReplicationHandler &replication_handler);
private:
static void PromoteReplicaToMainHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader,
static void PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader,
slk::Builder *res_builder);
static void DemoteMainToReplicaHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
static void DemoteMainToReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader,
slk::Builder *res_builder);
};
} // namespace memgraph::dbms

View File

@ -19,7 +19,7 @@ namespace memgraph::coordination {
enum class RegisterInstanceCoordinatorStatus : uint8_t {
NAME_EXISTS,
END_POINT_EXISTS,
ENDPOINT_EXISTS,
NOT_COORDINATOR,
RPC_FAILED,
SUCCESS

View File

@ -1,2 +1,10 @@
add_library(mg-dbms STATIC dbms_handler.cpp database.cpp replication_handler.cpp coordinator_handler.cpp replication_client.cpp inmemory/replication_handlers.cpp coordinator_handlers.cpp)
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-replication mg-coordination)
add_library(mg-dbms STATIC
dbms_handler.cpp
database.cpp
coordinator_handler.cpp
inmemory/replication_handlers.cpp
replication_handlers.cpp
rpc.cpp
)
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query mg-auth mg-replication mg-coordination)

View File

@ -18,20 +18,21 @@
namespace memgraph::dbms {
CoordinatorHandler::CoordinatorHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {}
CoordinatorHandler::CoordinatorHandler(coordination::CoordinatorState &coordinator_state)
: coordinator_state_(coordinator_state) {}
auto CoordinatorHandler::RegisterInstance(memgraph::coordination::CoordinatorClientConfig config)
-> coordination::RegisterInstanceCoordinatorStatus {
return dbms_handler_.CoordinatorState().RegisterInstance(config);
return coordinator_state_.RegisterInstance(config);
}
auto CoordinatorHandler::SetInstanceToMain(std::string instance_name)
-> coordination::SetInstanceToMainCoordinatorStatus {
return dbms_handler_.CoordinatorState().SetInstanceToMain(std::move(instance_name));
return coordinator_state_.SetInstanceToMain(std::move(instance_name));
}
auto CoordinatorHandler::ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus> {
return dbms_handler_.CoordinatorState().ShowInstances();
return coordinator_state_.ShowInstances();
}
} // namespace memgraph::dbms

View File

@ -15,11 +15,9 @@
#include "coordination/coordinator_config.hpp"
#include "coordination/coordinator_instance_status.hpp"
#include "coordination/coordinator_state.hpp"
#include "coordination/register_main_replica_coordinator_status.hpp"
#include "utils/result.hpp"
#include <cstdint>
#include <optional>
#include <vector>
namespace memgraph::dbms {
@ -28,7 +26,7 @@ class DbmsHandler;
class CoordinatorHandler {
public:
explicit CoordinatorHandler(DbmsHandler &dbms_handler);
explicit CoordinatorHandler(coordination::CoordinatorState &coordinator_state);
auto RegisterInstance(coordination::CoordinatorClientConfig config)
-> coordination::RegisterInstanceCoordinatorStatus;
@ -38,7 +36,7 @@ class CoordinatorHandler {
auto ShowInstances() const -> std::vector<coordination::CoordinatorInstanceStatus>;
private:
DbmsHandler &dbms_handler_;
coordination::CoordinatorState &coordinator_state_;
};
} // namespace memgraph::dbms

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,10 +11,7 @@
#include "dbms/database.hpp"
#include "dbms/inmemory/storage_helper.hpp"
#include "dbms/replication_handler.hpp"
#include "flags/storage_mode.hpp"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/inmemory/storage.hpp"
#include "storage/v2/storage_mode.hpp"
template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>;

View File

@ -11,29 +11,73 @@
#include "dbms/dbms_handler.hpp"
#include "dbms/coordinator_handlers.hpp"
#include "flags/replication.hpp"
#include <cstdint>
#include <filesystem>
#include "dbms/constants.hpp"
#include "dbms/global.hpp"
#include "dbms/replication_client.hpp"
#include "spdlog/spdlog.h"
#include "system/include/system/system.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/uuid.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
namespace {
constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability
constexpr std::string_view kLastCommitedSystemTsKey = "last_commited_system_ts"; // Key for timestamp durability
constexpr std::string_view kDBPrefix = "database:"; // Key prefix for database durability
std::string RegisterReplicaErrorToString(query::RegisterReplicaError error) {
switch (error) {
using enum query::RegisterReplicaError;
case NAME_EXISTS:
return "NAME_EXISTS";
case ENDPOINT_EXISTS:
return "ENDPOINT_EXISTS";
case CONNECTION_FAILED:
return "CONNECTION_FAILED";
case COULD_NOT_BE_PERSISTED:
return "COULD_NOT_BE_PERSISTED";
}
}
// Per storage
// NOTE Storage will connect to all replicas. Future work might change this
void RestoreReplication(replication::RoleMainData &mainData, DatabaseAccess db_acc) {
spdlog::info("Restoring replication role.");
// Each individual client has already been restored and started. Here we just go through each database and start its
// client
for (auto &instance_client : mainData.registered_replicas_) {
spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name());
const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock(
[&, db_acc](auto &storage_clients) mutable -> utils::BasicResult<query::RegisterReplicaError> {
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
auto *storage = db_acc->storage();
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
instance_client.name_);
}
storage_clients.push_back(std::move(client));
return {};
});
if (ret.HasError()) {
MG_ASSERT(query::RegisterReplicaError::CONNECTION_FAILED != ret.GetError());
LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_,
RegisterReplicaErrorToString(ret.GetError()));
}
spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name());
}
spdlog::info("Replication role restored to MAIN.");
}
} // namespace
#ifdef MG_ENTERPRISE
struct Durability {
enum class DurabilityVersion : uint8_t {
V0 = 0,
@ -112,11 +156,9 @@ struct Durability {
}
};
DbmsHandler::DbmsHandler(
storage::Config config,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
bool recovery_on_startup)
: default_config_{std::move(config)}, repl_state_{ReplicationStateRootPath(default_config_)} {
DbmsHandler::DbmsHandler(storage::Config config, memgraph::system::System &system,
replication::ReplicationState &repl_state, auth::SynchedAuth &auth, bool recovery_on_startup)
: default_config_{std::move(config)}, auth_{auth}, repl_state_{repl_state}, system_{&system} {
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
@ -150,19 +192,13 @@ DbmsHandler::DbmsHandler(
const auto uuid = json.at("uuid").get<utils::UUID>();
const auto rel_dir = json.at("rel_dir").get<std::filesystem::path>();
spdlog::info("Restoring database {} at {}.", name, rel_dir);
auto new_db = New_(name, uuid, rel_dir);
auto new_db = New_(name, uuid, nullptr, rel_dir);
MG_ASSERT(!new_db.HasError(), "Failed while creating database {}.", name);
directories.emplace(rel_dir.filename());
spdlog::info("Database {} restored.", name);
}
// Read the last timestamp
auto lcst = durability_->Get(kLastCommitedSystemTsKey);
if (lcst) {
last_commited_system_timestamp_ = std::stoul(*lcst);
system_timestamp_ = last_commited_system_timestamp_;
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth->Lock();
auto locked_auth = auth_.Lock();
auto it = durability_->begin(std::string{kDBPrefix});
auto end = durability_->end(std::string{kDBPrefix});
for (; it != end; ++it) {
@ -172,8 +208,6 @@ DbmsHandler::DbmsHandler(
locked_auth->DeleteDatabase(name);
durability_->Delete(key);
}
// Delete the last timestamp
durability_->Delete(kLastCommitedSystemTsKey);
}
/*
@ -198,45 +232,29 @@ DbmsHandler::DbmsHandler(
*/
// Setup the default DB
SetupDefault_();
/*
* REPLICATION RECOVERY AND STARTUP
*/
// Startup replication state (if recovered at startup)
auto replica = [this](replication::RoleReplicaData const &data) { return StartRpcServer(*this, data); };
// Replication recovery and frequent check start
auto main = [this](replication::RoleMainData &data) {
for (auto &client : data.registered_replicas_) {
SystemRestore(client);
}
ForEach([this](DatabaseAccess db) { RecoverReplication(db); });
for (auto &client : data.registered_replicas_) {
StartReplicaClient(*this, client);
}
return true;
};
// Startup proccess for main/replica
MG_ASSERT(std::visit(memgraph::utils::Overloaded{replica, main}, repl_state_.ReplicationData()),
"Replica recovery failure!");
// Warning
if (default_config_.durability.snapshot_wal_mode == storage::Config::Durability::SnapshotWalMode::DISABLED &&
repl_state_.IsMain()) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
}
// MAIN or REPLICA instance
if (FLAGS_coordinator_server_port) {
CoordinatorHandlers::Register(*this);
MG_ASSERT(coordinator_state_.GetCoordinatorServer().Start(), "Failed to start coordinator server!");
}
}
DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) {
struct DropDatabase : memgraph::system::ISystemAction {
explicit DropDatabase(utils::UUID uuid) : uuid_{uuid} {}
void DoDurability() override { /* Done during DBMS execution */
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const storage::replication::DropDatabaseRes &response) {
return response.result != storage::replication::DropDatabaseRes::Result::FAILURE;
};
return client.SteamAndFinalizeDelta<storage::replication::DropDatabaseRpc>(
check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), uuid_);
}
void PostReplication(replication::RoleMainData &mainData) const override {}
private:
utils::UUID uuid_;
};
DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name, system::Transaction *transaction) {
std::lock_guard<LockT> wr(lock_);
if (db_name == kDefaultDB) {
// MSG cannot delete the default db
@ -273,9 +291,10 @@ DbmsHandler::DeleteResult DbmsHandler::TryDelete(std::string_view db_name) {
// Success
// Save delta
if (system_transaction_) {
system_transaction_->delta.emplace(SystemTransaction::Delta::drop_database, uuid);
if (transaction) {
transaction->AddAction<DropDatabase>(uuid);
}
return {};
}
@ -296,18 +315,48 @@ DbmsHandler::DeleteResult DbmsHandler::Delete(utils::UUID uuid) {
return Delete_(db_name);
}
DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config) {
struct CreateDatabase : memgraph::system::ISystemAction {
explicit CreateDatabase(storage::SalientConfig config, DatabaseAccess db_acc)
: config_{std::move(config)}, db_acc(db_acc) {}
void DoDurability() override {
// Done during dbms execution
}
bool DoReplication(replication::ReplicationClient &client, replication::ReplicationEpoch const &epoch,
memgraph::system::Transaction const &txn) const override {
auto check_response = [](const storage::replication::CreateDatabaseRes &response) {
return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE;
};
return client.SteamAndFinalizeDelta<storage::replication::CreateDatabaseRpc>(
check_response, epoch.id(), txn.last_committed_system_timestamp(), txn.timestamp(), config_);
}
void PostReplication(replication::RoleMainData &mainData) const override {
// Sync database with REPLICAs
// NOTE: The function bellow is used to create ReplicationStorageClient, so it must be called on a new storage
// We don't need to have it here, since the function won't fail even if the replication client fails to
// connect We will just have everything ready, for recovery at some point.
dbms::DbmsHandler::RecoverStorageReplication(db_acc, mainData);
}
private:
storage::SalientConfig config_;
DatabaseAccess db_acc;
};
DbmsHandler::NewResultT DbmsHandler::New_(storage::Config storage_config, system::Transaction *txn) {
auto new_db = db_handler_.New(storage_config, repl_state_);
if (new_db.HasValue()) { // Success
// Save delta
if (system_transaction_) {
system_transaction_->delta.emplace(SystemTransaction::Delta::create_database, storage_config.salient);
}
UpdateDurability(storage_config);
return new_db.GetValue();
if (txn) {
txn->AddAction<CreateDatabase>(storage_config.salient, new_db.GetValue());
}
}
return new_db.GetError();
return new_db;
}
DbmsHandler::DeleteResult DbmsHandler::Delete_(std::string_view db_name) {
@ -361,89 +410,16 @@ void DbmsHandler::UpdateDurability(const storage::Config &config, std::optional<
durability_->Put(key, val);
}
AllSyncReplicaStatus DbmsHandler::Commit() {
if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt)
return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit
const auto &delta = *system_transaction_->delta;
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
// TODO Create a system client that can handle all of this automatically
switch (delta.action) {
using enum SystemTransaction::Delta::Action;
case CREATE_DATABASE: {
// Replication
auto main_handler = [&](memgraph::replication::RoleMainData &main_data) {
// TODO: data race issue? registered_replicas_ access not protected
// This is sync in any case, as this is the startup
for (auto &client : main_data.registered_replicas_) {
bool completed = SteamAndFinalizeDelta<storage::replication::CreateDatabaseRpc>(
client,
[](const storage::replication::CreateDatabaseRes &response) {
return response.result != storage::replication::CreateDatabaseRes::Result::FAILURE;
},
std::string(main_data.epoch_.id()), last_commited_system_timestamp_,
system_transaction_->system_timestamp, delta.config);
// TODO: reduce duplicate code
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
// Sync database with REPLICAs
RecoverReplication(Get_(delta.config.name));
};
auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ };
std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
} break;
case DROP_DATABASE: {
// Replication
auto main_handler = [&](memgraph::replication::RoleMainData &main_data) {
// TODO: data race issue? registered_replicas_ access not protected
// This is sync in any case, as this is the startup
for (auto &client : main_data.registered_replicas_) {
bool completed = SteamAndFinalizeDelta<storage::replication::DropDatabaseRpc>(
client,
[](const storage::replication::DropDatabaseRes &response) {
return response.result != storage::replication::DropDatabaseRes::Result::FAILURE;
},
std::string(main_data.epoch_.id()), last_commited_system_timestamp_,
system_transaction_->system_timestamp, delta.uuid);
// TODO: reduce duplicate code
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
};
auto replica_handler = [](memgraph::replication::RoleReplicaData &) { /* Nothing to do */ };
std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
} break;
}
durability_->Put(kLastCommitedSystemTsKey, std::to_string(system_transaction_->system_timestamp));
last_commited_system_timestamp_ = system_transaction_->system_timestamp;
ResetSystemTransaction();
return sync_status;
}
#else // not MG_ENTERPRISE
AllSyncReplicaStatus DbmsHandler::Commit() {
if (system_transaction_ == std::nullopt || system_transaction_->delta == std::nullopt) {
return AllSyncReplicaStatus::AllCommitsConfirmed; // Nothing to commit
}
const auto &delta = *system_transaction_->delta;
switch (delta.action) {
using enum SystemTransaction::Delta::Action;
case CREATE_DATABASE:
case DROP_DATABASE:
/* Community edition doesn't support multi-tenant replication */
break;
}
last_commited_system_timestamp_ = system_transaction_->system_timestamp;
ResetSystemTransaction();
return AllSyncReplicaStatus::AllCommitsConfirmed;
}
#endif
void DbmsHandler::RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data) {
if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) {
// Handle global replication state
spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash.");
// RECOVER REPLICA CONNECTIONS
memgraph::dbms::RestoreReplication(role_main_data, db_acc);
} else if (!role_main_data.registered_replicas_.empty()) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
}
} // namespace memgraph::dbms

View File

@ -25,24 +25,24 @@
#include "constants.hpp"
#include "dbms/database.hpp"
#include "dbms/inmemory/replication_handlers.hpp"
#include "dbms/replication_handler.hpp"
#include "dbms/rpc.hpp"
#include "kvstore/kvstore.hpp"
#include "license/license.hpp"
#include "replication/replication_client.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/replication/enums.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "storage/v2/transaction.hpp"
#include "system/system.hpp"
#include "utils/thread_pool.hpp"
#ifdef MG_ENTERPRISE
#include "coordination/coordinator_state.hpp"
#include "dbms/database_handler.hpp"
#endif
#include "dbms/transaction.hpp"
#include "global.hpp"
#include "query/config.hpp"
#include "query/interpreter_context.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/isolation_level.hpp"
#include "system/system.hpp"
#include "utils/logging.hpp"
#include "utils/result.hpp"
#include "utils/rw_lock.hpp"
@ -51,11 +51,6 @@
namespace memgraph::dbms {
enum class AllSyncReplicaStatus {
AllCommitsConfirmed,
SomeCommitsUnconfirmed,
};
struct Statistics {
uint64_t num_vertex; //!< Sum of vertexes in every database
uint64_t num_edges; //!< Sum of edges in every database
@ -111,8 +106,8 @@ class DbmsHandler {
* @param auth pointer to the global authenticator
* @param recovery_on_startup restore databases (and its content) and authentication data
*/
DbmsHandler(storage::Config config,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state,
auth::SynchedAuth &auth,
bool recovery_on_startup); // TODO If more arguments are added use a config struct
#else
/**
@ -120,15 +115,14 @@ class DbmsHandler {
*
* @param configs storage configuration
*/
DbmsHandler(storage::Config config)
: repl_state_{ReplicationStateRootPath(config)},
DbmsHandler(storage::Config config, memgraph::system::System &system, replication::ReplicationState &repl_state)
: repl_state_{repl_state},
system_{&system},
db_gatekeeper_{[&] {
config.salient.name = kDefaultDB;
return std::move(config);
}(),
repl_state_} {
RecoverReplication(Get());
}
repl_state_} {}
#endif
#ifdef MG_ENTERPRISE
@ -138,10 +132,10 @@ class DbmsHandler {
* @param name name of the database
* @return NewResultT context on success, error on failure
*/
NewResultT New(const std::string &name) {
NewResultT New(const std::string &name, system::Transaction *txn = nullptr) {
std::lock_guard<LockT> wr(lock_);
const auto uuid = utils::UUID{};
return New_(name, uuid);
return New_(name, uuid, txn);
}
/**
@ -234,7 +228,7 @@ class DbmsHandler {
* @param db_name database name
* @return DeleteResult error on failure
*/
DeleteResult TryDelete(std::string_view db_name);
DeleteResult TryDelete(std::string_view db_name, system::Transaction *transaction = nullptr);
/**
* @brief Delete or defer deletion of database.
@ -267,23 +261,12 @@ class DbmsHandler {
#endif
}
replication::ReplicationState &ReplicationState() { return repl_state_; }
replication::ReplicationState const &ReplicationState() const { return repl_state_; }
bool IsMain() const { return repl_state_.IsMain(); }
bool IsReplica() const { return repl_state_.IsReplica(); }
#ifdef MG_ENTERPRISE
coordination::CoordinatorState &CoordinatorState() { return coordinator_state_; }
#endif
/**
* @brief Return the statistics all databases.
*
* @return Statistics
*/
Statistics Stats() {
auto const replication_role = repl_state_.GetRole();
Statistics Stats(memgraph::replication_coordination_glue::ReplicationRole replication_role) {
Statistics stats{};
// TODO: Handle overflow?
#ifdef MG_ENTERPRISE
@ -319,8 +302,7 @@ class DbmsHandler {
*
* @return std::vector<DatabaseInfo>
*/
std::vector<DatabaseInfo> Info() {
auto const replication_role = repl_state_.GetRole();
std::vector<DatabaseInfo> Info(memgraph::replication_coordination_glue::ReplicationRole replication_role) {
std::vector<DatabaseInfo> res;
#ifdef MG_ENTERPRISE
std::shared_lock<LockT> rd(lock_);
@ -407,98 +389,17 @@ class DbmsHandler {
}
}
void NewSystemTransaction() {
DMG_ASSERT(!system_transaction_, "Already running a system transaction");
system_transaction_.emplace(++system_timestamp_);
}
void ResetSystemTransaction() { system_transaction_.reset(); }
//! \tparam RPC An rpc::RequestResponse
//! \tparam Args the args type
//! \param client the client to use for rpc communication
//! \param check predicate to check response is ok
//! \param args arguments to forward to the rpc request
//! \return If replica stream is completed or enqueued
template <typename RPC, typename... Args>
bool SteamAndFinalizeDelta(auto &client, auto &&check, Args &&...args) {
try {
auto stream = client.rpc_client_.template Stream<RPC>(std::forward<Args>(args)...);
auto task = [&client, check = std::forward<decltype(check)>(check), stream = std::move(stream)]() mutable {
if (stream.IsDefunct()) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
try {
if (check(stream.AwaitResponse())) {
return true;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// swallow error, fallthrough to error handling
}
// This replica needs SYSTEM recovery
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
};
if (client.mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) {
client.thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); });
return true;
}
return task();
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// This replica needs SYSTEM recovery
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
};
AllSyncReplicaStatus Commit();
auto LastCommitedTS() const -> uint64_t { return last_commited_system_timestamp_; }
void SetLastCommitedTS(uint64_t new_ts) { last_commited_system_timestamp_.store(new_ts); }
static void RecoverStorageReplication(DatabaseAccess db_acc, replication::RoleMainData &role_main_data);
auto default_config() const -> storage::Config const & {
#ifdef MG_ENTERPRISE
// When being called by intepreter no need to gain lock, it should already be under a system transaction
// But concurrently the FrequentCheck is running and will need to lock before reading last_commited_system_timestamp_
template <bool REQUIRE_LOCK = false>
void SystemRestore(replication::ReplicationClient &client) {
// Check if system is up to date
if (client.state_.WithLock(
[](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; }))
return;
// Try to recover...
{
auto [database_configs, last_commited_system_timestamp] = std::invoke([&] {
auto sys_guard =
std::unique_lock{system_lock_, std::defer_lock}; // ensure no other system transaction in progress
if constexpr (REQUIRE_LOCK) {
sys_guard.lock();
}
auto configs = std::vector<storage::SalientConfig>{};
ForEach([&configs](DatabaseAccess acc) { configs.emplace_back(acc->config().salient); });
return std::pair{configs, last_commited_system_timestamp_.load()};
});
try {
auto stream = client.rpc_client_.Stream<storage::replication::SystemRecoveryRpc>(last_commited_system_timestamp,
std::move(database_configs));
const auto response = stream.AwaitResponse();
if (response.result == storage::replication::SystemRecoveryRes::Result::FAILURE) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
}
// Successfully recovered
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; });
}
return default_config_;
#else
const auto acc = db_gatekeeper_.access();
MG_ASSERT(acc, "Failed to get default database!");
return acc->get()->config();
#endif
}
private:
#ifdef MG_ENTERPRISE
@ -524,7 +425,8 @@ class DbmsHandler {
* @param uuid undelying RocksDB directory
* @return NewResultT context on success, error on failure
*/
NewResultT New_(std::string_view name, utils::UUID uuid, std::optional<std::filesystem::path> rel_dir = {}) {
NewResultT New_(std::string_view name, utils::UUID uuid, system::Transaction *txn = nullptr,
std::optional<std::filesystem::path> rel_dir = {}) {
auto config_copy = default_config_;
config_copy.salient.name = name;
config_copy.salient.uuid = uuid;
@ -535,7 +437,7 @@ class DbmsHandler {
storage::UpdatePaths(config_copy,
default_config_.durability.storage_directory / kMultiTenantDir / std::string{uuid});
}
return New_(std::move(config_copy));
return New_(std::move(config_copy), txn);
}
/**
@ -544,11 +446,11 @@ class DbmsHandler {
* @param config configuration to be used
* @return NewResultT context on success, error on failure
*/
NewResultT New_(const storage::SalientConfig &config) {
NewResultT New_(const storage::SalientConfig &config, system::Transaction *txn = nullptr) {
auto config_copy = default_config_;
config_copy.salient = config; // name, uuid, mode, etc
UpdatePaths(config_copy, config_copy.durability.storage_directory / kMultiTenantDir / std::string{config.uuid});
return New_(std::move(config_copy));
return New_(std::move(config_copy), txn);
}
/**
@ -557,7 +459,7 @@ class DbmsHandler {
* @param storage_config storage configuration
* @return NewResultT context on success, error on failure
*/
NewResultT New_(storage::Config storage_config);
DbmsHandler::NewResultT New_(storage::Config storage_config, system::Transaction *txn = nullptr);
// TODO: new overload of Delete_ with DatabaseAccess
DeleteResult Delete_(std::string_view db_name);
@ -572,7 +474,8 @@ class DbmsHandler {
Get(kDefaultDB);
} catch (const UnknownDatabaseException &) {
// No default DB restored, create it
MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, ".").HasValue(), "Failed while creating the default database");
MG_ASSERT(New_(kDefaultDB, {/* random UUID */}, nullptr, ".").HasValue(),
"Failed while creating the default database");
}
// For back-compatibility...
@ -659,35 +562,24 @@ class DbmsHandler {
}
#endif
void RecoverReplication(DatabaseAccess db_acc) {
if (allow_mt_repl || db_acc->name() == dbms::kDefaultDB) {
// Handle global replication state
spdlog::info("Replication configuration will be stored and will be automatically restored in case of a crash.");
// RECOVER REPLICA CONNECTIONS
memgraph::dbms::RestoreReplication(repl_state_, std::move(db_acc));
} else if (const ::memgraph::replication::RoleMainData *data =
std::get_if<::memgraph::replication::RoleMainData>(&repl_state_.ReplicationData());
data && !data->registered_replicas_.empty()) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
}
#ifdef MG_ENTERPRISE
mutable LockT lock_{utils::RWLock::Priority::READ}; //!< protective lock
storage::Config default_config_; //!< Storage configuration used when creating new databases
DatabaseHandler db_handler_; //!< multi-tenancy storage handler
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
coordination::CoordinatorState coordinator_state_; //!< Replication coordinator
// TODO: move to be common
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
auth::SynchedAuth &auth_; //!< Synchronized auth::Auth
#endif
// TODO: Make an api
public:
utils::ResourceLock system_lock_{}; //!> Ensure exclusive access for system queries
private:
std::optional<SystemTransaction> system_transaction_; //!< Current system transaction (only one at a time)
uint64_t system_timestamp_{storage::kTimestampInitialId}; //!< System timestamp
std::atomic_uint64_t last_commited_system_timestamp_{
storage::kTimestampInitialId}; //!< Last commited system timestamp
replication::ReplicationState repl_state_; //!< Global replication state
// NOTE: atm the only reason this exists here, is because we pass it into the construction of New Database's
// Database only uses it as a convience to make the correct Access without out needing to be told the
// current replication role. TODO: make Database Access explicit about the role and remove this from
// dbms stuff
replication::ReplicationState &repl_state_; //!< Ref to global replication state
public:
// TODO fix to be non public/remove from dbms....maybe
system::System *system_;
#ifndef MG_ENTERPRISE
mutable utils::Gatekeeper<Database> db_gatekeeper_; //!< Single databases gatekeeper
#endif

View File

@ -11,10 +11,6 @@
#pragma once
#include <variant>
#include "dbms/constants.hpp"
#include "dbms/replication_handler.hpp"
#include "replication/state.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/inmemory/storage.hpp"

View File

@ -1,43 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_client.hpp"
#include "replication/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client) {
// No client error, start instance level client
auto const &endpoint = client.rpc_client_.Endpoint();
spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port);
client.StartFrequentCheck([&dbms_handler](bool reconnect, replication::ReplicationClient &client) {
// Working connection
// Check if system needs restoration
if (reconnect) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
#ifdef MG_ENTERPRISE
dbms_handler.SystemRestore<true>(client);
#endif
// Check if any database has been left behind
dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) {
// Specific database <-> replica client
db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) {
if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
// Database <-> replica might be behind, check and recover
client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc);
}
});
});
});
} // namespace memgraph::dbms
} // namespace memgraph::dbms

View File

@ -1,400 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_handler.hpp"
#include <algorithm>
#include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/global.hpp"
#include "dbms/inmemory/replication_handlers.hpp"
#include "dbms/replication_client.hpp"
#include "dbms/utils.hpp"
#include "replication/messages.hpp"
#include "replication/state.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/config.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/on_scope_exit.hpp"
using memgraph::replication::RoleMainData;
using memgraph::replication::RoleReplicaData;
namespace memgraph::dbms {
namespace {
std::string RegisterReplicaErrorToString(RegisterReplicaError error) {
switch (error) {
using enum RegisterReplicaError;
case NAME_EXISTS:
return "NAME_EXISTS";
case ENDPOINT_EXISTS:
return "ENDPOINT_EXISTS";
case CONNECTION_FAILED:
return "CONNECTION_FAILED";
case COULD_NOT_BE_PERSISTED:
return "COULD_NOT_BE_PERSISTED";
}
}
} // namespace
ReplicationHandler::ReplicationHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {}
bool ReplicationHandler::SetReplicationRoleMain() {
auto const main_handler = [](RoleMainData &) {
// If we are already MAIN, we don't want to change anything
return false;
};
auto const replica_handler = [this](RoleReplicaData const &) {
return memgraph::dbms::DoReplicaToMainPromotion(dbms_handler_);
};
// TODO: under lock
return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
}
bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) {
// We don't want to restart the server if we're already a REPLICA
if (dbms_handler_.ReplicationState().IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler_.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<RoleMainData>(dbms_handler_.ReplicationState().ReplicationData()).registered_replicas_.clear();
// Creates the server
dbms_handler_.ReplicationState().SetReplicationRoleReplica(config);
// Start
const auto success =
std::visit(utils::Overloaded{[](RoleMainData const &) {
// ASSERT
return false;
},
[this](RoleReplicaData const &data) { return StartRpcServer(dbms_handler_, data); }},
dbms_handler_.ReplicationState().ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<RegisterReplicaError> {
MG_ASSERT(dbms_handler_.ReplicationState().IsMain(), "Only main instance can register a replica!");
auto maybe_client = dbms_handler_.ReplicationState().RegisterReplica(config);
if (maybe_client.HasError()) {
switch (maybe_client.GetError()) {
case memgraph::replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case memgraph::replication::RegisterReplicaError::NAME_EXISTS:
return memgraph::dbms::RegisterReplicaError::NAME_EXISTS;
case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS:
return memgraph::dbms::RegisterReplicaError::ENDPOINT_EXISTS;
case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return memgraph::dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case memgraph::replication::RegisterReplicaError::SUCCESS:
break;
}
}
if (!allow_mt_repl && dbms_handler_.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
#ifdef MG_ENTERPRISE
// Update system before enabling individual storage <-> replica clients
dbms_handler_.SystemRestore(*maybe_client.GetValue());
#endif
const auto dbms_error = memgraph::dbms::HandleRegisterReplicaStatus(maybe_client);
if (dbms_error.has_value()) {
return *dbms_error;
}
auto &instance_client_ptr = maybe_client.GetValue();
const bool all_clients_good = memgraph::dbms::RegisterAllDatabasesClients(dbms_handler_, *instance_client_ptr);
// NOTE Currently if any databases fails, we revert back
if (!all_clients_good) {
spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name);
UnregisterReplica(config.name);
return RegisterReplicaError::CONNECTION_FAILED;
}
// No client error, start instance level client
StartReplicaClient(dbms_handler_, *instance_client_ptr);
return {};
}
auto ReplicationHandler::UnregisterReplica(std::string_view name) -> UnregisterReplicaResult {
auto const replica_handler = [](RoleReplicaData const &) -> UnregisterReplicaResult {
return UnregisterReplicaResult::NOT_MAIN;
};
auto const main_handler = [this, name](RoleMainData &mainData) -> UnregisterReplicaResult {
if (!dbms_handler_.ReplicationState().TryPersistUnregisterReplica(name)) {
return UnregisterReplicaResult::COULD_NOT_BE_PERSISTED;
}
// Remove database specific clients
dbms_handler_.ForEach([name](DatabaseAccess db_acc) {
db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) {
std::erase_if(clients, [name](const auto &client) { return client->Name() == name; });
});
});
// Remove instance level clients
auto const n_unregistered =
std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; });
return n_unregistered != 0 ? UnregisterReplicaResult::SUCCESS : UnregisterReplicaResult::CAN_NOT_UNREGISTER;
};
return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
}
auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole {
return dbms_handler_.ReplicationState().GetRole();
}
bool ReplicationHandler::IsMain() const { return dbms_handler_.ReplicationState().IsMain(); }
bool ReplicationHandler::IsReplica() const { return dbms_handler_.ReplicationState().IsReplica(); }
// Per storage
// NOTE Storage will connect to all replicas. Future work might change this
void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc) {
spdlog::info("Restoring replication role.");
/// MAIN
auto const recover_main = [db_acc = std::move(db_acc)](RoleMainData &mainData) mutable { // NOLINT
// Each individual client has already been restored and started. Here we just go through each database and start its
// client
for (auto &instance_client : mainData.registered_replicas_) {
spdlog::info("Replica {} restoration started for {}.", instance_client.name_, db_acc->name());
const auto &ret = db_acc->storage()->repl_storage_state_.replication_clients_.WithLock(
[&, db_acc](auto &storage_clients) mutable -> utils::BasicResult<RegisterReplicaError> {
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
auto *storage = db_acc->storage();
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
instance_client.name_);
}
storage_clients.push_back(std::move(client));
return {};
});
if (ret.HasError()) {
MG_ASSERT(RegisterReplicaError::CONNECTION_FAILED != ret.GetError());
LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_,
RegisterReplicaErrorToString(ret.GetError()));
}
spdlog::info("Replica {} restored for {}.", instance_client.name_, db_acc->name());
}
spdlog::info("Replication role restored to MAIN.");
};
/// REPLICA
auto const recover_replica = [](RoleReplicaData const &data) { /*nothing to do*/ };
std::visit(
utils::Overloaded{
recover_main,
recover_replica,
},
repl_state.ReplicationData());
}
namespace system_replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) {
replication::SystemHeartbeatReq req;
replication::SystemHeartbeatReq::Load(&req, req_reader);
replication::SystemHeartbeatRes res(ts);
memgraph::slk::Save(res, res_builder);
}
void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
memgraph::storage::replication::CreateDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::CreateDatabaseRes;
CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) {
spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
dbms_handler.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Create new
auto new_db = dbms_handler.Update(req.config);
if (new_db.HasValue()) {
// Successfully create db
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS);
spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
memgraph::storage::replication::DropDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::DropDatabaseRes;
DropDatabaseRes res(DropDatabaseRes::Result::FAILURE);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != dbms_handler.LastCommitedTS()) {
spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
dbms_handler.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment.
auto new_db = dbms_handler.Delete(req.uuid);
if (new_db.HasError()) {
if (new_db.GetError() == DeleteError::NON_EXISTENT) {
// Nothing to drop
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED);
}
} else {
// Successfully drop db
dbms_handler.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS);
spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
// TODO Speed up
memgraph::storage::replication::SystemRecoveryReq req;
memgraph::slk::Load(&req, req_reader);
using memgraph::storage::replication::SystemRecoveryRes;
SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE);
utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); });
// Get all current dbs
auto old = dbms_handler.All();
// Check/create the incoming dbs
for (const auto &config : req.database_configs) {
// Missing db
try {
if (dbms_handler.Update(config).HasError()) {
spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name);
return; // Send failure on exit
}
} catch (const UnknownDatabaseException &) {
spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException");
return; // Send failure on exit
}
const auto it = std::find(old.begin(), old.end(), config.name);
if (it != old.end()) old.erase(it);
}
// Delete all the leftover old dbs
for (const auto &remove_db : old) {
const auto del = dbms_handler.Delete(remove_db);
if (del.HasError()) {
// Some errors are not terminal
if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) {
spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db);
continue;
}
spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db);
return; // Send failure on exit
}
}
// Successfully recovered
dbms_handler.SetLastCommitedTS(req.forced_group_timestamp);
spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp);
res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS);
}
#endif
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler) {
#ifdef MG_ENTERPRISE
data.server->rpc_server_.Register<replication::SystemHeartbeatRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemHeartbeatRpc");
SystemHeartbeatHandler(dbms_handler.LastCommitedTS(), req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::CreateDatabaseRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received CreateDatabaseRpc");
CreateDatabaseHandler(dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::DropDatabaseRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received DropDatabaseRpc");
DropDatabaseHandler(dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::SystemRecoveryRpc>(
[&dbms_handler](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemRecoveryRpc");
SystemRecoveryHandler(dbms_handler, req_reader, res_builder);
});
#endif
}
} // namespace system_replication
bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) {
// Register handlers
InMemoryReplicationHandlers::Register(&dbms_handler, *data.server);
system_replication::Register(data, dbms_handler);
// Start server
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
}
} // namespace memgraph::dbms

View File

@ -1,82 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication_coordination_glue/role.hpp"
#include "dbms/database.hpp"
#include "utils/result.hpp"
namespace memgraph::replication {
struct ReplicationState;
struct ReplicationServerConfig;
struct ReplicationClientConfig;
} // namespace memgraph::replication
namespace memgraph::dbms {
class DbmsHandler;
enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED };
enum class UnregisterReplicaResult : uint8_t {
NOT_MAIN,
COULD_NOT_BE_PERSISTED,
CAN_NOT_UNREGISTER,
SUCCESS,
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages
struct ReplicationHandler {
explicit ReplicationHandler(DbmsHandler &dbms_handler);
// as REPLICA, become MAIN
bool SetReplicationRoleMain();
// as MAIN, become REPLICA
bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config);
// as MAIN, define and connect to REPLICAs
auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError>;
// as MAIN, remove a REPLICA connection
auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult;
// Helper pass-through (TODO: remove)
auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole;
bool IsMain() const;
bool IsReplica() const;
private:
DbmsHandler &dbms_handler_;
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages
void RestoreReplication(replication::ReplicationState &repl_state, DatabaseAccess db_acc);
namespace system_replication {
// System handlers
#ifdef MG_ENTERPRISE
void CreateDatabaseHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemRecoveryHandler(DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
#endif
/// Register all DBMS level RPC handlers
void Register(replication::RoleReplicaData const &data, DbmsHandler &dbms_handler);
} // namespace system_replication
bool StartRpcServer(DbmsHandler &dbms_handler, const replication::RoleReplicaData &data);
} // namespace memgraph::dbms

View File

@ -0,0 +1,191 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_handlers.hpp"
#include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "storage/v2/storage.hpp"
#include "system/state.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder) {
using memgraph::storage::replication::CreateDatabaseRes;
CreateDatabaseRes res(CreateDatabaseRes::Result::FAILURE);
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling CreateDatabase, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
memgraph::storage::replication::CreateDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("CreateDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// Create new
auto new_db = dbms_handler.Update(req.config);
if (new_db.HasValue()) {
// Successfully create db
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = CreateDatabaseRes(CreateDatabaseRes::Result::SUCCESS);
spdlog::debug("CreateDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
slk::Reader *req_reader, slk::Builder *res_builder) {
using memgraph::storage::replication::DropDatabaseRes;
DropDatabaseRes res(DropDatabaseRes::Result::FAILURE);
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling DropDatabase, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
memgraph::storage::replication::DropDatabaseReq req;
memgraph::slk::Load(&req, req_reader);
// Note: No need to check epoch, recovery mechanism is done by a full uptodate snapshot
// of the set of databases. Hence no history exists to maintain regarding epoch change.
// If MAIN has changed we need to check this new group_timestamp is consistent with
// what we have so far.
if (req.expected_group_timestamp != system_state_access.LastCommitedTS()) {
spdlog::debug("DropDatabaseHandler: bad expected timestamp {},{}", req.expected_group_timestamp,
system_state_access.LastCommitedTS());
memgraph::slk::Save(res, res_builder);
return;
}
try {
// NOTE: Single communication channel can exist at a time, no other database can be deleted/created at the moment.
auto new_db = dbms_handler.Delete(req.uuid);
if (new_db.HasError()) {
if (new_db.GetError() == DeleteError::NON_EXISTENT) {
// Nothing to drop
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::NO_NEED);
}
} else {
// Successfully drop db
system_state_access.SetLastCommitedTS(req.new_group_timestamp);
res = DropDatabaseRes(DropDatabaseRes::Result::SUCCESS);
spdlog::debug("DropDatabaseHandler: SUCCESS updated LCTS to {}", req.new_group_timestamp);
}
} catch (...) {
// Failure
}
memgraph::slk::Save(res, res_builder);
}
bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector<storage::SalientConfig> &database_configs) {
/*
* NO LICENSE
*/
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling SystemRecovery, an enterprise RPC message, without license.");
for (const auto &config : database_configs) {
// Only handle default DB
if (config.name != kDefaultDB) continue;
try {
if (dbms_handler.Update(config).HasError()) {
return false;
}
} catch (const UnknownDatabaseException &) {
return false;
}
}
return true;
}
/*
* MULTI-TENANCY
*/
// Get all current dbs
auto old = dbms_handler.All();
// Check/create the incoming dbs
for (const auto &config : database_configs) {
// Missing db
try {
if (dbms_handler.Update(config).HasError()) {
spdlog::debug("SystemRecoveryHandler: Failed to update database \"{}\".", config.name);
return false;
}
} catch (const UnknownDatabaseException &) {
spdlog::debug("SystemRecoveryHandler: UnknownDatabaseException");
return false;
}
const auto it = std::find(old.begin(), old.end(), config.name);
if (it != old.end()) old.erase(it);
}
// Delete all the leftover old dbs
for (const auto &remove_db : old) {
const auto del = dbms_handler.Delete(remove_db);
if (del.HasError()) {
// Some errors are not terminal
if (del.GetError() == DeleteError::DEFAULT_DB || del.GetError() == DeleteError::NON_EXISTENT) {
spdlog::debug("SystemRecoveryHandler: Dropped database \"{}\".", remove_db);
continue;
}
spdlog::debug("SystemRecoveryHandler: Failed to drop database \"{}\".", remove_db);
return false;
}
}
/*
* SUCCESS
*/
return true;
}
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler) {
// NOTE: Register even without license as the user could add a license at run-time
data.server->rpc_server_.Register<storage::replication::CreateDatabaseRpc>(
[system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received CreateDatabaseRpc");
CreateDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
});
data.server->rpc_server_.Register<storage::replication::DropDatabaseRpc>(
[system_state_access, &dbms_handler](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received DropDatabaseRpc");
DropDatabaseHandler(system_state_access, dbms_handler, req_reader, res_builder);
});
}
#endif
} // namespace memgraph::dbms

View File

@ -0,0 +1,32 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "replication/state.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
// RPC handlers
void CreateDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
DbmsHandler &dbms_handler, slk::Reader *req_reader, slk::Builder *res_builder);
void DropDatabaseHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access, DbmsHandler &dbms_handler,
slk::Reader *req_reader, slk::Builder *res_builder);
bool SystemRecoveryHandler(DbmsHandler &dbms_handler, const std::vector<storage::SalientConfig> &database_configs);
// RPC registration
void Register(replication::RoleReplicaData const &data, system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler);
#endif
} // namespace memgraph::dbms

118
src/dbms/rpc.cpp Normal file
View File

@ -0,0 +1,118 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/rpc.hpp"
#include "slk/streams.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/enum.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph {
namespace storage::replication {
void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
const utils::TypeInfo CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ, "CreateDatabaseReq", nullptr};
const utils::TypeInfo CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES, "CreateDatabaseRes", nullptr};
const utils::TypeInfo DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ, "DropDatabaseReq", nullptr};
const utils::TypeInfo DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES, "DropDatabaseRes", nullptr};
} // namespace storage::replication
// Autogenerated SLK serialization code
namespace slk {
// Serialize code for CreateDatabaseReq
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.config, builder);
}
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->config, reader);
}
// Serialize code for CreateDatabaseRes
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for DropDatabaseReq
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.uuid, builder);
}
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->uuid, reader);
}
// Serialize code for DropDatabaseRes
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace slk
} // namespace memgraph

118
src/dbms/rpc.hpp Normal file
View File

@ -0,0 +1,118 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <string>
#include <utility>
#include "rpc/messages.hpp"
#include "slk/streams.hpp"
#include "storage/v2/config.hpp"
#include "utils/uuid.hpp"
namespace memgraph::storage::replication {
struct CreateDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder);
CreateDatabaseReq() = default;
CreateDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
storage::SalientConfig config)
: epoch_id(std::string(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
config(std::move(config)) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
storage::SalientConfig config;
};
struct CreateDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder);
CreateDatabaseRes() = default;
explicit CreateDatabaseRes(Result res) : result(res) {}
Result result;
};
using CreateDatabaseRpc = rpc::RequestResponse<CreateDatabaseReq, CreateDatabaseRes>;
struct DropDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder);
DropDatabaseReq() = default;
DropDatabaseReq(std::string_view epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
const utils::UUID &uuid)
: epoch_id(std::string(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
uuid(uuid) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
utils::UUID uuid;
};
struct DropDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder);
DropDatabaseRes() = default;
explicit DropDatabaseRes(Result res) : result(res) {}
Result result;
};
using DropDatabaseRpc = rpc::RequestResponse<DropDatabaseReq, DropDatabaseRes>;
} // namespace memgraph::storage::replication
// SLK serialization declarations
namespace memgraph::slk {
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

View File

@ -11,7 +11,10 @@
#pragma once
#include <list>
#include <memory>
#include <optional>
#include "auth/models.hpp"
#include "storage/v2/config.hpp"
namespace memgraph::dbms {
@ -20,17 +23,70 @@ struct SystemTransaction {
enum class Action {
CREATE_DATABASE,
DROP_DATABASE,
UPDATE_AUTH_DATA,
DROP_AUTH_DATA,
/**
*
* CREATE USER user_name [IDENTIFIED BY 'password'];
* SET PASSWORD FOR user_name TO 'new_password';
* ^ SaveUser
*
* DROP USER user_name;
* ^ Directly on KVStore
*
* CREATE ROLE role_name;
* ^ SaveRole
*
* DROP ROLE
* ^ RemoveRole
*
* SET ROLE FOR user_name TO role_name;
* CLEAR ROLE FOR user_name;
* ^ Do stuff then do SaveUser
*
* GRANT privilege_list TO user_or_role;
* DENY AUTH, INDEX TO moderator:
* REVOKE AUTH, INDEX TO moderator:
* GRANT permission_level ON (LABELS | EDGE_TYPES) label_list TO user_or_role;
* REVOKE (LABELS | EDGE_TYPES) label_or_edge_type_list FROM user_or_role
* DENY (LABELS | EDGE_TYPES) label_or_edge_type_list TO user_or_role
* ^ all of these are EditPermissions <-> SaveUser/Role
*
* Multi-tenant TODO Doc;
* ^ Should all call SaveUser
*
*/
};
static constexpr struct CreateDatabase {
} create_database;
static constexpr struct DropDatabase {
} drop_database;
static constexpr struct UpdateAuthData {
} update_auth_data;
static constexpr struct DropAuthData {
} drop_auth_data;
enum class AuthData { USER, ROLE };
// Multi-tenancy
Delta(CreateDatabase /*tag*/, storage::SalientConfig config)
: action(Action::CREATE_DATABASE), config(std::move(config)) {}
Delta(DropDatabase /*tag*/, const utils::UUID &uuid) : action(Action::DROP_DATABASE), uuid(uuid) {}
// Auth
Delta(UpdateAuthData /*tag*/, std::optional<auth::User> user)
: action(Action::UPDATE_AUTH_DATA), auth_data{std::move(user), std::nullopt} {}
Delta(UpdateAuthData /*tag*/, std::optional<auth::Role> role)
: action(Action::UPDATE_AUTH_DATA), auth_data{std::nullopt, std::move(role)} {}
Delta(DropAuthData /*tag*/, AuthData type, std::string_view name)
: action(Action::DROP_AUTH_DATA),
auth_data_key{
.type = type,
.name = std::string{name},
} {}
// Generic
Delta(const Delta &) = delete;
Delta(Delta &&) = delete;
Delta &operator=(const Delta &) = delete;
@ -42,8 +98,14 @@ struct SystemTransaction {
std::destroy_at(&config);
break;
case Action::DROP_DATABASE:
std::destroy_at(&uuid);
break;
case Action::UPDATE_AUTH_DATA:
std::destroy_at(&auth_data);
break;
case Action::DROP_AUTH_DATA:
std::destroy_at(&auth_data_key);
break;
// Some deltas might have special destructor handling
}
}
@ -51,13 +113,20 @@ struct SystemTransaction {
union {
storage::SalientConfig config;
utils::UUID uuid;
struct {
std::optional<auth::User> user;
std::optional<auth::Role> role;
} auth_data;
struct {
AuthData type;
std::string name;
} auth_data_key;
};
};
explicit SystemTransaction(uint64_t timestamp) : system_timestamp(timestamp) {}
// Currently system transitions support a single delta
std::optional<Delta> delta{};
std::list<Delta> deltas{};
uint64_t system_timestamp;
};

View File

@ -1,131 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "replication/include/replication/state.hpp"
#include "utils/result.hpp"
namespace memgraph::dbms {
inline bool DoReplicaToMainPromotion(dbms::DbmsHandler &dbms_handler) {
auto &repl_state = dbms_handler.ReplicationState();
// STEP 1) bring down all REPLICA servers
dbms_handler.ForEach([](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
// Remember old epoch + storage timestamp association
storage->PrepareForNewEpoch();
});
// STEP 2) Change to MAIN
// TODO: restore replication servers if false?
if (!repl_state.SetReplicationRoleMain()) {
// TODO: Handle recovery on failure???
return false;
}
// STEP 3) We are now MAIN, update storage local epoch
const auto &epoch =
std::get<replication::RoleMainData>(std::as_const(dbms_handler.ReplicationState()).ReplicationData()).epoch_;
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.epoch_ = epoch;
});
return true;
};
inline bool SetReplicationRoleReplica(dbms::DbmsHandler &dbms_handler,
const memgraph::replication::ReplicationServerConfig &config) {
if (dbms_handler.ReplicationState().IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<replication::RoleMainData>(dbms_handler.ReplicationState().ReplicationData()).registered_replicas_.clear();
// Creates the server
dbms_handler.ReplicationState().SetReplicationRoleReplica(config);
// Start
const auto success = std::visit(utils::Overloaded{[](replication::RoleMainData const &) {
// ASSERT
return false;
},
[&dbms_handler](replication::RoleReplicaData const &data) {
return StartRpcServer(dbms_handler, data);
}},
dbms_handler.ReplicationState().ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
template <bool AllowRPCFailure = false>
inline bool RegisterAllDatabasesClients(dbms::DbmsHandler &dbms_handler,
replication::ReplicationClient &instance_client) {
if (!allow_mt_repl && dbms_handler.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
bool all_clients_good = true;
dbms_handler.ForEach([&](DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
if (!allow_mt_repl && storage->name() != kDefaultDB) {
return;
}
// TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes
if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return;
using enum storage::replication::ReplicaState;
all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock(
[storage, &instance_client, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT
auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
client->Start(storage, std::move(db_acc));
if (client->State() == MAYBE_BEHIND && !AllowRPCFailure) {
return false;
}
storage_clients.push_back(std::move(client));
return true;
});
});
return all_clients_good;
}
inline std::optional<RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client) {
if (instance_client.HasError()) switch (instance_client.GetError()) {
case replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case replication::RegisterReplicaError::NAME_EXISTS:
return dbms::RegisterReplicaError::NAME_EXISTS;
case replication::RegisterReplicaError::ENDPOINT_EXISTS:
return dbms::RegisterReplicaError::ENDPOINT_EXISTS;
case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case replication::RegisterReplicaError::SUCCESS:
break;
}
return {};
}
} // namespace memgraph::dbms

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -33,7 +33,7 @@ class WritePrioritizedRWLock;
struct Context {
memgraph::query::InterpreterContext *ic;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
memgraph::auth::SynchedAuth *auth;
#if MG_ENTERPRISE
memgraph::audit::Log *audit_log;
#endif

View File

@ -319,8 +319,7 @@ void SessionHL::Configure(const std::map<std::string, memgraph::communication::b
SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth
#ifdef MG_ENTERPRISE
,
memgraph::audit::Log *audit_log

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -25,8 +25,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
SessionHL(memgraph::query::InterpreterContext *interpreter_context,
memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
memgraph::communication::v2::OutputStream *output_stream, memgraph::auth::SynchedAuth *auth
#ifdef MG_ENTERPRISE
,
memgraph::audit::Log *audit_log
@ -88,7 +87,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
memgraph::audit::Log *audit_log_;
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata
#endif
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
memgraph::auth::SynchedAuth *auth_;
memgraph::communication::v2::ServerEndpoint endpoint_;
std::optional<std::string> implicit_db_;
};

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -66,9 +66,7 @@ bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph::
#endif
namespace memgraph::glue {
AuthChecker::AuthChecker(
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: auth_(auth) {}
AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {}
bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -22,8 +22,7 @@ namespace memgraph::glue {
class AuthChecker : public query::AuthChecker {
public:
explicit AuthChecker(
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth);
explicit AuthChecker(memgraph::auth::SynchedAuth *auth);
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges,
@ -41,7 +40,7 @@ class AuthChecker : public query::AuthChecker {
const std::string &db_name = "");
private:
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
memgraph::auth::SynchedAuth *auth_;
mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user
};
#ifdef MG_ENTERPRISE

View File

@ -210,16 +210,25 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedUserPrivile
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return {};
}
const auto &label_permissions = user->GetFineGrainedAccessLabelPermissions();
const auto &edge_type_permissions = user->GetFineGrainedAccessEdgeTypePermissions();
auto all_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER");
auto edge_type_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER");
auto all_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
user->GetUserFineGrainedAccessLabelPermissions(), "LABEL", "USER");
auto all_role_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
user->GetRoleFineGrainedAccessLabelPermissions(), "LABEL", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(all_role_fine_grained_permissions.begin()),
std::make_move_iterator(all_role_fine_grained_permissions.end()));
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(),
edge_type_fine_grained_permissions.end());
auto edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
user->GetUserFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "USER");
auto role_edge_type_fine_grained_permissions = GetFineGrainedPermissionForPrivilegeForUserOrRole(
user->GetRoleFineGrainedAccessEdgeTypePermissions(), "EDGE_TYPE", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(edge_type_fine_grained_permissions.begin()),
std::make_move_iterator(edge_type_fine_grained_permissions.end()));
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(),
std::make_move_iterator(role_edge_type_fine_grained_permissions.begin()),
std::make_move_iterator(role_edge_type_fine_grained_permissions.end()));
return ConstructFineGrainedPrivilegesResult(all_fine_grained_permissions);
}
@ -233,9 +242,9 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedRolePrivile
const auto &edge_type_permissions = role->GetFineGrainedAccessEdgeTypePermissions();
auto all_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "USER");
GetFineGrainedPermissionForPrivilegeForUserOrRole(label_permissions, "LABEL", "ROLE");
auto edge_type_fine_grained_permissions =
GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "USER");
GetFineGrainedPermissionForPrivilegeForUserOrRole(edge_type_permissions, "EDGE_TYPE", "ROLE");
all_fine_grained_permissions.insert(all_fine_grained_permissions.end(), edge_type_fine_grained_permissions.begin(),
edge_type_fine_grained_permissions.end());
@ -248,16 +257,15 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowFineGrainedRolePrivile
namespace memgraph::glue {
AuthQueryHandler::AuthQueryHandler(
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: auth_(auth) {}
AuthQueryHandler::AuthQueryHandler(memgraph::auth::SynchedAuth *auth) : auth_(auth) {}
bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional<std::string> &password) {
bool AuthQueryHandler::CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
try {
const auto [first_user, user_added] = std::invoke([&, this] {
auto locked_auth = auth_->Lock();
const auto first_user = !locked_auth->HasUsers();
const auto user_added = locked_auth->AddUser(username, password).has_value();
const auto user_added = locked_auth->AddUser(username, password, system_tx).has_value();
return std::make_pair(first_user, user_added);
});
@ -276,10 +284,11 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
}
}
#endif
);
,
system_tx);
#ifdef MG_ENTERPRISE
GrantDatabaseToUser(auth::kAllDatabases, username);
SetMainDatabase(dbms::kDefaultDB, username);
GrantDatabaseToUser(auth::kAllDatabases, username, system_tx);
SetMainDatabase(dbms::kDefaultDB, username, system_tx);
#endif
}
@ -289,18 +298,19 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
}
}
bool AuthQueryHandler::DropUser(const std::string &username) {
bool AuthQueryHandler::DropUser(const std::string &username, system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->RemoveUser(username);
return locked_auth->RemoveUser(username, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthQueryHandler::SetPassword(const std::string &username, const std::optional<std::string> &password) {
void AuthQueryHandler::SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
@ -308,39 +318,41 @@ void AuthQueryHandler::SetPassword(const std::string &username, const std::optio
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
}
locked_auth->UpdatePassword(*user, password);
locked_auth->SaveUser(*user);
locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::CreateRole(const std::string &rolename) {
bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
return locked_auth->AddRole(rolename).has_value();
return locked_auth->AddRole(rolename, system_tx).has_value();
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
#ifdef MG_ENTERPRISE
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db, const std::string &username) {
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->RevokeDatabaseFromUser(db, username);
return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db, const std::string &username) {
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->GrantDatabaseToUser(db, username);
return locked_auth->GrantDatabaseToUser(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
@ -360,27 +372,28 @@ std::vector<std::vector<memgraph::query::TypedValue>> AuthQueryHandler::GetDatab
}
}
bool AuthQueryHandler::SetMainDatabase(std::string_view db, const std::string &username) {
bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->SetMainDatabase(db, username);
return locked_auth->SetMainDatabase(db_name, username, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthQueryHandler::DeleteDatabase(std::string_view db) {
void AuthQueryHandler::DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) {
try {
auth_->Lock()->DeleteDatabase(std::string(db));
auth_->Lock()->DeleteDatabase(std::string(db_name), system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
#endif
bool AuthQueryHandler::DropRole(const std::string &rolename) {
bool AuthQueryHandler::DropRole(const std::string &rolename, system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto role = locked_auth->GetRole(rolename);
@ -389,7 +402,7 @@ bool AuthQueryHandler::DropRole(const std::string &rolename) {
return false;
};
return locked_auth->RemoveRole(rolename);
return locked_auth->RemoveRole(rolename, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
@ -461,7 +474,8 @@ std::vector<memgraph::query::TypedValue> AuthQueryHandler::GetUsernamesForRole(c
}
}
void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename) {
void AuthQueryHandler::SetRole(const std::string &username, const std::string &rolename,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
@ -477,13 +491,13 @@ void AuthQueryHandler::SetRole(const std::string &username, const std::string &r
current_role->rolename());
}
user->SetRole(*role);
locked_auth->SaveUser(*user);
locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthQueryHandler::ClearRole(const std::string &username) {
void AuthQueryHandler::ClearRole(const std::string &username, system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
@ -491,7 +505,7 @@ void AuthQueryHandler::ClearRole(const std::string &username) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
user->ClearRole();
locked_auth->SaveUser(*user);
locked_auth->SaveUser(*user, system_tx);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
@ -545,7 +559,8 @@ void AuthQueryHandler::GrantPrivilege(
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) {
,
system::Transaction *system_tx) {
EditPermissions(
user_or_role, privileges,
#ifdef MG_ENTERPRISE
@ -568,11 +583,13 @@ void AuthQueryHandler::GrantPrivilege(
}
}
#endif
);
,
system_tx);
} // namespace memgraph::glue
void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) {
EditPermissions(
user_or_role, privileges,
#ifdef MG_ENTERPRISE
@ -588,7 +605,8 @@ void AuthQueryHandler::DenyPrivilege(const std::string &user_or_role,
,
[](auto &fine_grained_permissions, const auto &privilege_collection) {}
#endif
);
,
system_tx);
}
void AuthQueryHandler::RevokePrivilege(
@ -600,7 +618,8 @@ void AuthQueryHandler::RevokePrivilege(
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) {
,
system::Transaction *system_tx) {
EditPermissions(
user_or_role, privileges,
#ifdef MG_ENTERPRISE
@ -622,7 +641,8 @@ void AuthQueryHandler::RevokePrivilege(
}
}
#endif
);
,
system_tx);
} // namespace memgraph::glue
template <class TEditPermissionsFun
@ -646,7 +666,8 @@ void AuthQueryHandler::EditPermissions(
,
const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun
#endif
) {
,
system::Transaction *system_tx) {
try {
std::vector<memgraph::auth::Permission> permissions;
permissions.reserve(privileges.size());
@ -675,7 +696,7 @@ void AuthQueryHandler::EditPermissions(
}
}
#endif
locked_auth->SaveUser(*user);
locked_auth->SaveUser(*user, system_tx);
} else {
for (const auto &permission : permissions) {
edit_permissions_fun(role->permissions(), permission);
@ -691,7 +712,7 @@ void AuthQueryHandler::EditPermissions(
}
}
#endif
locked_auth->SaveRole(*role);
locked_auth->SaveRole(*role, system_tx);
}
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());

View File

@ -23,32 +23,36 @@
namespace memgraph::glue {
class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
memgraph::auth::SynchedAuth *auth_;
public:
AuthQueryHandler(memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth);
explicit AuthQueryHandler(memgraph::auth::SynchedAuth *auth);
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override;
bool CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) override;
bool DropUser(const std::string &username) override;
bool DropUser(const std::string &username, system::Transaction *system_tx) override;
void SetPassword(const std::string &username, const std::optional<std::string> &password) override;
void SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) override;
#ifdef MG_ENTERPRISE
bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) override;
bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
bool GrantDatabaseToUser(const std::string &db, const std::string &username) override;
bool GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
bool SetMainDatabase(std::string_view db, const std::string &username) override;
bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override;
void DeleteDatabase(std::string_view db) override;
void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override;
#endif
bool CreateRole(const std::string &rolename) override;
bool CreateRole(const std::string &rolename, system::Transaction *system_tx) override;
bool DropRole(const std::string &rolename) override;
bool DropRole(const std::string &rolename, system::Transaction *system_tx) override;
std::vector<memgraph::query::TypedValue> GetUsernames() override;
@ -58,9 +62,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) override;
void SetRole(const std::string &username, const std::string &rolename) override;
void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) override;
void ClearRole(const std::string &username) override;
void ClearRole(const std::string &username, system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) override;
@ -74,10 +78,12 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) override;
,
system::Transaction *system_tx) override;
void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override;
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) override;
void RevokePrivilege(
const std::string &user_or_role, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges
@ -88,7 +94,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) override;
,
system::Transaction *system_tx) override;
private:
template <class TEditPermissionsFun
@ -112,6 +119,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
,
const TEditFineGrainedPermissionsFun &edit_fine_grained_permissions_fun
#endif
);
,
system::Transaction *system_tx);
};
} // namespace memgraph::glue

View File

@ -37,6 +37,7 @@ KVStore::KVStore(std::filesystem::path storage) : pimpl_(std::make_unique<impl>(
}
KVStore::~KVStore() {
if (pimpl_ == nullptr) return;
spdlog::debug("Destroying KVStore at {}", pimpl_->storage.string());
const auto sync = pimpl_->db->SyncWAL();
if (!sync.ok()) spdlog::error("KVStore sync failed!");

View File

@ -11,9 +11,12 @@
#include <cstdint>
#include "audit/log.hpp"
#include "auth/auth.hpp"
#include "communication/websocket/auth.hpp"
#include "communication/websocket/server.hpp"
#include "coordination/coordinator_handlers.hpp"
#include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/inmemory/replication_handlers.hpp"
#include "flags/all.hpp"
#include "glue/MonitoringServerT.hpp"
@ -24,14 +27,19 @@
#include "helpers.hpp"
#include "license/license_sender.hpp"
#include "memory/global_memory_control.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/procedure/callable_alias_mapper.hpp"
#include "query/procedure/module.hpp"
#include "query/procedure/py_module.hpp"
#include "replication_handler/replication_handler.hpp"
#include "replication_handler/system_replication.hpp"
#include "requests/requests.hpp"
#include "storage/v2/durability/durability.hpp"
#include "system/system.hpp"
#include "telemetry/telemetry.hpp"
#include "utils/signals.hpp"
#include "utils/sysinfo/memory.hpp"
@ -39,10 +47,6 @@
#include "utils/terminate_handler.hpp"
#include "version.hpp"
#include "dbms/dbms_handler.hpp"
#include "query/auth_query_handler.hpp"
#include "query/interpreter_context.hpp"
namespace {
constexpr const char *kMgUser = "MEMGRAPH_USER";
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
@ -356,44 +360,75 @@ int main(int argc, char **argv) {
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)};
auto auth_glue =
[](memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
std::unique_ptr<memgraph::query::AuthQueryHandler> &ah, std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth);
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
// Handle users passed via arguments
auto *maybe_username = std::getenv(kMgUser);
auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) {
ah->CreateUser(maybe_username, maybe_password);
} else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) {
ah->CreateUser(username, password);
}
}
};
auto auth_glue = [](memgraph::auth::SynchedAuth *auth, std::unique_ptr<memgraph::query::AuthQueryHandler> &ah,
std::unique_ptr<memgraph::query::AuthChecker> &ac) {
// Glue high level auth implementations to the query side
ah = std::make_unique<memgraph::glue::AuthQueryHandler>(auth);
ac = std::make_unique<memgraph::glue::AuthChecker>(auth);
// Handle users passed via arguments
auto *maybe_username = std::getenv(kMgUser);
auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) {
ah->CreateUser(maybe_username, maybe_password, nullptr);
} else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) {
ah->CreateUser(username, password, nullptr);
}
}
};
memgraph::auth::Auth::Config auth_config{FLAGS_auth_user_or_role_name_regex, FLAGS_auth_password_strength_regex,
FLAGS_auth_password_permit_null};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{
data_directory / "auth", auth_config};
memgraph::auth::SynchedAuth auth_{data_directory / "auth", auth_config};
std::unique_ptr<memgraph::query::AuthQueryHandler> auth_handler;
std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
auth_glue(&auth_, auth_handler, auth_checker);
memgraph::dbms::DbmsHandler dbms_handler(db_config
auto system = memgraph::system::System{db_config.durability.storage_directory, FLAGS_data_recovery_on_startup};
// singleton replication state
memgraph::replication::ReplicationState repl_state{ReplicationStateRootPath(db_config)};
// singleton coordinator state
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState coordinator_state;
#endif
memgraph::dbms::DbmsHandler dbms_handler(db_config, system, repl_state
#ifdef MG_ENTERPRISE
,
&auth_, FLAGS_data_recovery_on_startup
auth_, FLAGS_data_recovery_on_startup
#endif
);
// Note: Now that all system's subsystems are initialised (dbms & auth)
// We can now initialise the recovery of replication (which will include those subsystems)
// ReplicationHandler will handle the recovery
auto replication_handler = memgraph::replication::ReplicationHandler{repl_state, dbms_handler
#ifdef MG_ENTERPRISE
,
&system, auth_
#endif
};
#ifdef MG_ENTERPRISE
// MAIN or REPLICA instance
if (FLAGS_coordinator_server_port) {
memgraph::dbms::CoordinatorHandlers::Register(coordinator_state.GetCoordinatorServer(), replication_handler);
MG_ASSERT(coordinator_state.GetCoordinatorServer().Start(), "Failed to start coordinator server!");
}
#endif
auto db_acc = dbms_handler.Get();
memgraph::query::InterpreterContext interpreter_context_(
interp_config, &dbms_handler, &dbms_handler.ReplicationState(), auth_handler.get(), auth_checker.get());
memgraph::query::InterpreterContext interpreter_context_(interp_config, &dbms_handler, &repl_state, system,
#ifdef MG_ENTERPRISE
&coordinator_state,
#endif
auth_handler.get(), auth_checker.get(),
&replication_handler);
MG_ASSERT(db_acc, "Failed to access the main database");
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(),
@ -460,9 +495,9 @@ int main(int argc, char **argv) {
if (FLAGS_telemetry_enabled) {
telemetry.emplace(telemetry_server, data_directory / "telemetry", memgraph::glue::run_id_, machine_id,
service_name == "BoltS", FLAGS_data_directory, std::chrono::minutes(10));
telemetry->AddStorageCollector(dbms_handler, auth_);
telemetry->AddStorageCollector(dbms_handler, auth_, repl_state);
#ifdef MG_ENTERPRISE
telemetry->AddDatabaseCollector(dbms_handler);
telemetry->AddDatabaseCollector(dbms_handler, repl_state);
#else
telemetry->AddDatabaseCollector();
#endif

View File

@ -56,6 +56,7 @@ target_link_libraries(mg-query PUBLIC dl
mg-kvstore
mg-memory
mg::csv
mg::system
mg-flags
mg-dbms
mg-events)

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@
#include "query/frontend/ast/ast.hpp" // overkill
#include "query/typed_value.hpp"
#include "system/system.hpp"
namespace memgraph::query {
@ -33,23 +34,27 @@ class AuthQueryHandler {
/// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) = 0;
/// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0;
virtual bool DropUser(const std::string &username, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) = 0;
#ifdef MG_ENTERPRISE
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username) = 0;
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
/// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username) = 0;
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
/// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred.
@ -57,20 +62,20 @@ class AuthQueryHandler {
/// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(std::string_view db, const std::string &username) = 0;
virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0;
/// Delete database from all users
/// @throw QueryRuntimeException if an error ocurred.
virtual void DeleteDatabase(std::string_view db) = 0;
virtual void DeleteDatabase(std::string_view db, system::Transaction *system_tx) = 0;
#endif
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0;
virtual bool CreateRole(const std::string &rolename, system::Transaction *system_tx) = 0;
/// Return false if the role does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropRole(const std::string &rolename) = 0;
virtual bool DropRole(const std::string &rolename, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<memgraph::query::TypedValue> GetUsernames() = 0;
@ -85,10 +90,10 @@ class AuthQueryHandler {
virtual std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
virtual void SetRole(const std::string &username, const std::string &rolename, system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0;
virtual void ClearRole(const std::string &username, system::Transaction *system_tx) = 0;
virtual std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
@ -103,11 +108,13 @@ class AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
,
system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) = 0;
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
system::Transaction *system_tx) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(
@ -120,7 +127,8 @@ class AuthQueryHandler {
const std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>>
&edge_type_privileges
#endif
) = 0;
,
system::Transaction *system_tx) = 0;
};
} // namespace memgraph::query

View File

@ -34,7 +34,9 @@
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "csv/parsing.hpp"
#include "dbms/coordinator_handler.hpp"
#include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/global.hpp"
#include "dbms/inmemory/storage_helper.hpp"
#include "flags/replication.hpp"
@ -43,6 +45,7 @@
#include "license/license.hpp"
#include "memory/global_memory_control.hpp"
#include "memory/query_memory_control.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp"
#include "query/constants.hpp"
#include "query/context.hpp"
@ -58,12 +61,14 @@
#include "query/frontend/semantic/symbol_generator.hpp"
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
#include "query/interpreter_context.hpp"
#include "query/metadata.hpp"
#include "query/plan/hint_provider.hpp"
#include "query/plan/planner.hpp"
#include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp"
#include "query/procedure/module.hpp"
#include "query/replication_query_handler.hpp"
#include "query/stream.hpp"
#include "query/stream/common.hpp"
#include "query/stream/sources.hpp"
@ -71,6 +76,7 @@
#include "query/trigger.hpp"
#include "query/typed_value.hpp"
#include "replication/config.hpp"
#include "replication/state.hpp"
#include "spdlog/spdlog.h"
#include "storage/v2/disk/storage.hpp"
#include "storage/v2/edge.hpp"
@ -101,13 +107,6 @@
#include "utils/typeinfo.hpp"
#include "utils/variant_helpers.hpp"
#include "dbms/coordinator_handler.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "query/auth_query_handler.hpp"
#include "query/interpreter_context.hpp"
#include "replication/state.hpp"
#ifdef MG_ENTERPRISE
#include "coordination/constants.hpp"
#endif
@ -306,17 +305,18 @@ class ReplQueryHandler {
ReplicationQuery::ReplicaState state;
};
explicit ReplQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_{*dbms_handler} {}
explicit ReplQueryHandler(query::ReplicationQueryHandler &replication_query_handler)
: handler_{&replication_query_handler} {}
/// @throw QueryRuntimeException if an error ocurred.
void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) {
auto ValidatePort = [](std::optional<int64_t> port) -> void {
if (*port < 0 || *port > std::numeric_limits<uint16_t>::max()) {
if (!port || *port < 0 || *port > std::numeric_limits<uint16_t>::max()) {
throw QueryRuntimeException("Port number invalid!");
}
};
if (replication_role == ReplicationQuery::ReplicationRole::MAIN) {
if (!handler_.SetReplicationRoleMain()) {
if (!handler_->SetReplicationRoleMain()) {
throw QueryRuntimeException("Couldn't set replication role to main!");
}
} else {
@ -327,7 +327,7 @@ class ReplQueryHandler {
.port = static_cast<uint16_t>(*port),
};
if (!handler_.SetReplicationRoleReplica(config)) {
if (!handler_->SetReplicationRoleReplica(config)) {
throw QueryRuntimeException("Couldn't set role to replica!");
}
}
@ -335,7 +335,7 @@ class ReplQueryHandler {
/// @throw QueryRuntimeException if an error ocurred.
ReplicationQuery::ReplicationRole ShowReplicationRole() const {
switch (handler_.GetRole()) {
switch (handler_->GetRole()) {
case memgraph::replication_coordination_glue::ReplicationRole::MAIN:
return ReplicationQuery::ReplicationRole::MAIN;
case memgraph::replication_coordination_glue::ReplicationRole::REPLICA:
@ -349,7 +349,7 @@ class ReplQueryHandler {
const ReplicationQuery::SyncMode sync_mode, const std::chrono::seconds replica_check_frequency) {
// Coordinator is main by default so this check is OK although it should actually be nothing (neither main nor
// replica)
if (handler_.IsReplica()) {
if (handler_->IsReplica()) {
// replica can't register another replica
throw QueryRuntimeException("Replica can't register another replica!");
}
@ -368,7 +368,7 @@ class ReplQueryHandler {
.replica_check_frequency = replica_check_frequency,
.ssl = std::nullopt};
const auto error = handler_.RegisterReplica(replication_config).HasError();
const auto error = handler_->TryRegisterReplica(replication_config).HasError();
if (error) {
throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name));
@ -381,9 +381,9 @@ class ReplQueryHandler {
/// @throw QueryRuntimeException if an error occurred.
void DropReplica(std::string_view replica_name) {
auto const result = handler_.UnregisterReplica(replica_name);
auto const result = handler_->UnregisterReplica(replica_name);
switch (result) {
using enum memgraph::dbms::UnregisterReplicaResult;
using enum memgraph::query::UnregisterReplicaResult;
case NOT_MAIN:
throw QueryRuntimeException("Replica can't unregister a replica!");
case COULD_NOT_BE_PERSISTED:
@ -396,7 +396,7 @@ class ReplQueryHandler {
}
std::vector<ReplicaInfo> ShowReplicas(const dbms::Database &db) const {
if (handler_.IsReplica()) {
if (handler_->IsReplica()) {
// replica can't show registered replicas (it shouldn't have any)
throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!");
}
@ -447,19 +447,16 @@ class ReplQueryHandler {
}
private:
dbms::ReplicationHandler handler_;
query::ReplicationQueryHandler *handler_;
};
#ifdef MG_ENTERPRISE
class CoordQueryHandler final : public query::CoordinatorQueryHandler {
public:
explicit CoordQueryHandler(dbms::DbmsHandler *dbms_handler) : handler_ { *dbms_handler }
#ifdef MG_ENTERPRISE
, coordinator_handler_(*dbms_handler)
#endif
{
}
explicit CoordQueryHandler(coordination::CoordinatorState &coordinator_state)
: coordinator_handler_(coordinator_state) {}
#ifdef MG_ENTERPRISE
/// @throw QueryRuntimeException if an error ocurred.
void RegisterInstance(const std::string &coordinator_socket_address, const std::string &replication_socket_address,
const std::chrono::seconds instance_check_frequency, const std::string &instance_name,
@ -497,7 +494,7 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
using enum memgraph::coordination::RegisterInstanceCoordinatorStatus;
case NAME_EXISTS:
throw QueryRuntimeException("Couldn't register replica instance since instance with such name already exists!");
case END_POINT_EXISTS:
case ENDPOINT_EXISTS:
throw QueryRuntimeException(
"Couldn't register replica instance since instance with such endpoint already exists!");
case NOT_COORDINATOR:
@ -527,26 +524,20 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler {
}
}
#endif
#ifdef MG_ENTERPRISE
std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const override {
return coordinator_handler_.ShowInstances();
}
#endif
private:
dbms::ReplicationHandler handler_;
#ifdef MG_ENTERPRISE
dbms::CoordinatorHandler coordinator_handler_;
#endif
};
#endif
/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters) {
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters,
Interpreter &interpreter) {
AuthQueryHandler *auth = interpreter_context->auth;
#ifdef MG_ENTERPRISE
auto *db_handler = interpreter_context->dbms_handler;
@ -595,63 +586,111 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features"));
}
const auto forbid_on_replica = [has_license = license_check_result.HasError(),
is_replica = interpreter_context->repl_state->IsReplica()]() {
if (is_replica) {
#if MG_ENTERPRISE
if (has_license) {
throw QueryException(
"Query forbidden on the replica! Update on MAIN, as it is the only source of truth for authentication "
"data. MAIN will then replicate the update to connected REPLICAs");
}
throw QueryException(
"Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA.");
#else
throw QueryException(
"Query forbidden on the replica! Switch role to MAIN and update user data, then switch back to REPLICA.");
#endif
}
};
switch (auth_query->action_) {
case AuthQuery::Action::CREATE_USER:
callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError()] {
forbid_on_replica();
callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError(),
interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
MG_ASSERT(password.IsString() || password.IsNull());
if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString()))
: std::nullopt)) {
if (!auth->CreateUser(
username, password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt,
&*interpreter->system_transaction_)) {
throw UserAlreadyExistsException("User '{}' already exists.", username);
}
// If the license is not valid we create users with admin access
if (!valid_enterprise_license) {
spdlog::warn("Granting all the privileges to {}.", username);
auth->GrantPrivilege(username, kPrivilegesAll
auth->GrantPrivilege(
username, kPrivilegesAll
#ifdef MG_ENTERPRISE
,
{{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}},
{
{
{
AuthQuery::FineGrainedPrivilege::CREATE_DELETE, { query::kAsterisk }
}
}
}
,
{{{AuthQuery::FineGrainedPrivilege::CREATE_DELETE, {query::kAsterisk}}}},
{
{
{
AuthQuery::FineGrainedPrivilege::CREATE_DELETE, { query::kAsterisk }
}
}
}
#endif
);
,
&*interpreter->system_transaction_);
}
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::DROP_USER:
callback.fn = [auth, username] {
if (!auth->DropUser(username)) {
forbid_on_replica();
callback.fn = [auth, username, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->DropUser(username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("User '{}' doesn't exist.", username);
}
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::SET_PASSWORD:
callback.fn = [auth, username, password] {
forbid_on_replica();
callback.fn = [auth, username, password, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
MG_ASSERT(password.IsString() || password.IsNull());
auth->SetPassword(username,
password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt);
password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::CREATE_ROLE:
callback.fn = [auth, rolename] {
if (!auth->CreateRole(rolename)) {
forbid_on_replica();
callback.fn = [auth, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->CreateRole(rolename, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Role '{}' already exists.", rolename);
}
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::DROP_ROLE:
callback.fn = [auth, rolename] {
if (!auth->DropRole(rolename)) {
forbid_on_replica();
callback.fn = [auth, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
if (!auth->DropRole(rolename, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Role '{}' doesn't exist.", rolename);
}
return std::vector<std::vector<TypedValue>>();
@ -682,52 +721,79 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
};
return callback;
case AuthQuery::Action::SET_ROLE:
callback.fn = [auth, username, rolename] {
auth->SetRole(username, rolename);
forbid_on_replica();
callback.fn = [auth, username, rolename, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->SetRole(username, rolename, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::CLEAR_ROLE:
callback.fn = [auth, username] {
auth->ClearRole(username);
forbid_on_replica();
callback.fn = [auth, username, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->ClearRole(username, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::GRANT_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges
forbid_on_replica();
callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter
#ifdef MG_ENTERPRISE
,
label_privileges, edge_type_privileges
#endif
] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->GrantPrivilege(user_or_role, privileges
#ifdef MG_ENTERPRISE
,
label_privileges, edge_type_privileges
#endif
);
,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::DENY_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges] {
auth->DenyPrivilege(user_or_role, privileges);
forbid_on_replica();
callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->DenyPrivilege(user_or_role, privileges, &*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::REVOKE_PRIVILEGE: {
callback.fn = [auth, user_or_role, privileges
forbid_on_replica();
callback.fn = [auth, user_or_role, privileges, interpreter = &interpreter
#ifdef MG_ENTERPRISE
,
label_privileges, edge_type_privileges
#endif
] {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
auth->RevokePrivilege(user_or_role, privileges
#ifdef MG_ENTERPRISE
,
label_privileges, edge_type_privileges
#endif
);
,
&*interpreter->system_transaction_);
return std::vector<std::vector<TypedValue>>();
};
return callback;
@ -757,15 +823,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
};
return callback;
case AuthQuery::Action::GRANT_DATABASE_TO_USER:
forbid_on_replica();
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT
callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try {
std::optional<memgraph::dbms::DatabaseAccess> db =
std::nullopt; // Hold pointer to database to protect it until query is done
if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->GrantDatabaseToUser(database, username)) {
if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -778,15 +849,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
};
return callback;
case AuthQuery::Action::REVOKE_DATABASE_FROM_USER:
forbid_on_replica();
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT
callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try {
std::optional<memgraph::dbms::DatabaseAccess> db =
std::nullopt; // Hold pointer to database to protect it until query is done
if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->RevokeDatabaseFromUser(database, username)) {
if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -811,12 +887,17 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
#endif
return callback;
case AuthQuery::Action::SET_MAIN_DATABASE:
forbid_on_replica();
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler] { // NOLINT
callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try {
const auto db =
db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
if (!auth->SetMainDatabase(database, username)) {
if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username);
}
} catch (memgraph::dbms::UnknownDatabaseException &e) {
@ -834,7 +915,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
} // namespace
Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
dbms::DbmsHandler *dbms_handler, CurrentDB &current_db,
ReplicationQueryHandler &replication_query_handler, CurrentDB &current_db,
const query::InterpreterConfig &config, std::vector<Notification> *notifications) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
@ -864,7 +945,8 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING,
"Be careful the replication port must be different from the memgraph port!");
}
callback.fn = [handler = ReplQueryHandler{dbms_handler}, role = repl_query->role_, maybe_port]() mutable {
callback.fn = [handler = ReplQueryHandler{replication_query_handler}, role = repl_query->role_,
maybe_port]() mutable {
handler.SetReplicationRole(role, maybe_port);
return std::vector<std::vector<TypedValue>>();
};
@ -882,7 +964,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
#endif
callback.header = {"replication role"};
callback.fn = [handler = ReplQueryHandler{dbms_handler}] {
callback.fn = [handler = ReplQueryHandler{replication_query_handler}] {
auto mode = handler.ShowReplicationRole();
switch (mode) {
case ReplicationQuery::ReplicationRole::MAIN: {
@ -906,7 +988,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
auto socket_address = repl_query->socket_address_->Accept(evaluator);
const auto replica_check_frequency = config.replication_replica_check_frequency;
callback.fn = [handler = ReplQueryHandler{dbms_handler}, name, socket_address, sync_mode,
callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name, socket_address, sync_mode,
replica_check_frequency]() mutable {
handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency);
return std::vector<std::vector<TypedValue>>();
@ -923,7 +1005,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
}
#endif
const auto &name = repl_query->instance_name_;
callback.fn = [handler = ReplQueryHandler{dbms_handler}, name]() mutable {
callback.fn = [handler = ReplQueryHandler{replication_query_handler}, name]() mutable {
handler.DropReplica(name);
return std::vector<std::vector<TypedValue>>();
};
@ -941,7 +1023,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
callback.header = {
"name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master",
"state"};
callback.fn = [handler = ReplQueryHandler{dbms_handler}, replica_nfields = callback.header.size(),
callback.fn = [handler = ReplQueryHandler{replication_query_handler}, replica_nfields = callback.header.size(),
db_acc = current_db.db_acc_] {
const auto &replicas = handler.ShowReplicas(*db_acc->get());
auto typed_replicas = std::vector<std::vector<TypedValue>>{};
@ -989,16 +1071,17 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
}
}
#ifdef MG_ENTERPRISE
Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Parameters &parameters,
dbms::DbmsHandler *dbms_handler, const query::InterpreterConfig &config,
std::vector<Notification> *notifications) {
coordination::CoordinatorState *coordinator_state,
const query::InterpreterConfig &config, std::vector<Notification> *notifications) {
Callback callback;
switch (coordinator_query->action_) {
case CoordinatorQuery::Action::REGISTER_INSTANCE: {
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
}
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1014,7 +1097,7 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
auto coordinator_socket_address_tv = coordinator_query->coordinator_socket_address_->Accept(evaluator);
auto replication_socket_address_tv = coordinator_query->replication_socket_address_->Accept(evaluator);
callback.fn = [handler = CoordQueryHandler{dbms_handler}, coordinator_socket_address_tv,
callback.fn = [handler = CoordQueryHandler{*coordinator_state}, coordinator_socket_address_tv,
replication_socket_address_tv, main_check_frequency = config.replication_replica_check_frequency,
instance_name = coordinator_query->instance_name_,
sync_mode = coordinator_query->sync_mode_]() mutable {
@ -1029,13 +1112,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
fmt::format("Coordinator has registered coordinator server on {} for instance {}.",
coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_));
return callback;
#endif
}
case CoordinatorQuery::Action::SET_INSTANCE_TO_MAIN: {
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
}
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1049,20 +1130,18 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
EvaluationContext evaluation_context{.timestamp = QueryTimestamp(), .parameters = parameters};
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
callback.fn = [handler = CoordQueryHandler{dbms_handler},
callback.fn = [handler = CoordQueryHandler{*coordinator_state},
instance_name = coordinator_query->instance_name_]() mutable {
handler.SetInstanceToMain(instance_name);
return std::vector<std::vector<TypedValue>>();
};
return callback;
#endif
}
case CoordinatorQuery::Action::SHOW_REPLICATION_CLUSTER: {
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
}
#ifdef MG_ENTERPRISE
if constexpr (!coordination::allow_ha) {
throw QueryRuntimeException(
"High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to "
@ -1073,7 +1152,8 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
}
callback.header = {"name", "socket_address", "alive", "role"};
callback.fn = [handler = CoordQueryHandler{dbms_handler}, replica_nfields = callback.header.size()]() mutable {
callback.fn = [handler = CoordQueryHandler{*coordinator_state},
replica_nfields = callback.header.size()]() mutable {
auto const instances = handler.ShowInstances();
std::vector<std::vector<TypedValue>> result{};
result.reserve(result.size());
@ -1087,11 +1167,11 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param
return result;
};
return callback;
#endif
}
return callback;
}
}
#endif
stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator) {
return {
@ -2493,14 +2573,14 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
}
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context) {
InterpreterContext *interpreter_context, Interpreter &interpreter) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr),
@ -2525,15 +2605,16 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
}
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler,
CurrentDB &current_db, const InterpreterConfig &config) {
std::vector<Notification> *notifications,
ReplicationQueryHandler &replication_query_handler, CurrentDB &current_db,
const InterpreterConfig &config) {
if (in_explicit_transaction) {
throw ReplicationModificationInMulticommandTxException();
}
auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query);
auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, &dbms_handler, current_db, config,
notifications);
auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, replication_query_handler,
current_db, config, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2552,8 +2633,10 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
#ifdef MG_ENTERPRISE
PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler,
std::vector<Notification> *notifications,
coordination::CoordinatorState &coordinator_state,
const InterpreterConfig &config) {
if (in_explicit_transaction) {
throw CoordinatorModificationInMulticommandTxException();
@ -2561,7 +2644,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit
auto *coordinator_query = utils::Downcast<CoordinatorQuery>(parsed_query.query);
auto callback =
HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &dbms_handler, config, notifications);
HandleCoordinatorQuery(coordinator_query, parsed_query.parameters, &coordinator_state, config, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2579,6 +2662,7 @@ PreparedQuery PrepareCoordinatorQuery(ParsedQuery parsed_query, bool in_explicit
// False positive report for the std::make_shared above
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
#endif
PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, CurrentDB &current_db) {
if (in_explicit_transaction) {
@ -3681,7 +3765,8 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &current_db,
InterpreterContext *interpreter_context,
std::optional<std::function<void(std::string_view)>> on_change_cb) {
std::optional<std::function<void(std::string_view)>> on_change_cb,
Interpreter &interpreter) {
#ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
@ -3700,12 +3785,16 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
return PreparedQuery{
{"STATUS"},
std::move(parsed_query.required_privileges),
[db_name = query->db_name_, db_handler](AnyStream *stream,
std::optional<int> n) -> std::optional<QueryHandlerResult> {
[db_name = query->db_name_, db_handler, interpreter = &interpreter](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
std::vector<std::vector<TypedValue>> status;
std::string res;
const auto success = db_handler->New(db_name);
const auto success = db_handler->New(db_name, &*interpreter->system_transaction_);
if (success.HasError()) {
switch (success.GetError()) {
case dbms::NewError::EXISTS:
@ -3780,16 +3869,20 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
return PreparedQuery{
{"STATUS"},
std::move(parsed_query.required_privileges),
[db_name = query->db_name_, db_handler, auth = interpreter_context->auth](
[db_name = query->db_name_, db_handler, auth = interpreter_context->auth, interpreter = &interpreter](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
std::vector<std::vector<TypedValue>> status;
try {
// Remove database
auto success = db_handler->TryDelete(db_name);
auto success = db_handler->TryDelete(db_name, &*interpreter->system_transaction_);
if (!success.HasError()) {
// Remove from auth
if (auth) auth->DeleteDatabase(db_name);
if (auth) auth->DeleteDatabase(db_name, &*interpreter->system_transaction_);
} else {
switch (success.GetError()) {
case dbms::DeleteError::DEFAULT_DB:
@ -4040,18 +4133,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
utils::Downcast<ReplicationQuery>(parsed_query.query);
// TODO Split SHOW REPLICAS (which needs the db) and other replication queries
auto system_transaction_guard = std::invoke([&]() -> std::optional<SystemTransactionGuard> {
if (system_queries) {
// TODO: Ordering between system and data queries
// Start a system transaction
auto system_unique = std::unique_lock{interpreter_context_->dbms_handler->system_lock_, std::defer_lock};
if (!system_unique.try_lock_for(std::chrono::milliseconds(kSystemTxTryMS))) {
throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported.");
}
return std::optional<SystemTransactionGuard>{std::in_place, std::move(system_unique),
*interpreter_context_->dbms_handler};
auto system_transaction = std::invoke([&]() -> std::optional<memgraph::system::Transaction> {
if (!system_queries) return std::nullopt;
// TODO: Ordering between system and data queries
auto system_txn = interpreter_context_->system_->TryCreateTransaction(std::chrono::milliseconds(kSystemTxTryMS));
if (!system_txn) {
throw ConcurrentSystemQueriesException("Multiple concurrent system queries are not supported.");
}
return std::nullopt;
return system_txn;
});
// Some queries do not require a database to be executed (current_db_ won't be passed on to the Prepare*; special
@ -4117,7 +4207,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
/// SYSTEM (Replication) PURE
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_);
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, *this);
} else if (utils::Downcast<DatabaseInfoQuery>(parsed_query.query)) {
prepared_query = PrepareDatabaseInfoQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<SystemInfoQuery>(parsed_query.query)) {
@ -4128,13 +4218,18 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
&query_execution->notifications, current_db_);
} else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) {
/// TODO: make replication DB agnostic
prepared_query =
PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
*interpreter_context_->dbms_handler, current_db_, interpreter_context_->config);
prepared_query = PrepareReplicationQuery(
std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
*interpreter_context_->replication_handler_, current_db_, interpreter_context_->config);
} else if (utils::Downcast<CoordinatorQuery>(parsed_query.query)) {
#ifdef MG_ENTERPRISE
prepared_query =
PrepareCoordinatorQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
*interpreter_context_->dbms_handler, interpreter_context_->config);
*interpreter_context_->coordinator_state_, interpreter_context_->config);
#else
throw QueryRuntimeException("Coordinator queries are not part of community edition");
#endif
} else if (utils::Downcast<LockPathQuery>(parsed_query.query)) {
prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) {
@ -4177,8 +4272,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
}
/// SYSTEM (Replication) + INTERPRETER
// DMG_ASSERT(system_guard);
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_
/*, *system_guard*/);
prepared_query =
PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_);
} else if (utils::Downcast<EdgeImportModeQuery>(parsed_query.query)) {
@ -4210,7 +4305,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
query_execution->summary["db"] = *query_execution->prepared_query->db;
// prepare is done, move system txn guard to be owned by interpreter
system_transaction_guard_ = std::move(system_transaction_guard);
system_transaction_ = std::move(system_transaction);
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid,
query_execution->prepared_query->db};
} catch (const utils::BasicException &) {
@ -4360,13 +4455,13 @@ void Interpreter::Commit() {
current_transaction_.reset();
if (!current_db_.db_transactional_accessor_ || !current_db_.db_acc_) {
// No database nor db transaction; check for system transaction
if (!system_transaction_guard_) return;
if (!system_transaction_) return;
// TODO Distinguish between data and system transaction state
// Think about updating the status to a struct with bitfield
// Clean transaction status on exit
utils::OnScopeExit clean_status([this]() {
system_transaction_guard_.reset();
system_transaction_.reset();
// System transactions are not terminable
// Durability has happened at time of PULL
// Commit is doing replication and timestamp update
@ -4384,7 +4479,23 @@ void Interpreter::Commit() {
}
});
system_transaction_guard_->Commit();
auto const main_commit = [&](replication::RoleMainData &mainData) {
// Only enterprise can do system replication
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast()) {
return system_transaction_->Commit(memgraph::system::DoReplication{mainData});
}
#endif
return system_transaction_->Commit(memgraph::system::DoNothing{});
};
auto const replica_commit = [&](replication::RoleReplicaData &) {
return system_transaction_->Commit(memgraph::system::DoNothing{});
};
auto const commit_method = utils::Overloaded{main_commit, replica_commit};
[[maybe_unused]] auto sync_result = std::visit(commit_method, interpreter_context_->repl_state->ReplicationData());
// TODO: something with sync_result
return;
}
auto *db = current_db_.db_acc_->get();

View File

@ -72,6 +72,7 @@ inline constexpr size_t kExecutionPoolMaxBlockSize = 1024UL; // 2 ^ 10
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
#ifdef MG_ENTERPRISE
class CoordinatorQueryHandler {
public:
CoordinatorQueryHandler() = default;
@ -93,7 +94,6 @@ class CoordinatorQueryHandler {
ReplicationQuery::ReplicaState state;
};
#ifdef MG_ENTERPRISE
struct MainReplicaStatus {
std::string_view name;
std::string_view socket_address;
@ -103,9 +103,7 @@ class CoordinatorQueryHandler {
MainReplicaStatus(std::string_view name, std::string_view socket_address, bool alive, bool is_main)
: name{name}, socket_address{socket_address}, alive{alive}, is_main{is_main} {}
};
#endif
#ifdef MG_ENTERPRISE
/// @throw QueryRuntimeException if an error ocurred.
virtual void RegisterInstance(const std::string &coordinator_socket_address,
const std::string &replication_socket_address,
@ -117,9 +115,8 @@ class CoordinatorQueryHandler {
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<coordination::CoordinatorInstanceStatus> ShowInstances() const = 0;
#endif
};
#endif
class AnalyzeGraphQueryHandler {
public:
@ -296,32 +293,12 @@ class Interpreter final {
void SetUser(std::string_view username);
struct SystemTransactionGuard {
explicit SystemTransactionGuard(std::unique_lock<utils::ResourceLock> guard, dbms::DbmsHandler &dbms_handler)
: system_guard_(std::move(guard)), dbms_handler_{&dbms_handler} {
dbms_handler_->NewSystemTransaction();
}
SystemTransactionGuard &operator=(SystemTransactionGuard &&) = default;
SystemTransactionGuard(SystemTransactionGuard &&) = default;
~SystemTransactionGuard() {
if (system_guard_.owns_lock()) dbms_handler_->ResetSystemTransaction();
}
dbms::AllSyncReplicaStatus Commit() { return dbms_handler_->Commit(); }
private:
std::unique_lock<utils::ResourceLock> system_guard_;
dbms::DbmsHandler *dbms_handler_;
};
std::optional<SystemTransactionGuard> system_transaction_guard_{};
std::optional<memgraph::system::Transaction> system_transaction_{};
private:
void ResetInterpreter() {
query_executions_.clear();
system_guard.reset();
system_transaction_guard_.reset();
system_transaction_.reset();
transaction_queries_->clear();
if (current_db_.db_acc_ && current_db_.db_acc_->is_deleting()) {
current_db_.db_acc_.reset();
@ -386,8 +363,6 @@ class Interpreter final {
// TODO Figure out how this would work for multi-database
// Exists only during a single transaction (for now should be okay as is)
std::vector<std::unique_ptr<QueryExecution>> query_executions_;
// TODO: our upgradable lock guard for system
std::optional<utils::ResourceLockGuard> system_guard;
// all queries that are run as part of the current transaction
utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_;

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,12 +12,27 @@
#include "query/interpreter_context.hpp"
#include "query/interpreter.hpp"
#include "system/include/system/system.hpp"
namespace memgraph::query {
InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler,
replication::ReplicationState *rs, query::AuthQueryHandler *ah,
query::AuthChecker *ac)
: dbms_handler(dbms_handler), config(interpreter_config), repl_state(rs), auth(ah), auth_checker(ac) {}
replication::ReplicationState *rs, memgraph::system::System &system,
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState *coordinator_state,
#endif
AuthQueryHandler *ah, AuthChecker *ac,
ReplicationQueryHandler *replication_handler)
: dbms_handler(dbms_handler),
config(interpreter_config),
repl_state(rs),
#ifdef MG_ENTERPRISE
coordinator_state_{coordinator_state},
#endif
auth(ah),
auth_checker(ac),
replication_handler_{replication_handler},
system_{&system} {
}
std::vector<std::vector<TypedValue>> InterpreterContext::TerminateTransactions(
std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username,

View File

@ -20,14 +20,20 @@
#include "query/config.hpp"
#include "query/cypher_query_interpreter.hpp"
#include "query/replication_query_handler.hpp"
#include "query/typed_value.hpp"
#include "replication/state.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/transaction.hpp"
#include "system/state.hpp"
#include "system/system.hpp"
#include "utils/gatekeeper.hpp"
#include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
#ifdef MG_ENTERPRISE
#include "coordination/coordinator_state.hpp"
#endif
namespace memgraph::dbms {
class DbmsHandler;
@ -48,7 +54,12 @@ class Interpreter;
*/
struct InterpreterContext {
InterpreterContext(InterpreterConfig interpreter_config, dbms::DbmsHandler *dbms_handler,
replication::ReplicationState *rs, AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr);
replication::ReplicationState *rs, memgraph::system::System &system,
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState *coordinator_state,
#endif
AuthQueryHandler *ah = nullptr, AuthChecker *ac = nullptr,
ReplicationQueryHandler *replication_handler = nullptr);
memgraph::dbms::DbmsHandler *dbms_handler;
@ -59,9 +70,14 @@ struct InterpreterContext {
// GLOBAL
memgraph::replication::ReplicationState *repl_state;
#ifdef MG_ENTERPRISE
memgraph::coordination::CoordinatorState *coordinator_state_;
#endif
AuthQueryHandler *auth;
AuthChecker *auth_checker;
ReplicationQueryHandler *replication_handler_;
system::System *system_;
// Used to check active transactions
// TODO: Have a way to read the current database

View File

@ -0,0 +1,60 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication_coordination_glue/role.hpp"
#include "utils/result.hpp"
// BEGIN fwd declares
namespace memgraph::replication {
struct ReplicationState;
struct ReplicationServerConfig;
struct ReplicationClientConfig;
} // namespace memgraph::replication
namespace memgraph::query {
enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, CONNECTION_FAILED, COULD_NOT_BE_PERSISTED };
enum class UnregisterReplicaResult : uint8_t {
NOT_MAIN,
COULD_NOT_BE_PERSISTED,
CAN_NOT_UNREGISTER,
SUCCESS,
};
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
struct ReplicationQueryHandler {
virtual ~ReplicationQueryHandler() = default;
// as REPLICA, become MAIN
virtual bool SetReplicationRoleMain() = 0;
// as MAIN, become REPLICA
virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) = 0;
// as MAIN, define and connect to REPLICAs
virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError> = 0;
virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> utils::BasicResult<RegisterReplicaError> = 0;
// as MAIN, remove a REPLICA connection
virtual auto UnregisterReplica(std::string_view name) -> UnregisterReplicaResult = 0;
// Helper pass-through (TODO: remove)
virtual auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole = 0;
virtual bool IsMain() const = 0;
virtual bool IsReplica() const = 0;
};
} // namespace memgraph::query

View File

@ -6,7 +6,6 @@ target_sources(mg-replication
include/replication/epoch.hpp
include/replication/config.hpp
include/replication/status.hpp
include/replication/messages.hpp
include/replication/replication_client.hpp
include/replication/replication_server.hpp
@ -15,7 +14,6 @@ target_sources(mg-replication
epoch.cpp
config.cpp
status.cpp
messages.cpp
replication_client.cpp
replication_server.cpp
)

View File

@ -1,47 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "rpc/messages.hpp"
#include "slk/serialization.hpp"
namespace memgraph::replication {
struct SystemHeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder);
SystemHeartbeatReq() = default;
};
struct SystemHeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
SystemHeartbeatRes() = default;
explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {}
uint64_t system_timestamp;
};
using SystemHeartbeatRpc = rpc::RequestResponse<SystemHeartbeatReq, SystemHeartbeatRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/);
} // namespace memgraph::slk

View File

@ -41,26 +41,67 @@ struct ReplicationClient {
void StartFrequentCheck(F &&callback) {
// Help the user to get the most accurate replica state possible.
if (replica_check_frequency_ > std::chrono::seconds(0)) {
replica_checker_.Run("Replica Checker", replica_check_frequency_,
[this, cb = std::forward<F>(callback), reconnect = false]() mutable {
try {
{
auto stream{rpc_client_.Stream<memgraph::replication_coordination_glue::FrequentHeartbeatRpc>()};
stream.AwaitResponse();
}
cb(reconnect, *this);
reconnect = false;
} catch (const rpc::RpcFailedException &) {
// Nothing to do...wait for a reconnect
// NOTE: Here we are communicating with the instance connection.
// We don't have access to the undelying client; so the only thing we can do it
// tell the callback that this is a reconnection and to check the state
reconnect = true;
}
});
replica_checker_.Run(
"Replica Checker", replica_check_frequency_,
[this, cb = std::forward<F>(callback), reconnect = false]() mutable {
try {
{
auto stream{rpc_client_.Stream<memgraph::replication_coordination_glue::FrequentHeartbeatRpc>()};
stream.AwaitResponse();
}
cb(reconnect, *this);
reconnect = false;
} catch (const rpc::RpcFailedException &) {
// Nothing to do...wait for a reconnect
// NOTE: Here we are communicating with the instance connection.
// We don't have access to the undelying client; so the only thing we can do it
// tell the callback that this is a reconnection and to check the state
reconnect = true;
}
});
}
}
//! \tparam RPC An rpc::RequestResponse
//! \tparam Args the args type
//! \param client the client to use for rpc communication
//! \param check predicate to check response is ok
//! \param args arguments to forward to the rpc request
//! \return If replica stream is completed or enqueued
template <typename RPC, typename... Args>
bool SteamAndFinalizeDelta(auto &&check, Args &&...args) {
try {
auto stream = rpc_client_.template Stream<RPC>(std::forward<Args>(args)...);
auto task = [this, check = std::forward<decltype(check)>(check), stream = std::move(stream)]() mutable {
if (stream.IsDefunct()) {
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
try {
if (check(stream.AwaitResponse())) {
return true;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// swallow error, fallthrough to error handling
}
// This replica needs SYSTEM recovery
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
};
if (mode_ == memgraph::replication_coordination_glue::ReplicationMode::ASYNC) {
thread_pool_.AddTask([task = utils::CopyMovableFunctionWrapper{std::move(task)}]() mutable { task(); });
return true;
}
return task();
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
// This replica needs SYSTEM recovery
state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return false;
}
};
std::string name_;
communication::ClientContext rpc_context_;
rpc::Client rpc_client_;

View File

@ -1,52 +0,0 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication/messages.hpp"
namespace memgraph::replication {
constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq",
nullptr};
constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes",
nullptr};
void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace memgraph::replication
namespace memgraph::slk {
// Serialize code for SystemHeartbeatRes
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.system_timestamp, builder);
}
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->system_timestamp, reader);
}
// Serialize code for SystemHeartbeatReq
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) {
/* Nothing to serialize */
}
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) {
/* Nothing to serialize */
}
} // namespace memgraph::slk

View File

@ -0,0 +1,17 @@
add_library(mg-replication_handler STATIC)
add_library(mg::replication_handler ALIAS mg-replication_handler)
target_sources(mg-replication_handler
PUBLIC
include/replication_handler/replication_handler.hpp
include/replication_handler/system_replication.hpp
include/replication_handler/system_rpc.hpp
PRIVATE
replication_handler.cpp
system_replication.cpp
system_rpc.cpp
)
target_include_directories(mg-replication_handler PUBLIC include)
target_link_libraries(mg-replication_handler
PUBLIC mg-auth mg-dbms mg-replication)

View File

@ -0,0 +1,220 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "dbms/dbms_handler.hpp"
#include "replication/include/replication/state.hpp"
#include "replication_handler/system_rpc.hpp"
#include "utils/result.hpp"
namespace memgraph::replication {
inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client);
#ifdef MG_ENTERPRISE
void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth);
#else
void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler);
#endif
#ifdef MG_ENTERPRISE
// TODO: Split into 2 functions: dbms and auth
// When being called by interpreter no need to gain lock, it should already be under a system transaction
// But concurrently the FrequentCheck is running and will need to lock before reading last_committed_system_timestamp_
template <bool REQUIRE_LOCK = false>
void SystemRestore(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth) {
// Check if system is up to date
if (client.state_.WithLock(
[](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; }))
return;
// Try to recover...
{
struct DbInfo {
std::vector<storage::SalientConfig> configs;
uint64_t last_committed_timestamp;
};
DbInfo db_info = std::invoke([&] {
auto guard = std::invoke([&]() -> std::optional<memgraph::system::TransactionGuard> {
if constexpr (REQUIRE_LOCK) {
return system->GenTransactionGuard();
}
return std::nullopt;
});
if (license::global_license_checker.IsEnterpriseValidFast()) {
auto configs = std::vector<storage::SalientConfig>{};
dbms_handler.ForEach([&configs](dbms::DatabaseAccess acc) { configs.emplace_back(acc->config().salient); });
// TODO: This is `SystemRestore` maybe DbInfo is incorrect as it will need Auth also
return DbInfo{configs, system->LastCommittedSystemTimestamp()};
}
// No license -> send only default config
return DbInfo{{dbms_handler.Get()->config().salient}, system->LastCommittedSystemTimestamp()};
});
try {
auto stream = std::invoke([&]() {
// Handle only default database is no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
return client.rpc_client_.Stream<replication::SystemRecoveryRpc>(
db_info.last_committed_timestamp, std::move(db_info.configs), auth::Auth::Config{},
std::vector<auth::User>{}, std::vector<auth::Role>{});
}
return auth.WithLock([&](auto &locked_auth) {
return client.rpc_client_.Stream<replication::SystemRecoveryRpc>(
db_info.last_committed_timestamp, std::move(db_info.configs), locked_auth.GetConfig(),
locked_auth.AllUsers(), locked_auth.AllRoles());
});
});
const auto response = stream.AwaitResponse();
if (response.result == replication::SystemRecoveryRes::Result::FAILURE) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
} catch (memgraph::rpc::GenericRpcFailedException const &e) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
return;
}
}
// Successfully recovered
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::READY; });
}
#endif
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler {
#ifdef MG_ENTERPRISE
explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system,
memgraph::auth::SynchedAuth &auth);
#else
explicit ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler);
#endif
// as REPLICA, become MAIN
bool SetReplicationRoleMain() override;
// as MAIN, become REPLICA
bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) override;
// as MAIN, define and connect to REPLICAs
auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override;
auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> override;
// as MAIN, remove a REPLICA connection
auto UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult override;
bool DoReplicaToMainPromotion();
// Helper pass-through (TODO: remove)
auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole override;
bool IsMain() const override;
bool IsReplica() const override;
private:
template <bool HandleFailure>
auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!");
auto maybe_client = repl_state_.RegisterReplica(config);
if (maybe_client.HasError()) {
switch (maybe_client.GetError()) {
case memgraph::replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case memgraph::replication::RegisterReplicaError::NAME_EXISTS:
return memgraph::query::RegisterReplicaError::NAME_EXISTS;
case memgraph::replication::RegisterReplicaError::ENDPOINT_EXISTS:
return memgraph::query::RegisterReplicaError::ENDPOINT_EXISTS;
case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return memgraph::query::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case memgraph::replication::RegisterReplicaError::SUCCESS:
break;
}
}
if (!memgraph::dbms::allow_mt_repl && dbms_handler_.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!");
}
#ifdef MG_ENTERPRISE
// Update system before enabling individual storage <-> replica clients
SystemRestore(*maybe_client.GetValue(), system_, dbms_handler_, auth_);
#endif
const auto dbms_error = HandleRegisterReplicaStatus(maybe_client);
if (dbms_error.has_value()) {
return *dbms_error;
}
auto &instance_client_ptr = maybe_client.GetValue();
bool all_clients_good = true;
// Add database specific clients (NOTE Currently all databases are connected to each replica)
dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
if (!dbms::allow_mt_repl && storage->name() != dbms::kDefaultDB) {
return;
}
// TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes
if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return;
all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock(
[storage, &instance_client_ptr, db_acc = std::move(db_acc)](auto &storage_clients) mutable { // NOLINT
auto client = std::make_unique<storage::ReplicationStorageClient>(*instance_client_ptr);
// All good, start replica client
client->Start(storage, std::move(db_acc));
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
// MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due an error like branching of MAIN and REPLICA
const bool success = client->State() != storage::replication::ReplicaState::MAYBE_BEHIND;
if (HandleFailure || success) {
storage_clients.push_back(std::move(client));
}
return success;
});
});
// NOTE Currently if any databases fails, we revert back
if (!HandleFailure && !all_clients_good) {
spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name);
UnregisterReplica(config.name);
return memgraph::query::RegisterReplicaError::CONNECTION_FAILED;
}
// No client error, start instance level client
#ifdef MG_ENTERPRISE
StartReplicaClient(*instance_client_ptr, system_, dbms_handler_, auth_);
#else
StartReplicaClient(*instance_client_ptr, dbms_handler_);
#endif
return {};
}
memgraph::replication::ReplicationState &repl_state_;
memgraph::dbms::DbmsHandler &dbms_handler_;
#ifdef MG_ENTERPRISE
memgraph::system::System *system_;
memgraph::auth::SynchedAuth &auth_;
#endif
};
} // namespace memgraph::replication

View File

@ -0,0 +1,31 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/auth.hpp"
#include "dbms/dbms_handler.hpp"
#include "slk/streams.hpp"
#include "system/state.hpp"
namespace memgraph::replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder);
void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader,
slk::Builder *res_builder);
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth);
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data, auth::SynchedAuth &auth);
#else
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data);
#endif
} // namespace memgraph::replication

View File

@ -0,0 +1,95 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <vector>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "rpc/messages.hpp"
#include "storage/v2/config.hpp"
namespace memgraph::replication {
struct SystemHeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder);
SystemHeartbeatReq() = default;
};
struct SystemHeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
SystemHeartbeatRes() = default;
explicit SystemHeartbeatRes(uint64_t system_timestamp) : system_timestamp(system_timestamp) {}
uint64_t system_timestamp;
};
using SystemHeartbeatRpc = rpc::RequestResponse<SystemHeartbeatReq, SystemHeartbeatRes>;
struct SystemRecoveryReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder);
SystemRecoveryReq() = default;
SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector<storage::SalientConfig> database_configs,
auth::Auth::Config auth_config, std::vector<auth::User> users, std::vector<auth::Role> roles)
: forced_group_timestamp{forced_group_timestamp},
database_configs(std::move(database_configs)),
auth_config(std::move(auth_config)),
users{std::move(users)},
roles{std::move(roles)} {}
uint64_t forced_group_timestamp;
std::vector<storage::SalientConfig> database_configs;
auth::Auth::Config auth_config;
std::vector<auth::User> users;
std::vector<auth::Role> roles;
};
struct SystemRecoveryRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder);
SystemRecoveryRes() = default;
explicit SystemRecoveryRes(Result res) : result(res) {}
Result result;
};
using SystemRecoveryRpc = rpc::RequestResponse<SystemRecoveryReq, SystemRecoveryRes>;
} // namespace memgraph::replication
namespace memgraph::slk {
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/);
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/);
void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

View File

@ -0,0 +1,291 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/replication_handler.hpp"
#include "dbms/dbms_handler.hpp"
#include "replication_handler/system_replication.hpp"
namespace memgraph::replication {
namespace {
#ifdef MG_ENTERPRISE
void RecoverReplication(memgraph::replication::ReplicationState &repl_state, memgraph::system::System *system,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth) {
/*
* REPLICATION RECOVERY AND STARTUP
*/
// Startup replication state (if recovered at startup)
auto replica = [&dbms_handler, &auth](memgraph::replication::RoleReplicaData const &data) {
return memgraph::replication::StartRpcServer(dbms_handler, data, auth);
};
// Replication recovery and frequent check start
auto main = [system, &dbms_handler, &auth](memgraph::replication::RoleMainData &mainData) {
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::SystemRestore(client, system, dbms_handler, auth);
}
// DBMS here
dbms_handler.ForEach([&mainData](memgraph::dbms::DatabaseAccess db_acc) {
dbms::DbmsHandler::RecoverStorageReplication(std::move(db_acc), mainData);
});
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::StartReplicaClient(client, system, dbms_handler, auth);
}
// Warning
if (dbms_handler.default_config().durability.snapshot_wal_mode ==
memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
}
return true;
};
auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData());
MG_ASSERT(result, "Replica recovery failure!");
}
#else
void RecoverReplication(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler) {
// Startup replication state (if recovered at startup)
auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData const &data) {
return memgraph::replication::StartRpcServer(dbms_handler, data);
};
// Replication recovery and frequent check start
auto main = [&dbms_handler](memgraph::replication::RoleMainData &mainData) {
dbms::DbmsHandler::RecoverStorageReplication(dbms_handler.Get(), mainData);
for (auto &client : mainData.registered_replicas_) {
memgraph::replication::StartReplicaClient(client, dbms_handler);
}
// Warning
if (dbms_handler.default_config().durability.snapshot_wal_mode ==
memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) {
spdlog::warn(
"The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please "
"consider "
"enabling durability by using --storage-snapshot-interval-sec and --storage-wal-enabled flags because "
"without write-ahead logs this instance is not replicating any data.");
}
return true;
};
auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData());
MG_ASSERT(result, "Replica recovery failure!");
}
#endif
} // namespace
inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client) {
if (instance_client.HasError()) switch (instance_client.GetError()) {
case replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
case replication::RegisterReplicaError::NAME_EXISTS:
return query::RegisterReplicaError::NAME_EXISTS;
case replication::RegisterReplicaError::ENDPOINT_EXISTS:
return query::RegisterReplicaError::ENDPOINT_EXISTS;
case replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return query::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case replication::RegisterReplicaError::SUCCESS:
break;
}
return {};
}
#ifdef MG_ENTERPRISE
void StartReplicaClient(replication::ReplicationClient &client, system::System *system, dbms::DbmsHandler &dbms_handler,
auth::SynchedAuth &auth) {
#else
void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandler &dbms_handler) {
#endif
// No client error, start instance level client
auto const &endpoint = client.rpc_client_.Endpoint();
spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port);
client.StartFrequentCheck([&,
#ifdef MG_ENTERPRISE
system = system,
#endif
license = license::global_license_checker.IsEnterpriseValidFast()](
bool reconnect, replication::ReplicationClient &client) mutable {
// Working connection
// Check if system needs restoration
if (reconnect) {
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
// Check if license has changed
const auto new_license = license::global_license_checker.IsEnterpriseValidFast();
if (new_license != license) {
license = new_license;
client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; });
}
#ifdef MG_ENTERPRISE
SystemRestore<true>(client, system, dbms_handler, auth);
#endif
// Check if any database has been left behind
dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) {
// Specific database <-> replica client
db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) {
if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
// Database <-> replica might be behind, check and recover
client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc);
}
});
});
});
}
#ifdef MG_ENTERPRISE
ReplicationHandler::ReplicationHandler(memgraph::replication::ReplicationState &repl_state,
memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System *system,
memgraph::auth::SynchedAuth &auth)
: repl_state_{repl_state}, dbms_handler_{dbms_handler}, system_{system}, auth_{auth} {
RecoverReplication(repl_state_, system_, dbms_handler_, auth_);
}
#else
ReplicationHandler::ReplicationHandler(replication::ReplicationState &repl_state, dbms::DbmsHandler &dbms_handler)
: repl_state_{repl_state}, dbms_handler_{dbms_handler} {
RecoverReplication(repl_state_, dbms_handler_);
}
#endif
bool ReplicationHandler::SetReplicationRoleMain() {
auto const main_handler = [](memgraph::replication::RoleMainData &) {
// If we are already MAIN, we don't want to change anything
return false;
};
auto const replica_handler = [this](memgraph::replication::RoleReplicaData const &) {
return DoReplicaToMainPromotion();
};
// TODO: under lock
return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
}
bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) {
// We don't want to restart the server if we're already a REPLICA
if (repl_state_.IsReplica()) {
return false;
}
// TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler_.ForEach([&](memgraph::dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
});
// Remove instance level clients
std::get<memgraph::replication::RoleMainData>(repl_state_.ReplicationData()).registered_replicas_.clear();
// Creates the server
repl_state_.SetReplicationRoleReplica(config);
// Start
const auto success =
std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData const &) {
// ASSERT
return false;
},
[this](memgraph::replication::RoleReplicaData const &data) {
#ifdef MG_ENTERPRISE
return StartRpcServer(dbms_handler_, data, auth_);
#else
return StartRpcServer(dbms_handler_, data);
#endif
}},
repl_state_.ReplicationData());
// TODO Handle error (restore to main?)
return success;
}
bool ReplicationHandler::DoReplicaToMainPromotion() {
// STEP 1) bring down all REPLICA servers
dbms_handler_.ForEach([](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
// Remember old epoch + storage timestamp association
storage->PrepareForNewEpoch();
});
// STEP 2) Change to MAIN
// TODO: restore replication servers if false?
if (!repl_state_.SetReplicationRoleMain()) {
// TODO: Handle recovery on failure???
return false;
}
// STEP 3) We are now MAIN, update storage local epoch
const auto &epoch = std::get<replication::RoleMainData>(std::as_const(repl_state_).ReplicationData()).epoch_;
dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) {
auto *storage = db_acc->storage();
storage->repl_storage_state_.epoch_ = epoch;
});
return true;
};
// as MAIN, define and connect to REPLICAs
auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
return RegisterReplica_<false>(config);
}
auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<memgraph::query::RegisterReplicaError> {
return RegisterReplica_<true>(config);
}
auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult {
auto const replica_handler =
[](memgraph::replication::RoleReplicaData const &) -> memgraph::query::UnregisterReplicaResult {
return memgraph::query::UnregisterReplicaResult::NOT_MAIN;
};
auto const main_handler =
[this, name](memgraph::replication::RoleMainData &mainData) -> memgraph::query::UnregisterReplicaResult {
if (!repl_state_.TryPersistUnregisterReplica(name)) {
return memgraph::query::UnregisterReplicaResult::COULD_NOT_BE_PERSISTED;
}
// Remove database specific clients
dbms_handler_.ForEach([name](memgraph::dbms::DatabaseAccess db_acc) {
db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) {
std::erase_if(clients, [name](const auto &client) { return client->Name() == name; });
});
});
// Remove instance level clients
auto const n_unregistered =
std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; });
return n_unregistered != 0 ? memgraph::query::UnregisterReplicaResult::SUCCESS
: memgraph::query::UnregisterReplicaResult::CAN_NOT_UNREGISTER;
};
return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData());
}
auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole {
return repl_state_.GetRole();
}
bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); }
bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); }
} // namespace memgraph::replication

View File

@ -0,0 +1,115 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/system_replication.hpp"
#include <spdlog/spdlog.h>
#include "auth/replication_handlers.hpp"
#include "dbms/replication_handlers.hpp"
#include "license/license.hpp"
#include "replication_handler/system_rpc.hpp"
namespace memgraph::replication {
#ifdef MG_ENTERPRISE
void SystemHeartbeatHandler(const uint64_t ts, slk::Reader *req_reader, slk::Builder *res_builder) {
replication::SystemHeartbeatRes res{0};
// Ignore if no license
if (!license::global_license_checker.IsEnterpriseValidFast()) {
spdlog::error("Handling SystemHeartbeat, an enterprise RPC message, without license.");
memgraph::slk::Save(res, res_builder);
return;
}
replication::SystemHeartbeatReq req;
replication::SystemHeartbeatReq::Load(&req, req_reader);
res = replication::SystemHeartbeatRes{ts};
memgraph::slk::Save(res, res_builder);
}
void SystemRecoveryHandler(memgraph::system::ReplicaHandlerAccessToState &system_state_access,
dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth, slk::Reader *req_reader,
slk::Builder *res_builder) {
using memgraph::replication::SystemRecoveryRes;
SystemRecoveryRes res(SystemRecoveryRes::Result::FAILURE);
utils::OnScopeExit send_on_exit([&]() { memgraph::slk::Save(res, res_builder); });
memgraph::replication::SystemRecoveryReq req;
memgraph::slk::Load(&req, req_reader);
/*
* DBMS
*/
if (!dbms::SystemRecoveryHandler(dbms_handler, req.database_configs)) return; // Failure sent on exit
/*
* AUTH
*/
if (!auth::SystemRecoveryHandler(auth, req.auth_config, req.users, req.roles)) return; // Failure sent on exit
/*
* SUCCESSFUL RECOVERY
*/
system_state_access.SetLastCommitedTS(req.forced_group_timestamp);
spdlog::debug("SystemRecoveryHandler: SUCCESS updated LCTS to {}", req.forced_group_timestamp);
res = SystemRecoveryRes(SystemRecoveryRes::Result::SUCCESS);
}
void Register(replication::RoleReplicaData const &data, dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth) {
// NOTE: Register even without license as the user could add a license at run-time
// TODO: fix Register when system is removed from DbmsHandler
auto system_state_access = dbms_handler.system_->CreateSystemStateAccess();
// System
data.server->rpc_server_.Register<replication::SystemHeartbeatRpc>(
[system_state_access](auto *req_reader, auto *res_builder) {
spdlog::debug("Received SystemHeartbeatRpc");
SystemHeartbeatHandler(system_state_access.LastCommitedTS(), req_reader, res_builder);
});
data.server->rpc_server_.Register<replication::SystemRecoveryRpc>(
[system_state_access, &dbms_handler, &auth](auto *req_reader, auto *res_builder) mutable {
spdlog::debug("Received SystemRecoveryRpc");
SystemRecoveryHandler(system_state_access, dbms_handler, auth, req_reader, res_builder);
});
// DBMS
dbms::Register(data, system_state_access, dbms_handler);
// Auth
auth::Register(data, system_state_access, auth);
}
#endif
#ifdef MG_ENTERPRISE
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data,
auth::SynchedAuth &auth) {
#else
bool StartRpcServer(dbms::DbmsHandler &dbms_handler, const replication::RoleReplicaData &data) {
#endif
// Register storage handlers
dbms::InMemoryReplicationHandlers::Register(&dbms_handler, *data.server);
#ifdef MG_ENTERPRISE
// Register system handlers
Register(data, dbms_handler, auth);
#endif
// Start server
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
}
} // namespace memgraph::replication

View File

@ -0,0 +1,111 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "replication_handler/system_rpc.hpp"
#include <json/json.hpp>
#include "auth/rpc.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "storage/v2/replication/rpc.hpp"
#include "utils/enum.hpp"
namespace memgraph::slk {
// Serialize code for SystemHeartbeatRes
void Save(const memgraph::replication::SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.system_timestamp, builder);
}
void Load(memgraph::replication::SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->system_timestamp, reader);
}
// Serialize code for SystemHeartbeatReq
void Save(const memgraph::replication::SystemHeartbeatReq & /*self*/, memgraph::slk::Builder * /*builder*/) {
/* Nothing to serialize */
}
void Load(memgraph::replication::SystemHeartbeatReq * /*self*/, memgraph::slk::Reader * /*reader*/) {
/* Nothing to serialize */
}
// Serialize code for SystemRecoveryReq
void Save(const memgraph::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.forced_group_timestamp, builder);
memgraph::slk::Save(self.database_configs, builder);
memgraph::slk::Save(self.auth_config, builder);
memgraph::slk::Save(self.users, builder);
memgraph::slk::Save(self.roles, builder);
}
void Load(memgraph::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->forced_group_timestamp, reader);
memgraph::slk::Load(&self->database_configs, reader);
memgraph::slk::Load(&self->auth_config, reader);
memgraph::slk::Load(&self->users, reader);
memgraph::slk::Load(&self->roles, reader);
}
// Serialize code for SystemRecoveryRes
void Save(const memgraph::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace memgraph::slk
namespace memgraph::replication {
constexpr utils::TypeInfo SystemHeartbeatReq::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_REQ, "SystemHeartbeatReq",
nullptr};
constexpr utils::TypeInfo SystemHeartbeatRes::kType{utils::TypeId::REP_SYSTEM_HEARTBEAT_RES, "SystemHeartbeatRes",
nullptr};
constexpr utils::TypeInfo SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ, "SystemRecoveryReq",
nullptr};
constexpr utils::TypeInfo SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES, "SystemRecoveryRes",
nullptr};
void SystemHeartbeatReq::Save(const SystemHeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatReq::Load(SystemHeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemHeartbeatRes::Save(const SystemHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemHeartbeatRes::Load(SystemHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace memgraph::replication

View File

@ -59,39 +59,6 @@ void TimestampRes::Save(const TimestampRes &self, memgraph::slk::Builder *builde
memgraph::slk::Save(self, builder);
}
void TimestampRes::Load(TimestampRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void CreateDatabaseReq::Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseReq::Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void CreateDatabaseRes::Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CreateDatabaseRes::Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void DropDatabaseReq::Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseReq::Load(DropDatabaseReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void DropDatabaseRes::Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void DropDatabaseRes::Load(DropDatabaseRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void SystemRecoveryReq::Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryReq::Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SystemRecoveryRes::Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void SystemRecoveryRes::Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
} // namespace storage::replication
constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ,
@ -130,24 +97,6 @@ constexpr utils::TypeInfo storage::replication::TimestampReq::kType{utils::TypeI
constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes",
nullptr};
constexpr utils::TypeInfo storage::replication::CreateDatabaseReq::kType{utils::TypeId::REP_CREATE_DATABASE_REQ,
"CreateDatabaseReq", nullptr};
constexpr utils::TypeInfo storage::replication::CreateDatabaseRes::kType{utils::TypeId::REP_CREATE_DATABASE_RES,
"CreateDatabaseRes", nullptr};
constexpr utils::TypeInfo storage::replication::DropDatabaseReq::kType{utils::TypeId::REP_DROP_DATABASE_REQ,
"DropDatabaseReq", nullptr};
constexpr utils::TypeInfo storage::replication::DropDatabaseRes::kType{utils::TypeId::REP_DROP_DATABASE_RES,
"DropDatabaseRes", nullptr};
constexpr utils::TypeInfo storage::replication::SystemRecoveryReq::kType{utils::TypeId::REP_SYSTEM_RECOVERY_REQ,
"SystemRecoveryReq", nullptr};
constexpr utils::TypeInfo storage::replication::SystemRecoveryRes::kType{utils::TypeId::REP_SYSTEM_RECOVERY_RES,
"SystemRecoveryRes", nullptr};
// Autogenerated SLK serialization code
namespace slk {
// Serialize code for TimestampRes
@ -316,91 +265,5 @@ void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader)
memgraph::slk::Load(&self->items.enable_schema_metadata, reader);
}
// Serialize code for CreateDatabaseReq
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.config, builder);
}
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->config, reader);
}
// Serialize code for CreateDatabaseRes
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for DropDatabaseReq
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.epoch_id, builder);
memgraph::slk::Save(self.expected_group_timestamp, builder);
memgraph::slk::Save(self.new_group_timestamp, builder);
memgraph::slk::Save(self.uuid, builder);
}
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->epoch_id, reader);
memgraph::slk::Load(&self->expected_group_timestamp, reader);
memgraph::slk::Load(&self->new_group_timestamp, reader);
memgraph::slk::Load(&self->uuid, reader);
}
// Serialize code for DropDatabaseRes
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
// Serialize code for SystemRecoveryReq
void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.forced_group_timestamp, builder);
memgraph::slk::Save(self.database_configs, builder);
}
void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->forced_group_timestamp, reader);
memgraph::slk::Load(&self->database_configs, reader);
}
// Serialize code for SystemRecoveryRes
void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(utils::EnumToNum<uint8_t>(self.result), builder);
}
void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader) {
uint8_t res = 0;
memgraph::slk::Load(&res, reader);
if (!utils::NumToEnum(res, self->result)) {
throw SlkReaderException("Unexpected result line:{}!", __LINE__);
}
}
} // namespace slk
} // namespace memgraph

View File

@ -201,108 +201,6 @@ struct TimestampRes {
using TimestampRpc = rpc::RequestResponse<TimestampReq, TimestampRes>;
struct CreateDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CreateDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseReq &self, memgraph::slk::Builder *builder);
CreateDatabaseReq() = default;
CreateDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
storage::SalientConfig config)
: epoch_id(std::move(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
config(std::move(config)) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
storage::SalientConfig config;
};
struct CreateDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(CreateDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const CreateDatabaseRes &self, memgraph::slk::Builder *builder);
CreateDatabaseRes() = default;
explicit CreateDatabaseRes(Result res) : result(res) {}
Result result;
};
using CreateDatabaseRpc = rpc::RequestResponse<CreateDatabaseReq, CreateDatabaseRes>;
struct DropDatabaseReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(DropDatabaseReq *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseReq &self, memgraph::slk::Builder *builder);
DropDatabaseReq() = default;
DropDatabaseReq(std::string epoch_id, uint64_t expected_group_timestamp, uint64_t new_group_timestamp,
const utils::UUID &uuid)
: epoch_id(std::move(epoch_id)),
expected_group_timestamp{expected_group_timestamp},
new_group_timestamp(new_group_timestamp),
uuid(uuid) {}
std::string epoch_id;
uint64_t expected_group_timestamp;
uint64_t new_group_timestamp;
utils::UUID uuid;
};
struct DropDatabaseRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(DropDatabaseRes *self, memgraph::slk::Reader *reader);
static void Save(const DropDatabaseRes &self, memgraph::slk::Builder *builder);
DropDatabaseRes() = default;
explicit DropDatabaseRes(Result res) : result(res) {}
Result result;
};
using DropDatabaseRpc = rpc::RequestResponse<DropDatabaseReq, DropDatabaseRes>;
struct SystemRecoveryReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SystemRecoveryReq *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryReq &self, memgraph::slk::Builder *builder);
SystemRecoveryReq() = default;
SystemRecoveryReq(uint64_t forced_group_timestamp, std::vector<storage::SalientConfig> database_configs)
: forced_group_timestamp{forced_group_timestamp}, database_configs(std::move(database_configs)) {}
uint64_t forced_group_timestamp;
std::vector<storage::SalientConfig> database_configs;
};
struct SystemRecoveryRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Result : uint8_t { SUCCESS, NO_NEED, FAILURE, /* Leave at end */ N };
static void Load(SystemRecoveryRes *self, memgraph::slk::Reader *reader);
static void Save(const SystemRecoveryRes &self, memgraph::slk::Builder *builder);
SystemRecoveryRes() = default;
explicit SystemRecoveryRes(Result res) : result(res) {}
Result result;
};
using SystemRecoveryRpc = rpc::RequestResponse<SystemRecoveryReq, SystemRecoveryRes>;
} // namespace memgraph::storage::replication
// SLK serialization declarations
@ -356,28 +254,8 @@ void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph:
void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseReq &self, memgraph::slk::Builder *builder);
void Save(const memgraph::storage::SalientConfig &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CreateDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CreateDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::DropDatabaseRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::DropDatabaseRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SystemRecoveryReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SystemRecoveryReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SystemRecoveryRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SystemRecoveryRes *self, memgraph::slk::Reader *reader);
void Load(memgraph::storage::SalientConfig *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

23
src/system/CMakeLists.txt Normal file
View File

@ -0,0 +1,23 @@
add_library(mg-system STATIC)
add_library(mg::system ALIAS mg-system)
target_sources(mg-system
PUBLIC
include/system/action.hpp
include/system/system.hpp
include/system/transaction.hpp
include/system/state.hpp
PRIVATE
action.cpp
system.cpp
transaction.cpp
state.cpp
)
target_include_directories(mg-system PUBLIC include)
target_link_libraries(mg-system
PUBLIC
mg::replication
)

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -8,14 +8,4 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "replication/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client);
} // namespace memgraph::dbms
#include "system/include/system/action.hpp"

View File

@ -0,0 +1,38 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "replication/epoch.hpp"
#include "replication/replication_client.hpp"
#include "replication/state.hpp"
namespace memgraph::system {
struct Transaction;
/// The system action interface that subsystems will implement. This OO-style separation is needed so that one common
/// mechanism can be used for all subsystem replication within a system transaction, without the need for System to
/// know about all the subsystems.
struct ISystemAction {
/// Durability step which is defered until commit time
virtual void DoDurability() = 0;
/// Prepare the RPC payload that will be sent to all replicas clients
virtual bool DoReplication(memgraph::replication::ReplicationClient &client,
memgraph::replication::ReplicationEpoch const &epoch,
Transaction const &system_tx) const = 0;
virtual void PostReplication(memgraph::replication::RoleMainData &main_data) const = 0;
virtual ~ISystemAction() = default;
};
} // namespace memgraph::system

View File

@ -0,0 +1,57 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <atomic>
#include <cstdint>
#include "kvstore/kvstore.hpp"
#include "utils/file.hpp"
namespace memgraph::system {
namespace {
constexpr std::string_view kLastCommitedSystemTsKey = "last_committed_system_ts"; // Key for timestamp durability
}
struct State {
explicit State(std::optional<std::filesystem::path> storage, bool recovery_on_startup);
void FinalizeTransaction(std::uint64_t timestamp) {
if (durability_) {
durability_->Put(kLastCommitedSystemTsKey, std::to_string(timestamp));
}
last_committed_system_timestamp_.store(timestamp);
}
auto LastCommittedSystemTimestamp() -> uint64_t { return last_committed_system_timestamp_.load(); }
private:
friend struct ReplicaHandlerAccessToState;
friend struct Transaction;
std::optional<kvstore::KVStore> durability_;
std::atomic_uint64_t last_committed_system_timestamp_{};
};
struct ReplicaHandlerAccessToState {
explicit ReplicaHandlerAccessToState(memgraph::system::State &state) : state_{&state} {}
auto LastCommitedTS() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); }
void SetLastCommitedTS(uint64_t new_timestamp) { state_->last_committed_system_timestamp_.store(new_timestamp); }
private:
State *state_;
};
} // namespace memgraph::system

View File

@ -0,0 +1,52 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "system/state.hpp"
#include "system/transaction.hpp"
namespace memgraph::system {
struct TransactionGuard {
explicit TransactionGuard(std::unique_lock<std::timed_mutex> guard) : guard_(std::move(guard)) {}
private:
std::unique_lock<std::timed_mutex> guard_;
};
struct System {
// NOTE: default arguments to make testing easier.
System(std::optional<std::filesystem::path> storage = std::nullopt, bool recovery_on_startup = false)
: state_(std::move(storage), recovery_on_startup), timestamp_{state_.LastCommittedSystemTimestamp()} {}
auto TryCreateTransaction(std::chrono::microseconds try_time = std::chrono::milliseconds{100})
-> std::optional<Transaction> {
auto system_unique = std::unique_lock{mtx_, std::defer_lock};
if (!system_unique.try_lock_for(try_time)) {
return std::nullopt;
}
return Transaction{state_, std::move(system_unique), timestamp_++};
}
// TODO: this and LastCommittedSystemTimestamp maybe not needed
auto GenTransactionGuard() -> TransactionGuard { return TransactionGuard{std::unique_lock{mtx_}}; }
auto LastCommittedSystemTimestamp() -> uint64_t { return state_.LastCommittedSystemTimestamp(); }
auto CreateSystemStateAccess() -> ReplicaHandlerAccessToState { return ReplicaHandlerAccessToState{state_}; }
private:
State state_;
std::timed_mutex mtx_{};
std::uint64_t timestamp_{};
};
} // namespace memgraph::system

View File

@ -0,0 +1,124 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <list>
#include <memory>
#include <mutex>
#include <optional>
#include "replication/state.hpp"
#include "system/action.hpp"
#include "system/state.hpp"
namespace memgraph::system {
enum class AllSyncReplicaStatus : std::uint8_t {
AllCommitsConfirmed,
SomeCommitsUnconfirmed,
};
struct Transaction;
template <typename T>
concept ReplicationPolicy = requires(T handler, ISystemAction const &action, Transaction const &txn) {
{ handler.ApplyAction(action, txn) } -> std::same_as<AllSyncReplicaStatus>;
};
struct System;
struct Transaction {
template <std::derived_from<ISystemAction> TAction, typename... Args>
requires std::constructible_from<TAction, Args...>
void AddAction(Args &&...args) { actions_.emplace_back(std::make_unique<TAction>(std::forward<Args>(args)...)); }
template <ReplicationPolicy Handler>
auto Commit(Handler handler) -> AllSyncReplicaStatus {
if (!lock_.owns_lock() || actions_.empty()) {
// If no actions, we do not increment the last commited ts, since there is no delta to send to the REPLICA
Abort();
return AllSyncReplicaStatus::AllCommitsConfirmed; // TODO: some kind of error
}
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
while (!actions_.empty()) {
auto &action = actions_.front();
/// durability
action->DoDurability();
/// replication prep
auto action_sync_status = handler.ApplyAction(*action, *this);
if (action_sync_status != AllSyncReplicaStatus::AllCommitsConfirmed) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
actions_.pop_front();
}
state_->FinalizeTransaction(timestamp_);
lock_.unlock();
return sync_status;
}
void Abort() {
if (lock_.owns_lock()) {
lock_.unlock();
}
actions_.clear();
}
auto last_committed_system_timestamp() const -> uint64_t { return state_->last_committed_system_timestamp_.load(); }
auto timestamp() const -> uint64_t { return timestamp_; }
private:
friend struct System;
Transaction(State &state, std::unique_lock<std::timed_mutex> lock, std::uint64_t timestamp)
: state_{std::addressof(state)}, lock_(std::move(lock)), timestamp_{timestamp} {}
State *state_;
std::unique_lock<std::timed_mutex> lock_;
std::uint64_t timestamp_;
std::list<std::unique_ptr<ISystemAction>> actions_;
};
struct DoReplication {
explicit DoReplication(replication::RoleMainData &main_data) : main_data_{main_data} {}
auto ApplyAction(ISystemAction const &action, Transaction const &system_tx) -> AllSyncReplicaStatus {
auto sync_status = AllSyncReplicaStatus::AllCommitsConfirmed;
for (auto &client : main_data_.registered_replicas_) {
bool completed = action.DoReplication(client, main_data_.epoch_, system_tx);
if (!completed && client.mode_ == replication_coordination_glue::ReplicationMode::SYNC) {
sync_status = AllSyncReplicaStatus::SomeCommitsUnconfirmed;
}
}
action.PostReplication(main_data_);
return sync_status;
}
private:
replication::RoleMainData &main_data_;
};
static_assert(ReplicationPolicy<DoReplication>);
struct DoNothing {
auto ApplyAction(ISystemAction const & /*action*/, Transaction const & /*system_tx*/) -> AllSyncReplicaStatus {
return AllSyncReplicaStatus::AllCommitsConfirmed;
}
};
static_assert(ReplicationPolicy<DoNothing>);
} // namespace memgraph::system

57
src/system/state.cpp Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/state.hpp"
namespace memgraph::system {
namespace {
constexpr std::string_view kSystemDir = ".system";
constexpr std::string_view kVersion = "version"; // Key for version durability
constexpr std::string_view kVersionV1 = "V1"; // Value for version 1
auto InitializeSystemDurability(std::optional<std::filesystem::path> storage, bool recovery_on_startup)
-> std::optional<memgraph::kvstore::KVStore> {
if (!storage) return std::nullopt;
auto const &path = *storage;
memgraph::utils::EnsureDir(path);
auto system_dir = path / kSystemDir;
memgraph::utils::EnsureDir(system_dir);
auto durability = memgraph::kvstore::KVStore{std::move(system_dir)};
auto version = durability.Get(kVersion);
// TODO: migration schemes here in the future
if (!version || *version != kVersionV1) {
// ensure we start out with V1
durability.Put(kVersion, kVersionV1);
}
if (!recovery_on_startup) {
// reset last_committed_system_ts
durability.Delete(kLastCommitedSystemTsKey);
}
return durability;
}
auto LoadLastCommittedSystemTimestamp(std::optional<kvstore::KVStore> const &store) -> uint64_t {
auto lcst = store ? store->Get(kLastCommitedSystemTsKey) : std::nullopt;
return lcst ? std::stoul(*lcst) : 0U;
}
} // namespace
State::State(std::optional<std::filesystem::path> storage, bool recovery_on_startup)
: durability_{InitializeSystemDurability(std::move(storage), recovery_on_startup)},
last_committed_system_timestamp_{LoadLastCommittedSystemTimestamp(durability_)} {}
} // namespace memgraph::system

11
src/system/system.cpp Normal file
View File

@ -0,0 +1,11 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/include/system/system.hpp"

View File

@ -0,0 +1,11 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "system/include/system/transaction.hpp"

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -149,9 +149,9 @@ void Telemetry::AddClientCollector() {
}
#ifdef MG_ENTERPRISE
void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) {
AddCollector("database", [&dbms_handler]() -> nlohmann::json {
const auto &infos = dbms_handler.Info();
void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state) {
AddCollector("database", [&dbms_handler, &repl_state]() -> nlohmann::json {
const auto &infos = dbms_handler.Info(repl_state.GetRole());
auto dbs = nlohmann::json::array();
for (const auto &db_info : infos) {
dbs.push_back(memgraph::dbms::ToJson(db_info));
@ -162,11 +162,10 @@ void Telemetry::AddDatabaseCollector(dbms::DbmsHandler &dbms_handler) {
#else
#endif
void Telemetry::AddStorageCollector(
dbms::DbmsHandler &dbms_handler,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> &auth) {
AddCollector("storage", [&dbms_handler, &auth]() -> nlohmann::json {
auto stats = dbms_handler.Stats();
void Telemetry::AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth,
memgraph::replication::ReplicationState &repl_state) {
AddCollector("storage", [&dbms_handler, &auth, &repl_state]() -> nlohmann::json {
auto stats = dbms_handler.Stats(repl_state.GetRole());
stats.users = auth->AllUsers().size();
return ToJson(stats);
});

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -43,12 +43,11 @@ class Telemetry final {
void AddCollector(const std::string &name, const std::function<const nlohmann::json(void)> &func);
// Specialized collectors
void AddStorageCollector(
dbms::DbmsHandler &dbms_handler,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> &auth);
void AddStorageCollector(dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth,
memgraph::replication::ReplicationState &repl_state);
#ifdef MG_ENTERPRISE
void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler);
void AddDatabaseCollector(dbms::DbmsHandler &dbms_handler, replication::ReplicationState &repl_state);
#else
void AddDatabaseCollector() {
AddCollector("database", []() -> nlohmann::json { return nlohmann::json::array(); });

View File

@ -161,10 +161,22 @@ struct Gatekeeper {
~Accessor() { reset(); }
auto get() -> T * { return std::addressof(*owner_->value_); }
auto get() const -> const T * { return std::addressof(*owner_->value_); }
T *operator->() { return std::addressof(*owner_->value_); }
const T *operator->() const { return std::addressof(*owner_->value_); }
auto get() -> T * {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
auto get() const -> const T * {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
T *operator->() {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
const T *operator->() const {
if (owner_ == nullptr) return nullptr;
return std::addressof(*owner_->value_);
}
template <typename Func>
[[nodiscard]] auto try_exclusively(Func &&func) -> EvalResult<std::invoke_result_t<Func, T &>> {

View File

@ -93,6 +93,10 @@ enum class TypeId : uint64_t {
REP_SYSTEM_HEARTBEAT_RES,
REP_SYSTEM_RECOVERY_REQ,
REP_SYSTEM_RECOVERY_RES,
REP_UPDATE_AUTH_DATA_REQ,
REP_UPDATE_AUTH_DATA_RES,
REP_DROP_AUTH_DATA_REQ,
REP_DROP_AUTH_DATA_RES,
// Coordinator
COORD_FAILOVER_REQ,

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -26,6 +26,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e
class ExpansionBenchFixture : public benchmark::Fixture {
protected:
std::optional<memgraph::system::System> system;
std::optional<memgraph::query::InterpreterContext> interpreter_context;
std::optional<memgraph::query::Interpreter> interpreter;
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk;
@ -40,7 +41,14 @@ class ExpansionBenchFixture : public benchmark::Fixture {
auto db_acc_opt = db_gk->access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value());
system.emplace();
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
auto label = db_acc->storage()->NameToLabel("Starting");
@ -70,6 +78,7 @@ class ExpansionBenchFixture : public benchmark::Fixture {
void TearDown(const benchmark::State &) override {
interpreter = std::nullopt;
interpreter_context = std::nullopt;
system.reset();
db_gk.reset();
std::filesystem::remove_all(data_directory);
}

View File

@ -105,7 +105,9 @@ def is_port_in_use(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_directory):
def _start_instance(
name, args, log_file, setup_queries, use_ssl, procdir, data_directory, username=None, password=None
):
assert (
name not in MEMGRAPH_INSTANCES.keys()
), "If this raises, you are trying to start an instance with the same name than one already running."
@ -115,7 +117,9 @@ def _start_instance(name, args, log_file, setup_queries, use_ssl, procdir, data_
log_file_path = os.path.join(BUILD_DIR, "logs", log_file)
data_directory_path = os.path.join(BUILD_DIR, data_directory)
mg_instance = MemgraphInstanceRunner(MEMGRAPH_BINARY, use_ssl, {data_directory_path})
mg_instance = MemgraphInstanceRunner(
MEMGRAPH_BINARY, use_ssl, {data_directory_path}, username=username, password=password
)
MEMGRAPH_INSTANCES[name] = mg_instance
binary_args = args + ["--log-file", log_file_path] + ["--data-directory", data_directory_path]
@ -185,8 +189,14 @@ def start_instance(context, name, procdir):
data_directory = value["data_directory"]
else:
data_directory = tempfile.TemporaryDirectory().name
username = None
if "username" in value:
username = value["username"]
password = None
if "password" in value:
password = value["password"]
instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory)
instance = _start_instance(name, args, log_file, queries, use_ssl, procdir, data_directory, username, password)
mg_instances[name] = instance
assert len(mg_instances) == 1

View File

@ -57,7 +57,7 @@ def replace_paths(path):
class MemgraphInstanceRunner:
def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None):
def __init__(self, binary_path=MEMGRAPH_BINARY, use_ssl=False, delete_on_stop=None, username=None, password=None):
self.host = "127.0.0.1"
self.bolt_port = None
self.binary_path = binary_path
@ -65,12 +65,19 @@ class MemgraphInstanceRunner:
self.proc_mg = None
self.ssl = use_ssl
self.delete_on_stop = delete_on_stop
self.username = username
self.password = password
def execute_setup_queries(self, setup_queries):
if setup_queries is None:
return
# An assumption being database instance is fresh, no need for the auth.
conn = mgclient.connect(host=self.host, port=self.bolt_port, sslmode=self.ssl)
conn = mgclient.connect(
host=self.host,
port=self.bolt_port,
sslmode=self.ssl,
username=(self.username or ""),
password=(self.password or ""),
)
conn.autocommit = True
cursor = conn.cursor()
for query_coll in setup_queries:

View File

@ -3,6 +3,7 @@ find_package(gflags REQUIRED)
copy_e2e_python_files(replication_experiment common.py)
copy_e2e_python_files(replication_experiment conftest.py)
copy_e2e_python_files(replication_experiment multitenancy.py)
copy_e2e_python_files(replication_experiment auth.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." memgraph.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." interactive_mg_runner.py)
copy_e2e_python_files_from_parent_folder(replication_experiment ".." mg_utils.py)

View File

@ -0,0 +1,831 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import atexit
import os
import shutil
import sys
import tempfile
import time
from functools import partial
import interactive_mg_runner
import mgclient
import pytest
from common import execute_and_fetch_all
from mg_utils import mg_sleep_and_assert
interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
interactive_mg_runner.PROJECT_DIR = os.path.normpath(
os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..")
)
interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build"))
interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph"))
BOLT_PORTS = {"main": 7687, "replica_1": 7688, "replica_2": 7689}
REPLICATION_PORTS = {"replica_1": 10001, "replica_2": 10002}
TEMP_DIR = tempfile.TemporaryDirectory().name
def update_to_main(cursor):
execute_and_fetch_all(cursor, "SET REPLICATION ROLE TO MAIN;")
def add_user(cursor, username, password=None):
if password is not None:
return execute_and_fetch_all(cursor, f"CREATE USER {username} IDENTIFIED BY '{password}';")
return execute_and_fetch_all(cursor, f"CREATE USER {username};")
def show_users_func(cursor):
def func():
return set(execute_and_fetch_all(cursor, "SHOW USERS;"))
return func
def show_roles_func(cursor):
def func():
return set(execute_and_fetch_all(cursor, "SHOW ROLES;"))
return func
def show_users_for_role_func(cursor, rolename):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW USERS FOR {rolename};"))
return func
def show_role_for_user_func(cursor, username):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW ROLE FOR {username};"))
return func
def show_privileges_func(cursor, user_or_role):
def func():
return set(execute_and_fetch_all(cursor, f"SHOW PRIVILEGES FOR {user_or_role};"))
return func
def show_database_privileges_func(cursor, user):
def func():
return execute_and_fetch_all(cursor, f"SHOW DATABASE PRIVILEGES FOR {user};")
return func
def show_database_func(cursor):
def func():
return execute_and_fetch_all(cursor, f"SHOW DATABASE;")
return func
def try_and_count(cursor, query):
try:
execute_and_fetch_all(cursor, query)
except:
return 1
return 0
def only_main_queries(cursor):
n_exceptions = 0
n_exceptions += try_and_count(cursor, f"CREATE USER user_name")
n_exceptions += try_and_count(cursor, f"SET PASSWORD FOR user_name TO 'new_password'")
n_exceptions += try_and_count(cursor, f"DROP USER user_name")
n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name")
n_exceptions += try_and_count(cursor, f"DROP ROLE role_name")
n_exceptions += try_and_count(cursor, f"CREATE USER user_name")
n_exceptions += try_and_count(cursor, f"CREATE ROLE role_name")
n_exceptions += try_and_count(cursor, f"SET ROLE FOR user_name TO role_name")
n_exceptions += try_and_count(cursor, f"CLEAR ROLE FOR user_name")
n_exceptions += try_and_count(cursor, f"GRANT AUTH TO role_name")
n_exceptions += try_and_count(cursor, f"DENY AUTH, INDEX TO user_name")
n_exceptions += try_and_count(cursor, f"REVOKE AUTH FROM role_name")
n_exceptions += try_and_count(cursor, f"GRANT READ ON LABELS :l TO role_name;")
n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name")
n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;")
n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name")
n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;")
return n_exceptions
def main_and_repl_queries(cursor):
n_exceptions = 0
try_and_count(cursor, f"SHOW USERS")
try_and_count(cursor, f"SHOW ROLES")
try_and_count(cursor, f"SHOW USERS FOR ROLE role_name")
try_and_count(cursor, f"SHOW ROLE FOR user_name")
try_and_count(cursor, f"SHOW PRIVILEGES FOR role_name")
try_and_count(cursor, f"SHOW DATABASE PRIVILEGES FOR user_name")
return n_exceptions
def test_auth_queries_on_replica(connection):
# Goal: check that write auth queries are forbidden on REPLICAs
# 0/ Setup replication cluster
# 1/ Check queries
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor()
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor()
# 1/
assert only_main_queries(cursor_main) == 0
assert only_main_queries(cursor_replica_1) == 17
assert only_main_queries(cursor_replica_2) == 17
assert main_and_repl_queries(cursor_main) == 0
assert main_and_repl_queries(cursor_replica_1) == 0
assert main_and_repl_queries(cursor_replica_2) == 0
def test_manual_users_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ MAIN CREATE USER user1, user2
# REPLICA CREATE USER user3, user4
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1 and user2
# 2/ Check connections on REPLICAS
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE USER user3;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE USER user4 IDENTIFIED BY 'password';",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE USER user1;",
"CREATE USER user2 IDENTIFIED BY 'password';",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor = connection(BOLT_PORTS["main"], "main", "user1").cursor()
# 1/
expected_data = {("user2",), ("user1",)}
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor())
)
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor())
)
# 2/
connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor()
connection(BOLT_PORTS["replica_1"], "replica", "user2", "password").cursor()
connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor()
connection(BOLT_PORTS["replica_2"], "replica", "user2", "password").cursor()
def test_env_users_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ Set users from the environment
# MAIN gets users from the environment
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"username": "user1",
"password": "password",
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
# Start only replicas without the env user
interactive_mg_runner.stop_all(keep_directories=False)
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_1")
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "replica_2")
# Setup user
try:
os.environ["MEMGRAPH_USER"] = "user1"
os.environ["MEMGRAPH_PASSWORD"] = "password"
# Start main
interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, "main")
finally:
# Cleanup
del os.environ["MEMGRAPH_USER"]
del os.environ["MEMGRAPH_PASSWORD"]
# 1/
expected_data = {("user1",)}
assert expected_data == show_users_func(connection(BOLT_PORTS["main"], "main", "user1", "password").cursor())()
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_1"], "replica", "user1", "password").cursor())
)
mg_sleep_and_assert(
expected_data, show_users_func(connection(BOLT_PORTS["replica_2"], "replica", "user1", "password").cursor())
)
def test_manual_roles_recovery(connection):
# Goal: show system recovery in action at registration time
# 0/ MAIN CREATE USER user1, user2
# REPLICA CREATE USER user3, user4
# Setup replication cluster
# 1/ Check that both MAIN and REPLICA have user1 and user2
# 2/ Check that role1 and role2 are replicated
# 3/ Check that user1 has role1
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE ROLE role3;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE ROLE role4;",
"CREATE USER user4;",
"SET ROLE FOR user4 TO role4;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE ROLE role1;",
"CREATE ROLE role2;",
"CREATE USER user2;",
"SET ROLE FOR user2 TO role2;",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
connection(BOLT_PORTS["main"], "main", "user2").cursor() # Just check if it connects
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "user2").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "user2").cursor()
# 1/
expected_data = {
("user2",),
}
mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_1))
mg_sleep_and_assert(expected_data, show_users_func(cursor_replica_2))
# 2/
expected_data = {("role2",), ("role1",)}
mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_1))
mg_sleep_and_assert(expected_data, show_roles_func(cursor_replica_2))
# 3/
expected_data = {("role2",)}
mg_sleep_and_assert(
expected_data,
show_role_for_user_func(cursor_replica_1, "user2"),
)
mg_sleep_and_assert(
expected_data,
show_role_for_user_func(cursor_replica_2, "user2"),
)
def test_auth_config_recovery(connection):
# Goal: show we are replicating Auth::Config
# 0/ Setup auth configuration and compliant users
# 1/ Check that both MAIN and REPLICA have the same users
# 2/ Check that REPLICAS have the same config
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[A-Z]+$",
"--auth-password-permit-null=false",
"--auth-user-or-role-name-regex",
"^[O-o]+$",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
"CREATE USER OPQabc IDENTIFIED BY 'PASSWORD';",
"CREATE ROLE defRST;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[0-9]+$",
"--auth-password-permit-null=true",
"--auth-user-or-role-name-regex",
"^[A-Np-z]+$",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
"CREATE ROLE ABCpqr;",
"CREATE USER stuDEF;",
"CREATE USER GvHwI IDENTIFIED BY '123456';",
"SET ROLE FOR GvHwI TO ABCpqr;",
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--auth-password-strength-regex",
"^[a-z]+$",
"--auth-password-permit-null=false",
"--auth-user-or-role-name-regex",
"^[A-z]+$",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
"CREATE USER UsErA IDENTIFIED BY 'pass';",
"CREATE ROLE rOlE;",
"CREATE USER uSeRB IDENTIFIED BY 'word';",
"SET ROLE FOR uSeRB TO rOlE;",
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
# 1/
cursor_main = connection(BOLT_PORTS["main"], "main", "UsErA", "pass").cursor()
cursor_replica_1 = connection(BOLT_PORTS["replica_1"], "replica", "UsErA", "pass").cursor()
cursor_replica_2 = connection(BOLT_PORTS["replica_2"], "replica", "UsErA", "pass").cursor()
# 2/ Only MAIN can update users
def user_test(cursor):
with pytest.raises(mgclient.DatabaseError, match="Invalid user name."):
add_user(cursor, "UsEr1", "abcdef")
with pytest.raises(mgclient.DatabaseError, match="Null passwords aren't permitted!"):
add_user(cursor, "UsErC")
with pytest.raises(mgclient.DatabaseError, match="The user password doesn't conform to the required strength!"):
add_user(cursor, "UsErC", "123456")
user_test(cursor_main)
update_to_main(cursor_replica_1)
user_test(cursor_replica_1)
update_to_main(cursor_replica_2)
user_test(cursor_replica_2)
def test_auth_replication(connection):
# Goal: show that individual auth queries get replicated
MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL = {
"replica_1": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_1']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
],
"log_file": "replica1.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_1']};",
],
},
"replica_2": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['replica_2']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica2",
],
"log_file": "replica2.log",
"setup_queries": [
f"SET REPLICATION ROLE TO REPLICA WITH PORT {REPLICATION_PORTS['replica_2']};",
],
},
"main": {
"args": [
"--bolt-port",
f"{BOLT_PORTS['main']}",
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/main",
],
"log_file": "main.log",
"setup_queries": [
f"REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_1']}';",
f"REGISTER REPLICA replica_2 ASYNC TO '127.0.0.1:{REPLICATION_PORTS['replica_2']}';",
],
},
}
# 0/
interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION_MANUAL, keep_directories=False)
cursor_main = connection(BOLT_PORTS["main"], "main", "user1").cursor()
cursor_replica1 = connection(BOLT_PORTS["replica_1"], "replica").cursor()
cursor_replica2 = connection(BOLT_PORTS["replica_2"], "replica").cursor()
# 1/
def check(f, expected_data):
# mg_sleep_and_assert(
# REPLICA 1 is SYNC, should already be ready
assert expected_data == f(cursor_replica1)()
# )
mg_sleep_and_assert(expected_data, f(cursor_replica2))
# CREATE USER
execute_and_fetch_all(cursor_main, "CREATE USER user1")
check(
show_users_func,
{
("user1",),
},
)
execute_and_fetch_all(cursor_main, "CREATE USER user2 IDENTIFIED BY 'pass'")
check(
show_users_func,
{
("user2",),
("user1",),
},
)
connection(BOLT_PORTS["replica_1"], "replica", "user1").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user1").cursor() # Just check connection
connection(BOLT_PORTS["replica_1"], "replica", "user2", "pass").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user2", "pass").cursor() # Just check connection
# SET PASSWORD
execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user1 TO '1234'")
execute_and_fetch_all(cursor_main, "SET PASSWORD FOR user2 TO 'new_pass'")
connection(BOLT_PORTS["replica_1"], "replica", "user1", "1234").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user1", "1234").cursor() # Just check connection
connection(BOLT_PORTS["replica_1"], "replica", "user2", "new_pass").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica", "user2", "new_pass").cursor() # Just check connection
# DROP USER
execute_and_fetch_all(cursor_main, "DROP USER user2")
check(
show_users_func,
{
("user1",),
},
)
execute_and_fetch_all(cursor_main, "DROP USER user1")
check(show_users_func, set())
connection(BOLT_PORTS["replica_1"], "replica").cursor() # Just check connection
connection(BOLT_PORTS["replica_2"], "replica").cursor() # Just check connection
# CREATE ROLE
execute_and_fetch_all(cursor_main, "CREATE ROLE role1")
check(
show_roles_func,
{
("role1",),
},
)
execute_and_fetch_all(cursor_main, "CREATE ROLE role2")
check(
show_roles_func,
{
("role2",),
("role1",),
},
)
# DROP ROLE
execute_and_fetch_all(cursor_main, "DROP ROLE role2")
check(
show_roles_func,
{
("role1",),
},
)
execute_and_fetch_all(cursor_main, "DROP ROLE role1")
check(show_roles_func, set())
# SET ROLE
execute_and_fetch_all(cursor_main, "CREATE USER user3")
execute_and_fetch_all(cursor_main, "CREATE ROLE role3")
execute_and_fetch_all(cursor_main, "SET ROLE FOR user3 TO role3")
check(partial(show_role_for_user_func, username="user3"), {("role3",)})
execute_and_fetch_all(cursor_main, "CREATE USER user3b")
execute_and_fetch_all(cursor_main, "SET ROLE FOR user3b TO role3")
check(partial(show_role_for_user_func, username="user3b"), {("role3",)})
check(
partial(show_users_for_role_func, rolename="role3"),
{
("user3",),
("user3b",),
},
)
# CLEAR ROLE
execute_and_fetch_all(cursor_main, "CLEAR ROLE FOR user3")
check(partial(show_role_for_user_func, username="user3"), {("null",)})
check(
partial(show_users_for_role_func, rolename="role3"),
{
("user3b",),
},
)
# GRANT/REVOKE/DENY privileges TO user
execute_and_fetch_all(cursor_main, "CREATE USER user4")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4")
execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{
("CREATE", "GRANT", "GRANTED TO USER"),
("DELETE", "GRANT", "GRANTED TO USER"),
("SET", "GRANT", "GRANTED TO USER"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE SET FROM user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "GRANT", "GRANTED TO USER")},
)
execute_and_fetch_all(cursor_main, "DENY DELETE TO user4")
check(
partial(show_privileges_func, user_or_role="user4"),
{("CREATE", "GRANT", "GRANTED TO USER"), ("DELETE", "DENY", "DENIED TO USER")},
)
# GRANT/REVOKE/DENY privileges TO role
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b")
execute_and_fetch_all(cursor_main, "GRANT CREATE, DELETE, SET TO role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{
("CREATE", "GRANT", "GRANTED TO ROLE"),
("DELETE", "GRANT", "GRANTED TO ROLE"),
("SET", "GRANT", "GRANTED TO ROLE"),
},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{
("CREATE", "GRANT", "GRANTED TO ROLE"),
("DELETE", "GRANT", "GRANTED TO ROLE"),
("SET", "GRANT", "GRANTED TO ROLE"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE SET FROM role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "GRANT", "GRANTED TO ROLE")},
)
execute_and_fetch_all(cursor_main, "DENY DELETE TO role3")
check(
partial(show_privileges_func, user_or_role="role3"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("CREATE", "GRANT", "GRANTED TO ROLE"), ("DELETE", "DENY", "DENIED TO ROLE")},
)
# GRANT permission ON LABEL/EDGE to user/role
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM role3")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user4")
execute_and_fetch_all(cursor_main, "REVOKE ALL PRIVILEGES FROM user3b")
execute_and_fetch_all(cursor_main, "GRANT READ ON LABELS :l1 TO user4")
execute_and_fetch_all(cursor_main, "GRANT UPDATE ON LABELS :l2, :l3 TO role3")
check(
partial(show_privileges_func, user_or_role="user4"),
{
("LABEL :l1", "READ", "LABEL PERMISSION GRANTED TO USER"),
},
)
check(
partial(show_privileges_func, user_or_role="role3"),
{
("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{
("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
("LABEL :l2", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE"),
},
)
execute_and_fetch_all(cursor_main, "REVOKE LABELS :l1 FROM user4")
execute_and_fetch_all(cursor_main, "REVOKE LABELS :l2 FROM role3")
check(partial(show_privileges_func, user_or_role="user4"), set())
check(
partial(show_privileges_func, user_or_role="role3"),
{("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")},
)
check(
partial(show_privileges_func, user_or_role="user3b"),
{("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")},
)
# GRANT/REVOKE DATABASE
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test")
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2")
execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])])
execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])])
# SET MAIN DATABASE
execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4")
execute_and_fetch_all(cursor_main, "SET MAIN DATABASE auth_test FOR user4")
# Reconnect and check current db
assert (
execute_and_fetch_all(connection(BOLT_PORTS["main"], "main", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
assert (
execute_and_fetch_all(connection(BOLT_PORTS["replica_1"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
assert (
execute_and_fetch_all(connection(BOLT_PORTS["replica_2"], "replica", "user4").cursor(), "SHOW DATABASE")[0][0]
== "auth_test"
)
if __name__ == "__main__":
interactive_mg_runner.cleanup_directories_on_exit()
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -18,9 +18,9 @@ def connection():
connection_holder = None
role_holder = None
def inner_connection(port, role):
def inner_connection(port, role, username="", password=""):
nonlocal connection_holder, role_holder
connection_holder = connect(host="localhost", port=port)
connection_holder = connect(host="localhost", port=port, username=username, password=password)
role_holder = role
return connection_holder

View File

@ -2,3 +2,6 @@ workloads:
- name: "Replicate multitenancy"
binary: "tests/e2e/pytest_runner.sh"
args: ["replication_experimental/multitenancy.py"]
- name: "Replicate auth data"
binary: "tests/e2e/pytest_runner.sh"
args: ["replication_experimental/auth.py"]

View File

@ -33,7 +33,7 @@ int main(int argc, char **argv) {
// Memgraph backend
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_telemetry_integration_test"};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth_{
memgraph::auth::SynchedAuth auth_{
data_directory / "auth",
memgraph::auth::Auth::Config{std::string{memgraph::glue::kDefaultUserRoleRegex}, "", true}};
memgraph::glue::AuthQueryHandler auth_handler(&auth_);
@ -43,14 +43,20 @@ int main(int argc, char **argv) {
memgraph::storage::UpdatePaths(db_config, data_directory);
memgraph::replication::ReplicationState repl_state(ReplicationStateRootPath(db_config));
memgraph::dbms::DbmsHandler dbms_handler(db_config
memgraph::system::System system_state;
memgraph::dbms::DbmsHandler dbms_handler(db_config, system_state, repl_state
#ifdef MG_ENTERPRISE
,
&auth_, false
auth_, false
#endif
);
memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, &auth_handler,
&auth_checker);
memgraph::query::InterpreterContext interpreter_context_({}, &dbms_handler, &repl_state, system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
,
&auth_handler, &auth_checker);
memgraph::requests::Init();
memgraph::telemetry::Telemetry telemetry(FLAGS_endpoint, FLAGS_storage_directory, memgraph::utils::GenerateUUID(),
@ -65,9 +71,9 @@ int main(int argc, char **argv) {
});
// Memgraph specific collectors
telemetry.AddStorageCollector(dbms_handler, auth_);
telemetry.AddStorageCollector(dbms_handler, auth_, repl_state);
#ifdef MG_ENTERPRISE
telemetry.AddDatabaseCollector(dbms_handler);
telemetry.AddDatabaseCollector(dbms_handler, repl_state);
#else
telemetry.AddDatabaseCollector();
#endif

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,14 @@ int main(int argc, char *argv[]) {
auto db_acc_opt = db_gk.access();
MG_ASSERT(db_acc_opt, "Failed to access db");
auto &db_acc = *db_acc_opt;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state);
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
memgraph::query::Interpreter interpreter{&interpreter_context, db_acc};
ResultStreamFaker stream(db_acc->storage());

View File

@ -25,7 +25,7 @@
class AuthQueryHandlerFixture : public testing::Test {
protected:
std::filesystem::path test_folder_{std::filesystem::temp_directory_path() / "MG_tests_unit_auth_handler"};
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth{
memgraph::auth::SynchedAuth auth{
test_folder_ / ("unit_auth_handler_test_" + std::to_string(static_cast<int>(getpid()))),
memgraph::auth::Auth::Config{/* default */}};
memgraph::glue::AuthQueryHandler auth_handler{&auth};

View File

@ -10,6 +10,7 @@
// licenses/APL.txt.
#include "query/auth_query_handler.hpp"
#include "replication/state.hpp"
#include "storage/v2/config.hpp"
#ifdef MG_ENTERPRISE
#include <gmock/gmock.h>
@ -43,7 +44,9 @@ std::set<std::string> GetDirs(auto path) {
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler"};
std::filesystem::path db_dir{storage_directory / "databases"};
static memgraph::storage::Config storage_conf;
std::unique_ptr<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>> auth;
std::unique_ptr<memgraph::auth::SynchedAuth> auth;
std::unique_ptr<memgraph::system::System> system_state;
std::unique_ptr<memgraph::replication::ReplicationState> repl_state;
// Let this be global so we can test it different states throughout
@ -64,14 +67,18 @@ class TestEnvironment : public ::testing::Environment {
std::filesystem::remove_all(storage_directory);
}
}
auth =
std::make_unique<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>>(
storage_directory / "auth", memgraph::auth::Auth::Config{/* default */});
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, auth.get(), false);
auth = std::make_unique<memgraph::auth::SynchedAuth>(storage_directory / "auth",
memgraph::auth::Auth::Config{/* default */});
system_state = std::make_unique<memgraph::system::System>();
repl_state = std::make_unique<memgraph::replication::ReplicationState>(ReplicationStateRootPath(storage_conf));
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, *system_state.get(), *repl_state.get(),
*auth.get(), false);
}
void TearDown() override {
ptr_.reset();
repl_state.reset();
system_state.reset();
auth.reset();
std::filesystem::remove_all(storage_directory);
}

View File

@ -28,7 +28,9 @@
// Global
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_handler_community"};
static memgraph::storage::Config storage_conf;
std::unique_ptr<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>> auth;
std::unique_ptr<memgraph::auth::SynchedAuth> auth;
std::unique_ptr<memgraph::system::System> system_state;
std::unique_ptr<memgraph::replication::ReplicationState> repl_state;
// Let this be global so we can test it different states throughout
@ -49,14 +51,17 @@ class TestEnvironment : public ::testing::Environment {
std::filesystem::remove_all(storage_directory);
}
}
auth =
std::make_unique<memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock>>(
storage_directory / "auth", memgraph::auth::Auth::Config{/* default */});
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf);
auth = std::make_unique<memgraph::auth::SynchedAuth>(storage_directory / "auth",
memgraph::auth::Auth::Config{/* default */});
system_state = std::make_unique<memgraph::system::System>();
repl_state = std::make_unique<memgraph::replication::ReplicationState>(ReplicationStateRootPath(storage_conf));
ptr_ = std::make_unique<memgraph::dbms::DbmsHandler>(storage_conf, *system_state.get(), *repl_state.get());
}
void TearDown() override {
ptr_.reset();
repl_state.reset();
system_state.reset();
auth.reset();
std::filesystem::remove_all(storage_directory);
}

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -94,7 +94,16 @@ class InterpreterTest : public ::testing::Test {
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, kNoHandler, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
kNoHandler,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
void TearDown() override {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
@ -1150,8 +1159,16 @@ TYPED_TEST(InterpreterTest, AllowLoadCsvConfig) {
<< "Wrong storage mode!";
memgraph::replication::ReplicationState repl_state{std::nullopt};
memgraph::query::InterpreterContext csv_interpreter_context{
{.query = {.allow_load_csv = allow_load_csv}}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext csv_interpreter_context{{.query = {.allow_load_csv = allow_load_csv}},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker interpreter_faker{&csv_interpreter_context, db_acc};
for (const auto &query : queries) {
if (allow_load_csv) {

View File

@ -14,6 +14,7 @@
#include <filesystem>
#include <thread>
#include "auth/auth.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/result_stream_faker.hpp"
#include "csv/parsing.hpp"
@ -99,8 +100,17 @@ class MultiTenantTest : public ::testing::Test {
struct MinMemgraph {
explicit MinMemgraph(const memgraph::storage::Config &conf)
: auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}},
dbms{conf, &auth, true},
interpreter_context{{}, &dbms, &dbms.ReplicationState()} {
repl_state{ReplicationStateRootPath(conf)},
dbms{conf, system, repl_state, auth, true},
interpreter_context{{},
&dbms,
&repl_state,
system
#ifdef MG_ENTERPRISE
,
nullptr
#endif
} {
memgraph::utils::global_settings.Initialize(conf.durability.storage_directory / "settings");
memgraph::license::RegisterLicenseSettings(memgraph::license::global_license_checker,
memgraph::utils::global_settings);
@ -112,7 +122,9 @@ class MultiTenantTest : public ::testing::Test {
auto NewInterpreter() { return InterpreterFaker{&interpreter_context, dbms.Get()}; }
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth;
memgraph::auth::SynchedAuth auth;
memgraph::system::System system;
memgraph::replication::ReplicationState repl_state;
memgraph::dbms::DbmsHandler dbms;
memgraph::query::InterpreterContext interpreter_context;
};

View File

@ -267,7 +267,7 @@ memgraph::storage::EdgeAccessor CreateEdge(memgraph::storage::Storage::Accessor
}
template <class... TArgs>
void VerifyQueries(const std::vector<std::vector<memgraph::communication::bolt::Value>> &results, TArgs &&... args) {
void VerifyQueries(const std::vector<std::vector<memgraph::communication::bolt::Value>> &results, TArgs &&...args) {
std::vector<std::string> expected{std::forward<TArgs>(args)...};
std::vector<std::string> got;
got.reserve(results.size());
@ -314,8 +314,13 @@ class DumpTest : public ::testing::Test {
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext context{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
void TearDown() override {
if (std::is_same<StorageType, memgraph::storage::DiskStorage>::value) {
@ -722,7 +727,14 @@ TYPED_TEST(DumpTest, CheckStateVertexWithMultipleProperties) {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state);
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
{
ResultStreamFaker stream(this->db->storage());
@ -842,7 +854,14 @@ TYPED_TEST(DumpTest, CheckStateSimpleGraph) {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL))
<< "Wrong storage mode!";
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state);
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context(memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
{
ResultStreamFaker stream(this->db->storage());
memgraph::query::AnyStream query_stream(&stream, memgraph::utils::NewDeleteResource());

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -42,6 +42,7 @@ class QueryExecution : public testing::Test {
std::optional<memgraph::replication::ReplicationState> repl_state;
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk;
std::optional<memgraph::system::System> system_state;
void SetUp() override {
auto config = [&]() {
@ -65,14 +66,20 @@ class QueryExecution : public testing::Test {
: memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL),
"Wrong storage mode!");
db_acc_ = std::move(db_acc);
interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value());
system_state.emplace();
interpreter_context_.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
);
interpreter_.emplace(&*interpreter_context_, *db_acc_);
}
void TearDown() override {
interpreter_ = std::nullopt;
interpreter_context_ = std::nullopt;
system_state.reset();
db_acc_.reset();
db_gk.reset();
repl_state.reset();

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -104,7 +104,14 @@ class StreamsTestFixture : public ::testing::Test {
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
std::optional<StreamsTest> proxyStreams_;

View File

@ -25,22 +25,20 @@
#include "auth/auth.hpp"
#include "dbms/database.hpp"
#include "dbms/dbms_handler.hpp"
#include "dbms/replication_handler.hpp"
#include "query/interpreter_context.hpp"
#include "replication/config.hpp"
#include "replication/state.hpp"
#include "replication_handler/replication_handler.hpp"
#include "storage/v2/indices/label_index_stats.hpp"
#include "storage/v2/storage.hpp"
#include "storage/v2/view.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
using testing::UnorderedElementsAre;
using memgraph::dbms::RegisterReplicaError;
using memgraph::dbms::ReplicationHandler;
using memgraph::dbms::UnregisterReplicaResult;
using memgraph::query::RegisterReplicaError;
using memgraph::query::UnregisterReplicaResult;
using memgraph::replication::ReplicationClientConfig;
using memgraph::replication::ReplicationHandler;
using memgraph::replication::ReplicationServerConfig;
using memgraph::replication_coordination_glue::ReplicationMode;
using memgraph::replication_coordination_glue::ReplicationRole;
@ -114,21 +112,26 @@ class ReplicationTest : public ::testing::Test {
struct MinMemgraph {
MinMemgraph(const memgraph::storage::Config &conf)
: auth{conf.durability.storage_directory / "auth", memgraph::auth::Auth::Config{/* default */}},
dbms{conf
repl_state{ReplicationStateRootPath(conf)},
dbms{conf, system_, repl_state
#ifdef MG_ENTERPRISE
,
&auth, true
auth, true
#endif
},
repl_state{dbms.ReplicationState()},
db_acc{dbms.Get()},
db{*db_acc.get()},
repl_handler(dbms) {
repl_handler(repl_state, dbms
#ifdef MG_ENTERPRISE
,
&system_, auth
#endif
) {
}
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth;
memgraph::auth::SynchedAuth auth;
memgraph::system::System system_;
memgraph::replication::ReplicationState repl_state;
memgraph::dbms::DbmsHandler dbms;
memgraph::replication::ReplicationState &repl_state;
memgraph::dbms::DatabaseAccess db_acc;
memgraph::dbms::Database &db;
ReplicationHandler repl_handler;
@ -144,7 +147,7 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) {
.port = ports[0],
});
const auto &reg = main.repl_handler.RegisterReplica(ReplicationClientConfig{
const auto &reg = main.repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA",
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -442,7 +445,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -450,7 +453,7 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) {
})
.HasError());
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -587,7 +590,7 @@ TEST_F(ReplicationTest, RecoveryProcess) {
.port = ports[0],
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -663,7 +666,7 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) {
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA_ASYNC",
.mode = ReplicationMode::ASYNC,
.ip_address = local_host,
@ -715,7 +718,7 @@ TEST_F(ReplicationTest, EpochTest) {
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -724,7 +727,7 @@ TEST_F(ReplicationTest, EpochTest) {
.HasError());
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -758,7 +761,7 @@ TEST_F(ReplicationTest, EpochTest) {
ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain());
ASSERT_FALSE(replica1.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -791,7 +794,7 @@ TEST_F(ReplicationTest, EpochTest) {
.port = ports[0],
});
ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -834,7 +837,7 @@ TEST_F(ReplicationTest, ReplicationInformation) {
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -844,7 +847,7 @@ TEST_F(ReplicationTest, ReplicationInformation) {
.HasError());
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::ASYNC,
.ip_address = local_host,
@ -890,7 +893,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) {
.port = replica2_port,
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -899,7 +902,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) {
.HasError());
ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::ASYNC,
.ip_address = local_host,
@ -925,7 +928,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) {
});
ASSERT_FALSE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -934,7 +937,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) {
.HasError());
ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::ASYNC,
.ip_address = local_host,
@ -973,14 +976,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) {
.port = ports[1],
});
auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{
auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
.port = ports[0],
});
ASSERT_FALSE(res.HasError()) << (int)res.GetError();
res = main->repl_handler.RegisterReplica(ReplicationClientConfig{
res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -1030,14 +1033,14 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) {
.ip_address = local_host,
.port = ports[1],
});
auto res = main->repl_handler.RegisterReplica(ReplicationClientConfig{
auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[0],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
.port = ports[0],
});
ASSERT_FALSE(res.HasError());
res = main->repl_handler.RegisterReplica(ReplicationClientConfig{
res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{
.name = replicas[1],
.mode = ReplicationMode::SYNC,
.ip_address = local_host,
@ -1080,7 +1083,7 @@ TEST_F(ReplicationTest, AddingInvalidReplica) {
MinMemgraph main(main_conf);
ASSERT_TRUE(main.repl_handler
.RegisterReplica(ReplicationClientConfig{
.TryRegisterReplica(ReplicationClientConfig{
.name = "REPLICA",
.mode = ReplicationMode::SYNC,
.ip_address = local_host,

View File

@ -90,7 +90,16 @@ class StorageModeMultiTxTest : public ::testing::Test {
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
};

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -59,7 +59,16 @@ class TransactionQueueSimpleTest : public ::testing::Test {
return db_acc;
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker running_interpreter{&interpreter_context, db}, main_interpreter{&interpreter_context, db};
void TearDown() override {

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -68,7 +68,16 @@ class TransactionQueueMultipleTest : public ::testing::Test {
}() // iile
};
memgraph::query::InterpreterContext interpreter_context{{}, nullptr, &repl_state};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context{{},
nullptr,
&repl_state,
system_state
#ifdef MG_ENTERPRISE
,
nullptr
#endif
};
InterpreterFaker main_interpreter{&interpreter_context, db};
std::vector<InterpreterFaker *> running_interpreters;