From 19dff1ae5c41d94195bd498293f86be48f7d34ae Mon Sep 17 00:00:00 2001 From: niko4299 Date: Tue, 12 Jul 2022 12:07:28 +0200 Subject: [PATCH] LabelChecker fixed --- src/memgraph.cpp | 10 +++++++--- src/query/interpreter.cpp | 8 ++++---- src/query/interpreter.hpp | 2 +- src/query/plan/operator.cpp | 7 ++++--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/memgraph.cpp b/src/memgraph.cpp index aab4b6cbb..19e7f1f41 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -747,7 +747,7 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { } } - memgraph::auth::User GetUser(const std::string &username) override { + memgraph::auth::User *GetUser(const std::string &username) override { if (!std::regex_match(username, name_regex_)) { throw memgraph::query::QueryRuntimeException("Invalid user name."); } @@ -758,7 +758,8 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); } - return user.value(); + return new memgraph::auth::User(*user); + } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } @@ -821,17 +822,21 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { for (const auto &permission : permissions) { edit_fun(&user->permissions(), permission); } + for (const auto &label : labels) { edit_fun(&user->labelPermissions(), label); } + locked_auth->SaveUser(*user); } else { for (const auto &permission : permissions) { edit_fun(&role->permissions(), permission); } + for (const auto &label : labels) { edit_fun(&role->labelPermissions(), label); } + locked_auth->SaveRole(*role); } } catch (const memgraph::auth::AuthException &e) { @@ -1268,7 +1273,6 @@ int main(int argc, char **argv) { AuthChecker auth_checker{&auth}; interpreter_context.auth = &auth_handler; interpreter_context.auth_checker = &auth_checker; - // interpreter_context.label_checker = &label_checker; { // Triggers can execute query procedures, so we need to reload the modules first and then diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 3a1438cf9..595fad46f 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -267,6 +267,8 @@ class LabelChecker final : public memgraph::query::LabelChecker { const auto user_label_permissions = user_->GetLabelPermissions(); auto *dba = dba_; + if (user_label_permissions.Has("*") == memgraph::auth::PermissionLevel::GRANT) return true; + return std::all_of(labels.begin(), labels.end(), [&user_label_permissions, dba](const auto label) { return user_label_permissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT; }); @@ -958,10 +960,8 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); if (username.has_value()) { - memgraph::auth::User user = interpreter_context->auth->GetUser(*username); - LabelChecker label_checker{&user, dba}; - - ctx_.label_checker = &label_checker; + memgraph::auth::User *user = interpreter_context->auth->GetUser(*username); + ctx_.label_checker = new LabelChecker{user, dba}; } if (interpreter_context->config.execution_timeout_sec > 0) { diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 9c25ea383..046355948 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -92,7 +92,7 @@ class AuthQueryHandler { virtual std::vector GetUsernamesForRole(const std::string &rolename) = 0; /// @throw QueryRuntimeException if an error ocurred. - virtual memgraph::auth::User GetUser(const std::string &username) = 0; + virtual memgraph::auth::User *GetUser(const std::string &username) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual void SetRole(const std::string &username, const std::string &rolename) = 0; diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index e1422c598..e1e5a5a39 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -404,10 +404,11 @@ class ScanAllCursor : public Cursor { vertices_.emplace(std::move(next_vertices.value())); vertices_it_.emplace(vertices_.value().begin()); } + while (vertices_it_.value() != vertices_.value().end()) { - VertexAccessor vector = *vertices_it_.value(); - auto labels = vector.Labels(memgraph::storage::View::NEW).GetValue(); - if (context.label_checker->IsUserAuthorized(labels) || !context.label_checker) { + VertexAccessor vertex = *vertices_it_.value(); + auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue(); + if (!context.label_checker || context.label_checker->IsUserAuthorized(vertex_labels)) { frame[output_symbol_] = *vertices_it_.value(); ++vertices_it_.value(); return true;