From b0af87ed9bd895e0c0de2f622393673da03411f3 Mon Sep 17 00:00:00 2001 From: niko4299 Date: Mon, 11 Jul 2022 17:09:47 +0200 Subject: [PATCH 1/2] Added LabelChecker and dummy filtration --- src/auth/reference_modules/ldap.py | 64 +++++++++++++++------------ src/memgraph.cpp | 19 ++++++++ src/query/context.hpp | 3 ++ src/query/interpreter.cpp | 70 ++++++++++++++++++++++-------- src/query/interpreter.hpp | 4 ++ src/query/label_checker.hpp | 23 ++++++++++ src/query/plan/operator.cpp | 15 +++++-- 7 files changed, 148 insertions(+), 50 deletions(-) create mode 100644 src/query/label_checker.hpp diff --git a/src/auth/reference_modules/ldap.py b/src/auth/reference_modules/ldap.py index 761db8fd6..e181ce84f 100755 --- a/src/auth/reference_modules/ldap.py +++ b/src/auth/reference_modules/ldap.py @@ -18,19 +18,24 @@ roles_config = config["roles"] # Initialize LDAP server. tls = None if server_config["encryption"] != "disabled": - cert_file = server_config["cert_file"] if server_config["cert_file"] \ - else None + cert_file = server_config["cert_file"] if server_config["cert_file"] else None key_file = server_config["key_file"] if server_config["key_file"] else None ca_file = server_config["ca_file"] if server_config["ca_file"] else None - validate = ssl.CERT_REQUIRED if server_config["validate_cert"] \ - else ssl.CERT_NONE - tls = ldap3.Tls(local_private_key_file=key_file, - local_certificate_file=cert_file, - ca_certs_file=ca_file, - validate=validate) + validate = ssl.CERT_REQUIRED if server_config["validate_cert"] else ssl.CERT_NONE + tls = ldap3.Tls( + local_private_key_file=key_file, + local_certificate_file=cert_file, + ca_certs_file=ca_file, + validate=validate, + ) use_ssl = server_config["encryption"] == "ssl" -server = ldap3.Server(server_config["host"], port=server_config["port"], - tls=tls, use_ssl=use_ssl, get_info=ldap3.ALL) +server = ldap3.Server( + server_config["host"], + port=server_config["port"], + tls=tls, + use_ssl=use_ssl, + get_info=ldap3.ALL, +) # Main authentication/authorization function. @@ -40,14 +45,12 @@ def authenticate(username, password): return {"authenticated": False, "role": ""} # Create the DN of the user - dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + \ - users_config["suffix"] + dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + users_config["suffix"] # Bind to the server conn = ldap3.Connection(server, dn, password) if server_config["encryption"] == "starttls" and not conn.start_tls(): - print("ERROR: Couldn't issue STARTTLS to the LDAP server!", - file=sys.stderr) + print("ERROR: Couldn't issue STARTTLS to the LDAP server!", file=sys.stderr) return {"authenticated": False, "role": ""} if not conn.bind(): return {"authenticated": False, "role": ""} @@ -56,25 +59,32 @@ def authenticate(username, password): if roles_config["root_dn"] != "": # search for role search_filter = "(&(objectclass={objclass})({attr}={value}))".format( - objclass=roles_config["root_objectclass"], - attr=roles_config["user_attribute"], - value=ldap3.utils.conv.escape_filter_chars(dn)) - succ = conn.search(roles_config["root_dn"], search_filter, - search_scope=ldap3.LEVEL, - attributes=[roles_config["role_attribute"]]) + objclass=roles_config["root_objectclass"], + attr=roles_config["user_attribute"], + value=ldap3.utils.conv.escape_filter_chars(dn), + ) + succ = conn.search( + roles_config["root_dn"], + search_filter, + search_scope=ldap3.LEVEL, + attributes=[roles_config["role_attribute"]], + ) if not succ or len(conn.entries) == 0: return {"authenticated": True, "role": ""} if len(conn.entries) > 1: - roles = list(map(lambda x: x[roles_config["role_attribute"]].value, - conn.entries)) + roles = list(map(lambda x: x[roles_config["role_attribute"]].value, conn.entries)) # Because we don't know exactly which role the user should have # we authorize the user with an empty role. - print("WARNING: Found more than one role for " - "user '" + username + "':", ", ".join(roles) + "!", - file=sys.stderr) + print( + "WARNING: Found more than one role for " "user '" + username + "':", + ", ".join(roles) + "!", + file=sys.stderr, + ) return {"authenticated": True, "role": ""} - return {"authenticated": True, - "role": conn.entries[0][roles_config["role_attribute"]].value} + return { + "authenticated": True, + "role": conn.entries[0][roles_config["role_attribute"]].value, + } else: return {"authenticated": True, "role": ""} diff --git a/src/memgraph.cpp b/src/memgraph.cpp index e935c03ec..aab4b6cbb 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -38,6 +38,7 @@ #include "helpers.hpp" #include "py/py.hpp" #include "query/auth_checker.hpp" +#include "query/db_accessor.hpp" #include "query/discard_value_stream.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" @@ -746,6 +747,23 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { } } + memgraph::auth::User GetUser(const std::string &username) override { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); + } + + return user.value(); + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } + } + void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges, const std::vector &labels) override { @@ -1250,6 +1268,7 @@ int main(int argc, char **argv) { AuthChecker auth_checker{&auth}; interpreter_context.auth = &auth_handler; interpreter_context.auth_checker = &auth_checker; + // interpreter_context.label_checker = &label_checker; { // Triggers can execute query procedures, so we need to reload the modules first and then diff --git a/src/query/context.hpp b/src/query/context.hpp index 12b1f0cac..011cc57ad 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -13,8 +13,10 @@ #include +#include "auth/models.hpp" #include "query/common.hpp" #include "query/frontend/semantic/symbol_table.hpp" +#include "query/label_checker.hpp" #include "query/metadata.hpp" #include "query/parameters.hpp" #include "query/plan/profile.hpp" @@ -72,6 +74,7 @@ struct ExecutionContext { ExecutionStats execution_stats; TriggerContextCollector *trigger_context_collector{nullptr}; utils::AsyncTimer timer; + LabelChecker *label_checker{nullptr}; }; static_assert(std::is_move_assignable_v, "ExecutionContext must be move assignable!"); diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 209741f2b..3a1438cf9 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -36,6 +36,7 @@ #include "query/frontend/semantic/required_privileges.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/interpret/eval.hpp" +#include "query/label_checker.hpp" #include "query/metadata.hpp" #include "query/plan/planner.hpp" #include "query/plan/profile.hpp" @@ -257,8 +258,24 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { private: storage::Storage *db_; }; -/// returns false if the replication role can't be set -/// @throw QueryRuntimeException if an error ocurred. + +class LabelChecker final : public memgraph::query::LabelChecker { + public: + explicit LabelChecker(memgraph::auth::User *user, memgraph::query::DbAccessor *dba) : user_{user}, dba_(dba) {} + + bool IsUserAuthorized(const std::vector &labels) const final { + const auto user_label_permissions = user_->GetLabelPermissions(); + auto *dba = dba_; + + return std::all_of(labels.begin(), labels.end(), [&user_label_permissions, dba](const auto label) { + return user_label_permissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT; + }); + } + + private: + memgraph::auth::User *user_; + memgraph::query::DbAccessor *dba_; +}; Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters ¶meters, DbAccessor *db_accessor) { @@ -271,6 +288,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parameters; ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); @@ -292,10 +310,11 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, AuthQuery::Action::SHOW_ROLE_FOR_USER}; - if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { - throw utils::BasicException( - utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); - } + // if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { + // throw utils::BasicException( + // utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication + // features")); + // } switch (auth_query->action_) { case AuthQuery::Action::CREATE_USER: @@ -897,7 +916,7 @@ struct PullPlanVector { struct PullPlan { explicit PullPlan(std::shared_ptr plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - TriggerContextCollector *trigger_context_collector = nullptr, + std::optional username, TriggerContextCollector *trigger_context_collector = nullptr, std::optional memory_limit = {}); std::optional Pull(AnyStream *stream, std::optional n, const std::vector &output_symbols, @@ -926,7 +945,8 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - TriggerContextCollector *trigger_context_collector, const std::optional memory_limit) + std::optional username, TriggerContextCollector *trigger_context_collector, + const std::optional memory_limit) : plan_(plan), cursor_(plan->plan().MakeCursor(execution_memory)), frame_(plan->symbol_table().max_position(), execution_memory), @@ -937,6 +957,13 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par ctx_.evaluation_context.parameters = parameters; ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); + if (username.has_value()) { + memgraph::auth::User user = interpreter_context->auth->GetUser(*username); + LabelChecker label_checker{&user, dba}; + + ctx_.label_checker = &label_checker; + } + if (interpreter_context->config.execution_timeout_sec > 0) { ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; } @@ -1110,6 +1137,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, std::vector *notifications, + const std::string *username, TriggerContextCollector *trigger_context_collector = nullptr) { auto *cypher_query = utils::Downcast(parsed_query.query); @@ -1118,6 +1146,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::mapmemory_limit_, cypher_query->memory_scale_); if (memory_limit) { @@ -1153,8 +1182,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, trigger_context_collector, memory_limit); + auto pull_plan = + std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, + StringPointerToOptional(username), trigger_context_collector, memory_limit); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( AnyStream *stream, std::optional n) -> std::optional { @@ -1214,7 +1244,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory) { + DbAccessor *dba, utils::MemoryResource *execution_memory, + const std::string *username) { const std::string kProfileQueryStart = "profile "; MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart), @@ -1265,11 +1296,12 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); auto rw_type_checker = plan::ReadWriteTypeChecker(); rw_type_checker.InferRWType(const_cast(cypher_query_plan->plan())); + auto optional_username = StringPointerToOptional(username); return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, std::move(parsed_query.required_privileges), [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), - summary, dba, interpreter_context, execution_memory, memory_limit, + summary, dba, interpreter_context, execution_memory, memory_limit, optional_username, // We want to execute the query we are profiling lazily, so we delay // the construction of the corresponding context. stats_and_total_time = std::optional{}, @@ -1278,7 +1310,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra // No output symbols are given so that nothing is streamed. if (!stats_and_total_time) { stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context, - execution_memory, nullptr, memory_limit) + execution_memory, optional_username, nullptr, memory_limit) .Pull(stream, {}, {}, summary); pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); } @@ -1413,7 +1445,7 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory) { + DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } @@ -1433,8 +1465,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), 0.0, AstStorage{}, symbol_table)); - auto pull_plan = - std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory); + auto pull_plan = std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, + execution_memory, StringPointerToOptional(username)); return PreparedQuery{ callback.header, std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), @@ -2147,7 +2179,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory, - &query_execution->notifications, + &query_execution->notifications, username, trigger_context_collector_ ? &*trigger_context_collector_ : nullptr); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, @@ -2155,7 +2187,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception); + &query_execution->execution_memory_with_exception, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, &query_execution->execution_memory); @@ -2165,7 +2197,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception); + &query_execution->execution_memory_with_exception, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, interpreter_context_->db, diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 77814da4a..9c25ea383 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -13,6 +13,7 @@ #include +#include "auth/models.hpp" #include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/context.hpp" @@ -90,6 +91,9 @@ class AuthQueryHandler { /// @throw QueryRuntimeException if an error ocurred. virtual std::vector GetUsernamesForRole(const std::string &rolename) = 0; + /// @throw QueryRuntimeException if an error ocurred. + virtual memgraph::auth::User GetUser(const std::string &username) = 0; + /// @throw QueryRuntimeException if an error ocurred. virtual void SetRole(const std::string &username, const std::string &rolename) = 0; diff --git a/src/query/label_checker.hpp b/src/query/label_checker.hpp new file mode 100644 index 000000000..a874f0694 --- /dev/null +++ b/src/query/label_checker.hpp @@ -0,0 +1,23 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/models.hpp" +#include "query/frontend/ast/ast.hpp" +#include "storage/v2/id_types.hpp" + +namespace memgraph::query { +class LabelChecker { + public: + virtual bool IsUserAuthorized(const std::vector &label) const = 0; +}; +} // namespace memgraph::query diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 74463ead0..29e749242 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -404,10 +404,17 @@ class ScanAllCursor : public Cursor { vertices_.emplace(std::move(next_vertices.value())); vertices_it_.emplace(vertices_.value().begin()); } - - frame[output_symbol_] = *vertices_it_.value(); - ++vertices_it_.value(); - return true; + while (vertices_it_.value() != vertices_.value().end()) { + VertexAccessor vector = *vertices_it_.value(); + auto labels = vector.Labels(memgraph::storage::View::NEW).GetValue(); + if (context.label_checker->IsUserAuthorized(labels)) { + frame[output_symbol_] = *vertices_it_.value(); + ++vertices_it_.value(); + return true; + } + ++vertices_it_.value(); + } + return false; } void Shutdown() override { input_cursor_->Shutdown(); } From 216abaa0ef4ab9091034c8b2bddc7000698e02fc Mon Sep 17 00:00:00 2001 From: niko4299 Date: Mon, 11 Jul 2022 17:13:43 +0200 Subject: [PATCH 2/2] without username can see all --- src/query/plan/operator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 29e749242..e1422c598 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -407,7 +407,7 @@ class ScanAllCursor : public Cursor { while (vertices_it_.value() != vertices_.value().end()) { VertexAccessor vector = *vertices_it_.value(); auto labels = vector.Labels(memgraph::storage::View::NEW).GetValue(); - if (context.label_checker->IsUserAuthorized(labels)) { + if (context.label_checker->IsUserAuthorized(labels) || !context.label_checker) { frame[output_symbol_] = *vertices_it_.value(); ++vertices_it_.value(); return true;