Better auth user/role handling (#1699)

* Stop auth module from creating users
* Explicit about auth policy (check if no users defined OR auth module used)
* Role supports database access definition
* Authenticate() returns user or role
* AuthChecker generates QueryUserOrRole (can be empty)
* QueryUserOrRole actually authorizes
* Add auth cache invalidation
* Better database access queries (GRANT, DENY, REVOKE DATABASE)
This commit is contained in:
andrejtonev 2024-02-22 15:00:39 +01:00 committed by GitHub
parent 98727e0fa0
commit 6a4ef55e90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 1870 additions and 880 deletions

View File

@ -35,16 +35,42 @@ DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth m
}
return true;
});
DEFINE_bool(auth_module_create_missing_user, true, "Set to false to disable creation of missing users.");
DEFINE_bool(auth_module_create_missing_role, true, "Set to false to disable creation of missing roles.");
DEFINE_bool(auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.");
DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000,
"Timeout (in milliseconds) used when waiting for a "
"response from the auth module.",
FLAG_IN_RANGE(100, 1800000));
// DEPRECATED FLAGS
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters)
DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_user, true,
"Set to false to disable creation of missing users.", {
spdlog::warn(
"auth_module_create_missing_user flag is deprecated. It not possible to create "
"users through the module anymore.");
return true;
});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters)
DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_role, true,
"Set to false to disable creation of missing roles.", {
spdlog::warn(
"auth_module_create_missing_role flag is deprecated. It not possible to create "
"roles through the module anymore.");
return true;
});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters)
DEFINE_VALIDATED_HIDDEN_bool(
auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.", {
spdlog::warn(
"auth_module_manage_roles flag is deprecated. It not possible to create roles through the module anymore.");
return true;
});
namespace memgraph::auth {
const Auth::Epoch Auth::kStartEpoch = 1;
namespace {
#ifdef MG_ENTERPRISE
/**
@ -192,6 +218,17 @@ void MigrateVersions(kvstore::KVStore &store) {
version_str = kVersionV1;
}
}
auto ParseJson(std::string_view str) {
nlohmann::json data;
try {
data = nlohmann::json::parse(str);
} catch (const nlohmann::json::parse_error &e) {
throw AuthException("Couldn't load auth data!");
}
return data;
}
}; // namespace
Auth::Auth(std::string storage_directory, Config config)
@ -199,8 +236,11 @@ Auth::Auth(std::string storage_directory, Config config)
MigrateVersions(storage_);
}
std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) {
std::optional<UserOrRole> Auth::Authenticate(const std::string &username, const std::string &password) {
if (module_.IsUsed()) {
/*
* MODULE AUTH STORAGE
*/
const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings);
if (license_check_result.HasError()) {
spdlog::warn(license::LicenseCheckErrorToString(license_check_result.GetError(), "authentication modules"));
@ -225,108 +265,64 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::s
auto is_authenticated = ret_authenticated.get<bool>();
const auto &rolename = ret_role.get<std::string>();
// Check if role is present
auto role = GetRole(rolename);
if (!role) {
spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the role '{}' doesn't exist.",
username, rolename, "https://memgr.ph/auth"));
return std::nullopt;
}
// Authenticate the user.
if (!is_authenticated) return std::nullopt;
/**
* TODO
* The auth module should not update auth data.
* There is now way to replicate it and we should not be storing sensitive data if we don't have to.
*/
// Find or create the user and return it.
auto user = GetUser(username);
if (!user) {
if (FLAGS_auth_module_create_missing_user) {
user = AddUser(username, password);
if (!user) {
spdlog::warn(utils::MessageWithLink(
"Couldn't create the missing user '{}' using the auth module because the user already exists as a role.",
username, "https://memgr.ph/auth"));
return std::nullopt;
}
} else {
spdlog::warn(utils::MessageWithLink(
"Couldn't authenticate user '{}' using the auth module because the user doesn't exist.", username,
"https://memgr.ph/auth"));
return std::nullopt;
}
} else {
UpdatePassword(*user, password);
}
if (FLAGS_auth_module_manage_roles) {
if (!rolename.empty()) {
auto role = GetRole(rolename);
if (!role) {
if (FLAGS_auth_module_create_missing_role) {
role = AddRole(rolename);
if (!role) {
spdlog::warn(
utils::MessageWithLink("Couldn't authenticate user '{}' using the auth module because the user's "
"role '{}' already exists as a user.",
username, rolename, "https://memgr.ph/auth"));
return std::nullopt;
}
SaveRole(*role);
} else {
spdlog::warn(utils::MessageWithLink(
"Couldn't authenticate user '{}' using the auth module because the user's role '{}' doesn't exist.",
username, rolename, "https://memgr.ph/auth"));
return std::nullopt;
}
}
user->SetRole(*role);
} else {
user->ClearRole();
}
}
SaveUser(*user);
return user;
} else {
auto user = GetUser(username);
if (!user) {
spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username,
"https://memgr.ph/auth"));
return std::nullopt;
}
if (!user->CheckPassword(password)) {
spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.",
username, "https://memgr.ph/auth"));
return std::nullopt;
}
if (user->UpgradeHash(password)) {
SaveUser(*user);
}
return user;
return RoleWUsername{username, std::move(*role)};
}
/*
* LOCAL AUTH STORAGE
*/
auto user = GetUser(username);
if (!user) {
spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username,
"https://memgr.ph/auth"));
return std::nullopt;
}
if (!user->CheckPassword(password)) {
spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.",
username, "https://memgr.ph/auth"));
return std::nullopt;
}
if (user->UpgradeHash(password)) {
SaveUser(*user);
}
return user;
}
std::optional<User> Auth::GetUser(const std::string &username_orig) const {
auto username = utils::ToLowerCase(username_orig);
auto existing_user = storage_.Get(kUserPrefix + username);
if (!existing_user) return std::nullopt;
nlohmann::json data;
try {
data = nlohmann::json::parse(*existing_user);
} catch (const nlohmann::json::parse_error &e) {
throw AuthException("Couldn't load user data!");
}
auto user = User::Deserialize(data);
auto link = storage_.Get(kLinkPrefix + username);
void Auth::LinkUser(User &user) const {
auto link = storage_.Get(kLinkPrefix + user.username());
if (link) {
auto role = GetRole(*link);
if (role) {
user.SetRole(*role);
}
}
}
std::optional<User> Auth::GetUser(const std::string &username_orig) const {
if (module_.IsUsed()) return std::nullopt; // User's are not supported when using module
auto username = utils::ToLowerCase(username_orig);
auto existing_user = storage_.Get(kUserPrefix + username);
if (!existing_user) return std::nullopt;
auto user = User::Deserialize(ParseJson(*existing_user));
LinkUser(user);
return user;
}
void Auth::SaveUser(const User &user, system::Transaction *system_tx) {
DisableIfModuleUsed();
bool success = false;
if (const auto *role = user.role(); role != nullptr) {
success = storage_.PutMultiple(
@ -338,6 +334,10 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) {
if (!success) {
throw AuthException("Couldn't save user '{}'!", user.username());
}
// Durability updated -> new epoch
UpdateEpoch();
// All changes to the user end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
@ -347,6 +347,7 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) {
}
void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &password) {
DisableIfModuleUsed();
// Check if null
if (!password) {
if (!config_.password_permit_null) {
@ -378,6 +379,7 @@ void Auth::UpdatePassword(auth::User &user, const std::optional<std::string> &pa
std::optional<User> Auth::AddUser(const std::string &username, const std::optional<std::string> &password,
system::Transaction *system_tx) {
DisableIfModuleUsed();
if (!NameRegexMatch(username)) {
throw AuthException("Invalid user name.");
}
@ -392,12 +394,17 @@ std::optional<User> Auth::AddUser(const std::string &username, const std::option
}
bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) {
DisableIfModuleUsed();
auto username = utils::ToLowerCase(username_orig);
if (!storage_.Get(kUserPrefix + username)) return false;
std::vector<std::string> keys({kLinkPrefix + username, kUserPrefix + username});
if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove user '{}'!", username);
}
// Durability updated -> new epoch
UpdateEpoch();
// Handling drop user delta
if (system_tx) {
#ifdef MG_ENTERPRISE
@ -412,9 +419,12 @@ std::vector<auth::User> Auth::AllUsers() const {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
auto user = GetUser(username);
if (user) {
ret.push_back(std::move(*user));
try {
User user = auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure
LinkUser(user);
ret.emplace_back(std::move(user));
} catch (AuthException &) {
continue;
}
}
return ret;
@ -425,9 +435,12 @@ std::vector<std::string> Auth::AllUsernames() const {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (username != utils::ToLowerCase(username)) continue;
auto user = GetUser(username);
if (user) {
ret.push_back(username);
try {
// Check if serialized correctly
memgraph::auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure
ret.emplace_back(std::move(username));
} catch (AuthException &) {
continue;
}
}
return ret;
@ -435,25 +448,24 @@ std::vector<std::string> Auth::AllUsernames() const {
bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
bool Auth::AccessControlled() const { return HasUsers() || module_.IsUsed(); }
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
auto rolename = utils::ToLowerCase(rolename_orig);
auto existing_role = storage_.Get(kRolePrefix + rolename);
if (!existing_role) return std::nullopt;
nlohmann::json data;
try {
data = nlohmann::json::parse(*existing_role);
} catch (const nlohmann::json::parse_error &e) {
throw AuthException("Couldn't load role data!");
}
return Role::Deserialize(data);
return Role::Deserialize(ParseJson(*existing_role));
}
void Auth::SaveRole(const Role &role, system::Transaction *system_tx) {
if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) {
throw AuthException("Couldn't save role '{}'!", role.rolename());
}
// Durability updated -> new epoch
UpdateEpoch();
// All changes to the role end up calling this function, so no need to add a delta anywhere else
if (system_tx) {
#ifdef MG_ENTERPRISE
@ -486,6 +498,10 @@ bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *sys
if (!storage_.DeleteMultiple(keys)) {
throw AuthException("Couldn't remove role '{}'!", rolename);
}
// Durability updated -> new epoch
UpdateEpoch();
// Handling drop role delta
if (system_tx) {
#ifdef MG_ENTERPRISE
@ -500,11 +516,8 @@ std::vector<auth::Role> Auth::AllRoles() const {
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
if (rolename != utils::ToLowerCase(rolename)) continue;
if (auto role = GetRole(rolename)) {
ret.push_back(*role);
} else {
throw AuthException("Couldn't load role '{}'!", rolename);
}
Role role = memgraph::auth::Role::Deserialize(ParseJson(it->second)); // Will throw on failure
ret.emplace_back(std::move(role));
}
return ret;
}
@ -514,14 +527,19 @@ std::vector<std::string> Auth::AllRolenames() const {
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
if (rolename != utils::ToLowerCase(rolename)) continue;
if (auto role = GetRole(rolename)) {
ret.push_back(rolename);
try {
// Check that the data is serialized correctly
memgraph::auth::Role::Deserialize(ParseJson(it->second));
ret.emplace_back(std::move(rolename));
} catch (AuthException &) {
continue;
}
}
return ret;
}
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const {
DisableIfModuleUsed();
const auto rolename = utils::ToLowerCase(rolename_orig);
std::vector<auth::User> ret;
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
@ -540,51 +558,176 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
}
#ifdef MG_ENTERPRISE
bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (db == kAllDatabases) {
user->db_access().GrantAll();
} else {
user->db_access().Add(db);
Auth::Result Auth::GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) {
using enum Auth::Result;
if (module_.IsUsed()) {
if (auto role = GetRole(name)) {
GrantDatabase(db, *role, system_tx);
return SUCCESS;
}
SaveUser(*user, system_tx);
return true;
return NO_ROLE;
}
return false;
if (auto user = GetUser(name)) {
GrantDatabase(db, *user, system_tx);
return SUCCESS;
}
if (auto role = GetRole(name)) {
GrantDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_USER_ROLE;
}
bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (db == kAllDatabases) {
user->db_access().DenyAll();
} else {
user->db_access().Remove(db);
}
SaveUser(*user, system_tx);
return true;
void Auth::GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx) {
if (db == kAllDatabases) {
user.db_access().GrantAll();
} else {
user.db_access().Grant(db);
}
return false;
SaveUser(user, system_tx);
}
void Auth::GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx) {
if (db == kAllDatabases) {
role.db_access().GrantAll();
} else {
role.db_access().Grant(db);
}
SaveRole(role, system_tx);
}
Auth::Result Auth::DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) {
using enum Auth::Result;
if (module_.IsUsed()) {
if (auto role = GetRole(name)) {
DenyDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_ROLE;
}
if (auto user = GetUser(name)) {
DenyDatabase(db, *user, system_tx);
return SUCCESS;
}
if (auto role = GetRole(name)) {
DenyDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_USER_ROLE;
}
void Auth::DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx) {
if (db == kAllDatabases) {
user.db_access().DenyAll();
} else {
user.db_access().Deny(db);
}
SaveUser(user, system_tx);
}
void Auth::DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx) {
if (db == kAllDatabases) {
role.db_access().DenyAll();
} else {
role.db_access().Deny(db);
}
SaveRole(role, system_tx);
}
Auth::Result Auth::RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) {
using enum Auth::Result;
if (module_.IsUsed()) {
if (auto role = GetRole(name)) {
RevokeDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_ROLE;
}
if (auto user = GetUser(name)) {
RevokeDatabase(db, *user, system_tx);
return SUCCESS;
}
if (auto role = GetRole(name)) {
RevokeDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_USER_ROLE;
}
void Auth::RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx) {
if (db == kAllDatabases) {
user.db_access().RevokeAll();
} else {
user.db_access().Revoke(db);
}
SaveUser(user, system_tx);
}
void Auth::RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx) {
if (db == kAllDatabases) {
role.db_access().RevokeAll();
} else {
role.db_access().Revoke(db);
}
SaveRole(role, system_tx);
}
void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) {
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
if (auto user = GetUser(username)) {
user->db_access().Delete(db);
SaveUser(*user, system_tx);
try {
User user = auth::User::Deserialize(ParseJson(it->second));
LinkUser(user);
user.db_access().Revoke(db);
SaveUser(user, system_tx);
} catch (AuthException &) {
continue;
}
}
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
try {
auto role = memgraph::auth::Role::Deserialize(ParseJson(it->second));
role.db_access().Revoke(db);
SaveRole(role, system_tx);
} catch (AuthException &) {
continue;
}
}
}
bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) {
if (auto user = GetUser(name)) {
if (!user->db_access().SetDefault(db)) {
throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name);
Auth::Result Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) {
using enum Auth::Result;
if (module_.IsUsed()) {
if (auto role = GetRole(name)) {
SetMainDatabase(db, *role, system_tx);
return SUCCESS;
}
SaveUser(*user, system_tx);
return true;
return NO_ROLE;
}
return false;
if (auto user = GetUser(name)) {
SetMainDatabase(db, *user, system_tx);
return SUCCESS;
}
if (auto role = GetRole(name)) {
SetMainDatabase(db, *role, system_tx);
return SUCCESS;
}
return NO_USER_ROLE;
}
void Auth::SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx) {
if (!user.db_access().SetMain(db)) {
throw AuthException("Couldn't set default database '{}' for '{}'!", db, user.username());
}
SaveUser(user, system_tx);
}
void Auth::SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx) {
if (!role.db_access().SetMain(db)) {
throw AuthException("Couldn't set default database '{}' for '{}'!", db, role.rolename());
}
SaveRole(role, system_tx);
}
#endif

View File

@ -29,6 +29,18 @@ using SynchedAuth = memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph
static const constexpr char *const kAllDatabases = "*";
struct RoleWUsername : Role {
template <typename... Args>
RoleWUsername(std::string_view username, Args &&...args) : Role{std::forward<Args>(args)...}, username_{username} {}
std::string username() { return username_; }
const std::string &username() const { return username_; }
private:
std::string username_;
};
using UserOrRole = std::variant<User, RoleWUsername>;
/**
* This class serves as the main Authentication/Authorization storage.
* It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions.
@ -61,6 +73,25 @@ class Auth final {
std::regex password_regex{password_regex_str};
};
struct Epoch {
Epoch() : epoch_{0} {}
Epoch(unsigned e) : epoch_{e} {}
Epoch operator++() { return ++epoch_; }
bool operator==(const Epoch &rhs) const = default;
private:
unsigned epoch_;
};
static const Epoch kStartEpoch;
enum class Result {
SUCCESS,
NO_USER_ROLE,
NO_ROLE,
};
explicit Auth(std::string storage_directory, Config config);
/**
@ -89,7 +120,7 @@ class Auth final {
* @return a user when the username and password match, nullopt otherwise
* @throw AuthException if unable to authenticate for whatever reason.
*/
std::optional<User> Authenticate(const std::string &username, const std::string &password);
std::optional<UserOrRole> Authenticate(const std::string &username, const std::string &password);
/**
* Gets a user from the storage.
@ -101,6 +132,8 @@ class Auth final {
*/
std::optional<User> GetUser(const std::string &username) const;
void LinkUser(User &user) const;
/**
* Saves a user object to the storage.
*
@ -163,6 +196,13 @@ class Auth final {
*/
bool HasUsers() const;
/**
* Returns whether the access is controlled by authentication/authorization.
*
* @return `true` if auth needs to run
*/
bool AccessControlled() const;
/**
* Gets a role from the storage.
*
@ -173,6 +213,37 @@ class Auth final {
*/
std::optional<Role> GetRole(const std::string &rolename) const;
std::optional<UserOrRole> GetUserOrRole(const std::optional<std::string> &username,
const std::optional<std::string> &rolename) const {
auto expect = [](bool condition, std::string &&msg) {
if (!condition) throw AuthException(std::move(msg));
};
// Special case if we are using a module; we must find the specified role
if (module_.IsUsed()) {
expect(username && rolename, "When using a module, a role needs to be connected to a username.");
const auto role = GetRole(*rolename);
expect(role != std::nullopt, "No role named " + *rolename);
return UserOrRole(auth::RoleWUsername{*username, *role});
}
// First check if we need to find a role
if (username && rolename) {
const auto role = GetRole(*rolename);
expect(role != std::nullopt, "No role named " + *rolename);
return UserOrRole(auth::RoleWUsername{*username, *role});
}
// We are only looking for a user
if (username) {
const auto user = GetUser(*username);
expect(user != std::nullopt, "No user named " + *username);
return *user;
}
// No user or role
return std::nullopt;
}
/**
* Saves a role object to the storage.
*
@ -229,16 +300,6 @@ class Auth final {
std::vector<User> AllUsersForRole(const std::string &rolename) const;
#ifdef MG_ENTERPRISE
/**
* @brief Revoke access to individual database for a user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
/**
* @brief Grant access to individual database for a user.
*
@ -247,7 +308,33 @@ class Auth final {
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
Result GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
void GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr);
void GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr);
/**
* @brief Revoke access to individual database for a user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
Result DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
void DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr);
void DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr);
/**
* @brief Revoke access to individual database for a user.
*
* @param db name of the database to revoke
* @param name user's username
* @return true on success
* @throw AuthException if unable to find or update the user
*/
Result RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr);
void RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr);
void RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr);
/**
* @brief Delete a database from all users.
@ -265,9 +352,17 @@ class Auth final {
* @return true on success
* @throw AuthException if unable to find or update the user
*/
bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr);
Result SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr);
void SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx = nullptr);
void SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx = nullptr);
#endif
bool UpToDate(Epoch &e) const {
bool res = e == epoch_;
e = epoch_;
return res;
}
private:
/**
* @brief
@ -278,11 +373,18 @@ class Auth final {
*/
bool NameRegexMatch(const std::string &user_or_role) const;
void UpdateEpoch() { ++epoch_; }
void DisableIfModuleUsed() const {
if (module_.IsUsed()) throw AuthException("Operation not permited when using an authentication module.");
}
// Even though the `kvstore::KVStore` class is guaranteed to be thread-safe,
// Auth is not thread-safe because modifying users and roles might require
// more than one operation on the storage.
kvstore::KVStore storage_;
auth::Module module_;
Config config_;
Epoch epoch_{kStartEpoch};
};
} // namespace memgraph::auth

