From efa64fc864e44d7fff9587d0975a21c65d44cea1 Mon Sep 17 00:00:00 2001 From: niko4299 Date: Mon, 18 Jul 2022 15:28:52 +0200 Subject: [PATCH 1/2] Filtering --- src/query/interpreter.cpp | 3 ++- src/query/plan/operator.cpp | 17 +++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 595fad46f..ee6da8e2f 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -959,11 +959,12 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par ctx_.evaluation_context.parameters = parameters; ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); +#ifdef MG_ENTERPRISE if (username.has_value()) { memgraph::auth::User *user = interpreter_context->auth->GetUser(*username); ctx_.label_checker = new LabelChecker{user, dba}; } - +#endif if (interpreter_context->config.execution_timeout_sec > 0) { ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index e1e5a5a39..51a6ec542 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -394,8 +394,8 @@ class ScanAllCursor : public Cursor { while (!vertices_ || vertices_it_.value() == vertices_.value().end()) { if (!input_cursor_->Pull(frame, context)) return false; - // We need a getter function, because in case of exhausting a lazy - // iterable, we cannot simply reset it by calling begin(). + // We need a getter function, because in case of exhausting a lazy iterable, + // we cannot simply reset it by calling begin(). auto next_vertices = get_vertices_(frame, context); if (!next_vertices) continue; // Since vertices iterator isn't nothrow_move_assignable, we have to use @@ -405,17 +405,22 @@ class ScanAllCursor : public Cursor { vertices_it_.emplace(vertices_.value().begin()); } +#ifdef MG_ENTERPRISE while (vertices_it_.value() != vertices_.value().end()) { VertexAccessor vertex = *vertices_it_.value(); auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue(); if (!context.label_checker || context.label_checker->IsUserAuthorized(vertex_labels)) { - frame[output_symbol_] = *vertices_it_.value(); - ++vertices_it_.value(); - return true; + break; } ++vertices_it_.value(); } - return false; + if (vertices_it_.value() == vertices_.value().end()) return false; +#endif + + frame[output_symbol_] = *vertices_it_.value(); + ++vertices_it_.value(); + + return true; } void Shutdown() override { input_cursor_->Shutdown(); } From db655dab5e76aa466e39265960cf698c56c47a88 Mon Sep 17 00:00:00 2001 From: niko4299 Date: Tue, 19 Jul 2022 10:26:04 +0200 Subject: [PATCH 2/2] refactor --- src/query/interpreter.cpp | 9 ++++----- src/query/label_checker.hpp | 3 ++- src/query/plan/operator.cpp | 22 ++++++++++++++-------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index ee6da8e2f..4cf00de36 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -261,11 +261,11 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { class LabelChecker final : public memgraph::query::LabelChecker { public: - explicit LabelChecker(memgraph::auth::User *user, memgraph::query::DbAccessor *dba) : user_{user}, dba_(dba) {} + explicit LabelChecker(memgraph::auth::User *user) : user_{user} {} - bool IsUserAuthorized(const std::vector &labels) const final { + bool IsUserAuthorized(const std::vector &labels, + memgraph::query::DbAccessor *dba) const final { const auto user_label_permissions = user_->GetLabelPermissions(); - auto *dba = dba_; if (user_label_permissions.Has("*") == memgraph::auth::PermissionLevel::GRANT) return true; @@ -276,7 +276,6 @@ class LabelChecker final : public memgraph::query::LabelChecker { private: memgraph::auth::User *user_; - memgraph::query::DbAccessor *dba_; }; Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters ¶meters, @@ -962,7 +961,7 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par #ifdef MG_ENTERPRISE if (username.has_value()) { memgraph::auth::User *user = interpreter_context->auth->GetUser(*username); - ctx_.label_checker = new LabelChecker{user, dba}; + ctx_.label_checker = new LabelChecker{user}; } #endif if (interpreter_context->config.execution_timeout_sec > 0) { diff --git a/src/query/label_checker.hpp b/src/query/label_checker.hpp index a874f0694..c6d8f5de0 100644 --- a/src/query/label_checker.hpp +++ b/src/query/label_checker.hpp @@ -18,6 +18,7 @@ namespace memgraph::query { class LabelChecker { public: - virtual bool IsUserAuthorized(const std::vector &label) const = 0; + virtual bool IsUserAuthorized(const std::vector &label, + memgraph::query::DbAccessor *dba) const = 0; }; } // namespace memgraph::query diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 51a6ec542..8ec2610ab 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -32,6 +32,7 @@ #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/eval.hpp" +#include "query/label_checker.hpp" #include "query/path.hpp" #include "query/plan/scoped_profile.hpp" #include "query/procedure/cypher_types.hpp" @@ -406,14 +407,7 @@ class ScanAllCursor : public Cursor { } #ifdef MG_ENTERPRISE - while (vertices_it_.value() != vertices_.value().end()) { - VertexAccessor vertex = *vertices_it_.value(); - auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue(); - if (!context.label_checker || context.label_checker->IsUserAuthorized(vertex_labels)) { - break; - } - ++vertices_it_.value(); - } + FilterNodes(context.label_checker, context.db_accessor); if (vertices_it_.value() == vertices_.value().end()) return false; #endif @@ -423,6 +417,18 @@ class ScanAllCursor : public Cursor { return true; } + void FilterNodes(const LabelChecker *label_checker, DbAccessor *dba) { + if (!label_checker) return; + while (vertices_it_.value() != vertices_.value().end()) { + VertexAccessor vertex = *vertices_it_.value(); + auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue(); + if (label_checker->IsUserAuthorized(vertex_labels, dba)) { + break; + } + ++vertices_it_.value(); + } + } + void Shutdown() override { input_cursor_->Shutdown(); } void Reset() override {