diff --git a/.clang-tidy b/.clang-tidy index 3abb221ad..46103a7a9 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -53,6 +53,7 @@ Checks: '*, -performance-unnecessary-value-param, -readability-braces-around-statements, -readability-else-after-return, + -readability-function-cognitive-complexity, -readability-implicit-bool-conversion, -readability-magic-numbers, -readability-named-parameter, diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 759d5abba..d9bda1c60 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -144,7 +144,7 @@ std::optional<User> Auth::Authenticate(const std::string &username, const std::s } } -std::optional<User> Auth::GetUser(const std::string &username_orig) { +std::optional<User> Auth::GetUser(const std::string &username_orig) const { auto username = utils::ToLowerCase(username_orig); auto existing_user = storage_.Get(kUserPrefix + username); if (!existing_user) return std::nullopt; @@ -170,9 +170,9 @@ std::optional<User> Auth::GetUser(const std::string &username_orig) { void Auth::SaveUser(const User &user) { bool success = false; - if (user.role()) { - success = storage_.PutMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}, - {kLinkPrefix + user.username(), user.role()->rolename()}}); + if (const auto *role = user.role(); role != nullptr) { + success = storage_.PutMultiple( + {{kUserPrefix + user.username(), user.Serialize().dump()}, {kLinkPrefix + user.username(), role->rolename()}}); } else { success = storage_.PutAndDeleteMultiple({{kUserPrefix + user.username(), user.Serialize().dump()}}, {kLinkPrefix + user.username()}); @@ -203,7 +203,7 @@ bool Auth::RemoveUser(const std::string &username_orig) { return true; } -std::vector<auth::User> Auth::AllUsers() { +std::vector<auth::User> Auth::AllUsers() const { std::vector<auth::User> ret; for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); @@ -216,9 +216,9 @@ std::vector<auth::User> Auth::AllUsers() { return ret; } -bool Auth::HasUsers() { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } +bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } -std::optional<Role> Auth::GetRole(const std::string &rolename_orig) { +std::optional<Role> Auth::GetRole(const std::string &rolename_orig) const { auto rolename = utils::ToLowerCase(rolename_orig); auto existing_role = storage_.Get(kRolePrefix + rolename); if (!existing_role) return std::nullopt; @@ -265,7 +265,7 @@ bool Auth::RemoveRole(const std::string &rolename_orig) { return true; } -std::vector<auth::Role> Auth::AllRoles() { +std::vector<auth::Role> Auth::AllRoles() const { std::vector<auth::Role> ret; for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { auto rolename = it->first.substr(kRolePrefix.size()); @@ -280,7 +280,7 @@ std::vector<auth::Role> Auth::AllRoles() { return ret; } -std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) { +std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) const { auto rolename = utils::ToLowerCase(rolename_orig); std::vector<auth::User> ret; for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) { @@ -299,6 +299,4 @@ std::vector<auth::User> Auth::AllUsersForRole(const std::string &rolename_orig) return ret; } -std::mutex &Auth::WithLock() { return lock_; } - } // namespace auth diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index 9d87dfab7..0174d39b2 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -14,8 +14,7 @@ namespace auth { /** * This class serves as the main Authentication/Authorization storage. * It provides functions for managing Users, Roles and Permissions. - * NOTE: The functions in this class aren't thread safe. Use the `WithLock` lock - * if you want to have safe modifications of the storage. + * NOTE: The non-const functions in this class aren't thread safe. * TODO (mferencevic): Disable user/role modification functions when they are * being managed by the auth module. */ @@ -42,7 +41,7 @@ class Auth final { * @return a user when the user exists, nullopt otherwise * @throw AuthException if unable to load user data. */ - std::optional<User> GetUser(const std::string &username); + std::optional<User> GetUser(const std::string &username) const; /** * Saves a user object to the storage. @@ -81,14 +80,14 @@ class Auth final { * @return a list of users * @throw AuthException if unable to load user data. */ - std::vector<User> AllUsers(); + std::vector<User> AllUsers() const; /** * Returns whether there are users in the storage. * * @return `true` if the storage contains any users, `false` otherwise */ - bool HasUsers(); + bool HasUsers() const; /** * Gets a role from the storage. @@ -98,7 +97,7 @@ class Auth final { * @return a role when the role exists, nullopt otherwise * @throw AuthException if unable to load role data. */ - std::optional<Role> GetRole(const std::string &rolename); + std::optional<Role> GetRole(const std::string &rolename) const; /** * Saves a role object to the storage. @@ -136,7 +135,7 @@ class Auth final { * @return a list of roles * @throw AuthException if unable to load role data. */ - std::vector<Role> AllRoles(); + std::vector<Role> AllRoles() const; /** * Gets all users for a role from the storage. @@ -146,21 +145,13 @@ class Auth final { * @return a list of roles * @throw AuthException if unable to load user data. */ - std::vector<User> AllUsersForRole(const std::string &rolename); - - /** - * Returns a reference to the lock that should be used for all operations that - * require more than one interaction with this class. - */ - std::mutex &WithLock(); + std::vector<User> AllUsersForRole(const std::string &rolename) const; private: + // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe, + // Auth is not thread-safe because modifying users and roles might require + // more than one operation on the storage. kvstore::KVStore storage_; auth::Module module_; - // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe we - // use a mutex to lock all operations on the `User` and `Role` storage because - // some operations on the users and/or roles may require more than one - // operation on the storage. - std::mutex lock_; }; } // namespace auth diff --git a/src/auth/models.cpp b/src/auth/models.cpp index e3b6443e9..6a9ed1040 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -216,7 +216,7 @@ void User::SetRole(const Role &role) { role_.emplace(role); } void User::ClearRole() { role_ = std::nullopt; } -const Permissions User::GetPermissions() const { +Permissions User::GetPermissions() const { if (role_) { return Permissions(permissions_.grants() | role_->permissions().grants(), permissions_.denies() | role_->permissions().denies()); @@ -229,7 +229,12 @@ const std::string &User::username() const { return username_; } const Permissions &User::permissions() const { return permissions_; } Permissions &User::permissions() { return permissions_; } -std::optional<Role> User::role() const { return role_; } +const Role *User::role() const { + if (role_.has_value()) { + return &role_.value(); + } + return nullptr; +} nlohmann::json User::Serialize() const { nlohmann::json data = nlohmann::json::object(); diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 5bea738de..03eeb5241 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -127,14 +127,14 @@ class User final { void ClearRole(); - const Permissions GetPermissions() const; + Permissions GetPermissions() const; const std::string &username() const; const Permissions &permissions() const; Permissions &permissions(); - std::optional<Role> role() const; + const Role *role() const; nlohmann::json Serialize() const; diff --git a/src/kvstore/kvstore.cpp b/src/kvstore/kvstore.cpp index e5b63310b..44a7295c1 100644 --- a/src/kvstore/kvstore.cpp +++ b/src/kvstore/kvstore.cpp @@ -144,7 +144,7 @@ bool KVStore::iterator::IsValid() { return pimpl_->it != nullptr; } // TODO(ipaljak) The complexity of the size function should be at most // logarithmic. -size_t KVStore::Size(const std::string &prefix) { +size_t KVStore::Size(const std::string &prefix) const { size_t size = 0; for (auto it = this->begin(prefix); it != this->end(prefix); ++it) ++size; return size; diff --git a/src/kvstore/kvstore.hpp b/src/kvstore/kvstore.hpp index f51460a96..7858b472c 100644 --- a/src/kvstore/kvstore.hpp +++ b/src/kvstore/kvstore.hpp @@ -126,7 +126,7 @@ class KVStore final { * * @return - number of stored pairs. */ - size_t Size(const std::string &prefix = ""); + size_t Size(const std::string &prefix = "") const; /** * Compact the underlying storage for the key range [begin_prefix, @@ -186,9 +186,9 @@ class KVStore final { std::unique_ptr<impl> pimpl_; }; - iterator begin(const std::string &prefix = "") { return iterator(this, prefix); } + iterator begin(const std::string &prefix = "") const { return iterator(this, prefix); } - iterator end(const std::string &prefix = "") { return iterator(this, prefix, true); } + iterator end(const std::string &prefix = "") const { return iterator(this, prefix, true); } private: struct impl; diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 8add7d27b..352b9d28b 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -23,6 +23,7 @@ #include "communication/bolt/v1/constants.hpp" #include "helpers.hpp" #include "py/py.hpp" +#include "query/auth_checker.hpp" #include "query/discard_value_stream.hpp" #include "query/exceptions.hpp" #include "query/interpreter.hpp" @@ -40,8 +41,10 @@ #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/readable_size.hpp" +#include "utils/rw_lock.hpp" #include "utils/signals.hpp" #include "utils/string.hpp" +#include "utils/synchronized.hpp" #include "utils/sysinfo/memory.hpp" #include "utils/terminate_handler.hpp" #include "version.hpp" @@ -354,15 +357,373 @@ void ConfigureLogging() { struct SessionData { // Explicit constructor here to ensure that pointers to all objects are // supplied. - SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, auth::Auth *auth, - audit::Log *audit_log) + SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, audit::Log *audit_log) : db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {} storage::Storage *db; query::InterpreterContext *interpreter_context; - auth::Auth *auth; + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth; audit::Log *audit_log; }; + +DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+", + "Set to the regular expression that each user or role name must fulfill."); + +class AuthQueryHandler final : public query::AuthQueryHandler { + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; + std::regex name_regex_; + + public: + AuthQueryHandler(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, const std::regex &name_regex) + : auth_(auth), name_regex_(name_regex) {} + + bool CreateUser(const std::string &username, const std::optional<std::string> &password) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + return locked_auth->AddUser(username, password).has_value(); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + bool DropUser(const std::string &username) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) return false; + return locked_auth->RemoveUser(username); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + void SetPassword(const std::string &username, const std::optional<std::string> &password) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw query::QueryRuntimeException("User '{}' doesn't exist.", username); + } + user->UpdatePassword(password); + locked_auth->SaveUser(*user); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + bool CreateRole(const std::string &rolename) override { + if (!std::regex_match(rolename, name_regex_)) { + throw query::QueryRuntimeException("Invalid role name."); + } + try { + auto locked_auth = auth_->Lock(); + return locked_auth->AddRole(rolename).has_value(); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + bool DropRole(const std::string &rolename) override { + if (!std::regex_match(rolename, name_regex_)) { + throw query::QueryRuntimeException("Invalid role name."); + } + try { + auto locked_auth = auth_->Lock(); + auto role = locked_auth->GetRole(rolename); + if (!role) return false; + return locked_auth->RemoveRole(rolename); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + std::vector<query::TypedValue> GetUsernames() override { + try { + auto locked_auth = auth_->ReadLock(); + std::vector<query::TypedValue> usernames; + const auto &users = locked_auth->AllUsers(); + usernames.reserve(users.size()); + for (const auto &user : users) { + usernames.emplace_back(user.username()); + } + return usernames; + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + std::vector<query::TypedValue> GetRolenames() override { + try { + auto locked_auth = auth_->ReadLock(); + std::vector<query::TypedValue> rolenames; + const auto &roles = locked_auth->AllRoles(); + rolenames.reserve(roles.size()); + for (const auto &role : roles) { + rolenames.emplace_back(role.rolename()); + } + return rolenames; + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + std::optional<std::string> GetRolenameForUser(const std::string &username) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->ReadLock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw query::QueryRuntimeException("User '{}' doesn't exist .", username); + } + + if (const auto *role = user->role(); role != nullptr) { + return role->rolename(); + } + return std::nullopt; + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override { + if (!std::regex_match(rolename, name_regex_)) { + throw query::QueryRuntimeException("Invalid role name."); + } + try { + auto locked_auth = auth_->ReadLock(); + auto role = locked_auth->GetRole(rolename); + if (!role) { + throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename); + } + std::vector<query::TypedValue> usernames; + const auto &users = locked_auth->AllUsersForRole(rolename); + usernames.reserve(users.size()); + for (const auto &user : users) { + usernames.emplace_back(user.username()); + } + return usernames; + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + void SetRole(const std::string &username, const std::string &rolename) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + if (!std::regex_match(rolename, name_regex_)) { + throw query::QueryRuntimeException("Invalid role name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw query::QueryRuntimeException("User '{}' doesn't exist .", username); + } + auto role = locked_auth->GetRole(rolename); + if (!role) { + throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename); + } + if (const auto *current_role = user->role(); current_role != nullptr) { + throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username, + current_role->rolename()); + } + user->SetRole(*role); + locked_auth->SaveUser(*user); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + void ClearRole(const std::string &username) override { + if (!std::regex_match(username, name_regex_)) { + throw query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw query::QueryRuntimeException("User '{}' doesn't exist .", username); + } + user->ClearRole(); + locked_auth->SaveUser(*user); + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override { + if (!std::regex_match(user_or_role, name_regex_)) { + throw query::QueryRuntimeException("Invalid user or role name."); + } + try { + auto locked_auth = auth_->ReadLock(); + std::vector<std::vector<query::TypedValue>> grants; + auto user = locked_auth->GetUser(user_or_role); + auto role = locked_auth->GetRole(user_or_role); + if (!user && !role) { + throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); + } + if (user) { + const auto &permissions = user->GetPermissions(); + for (const auto &privilege : query::kPrivilegesAll) { + auto permission = glue::PrivilegeToPermission(privilege); + auto effective = permissions.Has(permission); + if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) { + std::vector<std::string> description; + auto user_level = user->permissions().Has(permission); + if (user_level == auth::PermissionLevel::GRANT) { + description.emplace_back("GRANTED TO USER"); + } else if (user_level == auth::PermissionLevel::DENY) { + description.emplace_back("DENIED TO USER"); + } + if (const auto *role = user->role(); role != nullptr) { + auto role_level = role->permissions().Has(permission); + if (role_level == auth::PermissionLevel::GRANT) { + description.emplace_back("GRANTED TO ROLE"); + } else if (role_level == auth::PermissionLevel::DENY) { + description.emplace_back("DENIED TO ROLE"); + } + } + grants.push_back({query::TypedValue(auth::PermissionToString(permission)), + query::TypedValue(auth::PermissionLevelToString(effective)), + query::TypedValue(utils::Join(description, ", "))}); + } + } + } else { + const auto &permissions = role->permissions(); + for (const auto &privilege : query::kPrivilegesAll) { + auto permission = glue::PrivilegeToPermission(privilege); + auto effective = permissions.Has(permission); + if (effective != auth::PermissionLevel::NEUTRAL) { + std::string description; + if (effective == auth::PermissionLevel::GRANT) { + description = "GRANTED TO ROLE"; + } else if (effective == auth::PermissionLevel::DENY) { + description = "DENIED TO ROLE"; + } + grants.push_back({query::TypedValue(auth::PermissionToString(permission)), + query::TypedValue(auth::PermissionLevelToString(effective)), + query::TypedValue(description)}); + } + } + } + return grants; + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } + + void GrantPrivilege(const std::string &user_or_role, + const std::vector<query::AuthQuery::Privilege> &privileges) override { + EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + // TODO (mferencevic): should we first check that the + // privilege is granted/denied/revoked before + // unconditionally granting/denying/revoking it? + permissions->Grant(permission); + }); + } + + void DenyPrivilege(const std::string &user_or_role, + const std::vector<query::AuthQuery::Privilege> &privileges) override { + EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + // TODO (mferencevic): should we first check that the + // privilege is granted/denied/revoked before + // unconditionally granting/denying/revoking it? + permissions->Deny(permission); + }); + } + + void RevokePrivilege(const std::string &user_or_role, + const std::vector<query::AuthQuery::Privilege> &privileges) override { + EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { + // TODO (mferencevic): should we first check that the + // privilege is granted/denied/revoked before + // unconditionally granting/denying/revoking it? + permissions->Revoke(permission); + }); + } + + private: + template <class TEditFun> + void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges, + const TEditFun &edit_fun) { + if (!std::regex_match(user_or_role, name_regex_)) { + throw query::QueryRuntimeException("Invalid user or role name."); + } + try { + auto locked_auth = auth_->Lock(); + std::vector<auth::Permission> permissions; + permissions.reserve(privileges.size()); + for (const auto &privilege : privileges) { + permissions.push_back(glue::PrivilegeToPermission(privilege)); + } + auto user = locked_auth->GetUser(user_or_role); + auto role = locked_auth->GetRole(user_or_role); + if (!user && !role) { + throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); + } + if (user) { + for (const auto &permission : permissions) { + edit_fun(&user->permissions(), permission); + } + locked_auth->SaveUser(*user); + } else { + for (const auto &permission : permissions) { + edit_fun(&role->permissions(), permission); + } + locked_auth->SaveRole(*role); + } + } catch (const auth::AuthException &e) { + throw query::QueryRuntimeException(e.what()); + } + } +}; + +class AuthChecker final : public query::AuthChecker { + public: + explicit AuthChecker(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) : auth_{auth} {} + + static bool IsUserAuthorized(const auth::User &user, const std::vector<query::AuthQuery::Privilege> &privileges) { + const auto user_permissions = user.GetPermissions(); + return std::all_of(privileges.begin(), privileges.end(), [&user_permissions](const auto privilege) { + return user_permissions.Has(glue::PrivilegeToPermission(privilege)) == auth::PermissionLevel::GRANT; + }); + } + + bool IsUserAuthorized(const std::optional<std::string> &username, + const std::vector<query::AuthQuery::Privilege> &privileges) const final { + std::optional<auth::User> maybe_user; + { + auto locked_auth = auth_->ReadLock(); + if (!locked_auth->HasUsers()) { + return true; + } + if (username.has_value()) { + maybe_user = locked_auth->GetUser(*username); + } + } + + return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges); + } + + private: + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; +}; + #else + struct SessionData { // Explicit constructor here to ensure that pointers to all objects are // supplied. @@ -371,6 +732,52 @@ struct SessionData { storage::Storage *db; query::InterpreterContext *interpreter_context; }; + +class NoAuthInCommunity : public query::QueryRuntimeException { + public: + NoAuthInCommunity() + : query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {} +}; + +class AuthQueryHandler final : public query::AuthQueryHandler { + public: + bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } + + bool DropUser(const std::string &) override { throw NoAuthInCommunity(); } + + void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } + + bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); } + + bool DropRole(const std::string &) override { throw NoAuthInCommunity(); } + + std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); } + + std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); } + + std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); } + + std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); } + + void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); } + + void ClearRole(const std::string &) override { throw NoAuthInCommunity(); } + + std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); } + + void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { + throw NoAuthInCommunity(); + } + + void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { + throw NoAuthInCommunity(); + } + + void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { + throw NoAuthInCommunity(); + } +}; + #endif class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> { @@ -400,22 +807,21 @@ class BoltSession final : public communication::bolt::Session<communication::Inp const std::string &query, const std::map<std::string, communication::bolt::Value> ¶ms) override { std::map<std::string, storage::PropertyValue> params_pv; for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second)); + const std::string *username{nullptr}; #ifdef MG_ENTERPRISE - audit_log_->Record(endpoint_.address, user_ ? user_->username() : "", query, storage::PropertyValue(params_pv)); + if (user_) { + username = &user_->username(); + } + audit_log_->Record(endpoint_.address, user_ ? *username : "", query, storage::PropertyValue(params_pv)); #endif try { - auto result = interpreter_.Prepare(query, params_pv); + auto result = interpreter_.Prepare(query, params_pv, username); #ifdef MG_ENTERPRISE - if (user_) { - const auto &permissions = user_->GetPermissions(); - for (const auto &privilege : result.privileges) { - if (permissions.Has(glue::PrivilegeToPermission(privilege)) != auth::PermissionLevel::GRANT) { - interpreter_.Abort(); - throw communication::bolt::ClientError( - "You are not authorized to execute this query! Please contact " - "your database administrator."); - } - } + if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges)) { + interpreter_.Abort(); + throw communication::bolt::ClientError( + "You are not authorized to execute this query! Please contact " + "your database administrator."); } #endif return {result.headers, result.qid}; @@ -442,9 +848,12 @@ class BoltSession final : public communication::bolt::Session<communication::Inp bool Authenticate(const std::string &username, const std::string &password) override { #ifdef MG_ENTERPRISE - if (!auth_->HasUsers()) return true; - user_ = auth_->Authenticate(username, password); - return !!user_; + auto locked_auth = auth_->Lock(); + if (!locked_auth->HasUsers()) { + return true; + } + user_ = locked_auth->Authenticate(username, password); + return user_.has_value(); #else return true; #endif @@ -522,7 +931,7 @@ class BoltSession final : public communication::bolt::Session<communication::Inp const storage::Storage *db_; query::Interpreter interpreter_; #ifdef MG_ENTERPRISE - auth::Auth *auth_; + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; std::optional<auth::User> user_; audit::Log *audit_log_; #endif @@ -532,374 +941,6 @@ class BoltSession final : public communication::bolt::Session<communication::Inp using ServerT = communication::Server<BoltSession, SessionData>; using communication::ServerContext; -#ifdef MG_ENTERPRISE -DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+", - "Set to the regular expression that each user or role name must fulfill."); - -class AuthQueryHandler final : public query::AuthQueryHandler { - auth::Auth *auth_; - std::regex name_regex_; - - public: - AuthQueryHandler(auth::Auth *auth, const std::regex &name_regex) : auth_(auth), name_regex_(name_regex) {} - - bool CreateUser(const std::string &username, const std::optional<std::string> &password) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - return !!auth_->AddUser(username, password); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - bool DropUser(const std::string &username) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto user = auth_->GetUser(username); - if (!user) return false; - return auth_->RemoveUser(username); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - void SetPassword(const std::string &username, const std::optional<std::string> &password) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto user = auth_->GetUser(username); - if (!user) { - throw query::QueryRuntimeException("User '{}' doesn't exist.", username); - } - user->UpdatePassword(password); - auth_->SaveUser(*user); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - bool CreateRole(const std::string &rolename) override { - if (!std::regex_match(rolename, name_regex_)) { - throw query::QueryRuntimeException("Invalid role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - return !!auth_->AddRole(rolename); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - bool DropRole(const std::string &rolename) override { - if (!std::regex_match(rolename, name_regex_)) { - throw query::QueryRuntimeException("Invalid role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto role = auth_->GetRole(rolename); - if (!role) return false; - return auth_->RemoveRole(rolename); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - std::vector<query::TypedValue> GetUsernames() override { - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - std::vector<query::TypedValue> usernames; - const auto &users = auth_->AllUsers(); - usernames.reserve(users.size()); - for (const auto &user : users) { - usernames.emplace_back(user.username()); - } - return usernames; - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - std::vector<query::TypedValue> GetRolenames() override { - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - std::vector<query::TypedValue> rolenames; - const auto &roles = auth_->AllRoles(); - rolenames.reserve(roles.size()); - for (const auto &role : roles) { - rolenames.emplace_back(role.rolename()); - } - return rolenames; - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - std::optional<std::string> GetRolenameForUser(const std::string &username) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto user = auth_->GetUser(username); - if (!user) { - throw query::QueryRuntimeException("User '{}' doesn't exist .", username); - } - if (user->role()) return user->role()->rolename(); - return std::nullopt; - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - std::vector<query::TypedValue> GetUsernamesForRole(const std::string &rolename) override { - if (!std::regex_match(rolename, name_regex_)) { - throw query::QueryRuntimeException("Invalid role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto role = auth_->GetRole(rolename); - if (!role) { - throw query::QueryRuntimeException("Role '{}' doesn't exist.", rolename); - } - std::vector<query::TypedValue> usernames; - const auto &users = auth_->AllUsersForRole(rolename); - usernames.reserve(users.size()); - for (const auto &user : users) { - usernames.emplace_back(user.username()); - } - return usernames; - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - void SetRole(const std::string &username, const std::string &rolename) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - if (!std::regex_match(rolename, name_regex_)) { - throw query::QueryRuntimeException("Invalid role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto user = auth_->GetUser(username); - if (!user) { - throw query::QueryRuntimeException("User '{}' doesn't exist .", username); - } - auto role = auth_->GetRole(rolename); - if (!role) { - throw query::QueryRuntimeException("Role '{}' doesn't exist .", rolename); - } - if (user->role()) { - throw query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username, - user->role()->rolename()); - } - user->SetRole(*role); - auth_->SaveUser(*user); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - void ClearRole(const std::string &username) override { - if (!std::regex_match(username, name_regex_)) { - throw query::QueryRuntimeException("Invalid user name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - auto user = auth_->GetUser(username); - if (!user) { - throw query::QueryRuntimeException("User '{}' doesn't exist .", username); - } - user->ClearRole(); - auth_->SaveUser(*user); - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &user_or_role) override { - if (!std::regex_match(user_or_role, name_regex_)) { - throw query::QueryRuntimeException("Invalid user or role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - std::vector<std::vector<query::TypedValue>> grants; - auto user = auth_->GetUser(user_or_role); - auto role = auth_->GetRole(user_or_role); - if (!user && !role) { - throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); - } - if (user) { - const auto &permissions = user->GetPermissions(); - for (const auto &privilege : query::kPrivilegesAll) { - auto permission = glue::PrivilegeToPermission(privilege); - auto effective = permissions.Has(permission); - if (permissions.Has(permission) != auth::PermissionLevel::NEUTRAL) { - std::vector<std::string> description; - auto user_level = user->permissions().Has(permission); - if (user_level == auth::PermissionLevel::GRANT) { - description.emplace_back("GRANTED TO USER"); - } else if (user_level == auth::PermissionLevel::DENY) { - description.emplace_back("DENIED TO USER"); - } - if (user->role()) { - auto role_level = user->role()->permissions().Has(permission); - if (role_level == auth::PermissionLevel::GRANT) { - description.emplace_back("GRANTED TO ROLE"); - } else if (role_level == auth::PermissionLevel::DENY) { - description.emplace_back("DENIED TO ROLE"); - } - } - grants.push_back({query::TypedValue(auth::PermissionToString(permission)), - query::TypedValue(auth::PermissionLevelToString(effective)), - query::TypedValue(utils::Join(description, ", "))}); - } - } - } else { - const auto &permissions = role->permissions(); - for (const auto &privilege : query::kPrivilegesAll) { - auto permission = glue::PrivilegeToPermission(privilege); - auto effective = permissions.Has(permission); - if (effective != auth::PermissionLevel::NEUTRAL) { - std::string description; - if (effective == auth::PermissionLevel::GRANT) { - description = "GRANTED TO ROLE"; - } else if (effective == auth::PermissionLevel::DENY) { - description = "DENIED TO ROLE"; - } - grants.push_back({query::TypedValue(auth::PermissionToString(permission)), - query::TypedValue(auth::PermissionLevelToString(effective)), - query::TypedValue(description)}); - } - } - } - return grants; - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } - - void GrantPrivilege(const std::string &user_or_role, - const std::vector<query::AuthQuery::Privilege> &privileges) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { - // TODO (mferencevic): should we first check that the - // privilege is granted/denied/revoked before - // unconditionally granting/denying/revoking it? - permissions->Grant(permission); - }); - } - - void DenyPrivilege(const std::string &user_or_role, - const std::vector<query::AuthQuery::Privilege> &privileges) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { - // TODO (mferencevic): should we first check that the - // privilege is granted/denied/revoked before - // unconditionally granting/denying/revoking it? - permissions->Deny(permission); - }); - } - - void RevokePrivilege(const std::string &user_or_role, - const std::vector<query::AuthQuery::Privilege> &privileges) override { - EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) { - // TODO (mferencevic): should we first check that the - // privilege is granted/denied/revoked before - // unconditionally granting/denying/revoking it? - permissions->Revoke(permission); - }); - } - - private: - template <class TEditFun> - void EditPermissions(const std::string &user_or_role, const std::vector<query::AuthQuery::Privilege> &privileges, - const TEditFun &edit_fun) { - if (!std::regex_match(user_or_role, name_regex_)) { - throw query::QueryRuntimeException("Invalid user or role name."); - } - try { - std::lock_guard<std::mutex> lock(auth_->WithLock()); - std::vector<auth::Permission> permissions; - permissions.reserve(privileges.size()); - for (const auto &privilege : privileges) { - permissions.push_back(glue::PrivilegeToPermission(privilege)); - } - auto user = auth_->GetUser(user_or_role); - auto role = auth_->GetRole(user_or_role); - if (!user && !role) { - throw query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role); - } - if (user) { - for (const auto &permission : permissions) { - edit_fun(&user->permissions(), permission); - } - auth_->SaveUser(*user); - } else { - for (const auto &permission : permissions) { - edit_fun(&role->permissions(), permission); - } - auth_->SaveRole(*role); - } - } catch (const auth::AuthException &e) { - throw query::QueryRuntimeException(e.what()); - } - } -}; -#else -class NoAuthInCommunity : public query::QueryRuntimeException { - public: - NoAuthInCommunity() - : query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {} -}; - -class AuthQueryHandler final : public query::AuthQueryHandler { - public: - bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } - - bool DropUser(const std::string &) override { throw NoAuthInCommunity(); } - - void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } - - bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); } - - bool DropRole(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); } - - std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); } - - void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); } - - void ClearRole(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); } - - void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } - - void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } - - void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } -}; -#endif - // Needed to correctly handle memgraph destruction from a signal handler. // Without having some sort of a flag, it is possible that a signal is handled // when we are exiting main, inside destructors of database::GraphDb and @@ -1013,7 +1054,7 @@ int main(int argc, char **argv) { // Begin enterprise features initialization // Auth - auth::Auth auth{data_directory / "auth"}; + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> auth{data_directory / "auth"}; // Audit log audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms}; @@ -1075,24 +1116,28 @@ int main(int argc, char **argv) { query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories); query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories(); - { - // Triggers can execute query procedures, so we need to reload the modules first and then - // the triggers - auto storage_accessor = interpreter_context.db->Access(); - auto dba = query::DbAccessor{&storage_accessor}; - interpreter_context.trigger_store.RestoreTriggers( - &interpreter_context.ast_cache, &dba, &interpreter_context.antlr_lock, interpreter_context.config.query); - } - // As the Stream transformations are using modules, they have to be restored after the query modules are loaded. interpreter_context.streams.RestoreStreams(); #ifdef MG_ENTERPRISE AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex)); + AuthChecker auth_checker{&auth}; #else AuthQueryHandler auth_handler; + query::AllowEverythingAuthChecker auth_checker{}; #endif interpreter_context.auth = &auth_handler; + interpreter_context.auth_checker = &auth_checker; + + { + // Triggers can execute query procedures, so we need to reload the modules first and then + // the triggers + auto storage_accessor = interpreter_context.db->Access(); + auto dba = query::DbAccessor{&storage_accessor}; + interpreter_context.trigger_store.RestoreTriggers(&interpreter_context.ast_cache, &dba, + &interpreter_context.antlr_lock, interpreter_context.config.query, + interpreter_context.auth_checker); + } ServerContext context; std::string service_name = "Bolt"; diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp new file mode 100644 index 000000000..8734a7c80 --- /dev/null +++ b/src/query/auth_checker.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "query/frontend/ast/ast.hpp" + +namespace query { +class AuthChecker { + public: + virtual bool IsUserAuthorized(const std::optional<std::string> &username, + const std::vector<query::AuthQuery::Privilege> &privileges) const = 0; +}; + +class AllowEverythingAuthChecker final : public query::AuthChecker { + bool IsUserAuthorized(const std::optional<std::string> &username, + const std::vector<query::AuthQuery::Privilege> &privileges) const override { + return true; + } +}; +} // namespace query \ No newline at end of file diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 236d4baf5..2885a7046 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -3,6 +3,7 @@ #include <atomic> #include <chrono> #include <limits> +#include <optional> #include "glue/communication.hpp" #include "query/constants.hpp" @@ -463,8 +464,13 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } } +std::optional<std::string> StringPointerToOptional(const std::string *str) { + return str == nullptr ? std::nullopt : std::make_optional(*str); +} + Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, - InterpreterContext *interpreter_context, DbAccessor *db_accessor) { + InterpreterContext *interpreter_context, DbAccessor *db_accessor, + const std::string *username) { Frame frame(0); SymbolTable symbol_table; EvaluationContext evaluation_context; @@ -484,19 +490,21 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup : stream_query->consumer_group_}; - callback.fn = [interpreter_context, stream_name = stream_query->stream_name_, - topic_names = stream_query->topic_names_, consumer_group = std::move(consumer_group), - batch_interval = - GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator), - batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator), - transformation_name = stream_query->transform_name_]() mutable { - interpreter_context->streams.Create(stream_name, query::StreamInfo{.topics = std::move(topic_names), - .consumer_group = std::move(consumer_group), - .batch_interval = batch_interval, - .batch_size = batch_size, - .transformation_name = transformation_name}); - return std::vector<std::vector<TypedValue>>{}; - }; + callback.fn = + [interpreter_context, stream_name = stream_query->stream_name_, topic_names = stream_query->topic_names_, + consumer_group = std::move(consumer_group), + batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator), + batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator), + transformation_name = stream_query->transform_name_, owner = StringPointerToOptional(username)]() mutable { + interpreter_context->streams.Create(stream_name, + query::StreamInfo{.topics = std::move(topic_names), + .consumer_group = std::move(consumer_group), + .batch_interval = batch_interval, + .batch_size = batch_size, + .transformation_name = std::move(transformation_name), + .owner = std::move(owner)}); + return std::vector<std::vector<TypedValue>>{}; + }; return callback; } case StreamQuery::Action::START_STREAM: { @@ -535,8 +543,8 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete return callback; } case StreamQuery::Action::SHOW_STREAMS: { - callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name", - "is running"}; + callback.header = {"name", "topics", "consumer_group", "batch_interval", "batch_size", "transformation_name", + "owner", "is running"}; callback.fn = [interpreter_context]() { auto streams_status = interpreter_context->streams.GetStreamInfo(); std::vector<std::vector<TypedValue>> results; @@ -565,6 +573,11 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete typed_status.emplace_back(); } typed_status.emplace_back(stream_info.transformation_name); + if (stream_info.owner.has_value()) { + typed_status.emplace_back(*stream_info.owner); + } else { + typed_status.emplace_back(); + } }; for (const auto &status : streams_status) { @@ -1231,16 +1244,17 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) { Callback CreateTrigger(TriggerQuery *trigger_query, const std::map<std::string, storage::PropertyValue> &user_parameters, - InterpreterContext *interpreter_context, DbAccessor *dba) { + InterpreterContext *interpreter_context, DbAccessor *dba, std::optional<std::string> owner) { return { {}, [trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_), event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, interpreter_context, dba, - user_parameters]() -> std::vector<std::vector<TypedValue>> { + user_parameters, owner = std::move(owner)]() mutable -> std::vector<std::vector<TypedValue>> { interpreter_context->trigger_store.AddTrigger( - trigger_name, trigger_statement, user_parameters, ToTriggerEventType(event_type), + std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type), before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache, - dba, &interpreter_context->antlr_lock, interpreter_context->config.query); + dba, &interpreter_context->antlr_lock, interpreter_context->config.query, std::move(owner), + interpreter_context->auth_checker); return {}; }}; } @@ -1255,7 +1269,7 @@ Callback DropTrigger(TriggerQuery *trigger_query, InterpreterContext *interprete } Callback ShowTriggers(InterpreterContext *interpreter_context) { - return {{"trigger name", "statement", "event type", "phase"}, [interpreter_context] { + return {{"trigger name", "statement", "event type", "phase", "owner"}, [interpreter_context] { std::vector<std::vector<TypedValue>> results; auto trigger_infos = interpreter_context->trigger_store.GetTriggerInfo(); results.reserve(trigger_infos.size()); @@ -1267,6 +1281,9 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) { typed_trigger_info.emplace_back(TriggerEventTypeToString(trigger_info.event_type)); typed_trigger_info.emplace_back(trigger_info.phase == TriggerPhase::BEFORE_COMMIT ? "BEFORE COMMIT" : "AFTER COMMIT"); + typed_trigger_info.emplace_back(trigger_info.owner.has_value() ? TypedValue{*trigger_info.owner} + : TypedValue{}); + results.push_back(std::move(typed_trigger_info)); } @@ -1276,7 +1293,8 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) { PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, InterpreterContext *interpreter_context, DbAccessor *dba, - const std::map<std::string, storage::PropertyValue> &user_parameters) { + const std::map<std::string, storage::PropertyValue> &user_parameters, + const std::string *username) { if (in_explicit_transaction) { throw TriggerModificationInMulticommandTxException(); } @@ -1284,11 +1302,12 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic auto *trigger_query = utils::Downcast<TriggerQuery>(parsed_query.query); MG_ASSERT(trigger_query); - auto callback = [trigger_query, interpreter_context, dba, &user_parameters] { + auto callback = [trigger_query, interpreter_context, dba, &user_parameters, + owner = StringPointerToOptional(username)]() mutable { switch (trigger_query->action_) { case TriggerQuery::Action::CREATE_TRIGGER: EventCounter::IncrementCounter(EventCounter::TriggersCreated); - return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba); + return CreateTrigger(trigger_query, user_parameters, interpreter_context, dba, std::move(owner)); case TriggerQuery::Action::DROP_TRIGGER: return DropTrigger(trigger_query, interpreter_context); case TriggerQuery::Action::SHOW_TRIGGERS: @@ -1315,14 +1334,15 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, InterpreterContext *interpreter_context, DbAccessor *dba, - const std::map<std::string, storage::PropertyValue> &user_parameters) { + const std::map<std::string, storage::PropertyValue> &user_parameters, + const std::string *username) { if (in_explicit_transaction) { throw StreamQueryInMulticommandTxException(); } auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query); MG_ASSERT(stream_query); - auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba); + auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( @@ -1651,7 +1671,8 @@ void Interpreter::RollbackTransaction() { } Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, - const std::map<std::string, storage::PropertyValue> ¶ms) { + const std::map<std::string, storage::PropertyValue> ¶ms, + const std::string *username) { if (!in_explicit_transaction_) { query_executions_.clear(); } @@ -1748,10 +1769,10 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareFreeMemoryQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); } else if (utils::Downcast<TriggerQuery>(parsed_query.query)) { prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, - &*execution_db_accessor_, params); + &*execution_db_accessor_, params, username); } else if (utils::Downcast<StreamQuery>(parsed_query.query)) { prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, - &*execution_db_accessor_, params); + &*execution_db_accessor_, params, username); } else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) { prepared_query = PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this); @@ -1809,7 +1830,7 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret trigger_context.AdaptForAccessor(&db_accessor); try { trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec, - &interpreter_context->is_shutting_down, trigger_context); + &interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker); } catch (const utils::BasicException &exception) { spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); db_accessor.Abort(); @@ -1864,7 +1885,7 @@ void Interpreter::Commit() { AdvanceCommand(); try { trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec, - &interpreter_context_->is_shutting_down, *trigger_context); + &interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker); } catch (const utils::BasicException &e) { throw utils::BasicException( fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 700809fa3..ec07b74f9 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -2,6 +2,7 @@ #include <gflags/gflags.h> +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/context.hpp" #include "query/cypher_query_interpreter.hpp" @@ -166,6 +167,7 @@ struct InterpreterContext { std::atomic<bool> is_shutting_down{false}; AuthQueryHandler *auth{nullptr}; + query::AuthChecker *auth_checker{nullptr}; utils::SkipList<QueryCacheEntry> ast_cache; utils::SkipList<PlanCacheEntry> plan_cache; @@ -205,7 +207,8 @@ class Interpreter final { * * @throw query::QueryException */ - PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms); + PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms, + const std::string *username); /** * Execute the last prepared query and stream *all* of the results into the diff --git a/src/query/streams.cpp b/src/query/streams.cpp index ae0c0443b..8f019b302 100644 --- a/src/query/streams.cpp +++ b/src/query/streams.cpp @@ -106,6 +106,7 @@ const std::string kBatchIntervalKey{"batch_interval"}; const std::string kBatchSizeKey{"batch_size"}; const std::string kIsRunningKey{"is_running"}; const std::string kTransformationName{"transformation_name"}; +const std::string kOwner{"owner"}; void to_json(nlohmann::json &data, StreamStatus &&status) { auto &info = status.info; @@ -127,6 +128,12 @@ void to_json(nlohmann::json &data, StreamStatus &&status) { data[kIsRunningKey] = status.is_running; data[kTransformationName] = status.info.transformation_name; + + if (info.owner.has_value()) { + data[kOwner] = std::move(*info.owner); + } else { + data[kOwner] = nullptr; + } } void from_json(const nlohmann::json &data, StreamStatus &status) { @@ -135,16 +142,14 @@ void from_json(const nlohmann::json &data, StreamStatus &status) { data.at(kTopicsKey).get_to(info.topics); data.at(kConsumerGroupKey).get_to(info.consumer_group); - const auto batch_interval = data.at(kBatchIntervalKey); - if (!batch_interval.is_null()) { + if (const auto batch_interval = data.at(kBatchIntervalKey); !batch_interval.is_null()) { using BatchInterval = decltype(info.batch_interval)::value_type; info.batch_interval = BatchInterval{batch_interval.get<BatchInterval::rep>()}; } else { info.batch_interval = {}; } - const auto batch_size = data.at(kBatchSizeKey); - if (!batch_size.is_null()) { + if (const auto batch_size = data.at(kBatchSizeKey); !batch_size.is_null()) { info.batch_size = batch_size.get<decltype(info.batch_size)::value_type>(); } else { info.batch_size = {}; @@ -152,6 +157,12 @@ void from_json(const nlohmann::json &data, StreamStatus &status) { data.at(kIsRunningKey).get_to(status.is_running); data.at(kTransformationName).get_to(status.info.transformation_name); + + if (const auto &owner = data.at(kOwner); !owner.is_null()) { + info.owner = owner.get<decltype(info.owner)::value_type>(); + } else { + info.owner = {}; + } } Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers, @@ -200,7 +211,8 @@ void Streams::Create(const std::string &stream_name, StreamInfo info) { auto it = CreateConsumer(*locked_streams, stream_name, std::move(info)); try { - Persist(CreateStatus(stream_name, it->second.transformation_name, *it->second.consumer->ReadLock())); + Persist( + CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *it->second.consumer->ReadLock())); } catch (...) { locked_streams->erase(it); throw; @@ -233,7 +245,7 @@ void Streams::Start(const std::string &stream_name) { auto locked_consumer = it->second.consumer->Lock(); locked_consumer->Start(); - Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer)); + Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer)); } void Streams::Stop(const std::string &stream_name) { @@ -243,7 +255,7 @@ void Streams::Stop(const std::string &stream_name) { auto locked_consumer = it->second.consumer->Lock(); locked_consumer->Stop(); - Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer)); + Persist(CreateStatus(stream_name, it->second.transformation_name, it->second.owner, *locked_consumer)); } void Streams::StartAll() { @@ -251,7 +263,7 @@ void Streams::StartAll() { auto locked_consumer = stream_data.consumer->Lock(); if (!locked_consumer->IsRunning()) { locked_consumer->Start(); - Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer)); + Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer)); } } } @@ -261,7 +273,7 @@ void Streams::StopAll() { auto locked_consumer = stream_data.consumer->Lock(); if (locked_consumer->IsRunning()) { locked_consumer->Stop(); - Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer)); + Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_consumer)); } } } @@ -270,8 +282,8 @@ std::vector<StreamStatus> Streams::GetStreamInfo() const { std::vector<StreamStatus> result; { for (auto locked_streams = streams_.ReadLock(); const auto &[stream_name, stream_data] : *locked_streams) { - result.emplace_back( - CreateStatus(stream_name, stream_data.transformation_name, *stream_data.consumer->ReadLock())); + result.emplace_back(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, + *stream_data.consumer->ReadLock())); } } return result; @@ -314,6 +326,7 @@ TransformationResult Streams::Check(const std::string &stream_name, std::optiona } StreamStatus Streams::CreateStatus(const std::string &name, const std::string &transformation_name, + const std::optional<std::string> &owner, const integrations::kafka::Consumer &consumer) { const auto &info = consumer.Info(); return StreamStatus{name, @@ -323,6 +336,7 @@ StreamStatus Streams::CreateStatus(const std::string &name, const std::string &t info.batch_interval, info.batch_size, transformation_name, + owner, }, consumer.IsRunning()}; } @@ -336,7 +350,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std auto *memory_resource = utils::NewDeleteResource(); auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name, - transformation_name = stream_info.transformation_name, + transformation_name = stream_info.transformation_name, owner = stream_info.owner, interpreter = std::make_shared<Interpreter>(interpreter_context_), result = mgp_result{nullptr, memory_resource}]( const std::vector<integrations::kafka::Message> &messages) mutable { @@ -347,6 +361,10 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std DiscardValueResultStream stream; spdlog::trace("Start transaction in stream '{}'", stream_name); + utils::OnScopeExit cleanup{[&interpreter, &result]() { + result.rows.clear(); + interpreter->Abort(); + }}; interpreter->BeginTransaction(); for (auto &row : result.rows) { @@ -358,7 +376,13 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std std::string query{query_value.ValueString()}; spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); auto prepare_result = - interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap()); + interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr); + if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) { + throw StreamsException{ + "Couldn't execute query '{}' for stream '{}' becuase the owner is not authorized to execute the " + "query!", + query, stream_name}; + } interpreter->PullAll(&stream); } @@ -376,7 +400,7 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std }; auto insert_result = map.insert_or_assign( - stream_name, StreamData{std::move(stream_info.transformation_name), + stream_name, StreamData{std::move(stream_info.transformation_name), std::move(stream_info.owner), std::make_unique<SynchronizedConsumer>(bootstrap_servers_, std::move(consumer_info), std::move(consumer_function))}); MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name); diff --git a/src/query/streams.hpp b/src/query/streams.hpp index de155dba1..bec2172f6 100644 --- a/src/query/streams.hpp +++ b/src/query/streams.hpp @@ -28,6 +28,7 @@ struct StreamInfo { std::optional<std::chrono::milliseconds> batch_interval; std::optional<int64_t> batch_size; std::string transformation_name; + std::optional<std::string> owner; }; struct StreamStatus { @@ -40,6 +41,7 @@ using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer, struct StreamData { std::string transformation_name; + std::optional<std::string> owner; std::unique_ptr<SynchronizedConsumer> consumer; }; @@ -131,6 +133,7 @@ class Streams final { using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>; static StreamStatus CreateStatus(const std::string &name, const std::string &transformation_name, + const std::optional<std::string> &owner, const integrations::kafka::Consumer &consumer); StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, StreamInfo stream_info); diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index edd4a08e9..14b92350e 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -142,51 +142,55 @@ std::vector<std::pair<Identifier, TriggerIdentifierTag>> GetPredefinedIdentifier Trigger::Trigger(std::string name, const std::string &query, const std::map<std::string, storage::PropertyValue> &user_parameters, const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache, - DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) + DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::AuthChecker *auth_checker) : name_{std::move(name)}, parsed_statements_{ParseQuery(query, user_parameters, query_cache, antlr_lock, query_config)}, - event_type_{event_type} { + event_type_{event_type}, + owner_{std::move(owner)} { // We check immediately if the query is valid by trying to create a plan. - GetPlan(db_accessor); + GetPlan(db_accessor, auth_checker); } Trigger::TriggerPlan::TriggerPlan(std::unique_ptr<LogicalPlan> logical_plan, std::vector<IdentifierInfo> identifiers) : cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {} -std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor) const { +std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor, + const query::AuthChecker *auth_checker) const { std::lock_guard plan_guard{plan_lock_}; - if (parsed_statements_.is_cacheable && trigger_plan_ && !trigger_plan_->cached_plan.IsExpired()) { - return trigger_plan_; + if (!parsed_statements_.is_cacheable || !trigger_plan_ || trigger_plan_->cached_plan.IsExpired()) { + auto identifiers = GetPredefinedIdentifiers(event_type_); + + AstStorage ast_storage; + ast_storage.properties_ = parsed_statements_.ast_storage.properties_; + ast_storage.labels_ = parsed_statements_.ast_storage.labels_; + ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_; + + std::vector<Identifier *> predefined_identifiers; + predefined_identifiers.reserve(identifiers.size()); + std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers), + [](auto &identifier) { return &identifier.first; }); + + auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query), + parsed_statements_.parameters, db_accessor, predefined_identifiers); + + trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers)); + } + if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges)) { + throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_); } - - auto identifiers = GetPredefinedIdentifiers(event_type_); - - AstStorage ast_storage; - ast_storage.properties_ = parsed_statements_.ast_storage.properties_; - ast_storage.labels_ = parsed_statements_.ast_storage.labels_; - ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_; - - std::vector<Identifier *> predefined_identifiers; - predefined_identifiers.reserve(identifiers.size()); - std::transform(identifiers.begin(), identifiers.end(), std::back_inserter(predefined_identifiers), - [](auto &identifier) { return &identifier.first; }); - - auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query), - parsed_statements_.parameters, db_accessor, predefined_identifiers); - - trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers)); return trigger_plan_; } void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, const double max_execution_time_sec, std::atomic<bool> *is_shutting_down, - const TriggerContext &context) const { + const TriggerContext &context, const AuthChecker *auth_checker) const { if (!context.ShouldEventTrigger(event_type_)) { return; } spdlog::debug("Executing trigger '{}'", name_); - auto trigger_plan = GetPlan(dba); + auto trigger_plan = GetPlan(dba, auth_checker); MG_ASSERT(trigger_plan, "Invalid trigger plan received"); auto &[plan, identifiers] = *trigger_plan; @@ -238,13 +242,14 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution namespace { // When the format of the persisted trigger is changed, increase this version -constexpr uint64_t kVersion{1}; +constexpr uint64_t kVersion{2}; } // namespace TriggerStore::TriggerStore(std::filesystem::path directory) : storage_{std::move(directory)} {} void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, - utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) { + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + const query::AuthChecker *auth_checker) { MG_ASSERT(before_commit_triggers_.size() == 0 && after_commit_triggers_.size() == 0, "Cannot restore trigger when some triggers already exist!"); spdlog::info("Loading triggers..."); @@ -292,10 +297,19 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache } const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]); + const auto owner_json = json_trigger_data["owner"]; + std::optional<std::string> owner{}; + if (owner_json.is_string()) { + owner.emplace(owner_json.get<std::string>()); + } else if (!owner_json.is_null()) { + spdlog::warn(invalid_state_message); + continue; + } + std::optional<Trigger> trigger; try { trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, antlr_lock, - query_config); + query_config, std::move(owner), auth_checker); } catch (const utils::BasicException &e) { spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what()); continue; @@ -309,11 +323,12 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache } } -void TriggerStore::AddTrigger(const std::string &name, const std::string &query, +void TriggerStore::AddTrigger(std::string name, const std::string &query, const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, - utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config) { + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::AuthChecker *auth_checker) { std::unique_lock store_guard{store_lock_}; if (storage_.Get(name)) { throw utils::BasicException("Trigger with the same name already exists."); @@ -321,7 +336,8 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query, std::optional<Trigger> trigger; try { - trigger.emplace(name, query, user_parameters, event_type, query_cache, db_accessor, antlr_lock, query_config); + trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, antlr_lock, + query_config, std::move(owner), auth_checker); } catch (const utils::BasicException &e) { const auto identifiers = GetPredefinedIdentifiers(event_type); std::stringstream identifier_names_stream; @@ -342,7 +358,13 @@ void TriggerStore::AddTrigger(const std::string &name, const std::string &query, data["event_type"] = event_type; data["phase"] = phase; data["version"] = kVersion; - storage_.Put(name, data.dump()); + + if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) { + data["owner"] = *owner_from_trigger; + } else { + data["owner"] = nullptr; + } + storage_.Put(trigger->Name(), data.dump()); store_guard.unlock(); auto triggers_acc = @@ -384,7 +406,7 @@ std::vector<TriggerStore::TriggerInfo> TriggerStore::GetTriggerInfo() const { const auto add_info = [&](const utils::SkipList<Trigger> &trigger_list, const TriggerPhase phase) { for (const auto &trigger : trigger_list.access()) { - info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase}); + info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()}); } }; diff --git a/src/query/trigger.hpp b/src/query/trigger.hpp index 49f505d67..0261454d2 100644 --- a/src/query/trigger.hpp +++ b/src/query/trigger.hpp @@ -9,6 +9,7 @@ #include <vector> #include "kvstore/kvstore.hpp" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/cypher_query_interpreter.hpp" #include "query/db_accessor.hpp" @@ -23,10 +24,12 @@ struct Trigger { explicit Trigger(std::string name, const std::string &query, const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock, - const InterpreterConfig::Query &query_config); + const InterpreterConfig::Query &query_config, std::optional<std::string> owner, + const query::AuthChecker *auth_checker); void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, - std::atomic<bool> *is_shutting_down, const TriggerContext &context) const; + std::atomic<bool> *is_shutting_down, const TriggerContext &context, + const AuthChecker *auth_checker) const; bool operator==(const Trigger &other) const { return name_ == other.name_; } // NOLINTNEXTLINE (modernize-use-nullptr) @@ -37,6 +40,7 @@ struct Trigger { const auto &Name() const noexcept { return name_; } const auto &OriginalStatement() const noexcept { return parsed_statements_.query_string; } + const auto &Owner() const noexcept { return owner_; } auto EventType() const noexcept { return event_type_; } private: @@ -48,7 +52,7 @@ struct Trigger { CachedPlan cached_plan; std::vector<IdentifierInfo> identifiers; }; - std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor) const; + std::shared_ptr<TriggerPlan> GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const; std::string name_; ParsedQuery parsed_statements_; @@ -57,6 +61,7 @@ struct Trigger { mutable utils::SpinLock plan_lock_; mutable std::shared_ptr<TriggerPlan> trigger_plan_; + std::optional<std::string> owner_; }; enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT }; @@ -65,12 +70,14 @@ struct TriggerStore { explicit TriggerStore(std::filesystem::path directory); void RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, - utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config); + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + const query::AuthChecker *auth_checker); - void AddTrigger(const std::string &name, const std::string &query, + void AddTrigger(std::string name, const std::string &query, const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, - utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config); + utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config, + std::optional<std::string> owner, const query::AuthChecker *auth_checker); void DropTrigger(const std::string &name); @@ -79,6 +86,7 @@ struct TriggerStore { std::string statement; TriggerEventType event_type; TriggerPhase phase; + std::optional<std::string> owner; }; std::vector<TriggerInfo> GetTriggerInfo() const; diff --git a/tests/benchmark/expansion.cpp b/tests/benchmark/expansion.cpp index 7cf7d3246..5b9565b8a 100644 --- a/tests/benchmark/expansion.cpp +++ b/tests/benchmark/expansion.cpp @@ -54,7 +54,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Match)(benchmark::State &state) { while (state.KeepRunning()) { ResultStreamFaker results(&*db); - interpreter->Prepare(query, {}); + interpreter->Prepare(query, {}, nullptr); interpreter->PullAll(&results); } } @@ -69,7 +69,7 @@ BENCHMARK_DEFINE_F(ExpansionBenchFixture, Expand)(benchmark::State &state) { while (state.KeepRunning()) { ResultStreamFaker results(&*db); - interpreter->Prepare(query, {}); + interpreter->Prepare(query, {}, nullptr); interpreter->PullAll(&results); } } diff --git a/tests/e2e/streams/CMakeLists.txt b/tests/e2e/streams/CMakeLists.txt index d142ae61a..44d5506f4 100644 --- a/tests/e2e/streams/CMakeLists.txt +++ b/tests/e2e/streams/CMakeLists.txt @@ -6,6 +6,9 @@ add_custom_target(memgraph__e2e__streams__${FILE_NAME} ALL DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE_NAME}) endfunction() +copy_streams_e2e_python_files(common.py) +copy_streams_e2e_python_files(conftest.py) copy_streams_e2e_python_files(streams_tests.py) +copy_streams_e2e_python_files(streams_owner_tests.py) copy_streams_e2e_python_files(streams_test_runner.sh) add_subdirectory(transformations) diff --git a/tests/e2e/streams/common.py b/tests/e2e/streams/common.py new file mode 100644 index 000000000..4245781b4 --- /dev/null +++ b/tests/e2e/streams/common.py @@ -0,0 +1,110 @@ +import mgclient +import time + +# These are the indices of the different values in the result of SHOW STREAM +# query +NAME = 0 +TOPICS = 1 +CONSUMER_GROUP = 2 +BATCH_INTERVAL = 3 +BATCH_SIZE = 4 +TRANSFORM = 5 +OWNER = 6 +IS_RUNNING = 7 + + +def execute_and_fetch_all(cursor, query): + cursor.execute(query) + return cursor.fetchall() + + +def connect(**kwargs): + connection = mgclient.connect(host="localhost", port=7687, **kwargs) + connection.autocommit = True + return connection + + +def timed_wait(fun): + start_time = time.time() + seconds = 10 + + while True: + current_time = time.time() + elapsed_time = current_time - start_time + + if elapsed_time > seconds: + return False + + if fun(): + return True + + time.sleep(0.1) + + +def check_one_result_row(cursor, query): + start_time = time.time() + seconds = 10 + + while True: + current_time = time.time() + elapsed_time = current_time - start_time + + if elapsed_time > seconds: + return False + + cursor.execute(query) + results = cursor.fetchall() + if len(results) < 1: + time.sleep(0.1) + continue + + return len(results) == 1 + + +def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes): + assert check_one_result_row(cursor, + "MATCH (n: MESSAGE {" + f"payload: '{payload_bytes.decode('utf-8')}'," + f"topic: '{topic}'" + "}) RETURN n") + + +def get_stream_info(cursor, stream_name): + stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS") + for stream_info in stream_infos: + if (stream_info[NAME] == stream_name): + return stream_info + + return None + + +def get_is_running(cursor, stream_name): + stream_info = get_stream_info(cursor, stream_name) + + assert stream_info + return stream_info[IS_RUNNING] + + +def start_stream(cursor, stream_name): + execute_and_fetch_all(cursor, f"START STREAM {stream_name}") + + assert get_is_running(cursor, stream_name) + + +def stop_stream(cursor, stream_name): + execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}") + + assert not get_is_running(cursor, stream_name) + + +def drop_stream(cursor, stream_name): + execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}") + + assert get_stream_info(cursor, stream_name) is None + + +def check_stream_info(cursor, stream_name, expected_stream_info): + stream_info = get_stream_info(cursor, stream_name) + assert len(stream_info) == len(expected_stream_info) + for info, expected_info in zip(stream_info, expected_stream_info): + assert info == expected_info diff --git a/tests/e2e/streams/conftest.py b/tests/e2e/streams/conftest.py new file mode 100644 index 000000000..b724d6801 --- /dev/null +++ b/tests/e2e/streams/conftest.py @@ -0,0 +1,45 @@ +import pytest +from kafka import KafkaProducer +from kafka.admin import KafkaAdminClient, NewTopic + +from common import execute_and_fetch_all, connect, NAME + +# To run these test locally a running Kafka sever is necessery. The test tries +# to connect on localhost:9092. + + +@pytest.fixture(autouse=True) +def connection(): + connection = connect() + yield connection + cursor = connection.cursor() + execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") + stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS") + for stream_info in stream_infos: + execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}") + users = execute_and_fetch_all(cursor, "SHOW USERS") + for username, in users: + execute_and_fetch_all(cursor, f"DROP USER {username}") + + +@pytest.fixture(scope="function") +def topics(): + admin_client = KafkaAdminClient( + bootstrap_servers="localhost:9092", client_id='test') + + topics = [] + topics_to_create = [] + for index in range(3): + topic = f"topic_{index}" + topics.append(topic) + topics_to_create.append(NewTopic(name=topic, + num_partitions=1, replication_factor=1)) + + admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000) + yield topics + admin_client.delete_topics(topics=topics, timeout_ms=5000) + + +@pytest.fixture(scope="function") +def producer(): + yield KafkaProducer(bootstrap_servers="localhost:9092") diff --git a/tests/e2e/streams/streams_owner_tests.py b/tests/e2e/streams/streams_owner_tests.py new file mode 100644 index 000000000..f48f7439c --- /dev/null +++ b/tests/e2e/streams/streams_owner_tests.py @@ -0,0 +1,151 @@ +import sys +import pytest +import time +import mgclient +import common + + +def get_cursor_with_user(username): + connection = common.connect(username=username, password="") + return connection.cursor() + + +def create_admin_user(cursor, admin_user): + common.execute_and_fetch_all(cursor, f"CREATE USER {admin_user}") + common.execute_and_fetch_all( + cursor, f"GRANT ALL PRIVILEGES TO {admin_user}") + + +def create_stream_user(cursor, stream_user): + common.execute_and_fetch_all(cursor, f"CREATE USER {stream_user}") + common.execute_and_fetch_all( + cursor, f"GRANT STREAM TO {stream_user}") + + +def test_ownerless_stream(producer, topics, connection): + assert len(topics) > 0 + userless_cursor = connection.cursor() + common.execute_and_fetch_all(userless_cursor, + "CREATE STREAM ownerless " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") + common.start_stream(userless_cursor, "ownerless") + time.sleep(1) + + admin_user = "admin_user" + create_admin_user(userless_cursor, admin_user) + + producer.send(topics[0], b"first message").get(timeout=60) + assert common.timed_wait( + lambda: not common.get_is_running(userless_cursor, "ownerless")) + + assert len(common.execute_and_fetch_all( + userless_cursor, "MATCH (n) RETURN n")) == 0 + + common.execute_and_fetch_all(userless_cursor, f"DROP USER {admin_user}") + common.start_stream(userless_cursor, "ownerless") + time.sleep(1) + + second_message = b"second message" + producer.send(topics[0], second_message).get(timeout=60) + common.check_vertex_exists_with_topic_and_payload( + userless_cursor, topics[0], second_message) + + assert len(common.execute_and_fetch_all( + userless_cursor, "MATCH (n) RETURN n")) == 1 + + +def test_owner_is_shown(topics, connection): + assert len(topics) > 0 + userless_cursor = connection.cursor() + + stream_user = "stream_user" + create_stream_user(userless_cursor, stream_user) + stream_cursor = get_cursor_with_user(stream_user) + + common.execute_and_fetch_all(stream_cursor, "CREATE STREAM test " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") + + common.check_stream_info(userless_cursor, "test", ("test", [ + topics[0]], "mg_consumer", None, None, + "transform.simple", stream_user, False)) + + +def test_insufficient_privileges(producer, topics, connection): + assert len(topics) > 0 + userless_cursor = connection.cursor() + + admin_user = "admin_user" + create_admin_user(userless_cursor, admin_user) + admin_cursor = get_cursor_with_user(admin_user) + + stream_user = "stream_user" + create_stream_user(userless_cursor, stream_user) + stream_cursor = get_cursor_with_user(stream_user) + + common.execute_and_fetch_all(stream_cursor, + "CREATE STREAM insufficient_test " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") + + # the stream is started by admin, but should check against the owner + # privileges + common.start_stream(admin_cursor, "insufficient_test") + time.sleep(1) + + producer.send(topics[0], b"first message").get(timeout=60) + assert common.timed_wait( + lambda: not common.get_is_running(userless_cursor, "insufficient_test")) + + assert len(common.execute_and_fetch_all( + userless_cursor, "MATCH (n) RETURN n")) == 0 + + common.execute_and_fetch_all( + admin_cursor, f"GRANT CREATE TO {stream_user}") + common.start_stream(userless_cursor, "insufficient_test") + time.sleep(1) + + second_message = b"second message" + producer.send(topics[0], second_message).get(timeout=60) + common.check_vertex_exists_with_topic_and_payload( + userless_cursor, topics[0], second_message) + + assert len(common.execute_and_fetch_all( + userless_cursor, "MATCH (n) RETURN n")) == 1 + + +def test_happy_case(producer, topics, connection): + assert len(topics) > 0 + userless_cursor = connection.cursor() + + admin_user = "admin_user" + create_admin_user(userless_cursor, admin_user) + admin_cursor = get_cursor_with_user(admin_user) + + stream_user = "stream_user" + create_stream_user(userless_cursor, stream_user) + stream_cursor = get_cursor_with_user(stream_user) + common.execute_and_fetch_all( + admin_cursor, f"GRANT CREATE TO {stream_user}") + + common.execute_and_fetch_all(stream_cursor, + "CREATE STREAM insufficient_test " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") + + common.start_stream(stream_cursor, "insufficient_test") + time.sleep(1) + + first_message = b"first message" + producer.send(topics[0], first_message).get(timeout=60) + + common.check_vertex_exists_with_topic_and_payload( + userless_cursor, topics[0], first_message) + + assert len(common.execute_and_fetch_all( + userless_cursor, "MATCH (n) RETURN n")) == 1 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/streams/streams_test_runner.sh b/tests/e2e/streams/streams_test_runner.sh index 9cc749250..9d644a2ab 100755 --- a/tests/e2e/streams/streams_test_runner.sh +++ b/tests/e2e/streams/streams_test_runner.sh @@ -3,4 +3,4 @@ # This workaround is necessary to run in the same virtualenv as the e2e runner.py DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -python3 "$DIR/streams_tests.py" +python3 "$DIR/$1" diff --git a/tests/e2e/streams/streams_tests.py b/tests/e2e/streams/streams_tests.py index a0d66b236..ef26cafdd 100755 --- a/tests/e2e/streams/streams_tests.py +++ b/tests/e2e/streams/streams_tests.py @@ -1,28 +1,11 @@ #!/usr/bin/python3 -# To run these test locally a running Kafka sever is necessery. The test tries -# to connect on localhost:9092. - -# All tests are implemented in this file, because using the same test fixtures -# in multiple files is not possible in a straightforward way - import sys import pytest import mgclient import time from multiprocessing import Process, Value -from kafka import KafkaProducer -from kafka.admin import KafkaAdminClient, NewTopic - -# These are the indices of the different values in the result of SHOW STREAM -# query -NAME = 0 -TOPICS = 1 -CONSUMER_GROUP = 2 -BATCH_INTERVAL = 3 -BATCH_SIZE = 4 -TRANSFORM = 5 -IS_RUNNING = 6 +import common # These are the indices of the query and parameters in the result of CHECK # STREAM query @@ -35,155 +18,22 @@ TRANSFORMATIONS_TO_CHECK = [ SIMPLE_MSG = b'message' -def execute_and_fetch_all(cursor, query): - cursor.execute(query) - return cursor.fetchall() - - -def connect(): - connection = mgclient.connect(host="localhost", port=7687) - connection.autocommit = True - return connection - - -@pytest.fixture(autouse=True) -def connection(): - connection = connect() - yield connection - cursor = connection.cursor() - execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n") - stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS") - for stream_info in stream_infos: - execute_and_fetch_all(cursor, f"DROP STREAM {stream_info[NAME]}") - - -@pytest.fixture(scope="function") -def topics(): - admin_client = KafkaAdminClient( - bootstrap_servers="localhost:9092", client_id='test') - - topics = [] - topics_to_create = [] - for index in range(3): - topic = f"topic_{index}" - topics.append(topic) - topics_to_create.append(NewTopic(name=topic, - num_partitions=1, replication_factor=1)) - - admin_client.create_topics(new_topics=topics_to_create, timeout_ms=5000) - yield topics - admin_client.delete_topics(topics=topics, timeout_ms=5000) - - -@pytest.fixture(scope="function") -def producer(): - yield KafkaProducer(bootstrap_servers="localhost:9092") - - -def timed_wait(fun): - start_time = time.time() - seconds = 10 - - while True: - current_time = time.time() - elapsed_time = current_time - start_time - - if elapsed_time > seconds: - return False - - if fun(): - return True - - -def check_one_result_row(cursor, query): - start_time = time.time() - seconds = 10 - - while True: - current_time = time.time() - elapsed_time = current_time - start_time - - if elapsed_time > seconds: - return False - - cursor.execute(query) - results = cursor.fetchall() - if len(results) < 1: - time.sleep(0.1) - continue - - return len(results) == 1 - - -def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes): - assert check_one_result_row(cursor, - "MATCH (n: MESSAGE {" - f"payload: '{payload_bytes.decode('utf-8')}'," - f"topic: '{topic}'" - "}) RETURN n") - - -def get_stream_info(cursor, stream_name): - stream_infos = execute_and_fetch_all(cursor, "SHOW STREAMS") - for stream_info in stream_infos: - if (stream_info[NAME] == stream_name): - return stream_info - - return None - - -def get_is_running(cursor, stream_name): - stream_info = get_stream_info(cursor, stream_name) - - assert stream_info - return stream_info[IS_RUNNING] - - -def start_stream(cursor, stream_name): - execute_and_fetch_all(cursor, f"START STREAM {stream_name}") - - assert get_is_running(cursor, stream_name) - - -def stop_stream(cursor, stream_name): - execute_and_fetch_all(cursor, f"STOP STREAM {stream_name}") - - assert not get_is_running(cursor, stream_name) - - -def drop_stream(cursor, stream_name): - execute_and_fetch_all(cursor, f"DROP STREAM {stream_name}") - - assert get_stream_info(cursor, stream_name) is None - - -def check_stream_info(cursor, stream_name, expected_stream_info): - stream_info = get_stream_info(cursor, stream_name) - assert len(stream_info) == len(expected_stream_info) - for info, expected_info in zip(stream_info, expected_stream_info): - assert info == expected_info - -############################################## -# Tests -############################################## - - @pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK) def test_simple(producer, topics, connection, transformation): assert len(topics) > 0 cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM test " - f"TOPICS {','.join(topics)} " - f"TRANSFORM {transformation}") - start_stream(cursor, "test") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test " + f"TOPICS {','.join(topics)} " + f"TRANSFORM {transformation}") + common.start_stream(cursor, "test") time.sleep(5) for topic in topics: producer.send(topic, SIMPLE_MSG).get(timeout=60) for topic in topics: - check_vertex_exists_with_topic_and_payload( + common.check_vertex_exists_with_topic_and_payload( cursor, topic, SIMPLE_MSG) @@ -196,13 +46,13 @@ def test_separate_consumers(producer, topics, connection, transformation): for topic in topics: stream_name = "stream_" + topic stream_names.append(stream_name) - execute_and_fetch_all(cursor, - f"CREATE STREAM {stream_name} " - f"TOPICS {topic} " - f"TRANSFORM {transformation}") + common.execute_and_fetch_all(cursor, + f"CREATE STREAM {stream_name} " + f"TOPICS {topic} " + f"TRANSFORM {transformation}") for stream_name in stream_names: - start_stream(cursor, stream_name) + common.start_stream(cursor, stream_name) time.sleep(5) @@ -210,7 +60,7 @@ def test_separate_consumers(producer, topics, connection, transformation): producer.send(topic, SIMPLE_MSG).get(timeout=60) for topic in topics: - check_vertex_exists_with_topic_and_payload( + common.check_vertex_exists_with_topic_and_payload( cursor, topic, SIMPLE_MSG) @@ -223,41 +73,42 @@ def test_start_from_last_committed_offset(producer, topics, connection): # restarting Memgraph during a single workload cannot be done currently. assert len(topics) > 0 cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM test " - f"TOPICS {topics[0]} " - "TRANSFORM transform.simple") - start_stream(cursor, "test") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test " + f"TOPICS {topics[0]} " + "TRANSFORM transform.simple") + common.start_stream(cursor, "test") time.sleep(1) producer.send(topics[0], SIMPLE_MSG).get(timeout=60) - check_vertex_exists_with_topic_and_payload( + common.check_vertex_exists_with_topic_and_payload( cursor, topics[0], SIMPLE_MSG) - stop_stream(cursor, "test") - drop_stream(cursor, "test") + common.stop_stream(cursor, "test") + common.drop_stream(cursor, "test") messages = [b"second message", b"third message"] for message in messages: producer.send(topics[0], message).get(timeout=60) for message in messages: - vertices_with_msg = execute_and_fetch_all(cursor, - "MATCH (n: MESSAGE {" - f"payload: '{message.decode('utf-8')}'" - "}) RETURN n") + vertices_with_msg = common.execute_and_fetch_all( + cursor, + "MATCH (n: MESSAGE {" + f"payload: '{message.decode('utf-8')}'" + "}) RETURN n") assert len(vertices_with_msg) == 0 - execute_and_fetch_all(cursor, - "CREATE STREAM test " - f"TOPICS {topics[0]} " - "TRANSFORM transform.simple") - start_stream(cursor, "test") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test " + f"TOPICS {topics[0]} " + "TRANSFORM transform.simple") + common.start_stream(cursor, "test") for message in messages: - check_vertex_exists_with_topic_and_payload( + common.check_vertex_exists_with_topic_and_payload( cursor, topics[0], message) @@ -265,16 +116,16 @@ def test_start_from_last_committed_offset(producer, topics, connection): def test_check_stream(producer, topics, connection, transformation): assert len(topics) > 0 cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM test " - f"TOPICS {topics[0]} " - f"TRANSFORM {transformation} " - "BATCH_SIZE 1") - start_stream(cursor, "test") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test " + f"TOPICS {topics[0]} " + f"TRANSFORM {transformation} " + "BATCH_SIZE 1") + common.start_stream(cursor, "test") time.sleep(1) producer.send(topics[0], SIMPLE_MSG).get(timeout=60) - stop_stream(cursor, "test") + common.stop_stream(cursor, "test") messages = [b"first message", b"second message", b"third message"] for message in messages: @@ -283,7 +134,7 @@ def test_check_stream(producer, topics, connection, transformation): def check_check_stream(batch_limit): assert transformation == "transform.simple" \ or transformation == "transform.with_parameters" - test_results = execute_and_fetch_all( + test_results = common.execute_and_fetch_all( cursor, f"CHECK STREAM test BATCH_LIMIT {batch_limit}") assert len(test_results) == batch_limit @@ -308,42 +159,47 @@ def test_check_stream(producer, topics, connection, transformation): check_check_stream(1) check_check_stream(2) check_check_stream(3) - start_stream(cursor, "test") + common.start_stream(cursor, "test") for message in messages: - check_vertex_exists_with_topic_and_payload( + common.check_vertex_exists_with_topic_and_payload( cursor, topics[0], message) def test_show_streams(producer, topics, connection): assert len(topics) > 1 cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM default_values " - f"TOPICS {topics[0]} " - f"TRANSFORM transform.simple") + common.execute_and_fetch_all(cursor, + "CREATE STREAM default_values " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") consumer_group = "my_special_consumer_group" batch_interval = 42 batch_size = 3 - execute_and_fetch_all(cursor, - "CREATE STREAM complex_values " - f"TOPICS {','.join(topics)} " - f"TRANSFORM transform.with_parameters " - f"CONSUMER_GROUP {consumer_group} " - f"BATCH_INTERVAL {batch_interval} " - f"BATCH_SIZE {batch_size} ") + common.execute_and_fetch_all(cursor, + "CREATE STREAM complex_values " + f"TOPICS {','.join(topics)} " + f"TRANSFORM transform.with_parameters " + f"CONSUMER_GROUP {consumer_group} " + f"BATCH_INTERVAL {batch_interval} " + f"BATCH_SIZE {batch_size} ") - assert len(execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2 + assert len(common.execute_and_fetch_all(cursor, "SHOW STREAMS")) == 2 - check_stream_info(cursor, "default_values", ("default_values", [ - topics[0]], "mg_consumer", None, None, - "transform.simple", False)) + common.check_stream_info(cursor, "default_values", ("default_values", [ + topics[0]], "mg_consumer", None, None, + "transform.simple", None, False)) - check_stream_info(cursor, "complex_values", ("complex_values", topics, - consumer_group, batch_interval, batch_size, - "transform.with_parameters", - False)) + common.check_stream_info(cursor, "complex_values", ( + "complex_values", + topics, + consumer_group, + batch_interval, + batch_size, + "transform.with_parameters", + None, + False)) @pytest.mark.parametrize("operation", ["START", "STOP"]) @@ -360,10 +216,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation): assert len(topics) > 1 assert operation == "START" or operation == "STOP" cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM test_stream " - f"TOPICS {topics[0]} " - f"TRANSFORM transform.simple") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test_stream " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") check_counter = Value('i', 0) check_result_len = Value('i', 0) @@ -377,10 +233,11 @@ def test_start_and_stop_during_check(producer, topics, connection, operation): def call_check(counter, result_len): # This process will call the CHECK query and increment the counter # based on its progress and expected behavior - connection = connect() + connection = common.connect() cursor = connection.cursor() counter.value = CHECK_BEFORE_EXECUTE - result = execute_and_fetch_all(cursor, "CHECK STREAM test_stream") + result = common.execute_and_fetch_all( + cursor, "CHECK STREAM test_stream") result_len.value = len(result) counter.value = CHECK_AFTER_FETCHALL if len(result) > 0 and "payload: 'message'" in result[0][QUERY]: @@ -397,11 +254,12 @@ def test_start_and_stop_during_check(producer, topics, connection, operation): def call_operation(counter): # This porcess will call the query with the specified operation and # increment the counter based on its progress and expected behavior - connection = connect() + connection = common.connect() cursor = connection.cursor() counter.value = OP_BEFORE_EXECUTE try: - execute_and_fetch_all(cursor, f"{operation} STREAM test_stream") + common.execute_and_fetch_all( + cursor, f"{operation} STREAM test_stream") counter.value = OP_AFTER_FETCHALL except mgclient.DatabaseError as e: if "Kafka consumer test_stream is already stopped" in str(e): @@ -421,15 +279,19 @@ def test_start_and_stop_during_check(producer, topics, connection, operation): time.sleep(0.5) - assert timed_wait(lambda: check_counter.value == CHECK_BEFORE_EXECUTE) - assert timed_wait(lambda: get_is_running(cursor, "test_stream")) + assert common.timed_wait( + lambda: check_counter.value == CHECK_BEFORE_EXECUTE) + assert common.timed_wait( + lambda: common.get_is_running(cursor, "test_stream")) assert check_counter.value == CHECK_BEFORE_EXECUTE, "SHOW STREAMS " \ "was blocked until the end of CHECK STREAM" operation_proc.start() - assert timed_wait(lambda: operation_counter.value == OP_BEFORE_EXECUTE) + assert common.timed_wait( + lambda: operation_counter.value == OP_BEFORE_EXECUTE) producer.send(topics[0], SIMPLE_MSG).get(timeout=60) - assert timed_wait(lambda: check_counter.value > CHECK_AFTER_FETCHALL) + assert common.timed_wait( + lambda: check_counter.value > CHECK_AFTER_FETCHALL) assert check_counter.value == CHECK_CORRECT_RESULT assert check_result_len.value == 1 check_stream_proc.join() @@ -437,10 +299,10 @@ def test_start_and_stop_during_check(producer, topics, connection, operation): operation_proc.join() if operation == "START": assert operation_counter.value == OP_AFTER_FETCHALL - assert get_is_running(cursor, "test_stream") + assert common.get_is_running(cursor, "test_stream") else: assert operation_counter.value == OP_ALREADY_STOPPED_EXCEPTION - assert not get_is_running(cursor, "test_stream") + assert not common.get_is_running(cursor, "test_stream") finally: # to make sure CHECK STREAM finishes @@ -455,42 +317,64 @@ def test_check_already_started_stream(topics, connection): assert len(topics) > 0 cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM started_stream " - f"TOPICS {topics[0]} " - f"TRANSFORM transform.simple") - start_stream(cursor, "started_stream") + common.execute_and_fetch_all(cursor, + "CREATE STREAM started_stream " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") + common.start_stream(cursor, "started_stream") with pytest.raises(mgclient.DatabaseError): - execute_and_fetch_all(cursor, "CHECK STREAM started_stream") + common.execute_and_fetch_all(cursor, "CHECK STREAM started_stream") def test_start_checked_stream_after_timeout(topics, connection): cursor = connection.cursor() - execute_and_fetch_all(cursor, - "CREATE STREAM test_stream " - f"TOPICS {topics[0]} " - f"TRANSFORM transform.simple") + common.execute_and_fetch_all(cursor, + "CREATE STREAM test_stream " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.simple") timeout_ms = 2000 def call_check(): - execute_and_fetch_all( - connect().cursor(), + common.execute_and_fetch_all( + common.connect().cursor(), f"CHECK STREAM test_stream TIMEOUT {timeout_ms}") check_stream_proc = Process(target=call_check, daemon=True) start = time.time() check_stream_proc.start() - assert timed_wait(lambda: get_is_running(cursor, "test_stream")) - start_stream(cursor, "test_stream") + assert common.timed_wait( + lambda: common.get_is_running(cursor, "test_stream")) + common.start_stream(cursor, "test_stream") end = time.time() assert (end - start) < 1.3 * \ timeout_ms, "The START STREAM was blocked too long" - assert get_is_running(cursor, "test_stream") - stop_stream(cursor, "test_stream") + assert common.get_is_running(cursor, "test_stream") + common.stop_stream(cursor, "test_stream") + + +def test_restart_after_error(producer, topics, connection): + cursor = connection.cursor() + common.execute_and_fetch_all(cursor, + "CREATE STREAM test_stream " + f"TOPICS {topics[0]} " + f"TRANSFORM transform.query") + + common.start_stream(cursor, "test_stream") + time.sleep(1) + + producer.send(topics[0], SIMPLE_MSG).get(timeout=60) + assert common.timed_wait( + lambda: not common.get_is_running(cursor, "test_stream")) + + common.start_stream(cursor, "test_stream") + time.sleep(1) + producer.send(topics[0], b'CREATE (n:VERTEX { id : 42 })') + assert common.check_one_result_row( + cursor, "MATCH (n:VERTEX { id : 42 }) RETURN n") if __name__ == "__main__": diff --git a/tests/e2e/streams/transformations/transform.py b/tests/e2e/streams/transformations/transform.py index b15ff13a1..249b72e61 100644 --- a/tests/e2e/streams/transformations/transform.py +++ b/tests/e2e/streams/transformations/transform.py @@ -35,3 +35,17 @@ def with_parameters(context: mgp.TransCtx, "topic": message.topic_name()})) return result_queries + + +@mgp.transformation +def query(messages: mgp.Messages + ) -> mgp.Record(query=str, parameters=mgp.Nullable[mgp.Map]): + result_queries = [] + + for i in range(0, messages.total_messages()): + message = messages.message_at(i) + payload_as_str = message.payload().decode("utf-8") + result_queries.append(mgp.Record( + query=payload_as_str, parameters=None)) + + return result_queries diff --git a/tests/e2e/streams/workloads.yaml b/tests/e2e/streams/workloads.yaml index 624707f13..f9f52f9f3 100644 --- a/tests/e2e/streams/workloads.yaml +++ b/tests/e2e/streams/workloads.yaml @@ -10,5 +10,10 @@ workloads: - name: "Streams start, stop and show" binary: "tests/e2e/streams/streams_test_runner.sh" proc: "tests/e2e/streams/transformations/" - args: [] + args: ["streams_tests.py"] + <<: *template_cluster + - name: "Streams with users" + binary: "tests/e2e/streams/streams_test_runner.sh" + proc: "tests/e2e/streams/transformations/" + args: ["streams_owner_tests.py"] <<: *template_cluster diff --git a/tests/e2e/triggers/CMakeLists.txt b/tests/e2e/triggers/CMakeLists.txt index d1587c56b..090f006f5 100644 --- a/tests/e2e/triggers/CMakeLists.txt +++ b/tests/e2e/triggers/CMakeLists.txt @@ -9,3 +9,6 @@ target_link_libraries(memgraph__e2e__triggers__on_update memgraph__e2e__triggers add_executable(memgraph__e2e__triggers__on_delete on_delete_triggers.cpp) target_link_libraries(memgraph__e2e__triggers__on_delete memgraph__e2e__triggers_common) + +add_executable(memgraph__e2e__triggers__privileges privilige_check.cpp) +target_link_libraries(memgraph__e2e__triggers__privileges memgraph__e2e__triggers_common) diff --git a/tests/e2e/triggers/common.cpp b/tests/e2e/triggers/common.cpp index 17f916217..955e848f4 100644 --- a/tests/e2e/triggers/common.cpp +++ b/tests/e2e/triggers/common.cpp @@ -1,7 +1,9 @@ #include "common.hpp" +#include <chrono> #include <cstdint> #include <optional> +#include <thread> #include <fmt/format.h> #include <gflags/gflags.h> @@ -10,13 +12,17 @@ DEFINE_uint64(bolt_port, 7687, "Bolt port"); -std::unique_ptr<mg::Client> Connect() { - auto client = - mg::Client::Connect({.host = "127.0.0.1", .port = static_cast<uint16_t>(FLAGS_bolt_port), .use_ssl = false}); +std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username) { + auto client = mg::Client::Connect({.host = "127.0.0.1", + .port = static_cast<uint16_t>(FLAGS_bolt_port), + .username = std::string{username}, + .use_ssl = false}); MG_ASSERT(client, "Failed to connect!"); return client; } +std::unique_ptr<mg::Client> Connect() { return ConnectWithUser(""); } + void CreateVertex(mg::Client &client, int vertex_id) { mg::Map parameters{ {"id", mg::Value{vertex_id}}, @@ -49,10 +55,12 @@ int GetNumberOfAllVertices(mg::Client &client) { } void WaitForNumberOfAllVertices(mg::Client &client, int number_of_vertices) { + using namespace std::chrono_literals; utils::Timer timer{}; while ((timer.Elapsed().count() <= 0.5) && GetNumberOfAllVertices(client) != number_of_vertices) { } CheckNumberOfAllVertices(client, number_of_vertices); + std::this_thread::sleep_for(100ms); } void CheckNumberOfAllVertices(mg::Client &client, int expected_number_of_vertices) { diff --git a/tests/e2e/triggers/common.hpp b/tests/e2e/triggers/common.hpp index 903dd8125..5196fa80a 100644 --- a/tests/e2e/triggers/common.hpp +++ b/tests/e2e/triggers/common.hpp @@ -11,6 +11,7 @@ constexpr std::string_view kVertexLabel{"VERTEX"}; constexpr std::string_view kEdgeLabel{"EDGE"}; std::unique_ptr<mg::Client> Connect(); +std::unique_ptr<mg::Client> ConnectWithUser(const std::string_view username); void CreateVertex(mg::Client &client, int vertex_id); void CreateEdge(mg::Client &client, int from_vertex, int to_vertex, int edge_id); diff --git a/tests/e2e/triggers/privilige_check.cpp b/tests/e2e/triggers/privilige_check.cpp new file mode 100644 index 000000000..ae83a5eb8 --- /dev/null +++ b/tests/e2e/triggers/privilige_check.cpp @@ -0,0 +1,161 @@ +#include <string> +#include <string_view> + +#include <gflags/gflags.h> +#include <spdlog/fmt/bundled/core.h> +#include <mgclient.hpp> +#include "common.hpp" +#include "utils/logging.hpp" + +constexpr std::string_view kTriggerPrefix{"CreatedVerticesTrigger"}; + +int main(int argc, char **argv) { + gflags::SetUsageMessage("Memgraph E2E Triggers privilege check"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + logging::RedirectToStderr(); + + constexpr int kVertexId{42}; + constexpr std::string_view kUserlessLabel{"USERLESS"}; + constexpr std::string_view kAdminUser{"ADMIN"}; + constexpr std::string_view kUserWithCreate{"USER_WITH_CREATE"}; + constexpr std::string_view kUserWithoutCreate{"USER_WITHOUT_CREATE"}; + + mg::Client::Init(); + + auto userless_client = Connect(); + + const auto get_number_of_triggers = [&userless_client] { + userless_client->Execute("SHOW TRIGGERS"); + auto result = userless_client->FetchAll(); + MG_ASSERT(result.has_value()); + return result->size(); + }; + + auto create_trigger = [&get_number_of_triggers](mg::Client &client, const std::string_view vertexLabel, + bool should_succeed = true) { + const auto number_of_triggers_before = get_number_of_triggers(); + client.Execute( + fmt::format("CREATE TRIGGER {}{} ON () CREATE " + "AFTER COMMIT " + "EXECUTE " + "UNWIND createdVertices as createdVertex " + "CREATE (n: {} {{ id: createdVertex.id }})", + kTriggerPrefix, vertexLabel, vertexLabel)); + client.DiscardAll(); + const auto number_of_triggers_after = get_number_of_triggers(); + if (should_succeed) { + MG_ASSERT(number_of_triggers_after == number_of_triggers_before + 1); + } else { + MG_ASSERT(number_of_triggers_after == number_of_triggers_before); + } + }; + + auto delete_vertices = [&userless_client] { + userless_client->Execute("MATCH (n) DETACH DELETE n;"); + userless_client->DiscardAll(); + CheckNumberOfAllVertices(*userless_client, 0); + }; + + auto create_user = [&userless_client](const std::string_view username) { + userless_client->Execute(fmt::format("CREATE USER {};", username)); + userless_client->DiscardAll(); + userless_client->Execute(fmt::format("GRANT TRIGGER TO {};", username)); + userless_client->DiscardAll(); + }; + + auto drop_user = [&userless_client](const std::string_view username) { + userless_client->Execute(fmt::format("DROP USER {};", username)); + userless_client->DiscardAll(); + }; + + auto drop_trigger_of_user = [&userless_client](const std::string_view username) { + userless_client->Execute(fmt::format("DROP TRIGGER {}{};", kTriggerPrefix, username)); + userless_client->DiscardAll(); + }; + + // Single trigger created without user, there is no existing users + create_trigger(*userless_client, kUserlessLabel); + CreateVertex(*userless_client, kVertexId); + + WaitForNumberOfAllVertices(*userless_client, 2); + CheckVertexExists(*userless_client, kVertexLabel, kVertexId); + CheckVertexExists(*userless_client, kUserlessLabel, kVertexId); + + delete_vertices(); + + // Single trigger created without user, there is an existing user + // The trigger fails because there is no owner + create_user(kAdminUser); + CreateVertex(*userless_client, kVertexId); + + CheckVertexExists(*userless_client, kVertexLabel, kVertexId); + CheckNumberOfAllVertices(*userless_client, 1); + + delete_vertices(); + + // Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there are + // existing users + // Only the trigger which owner has CREATE privilege will succeed + create_user(kUserWithCreate); + userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithCreate)); + userless_client->DiscardAll(); + create_user(kUserWithoutCreate); + auto client_with_create = ConnectWithUser(kUserWithCreate); + auto client_without_create = ConnectWithUser(kUserWithoutCreate); + create_trigger(*client_with_create, kUserWithCreate); + create_trigger(*client_without_create, kUserWithoutCreate, false); + + // Grant CREATE to be able to create the trigger than revoke it + userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate)); + userless_client->DiscardAll(); + create_trigger(*client_without_create, kUserWithoutCreate); + userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate)); + userless_client->DiscardAll(); + + CreateVertex(*userless_client, kVertexId); + + WaitForNumberOfAllVertices(*userless_client, 2); + CheckVertexExists(*userless_client, kVertexLabel, kVertexId); + CheckVertexExists(*userless_client, kUserWithCreate, kVertexId); + + delete_vertices(); + + // Three triggers: without an owner, an owner with CREATE privilege, an owner without CREATE privilege; there is no + // existing user + // All triggers will succeed, as there is no authorization is done when there are no users + drop_user(kAdminUser); + drop_user(kUserWithCreate); + drop_user(kUserWithoutCreate); + CreateVertex(*userless_client, kVertexId); + + WaitForNumberOfAllVertices(*userless_client, 4); + CheckVertexExists(*userless_client, kVertexLabel, kVertexId); + CheckVertexExists(*userless_client, kUserlessLabel, kVertexId); + CheckVertexExists(*userless_client, kUserWithCreate, kVertexId); + CheckVertexExists(*userless_client, kUserWithoutCreate, kVertexId); + + delete_vertices(); + drop_trigger_of_user(kUserlessLabel); + drop_trigger_of_user(kUserWithCreate); + drop_trigger_of_user(kUserWithoutCreate); + + // The BEFORE COMMIT trigger without proper privileges make the transaction fail + create_user(kUserWithoutCreate); + userless_client->Execute(fmt::format("GRANT CREATE TO {};", kUserWithoutCreate)); + userless_client->DiscardAll(); + client_without_create->Execute( + fmt::format("CREATE TRIGGER {}{} ON () CREATE " + "BEFORE COMMIT " + "EXECUTE " + "UNWIND createdVertices as createdVertex " + "CREATE (n: {} {{ id: createdVertex.id }})", + kTriggerPrefix, kUserWithoutCreate, kUserWithoutCreate)); + client_without_create->DiscardAll(); + userless_client->Execute(fmt::format("REVOKE CREATE FROM {};", kUserWithoutCreate)); + userless_client->DiscardAll(); + + CreateVertex(*userless_client, kVertexId); + CheckNumberOfAllVertices(*userless_client, 0); + + return 0; +} diff --git a/tests/e2e/triggers/workloads.yaml b/tests/e2e/triggers/workloads.yaml index 55d82d485..de3188a20 100644 --- a/tests/e2e/triggers/workloads.yaml +++ b/tests/e2e/triggers/workloads.yaml @@ -20,5 +20,9 @@ workloads: binary: "tests/e2e/triggers/memgraph__e2e__triggers__on_delete" args: ["--bolt-port", *bolt_port] <<: *template_cluster + - name: "Triggers privilege check" + binary: "tests/e2e/triggers/memgraph__e2e__triggers__privileges" + args: ["--bolt-port", *bolt_port] + <<: *template_cluster diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index 3aea249af..a181db737 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -22,7 +22,7 @@ int main(int argc, char *argv[]) { query::Interpreter interpreter{&interpreter_context}; ResultStreamFaker stream(&db); - auto [header, _, qid] = interpreter.Prepare(argv[1], {}); + auto [header, _, qid] = interpreter.Prepare(argv[1], {}, nullptr); stream.Header(header); auto summary = interpreter.PullAll(&stream); stream.Summary(summary); diff --git a/tests/unit/auth.cpp b/tests/unit/auth.cpp index 7fa2df4aa..201248bf4 100644 --- a/tests/unit/auth.cpp +++ b/tests/unit/auth.cpp @@ -165,14 +165,14 @@ TEST_F(AuthWithStorage, RoleManipulations) { { auto user1 = auth.GetUser("user1"); ASSERT_TRUE(user1); - auto role1 = user1->role(); - ASSERT_TRUE(role1); + const auto *role1 = user1->role(); + ASSERT_NE(role1, nullptr); ASSERT_EQ(role1->rolename(), "role1"); auto user2 = auth.GetUser("user2"); ASSERT_TRUE(user2); - auto role2 = user2->role(); - ASSERT_TRUE(role2); + const auto *role2 = user2->role(); + ASSERT_NE(role2, nullptr); ASSERT_EQ(role2->rolename(), "role2"); } @@ -181,13 +181,13 @@ TEST_F(AuthWithStorage, RoleManipulations) { { auto user1 = auth.GetUser("user1"); ASSERT_TRUE(user1); - auto role = user1->role(); - ASSERT_FALSE(role); + const auto *role = user1->role(); + ASSERT_EQ(role, nullptr); auto user2 = auth.GetUser("user2"); ASSERT_TRUE(user2); - auto role2 = user2->role(); - ASSERT_TRUE(role2); + const auto *role2 = user2->role(); + ASSERT_NE(role2, nullptr); ASSERT_EQ(role2->rolename(), "role2"); } @@ -199,13 +199,13 @@ TEST_F(AuthWithStorage, RoleManipulations) { { auto user1 = auth.GetUser("user1"); ASSERT_TRUE(user1); - auto role1 = user1->role(); - ASSERT_FALSE(role1); + const auto *role1 = user1->role(); + ASSERT_EQ(role1, nullptr); auto user2 = auth.GetUser("user2"); ASSERT_TRUE(user2); - auto role2 = user2->role(); - ASSERT_TRUE(role2); + const auto *role2 = user2->role(); + ASSERT_NE(role2, nullptr); ASSERT_EQ(role2->rolename(), "role2"); } @@ -245,8 +245,8 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) { { auto user = auth.GetUser("user"); ASSERT_TRUE(user); - auto role = user->role(); - ASSERT_TRUE(role); + const auto *role = user->role(); + ASSERT_NE(role, nullptr); ASSERT_EQ(role->rolename(), "role"); } @@ -260,7 +260,7 @@ TEST_F(AuthWithStorage, UserRoleLinkUnlink) { { auto user = auth.GetUser("user"); ASSERT_TRUE(user); - ASSERT_FALSE(user->role()); + ASSERT_EQ(user->role(), nullptr); } } @@ -620,8 +620,9 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { auto user = auth.GetUser("aLIce"); ASSERT_TRUE(user); ASSERT_EQ(user->username(), "alice"); - ASSERT_TRUE(user->role()); - ASSERT_EQ(user->role()->rolename(), "moderator"); + const auto *role = user->role(); + ASSERT_NE(role, nullptr); + ASSERT_EQ(role->rolename(), "moderator"); } // AllUsersForRole diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index 6080ecccb..5cc9df849 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -6,6 +6,7 @@ #include "glue/communication.hpp" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/exceptions.hpp" #include "query/interpreter.hpp" @@ -28,15 +29,17 @@ auto ToEdgeList(const communication::bolt::Value &v) { }; struct InterpreterFaker { - explicit InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config, - const std::filesystem::path &data_directory) + InterpreterFaker(storage::Storage *db, const query::InterpreterConfig config, + const std::filesystem::path &data_directory) : interpreter_context(db, config, data_directory, "not used bootstrap servers"), - interpreter(&interpreter_context) {} + interpreter(&interpreter_context) { + interpreter_context.auth_checker = &auth_checker; + } auto Prepare(const std::string &query, const std::map<std::string, storage::PropertyValue> ¶ms = {}) { ResultStreamFaker stream(interpreter_context.db); - const auto [header, _, qid] = interpreter.Prepare(query, params); + const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); stream.Header(header); return std::make_pair(std::move(stream), qid); } @@ -61,6 +64,7 @@ struct InterpreterFaker { return std::move(stream); } + query::AllowEverythingAuthChecker auth_checker; query::InterpreterContext interpreter_context; query::Interpreter interpreter; }; diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index 58f86491b..6f6176380 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -194,7 +194,7 @@ auto Execute(storage::Storage *db, const std::string &query) { query::Interpreter interpreter(&context); ResultStreamFaker stream(db); - auto [header, _, qid] = interpreter.Prepare(query, {}); + auto [header, _, qid] = interpreter.Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter.PullAll(&stream); stream.Summary(summary); @@ -711,7 +711,7 @@ class StatefulInterpreter { auto Execute(const std::string &query) { ResultStreamFaker stream(db_); - auto [header, _, qid] = interpreter_.Prepare(query, {}); + auto [header, _, qid] = interpreter_.Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter_.PullAll(&stream); stream.Summary(summary); diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index 4fb478d5f..b214696fd 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -42,7 +42,7 @@ class QueryExecution : public testing::Test { auto Execute(const std::string &query) { ResultStreamFaker stream(&*db_); - auto [header, _, qid] = interpreter_->Prepare(query, {}); + auto [header, _, qid] = interpreter_->Prepare(query, {}, nullptr); stream.Header(header); auto summary = interpreter_->PullAll(&stream); stream.Summary(summary); diff --git a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp index d73f04c26..4ce2a84ca 100644 --- a/tests/unit/query_streams.cpp +++ b/tests/unit/query_streams.cpp @@ -34,6 +34,7 @@ StreamInfo CreateDefaultStreamInfo() { .batch_interval = std::nullopt, .batch_size = std::nullopt, .transformation_name = "not used in the tests", + .owner = std::nullopt, }; } @@ -57,6 +58,8 @@ class StreamsTest : public ::testing::Test { // Though there is a Streams object in interpreter context, it makes more sense to use a separate object to test, // because that provides a way to recreate the streams object and also give better control over the arguments of the // Streams constructor. + // InterpreterContext::auth_checker_ is used in the Streams object, but only in the message processing part. Because + // these tests don't send any messages, the auth_checker_ pointer can be left as nullptr. query::InterpreterContext interpreter_context_{&db_, query::InterpreterConfig{}, data_directory_, "dont care bootstrap servers"}; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; @@ -172,12 +175,14 @@ TEST_F(StreamsTest, RestoreStreams) { if (i > 0) { stream_info.batch_interval = std::chrono::milliseconds((i + 1) * 10); stream_info.batch_size = 1000 + i; + stream_info.owner = std::string{"owner"} + iteration_postfix; } mock_cluster_.CreateTopic(stream_info.topics[0]); } stream_check_datas[1].info.batch_interval = {}; stream_check_datas[2].info.batch_size = {}; + stream_check_datas[3].info.owner = {}; const auto check_restore_logic = [&stream_check_datas, this]() { // Reset the Streams object to trigger reloading diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index a9f9838be..3904cbda9 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -1,12 +1,16 @@ +#include <gmock/gmock.h> #include <gtest/gtest.h> #include <filesystem> #include <fmt/format.h> +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/db_accessor.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/interpreter.hpp" #include "query/trigger.hpp" #include "query/typed_value.hpp" +#include "utils/exceptions.hpp" #include "utils/memory.hpp" namespace { @@ -16,6 +20,12 @@ const std::unordered_set<query::TriggerEventType> kAllEventTypes{ query::TriggerEventType::DELETE, query::TriggerEventType::VERTEX_UPDATE, query::TriggerEventType::EDGE_UPDATE, query::TriggerEventType::UPDATE, }; + +class MockAuthChecker : public query::AuthChecker { + public: + MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username, + const std::vector<query::AuthQuery::Privilege> &privileges)); +}; } // namespace class TriggerContextTest : public ::testing::Test { @@ -820,6 +830,7 @@ class TriggerStoreTest : public ::testing::Test { utils::SkipList<query::QueryCacheEntry> ast_cache; utils::SpinLock antlr_lock; + query::AllowEverythingAuthChecker auth_checker; private: void Clear() { @@ -836,7 +847,7 @@ TEST_F(TriggerStoreTest, Restore) { const auto reset_store = [&] { store.emplace(testing_directory); - store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}); + store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &auth_checker); }; reset_store(); @@ -853,34 +864,40 @@ TEST_F(TriggerStoreTest, Restore) { const auto *trigger_name_after = "trigger_after"; const auto *trigger_statement = "RETURN $parameter"; const auto event_type = query::TriggerEventType::VERTEX_CREATE; + const std::string owner{"owner"}; store->AddTrigger(trigger_name_before, trigger_statement, std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}}, event_type, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}); + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker); store->AddTrigger(trigger_name_after, trigger_statement, std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{"value"}}}, event_type, query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}); + query::InterpreterConfig::Query{}, {owner}, &auth_checker); const auto check_triggers = [&] { ASSERT_EQ(store->GetTriggerInfo().size(), 2); - const auto verify_trigger = [&](const auto &trigger, const auto &name) { + const auto verify_trigger = [&](const auto &trigger, const auto &name, const std::string *owner) { ASSERT_EQ(trigger.Name(), name); ASSERT_EQ(trigger.OriginalStatement(), trigger_statement); ASSERT_EQ(trigger.EventType(), event_type); + if (owner != nullptr) { + ASSERT_EQ(*trigger.Owner(), *owner); + } else { + ASSERT_FALSE(trigger.Owner().has_value()); + } }; const auto before_commit_triggers = store->BeforeCommitTriggers().access(); ASSERT_EQ(before_commit_triggers.size(), 1); for (const auto &trigger : before_commit_triggers) { - verify_trigger(trigger, trigger_name_before); + verify_trigger(trigger, trigger_name_before, nullptr); } const auto after_commit_triggers = store->AfterCommitTriggers().access(); ASSERT_EQ(after_commit_triggers.size(), 1); for (const auto &trigger : after_commit_triggers) { - verify_trigger(trigger, trigger_name_after); + verify_trigger(trigger, trigger_name_after, &owner); } }; @@ -906,32 +923,32 @@ TEST_F(TriggerStoreTest, AddTrigger) { // Invalid query in statements ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}), + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker), utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}), + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker), utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}), + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker), utils::BasicException); ASSERT_NO_THROW( store.AddTrigger("trigger", "RETURN $parameter", std::map<std::string, storage::PropertyValue>{{"parameter", storage::PropertyValue{1}}}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, - &antlr_lock, query::InterpreterConfig::Query{})); + &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &auth_checker)); // Inserting with the same name ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}), + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker), utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}), + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker), utils::BasicException); ASSERT_EQ(store.GetTriggerInfo().size(), 1); @@ -947,7 +964,7 @@ TEST_F(TriggerStoreTest, DropTrigger) { const auto *trigger_name = "trigger"; store.AddTrigger(trigger_name, "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}); + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker); ASSERT_THROW(store.DropTrigger("Unknown"), utils::BasicException); ASSERT_NO_THROW(store.DropTrigger(trigger_name)); @@ -960,7 +977,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) { std::vector<query::TriggerStore::TriggerInfo> expected_info; store.AddTrigger("trigger", "RETURN 1", {}, query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}); + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker); expected_info.push_back( {"trigger", "RETURN 1", query::TriggerEventType::VERTEX_CREATE, query::TriggerPhase::BEFORE_COMMIT}); @@ -971,7 +988,7 @@ TEST_F(TriggerStoreTest, TriggerInfo) { ASSERT_TRUE(std::all_of(expected_info.begin(), expected_info.end(), [&](const auto &info) { return std::find_if(trigger_info.begin(), trigger_info.end(), [&](const auto &other) { return info.name == other.name && info.statement == other.statement && - info.event_type == other.event_type && info.phase == other.phase; + info.event_type == other.event_type && info.phase == other.phase && !info.owner.has_value(); }) != trigger_info.end(); })); }; @@ -979,8 +996,8 @@ TEST_F(TriggerStoreTest, TriggerInfo) { check_trigger_info(); store.AddTrigger("edge_update_trigger", "RETURN 1", {}, query::TriggerEventType::EDGE_UPDATE, - query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{}); + query::TriggerPhase::AFTER_COMMIT, &ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, + std::nullopt, &auth_checker); expected_info.push_back( {"edge_update_trigger", "RETURN 1", query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT}); @@ -1093,8 +1110,56 @@ TEST_F(TriggerStoreTest, AnyTriggerAllKeywords) { SCOPED_TRACE(keyword); EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type, query::TriggerPhase::BEFORE_COMMIT, &ast_cache, &*dba, &antlr_lock, - query::InterpreterConfig::Query{})); + query::InterpreterConfig::Query{}, std::nullopt, &auth_checker)); store.DropTrigger(trigger_name); } } } + +TEST_F(TriggerStoreTest, AuthCheckerUsage) { + using Privilege = query::AuthQuery::Privilege; + using ::testing::_; + using ::testing::ElementsAre; + using ::testing::Return; + std::optional<query::TriggerStore> store{testing_directory}; + const std::optional<std::string> owner{"testing_owner"}; + MockAuthChecker mock_checker; + + ::testing::InSequence s; + + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE))) + .Times(1) + .WillOnce(Return(true)); + EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true)); + + ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {}, + query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache, + &*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, + &mock_checker)); + + ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {}, + query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache, + &*dba, &antlr_lock, query::InterpreterConfig::Query{}, owner, &mock_checker)); + + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::MATCH))) + .Times(1) + .WillOnce(Return(false)); + + ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {}, + query::TriggerEventType::EDGE_UPDATE, query::TriggerPhase::AFTER_COMMIT, &ast_cache, + &*dba, &antlr_lock, query::InterpreterConfig::Query{}, std::nullopt, &mock_checker); + , utils::BasicException); + + store.emplace(testing_directory); + EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional<std::string>{}, ElementsAre(Privilege::CREATE))) + .Times(1) + .WillOnce(Return(false)); + EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE))).Times(1).WillOnce(Return(true)); + + ASSERT_NO_THROW( + store->RestoreTriggers(&ast_cache, &*dba, &antlr_lock, query::InterpreterConfig::Query{}, &mock_checker)); + + const auto triggers = store->GetTriggerInfo(); + ASSERT_EQ(triggers.size(), 1); + ASSERT_EQ(triggers.front().owner, owner); +}