View File

@ -425,10 +425,11 @@ Role::Role(const std::string &rolename, const Permissions &permissions)
: rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {}
#ifdef MG_ENTERPRISE
Role::Role(const std::string &rolename, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler)
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access)
: rolename_(utils::ToLowerCase(rolename)),
permissions_(permissions),
fine_grained_access_handler_(std::move(fine_grained_access_handler)) {}
fine_grained_access_handler_(std::move(fine_grained_access_handler)),
db_access_(std::move(db_access)) {}
#endif
const std::string &Role::rolename() const { return rolename_; }
@ -454,8 +455,10 @@ nlohmann::json Role::Serialize() const {
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
data[kFineGrainedAccessHandler] = fine_grained_access_handler_.Serialize();
data[kDatabases] = db_access_.Serialize();
} else {
data[kFineGrainedAccessHandler] = {};
data[kDatabases] = {};
}
#endif
return data;
@ -471,12 +474,21 @@ Role Role::Deserialize(const nlohmann::json &data) {
auto permissions = Permissions::Deserialize(data[kPermissions]);
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
Databases db_access;
if (data[kDatabases].is_structured()) {
db_access = Databases::Deserialize(data[kDatabases]);
} else {
// Back-compatibility
spdlog::warn("Role without specified database access. Given access to the default database.");
db_access.Grant(dbms::kDefaultDB);
db_access.SetMain(dbms::kDefaultDB);
}
FineGrainedAccessHandler fine_grained_access_handler;
// We can have an empty fine_grained if the user was created without a valid license
if (data[kFineGrainedAccessHandler].is_object()) {
fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]);
}
return {data[kRoleName], permissions, std::move(fine_grained_access_handler)};
return {data[kRoleName], permissions, std::move(fine_grained_access_handler), std::move(db_access)};
}
#endif
return {data[kRoleName], permissions};
@ -493,7 +505,7 @@ bool operator==(const Role &first, const Role &second) {
}
#ifdef MG_ENTERPRISE
void Databases::Add(std::string_view db) {
void Databases::Grant(std::string_view db) {
if (allow_all_) {
grants_dbs_.clear();
allow_all_ = false;
@ -502,19 +514,19 @@ void Databases::Add(std::string_view db) {
denies_dbs_.erase(std::string{db}); // TODO: C++23 use transparent key compare
}
void Databases::Remove(const std::string &db) {
void Databases::Deny(const std::string &db) {
denies_dbs_.emplace(db);
grants_dbs_.erase(db);
}
void Databases::Delete(const std::string &db) {
void Databases::Revoke(const std::string &db) {
denies_dbs_.erase(db);
if (!allow_all_) {
grants_dbs_.erase(db);
}
// Reset if default deleted
if (default_db_ == db) {
default_db_ = "";
if (main_db_ == db) {
main_db_ = "";
}
}
@ -530,9 +542,16 @@ void Databases::DenyAll() {
denies_dbs_.clear();
}
bool Databases::SetDefault(std::string_view db) {
void Databases::RevokeAll() {
allow_all_ = false;
grants_dbs_.clear();
denies_dbs_.clear();
main_db_ = "";
}
bool Databases::SetMain(std::string_view db) {
if (!Contains(db)) return false;
default_db_ = db;
main_db_ = db;
return true;
}
@ -540,11 +559,11 @@ bool Databases::SetDefault(std::string_view db) {
return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db));
}
const std::string &Databases::GetDefault() const {
if (!Contains(default_db_)) {
throw AuthException("No access to the set default database \"{}\".", default_db_);
const std::string &Databases::GetMain() const {
if (!Contains(main_db_)) {
throw AuthException("No access to the set default database \"{}\".", main_db_);
}
return default_db_;
return main_db_;
}
nlohmann::json Databases::Serialize() const {
@ -552,7 +571,7 @@ nlohmann::json Databases::Serialize() const {
data[kGrants] = grants_dbs_;
data[kDenies] = denies_dbs_;
data[kAllowAll] = allow_all_;
data[kDefault] = default_db_;
data[kDefault] = main_db_;
return data;
}
@ -719,15 +738,16 @@ User User::Deserialize(const nlohmann::json &data) {
} else {
// Back-compatibility
spdlog::warn("User without specified database access. Given access to the default database.");
db_access.Add(dbms::kDefaultDB);
db_access.SetDefault(dbms::kDefaultDB);
db_access.Grant(dbms::kDefaultDB);
db_access.SetMain(dbms::kDefaultDB);
}
FineGrainedAccessHandler fine_grained_access_handler;
// We can have an empty fine_grained if the user was created without a valid license
if (data[kFineGrainedAccessHandler].is_object()) {
fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]);
}
return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler), db_access};
return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler),
std::move(db_access)};
}
#endif
return {data[kUsername], std::move(password_hash), permissions};

View File

@ -205,52 +205,10 @@ class FineGrainedAccessHandler final {
bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHandler &second);
#endif
class Role final {
public:
Role() = default;
explicit Role(const std::string &rolename);
Role(const std::string &rolename, const Permissions &permissions);
#ifdef MG_ENTERPRISE
Role(const std::string &rolename, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler);
#endif
Role(const Role &) = default;
Role &operator=(const Role &) = default;
Role(Role &&) noexcept = default;
Role &operator=(Role &&) noexcept = default;
~Role() = default;
const std::string &rolename() const;
const Permissions &permissions() const;
Permissions &permissions();
#ifdef MG_ENTERPRISE
const FineGrainedAccessHandler &fine_grained_access_handler() const;
FineGrainedAccessHandler &fine_grained_access_handler();
const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const;
const FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const;
#endif
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
static Role Deserialize(const nlohmann::json &data);
friend bool operator==(const Role &first, const Role &second);
private:
std::string rolename_;
Permissions permissions_;
#ifdef MG_ENTERPRISE
FineGrainedAccessHandler fine_grained_access_handler_;
#endif
};
bool operator==(const Role &first, const Role &second);
#ifdef MG_ENTERPRISE
class Databases final {
public:
Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), default_db_(dbms::kDefaultDB) {}
Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), main_db_(dbms::kDefaultDB) {}
Databases(const Databases &) = default;
Databases &operator=(const Databases &) = default;
@ -263,7 +221,7 @@ class Databases final {
*
* @param db name of the database to grant access to
*/
void Add(std::string_view db);
void Grant(std::string_view db);
/**
* @brief Remove database to the list of granted access.
@ -272,7 +230,7 @@ class Databases final {
*
* @param db name of the database to grant access to
*/
void Remove(const std::string &db);
void Deny(const std::string &db);
/**
* @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set.
@ -280,7 +238,7 @@ class Databases final {
*
* @param db name of the database to grant access to
*/
void Delete(const std::string &db);
void Revoke(const std::string &db);
/**
* @brief Set allow_all_ to true and clears grants and denied sets.
@ -292,10 +250,15 @@ class Databases final {
*/
void DenyAll();
/**
* @brief Set allow_all_ to false and clears grants and denied sets.
*/
void RevokeAll();
/**
* @brief Set the default database.
*/
bool SetDefault(std::string_view db);
bool SetMain(std::string_view db);
/**
* @brief Checks if access is grated to the database.
@ -304,11 +267,13 @@ class Databases final {
* @return true if allow_all and not denied or granted
*/
bool Contains(std::string_view db) const;
bool Denies(std::string_view db_name) const { return denies_dbs_.contains(db_name); }
bool Grants(std::string_view db_name) const { return allow_all_ || grants_dbs_.contains(db_name); }
bool GetAllowAll() const { return allow_all_; }
const std::set<std::string, std::less<>> &GetGrants() const { return grants_dbs_; }
const std::set<std::string, std::less<>> &GetDenies() const { return denies_dbs_; }
const std::string &GetDefault() const;
const std::string &GetMain() const;
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
@ -320,15 +285,69 @@ class Databases final {
: grants_dbs_(std::move(grant)),
denies_dbs_(std::move(deny)),
allow_all_(allow_all),
default_db_(std::move(default_db)) {}
main_db_(std::move(default_db)) {}
std::set<std::string, std::less<>> grants_dbs_; //!< set of databases with granted access
std::set<std::string, std::less<>> denies_dbs_; //!< set of databases with denied access
bool allow_all_; //!< flag to allow access to everything (denied overrides this)
std::string default_db_; //!< user's default database
std::string main_db_; //!< user's default database
};
#endif
class Role {
public:
Role() = default;
explicit Role(const std::string &rolename);
Role(const std::string &rolename, const Permissions &permissions);
#ifdef MG_ENTERPRISE
Role(const std::string &rolename, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {});
#endif
Role(const Role &) = default;
Role &operator=(const Role &) = default;
Role(Role &&) noexcept = default;
Role &operator=(Role &&) noexcept = default;
~Role() = default;
const std::string &rolename() const;
const Permissions &permissions() const;
Permissions &permissions();
Permissions GetPermissions() const { return permissions_; }
#ifdef MG_ENTERPRISE
const FineGrainedAccessHandler &fine_grained_access_handler() const;
FineGrainedAccessHandler &fine_grained_access_handler();
const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const;
const FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const;
#endif
#ifdef MG_ENTERPRISE
Databases &db_access() { return db_access_; }
const Databases &db_access() const { return db_access_; }
bool DeniesDB(std::string_view db_name) const { return db_access_.Denies(db_name); }
bool GrantsDB(std::string_view db_name) const { return db_access_.Grants(db_name); }
bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); }
#endif
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
static Role Deserialize(const nlohmann::json &data);
friend bool operator==(const Role &first, const Role &second);
private:
std::string rolename_;
Permissions permissions_;
#ifdef MG_ENTERPRISE
FineGrainedAccessHandler fine_grained_access_handler_;
Databases db_access_;
#endif
};
bool operator==(const Role &first, const Role &second);
// TODO (mferencevic): Implement password expiry.
class User final {
public:
@ -388,6 +407,18 @@ class User final {
#ifdef MG_ENTERPRISE
Databases &db_access() { return database_access_; }
const Databases &db_access() const { return database_access_; }
bool DeniesDB(std::string_view db_name) const {
bool denies = database_access_.Denies(db_name);
if (role_) denies |= role_->DeniesDB(db_name);
return denies;
}
bool GrantsDB(std::string_view db_name) const {
bool grants = database_access_.Grants(db_name);
if (role_) grants |= role_->GrantsDB(db_name);
return grants;
}
bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); }
#endif
nlohmann::json Serialize() const;
@ -403,7 +434,7 @@ class User final {
Permissions permissions_;
#ifdef MG_ENTERPRISE
FineGrainedAccessHandler fine_grained_access_handler_;
Databases database_access_;
Databases database_access_{};
#endif
std::optional<Role> role_;
};

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -403,7 +403,7 @@ nlohmann::json Module::Call(const nlohmann::json &params, int timeout_millisec)
return ret;
}
bool Module::IsUsed() { return !module_executable_path_.empty(); }
bool Module::IsUsed() const { return !module_executable_path_.empty(); }
void Module::Shutdown() {
if (pid_ == -1) return;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -49,7 +49,7 @@ class Module final {
/// specified executable path and can thus be used.
///
/// @return boolean indicating whether the module can be used
bool IsUsed();
bool IsUsed() const;
~Module();

View File

@ -18,11 +18,9 @@
#include "utils/enum.hpp"
namespace memgraph::slk {
// Serialize code for auth::Role
void Save(const auth::Role &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.Serialize().dump(), builder);
}
void Save(const auth::Role &self, Builder *builder) { memgraph::slk::Save(self.Serialize().dump(), builder); }
namespace {
auth::Role LoadAuthRole(memgraph::slk::Reader *reader) {
std::string tmp;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,19 +12,44 @@
#include "communication/websocket/auth.hpp"
#include <string>
#include "utils/variant_helpers.hpp"
namespace memgraph::communication::websocket {
bool SafeAuth::Authenticate(const std::string &username, const std::string &password) const {
return auth_->Lock()->Authenticate(username, password).has_value();
user_or_role_ = auth_->Lock()->Authenticate(username, password);
return user_or_role_.has_value();
}
bool SafeAuth::HasUserPermission(const std::string &username, const auth::Permission permission) const {
if (const auto user = auth_->ReadLock()->GetUser(username); user) {
return user->GetPermissions().Has(permission) == auth::PermissionLevel::GRANT;
bool SafeAuth::HasPermission(const auth::Permission permission) const {
auto locked_auth = auth_->ReadLock();
// Update if cache invalidated
if (!locked_auth->UpToDate(auth_epoch_) && user_or_role_) {
bool success = true;
std::visit(utils::Overloaded{[&](auth::User &user) {
auto tmp = locked_auth->GetUser(user.username());
if (!tmp) success = false;
user = std::move(*tmp);
},
[&](auth::Role &role) {
auto tmp = locked_auth->GetRole(role.rolename());
if (!tmp) success = false;
role = std::move(*tmp);
}},
*user_or_role_);
// Missing user/role; delete from cache
if (!success) user_or_role_.reset();
}
// Check permissions
if (user_or_role_) {
return std::visit(utils::Overloaded{[&](auto &user_or_role) {
return user_or_role.GetPermissions().Has(permission) == auth::PermissionLevel::GRANT;
}},
*user_or_role_);
}
// NOTE: websocket authenticates only if there is a user, so no need to check if access controlled
return false;
}
bool SafeAuth::HasAnyUsers() const { return auth_->ReadLock()->HasUsers(); }
bool SafeAuth::AccessControlled() const { return auth_->ReadLock()->AccessControlled(); }
} // namespace memgraph::communication::websocket

View File

@ -21,9 +21,9 @@ class AuthenticationInterface {
public:
virtual bool Authenticate(const std::string &username, const std::string &password) const = 0;
virtual bool HasUserPermission(const std::string &username, auth::Permission permission) const = 0;
virtual bool HasPermission(auth::Permission permission) const = 0;
virtual bool HasAnyUsers() const = 0;
virtual bool AccessControlled() const = 0;
};
class SafeAuth : public AuthenticationInterface {
@ -32,11 +32,13 @@ class SafeAuth : public AuthenticationInterface {
bool Authenticate(const std::string &username, const std::string &password) const override;
bool HasUserPermission(const std::string &username, auth::Permission permission) const override;
bool HasPermission(auth::Permission permission) const override;
bool HasAnyUsers() const override;
bool AccessControlled() const override;
private:
auth::SynchedAuth *auth_;
mutable std::optional<auth::UserOrRole> user_or_role_;
mutable auth::Auth::Epoch auth_epoch_{};
};
} // namespace memgraph::communication::websocket

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -80,7 +80,7 @@ bool Session::Run() {
return false;
}
authenticated_ = !auth_.HasAnyUsers();
authenticated_ = !auth_.AccessControlled();
connected_.store(true, std::memory_order_relaxed);
// run on the strand
@ -162,7 +162,7 @@ utils::BasicResult<std::string> Session::Authorize(const nlohmann::json &creds)
return {"Authentication failed!"};
}
#ifdef MG_ENTERPRISE
if (!auth_.HasUserPermission(creds.at("username").get<std::string>(), auth::Permission::WEBSOCKET)) {
if (!auth_.HasPermission(auth::Permission::WEBSOCKET)) {
return {"Authorization failed!"};
}
#endif

View File

@ -6,5 +6,6 @@ target_sources(mg-glue PRIVATE auth.cpp
SessionHL.cpp
ServerT.cpp
MonitoringServerT.cpp
run_id.cpp)
run_id.cpp
query_user.cpp)
target_link_libraries(mg-glue mg-query mg-auth mg-audit mg-flags)

View File

@ -11,6 +11,7 @@
#include <optional>
#include <utility>
#include "auth/auth.hpp"
#include "gflags/gflags.h"
#include "audit/log.hpp"
@ -19,17 +20,22 @@
#include "glue/SessionHL.hpp"
#include "glue/auth_checker.hpp"
#include "glue/communication.hpp"
#include "glue/query_user.hpp"
#include "glue/run_id.hpp"
#include "license/license.hpp"
#include "query/auth_checker.hpp"
#include "query/discard_value_stream.hpp"
#include "query/interpreter_context.hpp"
#include "query/query_user.hpp"
#include "utils/event_map.hpp"
#include "utils/spin_lock.hpp"
#include "utils/variant_helpers.hpp"
namespace memgraph::metrics {
extern const Event ActiveBoltSessions;
} // namespace memgraph::metrics
namespace {
auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgraph::query::QueryExtras {
auto const &as_map = extra.ValueMap();
@ -97,20 +103,24 @@ std::vector<memgraph::communication::bolt::Value> TypedValueResultStreamBase::De
}
return decoded_values;
}
TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {}
namespace memgraph::glue {
#ifdef MG_ENTERPRISE
inline static void MultiDatabaseAuth(const std::optional<auth::User> &user, std::string_view db) {
if (user && !AuthChecker::IsUserAuthorized(*user, {}, std::string(db))) {
void MultiDatabaseAuth(memgraph::query::QueryUserOrRole *user, std::string_view db) {
if (user && !user->IsAuthorized({}, std::string(db), &memgraph::query::session_long_policy)) {
throw memgraph::communication::bolt::ClientError(
"You are not authorized on the database \"{}\"! Please contact your database administrator.", db);
}
}
#endif
} // namespace
namespace memgraph::glue {
#ifdef MG_ENTERPRISE
std::string SessionHL::GetDefaultDB() {
if (user_.has_value()) {
return user_->db_access().GetDefault();
if (user_or_role_) {
return user_or_role_->GetDefaultDB();
}
return std::string{memgraph::dbms::kDefaultDB};
}
@ -132,13 +142,18 @@ bool SessionHL::Authenticate(const std::string &username, const std::string &pas
interpreter_.ResetUser();
{
auto locked_auth = auth_->Lock();
if (locked_auth->HasUsers()) {
user_ = locked_auth->Authenticate(username, password);
if (user_.has_value()) {
interpreter_.SetUser(user_->username());
if (locked_auth->AccessControlled()) {
const auto user_or_role = locked_auth->Authenticate(username, password);
if (user_or_role.has_value()) {
user_or_role_ = AuthChecker::GenQueryUser(auth_, *user_or_role);
interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, *user_or_role));
} else {
res = false;
}
} else {
// No access control -> give empty user
user_or_role_ = AuthChecker::GenQueryUser(auth_, std::nullopt);
interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, std::nullopt));
}
}
#ifdef MG_ENTERPRISE
@ -195,21 +210,17 @@ std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
}
#ifdef MG_ENTERPRISE
const std::string *username{nullptr};
if (user_) {
username = &user_->username();
}
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
auto &db = interpreter_.current_db_.db_acc_;
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
memgraph::storage::PropertyValue(params_pv), db ? db->get()->name() : "no known database");
const auto username = user_or_role_ ? (user_or_role_->username() ? *user_or_role_->username() : "") : "";
audit_log_->Record(endpoint_.address().to_string(), username, query, memgraph::storage::PropertyValue(params_pv),
db ? db->get()->name() : "no known database");
}
#endif
try {
auto result = interpreter_.Prepare(query, params_pv, ToQueryExtras(extra));
const std::string db_name = result.db ? *result.db : "";
if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) {
if (user_or_role_ && !user_or_role_->IsAuthorized(result.privileges, db_name, &query::session_long_policy)) {
interpreter_.Abort();
if (db_name.empty()) {
throw memgraph::communication::bolt::ClientError(
@ -311,7 +322,7 @@ void SessionHL::Configure(const std::map<std::string, memgraph::communication::b
// Check if the underlying database needs to be updated
if (update) {
MultiDatabaseAuth(user_, db);
MultiDatabaseAuth(user_or_role_.get(), db);
interpreter_.SetCurrentDB(db, in_explicit_db_);
}
#endif
@ -338,7 +349,7 @@ SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
// Metrics update
memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);
#ifdef MG_ENTERPRISE
interpreter_.OnChangeCB([&](std::string_view db_name) { MultiDatabaseAuth(user_, db_name); });
interpreter_.OnChangeCB([&](std::string_view db_name) { MultiDatabaseAuth(user_or_role_.get(), db_name); });
#endif
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
}

View File

@ -15,6 +15,7 @@
#include "communication/v2/server.hpp"
#include "communication/v2/session.hpp"
#include "dbms/database.hpp"
#include "glue/query_user.hpp"
#include "query/interpreter.hpp"
namespace memgraph::glue {
@ -82,7 +83,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::query::Interpreter interpreter_;
std::optional<memgraph::auth::User> user_;
std::unique_ptr<query::QueryUserOrRole> user_or_role_;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log_;
bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata

View File

@ -14,53 +14,74 @@
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "glue/auth.hpp"
#include "glue/query_user.hpp"
#include "license/license.hpp"
#include "query/auth_checker.hpp"
#include "query/constants.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/query_user.hpp"
#include "utils/logging.hpp"
#include "utils/synchronized.hpp"
#include "utils/variant_helpers.hpp"
#ifdef MG_ENTERPRISE
namespace {
bool IsUserAuthorizedLabels(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba,
const std::vector<memgraph::storage::LabelId> &labels,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) {
bool IsAuthorizedLabels(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba,
const std::vector<memgraph::storage::LabelId> &labels,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return std::all_of(labels.begin(), labels.end(), [dba, &user, fine_grained_privilege](const auto &label) {
return user.GetFineGrainedAccessLabelPermissions().Has(
dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(
fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT;
return std::all_of(labels.begin(), labels.end(), [dba, &user_or_role, fine_grained_privilege](const auto &label) {
return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) {
return user_or_role.GetFineGrainedAccessLabelPermissions().Has(
dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(
fine_grained_privilege)) ==
memgraph::auth::PermissionLevel::GRANT;
}},
user_or_role);
});
}
bool IsUserAuthorizedGloballyLabels(const memgraph::auth::User &user,
const memgraph::auth::FineGrainedPermission fine_grained_permission) {
bool IsAuthorizedGloballyLabels(const memgraph::auth::UserOrRole &user_or_role,
const memgraph::auth::FineGrainedPermission fine_grained_permission) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return user.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) ==
memgraph::auth::PermissionLevel::GRANT;
return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) {
return user_or_role.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk,
fine_grained_permission) ==
memgraph::auth::PermissionLevel::GRANT;
}},
user_or_role);
}
bool IsUserAuthorizedGloballyEdges(const memgraph::auth::User &user,
const memgraph::auth::FineGrainedPermission fine_grained_permission) {
bool IsAuthorizedGloballyEdges(const memgraph::auth::UserOrRole &user_or_role,
const memgraph::auth::FineGrainedPermission fine_grained_permission) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return user.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) ==
memgraph::auth::PermissionLevel::GRANT;
return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) {
return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk,
fine_grained_permission) ==
memgraph::auth::PermissionLevel::GRANT;
}},
user_or_role);
}
bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba,
const memgraph::storage::EdgeTypeId &edgeType,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) {
bool IsAuthorizedEdgeType(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba,
const memgraph::storage::EdgeTypeId &edgeType,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return user.GetFineGrainedAccessEdgeTypePermissions().Has(
dba->EdgeTypeToName(edgeType), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(
fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT;
return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) {
return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has(
dba->EdgeTypeToName(edgeType),
memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)) ==
memgraph::auth::PermissionLevel::GRANT;
}},
user_or_role);
}
} // namespace
#endif
@ -68,47 +89,54 @@ namespace memgraph::glue {
AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {}
bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) const {
std::optional<memgraph::auth::User> maybe_user;
{
auto locked_auth = auth_->ReadLock();
if (!locked_auth->HasUsers()) {
return true;
}
if (username.has_value()) {
maybe_user = locked_auth->GetUser(*username);
}
std::shared_ptr<query::QueryUserOrRole> AuthChecker::GenQueryUser(const std::optional<std::string> &username,
const std::optional<std::string> &rolename) const {
const auto user_or_role = auth_->ReadLock()->GetUserOrRole(username, rolename);
if (user_or_role) {
return std::make_shared<QueryUserOrRole>(auth_, *user_or_role);
}
// No user or role
return std::make_shared<QueryUserOrRole>(auth_);
}
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name);
std::unique_ptr<query::QueryUserOrRole> AuthChecker::GenQueryUser(auth::SynchedAuth *auth,
const std::optional<auth::UserOrRole> &user_or_role) {
if (user_or_role) {
return std::visit(
utils::Overloaded{[&](auto &user_or_role) { return std::make_unique<QueryUserOrRole>(auth, user_or_role); }},
*user_or_role);
}
// No user or role
return std::make_unique<QueryUserOrRole>(auth);
}
#ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const {
std::shared_ptr<query::QueryUserOrRole> user_or_role, const memgraph::query::DbAccessor *dba) const {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return {};
}
try {
auto user = user_.Lock();
if (username != user->username()) {
auto maybe_user = auth_->ReadLock()->GetUser(username);
if (!maybe_user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
*user = std::move(*maybe_user);
}
return std::make_unique<memgraph::glue::FineGrainedAuthChecker>(*user, dba);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
if (!user_or_role || !*user_or_role) {
throw query::QueryRuntimeException("No user specified for fine grained authorization!");
}
}
void AuthChecker::ClearCache() const {
user_.WithLock([](auto &user) mutable { user = {}; });
// Convert from query user to auth user or role
try {
auto glue_user = dynamic_cast<glue::QueryUserOrRole &>(*user_or_role);
if (glue_user.user_) {
return std::make_unique<glue::FineGrainedAuthChecker>(std::move(*glue_user.user_), dba);
}
if (glue_user.role_) {
return std::make_unique<glue::FineGrainedAuthChecker>(
auth::RoleWUsername{*glue_user.username(), std::move(*glue_user.role_)}, dba);
}
DMG_ASSERT(false, "Glue user has neither user not role");
} catch (std::bad_cast &e) {
DMG_ASSERT(false, "Using a non-glue user in glue...");
}
// Should never get here
return {};
}
#endif
@ -116,7 +144,7 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) { // NOLINT
#ifdef MG_ENTERPRISE
if (!db_name.empty() && !user.db_access().Contains(db_name)) {
if (!db_name.empty() && !user.HasAccess(db_name)) {
return false;
}
#endif
@ -127,9 +155,34 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user,
});
}
bool AuthChecker::IsRoleAuthorized(const memgraph::auth::Role &role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) { // NOLINT
#ifdef MG_ENTERPRISE
FineGrainedAuthChecker::FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba)
: user_{std::move(user)}, dba_(dba){};
if (!db_name.empty() && !role.HasAccess(db_name)) {
return false;
}
#endif
const auto role_permissions = role.permissions();
return std::all_of(privileges.begin(), privileges.end(), [&role_permissions](const auto privilege) {
return role_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) ==
memgraph::auth::PermissionLevel::GRANT;
});
}
bool AuthChecker::IsUserOrRoleAuthorized(const memgraph::auth::UserOrRole &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name) {
return std::visit(
utils::Overloaded{
[&](const auth::User &user) -> bool { return AuthChecker::IsUserAuthorized(user, privileges, db_name); },
[&](const auth::Role &role) -> bool { return AuthChecker::IsRoleAuthorized(role, privileges, db_name); }},
user_or_role);
}
#ifdef MG_ENTERPRISE
FineGrainedAuthChecker::FineGrainedAuthChecker(auth::UserOrRole user_or_role, const memgraph::query::DbAccessor *dba)
: user_or_role_{std::move(user_or_role)}, dba_(dba){};
bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex, const memgraph::storage::View view,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const {
@ -147,22 +200,22 @@ bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex,
}
}
return IsUserAuthorizedLabels(user_, dba_, *maybe_labels, fine_grained_privilege);
return IsAuthorizedLabels(user_or_role_, dba_, *maybe_labels, fine_grained_privilege);
}
bool FineGrainedAuthChecker::Has(const memgraph::query::EdgeAccessor &edge,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const {
return IsUserAuthorizedEdgeType(user_, dba_, edge.EdgeType(), fine_grained_privilege);
return IsAuthorizedEdgeType(user_or_role_, dba_, edge.EdgeType(), fine_grained_privilege);
}
bool FineGrainedAuthChecker::Has(const std::vector<memgraph::storage::LabelId> &labels,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const {
return IsUserAuthorizedLabels(user_, dba_, labels, fine_grained_privilege);
return IsAuthorizedLabels(user_or_role_, dba_, labels, fine_grained_privilege);
}
bool FineGrainedAuthChecker::Has(const memgraph::storage::EdgeTypeId &edge_type,
const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const {
return IsUserAuthorizedEdgeType(user_, dba_, edge_type, fine_grained_privilege);
return IsAuthorizedEdgeType(user_or_role_, dba_, edge_type, fine_grained_privilege);
}
bool FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices(
@ -170,7 +223,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices(
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return IsUserAuthorizedGloballyLabels(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege));
return IsAuthorizedGloballyLabels(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege));
}
bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges(
@ -178,7 +231,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges(
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
return true;
}
return IsUserAuthorizedGloballyEdges(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege));
return IsAuthorizedGloballyEdges(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege));
};
#endif
} // namespace memgraph::glue

