Add SHA-256 password encryption (#839)

This commit is contained in:
Josipmrden 2023-04-03 16:29:21 +02:00 committed by GitHub
parent f5a49ed29f
commit 128771a6ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 332 additions and 73 deletions

View File

@ -16,6 +16,10 @@ target_link_libraries(mg-auth mg-utils mg-kvstore mg-license )
target_link_libraries(mg-auth ${Seccomp_LIBRARIES})
target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS})
find_package(OpenSSL REQUIRED)
target_link_libraries(mg-auth ${OPENSSL_LIBRARIES})
target_include_directories(mg-auth SYSTEM PUBLIC ${OPENSSL_INCLUDE_DIR})
# Install reference auth modules and their configuration files.
install(PROGRAMS ${CMAKE_CURRENT_SOURCE_DIR}/reference_modules/example.py
DESTINATION lib/memgraph/auth_module)

View File

@ -1,19 +1,64 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
// this file except in compliance with the License. You may obtain a copy of the License at https://memgraph.com/legal.
//
//
#include "auth/crypto.hpp"
#include <iomanip>
#include <sstream>
#include <gflags/gflags.h>
#include <libbcrypt/bcrypt.h>
#include <openssl/evp.h>
#include <openssl/opensslv.h>
#include <openssl/sha.h>
#include "auth/exceptions.hpp"
#include "utils/enum.hpp"
#include "utils/flag_validation.hpp"
namespace {
using namespace std::literals;
inline constexpr std::array password_encryption_mappings{
std::pair{"bcrypt"sv, memgraph::auth::PasswordEncryptionAlgorithm::BCRYPT},
std::pair{"sha256"sv, memgraph::auth::PasswordEncryptionAlgorithm::SHA256},
std::pair{"sha256-multiple"sv, memgraph::auth::PasswordEncryptionAlgorithm::SHA256_MULTIPLE}};
inline constexpr uint64_t ONE_SHA_ITERATION = 1;
inline constexpr uint64_t MULTIPLE_SHA_ITERATIONS = 1024;
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,misc-unused-parameters)
DEFINE_VALIDATED_string(password_encryption_algorithm, "bcrypt",
"The password encryption algorithm used for authentication.", {
if (const auto result =
memgraph::utils::IsValidEnumValueString(value, password_encryption_mappings);
result.HasError()) {
const auto error = result.GetError();
switch (error) {
case memgraph::utils::ValidationError::EmptyValue: {
std::cout << "Password encryption algorithm cannot be empty." << std::endl;
break;
}
case memgraph::utils::ValidationError::InvalidValue: {
std::cout << "Invalid value for password encryption algorithm. Allowed values: "
<< memgraph::utils::GetAllowedEnumValuesString(password_encryption_mappings)
<< std::endl;
break;
}
}
return false;
}
return true;
});
namespace memgraph::auth {
const std::string EncryptPassword(const std::string &password) {
namespace BCrypt {
std::string EncryptPassword(const std::string &password) {
char salt[BCRYPT_HASHSIZE];
char hash[BCRYPT_HASHSIZE];
@ -28,7 +73,7 @@ const std::string EncryptPassword(const std::string &password) {
throw AuthException("Couldn't hash password!");
}
return std::string(hash);
return {hash};
}
bool VerifyPassword(const std::string &password, const std::string &hash) {
@ -38,5 +83,102 @@ bool VerifyPassword(const std::string &password, const std::string &hash) {
}
return ret == 0;
}
} // namespace BCrypt
namespace SHA {
#if OPENSSL_VERSION_MAJOR >= 3
std::string EncryptPasswordOpenSSL3(const std::string &password, const uint64_t number_of_iterations) {
unsigned char hash[SHA256_DIGEST_LENGTH];
EVP_MD_CTX *ctx = EVP_MD_CTX_new();
EVP_MD *md = EVP_MD_fetch(nullptr, "SHA2-256", nullptr);
EVP_DigestInit_ex(ctx, md, nullptr);
for (auto i = 0; i < number_of_iterations; i++) {
EVP_DigestUpdate(ctx, password.c_str(), password.size());
}
EVP_DigestFinal_ex(ctx, hash, nullptr);
EVP_MD_free(md);
EVP_MD_CTX_free(ctx);
std::stringstream result_stream;
for (auto hash_char : hash) {
result_stream << std::hex << std::setw(2) << std::setfill('0') << (int)hash_char;
}
return result_stream.str();
}
#else
std::string EncryptPasswordOpenSSL1_1(const std::string &password, const uint64_t number_of_iterations) {
unsigned char hash[SHA256_DIGEST_LENGTH];
SHA256_CTX sha256;
SHA256_Init(&sha256);
for (auto i = 0; i < number_of_iterations; i++) {
SHA256_Update(&sha256, password.c_str(), password.size());
}
SHA256_Final(hash, &sha256);
std::stringstream ss;
for (auto hash_char : hash) {
ss << std::hex << std::setw(2) << std::setfill('0') << (int)hash_char;
}
return ss.str();
}
#endif
std::string EncryptPassword(const std::string &password, const uint64_t number_of_iterations) {
#if OPENSSL_VERSION_MAJOR >= 3
return EncryptPasswordOpenSSL3(password, number_of_iterations);
#else
return EncryptPasswordOpenSSL1_1(password, number_of_iterations);
#endif
}
bool VerifyPassword(const std::string &password, const std::string &hash, const uint64_t number_of_iterations) {
auto password_hash = EncryptPassword(password, number_of_iterations);
return password_hash == hash;
}
} // namespace SHA
bool VerifyPassword(const std::string &password, const std::string &hash) {
const auto password_encryption_algorithm = utils::StringToEnum<PasswordEncryptionAlgorithm>(
FLAGS_password_encryption_algorithm, password_encryption_mappings);
if (!password_encryption_algorithm.has_value()) {
throw AuthException("Invalid password encryption flag '{}'!", FLAGS_password_encryption_algorithm);
}
switch (password_encryption_algorithm.value()) {
case PasswordEncryptionAlgorithm::BCRYPT:
return BCrypt::VerifyPassword(password, hash);
case PasswordEncryptionAlgorithm::SHA256:
return SHA::VerifyPassword(password, hash, ONE_SHA_ITERATION);
case PasswordEncryptionAlgorithm::SHA256_MULTIPLE:
return SHA::VerifyPassword(password, hash, MULTIPLE_SHA_ITERATIONS);
}
throw AuthException("Invalid password encryption flag '{}'!", FLAGS_password_encryption_algorithm);
}
std::string EncryptPassword(const std::string &password) {
const auto password_encryption_algorithm = utils::StringToEnum<PasswordEncryptionAlgorithm>(
FLAGS_password_encryption_algorithm, password_encryption_mappings);
if (!password_encryption_algorithm.has_value()) {
throw AuthException("Invalid password encryption flag '{}'!", FLAGS_password_encryption_algorithm);
}
switch (password_encryption_algorithm.value()) {
case PasswordEncryptionAlgorithm::BCRYPT:
return BCrypt::EncryptPassword(password);
case PasswordEncryptionAlgorithm::SHA256:
return SHA::EncryptPassword(password, ONE_SHA_ITERATION);
case PasswordEncryptionAlgorithm::SHA256_MULTIPLE:
return SHA::EncryptPassword(password, MULTIPLE_SHA_ITERATIONS);
}
}
} // namespace memgraph::auth

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -11,10 +11,11 @@
#include <string>
namespace memgraph::auth {
enum class PasswordEncryptionAlgorithm : uint8_t { BCRYPT, SHA256, SHA256_MULTIPLE };
/// @throw AuthException if unable to encrypt the password.
const std::string EncryptPassword(const std::string &password);
std::string EncryptPassword(const std::string &password);
/// @throw AuthException if unable to verify the password.
bool VerifyPassword(const std::string &password, const std::string &hash);
} // namespace memgraph::auth

View File

@ -57,6 +57,7 @@
#include "storage/v2/storage.hpp"
#include "storage/v2/view.hpp"
#include "telemetry/telemetry.hpp"
#include "utils/enum.hpp"
#include "utils/event_counter.hpp"
#include "utils/file.hpp"
#include "utils/flag_validation.hpp"
@ -103,42 +104,6 @@ constexpr const char *kMgUser = "MEMGRAPH_USER";
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
constexpr const char *kMgPassfile = "MEMGRAPH_PASSFILE";
namespace {
std::string GetAllowedEnumValuesString(const auto &mappings) {
std::vector<std::string> allowed_values;
allowed_values.reserve(mappings.size());
std::transform(mappings.begin(), mappings.end(), std::back_inserter(allowed_values),
[](const auto &mapping) { return std::string(mapping.first); });
return memgraph::utils::Join(allowed_values, ", ");
}
enum class ValidationError : uint8_t { EmptyValue, InvalidValue };
memgraph::utils::BasicResult<ValidationError> IsValidEnumValueString(const auto &value, const auto &mappings) {
if (value.empty()) {
return ValidationError::EmptyValue;
}
if (std::find_if(mappings.begin(), mappings.end(), [&](const auto &mapping) { return mapping.first == value; }) ==
mappings.cend()) {
return ValidationError::InvalidValue;
}
return {};
}
template <typename Enum>
std::optional<Enum> StringToEnum(const auto &value, const auto &mappings) {
const auto mapping_iter =
std::find_if(mappings.begin(), mappings.end(), [&](const auto &mapping) { return mapping.first == value; });
if (mapping_iter == mappings.cend()) {
return std::nullopt;
}
return mapping_iter->second;
}
} // namespace
// Short help flag.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_HIDDEN_bool(h, false, "Print usage and exit.");
@ -290,21 +255,21 @@ inline constexpr std::array isolation_level_mappings{
const std::string isolation_level_help_string =
fmt::format("Default isolation level used for the transactions. Allowed values: {}",
GetAllowedEnumValuesString(isolation_level_mappings));
memgraph::utils::GetAllowedEnumValuesString(isolation_level_mappings));
} // namespace
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_string(isolation_level, "SNAPSHOT_ISOLATION", isolation_level_help_string.c_str(), {
if (const auto result = IsValidEnumValueString(value, isolation_level_mappings); result.HasError()) {
if (const auto result = memgraph::utils::IsValidEnumValueString(value, isolation_level_mappings); result.HasError()) {
const auto error = result.GetError();
switch (error) {
case ValidationError::EmptyValue: {
case memgraph::utils::ValidationError::EmptyValue: {
std::cout << "Isolation level cannot be empty." << std::endl;
break;
}
case ValidationError::InvalidValue: {
case memgraph::utils::ValidationError::InvalidValue: {
std::cout << "Invalid value for isolation level. Allowed values: "
<< GetAllowedEnumValuesString(isolation_level_mappings) << std::endl;
<< memgraph::utils::GetAllowedEnumValuesString(isolation_level_mappings) << std::endl;
break;
}
}
@ -317,7 +282,7 @@ DEFINE_VALIDATED_string(isolation_level, "SNAPSHOT_ISOLATION", isolation_level_h
namespace {
memgraph::storage::IsolationLevel ParseIsolationLevel() {
const auto isolation_level =
StringToEnum<memgraph::storage::IsolationLevel>(FLAGS_isolation_level, isolation_level_mappings);
memgraph::utils::StringToEnum<memgraph::storage::IsolationLevel>(FLAGS_isolation_level, isolation_level_mappings);
MG_ASSERT(isolation_level, "Invalid isolation level");
return *isolation_level;
}
@ -377,21 +342,21 @@ inline constexpr std::array log_level_mappings{
std::pair{"INFO"sv, spdlog::level::info}, std::pair{"WARNING"sv, spdlog::level::warn},
std::pair{"ERROR"sv, spdlog::level::err}, std::pair{"CRITICAL"sv, spdlog::level::critical}};
const std::string log_level_help_string =
fmt::format("Minimum log level. Allowed values: {}", GetAllowedEnumValuesString(log_level_mappings));
const std::string log_level_help_string = fmt::format("Minimum log level. Allowed values: {}",
memgraph::utils::GetAllowedEnumValuesString(log_level_mappings));
} // namespace
DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
if (const auto result = IsValidEnumValueString(value, log_level_mappings); result.HasError()) {
if (const auto result = memgraph::utils::IsValidEnumValueString(value, log_level_mappings); result.HasError()) {
const auto error = result.GetError();
switch (error) {
case ValidationError::EmptyValue: {
case memgraph::utils::ValidationError::EmptyValue: {
std::cout << "Log level cannot be empty." << std::endl;
break;
}
case ValidationError::InvalidValue: {
std::cout << "Invalid value for log level. Allowed values: " << GetAllowedEnumValuesString(log_level_mappings)
<< std::endl;
case memgraph::utils::ValidationError::InvalidValue: {
std::cout << "Invalid value for log level. Allowed values: "
<< memgraph::utils::GetAllowedEnumValuesString(log_level_mappings) << std::endl;
break;
}
}
@ -403,7 +368,7 @@ DEFINE_VALIDATED_string(log_level, "WARNING", log_level_help_string.c_str(), {
namespace {
spdlog::level::level_enum ParseLogLevel() {
const auto log_level = StringToEnum<spdlog::level::level_enum>(FLAGS_log_level, log_level_mappings);
const auto log_level = memgraph::utils::StringToEnum<spdlog::level::level_enum>(FLAGS_log_level, log_level_mappings);
MG_ASSERT(log_level, "Invalid log level");
return *log_level;
}

57
src/utils/enum.hpp Normal file
View File

@ -0,0 +1,57 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include <string>
#include "utils/result.hpp"
#include "utils/string.hpp"
namespace memgraph::utils {
enum class ValidationError : uint8_t { EmptyValue, InvalidValue };
// Returns joined string representations for every enum in the mapping.
std::string GetAllowedEnumValuesString(const auto &mappings) {
std::vector<std::string> allowed_values;
allowed_values.reserve(mappings.size());
std::transform(mappings.begin(), mappings.end(), std::back_inserter(allowed_values),
[](const auto &mapping) { return std::string(mapping.first); });
return memgraph::utils::Join(allowed_values, ", ");
}
// Checks if the string value can be represented as an enum.
// If not, the BasicResult will contain an error.
memgraph::utils::BasicResult<ValidationError> IsValidEnumValueString(const auto &value, const auto &mappings) {
if (value.empty()) {
return ValidationError::EmptyValue;
}
if (std::find_if(mappings.begin(), mappings.end(), [&](const auto &mapping) { return mapping.first == value; }) ==
mappings.cend()) {
return ValidationError::InvalidValue;
}
return {};
}
// Tries to convert a string into enum, which would then contain a value if the conversion
// has been successful.
template <typename Enum>
std::optional<Enum> StringToEnum(const auto &value, const auto &mappings) {
const auto mapping_iter =
std::find_if(mappings.begin(), mappings.end(), [&](const auto &mapping) { return mapping.first == value; });
if (mapping_iter == mappings.cend()) {
return std::nullopt;
}
return mapping_iter->second;
}
} // namespace memgraph::utils

View File

@ -98,6 +98,7 @@ startup_config_dict = {
"IP address on which the websocket server for Memgraph monitoring should listen.",
),
"monitoring_port": ("7444", "7444", "Port on which the websocket server for Memgraph monitoring should listen."),
"password_encryption_algorithm": ("bcrypt", "bcrypt", "The password encryption algorithm used for authentication."),
"pulsar_service_url": ("", "", "Default URL used while connecting to Pulsar brokers."),
"query_execution_timeout_sec": (
"600",

View File

@ -337,6 +337,8 @@ void ExecuteWorkload(
std::vector<std::thread> threads;
threads.reserve(FLAGS_num_workers);
auto total_time_start = std::chrono::steady_clock::now();
std::vector<uint64_t> worker_retries(FLAGS_num_workers, 0);
std::vector<Metadata> worker_metadata(FLAGS_num_workers, Metadata());
std::vector<double> worker_duration(FLAGS_num_workers, 0.0);
@ -398,8 +400,12 @@ void ExecuteWorkload(
final_duration += worker_duration[i];
}
auto total_time_end = std::chrono::steady_clock::now();
auto total_time = std::chrono::duration_cast<std::chrono::duration<double>>(total_time_end - total_time_start);
final_duration /= FLAGS_num_workers;
nlohmann::json summary = nlohmann::json::object();
summary["total_time"] = total_time.count();
summary["count"] = queries.size();
summary["duration"] = final_duration;
summary["throughput"] = static_cast<double>(queries.size()) / final_duration;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -28,6 +28,7 @@ namespace fs = std::filesystem;
DECLARE_bool(auth_password_permit_null);
DECLARE_string(auth_password_strength_regex);
DECLARE_string(password_encryption_algorithm);
class AuthWithStorage : public ::testing::Test {
protected:
@ -55,23 +56,22 @@ TEST_F(AuthWithStorage, AddRole) {
TEST_F(AuthWithStorage, RemoveRole) {
ASSERT_TRUE(auth.AddRole("admin"));
ASSERT_TRUE(auth.RemoveRole("admin"));
ASSERT_FALSE(auth.RemoveRole("user"));
ASSERT_FALSE(auth.RemoveRole("user"));
}
class AuthWithStorage : public ::testing::Test {
protected:
virtual void SetUp() {
memgraph::utils::EnsureDir(test_folder_);
FLAGS_auth_password_permit_null = true;
FLAGS_auth_password_strength_regex = ".+";
TEST_F(AuthWithStorage, AddUser) {
ASSERT_FALSE(auth.HasUsers());
ASSERT_TRUE(auth.AddUser("test"));
ASSERT_TRUE(auth.HasUsers());
ASSERT_TRUE(auth.AddUser("test2"));
ASSERT_FALSE(auth.AddUser("test"));
}
memgraph::license::global_license_checker.EnableTesting();
}
TEST_F(AuthWithStorage, RemoveUser) {
ASSERT_FALSE(auth.HasUsers());
ASSERT_TRUE(auth.AddUser("test"));
ASSERT_TRUE(auth.HasUsers());
ASSERT_TRUE(auth.RemoveUser("test"));
virtual void TearDown() { fs::remove_all(test_folder_); }
fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"};
Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast<int>(getpid())))};
};
ASSERT_FALSE(auth.HasUsers());
ASSERT_FALSE(auth.RemoveUser("test2"));
ASSERT_FALSE(auth.RemoveUser("test"));
@ -926,3 +926,86 @@ TEST(AuthWithoutStorage, Crypto) {
ASSERT_TRUE(VerifyPassword("hello", hash));
ASSERT_FALSE(VerifyPassword("hello1", hash));
}
class AuthWithVariousEncryptionAlgorithms : public ::testing::Test {
protected:
virtual void SetUp() { FLAGS_password_encryption_algorithm = "bcrypt"; }
};
TEST_F(AuthWithVariousEncryptionAlgorithms, VerifyPasswordDefault) {
auto hash = EncryptPassword("hello");
ASSERT_TRUE(VerifyPassword("hello", hash));
ASSERT_FALSE(VerifyPassword("hello1", hash));
}
TEST_F(AuthWithVariousEncryptionAlgorithms, VerifyPasswordSHA256) {
FLAGS_password_encryption_algorithm = "sha256";
auto hash = EncryptPassword("hello");
ASSERT_TRUE(VerifyPassword("hello", hash));
ASSERT_FALSE(VerifyPassword("hello1", hash));
}
TEST_F(AuthWithVariousEncryptionAlgorithms, VerifyPasswordSHA256_1024) {
FLAGS_password_encryption_algorithm = "sha256-multiple";
auto hash = EncryptPassword("hello");
ASSERT_TRUE(VerifyPassword("hello", hash));
ASSERT_FALSE(VerifyPassword("hello1", hash));
}
TEST_F(AuthWithVariousEncryptionAlgorithms, VerifyPasswordThrow) {
FLAGS_password_encryption_algorithm = "abcd";
ASSERT_THROW(EncryptPassword("hello"), AuthException);
}
TEST_F(AuthWithVariousEncryptionAlgorithms, VerifyPasswordEmptyEncryptionThrow) {
FLAGS_password_encryption_algorithm = "";
ASSERT_THROW(EncryptPassword("hello"), AuthException);
}
class AuthWithStorageWithVariousEncryptionAlgorithms : public ::testing::Test {
protected:
virtual void SetUp() {
memgraph::utils::EnsureDir(test_folder_);
FLAGS_auth_password_permit_null = true;
FLAGS_auth_password_strength_regex = ".+";
FLAGS_password_encryption_algorithm = "bcrypt";
memgraph::license::global_license_checker.EnableTesting();
}
virtual void TearDown() { fs::remove_all(test_folder_); }
fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"};
Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast<int>(getpid())))};
};
TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserDefault) {
auto user = auth.AddUser("Alice", "alice");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
}
TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserSha256) {
FLAGS_password_encryption_algorithm = "sha256";
auto user = auth.AddUser("Alice", "alice");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
}
TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserSha256_1024) {
FLAGS_password_encryption_algorithm = "sha256-multiple";
auto user = auth.AddUser("Alice", "alice");
ASSERT_TRUE(user);
ASSERT_EQ(user->username(), "alice");
}
TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserThrow) {
FLAGS_password_encryption_algorithm = "abcd";
ASSERT_THROW(auth.AddUser("Alice", "alice"), AuthException);
}
TEST_F(AuthWithStorageWithVariousEncryptionAlgorithms, AddUserEmptyPasswordEncryptionThrow) {
FLAGS_password_encryption_algorithm = "";
ASSERT_THROW(auth.AddUser("Alice", "alice"), AuthException);
}