This commit is contained in:
niko4299 2022-07-19 10:26:04 +02:00
parent efa64fc864
commit db655dab5e
3 changed files with 20 additions and 14 deletions

View File

@ -261,11 +261,11 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
class LabelChecker final : public memgraph::query::LabelChecker {
public:
explicit LabelChecker(memgraph::auth::User *user, memgraph::query::DbAccessor *dba) : user_{user}, dba_(dba) {}
explicit LabelChecker(memgraph::auth::User *user) : user_{user} {}
bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &labels) const final {
bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &labels,
memgraph::query::DbAccessor *dba) const final {
const auto user_label_permissions = user_->GetLabelPermissions();
auto *dba = dba_;
if (user_label_permissions.Has("*") == memgraph::auth::PermissionLevel::GRANT) return true;
@ -276,7 +276,6 @@ class LabelChecker final : public memgraph::query::LabelChecker {
private:
memgraph::auth::User *user_;
memgraph::query::DbAccessor *dba_;
};
Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters &parameters,
@ -962,7 +961,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
#ifdef MG_ENTERPRISE
if (username.has_value()) {
memgraph::auth::User *user = interpreter_context->auth->GetUser(*username);
ctx_.label_checker = new LabelChecker{user, dba};
ctx_.label_checker = new LabelChecker{user};
}
#endif
if (interpreter_context->config.execution_timeout_sec > 0) {

View File

@ -18,6 +18,7 @@
namespace memgraph::query {
class LabelChecker {
public:
virtual bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &label) const = 0;
virtual bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &label,
memgraph::query::DbAccessor *dba) const = 0;
};
} // namespace memgraph::query

View File

@ -32,6 +32,7 @@
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/interpret/eval.hpp"
#include "query/label_checker.hpp"
#include "query/path.hpp"
#include "query/plan/scoped_profile.hpp"
#include "query/procedure/cypher_types.hpp"
@ -406,14 +407,7 @@ class ScanAllCursor : public Cursor {
}
#ifdef MG_ENTERPRISE
while (vertices_it_.value() != vertices_.value().end()) {
VertexAccessor vertex = *vertices_it_.value();
auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue();
if (!context.label_checker || context.label_checker->IsUserAuthorized(vertex_labels)) {
break;
}
++vertices_it_.value();
}
FilterNodes(context.label_checker, context.db_accessor);
if (vertices_it_.value() == vertices_.value().end()) return false;
#endif
@ -423,6 +417,18 @@ class ScanAllCursor : public Cursor {
return true;
}
void FilterNodes(const LabelChecker *label_checker, DbAccessor *dba) {
if (!label_checker) return;
while (vertices_it_.value() != vertices_.value().end()) {
VertexAccessor vertex = *vertices_it_.value();
auto vertex_labels = vertex.Labels(memgraph::storage::View::NEW).GetValue();
if (label_checker->IsUserAuthorized(vertex_labels, dba)) {
break;
}
++vertices_it_.value();
}
}
void Shutdown() override { input_cursor_->Shutdown(); }
void Reset() override {