View File

@ -22,53 +22,59 @@ namespace memgraph::glue {
class AuthChecker : public query::AuthChecker {
public:
explicit AuthChecker(memgraph::auth::SynchedAuth *auth);
explicit AuthChecker(auth::SynchedAuth *auth);
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name) const override;
std::shared_ptr<query::QueryUserOrRole> GenQueryUser(const std::optional<std::string> &username,
const std::optional<std::string> &rolename) const override;
static std::unique_ptr<query::QueryUserOrRole> GenQueryUser(auth::SynchedAuth *auth,
const std::optional<auth::UserOrRole> &user_or_role);
#ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const override;
void ClearCache() const override;
std::unique_ptr<query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(std::shared_ptr<query::QueryUserOrRole> user,
const query::DbAccessor *dba) const override;
#endif
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
[[nodiscard]] static bool IsUserAuthorized(const auth::User &user,
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name = "");
[[nodiscard]] static bool IsRoleAuthorized(const auth::Role &role,
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name = "");
[[nodiscard]] static bool IsUserOrRoleAuthorized(const auth::UserOrRole &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name = "");
private:
memgraph::auth::SynchedAuth *auth_;
mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user
auth::SynchedAuth *auth_;
mutable utils::Synchronized<auth::UserOrRole, utils::SpinLock> user_or_role_; // cached user
};
#ifdef MG_ENTERPRISE
class FineGrainedAuthChecker : public query::FineGrainedAuthChecker {
public:
explicit FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba);
explicit FineGrainedAuthChecker(auth::UserOrRole user, const query::DbAccessor *dba);
bool Has(const query::VertexAccessor &vertex, memgraph::storage::View view,
bool Has(const query::VertexAccessor &vertex, storage::View view,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool Has(const query::EdgeAccessor &edge,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool Has(const std::vector<memgraph::storage::LabelId> &labels,
bool Has(const std::vector<storage::LabelId> &labels,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool Has(const memgraph::storage::EdgeTypeId &edge_type,
bool Has(const storage::EdgeTypeId &edge_type,
query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool HasGlobalPrivilegeOnVertices(
memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool HasGlobalPrivilegeOnVertices(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool HasGlobalPrivilegeOnEdges(
memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
bool HasGlobalPrivilegeOnEdges(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override;
private:
auth::User user_;
const memgraph::query::DbAccessor *dba_;
auth::UserOrRole user_or_role_;
const query::DbAccessor *dba_;
};
#endif
} // namespace memgraph::glue

View File

@ -15,6 +15,7 @@
#include <fmt/format.h>
#include "auth/auth.hpp"
#include "auth/models.hpp"
#include "dbms/constants.hpp"
#include "glue/auth.hpp"
@ -123,6 +124,29 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowRolePrivileges(
}
#ifdef MG_ENTERPRISE
std::vector<std::vector<memgraph::query::TypedValue>> ShowDatabasePrivileges(
const std::optional<memgraph::auth::Role> &role) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !role) {
return {};
}
const auto &db = role->db_access();
const auto &allows = db.GetAllowAll();
const auto &grants = db.GetGrants();
const auto &denies = db.GetDenies();
std::vector<memgraph::query::TypedValue> res; // First element is a list of granted databases, second of revoked ones
if (allows) {
res.emplace_back("*");
} else {
std::vector<memgraph::query::TypedValue> grants_vec(grants.cbegin(), grants.cend());
res.emplace_back(std::move(grants_vec));
}
std::vector<memgraph::query::TypedValue> denies_vec(denies.cbegin(), denies.cend());
res.emplace_back(std::move(denies_vec));
return {res};
}
std::vector<std::vector<memgraph::query::TypedValue>> ShowDatabasePrivileges(
const std::optional<memgraph::auth::User> &user) {
if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) {
@ -130,9 +154,15 @@ std::vector<std::vector<memgraph::query::TypedValue>> ShowDatabasePrivileges(
}
const auto &db = user->db_access();
const auto &allows = db.GetAllowAll();
const auto &grants = db.GetGrants();
const auto &denies = db.GetDenies();
auto allows = db.GetAllowAll();
auto grants = db.GetGrants();
auto denies = db.GetDenies();
if (const auto *role = user->role()) {
const auto &role_db = role->db_access();
allows |= role_db.GetAllowAll();
grants.insert(role_db.GetGrants().begin(), role_db.GetGrants().end());
denies.insert(role_db.GetDenies().begin(), role_db.GetDenies().end());
}
std::vector<memgraph::query::TypedValue> res; // First element is a list of granted databases, second of revoked ones
if (allows) {
@ -287,7 +317,7 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option
,
system_tx);
#ifdef MG_ENTERPRISE
GrantDatabaseToUser(auth::kAllDatabases, username, system_tx);
GrantDatabase(auth::kAllDatabases, username, system_tx);
SetMainDatabase(dbms::kDefaultDB, username, system_tx);
#endif
}
@ -334,51 +364,97 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transacti
}
#ifdef MG_ENTERPRISE
bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
void AuthQueryHandler::GrantDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx);
const auto res = locked_auth->GrantDatabase(db_name, user_or_role, system_tx);
switch (res) {
using enum auth::Auth::Result;
case SUCCESS:
return;
case NO_USER_ROLE:
throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role);
case NO_ROLE:
throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role);
break;
}
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) {
void AuthQueryHandler::DenyDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->GrantDatabaseToUser(db_name, username, system_tx);
const auto res = locked_auth->DenyDatabase(db_name, user_or_role, system_tx);
switch (res) {
using enum auth::Auth::Result;
case SUCCESS:
return;
case NO_USER_ROLE:
throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role);
case NO_ROLE:
throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role);
break;
}
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthQueryHandler::RevokeDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
const auto res = locked_auth->RevokeDatabase(db_name, user_or_role, system_tx);
switch (res) {
using enum auth::Auth::Result;
case SUCCESS:
return;
case NO_USER_ROLE:
throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role);
case NO_ROLE:
throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role);
break;
}
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
std::vector<std::vector<memgraph::query::TypedValue>> AuthQueryHandler::GetDatabasePrivileges(
const std::string &username) {
const std::string &user_or_role) {
try {
auto locked_auth = auth_->ReadLock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
if (auto user = locked_auth->GetUser(user_or_role)) {
return ShowDatabasePrivileges(user);
}
return ShowDatabasePrivileges(user);
if (auto role = locked_auth->GetRole(user_or_role)) {
return ShowDatabasePrivileges(role);
}
throw memgraph::query::QueryRuntimeException("Neither user nor role '{}' exist.", user_or_role);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username,
void AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &user_or_role,
system::Transaction *system_tx) {
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->SetMainDatabase(db_name, username, system_tx);
const auto res = locked_auth->SetMainDatabase(db_name, user_or_role, system_tx);
switch (res) {
using enum auth::Auth::Result;
case SUCCESS:
return;
case NO_USER_ROLE:
throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role);
case NO_ROLE:
throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role);
break;
}
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}

