Add privilege check in triggers and streams ()

This commit is contained in:
János Benjamin Antal 2021-07-22 16:22:08 +02:00 committed by GitHub
parent 09c58501f1
commit 09cfca35f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 1413 additions and 810 deletions

View File

@ -53,6 +53,7 @@ Checks: '*,
-performance-unnecessary-value-param,
-readability-braces-around-statements,
-readability-else-after-return,
-readability-function-cognitive-complexity,
-readability-implicit-bool-conversion,
-readability-magic-numbers,
-readability-named-parameter,

View File

@ -144,7 +144,7 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::s
}
}
std::optional<User> Auth::GetUser(const std::string &username_orig) {
std::optional<User> Auth::GetUser(const std::string &username_orig) const {
auto username = utils::ToLowerCase(username_orig);
auto existing_user = storage_.Get(kUserPrefix + username);
if (!existing_user) return std::nullopt;
@ -170,9 +170,9 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) {
void Auth::SaveUser(const User &user) {
bool success = false;
if (user.role()) {
success = storage_.PutMultiple({{kUserPrefix + user.username(), user.Serialize().dump()},
{kLinkPrefix + user.username(), user.role()->rolename()}});
if (const auto *role = user.role(); role != nullptr) {
success = storage_.PutMultiple(
{{kUserPrefix + user.username(), user.Serialize().dump()}, {kLinkPrefix + user.username(), role->rolename()}});
} else {
success = storage_.PutAndDeleteMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}},
{kLinkPrefix + user.username()});
@ -203,7 +203,7 @@ bool Auth::RemoveUser(const std::string &username_orig) {
return true;
}
std::vector<auth::User> Auth::AllUsers() {
std::vector<auth::User> Auth::AllUsers() const {
std::vector<auth::User> ret;
for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) {
auto username = it->first.substr(kUserPrefix.size());
@ -216,9 +216,9 @@ std::vector<auth::User> Auth::AllUsers() {
return ret;
}
bool Auth::HasUsers() { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); }
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) {
std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const {
auto rolename = utils::ToLowerCase(rolename_orig);
auto existing_role = storage_.Get(kRolePrefix + rolename);
if (!existing_role) return std::nullopt;
@ -265,7 +265,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) {
return true;
}
std::vector<auth::Role> Auth::AllRoles() {
std::vector<auth::Role> Auth::AllRoles() const {
std::vector<auth::Role> ret;
for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) {
auto rolename = it->first.substr(kRolePrefix.size());
@ -280,7 +280,7 @@ std::vector<auth::Role> Auth::AllRoles() {
return ret;
}
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) {
std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const {
auto rolename = utils::ToLowerCase(rolename_orig);
std::vector<auth::User> ret;
for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) {
@ -299,6 +299,4 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig)
return ret;
}
std::mutex &Auth::WithLock() { return lock_; }
} // namespace auth

View File

@ -14,8 +14,7 @@ namespace auth {
/**
* This class serves as the main Authentication/Authorization storage.
* It provides functions for managing Users, Roles and Permissions.
* NOTE: The functions in this class aren't thread safe. Use the `WithLock` lock
* if you want to have safe modifications of the storage.
* NOTE: The non-const functions in this class aren't thread safe.
* TODO (mferencevic): Disable user/role modification functions when they are
* being managed by the auth module.
*/
@ -42,7 +41,7 @@ class Auth final {
* @return a user when the user exists, nullopt otherwise
* @throw AuthException if unable to load user data.
*/
std::optional<User> GetUser(const std::string &username);
std::optional<User> GetUser(const std::string &username) const;
/**
* Saves a user object to the storage.
@ -81,14 +80,14 @@ class Auth final {
* @return a list of users
* @throw AuthException if unable to load user data.
*/
std::vector<User> AllUsers();
std::vector<User> AllUsers() const;
/**
* Returns whether there are users in the storage.
*
* @return `true` if the storage contains any users, `false` otherwise
*/
bool HasUsers();
bool HasUsers() const;
/**
* Gets a role from the storage.
@ -98,7 +97,7 @@ class Auth final {
* @return a role when the role exists, nullopt otherwise
* @throw AuthException if unable to load role data.
*/
std::optional<Role> GetRole(const std::string &rolename);
std::optional<Role> GetRole(const std::string &rolename) const;
/**
* Saves a role object to the storage.
@ -136,7 +135,7 @@ class Auth final {
* @return a list of roles
* @throw AuthException if unable to load role data.
*/
std::vector<Role> AllRoles();
std::vector<Role> AllRoles() const;
/**
* Gets all users for a role from the storage.
@ -146,21 +145,13 @@ class Auth final {
* @return a list of roles
* @throw AuthException if unable to load user data.
*/
std::vector<User> AllUsersForRole(const std::string &rolename);
/**
* Returns a reference to the lock that should be used for all operations that
* require more than one interaction with this class.
*/
std::mutex &WithLock();
std::vector<User> AllUsersForRole(const std::string &rolename) const;
private:
// Even though the `kvstore::KVStore` class is guaranteed to be thread-safe,
// Auth is not thread-safe because modifying users and roles might require
// more than one operation on the storage.
kvstore::KVStore storage_;
auth::Module module_;
// Even though the `kvstore::KVStore` class is guaranteed to be thread-safe we
// use a mutex to lock all operations on the `User` and `Role` storage because
// some operations on the users and/or roles may require more than one
// operation on the storage.
std::mutex lock_;
};
} // namespace auth

View File

@ -216,7 +216,7 @@ void User::SetRole(const Role &role) { role_.emplace(role); }
void User::ClearRole() { role_ = std::nullopt; }
const Permissions User::GetPermissions() const {
Permissions User::GetPermissions() const {
if (role_) {
return Permissions(permissions_.grants() | role_->permissions().grants(),
permissions_.denies() | role_->permissions().denies());
@ -229,7 +229,12 @@ const std::string &User::username() const { return username_; }
const Permissions &User::permissions() const { return permissions_; }
Permissions &User::permissions() { return permissions_; }
std::optional<Role> User::role() const { return role_; }
const Role *User::role() const {
if (role_.has_value()) {
return &role_.value();
}
return nullptr;
}
nlohmann::json User::Serialize() const {
nlohmann::json data = nlohmann::json::object();

View File

@ -127,14 +127,14 @@ class User final {
void ClearRole();
const Permissions GetPermissions() const;
Permissions GetPermissions() const;
const std::string &username() const;
const Permissions &permissions() const;
Permissions &permissions();
std::optional<Role> role() const;
const Role *role() const;
nlohmann::json Serialize() const;

View File

@ -144,7 +144,7 @@ bool KVStore::iterator::IsValid() { return pimpl_->it != nullptr; }
// TODO(ipaljak) The complexity of the size function should be at most
// logarithmic.
size_t KVStore::Size(const std::string &prefix) {
size_t KVStore::Size(const std::string &prefix) const {
size_t size = 0;
for (auto it = this->begin(prefix); it != this->end(prefix); ++it) ++size;
return size;

View File

@ -126,7 +126,7 @@ class KVStore final {
*
* @return - number of stored pairs.
*/
size_t Size(const std::string &prefix = "");
size_t Size(const std::string &prefix = "") const;
/**
* Compact the underlying storage for the key range [begin_prefix,
@ -186,9 +186,9 @@ class KVStore final {
std::unique_ptr<impl> pimpl_;
};
iterator begin(const std::string &prefix = "") { return iterator(this, prefix); }
iterator begin(const std::string &prefix = "") const { return iterator(this, prefix); }
iterator end(const std::string &prefix = "") { return iterator(this, prefix, true); }
iterator end(const std::string &prefix = "") const { return iterator(this, prefix, true); }
private:
struct impl;

View File

@ -23,6 +23,7 @@
#include "communication/bolt/v1/constants.hpp"
#include "helpers.hpp"
#include "py/py.hpp"
#include "query/auth_checker.hpp"
#include "query/discard_value_stream.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter.hpp"
@ -40,8 +41,10 @@
#include "utils/logging.hpp"
#include "utils/memory_tracker.hpp"
#include "utils/readable_size.hpp"
#include "utils/rw_lock.hpp"
#include "utils/signals.hpp"
#include "utils/string.hpp"
#include "utils/synchronized.hpp"
#include "utils/sysinfo/memory.hpp"
#include "utils/terminate_handler.hpp"
#include "version.hpp"
@ -354,15 +357,373 @@ void ConfigureLogging() {
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, auth::Auth *auth,
audit::Log *audit_log)
SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context,
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, audit::Log *audit_log)
: db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
storage::Storage *db;
query::InterpreterContext *interpreter_context;
auth::Auth *auth;
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth;
audit::Log *audit_log;
};
DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
"Set to the regular expression that each user or role name must fulfill.");
class AuthQueryHandler final : public query::AuthQueryHandler {
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
std::regex name_regex_;
public:
AuthQueryHandler(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, const std::regex &name_regex)
: auth_(auth), name_regex_(name_regex) {}
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
return locked_auth->AddUser(username, password).has_value();
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool DropUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) return false;
return locked_auth->RemoveUser(username);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
}
user->UpdatePassword(password);
locked_auth->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool CreateRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
auto locked_auth = auth_->Lock();
return locked_auth->AddRole(rolename).has_value();
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool DropRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
auto locked_auth = auth_->Lock();
auto role = locked_auth->GetRole(rolename);
if (!role) return false;
return locked_auth->RemoveRole(rolename);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetUsernames() override {
try {
auto locked_auth = auth_->ReadLock();
std::vector<query::TypedValue> usernames;
const auto &users = locked_auth->AllUsers();
usernames.reserve(users.size());
for (const auto &user : users) {
usernames.emplace_back(user.username());
}
return usernames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetRolenames() override {
try {
auto locked_auth = auth_->ReadLock();
std::vector<query::TypedValue> rolenames;
const auto &roles = locked_auth->AllRoles();
rolenames.reserve(roles.size());
for (const auto &role : roles) {
rolenames.emplace_back(role.rolename());
}
return rolenames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::optional<std::string> GetRolenameForUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->ReadLock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
if (const auto *role = user->role(); role != nullptr) {
return role->rolename();
}
return std::nullopt;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
auto locked_auth = auth_->ReadLock();
auto role = locked_auth->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
}
std::vector<query::TypedValue> usernames;
const auto &users = locked_auth->AllUsersForRole(rolename);
usernames.reserve(users.size());
for (const auto &user : users) {
usernames.emplace_back(user.username());
}
return usernames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void SetRole(const std::string &username, const std::string &rolename) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
auto role = locked_auth->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
}
if (const auto *current_role = user->role(); current_role != nullptr) {
throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
current_role->rolename());
}
user->SetRole(*role);
locked_auth->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void ClearRole(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
user->ClearRole();
locked_auth->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
try {
auto locked_auth = auth_->ReadLock();
std::vector<std::vector<query::TypedValue>> grants;
auto user = locked_auth->GetUser(user_or_role);
auto role = locked_auth->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
const auto &permissions = user->GetPermissions();
for (const auto &privilege : query::kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) {
std::vector<std::string> description;
auto user_level = user->permissions().Has(permission);
if (user_level == auth::PermissionLevel::GRANT) {
description.emplace_back("GRANTED TO USER");
} else if (user_level == auth::PermissionLevel::DENY) {
description.emplace_back("DENIED TO USER");
}
if (const auto *role = user->role(); role != nullptr) {
auto role_level = role->permissions().Has(permission);
if (role_level == auth::PermissionLevel::GRANT) {
description.emplace_back("GRANTED TO ROLE");
} else if (role_level == auth::PermissionLevel::DENY) {
description.emplace_back("DENIED TO ROLE");
}
}
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(utils::Join(description, ", "))});
}
}
} else {
const auto &permissions = role->permissions();
for (const auto &privilege : query::kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (effective != auth::PermissionLevel::NEUTRAL) {
std::string description;
if (effective == auth::PermissionLevel::GRANT) {
description = "GRANTED TO ROLE";
} else if (effective == auth::PermissionLevel::DENY) {
description = "DENIED TO ROLE";
}
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(description)});
}
}
}
return grants;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void GrantPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Grant(permission);
});
}
void DenyPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Deny(permission);
});
}
void RevokePrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Revoke(permission);
});
}
private:
template <class TEditFun>
void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
const TEditFun &edit_fun) {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
try {
auto locked_auth = auth_->Lock();
std::vector<auth::Permission> permissions;
permissions.reserve(privileges.size());
for (const auto &privilege : privileges) {
permissions.push_back(glue::PrivilegeToPermission(privilege));
}
auto user = locked_auth->GetUser(user_or_role);
auto role = locked_auth->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
for (const auto &permission : permissions) {
edit_fun(&user->permissions(), permission);
}
locked_auth->SaveUser(*user);
} else {
for (const auto &permission : permissions) {
edit_fun(&role->permissions(), permission);
}
locked_auth->SaveRole(*role);
}
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
};
class AuthChecker final : public query::AuthChecker {
public:
explicit AuthChecker(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {}
static bool IsUserAuthorized(const auth::User &user, const std::vector<query::AuthQuery::Privilege> &privileges) {
const auto user_permissions = user.GetPermissions();
return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) {
return user_permissions.Has(glue::PrivilegeToPermission(privilege)) == auth::PermissionLevel::GRANT;
});
}
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const final {
std::optional<auth::User> maybe_user;
{
auto locked_auth = auth_->ReadLock();
if (!locked_auth->HasUsers()) {
return true;
}
if (username.has_value()) {
maybe_user = locked_auth->GetUser(*username);
}
}
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
}
private:
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
};
#else
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
@ -371,6 +732,52 @@ struct SessionData {
storage::Storage *db;
query::InterpreterContext *interpreter_context;
};
class NoAuthInCommunity : public query::QueryRuntimeException {
public:
NoAuthInCommunity()
: query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
};
class AuthQueryHandler final : public query::AuthQueryHandler {
public:
bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
};
#endif
class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> {
@ -400,22 +807,21 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
const std::string &query, const std::map<std::string, communication::bolt::Value> &params) override {
std::map<std::string, storage::PropertyValue> params_pv;
for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second));
const std::string *username{nullptr};
#ifdef MG_ENTERPRISE
audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query, storage::PropertyValue(params_pv));
if (user_) {
username = &user_->username();
}
audit_log_->Record(endpoint_.address, user_ ? *username : "", query, storage::PropertyValue(params_pv));
#endif
try {
auto result = interpreter_.Prepare(query, params_pv);
auto result = interpreter_.Prepare(query, params_pv, username);
#ifdef MG_ENTERPRISE
if (user_) {
const auto &permissions = user_->GetPermissions();
for (const auto &privilege : result.privileges) {
if (permissions.Has(glue::PrivilegeToPermission(privilege)) != auth::PermissionLevel::GRANT) {
interpreter_.Abort();
throw communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
"your database administrator.");
}
}
if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges)) {
interpreter_.Abort();
throw communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
"your database administrator.");
}
#endif
return {result.headers, result.qid};
@ -442,9 +848,12 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
bool Authenticate(const std::string &username, const std::string &password) override {
#ifdef MG_ENTERPRISE
if (!auth_->HasUsers()) return true;
user_ = auth_->Authenticate(username, password);
return !!user_;
auto locked_auth = auth_->Lock();
if (!locked_auth->HasUsers()) {
return true;
}
user_ = locked_auth->Authenticate(username, password);
return user_.has_value();
#else
return true;
#endif
@ -522,7 +931,7 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
const storage::Storage *db_;
query::Interpreter interpreter_;
#ifdef MG_ENTERPRISE
auth::Auth *auth_;
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_;
std::optional<auth::User> user_;
audit::Log *audit_log_;
#endif
@ -532,374 +941,6 @@ class BoltSession final : public communication::bolt::Session<communication::Inp
using ServerT = communication::Server<BoltSession, SessionData>;
using communication::ServerContext;
#ifdef MG_ENTERPRISE
DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+",
"Set to the regular expression that each user or role name must fulfill.");
class AuthQueryHandler final : public query::AuthQueryHandler {
auth::Auth *auth_;
std::regex name_regex_;
public:
AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex) : auth_(auth), name_regex_(name_regex) {}
bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
return !!auth_->AddUser(username, password);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool DropUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) return false;
return auth_->RemoveUser(username);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist.", username);
}
user->UpdatePassword(password);
auth_->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool CreateRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
return !!auth_->AddRole(rolename);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
bool DropRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto role = auth_->GetRole(rolename);
if (!role) return false;
return auth_->RemoveRole(rolename);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetUsernames() override {
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
std::vector<query::TypedValue> usernames;
const auto &users = auth_->AllUsers();
usernames.reserve(users.size());
for (const auto &user : users) {
usernames.emplace_back(user.username());
}
return usernames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetRolenames() override {
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
std::vector<query::TypedValue> rolenames;
const auto &roles = auth_->AllRoles();
rolenames.reserve(roles.size());
for (const auto &role : roles) {
rolenames.emplace_back(role.rolename());
}
return rolenames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::optional<std::string> GetRolenameForUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
if (user->role()) return user->role()->rolename();
return std::nullopt;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto role = auth_->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
}
std::vector<query::TypedValue> usernames;
const auto &users = auth_->AllUsersForRole(rolename);
usernames.reserve(users.size());
for (const auto &user : users) {
usernames.emplace_back(user.username());
}
return usernames;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void SetRole(const std::string &username, const std::string &rolename) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
if (!std::regex_match(rolename, name_regex_)) {
throw query::QueryRuntimeException("Invalid role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
auto role = auth_->GetRole(rolename);
if (!role) {
throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
}
if (user->role()) {
throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
user->role()->rolename());
}
user->SetRole(*role);
auth_->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void ClearRole(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw query::QueryRuntimeException("Invalid user name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
auto user = auth_->GetUser(username);
if (!user) {
throw query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
user->ClearRole();
auth_->SaveUser(*user);
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
std::vector<std::vector<query::TypedValue>> grants;
auto user = auth_->GetUser(user_or_role);
auto role = auth_->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
const auto &permissions = user->GetPermissions();
for (const auto &privilege : query::kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) {
std::vector<std::string> description;
auto user_level = user->permissions().Has(permission);
if (user_level == auth::PermissionLevel::GRANT) {
description.emplace_back("GRANTED TO USER");
} else if (user_level == auth::PermissionLevel::DENY) {
description.emplace_back("DENIED TO USER");
}
if (user->role()) {
auto role_level = user->role()->permissions().Has(permission);
if (role_level == auth::PermissionLevel::GRANT) {
description.emplace_back("GRANTED TO ROLE");
} else if (role_level == auth::PermissionLevel::DENY) {
description.emplace_back("DENIED TO ROLE");
}
}
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(utils::Join(description, ", "))});
}
}
} else {
const auto &permissions = role->permissions();
for (const auto &privilege : query::kPrivilegesAll) {
auto permission = glue::PrivilegeToPermission(privilege);
auto effective = permissions.Has(permission);
if (effective != auth::PermissionLevel::NEUTRAL) {
std::string description;
if (effective == auth::PermissionLevel::GRANT) {
description = "GRANTED TO ROLE";
} else if (effective == auth::PermissionLevel::DENY) {
description = "DENIED TO ROLE";
}
grants.push_back({query::TypedValue(auth::PermissionToString(permission)),
query::TypedValue(auth::PermissionLevelToString(effective)),
query::TypedValue(description)});
}
}
}
return grants;
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
void GrantPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Grant(permission);
});
}
void DenyPrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Deny(permission);
});
}
void RevokePrivilege(const std::string &user_or_role,
const std::vector<query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
permissions->Revoke(permission);
});
}
private:
template <class TEditFun>
void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges,
const TEditFun &edit_fun) {
if (!std::regex_match(user_or_role, name_regex_)) {
throw query::QueryRuntimeException("Invalid user or role name.");
}
try {
std::lock_guard<std::mutex> lock(auth_->WithLock());
std::vector<auth::Permission> permissions;
permissions.reserve(privileges.size());
for (const auto &privilege : privileges) {
permissions.push_back(glue::PrivilegeToPermission(privilege));
}
auto user = auth_->GetUser(user_or_role);
auto role = auth_->GetRole(user_or_role);
if (!user && !role) {
throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
}
if (user) {
for (const auto &permission : permissions) {
edit_fun(&user->permissions(), permission);
}
auth_->SaveUser(*user);
} else {
for (const auto &permission : permissions) {
edit_fun(&role->permissions(), permission);
}
auth_->SaveRole(*role);
}
} catch (const auth::AuthException &e) {
throw query::QueryRuntimeException(e.what());
}
}
};
#else
class NoAuthInCommunity : public query::QueryRuntimeException {
public:
NoAuthInCommunity()
: query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {}
};
class AuthQueryHandler final : public query::AuthQueryHandler {
public:
bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool DropUser(const std::string &) override { throw NoAuthInCommunity(); }
void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); }
bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); }
bool DropRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); }
std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); }
void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); }
void ClearRole(const std::string &) override { throw NoAuthInCommunity(); }
std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); }
void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override {
throw NoAuthInCommunity();
}
};
#endif
// Needed to correctly handle memgraph destruction from a signal handler.
// Without having some sort of a flag, it is possible that a signal is handled
// when we are exiting main, inside destructors of database::GraphDb and
@ -1013,7 +1054,7 @@ int main(int argc, char **argv) {
// Begin enterprise features initialization
// Auth
auth::Auth auth{data_directory / "auth"};
utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> auth{data_directory / "auth"};
// Audit log
audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms};
@ -1075,24 +1116,28 @@ int main(int argc, char **argv) {
query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories);
query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
auto storage_accessor = interpreter_context.db->Access();
auto dba = query::DbAccessor{&storage_accessor};
interpreter_context.trigger_store.RestoreTriggers(
&interpreter_context.ast_cache, &dba, &interpreter_context.antlr_lock, interpreter_context.config.query);
}
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
interpreter_context.streams.RestoreStreams();
#ifdef MG_ENTERPRISE
AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex));
AuthChecker auth_checker{&auth};
#else
AuthQueryHandler auth_handler;
query::AllowEverythingAuthChecker auth_checker{};
#endif
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
auto storage_accessor = interpreter_context.db->Access();
auto dba = query::DbAccessor{&storage_accessor};
interpreter_context.trigger_store.RestoreTriggers(&interpreter_context.ast_cache, &dba,
&interpreter_context.antlr_lock, interpreter_context.config.query,
interpreter_context.auth_checker);
}
ServerContext context;
std::string service_name = "Bolt";

View File

@ -0,0 +1,18 @@
#pragma once
#include "query/frontend/ast/ast.hpp"
namespace query {
class AuthChecker {
public:
virtual bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const = 0;
};
class AllowEverythingAuthChecker final : public query::AuthChecker {
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const override {
return true;
}
};
} // namespace query

View File

@ -3,6 +3,7 @@
#include <atomic>
#include <chrono>
#include <limits>
#include <optional>
#include "glue/communication.hpp"
#include "query/constants.hpp"
@ -463,8 +464,13 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
}
}
std::optional<std::string> StringPointerToOptional(const std::string *str) {
return str == nullptr ? std::nullopt : std::make_optional(*str);
}
Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
InterpreterContext *interpreter_context, DbAccessor *db_accessor,
const std::string *username) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
@ -484,19 +490,21 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup
: stream_query->consumer_group_};
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
topic_names = stream_query->topic_names_, consumer_group = std::move(consumer_group),
batch_interval =
GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator),
batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator),
transformation_name = stream_query->transform_name_]() mutable {
interpreter_context->streams.Create(stream_name, query::StreamInfo{.topics = std::move(topic_names),
.consumer_group = std::move(consumer_group),
.batch_interval = batch_interval,
.batch_size = batch_size,
.transformation_name = transformation_name});
return std::vector<std::vector<TypedValue>>{};
};
callback.fn =
[interpreter_context, stream_name = stream_query->stream_name_, topic_names = stream_query->topic_names_,
consumer_group = std::move(consumer_group),
batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator),
batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator),
transformation_name = stream_query->transform_name_, owner = StringPointerToOptional(username)]() mutable {
interpreter_context->streams.Create(stream_name,
query::StreamInfo{.topics = std::move(topic_names),
.consumer_group = std::move(consumer_group),
.batch_interval = batch_interval,
.batch_size = batch_size,
.transformation_name = std::move(transformation_name),
.owner = std::move(owner)});
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::START_STREAM: {
@ -535,8 +543,8 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
return callback;
}
case StreamQuery::Action::SHOW_STREAMS: {
callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name",
"is running"};
callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name",
"owner", "is running"};
callback.fn = [interpreter_context]() {
auto streams_status = interpreter_context->streams.GetStreamInfo();
std::vector<std::vector<TypedValue>> results;
@ -565,6 +573,11 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &paramete
typed_status.emplace_back();
}
typed_status.emplace_back(stream_info.transformation_name);
if (stream_info.owner.has_value()) {
typed_status.emplace_back(*stream_info.owner);
} else {
typed_status.emplace_back();
}
};
for (const auto &status : streams_status) {
@ -1231,16 +1244,17 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) {
Callback CreateTrigger(TriggerQuery *trigger_query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
InterpreterContext *interpreter_context, DbAccessor *dba) {
InterpreterContext *interpreter_context, DbAccessor *dba, std::optional<std::string> owner) {
return {
{},
[trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_),
event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, interpreter_context, dba,
user_parameters]() -> std::vector<std::vector<TypedValue>> {
user_parameters, owner = std::move(owner)]() mutable -> std::vector<std::vector<TypedValue>> {
interpreter_context->trigger_store.AddTrigger(
trigger_name, trigger_statement, user_parameters, ToTriggerEventType(event_type),
std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type),
before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache,
dba, &interpreter_context->antlr_lock, interpreter_context->config.query);
dba, &interpreter_context->antlr_lock, interpreter_context->config.query, std::move(owner),
interpreter_context->auth_checker);
return {};
}};
}
@ -1255,7 +1269,7 @@ Callback DropTrigger(TriggerQuery *trigger_query, InterpreterContext *interprete
}
Callback ShowTriggers(InterpreterContext *interpreter_context) {
return {{"trigger name", "statement", "event type", "phase"}, [interpreter_context] {
return {{"trigger name", "statement", "event type", "phase", "owner"}, [interpreter_context] {
std::vector<std::vector<TypedValue>> results;
auto trigger_infos = interpreter_context->trigger_store.GetTriggerInfo();
results.reserve(trigger_infos.size());
@ -1267,6 +1281,9 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {
typed_trigger_info.emplace_back(TriggerEventTypeToString(trigger_info.event_type));
typed_trigger_info.emplace_back(trigger_info.phase == TriggerPhase::BEFORE_COMMIT ? "BEFORE COMMIT"
: "AFTER COMMIT");
typed_trigger_info.emplace_back(trigger_info.owner.has_value() ? TypedValue{*trigger_info.owner}
: TypedValue{});
results.push_back(std::move(typed_trigger_info));
}
@ -1276,7 +1293,8 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {
PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> &user_parameters) {
const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) {
if (in_explicit_transaction) {
throw TriggerModificationInMulticommandTxException();
}
@ -1284,11 +1302,12 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
auto *trigger_query = utils::Downcast<TriggerQuery>(parsed_query.query);
MG_ASSERT(trigger_query);
auto callback = [trigger_query, interpreter_context, dba, &user_parameters] {
auto callback = [trigger_query, interpreter_context, dba, &user_parameters,
owner = StringPointerToOptional(username)]() mutable {
switch (trigger_query->action_) {
case TriggerQuery::Action::CREATE_TRIGGER:
EventCounter::IncrementCounter(EventCounter::TriggersCreated);
return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba);
return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba, std::move(owner));
case TriggerQuery::Action::DROP_TRIGGER:
return DropTrigger(trigger_query, interpreter_context);
case TriggerQuery::Action::SHOW_TRIGGERS:
@ -1315,14 +1334,15 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> &user_parameters) {
const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) {
if (in_explicit_transaction) {
throw StreamQueryInMulticommandTxException();
}
auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query);
MG_ASSERT(stream_query);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -1651,7 +1671,8 @@ void Interpreter::RollbackTransaction() {
}
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::map<std::string, storage::PropertyValue> &params) {
const std::map<std::string, storage::PropertyValue> &params,
const std::string *username) {
if (!in_explicit_transaction_) {
query_executions_.clear();
}
@ -1748,10 +1769,10 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_);
} else if (utils::Downcast<TriggerQuery>(parsed_query.query)) {
prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
&*execution_db_accessor_, params, username);
} else if (utils::Downcast<StreamQuery>(parsed_query.query)) {
prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
&*execution_db_accessor_, params, username);
} else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) {
prepared_query =
PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this);
@ -1809,7 +1830,7 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret
trigger_context.AdaptForAccessor(&db_accessor);
try {
trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec,
&interpreter_context->is_shutting_down, trigger_context);
&interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker);
} catch (const utils::BasicException &exception) {
spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what());
db_accessor.Abort();
@ -1864,7 +1885,7 @@ void Interpreter::Commit() {
AdvanceCommand();
try {
trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec,
&interpreter_context_->is_shutting_down, *trigger_context);
&interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker);
} catch (const utils::BasicException &e) {
throw utils::BasicException(
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what()));

View File

@ -2,6 +2,7 @@
#include <gflags/gflags.h>
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/context.hpp"
#include "query/cypher_query_interpreter.hpp"
@ -166,6 +167,7 @@ struct InterpreterContext {
std::atomic<bool> is_shutting_down{false};
AuthQueryHandler *auth{nullptr};
query::AuthChecker *auth_checker{nullptr};
utils::SkipList<QueryCacheEntry> ast_cache;
utils::SkipList<PlanCacheEntry> plan_cache;
@ -205,7 +207,8 @@ class Interpreter final {
*
* @throw query::QueryException
*/
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params);
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params,
const std::string *username);
/**
* Execute the last prepared query and stream *all* of the results into the

View File

@ -106,6 +106,7 @@ const std::string kBatchIntervalKey{"batch_interval"};
const std::string kBatchSizeKey{"batch_size"};
const std::string kIsRunningKey{"is_running"};
const std::string kTransformationName{"transformation_name"};
const std::string kOwner{"owner"};
void to_json(nlohmann::json &data, StreamStatus &&status) {
auto &info = status.info;
@ -127,6 +128,12 @@ void to_json(nlohmann::json &data, StreamStatus &&status) {
data[kIsRunningKey] = status.is_running;
data[kTransformationName] = status.info.transformation_name;
if (info.owner.has_value()) {
data[kOwner] = std::move(*info.owner);
} else {
data[kOwner] = nullptr;
}
}
void from_json(const nlohmann::json &data, StreamStatus &status) {
@ -135,16 +142,14 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
data.at(kTopicsKey).get_to(info.topics);
data.at(kConsumerGroupKey).get_to(info.consumer_group);
const auto batch_interval = data.at(kBatchIntervalKey);
if (!batch_interval.is_null()) {
if (const auto batch_interval = data.at(kBatchIntervalKey); !batch_interval.is_null()) {
using BatchInterval = decltype(info.batch_interval)::value_type;
info.batch_interval = BatchInterval{batch_interval.get<BatchInterval::rep>()};
} else {
info.batch_interval = {};
}
const auto batch_size = data.at(kBatchSizeKey);
if (!batch_size.is_null()) {
if (const auto batch_size = data.at(kBatchSizeKey); !batch_size.is_null()) {
info.batch_size = batch_size.get<decltype(info.batch_size)::value_type>();
} else {
info.batch_size = {};
@ -152,6 +157,12 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
data.at(kIsRunningKey).get_to(status.is_running);
data.at(kTransformationName).get_to(status.info.transformation_name);
if (const auto &owner = data.at(kOwner); !owner.is_null()) {
info.owner = owner.get<decltype(info.owner)::value_type>();
} else {
info.owner = {};
}
}
Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
@ -200,7 +211,8 @@ void Streams::Create(const std::string &stream_name, StreamInfo info) {
auto it = CreateConsumer(*locked_streams, stream_name, std::move(info));
try {
Persist(CreateStatus(stream_name, it->second.transformation_name, *it->second.consumer->ReadLock()));
Persist(
CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *it->second.consumer->ReadLock()));
} catch (...) {
locked_streams->erase(it);
throw;
@ -233,7 +245,7 @@ void Streams::Start(const std::string &stream_name) {
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Start();
Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer));
}
void Streams::Stop(const std::string &stream_name) {
@ -243,7 +255,7 @@ void Streams::Stop(const std::string &stream_name) {
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Stop();
Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer));
}
void Streams::StartAll() {
@ -251,7 +263,7 @@ void Streams::StartAll() {
auto locked_consumer = stream_data.consumer->Lock();
if (!locked_consumer->IsRunning()) {
locked_consumer->Start();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer));
}
}
}
@ -261,7 +273,7 @@ void Streams::StopAll() {
auto locked_consumer = stream_data.consumer->Lock();
if (locked_consumer->IsRunning()) {
locked_consumer->Stop();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer));
}
}
}
@ -270,8 +282,8 @@ std::vector<StreamStatus> Streams::GetStreamInfo() const {
std::vector<StreamStatus> result;
{
for (auto locked_streams = streams_.ReadLock(); const auto &[stream_name, stream_data] : *locked_streams) {
result.emplace_back(
CreateStatus(stream_name, stream_data.transformation_name, *stream_data.consumer->ReadLock()));
result.emplace_back(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner,
*stream_data.consumer->ReadLock()));
}
}
return result;
@ -314,6 +326,7 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona
}
StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name,
const std::optional<std::string> &owner,
const integrations::kafka::Consumer &consumer) {
const auto &info = consumer.Info();
return StreamStatus{name,
@ -323,6 +336,7 @@ StreamStatus Streams::CreateStatus(const std::string &name, const std::string &t
info.batch_interval,
info.batch_size,
transformation_name,
owner,
},
consumer.IsRunning()};
}
@ -336,7 +350,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
auto *memory_resource = utils::NewDeleteResource();
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
transformation_name = stream_info.transformation_name,
transformation_name = stream_info.transformation_name, owner = stream_info.owner,
interpreter = std::make_shared<Interpreter>(interpreter_context_),
result = mgp_result{nullptr, memory_resource}](
const std::vector<integrations::kafka::Message> &messages) mutable {
@ -347,6 +361,10 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
DiscardValueResultStream stream;
spdlog::trace("Start transaction in stream '{}'", stream_name);
utils::OnScopeExit cleanup{[&interpreter, &result]() {
result.rows.clear();
interpreter->Abort();
}};
interpreter->BeginTransaction();
for (auto &row : result.rows) {
@ -358,7 +376,13 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
std::string query{query_value.ValueString()};
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap());
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr);
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) {
throw StreamsException{
"Couldn't execute query '{}' for stream '{}' becuase the owner is not authorized to execute the "
"query!",
query, stream_name};
}
interpreter->PullAll(&stream);
}
@ -376,7 +400,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
};
auto insert_result = map.insert_or_assign(
stream_name, StreamData{std::move(stream_info.transformation_name),
stream_name, StreamData{std::move(stream_info.transformation_name), std::move(stream_info.owner),
std::make_unique<SynchronizedConsumer>(bootstrap_servers_, std::move(consumer_info),
std::move(consumer_function))});
MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name);

View File

@ -28,6 +28,7 @@ struct StreamInfo {
std::optional<std::chrono::milliseconds> batch_interval;
std::optional<int64_t> batch_size;
std::string transformation_name;
std::optional<std::string> owner;
};
struct StreamStatus {
@ -40,6 +41,7 @@ using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer,
struct StreamData {
std::string transformation_name;
std::optional<std::string> owner;
std::unique_ptr<SynchronizedConsumer> consumer;
};
@ -131,6 +133,7 @@ class Streams final {
using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;
static StreamStatus CreateStatus(const std::string &name, const std::string &transformation_name,
const std::optional<std::string> &owner,
const integrations::kafka::Consumer &consumer);
StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, StreamInfo stream_info);

View File

@ -142,51 +142,55 @@ std::vector<std::pair<Identifier, TriggerIdentifierTag>> GetPredefinedIdentifier
Trigger::Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache,
DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config)
DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker)
: name_{std::move(name)},
parsed_statements_{ParseQuery(query, user_parameters, query_cache, antlr_lock, query_config)},
event_type_{event_type} {
event_type_{event_type},
owner_{std::move(owner)} {
// We check immediately if the query is valid by trying to create a plan.
GetPlan(db_accessor);
GetPlan(db_accessor, auth_checker);
}
Trigger::TriggerPlan::TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers)
: cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {}
std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor) const {
std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
const query::AuthChecker *auth_checker) const {
std::lock_guard plan_guard{plan_lock_};
if (parsed_statements_.is_cacheable && trigger_plan_ && !trigger_plan_->cached_plan.IsExpired()) {
return trigger_plan_;
if (!parsed_statements_.is_cacheable || !trigger_plan_ || trigger_plan_->cached_plan.IsExpired()) {
auto identifiers = GetPredefinedIdentifiers(event_type_);
AstStorage ast_storage;
ast_storage.properties_ = parsed_statements_.ast_storage.properties_;
ast_storage.labels_ = parsed_statements_.ast_storage.labels_;
ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_;
std::vector<Identifier *> predefined_identifiers;
predefined_identifiers.reserve(identifiers.size());
std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers),
[](auto &identifier) { return &identifier.first; });
auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query),
parsed_statements_.parameters, db_accessor, predefined_identifiers);
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
}
if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) {
throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_);
}
auto identifiers = GetPredefinedIdentifiers(event_type_);
AstStorage ast_storage;
ast_storage.properties_ = parsed_statements_.ast_storage.properties_;
ast_storage.labels_ = parsed_statements_.ast_storage.labels_;
ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_;
std::vector<Identifier *> predefined_identifiers;
predefined_identifiers.reserve(identifiers.size());
std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers),
[](auto &identifier) { return &identifier.first; });
auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query),
parsed_statements_.parameters, db_accessor, predefined_identifiers);
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
return trigger_plan_;
}
void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory,
const double max_execution_time_sec, std::atomic<bool> *is_shutting_down,
const TriggerContext &context) const {
const TriggerContext &context, const AuthChecker *auth_checker) const {
if (!context.ShouldEventTrigger(event_type_)) {
return;
}
spdlog::debug("Executing trigger '{}'", name_);
auto trigger_plan = GetPlan(dba);
auto trigger_plan = GetPlan(dba, auth_checker);
MG_ASSERT(trigger_plan, "Invalid trigger plan received");
auto &[plan, identifiers] = *trigger_plan;
@ -238,13 +242,14 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution
namespace {
// When the format of the persisted trigger is changed, increase this version
constexpr uint64_t kVersion{1};
constexpr uint64_t kVersion{2};
} // namespace
TriggerStore::TriggerStore(std::filesystem::path directory) : storage_{std::move(directory)} {}
void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) {
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker) {
MG_ASSERT(before_commit_triggers_.size() == 0 && after_commit_triggers_.size() == 0,
"Cannot restore trigger when some triggers already exist!");
spdlog::info("Loading triggers...");
@ -292,10 +297,19 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
}
const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]);
const auto owner_json = json_trigger_data["owner"];
std::optional<std::string> owner{};
if (owner_json.is_string()) {
owner.emplace(owner_json.get<std::string>());
} else if (!owner_json.is_null()) {
spdlog::warn(invalid_state_message);
continue;
}
std::optional<Trigger> trigger;
try {
trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config);
query_config, std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what());
continue;
@ -309,11 +323,12 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
}
}
void TriggerStore::AddTrigger(const std::string &name, const std::string &query,
void TriggerStore::AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
TriggerEventType event_type, TriggerPhase phase,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) {
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker) {
std::unique_lock store_guard{store_lock_};
if (storage_.Get(name)) {
throw utils::BasicException("Trigger with the same name already exists.");
@ -321,7 +336,8 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query,
std::optional<Trigger> trigger;
try {
trigger.emplace(name, query, user_parameters, event_type, query_cache, db_accessor, antlr_lock, query_config);
trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config, std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
const auto identifiers = GetPredefinedIdentifiers(event_type);
std::stringstream identifier_names_stream;
@ -342,7 +358,13 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query,
data["event_type"] = event_type;
data["phase"] = phase;
data["version"] = kVersion;
storage_.Put(name, data.dump());
if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) {
data["owner"] = *owner_from_trigger;
} else {
data["owner"] = nullptr;
}
storage_.Put(trigger->Name(), data.dump());
store_guard.unlock();
auto triggers_acc =
@ -384,7 +406,7 @@ std::vector<TriggerStore::TriggerInfo> TriggerStore::GetTriggerInfo() const {
const auto add_info = [&](const utils::SkipList<Trigger> &trigger_list, const TriggerPhase phase) {
for (const auto &trigger : trigger_list.access()) {
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase});
info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()});
}
};

View File

@ -9,6 +9,7 @@
#include <vector>
#include "kvstore/kvstore.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/cypher_query_interpreter.hpp"
#include "query/db_accessor.hpp"
@ -23,10 +24,12 @@ struct Trigger {
explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock,
const InterpreterConfig::Query &query_config);
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);
void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec,
std::atomic<bool> *is_shutting_down, const TriggerContext &context) const;
std::atomic<bool> *is_shutting_down, const TriggerContext &context,
const AuthChecker *auth_checker) const;
bool operator==(const Trigger &other) const { return name_ == other.name_; }
// NOLINTNEXTLINE (modernize-use-nullptr)
@ -37,6 +40,7 @@ struct Trigger {
const auto &Name() const noexcept { return name_; }
const auto &OriginalStatement() const noexcept { return parsed_statements_.query_string; }
const auto &Owner() const noexcept { return owner_; }
auto EventType() const noexcept { return event_type_; }
private:
@ -48,7 +52,7 @@ struct Trigger {
CachedPlan cached_plan;
std::vector<IdentifierInfo> identifiers;
};
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor) const;
std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const;
std::string name_;
ParsedQuery parsed_statements_;
@ -57,6 +61,7 @@ struct Trigger {
mutable utils::SpinLock plan_lock_;
mutable std::shared_ptr<TriggerPlan> trigger_plan_;
std::optional<std::string> owner_;
};
enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT };
@ -65,12 +70,14 @@ struct TriggerStore {
explicit TriggerStore(std::filesystem::path directory);
void RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config);
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker);
void AddTrigger(const std::string &name, const std::string &query,
void AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config);
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker);
void DropTrigger(const std::string &name);
@ -79,6 +86,7 @@ struct TriggerStore {
std::string statement;
TriggerEventType event_type;
TriggerPhase phase;
std::optional<std::string> owner;
};
std::vector<TriggerInfo> GetTriggerInfo() const;

View File

@ -54,7 +54,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Match)(benchmark::State &state) {
while (state.KeepRunning()) {
ResultStreamFaker results(&*db);
interpreter->Prepare(query, {});
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}
}
@ -69,7 +69,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Expand)(benchmark::State &state) {
while (state.KeepRunning()) {
ResultStreamFaker results(&*db);
interpreter->Prepare(query, {});
interpreter->Prepare(query, {}, nullptr);
interpreter->PullAll(&results);
}
}

View File

@ -6,6 +6,9 @@ add_custom_target(memgraph__e2e__streams__${FILE_NAME} ALL
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME})
endfunction()
copy_streams_e2e_python_files(common.py)
copy_streams_e2e_python_files(conftest.py)
copy_streams_e2e_python_files(streams_tests.py)
copy_streams_e2e_python_files(streams_owner_tests.py)
copy_streams_e2e_python_files(streams_test_runner.sh)
add_subdirectory(transformations)

110
tests/e2e/streams/common.py Normal file
View File

@ -0,0 +1,110 @@
import mgclient
import time
# These are the indices of the different values in the result of SHOW STREAM
# query
NAME = 0
TOPICS = 1
CONSUMER_GROUP = 2
BATCH_INTERVAL = 3
BATCH_SIZE = 4
TRANSFORM = 5
OWNER = 6
IS_RUNNING = 7
def execute_and_fetch_all(cursor, query):
cursor.execute(query)
return cursor.fetchall()
def connect(**kwargs):
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
return connection
def timed_wait(fun):
start_time = time.time()
seconds = 10
while True:
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time > seconds:
return False
if fun():
return True
time.sleep(0.1)
def check_one_result_row(cursor, query):
start_time = time.time()
seconds = 10
while True:
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time > seconds:
return False
cursor.execute(query)
results = cursor.fetchall()
if len(results) < 1:
time.sleep(0.1)
continue
return len(results) == 1
def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes):
assert check_one_result_row(cursor,
"MATCH (n: MESSAGE {"
f"payload: '{payload_bytes.decode('utf-8')}',"
f"topic: '{topic}'"
"}) RETURN n")
def get_stream_info(cursor, stream_name):
stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
for stream_info in stream_infos:
if (stream_info[NAME] == stream_name):
return stream_info
return None
def get_is_running(cursor, stream_name):
stream_info = get_stream_info(cursor, stream_name)
assert stream_info
return stream_info[IS_RUNNING]
def start_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"START STREAM {stream_name}")
assert get_is_running(cursor, stream_name)
def stop_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}")
assert not get_is_running(cursor, stream_name)
def drop_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}")
assert get_stream_info(cursor, stream_name) is None
def check_stream_info(cursor, stream_name, expected_stream_info):
stream_info = get_stream_info(cursor, stream_name)
assert len(stream_info) == len(expected_stream_info)
for info, expected_info in zip(stream_info, expected_stream_info):
assert info == expected_info

View File

@ -0,0 +1,45 @@
import pytest
from kafka import KafkaProducer
from kafka.admin import KafkaAdminClient, NewTopic
from common import execute_and_fetch_all, connect, NAME
# To run these test locally a running Kafka sever is necessery. The test tries
# to connect on localhost:9092.
@pytest.fixture(autouse=True)
def connection():
connection = connect()
yield connection
cursor = connection.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
for stream_info in stream_infos:
execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}")
users = execute_and_fetch_all(cursor, "SHOW USERS")
for username, in users:
execute_and_fetch_all(cursor, f"DROP USER {username}")
@pytest.fixture(scope="function")
def topics():
admin_client = KafkaAdminClient(
bootstrap_servers="localhost:9092", client_id='test')
topics = []
topics_to_create = []
for index in range(3):
topic = f"topic_{index}"
topics.append(topic)
topics_to_create.append(NewTopic(name=topic,
num_partitions=1, replication_factor=1))
admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000)
yield topics
admin_client.delete_topics(topics=topics, timeout_ms=5000)
@pytest.fixture(scope="function")
def producer():
yield KafkaProducer(bootstrap_servers="localhost:9092")

View File

@ -0,0 +1,151 @@
import sys
import pytest
import time
import mgclient
import common
def get_cursor_with_user(username):
connection = common.connect(username=username, password="")
return connection.cursor()
def create_admin_user(cursor, admin_user):
common.execute_and_fetch_all(cursor, f"CREATE USER {admin_user}")
common.execute_and_fetch_all(
cursor, f"GRANT ALL PRIVILEGES TO {admin_user}")
def create_stream_user(cursor, stream_user):
common.execute_and_fetch_all(cursor, f"CREATE USER {stream_user}")
common.execute_and_fetch_all(
cursor, f"GRANT STREAM TO {stream_user}")
def test_ownerless_stream(producer, topics, connection):
assert len(topics) > 0
userless_cursor = connection.cursor()
common.execute_and_fetch_all(userless_cursor,
"CREATE STREAM ownerless "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.start_stream(userless_cursor, "ownerless")
time.sleep(1)
admin_user = "admin_user"
create_admin_user(userless_cursor, admin_user)
producer.send(topics[0], b"first message").get(timeout=60)
assert common.timed_wait(
lambda: not common.get_is_running(userless_cursor, "ownerless"))
assert len(common.execute_and_fetch_all(
userless_cursor, "MATCH (n) RETURN n")) == 0
common.execute_and_fetch_all(userless_cursor, f"DROP USER {admin_user}")
common.start_stream(userless_cursor, "ownerless")
time.sleep(1)
second_message = b"second message"
producer.send(topics[0], second_message).get(timeout=60)
common.check_vertex_exists_with_topic_and_payload(
userless_cursor, topics[0], second_message)
assert len(common.execute_and_fetch_all(
userless_cursor, "MATCH (n) RETURN n")) == 1
def test_owner_is_shown(topics, connection):
assert len(topics) > 0
userless_cursor = connection.cursor()
stream_user = "stream_user"
create_stream_user(userless_cursor, stream_user)
stream_cursor = get_cursor_with_user(stream_user)
common.execute_and_fetch_all(stream_cursor, "CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.check_stream_info(userless_cursor, "test", ("test", [
topics[0]], "mg_consumer", None, None,
"transform.simple", stream_user, False))
def test_insufficient_privileges(producer, topics, connection):
assert len(topics) > 0
userless_cursor = connection.cursor()
admin_user = "admin_user"
create_admin_user(userless_cursor, admin_user)
admin_cursor = get_cursor_with_user(admin_user)
stream_user = "stream_user"
create_stream_user(userless_cursor, stream_user)
stream_cursor = get_cursor_with_user(stream_user)
common.execute_and_fetch_all(stream_cursor,
"CREATE STREAM insufficient_test "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
# the stream is started by admin, but should check against the owner
# privileges
common.start_stream(admin_cursor, "insufficient_test")
time.sleep(1)
producer.send(topics[0], b"first message").get(timeout=60)
assert common.timed_wait(
lambda: not common.get_is_running(userless_cursor, "insufficient_test"))
assert len(common.execute_and_fetch_all(
userless_cursor, "MATCH (n) RETURN n")) == 0
common.execute_and_fetch_all(
admin_cursor, f"GRANT CREATE TO {stream_user}")
common.start_stream(userless_cursor, "insufficient_test")
time.sleep(1)
second_message = b"second message"
producer.send(topics[0], second_message).get(timeout=60)
common.check_vertex_exists_with_topic_and_payload(
userless_cursor, topics[0], second_message)
assert len(common.execute_and_fetch_all(
userless_cursor, "MATCH (n) RETURN n")) == 1
def test_happy_case(producer, topics, connection):
assert len(topics) > 0
userless_cursor = connection.cursor()
admin_user = "admin_user"
create_admin_user(userless_cursor, admin_user)
admin_cursor = get_cursor_with_user(admin_user)
stream_user = "stream_user"
create_stream_user(userless_cursor, stream_user)
stream_cursor = get_cursor_with_user(stream_user)
common.execute_and_fetch_all(
admin_cursor, f"GRANT CREATE TO {stream_user}")
common.execute_and_fetch_all(stream_cursor,
"CREATE STREAM insufficient_test "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.start_stream(stream_cursor, "insufficient_test")
time.sleep(1)
first_message = b"first message"
producer.send(topics[0], first_message).get(timeout=60)
common.check_vertex_exists_with_topic_and_payload(
userless_cursor, topics[0], first_message)
assert len(common.execute_and_fetch_all(
userless_cursor, "MATCH (n) RETURN n")) == 1
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -3,4 +3,4 @@
# This workaround is necessary to run in the same virtualenv as the e2e runner.py
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python3 "$DIR/streams_tests.py"
python3 "$DIR/$1"

View File

@ -1,28 +1,11 @@
#!/usr/bin/python3
# To run these test locally a running Kafka sever is necessery. The test tries
# to connect on localhost:9092.
# All tests are implemented in this file, because using the same test fixtures
# in multiple files is not possible in a straightforward way
import sys
import pytest
import mgclient
import time
from multiprocessing import Process, Value
from kafka import KafkaProducer
from kafka.admin import KafkaAdminClient, NewTopic
# These are the indices of the different values in the result of SHOW STREAM
# query
NAME = 0
TOPICS = 1
CONSUMER_GROUP = 2
BATCH_INTERVAL = 3
BATCH_SIZE = 4
TRANSFORM = 5
IS_RUNNING = 6
import common
# These are the indices of the query and parameters in the result of CHECK
# STREAM query
@ -35,155 +18,22 @@ TRANSFORMATIONS_TO_CHECK = [
SIMPLE_MSG = b'message'
def execute_and_fetch_all(cursor, query):
cursor.execute(query)
return cursor.fetchall()
def connect():
connection = mgclient.connect(host="localhost", port=7687)
connection.autocommit = True
return connection
@pytest.fixture(autouse=True)
def connection():
connection = connect()
yield connection
cursor = connection.cursor()
execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n")
stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
for stream_info in stream_infos:
execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}")
@pytest.fixture(scope="function")
def topics():
admin_client = KafkaAdminClient(
bootstrap_servers="localhost:9092", client_id='test')
topics = []
topics_to_create = []
for index in range(3):
topic = f"topic_{index}"
topics.append(topic)
topics_to_create.append(NewTopic(name=topic,
num_partitions=1, replication_factor=1))
admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000)
yield topics
admin_client.delete_topics(topics=topics, timeout_ms=5000)
@pytest.fixture(scope="function")
def producer():
yield KafkaProducer(bootstrap_servers="localhost:9092")
def timed_wait(fun):
start_time = time.time()
seconds = 10
while True:
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time > seconds:
return False
if fun():
return True
def check_one_result_row(cursor, query):
start_time = time.time()
seconds = 10
while True:
current_time = time.time()
elapsed_time = current_time - start_time
if elapsed_time > seconds:
return False
cursor.execute(query)
results = cursor.fetchall()
if len(results) < 1:
time.sleep(0.1)
continue
return len(results) == 1
def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes):
assert check_one_result_row(cursor,
"MATCH (n: MESSAGE {"
f"payload: '{payload_bytes.decode('utf-8')}',"
f"topic: '{topic}'"
"}) RETURN n")
def get_stream_info(cursor, stream_name):
stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS")
for stream_info in stream_infos:
if (stream_info[NAME] == stream_name):
return stream_info
return None
def get_is_running(cursor, stream_name):
stream_info = get_stream_info(cursor, stream_name)
assert stream_info
return stream_info[IS_RUNNING]
def start_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"START STREAM {stream_name}")
assert get_is_running(cursor, stream_name)
def stop_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}")
assert not get_is_running(cursor, stream_name)
def drop_stream(cursor, stream_name):
execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}")
assert get_stream_info(cursor, stream_name) is None
def check_stream_info(cursor, stream_name, expected_stream_info):
stream_info = get_stream_info(cursor, stream_name)
assert len(stream_info) == len(expected_stream_info)
for info, expected_info in zip(stream_info, expected_stream_info):
assert info == expected_info
##############################################
# Tests
##############################################
@pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK)
def test_simple(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation}")
start_stream(cursor, "test")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {','.join(topics)} "
f"TRANSFORM {transformation}")
common.start_stream(cursor, "test")
time.sleep(5)
for topic in topics:
producer.send(topic, SIMPLE_MSG).get(timeout=60)
for topic in topics:
check_vertex_exists_with_topic_and_payload(
common.check_vertex_exists_with_topic_and_payload(
cursor, topic, SIMPLE_MSG)
@ -196,13 +46,13 @@ def test_separate_consumers(producer, topics, connection, transformation):
for topic in topics:
stream_name = "stream_" + topic
stream_names.append(stream_name)
execute_and_fetch_all(cursor,
f"CREATE STREAM {stream_name} "
f"TOPICS {topic} "
f"TRANSFORM {transformation}")
common.execute_and_fetch_all(cursor,
f"CREATE STREAM {stream_name} "
f"TOPICS {topic} "
f"TRANSFORM {transformation}")
for stream_name in stream_names:
start_stream(cursor, stream_name)
common.start_stream(cursor, stream_name)
time.sleep(5)
@ -210,7 +60,7 @@ def test_separate_consumers(producer, topics, connection, transformation):
producer.send(topic, SIMPLE_MSG).get(timeout=60)
for topic in topics:
check_vertex_exists_with_topic_and_payload(
common.check_vertex_exists_with_topic_and_payload(
cursor, topic, SIMPLE_MSG)
@ -223,41 +73,42 @@ def test_start_from_last_committed_offset(producer, topics, connection):
# restarting Memgraph during a single workload cannot be done currently.
assert len(topics) > 0
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
start_stream(cursor, "test")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
common.start_stream(cursor, "test")
time.sleep(1)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
check_vertex_exists_with_topic_and_payload(
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], SIMPLE_MSG)
stop_stream(cursor, "test")
drop_stream(cursor, "test")
common.stop_stream(cursor, "test")
common.drop_stream(cursor, "test")
messages = [b"second message", b"third message"]
for message in messages:
producer.send(topics[0], message).get(timeout=60)
for message in messages:
vertices_with_msg = execute_and_fetch_all(cursor,
"MATCH (n: MESSAGE {"
f"payload: '{message.decode('utf-8')}'"
"}) RETURN n")
vertices_with_msg = common.execute_and_fetch_all(
cursor,
"MATCH (n: MESSAGE {"
f"payload: '{message.decode('utf-8')}'"
"}) RETURN n")
assert len(vertices_with_msg) == 0
execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
start_stream(cursor, "test")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
"TRANSFORM transform.simple")
common.start_stream(cursor, "test")
for message in messages:
check_vertex_exists_with_topic_and_payload(
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], message)
@ -265,16 +116,16 @@ def test_start_from_last_committed_offset(producer, topics, connection):
def test_check_stream(producer, topics, connection, transformation):
assert len(topics) > 0
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM {transformation} "
"BATCH_SIZE 1")
start_stream(cursor, "test")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test "
f"TOPICS {topics[0]} "
f"TRANSFORM {transformation} "
"BATCH_SIZE 1")
common.start_stream(cursor, "test")
time.sleep(1)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
stop_stream(cursor, "test")
common.stop_stream(cursor, "test")
messages = [b"first message", b"second message", b"third message"]
for message in messages:
@ -283,7 +134,7 @@ def test_check_stream(producer, topics, connection, transformation):
def check_check_stream(batch_limit):
assert transformation == "transform.simple" \
or transformation == "transform.with_parameters"
test_results = execute_and_fetch_all(
test_results = common.execute_and_fetch_all(
cursor, f"CHECK STREAM test BATCH_LIMIT {batch_limit}")
assert len(test_results) == batch_limit
@ -308,42 +159,47 @@ def test_check_stream(producer, topics, connection, transformation):
check_check_stream(1)
check_check_stream(2)
check_check_stream(3)
start_stream(cursor, "test")
common.start_stream(cursor, "test")
for message in messages:
check_vertex_exists_with_topic_and_payload(
common.check_vertex_exists_with_topic_and_payload(
cursor, topics[0], message)
def test_show_streams(producer, topics, connection):
assert len(topics) > 1
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM default_values "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(cursor,
"CREATE STREAM default_values "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
consumer_group = "my_special_consumer_group"
batch_interval = 42
batch_size = 3
execute_and_fetch_all(cursor,
"CREATE STREAM complex_values "
f"TOPICS {','.join(topics)} "
f"TRANSFORM transform.with_parameters "
f"CONSUMER_GROUP {consumer_group} "
f"BATCH_INTERVAL {batch_interval} "
f"BATCH_SIZE {batch_size} ")
common.execute_and_fetch_all(cursor,
"CREATE STREAM complex_values "
f"TOPICS {','.join(topics)} "
f"TRANSFORM transform.with_parameters "
f"CONSUMER_GROUP {consumer_group} "
f"BATCH_INTERVAL {batch_interval} "
f"BATCH_SIZE {batch_size} ")
assert len(execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2
assert len(common.execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2
check_stream_info(cursor, "default_values", ("default_values", [
topics[0]], "mg_consumer", None, None,
"transform.simple", False))
common.check_stream_info(cursor, "default_values", ("default_values", [
topics[0]], "mg_consumer", None, None,
"transform.simple", None, False))
check_stream_info(cursor, "complex_values", ("complex_values", topics,
consumer_group, batch_interval, batch_size,
"transform.with_parameters",
False))
common.check_stream_info(cursor, "complex_values", (
"complex_values",
topics,
consumer_group,
batch_interval,
batch_size,
"transform.with_parameters",
None,
False))
@pytest.mark.parametrize("operation", ["START", "STOP"])
@ -360,10 +216,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
assert len(topics) > 1
assert operation == "START" or operation == "STOP"
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
check_counter = Value('i', 0)
check_result_len = Value('i', 0)
@ -377,10 +233,11 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
def call_check(counter, result_len):
# This process will call the CHECK query and increment the counter
# based on its progress and expected behavior
connection = connect()
connection = common.connect()
cursor = connection.cursor()
counter.value = CHECK_BEFORE_EXECUTE
result = execute_and_fetch_all(cursor, "CHECK STREAM test_stream")
result = common.execute_and_fetch_all(
cursor, "CHECK STREAM test_stream")
result_len.value = len(result)
counter.value = CHECK_AFTER_FETCHALL
if len(result) > 0 and "payload: 'message'" in result[0][QUERY]:
@ -397,11 +254,12 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
def call_operation(counter):
# This porcess will call the query with the specified operation and
# increment the counter based on its progress and expected behavior
connection = connect()
connection = common.connect()
cursor = connection.cursor()
counter.value = OP_BEFORE_EXECUTE
try:
execute_and_fetch_all(cursor, f"{operation} STREAM test_stream")
common.execute_and_fetch_all(
cursor, f"{operation} STREAM test_stream")
counter.value = OP_AFTER_FETCHALL
except mgclient.DatabaseError as e:
if "Kafka consumer test_stream is already stopped" in str(e):
@ -421,15 +279,19 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
time.sleep(0.5)
assert timed_wait(lambda: check_counter.value == CHECK_BEFORE_EXECUTE)
assert timed_wait(lambda: get_is_running(cursor, "test_stream"))
assert common.timed_wait(
lambda: check_counter.value == CHECK_BEFORE_EXECUTE)
assert common.timed_wait(
lambda: common.get_is_running(cursor, "test_stream"))
assert check_counter.value == CHECK_BEFORE_EXECUTE, "SHOW STREAMS " \
"was blocked until the end of CHECK STREAM"
operation_proc.start()
assert timed_wait(lambda: operation_counter.value == OP_BEFORE_EXECUTE)
assert common.timed_wait(
lambda: operation_counter.value == OP_BEFORE_EXECUTE)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
assert timed_wait(lambda: check_counter.value > CHECK_AFTER_FETCHALL)
assert common.timed_wait(
lambda: check_counter.value > CHECK_AFTER_FETCHALL)
assert check_counter.value == CHECK_CORRECT_RESULT
assert check_result_len.value == 1
check_stream_proc.join()
@ -437,10 +299,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation):
operation_proc.join()
if operation == "START":
assert operation_counter.value == OP_AFTER_FETCHALL
assert get_is_running(cursor, "test_stream")
assert common.get_is_running(cursor, "test_stream")
else:
assert operation_counter.value == OP_ALREADY_STOPPED_EXCEPTION
assert not get_is_running(cursor, "test_stream")
assert not common.get_is_running(cursor, "test_stream")
finally:
# to make sure CHECK STREAM finishes
@ -455,42 +317,64 @@ def test_check_already_started_stream(topics, connection):
assert len(topics) > 0
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM started_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
start_stream(cursor, "started_stream")
common.execute_and_fetch_all(cursor,
"CREATE STREAM started_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.start_stream(cursor, "started_stream")
with pytest.raises(mgclient.DatabaseError):
execute_and_fetch_all(cursor, "CHECK STREAM started_stream")
common.execute_and_fetch_all(cursor, "CHECK STREAM started_stream")
def test_start_checked_stream_after_timeout(topics, connection):
cursor = connection.cursor()
execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.simple")
timeout_ms = 2000
def call_check():
execute_and_fetch_all(
connect().cursor(),
common.execute_and_fetch_all(
common.connect().cursor(),
f"CHECK STREAM test_stream TIMEOUT {timeout_ms}")
check_stream_proc = Process(target=call_check, daemon=True)
start = time.time()
check_stream_proc.start()
assert timed_wait(lambda: get_is_running(cursor, "test_stream"))
start_stream(cursor, "test_stream")
assert common.timed_wait(
lambda: common.get_is_running(cursor, "test_stream"))
common.start_stream(cursor, "test_stream")
end = time.time()
assert (end - start) < 1.3 * \
timeout_ms, "The START STREAM was blocked too long"
assert get_is_running(cursor, "test_stream")
stop_stream(cursor, "test_stream")
assert common.get_is_running(cursor, "test_stream")
common.stop_stream(cursor, "test_stream")
def test_restart_after_error(producer, topics, connection):
cursor = connection.cursor()
common.execute_and_fetch_all(cursor,
"CREATE STREAM test_stream "
f"TOPICS {topics[0]} "
f"TRANSFORM transform.query")
common.start_stream(cursor, "test_stream")
time.sleep(1)
producer.send(topics[0], SIMPLE_MSG).get(timeout=60)
assert common.timed_wait(
lambda: not common.get_is_running(cursor, "test_stream"))
common.start_stream(cursor, "test_stream")
time.sleep(1)
producer.send(topics[0], b'CREATE (n:VERTEX { id : 42 })')
assert common.check_one_result_row(
cursor, "MATCH (n:VERTEX { id : 42 }) RETURN n")
if __name__ == "__main__":

View File

@ -35,3 +35,17 @@ def with_parameters(context: mgp.TransCtx,
"topic": message.topic_name()}))
return result_queries
@mgp.transformation
def query(messages: mgp.Messages
) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]):
result_queries = []
for i in range(0, messages.total_messages()):
message = messages.message_at(i)
payload_as_str = message.payload().decode("utf-8")
result_queries.append(mgp.Record(
query=payload_as_str, parameters=None))
return result_queries

View File

@ -10,5 +10,10 @@ workloads:
- name: "Streams start, stop and show"
binary: "tests/e2e/streams/streams_test_runner.sh"
proc: "tests/e2e/streams/transformations/"
args: []
args: ["streams_tests.py"]
<<: *template_cluster
- name: "Streams with users"
binary: "tests/e2e/streams/streams_test_runner.sh"
proc: "tests/e2e/streams/transformations/"
args: ["streams_owner_tests.py"]
<<: *template_cluster

View File

@ -9,3 +9,6 @@ target_link_libraries(memgraph__e2e__triggers__on_update memgraph__e2e__triggers
add_executable(memgraph__e2e__triggers__on_delete on_delete_triggers.cpp)
target_link_libraries(memgraph__e2e__triggers__on_delete memgraph__e2e__triggers_common)
add_executable(memgraph__e2e__triggers__privileges privilige_check.cpp)
target_link_libraries(memgraph__e2e__triggers__privileges memgraph__e2e__triggers_common)

View File

@ -1,7 +1,9 @@
#include "common.hpp"
#include <chrono>
#include <cstdint>
#include <optional>
#include <thread>
#include <fmt/format.h>
#include <gflags/gflags.h>
@ -10,13 +12,17 @@
DEFINE_uint64(bolt_port, 7687, "Bolt port");
std::unique_ptr<mg::Client> Connect() {
auto client =
mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false});
std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username) {
auto client = mg::Client::Connect({.host = "127.0.0.1",
.port = static_cast<uint16_t>(FLAGS_bolt_port),
.username = std::string{username},
.use_ssl = false});
MG_ASSERT(client, "Failed to connect!");
return client;
}
std::unique_ptr<mg::Client> Connect() { return ConnectWithUser(""); }
void CreateVertex(mg::Client &client, int vertex_id) {
mg::Map parameters{
{"id", mg::Value{vertex_id}},
@ -49,10 +55,12 @@ int GetNumberOfAllVertices(mg::Client &client) {
}
void WaitForNumberOfAllVertices(mg::Client &client, int number_of_vertices) {
using namespace std::chrono_literals;
utils::Timer timer{};
while ((timer.Elapsed().count() <= 0.5) && GetNumberOfAllVertices(client) != number_of_vertices) {
}
CheckNumberOfAllVertices(client, number_of_vertices);
std::this_thread::sleep_for(100ms);
}
void CheckNumberOfAllVertices(mg::Client &client, int expected_number_of_vertices) {

View File

@ -11,6 +11,7 @@ constexpr std::string_view kVertexLabel{"VERTEX"};
constexpr std::string_view kEdgeLabel{"EDGE"};
std::unique_ptr<mg::Client> Connect();
std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username);
void CreateVertex(mg::Client &client, int vertex_id);
void CreateEdge(mg::Client &client, int from_vertex, int to_vertex, int edge_id);

View File

@ -0,0 +1,161 @@
#include <string>
#include <string_view>
#include <gflags/gflags.h>
#include <spdlog/fmt/bundled/core.h>
#include <mgclient.hpp>
#include "common.hpp"
#include "utils/logging.hpp"
constexpr std::string_view kTriggerPrefix{"CreatedVerticesTrigger"};
int main(int argc, char **argv) {
gflags::SetUsageMessage("Memgraph E2E Triggers privilege check");
gflags::ParseCommandLineFlags(&argc, &argv, true);
logging::RedirectToStderr();
constexpr int kVertexId{42};
constexpr std::string_view kUserlessLabel{"USERLESS"};
constexpr std::string_view kAdminUser{"ADMIN"};
constexpr std::string_view kUserWithCreate{"USER_WITH_CREATE"};
constexpr std::string_view kUserWithoutCreate{"USER_WITHOUT_CREATE"};
mg::Client::Init();
auto userless_client = Connect();
const auto get_number_of_triggers = [&userless_client] {
userless_client->Execute("SHOW TRIGGERS");
auto result = userless_client->FetchAll();
MG_ASSERT(result.has_value());
return result->size();
};
auto create_trigger = [&get_number_of_triggers](mg::Client &client, const std::string_view vertexLabel,
bool should_succeed = true) {
const auto number_of_triggers_before = get_number_of_triggers();
client.Execute(
fmt::format("CREATE TRIGGER {}{} ON () CREATE "
"AFTER COMMIT "
"EXECUTE "
"UNWIND createdVertices as createdVertex "
"CREATE (n: {} {{ id: createdVertex.id }})",
kTriggerPrefix, vertexLabel, vertexLabel));
client.DiscardAll();
const auto number_of_triggers_after = get_number_of_triggers();
if (should_succeed) {
MG_ASSERT(number_of_triggers_after == number_of_triggers_before + 1);
} else {
MG_ASSERT(number_of_triggers_after == number_of_triggers_before);
}
};
auto delete_vertices = [&userless_client] {
userless_client->Execute("MATCH (n) DETACH DELETE n;");
userless_client->DiscardAll();
CheckNumberOfAllVertices(*userless_client, 0);
};
auto create_user = [&userless_client](const std::string_view username) {
userless_client->Execute(fmt::format("CREATE USER {};", username));
userless_client->DiscardAll();
userless_client->Execute(fmt::format("GRANT TRIGGER TO {};", username));
userless_client->DiscardAll();
};
auto drop_user = [&userless_client](const std::string_view username) {
userless_client->Execute(fmt::format("DROP USER {};", username));
userless_client->DiscardAll();
};
auto drop_trigger_of_user = [&userless_client](const std::string_view username) {
userless_client->Execute(fmt::format("DROP TRIGGER {}{};", kTriggerPrefix, username));
userless_client->DiscardAll();
};
// Single trigger created without user, there is no existing users
create_trigger(*userless_client, kUserlessLabel);
CreateVertex(*userless_client, kVertexId);
WaitForNumberOfAllVertices(*userless_client, 2);
CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
CheckVertexExists(*userless_client, kUserlessLabel, kVertexId);
delete_vertices();
// Single trigger created without user, there is an existing user
// The trigger fails because there is no owner
create_user(kAdminUser);
CreateVertex(*userless_client, kVertexId);
CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
CheckNumberOfAllVertices(*userless_client, 1);
delete_vertices();
// Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there are
// existing users
// Only the trigger which owner has CREATE privilege will succeed
create_user(kUserWithCreate);
userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithCreate));
userless_client->DiscardAll();
create_user(kUserWithoutCreate);
auto client_with_create = ConnectWithUser(kUserWithCreate);
auto client_without_create = ConnectWithUser(kUserWithoutCreate);
create_trigger(*client_with_create, kUserWithCreate);
create_trigger(*client_without_create, kUserWithoutCreate, false);
// Grant CREATE to be able to create the trigger than revoke it
userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate));
userless_client->DiscardAll();
create_trigger(*client_without_create, kUserWithoutCreate);
userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate));
userless_client->DiscardAll();
CreateVertex(*userless_client, kVertexId);
WaitForNumberOfAllVertices(*userless_client, 2);
CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
CheckVertexExists(*userless_client, kUserWithCreate, kVertexId);
delete_vertices();
// Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there is no
// existing user
// All triggers will succeed, as there is no authorization is done when there are no users
drop_user(kAdminUser);
drop_user(kUserWithCreate);
drop_user(kUserWithoutCreate);
CreateVertex(*userless_client, kVertexId);
WaitForNumberOfAllVertices(*userless_client, 4);
CheckVertexExists(*userless_client, kVertexLabel, kVertexId);
CheckVertexExists(*userless_client, kUserlessLabel, kVertexId);
CheckVertexExists(*userless_client, kUserWithCreate, kVertexId);
CheckVertexExists(*userless_client, kUserWithoutCreate, kVertexId);
delete_vertices();
drop_trigger_of_user(kUserlessLabel);
drop_trigger_of_user(kUserWithCreate);
drop_trigger_of_user(kUserWithoutCreate);
// The BEFORE COMMIT trigger without proper privileges make the transaction fail
create_user(kUserWithoutCreate);
userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate));
userless_client->DiscardAll();
client_without_create->Execute(
fmt::format("CREATE TRIGGER {}{} ON () CREATE "
"BEFORE COMMIT "
"EXECUTE "
"UNWIND createdVertices as createdVertex "
"CREATE (n: {} {{ id: createdVertex.id }})",
kTriggerPrefix, kUserWithoutCreate, kUserWithoutCreate));
client_without_create->DiscardAll();
userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate));
userless_client->DiscardAll();
CreateVertex(*userless_client, kVertexId);
CheckNumberOfAllVertices(*userless_client, 0);
return 0;
}

View File

@ -20,5 +20,9 @@ workloads:
binary: "tests/e2e/triggers/memgraph__e2e__triggers__on_delete"
args: ["--bolt-port", *bolt_port]
<<: *template_cluster
- name: "Triggers privilege check"
binary: "tests/e2e/triggers/memgraph__e2e__triggers__privileges"
args: ["--bolt-port", *bolt_port]
<<: *template_cluster

View File

@ -22,7 +22,7 @@ int main(int argc, char *argv[]) {
query::Interpreter interpreter{&interpreter_context};
ResultStreamFaker stream(&db);
auto [header, _, qid] = interpreter.Prepare(argv[1], {});
auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr);
stream.Header(header);
auto summary = interpreter.PullAll(&stream);
stream.Summary(summary);

View File

@ -165,14 +165,14 @@ TEST_F(AuthWithStorage, RoleManipulations) {
{
auto user1 = auth.GetUser("user1");
ASSERT_TRUE(user1);
auto role1 = user1->role();
ASSERT_TRUE(role1);
const auto *role1 = user1->role();
ASSERT_NE(role1, nullptr);
ASSERT_EQ(role1->rolename(), "role1");
auto user2 = auth.GetUser("user2");
ASSERT_TRUE(user2);
auto role2 = user2->role();
ASSERT_TRUE(role2);
const auto *role2 = user2->role();
ASSERT_NE(role2, nullptr);
ASSERT_EQ(role2->rolename(), "role2");
}
@ -181,13 +181,13 @@ TEST_F(AuthWithStorage, RoleManipulations) {
{
auto user1 = auth.GetUser("user1");
ASSERT_TRUE(user1);
auto role = user1->role();
ASSERT_FALSE(role);
const auto *role = user1->role();
ASSERT_EQ(role, nullptr);
auto user2 = auth.GetUser("user2");
ASSERT_TRUE(user2);
auto role2 = user2->role();
ASSERT_TRUE(role2);
const auto *role2 = user2->role();
ASSERT_NE(role2, nullptr);
ASSERT_EQ(role2->rolename(), "role2");
}
@ -199,13 +199,13 @@ TEST_F(AuthWithStorage, RoleManipulations) {
{
auto user1 = auth.GetUser("user1");
ASSERT_TRUE(user1);
auto role1 = user1->role();
ASSERT_FALSE(role1);
const auto *role1 = user1->role();
ASSERT_EQ(role1, nullptr);
auto user2 = auth.GetUser("user2");
ASSERT_TRUE(user2);
auto role2 = user2->role();
ASSERT_TRUE(role2);
const auto *role2 = user2->role();
ASSERT_NE(role2, nullptr);
ASSERT_EQ(role2->rolename(), "role2");
}
@ -245,8 +245,8 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) {
{
auto user = auth.GetUser("user");
ASSERT_TRUE(user);
auto role = user->role();
ASSERT_TRUE(role);
const auto *role = user->role();
ASSERT_NE(role, nullptr);
ASSERT_EQ(role->rolename(), "role");
}
@ -260,7 +260,7 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) {
{
auto user = auth.GetUser("user");
ASSERT_TRUE(user);
ASSERT_FALSE(user->role());
ASSERT_EQ(user->role(), nullptr);
}
}
@ -620,8 +620,9 @@ TEST_F(AuthWithStorage, CaseInsensitivity) {
auto user = auth.GetUser("aLIce");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
ASSERT_TRUE(user->role());
ASSERT_EQ(user->role()->rolename(), "moderator");
const auto *role = user->role();
ASSERT_NE(role, nullptr);
ASSERT_EQ(role->rolename(), "moderator");
}
// AllUsersForRole

View File

@ -6,6 +6,7 @@
#include "glue/communication.hpp"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/exceptions.hpp"
#include "query/interpreter.hpp"
@ -28,15 +29,17 @@ auto ToEdgeList(const communication::bolt::Value &v) {
};
struct InterpreterFaker {
explicit InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config,
const std::filesystem::path &data_directory)
InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config,
const std::filesystem::path &data_directory)
: interpreter_context(db, config, data_directory, "not used bootstrap servers"),
interpreter(&interpreter_context) {}
interpreter(&interpreter_context) {
interpreter_context.auth_checker = &auth_checker;
}
auto Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context.db);
const auto [header, _, qid] = interpreter.Prepare(query, params);
const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);
}
@ -61,6 +64,7 @@ struct InterpreterFaker {
return std::move(stream);
}
query::AllowEverythingAuthChecker auth_checker;
query::InterpreterContext interpreter_context;
query::Interpreter interpreter;
};

View File

@ -194,7 +194,7 @@ auto Execute(storage::Storage *db, const std::string &query) {
query::Interpreter interpreter(&context);
ResultStreamFaker stream(db);
auto [header, _, qid] = interpreter.Prepare(query, {});
auto [header, _, qid] = interpreter.Prepare(query, {}, nullptr);
stream.Header(header);
auto summary = interpreter.PullAll(&stream);
stream.Summary(summary);
@ -711,7 +711,7 @@ class StatefulInterpreter {
auto Execute(const std::string &query) {
ResultStreamFaker stream(db_);
auto [header, _, qid] = interpreter_.Prepare(query, {});
auto [header, _, qid] = interpreter_.Prepare(query, {}, nullptr);
stream.Header(header);
auto summary = interpreter_.PullAll(&stream);
stream.Summary(summary);

View File

@ -42,7 +42,7 @@ class QueryExecution : public testing::Test {
auto Execute(const std::string &query) {
ResultStreamFaker stream(&*db_);
auto [header, _, qid] = interpreter_->Prepare(query, {});
auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr);
stream.Header(header);
auto summary = interpreter_->PullAll(&stream);
stream.Summary(summary);

View File

@ -34,6 +34,7 @@ StreamInfo CreateDefaultStreamInfo() {
.batch_interval = std::nullopt,
.batch_size = std::nullopt,
.transformation_name = "not used in the tests",
.owner = std::nullopt,
};
}
@ -57,6 +58,8 @@ class StreamsTest : public ::testing::Test {
// Though there is a Streams object in interpreter context, it makes more sense to use a separate object to test,
// because that provides a way to recreate the streams object and also give better control over the arguments of the
// Streams constructor.
// InterpreterContext::auth_checker_ is used in the Streams object, but only in the message processing part. Because
// these tests don't send any messages, the auth_checker_ pointer can be left as nullptr.
query::InterpreterContext interpreter_context_{&db_, query::InterpreterConfig{}, data_directory_,
"dont care bootstrap servers"};
std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"};
@ -172,12 +175,14 @@ TEST_F(StreamsTest, RestoreStreams) {
if (i > 0) {
stream_info.batch_interval = std::chrono::milliseconds((i + 1) * 10);
stream_info.batch_size = 1000 + i;
stream_info.owner = std::string{"owner"} + iteration_postfix;
}
mock_cluster_.CreateTopic(stream_info.topics[0]);
}
stream_check_datas[1].info.batch_interval = {};
stream_check_datas[2].info.batch_size = {};
stream_check_datas[3].info.owner = {};
const auto check_restore_logic = [&stream_check_datas, this]() {
// Reset the Streams object to trigger reloading

View File

@ -1,12 +1,16 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <filesystem>
#include <fmt/format.h>
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/db_accessor.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpreter.hpp"
#include "query/trigger.hpp"
#include "query/typed_value.hpp"
#include "utils/exceptions.hpp"
#include "utils/memory.hpp"
namespace {
@ -16,6 +20,12 @@ const std::unordered_set<query::TriggerEventType> kAllEventTypes{
query::TriggerEventType::DELETE, query::TriggerEventType::VERTEX_UPDATE, query::TriggerEventType::EDGE_UPDATE,
query::TriggerEventType::UPDATE,
};
class MockAuthChecker : public query::AuthChecker {
public:
MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges));
};
} // namespace
class TriggerContextTest : public ::testing::Test {
@ -820,6 +830,7 @@ class TriggerStoreTest : public ::testing::Test {
utils::SkipList<query::QueryCacheEntry> ast_cache;
utils::SpinLock antlr_lock;
query::AllowEverythingAuthChecker auth_checker;
private:
void Clear() {
@ -836,7 +847,7 @@ TEST_F(TriggerStoreTest, Restore) {
const auto reset_store = [&] {
store.emplace(testing_directory);
store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{});
store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &auth_checker);
};
reset_store();
@ -853,34 +864,40 @@ TEST_F(TriggerStoreTest, Restore) {
const auto *trigger_name_after = "trigger_after";
const auto *trigger_statement = "RETURN $parameter";
const auto event_type = query::TriggerEventType::VERTEX_CREATE;
const std::string owner{"owner"};
store->AddTrigger(trigger_name_before, trigger_statement,
std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}}, event_type,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{});
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
store->AddTrigger(trigger_name_after, trigger_statement,
std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{"value"}}},
event_type, query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{});
query::InterpreterConfig::Query{}, {owner}, &auth_checker);
const auto check_triggers = [&] {
ASSERT_EQ(store->GetTriggerInfo().size(), 2);
const auto verify_trigger = [&](const auto &trigger, const auto &name) {
const auto verify_trigger = [&](const auto &trigger, const auto &name, const std::string *owner) {
ASSERT_EQ(trigger.Name(), name);
ASSERT_EQ(trigger.OriginalStatement(), trigger_statement);
ASSERT_EQ(trigger.EventType(), event_type);
if (owner != nullptr) {
ASSERT_EQ(*trigger.Owner(), *owner);
} else {
ASSERT_FALSE(trigger.Owner().has_value());
}
};
const auto before_commit_triggers = store->BeforeCommitTriggers().access();
ASSERT_EQ(before_commit_triggers.size(), 1);
for (const auto &trigger : before_commit_triggers) {
verify_trigger(trigger, trigger_name_before);
verify_trigger(trigger, trigger_name_before, nullptr);
}
const auto after_commit_triggers = store->AfterCommitTriggers().access();
ASSERT_EQ(after_commit_triggers.size(), 1);
for (const auto &trigger : after_commit_triggers) {
verify_trigger(trigger, trigger_name_after);
verify_trigger(trigger, trigger_name_after, &owner);
}
};
@ -906,32 +923,32 @@ TEST_F(TriggerStoreTest, AddTrigger) {
// Invalid query in statements
ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}),
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}),
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}),
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
utils::BasicException);
ASSERT_NO_THROW(
store.AddTrigger("trigger", "RETURN $parameter",
std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}},
query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba,
&antlr_lock, query::InterpreterConfig::Query{}));
&antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &auth_checker));
// Inserting with the same name
ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}),
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
utils::BasicException);
ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}),
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker),
utils::BasicException);
ASSERT_EQ(store.GetTriggerInfo().size(), 1);
@ -947,7 +964,7 @@ TEST_F(TriggerStoreTest, DropTrigger) {
const auto *trigger_name = "trigger";
store.AddTrigger(trigger_name, "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{});
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
ASSERT_THROW(store.DropTrigger("Unknown"), utils::BasicException);
ASSERT_NO_THROW(store.DropTrigger(trigger_name));
@ -960,7 +977,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
std::vector<query::TriggerStore::TriggerInfo> expected_info;
store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{});
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker);
expected_info.push_back(
{"trigger", "RETURN 1", query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT});
@ -971,7 +988,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
ASSERT_TRUE(std::all_of(expected_info.begin(), expected_info.end(), [&](const auto &info) {
return std::find_if(trigger_info.begin(), trigger_info.end(), [&](const auto &other) {
return info.name == other.name && info.statement == other.statement &&
info.event_type == other.event_type && info.phase == other.phase;
info.event_type == other.event_type && info.phase == other.phase && !info.owner.has_value();
}) != trigger_info.end();
}));
};
@ -979,8 +996,8 @@ TEST_F(TriggerStoreTest, TriggerInfo) {
check_trigger_info();
store.AddTrigger("edge_update_trigger", "RETURN 1", {}, query::TriggerEventType::EDGE_UPDATE,
query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{});
query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{},
std::nullopt, &auth_checker);
expected_info.push_back(
{"edge_update_trigger", "RETURN 1", query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT});
@ -1093,8 +1110,56 @@ TEST_F(TriggerStoreTest, AnyTriggerAllKeywords) {
SCOPED_TRACE(keyword);
EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type,
query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock,
query::InterpreterConfig::Query{}));
query::InterpreterConfig::Query{}, std::nullopt, &auth_checker));
store.DropTrigger(trigger_name);
}
}
}
TEST_F(TriggerStoreTest, AuthCheckerUsage) {
using Privilege = query::AuthQuery::Privilege;
using ::testing::_;
using ::testing::ElementsAre;
using ::testing::Return;
std::optional<query::TriggerStore> store{testing_directory};
const std::optional<std::string> owner{"testing_owner"};
MockAuthChecker mock_checker;
::testing::InSequence s;
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE)))
.Times(1)
.WillOnce(Return(true));
EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true));
ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {},
query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
&*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt,
&mock_checker));
ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {},
query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
&*dba, &antlr_lock, query::InterpreterConfig::Query{}, owner, &mock_checker));
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::MATCH)))
.Times(1)
.WillOnce(Return(false));
ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {},
query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache,
&*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &mock_checker);
, utils::BasicException);
store.emplace(testing_directory);
EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE)))
.Times(1)
.WillOnce(Return(false));
EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true));
ASSERT_NO_THROW(
store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &mock_checker));
const auto triggers = store->GetTriggerInfo();
ASSERT_EQ(triggers.size(), 1);
ASSERT_EQ(triggers.front().owner, owner);
}