diff --git a/src/auth/CMakeLists.txt b/src/auth/CMakeLists.txt index 4bd867a02..4861cd049 100644 --- a/src/auth/CMakeLists.txt +++ b/src/auth/CMakeLists.txt @@ -1,10 +1,15 @@ set(auth_src_files auth.cpp crypto.cpp - models.cpp) + models.cpp + module.cpp) find_package(Ldap REQUIRED) +find_package(Seccomp REQUIRED) add_library(mg-auth STATIC ${auth_src_files}) target_link_libraries(mg-auth json libbcrypt glog gflags fmt ldap) target_link_libraries(mg-auth mg-utils) + +target_link_libraries(mg-auth ${Seccomp_LIBRARIES}) +target_include_directories(mg-auth SYSTEM PRIVATE ${Seccomp_INCLUDE_DIRS}) diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 62b7c234e..90d371882 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -1,6 +1,7 @@ #include "auth/auth.hpp" #include <cstring> +#include <iostream> #include <limits> #include <utility> @@ -34,6 +35,31 @@ DEFINE_bool(auth_ldap_create_role, true, DEFINE_string(auth_ldap_role_mapping_root_dn, "", "Set this value to the DN that contains all role mappings."); +DEFINE_VALIDATED_string( + auth_module_executable, "", + "Absolute path to the auth module executable that should be used.", { + if (value.empty()) return true; + // Check the file status, following symlinks. + auto status = std::filesystem::status(value); + if (!std::filesystem::is_regular_file(status)) { + std::cerr << "The auth module path doesn't exist or isn't a file!" + << std::endl; + return false; + } + return true; + }); +DEFINE_bool(auth_module_create_missing_user, true, + "Set to false to disable creation of missing users."); +DEFINE_bool(auth_module_create_missing_role, true, + "Set to false to disable creation of missing roles."); +DEFINE_bool( + auth_module_manage_roles, true, + "Set to false to disable management of roles through the auth module."); +DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000, + "Timeout (in milliseconds) used when waiting for a " + "response from the auth module.", + FLAG_IN_RANGE(100, 1800000)); + namespace auth { const std::string kUserPrefix = "user:"; @@ -83,7 +109,7 @@ void Init() { } Auth::Auth(const std::string &storage_directory) - : storage_(storage_directory) {} + : storage_(storage_directory), module_(FLAGS_auth_module_executable) {} /// Converts a `std::string` to a `struct berval`. std::pair<std::unique_ptr<char[]>, struct berval> LdapConvertString( @@ -197,7 +223,78 @@ std::optional<std::string> LdapFindRole(LDAP *ld, std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) { - if (FLAGS_auth_ldap_enabled) { + if (module_.IsUsed()) { + nlohmann::json params = nlohmann::json::object(); + params["username"] = username; + params["password"] = password; + + auto ret = module_.Call(params, FLAGS_auth_module_timeout_ms); + + // Verify response integrity. + if (!ret.is_object() || ret.find("authenticated") == ret.end() || + ret.find("role") == ret.end()) { + return std::nullopt; + } + const auto &ret_authenticated = ret.at("authenticated"); + const auto &ret_role = ret.at("role"); + if (!ret_authenticated.is_boolean() || !ret_role.is_string()) { + return std::nullopt; + } + auto is_authenticated = ret_authenticated.get<bool>(); + const auto &rolename = ret_role.get<std::string>(); + + // Authenticate the user. + if (!is_authenticated) return std::nullopt; + + // Find or create the user and return it. + auto user = GetUser(username); + if (!user) { + if (FLAGS_auth_module_create_missing_user) { + user = AddUser(username, password); + if (!user) { + LOG(WARNING) << "Couldn't authenticate user '" << username + << "' using the auth module because the user already " + "exists as a role!"; + return std::nullopt; + } + } else { + LOG(WARNING) + << "Couldn't authenticate user '" << username + << "' using the auth module because the user doesn't exist!"; + return std::nullopt; + } + } else { + user->UpdatePassword(password); + } + if (FLAGS_auth_module_manage_roles) { + if (!rolename.empty()) { + auto role = GetRole(rolename); + if (!role) { + if (FLAGS_auth_module_create_missing_role) { + role = AddRole(rolename); + if (!role) { + LOG(WARNING) + << "Couldn't authenticate user '" << username + << "' using the auth module because the user's role '" + << rolename << "' already exists as a user!"; + return std::nullopt; + } + SaveRole(*role); + } else { + LOG(WARNING) << "Couldn't authenticate user '" << username + << "' using the auth module because the user's role '" + << rolename << "' doesn't exist!"; + return std::nullopt; + } + } + user->SetRole(*role); + } else { + user->ClearRole(); + } + } + SaveUser(*user); + return user; + } else if (FLAGS_auth_ldap_enabled) { LDAP *ld = nullptr; // Initialize the LDAP struct. diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index a04d410a8..ee603f8cc 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -6,6 +6,7 @@ #include "auth/exceptions.hpp" #include "auth/models.hpp" +#include "auth/module.hpp" #include "storage/common/kvstore/kvstore.hpp" namespace auth { @@ -23,6 +24,8 @@ void Init(); * It provides functions for managing Users, Roles and Permissions. * NOTE: The functions in this class aren't thread safe. Use the `WithLock` lock * if you want to have safe modifications of the storage. + * TODO (mferencevic): Disable user/role modification functions when they are + * being managed by the auth module. */ class Auth final { public: @@ -150,6 +153,7 @@ class Auth final { private: storage::KVStore storage_; + auth::Module module_; // Even though the `storage::KVStore` class is guaranteed to be thread-safe we // use a mutex to lock all operations on the `User` and `Role` storage because // some operations on the users and/or roles may require more than one diff --git a/src/auth/module.cpp b/src/auth/module.cpp new file mode 100644 index 000000000..21777766a --- /dev/null +++ b/src/auth/module.cpp @@ -0,0 +1,473 @@ +#include "auth/module.hpp" + +#include <cerrno> +#include <chrono> +#include <csignal> +#include <cstdlib> +#include <cstring> +#include <thread> + +#include <fcntl.h> +#include <libgen.h> +#include <linux/limits.h> +#include <poll.h> +#include <pwd.h> +#include <sched.h> +#include <seccomp.h> +#include <sys/resource.h> +#include <sys/time.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +#include <fmt/format.h> +#include <gflags/gflags.h> +#include <glog/logging.h> + +#include "utils/file.hpp" + +namespace { + +///////////////////////////////////////////////////////////////////////// +// Constants used for starting and communicating with the target process. +///////////////////////////////////////////////////////////////////////// + +const int kPipeReadEnd = 0; +const int kPipeWriteEnd = 1; + +const int kCommunicationToModuleFd = 1000; +const int kCommunicationFromModuleFd = 1001; + +const int kTerminateTimeoutSec = 5; + +//////////////////// +// Helper functions. +//////////////////// + +std::filesystem::path GetTemporaryPath(pid_t pid) { + return std::filesystem::temp_directory_path() / "memgraph" / + fmt::format("auth_module_{}", pid); +} + +std::string GetEnvironmentVariable(const std::string &name) { + char *value = secure_getenv(name.c_str()); + if (value == nullptr) { + return fmt::format("{}=", name); + } + return fmt::format("{}={}", name, value); +} + +/////////////////////////////////////////// +// char** wrapper used for C library calls. +/////////////////////////////////////////// + +const int kCharppMaxElements = 20; + +class CharPP final { + public: + CharPP() { memset(data_, 0, sizeof(char *) * kCharppMaxElements); } + + ~CharPP() { + for (size_t i = 0; i < size_; ++i) { + free(data_[i]); + } + } + + CharPP(const CharPP &) = delete; + CharPP(CharPP &&) = delete; + CharPP &operator=(const CharPP &) = delete; + CharPP &operator=(CharPP &&) = delete; + + void Add(const char *value) { + if (size_ == kCharppMaxElements) return; + int len = strlen(value); + char *item = static_cast<char *>(malloc(sizeof(char) * (len + 1))); + if (item == nullptr) return; + memcpy(item, value, len); + item[len] = 0; + data_[size_++] = item; + } + + void Add(const std::string &value) { Add(value.c_str()); } + + char **Get() { return data_; } + + private: + char *data_[kCharppMaxElements]; + size_t size_{0}; +}; + +//////////////////////////////////// +// Security functions and constants. +//////////////////////////////////// + +const std::vector<int> kSeccompSyscallsBlacklist = { + SCMP_SYS(mknod), + SCMP_SYS(mount), + SCMP_SYS(setuid), + SCMP_SYS(stime), + SCMP_SYS(ptrace), + SCMP_SYS(setgid), + SCMP_SYS(acct), + SCMP_SYS(umount), + SCMP_SYS(setpgid), + SCMP_SYS(chroot), + SCMP_SYS(setreuid), + SCMP_SYS(setregid), + SCMP_SYS(sethostname), + SCMP_SYS(settimeofday), + SCMP_SYS(setgroups), + SCMP_SYS(swapon), + SCMP_SYS(reboot), + SCMP_SYS(setpriority), + SCMP_SYS(ioperm), + SCMP_SYS(syslog), + SCMP_SYS(iopl), + SCMP_SYS(vhangup), + SCMP_SYS(vm86old), + SCMP_SYS(swapoff), + SCMP_SYS(setdomainname), + SCMP_SYS(adjtimex), + SCMP_SYS(init_module), + SCMP_SYS(delete_module), + SCMP_SYS(setfsuid), + SCMP_SYS(setfsgid), + SCMP_SYS(setresuid), + SCMP_SYS(vm86), + SCMP_SYS(setresgid), + SCMP_SYS(capset), + SCMP_SYS(setreuid), + SCMP_SYS(setregid), + SCMP_SYS(setgroups), + SCMP_SYS(setresuid), + SCMP_SYS(setresgid), + SCMP_SYS(setuid), + SCMP_SYS(setgid), + SCMP_SYS(setfsuid), + SCMP_SYS(setfsgid), + SCMP_SYS(pivot_root), + SCMP_SYS(sched_setaffinity), + SCMP_SYS(clock_settime), + SCMP_SYS(kexec_load), + SCMP_SYS(mknodat), + SCMP_SYS(unshare), + SCMP_SYS(seccomp), +}; + +bool SetupSeccomp() { + // Initialize the seccomp context. + scmp_filter_ctx ctx; + ctx = seccomp_init(SCMP_ACT_ALLOW); + if (ctx == nullptr) return false; + + // Add all general blacklist rules. + for (auto syscall_num : kSeccompSyscallsBlacklist) { + if (seccomp_rule_add(ctx, SCMP_ACT_KILL, syscall_num, 0) != 0) { + seccomp_release(ctx); + return false; + } + } + + // Load the context for the current process. + auto ret = seccomp_load(ctx); + + // Free the context and return success/failure. + seccomp_release(ctx); + return ret == 0; +} + +bool SetLimit(int resource, rlim_t n) { + struct rlimit limit; + limit.rlim_cur = limit.rlim_max = n; + return setrlimit(resource, &limit) == 0; +} + +//////////////////////////////////////////////////// +// Target function used to start the module process. +//////////////////////////////////////////////////// + +int Target(void *arg) { + // NOTE: (D)LOG shouldn't be used here because it wasn't initialized in this + // process and something really bad could happen. + + // Get a pointer to the passed arguments. + auto *ta = reinterpret_cast<auth::TargetArguments *>(arg); + + // Redirect `stdin` to `/dev/null`. + int fd = open("/dev/null", O_RDONLY | O_CLOEXEC); + if (fd == -1) { + return EXIT_FAILURE; + } + if (dup2(fd, STDIN_FILENO) != STDIN_FILENO) { + return EXIT_FAILURE; + } + + // Create the working directory. + std::filesystem::path working_path = GetTemporaryPath(getpid()); + utils::DeleteDir(working_path); + if (!utils::EnsureDir(working_path)) { + return EXIT_FAILURE; + } + + // Change the current directory to the working directory. + if (chdir(working_path.c_str()) != 0) { + return EXIT_FAILURE; + } + + // Create the executable CharPP object. + CharPP exe; + exe.Add(ta->module_executable_path); + + // Create the environment CharPP object. + CharPP env; + env.Add(GetEnvironmentVariable("PATH")); + env.Add(GetEnvironmentVariable("USER")); + env.Add(GetEnvironmentVariable("LANG")); + env.Add(GetEnvironmentVariable("LANGUAGE")); + env.Add(GetEnvironmentVariable("HOME")); + + // Connect the communication input pipe. + if (dup2(ta->pipe_to_module, kCommunicationToModuleFd) != + kCommunicationToModuleFd) { + return EXIT_FAILURE; + } + + // Connect the communication output pipe. + if (dup2(ta->pipe_from_module, kCommunicationFromModuleFd) != + kCommunicationFromModuleFd) { + return EXIT_FAILURE; + } + + // Set process limits. + // Disable core dumps. + if (!SetLimit(RLIMIT_CORE, 0)) { + return EXIT_FAILURE; + } + + // Ignore SIGINT. + struct sigaction action; + // `sa_sigaction` must be cleared before `sa_handler` is set because on some + // platforms the two are a union. + action.sa_sigaction = nullptr; + action.sa_handler = SIG_IGN; + sigemptyset(&action.sa_mask); + action.sa_flags = 0; + if (sigaction(SIGINT, &action, nullptr) != 0) { + return EXIT_FAILURE; + } + + // Setup seccomp. + if (!SetupSeccomp()) { + return EXIT_FAILURE; + } + + execve(*exe.Get(), exe.Get(), env.Get()); + + return EXIT_FAILURE; +} + +///////////////////////////////////////////////////// +// Function used to send data to the started process. +///////////////////////////////////////////////////// + +/// The data that is being sent to the module process is always a newline +/// terminated JSON encoded string. + +bool PutData(int fd, const nlohmann::json &data, int timeout_millisec) { + std::string encoded; + try { + encoded = data.dump(); + } catch (const nlohmann::json::type_error &) { + return false; + } + + if (encoded.empty()) return false; + if (*encoded.rbegin() != '\n') { + encoded.push_back('\n'); + } + + size_t put = 0; + while (put < encoded.size()) { + struct pollfd desc; + desc.fd = fd; + desc.events = POLLOUT; + desc.revents = 0; + if (poll(&desc, 1, timeout_millisec) <= 0) { + return false; + } + int ret = write(fd, encoded.data() + put, encoded.size() - put); + if (ret > 0) { + put += ret; + } else if (ret == 0 || errno != EINTR) { + return false; + } + } + return true; +} + +////////////////////////////////////////////////////// +// Function used to get data from the started process. +////////////////////////////////////////////////////// + +/// The data that is being received from the module process is always a newline +/// terminated JSON encoded string. The JSON encoded string must be in a single +/// line and all newline characters may only appear encoded as a part of a +/// character string. + +nlohmann::json GetData(int fd, int timeout_millisec) { + std::string data; + while (true) { + struct pollfd desc; + desc.fd = fd; + desc.events = POLLIN; + desc.revents = 0; + if (poll(&desc, 1, timeout_millisec) <= 0) { + return {}; + } + char ch; + int ret = read(fd, &ch, 1); + if (ret > 0) { + data += ch; + if (ch == '\n') break; + } else if (ret == 0 || errno != EINTR) { + return {}; + } + } + try { + return nlohmann::json::parse(data); + } catch (const nlohmann::json::parse_error &) { + return {}; + } +} + +} // namespace + +namespace auth { + +Module::Module(const std::string &module_executable_path) + : module_executable_path_(module_executable_path) {} + +bool Module::Startup() { + // Check whether the process is alive. + if (pid_ != -1 && waitpid(pid_, &status_, WNOHANG | WUNTRACED) == 0) { + return true; + } + + // Cleanup leftover state. + Shutdown(); + + // Setup communication pipes. + if (pipe2(pipe_to_module_, O_CLOEXEC) != 0) { + LOG(ERROR) << "Couldn't create communication pipe from the database to " + "the auth module!"; + return false; + } + if (pipe2(pipe_from_module_, O_CLOEXEC) != 0) { + LOG(ERROR) << "Couldn't create communication pipe from the auth module to " + "the database!"; + close(pipe_to_module_[kPipeReadEnd]); + close(pipe_to_module_[kPipeWriteEnd]); + return false; + } + + // Find the top of the stack. + uint8_t *stack_top = stack_.get() + kStackSizeBytes; + + // Set the target arguments. + target_arguments_->module_executable_path = module_executable_path_; + target_arguments_->pipe_to_module = pipe_to_module_[kPipeReadEnd]; + target_arguments_->pipe_from_module = pipe_from_module_[kPipeWriteEnd]; + + // Create the process. + pid_ = clone(Target, stack_top, CLONE_VFORK, target_arguments_.get()); + if (pid_ == -1) { + LOG(ERROR) << "Couldn't start the auth module process!"; + close(pipe_to_module_[kPipeReadEnd]); + close(pipe_to_module_[kPipeWriteEnd]); + close(pipe_from_module_[kPipeReadEnd]); + close(pipe_from_module_[kPipeWriteEnd]); + return false; + } + + // Close pipes that won't be used from the master process. + close(pipe_to_module_[kPipeReadEnd]); + close(pipe_from_module_[kPipeWriteEnd]); + + return true; +} + +nlohmann::json Module::Call(const nlohmann::json ¶ms, + int timeout_millisec) { + std::lock_guard<std::mutex> guard(lock_); + + if (!params.is_object()) return {}; + + // Ensure that the module is up and running. + if (!Startup()) return {}; + + // Put the request to the module process. + if (!PutData(pipe_to_module_[kPipeWriteEnd], params, timeout_millisec)) { + LOG(ERROR) << "Couldn't send data to the auth module process!"; + return {}; + } + + // Get the response from the module process. + auto ret = GetData(pipe_from_module_[kPipeReadEnd], timeout_millisec); + if (ret.is_null()) { + LOG(ERROR) << "Couldn't receive data from the auth module process!"; + return {}; + } + if (!ret.is_object()) { + LOG(ERROR) << "Data received from the auth module is of wrong type!"; + return {}; + } + return ret; +} + +bool Module::IsUsed() { return !module_executable_path_.empty(); } + +void Module::Shutdown() { + if (pid_ == -1) return; + + // Try to terminate the process gracefully in `kTerminateTimeoutSec`. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + for (int i = 0; i < kTerminateTimeoutSec * 10; ++i) { + LOG(INFO) << "Terminating the auth module process with pid " << pid_; + kill(pid_, SIGTERM); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + int ret = waitpid(pid_, &status_, WNOHANG | WUNTRACED); + if (ret == pid_ || ret == -1) { + break; + } + } + + // If the process is still alive, kill it and wait for it to die. + if (waitpid(pid_, &status_, WNOHANG | WUNTRACED) == 0) { + LOG(WARNING) << "Killing the auth module process with pid " << pid_; + kill(pid_, SIGKILL); + waitpid(pid_, &status_, 0); + } + + // Delete the working directory. + utils::DeleteDir(GetTemporaryPath(pid_)); + + // Close leftover open pipes. + // We have to be careful to close only the leftover open pipes (the + // pipe_to_module WriteEnd and pipe_from_module ReadEnd), the other two ends + // were closed in the function that created them because they aren't used from + // the master process (they are only used from the module process). + close(pipe_to_module_[kPipeWriteEnd]); + close(pipe_from_module_[kPipeReadEnd]); + + // Reset variables. + pid_ = -1; + status_ = 0; + pipe_to_module_[kPipeReadEnd] = -1; + pipe_to_module_[kPipeWriteEnd] = -1; + pipe_from_module_[kPipeReadEnd] = -1; + pipe_from_module_[kPipeWriteEnd] = -1; +} + +Module::~Module() { Shutdown(); } + +} // namespace auth diff --git a/src/auth/module.hpp b/src/auth/module.hpp new file mode 100644 index 000000000..6932d7a62 --- /dev/null +++ b/src/auth/module.hpp @@ -0,0 +1,65 @@ +/// @file +#pragma once + +#include <filesystem> +#include <map> +#include <mutex> +#include <string> + +#include <json/json.hpp> + +namespace auth { + +struct TargetArguments { + std::filesystem::path module_executable_path; + int pipe_to_module{-1}; + int pipe_from_module{-1}; +}; + +/// Wrapper around the module executable. +class Module final { + private: + const int kStackSizeBytes = 262144; + + public: + explicit Module(const std::string &module_executable_path); + + Module(const Module &) = delete; + Module(Module &&) = delete; + Module &operator=(const Module &) = delete; + Module &operator=(Module &&) = delete; + + /// Call the function in the module with the specified parameters and return + /// the response. + /// + /// @param parameters dict used to call the module function + /// @param timeout_millisec timeout in ms used for communication with the + /// module + /// @return dict retuned by module function + nlohmann::json Call(const nlohmann::json ¶ms, int timeout_millisec); + + /// This function returns a boolean value indicating whether the module has a + /// specified executable path and can thus be used. + /// + /// @return boolean indicating whether the module can be used + bool IsUsed(); + + ~Module(); + + private: + bool Startup(); + void Shutdown(); + + std::string module_executable_path_; + std::mutex lock_; + pid_t pid_{-1}; + int status_{0}; + // The stack used for the `clone` system call must be heap allocated. + std::unique_ptr<uint8_t[]> stack_{new uint8_t[kStackSizeBytes]}; + // The target arguments passed to the new process must be heap allocated. + std::unique_ptr<TargetArguments> target_arguments_{new TargetArguments()}; + int pipe_to_module_[2] = {-1, -1}; + int pipe_from_module_[2] = {-1, -1}; +}; + +} // namespace auth diff --git a/src/auth/reference_modules/example.py b/src/auth/reference_modules/example.py new file mode 100755 index 000000000..d40f34892 --- /dev/null +++ b/src/auth/reference_modules/example.py @@ -0,0 +1,16 @@ +#!/usr/bin/python3 +import json +import io + + +def authenticate(username, password): + return {"authenticated": True, "role": ""} + + +if __name__ == "__main__": + input_stream = io.FileIO(1000, mode="r") + output_stream = io.FileIO(1001, mode="w") + while True: + params = json.loads(input_stream.readline().decode("ascii")) + ret = authenticate(**params) + output_stream.write((json.dumps(ret) + "\n").encode("ascii"))