View File

@ -37,15 +37,19 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
system::Transaction *system_tx) override;
#ifdef MG_ENTERPRISE
bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
void GrantDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) override;
bool GrantDatabaseToUser(const std::string &db_name, const std::string &username,
system::Transaction *system_tx) override;
void DenyDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &username) override;
void RevokeDatabase(const std::string &db_name, const std::string &user_or_role,
system::Transaction *system_tx) override;
bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override;
std::vector<std::vector<memgraph::query::TypedValue>> GetDatabasePrivileges(const std::string &user_or_role) override;
void SetMainDatabase(std::string_view db_name, const std::string &user_or_role,
system::Transaction *system_tx) override;
void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override;
#endif

41
src/glue/query_user.cpp Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "glue/query_user.hpp"
#include "glue/auth_checker.hpp"
namespace memgraph::glue {
bool QueryUserOrRole::IsAuthorized(const std::vector<query::AuthQuery::Privilege> &privileges,
const std::string &db_name, query::UserPolicy *policy) const {
auto locked_auth = auth_->Lock();
// Check policy and update if behind (and policy permits it)
if (policy->DoUpdate() && !locked_auth->UpToDate(auth_epoch_)) {
if (user_) user_ = locked_auth->GetUser(user_->username());
if (role_) role_ = locked_auth->GetRole(role_->rolename());
}
if (user_) return AuthChecker::IsUserAuthorized(*user_, privileges, db_name);
if (role_) return AuthChecker::IsRoleAuthorized(*role_, privileges, db_name);
return !policy->DoUpdate() || !locked_auth->AccessControlled();
}
#ifdef MG_ENTERPRISE
std::string QueryUserOrRole::GetDefaultDB() const {
if (user_) return user_->db_access().GetMain();
if (role_) return role_->db_access().GetMain();
return std::string{dbms::kDefaultDB};
}
#endif
} // namespace memgraph::glue

57
src/glue/query_user.hpp Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include "auth/auth.hpp"
#include "query/query_user.hpp"
#include "utils/variant_helpers.hpp"
namespace memgraph::glue {
struct QueryUserOrRole : public query::QueryUserOrRole {
bool IsAuthorized(const std::vector<query::AuthQuery::Privilege> &privileges, const std::string &db_name,
query::UserPolicy *policy) const override;
#ifdef MG_ENTERPRISE
std::string GetDefaultDB() const override;
#endif
explicit QueryUserOrRole(auth::SynchedAuth *auth) : query::QueryUserOrRole{std::nullopt, std::nullopt}, auth_{auth} {}
QueryUserOrRole(auth::SynchedAuth *auth, auth::UserOrRole user_or_role)
: query::QueryUserOrRole{std::visit(
utils::Overloaded{[](const auto &user_or_role) { return user_or_role.username(); }},
user_or_role),
std::visit(utils::Overloaded{[&](const auth::User &) -> std::optional<std::string> {
return std::nullopt;
},
[&](const auth::Role &role) -> std::optional<std::string> {
return role.rolename();
}},
user_or_role)},
auth_{auth} {
std::visit(utils::Overloaded{[&](auth::User &&user) { user_.emplace(std::move(user)); },
[&](auth::Role &&role) { role_.emplace(std::move(role)); }},
std::move(user_or_role));
}
private:
friend class AuthChecker;
auth::SynchedAuth *auth_;
mutable std::optional<auth::User> user_{};
mutable std::optional<auth::Role> role_{};
mutable auth::Auth::Epoch auth_epoch_{auth::Auth::kStartEpoch};
};
} // namespace memgraph::glue

View File

