diff --git a/environment/os/amzn-2.sh b/environment/os/amzn-2.sh index 15ff29106..a9cc3e4b2 100755 --- a/environment/os/amzn-2.sh +++ b/environment/os/amzn-2.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -9,7 +7,7 @@ check_operating_system "amzn-2" check_architecture "x86_64" TOOLCHAIN_BUILD_DEPS=( - gcc gcc-c++ make # generic build tools + git gcc gcc-c++ make # generic build tools wget # used for archive download gnupg2 # used for archive signature verification tar gzip bzip2 xz unzip # used for archive unpacking @@ -63,6 +61,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/centos-7.sh b/environment/os/centos-7.sh index df16fbc73..d9fc93912 100755 --- a/environment/os/centos-7.sh +++ b/environment/os/centos-7.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -63,6 +61,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/centos-9.sh b/environment/os/centos-9.sh index 8a431807e..8177c9223 100755 --- a/environment/os/centos-9.sh +++ b/environment/os/centos-9.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -9,8 +7,10 @@ check_operating_system "centos-9" check_architecture "x86_64" TOOLCHAIN_BUILD_DEPS=( - coreutils-common gcc gcc-c++ make # generic build tools wget # used for archive download + coreutils-common gcc gcc-c++ make # generic build tools + # NOTE: Pure libcurl conflicts with libcurl-minimal + libcurl-devel # cmake build requires it gnupg2 # used for archive signature verification tar gzip bzip2 xz unzip # used for archive unpacking zlib-devel # zlib library used for all builds @@ -64,6 +64,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) @@ -123,7 +125,9 @@ install() { else echo "NOTE: export LANG=en_US.utf8" fi - yum update -y + # --nobest is used because of libipt because we install custom versions + # because libipt-devel is not available on CentOS 9 Stream + yum update -y --nobest yum install -y wget git python3 python3-pip for pkg in $1; do diff --git a/environment/os/debian-10.sh b/environment/os/debian-10.sh index 4c1deda42..9a64854de 100755 --- a/environment/os/debian-10.sh +++ b/environment/os/debian-10.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. + check_operating_system "debian-10" check_architecture "x86_64" diff --git a/environment/os/debian-11-arm.sh b/environment/os/debian-11-arm.sh index c8a3cca1c..8e17a8fdd 100755 --- a/environment/os/debian-11-arm.sh +++ b/environment/os/debian-11-arm.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. + check_operating_system "debian-11" check_architecture "arm64" "aarch64" diff --git a/environment/os/debian-11.sh b/environment/os/debian-11.sh index c7e82b52c..ac05f6ba6 100755 --- a/environment/os/debian-11.sh +++ b/environment/os/debian-11.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -61,6 +59,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/debian-12-arm.sh b/environment/os/debian-12-arm.sh new file mode 100755 index 000000000..15d3f7473 --- /dev/null +++ b/environment/os/debian-12-arm.sh @@ -0,0 +1,134 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "debian-12" +check_architecture "arm64" "aarch64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils gcc g++ build-essential make # generic build tools + wget # used for archive download + gnupg # used for archive signature verification + tar gzip bzip2 xz-utils unzip # used for archive unpacking + zlib1g-dev # zlib library used for all builds + libexpat1-dev liblzma-dev python3-dev texinfo # for gdb + libcurl4-openssl-dev # for cmake + libreadline-dev # for cmake and llvm + libffi-dev libxml2-dev # for llvm + libedit-dev libpcre2-dev libpcre3-dev automake bison # for swig + curl # snappy + file # for libunwind + libssl-dev # for libevent + libgmp-dev + gperf # for proxygen + git # for fbthrift +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz-utils # used for archive unpacking + zlib1g # zlib library used for all builds + libexpat1 liblzma5 python3 # for gdb + libcurl4 # for cmake + file # for CPack + libreadline8 # for cmake and llvm + libffi8 libxml2 # for llvm + libssl-dev # for libevent +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make pkg-config # build system + curl wget # for downloading libs + uuid-dev default-jre-headless # required by antlr + libreadline-dev # for memgraph console + libpython3-dev python3-dev # for query modules + libssl-dev + libseccomp-dev + netcat # tests are using nc to wait for memgraph + python3 virtualenv python3-virtualenv python3-pip # for qa, macro_benchmark and stress tests + python3-yaml # for the configuration generator + libcurl4-openssl-dev # mg-requests + sbcl # for custom Lisp C++ preprocessing + doxygen graphviz # source documentation generators + mono-runtime mono-mcs zip unzip default-jdk-headless custom-maven3.9.3 # for driver tests + dotnet-sdk-7.0 golang custom-golang1.18.9 nodejs npm + autoconf # for jemalloc code generation + libtool # for protobuf code generation + libsasl2-dev +) + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if ! dpkg -s "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + apt update + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests doesn't work the LANG export. + if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + apt install -y wget + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == dotnet-sdk-7.0 ]; then + if ! dpkg -s "$pkg" 2>/dev/null >/dev/null; then + wget -nv https://packages.microsoft.com/config/debian/12/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + dpkg -i packages-microsoft-prod.deb + apt-get update + apt-get install -y apt-transport-https dotnet-sdk-7.0 + fi + continue + fi + apt install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/debian-12.sh b/environment/os/debian-12.sh new file mode 100755 index 000000000..1709230ad --- /dev/null +++ b/environment/os/debian-12.sh @@ -0,0 +1,136 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "debian-12" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils gcc g++ build-essential make # generic build tools + wget # used for archive download + gnupg # used for archive signature verification + tar gzip bzip2 xz-utils unzip # used for archive unpacking + zlib1g-dev # zlib library used for all builds + libexpat1-dev libipt-dev libbabeltrace-dev liblzma-dev python3-dev texinfo # for gdb + libcurl4-openssl-dev # for cmake + libreadline-dev # for cmake and llvm + libffi-dev libxml2-dev # for llvm + libedit-dev libpcre2-dev libpcre3-dev automake bison # for swig + curl # snappy + file # for libunwind + libssl-dev # for libevent + libgmp-dev + gperf # for proxygen + git # for fbthrift +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz-utils # used for archive unpacking + zlib1g # zlib library used for all builds + libexpat1 libipt2 libbabeltrace1 liblzma5 python3 # for gdb + libcurl4 # for cmake + file # for CPack + libreadline8 # for cmake and llvm + libffi8 libxml2 # for llvm + libssl-dev # for libevent +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make cmake pkg-config # build system + curl wget # for downloading libs + uuid-dev default-jre-headless # required by antlr + libreadline-dev # for memgraph console + libpython3-dev python3-dev # for query modules + libssl-dev + libseccomp-dev + netcat-traditional # tests are using nc to wait for memgraph + python3 virtualenv python3-virtualenv python3-pip # for qa, macro_benchmark and stress tests + python3-yaml # for the configuration generator + libcurl4-openssl-dev # mg-requests + sbcl # for custom Lisp C++ preprocessing + doxygen graphviz # source documentation generators + mono-runtime mono-mcs zip unzip default-jdk-headless custom-maven3.9.3 # for driver tests + dotnet-sdk-7.0 golang custom-golang1.18.9 nodejs npm + autoconf # for jemalloc code generation + libtool # for protobuf code generation + libsasl2-dev +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if ! dpkg -s "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + apt update + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests doesn't work the LANG export. + if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + apt install -y wget + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == dotnet-sdk-7.0 ]; then + if ! dpkg -s "$pkg" 2>/dev/null >/dev/null; then + wget -nv https://packages.microsoft.com/config/debian/12/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + dpkg -i packages-microsoft-prod.deb + apt-get update + apt-get install -y apt-transport-https dotnet-sdk-7.0 + fi + continue + fi + apt install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/fedora-36.sh b/environment/os/fedora-36.sh index f7bd0c53a..f8b8995d9 100755 --- a/environment/os/fedora-36.sh +++ b/environment/os/fedora-36.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. + check_operating_system "fedora-36" check_architecture "x86_64" @@ -27,6 +27,7 @@ TOOLCHAIN_BUILD_DEPS=( libipt libipt-devel # intel patch perl # for openssl + git ) TOOLCHAIN_RUN_DEPS=( diff --git a/environment/os/fedora-38.sh b/environment/os/fedora-38.sh index 7837f018b..951bec46f 100755 --- a/environment/os/fedora-38.sh +++ b/environment/os/fedora-38.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -27,6 +25,7 @@ TOOLCHAIN_BUILD_DEPS=( libipt libipt-devel # intel patch perl # for openssl + git ) TOOLCHAIN_RUN_DEPS=( @@ -58,6 +57,16 @@ MEMGRAPH_BUILD_DEPS=( libtool # for protobuf code generation ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + list() { echo "$1" } diff --git a/environment/os/fedora-39.sh b/environment/os/fedora-39.sh new file mode 100755 index 000000000..4b0e82992 --- /dev/null +++ b/environment/os/fedora-39.sh @@ -0,0 +1,117 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "fedora-39" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils-common gcc gcc-c++ make # generic build tools + wget # used for archive download + gnupg2 # used for archive signature verification + tar gzip bzip2 xz unzip # used for archive unpacking + zlib-devel # zlib library used for all builds + expat-devel xz-devel python3-devel texinfo libbabeltrace-devel # for gdb + curl libcurl-devel # for cmake + readline-devel # for cmake and llvm + libffi-devel libxml2-devel # for llvm + libedit-devel pcre-devel pcre2-devel automake bison # for swig + file + openssl-devel + gmp-devel + gperf + diffutils + libipt libipt-devel # intel + patch + perl # for openssl + git +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz # used for archive unpacking + zlib # zlib library used for all builds + expat xz-libs python3 # for gdb + readline # for cmake and llvm + libffi libxml2 # for llvm + openssl-devel +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make pkgconf-pkg-config # build system + wget # for downloading libs + libuuid-devel java-11-openjdk # required by antlr + readline-devel # for memgraph console + python3-devel # for query modules + openssl-devel + libseccomp-devel + python3 python3-pip python3-virtualenv python3-virtualenvwrapper python3-pyyaml nmap-ncat # for tests + libcurl-devel # mg-requests + rpm-build rpmlint # for RPM package building + doxygen graphviz # source documentation generators + which nodejs golang zip unzip java-11-openjdk-devel # for driver tests + sbcl # for custom Lisp C++ preprocessing + autoconf # for jemalloc code generation + libtool # for protobuf code generation +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + if [ -v LD_LIBRARY_PATH ]; then + # On Fedora 38 yum/dnf and python11 use newer glibc which is not compatible + # with ours, so we need to momentarely disable env + local OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH} + LD_LIBRARY_PATH="" + fi + local missing="" + for pkg in $1; do + if ! dnf list installed "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi + if [ -v OLD_LD_LIBRARY_PATH ]; then + echo "Restoring LD_LIBRARY_PATH..." + LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH} + fi +} + +install() { + cd "$DIR" + if [ "$EUID" -ne 0 ]; then + echo "Please run as root." + exit 1 + fi + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests don't work without the LANG export. + if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + dnf update -y + for pkg in $1; do + dnf install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/rocky-9.3.sh b/environment/os/rocky-9.3.sh new file mode 100755 index 000000000..571278654 --- /dev/null +++ b/environment/os/rocky-9.3.sh @@ -0,0 +1,188 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +# TODO(gitbuda): Rocky gets automatically updates -> figure out how to handle it. +check_operating_system "rocky-9.3" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + wget # used for archive download + coreutils-common gcc gcc-c++ make # generic build tools + # NOTE: Pure libcurl conflicts with libcurl-minimal + libcurl-devel # cmake build requires it + gnupg2 # used for archive signature verification + tar gzip bzip2 xz unzip # used for archive unpacking + zlib-devel # zlib library used for all builds + expat-devel xz-devel python3-devel perl-Unicode-EastAsianWidth texinfo libbabeltrace-devel # for gdb + readline-devel # for cmake and llvm + libffi-devel libxml2-devel # for llvm + libedit-devel pcre-devel pcre2-devel automake bison # for swig + file + openssl-devel + gmp-devel + gperf + diffutils + libipt libipt-devel # intel + patch +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz # used for archive unpacking + zlib # zlib library used for all builds + expat xz-libs python3 # for gdb + readline # for cmake and llvm + libffi libxml2 # for llvm + openssl-devel + perl # for openssl +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make cmake pkgconf-pkg-config # build system + wget # for downloading libs + libuuid-devel java-11-openjdk # required by antlr + readline-devel # for memgraph console + python3-devel # for query modules + openssl-devel + libseccomp-devel + python3 python3-pip python3-virtualenv nmap-ncat # for qa, macro_benchmark and stress tests + # + # IMPORTANT: python3-yaml does NOT exist on CentOS + # Install it manually using `pip3 install PyYAML` + # + PyYAML # Package name here does not correspond to the yum package! + libcurl-devel # mg-requests + rpm-build rpmlint # for RPM package building + doxygen graphviz # source documentation generators + which nodejs golang custom-golang1.18.9 # for driver tests + zip unzip java-11-openjdk-devel java-17-openjdk java-17-openjdk-devel custom-maven3.9.3 # for driver tests + sbcl # for custom Lisp C++ preprocessing + autoconf # for jemalloc code generation + libtool # for protobuf code generation + cyrus-sasl-devel +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == "PyYAML" ]; then + if ! python3 -c "import yaml" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == "python3-virtualenv" ]; then + continue + fi + if ! yum list installed "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + if [ "$EUID" -ne 0 ]; then + echo "Please run as root." + exit 1 + fi + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests doesn't work the LANG export. + if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + yum update -y + yum install -y wget git python3 python3-pip + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == perl-Unicode-EastAsianWidth ]; then + if ! dnf list installed perl-Unicode-EastAsianWidth >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/CRB/x86_64/os/Packages/p/perl-Unicode-EastAsianWidth-12.0-7.el9.noarch.rpm + fi + continue + fi + if [ "$pkg" == texinfo ]; then + if ! dnf list installed texinfo >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/CRB/x86_64/os/Packages/t/texinfo-6.7-15.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == libbabeltrace-devel ]; then + if ! dnf list installed libbabeltrace-devel >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/devel/x86_64/os/Packages/l/libbabeltrace-devel-1.5.8-10.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == libipt-devel ]; then + if ! dnf list installed libipt-devel >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/devel/x86_64/os/Packages/l/libipt-devel-2.0.4-5.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == PyYAML ]; then + if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker). + pip3 install --user PyYAML + else # Running using sudo. + sudo -H -u "$SUDO_USER" bash -c "pip3 install --user PyYAML" + fi + continue + fi + if [ "$pkg" == python3-virtualenv ]; then + if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker). + pip3 install virtualenv + pip3 install virtualenvwrapper + else # Running using sudo. + sudo -H -u "$SUDO_USER" bash -c "pip3 install virtualenv" + sudo -H -u "$SUDO_USER" bash -c "pip3 install virtualenvwrapper" + fi + continue + fi + yum install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/template.sh b/environment/os/template.sh index b1f2f8fe4..692926efb 100755 --- a/environment/os/template.sh +++ b/environment/os/template.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -20,6 +18,10 @@ MEMGRAPH_BUILD_DEPS=( pkg ) +MEMGRAPH_TEST_DEPS=( + pkg +) + MEMGRAPH_RUN_DEPS=( pkg ) diff --git a/environment/os/ubuntu-18.04.sh b/environment/os/ubuntu-18.04.sh index 27d876e4f..451d5e69c 100755 --- a/environment/os/ubuntu-18.04.sh +++ b/environment/os/ubuntu-18.04.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. + check_operating_system "ubuntu-18.04" check_architecture "x86_64" diff --git a/environment/os/ubuntu-20.04.sh b/environment/os/ubuntu-20.04.sh index 8a308406e..7739b49d1 100755 --- a/environment/os/ubuntu-20.04.sh +++ b/environment/os/ubuntu-20.04.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/os/ubuntu-22.04-arm.sh b/environment/os/ubuntu-22.04-arm.sh index 45a4f3d4c..9326e52e9 100755 --- a/environment/os/ubuntu-22.04-arm.sh +++ b/environment/os/ubuntu-22.04-arm.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/os/ubuntu-22.04.sh b/environment/os/ubuntu-22.04.sh index 59361dd81..649338e53 100755 --- a/environment/os/ubuntu-22.04.sh +++ b/environment/os/ubuntu-22.04.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/toolchain/.gitignore b/environment/toolchain/.gitignore index e75f93b12..6ba5f327d 100644 --- a/environment/toolchain/.gitignore +++ b/environment/toolchain/.gitignore @@ -2,3 +2,4 @@ archives build output *.tar.gz +tmp_build.sh diff --git a/environment/toolchain/template_build.sh b/environment/toolchain/template_build.sh new file mode 100644 index 000000000..b01902ab9 --- /dev/null +++ b/environment/toolchain/template_build.sh @@ -0,0 +1,48 @@ +#!/bin/bash -e + +# NOTE: Copy this under memgraph/environment/toolchain/vN/tmp_build.sh, edit and test. + +pushd () { command pushd "$@" > /dev/null; } +popd () { command popd "$@" > /dev/null; } +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +CPUS=$( grep -c processor < /proc/cpuinfo ) +cd "$DIR" +source "$DIR/../../util.sh" +DISTRO="$(operating_system)" +TOOLCHAIN_VERSION=5 +NAME=toolchain-v$TOOLCHAIN_VERSION +PREFIX=/opt/$NAME +function log_tool_name () { + echo "" + echo "" + echo "#### $1 ####" + echo "" + echo "" +} + +# HERE: Remove/clear dependencies from a given toolchain. + +mkdir -p archives && pushd archives +# HERE: Download dependencies here. +popd + +mkdir -p build +pushd build +source $PREFIX/activate +export CC=$PREFIX/bin/clang +export CXX=$PREFIX/bin/clang++ +export CFLAGS="$CFLAGS -fPIC" +export PATH=$PREFIX/bin:$PATH +export LD_LIBRARY_PATH=$PREFIX/lib64 +COMMON_CMAKE_FLAGS="-DCMAKE_INSTALL_PREFIX=$PREFIX + -DCMAKE_PREFIX_PATH=$PREFIX + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=$CC + -DCMAKE_CXX_COMPILER=$CXX + -DBUILD_SHARED_LIBS=OFF + -DCMAKE_CXX_STANDARD=20 + -DBUILD_TESTING=OFF + -DCMAKE_REQUIRED_INCLUDES=$PREFIX/include + -DCMAKE_POSITION_INDEPENDENT_CODE=ON" + +# HERE: Add dependencies to test below. diff --git a/environment/toolchain/v5/build.sh b/environment/toolchain/v5/build.sh index b6c1ff6d8..aade6f9c5 100755 --- a/environment/toolchain/v5/build.sh +++ b/environment/toolchain/v5/build.sh @@ -307,7 +307,7 @@ if [ ! -f $PREFIX/bin/ld.gold ]; then fi log_tool_name "GDB $GDB_VERSION" -if [ ! -f $PREFIX/bin/gdb ]; then +if [[ ! -f "$PREFIX/bin/gdb" && "$DISTRO" -ne "amzn-2" ]]; then if [ -d gdb-$GDB_VERSION ]; then rm -rf gdb-$GDB_VERSION fi @@ -671,7 +671,6 @@ PROXYGEN_SHA256=5360a8ccdfb2f5a6c7b3eed331ec7ab0e2c792d579c6fff499c85c516c11fe14 WANGLE_SHA256=1002e9c32b6f4837f6a760016e3b3e22f3509880ef3eaad191c80dc92655f23f # WANGLE_SHA256=0e493c03572bb27fe9ca03a9da5023e52fde99c95abdcaa919bb6190e7e69532 -FLEX_VERSION=2.6.4 FMT_SHA256=78b8c0a72b1c35e4443a7e308df52498252d1cefc2b08c9a97bc9ee6cfe61f8b FMT_VERSION=10.1.1 # NOTE: spdlog depends on exact fmt versions -> UPGRADE fmt and spdlog TOGETHER. @@ -690,8 +689,8 @@ LZ4_VERSION=1.9.4 SNAPPY_SHA256=75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7 SNAPPY_VERSION=1.1.9 XZ_VERSION=5.2.5 # for LZMA -ZLIB_VERSION=1.3 -ZSTD_VERSION=1.5.0 +ZLIB_VERSION=1.3.1 +ZSTD_VERSION=1.5.5 pushd archives if [ ! -f boost_$BOOST_VERSION_UNDERSCORES.tar.gz ]; then @@ -700,7 +699,7 @@ if [ ! -f boost_$BOOST_VERSION_UNDERSCORES.tar.gz ]; then wget https://boostorg.jfrog.io/artifactory/main/release/$BOOST_VERSION/source/boost_$BOOST_VERSION_UNDERSCORES.tar.gz -O boost_$BOOST_VERSION_UNDERSCORES.tar.gz fi if [ ! -f bzip2-$BZIP2_VERSION.tar.gz ]; then - wget https://sourceforge.net/projects/bzip2/files/bzip2-$BZIP2_VERSION.tar.gz -O bzip2-$BZIP2_VERSION.tar.gz + wget https://sourceware.org/pub/bzip2/bzip2-$BZIP2_VERSION.tar.gz -O bzip2-$BZIP2_VERSION.tar.gz fi if [ ! -f double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz ]; then wget https://github.com/google/double-conversion/archive/refs/tags/v$DOUBLE_CONVERSION_VERSION.tar.gz -O double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz @@ -708,9 +707,7 @@ fi if [ ! -f fizz-$FBLIBS_VERSION.tar.gz ]; then wget https://github.com/facebookincubator/fizz/releases/download/v$FBLIBS_VERSION/fizz-v$FBLIBS_VERSION.tar.gz -O fizz-$FBLIBS_VERSION.tar.gz fi -if [ ! -f flex-$FLEX_VERSION.tar.gz ]; then - wget https://github.com/westes/flex/releases/download/v$FLEX_VERSION/flex-$FLEX_VERSION.tar.gz -O flex-$FLEX_VERSION.tar.gz -fi + if [ ! -f fmt-$FMT_VERSION.tar.gz ]; then wget https://github.com/fmtlib/fmt/archive/refs/tags/$FMT_VERSION.tar.gz -O fmt-$FMT_VERSION.tar.gz fi @@ -765,14 +762,6 @@ echo "$BZIP2_SHA256 bzip2-$BZIP2_VERSION.tar.gz" | sha256sum -c echo "$DOUBLE_CONVERSION_SHA256 double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz" | sha256sum -c # verify fizz echo "$FIZZ_SHA256 fizz-$FBLIBS_VERSION.tar.gz" | sha256sum -c -# verify flex -if [ ! -f flex-$FLEX_VERSION.tar.gz.sig ]; then - wget https://github.com/westes/flex/releases/download/v$FLEX_VERSION/flex-$FLEX_VERSION.tar.gz.sig -fi -if false; then - $GPG --keyserver $KEYSERVER --recv-keys 0xE4B29C8D64885307 - $GPG --verify flex-$FLEX_VERSION.tar.gz.sig flex-$FLEX_VERSION.tar.gz -fi # verify fmt echo "$FMT_SHA256 fmt-$FMT_VERSION.tar.gz" | sha256sum -c # verify spdlog @@ -1025,7 +1014,6 @@ if [ ! -d $PREFIX/include/gflags ]; then if [ -d gflags ]; then rm -rf gflags fi - git clone https://github.com/memgraph/gflags.git gflags pushd gflags git checkout $GFLAGS_COMMIT_HASH @@ -1034,7 +1022,7 @@ if [ ! -d $PREFIX/include/gflags ]; then cmake .. $COMMON_CMAKE_FLAGS \ -DREGISTER_INSTALL_PREFIX=OFF \ -DBUILD_gflags_nothreads_LIB=OFF \ - -DGFLAGS_NO_FILENAMES=0 + -DGFLAGS_NO_FILENAMES=1 make -j$CPUS install popd && popd fi @@ -1232,18 +1220,6 @@ if false; then fi fi -log_tool_name "flex $FLEX_VERSION" -if [ ! -f $PREFIX/include/FlexLexer.h ]; then - if [ -d flex-$FLEX_VERSION ]; then - rm -rf flex-$FLEX_VERSION - fi - tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz - pushd flex-$FLEX_VERSION - ./configure $COMMON_CONFIGURE_FLAGS - make -j$CPUS install - popd -fi - popd # NOTE: It's important/clean (e.g., easier upload to S3) to have a separated # folder to the output archive. diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 3bb7648db..54825c70c 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -35,16 +35,42 @@ DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth m } return true; }); -DEFINE_bool(auth_module_create_missing_user, true, "Set to false to disable creation of missing users."); -DEFINE_bool(auth_module_create_missing_role, true, "Set to false to disable creation of missing roles."); -DEFINE_bool(auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module."); DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000, "Timeout (in milliseconds) used when waiting for a " "response from the auth module.", FLAG_IN_RANGE(100, 1800000)); +// DEPRECATED FLAGS +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_user, true, + "Set to false to disable creation of missing users.", { + spdlog::warn( + "auth_module_create_missing_user flag is deprecated. It not possible to create " + "users through the module anymore."); + return true; + }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_role, true, + "Set to false to disable creation of missing roles.", { + spdlog::warn( + "auth_module_create_missing_role flag is deprecated. It not possible to create " + "roles through the module anymore."); + return true; + }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool( + auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.", { + spdlog::warn( + "auth_module_manage_roles flag is deprecated. It not possible to create roles through the module anymore."); + return true; + }); + namespace memgraph::auth { +const Auth::Epoch Auth::kStartEpoch = 1; + namespace { #ifdef MG_ENTERPRISE /** @@ -192,6 +218,17 @@ void MigrateVersions(kvstore::KVStore &store) { version_str = kVersionV1; } } + +auto ParseJson(std::string_view str) { + nlohmann::json data; + try { + data = nlohmann::json::parse(str); + } catch (const nlohmann::json::parse_error &e) { + throw AuthException("Couldn't load auth data!"); + } + return data; +} + }; // namespace Auth::Auth(std::string storage_directory, Config config) @@ -199,8 +236,11 @@ Auth::Auth(std::string storage_directory, Config config) MigrateVersions(storage_); } -std::optional Auth::Authenticate(const std::string &username, const std::string &password) { +std::optional Auth::Authenticate(const std::string &username, const std::string &password) { if (module_.IsUsed()) { + /* + * MODULE AUTH STORAGE + */ const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings); if (license_check_result.HasError()) { spdlog::warn(license::LicenseCheckErrorToString(license_check_result.GetError(), "authentication modules")); @@ -225,108 +265,64 @@ std::optional Auth::Authenticate(const std::string &username, const std::s auto is_authenticated = ret_authenticated.get(); const auto &rolename = ret_role.get(); + // Check if role is present + auto role = GetRole(rolename); + if (!role) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the role '{}' doesn't exist.", + username, rolename, "https://memgr.ph/auth")); + return std::nullopt; + } + // Authenticate the user. if (!is_authenticated) return std::nullopt; - /** - * TODO - * The auth module should not update auth data. - * There is now way to replicate it and we should not be storing sensitive data if we don't have to. - */ - - // Find or create the user and return it. - auto user = GetUser(username); - if (!user) { - if (FLAGS_auth_module_create_missing_user) { - user = AddUser(username, password); - if (!user) { - spdlog::warn(utils::MessageWithLink( - "Couldn't create the missing user '{}' using the auth module because the user already exists as a role.", - username, "https://memgr.ph/auth")); - return std::nullopt; - } - } else { - spdlog::warn(utils::MessageWithLink( - "Couldn't authenticate user '{}' using the auth module because the user doesn't exist.", username, - "https://memgr.ph/auth")); - return std::nullopt; - } - } else { - UpdatePassword(*user, password); - } - if (FLAGS_auth_module_manage_roles) { - if (!rolename.empty()) { - auto role = GetRole(rolename); - if (!role) { - if (FLAGS_auth_module_create_missing_role) { - role = AddRole(rolename); - if (!role) { - spdlog::warn( - utils::MessageWithLink("Couldn't authenticate user '{}' using the auth module because the user's " - "role '{}' already exists as a user.", - username, rolename, "https://memgr.ph/auth")); - return std::nullopt; - } - SaveRole(*role); - } else { - spdlog::warn(utils::MessageWithLink( - "Couldn't authenticate user '{}' using the auth module because the user's role '{}' doesn't exist.", - username, rolename, "https://memgr.ph/auth")); - return std::nullopt; - } - } - user->SetRole(*role); - } else { - user->ClearRole(); - } - } - SaveUser(*user); - return user; - } else { - auto user = GetUser(username); - if (!user) { - spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username, - "https://memgr.ph/auth")); - return std::nullopt; - } - if (!user->CheckPassword(password)) { - spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.", - username, "https://memgr.ph/auth")); - return std::nullopt; - } - if (user->UpgradeHash(password)) { - SaveUser(*user); - } - - return user; + return RoleWUsername{username, std::move(*role)}; } + + /* + * LOCAL AUTH STORAGE + */ + auto user = GetUser(username); + if (!user) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username, + "https://memgr.ph/auth")); + return std::nullopt; + } + if (!user->CheckPassword(password)) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.", + username, "https://memgr.ph/auth")); + return std::nullopt; + } + if (user->UpgradeHash(password)) { + SaveUser(*user); + } + + return user; } -std::optional Auth::GetUser(const std::string &username_orig) const { - auto username = utils::ToLowerCase(username_orig); - auto existing_user = storage_.Get(kUserPrefix + username); - if (!existing_user) return std::nullopt; - - nlohmann::json data; - try { - data = nlohmann::json::parse(*existing_user); - } catch (const nlohmann::json::parse_error &e) { - throw AuthException("Couldn't load user data!"); - } - - auto user = User::Deserialize(data); - auto link = storage_.Get(kLinkPrefix + username); - +void Auth::LinkUser(User &user) const { + auto link = storage_.Get(kLinkPrefix + user.username()); if (link) { auto role = GetRole(*link); if (role) { user.SetRole(*role); } } +} + +std::optional Auth::GetUser(const std::string &username_orig) const { + if (module_.IsUsed()) return std::nullopt; // User's are not supported when using module + auto username = utils::ToLowerCase(username_orig); + auto existing_user = storage_.Get(kUserPrefix + username); + if (!existing_user) return std::nullopt; + + auto user = User::Deserialize(ParseJson(*existing_user)); + LinkUser(user); return user; } void Auth::SaveUser(const User &user, system::Transaction *system_tx) { + DisableIfModuleUsed(); bool success = false; if (const auto *role = user.role(); role != nullptr) { success = storage_.PutMultiple( @@ -338,6 +334,10 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) { if (!success) { throw AuthException("Couldn't save user '{}'!", user.username()); } + + // Durability updated -> new epoch + UpdateEpoch(); + // All changes to the user end up calling this function, so no need to add a delta anywhere else if (system_tx) { #ifdef MG_ENTERPRISE @@ -347,6 +347,7 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) { } void Auth::UpdatePassword(auth::User &user, const std::optional &password) { + DisableIfModuleUsed(); // Check if null if (!password) { if (!config_.password_permit_null) { @@ -378,6 +379,7 @@ void Auth::UpdatePassword(auth::User &user, const std::optional &pa std::optional Auth::AddUser(const std::string &username, const std::optional &password, system::Transaction *system_tx) { + DisableIfModuleUsed(); if (!NameRegexMatch(username)) { throw AuthException("Invalid user name."); } @@ -392,12 +394,17 @@ std::optional Auth::AddUser(const std::string &username, const std::option } bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) { + DisableIfModuleUsed(); auto username = utils::ToLowerCase(username_orig); if (!storage_.Get(kUserPrefix + username)) return false; std::vector keys({kLinkPrefix + username, kUserPrefix + username}); if (!storage_.DeleteMultiple(keys)) { throw AuthException("Couldn't remove user '{}'!", username); } + + // Durability updated -> new epoch + UpdateEpoch(); + // Handling drop user delta if (system_tx) { #ifdef MG_ENTERPRISE @@ -412,9 +419,12 @@ std::vector Auth::AllUsers() const { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); if (username != utils::ToLowerCase(username)) continue; - auto user = GetUser(username); - if (user) { - ret.push_back(std::move(*user)); + try { + User user = auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure + LinkUser(user); + ret.emplace_back(std::move(user)); + } catch (AuthException &) { + continue; } } return ret; @@ -425,9 +435,12 @@ std::vector Auth::AllUsernames() const { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); if (username != utils::ToLowerCase(username)) continue; - auto user = GetUser(username); - if (user) { - ret.push_back(username); + try { + // Check if serialized correctly + memgraph::auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure + ret.emplace_back(std::move(username)); + } catch (AuthException &) { + continue; } } return ret; @@ -435,25 +448,24 @@ std::vector Auth::AllUsernames() const { bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } +bool Auth::AccessControlled() const { return HasUsers() || module_.IsUsed(); } + std::optional Auth::GetRole(const std::string &rolename_orig) const { auto rolename = utils::ToLowerCase(rolename_orig); auto existing_role = storage_.Get(kRolePrefix + rolename); if (!existing_role) return std::nullopt; - nlohmann::json data; - try { - data = nlohmann::json::parse(*existing_role); - } catch (const nlohmann::json::parse_error &e) { - throw AuthException("Couldn't load role data!"); - } - - return Role::Deserialize(data); + return Role::Deserialize(ParseJson(*existing_role)); } void Auth::SaveRole(const Role &role, system::Transaction *system_tx) { if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) { throw AuthException("Couldn't save role '{}'!", role.rolename()); } + + // Durability updated -> new epoch + UpdateEpoch(); + // All changes to the role end up calling this function, so no need to add a delta anywhere else if (system_tx) { #ifdef MG_ENTERPRISE @@ -486,6 +498,10 @@ bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *sys if (!storage_.DeleteMultiple(keys)) { throw AuthException("Couldn't remove role '{}'!", rolename); } + + // Durability updated -> new epoch + UpdateEpoch(); + // Handling drop role delta if (system_tx) { #ifdef MG_ENTERPRISE @@ -500,11 +516,8 @@ std::vector Auth::AllRoles() const { for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { auto rolename = it->first.substr(kRolePrefix.size()); if (rolename != utils::ToLowerCase(rolename)) continue; - if (auto role = GetRole(rolename)) { - ret.push_back(*role); - } else { - throw AuthException("Couldn't load role '{}'!", rolename); - } + Role role = memgraph::auth::Role::Deserialize(ParseJson(it->second)); // Will throw on failure + ret.emplace_back(std::move(role)); } return ret; } @@ -514,14 +527,19 @@ std::vector Auth::AllRolenames() const { for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { auto rolename = it->first.substr(kRolePrefix.size()); if (rolename != utils::ToLowerCase(rolename)) continue; - if (auto role = GetRole(rolename)) { - ret.push_back(rolename); + try { + // Check that the data is serialized correctly + memgraph::auth::Role::Deserialize(ParseJson(it->second)); + ret.emplace_back(std::move(rolename)); + } catch (AuthException &) { + continue; } } return ret; } std::vector Auth::AllUsersForRole(const std::string &rolename_orig) const { + DisableIfModuleUsed(); const auto rolename = utils::ToLowerCase(rolename_orig); std::vector ret; for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) { @@ -540,51 +558,176 @@ std::vector Auth::AllUsersForRole(const std::string &rolename_orig) } #ifdef MG_ENTERPRISE -bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (db == kAllDatabases) { - user->db_access().GrantAll(); - } else { - user->db_access().Add(db); +Auth::Result Auth::GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + GrantDatabase(db, *role, system_tx); + return SUCCESS; } - SaveUser(*user, system_tx); - return true; + return NO_ROLE; } - return false; + if (auto user = GetUser(name)) { + GrantDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + GrantDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; } -bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (db == kAllDatabases) { - user->db_access().DenyAll(); - } else { - user->db_access().Remove(db); - } - SaveUser(*user, system_tx); - return true; +void Auth::GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().GrantAll(); + } else { + user.db_access().Grant(db); } - return false; + SaveUser(user, system_tx); +} + +void Auth::GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().GrantAll(); + } else { + role.db_access().Grant(db); + } + SaveRole(role, system_tx); +} + +Auth::Result Auth::DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + DenyDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_ROLE; + } + if (auto user = GetUser(name)) { + DenyDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + DenyDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().DenyAll(); + } else { + user.db_access().Deny(db); + } + SaveUser(user, system_tx); +} + +void Auth::DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().DenyAll(); + } else { + role.db_access().Deny(db); + } + SaveRole(role, system_tx); +} + +Auth::Result Auth::RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + RevokeDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_ROLE; + } + if (auto user = GetUser(name)) { + RevokeDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + RevokeDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().RevokeAll(); + } else { + user.db_access().Revoke(db); + } + SaveUser(user, system_tx); +} + +void Auth::RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().RevokeAll(); + } else { + role.db_access().Revoke(db); + } + SaveRole(role, system_tx); } void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); - if (auto user = GetUser(username)) { - user->db_access().Delete(db); - SaveUser(*user, system_tx); + try { + User user = auth::User::Deserialize(ParseJson(it->second)); + LinkUser(user); + user.db_access().Revoke(db); + SaveUser(user, system_tx); + } catch (AuthException &) { + continue; + } + } + for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { + auto rolename = it->first.substr(kRolePrefix.size()); + try { + auto role = memgraph::auth::Role::Deserialize(ParseJson(it->second)); + role.db_access().Revoke(db); + SaveRole(role, system_tx); + } catch (AuthException &) { + continue; } } } -bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (!user->db_access().SetDefault(db)) { - throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name); +Auth::Result Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + SetMainDatabase(db, *role, system_tx); + return SUCCESS; } - SaveUser(*user, system_tx); - return true; + return NO_ROLE; } - return false; + if (auto user = GetUser(name)) { + SetMainDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + SetMainDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx) { + if (!user.db_access().SetMain(db)) { + throw AuthException("Couldn't set default database '{}' for '{}'!", db, user.username()); + } + SaveUser(user, system_tx); +} + +void Auth::SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx) { + if (!role.db_access().SetMain(db)) { + throw AuthException("Couldn't set default database '{}' for '{}'!", db, role.rolename()); + } + SaveRole(role, system_tx); } #endif diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index 4b1bcd479..f8d3d58be 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -29,6 +29,18 @@ using SynchedAuth = memgraph::utils::Synchronized + RoleWUsername(std::string_view username, Args &&...args) : Role{std::forward(args)...}, username_{username} {} + + std::string username() { return username_; } + const std::string &username() const { return username_; } + + private: + std::string username_; +}; +using UserOrRole = std::variant; + /** * This class serves as the main Authentication/Authorization storage. * It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions. @@ -61,6 +73,25 @@ class Auth final { std::regex password_regex{password_regex_str}; }; + struct Epoch { + Epoch() : epoch_{0} {} + Epoch(unsigned e) : epoch_{e} {} + + Epoch operator++() { return ++epoch_; } + bool operator==(const Epoch &rhs) const = default; + + private: + unsigned epoch_; + }; + + static const Epoch kStartEpoch; + + enum class Result { + SUCCESS, + NO_USER_ROLE, + NO_ROLE, + }; + explicit Auth(std::string storage_directory, Config config); /** @@ -89,7 +120,7 @@ class Auth final { * @return a user when the username and password match, nullopt otherwise * @throw AuthException if unable to authenticate for whatever reason. */ - std::optional Authenticate(const std::string &username, const std::string &password); + std::optional Authenticate(const std::string &username, const std::string &password); /** * Gets a user from the storage. @@ -101,6 +132,8 @@ class Auth final { */ std::optional GetUser(const std::string &username) const; + void LinkUser(User &user) const; + /** * Saves a user object to the storage. * @@ -163,6 +196,13 @@ class Auth final { */ bool HasUsers() const; + /** + * Returns whether the access is controlled by authentication/authorization. + * + * @return `true` if auth needs to run + */ + bool AccessControlled() const; + /** * Gets a role from the storage. * @@ -173,6 +213,37 @@ class Auth final { */ std::optional GetRole(const std::string &rolename) const; + std::optional GetUserOrRole(const std::optional &username, + const std::optional &rolename) const { + auto expect = [](bool condition, std::string &&msg) { + if (!condition) throw AuthException(std::move(msg)); + }; + // Special case if we are using a module; we must find the specified role + if (module_.IsUsed()) { + expect(username && rolename, "When using a module, a role needs to be connected to a username."); + const auto role = GetRole(*rolename); + expect(role != std::nullopt, "No role named " + *rolename); + return UserOrRole(auth::RoleWUsername{*username, *role}); + } + + // First check if we need to find a role + if (username && rolename) { + const auto role = GetRole(*rolename); + expect(role != std::nullopt, "No role named " + *rolename); + return UserOrRole(auth::RoleWUsername{*username, *role}); + } + + // We are only looking for a user + if (username) { + const auto user = GetUser(*username); + expect(user != std::nullopt, "No user named " + *username); + return *user; + } + + // No user or role + return std::nullopt; + } + /** * Saves a role object to the storage. * @@ -229,16 +300,6 @@ class Auth final { std::vector AllUsersForRole(const std::string &rolename) const; #ifdef MG_ENTERPRISE - /** - * @brief Revoke access to individual database for a user. - * - * @param db name of the database to revoke - * @param name user's username - * @return true on success - * @throw AuthException if unable to find or update the user - */ - bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); - /** * @brief Grant access to individual database for a user. * @@ -247,7 +308,33 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + Result GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); + + /** + * @brief Revoke access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + Result DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); + + /** + * @brief Revoke access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + Result RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); /** * @brief Delete a database from all users. @@ -265,9 +352,17 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr); + Result SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr); + void SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx = nullptr); + void SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx = nullptr); #endif + bool UpToDate(Epoch &e) const { + bool res = e == epoch_; + e = epoch_; + return res; + } + private: /** * @brief @@ -278,11 +373,18 @@ class Auth final { */ bool NameRegexMatch(const std::string &user_or_role) const; + void UpdateEpoch() { ++epoch_; } + + void DisableIfModuleUsed() const { + if (module_.IsUsed()) throw AuthException("Operation not permited when using an authentication module."); + } + // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe, // Auth is not thread-safe because modifying users and roles might require // more than one operation on the storage. kvstore::KVStore storage_; auth::Module module_; Config config_; + Epoch epoch_{kStartEpoch}; }; } // namespace memgraph::auth diff --git a/src/auth/models.cpp b/src/auth/models.cpp index f75e6fe32..51b13329a 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -425,10 +425,11 @@ Role::Role(const std::string &rolename, const Permissions &permissions) : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {} #ifdef MG_ENTERPRISE Role::Role(const std::string &rolename, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler) + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access) : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions), - fine_grained_access_handler_(std::move(fine_grained_access_handler)) {} + fine_grained_access_handler_(std::move(fine_grained_access_handler)), + db_access_(std::move(db_access)) {} #endif const std::string &Role::rolename() const { return rolename_; } @@ -454,8 +455,10 @@ nlohmann::json Role::Serialize() const { #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { data[kFineGrainedAccessHandler] = fine_grained_access_handler_.Serialize(); + data[kDatabases] = db_access_.Serialize(); } else { data[kFineGrainedAccessHandler] = {}; + data[kDatabases] = {}; } #endif return data; @@ -471,12 +474,21 @@ Role Role::Deserialize(const nlohmann::json &data) { auto permissions = Permissions::Deserialize(data[kPermissions]); #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { + Databases db_access; + if (data[kDatabases].is_structured()) { + db_access = Databases::Deserialize(data[kDatabases]); + } else { + // Back-compatibility + spdlog::warn("Role without specified database access. Given access to the default database."); + db_access.Grant(dbms::kDefaultDB); + db_access.SetMain(dbms::kDefaultDB); + } FineGrainedAccessHandler fine_grained_access_handler; // We can have an empty fine_grained if the user was created without a valid license if (data[kFineGrainedAccessHandler].is_object()) { fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]); } - return {data[kRoleName], permissions, std::move(fine_grained_access_handler)}; + return {data[kRoleName], permissions, std::move(fine_grained_access_handler), std::move(db_access)}; } #endif return {data[kRoleName], permissions}; @@ -493,7 +505,7 @@ bool operator==(const Role &first, const Role &second) { } #ifdef MG_ENTERPRISE -void Databases::Add(std::string_view db) { +void Databases::Grant(std::string_view db) { if (allow_all_) { grants_dbs_.clear(); allow_all_ = false; @@ -502,19 +514,19 @@ void Databases::Add(std::string_view db) { denies_dbs_.erase(std::string{db}); // TODO: C++23 use transparent key compare } -void Databases::Remove(const std::string &db) { +void Databases::Deny(const std::string &db) { denies_dbs_.emplace(db); grants_dbs_.erase(db); } -void Databases::Delete(const std::string &db) { +void Databases::Revoke(const std::string &db) { denies_dbs_.erase(db); if (!allow_all_) { grants_dbs_.erase(db); } // Reset if default deleted - if (default_db_ == db) { - default_db_ = ""; + if (main_db_ == db) { + main_db_ = ""; } } @@ -530,9 +542,16 @@ void Databases::DenyAll() { denies_dbs_.clear(); } -bool Databases::SetDefault(std::string_view db) { +void Databases::RevokeAll() { + allow_all_ = false; + grants_dbs_.clear(); + denies_dbs_.clear(); + main_db_ = ""; +} + +bool Databases::SetMain(std::string_view db) { if (!Contains(db)) return false; - default_db_ = db; + main_db_ = db; return true; } @@ -540,11 +559,11 @@ bool Databases::SetDefault(std::string_view db) { return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db)); } -const std::string &Databases::GetDefault() const { - if (!Contains(default_db_)) { - throw AuthException("No access to the set default database \"{}\".", default_db_); +const std::string &Databases::GetMain() const { + if (!Contains(main_db_)) { + throw AuthException("No access to the set default database \"{}\".", main_db_); } - return default_db_; + return main_db_; } nlohmann::json Databases::Serialize() const { @@ -552,7 +571,7 @@ nlohmann::json Databases::Serialize() const { data[kGrants] = grants_dbs_; data[kDenies] = denies_dbs_; data[kAllowAll] = allow_all_; - data[kDefault] = default_db_; + data[kDefault] = main_db_; return data; } @@ -719,15 +738,16 @@ User User::Deserialize(const nlohmann::json &data) { } else { // Back-compatibility spdlog::warn("User without specified database access. Given access to the default database."); - db_access.Add(dbms::kDefaultDB); - db_access.SetDefault(dbms::kDefaultDB); + db_access.Grant(dbms::kDefaultDB); + db_access.SetMain(dbms::kDefaultDB); } FineGrainedAccessHandler fine_grained_access_handler; // We can have an empty fine_grained if the user was created without a valid license if (data[kFineGrainedAccessHandler].is_object()) { fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]); } - return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler), db_access}; + return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler), + std::move(db_access)}; } #endif return {data[kUsername], std::move(password_hash), permissions}; diff --git a/src/auth/models.hpp b/src/auth/models.hpp index b65d172ff..9b12abee4 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -205,52 +205,10 @@ class FineGrainedAccessHandler final { bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHandler &second); #endif -class Role final { - public: - Role() = default; - - explicit Role(const std::string &rolename); - Role(const std::string &rolename, const Permissions &permissions); -#ifdef MG_ENTERPRISE - Role(const std::string &rolename, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler); -#endif - Role(const Role &) = default; - Role &operator=(const Role &) = default; - Role(Role &&) noexcept = default; - Role &operator=(Role &&) noexcept = default; - ~Role() = default; - - const std::string &rolename() const; - const Permissions &permissions() const; - Permissions &permissions(); -#ifdef MG_ENTERPRISE - const FineGrainedAccessHandler &fine_grained_access_handler() const; - FineGrainedAccessHandler &fine_grained_access_handler(); - const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const; - const FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const; -#endif - nlohmann::json Serialize() const; - - /// @throw AuthException if unable to deserialize. - static Role Deserialize(const nlohmann::json &data); - - friend bool operator==(const Role &first, const Role &second); - - private: - std::string rolename_; - Permissions permissions_; -#ifdef MG_ENTERPRISE - FineGrainedAccessHandler fine_grained_access_handler_; -#endif -}; - -bool operator==(const Role &first, const Role &second); - #ifdef MG_ENTERPRISE class Databases final { public: - Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), default_db_(dbms::kDefaultDB) {} + Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), main_db_(dbms::kDefaultDB) {} Databases(const Databases &) = default; Databases &operator=(const Databases &) = default; @@ -263,7 +221,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Add(std::string_view db); + void Grant(std::string_view db); /** * @brief Remove database to the list of granted access. @@ -272,7 +230,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Remove(const std::string &db); + void Deny(const std::string &db); /** * @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set. @@ -280,7 +238,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Delete(const std::string &db); + void Revoke(const std::string &db); /** * @brief Set allow_all_ to true and clears grants and denied sets. @@ -292,10 +250,15 @@ class Databases final { */ void DenyAll(); + /** + * @brief Set allow_all_ to false and clears grants and denied sets. + */ + void RevokeAll(); + /** * @brief Set the default database. */ - bool SetDefault(std::string_view db); + bool SetMain(std::string_view db); /** * @brief Checks if access is grated to the database. @@ -304,11 +267,13 @@ class Databases final { * @return true if allow_all and not denied or granted */ bool Contains(std::string_view db) const; + bool Denies(std::string_view db_name) const { return denies_dbs_.contains(db_name); } + bool Grants(std::string_view db_name) const { return allow_all_ || grants_dbs_.contains(db_name); } bool GetAllowAll() const { return allow_all_; } const std::set> &GetGrants() const { return grants_dbs_; } const std::set> &GetDenies() const { return denies_dbs_; } - const std::string &GetDefault() const; + const std::string &GetMain() const; nlohmann::json Serialize() const; /// @throw AuthException if unable to deserialize. @@ -320,15 +285,69 @@ class Databases final { : grants_dbs_(std::move(grant)), denies_dbs_(std::move(deny)), allow_all_(allow_all), - default_db_(std::move(default_db)) {} + main_db_(std::move(default_db)) {} std::set> grants_dbs_; //!< set of databases with granted access std::set> denies_dbs_; //!< set of databases with denied access bool allow_all_; //!< flag to allow access to everything (denied overrides this) - std::string default_db_; //!< user's default database + std::string main_db_; //!< user's default database }; #endif +class Role { + public: + Role() = default; + + explicit Role(const std::string &rolename); + Role(const std::string &rolename, const Permissions &permissions); +#ifdef MG_ENTERPRISE + Role(const std::string &rolename, const Permissions &permissions, + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {}); +#endif + Role(const Role &) = default; + Role &operator=(const Role &) = default; + Role(Role &&) noexcept = default; + Role &operator=(Role &&) noexcept = default; + ~Role() = default; + + const std::string &rolename() const; + const Permissions &permissions() const; + Permissions &permissions(); + Permissions GetPermissions() const { return permissions_; } +#ifdef MG_ENTERPRISE + const FineGrainedAccessHandler &fine_grained_access_handler() const; + FineGrainedAccessHandler &fine_grained_access_handler(); + const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const; + const FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const; +#endif + +#ifdef MG_ENTERPRISE + Databases &db_access() { return db_access_; } + const Databases &db_access() const { return db_access_; } + + bool DeniesDB(std::string_view db_name) const { return db_access_.Denies(db_name); } + bool GrantsDB(std::string_view db_name) const { return db_access_.Grants(db_name); } + bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); } +#endif + + nlohmann::json Serialize() const; + + /// @throw AuthException if unable to deserialize. + static Role Deserialize(const nlohmann::json &data); + + friend bool operator==(const Role &first, const Role &second); + + private: + std::string rolename_; + Permissions permissions_; +#ifdef MG_ENTERPRISE + FineGrainedAccessHandler fine_grained_access_handler_; + Databases db_access_; +#endif +}; + +bool operator==(const Role &first, const Role &second); + // TODO (mferencevic): Implement password expiry. class User final { public: @@ -388,6 +407,18 @@ class User final { #ifdef MG_ENTERPRISE Databases &db_access() { return database_access_; } const Databases &db_access() const { return database_access_; } + + bool DeniesDB(std::string_view db_name) const { + bool denies = database_access_.Denies(db_name); + if (role_) denies |= role_->DeniesDB(db_name); + return denies; + } + bool GrantsDB(std::string_view db_name) const { + bool grants = database_access_.Grants(db_name); + if (role_) grants |= role_->GrantsDB(db_name); + return grants; + } + bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); } #endif nlohmann::json Serialize() const; @@ -403,7 +434,7 @@ class User final { Permissions permissions_; #ifdef MG_ENTERPRISE FineGrainedAccessHandler fine_grained_access_handler_; - Databases database_access_; + Databases database_access_{}; #endif std::optional role_; }; diff --git a/src/auth/module.cpp b/src/auth/module.cpp index 45b93182a..04fa7fa73 100644 --- a/src/auth/module.cpp +++ b/src/auth/module.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -403,7 +403,7 @@ nlohmann::json Module::Call(const nlohmann::json ¶ms, int timeout_millisec) return ret; } -bool Module::IsUsed() { return !module_executable_path_.empty(); } +bool Module::IsUsed() const { return !module_executable_path_.empty(); } void Module::Shutdown() { if (pid_ == -1) return; diff --git a/src/auth/module.hpp b/src/auth/module.hpp index e711708f7..712466950 100644 --- a/src/auth/module.hpp +++ b/src/auth/module.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -49,7 +49,7 @@ class Module final { /// specified executable path and can thus be used. /// /// @return boolean indicating whether the module can be used - bool IsUsed(); + bool IsUsed() const; ~Module(); diff --git a/src/auth/rpc.cpp b/src/auth/rpc.cpp index b658c9491..6f264ccdf 100644 --- a/src/auth/rpc.cpp +++ b/src/auth/rpc.cpp @@ -18,11 +18,9 @@ #include "utils/enum.hpp" namespace memgraph::slk { - // Serialize code for auth::Role -void Save(const auth::Role &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.Serialize().dump(), builder); -} +void Save(const auth::Role &self, Builder *builder) { memgraph::slk::Save(self.Serialize().dump(), builder); } + namespace { auth::Role LoadAuthRole(memgraph::slk::Reader *reader) { std::string tmp; diff --git a/src/communication/websocket/auth.cpp b/src/communication/websocket/auth.cpp index 6efa97a08..68f873ee4 100644 --- a/src/communication/websocket/auth.cpp +++ b/src/communication/websocket/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,19 +12,44 @@ #include "communication/websocket/auth.hpp" #include +#include "utils/variant_helpers.hpp" namespace memgraph::communication::websocket { bool SafeAuth::Authenticate(const std::string &username, const std::string &password) const { - return auth_->Lock()->Authenticate(username, password).has_value(); + user_or_role_ = auth_->Lock()->Authenticate(username, password); + return user_or_role_.has_value(); } -bool SafeAuth::HasUserPermission(const std::string &username, const auth::Permission permission) const { - if (const auto user = auth_->ReadLock()->GetUser(username); user) { - return user->GetPermissions().Has(permission) == auth::PermissionLevel::GRANT; +bool SafeAuth::HasPermission(const auth::Permission permission) const { + auto locked_auth = auth_->ReadLock(); + // Update if cache invalidated + if (!locked_auth->UpToDate(auth_epoch_) && user_or_role_) { + bool success = true; + std::visit(utils::Overloaded{[&](auth::User &user) { + auto tmp = locked_auth->GetUser(user.username()); + if (!tmp) success = false; + user = std::move(*tmp); + }, + [&](auth::Role &role) { + auto tmp = locked_auth->GetRole(role.rolename()); + if (!tmp) success = false; + role = std::move(*tmp); + }}, + *user_or_role_); + // Missing user/role; delete from cache + if (!success) user_or_role_.reset(); } + // Check permissions + if (user_or_role_) { + return std::visit(utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetPermissions().Has(permission) == auth::PermissionLevel::GRANT; + }}, + *user_or_role_); + } + // NOTE: websocket authenticates only if there is a user, so no need to check if access controlled return false; } -bool SafeAuth::HasAnyUsers() const { return auth_->ReadLock()->HasUsers(); } +bool SafeAuth::AccessControlled() const { return auth_->ReadLock()->AccessControlled(); } } // namespace memgraph::communication::websocket diff --git a/src/communication/websocket/auth.hpp b/src/communication/websocket/auth.hpp index 1ab865a2a..cb838382c 100644 --- a/src/communication/websocket/auth.hpp +++ b/src/communication/websocket/auth.hpp @@ -21,9 +21,9 @@ class AuthenticationInterface { public: virtual bool Authenticate(const std::string &username, const std::string &password) const = 0; - virtual bool HasUserPermission(const std::string &username, auth::Permission permission) const = 0; + virtual bool HasPermission(auth::Permission permission) const = 0; - virtual bool HasAnyUsers() const = 0; + virtual bool AccessControlled() const = 0; }; class SafeAuth : public AuthenticationInterface { @@ -32,11 +32,13 @@ class SafeAuth : public AuthenticationInterface { bool Authenticate(const std::string &username, const std::string &password) const override; - bool HasUserPermission(const std::string &username, auth::Permission permission) const override; + bool HasPermission(auth::Permission permission) const override; - bool HasAnyUsers() const override; + bool AccessControlled() const override; private: auth::SynchedAuth *auth_; + mutable std::optional user_or_role_; + mutable auth::Auth::Epoch auth_epoch_{}; }; } // namespace memgraph::communication::websocket diff --git a/src/communication/websocket/session.cpp b/src/communication/websocket/session.cpp index 13c788ddd..094ed8f83 100644 --- a/src/communication/websocket/session.cpp +++ b/src/communication/websocket/session.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -80,7 +80,7 @@ bool Session::Run() { return false; } - authenticated_ = !auth_.HasAnyUsers(); + authenticated_ = !auth_.AccessControlled(); connected_.store(true, std::memory_order_relaxed); // run on the strand @@ -162,7 +162,7 @@ utils::BasicResult Session::Authorize(const nlohmann::json &creds) return {"Authentication failed!"}; } #ifdef MG_ENTERPRISE - if (!auth_.HasUserPermission(creds.at("username").get(), auth::Permission::WEBSOCKET)) { + if (!auth_.HasPermission(auth::Permission::WEBSOCKET)) { return {"Authorization failed!"}; } #endif diff --git a/src/glue/CMakeLists.txt b/src/glue/CMakeLists.txt index da287179f..8f3aec412 100644 --- a/src/glue/CMakeLists.txt +++ b/src/glue/CMakeLists.txt @@ -6,5 +6,6 @@ target_sources(mg-glue PRIVATE auth.cpp SessionHL.cpp ServerT.cpp MonitoringServerT.cpp - run_id.cpp) + run_id.cpp + query_user.cpp) target_link_libraries(mg-glue mg-query mg-auth mg-audit mg-flags) diff --git a/src/glue/SessionHL.cpp b/src/glue/SessionHL.cpp index 07e1bf6e8..6c901516c 100644 --- a/src/glue/SessionHL.cpp +++ b/src/glue/SessionHL.cpp @@ -11,6 +11,7 @@ #include #include +#include "auth/auth.hpp" #include "gflags/gflags.h" #include "audit/log.hpp" @@ -19,17 +20,22 @@ #include "glue/SessionHL.hpp" #include "glue/auth_checker.hpp" #include "glue/communication.hpp" +#include "glue/query_user.hpp" #include "glue/run_id.hpp" #include "license/license.hpp" +#include "query/auth_checker.hpp" #include "query/discard_value_stream.hpp" #include "query/interpreter_context.hpp" +#include "query/query_user.hpp" #include "utils/event_map.hpp" #include "utils/spin_lock.hpp" +#include "utils/variant_helpers.hpp" namespace memgraph::metrics { extern const Event ActiveBoltSessions; } // namespace memgraph::metrics +namespace { auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgraph::query::QueryExtras { auto const &as_map = extra.ValueMap(); @@ -97,20 +103,24 @@ std::vector TypedValueResultStreamBase::De } return decoded_values; } + TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {} -namespace memgraph::glue { - #ifdef MG_ENTERPRISE -inline static void MultiDatabaseAuth(const std::optional &user, std::string_view db) { - if (user && !AuthChecker::IsUserAuthorized(*user, {}, std::string(db))) { +void MultiDatabaseAuth(memgraph::query::QueryUserOrRole *user, std::string_view db) { + if (user && !user->IsAuthorized({}, std::string(db), &memgraph::query::session_long_policy)) { throw memgraph::communication::bolt::ClientError( "You are not authorized on the database \"{}\"! Please contact your database administrator.", db); } } +#endif +} // namespace +namespace memgraph::glue { + +#ifdef MG_ENTERPRISE std::string SessionHL::GetDefaultDB() { - if (user_.has_value()) { - return user_->db_access().GetDefault(); + if (user_or_role_) { + return user_or_role_->GetDefaultDB(); } return std::string{memgraph::dbms::kDefaultDB}; } @@ -132,13 +142,18 @@ bool SessionHL::Authenticate(const std::string &username, const std::string &pas interpreter_.ResetUser(); { auto locked_auth = auth_->Lock(); - if (locked_auth->HasUsers()) { - user_ = locked_auth->Authenticate(username, password); - if (user_.has_value()) { - interpreter_.SetUser(user_->username()); + if (locked_auth->AccessControlled()) { + const auto user_or_role = locked_auth->Authenticate(username, password); + if (user_or_role.has_value()) { + user_or_role_ = AuthChecker::GenQueryUser(auth_, *user_or_role); + interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, *user_or_role)); } else { res = false; } + } else { + // No access control -> give empty user + user_or_role_ = AuthChecker::GenQueryUser(auth_, std::nullopt); + interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, std::nullopt)); } } #ifdef MG_ENTERPRISE @@ -195,21 +210,17 @@ std::pair, std::optional> SessionHL::Interpret( } #ifdef MG_ENTERPRISE - const std::string *username{nullptr}; - if (user_) { - username = &user_->username(); - } - if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { auto &db = interpreter_.current_db_.db_acc_; - audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query, - memgraph::storage::PropertyValue(params_pv), db ? db->get()->name() : "no known database"); + const auto username = user_or_role_ ? (user_or_role_->username() ? *user_or_role_->username() : "") : ""; + audit_log_->Record(endpoint_.address().to_string(), username, query, memgraph::storage::PropertyValue(params_pv), + db ? db->get()->name() : "no known database"); } #endif try { auto result = interpreter_.Prepare(query, params_pv, ToQueryExtras(extra)); const std::string db_name = result.db ? *result.db : ""; - if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) { + if (user_or_role_ && !user_or_role_->IsAuthorized(result.privileges, db_name, &query::session_long_policy)) { interpreter_.Abort(); if (db_name.empty()) { throw memgraph::communication::bolt::ClientError( @@ -311,7 +322,7 @@ void SessionHL::Configure(const std::mapinterpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); }); } diff --git a/src/glue/SessionHL.hpp b/src/glue/SessionHL.hpp index 64dcddda5..cf0280fcc 100644 --- a/src/glue/SessionHL.hpp +++ b/src/glue/SessionHL.hpp @@ -15,6 +15,7 @@ #include "communication/v2/server.hpp" #include "communication/v2/session.hpp" #include "dbms/database.hpp" +#include "glue/query_user.hpp" #include "query/interpreter.hpp" namespace memgraph::glue { @@ -82,7 +83,7 @@ class SessionHL final : public memgraph::communication::bolt::Session user_; + std::unique_ptr user_or_role_; #ifdef MG_ENTERPRISE memgraph::audit::Log *audit_log_; bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 4db6c827e..99463d323 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -14,53 +14,74 @@ #include "auth/auth.hpp" #include "auth/models.hpp" #include "glue/auth.hpp" +#include "glue/query_user.hpp" #include "license/license.hpp" +#include "query/auth_checker.hpp" #include "query/constants.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" +#include "utils/logging.hpp" #include "utils/synchronized.hpp" +#include "utils/variant_helpers.hpp" #ifdef MG_ENTERPRISE namespace { -bool IsUserAuthorizedLabels(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba, - const std::vector &labels, - const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { +bool IsAuthorizedLabels(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba, + const std::vector &labels, + const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return std::all_of(labels.begin(), labels.end(), [dba, &user, fine_grained_privilege](const auto &label) { - return user.GetFineGrainedAccessLabelPermissions().Has( - dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( - fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT; + return std::all_of(labels.begin(), labels.end(), [dba, &user_or_role, fine_grained_privilege](const auto &label) { + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessLabelPermissions().Has( + dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( + fine_grained_privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); }); } -bool IsUserAuthorizedGloballyLabels(const memgraph::auth::User &user, - const memgraph::auth::FineGrainedPermission fine_grained_permission) { +bool IsAuthorizedGloballyLabels(const memgraph::auth::UserOrRole &user_or_role, + const memgraph::auth::FineGrainedPermission fine_grained_permission) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) == - memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk, + fine_grained_permission) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } -bool IsUserAuthorizedGloballyEdges(const memgraph::auth::User &user, - const memgraph::auth::FineGrainedPermission fine_grained_permission) { +bool IsAuthorizedGloballyEdges(const memgraph::auth::UserOrRole &user_or_role, + const memgraph::auth::FineGrainedPermission fine_grained_permission) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) == - memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk, + fine_grained_permission) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } -bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba, - const memgraph::storage::EdgeTypeId &edgeType, - const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { +bool IsAuthorizedEdgeType(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba, + const memgraph::storage::EdgeTypeId &edgeType, + const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessEdgeTypePermissions().Has( - dba->EdgeTypeToName(edgeType), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( - fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has( + dba->EdgeTypeToName(edgeType), + memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } } // namespace #endif @@ -68,47 +89,54 @@ namespace memgraph::glue { AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {} -bool AuthChecker::IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const { - std::optional maybe_user; - { - auto locked_auth = auth_->ReadLock(); - if (!locked_auth->HasUsers()) { - return true; - } - if (username.has_value()) { - maybe_user = locked_auth->GetUser(*username); - } +std::shared_ptr AuthChecker::GenQueryUser(const std::optional &username, + const std::optional &rolename) const { + const auto user_or_role = auth_->ReadLock()->GetUserOrRole(username, rolename); + if (user_or_role) { + return std::make_shared(auth_, *user_or_role); } + // No user or role + return std::make_shared(auth_); +} - return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name); +std::unique_ptr AuthChecker::GenQueryUser(auth::SynchedAuth *auth, + const std::optional &user_or_role) { + if (user_or_role) { + return std::visit( + utils::Overloaded{[&](auto &user_or_role) { return std::make_unique(auth, user_or_role); }}, + *user_or_role); + } + // No user or role + return std::make_unique(auth); } #ifdef MG_ENTERPRISE std::unique_ptr AuthChecker::GetFineGrainedAuthChecker( - const std::string &username, const memgraph::query::DbAccessor *dba) const { + std::shared_ptr user_or_role, const memgraph::query::DbAccessor *dba) const { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return {}; } - try { - auto user = user_.Lock(); - if (username != user->username()) { - auto maybe_user = auth_->ReadLock()->GetUser(username); - if (!maybe_user) { - throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); - } - *user = std::move(*maybe_user); - } - return std::make_unique(*user, dba); - - } catch (const memgraph::auth::AuthException &e) { - throw memgraph::query::QueryRuntimeException(e.what()); + if (!user_or_role || !*user_or_role) { + throw query::QueryRuntimeException("No user specified for fine grained authorization!"); } -} -void AuthChecker::ClearCache() const { - user_.WithLock([](auto &user) mutable { user = {}; }); + // Convert from query user to auth user or role + try { + auto glue_user = dynamic_cast(*user_or_role); + if (glue_user.user_) { + return std::make_unique(std::move(*glue_user.user_), dba); + } + if (glue_user.role_) { + return std::make_unique( + auth::RoleWUsername{*glue_user.username(), std::move(*glue_user.role_)}, dba); + } + DMG_ASSERT(false, "Glue user has neither user not role"); + } catch (std::bad_cast &e) { + DMG_ASSERT(false, "Using a non-glue user in glue..."); + } + + // Should never get here + return {}; } #endif @@ -116,7 +144,7 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, const std::vector &privileges, const std::string &db_name) { // NOLINT #ifdef MG_ENTERPRISE - if (!db_name.empty() && !user.db_access().Contains(db_name)) { + if (!db_name.empty() && !user.HasAccess(db_name)) { return false; } #endif @@ -127,9 +155,34 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, }); } +bool AuthChecker::IsRoleAuthorized(const memgraph::auth::Role &role, + const std::vector &privileges, + const std::string &db_name) { // NOLINT #ifdef MG_ENTERPRISE -FineGrainedAuthChecker::FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba) - : user_{std::move(user)}, dba_(dba){}; + if (!db_name.empty() && !role.HasAccess(db_name)) { + return false; + } +#endif + const auto role_permissions = role.permissions(); + return std::all_of(privileges.begin(), privileges.end(), [&role_permissions](const auto privilege) { + return role_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }); +} + +bool AuthChecker::IsUserOrRoleAuthorized(const memgraph::auth::UserOrRole &user_or_role, + const std::vector &privileges, + const std::string &db_name) { + return std::visit( + utils::Overloaded{ + [&](const auth::User &user) -> bool { return AuthChecker::IsUserAuthorized(user, privileges, db_name); }, + [&](const auth::Role &role) -> bool { return AuthChecker::IsRoleAuthorized(role, privileges, db_name); }}, + user_or_role); +} + +#ifdef MG_ENTERPRISE +FineGrainedAuthChecker::FineGrainedAuthChecker(auth::UserOrRole user_or_role, const memgraph::query::DbAccessor *dba) + : user_or_role_{std::move(user_or_role)}, dba_(dba){}; bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex, const memgraph::storage::View view, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { @@ -147,22 +200,22 @@ bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex, } } - return IsUserAuthorizedLabels(user_, dba_, *maybe_labels, fine_grained_privilege); + return IsAuthorizedLabels(user_or_role_, dba_, *maybe_labels, fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const memgraph::query::EdgeAccessor &edge, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedEdgeType(user_, dba_, edge.EdgeType(), fine_grained_privilege); + return IsAuthorizedEdgeType(user_or_role_, dba_, edge.EdgeType(), fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const std::vector &labels, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedLabels(user_, dba_, labels, fine_grained_privilege); + return IsAuthorizedLabels(user_or_role_, dba_, labels, fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const memgraph::storage::EdgeTypeId &edge_type, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedEdgeType(user_, dba_, edge_type, fine_grained_privilege); + return IsAuthorizedEdgeType(user_or_role_, dba_, edge_type, fine_grained_privilege); } bool FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices( @@ -170,7 +223,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices( if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return IsUserAuthorizedGloballyLabels(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); + return IsAuthorizedGloballyLabels(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); } bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges( @@ -178,7 +231,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges( if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return IsUserAuthorizedGloballyEdges(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); + return IsAuthorizedGloballyEdges(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); }; #endif } // namespace memgraph::glue diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index 217ac0c74..ef8e993df 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -22,53 +22,59 @@ namespace memgraph::glue { class AuthChecker : public query::AuthChecker { public: - explicit AuthChecker(memgraph::auth::SynchedAuth *auth); + explicit AuthChecker(auth::SynchedAuth *auth); - bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const override; + std::shared_ptr GenQueryUser(const std::optional &username, + const std::optional &rolename) const override; + + static std::unique_ptr GenQueryUser(auth::SynchedAuth *auth, + const std::optional &user_or_role); #ifdef MG_ENTERPRISE - std::unique_ptr GetFineGrainedAuthChecker( - const std::string &username, const memgraph::query::DbAccessor *dba) const override; - - void ClearCache() const override; - + std::unique_ptr GetFineGrainedAuthChecker(std::shared_ptr user, + const query::DbAccessor *dba) const override; #endif - [[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, - const std::vector &privileges, + + [[nodiscard]] static bool IsUserAuthorized(const auth::User &user, + const std::vector &privileges, const std::string &db_name = ""); + [[nodiscard]] static bool IsRoleAuthorized(const auth::Role &role, + const std::vector &privileges, + const std::string &db_name = ""); + + [[nodiscard]] static bool IsUserOrRoleAuthorized(const auth::UserOrRole &user_or_role, + const std::vector &privileges, + const std::string &db_name = ""); + private: - memgraph::auth::SynchedAuth *auth_; - mutable memgraph::utils::Synchronized user_; // cached user + auth::SynchedAuth *auth_; + mutable utils::Synchronized user_or_role_; // cached user }; #ifdef MG_ENTERPRISE class FineGrainedAuthChecker : public query::FineGrainedAuthChecker { public: - explicit FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba); + explicit FineGrainedAuthChecker(auth::UserOrRole user, const query::DbAccessor *dba); - bool Has(const query::VertexAccessor &vertex, memgraph::storage::View view, + bool Has(const query::VertexAccessor &vertex, storage::View view, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; bool Has(const query::EdgeAccessor &edge, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool Has(const std::vector &labels, + bool Has(const std::vector &labels, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool Has(const memgraph::storage::EdgeTypeId &edge_type, + bool Has(const storage::EdgeTypeId &edge_type, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool HasGlobalPrivilegeOnVertices( - memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; + bool HasGlobalPrivilegeOnVertices(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool HasGlobalPrivilegeOnEdges( - memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; + bool HasGlobalPrivilegeOnEdges(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; private: - auth::User user_; - const memgraph::query::DbAccessor *dba_; + auth::UserOrRole user_or_role_; + const query::DbAccessor *dba_; }; #endif } // namespace memgraph::glue diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index 2d7260b3c..6178b152e 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -15,6 +15,7 @@ #include +#include "auth/auth.hpp" #include "auth/models.hpp" #include "dbms/constants.hpp" #include "glue/auth.hpp" @@ -123,6 +124,29 @@ std::vector> ShowRolePrivileges( } #ifdef MG_ENTERPRISE +std::vector> ShowDatabasePrivileges( + const std::optional &role) { + if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !role) { + return {}; + } + + const auto &db = role->db_access(); + const auto &allows = db.GetAllowAll(); + const auto &grants = db.GetGrants(); + const auto &denies = db.GetDenies(); + + std::vector res; // First element is a list of granted databases, second of revoked ones + if (allows) { + res.emplace_back("*"); + } else { + std::vector grants_vec(grants.cbegin(), grants.cend()); + res.emplace_back(std::move(grants_vec)); + } + std::vector denies_vec(denies.cbegin(), denies.cend()); + res.emplace_back(std::move(denies_vec)); + return {res}; +} + std::vector> ShowDatabasePrivileges( const std::optional &user) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) { @@ -130,9 +154,15 @@ std::vector> ShowDatabasePrivileges( } const auto &db = user->db_access(); - const auto &allows = db.GetAllowAll(); - const auto &grants = db.GetGrants(); - const auto &denies = db.GetDenies(); + auto allows = db.GetAllowAll(); + auto grants = db.GetGrants(); + auto denies = db.GetDenies(); + if (const auto *role = user->role()) { + const auto &role_db = role->db_access(); + allows |= role_db.GetAllowAll(); + grants.insert(role_db.GetGrants().begin(), role_db.GetGrants().end()); + denies.insert(role_db.GetDenies().begin(), role_db.GetDenies().end()); + } std::vector res; // First element is a list of granted databases, second of revoked ones if (allows) { @@ -287,7 +317,7 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option , system_tx); #ifdef MG_ENTERPRISE - GrantDatabaseToUser(auth::kAllDatabases, username, system_tx); + GrantDatabase(auth::kAllDatabases, username, system_tx); SetMainDatabase(dbms::kDefaultDB, username, system_tx); #endif } @@ -334,51 +364,97 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transacti } #ifdef MG_ENTERPRISE -bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) { +void AuthQueryHandler::GrantDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx); + const auto res = locked_auth->GrantDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) { +void AuthQueryHandler::DenyDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->GrantDatabaseToUser(db_name, username, system_tx); + const auto res = locked_auth->DenyDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} + +void AuthQueryHandler::RevokeDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { + try { + auto locked_auth = auth_->Lock(); + const auto res = locked_auth->RevokeDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } std::vector> AuthQueryHandler::GetDatabasePrivileges( - const std::string &username) { + const std::string &user_or_role) { try { auto locked_auth = auth_->ReadLock(); - auto user = locked_auth->GetUser(username); - if (!user) { - throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); + if (auto user = locked_auth->GetUser(user_or_role)) { + return ShowDatabasePrivileges(user); } - return ShowDatabasePrivileges(user); + if (auto role = locked_auth->GetRole(user_or_role)) { + return ShowDatabasePrivileges(role); + } + throw memgraph::query::QueryRuntimeException("Neither user nor role '{}' exist.", user_or_role); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username, +void AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &user_or_role, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->SetMainDatabase(db_name, username, system_tx); + const auto res = locked_auth->SetMainDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index 52db6075f..d78daaea4 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -37,15 +37,19 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { system::Transaction *system_tx) override; #ifdef MG_ENTERPRISE - bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) override; + void GrantDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; - bool GrantDatabaseToUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) override; + void DenyDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; - std::vector> GetDatabasePrivileges(const std::string &username) override; + void RevokeDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; - bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override; + std::vector> GetDatabasePrivileges(const std::string &user_or_role) override; + + void SetMainDatabase(std::string_view db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override; #endif diff --git a/src/glue/query_user.cpp b/src/glue/query_user.cpp new file mode 100644 index 000000000..5cd6e6750 --- /dev/null +++ b/src/glue/query_user.cpp @@ -0,0 +1,41 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "glue/query_user.hpp" + +#include "glue/auth_checker.hpp" + +namespace memgraph::glue { + +bool QueryUserOrRole::IsAuthorized(const std::vector &privileges, + const std::string &db_name, query::UserPolicy *policy) const { + auto locked_auth = auth_->Lock(); + // Check policy and update if behind (and policy permits it) + if (policy->DoUpdate() && !locked_auth->UpToDate(auth_epoch_)) { + if (user_) user_ = locked_auth->GetUser(user_->username()); + if (role_) role_ = locked_auth->GetRole(role_->rolename()); + } + + if (user_) return AuthChecker::IsUserAuthorized(*user_, privileges, db_name); + if (role_) return AuthChecker::IsRoleAuthorized(*role_, privileges, db_name); + + return !policy->DoUpdate() || !locked_auth->AccessControlled(); +} + +#ifdef MG_ENTERPRISE +std::string QueryUserOrRole::GetDefaultDB() const { + if (user_) return user_->db_access().GetMain(); + if (role_) return role_->db_access().GetMain(); + return std::string{dbms::kDefaultDB}; +} +#endif + +} // namespace memgraph::glue diff --git a/src/glue/query_user.hpp b/src/glue/query_user.hpp new file mode 100644 index 000000000..22f3598db --- /dev/null +++ b/src/glue/query_user.hpp @@ -0,0 +1,57 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include + +#include "auth/auth.hpp" +#include "query/query_user.hpp" +#include "utils/variant_helpers.hpp" + +namespace memgraph::glue { + +struct QueryUserOrRole : public query::QueryUserOrRole { + bool IsAuthorized(const std::vector &privileges, const std::string &db_name, + query::UserPolicy *policy) const override; + +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const override; +#endif + + explicit QueryUserOrRole(auth::SynchedAuth *auth) : query::QueryUserOrRole{std::nullopt, std::nullopt}, auth_{auth} {} + + QueryUserOrRole(auth::SynchedAuth *auth, auth::UserOrRole user_or_role) + : query::QueryUserOrRole{std::visit( + utils::Overloaded{[](const auto &user_or_role) { return user_or_role.username(); }}, + user_or_role), + std::visit(utils::Overloaded{[&](const auth::User &) -> std::optional { + return std::nullopt; + }, + [&](const auth::Role &role) -> std::optional { + return role.rolename(); + }}, + user_or_role)}, + auth_{auth} { + std::visit(utils::Overloaded{[&](auth::User &&user) { user_.emplace(std::move(user)); }, + [&](auth::Role &&role) { role_.emplace(std::move(role)); }}, + std::move(user_or_role)); + } + + private: + friend class AuthChecker; + auth::SynchedAuth *auth_; + mutable std::optional user_{}; + mutable std::optional role_{}; + mutable auth::Auth::Epoch auth_epoch_{auth::Auth::kStartEpoch}; +}; + +} // namespace memgraph::glue diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 378d55e77..34d64f434 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -27,6 +27,7 @@ #include "helpers.hpp" #include "license/license_sender.hpp" #include "memory/global_memory_control.hpp" +#include "query/auth_checker.hpp" #include "query/auth_query_handler.hpp" #include "query/config.hpp" #include "query/discard_value_stream.hpp" @@ -57,8 +58,13 @@ constexpr uint64_t kMgVmMaxMapCount = 262144; void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbms::DatabaseAccess &db_acc, std::string cypherl_file_path, memgraph::audit::Log *audit_log = nullptr) { memgraph::query::Interpreter interpreter(&ctx, db_acc); - std::ifstream file(cypherl_file_path); + // Temporary empty user + // TODO: Double check with buda + memgraph::query::AllowEverythingAuthChecker tmp_auth_checker; + auto tmp_user = tmp_auth_checker.GenQueryUser(std::nullopt, std::nullopt); + interpreter.SetUser(tmp_user); + std::ifstream file(cypherl_file_path); if (!file.is_open()) { spdlog::trace("Could not find init file {}", cypherl_file_path); return; diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt index 3bc7c9499..d70ede482 100644 --- a/src/query/CMakeLists.txt +++ b/src/query/CMakeLists.txt @@ -40,6 +40,7 @@ set(mg_query_sources db_accessor.cpp auth_query_handler.cpp interpreter_context.cpp + query_user.cpp ) add_library(mg-query STATIC ${mg_query_sources}) diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index 1eb9d02e9..183cbd900 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,7 +16,9 @@ #include #include +#include "dbms/constants.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" #include "storage/v2/id_types.hpp" namespace memgraph::query { @@ -29,15 +31,12 @@ class AuthChecker { public: virtual ~AuthChecker() = default; - [[nodiscard]] virtual bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const = 0; + virtual std::shared_ptr GenQueryUser(const std::optional &username, + const std::optional &rolename) const = 0; #ifdef MG_ENTERPRISE [[nodiscard]] virtual std::unique_ptr GetFineGrainedAuthChecker( - const std::string &username, const DbAccessor *db_accessor) const = 0; - - virtual void ClearCache() const = 0; + std::shared_ptr user, const DbAccessor *db_accessor) const = 0; #endif }; #ifdef MG_ENTERPRISE @@ -98,19 +97,29 @@ class AllowEverythingFineGrainedAuthChecker final : public FineGrainedAuthChecke class AllowEverythingAuthChecker final : public AuthChecker { public: - bool IsUserAuthorized(const std::optional & /*username*/, - const std::vector & /*privileges*/, - const std::string & /*db*/) const override { - return true; + struct User : query::QueryUserOrRole { + User() : query::QueryUserOrRole{std::nullopt, std::nullopt} {} + User(std::string name) : query::QueryUserOrRole{std::move(name), std::nullopt} {} + bool IsAuthorized(const std::vector & /*privileges*/, const std::string & /*db_name*/, + UserPolicy * /*policy*/) const override { + return true; + } +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const override { return std::string{dbms::kDefaultDB}; } +#endif + }; + + std::shared_ptr GenQueryUser(const std::optional &name, + const std::optional & /*role*/) const override { + if (name) return std::make_shared(std::move(*name)); + return std::make_shared(); } #ifdef MG_ENTERPRISE - std::unique_ptr GetFineGrainedAuthChecker(const std::string & /*username*/, + std::unique_ptr GetFineGrainedAuthChecker(std::shared_ptr /*user*/, const DbAccessor * /*dba*/) const override { return std::make_unique(); } - - void ClearCache() const override {} #endif }; diff --git a/src/query/auth_query_handler.hpp b/src/query/auth_query_handler.hpp index 0258005c3..acc90c2c5 100644 --- a/src/query/auth_query_handler.hpp +++ b/src/query/auth_query_handler.hpp @@ -46,15 +46,17 @@ class AuthQueryHandler { system::Transaction *system_tx) = 0; #ifdef MG_ENTERPRISE - /// Return true if access revoked successfully - /// @throw QueryRuntimeException if an error ocurred. - virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username, - system::Transaction *system_tx) = 0; - /// Return true if access granted successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username, - system::Transaction *system_tx) = 0; + virtual void GrantDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; + + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual void DenyDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; + + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual void RevokeDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; /// Returns database access rights for the user /// @throw QueryRuntimeException if an error ocurred. @@ -62,7 +64,7 @@ class AuthQueryHandler { /// Return true if main database set successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0; + virtual void SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0; /// Delete database from all users /// @throw QueryRuntimeException if an error ocurred. diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index d106f226e..f0abd1c86 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -2850,6 +2850,7 @@ class AuthQuery : public memgraph::query::Query { SHOW_ROLE_FOR_USER, SHOW_USERS_FOR_ROLE, GRANT_DATABASE_TO_USER, + DENY_DATABASE_FROM_USER, REVOKE_DATABASE_FROM_USER, SHOW_DATABASE_PRIVILEGES, SET_MAIN_DATABASE, diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index a84768e96..196c02cde 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1802,22 +1802,35 @@ antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsers /** * @return AuthQuery* */ -antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) { +antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER; auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } /** * @return AuthQuery* */ -antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) { +antlrcpp::Any CypherMainVisitor::visitDenyDatabaseFromUserOrRole( + MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::DENY_DATABASE_FROM_USER; + auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUserOrRole( + MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER; auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } @@ -1827,7 +1840,7 @@ antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::Rev antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES; - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } @@ -1838,7 +1851,7 @@ antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDat auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE; auth->database_ = std::any_cast(ctx->db->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 5fd41b83e..599487cf7 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -620,12 +620,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { /** * @return AuthQuery* */ - antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override; + antlrcpp::Any visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) override; /** * @return AuthQuery* */ - antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override; + antlrcpp::Any visitDenyDatabaseFromUserOrRole(MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitRevokeDatabaseFromUserOrRole(MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) override; /** * @return AuthQuery* diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 507511d27..0c75d1d82 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -177,8 +177,9 @@ authQuery : createRole | showPrivileges | showRoleForUser | showUsersForRole - | grantDatabaseToUser - | revokeDatabaseFromUser + | grantDatabaseToUserOrRole + | denyDatabaseFromUserOrRole + | revokeDatabaseFromUserOrRole | showDatabasePrivileges | setMainDatabase ; @@ -304,13 +305,15 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ; -grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ; +grantDatabaseToUserOrRole : GRANT DATABASE db=wildcardName TO userOrRole=userOrRoleName ; -revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ; +denyDatabaseFromUserOrRole : DENY DATABASE db=wildcardName FROM userOrRole=userOrRoleName ; -showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ; +revokeDatabaseFromUserOrRole : REVOKE DATABASE db=wildcardName FROM userOrRole=userOrRoleName ; -setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ; +showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR userOrRole=userOrRoleName ; + +setMainDatabase : SET MAIN DATABASE db=symbolicName FOR userOrRole=userOrRoleName ; privilege : CREATE | DELETE diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 31efd9ee5..565777581 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -69,6 +69,7 @@ #include "query/plan/profile.hpp" #include "query/plan/vertex_count_cache.hpp" #include "query/procedure/module.hpp" +#include "query/query_user.hpp" #include "query/replication_query_handler.hpp" #include "query/stream.hpp" #include "query/stream/common.hpp" @@ -630,6 +631,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ AuthQuery::Action::SHOW_USERS_FOR_ROLE, AuthQuery::Action::SHOW_ROLE_FOR_USER, AuthQuery::Action::GRANT_DATABASE_TO_USER, + AuthQuery::Action::DENY_DATABASE_FROM_USER, AuthQuery::Action::REVOKE_DATABASE_FROM_USER, AuthQuery::Action::SHOW_DATABASE_PRIVILEGES, AuthQuery::Action::SET_MAIN_DATABASE}; @@ -889,9 +891,31 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); + auth->GrantDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw QueryRuntimeException(e.what()); + } +#else + callback.fn = [] { +#endif + return std::vector>(); + }; + return callback; + case AuthQuery::Action::DENY_DATABASE_FROM_USER: + forbid_on_replica(); +#ifdef MG_ENTERPRISE + callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + try { + std::optional db = + std::nullopt; // Hold pointer to database to protect it until query is done + if (database != memgraph::auth::kAllDatabases) { + db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } + auth->DenyDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { throw QueryRuntimeException(e.what()); } @@ -915,9 +939,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); - } + auth->RevokeDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { throw QueryRuntimeException(e.what()); } @@ -950,9 +972,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ try { const auto db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull - if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); - } + auth->SetMainDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { throw QueryRuntimeException(e.what()); } @@ -1276,7 +1296,7 @@ std::vector EvaluateTopicNames(ExpressionVisitor &evalu Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, memgraph::dbms::DatabaseAccess db_acc, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer"; std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup : stream_query->consumer_group_}; @@ -1303,10 +1323,13 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename()); + return [db_acc = std::move(db_acc), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), consumer_group = std::move(consumer_group), common_stream_info = std::move(common_stream_info), - bootstrap_servers = std::move(bootstrap), owner = username, + bootstrap_servers = std::move(bootstrap), owner = std::move(owner), configs = get_config_map(stream_query->configs_, "Configs"), credentials = get_config_map(stream_query->credentials_, "Credentials"), default_server = interpreter_context->config.default_kafka_bootstrap_servers]() mutable { @@ -1328,7 +1351,7 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, memgraph::dbms::DatabaseAccess db, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator); if (service_url && service_url->empty()) { throw SemanticException("Service URL must not be an empty string!"); @@ -1336,9 +1359,13 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator); memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename()); + return [db = std::move(db), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), - common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), owner = username, + common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), + owner = std::move(owner), default_service = interpreter_context->config.default_pulsar_service_url]() mutable { std::string url = service_url ? std::move(*service_url) : std::move(default_service); db->streams()->Create( @@ -1352,7 +1379,7 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, memgraph::dbms::DatabaseAccess &db_acc, InterpreterContext *interpreter_context, - const std::optional &username, std::vector *notifications) { + std::shared_ptr user_or_role, std::vector *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. EvaluationContext evaluation_context; @@ -1365,10 +1392,12 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete case StreamQuery::Action::CREATE_STREAM: { switch (stream_query->type_) { case StreamQuery::Type::KAFKA: - callback.fn = GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); + callback.fn = + GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role)); break; case StreamQuery::Type::PULSAR: - callback.fn = GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); + callback.fn = + GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role)); break; } notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_STREAM, @@ -1641,7 +1670,7 @@ struct TxTimeout { struct PullPlan { explicit PullPlan(std::shared_ptr plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional username, std::atomic *transaction_status, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, TriggerContextCollector *trigger_context_collector = nullptr, std::optional memory_limit = {}, bool use_monotonic_memory = true, @@ -1681,7 +1710,7 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional username, std::atomic *transaction_status, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, TriggerContextCollector *trigger_context_collector, const std::optional memory_limit, bool use_monotonic_memory, FrameChangeCollector *frame_change_collector) @@ -1697,10 +1726,9 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &pa ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); #ifdef MG_ENTERPRISE - if (license::global_license_checker.IsEnterpriseValidFast() && username.has_value() && dba) { - // TODO How can we avoid creating this every time? If we must create it, it would be faster with an auth::User - // instead of the username - auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(*username, dba); + if (license::global_license_checker.IsEnterpriseValidFast() && user_or_role && *user_or_role && dba) { + // Create only if an explicit user is defined + auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(std::move(user_or_role), dba); // if the user has global privileges to read, edit and write anything, we don't need to perform authorization // otherwise, we do assign the auth checker to check for label access control @@ -1990,7 +2018,7 @@ bool IsCallBatchedProcedureQuery(const std::vector &c PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, CurrentDB ¤t_db, utils::MemoryResource *execution_memory, std::vector *notifications, - std::optional const &username, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, FrameChangeCollector *frame_change_collector = nullptr) { @@ -2058,8 +2086,8 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map( - plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, username, transaction_status, - std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory, + plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, std::move(user_or_role), + transaction_status, std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory, frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( @@ -2131,7 +2159,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map *summary, std::vector *notifications, InterpreterContext *interpreter_context, CurrentDB ¤t_db, - utils::MemoryResource *execution_memory, std::optional const &username, + utils::MemoryResource *execution_memory, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, FrameChangeCollector *frame_change_collector) { @@ -2209,37 +2238,37 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra rw_type_checker.InferRWType(const_cast(cypher_query_plan->plan())); - return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, - std::move(parsed_query.required_privileges), - [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), - summary, dba, interpreter_context, execution_memory, memory_limit, username, - // We want to execute the query we are profiling lazily, so we delay - // the construction of the corresponding context. - stats_and_total_time = std::optional{}, - pull_plan = std::shared_ptr(nullptr), transaction_status, use_monotonic_memory, - frame_change_collector, tx_timer = std::move(tx_timer)]( - AnyStream *stream, std::optional n) mutable -> std::optional { - // No output symbols are given so that nothing is streamed. - if (!stats_and_total_time) { - stats_and_total_time = - PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, username, - transaction_status, std::move(tx_timer), nullptr, memory_limit, - use_monotonic_memory, - frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr) - .Pull(stream, {}, {}, summary); - pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); - } + return PreparedQuery{ + {"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, + std::move(parsed_query.required_privileges), + [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), summary, dba, + interpreter_context, execution_memory, memory_limit, user_or_role = std::move(user_or_role), + // We want to execute the query we are profiling lazily, so we delay + // the construction of the corresponding context. + stats_and_total_time = std::optional{}, + pull_plan = std::shared_ptr(nullptr), transaction_status, use_monotonic_memory, + frame_change_collector, tx_timer = std::move(tx_timer)]( + AnyStream *stream, std::optional n) mutable -> std::optional { + // No output symbols are given so that nothing is streamed. + if (!stats_and_total_time) { + stats_and_total_time = + PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, std::move(user_or_role), + transaction_status, std::move(tx_timer), nullptr, memory_limit, use_monotonic_memory, + frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr) + .Pull(stream, {}, {}, summary); + pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); + } - MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); + MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); - if (pull_plan->Pull(stream, n)) { - summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); - return QueryHandlerResult::ABORT; - } + if (pull_plan->Pull(stream, n)) { + summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); + return QueryHandlerResult::ABORT; + } - return std::nullopt; - }, - rw_type_checker.type}; + return std::nullopt; + }, + rw_type_checker.type}; } PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, CurrentDB ¤t_db) { @@ -2732,26 +2761,22 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter); - return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), - [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr), - interpreter_context]( // NOLINT - AnyStream *stream, std::optional n) mutable -> std::optional { - if (!pull_plan) { - // Run the specific query - auto results = handler(); - pull_plan = std::make_shared(std::move(results)); -#ifdef MG_ENTERPRISE - // Invalidate auth cache after every type of AuthQuery - interpreter_context->auth_checker->ClearCache(); -#endif - } + return PreparedQuery{ + std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr)]( // NOLINT + AnyStream *stream, std::optional n) mutable -> std::optional { + if (!pull_plan) { + // Run the specific query + auto results = handler(); + pull_plan = std::make_shared(std::move(results)); + } - if (pull_plan->Pull(stream, n)) { - return QueryHandlerResult::COMMIT; - } - return std::nullopt; - }, - RWType::NONE}; + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; } PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, @@ -2955,17 +2980,18 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) { Callback CreateTrigger(TriggerQuery *trigger_query, const std::map &user_parameters, TriggerStore *trigger_store, InterpreterContext *interpreter_context, DbAccessor *dba, - std::optional owner) { + std::shared_ptr user_or_role) { + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename()); return {{}, [trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_), event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, trigger_store, interpreter_context, dba, user_parameters, owner = std::move(owner)]() mutable -> std::vector> { - trigger_store->AddTrigger(std::move(trigger_name), trigger_statement, user_parameters, - ToTriggerEventType(event_type), - before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, - &interpreter_context->ast_cache, dba, interpreter_context->config.query, - std::move(owner), interpreter_context->auth_checker); + trigger_store->AddTrigger( + std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type), + before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, + &interpreter_context->ast_cache, dba, interpreter_context->config.query, std::move(owner)); memgraph::metrics::IncrementCounter(memgraph::metrics::TriggersCreated); return {}; }}; @@ -3007,7 +3033,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra std::vector *notifications, CurrentDB ¤t_db, InterpreterContext *interpreter_context, const std::map &user_parameters, - std::optional const &username) { + std::shared_ptr user_or_role) { if (in_explicit_transaction) { throw TriggerModificationInMulticommandTxException(); } @@ -3021,8 +3047,9 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra MG_ASSERT(trigger_query); std::optional trigger_notification; + auto callback = std::invoke([trigger_query, trigger_store, interpreter_context, dba, &user_parameters, - owner = username, &trigger_notification]() mutable { + owner = std::move(user_or_role), &trigger_notification]() mutable { switch (trigger_query->action_) { case TriggerQuery::Action::CREATE_TRIGGER: trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::CREATE_TRIGGER, @@ -3060,7 +3087,8 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector *notifications, CurrentDB ¤t_db, - InterpreterContext *interpreter_context, const std::optional &username) { + InterpreterContext *interpreter_context, + std::shared_ptr user_or_role) { if (in_explicit_transaction) { throw StreamQueryInMulticommandTxException(); } @@ -3070,8 +3098,8 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran auto *stream_query = utils::Downcast(parsed_query.query); MG_ASSERT(stream_query); - auto callback = - HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, username, notifications); + auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, + std::move(user_or_role), notifications); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -3375,7 +3403,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra } template -auto ShowTransactions(const std::unordered_set &interpreters, const std::optional &username, +auto ShowTransactions(const std::unordered_set &interpreters, QueryUserOrRole *user_or_role, Func &&privilege_checker) -> std::vector> { std::vector> results; results.reserve(interpreters.size()); @@ -3395,11 +3423,21 @@ auto ShowTransactions(const std::unordered_set &interpreters, con static std::string all; return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all; }; - if (transaction_id.has_value() && - (interpreter->username_ == username || privilege_checker(get_interpreter_db_name()))) { + + auto same_user = [](const auto &lv, const auto &rv) { + if (lv.get() == rv) return true; + if (lv && rv) return *lv == *rv; + return false; + }; + + if (transaction_id.has_value() && (same_user(interpreter->user_or_role_, user_or_role) || + privilege_checker(user_or_role, get_interpreter_db_name()))) { const auto &typed_queries = interpreter->GetQueries(); - results.push_back({TypedValue(interpreter->username_.value_or("")), - TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)}); + results.push_back( + {TypedValue(interpreter->user_or_role_ + ? (interpreter->user_or_role_->username() ? *interpreter->user_or_role_->username() : "") + : ""), + TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)}); // Handle user-defined metadata std::map metadata_tv; if (interpreter->metadata_) { @@ -3414,17 +3452,19 @@ auto ShowTransactions(const std::unordered_set &interpreters, con } Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, - const std::optional &username, const Parameters ¶meters, + std::shared_ptr user_or_role, const Parameters ¶meters, InterpreterContext *interpreter_context) { - auto privilege_checker = [username, auth_checker = interpreter_context->auth_checker](std::string const &db_name) { - return auth_checker->IsUserAuthorized(username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name); + auto privilege_checker = [](QueryUserOrRole *user_or_role, std::string const &db_name) { + return user_or_role && user_or_role->IsAuthorized({query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name, + &query::up_to_date_policy); }; Callback callback; switch (transaction_query->action_) { case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: { - auto show_transactions = [username, privilege_checker = std::move(privilege_checker)](const auto &interpreters) { - return ShowTransactions(interpreters, username, privilege_checker); + auto show_transactions = [user_or_role = std::move(user_or_role), + privilege_checker = std::move(privilege_checker)](const auto &interpreters) { + return ShowTransactions(interpreters, user_or_role.get(), privilege_checker); }; callback.header = {"username", "transaction_id", "query", "metadata"}; callback.fn = [interpreter_context, show_transactions = std::move(show_transactions)] { @@ -3442,9 +3482,10 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, return std::string(expression->Accept(evaluator).ValueString()); }); callback.header = {"transaction_id", "killed"}; - callback.fn = [interpreter_context, maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids), username, + callback.fn = [interpreter_context, maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids), + user_or_role = std::move(user_or_role), privilege_checker = std::move(privilege_checker)]() mutable { - return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), username, + return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), user_or_role.get(), std::move(privilege_checker)); }; break; @@ -3454,12 +3495,12 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, return callback; } -PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional &username, +PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, std::shared_ptr user_or_role, InterpreterContext *interpreter_context) { auto *transaction_queue_query = utils::Downcast(parsed_query.query); MG_ASSERT(transaction_queue_query); - auto callback = - HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context); + auto callback = HandleTransactionQueueQuery(transaction_queue_query, std::move(user_or_role), parsed_query.parameters, + interpreter_context); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -4096,7 +4137,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur } PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); @@ -4107,7 +4148,8 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon Callback callback; callback.header = {"Name"}; - callback.fn = [auth, db_handler, username]() mutable -> std::vector> { + callback.fn = [auth, db_handler, + user_or_role = std::move(user_or_role)]() mutable -> std::vector> { std::vector> status; auto gen_status = [&](T all, K denied) { Sort(all); @@ -4129,12 +4171,12 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon status.erase(iter, status.end()); }; - if (!username) { + if (!user_or_role || !*user_or_role) { // No user, return all gen_status(db_handler->All(), std::vector{}); } else { // User has a subset of accessible dbs; this is synched with the SessionContextHandler - const auto &db_priv = auth->GetDatabasePrivileges(*username); + const auto &db_priv = auth->GetDatabasePrivileges(user_or_role->key()); const auto &allowed = db_priv[0][0]; const auto &denied = db_priv[0][1].ValueList(); if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) { @@ -4202,6 +4244,7 @@ void Interpreter::SetCurrentDB(std::string_view db_name, bool in_explicit_db) { Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::map ¶ms, QueryExtras const &extras) { + MG_ASSERT(user_or_role_, "Trying to prepare a query without a query user."); // Handle transaction control queries. const auto upper_case_query = utils::ToUpperCase(query_string); const auto trimmed_query = utils::Trim(upper_case_query); @@ -4345,7 +4388,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, frame_change_collector_.emplace(); if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, - current_db_, memory_resource, &query_execution->notifications, username_, + current_db_, memory_resource, &query_execution->notifications, user_or_role_, &transaction_status_, current_timeout_timer_, &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, @@ -4353,7 +4396,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, &query_execution->notifications, interpreter_context_, current_db_, - &query_execution->execution_memory_with_exception, username_, + &query_execution->execution_memory_with_exception, user_or_role_, &transaction_status_, current_timeout_timer_, &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), current_db_); @@ -4400,11 +4443,11 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - current_db_, interpreter_context_, params, username_); + current_db_, interpreter_context_, params, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - current_db_, interpreter_context_, username_); + current_db_, interpreter_context_, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, current_db_, this); } else if (utils::Downcast(parsed_query.query)) { @@ -4425,7 +4468,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (in_explicit_transaction_) { throw TransactionQueueInMulticommandTxException(); } - prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, interpreter_context_); + prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), user_or_role_, interpreter_context_); } else if (utils::Downcast(parsed_query.query)) { if (in_explicit_transaction_) { throw MultiDatabaseQueryInMulticommandTxException(); @@ -4435,7 +4478,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_); + prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { if (in_explicit_transaction_) { throw EdgeImportModeModificationInMulticommandTxException(); @@ -4511,6 +4554,12 @@ std::vector Interpreter::GetQueries() { void Interpreter::Abort() { bool decrement = true; + + // System tx + // TODO Implement system transaction scope and the ability to abort + system_transaction_.reset(); + + // Data tx auto expected = TransactionStatus::ACTIVE; while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) { if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) { @@ -4562,8 +4611,7 @@ void RunTriggersAfterCommit(dbms::DatabaseAccess db_acc, InterpreterContext *int trigger_context.AdaptForAccessor(&db_accessor); try { trigger.Execute(&db_accessor, &execution_memory, flags::run_time::GetExecutionTimeout(), - &interpreter_context->is_shutting_down, transaction_status, trigger_context, - interpreter_context->auth_checker); + &interpreter_context->is_shutting_down, transaction_status, trigger_context); } catch (const utils::BasicException &exception) { spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); db_accessor.Abort(); @@ -4720,8 +4768,7 @@ void Interpreter::Commit() { AdvanceCommand(); try { trigger.Execute(&*current_db_.execution_db_accessor_, &execution_memory, flags::run_time::GetExecutionTimeout(), - &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context, - interpreter_context_->auth_checker); + &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context); } catch (const utils::BasicException &e) { throw utils::BasicException( fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); @@ -4836,7 +4883,7 @@ void Interpreter::SetNextTransactionIsolationLevel(const storage::IsolationLevel void Interpreter::SetSessionIsolationLevel(const storage::IsolationLevel isolation_level) { interpreter_isolation_level.emplace(isolation_level); } -void Interpreter::ResetUser() { username_.reset(); } -void Interpreter::SetUser(std::string_view username) { username_ = username; } +void Interpreter::ResetUser() { user_or_role_.reset(); } +void Interpreter::SetUser(std::shared_ptr user_or_role) { user_or_role_ = std::move(user_or_role); } } // namespace memgraph::query diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index b4b130f72..1ffbaa597 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -210,7 +210,7 @@ class Interpreter final { std::optional db; }; - std::optional username_; + std::shared_ptr user_or_role_{}; bool in_explicit_transaction_{false}; CurrentDB current_db_; @@ -300,7 +300,7 @@ class Interpreter final { void ResetUser(); - void SetUser(std::string_view username); + void SetUser(std::shared_ptr user); std::optional system_transaction_{}; diff --git a/src/query/interpreter_context.cpp b/src/query/interpreter_context.cpp index f7b4584ba..eb35dbf03 100644 --- a/src/query/interpreter_context.cpp +++ b/src/query/interpreter_context.cpp @@ -35,13 +35,13 @@ InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbm } std::vector> InterpreterContext::TerminateTransactions( - std::vector maybe_kill_transaction_ids, const std::optional &username, - std::function privilege_checker) { + std::vector maybe_kill_transaction_ids, QueryUserOrRole *user_or_role, + std::function privilege_checker) { auto not_found_midpoint = maybe_kill_transaction_ids.end(); // Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed // TERMINATE and SHOW TRANSACTIONS are mutually exclusive - interpreters.WithLock([¬_found_midpoint, &maybe_kill_transaction_ids, username, + interpreters.WithLock([¬_found_midpoint, &maybe_kill_transaction_ids, user_or_role, privilege_checker = std::move(privilege_checker)](const auto &interpreters) { for (Interpreter *interpreter : interpreters) { TransactionStatus alive_status = TransactionStatus::ACTIVE; @@ -73,7 +73,15 @@ std::vector> InterpreterContext::TerminateTransactions( static std::string all; return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all; }; - if (interpreter->username_ == username || privilege_checker(get_interpreter_db_name())) { + + auto same_user = [](const auto &lv, const auto &rv) { + if (lv.get() == rv) return true; + if (lv && rv) return *lv == *rv; + return false; + }; + + if (same_user(interpreter->user_or_role_, user_or_role) || + privilege_checker(user_or_role, get_interpreter_db_name())) { killed = true; // Note: this is used by the above `clean_status` (OnScopeExit) spdlog::warn("Transaction {} successfully killed", transaction_id); } else { diff --git a/src/query/interpreter_context.hpp b/src/query/interpreter_context.hpp index c5fe00d2d..559ea3342 100644 --- a/src/query/interpreter_context.hpp +++ b/src/query/interpreter_context.hpp @@ -46,6 +46,7 @@ constexpr uint64_t kInterpreterTransactionInitialId = 1ULL << 63U; class AuthQueryHandler; class AuthChecker; class Interpreter; +struct QueryUserOrRole; /** * Holds data shared between multiple `Interpreter` instances (which might be @@ -95,8 +96,8 @@ struct InterpreterContext { void Shutdown() { is_shutting_down.store(true, std::memory_order_release); } std::vector> TerminateTransactions( - std::vector maybe_kill_transaction_ids, const std::optional &username, - std::function privilege_checker); + std::vector maybe_kill_transaction_ids, QueryUserOrRole *user_or_role, + std::function privilege_checker); }; } // namespace memgraph::query diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index ca32d1964..b648a8c8b 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1554,15 +1554,15 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { // for the given (edge, vertex) pair checks if they satisfy the // "where" condition. if so, places them in the to_visit_ structure. - auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) { + auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) -> bool { // if we already processed the given vertex it doesn't get expanded - if (processed_.find(vertex) != processed_.end()) return; + if (processed_.find(vertex) != processed_.end()) return false; #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { - return; + return false; } #endif frame[self_.filter_lambda_.inner_edge_symbol] = edge; @@ -1581,9 +1581,9 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { TypedValue result = self_.filter_lambda_.expression->Accept(evaluator); switch (result.type()) { case TypedValue::Type::Null: - return; + return true; case TypedValue::Type::Bool: - if (!result.ValueBool()) return; + if (!result.ValueBool()) return true; break; default: throw QueryRuntimeException("Expansion condition must evaluate to boolean or null."); @@ -1591,10 +1591,11 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { } to_visit_next_.emplace_back(edge, vertex, std::move(curr_acc_path)); processed_.emplace(vertex, edge); + return true; }; - auto restore_frame_state_after_expansion = [this, &frame]() { - if (self_.filter_lambda_.accumulated_path_symbol) { + auto restore_frame_state_after_expansion = [this, &frame](bool was_expanded) { + if (was_expanded && self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); } }; @@ -1606,15 +1607,15 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { - expand_pair(edge, edge.To()); - restore_frame_state_after_expansion(); + bool was_expanded = expand_pair(edge, edge.To()); + restore_frame_state_after_expansion(was_expanded); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { - expand_pair(edge, edge.From()); - restore_frame_state_after_expansion(); + bool was_expanded = expand_pair(edge, edge.From()); + restore_frame_state_after_expansion(was_expanded); } } }; @@ -1805,18 +1806,8 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // For the given (edge, vertex, weight, depth) tuple checks if they // satisfy the "where" condition. if so, places them in the priority // queue. - auto expand_pair = [this, &evaluator, &frame, &create_state, &context]( - const EdgeAccessor &edge, const VertexAccessor &vertex, const TypedValue &total_weight, - int64_t depth) { -#ifdef MG_ENTERPRISE - if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && - !(context.auth_checker->Has(vertex, storage::View::OLD, - memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && - context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { - return; - } -#endif - + auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex, + const TypedValue &total_weight, int64_t depth) { frame[self_.weight_lambda_->inner_edge_symbol] = edge; frame[self_.weight_lambda_->inner_node_symbol] = vertex; TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator); @@ -1859,11 +1850,19 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // Populates the priority queue structure with expansions // from the given vertex. skips expansions that don't satisfy // the "where" condition. - auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion]( + auto expand_from_vertex = [this, &context, &expand_pair, &restore_frame_state_after_expansion]( const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { +#ifdef MG_ENTERPRISE + if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && + !(context.auth_checker->Has(edge.To(), storage::View::OLD, + memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && + context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { + continue; + } +#endif expand_pair(edge, edge.To(), weight, depth); restore_frame_state_after_expansion(); } @@ -1871,6 +1870,14 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { +#ifdef MG_ENTERPRISE + if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && + !(context.auth_checker->Has(edge.From(), storage::View::OLD, + memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && + context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { + continue; + } +#endif expand_pair(edge, edge.From(), weight, depth); restore_frame_state_after_expansion(); } diff --git a/src/query/query_user.cpp b/src/query/query_user.cpp new file mode 100644 index 000000000..005601f81 --- /dev/null +++ b/src/query/query_user.cpp @@ -0,0 +1,21 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/query_user.hpp" + +namespace memgraph::query { +// The variables below are used to define a user auth policy. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +SessionLongPolicy session_long_policy; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +UpToDatePolicy up_to_date_policy; + +} // namespace memgraph::query diff --git a/src/query/query_user.hpp b/src/query/query_user.hpp new file mode 100644 index 000000000..62d2e32b1 --- /dev/null +++ b/src/query/query_user.hpp @@ -0,0 +1,61 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include + +#include "query/frontend/ast/ast.hpp" + +namespace memgraph::query { + +class UserPolicy { + public: + virtual bool DoUpdate() const = 0; +}; +extern struct SessionLongPolicy : UserPolicy { + public: + bool DoUpdate() const override { return false; } +} session_long_policy; +extern struct UpToDatePolicy : UserPolicy { + public: + bool DoUpdate() const override { return true; } +} up_to_date_policy; + +struct QueryUserOrRole { + QueryUserOrRole(std::optional username, std::optional rolename) + : username_{std::move(username)}, rolename_{std::move(rolename)} {} + virtual ~QueryUserOrRole() = default; + + virtual bool IsAuthorized(const std::vector &privileges, const std::string &db_name, + UserPolicy *policy) const = 0; + +#ifdef MG_ENTERPRISE + virtual std::string GetDefaultDB() const = 0; +#endif + + std::string key() const { + // NOTE: Each role has an associated username, that's why we check it with higher priority + return rolename_ ? *rolename_ : (username_ ? *username_ : ""); + } + const std::optional &username() const { return username_; } + const std::optional &rolename() const { return rolename_; } + + bool operator==(const QueryUserOrRole &other) const = default; + operator bool() const { return username_.has_value(); } + + private: + std::optional username_; + std::optional rolename_; +}; + +} // namespace memgraph::query diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index 101ca592c..b8984b94b 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -29,6 +29,7 @@ #include "query/procedure/mg_procedure_helpers.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" +#include "query/query_user.hpp" #include "query/stream/sources.hpp" #include "query/typed_value.hpp" #include "utils/event_counter.hpp" @@ -131,6 +132,7 @@ StreamStatus CreateStatus(std::string stream_name, std::string transfor const std::string kStreamName{"name"}; const std::string kIsRunningKey{"is_running"}; const std::string kOwner{"owner"}; +const std::string kOwnerRole{"owner_role"}; const std::string kType{"type"}; } // namespace @@ -142,6 +144,11 @@ void to_json(nlohmann::json &data, StreamStatus &&status) { if (status.owner.has_value()) { data[kOwner] = std::move(*status.owner); + if (status.owner_role.has_value()) { + data[kOwnerRole] = std::move(*status.owner_role); + } else { + data[kOwnerRole] = nullptr; + } } else { data[kOwner] = nullptr; } @@ -156,6 +163,11 @@ void from_json(const nlohmann::json &data, StreamStatus &status) { if (const auto &owner = data.at(kOwner); !owner.is_null()) { status.owner = owner.get(); + if (const auto &owner_role = data.at(kOwnerRole); !owner_role.is_null()) { + owner_role.get_to(status.owner_role); + } else { + status.owner_role = {}; + } } else { status.owner = {}; } @@ -449,7 +461,7 @@ void Streams::RegisterPulsarProcedures() { template void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info, - std::optional owner, TDbAccess db_acc, InterpreterContext *ic) { + std::shared_ptr owner, TDbAccess db_acc, InterpreterContext *ic) { auto locked_streams = streams_.Lock(); auto it = CreateConsumer(*locked_streams, stream_name, std::move(info), std::move(owner), std::move(db_acc), ic); @@ -469,31 +481,39 @@ void Streams::Create(const std::string &stream_name, typename TStream::StreamInf template void Streams::Create(const std::string &stream_name, KafkaStream::StreamInfo info, - std::optional owner, + std::shared_ptr owner, dbms::DatabaseAccess db, InterpreterContext *ic); template void Streams::Create(const std::string &stream_name, PulsarStream::StreamInfo info, - std::optional owner, + std::shared_ptr owner, dbms::DatabaseAccess db, InterpreterContext *ic); template Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name, typename TStream::StreamInfo stream_info, - std::optional owner, TDbAccess db_acc, + std::shared_ptr owner, TDbAccess db_acc, InterpreterContext *interpreter_context) { if (map.contains(stream_name)) { throw StreamsException{"Stream already exists with name '{}'", stream_name}; } + auto ownername = owner->username(); + auto rolename = owner->rolename(); + auto *memory_resource = utils::NewDeleteResource(); auto consumer_function = [interpreter_context, memory_resource, stream_name, - transformation_name = stream_info.common_info.transformation_name, owner = owner, + transformation_name = stream_info.common_info.transformation_name, owner = std::move(owner), interpreter = std::make_shared(interpreter_context, std::move(db_acc)), result = mgp_result{nullptr, memory_resource}, total_retries = interpreter_context->config.stream_transaction_conflict_retries, retry_interval = interpreter_context->config.stream_transaction_retry_interval]( const std::vector &messages) mutable { + // Set interpreter's user to the stream owner + // NOTE: We generate an empty user to avoid generating interpreter's fine grained access control and rely only on + // the global auth_checker used in the stream itself + // TODO: Fix auth inconsistency + interpreter->SetUser(interpreter_context->auth_checker->GenQueryUser(std::nullopt, std::nullopt)); #ifdef MG_ENTERPRISE interpreter->OnChangeCB([](auto) { return false; }); // Disable database change #endif @@ -523,12 +543,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std spdlog::trace("Processing row in stream '{}'", stream_name); auto [query_value, params_value] = ExtractTransformationResult(row.values, transformation_name, stream_name); storage::PropertyValue params_prop{params_value}; - std::string query{query_value.ValueString()}; spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); auto prepare_result = interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), {}); - if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) { + if (!owner->IsAuthorized(prepare_result.privileges, "", &up_to_date_policy)) { throw StreamsException{ "Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the " "query!", @@ -553,7 +572,8 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std }; auto insert_result = map.try_emplace( - stream_name, StreamData{std::move(stream_info.common_info.transformation_name), std::move(owner), + stream_name, StreamData{std::move(stream_info.common_info.transformation_name), std::move(ownername), + std::move(rolename), std::make_unique>( stream_name, std::move(stream_info), std::move(consumer_function))}); MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name); @@ -575,6 +595,7 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) { const auto create_consumer = [&, &stream_name = stream_name](StreamStatus status, auto &&stream_json_data) { try { + // TODO: Migration stream_json_data.get_to(status); } catch (const nlohmann::json::type_error &exception) { spdlog::warn(get_failed_message("invalid type conversion", exception.what())); @@ -586,8 +607,8 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) { MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name); try { - auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner), - db, ic); + auto owner = ic->auth_checker->GenQueryUser(status.owner, status.owner_role); + auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(owner), db, ic); if (status.is_running) { std::visit( [&](const auto &stream_data) { @@ -745,7 +766,7 @@ std::vector> Streams::GetStreamInfo() const { auto info = locked_stream_source->Info(stream_data.transformation_name); result.emplace_back(StreamStatus<>{stream_name, StreamType(*locked_stream_source), locked_stream_source->IsRunning(), std::move(info.common_info), - stream_data.owner}); + stream_data.owner, stream_data.owner_role}); }, stream_data); } diff --git a/src/query/stream/streams.hpp b/src/query/stream/streams.hpp index bad1f8c98..e1660bdb4 100644 --- a/src/query/stream/streams.hpp +++ b/src/query/stream/streams.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,6 +67,7 @@ struct StreamStatus { bool is_running; StreamInfoType info; std::optional owner; + std::optional owner_role; }; using TransformationResult = std::vector>; @@ -100,7 +101,7 @@ class Streams final { /// /// @throws StreamsException if the stream with the same name exists or if the creation of Kafka consumer fails template - void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional owner, + void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::shared_ptr owner, TDbAccess db, InterpreterContext *interpreter_context); /// Deletes an existing stream and all the data that was persisted. @@ -182,6 +183,7 @@ class Streams final { struct StreamData { std::string transformation_name; std::optional owner; + std::optional owner_role; std::unique_ptr> stream_source; }; @@ -191,7 +193,7 @@ class Streams final { template StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, - typename TStream::StreamInfo stream_info, std::optional owner, + typename TStream::StreamInfo stream_info, std::shared_ptr owner, TDbAccess db, InterpreterContext *interpreter_context); template diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index 7998714c1..437389128 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,14 +11,13 @@ #include "query/trigger.hpp" -#include - #include "query/config.hpp" #include "query/context.hpp" #include "query/cypher_query_interpreter.hpp" #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpret/frame.hpp" +#include "query/query_user.hpp" #include "query/serialization/property_value.hpp" #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" @@ -154,20 +153,19 @@ Trigger::Trigger(std::string name, const std::string &query, const std::map &user_parameters, const TriggerEventType event_type, utils::SkipList *query_cache, DbAccessor *db_accessor, const InterpreterConfig::Query &query_config, - std::optional owner, const query::AuthChecker *auth_checker) + std::shared_ptr owner) : name_{std::move(name)}, parsed_statements_{ParseQuery(query, user_parameters, query_cache, query_config)}, event_type_{event_type}, owner_{std::move(owner)} { // We check immediately if the query is valid by trying to create a plan. - GetPlan(db_accessor, auth_checker); + GetPlan(db_accessor); } Trigger::TriggerPlan::TriggerPlan(std::unique_ptr logical_plan, std::vector identifiers) : cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {} -std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, - const query::AuthChecker *auth_checker) const { +std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor) const { std::lock_guard plan_guard{plan_lock_}; if (!parsed_statements_.is_cacheable || !trigger_plan_) { auto identifiers = GetPredefinedIdentifiers(event_type_); @@ -187,7 +185,7 @@ std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, trigger_plan_ = std::make_shared(std::move(logical_plan), std::move(identifiers)); } - if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) { + if (!owner_->IsAuthorized(parsed_statements_.required_privileges, "", &up_to_date_policy)) { throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_); } return trigger_plan_; @@ -195,14 +193,13 @@ std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, const double max_execution_time_sec, std::atomic *is_shutting_down, - std::atomic *transaction_status, const TriggerContext &context, - const AuthChecker *auth_checker) const { + std::atomic *transaction_status, const TriggerContext &context) const { if (!context.ShouldEventTrigger(event_type_)) { return; } spdlog::debug("Executing trigger '{}'", name_); - auto trigger_plan = GetPlan(dba, auth_checker); + auto trigger_plan = GetPlan(dba); MG_ASSERT(trigger_plan, "Invalid trigger plan received"); auto &[plan, identifiers] = *trigger_plan; @@ -308,6 +305,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList *query_cache } const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]); + // TODO: Migration const auto owner_json = json_trigger_data["owner"]; std::optional owner{}; if (owner_json.is_string()) { @@ -317,10 +315,21 @@ void TriggerStore::RestoreTriggers(utils::SkipList *query_cache continue; } + const auto owner_role_json = json_trigger_data["owner_role"]; + std::optional role{}; + if (owner_role_json.is_string()) { + owner.emplace(owner_role_json.get()); + } else if (!owner_role_json.is_null()) { + spdlog::warn(invalid_state_message); + continue; + } + + auto user = auth_checker->GenQueryUser(owner, role); + std::optional trigger; try { trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, query_config, - std::move(owner), auth_checker); + std::move(user)); } catch (const utils::BasicException &e) { spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what()); continue; @@ -338,8 +347,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker) { + const InterpreterConfig::Query &query_config, std::shared_ptr owner) { std::unique_lock store_guard{store_lock_}; if (storage_.Get(name)) { throw utils::BasicException("Trigger with the same name already exists."); @@ -348,7 +356,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, std::optional trigger; try { trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, query_config, - std::move(owner), auth_checker); + std::move(owner)); } catch (const utils::BasicException &e) { const auto identifiers = GetPredefinedIdentifiers(event_type); std::stringstream identifier_names_stream; @@ -370,10 +378,23 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, data["phase"] = phase; data["version"] = kVersion; - if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) { - data["owner"] = *owner_from_trigger; + if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger && *owner_from_trigger) { + const auto &maybe_username = owner_from_trigger->username(); + if (maybe_username) { + data["owner"] = *maybe_username; + // Roles need to be associated with a username + const auto &maybe_rolename = owner_from_trigger->rolename(); + if (maybe_rolename) { + data["owner_role"] = *maybe_rolename; + } else { + data["owner_role"] = nullptr; + } + } else { + data["owner"] = nullptr; + } } else { data["owner"] = nullptr; + data["owner_role"] = nullptr; } storage_.Put(trigger->Name(), data.dump()); store_guard.unlock(); @@ -417,7 +438,9 @@ std::vector TriggerStore::GetTriggerInfo() const { const auto add_info = [&](const utils::SkipList &trigger_list, const TriggerPhase phase) { for (const auto &trigger : trigger_list.access()) { - info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()}); + std::optional owner_str{}; + if (const auto &owner = trigger.Owner(); owner && *owner) owner_str = owner->username(); + info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, std::move(owner_str)}); } }; diff --git a/src/query/trigger.hpp b/src/query/trigger.hpp index a6e19032e..91c74579e 100644 --- a/src/query/trigger.hpp +++ b/src/query/trigger.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -37,12 +37,11 @@ struct Trigger { explicit Trigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker); + const InterpreterConfig::Query &query_config, std::shared_ptr owner); void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, std::atomic *is_shutting_down, std::atomic *transaction_status, - const TriggerContext &context, const AuthChecker *auth_checker) const; + const TriggerContext &context) const; bool operator==(const Trigger &other) const { return name_ == other.name_; } // NOLINTNEXTLINE (modernize-use-nullptr) @@ -65,7 +64,7 @@ struct Trigger { PlanWrapper cached_plan; std::vector identifiers; }; - std::shared_ptr GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const; + std::shared_ptr GetPlan(DbAccessor *db_accessor) const; std::string name_; ParsedQuery parsed_statements_; @@ -74,7 +73,7 @@ struct Trigger { mutable utils::SpinLock plan_lock_; mutable std::shared_ptr trigger_plan_; - std::optional owner_; + std::shared_ptr owner_; }; enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT }; @@ -88,8 +87,7 @@ struct TriggerStore { void AddTrigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker); + const InterpreterConfig::Query &query_config, std::shared_ptr owner); void DropTrigger(const std::string &name); diff --git a/src/replication_handler/replication_handler.cpp b/src/replication_handler/replication_handler.cpp index 0d95cbd51..7a9a7cd58 100644 --- a/src/replication_handler/replication_handler.cpp +++ b/src/replication_handler/replication_handler.cpp @@ -103,7 +103,8 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, inline std::optional HandleRegisterReplicaStatus( utils::BasicResult &instance_client) { - if (instance_client.HasError()) switch (instance_client.GetError()) { + if (instance_client.HasError()) { + switch (instance_client.GetError()) { case replication::RegisterReplicaError::NOT_MAIN: MG_ASSERT(false, "Only main instance can register a replica!"); return {}; @@ -116,6 +117,7 @@ inline std::optional HandleRegisterReplicaStatus( case replication::RegisterReplicaError::SUCCESS: break; } + } return {}; } diff --git a/tests/benchmark/expansion.cpp b/tests/benchmark/expansion.cpp index 0c4579476..d47ca1aca 100644 --- a/tests/benchmark/expansion.cpp +++ b/tests/benchmark/expansion.cpp @@ -27,6 +27,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e class ExpansionBenchFixture : public benchmark::Fixture { protected: std::optional system; + std::optional auth_checker; std::optional interpreter_context; std::optional interpreter; std::optional> db_gk; @@ -43,6 +44,7 @@ class ExpansionBenchFixture : public benchmark::Fixture { auto &db_acc = *db_acc_opt; system.emplace(); + auth_checker.emplace(); interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system #ifdef MG_ENTERPRISE , @@ -73,13 +75,15 @@ class ExpansionBenchFixture : public benchmark::Fixture { } interpreter.emplace(&*interpreter_context, std::move(db_acc)); + interpreter->SetUser(auth_checker->GenQueryUser(std::nullopt, std::nullopt)); } void TearDown(const benchmark::State &) override { interpreter = std::nullopt; interpreter_context = std::nullopt; - system.reset(); db_gk.reset(); + auth_checker.reset(); + system.reset(); std::filesystem::remove_all(data_directory); } }; diff --git a/tests/e2e/configuration/default_config.py b/tests/e2e/configuration/default_config.py index 4a5b6b858..2b53ea3c6 100644 --- a/tests/e2e/configuration/default_config.py +++ b/tests/e2e/configuration/default_config.py @@ -14,14 +14,7 @@ # If you wish to modify these, update the startup_config_dict and workloads.yaml ! startup_config_dict = { - "auth_module_create_missing_role": ("true", "true", "Set to false to disable creation of missing roles."), - "auth_module_create_missing_user": ("true", "true", "Set to false to disable creation of missing users."), "auth_module_executable": ("", "", "Absolute path to the auth module executable that should be used."), - "auth_module_manage_roles": ( - "true", - "true", - "Set to false to disable management of roles through the auth module.", - ), "auth_module_timeout_ms": ( "10000", "10000", diff --git a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py index b087fa6c5..bc8f107da 100644 --- a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py +++ b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py @@ -19,10 +19,10 @@ from mgclient import DatabaseError @pytest.mark.parametrize("switch", [False, True]) def test_create_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") @@ -33,10 +33,10 @@ def test_create_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -47,10 +47,10 @@ def test_create_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") @@ -61,10 +61,10 @@ def test_create_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -75,10 +75,10 @@ def test_create_node_specific_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") @@ -91,10 +91,10 @@ def test_delete_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -105,10 +105,10 @@ def test_delete_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") @@ -123,10 +123,10 @@ def test_delete_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -137,11 +137,11 @@ def test_delete_node_specific_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -156,11 +156,11 @@ def test_create_edge_all_labels_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -174,11 +174,11 @@ def test_create_edge_all_labels_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -192,11 +192,11 @@ def test_create_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -210,7 +210,6 @@ def test_create_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -218,6 +217,7 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -231,7 +231,6 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") @@ -240,6 +239,7 @@ def test_create_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -253,7 +253,6 @@ def test_create_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") @@ -262,6 +261,7 @@ def test_create_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -275,11 +275,11 @@ def test_create_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -293,11 +293,11 @@ def test_delete_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -311,7 +311,6 @@ def test_delete_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -319,6 +318,7 @@ def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -332,7 +332,6 @@ def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;") @@ -341,6 +340,7 @@ def test_delete_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -354,7 +354,6 @@ def test_delete_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;") @@ -363,6 +362,7 @@ def test_delete_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -376,13 +376,13 @@ def test_delete_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_with_edge_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -393,13 +393,13 @@ def test_delete_node_with_edge_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_with_edge_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -415,10 +415,10 @@ def test_delete_node_with_edge_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") @@ -429,10 +429,10 @@ def test_merge_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -443,10 +443,10 @@ def test_merge_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") @@ -457,10 +457,10 @@ def test_merge_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -471,11 +471,11 @@ def test_merge_node_specific_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( @@ -489,11 +489,11 @@ def test_merge_edge_all_labels_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -507,11 +507,11 @@ def test_merge_edge_all_labels_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -525,11 +525,11 @@ def test_merge_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -543,7 +543,6 @@ def test_merge_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -551,6 +550,7 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -564,7 +564,6 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") @@ -573,6 +572,7 @@ def test_merge_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -586,7 +586,6 @@ def test_merge_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") @@ -595,6 +594,7 @@ def test_merge_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -608,10 +608,10 @@ def test_merge_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_set_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -621,12 +621,12 @@ def test_set_label_when_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_set_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -637,11 +637,11 @@ def test_set_label_when_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_remove_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -651,12 +651,12 @@ def test_remove_label_when_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_remove_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -667,12 +667,12 @@ def test_remove_label_when_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_nodes_pass_when_having_create_delete(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) diff --git a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py index f2071c54b..4bae2b2f4 100644 --- a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py +++ b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py @@ -7,9 +7,9 @@ import pytest @pytest.mark.parametrize("switch", [False, True]) def test_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -21,9 +21,9 @@ def test_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_deny_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -35,9 +35,9 @@ def test_deny_all_edge_types_and_all_labels(switch): @pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -49,10 +49,10 @@ def test_revoke_all_edge_types_and_all_labels(switch): @pytest.mark.parametrize("switch", [False, True]) def test_deny_edge_type(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -64,10 +64,10 @@ def test_deny_edge_type(switch): @pytest.mark.parametrize("switch", [False, True]) def test_denied_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -79,10 +79,10 @@ def test_denied_node_label(switch): @pytest.mark.parametrize("switch", [False, True]) def test_denied_one_of_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -94,8 +94,8 @@ def test_denied_one_of_node_label(switch): @pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") @@ -106,8 +106,8 @@ def test_revoke_all_labels(switch): @pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_edge_types(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") diff --git a/tests/e2e/fine_grained_access/path_filtering_tests.py b/tests/e2e/fine_grained_access/path_filtering_tests.py index c5b873972..e8c395b2e 100644 --- a/tests/e2e/fine_grained_access/path_filtering_tests.py +++ b/tests/e2e/fine_grained_access/path_filtering_tests.py @@ -7,11 +7,11 @@ import pytest @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -54,11 +54,11 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -72,7 +72,6 @@ def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -80,6 +79,7 @@ def test_weighted_shortest_path_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -94,7 +94,6 @@ def test_weighted_shortest_path_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -102,6 +101,7 @@ def test_weighted_shortest_path_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -116,7 +116,6 @@ def test_weighted_shortest_path_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -124,6 +123,7 @@ def test_weighted_shortest_path_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -162,7 +162,6 @@ def test_weighted_shortest_path_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -170,6 +169,7 @@ def test_weighted_shortest_path_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -213,11 +213,11 @@ def test_weighted_shortest_path_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -235,11 +235,11 @@ def test_dfs_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -251,7 +251,6 @@ def test_dfs_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -259,6 +258,7 @@ def test_dfs_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -272,7 +272,6 @@ def test_dfs_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -280,6 +279,7 @@ def test_dfs_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -293,7 +293,6 @@ def test_dfs_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -301,6 +300,7 @@ def test_dfs_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -318,7 +318,6 @@ def test_dfs_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") @@ -327,6 +326,7 @@ def test_dfs_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -344,11 +344,11 @@ def test_dfs_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -366,11 +366,11 @@ def test_bfs_sts_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -384,7 +384,6 @@ def test_bfs_sts_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -392,6 +391,7 @@ def test_bfs_sts_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -405,7 +405,6 @@ def test_bfs_sts_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -413,6 +412,7 @@ def test_bfs_sts_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -426,7 +426,6 @@ def test_bfs_sts_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -434,6 +433,7 @@ def test_bfs_sts_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -450,7 +450,6 @@ def test_bfs_sts_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -458,6 +457,7 @@ def test_bfs_sts_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -474,11 +474,11 @@ def test_bfs_sts_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -496,11 +496,11 @@ def test_bfs_single_source_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -512,7 +512,6 @@ def test_bfs_single_source_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -520,6 +519,7 @@ def test_bfs_single_source_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -533,7 +533,6 @@ def test_bfs_single_source_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -541,6 +540,7 @@ def test_bfs_single_source_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -554,7 +554,6 @@ def test_bfs_single_source_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -562,6 +561,7 @@ def test_bfs_single_source_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -579,7 +579,6 @@ def test_bfs_single_source_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -587,6 +586,7 @@ def test_bfs_single_source_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -604,11 +604,11 @@ def test_bfs_single_source_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -651,11 +651,11 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -669,7 +669,6 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -677,6 +676,7 @@ def test_all_shortest_paths_when_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -691,7 +691,6 @@ def test_all_shortest_paths_when_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -699,6 +698,7 @@ def test_all_shortest_paths_when_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -713,7 +713,6 @@ def test_all_shortest_paths_when_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -721,6 +720,7 @@ def test_all_shortest_paths_when_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -759,7 +759,6 @@ def test_all_shortest_paths_when_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -767,6 +766,7 @@ def test_all_shortest_paths_when_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) diff --git a/tests/e2e/fine_grained_access/workloads.yaml b/tests/e2e/fine_grained_access/workloads.yaml index ad1dd43b2..6128b4b7d 100644 --- a/tests/e2e/fine_grained_access/workloads.yaml +++ b/tests/e2e/fine_grained_access/workloads.yaml @@ -84,11 +84,12 @@ show_databases_w_user_setup_queries: &show_databases_w_user_setup_queries - "GRANT DATABASE db1 TO user;" - "GRANT ALL PRIVILEGES TO user2;" - "GRANT DATABASE db2 TO user2;" + - "GRANT DATABASE memgraph TO user2;" - "REVOKE DATABASE memgraph FROM user2;" - "SET MAIN DATABASE db2 FOR user2" - "GRANT ALL PRIVILEGES TO user3;" - "GRANT DATABASE * TO user3;" - - "REVOKE DATABASE memgraph FROM user3;" + - "DENY DATABASE memgraph FROM user3;" - "SET MAIN DATABASE db1 FOR user3" create_delete_filtering_in_memory_cluster: &create_delete_filtering_in_memory_cluster diff --git a/tests/e2e/lba_procedures/read_permission_queries.py b/tests/e2e/lba_procedures/read_permission_queries.py index 4348e6bda..4c02910da 100644 --- a/tests/e2e/lba_procedures/read_permission_queries.py +++ b/tests/e2e/lba_procedures/read_permission_queries.py @@ -107,18 +107,21 @@ def execute_read_node_assertion( operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool ) -> None: admin_cursor = get_admin_cursor() - user_cursor = get_user_cursor() if switch: create_multi_db(admin_cursor) switch_db(admin_cursor) - switch_db(user_cursor) reset_permissions(admin_cursor, create_index) for operation in operation_case: execute_and_fetch_all(admin_cursor, operation) + # Connect after possible auth changes + user_cursor = get_user_cursor() + if switch: + switch_db(user_cursor) + for mq in queries: results = execute_and_fetch_all(user_cursor, mq) assert len(results) == expected_size diff --git a/tests/e2e/replication_experimental/auth.py b/tests/e2e/replication_experimental/auth.py index 44738ccbd..60da2f513 100644 --- a/tests/e2e/replication_experimental/auth.py +++ b/tests/e2e/replication_experimental/auth.py @@ -121,6 +121,7 @@ def only_main_queries(cursor): n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name") n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;") n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name") + n_exceptions += try_and_count(cursor, f"DENY DATABASE memgraph FROM user_name;") n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;") return n_exceptions @@ -198,8 +199,8 @@ def test_auth_queries_on_replica(connection): # 1/ assert only_main_queries(cursor_main) == 0 - assert only_main_queries(cursor_replica_1) == 17 - assert only_main_queries(cursor_replica_2) == 17 + assert only_main_queries(cursor_replica_1) == 18 + assert only_main_queries(cursor_replica_2) == 18 assert main_and_repl_queries(cursor_main) == 0 assert main_and_repl_queries(cursor_replica_1) == 0 assert main_and_repl_queries(cursor_replica_2) == 0 @@ -383,6 +384,7 @@ def test_manual_roles_recovery(connection): "--log-level=TRACE", "--data_directory", TEMP_DIR + "/replica1", + "--also-log-to-stderr", ], "log_file": "replica1.log", "setup_queries": [ @@ -818,13 +820,15 @@ def test_auth_replication(connection): {("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")}, ) - # GRANT/REVOKE DATABASE + # GRANT/DENY DATABASE execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test") execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2") execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4") check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])]) - execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4") + execute_and_fetch_all(cursor_main, "DENY DATABASE auth_test2 FROM user4") check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])]) + execute_and_fetch_all(cursor_main, "REVOKE DATABASE memgraph FROM user4") + check(partial(show_database_privileges_func, user="user4"), [(["auth_test"], ["auth_test2"])]) # SET MAIN DATABASE execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4") diff --git a/tests/e2e/transaction_queue/test_transaction_queue.py b/tests/e2e/transaction_queue/test_transaction_queue.py index 221243c50..a563b0aff 100644 --- a/tests/e2e/transaction_queue/test_transaction_queue.py +++ b/tests/e2e/transaction_queue/test_transaction_queue.py @@ -70,21 +70,26 @@ def test_multitenant_transactions(): # TODO Add SHOW TRANSACTIONS ON * that should return all transactions -def test_admin_has_one_transaction(): +def test_admin_has_one_transaction(request): """Creates admin and tests that he sees only one transaction.""" # a_cursor is used for creating admin user, simulates main thread superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + admin_cursor = connect(username="admin", password="").cursor() process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1)) process.start() process.join() - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") -def test_user_can_see_its_transaction(): +def test_user_can_see_its_transaction(request): """Tests that user without privileges can see its own transaction""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") @@ -92,20 +97,31 @@ def test_user_can_see_its_transaction(): execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + user_cursor = connect(username="user", password="").cursor() process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1)) process.start() process.join() admin_cursor = connect(username="admin", password="").cursor() - execute_and_fetch_all(admin_cursor, "DROP USER user") - execute_and_fetch_all(admin_cursor, "DROP USER admin") -def test_explicit_transaction_output(): +def test_explicit_transaction_output(request): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + admin_connection = connect(username="admin", password="") admin_cursor = admin_connection.cursor() # Admin starts running explicit transaction @@ -123,10 +139,9 @@ def test_explicit_transaction_output(): assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"] execute_and_fetch_all(superadmin_cursor, "ROLLBACK") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") -def test_superadmin_cannot_see_admin_can_see_admin(): +def test_superadmin_cannot_see_admin_can_see_admin(request): """Tests that superadmin cannot see the transaction created by admin but two admins can see and kill each other's transactions.""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") @@ -135,6 +150,13 @@ def test_superadmin_cannot_see_admin_can_see_admin(): execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + + request.addfinalizer(on_exit) + # Admin starts running infinite query admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() @@ -160,19 +182,23 @@ def test_superadmin_cannot_see_admin_can_see_admin(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") admin_connection_1.close() admin_connection_2.close() -def test_admin_sees_superadmin(): +def test_admin_sees_superadmin(request): """Tests that admin created by superadmin can see the superadmin's transaction.""" superadmin_connection = connect() superadmin_cursor = superadmin_connection.cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + # Admin starts running infinite query process = multiprocessing.Process( target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) @@ -194,17 +220,23 @@ def test_admin_sees_superadmin(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(admin_cursor, "DROP USER admin") superadmin_connection.close() -def test_admin_can_see_user_transaction(): +def test_admin_can_see_user_transaction(request): """Tests that admin can see user's transaction and kill it.""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + # Admin starts running infinite query admin_connection = connect(username="admin", password="") admin_cursor = admin_connection.cursor() @@ -229,13 +261,11 @@ def test_admin_can_see_user_transaction(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") - execute_and_fetch_all(superadmin_cursor, "DROP USER user") admin_connection.close() user_connection.close() -def test_user_cannot_see_admin_transaction(): +def test_user_cannot_see_admin_transaction(request): """User cannot see admin's transaction but other admin can and he can kill it.""" # Superadmin creates two admins and one user superadmin_cursor = connect().cursor() @@ -246,6 +276,14 @@ def test_user_cannot_see_admin_transaction(): execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() admin_connection_2 = connect(username="admin2", password="") @@ -274,9 +312,6 @@ def test_user_cannot_see_admin_transaction(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") - execute_and_fetch_all(superadmin_cursor, "DROP USER user") admin_connection_1.close() admin_connection_2.close() user_connection.close() @@ -300,12 +335,18 @@ def test_killing_multiple_non_existing_transactions(): assert results[i][1] == False # not killed -def test_admin_killing_multiple_non_existing_transactions(): +def test_admin_killing_multiple_non_existing_transactions(request): # Starting, superadmin admin superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + # Connect with admin admin_cursor = connect(username="admin", password="").cursor() transactions_id = ["'1'", "'2'", "'3'"] @@ -314,7 +355,6 @@ def test_admin_killing_multiple_non_existing_transactions(): for i in range(len(results)): assert results[i][0] == eval(transactions_id[i]) # transaction id assert results[i][1] == False # not killed - execute_and_fetch_all(admin_cursor, "DROP USER admin") def test_user_killing_some_transactions(): diff --git a/tests/gql_behave/tests/memgraph_V1/features/match.feature b/tests/gql_behave/tests/memgraph_V1/features/match.feature index eaf8d3f44..47da2fadf 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/match.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/match.feature @@ -761,6 +761,7 @@ Feature: Match Then the result should be: | path | | <(:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3})> | + | <(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3})> | Scenario: Test DFS variable expand using IN edges with filter by edge type1 Given graph "graph_edges" @@ -771,6 +772,7 @@ Feature: Match Then the result should be: | path | | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})> | Scenario: Test DFS variable expand with filter by edge type2 Given graph "graph_edges" @@ -781,6 +783,7 @@ Feature: Match Then the result should be: | path | | <(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | + | <(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | Scenario: Test DFS variable expand using IN edges with filter by edge type2 Given graph "graph_edges" @@ -791,6 +794,7 @@ Feature: Match Then the result should be: | path | | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})> | Scenario: Using path indentifier from CREATE in MERGE Given an empty graph diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature index 23edc69cd..01855e548 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature @@ -249,3 +249,15 @@ Feature: Bfs Then the result should be: | path | | <(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | + + Scenario: Test BFS variable expand with already processed vertex and loop with filter by path + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label1)-[*BFS 1..1 (e, n, p | True)]-() RETURN path; + """ + Then the result should be: + | path | + | < (:label1 {id: 1})-[:type3 {id: 20}]->(:label5 {id: 5}) > | + | < (:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3}) > | + | < (:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2}) > | diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature index afd484696..a160e471a 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature @@ -269,3 +269,15 @@ Feature: Weighted Shortest Path Then the result should be: | path | total_weight | | <(:station {arrival: 08:00:00.000000000, departure: 08:15:00.000000000, name: 'A'})-[:ride {duration: PT1H5M, id: 1}]->(:station {arrival: 09:20:00.000000000, departure: 09:30:00.000000000, name: 'B'})-[:ride {duration: PT30M, id: 2}]->(:station {arrival: 10:00:00.000000000, departure: 10:20:00.000000000, name: 'C'})> | PT2H20M | + + Scenario: Test wShortest variable expand with already processed vertex and loop with filter by path + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label1)-[*WSHORTEST ..1 (r, n | r.id) total_weight (e, n, p | True)]-() RETURN path; + """ + Then the result should be: + | path | + | < (:label1 {id: 1})-[:type3 {id: 20}]->(:label5 {id: 5}) > | + | < (:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3}) > | + | < (:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2}) > | diff --git a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher index 3657b855b..1aa081cc1 100644 --- a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher +++ b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher @@ -1,3 +1,4 @@ CREATE (:label1 {id: 1})-[:type1 {id:1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})-[:type1 {id: 3}]->(:label4 {id: 4}); MATCH (n :label1), (m :label3) CREATE (n)-[:type2 {id: 10}]->(m); MATCH (n :label1) CREATE (n)-[:type3 {id: 20}]->(:label5 { id: 5 }); +MATCH (n :label1) CREATE (n)-[:same {id: 30}]->(n); diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py index a74b19a4e..6dcd42f38 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -193,12 +193,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): "GRANT DATABASE db2 TO user", "CREATE USER useR2 IDENTIFIED BY 'user'", "GRANT DATABASE db2 TO user2", - "REVOKE DATABASE memgraph FROM user2", + "DENY DATABASE memgraph FROM user2", "SET MAIN DATABASE db2 FOR user2", "CREATE USER user3 IDENTIFIED BY 'user'", "GRANT ALL PRIVILEGES TO user3", "GRANT DATABASE * TO user3", - "REVOKE DATABASE memgraph FROM user3", + "DENY DATABASE memgraph FROM user3", ] ) diff --git a/tests/integration/ldap/runner.py b/tests/integration/ldap/runner.py index 8fc3af913..9e1a20f71 100755 --- a/tests/integration/ldap/runner.py +++ b/tests/integration/ldap/runner.py @@ -139,7 +139,7 @@ class Memgraph: def initialize_test(memgraph, tester_binary, **kwargs): memgraph.start(module_executable="") - execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) + execute_tester(tester_binary, ["CREATE ROLE root_role", "GRANT ALL PRIVILEGES TO root_role"]) check_login = kwargs.pop("check_login", True) memgraph.restart(**kwargs) if check_login: @@ -149,20 +149,24 @@ def initialize_test(memgraph, tester_binary, **kwargs): # Tests -def test_basic(memgraph, tester_binary): +def test_module_ux(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + execute_tester(tester_binary, ["CREATE USER user1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE role1"], "root", query_should_fail=False) + execute_tester(tester_binary, ["DROP USER user1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["DROP ROLE role1"], "root", query_should_fail=False) + execute_tester(tester_binary, ["SET ROLE FOR user1 TO role1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["CLEAR ROLE FOR user1"], "root", query_should_fail=True) memgraph.stop() -def test_only_existing_users(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, create_missing_user=False) +def test_user_auth(memgraph, tester_binary): + initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice", auth_should_fail=True) - execute_tester(tester_binary, ["CREATE USER alice"], "root") + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -170,77 +174,50 @@ def test_only_existing_users(memgraph, tester_binary): def test_role_mapping(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") + execute_tester(tester_binary, [], "alice", auth_should_fail=True) + execute_tester(tester_binary, [], "bob", auth_should_fail=True) + execute_tester(tester_binary, [], "carol", auth_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") + execute_tester(tester_binary, ["CREATE ROLE admin"], "root") + execute_tester(tester_binary, [], "alice", auth_should_fail=False) + execute_tester(tester_binary, [], "bob", auth_should_fail=True) + execute_tester(tester_binary, [], "carol", auth_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO admin"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=False) - execute_tester(tester_binary, [], "bob") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True) - - execute_tester(tester_binary, [], "carol") + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave") + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=False) memgraph.stop() +def test_instance_restart(memgraph, tester_binary): + initialize_test(memgraph, tester_binary) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + memgraph.restart() + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + memgraph.stop() + + def test_role_removal(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.stop() - - -def test_only_existing_roles(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, create_missing_role=False) - execute_tester(tester_binary, [], "bob") + execute_tester(tester_binary, ["DROP ROLE moderator"], "root") execute_tester(tester_binary, [], "alice", auth_should_fail=True) - execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") - execute_tester(tester_binary, [], "alice") - memgraph.stop() - - -def test_role_is_user(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "admin") - execute_tester(tester_binary, [], "carol", auth_should_fail=True) - memgraph.stop() - - -def test_user_is_role(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "carol") - execute_tester(tester_binary, [], "admin", auth_should_fail=True) - memgraph.stop() - - -def test_user_permissions_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_role_permissions_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_only_authentication(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, manage_roles=False) - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -258,36 +235,36 @@ def test_wrong_suffix(memgraph, tester_binary): def test_suffix_with_spaces(memgraph, tester_binary): initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com") - execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_role_mapping_wrong_root_dn(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() -def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, root_objectclass="person") - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() +# def test_role_mapping_wrong_root_dn(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() -def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, user_attribute="cn") - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() +# def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, root_objectclass="person") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() + + +# def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, user_attribute="cn") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() def test_wrong_password(memgraph, tester_binary): @@ -297,31 +274,9 @@ def test_wrong_password(memgraph, tester_binary): memgraph.stop() -def test_password_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, check_login=False) - memgraph.restart(module_executable="") - execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo") - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.restart() - execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.restart(module_executable="") - execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.stop() - - def test_user_multiple_roles(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, check_login=False) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) - memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) - memgraph.restart(manage_roles=False, root_dn="") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) + initialize_test(memgraph, tester_binary) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", auth_should_fail=True) memgraph.stop() diff --git a/tests/integration/ldap/schema.ldif b/tests/integration/ldap/schema.ldif index f47ca5e8f..730c04415 100644 --- a/tests/integration/ldap/schema.ldif +++ b/tests/integration/ldap/schema.ldif @@ -84,6 +84,13 @@ objectclass: organizationalUnit objectclass: top ou: roles +# Role root +dn: cn=root_role,ou=roles,dc=memgraph,dc=com +cn: root_role +member: cn=root,ou=people,dc=memgraph,dc=com +objectclass: groupOfNames +objectclass: top + # Role moderator dn: cn=moderator,ou=roles,dc=memgraph,dc=com cn: moderator diff --git a/tests/integration/ldap/tester.cpp b/tests/integration/ldap/tester.cpp index 8f79938c7..fc8acfd82 100644 --- a/tests/integration/ldap/tester.cpp +++ b/tests/integration/ldap/tester.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -48,6 +48,7 @@ int main(int argc, char **argv) { } if (FLAGS_auth_should_fail) { MG_ASSERT(!what.empty(), "The authentication should have failed!"); + return 0; // Auth failed, nothing left to do } else { MG_ASSERT(what.empty(), "The authentication should have succeeded, but " diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index f2ce9c572..32e093a7a 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -50,6 +50,8 @@ int main(int argc, char *argv[]) { memgraph::query::Interpreter interpreter{&interpreter_context, db_acc}; ResultStreamFaker stream(db_acc->storage()); + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, {}); stream.Header(header); auto summary = interpreter.PullAll(&stream); diff --git a/tests/unit/auth.cpp b/tests/unit/auth.cpp index bc2947a12..8bac5a05b 100644 --- a/tests/unit/auth.cpp +++ b/tests/unit/auth.cpp @@ -280,6 +280,8 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { + const auto all = auth->AllUsernames(); + for (const auto &user : all) std::cout << user << std::endl; auto users = auth->AllUsers(); std::sort(users.begin(), users.end(), [](const User &a, const User &b) { return a.username() < b.username(); }); ASSERT_EQ(users.size(), 2); @@ -774,14 +776,16 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // Authenticate { - auto user = auth->Authenticate("alice", "alice"); - ASSERT_TRUE(user); - ASSERT_EQ(user->username(), "alice"); + auto user_or_role = auth->Authenticate("alice", "alice"); + ASSERT_TRUE(user_or_role); + const auto &user = std::get(*user_or_role); + ASSERT_EQ(user.username(), "alice"); } { - auto user = auth->Authenticate("alICe", "alice"); - ASSERT_TRUE(user); - ASSERT_EQ(user->username(), "alice"); + auto user_or_role = auth->Authenticate("alICe", "alice"); + ASSERT_TRUE(user_or_role); + const auto &user = std::get(*user_or_role); + ASSERT_EQ(user.username(), "alice"); } // GetUser @@ -809,6 +813,8 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // AllUsers { + const auto all = auth->AllUsernames(); + for (const auto &user : all) std::cout << user << std::endl; auto users = auth->AllUsers(); ASSERT_EQ(users.size(), 2); std::sort(users.begin(), users.end(), [](const auto &a, const auto &b) { return a.username() < b.username(); }); diff --git a/tests/unit/auth_checker.cpp b/tests/unit/auth_checker.cpp index f4c499cd7..50bec4cbc 100644 --- a/tests/unit/auth_checker.cpp +++ b/tests/unit/auth_checker.cpp @@ -12,11 +12,14 @@ #include #include +#include "auth/exceptions.hpp" #include "auth/models.hpp" #include "disk_test_utils.hpp" #include "glue/auth_checker.hpp" #include "license/license.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" #include "query_plan_common.hpp" #include "storage/v2/config.hpp" #include "storage/v2/disk/storage.hpp" @@ -225,4 +228,123 @@ TYPED_TEST(FineGrainedAuthCheckerFixture, GrantAndDenySpecificEdgeTypes) { ASSERT_FALSE(auth_checker.Has(this->r3, memgraph::query::AuthQuery::FineGrainedPrivilege::READ)); ASSERT_FALSE(auth_checker.Has(this->r4, memgraph::query::AuthQuery::FineGrainedPrivilege::READ)); } + +TEST(AuthChecker, Generate) { + std::filesystem::path auth_dir{std::filesystem::temp_directory_path() / "MG_auth_checker"}; + memgraph::utils::OnScopeExit clean([&]() { + if (std::filesystem::exists(auth_dir)) { + std::filesystem::remove_all(auth_dir); + } + }); + memgraph::auth::SynchedAuth auth(auth_dir, memgraph::auth::Auth::Config{/* default config */}); + memgraph::glue::AuthChecker auth_checker(&auth); + + auto empty_user = auth_checker.GenQueryUser(std::nullopt, std::nullopt); + ASSERT_THROW(auth_checker.GenQueryUser("does_not_exist", std::nullopt), memgraph::auth::AuthException); + + EXPECT_FALSE(empty_user && *empty_user); + // Still empty auth, so the above should have su permissions + using enum memgraph::query::AuthQuery::Privilege; + EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy)); + EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph", + &memgraph::query::session_long_policy)); + EXPECT_TRUE( + empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy)); + + // Add user + auth->AddUser("new_user"); + + // ~Empty user should now fail~ + // NOTE: Cache invalidation has been disabled, so this will pass; change if it is ever turned on + EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy)); + EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph", + &memgraph::query::session_long_policy)); + EXPECT_TRUE( + empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy)); + + // Add role and new user + auto new_role = *auth->AddRole("new_role"); + auto new_user2 = *auth->AddUser("new_user2"); + auto role = auth_checker.GenQueryUser("anyuser", "new_role"); + auto user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + + // Should be permission-less by default + EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + + // Update permissions and recheck + new_user2.permissions().Grant(memgraph::auth::Permission::AUTH); + new_role.permissions().Grant(memgraph::auth::Permission::TRIGGER); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + role = auth_checker.GenQueryUser("no check", "new_role"); + user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + + // Connect role and recheck + new_user2.SetRole(new_role); + auth->SaveUser(new_user2); + user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + + // Add database and recheck + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + new_user2.db_access().Grant("another"); + new_role.db_access().Grant("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + // Session policy test + // Session long policy + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + // Up to date policy + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); + + new_user2.db_access().Deny("memgraph"); + new_role.db_access().Deny("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); + + new_user2.db_access().Revoke("memgraph"); + new_role.db_access().Revoke("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); +} #endif diff --git a/tests/unit/interpreter_faker.hpp b/tests/unit/interpreter_faker.hpp index 3b6075911..c1e3b4b06 100644 --- a/tests/unit/interpreter_faker.hpp +++ b/tests/unit/interpreter_faker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ struct InterpreterFaker { : interpreter_context(interpreter_context), interpreter(interpreter_context, db) { interpreter_context->auth_checker = &auth_checker; interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); }); + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); } auto Prepare(const std::string &query, const std::map ¶ms = {}) { diff --git a/tests/unit/monitoring.cpp b/tests/unit/monitoring.cpp index e04e091e5..26dc6ad47 100644 --- a/tests/unit/monitoring.cpp +++ b/tests/unit/monitoring.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -44,11 +44,9 @@ struct MockAuth : public memgraph::communication::websocket::AuthenticationInter return authentication; } - bool HasUserPermission(const std::string & /*username*/, memgraph::auth::Permission /*permission*/) const override { - return authorization; - } + bool HasPermission(memgraph::auth::Permission /*permission*/) const override { return authorization; } - bool HasAnyUsers() const override { return has_any_users; } + bool AccessControlled() const override { return has_any_users; } bool authentication{true}; bool authorization{true}; diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index e94659ce7..2dd1e7ac7 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -21,6 +21,8 @@ #include "communication/result_stream_faker.hpp" #include "dbms/database.hpp" #include "disk_test_utils.hpp" +#include "glue/auth_checker.hpp" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/dump.hpp" #include "query/interpreter.hpp" @@ -235,6 +237,8 @@ DatabaseState GetState(memgraph::storage::Storage *db) { auto Execute(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db, const std::string &query) { memgraph::query::Interpreter interpreter(context, db); + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); ResultStreamFaker stream(db->storage()); auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, {}); @@ -934,7 +938,10 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabase) { class StatefulInterpreter { public: explicit StatefulInterpreter(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db) - : context_(context), interpreter_(context_, db) {} + : context_(context), interpreter_(context_, db) { + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter_.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); + } auto Execute(const std::string &query) { ResultStreamFaker stream(interpreter_.current_db_.db_acc_->get()->storage()); @@ -1157,7 +1164,7 @@ TYPED_TEST(DumpTest, DumpDatabaseWithTriggers) { memgraph::query::DbAccessor dba(acc.get()); const std::map props; trigger_store->AddTrigger(trigger_name, trigger_statement, props, trigger_event_type, trigger_phase, &ast_cache, - &dba, query_config, std::nullopt, &auth_checker); + &dba, query_config, auth_checker.GenQueryUser(std::nullopt, std::nullopt)); } { ResultStreamFaker stream(this->db->storage()); diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index ac04cabdd..262ebd4e1 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "communication/result_stream_faker.hpp" +#include "query/auth_checker.hpp" #include "query/interpreter.hpp" #include "query/interpreter_context.hpp" #include "query/stream/streams.hpp" @@ -36,6 +37,7 @@ class QueryExecution : public testing::Test { const std::string testSuite = "query_plan_edge_cases"; std::optional db_acc_; std::optional interpreter_context_; + std::optional auth_checker_; std::optional interpreter_; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"}; @@ -73,11 +75,14 @@ class QueryExecution : public testing::Test { nullptr #endif ); + auth_checker_.emplace(); interpreter_.emplace(&*interpreter_context_, *db_acc_); + interpreter_->SetUser(auth_checker_->GenQueryUser(std::nullopt, std::nullopt)); } void TearDown() override { interpreter_ = std::nullopt; + auth_checker_.reset(); interpreter_context_ = std::nullopt; system_state.reset(); db_acc_.reset(); diff --git a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp index cde3d937a..5b246468f 100644 --- a/tests/unit/query_streams.cpp +++ b/tests/unit/query_streams.cpp @@ -20,9 +20,11 @@ #include "integrations/constants.hpp" #include "integrations/kafka/exceptions.hpp" #include "kafka_mock.hpp" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/interpreter.hpp" #include "query/interpreter_context.hpp" +#include "query/query_user.hpp" #include "query/stream/streams.hpp" #include "storage/v2/config.hpp" #include "storage/v2/disk/storage.hpp" @@ -35,11 +37,23 @@ using StreamStatus = memgraph::query::stream::StreamStatus &privileges, const std::string &db_name, + memgraph::query::UserPolicy *policy) const { + return true; + } +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const { return "memgraph"; } +#endif +}; + struct StreamCheckData { std::string name; StreamInfo info; bool is_running; - std::optional owner; + std::shared_ptr owner; }; std::string GetDefaultStreamName() { @@ -105,13 +119,16 @@ class StreamsTestFixture : public ::testing::Test { }() // iile }; memgraph::system::System system_state; - memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, - system_state + memgraph::query::AllowEverythingAuthChecker auth_checker; + memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, + nullptr, + &repl_state, + system_state, #ifdef MG_ENTERPRISE - , - nullptr + nullptr, #endif - }; + nullptr, + &auth_checker}; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; std::optional proxyStreams_; @@ -173,7 +190,7 @@ class StreamsTestFixture : public ::testing::Test { } StreamCheckData CreateDefaultStreamCheckData() { - return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::nullopt}; + return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::make_unique()}; } void Clear() { @@ -215,11 +232,11 @@ TYPED_TEST(StreamsTestFixture, CreateAlreadyExisting) { auto stream_info = this->CreateDefaultStreamInfo(); auto stream_name = GetDefaultStreamName(); this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); try { this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); FAIL() << "Creating already existing stream should throw\n"; } catch (memgraph::query::stream::StreamsException &exception) { EXPECT_EQ(exception.what(), fmt::format("Stream already exists with name '{}'", stream_name)); @@ -231,7 +248,7 @@ TYPED_TEST(StreamsTestFixture, DropNotExistingStream) { const auto stream_name = GetDefaultStreamName(); const std::string not_existing_stream_name{"ThisDoesn'tExists"}; this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); try { this->proxyStreams_->streams_->Drop(not_existing_stream_name); @@ -262,7 +279,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { if (i > 0) { stream_info.common_info.batch_interval = std::chrono::milliseconds((i + 1) * 10); stream_info.common_info.batch_size = 1000 + i; - stream_check_data.owner = std::string{"owner"} + iteration_postfix; + stream_check_data.owner = std::make_unique(); // These are just random numbers to make the CONFIGS and CREDENTIALS map vary between consumers: // - 0 means no config, no credential @@ -280,7 +297,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { this->mock_cluster_.CreateTopic(stream_info.topics[0]); } - stream_check_datas[3].owner = {}; + stream_check_datas[3].owner = std::make_unique(); const auto check_restore_logic = [&stream_check_datas, this]() { // Reset the Streams object to trigger reloading @@ -336,7 +353,7 @@ TYPED_TEST(StreamsTestFixture, CheckWithTimeout) { const auto stream_info = this->CreateDefaultStreamInfo(); const auto stream_name = GetDefaultStreamName(); this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); std::chrono::milliseconds timeout{3000}; @@ -360,9 +377,10 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidConfig) { EXPECT_TRUE(message.find(kInvalidConfigName) != std::string::npos) << message; EXPECT_TRUE(message.find(kConfigValue) != std::string::npos) << message; }; - EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), - memgraph::integrations::kafka::SettingCustomConfigFailed, checker); + EXPECT_THROW_WITH_MSG( + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_), + memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) { @@ -376,7 +394,8 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) { EXPECT_TRUE(message.find(memgraph::integrations::kReducted) != std::string::npos) << message; EXPECT_TRUE(message.find(kCredentialValue) == std::string::npos) << message; }; - EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), - memgraph::integrations::kafka::SettingCustomConfigFailed, checker); + EXPECT_THROW_WITH_MSG( + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_), + memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index 1b2ca5e9c..06aa1dbd9 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -21,6 +21,7 @@ #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpreter.hpp" +#include "query/query_user.hpp" #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "storage/v2/config.hpp" @@ -42,16 +43,27 @@ const std::unordered_set kAllEventTypes{ class MockAuthChecker : public memgraph::query::AuthChecker { public: - MOCK_CONST_METHOD3(IsUserAuthorized, - bool(const std::optional &username, - const std::vector &privileges, const std::string &db)); + MOCK_CONST_METHOD2(GenQueryUser, + std::shared_ptr(const std::optional &username, + const std::optional &rolename)); #ifdef MG_ENTERPRISE - MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, - std::unique_ptr( - const std::string &username, const memgraph::query::DbAccessor *db_accessor)); + MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr( + std::shared_ptr user, + const memgraph::query::DbAccessor *db_accessor)); MOCK_CONST_METHOD0(ClearCache, void()); #endif }; + +class MockQueryUser : public memgraph::query::QueryUserOrRole { + public: + MockQueryUser(std::optional name) : memgraph::query::QueryUserOrRole(std::move(name), std::nullopt) {} + MOCK_CONST_METHOD3(IsAuthorized, bool(const std::vector &privileges, + const std::string &db_name, memgraph::query::UserPolicy *policy)); + +#ifdef MG_ENTERPRISE + MOCK_CONST_METHOD0(GetDefaultDB, std::string()); +#endif +}; } // namespace const std::string testSuite = "query_trigger"; @@ -966,12 +978,12 @@ TYPED_TEST(TriggerStoreTest, Restore) { trigger_name_before, trigger_statement, std::map{{"parameter", memgraph::storage::PropertyValue{1}}}, event_type, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); store->AddTrigger( trigger_name_after, trigger_statement, std::map{{"parameter", memgraph::storage::PropertyValue{"value"}}}, event_type, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, {owner}, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(owner, std::nullopt)); const auto check_triggers = [&] { ASSERT_EQ(store->GetTriggerInfo().size(), 2); @@ -981,9 +993,9 @@ TYPED_TEST(TriggerStoreTest, Restore) { ASSERT_EQ(trigger.OriginalStatement(), trigger_statement); ASSERT_EQ(trigger.EventType(), event_type); if (owner != nullptr) { - ASSERT_EQ(*trigger.Owner(), *owner); + ASSERT_EQ(trigger.Owner()->username(), *owner); } else { - ASSERT_FALSE(trigger.Owner().has_value()); + ASSERT_FALSE(trigger.Owner()->username()); } }; @@ -1022,32 +1034,38 @@ TYPED_TEST(TriggerStoreTest, AddTrigger) { // Invalid query in statements ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_NO_THROW(store.AddTrigger( "trigger", "RETURN $parameter", std::map{{"parameter", memgraph::storage::PropertyValue{1}}}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, - &*this->dba, memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker)); + &*this->dba, memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt))); // Inserting with the same name ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_EQ(store.GetTriggerInfo().size(), 1); @@ -1063,7 +1081,8 @@ TYPED_TEST(TriggerStoreTest, DropTrigger) { const auto *trigger_name = "trigger"; store.AddTrigger(trigger_name, "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); ASSERT_THROW(store.DropTrigger("Unknown"), memgraph::utils::BasicException); ASSERT_NO_THROW(store.DropTrigger(trigger_name)); @@ -1076,7 +1095,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { std::vector expected_info; store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); expected_info.push_back({"trigger", "RETURN 1", memgraph::query::TriggerEventType::VERTEX_CREATE, @@ -1099,7 +1119,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { store.AddTrigger("edge_update_trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); expected_info.push_back({"edge_update_trigger", "RETURN 1", memgraph::query::TriggerEventType::EDGE_UPDATE, @@ -1216,7 +1237,8 @@ TYPED_TEST(TriggerStoreTest, AnyTriggerAllKeywords) { SCOPED_TRACE(keyword); EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker)); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt))); store.DropTrigger(trigger_name); } } @@ -1228,45 +1250,50 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) { using ::testing::ElementsAre; using ::testing::Return; std::optional store{this->testing_directory}; - const std::optional owner{"testing_owner"}; MockAuthChecker mock_checker; + const std::optional owner{"mock_user"}; + MockQueryUser mock_user(owner); + std::shared_ptr mock_user_ptr( + &mock_user, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ }); + MockQueryUser mock_userless(std::nullopt); + std::shared_ptr mock_userless_ptr( + &mock_userless, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ }); ::testing::InSequence s; - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) - .Times(1) + // TODO Userless + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(true)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) - .Times(1) - .WillOnce(Return(true)); - ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker)); + memgraph::query::InterpreterConfig::Query{}, mock_user_ptr)); + EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) + .WillOnce(Return(true)); ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, owner, &mock_checker)); + memgraph::query::InterpreterConfig::Query{}, mock_userless_ptr)); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::MATCH), "")) - .Times(1) + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::MATCH), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(false)); + ASSERT_THROW( + store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {}, + memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, + &this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, mock_user_ptr); + , memgraph::utils::BasicException); - ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {}, - memgraph::query::TriggerEventType::EDGE_UPDATE, - memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker); - , memgraph::utils::BasicException); - + // Restore store.emplace(this->testing_directory); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) - .Times(1) - .WillOnce(Return(false)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) - .Times(1) + + std::optional nopt{}; + EXPECT_CALL(mock_checker, GenQueryUser(owner, nopt)).WillOnce(Return(mock_user_ptr)); + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(true)); + EXPECT_CALL(mock_checker, GenQueryUser(nopt, nopt)).WillOnce(Return(mock_userless_ptr)); + EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) + .WillOnce(Return(false)); ASSERT_NO_THROW(store->RestoreTriggers(&this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, &mock_checker));