LabelChecker fixed

This commit is contained in:
niko4299 2022-07-12 12:07:28 +02:00
parent 41ae33c671
commit 19dff1ae5c
4 changed files with 16 additions and 11 deletions

View File

@ -747,7 +747,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
}
memgraph::auth::User GetUser(const std::string &username) override {
memgraph::auth::User *GetUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
@ -758,7 +758,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
return user.value();
return new memgraph::auth::User(*user);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
@ -821,17 +822,21 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
for (const auto &permission : permissions) {
edit_fun(&user->permissions(), permission);
}
for (const auto &label : labels) {
edit_fun(&user->labelPermissions(), label);
}
locked_auth->SaveUser(*user);
} else {
for (const auto &permission : permissions) {
edit_fun(&role->permissions(), permission);
}
for (const auto &label : labels) {
edit_fun(&role->labelPermissions(), label);
}
locked_auth->SaveRole(*role);
}
} catch (const memgraph::auth::AuthException &e) {
@ -1268,7 +1273,6 @@ int main(int argc, char **argv) {
AuthChecker auth_checker{&auth};
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
// interpreter_context.label_checker = &label_checker;
{
// Triggers can execute query procedures, so we need to reload the modules first and then

View File

@ -267,6 +267,8 @@ class LabelChecker final : public memgraph::query::LabelChecker {
const auto user_label_permissions = user_->GetLabelPermissions();
auto *dba = dba_;
if (user_label_permissions.Has("*") == memgraph::auth::PermissionLevel::GRANT) return true;
return std::all_of(labels.begin(), labels.end(), [&user_label_permissions, dba](const auto label) {
return user_label_permissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT;
});
@ -958,10 +960,8 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
if (username.has_value()) {
memgraph::auth::User user = interpreter_context->auth->GetUser(*username);
LabelChecker label_checker{&user, dba};
ctx_.label_checker = &label_checker;
memgraph::auth::User *user = interpreter_context->auth->GetUser(*username);
ctx_.label_checker = new LabelChecker{user, dba};
}
if (interpreter_context->config.execution_timeout_sec > 0) {

View File

@ -92,7 +92,7 @@ class AuthQueryHandler {
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual memgraph::auth::User GetUser(const std::string &username) = 0;
virtual memgraph::auth::User *GetUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;

View File

@ -404,10 +404,11 @@ class ScanAllCursor : public Cursor {
vertices_.emplace(std::move(next_vertices.value()));
vertices_it_.emplace(vertices_.value().begin());
}
while (vertices_it_.value() != vertices_.value().end()) {
VertexAccessor vector = *vertices_it_.value();
auto labels = vector.Labels(memgraph::storage::View::NEW).GetValue();
if (context.label_checker->IsUserAuthorized(labels) || !context.label_checker) {
VertexAccessor vertex = *vertices_it_.value();
auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue();
if (!context.label_checker || context.label_checker->IsUserAuthorized(vertex_labels)) {
frame[output_symbol_] = *vertices_it_.value();
++vertices_it_.value();
return true;