@ -27,6 +27,7 @@
#include "helpers.hpp"
#include "license/license_sender.hpp"
#include "memory/global_memory_control.hpp"
#include "query/auth_checker.hpp"
#include "query/auth_query_handler.hpp"
#include "query/config.hpp"
#include "query/discard_value_stream.hpp"
@ -57,8 +58,13 @@ constexpr uint64_t kMgVmMaxMapCount = 262144;
void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbms::DatabaseAccess &db_acc,
std::string cypherl_file_path, memgraph::audit::Log *audit_log = nullptr) {
memgraph::query::Interpreter interpreter(&ctx, db_acc);
std::ifstream file(cypherl_file_path);
// Temporary empty user
// TODO: Double check with buda
memgraph::query::AllowEverythingAuthChecker tmp_auth_checker;
auto tmp_user = tmp_auth_checker.GenQueryUser(std::nullopt, std::nullopt);
interpreter.SetUser(tmp_user);
std::ifstream file(cypherl_file_path);
if (!file.is_open()) {
spdlog::trace("Could not find init file {}", cypherl_file_path);
return;

View File

@ -40,6 +40,7 @@ set(mg_query_sources
db_accessor.cpp
auth_query_handler.cpp
interpreter_context.cpp
query_user.cpp
)
add_library(mg-query STATIC ${mg_query_sources})

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,7 +16,9 @@
#include <string>
#include <vector>
#include "dbms/constants.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/query_user.hpp"
#include "storage/v2/id_types.hpp"
namespace memgraph::query {
@ -29,15 +31,12 @@ class AuthChecker {
public:
virtual ~AuthChecker() = default;
[[nodiscard]] virtual bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<AuthQuery::Privilege> &privileges,
const std::string &db_name) const = 0;
virtual std::shared_ptr<QueryUserOrRole> GenQueryUser(const std::optional<std::string> &username,
const std::optional<std::string> &rolename) const = 0;
#ifdef MG_ENTERPRISE
[[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const DbAccessor *db_accessor) const = 0;
virtual void ClearCache() const = 0;
std::shared_ptr<QueryUserOrRole> user, const DbAccessor *db_accessor) const = 0;
#endif
};
#ifdef MG_ENTERPRISE
@ -98,19 +97,29 @@ class AllowEverythingFineGrainedAuthChecker final : public FineGrainedAuthChecke
class AllowEverythingAuthChecker final : public AuthChecker {
public:
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
const std::vector<AuthQuery::Privilege> & /*privileges*/,
const std::string & /*db*/) const override {
return true;
struct User : query::QueryUserOrRole {
User() : query::QueryUserOrRole{std::nullopt, std::nullopt} {}
User(std::string name) : query::QueryUserOrRole{std::move(name), std::nullopt} {}
bool IsAuthorized(const std::vector<AuthQuery::Privilege> & /*privileges*/, const std::string & /*db_name*/,
UserPolicy * /*policy*/) const override {
return true;
}
#ifdef MG_ENTERPRISE
std::string GetDefaultDB() const override { return std::string{dbms::kDefaultDB}; }
#endif
};
std::shared_ptr<query::QueryUserOrRole> GenQueryUser(const std::optional<std::string> &name,
const std::optional<std::string> & /*role*/) const override {
if (name) return std::make_shared<User>(std::move(*name));
return std::make_shared<User>();
}
#ifdef MG_ENTERPRISE
std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(const std::string & /*username*/,
std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(std::shared_ptr<QueryUserOrRole> /*user*/,
const DbAccessor * /*dba*/) const override {
return std::make_unique<AllowEverythingFineGrainedAuthChecker>();
}
void ClearCache() const override {}
#endif
};

View File

@ -46,15 +46,17 @@ class AuthQueryHandler {
system::Transaction *system_tx) = 0;
#ifdef MG_ENTERPRISE
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
/// Return true if access granted successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username,
system::Transaction *system_tx) = 0;
virtual void GrantDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0;
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0;
/// Return true if access revoked successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokeDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0;
/// Returns database access rights for the user
/// @throw QueryRuntimeException if an error ocurred.
@ -62,7 +64,7 @@ class AuthQueryHandler {
/// Return true if main database set successfully
/// @throw QueryRuntimeException if an error ocurred.
virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0;
virtual void SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0;
/// Delete database from all users
/// @throw QueryRuntimeException if an error ocurred.

View File

@ -2819,6 +2819,7 @@ class AuthQuery : public memgraph::query::Query {
SHOW_ROLE_FOR_USER,
SHOW_USERS_FOR_ROLE,
GRANT_DATABASE_TO_USER,
DENY_DATABASE_FROM_USER,
REVOKE_DATABASE_FROM_USER,
SHOW_DATABASE_PRIVILEGES,
SET_MAIN_DATABASE,

View File

@ -1780,22 +1780,35 @@ antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsers
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) {
antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER;
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) {
antlrcpp::Any CypherMainVisitor::visitDenyDatabaseFromUserOrRole(
MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::DENY_DATABASE_FROM_USER;
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
}
/**
* @return AuthQuery*
*/
antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUserOrRole(
MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER;
auth->database_ = std::any_cast<std::string>(ctx->wildcardName()->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
}
@ -1805,7 +1818,7 @@ antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::Rev
antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) {
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES;
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
}
@ -1816,7 +1829,7 @@ antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDat
auto *auth = storage_->Create<AuthQuery>();
auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE;
auth->database_ = std::any_cast<std::string>(ctx->db->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->user->accept(this));
auth->user_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
return auth;
}

View File

@ -605,12 +605,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
/**
* @return AuthQuery*
*/
antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override;
antlrcpp::Any visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override;
antlrcpp::Any visitDenyDatabaseFromUserOrRole(MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitRevokeDatabaseFromUserOrRole(MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) override;
/**
* @return AuthQuery*

View File

@ -176,8 +176,9 @@ authQuery : createRole
| showPrivileges
| showRoleForUser
| showUsersForRole
| grantDatabaseToUser
| revokeDatabaseFromUser
| grantDatabaseToUserOrRole
| denyDatabaseFromUserOrRole
| revokeDatabaseFromUserOrRole
| showDatabasePrivileges
| setMainDatabase
;
@ -303,13 +304,15 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol
revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ;
grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ;
grantDatabaseToUserOrRole : GRANT DATABASE db=wildcardName TO userOrRole=userOrRoleName ;
revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ;
denyDatabaseFromUserOrRole : DENY DATABASE db=wildcardName FROM userOrRole=userOrRoleName ;
showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ;
revokeDatabaseFromUserOrRole : REVOKE DATABASE db=wildcardName FROM userOrRole=userOrRoleName ;
setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ;
showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR userOrRole=userOrRoleName ;
setMainDatabase : SET MAIN DATABASE db=symbolicName FOR userOrRole=userOrRoleName ;
privilege : CREATE
| DELETE

View File

@ -68,6 +68,7 @@
#include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp"
#include "query/procedure/module.hpp"
#include "query/query_user.hpp"
#include "query/replication_query_handler.hpp"
#include "query/stream.hpp"
#include "query/stream/common.hpp"
@ -629,6 +630,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
AuthQuery::Action::SHOW_USERS_FOR_ROLE,
AuthQuery::Action::SHOW_ROLE_FOR_USER,
AuthQuery::Action::GRANT_DATABASE_TO_USER,
AuthQuery::Action::DENY_DATABASE_FROM_USER,
AuthQuery::Action::REVOKE_DATABASE_FROM_USER,
AuthQuery::Action::SHOW_DATABASE_PRIVILEGES,
AuthQuery::Action::SET_MAIN_DATABASE};
@ -888,9 +890,31 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username);
auth->GrantDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
#else
callback.fn = [] {
#endif
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::DENY_DATABASE_FROM_USER:
forbid_on_replica();
#ifdef MG_ENTERPRISE
callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT
if (!interpreter->system_transaction_) {
throw QueryException("Expected to be in a system transaction");
}
try {
std::optional<memgraph::dbms::DatabaseAccess> db =
std::nullopt; // Hold pointer to database to protect it until query is done
if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
}
auth->DenyDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
@ -914,9 +938,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
if (database != memgraph::auth::kAllDatabases) {
db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
}
if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username);
}
auth->RevokeDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
@ -949,9 +971,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
try {
const auto db =
db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull
if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) {
throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username);
}
auth->SetMainDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception
} catch (memgraph::dbms::UnknownDatabaseException &e) {
throw QueryRuntimeException(e.what());
}
@ -1275,7 +1295,7 @@ std::vector<std::string> EvaluateTopicNames(ExpressionVisitor<TypedValue> &evalu
Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator,
memgraph::dbms::DatabaseAccess db_acc,
InterpreterContext *interpreter_context,
const std::optional<std::string> &username) {
std::shared_ptr<QueryUserOrRole> user_or_role) {
static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer";
std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup
: stream_query->consumer_group_};
@ -1302,10 +1322,13 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp
memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated);
// Make a copy of the user and pass it to the subsystem
auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename());
return [db_acc = std::move(db_acc), interpreter_context, stream_name = stream_query->stream_name_,
topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_),
consumer_group = std::move(consumer_group), common_stream_info = std::move(common_stream_info),
bootstrap_servers = std::move(bootstrap), owner = username,
bootstrap_servers = std::move(bootstrap), owner = std::move(owner),
configs = get_config_map(stream_query->configs_, "Configs"),
credentials = get_config_map(stream_query->credentials_, "Credentials"),
default_server = interpreter_context->config.default_kafka_bootstrap_servers]() mutable {
@ -1327,7 +1350,7 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp
Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator,
memgraph::dbms::DatabaseAccess db,
InterpreterContext *interpreter_context,
const std::optional<std::string> &username) {
std::shared_ptr<QueryUserOrRole> user_or_role) {
auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator);
if (service_url && service_url->empty()) {
throw SemanticException("Service URL must not be an empty string!");
@ -1335,9 +1358,13 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex
auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator);
memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated);
// Make a copy of the user and pass it to the subsystem
auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename());
return [db = std::move(db), interpreter_context, stream_name = stream_query->stream_name_,
topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_),
common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), owner = username,
common_stream_info = std::move(common_stream_info), service_url = std::move(service_url),
owner = std::move(owner),
default_service = interpreter_context->config.default_pulsar_service_url]() mutable {
std::string url = service_url ? std::move(*service_url) : std::move(default_service);
db->streams()->Create<query::stream::PulsarStream>(
@ -1351,7 +1378,7 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex
Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
memgraph::dbms::DatabaseAccess &db_acc, InterpreterContext *interpreter_context,
const std::optional<std::string> &username, std::vector<Notification> *notifications) {
std::shared_ptr<QueryUserOrRole> user_or_role, std::vector<Notification> *notifications) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
EvaluationContext evaluation_context;
@ -1364,10 +1391,12 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
case StreamQuery::Action::CREATE_STREAM: {
switch (stream_query->type_) {
case StreamQuery::Type::KAFKA:
callback.fn = GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username);
callback.fn =
GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role));
break;
case StreamQuery::Type::PULSAR:
callback.fn = GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username);
callback.fn =
GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role));
break;
}
notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_STREAM,
@ -1640,7 +1669,7 @@ struct TxTimeout {
struct PullPlan {
explicit PullPlan(std::shared_ptr<PlanWrapper> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<QueryUserOrRole> user_or_role, std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {}, bool use_monotonic_memory = true,
@ -1680,7 +1709,7 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<PlanWrapper> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<QueryUserOrRole> user_or_role, std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer, TriggerContextCollector *trigger_context_collector,
const std::optional<size_t> memory_limit, bool use_monotonic_memory,
FrameChangeCollector *frame_change_collector)
@ -1696,10 +1725,9 @@ PullPlan::PullPlan(const std::shared_ptr<PlanWrapper> plan, const Parameters &pa
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
#ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && username.has_value() && dba) {
// TODO How can we avoid creating this every time? If we must create it, it would be faster with an auth::User
// instead of the username
auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(*username, dba);
if (license::global_license_checker.IsEnterpriseValidFast() && user_or_role && *user_or_role && dba) {
// Create only if an explicit user is defined
auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(std::move(user_or_role), dba);
// if the user has global privileges to read, edit and write anything, we don't need to perform authorization
// otherwise, we do assign the auth checker to check for label access control
@ -1989,7 +2017,7 @@ bool IsCallBatchedProcedureQuery(const std::vector<memgraph::query::Clause *> &c
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, CurrentDB &current_db,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
std::optional<std::string> const &username,
std::shared_ptr<QueryUserOrRole> user_or_role,
std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer,
FrameChangeCollector *frame_change_collector = nullptr) {
@ -2057,8 +2085,8 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
auto *trigger_context_collector =
current_db.trigger_context_collector_ ? &*current_db.trigger_context_collector_ : nullptr;
auto pull_plan = std::make_shared<PullPlan>(
plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, username, transaction_status,
std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory,
plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, std::move(user_or_role),
transaction_status, std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory,
frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
@ -2130,7 +2158,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, std::vector<Notification> *notifications,
InterpreterContext *interpreter_context, CurrentDB &current_db,
utils::MemoryResource *execution_memory, std::optional<std::string> const &username,
utils::MemoryResource *execution_memory,
std::shared_ptr<QueryUserOrRole> user_or_role,
std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer,
FrameChangeCollector *frame_change_collector) {
@ -2208,37 +2237,37 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));
return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit, username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory,
frame_change_collector, tx_timer = std::move(tx_timer)](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time =
PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, username,
transaction_status, std::move(tx_timer), nullptr, memory_limit,
use_monotonic_memory,
frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
return PreparedQuery{
{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), summary, dba,
interpreter_context, execution_memory, memory_limit, user_or_role = std::move(user_or_role),
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status, use_monotonic_memory,
frame_change_collector, tx_timer = std::move(tx_timer)](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time =
PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, std::move(user_or_role),
transaction_status, std::move(tx_timer), nullptr, memory_limit, use_monotonic_memory,
frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
MG_ASSERT(stats_and_total_time, "Failed to execute the query!");
MG_ASSERT(stats_and_total_time, "Failed to execute the query!");
if (pull_plan->Pull(stream, n)) {
summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump());
return QueryHandlerResult::ABORT;
}
if (pull_plan->Pull(stream, n)) {
summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump());
return QueryHandlerResult::ABORT;
}
return std::nullopt;
},
rw_type_checker.type};
return std::nullopt;
},
rw_type_checker.type};
}
PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, CurrentDB &current_db) {
@ -2662,26 +2691,22 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr),
interpreter_context]( // NOLINT
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (!pull_plan) {
// Run the specific query
auto results = handler();
pull_plan = std::make_shared<PullPlanVector>(std::move(results));
#ifdef MG_ENTERPRISE
// Invalidate auth cache after every type of AuthQuery
interpreter_context->auth_checker->ClearCache();
#endif
}
return PreparedQuery{
std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( // NOLINT
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (!pull_plan) {
// Run the specific query
auto results = handler();
pull_plan = std::make_shared<PullPlanVector>(std::move(results));
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE};
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE};
}
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
@ -2885,17 +2910,18 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) {
Callback CreateTrigger(TriggerQuery *trigger_query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
TriggerStore *trigger_store, InterpreterContext *interpreter_context, DbAccessor *dba,
std::optional<std::string> owner) {
std::shared_ptr<QueryUserOrRole> user_or_role) {
// Make a copy of the user and pass it to the subsystem
auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename());
return {{},
[trigger_name = std::move(trigger_query->trigger_name_),
trigger_statement = std::move(trigger_query->statement_), event_type = trigger_query->event_type_,
before_commit = trigger_query->before_commit_, trigger_store, interpreter_context, dba, user_parameters,
owner = std::move(owner)]() mutable -> std::vector<std::vector<TypedValue>> {
trigger_store->AddTrigger(std::move(trigger_name), trigger_statement, user_parameters,
ToTriggerEventType(event_type),
before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT,
&interpreter_context->ast_cache, dba, interpreter_context->config.query,
std::move(owner), interpreter_context->auth_checker);
trigger_store->AddTrigger(
std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type),
before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT,
&interpreter_context->ast_cache, dba, interpreter_context->config.query, std::move(owner));
memgraph::metrics::IncrementCounter(memgraph::metrics::TriggersCreated);
return {};
}};
@ -2937,7 +2963,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra
std::vector<Notification> *notifications, CurrentDB &current_db,
InterpreterContext *interpreter_context,
const std::map<std::string, storage::PropertyValue> &user_parameters,
std::optional<std::string> const &username) {
std::shared_ptr<QueryUserOrRole> user_or_role) {
if (in_explicit_transaction) {
throw TriggerModificationInMulticommandTxException();
}
@ -2951,8 +2977,9 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra
MG_ASSERT(trigger_query);
std::optional<Notification> trigger_notification;
auto callback = std::invoke([trigger_query, trigger_store, interpreter_context, dba, &user_parameters,
owner = username, &trigger_notification]() mutable {
owner = std::move(user_or_role), &trigger_notification]() mutable {
switch (trigger_query->action_) {
case TriggerQuery::Action::CREATE_TRIGGER:
trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::CREATE_TRIGGER,
@ -2990,7 +3017,8 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, CurrentDB &current_db,
InterpreterContext *interpreter_context, const std::optional<std::string> &username) {
InterpreterContext *interpreter_context,
std::shared_ptr<QueryUserOrRole> user_or_role) {
if (in_explicit_transaction) {
throw StreamQueryInMulticommandTxException();
}
@ -3000,8 +3028,8 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran
auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query);
MG_ASSERT(stream_query);
auto callback =
HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, username, notifications);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context,
std::move(user_or_role), notifications);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -3305,7 +3333,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra
}
template <typename Func>
auto ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, const std::optional<std::string> &username,
auto ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, QueryUserOrRole *user_or_role,
Func &&privilege_checker) -> std::vector<std::vector<TypedValue>> {
std::vector<std::vector<TypedValue>> results;
results.reserve(interpreters.size());
@ -3325,11 +3353,21 @@ auto ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, con
static std::string all;
return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all;
};
if (transaction_id.has_value() &&
(interpreter->username_ == username || privilege_checker(get_interpreter_db_name()))) {
auto same_user = [](const auto &lv, const auto &rv) {
if (lv.get() == rv) return true;
if (lv && rv) return *lv == *rv;
return false;
};
if (transaction_id.has_value() && (same_user(interpreter->user_or_role_, user_or_role) ||
privilege_checker(user_or_role, get_interpreter_db_name()))) {
const auto &typed_queries = interpreter->GetQueries();
results.push_back({TypedValue(interpreter->username_.value_or("")),
TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)});
results.push_back(
{TypedValue(interpreter->user_or_role_
? (interpreter->user_or_role_->username() ? *interpreter->user_or_role_->username() : "")
: ""),
TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)});
// Handle user-defined metadata
std::map<std::string, TypedValue> metadata_tv;
if (interpreter->metadata_) {
@ -3344,17 +3382,19 @@ auto ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, con
}
Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
const std::optional<std::string> &username, const Parameters &parameters,
std::shared_ptr<QueryUserOrRole> user_or_role, const Parameters &parameters,
InterpreterContext *interpreter_context) {
auto privilege_checker = [username, auth_checker = interpreter_context->auth_checker](std::string const &db_name) {
return auth_checker->IsUserAuthorized(username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name);
auto privilege_checker = [](QueryUserOrRole *user_or_role, std::string const &db_name) {
return user_or_role && user_or_role->IsAuthorized({query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name,
&query::up_to_date_policy);
};
Callback callback;
switch (transaction_query->action_) {
case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: {
auto show_transactions = [username, privilege_checker = std::move(privilege_checker)](const auto &interpreters) {
return ShowTransactions(interpreters, username, privilege_checker);
auto show_transactions = [user_or_role = std::move(user_or_role),
privilege_checker = std::move(privilege_checker)](const auto &interpreters) {
return ShowTransactions(interpreters, user_or_role.get(), privilege_checker);
};
callback.header = {"username", "transaction_id", "query", "metadata"};
callback.fn = [interpreter_context, show_transactions = std::move(show_transactions)] {
@ -3372,9 +3412,10 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
return std::string(expression->Accept(evaluator).ValueString());
});
callback.header = {"transaction_id", "killed"};
callback.fn = [interpreter_context, maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids), username,
callback.fn = [interpreter_context, maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids),
user_or_role = std::move(user_or_role),
privilege_checker = std::move(privilege_checker)]() mutable {
return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), username,
return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), user_or_role.get(),
std::move(privilege_checker));
};
break;
@ -3384,12 +3425,12 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
return callback;
}
PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional<std::string> &username,
PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, std::shared_ptr<QueryUserOrRole> user_or_role,
InterpreterContext *interpreter_context) {
auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query);
MG_ASSERT(transaction_queue_query);
auto callback =
HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context);
auto callback = HandleTransactionQueueQuery(transaction_queue_query, std::move(user_or_role), parsed_query.parameters,
interpreter_context);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -4022,7 +4063,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
}
PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context,
const std::optional<std::string> &username) {
std::shared_ptr<QueryUserOrRole> user_or_role) {
#ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
@ -4033,7 +4074,8 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon
Callback callback;
callback.header = {"Name"};
callback.fn = [auth, db_handler, username]() mutable -> std::vector<std::vector<TypedValue>> {
callback.fn = [auth, db_handler,
user_or_role = std::move(user_or_role)]() mutable -> std::vector<std::vector<TypedValue>> {
std::vector<std::vector<TypedValue>> status;
auto gen_status = [&]<typename T, typename K>(T all, K denied) {
Sort(all);
@ -4055,12 +4097,12 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon
status.erase(iter, status.end());
};
if (!username) {
if (!user_or_role || !*user_or_role) {
// No user, return all
gen_status(db_handler->All(), std::vector<TypedValue>{});
} else {
// User has a subset of accessible dbs; this is synched with the SessionContextHandler
const auto &db_priv = auth->GetDatabasePrivileges(*username);
const auto &db_priv = auth->GetDatabasePrivileges(user_or_role->key());
const auto &allowed = db_priv[0][0];
const auto &denied = db_priv[0][1].ValueList();
if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) {
@ -4128,6 +4170,7 @@ void Interpreter::SetCurrentDB(std::string_view db_name, bool in_explicit_db) {
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::map<std::string, storage::PropertyValue> &params,
QueryExtras const &extras) {
MG_ASSERT(user_or_role_, "Trying to prepare a query without a query user.");
// Handle transaction control queries.
const auto upper_case_query = utils::ToUpperCase(query_string);
const auto trimmed_query = utils::Trim(upper_case_query);
@ -4270,7 +4313,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
frame_change_collector_.emplace();
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
current_db_, memory_resource, &query_execution->notifications, username_,
current_db_, memory_resource, &query_execution->notifications, user_or_role_,
&transaction_status_, current_timeout_timer_, &*frame_change_collector_);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary,
@ -4278,7 +4321,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
&query_execution->notifications, interpreter_context_, current_db_,
&query_execution->execution_memory_with_exception, username_,
&query_execution->execution_memory_with_exception, user_or_role_,
&transaction_status_, current_timeout_timer_, &*frame_change_collector_);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), current_db_);
@ -4322,11 +4365,11 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<TriggerQuery>(parsed_query.query)) {
prepared_query =
PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
current_db_, interpreter_context_, params, username_);
current_db_, interpreter_context_, params, user_or_role_);
} else if (utils::Downcast<StreamQuery>(parsed_query.query)) {
prepared_query =
PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
current_db_, interpreter_context_, username_);
current_db_, interpreter_context_, user_or_role_);
} else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) {
prepared_query = PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, current_db_, this);
} else if (utils::Downcast<CreateSnapshotQuery>(parsed_query.query)) {
@ -4347,7 +4390,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (in_explicit_transaction_) {
throw TransactionQueueInMulticommandTxException();
}
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, interpreter_context_);
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), user_or_role_, interpreter_context_);
} else if (utils::Downcast<MultiDatabaseQuery>(parsed_query.query)) {
if (in_explicit_transaction_) {
throw MultiDatabaseQueryInMulticommandTxException();
@ -4357,7 +4400,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query =
PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_);
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, user_or_role_);
} else if (utils::Downcast<EdgeImportModeQuery>(parsed_query.query)) {
if (in_explicit_transaction_) {
throw EdgeImportModeModificationInMulticommandTxException();
@ -4433,6 +4476,12 @@ std::vector<TypedValue> Interpreter::GetQueries() {
void Interpreter::Abort() {
bool decrement = true;
// System tx
// TODO Implement system transaction scope and the ability to abort
system_transaction_.reset();
// Data tx
auto expected = TransactionStatus::ACTIVE;
while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) {
if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) {
@ -4484,8 +4533,7 @@ void RunTriggersAfterCommit(dbms::DatabaseAccess db_acc, InterpreterContext *int
trigger_context.AdaptForAccessor(&db_accessor);
try {
trigger.Execute(&db_accessor, &execution_memory, flags::run_time::GetExecutionTimeout(),
&interpreter_context->is_shutting_down, transaction_status, trigger_context,
interpreter_context->auth_checker);
&interpreter_context->is_shutting_down, transaction_status, trigger_context);
} catch (const utils::BasicException &exception) {
spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what());
db_accessor.Abort();
@ -4642,8 +4690,7 @@ void Interpreter::Commit() {
AdvanceCommand();
try {
trigger.Execute(&*current_db_.execution_db_accessor_, &execution_memory, flags::run_time::GetExecutionTimeout(),
&interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context,
interpreter_context_->auth_checker);
&interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context);
} catch (const utils::BasicException &e) {
throw utils::BasicException(
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what()));
@ -4758,7 +4805,7 @@ void Interpreter::SetNextTransactionIsolationLevel(const storage::IsolationLevel
void Interpreter::SetSessionIsolationLevel(const storage::IsolationLevel isolation_level) {
interpreter_isolation_level.emplace(isolation_level);
}
void Interpreter::ResetUser() { username_.reset(); }
void Interpreter::SetUser(std::string_view username) { username_ = username; }
void Interpreter::ResetUser() { user_or_role_.reset(); }
void Interpreter::SetUser(std::shared_ptr<QueryUserOrRole> user_or_role) { user_or_role_ = std::move(user_or_role); }
} // namespace memgraph::query

View File

@ -210,7 +210,7 @@ class Interpreter final {
std::optional<std::string> db;
};
std::optional<std::string> username_;
std::shared_ptr<QueryUserOrRole> user_or_role_{};
bool in_explicit_transaction_{false};
CurrentDB current_db_;
@ -300,7 +300,7 @@ class Interpreter final {
void ResetUser();
void SetUser(std::string_view username);
void SetUser(std::shared_ptr<QueryUserOrRole> user);
std::optional<memgraph::system::Transaction> system_transaction_{};

View File

@ -35,13 +35,13 @@ InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbm
}
std::vector<std::vector<TypedValue>> InterpreterContext::TerminateTransactions(
std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username,
std::function<bool(std::string const &)> privilege_checker) {
std::vector<std::string> maybe_kill_transaction_ids, QueryUserOrRole *user_or_role,
std::function<bool(QueryUserOrRole *, std::string const &)> privilege_checker) {
auto not_found_midpoint = maybe_kill_transaction_ids.end();
// Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed
// TERMINATE and SHOW TRANSACTIONS are mutually exclusive
interpreters.WithLock([&not_found_midpoint, &maybe_kill_transaction_ids, username,
interpreters.WithLock([&not_found_midpoint, &maybe_kill_transaction_ids, user_or_role,
privilege_checker = std::move(privilege_checker)](const auto &interpreters) {
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
@ -73,7 +73,15 @@ std::vector<std::vector<TypedValue>> InterpreterContext::TerminateTransactions(
static std::string all;
return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all;
};
if (interpreter->username_ == username || privilege_checker(get_interpreter_db_name())) {
auto same_user = [](const auto &lv, const auto &rv) {
if (lv.get() == rv) return true;
if (lv && rv) return *lv == *rv;
return false;
};
if (same_user(interpreter->user_or_role_, user_or_role) ||
privilege_checker(user_or_role, get_interpreter_db_name())) {
killed = true; // Note: this is used by the above `clean_status` (OnScopeExit)
spdlog::warn("Transaction {} successfully killed", transaction_id);
} else {

View File

@ -46,6 +46,7 @@ constexpr uint64_t kInterpreterTransactionInitialId = 1ULL << 63U;
class AuthQueryHandler;
class AuthChecker;
class Interpreter;
struct QueryUserOrRole;
/**
* Holds data shared between multiple `Interpreter` instances (which might be
@ -95,8 +96,8 @@ struct InterpreterContext {
void Shutdown() { is_shutting_down.store(true, std::memory_order_release); }
std::vector<std::vector<TypedValue>> TerminateTransactions(
std::vector<std::string> maybe_kill_transaction_ids, const std::optional<std::string> &username,
std::function<bool(std::string const &)> privilege_checker);
std::vector<std::string> maybe_kill_transaction_ids, QueryUserOrRole *user_or_role,
std::function<bool(QueryUserOrRole *, std::string const &)> privilege_checker);
};
} // namespace memgraph::query

21
src/query/query_user.cpp Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/query_user.hpp"
namespace memgraph::query {
// The variables below are used to define a user auth policy.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
SessionLongPolicy session_long_policy;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
UpToDatePolicy up_to_date_policy;
} // namespace memgraph::query

61
src/query/query_user.hpp Normal file
View File

@ -0,0 +1,61 @@
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <string>
#include <vector>
#include "query/frontend/ast/ast.hpp"
namespace memgraph::query {
class UserPolicy {
public:
virtual bool DoUpdate() const = 0;
};
extern struct SessionLongPolicy : UserPolicy {
public:
bool DoUpdate() const override { return false; }
} session_long_policy;
extern struct UpToDatePolicy : UserPolicy {
public:
bool DoUpdate() const override { return true; }
} up_to_date_policy;
struct QueryUserOrRole {
QueryUserOrRole(std::optional<std::string> username, std::optional<std::string> rolename)
: username_{std::move(username)}, rolename_{std::move(rolename)} {}
virtual ~QueryUserOrRole() = default;
virtual bool IsAuthorized(const std::vector<AuthQuery::Privilege> &privileges, const std::string &db_name,
UserPolicy *policy) const = 0;
#ifdef MG_ENTERPRISE
virtual std::string GetDefaultDB() const = 0;
#endif
std::string key() const {
// NOTE: Each role has an associated username, that's why we check it with higher priority
return rolename_ ? *rolename_ : (username_ ? *username_ : "");
}
const std::optional<std::string> &username() const { return username_; }
const std::optional<std::string> &rolename() const { return rolename_; }
bool operator==(const QueryUserOrRole &other) const = default;
operator bool() const { return username_.has_value(); }
private:
std::optional<std::string> username_;
std::optional<std::string> rolename_;
};
} // namespace memgraph::query

View File

@ -29,6 +29,7 @@
#include "query/procedure/mg_procedure_helpers.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/query_user.hpp"
#include "query/stream/sources.hpp"
#include "query/typed_value.hpp"
#include "utils/event_counter.hpp"
@ -131,6 +132,7 @@ StreamStatus<TStream> CreateStatus(std::string stream_name, std::string transfor
const std::string kStreamName{"name"};
const std::string kIsRunningKey{"is_running"};
const std::string kOwner{"owner"};
const std::string kOwnerRole{"owner_role"};
const std::string kType{"type"};
} // namespace
@ -142,6 +144,11 @@ void to_json(nlohmann::json &data, StreamStatus<TStream> &&status) {
if (status.owner.has_value()) {
data[kOwner] = std::move(*status.owner);
if (status.owner_role.has_value()) {
data[kOwnerRole] = std::move(*status.owner_role);
} else {
data[kOwnerRole] = nullptr;
}
} else {
data[kOwner] = nullptr;
}
@ -156,6 +163,11 @@ void from_json(const nlohmann::json &data, StreamStatus<TStream> &status) {
if (const auto &owner = data.at(kOwner); !owner.is_null()) {
status.owner = owner.get<typename decltype(status.owner)::value_type>();
if (const auto &owner_role = data.at(kOwnerRole); !owner_role.is_null()) {
owner_role.get_to(status.owner_role);
} else {
status.owner_role = {};
}
} else {
status.owner = {};
}
@ -449,7 +461,7 @@ void Streams::RegisterPulsarProcedures() {
template <Stream TStream, typename TDbAccess>
void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info,
std::optional<std::string> owner, TDbAccess db_acc, InterpreterContext *ic) {
std::shared_ptr<QueryUserOrRole> owner, TDbAccess db_acc, InterpreterContext *ic) {
auto locked_streams = streams_.Lock();
auto it = CreateConsumer<TStream, TDbAccess>(*locked_streams, stream_name, std::move(info), std::move(owner),
std::move(db_acc), ic);
@ -469,31 +481,39 @@ void Streams::Create(const std::string &stream_name, typename TStream::StreamInf
template void Streams::Create<KafkaStream, dbms::DatabaseAccess>(const std::string &stream_name,
KafkaStream::StreamInfo info,
std::optional<std::string> owner,
std::shared_ptr<QueryUserOrRole> owner,
dbms::DatabaseAccess db, InterpreterContext *ic);
template void Streams::Create<PulsarStream, dbms::DatabaseAccess>(const std::string &stream_name,
PulsarStream::StreamInfo info,
std::optional<std::string> owner,
std::shared_ptr<QueryUserOrRole> owner,
dbms::DatabaseAccess db, InterpreterContext *ic);
template <Stream TStream, typename TDbAccess>
Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name,
typename TStream::StreamInfo stream_info,
std::optional<std::string> owner, TDbAccess db_acc,
std::shared_ptr<QueryUserOrRole> owner, TDbAccess db_acc,
InterpreterContext *interpreter_context) {
if (map.contains(stream_name)) {
throw StreamsException{"Stream already exists with name '{}'", stream_name};
}
auto ownername = owner->username();
auto rolename = owner->rolename();
auto *memory_resource = utils::NewDeleteResource();
auto consumer_function = [interpreter_context, memory_resource, stream_name,
transformation_name = stream_info.common_info.transformation_name, owner = owner,
transformation_name = stream_info.common_info.transformation_name, owner = std::move(owner),
interpreter = std::make_shared<Interpreter>(interpreter_context, std::move(db_acc)),
result = mgp_result{nullptr, memory_resource},
total_retries = interpreter_context->config.stream_transaction_conflict_retries,
retry_interval = interpreter_context->config.stream_transaction_retry_interval](
const std::vector<typename TStream::Message> &messages) mutable {
// Set interpreter's user to the stream owner
// NOTE: We generate an empty user to avoid generating interpreter's fine grained access control and rely only on
// the global auth_checker used in the stream itself
// TODO: Fix auth inconsistency
interpreter->SetUser(interpreter_context->auth_checker->GenQueryUser(std::nullopt, std::nullopt));
#ifdef MG_ENTERPRISE
interpreter->OnChangeCB([](auto) { return false; }); // Disable database change
#endif
@ -523,12 +543,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
spdlog::trace("Processing row in stream '{}'", stream_name);
auto [query_value, params_value] = ExtractTransformationResult(row.values, transformation_name, stream_name);
storage::PropertyValue params_prop{params_value};
std::string query{query_value.ValueString()};
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), {});
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) {
if (!owner->IsAuthorized(prepare_result.privileges, "", &up_to_date_policy)) {
throw StreamsException{
"Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the "
"query!",
@ -553,7 +572,8 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
};
auto insert_result = map.try_emplace(
stream_name, StreamData<TStream>{std::move(stream_info.common_info.transformation_name), std::move(owner),
stream_name, StreamData<TStream>{std::move(stream_info.common_info.transformation_name), std::move(ownername),
std::move(rolename),
std::make_unique<SynchronizedStreamSource<TStream>>(
stream_name, std::move(stream_info), std::move(consumer_function))});
MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name);
@ -575,6 +595,7 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) {
const auto create_consumer = [&, &stream_name = stream_name]<typename T>(StreamStatus<T> status,
auto &&stream_json_data) {
try {
// TODO: Migration
stream_json_data.get_to(status);
} catch (const nlohmann::json::type_error &exception) {
spdlog::warn(get_failed_message("invalid type conversion", exception.what()));
@ -586,8 +607,8 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) {
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name);
try {
auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner),
db, ic);
auto owner = ic->auth_checker->GenQueryUser(status.owner, status.owner_role);
auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(owner), db, ic);
if (status.is_running) {
std::visit(
[&](const auto &stream_data) {
@ -745,7 +766,7 @@ std::vector<StreamStatus<>> Streams::GetStreamInfo() const {
auto info = locked_stream_source->Info(stream_data.transformation_name);
result.emplace_back(StreamStatus<>{stream_name, StreamType(*locked_stream_source),
locked_stream_source->IsRunning(), std::move(info.common_info),
stream_data.owner});
stream_data.owner, stream_data.owner_role});
},
stream_data);
}

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -67,6 +67,7 @@ struct StreamStatus {
bool is_running;
StreamInfoType<T> info;
std::optional<std::string> owner;
std::optional<std::string> owner_role;
};
using TransformationResult = std::vector<std::vector<TypedValue>>;
@ -100,7 +101,7 @@ class Streams final {
///
/// @throws StreamsException if the stream with the same name exists or if the creation of Kafka consumer fails
template <Stream TStream, typename TDbAccess>
void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional<std::string> owner,
void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::shared_ptr<QueryUserOrRole> owner,
TDbAccess db, InterpreterContext *interpreter_context);
/// Deletes an existing stream and all the data that was persisted.
@ -182,6 +183,7 @@ class Streams final {
struct StreamData {
std::string transformation_name;
std::optional<std::string> owner;
std::optional<std::string> owner_role;
std::unique_ptr<SynchronizedStreamSource<TStream>> stream_source;
};
@ -191,7 +193,7 @@ class Streams final {
template <Stream TStream, typename TDbAccess>
StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name,
typename TStream::StreamInfo stream_info, std::optional<std::string> owner,
typename TStream::StreamInfo stream_info, std::shared_ptr<QueryUserOrRole> owner,
TDbAccess db, InterpreterContext *interpreter_context);
template <Stream TStream>

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,14 +11,13 @@
#include "query/trigger.hpp"
#include <concepts>
#include "query/config.hpp"
#include "query/context.hpp"
#include "query/cypher_query_interpreter.hpp"
#include "query/db_accessor.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpret/frame.hpp"
#include "query/query_user.hpp"
#include "query/serialization/property_value.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/property_value.hpp"
@ -154,20 +153,19 @@ Trigger::Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache,
DbAccessor *db_accessor, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker)
std::shared_ptr<QueryUserOrRole> owner)
: name_{std::move(name)},
parsed_statements_{ParseQuery(query, user_parameters, query_cache, query_config)},
event_type_{event_type},
owner_{std::move(owner)} {
// We check immediately if the query is valid by trying to create a plan.
GetPlan(db_accessor, auth_checker);
GetPlan(db_accessor);
}
Trigger::TriggerPlan::TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers)
: cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {}
std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
const query::AuthChecker *auth_checker) const {
std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor) const {
std::lock_guard plan_guard{plan_lock_};
if (!parsed_statements_.is_cacheable || !trigger_plan_) {
auto identifiers = GetPredefinedIdentifiers(event_type_);
@ -187,7 +185,7 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
}
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) {
if (!owner_->IsAuthorized(parsed_statements_.required_privileges, "", &up_to_date_policy)) {
throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_);
}
return trigger_plan_;
@ -195,14 +193,13 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory,
const double max_execution_time_sec, std::atomic<bool> *is_shutting_down,
std::atomic<TransactionStatus> *transaction_status, const TriggerContext &context,
const AuthChecker *auth_checker) const {
std::atomic<TransactionStatus> *transaction_status, const TriggerContext &context) const {
if (!context.ShouldEventTrigger(event_type_)) {
return;
}
spdlog::debug("Executing trigger '{}'", name_);
auto trigger_plan = GetPlan(dba, auth_checker);
auto trigger_plan = GetPlan(dba);
MG_ASSERT(trigger_plan, "Invalid trigger plan received");
auto &[plan, identifiers] = *trigger_plan;
@ -308,6 +305,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
}
const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]);
// TODO: Migration
const auto owner_json = json_trigger_data["owner"];
std::optional<std::string> owner{};
if (owner_json.is_string()) {
@ -317,10 +315,21 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
continue;
}
const auto owner_role_json = json_trigger_data["owner_role"];
std::optional<std::string> role{};
if (owner_role_json.is_string()) {
owner.emplace(owner_role_json.get<std::string>());
} else if (!owner_role_json.is_null()) {
spdlog::warn(invalid_state_message);
continue;
}
auto user = auth_checker->GenQueryUser(owner, role);
std::optional<Trigger> trigger;
try {
trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, query_config,
std::move(owner), auth_checker);
std::move(user));
} catch (const utils::BasicException &e) {
spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what());
continue;
@ -338,8 +347,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
TriggerEventType event_type, TriggerPhase phase,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker) {
const InterpreterConfig::Query &query_config, std::shared_ptr<QueryUserOrRole> owner) {
std::unique_lock store_guard{store_lock_};
if (storage_.Get(name)) {
throw utils::BasicException("Trigger with the same name already exists.");
@ -348,7 +356,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query,
std::optional<Trigger> trigger;
try {
trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, query_config,
std::move(owner), auth_checker);
std::move(owner));
} catch (const utils::BasicException &e) {
const auto identifiers = GetPredefinedIdentifiers(event_type);
std::stringstream identifier_names_stream;
@ -370,10 +378,23 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query,
data["phase"] = phase;
data["version"] = kVersion;
if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) {
data["owner"] = *owner_from_trigger;
if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger && *owner_from_trigger) {
const auto &maybe_username = owner_from_trigger->username();
if (maybe_username) {
data["owner"] = *maybe_username;
// Roles need to be associated with a username
const auto &maybe_rolename = owner_from_trigger->rolename();
if (maybe_rolename) {
data["owner_role"] = *maybe_rolename;
} else {
data["owner_role"] = nullptr;
}
} else {
data["owner"] = nullptr;
}
} else {
data["owner"] = nullptr;
data["owner_role"] = nullptr;
}
storage_.Put(trigger->Name(), data.dump());
store_guard.unlock();
@ -417,7 +438,9 @@ std::vector<TriggerStore::TriggerInfo> TriggerStore::GetTriggerInfo() const {
const auto add_info = [&](const utils::SkipList<Trigger> &trigger_list, const TriggerPhase phase) {
for (const auto &trigger : trigger_list.access()) {
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()});
std::optional<std::string> owner_str{};
if (const auto &owner = trigger.Owner(); owner && *owner) owner_str = owner->username();
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, std::move(owner_str)});
}
};

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -37,12 +37,11 @@ struct Trigger {
explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);
const InterpreterConfig::Query &query_config, std::shared_ptr<QueryUserOrRole> owner);
void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec,
std::atomic<bool> *is_shutting_down, std::atomic<TransactionStatus> *transaction_status,
const TriggerContext &context, const AuthChecker *auth_checker) const;
const TriggerContext &context) const;
bool operator==(const Trigger &other) const { return name_ == other.name_; }
// NOLINTNEXTLINE (modernize-use-nullptr)
@ -65,7 +64,7 @@ struct Trigger {
PlanWrapper cached_plan;
std::vector<IdentifierInfo> identifiers;
};
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const;
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor) const;
std::string name_;
ParsedQuery parsed_statements_;
@ -74,7 +73,7 @@ struct Trigger {
mutable utils::SpinLock plan_lock_;
mutable std::shared_ptr<TriggerPlan> trigger_plan_;
std::optional<std::string> owner_;
std::shared_ptr<QueryUserOrRole> owner_;
};
enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT };
@ -88,8 +87,7 @@ struct TriggerStore {
void AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);
const InterpreterConfig::Query &query_config, std::shared_ptr<QueryUserOrRole> owner);
void DropTrigger(const std::string &name);

View File

@ -103,7 +103,8 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state,
inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
utils::BasicResult<replication::RegisterReplicaError, replication::ReplicationClient *> &instance_client) {
if (instance_client.HasError()) switch (instance_client.GetError()) {
if (instance_client.HasError()) {
switch (instance_client.GetError()) {
case replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!");
return {};
@ -116,6 +117,7 @@ inline std::optional<query::RegisterReplicaError> HandleRegisterReplicaStatus(
case replication::RegisterReplicaError::SUCCESS:
break;
}
}
return {};
}

View File

@ -27,6 +27,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e
class ExpansionBenchFixture : public benchmark::Fixture {
protected:
std::optional<memgraph::system::System> system;
std::optional<memgraph::query::AllowEverythingAuthChecker> auth_checker;
std::optional<memgraph::query::InterpreterContext> interpreter_context;
std::optional<memgraph::query::Interpreter> interpreter;
std::optional<memgraph::utils::Gatekeeper<memgraph::dbms::Database>> db_gk;
@ -43,6 +44,7 @@ class ExpansionBenchFixture : public benchmark::Fixture {
auto &db_acc = *db_acc_opt;
system.emplace();
auth_checker.emplace();
interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system
#ifdef MG_ENTERPRISE
,
@ -73,13 +75,15 @@ class ExpansionBenchFixture : public benchmark::Fixture {
}
interpreter.emplace(&*interpreter_context, std::move(db_acc));
interpreter->SetUser(auth_checker->GenQueryUser(std::nullopt, std::nullopt));
}
void TearDown(const benchmark::State &) override {
interpreter = std::nullopt;
interpreter_context = std::nullopt;
system.reset();
db_gk.reset();
auth_checker.reset();
system.reset();
std::filesystem::remove_all(data_directory);
}
};

View File

@ -14,14 +14,7 @@
# If you wish to modify these, update the startup_config_dict and workloads.yaml !
startup_config_dict = {
"auth_module_create_missing_role": ("true", "true", "Set to false to disable creation of missing roles."),
"auth_module_create_missing_user": ("true", "true", "Set to false to disable creation of missing users."),
"auth_module_executable": ("", "", "Absolute path to the auth module executable that should be used."),
"auth_module_manage_roles": (
"true",
"true",
"Set to false to disable management of roles through the auth module.",
),
"auth_module_timeout_ms": (
"10000",
"10000",

View File

@ -19,10 +19,10 @@ from mgclient import DatabaseError
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
@ -33,10 +33,10 @@ def test_create_node_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -47,10 +47,10 @@ def test_create_node_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;")
@ -61,10 +61,10 @@ def test_create_node_specific_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -75,10 +75,10 @@ def test_create_node_specific_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
@ -91,10 +91,10 @@ def test_delete_node_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -105,10 +105,10 @@ def test_delete_node_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;")
@ -123,10 +123,10 @@ def test_delete_node_specific_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -137,11 +137,11 @@ def test_delete_node_specific_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -156,11 +156,11 @@ def test_create_edge_all_labels_all_edge_types_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -174,11 +174,11 @@ def test_create_edge_all_labels_all_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -192,11 +192,11 @@ def test_create_edge_all_labels_denied_all_edge_types_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -210,7 +210,6 @@ def test_create_edge_all_labels_granted_all_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
@ -218,6 +217,7 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -231,7 +231,6 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
@ -240,6 +239,7 @@ def test_create_edge_first_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -253,7 +253,6 @@ def test_create_edge_first_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_create_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
@ -262,6 +261,7 @@ def test_create_edge_second_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -275,11 +275,11 @@ def test_create_edge_second_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -293,11 +293,11 @@ def test_delete_edge_all_labels_denied_all_edge_types_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -311,7 +311,6 @@ def test_delete_edge_all_labels_granted_all_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
@ -319,6 +318,7 @@ def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -332,7 +332,6 @@ def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;")
@ -341,6 +340,7 @@ def test_delete_edge_first_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -354,7 +354,6 @@ def test_delete_edge_first_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;")
@ -363,6 +362,7 @@ def test_delete_edge_second_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -376,13 +376,13 @@ def test_delete_edge_second_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_with_edge_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT UPDATE ON LABELS :test_delete_1 TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -393,13 +393,13 @@ def test_delete_node_with_edge_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_delete_node_with_edge_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(
admin_connection.cursor(),
"GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -415,10 +415,10 @@ def test_delete_node_with_edge_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
@ -429,10 +429,10 @@ def test_merge_node_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -443,10 +443,10 @@ def test_merge_node_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_specific_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;")
@ -457,10 +457,10 @@ def test_merge_node_specific_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_node_specific_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -471,11 +471,11 @@ def test_merge_node_specific_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(
@ -489,11 +489,11 @@ def test_merge_edge_all_labels_all_edge_types_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -507,11 +507,11 @@ def test_merge_edge_all_labels_all_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_denied_all_edge_types_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -525,11 +525,11 @@ def test_merge_edge_all_labels_denied_all_edge_types_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_granted_all_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -543,7 +543,6 @@ def test_merge_edge_all_labels_granted_all_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
@ -551,6 +550,7 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch):
admin_connection.cursor(),
"GRANT UPDATE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -564,7 +564,6 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_first_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;")
@ -573,6 +572,7 @@ def test_merge_edge_first_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -586,7 +586,6 @@ def test_merge_edge_first_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_edge_second_node_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;")
@ -595,6 +594,7 @@ def test_merge_edge_second_node_label_granted(switch):
admin_connection.cursor(),
"GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;",
)
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -608,10 +608,10 @@ def test_merge_edge_second_node_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_set_label_when_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -621,12 +621,12 @@ def test_set_label_when_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_set_label_when_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -637,11 +637,11 @@ def test_set_label_when_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_remove_label_when_label_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -651,12 +651,12 @@ def test_remove_label_when_label_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_remove_label_when_label_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -667,12 +667,12 @@ def test_remove_label_when_label_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_merge_nodes_pass_when_having_create_delete(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.reset_and_prepare(admin_connection.cursor())
common.create_multi_db(admin_connection.cursor(), switch)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())

View File

@ -7,9 +7,9 @@ import pytest
@pytest.mark.parametrize("switch", [False, True])
def test_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -21,9 +21,9 @@ def test_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_deny_all_edge_types_and_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -35,9 +35,9 @@ def test_deny_all_edge_types_and_all_labels(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_edge_types_and_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -49,10 +49,10 @@ def test_revoke_all_edge_types_and_all_labels(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_deny_edge_type(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -64,10 +64,10 @@ def test_deny_edge_type(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_denied_node_label(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -79,10 +79,10 @@ def test_denied_node_label(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_denied_one_of_node_label(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -94,8 +94,8 @@ def test_denied_one_of_node_label(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_labels(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")
@ -106,8 +106,8 @@ def test_revoke_all_labels(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_revoke_all_edge_types(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;")

View File

@ -7,11 +7,11 @@ import pytest
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -54,11 +54,11 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -72,7 +72,6 @@ def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -80,6 +79,7 @@ def test_weighted_shortest_path_denied_start(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -94,7 +94,6 @@ def test_weighted_shortest_path_denied_start(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -102,6 +101,7 @@ def test_weighted_shortest_path_denied_destination(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -116,7 +116,6 @@ def test_weighted_shortest_path_denied_destination(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -124,6 +123,7 @@ def test_weighted_shortest_path_denied_label_1(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -162,7 +162,6 @@ def test_weighted_shortest_path_denied_label_1(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_weighted_shortest_path_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -170,6 +169,7 @@ def test_weighted_shortest_path_denied_edge_type_3(switch):
admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;"
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -213,11 +213,11 @@ def test_weighted_shortest_path_denied_edge_type_3(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -235,11 +235,11 @@ def test_dfs_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -251,7 +251,6 @@ def test_dfs_all_edge_types_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -259,6 +258,7 @@ def test_dfs_denied_start(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -272,7 +272,6 @@ def test_dfs_denied_start(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -280,6 +279,7 @@ def test_dfs_denied_destination(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -293,7 +293,6 @@ def test_dfs_denied_destination(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -301,6 +300,7 @@ def test_dfs_denied_label_1(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -318,7 +318,6 @@ def test_dfs_denied_label_1(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_dfs_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
@ -327,6 +326,7 @@ def test_dfs_denied_edge_type_3(switch):
admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;"
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -344,11 +344,11 @@ def test_dfs_denied_edge_type_3(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -366,11 +366,11 @@ def test_bfs_sts_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -384,7 +384,6 @@ def test_bfs_sts_all_edge_types_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -392,6 +391,7 @@ def test_bfs_sts_denied_start(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -405,7 +405,6 @@ def test_bfs_sts_denied_start(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -413,6 +412,7 @@ def test_bfs_sts_denied_destination(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -426,7 +426,6 @@ def test_bfs_sts_denied_destination(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -434,6 +433,7 @@ def test_bfs_sts_denied_label_1(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -450,7 +450,6 @@ def test_bfs_sts_denied_label_1(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_sts_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -458,6 +457,7 @@ def test_bfs_sts_denied_edge_type_3(switch):
admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;"
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -474,11 +474,11 @@ def test_bfs_sts_denied_edge_type_3(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -496,11 +496,11 @@ def test_bfs_single_source_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -512,7 +512,6 @@ def test_bfs_single_source_all_edge_types_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -520,6 +519,7 @@ def test_bfs_single_source_denied_start(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -533,7 +533,6 @@ def test_bfs_single_source_denied_start(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -541,6 +540,7 @@ def test_bfs_single_source_denied_destination(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -554,7 +554,6 @@ def test_bfs_single_source_denied_destination(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -562,6 +561,7 @@ def test_bfs_single_source_denied_label_1(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -579,7 +579,6 @@ def test_bfs_single_source_denied_label_1(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_bfs_single_source_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -587,6 +586,7 @@ def test_bfs_single_source_denied_edge_type_3(switch):
admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;"
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -604,11 +604,11 @@ def test_bfs_single_source_denied_edge_type_3(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -651,11 +651,11 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -669,7 +669,6 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_start(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -677,6 +676,7 @@ def test_all_shortest_paths_when_denied_start(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -691,7 +691,6 @@ def test_all_shortest_paths_when_denied_start(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_destination(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -699,6 +698,7 @@ def test_all_shortest_paths_when_denied_destination(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -713,7 +713,6 @@ def test_all_shortest_paths_when_denied_destination(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_label_1(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(
@ -721,6 +720,7 @@ def test_all_shortest_paths_when_denied_label_1(switch):
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())
@ -759,7 +759,6 @@ def test_all_shortest_paths_when_denied_label_1(switch):
@pytest.mark.parametrize("switch", [False, True])
def test_all_shortest_paths_when_denied_edge_type_3(switch):
admin_connection = common.connect(username="admin", password="test")
user_connection = common.connect(username="user", password="test")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;")
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;")
@ -767,6 +766,7 @@ def test_all_shortest_paths_when_denied_edge_type_3(switch):
admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;"
)
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;")
user_connection = common.connect(username="user", password="test")
if switch:
common.switch_db(user_connection.cursor())

View File

@ -84,11 +84,12 @@ show_databases_w_user_setup_queries: &show_databases_w_user_setup_queries
- "GRANT DATABASE db1 TO user;"
- "GRANT ALL PRIVILEGES TO user2;"
- "GRANT DATABASE db2 TO user2;"
- "GRANT DATABASE memgraph TO user2;"
- "REVOKE DATABASE memgraph FROM user2;"
- "SET MAIN DATABASE db2 FOR user2"
- "GRANT ALL PRIVILEGES TO user3;"
- "GRANT DATABASE * TO user3;"
- "REVOKE DATABASE memgraph FROM user3;"
- "DENY DATABASE memgraph FROM user3;"
- "SET MAIN DATABASE db1 FOR user3"
create_delete_filtering_in_memory_cluster: &create_delete_filtering_in_memory_cluster

View File

@ -107,18 +107,21 @@ def execute_read_node_assertion(
operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool
) -> None:
admin_cursor = get_admin_cursor()
user_cursor = get_user_cursor()
if switch:
create_multi_db(admin_cursor)
switch_db(admin_cursor)
switch_db(user_cursor)
reset_permissions(admin_cursor, create_index)
for operation in operation_case:
execute_and_fetch_all(admin_cursor, operation)
# Connect after possible auth changes
user_cursor = get_user_cursor()
if switch:
switch_db(user_cursor)
for mq in queries:
results = execute_and_fetch_all(user_cursor, mq)
assert len(results) == expected_size

View File

@ -121,6 +121,7 @@ def only_main_queries(cursor):
n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name")
n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;")
n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name")
n_exceptions += try_and_count(cursor, f"DENY DATABASE memgraph FROM user_name;")
n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;")
return n_exceptions
@ -198,8 +199,8 @@ def test_auth_queries_on_replica(connection):
# 1/
assert only_main_queries(cursor_main) == 0
assert only_main_queries(cursor_replica_1) == 17
assert only_main_queries(cursor_replica_2) == 17
assert only_main_queries(cursor_replica_1) == 18
assert only_main_queries(cursor_replica_2) == 18
assert main_and_repl_queries(cursor_main) == 0
assert main_and_repl_queries(cursor_replica_1) == 0
assert main_and_repl_queries(cursor_replica_2) == 0
@ -383,6 +384,7 @@ def test_manual_roles_recovery(connection):
"--log-level=TRACE",
"--data_directory",
TEMP_DIR + "/replica1",
"--also-log-to-stderr",
],
"log_file": "replica1.log",
"setup_queries": [
@ -818,13 +820,15 @@ def test_auth_replication(connection):
{("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")},
)
# GRANT/REVOKE DATABASE
# GRANT/DENY DATABASE
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test")
execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2")
execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])])
execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4")
execute_and_fetch_all(cursor_main, "DENY DATABASE auth_test2 FROM user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])])
execute_and_fetch_all(cursor_main, "REVOKE DATABASE memgraph FROM user4")
check(partial(show_database_privileges_func, user="user4"), [(["auth_test"], ["auth_test2"])])
# SET MAIN DATABASE
execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4")

View File

@ -70,21 +70,26 @@ def test_multitenant_transactions():
# TODO Add SHOW TRANSACTIONS ON * that should return all transactions
def test_admin_has_one_transaction():
def test_admin_has_one_transaction(request):
"""Creates admin and tests that he sees only one transaction."""
# a_cursor is used for creating admin user, simulates main thread
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
request.addfinalizer(on_exit)
admin_cursor = connect(username="admin", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1))
process.start()
process.join()
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_user_can_see_its_transaction():
def test_user_can_see_its_transaction(request):
"""Tests that user without privileges can see its own transaction"""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
@ -92,20 +97,31 @@ def test_user_can_see_its_transaction():
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
request.addfinalizer(on_exit)
user_cursor = connect(username="user", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1))
process.start()
process.join()
admin_cursor = connect(username="admin", password="").cursor()
execute_and_fetch_all(admin_cursor, "DROP USER user")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_explicit_transaction_output():
def test_explicit_transaction_output(request):
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
request.addfinalizer(on_exit)
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
# Admin starts running explicit transaction
@ -123,10 +139,9 @@ def test_explicit_transaction_output():
assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]
execute_and_fetch_all(superadmin_cursor, "ROLLBACK")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_superadmin_cannot_see_admin_can_see_admin():
def test_superadmin_cannot_see_admin_can_see_admin(request):
"""Tests that superadmin cannot see the transaction created by admin but two admins can see and kill each other's transactions."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
@ -135,6 +150,13 @@ def test_superadmin_cannot_see_admin_can_see_admin():
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
request.addfinalizer(on_exit)
# Admin starts running infinite query
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
@ -160,19 +182,23 @@ def test_superadmin_cannot_see_admin_can_see_admin():
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
admin_connection_1.close()
admin_connection_2.close()
def test_admin_sees_superadmin():
def test_admin_sees_superadmin(request):
"""Tests that admin created by superadmin can see the superadmin's transaction."""
superadmin_connection = connect()
superadmin_cursor = superadmin_connection.cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
def on_exit():
execute_and_fetch_all(admin_cursor, "DROP USER admin")
request.addfinalizer(on_exit)
# Admin starts running infinite query
process = multiprocessing.Process(
target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
@ -194,17 +220,23 @@ def test_admin_sees_superadmin():
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
superadmin_connection.close()
def test_admin_can_see_user_transaction():
def test_admin_can_see_user_transaction(request):
"""Tests that admin can see user's transaction and kill it."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
request.addfinalizer(on_exit)
# Admin starts running infinite query
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
@ -229,13 +261,11 @@ def test_admin_can_see_user_transaction():
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection.close()
user_connection.close()
def test_user_cannot_see_admin_transaction():
def test_user_cannot_see_admin_transaction(request):
"""User cannot see admin's transaction but other admin can and he can kill it."""
# Superadmin creates two admins and one user
superadmin_cursor = connect().cursor()
@ -246,6 +276,14 @@ def test_user_cannot_see_admin_transaction():
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
def on_exit():
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
request.addfinalizer(on_exit)
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
admin_connection_2 = connect(username="admin2", password="")
@ -274,9 +312,6 @@ def test_user_cannot_see_admin_transaction():
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection_1.close()
admin_connection_2.close()
user_connection.close()
@ -300,12 +335,18 @@ def test_killing_multiple_non_existing_transactions():
assert results[i][1] == False # not killed
def test_admin_killing_multiple_non_existing_transactions():
def test_admin_killing_multiple_non_existing_transactions(request):
# Starting, superadmin admin
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin")
def on_exit():
execute_and_fetch_all(admin_cursor, "DROP USER admin")
request.addfinalizer(on_exit)
# Connect with admin
admin_cursor = connect(username="admin", password="").cursor()
transactions_id = ["'1'", "'2'", "'3'"]
@ -314,7 +355,6 @@ def test_admin_killing_multiple_non_existing_transactions():
for i in range(len(results)):
assert results[i][0] == eval(transactions_id[i]) # transaction id
assert results[i][1] == False # not killed
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_user_killing_some_transactions():

View File

@ -193,12 +193,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary):
"GRANT DATABASE db2 TO user",
"CREATE USER useR2 IDENTIFIED BY 'user'",
"GRANT DATABASE db2 TO user2",
"REVOKE DATABASE memgraph FROM user2",
"DENY DATABASE memgraph FROM user2",
"SET MAIN DATABASE db2 FOR user2",
"CREATE USER user3 IDENTIFIED BY 'user'",
"GRANT ALL PRIVILEGES TO user3",
"GRANT DATABASE * TO user3",
"REVOKE DATABASE memgraph FROM user3",
"DENY DATABASE memgraph FROM user3",
]
)

View File

@ -139,7 +139,7 @@ class Memgraph:
def initialize_test(memgraph, tester_binary, **kwargs):
memgraph.start(module_executable="")
execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"])
execute_tester(tester_binary, ["CREATE ROLE root_role", "GRANT ALL PRIVILEGES TO root_role"])
check_login = kwargs.pop("check_login", True)
memgraph.restart(**kwargs)
if check_login:
@ -149,20 +149,24 @@ def initialize_test(memgraph, tester_binary, **kwargs):
# Tests
def test_basic(memgraph, tester_binary):
def test_module_ux(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, ["CREATE USER user1"], "root", query_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE role1"], "root", query_should_fail=False)
execute_tester(tester_binary, ["DROP USER user1"], "root", query_should_fail=True)
execute_tester(tester_binary, ["DROP ROLE role1"], "root", query_should_fail=False)
execute_tester(tester_binary, ["SET ROLE FOR user1 TO role1"], "root", query_should_fail=True)
execute_tester(tester_binary, ["CLEAR ROLE FOR user1"], "root", query_should_fail=True)
memgraph.stop()
def test_only_existing_users(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, create_missing_user=False)
def test_user_auth(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice", auth_should_fail=True)
execute_tester(tester_binary, ["CREATE USER alice"], "root")
execute_tester(tester_binary, ["CREATE ROLE moderator"], "root")
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
@ -170,77 +174,50 @@ def test_only_existing_users(memgraph, tester_binary):
def test_role_mapping(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, [], "alice", auth_should_fail=True)
execute_tester(tester_binary, [], "bob", auth_should_fail=True)
execute_tester(tester_binary, [], "carol", auth_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator"], "root")
execute_tester(tester_binary, ["CREATE ROLE admin"], "root")
execute_tester(tester_binary, [], "alice", auth_should_fail=False)
execute_tester(tester_binary, [], "bob", auth_should_fail=True)
execute_tester(tester_binary, [], "carol", auth_should_fail=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO admin"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=False)
execute_tester(tester_binary, [], "bob")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True)
execute_tester(tester_binary, [], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True)
execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave")
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=False)
memgraph.stop()
def test_instance_restart(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, ["CREATE ROLE moderator"], "root")
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_removal(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "alice")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator"], "root")
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.stop()
def test_only_existing_roles(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, create_missing_role=False)
execute_tester(tester_binary, [], "bob")
execute_tester(tester_binary, ["DROP ROLE moderator"], "root")
execute_tester(tester_binary, [], "alice", auth_should_fail=True)
execute_tester(tester_binary, ["CREATE ROLE moderator"], "root")
execute_tester(tester_binary, [], "alice")
memgraph.stop()
def test_role_is_user(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "admin")
execute_tester(tester_binary, [], "carol", auth_should_fail=True)
memgraph.stop()
def test_user_is_role(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, [], "carol")
execute_tester(tester_binary, [], "admin", auth_should_fail=True)
memgraph.stop()
def test_user_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_permissions_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_only_authentication(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, manage_roles=False)
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.stop()
@ -258,36 +235,36 @@ def test_wrong_suffix(memgraph, tester_binary):
def test_suffix_with_spaces(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com")
execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, root_objectclass="person")
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
# def test_role_mapping_wrong_root_dn(memgraph, tester_binary):
# initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com")
# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
# memgraph.restart()
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
# memgraph.stop()
def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, user_attribute="cn")
execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
memgraph.stop()
# def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary):
# initialize_test(memgraph, tester_binary, root_objectclass="person")
# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
# memgraph.restart()
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
# memgraph.stop()
# def test_role_mapping_wrong_user_attribute(memgraph, tester_binary):
# initialize_test(memgraph, tester_binary, user_attribute="cn")
# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root")
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True)
# memgraph.restart()
# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice")
# memgraph.stop()
def test_wrong_password(memgraph, tester_binary):
@ -297,31 +274,9 @@ def test_wrong_password(memgraph, tester_binary):
memgraph.stop()
def test_password_persistancy(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, check_login=False)
memgraph.restart(module_executable="")
execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo")
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart()
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.restart(module_executable="")
execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True)
execute_tester(tester_binary, ["SHOW USERS"], "root", password="root")
memgraph.stop()
def test_user_multiple_roles(memgraph, tester_binary):
initialize_test(memgraph, tester_binary, check_login=False)
memgraph.restart()
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
memgraph.restart(manage_roles=False)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
memgraph.restart(manage_roles=False, root_dn="")
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True)
execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True)
initialize_test(memgraph, tester_binary)
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", auth_should_fail=True)
memgraph.stop()

View File

@ -84,6 +84,13 @@ objectclass: organizationalUnit
objectclass: top
ou: roles
# Role root
dn: cn=root_role,ou=roles,dc=memgraph,dc=com
cn: root_role
member: cn=root,ou=people,dc=memgraph,dc=com
objectclass: groupOfNames
objectclass: top
# Role moderator
dn: cn=moderator,ou=roles,dc=memgraph,dc=com
cn: moderator

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -48,6 +48,7 @@ int main(int argc, char **argv) {
}
if (FLAGS_auth_should_fail) {
MG_ASSERT(!what.empty(), "The authentication should have failed!");
return 0; // Auth failed, nothing left to do
} else {
MG_ASSERT(what.empty(),
"The authentication should have succeeded, but "

View File

@ -50,6 +50,8 @@ int main(int argc, char *argv[]) {
memgraph::query::Interpreter interpreter{&interpreter_context, db_acc};
ResultStreamFaker stream(db_acc->storage());
memgraph::query::AllowEverythingAuthChecker auth_checker;
interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt));
auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, {});
stream.Header(header);
auto summary = interpreter.PullAll(&stream);

View File

@ -280,6 +280,8 @@ TEST_F(AuthWithStorage, RoleManipulations) {
}
{
const auto all = auth->AllUsernames();
for (const auto &user : all) std::cout << user << std::endl;
auto users = auth->AllUsers();
std::sort(users.begin(), users.end(), [](const User &a, const User &b) { return a.username() < b.username(); });
ASSERT_EQ(users.size(), 2);
@ -774,14 +776,16 @@ TEST_F(AuthWithStorage, CaseInsensitivity) {
// Authenticate
{
auto user = auth->Authenticate("alice", "alice");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
auto user_or_role = auth->Authenticate("alice", "alice");
ASSERT_TRUE(user_or_role);
const auto &user = std::get<memgraph::auth::User>(*user_or_role);
ASSERT_EQ(user.username(), "alice");
}
{
auto user = auth->Authenticate("alICe", "alice");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
auto user_or_role = auth->Authenticate("alICe", "alice");
ASSERT_TRUE(user_or_role);
const auto &user = std::get<memgraph::auth::User>(*user_or_role);
ASSERT_EQ(user.username(), "alice");
}
// GetUser
@ -809,6 +813,8 @@ TEST_F(AuthWithStorage, CaseInsensitivity) {
// AllUsers
{
const auto all = auth->AllUsernames();
for (const auto &user : all) std::cout << user << std::endl;
auto users = auth->AllUsers();
ASSERT_EQ(users.size(), 2);
std::sort(users.begin(), users.end(), [](const auto &a, const auto &b) { return a.username() < b.username(); });

View File

@ -12,11 +12,14 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "auth/exceptions.hpp"
#include "auth/models.hpp"
#include "disk_test_utils.hpp"
#include "glue/auth_checker.hpp"
#include "license/license.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/query_user.hpp"
#include "query_plan_common.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/disk/storage.hpp"
@ -225,4 +228,123 @@ TYPED_TEST(FineGrainedAuthCheckerFixture, GrantAndDenySpecificEdgeTypes) {
ASSERT_FALSE(auth_checker.Has(this->r3, memgraph::query::AuthQuery::FineGrainedPrivilege::READ));
ASSERT_FALSE(auth_checker.Has(this->r4, memgraph::query::AuthQuery::FineGrainedPrivilege::READ));
}
TEST(AuthChecker, Generate) {
std::filesystem::path auth_dir{std::filesystem::temp_directory_path() / "MG_auth_checker"};
memgraph::utils::OnScopeExit clean([&]() {
if (std::filesystem::exists(auth_dir)) {
std::filesystem::remove_all(auth_dir);
}
});
memgraph::auth::SynchedAuth auth(auth_dir, memgraph::auth::Auth::Config{/* default config */});
memgraph::glue::AuthChecker auth_checker(&auth);
auto empty_user = auth_checker.GenQueryUser(std::nullopt, std::nullopt);
ASSERT_THROW(auth_checker.GenQueryUser("does_not_exist", std::nullopt), memgraph::auth::AuthException);
EXPECT_FALSE(empty_user && *empty_user);
// Still empty auth, so the above should have su permissions
using enum memgraph::query::AuthQuery::Privilege;
EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy));
EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph",
&memgraph::query::session_long_policy));
EXPECT_TRUE(
empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy));
// Add user
auth->AddUser("new_user");
// ~Empty user should now fail~
// NOTE: Cache invalidation has been disabled, so this will pass; change if it is ever turned on
EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy));
EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph",
&memgraph::query::session_long_policy));
EXPECT_TRUE(
empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy));
// Add role and new user
auto new_role = *auth->AddRole("new_role");
auto new_user2 = *auth->AddUser("new_user2");
auto role = auth_checker.GenQueryUser("anyuser", "new_role");
auto user2 = auth_checker.GenQueryUser("new_user2", std::nullopt);
// Should be permission-less by default
EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
// Update permissions and recheck
new_user2.permissions().Grant(memgraph::auth::Permission::AUTH);
new_role.permissions().Grant(memgraph::auth::Permission::TRIGGER);
auth->SaveUser(new_user2);
auth->SaveRole(new_role);
role = auth_checker.GenQueryUser("no check", "new_role");
user2 = auth_checker.GenQueryUser("new_user2", std::nullopt);
EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
// Connect role and recheck
new_user2.SetRole(new_role);
auth->SaveUser(new_user2);
user2 = auth_checker.GenQueryUser("new_user2", std::nullopt);
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_TRUE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
// Add database and recheck
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
new_user2.db_access().Grant("another");
new_role.db_access().Grant("non_default");
auth->SaveUser(new_user2);
auth->SaveRole(new_role);
// Session policy test
// Session long policy
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy));
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy));
// Up to date policy
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy));
new_user2.db_access().Deny("memgraph");
new_role.db_access().Deny("non_default");
auth->SaveUser(new_user2);
auth->SaveRole(new_role);
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy));
new_user2.db_access().Revoke("memgraph");
new_role.db_access().Revoke("non_default");
auth->SaveUser(new_user2);
auth->SaveRole(new_role);
EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy));
EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy));
EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy));
}
#endif

