From d0babcddc5f110e0273ea6ce59ca04f4e5499403 Mon Sep 17 00:00:00 2001 From: Josip Mrden Date: Wed, 28 Feb 2024 13:39:41 +0100 Subject: [PATCH] Pass user to query execution --- src/query/context.hpp | 2 +- src/query/interpreter.cpp | 11 +++-------- src/query/procedure/mg_procedure_impl.cpp | 7 ++----- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/query/context.hpp b/src/query/context.hpp index 288505c7d..b2977a6f8 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -92,7 +92,7 @@ struct ExecutionContext { TriggerContextCollector *trigger_context_collector{nullptr}; FrameChangeCollector *frame_change_collector{nullptr}; std::shared_ptr timer; - UserExecutionContextInfo user_info; + std::shared_ptr user_or_role; #ifdef MG_ENTERPRISE std::unique_ptr auth_checker{nullptr}; #endif diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 4395da387..5c987d0bd 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -1720,17 +1720,12 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &pa ctx_.evaluation_context.parameters = parameters; ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); - if (!user_or_role) { - ctx_.user_info = {.mode = UserExecutionContextInfo::UserMode::NONE, .name = ""}; - } else { - ctx_.user_info = {.mode = user_or_role->username() ? UserExecutionContextInfo::UserMode::USER - : UserExecutionContextInfo::UserMode::ROLE, - .name = user_or_role->key()}; - } + ctx_.user_or_role = user_or_role; + #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && user_or_role && *user_or_role && dba) { // Create only if an explicit user is defined - auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(std::move(user_or_role), dba); + auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(user_or_role, dba); // if the user has global privileges to read, edit and write anything, we don't need to perform authorization // otherwise, we do assign the auth checker to check for label access control diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 4b2d8d1c6..373099a40 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -23,6 +23,7 @@ #include #include +#include "glue/auth.hpp" #include "license/license.hpp" #include "mg_procedure.h" #include "module.hpp" @@ -4028,10 +4029,10 @@ mgp_error mgp_untrack_current_thread_allocations(mgp_graph *graph) { mgp_error mgp_execute_query(mgp_graph *graph, const char *query) { return WrapExceptions([&]() { auto query_string = std::string(query); - auto user_info = graph->ctx->user_info; auto *instance = memgraph::query::InterpreterContext::getInstance(); memgraph::query::Interpreter interpreter(instance); + interpreter.SetUser(graph->ctx->user_or_role); instance->interpreters.WithLock([&interpreter](auto &interpreters) { interpreters.insert(&interpreter); }); @@ -4039,10 +4040,6 @@ mgp_error mgp_execute_query(mgp_graph *graph, const char *query) { instance->interpreters.WithLock([&interpreter](auto &interpreters) { interpreters.erase(&interpreter); }); }); - memgraph::query::AllowEverythingAuthChecker tmp_auth_checker; - auto tmp_user = tmp_auth_checker.GenQueryUser(std::nullopt, std::nullopt); - interpreter.SetUser(tmp_user); - auto results = interpreter.Prepare(query_string, {}, {}); memgraph::query::DiscardValueResultStream stream; interpreter.Pull(&stream, {}, results.qid);