remote changes merged

This commit is contained in:
Boris Tasevski 2022-07-11 17:35:54 +02:00
commit 41ae33c671
7 changed files with 148 additions and 50 deletions

View File

@ -18,19 +18,24 @@ roles_config = config["roles"]
# Initialize LDAP server.
tls = None
if server_config["encryption"] != "disabled":
cert_file = server_config["cert_file"] if server_config["cert_file"] \
else None
cert_file = server_config["cert_file"] if server_config["cert_file"] else None
key_file = server_config["key_file"] if server_config["key_file"] else None
ca_file = server_config["ca_file"] if server_config["ca_file"] else None
validate = ssl.CERT_REQUIRED if server_config["validate_cert"] \
else ssl.CERT_NONE
tls = ldap3.Tls(local_private_key_file=key_file,
local_certificate_file=cert_file,
ca_certs_file=ca_file,
validate=validate)
validate = ssl.CERT_REQUIRED if server_config["validate_cert"] else ssl.CERT_NONE
tls = ldap3.Tls(
local_private_key_file=key_file,
local_certificate_file=cert_file,
ca_certs_file=ca_file,
validate=validate,
)
use_ssl = server_config["encryption"] == "ssl"
server = ldap3.Server(server_config["host"], port=server_config["port"],
tls=tls, use_ssl=use_ssl, get_info=ldap3.ALL)
server = ldap3.Server(
server_config["host"],
port=server_config["port"],
tls=tls,
use_ssl=use_ssl,
get_info=ldap3.ALL,
)
# Main authentication/authorization function.
@ -40,14 +45,12 @@ def authenticate(username, password):
return {"authenticated": False, "role": ""}
# Create the DN of the user
dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + \
users_config["suffix"]
dn = users_config["prefix"] + ldap3.utils.dn.escape_rdn(username) + users_config["suffix"]
# Bind to the server
conn = ldap3.Connection(server, dn, password)
if server_config["encryption"] == "starttls" and not conn.start_tls():
print("ERROR: Couldn't issue STARTTLS to the LDAP server!",
file=sys.stderr)
print("ERROR: Couldn't issue STARTTLS to the LDAP server!", file=sys.stderr)
return {"authenticated": False, "role": ""}
if not conn.bind():
return {"authenticated": False, "role": ""}
@ -56,25 +59,32 @@ def authenticate(username, password):
if roles_config["root_dn"] != "":
# search for role
search_filter = "(&(objectclass={objclass})({attr}={value}))".format(
objclass=roles_config["root_objectclass"],
attr=roles_config["user_attribute"],
value=ldap3.utils.conv.escape_filter_chars(dn))
succ = conn.search(roles_config["root_dn"], search_filter,
search_scope=ldap3.LEVEL,
attributes=[roles_config["role_attribute"]])
objclass=roles_config["root_objectclass"],
attr=roles_config["user_attribute"],
value=ldap3.utils.conv.escape_filter_chars(dn),
)
succ = conn.search(
roles_config["root_dn"],
search_filter,
search_scope=ldap3.LEVEL,
attributes=[roles_config["role_attribute"]],
)
if not succ or len(conn.entries) == 0:
return {"authenticated": True, "role": ""}
if len(conn.entries) > 1:
roles = list(map(lambda x: x[roles_config["role_attribute"]].value,
conn.entries))
roles = list(map(lambda x: x[roles_config["role_attribute"]].value, conn.entries))
# Because we don't know exactly which role the user should have
# we authorize the user with an empty role.
print("WARNING: Found more than one role for "
"user '" + username + "':", ", ".join(roles) + "!",
file=sys.stderr)
print(
"WARNING: Found more than one role for " "user '" + username + "':",
", ".join(roles) + "!",
file=sys.stderr,
)
return {"authenticated": True, "role": ""}
return {"authenticated": True,
"role": conn.entries[0][roles_config["role_attribute"]].value}
return {
"authenticated": True,
"role": conn.entries[0][roles_config["role_attribute"]].value,
}
else:
return {"authenticated": True, "role": ""}

View File