View File

@ -1,4 +1,4 @@
// Copyright 2023 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@ struct InterpreterFaker {
: interpreter_context(interpreter_context), interpreter(interpreter_context, db) {
interpreter_context->auth_checker = &auth_checker;
interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); });
interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt));
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2024 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -44,11 +44,9 @@ struct MockAuth : public memgraph::communication::websocket::AuthenticationInter
return authentication;
}
bool HasUserPermission(const std::string & /*username*/, memgraph::auth::Permission /*permission*/) const override {
return authorization;
}
bool HasPermission(memgraph::auth::Permission /*permission*/) const override { return authorization; }
bool HasAnyUsers() const override { return has_any_users; }
bool AccessControlled() const override { return has_any_users; }
bool authentication{true};
bool authorization{true};

View File

@ -21,6 +21,8 @@
#include "communication/result_stream_faker.hpp"
#include "dbms/database.hpp"
#include "disk_test_utils.hpp"
#include "glue/auth_checker.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/dump.hpp"
#include "query/interpreter.hpp"
@ -216,6 +218,8 @@ DatabaseState GetState(memgraph::storage::Storage *db) {
auto Execute(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db,
const std::string &query) {
memgraph::query::Interpreter interpreter(context, db);
memgraph::query::AllowEverythingAuthChecker auth_checker;
interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt));
ResultStreamFaker stream(db->storage());
auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, {});
@ -915,7 +919,10 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabase) {
class StatefulInterpreter {
public:
explicit StatefulInterpreter(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db)
: context_(context), interpreter_(context_, db) {}
: context_(context), interpreter_(context_, db) {
memgraph::query::AllowEverythingAuthChecker auth_checker;
interpreter_.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt));
}
auto Execute(const std::string &query) {
ResultStreamFaker stream(interpreter_.current_db_.db_acc_->get()->storage());
@ -1138,7 +1145,7 @@ TYPED_TEST(DumpTest, DumpDatabaseWithTriggers) {
memgraph::query::DbAccessor dba(acc.get());
const std::map<std::string, memgraph::storage::PropertyValue> props;
trigger_store->AddTrigger(trigger_name, trigger_statement, props, trigger_event_type, trigger_phase, &ast_cache,
&dba, query_config, std::nullopt, &auth_checker);
&dba, query_config, auth_checker.GenQueryUser(std::nullopt, std::nullopt));
}
{
ResultStreamFaker stream(this->db->storage());

View File

@ -22,6 +22,7 @@
#include "gtest/gtest.h"
#include "communication/result_stream_faker.hpp"
#include "query/auth_checker.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/stream/streams.hpp"
@ -36,6 +37,7 @@ class QueryExecution : public testing::Test {
const std::string testSuite = "query_plan_edge_cases";
std::optional<memgraph::dbms::DatabaseAccess> db_acc_;
std::optional<memgraph::query::InterpreterContext> interpreter_context_;
std::optional<memgraph::query::AllowEverythingAuthChecker> auth_checker_;
std::optional<memgraph::query::Interpreter> interpreter_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"};
@ -73,11 +75,14 @@ class QueryExecution : public testing::Test {
nullptr
#endif
);
auth_checker_.emplace();
interpreter_.emplace(&*interpreter_context_, *db_acc_);
interpreter_->SetUser(auth_checker_->GenQueryUser(std::nullopt, std::nullopt));
}
void TearDown() override {
interpreter_ = std::nullopt;
auth_checker_.reset();
interpreter_context_ = std::nullopt;
system_state.reset();
db_acc_.reset();

View File

@ -20,9 +20,11 @@
#include "integrations/constants.hpp"
#include "integrations/kafka/exceptions.hpp"
#include "kafka_mock.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/interpreter.hpp"
#include "query/interpreter_context.hpp"
#include "query/query_user.hpp"
#include "query/stream/streams.hpp"
#include "storage/v2/config.hpp"
#include "storage/v2/disk/storage.hpp"
@ -35,11 +37,23 @@ using StreamStatus = memgraph::query::stream::StreamStatus<memgraph::query::stre
namespace {
const static std::string kTopicName{"TrialTopic"};
struct FakeUser : memgraph::query::QueryUserOrRole {
FakeUser() : memgraph::query::QueryUserOrRole{std::nullopt, std::nullopt} {}
bool IsAuthorized(const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, const std::string &db_name,
memgraph::query::UserPolicy *policy) const {
return true;
}
#ifdef MG_ENTERPRISE
std::string GetDefaultDB() const { return "memgraph"; }
#endif
};
struct StreamCheckData {
std::string name;
StreamInfo info;
bool is_running;
std::optional<std::string> owner;
std::shared_ptr<memgraph::query::QueryUserOrRole> owner;
};
std::string GetDefaultStreamName() {
@ -105,13 +119,16 @@ class StreamsTestFixture : public ::testing::Test {
}() // iile
};
memgraph::system::System system_state;
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state,
system_state
memgraph::query::AllowEverythingAuthChecker auth_checker;
memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{},
nullptr,
&repl_state,
system_state,
#ifdef MG_ENTERPRISE
,
nullptr
nullptr,
#endif
};
nullptr,
&auth_checker};
std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
std::optional<StreamsTest> proxyStreams_;
@ -173,7 +190,7 @@ class StreamsTestFixture : public ::testing::Test {
}
StreamCheckData CreateDefaultStreamCheckData() {
return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::nullopt};
return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::make_unique<FakeUser>()};
}
void Clear() {
@ -215,11 +232,11 @@ TYPED_TEST(StreamsTestFixture, CreateAlreadyExisting) {
auto stream_info = this->CreateDefaultStreamInfo();
auto stream_name = GetDefaultStreamName();
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_);
try {
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_);
FAIL() << "Creating already existing stream should throw\n";
} catch (memgraph::query::stream::StreamsException &exception) {
EXPECT_EQ(exception.what(), fmt::format("Stream already exists with name '{}'", stream_name));
@ -231,7 +248,7 @@ TYPED_TEST(StreamsTestFixture, DropNotExistingStream) {
const auto stream_name = GetDefaultStreamName();
const std::string not_existing_stream_name{"ThisDoesn'tExists"};
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_);
try {
this->proxyStreams_->streams_->Drop(not_existing_stream_name);
@ -262,7 +279,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) {
if (i > 0) {
stream_info.common_info.batch_interval = std::chrono::milliseconds((i + 1) * 10);
stream_info.common_info.batch_size = 1000 + i;
stream_check_data.owner = std::string{"owner"} + iteration_postfix;
stream_check_data.owner = std::make_unique<FakeUser>();
// These are just random numbers to make the CONFIGS and CREDENTIALS map vary between consumers:
// - 0 means no config, no credential
@ -280,7 +297,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) {
this->mock_cluster_.CreateTopic(stream_info.topics[0]);
}
stream_check_datas[3].owner = {};
stream_check_datas[3].owner = std::make_unique<FakeUser>();
const auto check_restore_logic = [&stream_check_datas, this]() {
// Reset the Streams object to trigger reloading
@ -336,7 +353,7 @@ TYPED_TEST(StreamsTestFixture, CheckWithTimeout) {
const auto stream_info = this->CreateDefaultStreamInfo();
const auto stream_name = GetDefaultStreamName();
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_);
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_);
std::chrono::milliseconds timeout{3000};
@ -360,9 +377,10 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidConfig) {
EXPECT_TRUE(message.find(kInvalidConfigName) != std::string::npos) << message;
EXPECT_TRUE(message.find(kConfigValue) != std::string::npos) << message;
};
EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
EXPECT_THROW_WITH_MSG(
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
}
TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) {
@ -376,7 +394,8 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) {
EXPECT_TRUE(message.find(memgraph::integrations::kReducted) != std::string::npos) << message;
EXPECT_TRUE(message.find(kCredentialValue) == std::string::npos) << message;
};
EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
EXPECT_THROW_WITH_MSG(
this->proxyStreams_->streams_->template Create<memgraph::query::stream::KafkaStream>(
stream_name, stream_info, std::make_unique<FakeUser>(), this->db_, &this->interpreter_context_),
memgraph::integrations::kafka::SettingCustomConfigFailed, checker);
}

