Add privilege check in triggers and streams (#200)

parent 09c58501f1
commit 09cfca35f8

Changed files:
  .clang-tidy
  src/
  tests/benchmark/
  tests/e2e/streams/ (CMakeLists.txt, common.py, conftest.py, streams_owner_tests.py, streams_test_runner.sh, streams_tests.py, transformations/, workloads.yaml)
  tests/e2e/triggers/
  tests/manual/
  tests/unit/
.clang-tidy:

@@ -53,6 +53,7 @@ Checks: '*,
   -performance-unnecessary-value-param,
   -readability-braces-around-statements,
   -readability-else-after-return,
   -readability-function-cognitive-complexity,
   -readability-implicit-bool-conversion,
   -readability-magic-numbers,
   -readability-named-parameter,
@@ -144,7 +144,7 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) {
   }
 }
 
-std::optional<User> Auth::GetUser(const std::string &username_orig) {
+std::optional<User> Auth::GetUser(const std::string &username_orig) const {
   auto username = utils::ToLowerCase(username_orig);
   auto existing_user = storage_.Get(kUserPrefix + username);
   if (!existing_user) return std::nullopt;
@@ -170,9 +170,9 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) {
 
 void Auth::SaveUser(const User &user) {
   bool success = false;
-  if (user.role()) {
-    success = storage_.PutMultiple({{kUserPrefix + user.username(), user.Serialize().dump()},
-                                    {kLinkPrefix + user.username(), user.role()->rolename()}});
+  if (const auto *role = user.role(); role != nullptr) {
+    success = storage_.PutMultiple(
+        {{kUserPrefix + user.username(), user.Serialize().dump()}, {kLinkPrefix + user.username(), role->rolename()}});
   } else {
     success = storage_.PutAndDeleteMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}},
                                             {kLinkPrefix + user.username()});
@@ -203,7 +203,7 @@ bool Auth::RemoveUser(const std::string &username_orig) {
   return true;
 }
 
-std::vector<auth::User> Auth::AllUsers() {
+std::vector<auth::User> Auth::AllUsers() const {
   std::vector<auth::User> ret;
   for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
     auto username = it->first.substr(kUserPrefix.size());
@@ -216,9 +216,9 @@ std::vector<auth::User> Auth::AllUsers() {
   return ret;
 }
 
-bool Auth::HasUsers() { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
+bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
 
-std::optional<Role> Auth::GetRole(const std::string &rolename_orig) {
+std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
   auto rolename = utils::ToLowerCase(rolename_orig);
   auto existing_role = storage_.Get(kRolePrefix + rolename);
   if (!existing_role) return std::nullopt;
@@ -265,7 +265,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
   return true;
 }
 
-std::vector<auth::Role> Auth::AllRoles() {
+std::vector<auth::Role> Auth::AllRoles() const {
   std::vector<auth::Role> ret;
   for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
     auto rolename = it->first.substr(kRolePrefix.size());
@@ -280,7 +280,7 @@ std::vector<auth::Role> Auth::AllRoles() {
   return ret;
 }
 
-std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) {
+std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const {
   auto rolename = utils::ToLowerCase(rolename_orig);
   std::vector<auth::User> ret;
   for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
@@ -299,6 +299,4 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) {
   return ret;
 }
 
-std::mutex &Auth::WithLock() { return lock_; }
-
 }  // namespace auth
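
With `Auth::WithLock()` gone, callers no longer share a bare `std::mutex`; the `Auth` instance is instead wrapped in `utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock>` (see the src/memgraph.cpp changes below), so the storage is only reachable through locked handles, and the const-qualified read path can take a shared lock. A minimal sketch of that access pattern, using a plain `std::shared_mutex` as a stand-in for Memgraph's actual `utils::Synchronized`/`WritePrioritizedRWLock`:

```cpp
#include <mutex>
#include <shared_mutex>
#include <utility>

// Simplified stand-in for utils::Synchronized<T, WritePrioritizedRWLock>:
// the guarded object is only reachable through RAII handles, so callers
// cannot forget to lock the way they could with Auth::WithLock().
template <typename T>
class Synchronized {
 public:
  template <typename... Args>
  explicit Synchronized(Args &&...args) : object_(std::forward<Args>(args)...) {}

  // Exclusive handle for the mutating calls (AddUser, SaveUser, ...).
  auto Lock() {
    struct Handle {
      std::unique_lock<std::shared_mutex> guard;
      T *object;
      T *operator->() { return object; }
    };
    return Handle{std::unique_lock<std::shared_mutex>{mutex_}, &object_};
  }

  // Shared handle for the now-const calls (GetUser, AllUsers, HasUsers, ...).
  auto ReadLock() {
    struct Handle {
      std::shared_lock<std::shared_mutex> guard;
      const T *object;
      const T *operator->() const { return object; }
    };
    return Handle{std::shared_lock<std::shared_mutex>{mutex_}, &object_};
  }

 private:
  T object_;
  std::shared_mutex mutex_;  // the real lock type prioritizes writers
};
```

A reader then writes `auto locked = synchronized_auth.ReadLock(); locked->GetUser(name);`, holding the shared lock exactly for the lifetime of the handle.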
@@ -14,8 +14,7 @@ namespace auth {
 /**
  * This class serves as the main Authentication/Authorization storage.
  * It provides functions for managing Users, Roles and Permissions.
- * NOTE: The functions in this class aren't thread safe. Use the `WithLock` lock
- * if you want to have safe modifications of the storage.
+ * NOTE: The non-const functions in this class aren't thread safe.
  * TODO (mferencevic): Disable user/role modification functions when they are
  * being managed by the auth module.
  */
@@ -42,7 +41,7 @@ class Auth final {
   * @return a user when the user exists, nullopt otherwise
   * @throw AuthException if unable to load user data.
   */
-  std::optional<User> GetUser(const std::string &username);
+  std::optional<User> GetUser(const std::string &username) const;
 
   /**
    * Saves a user object to the storage.
@@ -81,14 +80,14 @@ class Auth final {
   * @return a list of users
   * @throw AuthException if unable to load user data.
   */
-  std::vector<User> AllUsers();
+  std::vector<User> AllUsers() const;
 
   /**
    * Returns whether there are users in the storage.
    *
    * @return `true` if the storage contains any users, `false` otherwise
    */
-  bool HasUsers();
+  bool HasUsers() const;
 
   /**
    * Gets a role from the storage.
@@ -98,7 +97,7 @@ class Auth final {
   * @return a role when the role exists, nullopt otherwise
   * @throw AuthException if unable to load role data.
   */
-  std::optional<Role> GetRole(const std::string &rolename);
+  std::optional<Role> GetRole(const std::string &rolename) const;
 
   /**
    * Saves a role object to the storage.
@@ -136,7 +135,7 @@ class Auth final {
   * @return a list of roles
   * @throw AuthException if unable to load role data.
   */
-  std::vector<Role> AllRoles();
+  std::vector<Role> AllRoles() const;
 
   /**
    * Gets all users for a role from the storage.
@@ -146,21 +145,13 @@ class Auth final {
   * @return a list of roles
   * @throw AuthException if unable to load user data.
   */
-  std::vector<User> AllUsersForRole(const std::string &rolename);
-
-  /**
-   * Returns a reference to the lock that should be used for all operations that
-   * require more than one interaction with this class.
-   */
-  std::mutex &WithLock();
+  std::vector<User> AllUsersForRole(const std::string &rolename) const;
 
  private:
+  // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe,
+  // Auth is not thread-safe because modifying users and roles might require
+  // more than one operation on the storage.
   kvstore::KVStore storage_;
   auth::Module module_;
-  // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe we
-  // use a mutex to lock all operations on the `User` and `Role` storage because
-  // some operations on the users and/or roles may require more than one
-  // operation on the storage.
-  std::mutex lock_;
 };
 }  // namespace auth
@@ -216,7 +216,7 @@ void User::SetRole(const Role &role) { role_.emplace(role); }
 
 void User::ClearRole() { role_ = std::nullopt; }
 
-const Permissions User::GetPermissions() const {
+Permissions User::GetPermissions() const {
   if (role_) {
     return Permissions(permissions_.grants() | role_->permissions().grants(),
                        permissions_.denies() | role_->permissions().denies());
@@ -229,7 +229,12 @@ const std::string &User::username() const { return username_; }
 const Permissions &User::permissions() const { return permissions_; }
 Permissions &User::permissions() { return permissions_; }
 
-std::optional<Role> User::role() const { return role_; }
+const Role *User::role() const {
+  if (role_.has_value()) {
+    return &role_.value();
+  }
+  return nullptr;
+}
 
 nlohmann::json User::Serialize() const {
   nlohmann::json data = nlohmann::json::object();
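
Switching `User::role()` from returning `std::optional<Role>` by value to `const Role *` removes a full `Role` copy (rolename plus permissions) on every authorization check, and enables the `if (const auto *role = user.role(); role != nullptr)` pattern used throughout this commit. A self-contained sketch of the accessor shape, with trimmed stand-ins for `auth::Role`/`auth::User`:

```cpp
#include <optional>
#include <string>
#include <utility>

// Trimmed stand-ins for auth::Role and auth::User, just to show the accessor
// change: returning a pointer into the stored std::optional avoids the copy.
struct Role {
  std::string rolename;
};

class User {
 public:
  // Old shape: std::optional<Role> role() const;   // copied the Role
  // New shape: expose the stored value (or nullptr) without a copy.
  const Role *role() const { return role_.has_value() ? &*role_ : nullptr; }

  void SetRole(Role role) { role_ = std::move(role); }

 private:
  std::optional<Role> role_;
};

// Callers pick the pointer up with an if-with-initializer, as the diff does:
//   if (const auto *role = user.role(); role != nullptr) { /* role->rolename */ }
```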
@@ -127,14 +127,14 @@ class User final {
 
   void ClearRole();
 
-  const Permissions GetPermissions() const;
+  Permissions GetPermissions() const;
 
   const std::string &username() const;
 
   const Permissions &permissions() const;
   Permissions &permissions();
 
-  std::optional<Role> role() const;
+  const Role *role() const;
 
   nlohmann::json Serialize() const;
 
@@ -144,7 +144,7 @@ bool KVStore::iterator::IsValid() { return pimpl_->it != nullptr; }
 
 // TODO(ipaljak) The complexity of the size function should be at most
 // logarithmic.
-size_t KVStore::Size(const std::string &prefix) {
+size_t KVStore::Size(const std::string &prefix) const {
   size_t size = 0;
   for (auto it = this->begin(prefix); it != this->end(prefix); ++it) ++size;
   return size;
@@ -126,7 +126,7 @@ class KVStore final {
    *
    * @return - number of stored pairs.
    */
-  size_t Size(const std::string &prefix = "");
+  size_t Size(const std::string &prefix = "") const;
 
   /**
    * Compact the underlying storage for the key range [begin_prefix,
@@ -186,9 +186,9 @@ class KVStore final {
     std::unique_ptr<impl> pimpl_;
   };
 
-  iterator begin(const std::string &prefix = "") { return iterator(this, prefix); }
+  iterator begin(const std::string &prefix = "") const { return iterator(this, prefix); }
 
-  iterator end(const std::string &prefix = "") { return iterator(this, prefix, true); }
+  iterator end(const std::string &prefix = "") const { return iterator(this, prefix, true); }
 
  private:
  struct impl;
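
Const-qualifying `begin`/`end`/`Size` is what lets the newly const `Auth` methods (`HasUsers() const`, `AllUsers() const`, ...) iterate the store through a read-only path. A minimal sketch of a prefix-scoped, const-qualified `Size()` in the same spirit, using `std::map` instead of Memgraph's RocksDB-backed KVStore:

```cpp
#include <cstddef>
#include <map>
#include <string>

// Minimal sketch (not Memgraph's KVStore): a prefix-scoped, const-qualified
// Size() counting keys that share a prefix, usable through a const object.
class PrefixStore {
 public:
  void Put(const std::string &key, const std::string &value) { data_[key] = value; }

  // Const like KVStore::Size after this commit, so read-only callers
  // (e.g. Auth::HasUsers() const) can use it via a shared/read lock.
  size_t Size(const std::string &prefix = "") const {
    size_t size = 0;
    for (auto it = data_.lower_bound(prefix);
         it != data_.end() && it->first.compare(0, prefix.size(), prefix) == 0; ++it) {
      ++size;
    }
    return size;
  }

 private:
  std::map<std::string, std::string> data_;
};
```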
src/memgraph.cpp:

@@ -23,6 +23,7 @@
 #include "communication/bolt/v1/constants.hpp"
 #include "helpers.hpp"
 #include "py/py.hpp"
+#include "query/auth_checker.hpp"
 #include "query/discard_value_stream.hpp"
 #include "query/exceptions.hpp"
 #include "query/interpreter.hpp"
@@ -40,8 +41,10 @@
 #include "utils/logging.hpp"
 #include "utils/memory_tracker.hpp"
 #include "utils/readable_size.hpp"
+#include "utils/rw_lock.hpp"
 #include "utils/signals.hpp"
 #include "utils/string.hpp"
+#include "utils/synchronized.hpp"
 #include "utils/sysinfo/memory.hpp"
 #include "utils/terminate_handler.hpp"
 #include "version.hpp"
@@ -354,15 +357,373 @@ void ConfigureLogging() {
 struct SessionData {
   // Explicit constructor here to ensure that pointers to all objects are
   // supplied.
-  SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, auth::Auth *auth,
-              audit::Log *audit_log)
+  SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context,
+              utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, audit::Log *audit_log)
       : db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
   storage::Storage *db;
   query::InterpreterContext *interpreter_context;
-  auth::Auth *auth;
+  utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth;
   audit::Log *audit_log;
 };
 
+DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
+              "Set to the regular expression that each user or role name must fulfill.");
+
+class AuthQueryHandler final : public query::AuthQueryHandler {
+  utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
+  std::regex name_regex_;
+
+ public:
+  AuthQueryHandler(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, const std::regex &name_regex)
+      : auth_(auth), name_regex_(name_regex) {}
+
+  bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      return locked_auth->AddUser(username, password).has_value();
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  bool DropUser(const std::string &username) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      auto user = locked_auth->GetUser(username);
+      if (!user) return false;
+      return locked_auth->RemoveUser(username);
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      auto user = locked_auth->GetUser(username);
+      if (!user) {
+        throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
+      }
+      user->UpdatePassword(password);
+      locked_auth->SaveUser(*user);
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  bool CreateRole(const std::string &rolename) override {
+    if (!std::regex_match(rolename, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid role name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      return locked_auth->AddRole(rolename).has_value();
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  bool DropRole(const std::string &rolename) override {
+    if (!std::regex_match(rolename, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid role name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      auto role = locked_auth->GetRole(rolename);
+      if (!role) return false;
+      return locked_auth->RemoveRole(rolename);
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  std::vector<query::TypedValue> GetUsernames() override {
+    try {
+      auto locked_auth = auth_->ReadLock();
+      std::vector<query::TypedValue> usernames;
+      const auto &users = locked_auth->AllUsers();
+      usernames.reserve(users.size());
+      for (const auto &user : users) {
+        usernames.emplace_back(user.username());
+      }
+      return usernames;
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  std::vector<query::TypedValue> GetRolenames() override {
+    try {
+      auto locked_auth = auth_->ReadLock();
+      std::vector<query::TypedValue> rolenames;
+      const auto &roles = locked_auth->AllRoles();
+      rolenames.reserve(roles.size());
+      for (const auto &role : roles) {
+        rolenames.emplace_back(role.rolename());
+      }
+      return rolenames;
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  std::optional<std::string> GetRolenameForUser(const std::string &username) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    try {
+      auto locked_auth = auth_->ReadLock();
+      auto user = locked_auth->GetUser(username);
+      if (!user) {
+        throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
+      }
+
+      if (const auto *role = user->role(); role != nullptr) {
+        return role->rolename();
+      }
+      return std::nullopt;
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
+    if (!std::regex_match(rolename, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid role name.");
+    }
+    try {
+      auto locked_auth = auth_->ReadLock();
+      auto role = locked_auth->GetRole(rolename);
+      if (!role) {
+        throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
+      }
+      std::vector<query::TypedValue> usernames;
+      const auto &users = locked_auth->AllUsersForRole(rolename);
+      usernames.reserve(users.size());
+      for (const auto &user : users) {
+        usernames.emplace_back(user.username());
+      }
+      return usernames;
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  void SetRole(const std::string &username, const std::string &rolename) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    if (!std::regex_match(rolename, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid role name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      auto user = locked_auth->GetUser(username);
+      if (!user) {
+        throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
+      }
+      auto role = locked_auth->GetRole(rolename);
+      if (!role) {
+        throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
+      }
+      if (const auto *current_role = user->role(); current_role != nullptr) {
+        throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
+                                           current_role->rolename());
+      }
+      user->SetRole(*role);
+      locked_auth->SaveUser(*user);
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  void ClearRole(const std::string &username) override {
+    if (!std::regex_match(username, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      auto user = locked_auth->GetUser(username);
+      if (!user) {
+        throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
+      }
+      user->ClearRole();
+      locked_auth->SaveUser(*user);
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
+    if (!std::regex_match(user_or_role, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user or role name.");
+    }
+    try {
+      auto locked_auth = auth_->ReadLock();
+      std::vector<std::vector<query::TypedValue>> grants;
+      auto user = locked_auth->GetUser(user_or_role);
+      auto role = locked_auth->GetRole(user_or_role);
+      if (!user && !role) {
+        throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
+      }
+      if (user) {
+        const auto &permissions = user->GetPermissions();
+        for (const auto &privilege : query::kPrivilegesAll) {
+          auto permission = glue::PrivilegeToPermission(privilege);
+          auto effective = permissions.Has(permission);
+          if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) {
+            std::vector<std::string> description;
+            auto user_level = user->permissions().Has(permission);
+            if (user_level == auth::PermissionLevel::GRANT) {
+              description.emplace_back("GRANTED TO USER");
+            } else if (user_level == auth::PermissionLevel::DENY) {
+              description.emplace_back("DENIED TO USER");
+            }
+            if (const auto *role = user->role(); role != nullptr) {
+              auto role_level = role->permissions().Has(permission);
+              if (role_level == auth::PermissionLevel::GRANT) {
+                description.emplace_back("GRANTED TO ROLE");
+              } else if (role_level == auth::PermissionLevel::DENY) {
+                description.emplace_back("DENIED TO ROLE");
+              }
+            }
+            grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
+                              query::TypedValue(auth::PermissionLevelToString(effective)),
+                              query::TypedValue(utils::Join(description, ", "))});
+          }
+        }
+      } else {
+        const auto &permissions = role->permissions();
+        for (const auto &privilege : query::kPrivilegesAll) {
+          auto permission = glue::PrivilegeToPermission(privilege);
+          auto effective = permissions.Has(permission);
+          if (effective != auth::PermissionLevel::NEUTRAL) {
+            std::string description;
+            if (effective == auth::PermissionLevel::GRANT) {
+              description = "GRANTED TO ROLE";
+            } else if (effective == auth::PermissionLevel::DENY) {
+              description = "DENIED TO ROLE";
+            }
+            grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
+                              query::TypedValue(auth::PermissionLevelToString(effective)),
+                              query::TypedValue(description)});
+          }
+        }
+      }
+      return grants;
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+
+  void GrantPrivilege(const std::string &user_or_role,
+                      const std::vector<query::AuthQuery::Privilege> &privileges) override {
+    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
+      // TODO (mferencevic): should we first check that the
+      // privilege is granted/denied/revoked before
+      // unconditionally granting/denying/revoking it?
+      permissions->Grant(permission);
+    });
+  }
+
+  void DenyPrivilege(const std::string &user_or_role,
+                     const std::vector<query::AuthQuery::Privilege> &privileges) override {
+    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
+      // TODO (mferencevic): should we first check that the
+      // privilege is granted/denied/revoked before
+      // unconditionally granting/denying/revoking it?
+      permissions->Deny(permission);
+    });
+  }
+
+  void RevokePrivilege(const std::string &user_or_role,
+                       const std::vector<query::AuthQuery::Privilege> &privileges) override {
+    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
+      // TODO (mferencevic): should we first check that the
+      // privilege is granted/denied/revoked before
+      // unconditionally granting/denying/revoking it?
+      permissions->Revoke(permission);
+    });
+  }
+
+ private:
+  template <class TEditFun>
+  void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
+                       const TEditFun &edit_fun) {
+    if (!std::regex_match(user_or_role, name_regex_)) {
+      throw query::QueryRuntimeException("Invalid user or role name.");
+    }
+    try {
+      auto locked_auth = auth_->Lock();
+      std::vector<auth::Permission> permissions;
+      permissions.reserve(privileges.size());
+      for (const auto &privilege : privileges) {
+        permissions.push_back(glue::PrivilegeToPermission(privilege));
+      }
+      auto user = locked_auth->GetUser(user_or_role);
+      auto role = locked_auth->GetRole(user_or_role);
+      if (!user && !role) {
+        throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
+      }
+      if (user) {
+        for (const auto &permission : permissions) {
+          edit_fun(&user->permissions(), permission);
+        }
+        locked_auth->SaveUser(*user);
+      } else {
+        for (const auto &permission : permissions) {
+          edit_fun(&role->permissions(), permission);
+        }
+        locked_auth->SaveRole(*role);
+      }
+    } catch (const auth::AuthException &e) {
+      throw query::QueryRuntimeException(e.what());
+    }
+  }
+};
+
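Every handler above follows the same shape: validate the name against `--auth-user-or-role-name-regex`, take the appropriate handle (exclusive `Lock()` for mutations, shared `ReadLock()` for read-only queries such as `GetUsernames` or `GetPrivileges`), and translate `auth::AuthException` into `query::QueryRuntimeException`. A condensed sketch of that shape; the helper name and the standard-library exception stand-in are illustrative, not part of the diff:

```cpp
#include <regex>
#include <stdexcept>
#include <string>

// Stand-in for query::QueryRuntimeException in this sketch.
class QueryRuntimeException : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;
};

// Validate the name, acquire a lock handle, run the body, and rewrap any
// auth-layer exception; `lock` is e.g. [&] { return auth->Lock(); } for
// writes or [&] { return auth->ReadLock(); } for reads.
template <typename LockFn, typename BodyFn>
auto GuardedAuthCall(const std::string &name, const std::regex &name_regex, LockFn lock, BodyFn body) {
  if (!std::regex_match(name, name_regex)) {
    throw QueryRuntimeException("Invalid user or role name.");
  }
  try {
    auto locked = lock();
    return body(locked);
  } catch (const std::exception &e) {  // auth::AuthException in the real code
    throw QueryRuntimeException(e.what());
  }
}
```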
+class AuthChecker final : public query::AuthChecker {
+ public:
+  explicit AuthChecker(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {}
+
+  static bool IsUserAuthorized(const auth::User &user, const std::vector<query::AuthQuery::Privilege> &privileges) {
+    const auto user_permissions = user.GetPermissions();
+    return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) {
+      return user_permissions.Has(glue::PrivilegeToPermission(privilege)) == auth::PermissionLevel::GRANT;
+    });
+  }
+
+  bool IsUserAuthorized(const std::optional<std::string> &username,
+                        const std::vector<query::AuthQuery::Privilege> &privileges) const final {
+    std::optional<auth::User> maybe_user;
+    {
+      auto locked_auth = auth_->ReadLock();
+      if (!locked_auth->HasUsers()) {
+        return true;
+      }
+      if (username.has_value()) {
+        maybe_user = locked_auth->GetUser(*username);
+      }
+    }
+
+    return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
+  }
+
+ private:
+  utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
+};
+
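`AuthChecker` encodes the authorization rule used for both ad-hoc queries and, via the interpreter context, restored triggers: if no users exist yet, everything is allowed; otherwise the named user must hold an explicit GRANT for every privilege the prepared query requires. Note the read lock is released before the permission set is evaluated. A self-contained sketch of just the privilege test, with a stand-in enum for `auth::PermissionLevel`:

```cpp
#include <algorithm>
#include <vector>

// Stand-in for auth::PermissionLevel; a single DENY or NEUTRAL answer
// rejects the whole query.
enum class PermissionLevel { GRANT, NEUTRAL, DENY };

bool IsAuthorized(const std::vector<PermissionLevel> &levels_for_requested_privileges) {
  return std::all_of(levels_for_requested_privileges.begin(), levels_for_requested_privileges.end(),
                     [](PermissionLevel level) { return level == PermissionLevel::GRANT; });
}
```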
 #else
 
 struct SessionData {
   // Explicit constructor here to ensure that pointers to all objects are
   // supplied.
@@ -371,6 +732,52 @@ struct SessionData {
   storage::Storage *db;
   query::InterpreterContext *interpreter_context;
 };
 
+class NoAuthInCommunity : public query::QueryRuntimeException {
+ public:
+  NoAuthInCommunity()
+      : query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
+};
+
+class AuthQueryHandler final : public query::AuthQueryHandler {
+ public:
+  bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
+
+  bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
+
+  void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
+
+  bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
+
+  bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
+
+  std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
+
+  std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
+
+  std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
+
+  std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
+
+  void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
+
+  void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
+
+  std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
+
+  void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
+    throw NoAuthInCommunity();
+  }
+
+  void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
+    throw NoAuthInCommunity();
+  }
+
+  void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
+    throw NoAuthInCommunity();
+  }
+};
+
 #endif
 
 class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
@@ -400,22 +807,21 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
       const std::string &query, const std::map<std::string, communication::bolt::Value> &params) override {
     std::map<std::string, storage::PropertyValue> params_pv;
     for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
+    const std::string *username{nullptr};
 #ifdef MG_ENTERPRISE
-    audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query, storage::PropertyValue(params_pv));
+    if (user_) {
+      username = &user_->username();
+    }
+    audit_log_->Record(endpoint_.address, user_ ? *username : "", query, storage::PropertyValue(params_pv));
 #endif
     try {
-      auto result = interpreter_.Prepare(query, params_pv);
+      auto result = interpreter_.Prepare(query, params_pv, username);
 #ifdef MG_ENTERPRISE
-      if (user_) {
-        const auto &permissions = user_->GetPermissions();
-        for (const auto &privilege : result.privileges) {
-          if (permissions.Has(glue::PrivilegeToPermission(privilege)) != auth::PermissionLevel::GRANT) {
-            interpreter_.Abort();
-            throw communication::bolt::ClientError(
-                "You are not authorized to execute this query! Please contact "
-                "your database administrator.");
-          }
-        }
+      if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges)) {
+        interpreter_.Abort();
+        throw communication::bolt::ClientError(
+            "You are not authorized to execute this query! Please contact "
+            "your database administrator.");
       }
 #endif
       return {result.headers, result.qid};
@@ -442,9 +848,12 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
 
   bool Authenticate(const std::string &username, const std::string &password) override {
 #ifdef MG_ENTERPRISE
-    if (!auth_->HasUsers()) return true;
-    user_ = auth_->Authenticate(username, password);
-    return !!user_;
+    auto locked_auth = auth_->Lock();
+    if (!locked_auth->HasUsers()) {
+      return true;
+    }
+    user_ = locked_auth->Authenticate(username, password);
+    return user_.has_value();
 #else
     return true;
 #endif
@@ -522,7 +931,7 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
   const storage::Storage *db_;
   query::Interpreter interpreter_;
 #ifdef MG_ENTERPRISE
-  auth::Auth *auth_;
+  utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
   std::optional<auth::User> user_;
   audit::Log *audit_log_;
 #endif
@@ -532,374 +941,6 @@ class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
 using ServerT = communication::Server<BoltSession, SessionData>;
 using communication::ServerContext;
 
-#ifdef MG_ENTERPRISE
-DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
-              "Set to the regular expression that each user or role name must fulfill.");
-
-class AuthQueryHandler final : public query::AuthQueryHandler {
-  auth::Auth *auth_;
-  std::regex name_regex_;
-
- public:
-  AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex) : auth_(auth), name_regex_(name_regex) {}
-
-  bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      return !!auth_->AddUser(username, password);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  bool DropUser(const std::string &username) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto user = auth_->GetUser(username);
-      if (!user) return false;
-      return auth_->RemoveUser(username);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto user = auth_->GetUser(username);
-      if (!user) {
-        throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
-      }
-      user->UpdatePassword(password);
-      auth_->SaveUser(*user);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  bool CreateRole(const std::string &rolename) override {
-    if (!std::regex_match(rolename, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      return !!auth_->AddRole(rolename);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  bool DropRole(const std::string &rolename) override {
-    if (!std::regex_match(rolename, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto role = auth_->GetRole(rolename);
-      if (!role) return false;
-      return auth_->RemoveRole(rolename);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  std::vector<query::TypedValue> GetUsernames() override {
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      std::vector<query::TypedValue> usernames;
-      const auto &users = auth_->AllUsers();
-      usernames.reserve(users.size());
-      for (const auto &user : users) {
-        usernames.emplace_back(user.username());
-      }
-      return usernames;
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  std::vector<query::TypedValue> GetRolenames() override {
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      std::vector<query::TypedValue> rolenames;
-      const auto &roles = auth_->AllRoles();
-      rolenames.reserve(roles.size());
-      for (const auto &role : roles) {
-        rolenames.emplace_back(role.rolename());
-      }
-      return rolenames;
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  std::optional<std::string> GetRolenameForUser(const std::string &username) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto user = auth_->GetUser(username);
-      if (!user) {
-        throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
-      }
-      if (user->role()) return user->role()->rolename();
-      return std::nullopt;
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
-    if (!std::regex_match(rolename, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto role = auth_->GetRole(rolename);
-      if (!role) {
-        throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
-      }
-      std::vector<query::TypedValue> usernames;
-      const auto &users = auth_->AllUsersForRole(rolename);
-      usernames.reserve(users.size());
-      for (const auto &user : users) {
-        usernames.emplace_back(user.username());
-      }
-      return usernames;
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  void SetRole(const std::string &username, const std::string &rolename) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    if (!std::regex_match(rolename, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto user = auth_->GetUser(username);
-      if (!user) {
-        throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
-      }
-      auto role = auth_->GetRole(rolename);
-      if (!role) {
-        throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
-      }
-      if (user->role()) {
-        throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
-                                           user->role()->rolename());
-      }
-      user->SetRole(*role);
-      auth_->SaveUser(*user);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  void ClearRole(const std::string &username) override {
-    if (!std::regex_match(username, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      auto user = auth_->GetUser(username);
-      if (!user) {
-        throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
-      }
-      user->ClearRole();
-      auth_->SaveUser(*user);
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
-    if (!std::regex_match(user_or_role, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user or role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      std::vector<std::vector<query::TypedValue>> grants;
-      auto user = auth_->GetUser(user_or_role);
-      auto role = auth_->GetRole(user_or_role);
-      if (!user && !role) {
-        throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
-      }
-      if (user) {
-        const auto &permissions = user->GetPermissions();
-        for (const auto &privilege : query::kPrivilegesAll) {
-          auto permission = glue::PrivilegeToPermission(privilege);
-          auto effective = permissions.Has(permission);
-          if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) {
-            std::vector<std::string> description;
-            auto user_level = user->permissions().Has(permission);
-            if (user_level == auth::PermissionLevel::GRANT) {
-              description.emplace_back("GRANTED TO USER");
-            } else if (user_level == auth::PermissionLevel::DENY) {
-              description.emplace_back("DENIED TO USER");
-            }
-            if (user->role()) {
-              auto role_level = user->role()->permissions().Has(permission);
-              if (role_level == auth::PermissionLevel::GRANT) {
-                description.emplace_back("GRANTED TO ROLE");
-              } else if (role_level == auth::PermissionLevel::DENY) {
-                description.emplace_back("DENIED TO ROLE");
-              }
-            }
-            grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
-                              query::TypedValue(auth::PermissionLevelToString(effective)),
-                              query::TypedValue(utils::Join(description, ", "))});
-          }
-        }
-      } else {
-        const auto &permissions = role->permissions();
-        for (const auto &privilege : query::kPrivilegesAll) {
-          auto permission = glue::PrivilegeToPermission(privilege);
-          auto effective = permissions.Has(permission);
-          if (effective != auth::PermissionLevel::NEUTRAL) {
-            std::string description;
-            if (effective == auth::PermissionLevel::GRANT) {
-              description = "GRANTED TO ROLE";
-            } else if (effective == auth::PermissionLevel::DENY) {
-              description = "DENIED TO ROLE";
-            }
-            grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
-                              query::TypedValue(auth::PermissionLevelToString(effective)),
-                              query::TypedValue(description)});
-          }
-        }
-      }
-      return grants;
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-
-  void GrantPrivilege(const std::string &user_or_role,
-                      const std::vector<query::AuthQuery::Privilege> &privileges) override {
-    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
-      // TODO (mferencevic): should we first check that the
-      // privilege is granted/denied/revoked before
-      // unconditionally granting/denying/revoking it?
-      permissions->Grant(permission);
-    });
-  }
-
-  void DenyPrivilege(const std::string &user_or_role,
-                     const std::vector<query::AuthQuery::Privilege> &privileges) override {
-    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
-      // TODO (mferencevic): should we first check that the
-      // privilege is granted/denied/revoked before
-      // unconditionally granting/denying/revoking it?
-      permissions->Deny(permission);
-    });
-  }
-
-  void RevokePrivilege(const std::string &user_or_role,
-                       const std::vector<query::AuthQuery::Privilege> &privileges) override {
-    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
-      // TODO (mferencevic): should we first check that the
-      // privilege is granted/denied/revoked before
-      // unconditionally granting/denying/revoking it?
-      permissions->Revoke(permission);
-    });
-  }
-
- private:
-  template <class TEditFun>
-  void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
-                       const TEditFun &edit_fun) {
-    if (!std::regex_match(user_or_role, name_regex_)) {
-      throw query::QueryRuntimeException("Invalid user or role name.");
-    }
-    try {
-      std::lock_guard<std::mutex> lock(auth_->WithLock());
-      std::vector<auth::Permission> permissions;
-      permissions.reserve(privileges.size());
-      for (const auto &privilege : privileges) {
-        permissions.push_back(glue::PrivilegeToPermission(privilege));
-      }
-      auto user = auth_->GetUser(user_or_role);
-      auto role = auth_->GetRole(user_or_role);
-      if (!user && !role) {
-        throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
-      }
-      if (user) {
-        for (const auto &permission : permissions) {
-          edit_fun(&user->permissions(), permission);
-        }
-        auth_->SaveUser(*user);
-      } else {
-        for (const auto &permission : permissions) {
-          edit_fun(&role->permissions(), permission);
-        }
-        auth_->SaveRole(*role);
-      }
-    } catch (const auth::AuthException &e) {
-      throw query::QueryRuntimeException(e.what());
-    }
-  }
-};
-#else
-class NoAuthInCommunity : public query::QueryRuntimeException {
- public:
-  NoAuthInCommunity()
-      : query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
-};
-
-class AuthQueryHandler final : public query::AuthQueryHandler {
- public:
-  bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
-
-  bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
-
-  void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
-
-  bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
-
-  bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
-
-  std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
-
-  std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
-
-  std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
-
-  std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
-
-  void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
-
-  void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
-
-  std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
-
-  void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
-    throw NoAuthInCommunity();
-  }
-
-  void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
-    throw NoAuthInCommunity();
-  }
-
-  void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
-    throw NoAuthInCommunity();
-  }
-};
-#endif
 
 // Needed to correctly handle memgraph destruction from a signal handler.
 // Without having some sort of a flag, it is possible that a signal is handled
 // when we are exiting main, inside destructors of database::GraphDb and
@@ -1013,7 +1054,7 @@ int main(int argc, char **argv) {
   // Begin enterprise features initialization
 
   // Auth
-  auth::Auth auth{data_directory / "auth"};
+  utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> auth{data_directory / "auth"};
 
   // Audit log
   audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms};
@@ -1075,24 +1116,28 @@ int main(int argc, char **argv) {
   query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories);
   query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
 
-  {
-    // Triggers can execute query procedures, so we need to reload the modules first and then
-    // the triggers
-    auto storage_accessor = interpreter_context.db->Access();
-    auto dba = query::DbAccessor{&storage_accessor};
-    interpreter_context.trigger_store.RestoreTriggers(
-        &interpreter_context.ast_cache, &dba, &interpreter_context.antlr_lock, interpreter_context.config.query);
-  }
-
   // As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
   interpreter_context.streams.RestoreStreams();
 
 #ifdef MG_ENTERPRISE
   AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex));
+  AuthChecker auth_checker{&auth};
 #else
   AuthQueryHandler auth_handler;
+  query::AllowEverythingAuthChecker auth_checker{};
 #endif
   interpreter_context.auth = &auth_handler;
+  interpreter_context.auth_checker = &auth_checker;
 
+  {
+    // Triggers can execute query procedures, so we need to reload the modules first and then
+    // the triggers
+    auto storage_accessor = interpreter_context.db->Access();
+    auto dba = query::DbAccessor{&storage_accessor};
+    interpreter_context.trigger_store.RestoreTriggers(&interpreter_context.ast_cache, &dba,
+                                                      &interpreter_context.antlr_lock,
+                                                      interpreter_context.config.query,
+                                                      interpreter_context.auth_checker);
+  }
+
   ServerContext context;
   std::string service_name = "Bolt";
src/query/auth_checker.hpp (new file):

@@ -0,0 +1,18 @@
+#pragma once
+
+#include "query/frontend/ast/ast.hpp"
+
+namespace query {
+class AuthChecker {
+ public:
+  virtual bool IsUserAuthorized(const std::optional<std::string> &username,
+                                const std::vector<query::AuthQuery::Privilege> &privileges) const = 0;
+};
+
+class AllowEverythingAuthChecker final : public query::AuthChecker {
+  bool IsUserAuthorized(const std::optional<std::string> &username,
+                        const std::vector<query::AuthQuery::Privilege> &privileges) const override {
+    return true;
+  }
+};
+}  // namespace query
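
The new `query::AuthChecker` interface keeps the query engine free of any compile-time dependency on the auth module: enterprise builds plug in the `AuthChecker` from src/memgraph.cpp, community builds plug in `AllowEverythingAuthChecker`. For illustration only, a hypothetical third implementation (nothing like it exists in the commit) that a test harness could inject; it assumes `query::AuthQuery::Privilege::AUTH` names the user-management privilege:

```cpp
#include <algorithm>
#include <optional>
#include <string>
#include <vector>

#include "query/auth_checker.hpp"

// Hypothetical checker, not part of the commit: forbid one privilege while
// allowing everything else, regardless of the username.
class DenyAuthQueriesChecker final : public query::AuthChecker {
 public:
  bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
                        const std::vector<query::AuthQuery::Privilege> &privileges) const override {
    return std::none_of(privileges.begin(), privileges.end(), [](const auto privilege) {
      return privilege == query::AuthQuery::Privilege::AUTH;
    });
  }
};
```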
src/query/interpreter.cpp:

@@ -3,6 +3,7 @@
 #include <atomic>
 #include <chrono>
 #include <limits>
+#include <optional>
 
 #include "glue/communication.hpp"
 #include "query/constants.hpp"
@@ -463,8 +464,13 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
   }
 }
 
+std::optional<std::string> StringPointerToOptional(const std::string *str) {
+  return str == nullptr ? std::nullopt : std::make_optional(*str);
+}
+
 Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
-                           InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
+                           InterpreterContext *interpreter_context, DbAccessor *db_accessor,
+                           const std::string *username) {
   Frame frame(0);
   SymbolTable symbol_table;
   EvaluationContext evaluation_context;
@@ -484,19 +490,21 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
       std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup
                                                                        : stream_query->consumer_group_};
 
-      callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
-                     topic_names = stream_query->topic_names_, consumer_group = std::move(consumer_group),
-                     batch_interval =
-                         GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator),
-                     batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator),
-                     transformation_name = stream_query->transform_name_]() mutable {
-        interpreter_context->streams.Create(stream_name, query::StreamInfo{.topics = std::move(topic_names),
-                                                                           .consumer_group = std::move(consumer_group),
-                                                                           .batch_interval = batch_interval,
-                                                                           .batch_size = batch_size,
-                                                                           .transformation_name = transformation_name});
-        return std::vector<std::vector<TypedValue>>{};
-      };
+      callback.fn =
+          [interpreter_context, stream_name = stream_query->stream_name_, topic_names = stream_query->topic_names_,
+           consumer_group = std::move(consumer_group),
+           batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator),
+           batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator),
+           transformation_name = stream_query->transform_name_, owner = StringPointerToOptional(username)]() mutable {
+            interpreter_context->streams.Create(stream_name,
+                                                query::StreamInfo{.topics = std::move(topic_names),
+                                                                  .consumer_group = std::move(consumer_group),
+                                                                  .batch_interval = batch_interval,
+                                                                  .batch_size = batch_size,
+                                                                  .transformation_name = std::move(transformation_name),
+                                                                  .owner = std::move(owner)});
+            return std::vector<std::vector<TypedValue>>{};
+          };
       return callback;
     }
     case StreamQuery::Action::START_STREAM: {
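The net effect of the CREATE STREAM branch above: the username travels from the Bolt session as a nullable pointer and is persisted in the stream's metadata as `std::optional<std::string>`. A minimal sketch with a stand-in for `query::StreamInfo` (field set trimmed to what the example needs; the transformation name is illustrative):

```cpp
#include <optional>
#include <string>

// Same helper as in the diff: nullptr means "no authenticated user".
std::optional<std::string> StringPointerToOptional(const std::string *str) {
  return str == nullptr ? std::nullopt : std::make_optional(*str);
}

struct StreamInfo {                  // trimmed stand-in for query::StreamInfo
  std::string transformation_name;
  std::optional<std::string> owner;  // empty for streams created before this change
};

StreamInfo MakeStreamInfo(const std::string *username) {
  return StreamInfo{.transformation_name = "transform.example",
                    .owner = StringPointerToOptional(username)};
}
```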
@ -535,8 +543,8 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete
|
||||
return callback;
|
||||
}
|
||||
case StreamQuery::Action::SHOW_STREAMS: {
|
||||
callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name",
|
||||
"is running"};
|
||||
callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name",
|
||||
"owner", "is running"};
|
||||
callback.fn = [interpreter_context]() {
|
||||
auto streams_status = interpreter_context->streams.GetStreamInfo();
|
||||
std::vector<std::vector<TypedValue>> results;
|
||||
@ -565,6 +573,11 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete
|
||||
typed_status.emplace_back();
|
||||
}
|
||||
typed_status.emplace_back(stream_info.transformation_name);
|
||||
if (stream_info.owner.has_value()) {
|
||||
typed_status.emplace_back(*stream_info.owner);
|
||||
} else {
|
||||
typed_status.emplace_back();
|
||||
}
|
||||
};
|
||||
|
||||
for (const auto &status : streams_status) {
|
||||
@ -1231,16 +1244,17 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) {
|
||||
|
||||
Callback CreateTrigger(TriggerQuery *trigger_query,
|
||||
const std::map<std::string, storage::PropertyValue> &user_parameters,
|
||||
InterpreterContext *interpreter_context, DbAccessor *dba) {
|
||||
InterpreterContext *interpreter_context, DbAccessor *dba, std::optional<std::string> owner) {
|
||||
return {
|
||||
{},
|
||||
[trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_),
|
||||
event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, interpreter_context, dba,
|
||||
user_parameters]() -> std::vector<std::vector<TypedValue>> {
|
||||
user_parameters, owner = std::move(owner)]() mutable -> std::vector<std::vector<TypedValue>> {
|
||||
interpreter_context->trigger_store.AddTrigger(
|
||||
trigger_name, trigger_statement, user_parameters, ToTriggerEventType(event_type),
|
||||
std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type),
|
||||
before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache,
|
||||
dba, &interpreter_context->antlr_lock, interpreter_context->config.query);
|
||||
dba, &interpreter_context->antlr_lock, interpreter_context->config.query, std::move(owner),
|
||||
interpreter_context->auth_checker);
|
||||
return {};
|
||||
}};
|
||||
}
|
||||
@ -1255,7 +1269,7 @@ Callback DropTrigger(TriggerQuery *trigger_query, InterpreterContext *interprete
|
||||
}
|
||||
|
||||
Callback ShowTriggers(InterpreterContext *interpreter_context) {
|
||||
return {{"trigger name", "statement", "event type", "phase"}, [interpreter_context] {
|
||||
return {{"trigger name", "statement", "event type", "phase", "owner"}, [interpreter_context] {
|
||||
std::vector<std::vector<TypedValue>> results;
|
||||
auto trigger_infos = interpreter_context->trigger_store.GetTriggerInfo();
|
||||
results.reserve(trigger_infos.size());
|
||||
@ -1267,6 +1281,9 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {
|
||||
typed_trigger_info.emplace_back(TriggerEventTypeToString(trigger_info.event_type));
|
||||
typed_trigger_info.emplace_back(trigger_info.phase == TriggerPhase::BEFORE_COMMIT ? "BEFORE COMMIT"
|
||||
: "AFTER COMMIT");
|
||||
typed_trigger_info.emplace_back(trigger_info.owner.has_value() ? TypedValue{*trigger_info.owner}
|
||||
: TypedValue{});
|
||||
|
||||
results.push_back(std::move(typed_trigger_info));
|
||||
}
|
||||
|
||||
@ -1276,7 +1293,8 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {

PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> &user_parameters) {
const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) {
if (in_explicit_transaction) {
throw TriggerModificationInMulticommandTxException();
}
@ -1284,11 +1302,12 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
auto *trigger_query = utils::Downcast<TriggerQuery>(parsed_query.query);
MG_ASSERT(trigger_query);

auto callback = [trigger_query, interpreter_context, dba, &user_parameters] {
auto callback = [trigger_query, interpreter_context, dba, &user_parameters,
owner = StringPointerToOptional(username)]() mutable {
switch (trigger_query->action_) {
case TriggerQuery::Action::CREATE_TRIGGER:
EventCounter::IncrementCounter(EventCounter::TriggersCreated);
return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba);
return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba, std::move(owner));
case TriggerQuery::Action::DROP_TRIGGER:
return DropTrigger(trigger_query, interpreter_context);
case TriggerQuery::Action::SHOW_TRIGGERS:
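The new capture calls a `StringPointerToOptional(username)` helper whose definition lies outside this hunk. A plausible one-line reconstruction, under the assumption that it simply maps a null pointer to `std::nullopt` and copies the string otherwise:

#include <iostream>
#include <optional>
#include <string>

// Hypothetical reconstruction of the helper referenced in the diff.
std::optional<std::string> StringPointerToOptional(const std::string *str) {
  return str == nullptr ? std::nullopt : std::make_optional(*str);
}

int main() {
  std::string user{"alice"};
  std::cout << StringPointerToOptional(&user).value_or("<none>") << '\n';   // alice
  std::cout << StringPointerToOptional(nullptr).value_or("<none>") << '\n'; // <none>
}
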
@ -1315,14 +1334,15 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic

PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> &user_parameters) {
const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) {
if (in_explicit_transaction) {
throw StreamQueryInMulticommandTxException();
}

auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query);
MG_ASSERT(stream_query);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username);

return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -1651,7 +1671,8 @@ void Interpreter::RollbackTransaction() {
}

Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::map<std::string, storage::PropertyValue> &params) {
const std::map<std::string, storage::PropertyValue> &params,
const std::string *username) {
if (!in_explicit_transaction_) {
query_executions_.clear();
}
@ -1748,10 +1769,10 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_);
} else if (utils::Downcast<TriggerQuery>(parsed_query.query)) {
prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
&*execution_db_accessor_, params, username);
} else if (utils::Downcast<StreamQuery>(parsed_query.query)) {
prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
&*execution_db_accessor_, params, username);
} else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) {
prepared_query =
PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this);
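With the extended `Prepare` signature, every call site has to state who is preparing the query; `nullptr` marks internal callers with no session user. A hedged sketch of the two call shapes — `FakeInterpreter` is a stand-in, not the real API:

#include <iostream>
#include <map>
#include <string>

struct FakeInterpreter {
  void Prepare(const std::string &query, const std::map<std::string, int> &params,
               const std::string *username) {
    std::cout << "preparing '" << query << "' for "
              << (username != nullptr ? *username : std::string{"<no user>"}) << '\n';
  }
};

int main() {
  FakeInterpreter interpreter;
  std::string session_user{"alice"};
  interpreter.Prepare("CREATE TRIGGER t ...", {}, &session_user);  // Bolt session with a user
  interpreter.Prepare("MATCH (n) RETURN n", {}, nullptr);          // internal caller
}
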
@ -1809,7 +1830,7 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret
trigger_context.AdaptForAccessor(&db_accessor);
try {
trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec,
&interpreter_context->is_shutting_down, trigger_context);
&interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker);
} catch (const utils::BasicException &exception) {
spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what());
db_accessor.Abort();
@ -1864,7 +1885,7 @@ void Interpreter::Commit() {
AdvanceCommand();
try {
trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec,
&interpreter_context_->is_shutting_down, *trigger_context);
&interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker);
} catch (const utils::BasicException &e) {
throw utils::BasicException(
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what()));

@ -2,6 +2,7 @@

#include <gflags/gflags.h>

#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/context.hpp"
#include "query/cypher_query_interpreter.hpp"
@ -166,6 +167,7 @@ struct InterpreterContext {
std::atomic<bool> is_shutting_down{false};

AuthQueryHandler *auth{nullptr};
query::AuthChecker *auth_checker{nullptr};

utils::SkipList<QueryCacheEntry> ast_cache;
utils::SkipList<PlanCacheEntry> plan_cache;
@ -205,7 +207,8 @@ class Interpreter final {
*
* @throw query::QueryException
*/
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params);
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params,
const std::string *username);

/**
* Execute the last prepared query and stream *all* of the results into the

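The diff adds the `auth_checker` pointer to `InterpreterContext`, but the startup wiring is not part of these hunks. Presumably it mirrors the existing non-owning `auth` pointer; a sketch under that assumption, with heavily reduced stand-in types:

#include <iostream>
#include <vector>

struct AuthChecker {  // stand-in for query::AuthChecker
  bool IsUserAuthorized(const std::vector<int> &privileges) const { return true; }
};

struct InterpreterContext {  // reduced stand-in
  AuthChecker *auth_checker{nullptr};
};

int main() {
  AuthChecker checker;              // lives for the whole process, like the auth handler
  InterpreterContext context;
  context.auth_checker = &checker;  // non-owning raw pointer, mirroring `auth`
  std::cout << std::boolalpha << context.auth_checker->IsUserAuthorized({}) << '\n';
}
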
@ -106,6 +106,7 @@ const std::string kBatchIntervalKey{"batch_interval"};
const std::string kBatchSizeKey{"batch_size"};
const std::string kIsRunningKey{"is_running"};
const std::string kTransformationName{"transformation_name"};
const std::string kOwner{"owner"};

void to_json(nlohmann::json &data, StreamStatus &&status) {
auto &info = status.info;
@ -127,6 +128,12 @@ void to_json(nlohmann::json &data, StreamStatus &&status) {

data[kIsRunningKey] = status.is_running;
data[kTransformationName] = status.info.transformation_name;

if (info.owner.has_value()) {
data[kOwner] = std::move(*info.owner);
} else {
data[kOwner] = nullptr;
}
}

void from_json(const nlohmann::json &data, StreamStatus &status) {
@ -135,16 +142,14 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
data.at(kTopicsKey).get_to(info.topics);
data.at(kConsumerGroupKey).get_to(info.consumer_group);

const auto batch_interval = data.at(kBatchIntervalKey);
if (!batch_interval.is_null()) {
if (const auto batch_interval = data.at(kBatchIntervalKey); !batch_interval.is_null()) {
using BatchInterval = decltype(info.batch_interval)::value_type;
info.batch_interval = BatchInterval{batch_interval.get<BatchInterval::rep>()};
} else {
info.batch_interval = {};
}

const auto batch_size = data.at(kBatchSizeKey);
if (!batch_size.is_null()) {
if (const auto batch_size = data.at(kBatchSizeKey); !batch_size.is_null()) {
info.batch_size = batch_size.get<decltype(info.batch_size)::value_type>();
} else {
info.batch_size = {};
@ -152,6 +157,12 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {

data.at(kIsRunningKey).get_to(status.is_running);
data.at(kTransformationName).get_to(status.info.transformation_name);

if (const auto &owner = data.at(kOwner); !owner.is_null()) {
info.owner = owner.get<decltype(info.owner)::value_type>();
} else {
info.owner = {};
}
}

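The to_json/from_json pair above encodes a missing owner as JSON null. A self-contained round-trip sketch of that convention using the same nlohmann::json library the diff relies on (the helper names are ours, not the source's):

#include <cassert>
#include <iostream>
#include <nlohmann/json.hpp>
#include <optional>
#include <string>

// A present owner becomes a JSON string, an absent one becomes JSON null.
nlohmann::json SerializeOwner(const std::optional<std::string> &owner) {
  return owner.has_value() ? nlohmann::json(*owner) : nlohmann::json(nullptr);
}

std::optional<std::string> DeserializeOwner(const nlohmann::json &data) {
  if (data.is_null()) return std::nullopt;
  return data.get<std::string>();
}

int main() {
  assert(DeserializeOwner(SerializeOwner("alice")) == std::optional<std::string>{"alice"});
  assert(!DeserializeOwner(SerializeOwner(std::nullopt)).has_value());
  std::cout << "owner round-trip ok\n";
}
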
Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
@ -200,7 +211,8 @@ void Streams::Create(const std::string &stream_name, StreamInfo info) {
auto it = CreateConsumer(*locked_streams, stream_name, std::move(info));

try {
Persist(CreateStatus(stream_name, it->second.transformation_name, *it->second.consumer->ReadLock()));
Persist(
CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *it->second.consumer->ReadLock()));
} catch (...) {
locked_streams->erase(it);
throw;
@ -233,7 +245,7 @@ void Streams::Start(const std::string &stream_name) {
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Start();

Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer));
}

void Streams::Stop(const std::string &stream_name) {
@ -243,7 +255,7 @@ void Streams::Stop(const std::string &stream_name) {
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Stop();

Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer));
}

void Streams::StartAll() {
@ -251,7 +263,7 @@ void Streams::StartAll() {
auto locked_consumer = stream_data.consumer->Lock();
if (!locked_consumer->IsRunning()) {
locked_consumer->Start();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer));
}
}
}
@ -261,7 +273,7 @@ void Streams::StopAll() {
auto locked_consumer = stream_data.consumer->Lock();
if (locked_consumer->IsRunning()) {
locked_consumer->Stop();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer));
}
}
}
@ -270,8 +282,8 @@ std::vector<StreamStatus> Streams::GetStreamInfo() const {
std::vector<StreamStatus> result;
{
for (auto locked_streams = streams_.ReadLock(); const auto &[stream_name, stream_data] : *locked_streams) {
result.emplace_back(
CreateStatus(stream_name, stream_data.transformation_name, *stream_data.consumer->ReadLock()));
result.emplace_back(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner,
*stream_data.consumer->ReadLock()));
}
}
return result;
@ -314,6 +326,7 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona
}

StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name,
const std::optional<std::string> &owner,
const integrations::kafka::Consumer &consumer) {
const auto &info = consumer.Info();
return StreamStatus{name,
@ -323,6 +336,7 @@ StreamStatus Streams::CreateStatus(const std::string &name, const std::string &t
info.batch_interval,
info.batch_size,
transformation_name,
owner,
},
consumer.IsRunning()};
}
@ -336,7 +350,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
auto *memory_resource = utils::NewDeleteResource();

auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
transformation_name = stream_info.transformation_name,
transformation_name = stream_info.transformation_name, owner = stream_info.owner,
interpreter = std::make_shared<Interpreter>(interpreter_context_),
result = mgp_result{nullptr, memory_resource}](
const std::vector<integrations::kafka::Message> &messages) mutable {
@ -347,6 +361,10 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
DiscardValueResultStream stream;

spdlog::trace("Start transaction in stream '{}'", stream_name);
utils::OnScopeExit cleanup{[&interpreter, &result]() {
result.rows.clear();
interpreter->Abort();
}};
interpreter->BeginTransaction();

for (auto &row : result.rows) {
@ -358,7 +376,13 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
std::string query{query_value.ValueString()};
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap());
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr);
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) {
throw StreamsException{
"Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the "
"query!",
query, stream_name};
}
interpreter->PullAll(&stream);
}

@ -376,7 +400,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
};

auto insert_result = map.insert_or_assign(
stream_name, StreamData{std::move(stream_info.transformation_name),
stream_name, StreamData{std::move(stream_info.transformation_name), std::move(stream_info.owner),
std::make_unique<SynchronizedConsumer>(bootstrap_servers_, std::move(consumer_info),
std::move(consumer_function))});
MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name);
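The key addition in the consumer function is the authorization gate between Prepare and PullAll: each query produced by the transformation is checked against the stream owner's privileges before any result is pulled. An end-to-end sketch of that gate with a toy policy — the real AuthChecker consults the auth store (and, per the e2e tests below, rejects ownerless streams once any user exists); the stand-in here only mimics the shape of the call:

#include <iostream>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>

enum class Privilege { MATCH, CREATE };

struct AuthChecker {
  // Toy policy for illustration only: ownerless is allowed, "admin" may do
  // anything, everyone else is denied CREATE. The real policy differs.
  bool IsUserAuthorized(const std::optional<std::string> &owner,
                        const std::set<Privilege> &required) const {
    if (!owner) return true;
    return *owner == "admin" || required.count(Privilege::CREATE) == 0;
  }
};

void RunStreamQuery(const AuthChecker &checker, const std::optional<std::string> &owner,
                    const std::set<Privilege> &required_privileges, const std::string &query) {
  if (!checker.IsUserAuthorized(owner, required_privileges)) {
    throw std::runtime_error("owner is not authorized to execute '" + query + "'");
  }
  std::cout << "executing: " << query << '\n';
}

int main() {
  AuthChecker checker;
  RunStreamQuery(checker, "admin", {Privilege::CREATE}, "CREATE (n)");
  try {
    RunStreamQuery(checker, "reader", {Privilege::CREATE}, "CREATE (n)");
  } catch (const std::runtime_error &e) {
    std::cout << "rejected: " << e.what() << '\n';
  }
}
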
@ -28,6 +28,7 @@ struct StreamInfo {
std::optional<std::chrono::milliseconds> batch_interval;
std::optional<int64_t> batch_size;
std::string transformation_name;
std::optional<std::string> owner;
};

struct StreamStatus {
@ -40,6 +41,7 @@ using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer,

struct StreamData {
std::string transformation_name;
std::optional<std::string> owner;
std::unique_ptr<SynchronizedConsumer> consumer;
};

@ -131,6 +133,7 @@ class Streams final {
using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;

static StreamStatus CreateStatus(const std::string &name, const std::string &transformation_name,
const std::optional<std::string> &owner,
const integrations::kafka::Consumer &consumer);

StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, StreamInfo stream_info);

@ -142,51 +142,55 @@ std::vector<std::pair<Identifier, TriggerIdentifierTag>> GetPredefinedIdentifier
Trigger::Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache,
DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config)
DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker)
: name_{std::move(name)},
parsed_statements_{ParseQuery(query, user_parameters, query_cache, antlr_lock, query_config)},
event_type_{event_type} {
event_type_{event_type},
owner_{std::move(owner)} {
// We check immediately if the query is valid by trying to create a plan.
GetPlan(db_accessor);
GetPlan(db_accessor, auth_checker);
}

Trigger::TriggerPlan::TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers)
: cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {}

std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor) const {
std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
const query::AuthChecker *auth_checker) const {
std::lock_guard plan_guard{plan_lock_};
if (parsed_statements_.is_cacheable && trigger_plan_ && !trigger_plan_->cached_plan.IsExpired()) {
return trigger_plan_;
if (!parsed_statements_.is_cacheable || !trigger_plan_ || trigger_plan_->cached_plan.IsExpired()) {
auto identifiers = GetPredefinedIdentifiers(event_type_);

AstStorage ast_storage;
ast_storage.properties_ = parsed_statements_.ast_storage.properties_;
ast_storage.labels_ = parsed_statements_.ast_storage.labels_;
ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_;

std::vector<Identifier *> predefined_identifiers;
predefined_identifiers.reserve(identifiers.size());
std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers),
[](auto &identifier) { return &identifier.first; });

auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query),
parsed_statements_.parameters, db_accessor, predefined_identifiers);

trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
}
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) {
throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_);
}

auto identifiers = GetPredefinedIdentifiers(event_type_);

AstStorage ast_storage;
ast_storage.properties_ = parsed_statements_.ast_storage.properties_;
ast_storage.labels_ = parsed_statements_.ast_storage.labels_;
ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_;

std::vector<Identifier *> predefined_identifiers;
predefined_identifiers.reserve(identifiers.size());
std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers),
[](auto &identifier) { return &identifier.first; });

auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query),
parsed_statements_.parameters, db_accessor, predefined_identifiers);

trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
return trigger_plan_;
}

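The restructured GetPlan is worth spelling out: the plan is rebuilt only when missing or expired, but the owner's privileges are re-checked on every call, so a privilege revoked after creation takes effect even while the cached plan is still valid. A simplified sketch of that cache-then-authorize ordering, with all types reduced to stand-ins:

#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>

struct Plan { std::string query; };

class Trigger {
 public:
  Trigger(std::string query, std::optional<std::string> owner)
      : query_(std::move(query)), owner_(std::move(owner)) {}

  std::shared_ptr<Plan> GetPlan(bool owner_still_authorized) const {
    if (!plan_) {  // build lazily; the real cache also checks expiry
      plan_ = std::make_shared<Plan>(Plan{query_});
    }
    // The check runs after (re)building, on every call, not only on rebuilds.
    if (owner_ && !owner_still_authorized) {
      throw std::runtime_error("the owner of this trigger is no longer authorized");
    }
    return plan_;
  }

 private:
  std::string query_;
  std::optional<std::string> owner_;
  mutable std::shared_ptr<Plan> plan_;  // mutable: caching inside a const method
};

int main() {
  Trigger trigger("CREATE (n)", "alice");
  std::cout << trigger.GetPlan(true)->query << '\n';  // authorized: plan returned
  try {
    trigger.GetPlan(false);                           // privileges revoked later
  } catch (const std::runtime_error &e) {
    std::cout << "blocked: " << e.what() << '\n';
  }
}
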
void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory,
const double max_execution_time_sec, std::atomic<bool> *is_shutting_down,
const TriggerContext &context) const {
const TriggerContext &context, const AuthChecker *auth_checker) const {
if (!context.ShouldEventTrigger(event_type_)) {
return;
}

spdlog::debug("Executing trigger '{}'", name_);
auto trigger_plan = GetPlan(dba);
auto trigger_plan = GetPlan(dba, auth_checker);
MG_ASSERT(trigger_plan, "Invalid trigger plan received");
auto &[plan, identifiers] = *trigger_plan;

@ -238,13 +242,14 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution

namespace {
// When the format of the persisted trigger is changed, increase this version
constexpr uint64_t kVersion{1};
constexpr uint64_t kVersion{2};
}  // namespace

TriggerStore::TriggerStore(std::filesystem::path directory) : storage_{std::move(directory)} {}

void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) {
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker) {
MG_ASSERT(before_commit_triggers_.size() == 0 && after_commit_triggers_.size() == 0,
"Cannot restore trigger when some triggers already exist!");
spdlog::info("Loading triggers...");
@ -292,10 +297,19 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
}
const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]);

const auto owner_json = json_trigger_data["owner"];
std::optional<std::string> owner{};
if (owner_json.is_string()) {
owner.emplace(owner_json.get<std::string>());
} else if (!owner_json.is_null()) {
spdlog::warn(invalid_state_message);
continue;
}

std::optional<Trigger> trigger;
try {
trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config);
query_config, std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what());
continue;
@ -309,11 +323,12 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
}
}

void TriggerStore::AddTrigger(const std::string &name, const std::string &query,
void TriggerStore::AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
TriggerEventType event_type, TriggerPhase phase,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) {
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker) {
std::unique_lock store_guard{store_lock_};
if (storage_.Get(name)) {
throw utils::BasicException("Trigger with the same name already exists.");
@ -321,7 +336,8 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query,

std::optional<Trigger> trigger;
try {
trigger.emplace(name, query, user_parameters, event_type, query_cache, db_accessor, antlr_lock, query_config);
trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config, std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
const auto identifiers = GetPredefinedIdentifiers(event_type);
std::stringstream identifier_names_stream;
@ -342,7 +358,13 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query,
data["event_type"] = event_type;
data["phase"] = phase;
data["version"] = kVersion;
storage_.Put(name, data.dump());

if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) {
data["owner"] = *owner_from_trigger;
} else {
data["owner"] = nullptr;
}
storage_.Put(trigger->Name(), data.dump());
store_guard.unlock();

auto triggers_acc =
@ -384,7 +406,7 @@ std::vector<TriggerStore::TriggerInfo> TriggerStore::GetTriggerInfo() const {

const auto add_info = [&](const utils::SkipList<Trigger> &trigger_list, const TriggerPhase phase) {
for (const auto &trigger : trigger_list.access()) {
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase});
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()});
}
};

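Bumping kVersion to 2 goes hand in hand with the new "owner" key: a present owner is persisted as a string, an absent one as null, and restore skips records with any other shape. A hedged sketch of that format using nlohmann::json — the key set is reduced to a minimum here; the real record also stores the statement, event type, phase, and user parameters:

#include <cstdint>
#include <iostream>
#include <nlohmann/json.hpp>
#include <optional>
#include <string>

constexpr uint64_t kVersion{2};

nlohmann::json PersistTrigger(const std::string &name, const std::optional<std::string> &owner) {
  nlohmann::json data;
  data["name"] = name;
  data["version"] = kVersion;
  if (owner.has_value()) {
    data["owner"] = *owner;
  } else {
    data["owner"] = nullptr;
  }
  return data;
}

// Outer optional: parse success; inner optional: owner present or not.
std::optional<std::optional<std::string>> RestoreOwner(const nlohmann::json &data) {
  const auto &owner_json = data.at("owner");
  if (owner_json.is_string()) return std::optional<std::string>{owner_json.get<std::string>()};
  if (owner_json.is_null()) return std::optional<std::string>{};
  return std::nullopt;  // corrupt record: the caller logs a warning and skips it
}

int main() {
  auto restored = RestoreOwner(PersistTrigger("t1", "alice"));
  std::cout << (restored && *restored ? **restored : std::string{"<none>"}) << '\n';
}
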
@ -9,6 +9,7 @@
#include <vector>

#include "kvstore/kvstore.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/cypher_query_interpreter.hpp"
#include "query/db_accessor.hpp"
@ -23,10 +24,12 @@ struct Trigger {
explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock,
const InterpreterConfig::Query &query_config);
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);

void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec,
std::atomic<bool> *is_shutting_down, const TriggerContext &context) const;
std::atomic<bool> *is_shutting_down, const TriggerContext &context,
const AuthChecker *auth_checker) const;

bool operator==(const Trigger &other) const { return name_ == other.name_; }
// NOLINTNEXTLINE (modernize-use-nullptr)
@ -37,6 +40,7 @@ struct Trigger {

const auto &Name() const noexcept { return name_; }
const auto &OriginalStatement() const noexcept { return parsed_statements_.query_string; }
const auto &Owner() const noexcept { return owner_; }
auto EventType() const noexcept { return event_type_; }

private:
@ -48,7 +52,7 @@ struct Trigger {
CachedPlan cached_plan;
std::vector<IdentifierInfo> identifiers;
};
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor) const;
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const;

std::string name_;
ParsedQuery parsed_statements_;
@ -57,6 +61,7 @@ struct Trigger {

mutable utils::SpinLock plan_lock_;
mutable std::shared_ptr<TriggerPlan> trigger_plan_;
std::optional<std::string> owner_;
};

enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT };
@ -65,12 +70,14 @@ struct TriggerStore {
explicit TriggerStore(std::filesystem::path directory);

void RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config);
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker);

void AddTrigger(const std::string &name, const std::string &query,
void AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config);
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker);

void DropTrigger(const std::string &name);

@ -79,6 +86,7 @@ struct TriggerStore {
std::string statement;
TriggerEventType event_type;
TriggerPhase phase;
std::optional<std::string> owner;
};

std::vector<TriggerInfo> GetTriggerInfo() const;

@ -54,7 +54,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Match)(benchmark::State &state) {

while (state.KeepRunning()) {
ResultStreamFaker results(&*db);
interpreter->Prepare(query, {});
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}
}
@ -69,7 +69,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Expand)(benchmark::State &state) {

while (state.KeepRunning()) {
ResultStreamFaker results(&*db);
interpreter->Prepare(query, {});
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}
}

@ -6,6 +6,9 @@ add_custom_target(memgraph__e2e__streams__${FILE_NAME} ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction()

copy_streams_e2e_python_files(common.py)
copy_streams_e2e_python_files(conftest.py)
copy_streams_e2e_python_files(streams_tests.py)
copy_streams_e2e_python_files(streams_owner_tests.py)
copy_streams_e2e_python_files(streams_test_runner.sh)
add_subdirectory(transformations)

110 tests/e2e/streams/common.py Normal file
@ -0,0 +1,110 @@
import mgclient
import time

# These are the indices of the different values in the result of SHOW STREAM
# query
NAME = 0
TOPICS = 1
CONSUMER_GROUP = 2
BATCH_INTERVAL = 3
BATCH_SIZE = 4
TRANSFORM = 5
OWNER = 6
IS_RUNNING = 7


def execute_and_fetch_all(cursor, query):
    cursor.execute(query)
    return cursor.fetchall()


def connect(**kwargs):
    connection = mgclient.connect(host="localhost", port=7687, **kwargs)
    connection.autocommit = True
    return connection


def timed_wait(fun):
    start_time = time.time()
    seconds = 10

    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if elapsed_time > seconds:
            return False

        if fun():
            return True

        time.sleep(0.1)


def check_one_result_row(cursor, query):
    start_time = time.time()
    seconds = 10

    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if elapsed_time > seconds:
            return False

        cursor.execute(query)
        results = cursor.fetchall()
        if len(results) < 1:
            time.sleep(0.1)
            continue

        return len(results) == 1


def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes):
    assert check_one_result_row(cursor,
                                "MATCH (n: MESSAGE {"
                                f"payload: '{payload_bytes.decode('utf-8')}',"
                                f"topic: '{topic}'"
                                "}) RETURN n")


def get_stream_info(cursor, stream_name):
    stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
    for stream_info in stream_infos:
        if (stream_info[NAME] == stream_name):
            return stream_info

    return None


def get_is_running(cursor, stream_name):
    stream_info = get_stream_info(cursor, stream_name)

    assert stream_info
    return stream_info[IS_RUNNING]


def start_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"START STREAM {stream_name}")

    assert get_is_running(cursor, stream_name)


def stop_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}")

    assert not get_is_running(cursor, stream_name)


def drop_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}")

    assert get_stream_info(cursor, stream_name) is None


def check_stream_info(cursor, stream_name, expected_stream_info):
    stream_info = get_stream_info(cursor, stream_name)
    assert len(stream_info) == len(expected_stream_info)
    for info, expected_info in zip(stream_info, expected_stream_info):
        assert info == expected_info

45 tests/e2e/streams/conftest.py Normal file
@ -0,0 +1,45 @@
import pytest
from kafka import KafkaProducer
from kafka.admin import KafkaAdminClient, NewTopic

from common import execute_and_fetch_all, connect, NAME

# To run these tests locally a running Kafka server is necessary. The tests
# try to connect on localhost:9092.


@pytest.fixture(autouse=True)
def connection():
    connection = connect()
    yield connection
    cursor = connection.cursor()
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
    for stream_info in stream_infos:
        execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}")
    users = execute_and_fetch_all(cursor, "SHOW USERS")
    for username, in users:
        execute_and_fetch_all(cursor, f"DROP USER {username}")


@pytest.fixture(scope="function")
def topics():
    admin_client = KafkaAdminClient(
        bootstrap_servers="localhost:9092", client_id='test')

    topics = []
    topics_to_create = []
    for index in range(3):
        topic = f"topic_{index}"
        topics.append(topic)
        topics_to_create.append(NewTopic(name=topic,
                                         num_partitions=1, replication_factor=1))

    admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000)
    yield topics
    admin_client.delete_topics(topics=topics, timeout_ms=5000)


@pytest.fixture(scope="function")
def producer():
    yield KafkaProducer(bootstrap_servers="localhost:9092")

151 tests/e2e/streams/streams_owner_tests.py Normal file
@ -0,0 +1,151 @@
import sys
import pytest
import time
import mgclient
import common


def get_cursor_with_user(username):
    connection = common.connect(username=username, password="")
    return connection.cursor()


def create_admin_user(cursor, admin_user):
    common.execute_and_fetch_all(cursor, f"CREATE USER {admin_user}")
    common.execute_and_fetch_all(
        cursor, f"GRANT ALL PRIVILEGES TO {admin_user}")


def create_stream_user(cursor, stream_user):
    common.execute_and_fetch_all(cursor, f"CREATE USER {stream_user}")
    common.execute_and_fetch_all(
        cursor, f"GRANT STREAM TO {stream_user}")


def test_ownerless_stream(producer, topics, connection):
    assert len(topics) > 0
    userless_cursor = connection.cursor()
    common.execute_and_fetch_all(userless_cursor,
                                 "CREATE STREAM ownerless "
                                 f"TOPICS {topics[0]} "
                                 f"TRANSFORM transform.simple")
    common.start_stream(userless_cursor, "ownerless")
    time.sleep(1)

    admin_user = "admin_user"
    create_admin_user(userless_cursor, admin_user)

    producer.send(topics[0], b"first message").get(timeout=60)
    assert common.timed_wait(
        lambda: not common.get_is_running(userless_cursor, "ownerless"))

    assert len(common.execute_and_fetch_all(
        userless_cursor, "MATCH (n) RETURN n")) == 0

    common.execute_and_fetch_all(userless_cursor, f"DROP USER {admin_user}")
    common.start_stream(userless_cursor, "ownerless")
    time.sleep(1)

    second_message = b"second message"
    producer.send(topics[0], second_message).get(timeout=60)
    common.check_vertex_exists_with_topic_and_payload(
        userless_cursor, topics[0], second_message)

    assert len(common.execute_and_fetch_all(
        userless_cursor, "MATCH (n) RETURN n")) == 1


def test_owner_is_shown(topics, connection):
    assert len(topics) > 0
    userless_cursor = connection.cursor()

    stream_user = "stream_user"
    create_stream_user(userless_cursor, stream_user)
    stream_cursor = get_cursor_with_user(stream_user)

    common.execute_and_fetch_all(stream_cursor, "CREATE STREAM test "
                                 f"TOPICS {topics[0]} "
                                 f"TRANSFORM transform.simple")

    common.check_stream_info(userless_cursor, "test", ("test", [
        topics[0]], "mg_consumer", None, None,
        "transform.simple", stream_user, False))


def test_insufficient_privileges(producer, topics, connection):
    assert len(topics) > 0
    userless_cursor = connection.cursor()

    admin_user = "admin_user"
    create_admin_user(userless_cursor, admin_user)
    admin_cursor = get_cursor_with_user(admin_user)

    stream_user = "stream_user"
    create_stream_user(userless_cursor, stream_user)
    stream_cursor = get_cursor_with_user(stream_user)

    common.execute_and_fetch_all(stream_cursor,
                                 "CREATE STREAM insufficient_test "
                                 f"TOPICS {topics[0]} "
                                 f"TRANSFORM transform.simple")

    # The stream is started by the admin, but the privileges must be checked
    # against the owner.
    common.start_stream(admin_cursor, "insufficient_test")
    time.sleep(1)

    producer.send(topics[0], b"first message").get(timeout=60)
    assert common.timed_wait(
        lambda: not common.get_is_running(userless_cursor, "insufficient_test"))

    assert len(common.execute_and_fetch_all(
        userless_cursor, "MATCH (n) RETURN n")) == 0

    common.execute_and_fetch_all(
        admin_cursor, f"GRANT CREATE TO {stream_user}")
    common.start_stream(userless_cursor, "insufficient_test")
    time.sleep(1)

    second_message = b"second message"
    producer.send(topics[0], second_message).get(timeout=60)
    common.check_vertex_exists_with_topic_and_payload(
        userless_cursor, topics[0], second_message)

    assert len(common.execute_and_fetch_all(
        userless_cursor, "MATCH (n) RETURN n")) == 1


def test_happy_case(producer, topics, connection):
    assert len(topics) > 0
    userless_cursor = connection.cursor()

    admin_user = "admin_user"
    create_admin_user(userless_cursor, admin_user)
    admin_cursor = get_cursor_with_user(admin_user)

    stream_user = "stream_user"
    create_stream_user(userless_cursor, stream_user)
    stream_cursor = get_cursor_with_user(stream_user)
    common.execute_and_fetch_all(
        admin_cursor, f"GRANT CREATE TO {stream_user}")

    common.execute_and_fetch_all(stream_cursor,
                                 "CREATE STREAM insufficient_test "
                                 f"TOPICS {topics[0]} "
                                 f"TRANSFORM transform.simple")

    common.start_stream(stream_cursor, "insufficient_test")
    time.sleep(1)

    first_message = b"first message"
    producer.send(topics[0], first_message).get(timeout=60)

    common.check_vertex_exists_with_topic_and_payload(
        userless_cursor, topics[0], first_message)

    assert len(common.execute_and_fetch_all(
        userless_cursor, "MATCH (n) RETURN n")) == 1


if __name__ == "__main__":
    sys.exit(pytest.main([__file__, "-rA"]))

@ -3,4 +3,4 @@
# This workaround is necessary to run in the same virtualenv as the e2e runner.py

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python3 "$DIR/streams_tests.py"
python3 "$DIR/$1"

@ -1,28 +1,11 @@
#!/usr/bin/python3

# To run these tests locally a running Kafka server is necessary. The tests
# try to connect on localhost:9092.

# All tests are implemented in this file, because using the same test fixtures
# in multiple files is not possible in a straightforward way

import sys
import pytest
import mgclient
import time
from multiprocessing import Process, Value
from kafka import KafkaProducer
from kafka.admin import KafkaAdminClient, NewTopic

# These are the indices of the different values in the result of SHOW STREAM
# query
NAME = 0
TOPICS = 1
CONSUMER_GROUP = 2
BATCH_INTERVAL = 3
BATCH_SIZE = 4
TRANSFORM = 5
IS_RUNNING = 6
import common

# These are the indices of the query and parameters in the result of CHECK
# STREAM query
@ -35,155 +18,22 @@ TRANSFORMATIONS_TO_CHECK = [
SIMPLE_MSG = b'message'


def execute_and_fetch_all(cursor, query):
    cursor.execute(query)
    return cursor.fetchall()


def connect():
    connection = mgclient.connect(host="localhost", port=7687)
    connection.autocommit = True
    return connection


@pytest.fixture(autouse=True)
def connection():
    connection = connect()
    yield connection
    cursor = connection.cursor()
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
    stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
    for stream_info in stream_infos:
        execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}")


@pytest.fixture(scope="function")
def topics():
    admin_client = KafkaAdminClient(
        bootstrap_servers="localhost:9092", client_id='test')

    topics = []
    topics_to_create = []
    for index in range(3):
        topic = f"topic_{index}"
        topics.append(topic)
        topics_to_create.append(NewTopic(name=topic,
                                         num_partitions=1, replication_factor=1))

    admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000)
    yield topics
    admin_client.delete_topics(topics=topics, timeout_ms=5000)


@pytest.fixture(scope="function")
def producer():
    yield KafkaProducer(bootstrap_servers="localhost:9092")


def timed_wait(fun):
    start_time = time.time()
    seconds = 10

    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if elapsed_time > seconds:
            return False

        if fun():
            return True


def check_one_result_row(cursor, query):
    start_time = time.time()
    seconds = 10

    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if elapsed_time > seconds:
            return False

        cursor.execute(query)
        results = cursor.fetchall()
        if len(results) < 1:
            time.sleep(0.1)
            continue

        return len(results) == 1


def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes):
    assert check_one_result_row(cursor,
                                "MATCH (n: MESSAGE {"
                                f"payload: '{payload_bytes.decode('utf-8')}',"
                                f"topic: '{topic}'"
                                "}) RETURN n")


def get_stream_info(cursor, stream_name):
    stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
    for stream_info in stream_infos:
        if (stream_info[NAME] == stream_name):
            return stream_info

    return None


def get_is_running(cursor, stream_name):
    stream_info = get_stream_info(cursor, stream_name)

    assert stream_info
    return stream_info[IS_RUNNING]


def start_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"START STREAM {stream_name}")

    assert get_is_running(cursor, stream_name)


def stop_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}")

    assert not get_is_running(cursor, stream_name)


def drop_stream(cursor, stream_name):
    execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}")

    assert get_stream_info(cursor, stream_name) is None


def check_stream_info(cursor, stream_name, expected_stream_info):
    stream_info = get_stream_info(cursor, stream_name)
    assert len(stream_info) == len(expected_stream_info)
    for info, expected_info in zip(stream_info, expected_stream_info):
        assert info == expected_info

##############################################
# Tests
##############################################


@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
def test_simple(producer, topics, connection, transformation):
    assert len(topics) > 0
    cursor = connection.cursor()
    execute_and_fetch_all(cursor,
                          "CREATE STREAM test "
                          f"TOPICS {','.join(topics)} "
                          f"TRANSFORM {transformation}")
    start_stream(cursor, "test")
    common.execute_and_fetch_all(cursor,
                                 "CREATE STREAM test "
                                 f"TOPICS {','.join(topics)} "
                                 f"TRANSFORM {transformation}")
    common.start_stream(cursor, "test")
    time.sleep(5)

    for topic in topics:
        producer.send(topic, SIMPLE_MSG).get(timeout=60)

    for topic in topics:
        check_vertex_exists_with_topic_and_payload(
        common.check_vertex_exists_with_topic_and_payload(
            cursor, topic, SIMPLE_MSG)


@ -196,13 +46,13 @@ def test_separate_consumers(producer, topics, connection, transformation):
    for topic in topics:
        stream_name = "stream_" + topic
        stream_names.append(stream_name)
        execute_and_fetch_all(cursor,
                              f"CREATE STREAM {stream_name} "
                              f"TOPICS {topic} "
                              f"TRANSFORM {transformation}")
        common.execute_and_fetch_all(cursor,
                                     f"CREATE STREAM {stream_name} "
                                     f"TOPICS {topic} "
                                     f"TRANSFORM {transformation}")

    for stream_name in stream_names:
        start_stream(cursor, stream_name)
        common.start_stream(cursor, stream_name)

    time.sleep(5)

@ -210,7 +60,7 @@ def test_separate_consumers(producer, topics, connection, transformation):
        producer.send(topic, SIMPLE_MSG).get(timeout=60)

    for topic in topics:
        check_vertex_exists_with_topic_and_payload(
        common.check_vertex_exists_with_topic_and_payload(
            cursor, topic, SIMPLE_MSG)


@ -223,41 +73,42 @@ def test_start_from_last_committed_offset(producer, topics, connection):
    # restarting Memgraph during a single workload cannot be done currently.
    assert len(topics) > 0
    cursor = connection.cursor()
    execute_and_fetch_all(cursor,
                          "CREATE STREAM test "
                          f"TOPICS {topics[0]} "
                          "TRANSFORM transform.simple")
    start_stream(cursor, "test")
    common.execute_and_fetch_all(cursor,
                                 "CREATE STREAM test "
                                 f"TOPICS {topics[0]} "
                                 "TRANSFORM transform.simple")
    common.start_stream(cursor, "test")
    time.sleep(1)

    producer.send(topics[0], SIMPLE_MSG).get(timeout=60)

    check_vertex_exists_with_topic_and_payload(
    common.check_vertex_exists_with_topic_and_payload(
        cursor, topics[0], SIMPLE_MSG)

    stop_stream(cursor, "test")
    drop_stream(cursor, "test")
    common.stop_stream(cursor, "test")
    common.drop_stream(cursor, "test")

    messages = [b"second message", b"third message"]
    for message in messages:
        producer.send(topics[0], message).get(timeout=60)

    for message in messages:
        vertices_with_msg = execute_and_fetch_all(cursor,
                                                  "MATCH (n: MESSAGE {"
                                                  f"payload: '{message.decode('utf-8')}'"
                                                  "}) RETURN n")
        vertices_with_msg = common.execute_and_fetch_all(
            cursor,
            "MATCH (n: MESSAGE {"
            f"payload: '{message.decode('utf-8')}'"
            "}) RETURN n")

        assert len(vertices_with_msg) == 0

    execute_and_fetch_all(cursor,
                          "CREATE STREAM test "
                          f"TOPICS {topics[0]} "
                          "TRANSFORM transform.simple")
    start_stream(cursor, "test")
    common.execute_and_fetch_all(cursor,
                                 "CREATE STREAM test "
                                 f"TOPICS {topics[0]} "
                                 "TRANSFORM transform.simple")
    common.start_stream(cursor, "test")

    for message in messages:
        check_vertex_exists_with_topic_and_payload(
        common.check_vertex_exists_with_topic_and_payload(
            cursor, topics[0], message)


@ -265,16 +116,16 @@ def test_start_from_last_committed_offset(producer, topics, connection):
def test_check_stream(producer, topics, connection, transformation):
    assert len(topics) > 0
    cursor = connection.cursor()
    execute_and_fetch_all(cursor,
                          "CREATE STREAM test "
                          f"TOPICS {topics[0]} "
                          f"TRANSFORM {transformation} "
                          "BATCH_SIZE 1")
    start_stream(cursor, "test")
    common.execute_and_fetch_all(cursor,
                                 "CREATE STREAM test "
                                 f"TOPICS {topics[0]} "
                                 f"TRANSFORM {transformation} "
                                 "BATCH_SIZE 1")
    common.start_stream(cursor, "test")
    time.sleep(1)

    producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
    stop_stream(cursor, "test")
    common.stop_stream(cursor, "test")

    messages = [b"first message", b"second message", b"third message"]
    for message in messages:
@ -283,7 +134,7 @@ def test_check_stream(producer, topics, connection, transformation):
|
||||
def check_check_stream(batch_limit):
|
||||
assert transformation == "transform.simple" \
|
||||
or transformation == "transform.with_parameters"
|
||||
test_results = execute_and_fetch_all(
|
||||
test_results = common.execute_and_fetch_all(
|
||||
cursor, f"CHECK STREAM test BATCH_LIMIT {batch_limit}")
|
||||
assert len(test_results) == batch_limit
|
||||
|
||||
@ -308,42 +159,47 @@ def test_check_stream(producer, topics, connection, transformation):
|
||||
check_check_stream(1)
|
||||
check_check_stream(2)
|
||||
check_check_stream(3)
|
||||
start_stream(cursor, "test")
|
||||
common.start_stream(cursor, "test")
|
||||
|
||||
for message in messages:
|
||||
check_vertex_exists_with_topic_and_payload(
|
||||
common.check_vertex_exists_with_topic_and_payload(
|
||||
cursor, topics[0], message)
|
||||
|
||||
|
||||
def test_show_streams(producer, topics, connection):
|
||||
assert len(topics) > 1
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM default_values "
|
||||
f"TOPICS {topics[0]} "
|
||||
f"TRANSFORM transform.simple")
|
||||
common.execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM default_values "
|
||||
f"TOPICS {topics[0]} "
|
||||
f"TRANSFORM transform.simple")
|
||||
|
||||
consumer_group = "my_special_consumer_group"
|
||||
batch_interval = 42
|
||||
batch_size = 3
|
||||
execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM complex_values "
|
||||
f"TOPICS {','.join(topics)} "
|
||||
f"TRANSFORM transform.with_parameters "
|
||||
f"CONSUMER_GROUP {consumer_group} "
|
||||
f"BATCH_INTERVAL {batch_interval} "
|
||||
f"BATCH_SIZE {batch_size} ")
|
||||
common.execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM complex_values "
|
||||
f"TOPICS {','.join(topics)} "
|
||||
f"TRANSFORM transform.with_parameters "
|
||||
f"CONSUMER_GROUP {consumer_group} "
|
||||
f"BATCH_INTERVAL {batch_interval} "
|
||||
f"BATCH_SIZE {batch_size} ")
|
||||
|
||||
assert len(execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2
|
||||
assert len(common.execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2
|
||||
|
||||
check_stream_info(cursor, "default_values", ("default_values", [
|
||||
topics[0]], "mg_consumer", None, None,
|
||||
"transform.simple", False))
|
||||
common.check_stream_info(cursor, "default_values", ("default_values", [
|
||||
topics[0]], "mg_consumer", None, None,
|
||||
"transform.simple", None, False))
|
||||
|
||||
check_stream_info(cursor, "complex_values", ("complex_values", topics,
|
||||
consumer_group, batch_interval, batch_size,
|
||||
"transform.with_parameters",
|
||||
False))
|
||||
common.check_stream_info(cursor, "complex_values", (
|
||||
"complex_values",
|
||||
topics,
|
||||
consumer_group,
|
||||
batch_interval,
|
||||
batch_size,
|
||||
"transform.with_parameters",
|
||||
None,
|
||||
False))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("operation", ["START", "STOP"])
|
||||
@ -360,10 +216,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
|
||||
assert len(topics) > 1
|
||||
assert operation == "START" or operation == "STOP"
|
||||
cursor = connection.cursor()
|
||||
execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM test_stream "
|
||||
f"TOPICS {topics[0]} "
|
||||
f"TRANSFORM transform.simple")
|
||||
common.execute_and_fetch_all(cursor,
|
||||
"CREATE STREAM test_stream "
|
||||
f"TOPICS {topics[0]} "
|
||||
f"TRANSFORM transform.simple")
|
||||
|
||||
check_counter = Value('i', 0)
|
||||
check_result_len = Value('i', 0)
|
||||
@ -377,10 +233,11 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
|
||||
def call_check(counter, result_len):
|
||||
# This process will call the CHECK query and increment the counter
|
||||
# based on its progress and expected behavior
|
||||
connection = connect()
|
||||
connection = common.connect()
|
||||
cursor = connection.cursor()
|
||||
counter.value = CHECK_BEFORE_EXECUTE
|
||||
result = execute_and_fetch_all(cursor, "CHECK STREAM test_stream")
|
||||
result = common.execute_and_fetch_all(
|
||||
cursor, "CHECK STREAM test_stream")
|
||||
result_len.value = len(result)
|
||||
counter.value = CHECK_AFTER_FETCHALL
|
||||
if len(result) > 0 and "payload: 'message'" in result[0][QUERY]:
|
||||
@ -397,11 +254,12 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
|
||||
def call_operation(counter):
|
||||
# This porcess will call the query with the specified operation and
|
||||
# increment the counter based on its progress and expected behavior
|
||||
connection = connect()
|
||||
connection = common.connect()
|
||||
cursor = connection.cursor()
|
||||
counter.value = OP_BEFORE_EXECUTE
|
||||
try:
|
||||
execute_and_fetch_all(cursor, f"{operation} STREAM test_stream")
|
||||
common.execute_and_fetch_all(
|
||||
cursor, f"{operation} STREAM test_stream")
|
||||
counter.value = OP_AFTER_FETCHALL
|
||||
except mgclient.DatabaseError as e:
|
||||
if "Kafka consumer test_stream is already stopped" in str(e):
|
||||
@ -421,15 +279,19 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert timed_wait(lambda: check_counter.value == CHECK_BEFORE_EXECUTE)
|
||||
assert timed_wait(lambda: get_is_running(cursor, "test_stream"))
|
||||
assert common.timed_wait(
|
||||
lambda: check_counter.value == CHECK_BEFORE_EXECUTE)
|
||||
assert common.timed_wait(
|
||||
lambda: common.get_is_running(cursor, "test_stream"))
|
||||
assert check_counter.value == CHECK_BEFORE_EXECUTE, "SHOW STREAMS " \
|
||||
"was blocked until the end of CHECK STREAM"
|
||||
operation_proc.start()
|
||||
assert timed_wait(lambda: operation_counter.value == OP_BEFORE_EXECUTE)
|
||||
assert common.timed_wait(
|
||||
lambda: operation_counter.value == OP_BEFORE_EXECUTE)
|
||||
|
||||
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
|
||||
assert timed_wait(lambda: check_counter.value > CHECK_AFTER_FETCHALL)
|
||||
assert common.timed_wait(
|
||||
lambda: check_counter.value > CHECK_AFTER_FETCHALL)
|
||||
assert check_counter.value == CHECK_CORRECT_RESULT
|
||||
assert check_result_len.value == 1
|
||||
check_stream_proc.join()
|
||||
@ -437,10 +299,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
|
||||
operation_proc.join()
|
||||
if operation == "START":
|
||||
assert operation_counter.value == OP_AFTER_FETCHALL
|
||||
assert get_is_running(cursor, "test_stream")
|
||||
assert common.get_is_running(cursor, "test_stream")
|
||||
else:
|
||||
assert operation_counter.value == OP_ALREADY_STOPPED_EXCEPTION
|
||||
assert not get_is_running(cursor, "test_stream")
|
||||
assert not common.get_is_running(cursor, "test_stream")
|
||||
|
||||
finally:
|
||||
# to make sure CHECK STREAM finishes
|
||||
@@ -455,42 +317,64 @@ def test_check_already_started_stream(topics, connection):
     assert len(topics) > 0
     cursor = connection.cursor()
 
-    execute_and_fetch_all(cursor,
-                          "CREATE STREAM started_stream "
-                          f"TOPICS {topics[0]} "
-                          f"TRANSFORM transform.simple")
-    start_stream(cursor, "started_stream")
+    common.execute_and_fetch_all(cursor,
+                                 "CREATE STREAM started_stream "
+                                 f"TOPICS {topics[0]} "
+                                 f"TRANSFORM transform.simple")
+    common.start_stream(cursor, "started_stream")
 
     with pytest.raises(mgclient.DatabaseError):
-        execute_and_fetch_all(cursor, "CHECK STREAM started_stream")
+        common.execute_and_fetch_all(cursor, "CHECK STREAM started_stream")
 
 
 def test_start_checked_stream_after_timeout(topics, connection):
     cursor = connection.cursor()
-    execute_and_fetch_all(cursor,
-                          "CREATE STREAM test_stream "
-                          f"TOPICS {topics[0]} "
-                          f"TRANSFORM transform.simple")
+    common.execute_and_fetch_all(cursor,
+                                 "CREATE STREAM test_stream "
+                                 f"TOPICS {topics[0]} "
+                                 f"TRANSFORM transform.simple")
 
     timeout_ms = 2000
 
     def call_check():
-        execute_and_fetch_all(
-            connect().cursor(),
+        common.execute_and_fetch_all(
+            common.connect().cursor(),
             f"CHECK STREAM test_stream TIMEOUT {timeout_ms}")
 
     check_stream_proc = Process(target=call_check, daemon=True)
 
     start = time.time()
     check_stream_proc.start()
-    assert timed_wait(lambda: get_is_running(cursor, "test_stream"))
-    start_stream(cursor, "test_stream")
+    assert common.timed_wait(
+        lambda: common.get_is_running(cursor, "test_stream"))
+    common.start_stream(cursor, "test_stream")
     end = time.time()
 
     assert (end - start) < 1.3 * \
         timeout_ms, "The START STREAM was blocked too long"
-    assert get_is_running(cursor, "test_stream")
-    stop_stream(cursor, "test_stream")
+    assert common.get_is_running(cursor, "test_stream")
+    common.stop_stream(cursor, "test_stream")
 
 
+def test_restart_after_error(producer, topics, connection):
+    cursor = connection.cursor()
+    common.execute_and_fetch_all(cursor,
+                                 "CREATE STREAM test_stream "
+                                 f"TOPICS {topics[0]} "
+                                 f"TRANSFORM transform.query")
+
+    common.start_stream(cursor, "test_stream")
+    time.sleep(1)
+
+    producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
+    assert common.timed_wait(
+        lambda: not common.get_is_running(cursor, "test_stream"))
+
+    common.start_stream(cursor, "test_stream")
+    time.sleep(1)
+    producer.send(topics[0], b'CREATE (n:VERTEX { id : 42 })')
+    assert common.check_one_result_row(
+        cursor, "MATCH (n:VERTEX { id : 42 }) RETURN n")
+
+
 if __name__ == "__main__":
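The tests above lean on common.timed_wait to poll for stream state changes. common.py is part of this commit but is not shown in this excerpt, so the following is only a minimal sketch of what such a polling helper typically looks like; the function name matches the call sites, but the defaults and body are assumptions, not the actual implementation:

import time

def timed_wait(predicate, timeout_sec=60, poll_interval_sec=0.1):
    # Hypothetical stand-in for common.timed_wait: repeatedly evaluate the
    # predicate until it returns True or the deadline expires.
    deadline = time.monotonic() + timeout_sec
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(poll_interval_sec)
    return False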
@@ -35,3 +35,17 @@ def with_parameters(context: mgp.TransCtx,
                                    "topic": message.topic_name()}))
 
     return result_queries
+
+
+@mgp.transformation
+def query(messages: mgp.Messages
+          ) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]):
+    result_queries = []
+
+    for i in range(0, messages.total_messages()):
+        message = messages.message_at(i)
+        payload_as_str = message.payload().decode("utf-8")
+        result_queries.append(mgp.Record(
+            query=payload_as_str, parameters=None))
+
+    return result_queries
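For reference, the new transform.query transformation forwards each Kafka message payload verbatim as a Cypher query with no parameters. A rough stand-alone illustration of the record it produces; FakeMessage is a made-up stand-in, since the mgp module is only importable inside Memgraph:

class FakeMessage:
    # Stand-in for mgp.Message; only the payload accessor is modeled.
    def payload(self):
        return b"CREATE (n:VERTEX { id : 42 })"

payload_as_str = FakeMessage().payload().decode("utf-8")
# The transformation emits mgp.Record(query=..., parameters=None);
# Memgraph then executes the payload text directly as a query.
print({"query": payload_as_str, "parameters": None})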
@@ -10,5 +10,10 @@ workloads:
   - name: "Streams start, stop and show"
     binary: "tests/e2e/streams/streams_test_runner.sh"
     proc: "tests/e2e/streams/transformations/"
-    args: []
+    args: ["streams_tests.py"]
     <<: *template_cluster
+  - name: "Streams with users"
+    binary: "tests/e2e/streams/streams_test_runner.sh"
+    proc: "tests/e2e/streams/transformations/"
+    args: ["streams_owner_tests.py"]
+    <<: *template_cluster
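Each workload entry pulls in the shared cluster definition through the YAML merge key `<<: *template_cluster`. A small self-contained demonstration of how merge keys resolve; the anchor contents below are made up for illustration, the real template_cluster is defined elsewhere in workloads.yaml:

import yaml  # requires PyYAML, whose SafeLoader resolves merge keys

doc = """
template: &template_cluster
  cluster:
    main:
      args: ["--bolt-port", "7687"]
workloads:
  - name: "Streams with users"
    args: ["streams_owner_tests.py"]
    <<: *template_cluster
"""
data = yaml.safe_load(doc)
# The merge key copies the anchored mapping into the workload entry.
print(data["workloads"][0]["cluster"]["main"]["args"])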
@@ -9,3 +9,6 @@ target_link_libraries(memgraph__e2e__triggers__on_update memgraph__e2e__triggers_common)
 
 add_executable(memgraph__e2e__triggers__on_delete on_delete_triggers.cpp)
 target_link_libraries(memgraph__e2e__triggers__on_delete memgraph__e2e__triggers_common)
+
+add_executable(memgraph__e2e__triggers__privileges privilige_check.cpp)
+target_link_libraries(memgraph__e2e__triggers__privileges memgraph__e2e__triggers_common)
@@ -1,7 +1,9 @@
 #include "common.hpp"
 
+#include <chrono>
 #include <cstdint>
 #include <optional>
+#include <thread>
 
 #include <fmt/format.h>
 #include <gflags/gflags.h>
@@ -10,13 +12,17 @@
 
 DEFINE_uint64(bolt_port, 7687, "Bolt port");
 
-std::unique_ptr<mg::Client> Connect() {
-  auto client =
-      mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
+std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username) {
+  auto client = mg::Client::Connect({.host = "127.0.0.1",
+                                     .port = static_cast<uint16_t>(FLAGS_bolt_port),
+                                     .username = std::string{username},
+                                     .use_ssl = false});
   MG_ASSERT(client, "Failed to connect!");
   return client;
 }
 
+std::unique_ptr<mg::Client> Connect() { return ConnectWithUser(""); }
+
 void CreateVertex(mg::Client &client, int vertex_id) {
   mg::Map parameters{
       {"id", mg::Value{vertex_id}},
@@ -49,10 +55,12 @@ int GetNumberOfAllVertices(mg::Client &client) {
 }
 
 void WaitForNumberOfAllVertices(mg::Client &client, int number_of_vertices) {
+  using namespace std::chrono_literals;
   utils::Timer timer{};
   while ((timer.Elapsed().count() <= 0.5) && GetNumberOfAllVertices(client) != number_of_vertices) {
+    std::this_thread::sleep_for(100ms);
   }
   CheckNumberOfAllVertices(client, number_of_vertices);
 }
 
 void CheckNumberOfAllVertices(mg::Client &client, int expected_number_of_vertices) {
@@ -11,6 +11,7 @@ constexpr std::string_view kVertexLabel{"VERTEX"};
 constexpr std::string_view kEdgeLabel{"EDGE"};
 
 std::unique_ptr<mg::Client> Connect();
+std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username);
 void CreateVertex(mg::Client &client, int vertex_id);
 void CreateEdge(mg::Client &client, int from_vertex, int to_vertex, int edge_id);
 
tests/e2e/triggers/privilige_check.cpp (new file, 161 lines)
@@ -0,0 +1,161 @@
+#include <string>
+#include <string_view>
+
+#include <gflags/gflags.h>
+#include <spdlog/fmt/bundled/core.h>
+#include <mgclient.hpp>
+#include "common.hpp"
+#include "utils/logging.hpp"
+
+constexpr std::string_view kTriggerPrefix{"CreatedVerticesTrigger"};
+
+int main(int argc, char **argv) {
+  gflags::SetUsageMessage("Memgraph E2E Triggers privilege check");
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  logging::RedirectToStderr();
+
+  constexpr int kVertexId{42};
+  constexpr std::string_view kUserlessLabel{"USERLESS"};
+  constexpr std::string_view kAdminUser{"ADMIN"};
+  constexpr std::string_view kUserWithCreate{"USER_WITH_CREATE"};
+  constexpr std::string_view kUserWithoutCreate{"USER_WITHOUT_CREATE"};
+
+  mg::Client::Init();
+
+  auto userless_client = Connect();
+
+  const auto get_number_of_triggers = [&userless_client] {
+    userless_client->Execute("SHOW TRIGGERS");
+    auto result = userless_client->FetchAll();
+    MG_ASSERT(result.has_value());
+    return result->size();
+  };
+
+  auto create_trigger = [&get_number_of_triggers](mg::Client &client, const std::string_view vertexLabel,
+                                                  bool should_succeed = true) {
+    const auto number_of_triggers_before = get_number_of_triggers();
+    client.Execute(
+        fmt::format("CREATE TRIGGER {}{} ON () CREATE "
+                    "AFTER COMMIT "
+                    "EXECUTE "
+                    "UNWIND createdVertices as createdVertex "
+                    "CREATE (n: {} {{ id: createdVertex.id }})",
+                    kTriggerPrefix, vertexLabel, vertexLabel));
+    client.DiscardAll();
+    const auto number_of_triggers_after = get_number_of_triggers();
+    if (should_succeed) {
+      MG_ASSERT(number_of_triggers_after == number_of_triggers_before + 1);
+    } else {
+      MG_ASSERT(number_of_triggers_after == number_of_triggers_before);
+    }
+  };
+
+  auto delete_vertices = [&userless_client] {
+    userless_client->Execute("MATCH (n) DETACH DELETE n;");
+    userless_client->DiscardAll();
+    CheckNumberOfAllVertices(*userless_client, 0);
+  };
+
+  auto create_user = [&userless_client](const std::string_view username) {
+    userless_client->Execute(fmt::format("CREATE USER {};", username));
+    userless_client->DiscardAll();
+    userless_client->Execute(fmt::format("GRANT TRIGGER TO {};", username));
+    userless_client->DiscardAll();
+  };
+
+  auto drop_user = [&userless_client](const std::string_view username) {
+    userless_client->Execute(fmt::format("DROP USER {};", username));
+    userless_client->DiscardAll();
+  };
+
+  auto drop_trigger_of_user = [&userless_client](const std::string_view username) {
+    userless_client->Execute(fmt::format("DROP TRIGGER {}{};", kTriggerPrefix, username));
+    userless_client->DiscardAll();
+  };
+
+  // Single trigger created without a user; there are no existing users
+  create_trigger(*userless_client, kUserlessLabel);
+  CreateVertex(*userless_client, kVertexId);
+
+  WaitForNumberOfAllVertices(*userless_client, 2);
+  CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
+  CheckVertexExists(*userless_client, kUserlessLabel, kVertexId);
+
+  delete_vertices();
+
+  // Single trigger created without a user; there is an existing user
+  // The trigger fails because there is no owner
+  create_user(kAdminUser);
+  CreateVertex(*userless_client, kVertexId);
+
+  CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
+  CheckNumberOfAllVertices(*userless_client, 1);
+
+  delete_vertices();
+
+  // Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there are
+  // existing users
+  // Only the trigger whose owner has the CREATE privilege will succeed
+  create_user(kUserWithCreate);
+  userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithCreate));
+  userless_client->DiscardAll();
+  create_user(kUserWithoutCreate);
+  auto client_with_create = ConnectWithUser(kUserWithCreate);
+  auto client_without_create = ConnectWithUser(kUserWithoutCreate);
+  create_trigger(*client_with_create, kUserWithCreate);
+  create_trigger(*client_without_create, kUserWithoutCreate, false);
+
+  // Grant CREATE to be able to create the trigger, then revoke it
+  userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate));
+  userless_client->DiscardAll();
+  create_trigger(*client_without_create, kUserWithoutCreate);
+  userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate));
+  userless_client->DiscardAll();
+
+  CreateVertex(*userless_client, kVertexId);
+
+  WaitForNumberOfAllVertices(*userless_client, 2);
+  CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
+  CheckVertexExists(*userless_client, kUserWithCreate, kVertexId);
+
+  delete_vertices();
+
+  // Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there is no
+  // existing user
+  // All triggers will succeed, as no authorization is done when there are no users
+  drop_user(kAdminUser);
+  drop_user(kUserWithCreate);
+  drop_user(kUserWithoutCreate);
+  CreateVertex(*userless_client, kVertexId);
+
+  WaitForNumberOfAllVertices(*userless_client, 4);
+  CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
+  CheckVertexExists(*userless_client, kUserlessLabel, kVertexId);
+  CheckVertexExists(*userless_client, kUserWithCreate, kVertexId);
+  CheckVertexExists(*userless_client, kUserWithoutCreate, kVertexId);
+
+  delete_vertices();
+  drop_trigger_of_user(kUserlessLabel);
+  drop_trigger_of_user(kUserWithCreate);
+  drop_trigger_of_user(kUserWithoutCreate);
+
+  // A BEFORE COMMIT trigger without the proper privileges makes the transaction fail
+  create_user(kUserWithoutCreate);
+  userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate));
+  userless_client->DiscardAll();
+  client_without_create->Execute(
+      fmt::format("CREATE TRIGGER {}{} ON () CREATE "
+                  "BEFORE COMMIT "
+                  "EXECUTE "
+                  "UNWIND createdVertices as createdVertex "
+                  "CREATE (n: {} {{ id: createdVertex.id }})",
+                  kTriggerPrefix, kUserWithoutCreate, kUserWithoutCreate));
+  client_without_create->DiscardAll();
+  userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate));
+  userless_client->DiscardAll();
+
+  CreateVertex(*userless_client, kVertexId);
+  CheckNumberOfAllVertices(*userless_client, 0);
+
+  return 0;
+}
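The same owner-privilege behavior can be poked at interactively. Below is a hedged sketch using the Python mgclient driver; the connect helper, user name, and the instance on port 7687 are assumptions for illustration, and the authoritative checks remain the C++ asserts above:

import mgclient

def connect(username=""):
    # Mirrors ConnectWithUser from common.cpp: open a Bolt connection as the
    # given user (empty string means no authentication).
    conn = mgclient.connect(host="127.0.0.1", port=7687, username=username)
    conn.autocommit = True
    return conn

admin = connect().cursor()
admin.execute("CREATE USER USER_WITHOUT_CREATE;")
admin.execute("GRANT TRIGGER TO USER_WITHOUT_CREATE;")  # may manage triggers

user = connect("USER_WITHOUT_CREATE").cursor()
try:
    # The trigger body needs the CREATE privilege, which this owner lacks, so
    # Memgraph is expected to reject it; the e2e test above verifies that the
    # trigger count stays unchanged in this situation.
    user.execute("CREATE TRIGGER t ON () CREATE AFTER COMMIT EXECUTE "
                 "UNWIND createdVertices AS createdVertex "
                 "CREATE (:COPY { id: createdVertex.id });")
except mgclient.DatabaseError as error:
    print("rejected as expected:", error)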
@@ -20,5 +20,9 @@ workloads:
     binary: "tests/e2e/triggers/memgraph__e2e__triggers__on_delete"
     args: ["--bolt-port", *bolt_port]
     <<: *template_cluster
+  - name: "Triggers privilege check"
+    binary: "tests/e2e/triggers/memgraph__e2e__triggers__privileges"
+    args: ["--bolt-port", *bolt_port]
+    <<: *template_cluster
 
 
@@ -22,7 +22,7 @@ int main(int argc, char *argv[]) {
   query::Interpreter interpreter{&interpreter_context};
 
   ResultStreamFaker stream(&db);
-  auto [header, _, qid] = interpreter.Prepare(argv[1], {});
+  auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr);
   stream.Header(header);
   auto summary = interpreter.PullAll(&stream);
   stream.Summary(summary);
@@ -165,14 +165,14 @@ TEST_F(AuthWithStorage, RoleManipulations) {
   {
     auto user1 = auth.GetUser("user1");
     ASSERT_TRUE(user1);
-    auto role1 = user1->role();
-    ASSERT_TRUE(role1);
+    const auto *role1 = user1->role();
+    ASSERT_NE(role1, nullptr);
     ASSERT_EQ(role1->rolename(), "role1");
 
     auto user2 = auth.GetUser("user2");
     ASSERT_TRUE(user2);
-    auto role2 = user2->role();
-    ASSERT_TRUE(role2);
+    const auto *role2 = user2->role();
+    ASSERT_NE(role2, nullptr);
     ASSERT_EQ(role2->rolename(), "role2");
   }
 
@@ -181,13 +181,13 @@ TEST_F(AuthWithStorage, RoleManipulations) {
   {
     auto user1 = auth.GetUser("user1");
     ASSERT_TRUE(user1);
-    auto role = user1->role();
-    ASSERT_FALSE(role);
+    const auto *role = user1->role();
+    ASSERT_EQ(role, nullptr);
 
     auto user2 = auth.GetUser("user2");
     ASSERT_TRUE(user2);
-    auto role2 = user2->role();
-    ASSERT_TRUE(role2);
+    const auto *role2 = user2->role();
+    ASSERT_NE(role2, nullptr);
     ASSERT_EQ(role2->rolename(), "role2");
   }
 
@@ -199,13 +199,13 @@ TEST_F(AuthWithStorage, RoleManipulations) {
   {
     auto user1 = auth.GetUser("user1");
     ASSERT_TRUE(user1);
-    auto role1 = user1->role();
-    ASSERT_FALSE(role1);
+    const auto *role1 = user1->role();
+    ASSERT_EQ(role1, nullptr);
 
     auto user2 = auth.GetUser("user2");
     ASSERT_TRUE(user2);
-    auto role2 = user2->role();
-    ASSERT_TRUE(role2);
+    const auto *role2 = user2->role();
+    ASSERT_NE(role2, nullptr);
     ASSERT_EQ(role2->rolename(), "role2");
   }
 
@@ -245,8 +245,8 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) {
   {
     auto user = auth.GetUser("user");
     ASSERT_TRUE(user);
-    auto role = user->role();
-    ASSERT_TRUE(role);
+    const auto *role = user->role();
+    ASSERT_NE(role, nullptr);
     ASSERT_EQ(role->rolename(), "role");
   }
 
@@ -260,7 +260,7 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) {
   {
     auto user = auth.GetUser("user");
     ASSERT_TRUE(user);
-    ASSERT_FALSE(user->role());
+    ASSERT_EQ(user->role(), nullptr);
   }
 }
 
@@ -620,8 +620,9 @@ TEST_F(AuthWithStorage, CaseInsensitivity) {
     auto user = auth.GetUser("aLIce");
     ASSERT_TRUE(user);
     ASSERT_EQ(user->username(), "alice");
-    ASSERT_TRUE(user->role());
-    ASSERT_EQ(user->role()->rolename(), "moderator");
+    const auto *role = user->role();
+    ASSERT_NE(role, nullptr);
+    ASSERT_EQ(role->rolename(), "moderator");
   }
 
   // AllUsersForRole
@@ -6,6 +6,7 @@
 #include "glue/communication.hpp"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "query/auth_checker.hpp"
 #include "query/config.hpp"
 #include "query/exceptions.hpp"
 #include "query/interpreter.hpp"
@@ -28,15 +29,17 @@ auto ToEdgeList(const communication::bolt::Value &v) {
 };
 
 struct InterpreterFaker {
-  explicit InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config,
-                            const std::filesystem::path &data_directory)
+  InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config,
+                   const std::filesystem::path &data_directory)
       : interpreter_context(db, config, data_directory, "not used bootstrap servers"),
-        interpreter(&interpreter_context) {}
+        interpreter(&interpreter_context) {
+    interpreter_context.auth_checker = &auth_checker;
+  }
 
   auto Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params = {}) {
     ResultStreamFaker stream(interpreter_context.db);
 
-    const auto [header, _, qid] = interpreter.Prepare(query, params);
+    const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
     stream.Header(header);
     return std::make_pair(std::move(stream), qid);
   }
@@ -61,6 +64,7 @@ struct InterpreterFaker {
     return std::move(stream);
   }
 
+  query::AllowEverythingAuthChecker auth_checker;
  query::InterpreterContext interpreter_context;
   query::Interpreter interpreter;
 };
@@ -194,7 +194,7 @@ auto Execute(storage::Storage *db, const std::string &query) {
   query::Interpreter interpreter(&context);
   ResultStreamFaker stream(db);
 
-  auto [header, _, qid] = interpreter.Prepare(query, {});
+  auto [header, _, qid] = interpreter.Prepare(query, {}, nullptr);
   stream.Header(header);
   auto summary = interpreter.PullAll(&stream);
   stream.Summary(summary);
@@ -711,7 +711,7 @@ class StatefulInterpreter {
   auto Execute(const std::string &query) {
     ResultStreamFaker stream(db_);
 
-    auto [header, _, qid] = interpreter_.Prepare(query, {});
+    auto [header, _, qid] = interpreter_.Prepare(query, {}, nullptr);
     stream.Header(header);
     auto summary = interpreter_.PullAll(&stream);
     stream.Summary(summary);
@@ -42,7 +42,7 @@ class QueryExecution : public testing::Test {
   auto Execute(const std::string &query) {
     ResultStreamFaker stream(&*db_);
 
-    auto [header, _, qid] = interpreter_->Prepare(query, {});
+    auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr);
     stream.Header(header);
     auto summary = interpreter_->PullAll(&stream);
     stream.Summary(summary);
@@ -34,6 +34,7 @@ StreamInfo CreateDefaultStreamInfo() {
       .batch_interval = std::nullopt,
       .batch_size = std::nullopt,
       .transformation_name = "not used in the tests",
+      .owner = std::nullopt,
   };
 }
 
@@ -57,6 +58,8 @@ class StreamsTest : public ::testing::Test {
   // Though there is a Streams object in interpreter context, it makes more sense to use a separate object to test,
   // because that provides a way to recreate the streams object and also give better control over the arguments of the
   // Streams constructor.
+  // InterpreterContext::auth_checker_ is used in the Streams object, but only in the message processing part. Because
+  // these tests don't send any messages, the auth_checker_ pointer can be left as nullptr.
   query::InterpreterContext interpreter_context_{&db_, query::InterpreterConfig{}, data_directory_,
                                                  "dont care bootstrap servers"};
   std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
@@ -172,12 +175,14 @@ TEST_F(StreamsTest, RestoreStreams) {
     if (i > 0) {
      stream_info.batch_interval = std::chrono::milliseconds((i + 1) * 10);
       stream_info.batch_size = 1000 + i;
+      stream_info.owner = std::string{"owner"} + iteration_postfix;
     }
 
     mock_cluster_.CreateTopic(stream_info.topics[0]);
   }
   stream_check_datas[1].info.batch_interval = {};
   stream_check_datas[2].info.batch_size = {};
+  stream_check_datas[3].info.owner = {};
 
   const auto check_restore_logic = [&stream_check_datas, this]() {
     // Reset the Streams object to trigger reloading
@@ -1,12 +1,16 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include <filesystem>
 
 #include <fmt/format.h>
+#include "query/auth_checker.hpp"
 #include "query/config.hpp"
 #include "query/db_accessor.hpp"
 #include "query/frontend/ast/ast.hpp"
 #include "query/interpreter.hpp"
 #include "query/trigger.hpp"
+#include "query/typed_value.hpp"
+#include "utils/exceptions.hpp"
 #include "utils/memory.hpp"
 
 namespace {
@@ -16,6 +20,12 @@ const std::unordered_set<query::TriggerEventType> kAllEventTypes{
     query::TriggerEventType::DELETE, query::TriggerEventType::VERTEX_UPDATE, query::TriggerEventType::EDGE_UPDATE,
     query::TriggerEventType::UPDATE,
 };
+
+class MockAuthChecker : public query::AuthChecker {
+ public:
+  MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username,
+                                            const std::vector<query::AuthQuery::Privilege> &privileges));
+};
 }  // namespace
 
 class TriggerContextTest : public ::testing::Test {
@@ -820,6 +830,7 @@ class TriggerStoreTest : public ::testing::Test {
 
   utils::SkipList<query::QueryCacheEntry> ast_cache;
   utils::SpinLock antlr_lock;
+  query::AllowEverythingAuthChecker auth_checker;
 
  private:
  void Clear() {
@@ -836,7 +847,7 @@ TEST_F(TriggerStoreTest, Restore) {
 
   const auto reset_store = [&] {
     store.emplace(testing_directory);
-    store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{});
+    store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &auth_checker);
   };
 
   reset_store();
@@ -853,34 +864,40 @@ TEST_F(TriggerStoreTest, Restore) {
   const auto *trigger_name_after = "trigger_after";
   const auto *trigger_statement = "RETURN $parameter";
   const auto event_type = query::TriggerEventType::VERTEX_CREATE;
+  const std::string owner{"owner"};
   store->AddTrigger(trigger_name_before, trigger_statement,
                     std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}}, event_type,
                     query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                    query::InterpreterConfig::Query{});
+                    query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
   store->AddTrigger(trigger_name_after, trigger_statement,
                     std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{"value"}}},
                     event_type, query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                    query::InterpreterConfig::Query{});
+                    query::InterpreterConfig::Query{}, {owner}, &auth_checker);
 
   const auto check_triggers = [&] {
     ASSERT_EQ(store->GetTriggerInfo().size(), 2);
 
-    const auto verify_trigger = [&](const auto &trigger, const auto &name) {
+    const auto verify_trigger = [&](const auto &trigger, const auto &name, const std::string *owner) {
       ASSERT_EQ(trigger.Name(), name);
       ASSERT_EQ(trigger.OriginalStatement(), trigger_statement);
       ASSERT_EQ(trigger.EventType(), event_type);
+      if (owner != nullptr) {
+        ASSERT_EQ(*trigger.Owner(), *owner);
+      } else {
+        ASSERT_FALSE(trigger.Owner().has_value());
+      }
     };
 
     const auto before_commit_triggers = store->BeforeCommitTriggers().access();
     ASSERT_EQ(before_commit_triggers.size(), 1);
     for (const auto &trigger : before_commit_triggers) {
-      verify_trigger(trigger, trigger_name_before);
+      verify_trigger(trigger, trigger_name_before, nullptr);
     }
 
     const auto after_commit_triggers = store->AfterCommitTriggers().access();
     ASSERT_EQ(after_commit_triggers.size(), 1);
     for (const auto &trigger : after_commit_triggers) {
-      verify_trigger(trigger, trigger_name_after);
+      verify_trigger(trigger, trigger_name_after, &owner);
     }
   };
 
@@ -906,32 +923,32 @@ TEST_F(TriggerStoreTest, AddTrigger) {
   // Invalid query in statements
   ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, query::TriggerEventType::VERTEX_CREATE,
                                 query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                query::InterpreterConfig::Query{}),
+                                query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
               utils::BasicException);
   ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, query::TriggerEventType::VERTEX_CREATE,
                                 query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                query::InterpreterConfig::Query{}),
+                                query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
               utils::BasicException);
 
   ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, query::TriggerEventType::VERTEX_CREATE,
                                 query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                query::InterpreterConfig::Query{}),
+                                query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
               utils::BasicException);
 
   ASSERT_NO_THROW(
       store.AddTrigger("trigger", "RETURN $parameter",
                        std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}},
                       query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba,
-                       &antlr_lock, query::InterpreterConfig::Query{}));
+                       &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &auth_checker));
 
   // Inserting with the same name
   ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
                                 query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                query::InterpreterConfig::Query{}),
+                                query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
               utils::BasicException);
   ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
                                 query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                query::InterpreterConfig::Query{}),
+                                query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
               utils::BasicException);
 
   ASSERT_EQ(store.GetTriggerInfo().size(), 1);
@@ -947,7 +964,7 @@ TEST_F(TriggerStoreTest, DropTrigger) {
   const auto *trigger_name = "trigger";
   store.AddTrigger(trigger_name, "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
                    query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                   query::InterpreterConfig::Query{});
+                   query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
 
   ASSERT_THROW(store.DropTrigger("Unknown"), utils::BasicException);
   ASSERT_NO_THROW(store.DropTrigger(trigger_name));
@@ -960,7 +977,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
   std::vector<query::TriggerStore::TriggerInfo> expected_info;
   store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
                    query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                   query::InterpreterConfig::Query{});
+                   query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
   expected_info.push_back(
       {"trigger", "RETURN 1", query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT});
 
@@ -971,7 +988,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
   ASSERT_TRUE(std::all_of(expected_info.begin(), expected_info.end(), [&](const auto &info) {
     return std::find_if(trigger_info.begin(), trigger_info.end(), [&](const auto &other) {
              return info.name == other.name && info.statement == other.statement &&
-                    info.event_type == other.event_type && info.phase == other.phase;
+                    info.event_type == other.event_type && info.phase == other.phase && !info.owner.has_value();
           }) != trigger_info.end();
   }));
 };
@@ -979,8 +996,8 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
   check_trigger_info();
 
   store.AddTrigger("edge_update_trigger", "RETURN 1", {}, query::TriggerEventType::EDGE_UPDATE,
-                   query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                   query::InterpreterConfig::Query{});
+                   query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{},
+                   std::nullopt, &auth_checker);
   expected_info.push_back(
       {"edge_update_trigger", "RETURN 1", query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT});
 
@@ -1093,8 +1110,56 @@ TEST_F(TriggerStoreTest, AnyTriggerAllKeywords) {
     SCOPED_TRACE(keyword);
     EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type,
                                      query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
-                                     query::InterpreterConfig::Query{}));
+                                     query::InterpreterConfig::Query{}, std::nullopt, &auth_checker));
     store.DropTrigger(trigger_name);
   }
 }
 }
+
+TEST_F(TriggerStoreTest, AuthCheckerUsage) {
+  using Privilege = query::AuthQuery::Privilege;
+  using ::testing::_;
+  using ::testing::ElementsAre;
+  using ::testing::Return;
+  std::optional<query::TriggerStore> store{testing_directory};
+  const std::optional<std::string> owner{"testing_owner"};
+  MockAuthChecker mock_checker;
+
+  ::testing::InSequence s;
+
+  EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE)))
+      .Times(1)
+      .WillOnce(Return(true));
+  EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true));
+
+  ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {},
+                                    query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
+                                    &*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt,
+                                    &mock_checker));
+
+  ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {},
+                                    query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
+                                    &*dba, &antlr_lock, query::InterpreterConfig::Query{}, owner, &mock_checker));
+
+  EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::MATCH)))
+      .Times(1)
+      .WillOnce(Return(false));
+
+  ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {},
+                                 query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
+                                 &*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &mock_checker);
+               , utils::BasicException);
+
+  store.emplace(testing_directory);
+  EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE)))
+      .Times(1)
+      .WillOnce(Return(false));
+  EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true));
+
+  ASSERT_NO_THROW(
+      store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &mock_checker));
+
+  const auto triggers = store->GetTriggerInfo();
+  ASSERT_EQ(triggers.size(), 1);
+  ASSERT_EQ(triggers.front().owner, owner);
+}