Merge old Label Based Auth Epic branch into new one because of commits with bad checks on the old epic branch (#478)

* grammar expanded; (#462)

* T0954 mg expand user and role to hold permissions on labels (#465)

* added FineGrainedAccessPermissions class to model

* expanded user and role with fine grained access permissions

* fixed grammar

* [E129 < T0953-MG] GRANT, DENY, REVOKE added in interpreter and mainVisitor (#464)

* GRANT, DENY, REVOKE added in interpreter and mainVisitor

* Commented labelPermissons

* remove labelsPermission adding

* Removed extra lambda

* [E129<-T0955-MG] Expand ExecutionContext with label related information (#467)

* Added FineGrainedAccessChecker to Context

* fixed failing tests for label based authorization (#480)

* Marked FineGrainedAccessChecker ctor explicit; Introduced change to clang-tidy; (#483)

Co-authored-by: niko4299 <51059248+niko4299@users.noreply.github.com>
This commit is contained in:
Boris Taševski 2022-08-02 12:51:22 +02:00 committed by GitHub
parent 80e0e439b7
commit 480df4ed69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 391 additions and 60 deletions

View File

@ -130,7 +130,7 @@ jobs:
source /opt/toolchain-v4/activate
# Restrict clang-tidy results only to the modified parts
git diff -U0 ${{ env.BASE_BRANCH }}... -- src | ./tools/github/clang-tidy/clang-tidy-diff.py -p 1 -j $THREADS -path build | tee ./build/clang_tidy_output.txt
git diff -U0 ${{ env.BASE_BRANCH }}... -- src | ./tools/github/clang-tidy/clang-tidy-diff.py -p 1 -j $THREADS -path build -regex ".+\.cpp" | tee ./build/clang_tidy_output.txt
# Fail if any warning is reported
! cat ./build/clang_tidy_output.txt | ./tools/github/clang-tidy/grep_error_lines.sh > /dev/null

View File

@ -183,19 +183,137 @@ bool operator==(const Permissions &first, const Permissions &second) {
bool operator!=(const Permissions &first, const Permissions &second) { return !(first == second); }
const std::string ASTERISK = "*";
FineGrainedAccessPermissions::FineGrainedAccessPermissions(const std::unordered_set<std::string> &grants,
const std::unordered_set<std::string> &denies)
: grants_(grants), denies_(denies) {}
PermissionLevel FineGrainedAccessPermissions::Has(const std::string &permission) const {
if ((denies_.size() == 1 && denies_.find(ASTERISK) != denies_.end()) || denies_.find(permission) != denies_.end()) {
return PermissionLevel::DENY;
}
if ((grants_.size() == 1 && grants_.find(ASTERISK) != grants_.end()) || grants_.find(permission) != denies_.end()) {
return PermissionLevel::GRANT;
}
return PermissionLevel::NEUTRAL;
}
void FineGrainedAccessPermissions::Grant(const std::string &permission) {
if (permission == ASTERISK) {
grants_.clear();
grants_.insert(permission);
return;
}
auto deniedPermissionIter = denies_.find(permission);
if (deniedPermissionIter != denies_.end()) {
denies_.erase(deniedPermissionIter);
}
if (grants_.size() == 1 && grants_.find(ASTERISK) != grants_.end()) {
grants_.erase(ASTERISK);
}
if (grants_.find(permission) == grants_.end()) {
grants_.insert(permission);
}
}
void FineGrainedAccessPermissions::Revoke(const std::string &permission) {
if (permission == ASTERISK) {
grants_.clear();
denies_.clear();
return;
}
auto deniedPermissionIter = denies_.find(permission);
auto grantedPermissionIter = grants_.find(permission);
if (deniedPermissionIter != denies_.end()) {
denies_.erase(deniedPermissionIter);
}
if (grantedPermissionIter != grants_.end()) {
grants_.erase(grantedPermissionIter);
}
}
void FineGrainedAccessPermissions::Deny(const std::string &permission) {
if (permission == ASTERISK) {
denies_.clear();
denies_.insert(permission);
return;
}
auto grantedPermissionIter = grants_.find(permission);
if (grantedPermissionIter != grants_.end()) {
grants_.erase(grantedPermissionIter);
}
if (denies_.size() == 1 && denies_.find(ASTERISK) != denies_.end()) {
denies_.erase(ASTERISK);
}
if (denies_.find(permission) == denies_.end()) {
denies_.insert(permission);
}
}
nlohmann::json FineGrainedAccessPermissions::Serialize() const {
nlohmann::json data = nlohmann::json::object();
data["grants"] = grants_;
data["denies"] = denies_;
return data;
}
FineGrainedAccessPermissions FineGrainedAccessPermissions::Deserialize(const nlohmann::json &data) {
if (!data.is_object()) {
throw AuthException("Couldn't load permissions data!");
}
return FineGrainedAccessPermissions(data["grants"], data["denies"]);
}
const std::unordered_set<std::string> &FineGrainedAccessPermissions::grants() const { return grants_; }
const std::unordered_set<std::string> &FineGrainedAccessPermissions::denies() const { return denies_; }
bool operator==(const FineGrainedAccessPermissions &first, const FineGrainedAccessPermissions &second) {
return first.grants() == second.grants() && first.denies() == second.denies();
}
bool operator!=(const FineGrainedAccessPermissions &first, const FineGrainedAccessPermissions &second) {
return !(first == second);
}
Role::Role(const std::string &rolename) : rolename_(utils::ToLowerCase(rolename)) {}
Role::Role(const std::string &rolename, const Permissions &permissions)
: rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {}
Role::Role(const std::string &rolename, const Permissions &permissions,
const FineGrainedAccessPermissions &fine_grained_access_permissions)
: rolename_(utils::ToLowerCase(rolename)),
permissions_(permissions),
fine_grained_access_permissions_(fine_grained_access_permissions) {}
const std::string &Role::rolename() const { return rolename_; }
const Permissions &Role::permissions() const { return permissions_; }
Permissions &Role::permissions() { return permissions_; }
const FineGrainedAccessPermissions &Role::fine_grained_access_permissions() const {
return fine_grained_access_permissions_;
}
FineGrainedAccessPermissions &Role::fine_grained_access_permissions() { return fine_grained_access_permissions_; }
nlohmann::json Role::Serialize() const {
nlohmann::json data = nlohmann::json::object();
data["rolename"] = rolename_;
data["permissions"] = permissions_.Serialize();
data["fine_grained_access_permissions"] = fine_grained_access_permissions_.Serialize();
return data;
}
@ -203,11 +321,14 @@ Role Role::Deserialize(const nlohmann::json &data) {
if (!data.is_object()) {
throw AuthException("Couldn't load role data!");
}
if (!data["rolename"].is_string() || !data["permissions"].is_object()) {
if (!data["rolename"].is_string() || !data["permissions"].is_object() ||
!data["fine_grained_access_permissions"].is_object()) {
throw AuthException("Couldn't load role data!");
}
auto permissions = Permissions::Deserialize(data["permissions"]);
return {data["rolename"], permissions};
auto fine_grained_access_permissions =
FineGrainedAccessPermissions::Deserialize(data["fine_grained_access_permissions"]);
return {data["rolename"], permissions, fine_grained_access_permissions};
}
bool operator==(const Role &first, const Role &second) {
@ -216,8 +337,12 @@ bool operator==(const Role &first, const Role &second) {
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions)
: username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {}
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
const FineGrainedAccessPermissions &fine_grained_access_permissions)
: username_(utils::ToLowerCase(username)),
password_hash_(password_hash),
permissions_(permissions),
fine_grained_access_permissions_(fine_grained_access_permissions) {}
bool User::CheckPassword(const std::string &password) {
if (password_hash_.empty()) return true;
@ -266,10 +391,35 @@ Permissions User::GetPermissions() const {
return permissions_;
}
FineGrainedAccessPermissions User::GetFineGrainedAccessPermissions() const {
if (role_) {
std::unordered_set<std::string> resultGrants;
std::set_union(fine_grained_access_permissions_.grants().begin(), fine_grained_access_permissions_.grants().end(),
role_->fine_grained_access_permissions().grants().begin(),
role_->fine_grained_access_permissions().grants().end(),
std::inserter(resultGrants, resultGrants.begin()));
std::unordered_set<std::string> resultDenies;
std::set_union(fine_grained_access_permissions_.denies().begin(), fine_grained_access_permissions_.denies().end(),
role_->fine_grained_access_permissions().denies().begin(),
role_->fine_grained_access_permissions().denies().end(),
std::inserter(resultDenies, resultDenies.begin()));
return FineGrainedAccessPermissions(resultGrants, resultDenies);
}
return fine_grained_access_permissions_;
}
const std::string &User::username() const { return username_; }
const Permissions &User::permissions() const { return permissions_; }
Permissions &User::permissions() { return permissions_; }
const FineGrainedAccessPermissions &User::fine_grained_access_permissions() const {
return fine_grained_access_permissions_;
}
FineGrainedAccessPermissions &User::fine_grained_access_permissions() { return fine_grained_access_permissions_; }
const Role *User::role() const {
if (role_.has_value()) {
@ -283,6 +433,7 @@ nlohmann::json User::Serialize() const {
data["username"] = username_;
data["password_hash"] = password_hash_;
data["permissions"] = permissions_.Serialize();
data["fine_grained_access_permissions"] = fine_grained_access_permissions_.Serialize();
// The role shouldn't be serialized here, it is stored as a foreign key.
return data;
}
@ -295,7 +446,9 @@ User User::Deserialize(const nlohmann::json &data) {
throw AuthException("Couldn't load user data!");
}
auto permissions = Permissions::Deserialize(data["permissions"]);
return {data["username"], data["password_hash"], permissions};
auto fine_grained_access_permissions =
FineGrainedAccessPermissions::Deserialize(data["fine_grained_access_permissions"]);
return {data["username"], data["password_hash"], permissions, fine_grained_access_permissions};
}
bool operator==(const User &first, const User &second) {

View File

@ -10,6 +10,7 @@
#include <optional>
#include <string>
#include <unordered_set>
#include <json/json.hpp>
@ -88,15 +89,48 @@ bool operator==(const Permissions &first, const Permissions &second);
bool operator!=(const Permissions &first, const Permissions &second);
class FineGrainedAccessPermissions final {
public:
explicit FineGrainedAccessPermissions(const std::unordered_set<std::string> &grants = {},
const std::unordered_set<std::string> &denies = {});
PermissionLevel Has(const std::string &permission) const;
void Grant(const std::string &permission);
void Revoke(const std::string &permission);
void Deny(const std::string &permission);
nlohmann::json Serialize() const;
/// @throw AuthException if unable to deserialize.
static FineGrainedAccessPermissions Deserialize(const nlohmann::json &data);
const std::unordered_set<std::string> &grants() const;
const std::unordered_set<std::string> &denies() const;
private:
std::unordered_set<std::string> grants_{};
std::unordered_set<std::string> denies_{};
};
bool operator==(const FineGrainedAccessPermissions &first, const FineGrainedAccessPermissions &second);
bool operator!=(const FineGrainedAccessPermissions &first, const FineGrainedAccessPermissions &second);
class Role final {
public:
Role(const std::string &rolename);
Role(const std::string &rolename, const Permissions &permissions);
Role(const std::string &rolename, const Permissions &permissions,
const FineGrainedAccessPermissions &fine_grained_access_permissions);
const std::string &rolename() const;
const Permissions &permissions() const;
Permissions &permissions();
const FineGrainedAccessPermissions &fine_grained_access_permissions() const;
FineGrainedAccessPermissions &fine_grained_access_permissions();
nlohmann::json Serialize() const;
@ -108,6 +142,7 @@ class Role final {
private:
std::string rolename_;
Permissions permissions_;
FineGrainedAccessPermissions fine_grained_access_permissions_;
};
bool operator==(const Role &first, const Role &second);
@ -117,7 +152,8 @@ class User final {
public:
User(const std::string &username);
User(const std::string &username, const std::string &password_hash, const Permissions &permissions);
User(const std::string &username, const std::string &password_hash, const Permissions &permissions,
const FineGrainedAccessPermissions &fine_grained_access_permissions);
/// @throw AuthException if unable to verify the password.
bool CheckPassword(const std::string &password);
@ -130,11 +166,14 @@ class User final {
void ClearRole();
Permissions GetPermissions() const;
FineGrainedAccessPermissions GetFineGrainedAccessPermissions() const;
const std::string &username() const;
const Permissions &permissions() const;
Permissions &permissions();
const FineGrainedAccessPermissions &fine_grained_access_permissions() const;
FineGrainedAccessPermissions &fine_grained_access_permissions();
const Role *role() const;
@ -149,6 +188,7 @@ class User final {
std::string username_;
std::string password_hash_;
Permissions permissions_;
FineGrainedAccessPermissions fine_grained_access_permissions_;
std::optional<Role> role_;
};

View File

@ -506,7 +506,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
if (first_user) {
spdlog::info("{} is first created user. Granting all privileges.", username);
GrantPrivilege(username, memgraph::query::kPrivilegesAll);
GrantPrivilege(username, memgraph::query::kPrivilegesAll, {"*"});
}
return user_added;
@ -751,9 +751,28 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
}
memgraph::auth::User *GetUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
return new memgraph::auth::User(*user);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void GrantPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) override {
EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
@ -762,8 +781,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
void DenyPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) override {
EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
@ -772,8 +792,9 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
void RevokePrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) override {
EditPermissions(user_or_role, privileges, labels, [](auto *permissions, const auto &permission) {
// TODO (mferencevic): should we first check that the
// privilege is granted/denied/revoked before
// unconditionally granting/denying/revoking it?
@ -784,7 +805,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
private:
template <class TEditFun>
void EditPermissions(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, const TEditFun &edit_fun) {
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels, const TEditFun &edit_fun) {
if (!std::regex_match(user_or_role, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user or role name.");
}
@ -804,11 +826,18 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
for (const auto &permission : permissions) {
edit_fun(&user->permissions(), permission);
}
for (const auto &label : labels) {
edit_fun(&user->fine_grained_access_permissions(), label);
}
locked_auth->SaveUser(*user);
} else {
for (const auto &permission : permissions) {
edit_fun(&role->permissions(), permission);
}
for (const auto &label : labels) {
edit_fun(&user->fine_grained_access_permissions(), label);
}
locked_auth->SaveRole(*role);
}
} catch (const memgraph::auth::AuthException &e) {

View File

@ -14,6 +14,7 @@
#include <type_traits>
#include "query/common.hpp"
#include "query/fine_grained_access_checker.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/metadata.hpp"
#include "query/parameters.hpp"
@ -72,6 +73,7 @@ struct ExecutionContext {
ExecutionStats execution_stats;
TriggerContextCollector *trigger_context_collector{nullptr};
utils::AsyncTimer timer;
FineGrainedAccessChecker *fine_grained_access_checker{nullptr};
};
static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!");

View File

@ -0,0 +1,24 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/models.hpp"
#include "query/frontend/ast/ast.hpp"
#include "storage/v2/id_types.hpp"
namespace memgraph::query {
class FineGrainedAccessChecker {
public:
virtual bool IsUserAuthorizedLabels(const std::vector<memgraph::storage::LabelId> &label,
memgraph::query::DbAccessor *dba) const = 0;
};
} // namespace memgraph::query

View File

@ -2234,6 +2234,7 @@ cpp<#
(:serialize (:slk))
(:clone))
(lcp:define-class auth-query (query)
((action "Action" :scope :public)
(user "std::string" :scope :public)
@ -2242,7 +2243,8 @@ cpp<#
(password "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(privileges "std::vector<Privilege>" :scope :public))
(privileges "std::vector<Privilege>" :scope :public)
(labels "std::vector<std::string>" :scope :public))
(:public
(lcp:define-enum action
(create-role drop-role show-roles create-user set-password drop-user
@ -2264,13 +2266,14 @@ cpp<#
#>cpp
AuthQuery(Action action, std::string user, std::string role,
std::string user_or_role, Expression *password,
std::vector<Privilege> privileges)
std::vector<Privilege> privileges, std::vector<std::string> labels)
: action_(action),
user_(user),
role_(role),
user_or_role_(user_or_role),
password_(password),
privileges_(privileges) {}
privileges_(privileges),
labels_(labels) {}
cpp<#)
(:private
#>cpp

View File

@ -1277,7 +1277,11 @@ antlrcpp::Any CypherMainVisitor::visitGrantPrivilege(MemgraphCypher::GrantPrivil
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->privilegeList()) {
for (auto *privilege : ctx->privilegeList()->privilege()) {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
if (privilege->LABELS()) {
auth->labels_ = std::any_cast<std::vector<std::string>>(privilege->labelList()->accept(this));
} else {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
}
}
} else {
/* grant all privileges */
@ -1295,7 +1299,11 @@ antlrcpp::Any CypherMainVisitor::visitDenyPrivilege(MemgraphCypher::DenyPrivileg
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->privilegeList()) {
for (auto *privilege : ctx->privilegeList()->privilege()) {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
if (privilege->LABELS()) {
auth->labels_ = std::any_cast<std::vector<std::string>>(privilege->labelList()->accept(this));
} else {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
}
}
} else {
/* deny all privileges */
@ -1313,7 +1321,11 @@ antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePriv
auth->user_or_role_ = std::any_cast<std::string>(ctx->userOrRole->accept(this));
if (ctx->privilegeList()) {
for (auto *privilege : ctx->privilegeList()->privilege()) {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
if (privilege->LABELS()) {
auth->labels_ = std::any_cast<std::vector<std::string>>(privilege->labelList()->accept(this));
} else {
auth->privileges_.push_back(std::any_cast<AuthQuery::Privilege>(privilege->accept(this)));
}
}
} else {
/* revoke all privileges */
@ -1322,6 +1334,19 @@ antlrcpp::Any CypherMainVisitor::visitRevokePrivilege(MemgraphCypher::RevokePriv
return auth;
}
antlrcpp::Any CypherMainVisitor::visitLabelList(MemgraphCypher::LabelListContext *ctx) {
std::vector<std::string> labels;
if (ctx->listOfLabels()) {
for (auto *label : ctx->listOfLabels()->label()) {
labels.push_back(std::any_cast<std::string>(label->symbolicName()->accept(this)));
}
} else {
labels.emplace_back("*");
}
return labels;
}
/**
* @return AuthQuery::Privilege
*/

View File

@ -478,6 +478,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) override;
/**
* @return AuthQuery::LabelList
*/
antlrcpp::Any visitLabelList(MemgraphCypher::LabelListContext *ctx) override;
/**
* @return AuthQuery*
*/

View File

@ -56,6 +56,7 @@ memgraphCypherKeyword : cypherKeyword
| IDENTIFIED
| ISOLATION
| KAFKA
| LABELS
| LEVEL
| LOAD
| LOCK
@ -254,10 +255,17 @@ privilege : CREATE
| MODULE_READ
| MODULE_WRITE
| WEBSOCKET
| LABELS labels=labelList
;
privilegeList : privilege ( ',' privilege )* ;
labelList : '*' | listOfLabels ;
listOfLabels : label ( ',' label )* ;
label : COLON symbolicName ;
showPrivileges : SHOW PRIVILEGES FOR userOrRole=userOrRoleName ;
showRoleForUser : SHOW ROLE FOR user=userOrRoleName ;

View File

@ -66,6 +66,7 @@ IDENTIFIED : I D E N T I F I E D ;
IGNORE : I G N O R E ;
ISOLATION : I S O L A T I O N ;
KAFKA : K A F K A ;
LABELS : L A B E L S ;
LEVEL : L E V E L ;
LOAD : L O A D ;
LOCK : L O C K ;

View File

@ -204,8 +204,9 @@ const trie::Trie kKeywords = {"union",
"pulsar",
"service_url",
"version",
"websocket"
"foreach"};
"websocket",
"foreach",
"labels"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(

View File

@ -29,6 +29,7 @@
#include "query/db_accessor.hpp"
#include "query/dump.hpp"
#include "query/exceptions.hpp"
#include "query/fine_grained_access_checker.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/ast_visitor.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
@ -259,6 +260,24 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
private:
storage::Storage *db_;
};
class FineGrainedAccessChecker final : public memgraph::query::FineGrainedAccessChecker {
public:
explicit FineGrainedAccessChecker(memgraph::auth::User *user) : user_{user} {}
bool IsUserAuthorizedLabels(const std::vector<memgraph::storage::LabelId> &labels,
memgraph::query::DbAccessor *dba) const final {
auto labelPermissions = user_->GetFineGrainedAccessPermissions();
return std::any_of(labels.begin(), labels.end(), [&labelPermissions, &dba](const auto label) {
return labelPermissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT;
});
}
private:
memgraph::auth::User *user_;
};
/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.
@ -280,6 +299,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
std::string rolename = auth_query->role_;
std::string user_or_role = auth_query->user_or_role_;
std::vector<AuthQuery::Privilege> privileges = auth_query->privileges_;
std::vector<std::string> labels = auth_query->labels_;
auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator);
Callback callback;
@ -309,7 +329,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
// If the license is not valid we create users with admin access
if (!valid_enterprise_license) {
spdlog::warn("Granting all the privileges to {}.", username);
auth->GrantPrivilege(username, kPrivilegesAll);
auth->GrantPrivilege(username, kPrivilegesAll, {"*"});
}
return std::vector<std::vector<TypedValue>>();
@ -384,20 +404,20 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
};
return callback;
case AuthQuery::Action::GRANT_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges] {
auth->GrantPrivilege(user_or_role, privileges);
callback.fn = [auth, user_or_role, privileges, labels] {
auth->GrantPrivilege(user_or_role, privileges, labels);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::DENY_PRIVILEGE:
callback.fn = [auth, user_or_role, privileges] {
auth->DenyPrivilege(user_or_role, privileges);
callback.fn = [auth, user_or_role, privileges, labels] {
auth->DenyPrivilege(user_or_role, privileges, labels);
return std::vector<std::vector<TypedValue>>();
};
return callback;
case AuthQuery::Action::REVOKE_PRIVILEGE: {
callback.fn = [auth, user_or_role, privileges] {
auth->RevokePrivilege(user_or_role, privileges);
callback.fn = [auth, user_or_role, privileges, labels] {
auth->RevokePrivilege(user_or_role, privileges, labels);
return std::vector<std::vector<TypedValue>>();
};
return callback;
@ -897,7 +917,7 @@ struct PullPlanVector {
struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
@ -926,7 +946,8 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector,
const std::optional<size_t> memory_limit)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory),
@ -937,6 +958,12 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.evaluation_context.parameters = parameters;
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
#ifdef MG_ENTERPRISE
if (username.has_value()) {
memgraph::auth::User *user = interpreter_context->auth->GetUser(*username);
ctx_.fine_grained_access_checker = new FineGrainedAccessChecker{user};
}
#endif
if (interpreter_context->config.execution_timeout_sec > 0) {
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
}
@ -1110,6 +1137,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username,
TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
@ -1153,8 +1181,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, trigger_context_collector, memory_limit);
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
@ -1214,7 +1243,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::string *username) {
const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1264,12 +1294,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
auto rw_type_checker = plan::ReadWriteTypeChecker();
auto optional_username = StringPointerToOptional(username);
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));
return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit,
summary, dba, interpreter_context, execution_memory, memory_limit, optional_username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
@ -1278,7 +1310,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, nullptr, memory_limit)
execution_memory, optional_username, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
@ -1413,7 +1445,7 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
@ -1433,8 +1465,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory);
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
@ -2146,7 +2178,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications,
&query_execution->notifications, username,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
@ -2154,7 +2186,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory);
@ -2164,7 +2196,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db,

View File

@ -13,6 +13,7 @@
#include <gflags/gflags.h>
#include "auth/models.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/context.hpp"
@ -99,14 +100,19 @@ class AuthQueryHandler {
virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual memgraph::auth::User *GetUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) = 0;
};
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };

View File

@ -468,8 +468,9 @@ TEST(AuthWithoutStorage, CaseInsensitivity) {
}
{
auto perms = Permissions();
auto user1 = User("test", "pw", perms);
auto user2 = User("Test", "pw", perms);
auto fine_grained_perms = FineGrainedAccessPermissions();
auto user1 = User("test", "pw", perms, fine_grained_perms);
auto user2 = User("Test", "pw", perms, fine_grained_perms);
ASSERT_EQ(user1, user2);
ASSERT_EQ(user1.username(), user2.username());
ASSERT_EQ(user1.username(), "test");
@ -485,8 +486,9 @@ TEST(AuthWithoutStorage, CaseInsensitivity) {
}
{
auto perms = Permissions();
auto role1 = Role("role", perms);
auto role2 = Role("Role", perms);
auto fine_grained_perms = FineGrainedAccessPermissions();
auto role1 = Role("role", perms, fine_grained_perms);
auto role2 = Role("Role", perms, fine_grained_perms);
ASSERT_EQ(role1, role2);
ASSERT_EQ(role1.rolename(), role2.rolename());
ASSERT_EQ(role1.rolename(), "role");

View File

@ -531,9 +531,9 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
memgraph::query::test_common::OnCreate { \
std::vector<memgraph::query::Clause *> { __VA_ARGS__ } \
}
#define CREATE_INDEX_ON(label, property) \
#define CREATE_INDEX_ON(label, property) \
storage.Create<memgraph::query::IndexQuery>(memgraph::query::IndexQuery::Action::CREATE, (label), \
std::vector<memgraph::query::PropertyIx>{(property)})
std::vector<memgraph::query::PropertyIx>{(property)})
#define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__)
#define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__)
#define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__)
@ -583,7 +583,7 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec
#define COALESCE(...) storage.Create<memgraph::query::Coalesce>(std::vector<memgraph::query::Expression *>{__VA_ARGS__})
#define EXTRACT(variable, list, expr) \
storage.Create<memgraph::query::Extract>(storage.Create<memgraph::query::Identifier>(variable), list, expr)
#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \
storage.Create<memgraph::query::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges))
#define AUTH_QUERY(action, user, role, user_or_role, password, privileges, labels) \
storage.Create<memgraph::query::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges), (labels))
#define DROP_USER(usernames) storage.Create<memgraph::query::DropUser>((usernames))
#define CALL_PROCEDURE(...) memgraph::query::test_common::GetCallProcedure(storage, __VA_ARGS__)

View File

@ -98,8 +98,8 @@ TEST_F(TestPrivilegeExtractor, CreateIndex) {
}
TEST_F(TestPrivilegeExtractor, AuthQuery) {
auto *query =
AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, std::vector<AuthQuery::Privilege>{});
auto *query = AUTH_QUERY(AuthQuery::Action::CREATE_ROLE, "", "role", "", nullptr, std::vector<AuthQuery::Privilege>{},
std::vector<std::string>{});
EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::AUTH));
}