View File

@ -21,6 +21,7 @@
#include "query/db_accessor.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpreter.hpp"
#include "query/query_user.hpp"
#include "query/trigger.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/config.hpp"
@ -42,16 +43,27 @@ const std::unordered_set<memgraph::query::TriggerEventType> kAllEventTypes{
class MockAuthChecker : public memgraph::query::AuthChecker {
public:
MOCK_CONST_METHOD3(IsUserAuthorized,
bool(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, const std::string &db));
MOCK_CONST_METHOD2(GenQueryUser,
std::shared_ptr<memgraph::query::QueryUserOrRole>(const std::optional<std::string> &username,
const std::optional<std::string> &rolename));
#ifdef MG_ENTERPRISE
MOCK_CONST_METHOD2(GetFineGrainedAuthChecker,
std::unique_ptr<memgraph::query::FineGrainedAuthChecker>(
const std::string &username, const memgraph::query::DbAccessor *db_accessor));
MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr<memgraph::query::FineGrainedAuthChecker>(
std::shared_ptr<memgraph::query::QueryUserOrRole> user,
const memgraph::query::DbAccessor *db_accessor));
MOCK_CONST_METHOD0(ClearCache, void());
#endif
};
class MockQueryUser : public memgraph::query::QueryUserOrRole {
public:
MockQueryUser(std::optional<std::string> name) : memgraph::query::QueryUserOrRole(std::move(name), std::nullopt) {}
MOCK_CONST_METHOD3(IsAuthorized, bool(const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::string &db_name, memgraph::query::UserPolicy *policy));
#ifdef MG_ENTERPRISE
MOCK_CONST_METHOD0(GetDefaultDB, std::string());
#endif
};
} // namespace
const std::string testSuite = "query_trigger";
@ -966,12 +978,12 @@ TYPED_TEST(TriggerStoreTest, Restore) {
trigger_name_before, trigger_statement,
std::map<std::string, memgraph::storage::PropertyValue>{{"parameter", memgraph::storage::PropertyValue{1}}},
event_type, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker);
memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(std::nullopt, std::nullopt));
store->AddTrigger(
trigger_name_after, trigger_statement,
std::map<std::string, memgraph::storage::PropertyValue>{{"parameter", memgraph::storage::PropertyValue{"value"}}},
event_type, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, {owner}, &this->auth_checker);
memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(owner, std::nullopt));
const auto check_triggers = [&] {
ASSERT_EQ(store->GetTriggerInfo().size(), 2);
@ -981,9 +993,9 @@ TYPED_TEST(TriggerStoreTest, Restore) {
ASSERT_EQ(trigger.OriginalStatement(), trigger_statement);
ASSERT_EQ(trigger.EventType(), event_type);
if (owner != nullptr) {
ASSERT_EQ(*trigger.Owner(), *owner);
ASSERT_EQ(trigger.Owner()->username(), *owner);
} else {
ASSERT_FALSE(trigger.Owner().has_value());
ASSERT_FALSE(trigger.Owner()->username());
}
};
@ -1022,32 +1034,38 @@ TYPED_TEST(TriggerStoreTest, AddTrigger) {
// Invalid query in statements
ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker),
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)),
memgraph::utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker),
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)),
memgraph::utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker),
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)),
memgraph::utils::BasicException);
ASSERT_NO_THROW(store.AddTrigger(
"trigger", "RETURN $parameter",
std::map<std::string, memgraph::storage::PropertyValue>{{"parameter", memgraph::storage::PropertyValue{1}}},
memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache,
&*this->dba, memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker));
&*this->dba, memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)));
// Inserting with the same name
ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker),
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)),
memgraph::utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker),
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)),
memgraph::utils::BasicException);
ASSERT_EQ(store.GetTriggerInfo().size(), 1);
@ -1063,7 +1081,8 @@ TYPED_TEST(TriggerStoreTest, DropTrigger) {
const auto *trigger_name = "trigger";
store.AddTrigger(trigger_name, "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker);
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt));
ASSERT_THROW(store.DropTrigger("Unknown"), memgraph::utils::BasicException);
ASSERT_NO_THROW(store.DropTrigger(trigger_name));
@ -1076,7 +1095,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) {
std::vector<memgraph::query::TriggerStore::TriggerInfo> expected_info;
store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker);
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt));
expected_info.push_back({"trigger",
"RETURN 1",
memgraph::query::TriggerEventType::VERTEX_CREATE,
@ -1099,7 +1119,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) {
store.AddTrigger("edge_update_trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::EDGE_UPDATE,
memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker);
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt));
expected_info.push_back({"edge_update_trigger",
"RETURN 1",
memgraph::query::TriggerEventType::EDGE_UPDATE,
@ -1216,7 +1237,8 @@ TYPED_TEST(TriggerStoreTest, AnyTriggerAllKeywords) {
SCOPED_TRACE(keyword);
EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type,
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker));
memgraph::query::InterpreterConfig::Query{},
this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)));
store.DropTrigger(trigger_name);
}
}
@ -1228,45 +1250,50 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) {
using ::testing::ElementsAre;
using ::testing::Return;
std::optional<memgraph::query::TriggerStore> store{this->testing_directory};
const std::optional<std::string> owner{"testing_owner"};
MockAuthChecker mock_checker;
const std::optional<std::string> owner{"mock_user"};
MockQueryUser mock_user(owner);
std::shared_ptr<memgraph::query::QueryUserOrRole> mock_user_ptr(
&mock_user, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ });
MockQueryUser mock_userless(std::nullopt);
std::shared_ptr<memgraph::query::QueryUserOrRole> mock_userless_ptr(
&mock_userless, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ });
::testing::InSequence s;
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE), ""))
.Times(1)
// TODO Userless
EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy))
.WillOnce(Return(true));
EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), ""))
.Times(1)
.WillOnce(Return(true));
ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {},
memgraph::query::TriggerEventType::EDGE_UPDATE,
memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker));
memgraph::query::InterpreterConfig::Query{}, mock_user_ptr));
EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy))
.WillOnce(Return(true));
ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {},
memgraph::query::TriggerEventType::EDGE_UPDATE,
memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, owner, &mock_checker));
memgraph::query::InterpreterConfig::Query{}, mock_userless_ptr));
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::MATCH), ""))
.Times(1)
EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::MATCH), "", &memgraph::query::up_to_date_policy))
.WillOnce(Return(false));
ASSERT_THROW(
store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {},
memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT,
&this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, mock_user_ptr);
, memgraph::utils::BasicException);
ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {},
memgraph::query::TriggerEventType::EDGE_UPDATE,
memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba,
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker);
, memgraph::utils::BasicException);
// Restore
store.emplace(this->testing_directory);
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE), ""))
.Times(1)
.WillOnce(Return(false));
EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), ""))
.Times(1)
std::optional<std::string> nopt{};
EXPECT_CALL(mock_checker, GenQueryUser(owner, nopt)).WillOnce(Return(mock_user_ptr));
EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy))
.WillOnce(Return(true));
EXPECT_CALL(mock_checker, GenQueryUser(nopt, nopt)).WillOnce(Return(mock_userless_ptr));
EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy))
.WillOnce(Return(false));
ASSERT_NO_THROW(store->RestoreTriggers(&this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{},
&mock_checker));