@ -38,6 +38,7 @@
#include "helpers.hpp"
#include "py/py.hpp"
#include "query/auth_checker.hpp"
#include "query/db_accessor.hpp"
#include "query/discard_value_stream.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/ast/ast.hpp"
@ -746,6 +747,23 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
}
memgraph::auth::User GetUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
return user.value();
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void GrantPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) override {
@ -1250,6 +1268,7 @@ int main(int argc, char **argv) {
AuthChecker auth_checker{&auth};
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
// interpreter_context.label_checker = &label_checker;
{
// Triggers can execute query procedures, so we need to reload the modules first and then

View File

@ -13,8 +13,10 @@
#include <type_traits>
#include "auth/models.hpp"
#include "query/common.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/label_checker.hpp"
#include "query/metadata.hpp"
#include "query/parameters.hpp"
#include "query/plan/profile.hpp"
@ -72,6 +74,7 @@ struct ExecutionContext {
ExecutionStats execution_stats;
TriggerContextCollector *trigger_context_collector{nullptr};
utils::AsyncTimer timer;
LabelChecker *label_checker{nullptr};
};
static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!");

View File

@ -36,6 +36,7 @@
#include "query/frontend/semantic/required_privileges.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
#include "query/interpret/eval.hpp"
#include "query/label_checker.hpp"
#include "query/metadata.hpp"
#include "query/plan/planner.hpp"
#include "query/plan/profile.hpp"
@ -257,8 +258,24 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
private:
storage::Storage *db_;
};
/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.
class LabelChecker final : public memgraph::query::LabelChecker {
public:
explicit LabelChecker(memgraph::auth::User *user, memgraph::query::DbAccessor *dba) : user_{user}, dba_(dba) {}
bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &labels) const final {
const auto user_label_permissions = user_->GetLabelPermissions();
auto *dba = dba_;
return std::all_of(labels.begin(), labels.end(), [&user_label_permissions, dba](const auto label) {
return user_label_permissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT;
});
}
private:
memgraph::auth::User *user_;
memgraph::query::DbAccessor *dba_;
};
Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Parameters &parameters,
DbAccessor *db_accessor) {
@ -271,6 +288,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
@ -292,10 +310,11 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa
AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE,
AuthQuery::Action::SHOW_ROLE_FOR_USER};
if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
throw utils::BasicException(
utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features"));
}
// if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) {
// throw utils::BasicException(
// utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication
// features"));
// }
switch (auth_query->action_) {
case AuthQuery::Action::CREATE_USER:
@ -897,7 +916,7 @@ struct PullPlanVector {
struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
@ -926,7 +945,8 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector,
const std::optional<size_t> memory_limit)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory),
@ -937,6 +957,13 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.evaluation_context.parameters = parameters;
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
if (username.has_value()) {
memgraph::auth::User user = interpreter_context->auth->GetUser(*username);
LabelChecker label_checker{&user, dba};
ctx_.label_checker = &label_checker;
}
if (interpreter_context->config.execution_timeout_sec > 0) {
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
}
@ -1110,6 +1137,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username,
TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
@ -1118,6 +1146,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parsed_query.parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::View::OLD);
const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_);
if (memory_limit) {
@ -1153,8 +1182,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, trigger_context_collector, memory_limit);
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
@ -1214,7 +1244,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::string *username) {
const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1265,11 +1296,12 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
auto rw_type_checker = plan::ReadWriteTypeChecker();
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));
auto optional_username = StringPointerToOptional(username);
return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit,
summary, dba, interpreter_context, execution_memory, memory_limit, optional_username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
@ -1278,7 +1310,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, nullptr, memory_limit)
execution_memory, optional_username, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
@ -1413,7 +1445,7 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
@ -1433,8 +1465,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory);
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
@ -2147,7 +2179,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications,
&query_execution->notifications, username,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
@ -2155,7 +2187,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory);
@ -2165,7 +2197,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db,

View File

@ -13,6 +13,7 @@
#include <gflags/gflags.h>
#include "auth/models.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/context.hpp"
@ -90,6 +91,9 @@ class AuthQueryHandler {
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual memgraph::auth::User GetUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;

View File

@ -0,0 +1,23 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "auth/models.hpp"
#include "query/frontend/ast/ast.hpp"
#include "storage/v2/id_types.hpp"
namespace memgraph::query {
class LabelChecker {
public:
virtual bool IsUserAuthorized(const std::vector<memgraph::storage::LabelId> &label) const = 0;
};
} // namespace memgraph::query

View File

@ -404,10 +404,17 @@ class ScanAllCursor : public Cursor {
vertices_.emplace(std::move(next_vertices.value()));
vertices_it_.emplace(vertices_.value().begin());
}
frame[output_symbol_] = *vertices_it_.value();
++vertices_it_.value();
return true;
while (vertices_it_.value() != vertices_.value().end()) {
VertexAccessor vector = *vertices_it_.value();
auto labels = vector.Labels(memgraph::storage::View::NEW).GetValue();
if (context.label_checker->IsUserAuthorized(labels) || !context.label_checker) {
frame[output_symbol_] = *vertices_it_.value();
++vertices_it_.value();
return true;
}
++vertices_it_.value();
}
return false;
}
void Shutdown() override { input_cursor_->Shutdown(); }