diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index e5c609800..77a711eb2 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -3,7 +3,6 @@ name: Bug report about: Create a report to help us improve title: "" labels: bug -assignees: gitbuda --- **Memgraph version** diff --git a/.github/workflows/diff.yaml b/.github/workflows/diff.yaml index 143ac102f..a2dc0aef2 100644 --- a/.github/workflows/diff.yaml +++ b/.github/workflows/diff.yaml @@ -336,53 +336,6 @@ jobs: # multiple paths could be defined build/logs - experimental_build_ha: - name: "High availability build" - runs-on: [self-hosted, Linux, X64, Diff] - env: - THREADS: 24 - MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} - MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} - - steps: - - name: Set up repository - uses: actions/checkout@v4 - with: - # Number of commits to fetch. `0` indicates all history for all - # branches and tags. (default: 1) - fetch-depth: 0 - - - name: Build release binaries - run: | - source /opt/toolchain-v4/activate - ./init - cd build - cmake -DCMAKE_BUILD_TYPE=Release -DMG_EXPERIMENTAL_HIGH_AVAILABILITY=ON .. 
- make -j$THREADS - - name: Run unit tests - run: | - source /opt/toolchain-v4/activate - cd build - ctest -R memgraph__unit --output-on-failure -j$THREADS - - name: Run e2e tests - if: false - run: | - cd tests - ./setup.sh /opt/toolchain-v4/activate - source ve3/bin/activate_e2e - cd e2e - ./run.sh "Coordinator" - ./run.sh "Client initiated failover" - ./run.sh "Uninitialized cluster" - - name: Save test data - uses: actions/upload-artifact@v4 - if: always() - with: - name: "Test data(High availability build)" - path: | - # multiple paths could be defined - build/logs - release_jepsen_test: name: "Release Jepsen Test" runs-on: [self-hosted, Linux, X64, Debian10, JepsenControl] diff --git a/.github/workflows/release_build_test.yaml b/.github/workflows/release_build_test.yaml new file mode 100644 index 000000000..cc4884758 --- /dev/null +++ b/.github/workflows/release_build_test.yaml @@ -0,0 +1,208 @@ +name: Release build test +concurrency: + group: ${{ github.workflow }}-${{ github.ref_name }} + cancel-in-progress: true + +on: + workflow_dispatch: + inputs: + build_type: + type: choice + description: "Memgraph Build type. Default value is Release." 
+ default: 'Release' + options: + - Release + - RelWithDebInfo + + push: + branches: + - "release/**" + tags: + - "v*.*.*-rc*" + - "v*.*-rc*" + schedule: + # UTC + - cron: "0 22 * * *" + +env: + THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} + BUILD_TYPE: ${{ github.event.inputs.build_type || 'Release' }} + +jobs: + Debian10: + uses: ./.github/workflows/release_debian10.yaml + with: + build_type: ${{ github.event.inputs.build_type || 'Release' }} + secrets: inherit + + Ubuntu20_04: + uses: ./.github/workflows/release_ubuntu2004.yaml + with: + build_type: ${{ github.event.inputs.build_type || 'Release' }} + secrets: inherit + + PackageDebian10: + if: github.ref_type == 'tag' + needs: [Debian10] + runs-on: [self-hosted, DockerMgBuild, X64] + timeout-minutes: 60 + steps: + - name: "Set up repository" + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package debian-10 $BUILD_TYPE + - name: Upload to S3 + uses: jakejarvis/s3-sync-action@v0.5.1 + env: + AWS_S3_BUCKET: "deps.memgraph.io" + AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }} + AWS_REGION: "eu-west-1" + SOURCE_DIR: "build/output" + DEST_DIR: "memgraph-unofficial/${{ github.ref_name }}/" + - name: "Upload package" + uses: actions/upload-artifact@v4 + with: + name: debian-10 + path: build/output/debian-10/memgraph*.deb + + PackageUbuntu20_04: + if: github.ref_type == 'tag' + needs: [Ubuntu20_04] + runs-on: [self-hosted, DockerMgBuild, X64] + timeout-minutes: 60 + steps: + - name: "Set up repository" + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package ubuntu-22.04 $BUILD_TYPE + - name: Upload to S3 + uses: 
jakejarvis/s3-sync-action@v0.5.1 + env: + AWS_S3_BUCKET: "deps.memgraph.io" + AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }} + AWS_REGION: "eu-west-1" + SOURCE_DIR: "build/output" + DEST_DIR: "memgraph-unofficial/${{ github.ref_name }}/" + - name: "Upload package" + uses: actions/upload-artifact@v4 + with: + name: ubuntu-22.04 + path: build/output/ubuntu-22.04/memgraph*.deb + + PackageUbuntu20_04_ARM: + if: github.ref_type == 'tag' + needs: [Ubuntu20_04] + runs-on: [self-hosted, DockerMgBuild, ARM64] + # M1 Mac mini is sometimes slower + timeout-minutes: 150 + steps: + - name: "Set up repository" + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package ubuntu-22.04-arm $BUILD_TYPE + - name: "Upload package" + uses: actions/upload-artifact@v4 + with: + name: ubuntu-22.04-aarch64 + path: build/output/ubuntu-22.04-arm/memgraph*.deb + + PushToS3Ubuntu20_04_ARM: + if: github.ref_type == 'tag' + needs: [PackageUbuntu20_04_ARM] + runs-on: ubuntu-latest + steps: + - name: Download package + uses: actions/download-artifact@v4 + with: + name: ubuntu-22.04-aarch64 + path: build/output/release + - name: Upload to S3 + uses: jakejarvis/s3-sync-action@v0.5.1 + env: + AWS_S3_BUCKET: "deps.memgraph.io" + AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }} + AWS_REGION: "eu-west-1" + SOURCE_DIR: "build/output/release" + DEST_DIR: "memgraph-unofficial/${{ github.ref_name }}/" + + PackageDebian11: + if: github.ref_type == 'tag' + needs: [Debian10, Ubuntu20_04] + runs-on: [self-hosted, DockerMgBuild, X64] + timeout-minutes: 60 + steps: + - name: "Set up repository" + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package 
debian-11 $BUILD_TYPE + - name: Upload to S3 + uses: jakejarvis/s3-sync-action@v0.5.1 + env: + AWS_S3_BUCKET: "deps.memgraph.io" + AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }} + AWS_REGION: "eu-west-1" + SOURCE_DIR: "build/output" + DEST_DIR: "memgraph-unofficial/${{ github.ref_name }}/" + - name: "Upload package" + uses: actions/upload-artifact@v4 + with: + name: debian-11 + path: build/output/debian-11/memgraph*.deb + + PackageDebian11_ARM: + if: github.ref_type == 'tag' + needs: [Debian10, Ubuntu20_04] + runs-on: [self-hosted, DockerMgBuild, ARM64] + # M1 Mac mini is sometimes slower + timeout-minutes: 150 + steps: + - name: "Set up repository" + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package debian-11-arm $BUILD_TYPE + - name: "Upload package" + uses: actions/upload-artifact@v4 + with: + name: debian-11-aarch64 + path: build/output/debian-11-arm/memgraph*.deb + + PushToS3Debian11_ARM: + if: github.ref_type == 'tag' + needs: [PackageDebian11_ARM] + runs-on: ubuntu-latest + steps: + - name: Download package + uses: actions/download-artifact@v4 + with: + name: debian-11-aarch64 + path: build/output/release + - name: Upload to S3 + uses: jakejarvis/s3-sync-action@v0.5.1 + env: + AWS_S3_BUCKET: "deps.memgraph.io" + AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }} + AWS_REGION: "eu-west-1" + SOURCE_DIR: "build/output/release" + DEST_DIR: "memgraph-unofficial/${{ github.ref_name }}/" diff --git a/.github/workflows/release_debian10.yaml b/.github/workflows/release_debian10.yaml index 9feb5b1f0..188617a6e 100644 --- a/.github/workflows/release_debian10.yaml +++ b/.github/workflows/release_debian10.yaml @@ -1,6 +1,12 @@ name: Release Debian 10 on: + workflow_call: + inputs: + build_type: + type: string + 
description: "Memgraph Build type. Default value is Release." + default: 'Release' workflow_dispatch: inputs: build_type: @@ -11,10 +17,8 @@ on: - Release - RelWithDebInfo - schedule: - - cron: "0 22 * * *" - env: + OS: "Debian10" THREADS: 24 MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} @@ -111,7 +115,7 @@ jobs: - name: Save code coverage uses: actions/upload-artifact@v4 with: - name: "Code coverage(Coverage build)" + name: "Code coverage(Coverage build)-${{ env.OS }}" path: tools/github/generated/code_coverage.tar.gz debug_build: @@ -165,7 +169,7 @@ jobs: - name: Save cppcheck and clang-format errors uses: actions/upload-artifact@v4 with: - name: "Code coverage(Debug build)" + name: "Code coverage(Debug build)-${{ env.OS }}" path: tools/github/cppcheck_and_clang_format.txt debug_integration_test: @@ -242,7 +246,7 @@ jobs: - name: Save enterprise DEB package uses: actions/upload-artifact@v4 with: - name: "Enterprise DEB package" + name: "Enterprise DEB package-${{ env.OS}}" path: build/output/memgraph*.deb - name: Run GQL Behave tests @@ -255,7 +259,7 @@ jobs: - name: Save quality assurance status uses: actions/upload-artifact@v4 with: - name: "GQL Behave Status" + name: "GQL Behave Status-${{ env.OS }}" path: | tests/gql_behave/gql_behave_status.csv tests/gql_behave/gql_behave_status.html @@ -321,7 +325,6 @@ jobs: --no-strict release_e2e_test: - if: false name: "Release End-to-end Test" runs-on: [self-hosted, Linux, X64, Debian10] timeout-minutes: 60 @@ -456,5 +459,5 @@ jobs: uses: actions/upload-artifact@v4 if: ${{ always() }} with: - name: "Jepsen Report" + name: "Jepsen Report-${{ env.OS }}" path: tests/jepsen/Jepsen.tar.gz diff --git a/.github/workflows/release_ubuntu2004.yaml b/.github/workflows/release_ubuntu2004.yaml index 77feea2fe..be6099128 100644 --- a/.github/workflows/release_ubuntu2004.yaml +++ b/.github/workflows/release_ubuntu2004.yaml @@ -1,6 +1,12 @@ 
name: Release Ubuntu 20.04 on: + workflow_call: + inputs: + build_type: + type: string + description: "Memgraph Build type. Default value is Release." + default: 'Release' workflow_dispatch: inputs: build_type: @@ -11,10 +17,8 @@ on: - Release - RelWithDebInfo - schedule: - - cron: "0 22 * * *" - env: + OS: "Ubuntu 20.04" THREADS: 24 MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} @@ -107,7 +111,7 @@ jobs: - name: Save code coverage uses: actions/upload-artifact@v4 with: - name: "Code coverage(Coverage build)" + name: "Code coverage(Coverage build)-${{ env.OS }}" path: tools/github/generated/code_coverage.tar.gz debug_build: @@ -161,7 +165,7 @@ jobs: - name: Save cppcheck and clang-format errors uses: actions/upload-artifact@v4 with: - name: "Code coverage(Debug build)" + name: "Code coverage(Debug build)-${{ env.OS }}" path: tools/github/cppcheck_and_clang_format.txt debug_integration_test: @@ -238,7 +242,7 @@ jobs: - name: Save enterprise DEB package uses: actions/upload-artifact@v4 with: - name: "Enterprise DEB package" + name: "Enterprise DEB package-${{ env.OS }}" path: build/output/memgraph*.deb - name: Run GQL Behave tests @@ -251,7 +255,7 @@ jobs: - name: Save quality assurance status uses: actions/upload-artifact@v4 with: - name: "GQL Behave Status" + name: "GQL Behave Status-${{ env.OS }}" path: | tests/gql_behave/gql_behave_status.csv tests/gql_behave/gql_behave_status.html @@ -317,7 +321,6 @@ jobs: --no-strict release_e2e_test: - if: false name: "Release End-to-end Test" runs-on: [self-hosted, Linux, X64, Ubuntu20.04] timeout-minutes: 60 diff --git a/.github/workflows/stress_test_large.yaml b/.github/workflows/stress_test_large.yaml index bdb805f5d..712d245ae 100644 --- a/.github/workflows/stress_test_large.yaml +++ b/.github/workflows/stress_test_large.yaml @@ -1,4 +1,7 @@ name: Stress test large +concurrency: + group: ${{ github.workflow }}-${{ 
github.ref_name }} + cancel-in-progress: true on: workflow_dispatch: @@ -10,7 +13,10 @@ on: options: - Release - RelWithDebInfo - + push: + tags: + - "v*.*.*-rc*" + - "v*.*-rc*" schedule: - cron: "0 22 * * *" diff --git a/CMakeLists.txt b/CMakeLists.txt index 62c5a6fcf..028406447 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -211,8 +211,13 @@ set(CMAKE_CXX_FLAGS_RELWITHDEBINFO # ** Static linking is allowed only for executables! ** set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static-libgcc -static-libstdc++") -# Use lld linker to speedup build -add_link_options(-fuse-ld=lld) # TODO: use mold linker +# Use lld linker to speedup build and use less memory. +add_link_options(-fuse-ld=lld) +# NOTE: Moving to latest Clang (probably starting from 15), lld stopped to work +# without explicit link_directories call. +string(REPLACE ":" " " LD_LIBS $ENV{LD_LIBRARY_PATH}) +separate_arguments(LD_LIBS) +link_directories(${LD_LIBS}) # release flags set(CMAKE_CXX_FLAGS_RELEASE "-O2 -DNDEBUG") @@ -271,18 +276,6 @@ endif() set(libs_dir ${CMAKE_SOURCE_DIR}/libs) add_subdirectory(libs EXCLUDE_FROM_ALL) -option(MG_EXPERIMENTAL_HIGH_AVAILABILITY "Feature flag for experimental high availability" OFF) - -if (NOT MG_ENTERPRISE AND MG_EXPERIMENTAL_HIGH_AVAILABILITY) - set(MG_EXPERIMENTAL_HIGH_AVAILABILITY OFF) - message(FATAL_ERROR "MG_EXPERIMENTAL_HIGH_AVAILABILITY can only be used with enterpise version of the code.") -endif () - -if (MG_EXPERIMENTAL_HIGH_AVAILABILITY) - add_compile_definitions(MG_EXPERIMENTAL_HIGH_AVAILABILITY) -endif () - -# Optional subproject configuration ------------------------------------------- option(TEST_COVERAGE "Generate coverage reports from running memgraph" OFF) option(TOOLS "Build tools binaries" ON) option(QUERY_MODULES "Build query modules containing custom procedures" ON) diff --git a/environment/os/amzn-2.sh b/environment/os/amzn-2.sh index 15ff29106..a9cc3e4b2 100755 --- a/environment/os/amzn-2.sh +++ b/environment/os/amzn-2.sh @@ 
-1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -9,7 +7,7 @@ check_operating_system "amzn-2" check_architecture "x86_64" TOOLCHAIN_BUILD_DEPS=( - gcc gcc-c++ make # generic build tools + git gcc gcc-c++ make # generic build tools wget # used for archive download gnupg2 # used for archive signature verification tar gzip bzip2 xz unzip # used for archive unpacking @@ -63,6 +61,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/centos-7.sh b/environment/os/centos-7.sh index df16fbc73..d9fc93912 100755 --- a/environment/os/centos-7.sh +++ b/environment/os/centos-7.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -63,6 +61,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/centos-9.sh b/environment/os/centos-9.sh index 8a431807e..8177c9223 100755 --- a/environment/os/centos-9.sh +++ b/environment/os/centos-9.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -9,8 +7,10 @@ check_operating_system "centos-9" check_architecture "x86_64" TOOLCHAIN_BUILD_DEPS=( - coreutils-common gcc gcc-c++ make # generic build tools wget # used for archive download + coreutils-common gcc gcc-c++ make # generic build tools + # NOTE: Pure libcurl conflicts with libcurl-minimal + libcurl-devel # cmake build requires it gnupg2 # used for archive signature verification tar gzip bzip2 xz unzip # used for archive unpacking zlib-devel # zlib library used for all builds @@ -64,6 +64,8 @@ MEMGRAPH_BUILD_DEPS=( cyrus-sasl-devel ) 
+MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) @@ -123,7 +125,9 @@ install() { else echo "NOTE: export LANG=en_US.utf8" fi - yum update -y + # --nobest is used because of libipt because we install custom versions + # because libipt-devel is not available on CentOS 9 Stream + yum update -y --nobest yum install -y wget git python3 python3-pip for pkg in $1; do diff --git a/environment/os/debian-10.sh b/environment/os/debian-10.sh index 4c1deda42..9a64854de 100755 --- a/environment/os/debian-10.sh +++ b/environment/os/debian-10.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. + check_operating_system "debian-10" check_architecture "x86_64" diff --git a/environment/os/debian-11-arm.sh b/environment/os/debian-11-arm.sh index c8a3cca1c..8e17a8fdd 100755 --- a/environment/os/debian-11-arm.sh +++ b/environment/os/debian-11-arm.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. 
+ check_operating_system "debian-11" check_architecture "arm64" "aarch64" diff --git a/environment/os/debian-11.sh b/environment/os/debian-11.sh index c7e82b52c..ac05f6ba6 100755 --- a/environment/os/debian-11.sh +++ b/environment/os/debian-11.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -61,6 +59,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp ) diff --git a/environment/os/debian-12-arm.sh b/environment/os/debian-12-arm.sh new file mode 100755 index 000000000..15d3f7473 --- /dev/null +++ b/environment/os/debian-12-arm.sh @@ -0,0 +1,134 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "debian-12" +check_architecture "arm64" "aarch64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils gcc g++ build-essential make # generic build tools + wget # used for archive download + gnupg # used for archive signature verification + tar gzip bzip2 xz-utils unzip # used for archive unpacking + zlib1g-dev # zlib library used for all builds + libexpat1-dev liblzma-dev python3-dev texinfo # for gdb + libcurl4-openssl-dev # for cmake + libreadline-dev # for cmake and llvm + libffi-dev libxml2-dev # for llvm + libedit-dev libpcre2-dev libpcre3-dev automake bison # for swig + curl # snappy + file # for libunwind + libssl-dev # for libevent + libgmp-dev + gperf # for proxygen + git # for fbthrift +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz-utils # used for archive unpacking + zlib1g # zlib library used for all builds + libexpat1 liblzma5 python3 # for gdb + libcurl4 # for cmake + file # for CPack + libreadline8 # for cmake and llvm + libffi8 libxml2 # for llvm + libssl-dev # for libevent +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make 
pkg-config # build system + curl wget # for downloading libs + uuid-dev default-jre-headless # required by antlr + libreadline-dev # for memgraph console + libpython3-dev python3-dev # for query modules + libssl-dev + libseccomp-dev + netcat # tests are using nc to wait for memgraph + python3 virtualenv python3-virtualenv python3-pip # for qa, macro_benchmark and stress tests + python3-yaml # for the configuration generator + libcurl4-openssl-dev # mg-requests + sbcl # for custom Lisp C++ preprocessing + doxygen graphviz # source documentation generators + mono-runtime mono-mcs zip unzip default-jdk-headless custom-maven3.9.3 # for driver tests + dotnet-sdk-7.0 golang custom-golang1.18.9 nodejs npm + autoconf # for jemalloc code generation + libtool # for protobuf code generation + libsasl2-dev +) + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if ! dpkg -s "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + apt update + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests doesn't work the LANG export. 
+ if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + apt install -y wget + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == dotnet-sdk-7.0 ]; then + if ! dpkg -s "$pkg" 2>/dev/null >/dev/null; then + wget -nv https://packages.microsoft.com/config/debian/12/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + dpkg -i packages-microsoft-prod.deb + apt-get update + apt-get install -y apt-transport-https dotnet-sdk-7.0 + fi + continue + fi + apt install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/debian-12.sh b/environment/os/debian-12.sh new file mode 100755 index 000000000..1709230ad --- /dev/null +++ b/environment/os/debian-12.sh @@ -0,0 +1,136 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "debian-12" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils gcc g++ build-essential make # generic build tools + wget # used for archive download + gnupg # used for archive signature verification + tar gzip bzip2 xz-utils unzip # used for archive unpacking + zlib1g-dev # zlib library used for all builds + libexpat1-dev libipt-dev libbabeltrace-dev liblzma-dev python3-dev texinfo # for gdb + libcurl4-openssl-dev # for cmake + libreadline-dev # for cmake and llvm + libffi-dev libxml2-dev # for llvm + libedit-dev libpcre2-dev libpcre3-dev automake bison # for swig + curl # snappy + file # for libunwind + libssl-dev # for libevent + libgmp-dev + gperf # for proxygen + git # for fbthrift +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz-utils # used for archive unpacking + zlib1g # zlib library used for all builds + 
libexpat1 libipt2 libbabeltrace1 liblzma5 python3 # for gdb + libcurl4 # for cmake + file # for CPack + libreadline8 # for cmake and llvm + libffi8 libxml2 # for llvm + libssl-dev # for libevent +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make cmake pkg-config # build system + curl wget # for downloading libs + uuid-dev default-jre-headless # required by antlr + libreadline-dev # for memgraph console + libpython3-dev python3-dev # for query modules + libssl-dev + libseccomp-dev + netcat-traditional # tests are using nc to wait for memgraph + python3 virtualenv python3-virtualenv python3-pip # for qa, macro_benchmark and stress tests + python3-yaml # for the configuration generator + libcurl4-openssl-dev # mg-requests + sbcl # for custom Lisp C++ preprocessing + doxygen graphviz # source documentation generators + mono-runtime mono-mcs zip unzip default-jdk-headless custom-maven3.9.3 # for driver tests + dotnet-sdk-7.0 golang custom-golang1.18.9 nodejs npm + autoconf # for jemalloc code generation + libtool # for protobuf code generation + libsasl2-dev +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if ! dpkg -s "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + apt update + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests doesn't work the LANG export. 
+ if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + apt install -y wget + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == dotnet-sdk-7.0 ]; then + if ! dpkg -s "$pkg" 2>/dev/null >/dev/null; then + wget -nv https://packages.microsoft.com/config/debian/12/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + dpkg -i packages-microsoft-prod.deb + apt-get update + apt-get install -y apt-transport-https dotnet-sdk-7.0 + fi + continue + fi + apt install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/fedora-36.sh b/environment/os/fedora-36.sh index f7bd0c53a..f8b8995d9 100755 --- a/environment/os/fedora-36.sh +++ b/environment/os/fedora-36.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. 
+ check_operating_system "fedora-36" check_architecture "x86_64" @@ -27,6 +27,7 @@ TOOLCHAIN_BUILD_DEPS=( libipt libipt-devel # intel patch perl # for openssl + git ) TOOLCHAIN_RUN_DEPS=( diff --git a/environment/os/fedora-38.sh b/environment/os/fedora-38.sh index 7837f018b..951bec46f 100755 --- a/environment/os/fedora-38.sh +++ b/environment/os/fedora-38.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -27,6 +25,7 @@ TOOLCHAIN_BUILD_DEPS=( libipt libipt-devel # intel patch perl # for openssl + git ) TOOLCHAIN_RUN_DEPS=( @@ -58,6 +57,16 @@ MEMGRAPH_BUILD_DEPS=( libtool # for protobuf code generation ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + list() { echo "$1" } diff --git a/environment/os/fedora-39.sh b/environment/os/fedora-39.sh new file mode 100755 index 000000000..4b0e82992 --- /dev/null +++ b/environment/os/fedora-39.sh @@ -0,0 +1,117 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +check_operating_system "fedora-39" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + coreutils-common gcc gcc-c++ make # generic build tools + wget # used for archive download + gnupg2 # used for archive signature verification + tar gzip bzip2 xz unzip # used for archive unpacking + zlib-devel # zlib library used for all builds + expat-devel xz-devel python3-devel texinfo libbabeltrace-devel # for gdb + curl libcurl-devel # for cmake + readline-devel # for cmake and llvm + libffi-devel libxml2-devel # for llvm + libedit-devel pcre-devel pcre2-devel automake bison # for swig + file + openssl-devel + gmp-devel + gperf + diffutils + libipt libipt-devel # intel + patch + perl # for openssl + git +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz # used 
for archive unpacking + zlib # zlib library used for all builds + expat xz-libs python3 # for gdb + readline # for cmake and llvm + libffi libxml2 # for llvm + openssl-devel +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make pkgconf-pkg-config # build system + wget # for downloading libs + libuuid-devel java-11-openjdk # required by antlr + readline-devel # for memgraph console + python3-devel # for query modules + openssl-devel + libseccomp-devel + python3 python3-pip python3-virtualenv python3-virtualenvwrapper python3-pyyaml nmap-ncat # for tests + libcurl-devel # mg-requests + rpm-build rpmlint # for RPM package building + doxygen graphviz # source documentation generators + which nodejs golang zip unzip java-11-openjdk-devel # for driver tests + sbcl # for custom Lisp C++ preprocessing + autoconf # for jemalloc code generation + libtool # for protobuf code generation +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + if [ -v LD_LIBRARY_PATH ]; then + # On Fedora 38 yum/dnf and python11 use newer glibc which is not compatible + # with ours, so we need to momentarely disable env + local OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH} + LD_LIBRARY_PATH="" + fi + local missing="" + for pkg in $1; do + if ! dnf list installed "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi + if [ -v OLD_LD_LIBRARY_PATH ]; then + echo "Restoring LD_LIBRARY_PATH..." + LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH} + fi +} + +install() { + cd "$DIR" + if [ "$EUID" -ne 0 ]; then + echo "Please run as root." + exit 1 + fi + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests don't work without the LANG export. 
+ if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + dnf update -y + for pkg in $1; do + dnf install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/rocky-9.3.sh b/environment/os/rocky-9.3.sh new file mode 100755 index 000000000..571278654 --- /dev/null +++ b/environment/os/rocky-9.3.sh @@ -0,0 +1,188 @@ +#!/bin/bash +set -Eeuo pipefail +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$DIR/../util.sh" + +# TODO(gitbuda): Rocky gets automatically updates -> figure out how to handle it. +check_operating_system "rocky-9.3" +check_architecture "x86_64" + +TOOLCHAIN_BUILD_DEPS=( + wget # used for archive download + coreutils-common gcc gcc-c++ make # generic build tools + # NOTE: Pure libcurl conflicts with libcurl-minimal + libcurl-devel # cmake build requires it + gnupg2 # used for archive signature verification + tar gzip bzip2 xz unzip # used for archive unpacking + zlib-devel # zlib library used for all builds + expat-devel xz-devel python3-devel perl-Unicode-EastAsianWidth texinfo libbabeltrace-devel # for gdb + readline-devel # for cmake and llvm + libffi-devel libxml2-devel # for llvm + libedit-devel pcre-devel pcre2-devel automake bison # for swig + file + openssl-devel + gmp-devel + gperf + diffutils + libipt libipt-devel # intel + patch +) + +TOOLCHAIN_RUN_DEPS=( + make # generic build tools + tar gzip bzip2 xz # used for archive unpacking + zlib # zlib library used for all builds + expat xz-libs python3 # for gdb + readline # for cmake and llvm + libffi libxml2 # for llvm + openssl-devel + perl # for openssl +) + +MEMGRAPH_BUILD_DEPS=( + git # source code control + make cmake pkgconf-pkg-config # build system + wget # for downloading libs + libuuid-devel java-11-openjdk # required by antlr + readline-devel # for memgraph console + python3-devel # for query modules + openssl-devel + 
libseccomp-devel + python3 python3-pip python3-virtualenv nmap-ncat # for qa, macro_benchmark and stress tests + # + # IMPORTANT: python3-yaml does NOT exist on CentOS + # Install it manually using `pip3 install PyYAML` + # + PyYAML # Package name here does not correspond to the yum package! + libcurl-devel # mg-requests + rpm-build rpmlint # for RPM package building + doxygen graphviz # source documentation generators + which nodejs golang custom-golang1.18.9 # for driver tests + zip unzip java-11-openjdk-devel java-17-openjdk java-17-openjdk-devel custom-maven3.9.3 # for driver tests + sbcl # for custom Lisp C++ preprocessing + autoconf # for jemalloc code generation + libtool # for protobuf code generation + cyrus-sasl-devel +) + +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + +MEMGRAPH_RUN_DEPS=( + logrotate openssl python3 libseccomp +) + +NEW_DEPS=( + wget curl tar gzip +) + +list() { + echo "$1" +} + +check() { + local missing="" + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + if [ ! -f "/opt/apache-maven-3.9.3/bin/mvn" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + if [ ! -f "/opt/go1.18.9/go/bin/go" ]; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == "PyYAML" ]; then + if ! python3 -c "import yaml" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + continue + fi + if [ "$pkg" == "python3-virtualenv" ]; then + continue + fi + if ! yum list installed "$pkg" >/dev/null 2>/dev/null; then + missing="$pkg $missing" + fi + done + if [ "$missing" != "" ]; then + echo "MISSING PACKAGES: $missing" + exit 1 + fi +} + +install() { + cd "$DIR" + if [ "$EUID" -ne 0 ]; then + echo "Please run as root." + exit 1 + fi + # If GitHub Actions runner is installed, append LANG to the environment. + # Python related tests don't work without the LANG export. 
+ if [ -d "/home/gh/actions-runner" ]; then + echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env + else + echo "NOTE: export LANG=en_US.utf8" + fi + yum update -y + yum install -y wget git python3 python3-pip + + for pkg in $1; do + if [ "$pkg" == custom-maven3.9.3 ]; then + install_custom_maven "3.9.3" + continue + fi + if [ "$pkg" == custom-golang1.18.9 ]; then + install_custom_golang "1.18.9" + continue + fi + if [ "$pkg" == perl-Unicode-EastAsianWidth ]; then + if ! dnf list installed perl-Unicode-EastAsianWidth >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/CRB/x86_64/os/Packages/p/perl-Unicode-EastAsianWidth-12.0-7.el9.noarch.rpm + fi + continue + fi + if [ "$pkg" == texinfo ]; then + if ! dnf list installed texinfo >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/CRB/x86_64/os/Packages/t/texinfo-6.7-15.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == libbabeltrace-devel ]; then + if ! dnf list installed libbabeltrace-devel >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/devel/x86_64/os/Packages/l/libbabeltrace-devel-1.5.8-10.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == libipt-devel ]; then + if ! dnf list installed libipt-devel >/dev/null 2>/dev/null; then + dnf install -y https://dl.rockylinux.org/pub/rocky/9/devel/x86_64/os/Packages/l/libipt-devel-2.0.4-5.el9.x86_64.rpm + fi + continue + fi + if [ "$pkg" == PyYAML ]; then + if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker). + pip3 install --user PyYAML + else # Running using sudo. + sudo -H -u "$SUDO_USER" bash -c "pip3 install --user PyYAML" + fi + continue + fi + if [ "$pkg" == python3-virtualenv ]; then + if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker). + pip3 install virtualenv + pip3 install virtualenvwrapper + else # Running using sudo. 
+ sudo -H -u "$SUDO_USER" bash -c "pip3 install virtualenv" + sudo -H -u "$SUDO_USER" bash -c "pip3 install virtualenvwrapper" + fi + continue + fi + yum install -y "$pkg" + done +} + +deps=$2"[*]" +"$1" "${!deps}" diff --git a/environment/os/template.sh b/environment/os/template.sh index b1f2f8fe4..692926efb 100755 --- a/environment/os/template.sh +++ b/environment/os/template.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -20,6 +18,10 @@ MEMGRAPH_BUILD_DEPS=( pkg ) +MEMGRAPH_TEST_DEPS=( + pkg +) + MEMGRAPH_RUN_DEPS=( pkg ) diff --git a/environment/os/ubuntu-18.04.sh b/environment/os/ubuntu-18.04.sh index 27d876e4f..451d5e69c 100755 --- a/environment/os/ubuntu-18.04.sh +++ b/environment/os/ubuntu-18.04.sh @@ -1,10 +1,10 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" +# IMPORTANT: Deprecated since memgraph v2.12.0. 
+ check_operating_system "ubuntu-18.04" check_architecture "x86_64" diff --git a/environment/os/ubuntu-20.04.sh b/environment/os/ubuntu-20.04.sh index 8a308406e..7739b49d1 100755 --- a/environment/os/ubuntu-20.04.sh +++ b/environment/os/ubuntu-20.04.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/os/ubuntu-22.04-arm.sh b/environment/os/ubuntu-22.04-arm.sh index 45a4f3d4c..9326e52e9 100755 --- a/environment/os/ubuntu-22.04-arm.sh +++ b/environment/os/ubuntu-22.04-arm.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/os/ubuntu-22.04.sh b/environment/os/ubuntu-22.04.sh index 59361dd81..649338e53 100755 --- a/environment/os/ubuntu-22.04.sh +++ b/environment/os/ubuntu-22.04.sh @@ -1,7 +1,5 @@ #!/bin/bash - set -Eeuo pipefail - DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" source "$DIR/../util.sh" @@ -60,6 +58,8 @@ MEMGRAPH_BUILD_DEPS=( libsasl2-dev ) +MEMGRAPH_TEST_DEPS="${MEMGRAPH_BUILD_DEPS[*]}" + MEMGRAPH_RUN_DEPS=( logrotate openssl python3 libseccomp2 ) diff --git a/environment/toolchain/.gitignore b/environment/toolchain/.gitignore index e75f93b12..6ba5f327d 100644 --- a/environment/toolchain/.gitignore +++ b/environment/toolchain/.gitignore @@ -2,3 +2,4 @@ archives build output *.tar.gz +tmp_build.sh diff --git a/environment/toolchain/template_build.sh b/environment/toolchain/template_build.sh new file mode 100644 index 000000000..b01902ab9 --- /dev/null +++ 
b/environment/toolchain/template_build.sh @@ -0,0 +1,48 @@ +#!/bin/bash -e + +# NOTE: Copy this under memgraph/environment/toolchain/vN/tmp_build.sh, edit and test. + +pushd () { command pushd "$@" > /dev/null; } +popd () { command popd "$@" > /dev/null; } +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +CPUS=$( grep -c processor < /proc/cpuinfo ) +cd "$DIR" +source "$DIR/../../util.sh" +DISTRO="$(operating_system)" +TOOLCHAIN_VERSION=5 +NAME=toolchain-v$TOOLCHAIN_VERSION +PREFIX=/opt/$NAME +function log_tool_name () { + echo "" + echo "" + echo "#### $1 ####" + echo "" + echo "" +} + +# HERE: Remove/clear dependencies from a given toolchain. + +mkdir -p archives && pushd archives +# HERE: Download dependencies here. +popd + +mkdir -p build +pushd build +source $PREFIX/activate +export CC=$PREFIX/bin/clang +export CXX=$PREFIX/bin/clang++ +export CFLAGS="$CFLAGS -fPIC" +export PATH=$PREFIX/bin:$PATH +export LD_LIBRARY_PATH=$PREFIX/lib64 +COMMON_CMAKE_FLAGS="-DCMAKE_INSTALL_PREFIX=$PREFIX + -DCMAKE_PREFIX_PATH=$PREFIX + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=$CC + -DCMAKE_CXX_COMPILER=$CXX + -DBUILD_SHARED_LIBS=OFF + -DCMAKE_CXX_STANDARD=20 + -DBUILD_TESTING=OFF + -DCMAKE_REQUIRED_INCLUDES=$PREFIX/include + -DCMAKE_POSITION_INDEPENDENT_CODE=ON" + +# HERE: Add dependencies to test below. diff --git a/environment/toolchain/v5/build.sh b/environment/toolchain/v5/build.sh index b6c1ff6d8..aade6f9c5 100755 --- a/environment/toolchain/v5/build.sh +++ b/environment/toolchain/v5/build.sh @@ -307,7 +307,7 @@ if [ ! -f $PREFIX/bin/ld.gold ]; then fi log_tool_name "GDB $GDB_VERSION" -if [ ! -f $PREFIX/bin/gdb ]; then +if [[ ! 
-f "$PREFIX/bin/gdb" && "$DISTRO" != "amzn-2" ]]; then if [ -d gdb-$GDB_VERSION ]; then rm -rf gdb-$GDB_VERSION fi @@ -671,7 +671,6 @@ PROXYGEN_SHA256=5360a8ccdfb2f5a6c7b3eed331ec7ab0e2c792d579c6fff499c85c516c11fe14 WANGLE_SHA256=1002e9c32b6f4837f6a760016e3b3e22f3509880ef3eaad191c80dc92655f23f # WANGLE_SHA256=0e493c03572bb27fe9ca03a9da5023e52fde99c95abdcaa919bb6190e7e69532 -FLEX_VERSION=2.6.4 FMT_SHA256=78b8c0a72b1c35e4443a7e308df52498252d1cefc2b08c9a97bc9ee6cfe61f8b FMT_VERSION=10.1.1 # NOTE: spdlog depends on exact fmt versions -> UPGRADE fmt and spdlog TOGETHER. @@ -690,8 +689,8 @@ LZ4_VERSION=1.9.4 SNAPPY_SHA256=75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7 SNAPPY_VERSION=1.1.9 XZ_VERSION=5.2.5 # for LZMA -ZLIB_VERSION=1.3 -ZSTD_VERSION=1.5.0 +ZLIB_VERSION=1.3.1 +ZSTD_VERSION=1.5.5 pushd archives if [ ! -f boost_$BOOST_VERSION_UNDERSCORES.tar.gz ]; then @@ -700,7 +699,7 @@ if [ ! -f boost_$BOOST_VERSION_UNDERSCORES.tar.gz ]; then wget https://boostorg.jfrog.io/artifactory/main/release/$BOOST_VERSION/source/boost_$BOOST_VERSION_UNDERSCORES.tar.gz -O boost_$BOOST_VERSION_UNDERSCORES.tar.gz fi if [ ! -f bzip2-$BZIP2_VERSION.tar.gz ]; then - wget https://sourceforge.net/projects/bzip2/files/bzip2-$BZIP2_VERSION.tar.gz -O bzip2-$BZIP2_VERSION.tar.gz + wget https://sourceware.org/pub/bzip2/bzip2-$BZIP2_VERSION.tar.gz -O bzip2-$BZIP2_VERSION.tar.gz fi if [ ! -f double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz ]; then wget https://github.com/google/double-conversion/archive/refs/tags/v$DOUBLE_CONVERSION_VERSION.tar.gz -O double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz @@ -708,9 +707,7 @@ fi if [ ! -f fizz-$FBLIBS_VERSION.tar.gz ]; then wget https://github.com/facebookincubator/fizz/releases/download/v$FBLIBS_VERSION/fizz-v$FBLIBS_VERSION.tar.gz -O fizz-$FBLIBS_VERSION.tar.gz fi -if [ ! 
-f flex-$FLEX_VERSION.tar.gz ]; then - wget https://github.com/westes/flex/releases/download/v$FLEX_VERSION/flex-$FLEX_VERSION.tar.gz -O flex-$FLEX_VERSION.tar.gz -fi + if [ ! -f fmt-$FMT_VERSION.tar.gz ]; then wget https://github.com/fmtlib/fmt/archive/refs/tags/$FMT_VERSION.tar.gz -O fmt-$FMT_VERSION.tar.gz fi @@ -765,14 +762,6 @@ echo "$BZIP2_SHA256 bzip2-$BZIP2_VERSION.tar.gz" | sha256sum -c echo "$DOUBLE_CONVERSION_SHA256 double-conversion-$DOUBLE_CONVERSION_VERSION.tar.gz" | sha256sum -c # verify fizz echo "$FIZZ_SHA256 fizz-$FBLIBS_VERSION.tar.gz" | sha256sum -c -# verify flex -if [ ! -f flex-$FLEX_VERSION.tar.gz.sig ]; then - wget https://github.com/westes/flex/releases/download/v$FLEX_VERSION/flex-$FLEX_VERSION.tar.gz.sig -fi -if false; then - $GPG --keyserver $KEYSERVER --recv-keys 0xE4B29C8D64885307 - $GPG --verify flex-$FLEX_VERSION.tar.gz.sig flex-$FLEX_VERSION.tar.gz -fi # verify fmt echo "$FMT_SHA256 fmt-$FMT_VERSION.tar.gz" | sha256sum -c # verify spdlog @@ -1025,7 +1014,6 @@ if [ ! -d $PREFIX/include/gflags ]; then if [ -d gflags ]; then rm -rf gflags fi - git clone https://github.com/memgraph/gflags.git gflags pushd gflags git checkout $GFLAGS_COMMIT_HASH @@ -1034,7 +1022,7 @@ if [ ! -d $PREFIX/include/gflags ]; then cmake .. $COMMON_CMAKE_FLAGS \ -DREGISTER_INSTALL_PREFIX=OFF \ -DBUILD_gflags_nothreads_LIB=OFF \ - -DGFLAGS_NO_FILENAMES=0 + -DGFLAGS_NO_FILENAMES=1 make -j$CPUS install popd && popd fi @@ -1232,18 +1220,6 @@ if false; then fi fi -log_tool_name "flex $FLEX_VERSION" -if [ ! -f $PREFIX/include/FlexLexer.h ]; then - if [ -d flex-$FLEX_VERSION ]; then - rm -rf flex-$FLEX_VERSION - fi - tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz - pushd flex-$FLEX_VERSION - ./configure $COMMON_CONFIGURE_FLAGS - make -j$CPUS install - popd -fi - popd # NOTE: It's important/clean (e.g., easier upload to S3) to have a separated # folder to the output archive. 
diff --git a/include/_mgp.hpp b/include/_mgp.hpp index 4f6797739..8b67bc36a 100644 --- a/include/_mgp.hpp +++ b/include/_mgp.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -283,7 +283,7 @@ inline mgp_list *list_all_unique_constraints(mgp_graph *graph, mgp_memory *memor } // mgp_graph - + inline bool graph_is_transactional(mgp_graph *graph) { return MgInvoke(mgp_graph_is_transactional, graph); } inline bool graph_is_mutable(mgp_graph *graph) { return MgInvoke(mgp_graph_is_mutable, graph); } diff --git a/libs/.gitignore b/libs/.gitignore index 1d149f2f0..6eb8fabc0 100644 --- a/libs/.gitignore +++ b/libs/.gitignore @@ -7,3 +7,4 @@ !pulsar.patch !antlr4.10.1.patch !rocksdb8.1.1.patch +!nuraft2.1.0.patch diff --git a/libs/CMakeLists.txt b/libs/CMakeLists.txt index fd6823ee5..7d568d548 100644 --- a/libs/CMakeLists.txt +++ b/libs/CMakeLists.txt @@ -16,7 +16,7 @@ set(GFLAGS_NOTHREADS OFF) # NOTE: config/generate.py depends on the gflags help XML format. 
find_package(gflags REQUIRED) -find_package(fmt 8.0.1) +find_package(fmt 8.0.1 REQUIRED) find_package(ZLIB 1.2.11 REQUIRED) set(LIB_DIR ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/libs/librdtsc.patch b/libs/librdtsc.patch index 70a98c94a..c6022adac 100644 --- a/libs/librdtsc.patch +++ b/libs/librdtsc.patch @@ -5,7 +5,7 @@ index ee9b58c..31359a9 100644 @@ -48,7 +48,7 @@ option(LIBRDTSC_USE_PMU "Enables PMU usage on ARM platforms" OFF) # | Library Build and Install Properties | # +--------------------------------------------------------+ - + -add_library(rdtsc SHARED +add_library(rdtsc src/cycles.c @@ -14,7 +14,7 @@ index ee9b58c..31359a9 100644 @@ -72,15 +72,6 @@ target_include_directories(rdtsc PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include ) - + -# Install directory changes depending on build mode -if (CMAKE_BUILD_TYPE MATCHES "^[Dd]ebug") - # During debug, the library will be installed into a local directory @@ -27,3 +27,15 @@ index ee9b58c..31359a9 100644 # Specifying what to export when installing (GNUInstallDirs required) install(TARGETS rdtsc EXPORT librstsc-config +diff --git a/include/librdtsc/common_timer.h b/include/librdtsc/common_timer.h +index a6922d8..080dc77 100644 +--- a/include/librdtsc/common_timer.h ++++ b/include/librdtsc/common_timer.h +@@ -2,6 +2,7 @@ + #define LIBRDTSC_COMMON_TIMER_H + + #include ++#include + + extern uint64_t rdtsc_get_tsc_freq_arch(); + extern uint64_t rdtsc_get_tsc_freq(); diff --git a/libs/nuraft2.1.0.patch b/libs/nuraft2.1.0.patch new file mode 100644 index 000000000..574978872 --- /dev/null +++ b/libs/nuraft2.1.0.patch @@ -0,0 +1,24 @@ +diff --git a/include/libnuraft/asio_service_options.hxx b/include/libnuraft/asio_service_options.hxx +index 8fe1ec9..9497355 100644 +--- a/include/libnuraft/asio_service_options.hxx ++++ b/include/libnuraft/asio_service_options.hxx +@@ -17,6 +17,7 @@ limitations under the License. 
+ + #pragma once + ++#include + #include + #include + #include +diff --git a/include/libnuraft/callback.hxx b/include/libnuraft/callback.hxx +index 7b71624..d48c1e2 100644 +--- a/include/libnuraft/callback.hxx ++++ b/include/libnuraft/callback.hxx +@@ -18,6 +18,7 @@ limitations under the License. + #ifndef _CALLBACK_H_ + #define _CALLBACK_H_ + ++#include + #include + #include + diff --git a/libs/rocksdb.patch b/libs/rocksdb.patch deleted file mode 100644 index 297e509fb..000000000 --- a/libs/rocksdb.patch +++ /dev/null @@ -1,21 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 6761929..6a369af 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -220,6 +220,7 @@ else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -momit-leaf-frame-pointer") - endif() - endif() -+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-copy -Wno-unused-but-set-variable") - endif() - - include(CheckCCompilerFlag) -@@ -997,7 +998,7 @@ if(NOT WIN32 OR ROCKSDB_INSTALL_ON_WINDOWS) - - if(ROCKSDB_BUILD_SHARED) - install( -- TARGETS ${ROCKSDB_SHARED_LIB} -+ TARGETS ${ROCKSDB_SHARED_LIB} OPTIONAL - EXPORT RocksDBTargets - COMPONENT runtime - ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" diff --git a/libs/setup.sh b/libs/setup.sh index 74291911e..9c2a38c47 100755 --- a/libs/setup.sh +++ b/libs/setup.sh @@ -168,12 +168,11 @@ pushd antlr4 git apply ../antlr4.10.1.patch popd -# cppitertools v2.0 2019-12-23 -cppitertools_ref="cb3635456bdb531121b82b4d2e3afc7ae1f56d47" +cppitertools_ref="v2.1" # 2021-01-15 repo_clone_try_double "${primary_urls[cppitertools]}" "${secondary_urls[cppitertools]}" "cppitertools" "$cppitertools_ref" # rapidcheck -rapidcheck_tag="7bc7d302191a4f3d0bf005692677126136e02f60" # (2020-05-04) +rapidcheck_tag="1c91f40e64d87869250cfb610376c629307bf77d" # (2023-08-15) repo_clone_try_double "${primary_urls[rapidcheck]}" "${secondary_urls[rapidcheck]}" "rapidcheck" "$rapidcheck_tag" # google benchmark @@ -221,7 +220,7 @@ repo_clone_try_double "${primary_urls[pymgclient]}" 
"${secondary_urls[pymgclient mgconsole_tag="v1.4.0" # (2023-05-21) repo_clone_try_double "${primary_urls[mgconsole]}" "${secondary_urls[mgconsole]}" "mgconsole" "$mgconsole_tag" true -spdlog_tag="v1.9.2" # (2021-08-12) +spdlog_tag="v1.12.0" # (2022-11-02) repo_clone_try_double "${primary_urls[spdlog]}" "${secondary_urls[spdlog]}" "spdlog" "$spdlog_tag" true # librdkafka @@ -286,5 +285,6 @@ repo_clone_try_double "${primary_urls[range-v3]}" "${secondary_urls[range-v3]}" nuraft_tag="v2.1.0" repo_clone_try_double "${primary_urls[nuraft]}" "${secondary_urls[nuraft]}" "nuraft" "$nuraft_tag" true pushd nuraft +git apply ../nuraft2.1.0.patch ./prepare.sh popd diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index 3bb7648db..54825c70c 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -35,16 +35,42 @@ DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth m } return true; }); -DEFINE_bool(auth_module_create_missing_user, true, "Set to false to disable creation of missing users."); -DEFINE_bool(auth_module_create_missing_role, true, "Set to false to disable creation of missing roles."); -DEFINE_bool(auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module."); DEFINE_VALIDATED_int32(auth_module_timeout_ms, 10000, "Timeout (in milliseconds) used when waiting for a " "response from the auth module.", FLAG_IN_RANGE(100, 1800000)); +// DEPRECATED FLAGS +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_user, true, + "Set to false to disable creation of missing users.", { + spdlog::warn( + "auth_module_create_missing_user flag is deprecated. 
It is no longer possible to create " + "users through the module anymore."); + return true; + }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool(auth_module_create_missing_role, true, + "Set to false to disable creation of missing roles.", { + spdlog::warn( + "auth_module_create_missing_role flag is deprecated. It is no longer possible to create " + "roles through the module anymore."); + return true; + }); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables, misc-unused-parameters) +DEFINE_VALIDATED_HIDDEN_bool( + auth_module_manage_roles, true, "Set to false to disable management of roles through the auth module.", { + spdlog::warn( + "auth_module_manage_roles flag is deprecated. It is no longer possible to create roles through the module anymore."); + return true; + }); + namespace memgraph::auth { +const Auth::Epoch Auth::kStartEpoch = 1; + namespace { #ifdef MG_ENTERPRISE /** @@ -192,6 +218,17 @@ void MigrateVersions(kvstore::KVStore &store) { version_str = kVersionV1; } } + +auto ParseJson(std::string_view str) { + nlohmann::json data; + try { + data = nlohmann::json::parse(str); + } catch (const nlohmann::json::parse_error &e) { + throw AuthException("Couldn't load auth data!"); + } + return data; +} + }; // namespace Auth::Auth(std::string storage_directory, Config config) @@ -199,8 +236,11 @@ Auth::Auth(std::string storage_directory, Config config) MigrateVersions(storage_); } -std::optional Auth::Authenticate(const std::string &username, const std::string &password) { +std::optional Auth::Authenticate(const std::string &username, const std::string &password) { if (module_.IsUsed()) { + /* + * MODULE AUTH STORAGE + */ const auto license_check_result = license::global_license_checker.IsEnterpriseValid(utils::global_settings); if (license_check_result.HasError()) { spdlog::warn(license::LicenseCheckErrorToString(license_check_result.GetError(), "authentication modules")); @@ -225,108 
+265,64 @@ std::optional Auth::Authenticate(const std::string &username, const std::s auto is_authenticated = ret_authenticated.get(); const auto &rolename = ret_role.get(); + // Check if role is present + auto role = GetRole(rolename); + if (!role) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the role '{}' doesn't exist.", + username, rolename, "https://memgr.ph/auth")); + return std::nullopt; + } + // Authenticate the user. if (!is_authenticated) return std::nullopt; - /** - * TODO - * The auth module should not update auth data. - * There is now way to replicate it and we should not be storing sensitive data if we don't have to. - */ - - // Find or create the user and return it. - auto user = GetUser(username); - if (!user) { - if (FLAGS_auth_module_create_missing_user) { - user = AddUser(username, password); - if (!user) { - spdlog::warn(utils::MessageWithLink( - "Couldn't create the missing user '{}' using the auth module because the user already exists as a role.", - username, "https://memgr.ph/auth")); - return std::nullopt; - } - } else { - spdlog::warn(utils::MessageWithLink( - "Couldn't authenticate user '{}' using the auth module because the user doesn't exist.", username, - "https://memgr.ph/auth")); - return std::nullopt; - } - } else { - UpdatePassword(*user, password); - } - if (FLAGS_auth_module_manage_roles) { - if (!rolename.empty()) { - auto role = GetRole(rolename); - if (!role) { - if (FLAGS_auth_module_create_missing_role) { - role = AddRole(rolename); - if (!role) { - spdlog::warn( - utils::MessageWithLink("Couldn't authenticate user '{}' using the auth module because the user's " - "role '{}' already exists as a user.", - username, rolename, "https://memgr.ph/auth")); - return std::nullopt; - } - SaveRole(*role); - } else { - spdlog::warn(utils::MessageWithLink( - "Couldn't authenticate user '{}' using the auth module because the user's role '{}' doesn't exist.", - username, rolename, 
"https://memgr.ph/auth")); - return std::nullopt; - } - } - user->SetRole(*role); - } else { - user->ClearRole(); - } - } - SaveUser(*user); - return user; - } else { - auto user = GetUser(username); - if (!user) { - spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username, - "https://memgr.ph/auth")); - return std::nullopt; - } - if (!user->CheckPassword(password)) { - spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.", - username, "https://memgr.ph/auth")); - return std::nullopt; - } - if (user->UpgradeHash(password)) { - SaveUser(*user); - } - - return user; + return RoleWUsername{username, std::move(*role)}; } + + /* + * LOCAL AUTH STORAGE + */ + auto user = GetUser(username); + if (!user) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the user doesn't exist.", username, + "https://memgr.ph/auth")); + return std::nullopt; + } + if (!user->CheckPassword(password)) { + spdlog::warn(utils::MessageWithLink("Couldn't authenticate user '{}' because the password is not correct.", + username, "https://memgr.ph/auth")); + return std::nullopt; + } + if (user->UpgradeHash(password)) { + SaveUser(*user); + } + + return user; } -std::optional Auth::GetUser(const std::string &username_orig) const { - auto username = utils::ToLowerCase(username_orig); - auto existing_user = storage_.Get(kUserPrefix + username); - if (!existing_user) return std::nullopt; - - nlohmann::json data; - try { - data = nlohmann::json::parse(*existing_user); - } catch (const nlohmann::json::parse_error &e) { - throw AuthException("Couldn't load user data!"); - } - - auto user = User::Deserialize(data); - auto link = storage_.Get(kLinkPrefix + username); - +void Auth::LinkUser(User &user) const { + auto link = storage_.Get(kLinkPrefix + user.username()); if (link) { auto role = GetRole(*link); if (role) { user.SetRole(*role); } } +} + +std::optional 
Auth::GetUser(const std::string &username_orig) const { + if (module_.IsUsed()) return std::nullopt; // User's are not supported when using module + auto username = utils::ToLowerCase(username_orig); + auto existing_user = storage_.Get(kUserPrefix + username); + if (!existing_user) return std::nullopt; + + auto user = User::Deserialize(ParseJson(*existing_user)); + LinkUser(user); return user; } void Auth::SaveUser(const User &user, system::Transaction *system_tx) { + DisableIfModuleUsed(); bool success = false; if (const auto *role = user.role(); role != nullptr) { success = storage_.PutMultiple( @@ -338,6 +334,10 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) { if (!success) { throw AuthException("Couldn't save user '{}'!", user.username()); } + + // Durability updated -> new epoch + UpdateEpoch(); + // All changes to the user end up calling this function, so no need to add a delta anywhere else if (system_tx) { #ifdef MG_ENTERPRISE @@ -347,6 +347,7 @@ void Auth::SaveUser(const User &user, system::Transaction *system_tx) { } void Auth::UpdatePassword(auth::User &user, const std::optional &password) { + DisableIfModuleUsed(); // Check if null if (!password) { if (!config_.password_permit_null) { @@ -378,6 +379,7 @@ void Auth::UpdatePassword(auth::User &user, const std::optional &pa std::optional Auth::AddUser(const std::string &username, const std::optional &password, system::Transaction *system_tx) { + DisableIfModuleUsed(); if (!NameRegexMatch(username)) { throw AuthException("Invalid user name."); } @@ -392,12 +394,17 @@ std::optional Auth::AddUser(const std::string &username, const std::option } bool Auth::RemoveUser(const std::string &username_orig, system::Transaction *system_tx) { + DisableIfModuleUsed(); auto username = utils::ToLowerCase(username_orig); if (!storage_.Get(kUserPrefix + username)) return false; std::vector keys({kLinkPrefix + username, kUserPrefix + username}); if (!storage_.DeleteMultiple(keys)) { throw 
AuthException("Couldn't remove user '{}'!", username); } + + // Durability updated -> new epoch + UpdateEpoch(); + // Handling drop user delta if (system_tx) { #ifdef MG_ENTERPRISE @@ -412,9 +419,12 @@ std::vector Auth::AllUsers() const { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); if (username != utils::ToLowerCase(username)) continue; - auto user = GetUser(username); - if (user) { - ret.push_back(std::move(*user)); + try { + User user = auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure + LinkUser(user); + ret.emplace_back(std::move(user)); + } catch (AuthException &) { + continue; } } return ret; @@ -425,9 +435,12 @@ std::vector Auth::AllUsernames() const { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = it->first.substr(kUserPrefix.size()); if (username != utils::ToLowerCase(username)) continue; - auto user = GetUser(username); - if (user) { - ret.push_back(username); + try { + // Check if serialized correctly + memgraph::auth::User::Deserialize(ParseJson(it->second)); // Will throw on failure + ret.emplace_back(std::move(username)); + } catch (AuthException &) { + continue; } } return ret; @@ -435,25 +448,24 @@ std::vector Auth::AllUsernames() const { bool Auth::HasUsers() const { return storage_.begin(kUserPrefix) != storage_.end(kUserPrefix); } +bool Auth::AccessControlled() const { return HasUsers() || module_.IsUsed(); } + std::optional Auth::GetRole(const std::string &rolename_orig) const { auto rolename = utils::ToLowerCase(rolename_orig); auto existing_role = storage_.Get(kRolePrefix + rolename); if (!existing_role) return std::nullopt; - nlohmann::json data; - try { - data = nlohmann::json::parse(*existing_role); - } catch (const nlohmann::json::parse_error &e) { - throw AuthException("Couldn't load role data!"); - } - - return Role::Deserialize(data); + return 
Role::Deserialize(ParseJson(*existing_role)); } void Auth::SaveRole(const Role &role, system::Transaction *system_tx) { if (!storage_.Put(kRolePrefix + role.rolename(), role.Serialize().dump())) { throw AuthException("Couldn't save role '{}'!", role.rolename()); } + + // Durability updated -> new epoch + UpdateEpoch(); + // All changes to the role end up calling this function, so no need to add a delta anywhere else if (system_tx) { #ifdef MG_ENTERPRISE @@ -486,6 +498,10 @@ bool Auth::RemoveRole(const std::string &rolename_orig, system::Transaction *sys if (!storage_.DeleteMultiple(keys)) { throw AuthException("Couldn't remove role '{}'!", rolename); } + + // Durability updated -> new epoch + UpdateEpoch(); + // Handling drop role delta if (system_tx) { #ifdef MG_ENTERPRISE @@ -500,11 +516,8 @@ std::vector Auth::AllRoles() const { for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { auto rolename = it->first.substr(kRolePrefix.size()); if (rolename != utils::ToLowerCase(rolename)) continue; - if (auto role = GetRole(rolename)) { - ret.push_back(*role); - } else { - throw AuthException("Couldn't load role '{}'!", rolename); - } + Role role = memgraph::auth::Role::Deserialize(ParseJson(it->second)); // Will throw on failure + ret.emplace_back(std::move(role)); } return ret; } @@ -514,14 +527,19 @@ std::vector Auth::AllRolenames() const { for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { auto rolename = it->first.substr(kRolePrefix.size()); if (rolename != utils::ToLowerCase(rolename)) continue; - if (auto role = GetRole(rolename)) { - ret.push_back(rolename); + try { + // Check that the data is serialized correctly + memgraph::auth::Role::Deserialize(ParseJson(it->second)); + ret.emplace_back(std::move(rolename)); + } catch (AuthException &) { + continue; } } return ret; } std::vector Auth::AllUsersForRole(const std::string &rolename_orig) const { + DisableIfModuleUsed(); const auto rolename = 
utils::ToLowerCase(rolename_orig); std::vector ret; for (auto it = storage_.begin(kLinkPrefix); it != storage_.end(kLinkPrefix); ++it) { @@ -540,51 +558,176 @@ std::vector Auth::AllUsersForRole(const std::string &rolename_orig) } #ifdef MG_ENTERPRISE -bool Auth::GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (db == kAllDatabases) { - user->db_access().GrantAll(); - } else { - user->db_access().Add(db); +Auth::Result Auth::GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + GrantDatabase(db, *role, system_tx); + return SUCCESS; } - SaveUser(*user, system_tx); - return true; + return NO_ROLE; } - return false; + if (auto user = GetUser(name)) { + GrantDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + GrantDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; } -bool Auth::RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (db == kAllDatabases) { - user->db_access().DenyAll(); - } else { - user->db_access().Remove(db); - } - SaveUser(*user, system_tx); - return true; +void Auth::GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().GrantAll(); + } else { + user.db_access().Grant(db); } - return false; + SaveUser(user, system_tx); +} + +void Auth::GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().GrantAll(); + } else { + role.db_access().Grant(db); + } + SaveRole(role, system_tx); +} + +Auth::Result Auth::DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if 
(module_.IsUsed()) { + if (auto role = GetRole(name)) { + DenyDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_ROLE; + } + if (auto user = GetUser(name)) { + DenyDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + DenyDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().DenyAll(); + } else { + user.db_access().Deny(db); + } + SaveUser(user, system_tx); +} + +void Auth::DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().DenyAll(); + } else { + role.db_access().Deny(db); + } + SaveRole(role, system_tx); +} + +Auth::Result Auth::RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + RevokeDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_ROLE; + } + if (auto user = GetUser(name)) { + RevokeDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + RevokeDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx) { + if (db == kAllDatabases) { + user.db_access().RevokeAll(); + } else { + user.db_access().Revoke(db); + } + SaveUser(user, system_tx); +} + +void Auth::RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx) { + if (db == kAllDatabases) { + role.db_access().RevokeAll(); + } else { + role.db_access().Revoke(db); + } + SaveRole(role, system_tx); } void Auth::DeleteDatabase(const std::string &db, system::Transaction *system_tx) { for (auto it = storage_.begin(kUserPrefix); it != storage_.end(kUserPrefix); ++it) { auto username = 
it->first.substr(kUserPrefix.size()); - if (auto user = GetUser(username)) { - user->db_access().Delete(db); - SaveUser(*user, system_tx); + try { + User user = auth::User::Deserialize(ParseJson(it->second)); + LinkUser(user); + user.db_access().Revoke(db); + SaveUser(user, system_tx); + } catch (AuthException &) { + continue; + } + } + for (auto it = storage_.begin(kRolePrefix); it != storage_.end(kRolePrefix); ++it) { + auto rolename = it->first.substr(kRolePrefix.size()); + try { + auto role = memgraph::auth::Role::Deserialize(ParseJson(it->second)); + role.db_access().Revoke(db); + SaveRole(role, system_tx); + } catch (AuthException &) { + continue; } } } -bool Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) { - if (auto user = GetUser(name)) { - if (!user->db_access().SetDefault(db)) { - throw AuthException("Couldn't set default database '{}' for user '{}'!", db, name); +Auth::Result Auth::SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx) { + using enum Auth::Result; + if (module_.IsUsed()) { + if (auto role = GetRole(name)) { + SetMainDatabase(db, *role, system_tx); + return SUCCESS; } - SaveUser(*user, system_tx); - return true; + return NO_ROLE; } - return false; + if (auto user = GetUser(name)) { + SetMainDatabase(db, *user, system_tx); + return SUCCESS; + } + if (auto role = GetRole(name)) { + SetMainDatabase(db, *role, system_tx); + return SUCCESS; + } + return NO_USER_ROLE; +} + +void Auth::SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx) { + if (!user.db_access().SetMain(db)) { + throw AuthException("Couldn't set default database '{}' for '{}'!", db, user.username()); + } + SaveUser(user, system_tx); +} + +void Auth::SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx) { + if (!role.db_access().SetMain(db)) { + throw AuthException("Couldn't set default database '{}' for '{}'!", db, role.rolename()); 
+ } + SaveRole(role, system_tx); } #endif diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index 4b1bcd479..f8d3d58be 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -29,6 +29,18 @@ using SynchedAuth = memgraph::utils::Synchronized + RoleWUsername(std::string_view username, Args &&...args) : Role{std::forward(args)...}, username_{username} {} + + std::string username() { return username_; } + const std::string &username() const { return username_; } + + private: + std::string username_; +}; +using UserOrRole = std::variant; + /** * This class serves as the main Authentication/Authorization storage. * It provides functions for managing Users, Roles, Permissions and FineGrainedAccessPermissions. @@ -61,6 +73,25 @@ class Auth final { std::regex password_regex{password_regex_str}; }; + struct Epoch { + Epoch() : epoch_{0} {} + Epoch(unsigned e) : epoch_{e} {} + + Epoch operator++() { return ++epoch_; } + bool operator==(const Epoch &rhs) const = default; + + private: + unsigned epoch_; + }; + + static const Epoch kStartEpoch; + + enum class Result { + SUCCESS, + NO_USER_ROLE, + NO_ROLE, + }; + explicit Auth(std::string storage_directory, Config config); /** @@ -89,7 +120,7 @@ class Auth final { * @return a user when the username and password match, nullopt otherwise * @throw AuthException if unable to authenticate for whatever reason. */ - std::optional Authenticate(const std::string &username, const std::string &password); + std::optional Authenticate(const std::string &username, const std::string &password); /** * Gets a user from the storage. @@ -101,6 +132,8 @@ class Auth final { */ std::optional GetUser(const std::string &username) const; + void LinkUser(User &user) const; + /** * Saves a user object to the storage. * @@ -163,6 +196,13 @@ class Auth final { */ bool HasUsers() const; + /** + * Returns whether the access is controlled by authentication/authorization. 
+ * + * @return `true` if auth needs to run + */ + bool AccessControlled() const; + /** * Gets a role from the storage. * @@ -173,6 +213,37 @@ class Auth final { */ std::optional GetRole(const std::string &rolename) const; + std::optional GetUserOrRole(const std::optional &username, + const std::optional &rolename) const { + auto expect = [](bool condition, std::string &&msg) { + if (!condition) throw AuthException(std::move(msg)); + }; + // Special case if we are using a module; we must find the specified role + if (module_.IsUsed()) { + expect(username && rolename, "When using a module, a role needs to be connected to a username."); + const auto role = GetRole(*rolename); + expect(role != std::nullopt, "No role named " + *rolename); + return UserOrRole(auth::RoleWUsername{*username, *role}); + } + + // First check if we need to find a role + if (username && rolename) { + const auto role = GetRole(*rolename); + expect(role != std::nullopt, "No role named " + *rolename); + return UserOrRole(auth::RoleWUsername{*username, *role}); + } + + // We are only looking for a user + if (username) { + const auto user = GetUser(*username); + expect(user != std::nullopt, "No user named " + *username); + return *user; + } + + // No user or role + return std::nullopt; + } + /** * Saves a role object to the storage. * @@ -229,16 +300,6 @@ class Auth final { std::vector AllUsersForRole(const std::string &rolename) const; #ifdef MG_ENTERPRISE - /** - * @brief Revoke access to individual database for a user. - * - * @param db name of the database to revoke - * @param name user's username - * @return true on success - * @throw AuthException if unable to find or update the user - */ - bool RevokeDatabaseFromUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); - /** * @brief Grant access to individual database for a user. 
* @@ -247,7 +308,33 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool GrantDatabaseToUser(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + Result GrantDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void GrantDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void GrantDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); + + /** + * @brief Revoke access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + Result DenyDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void DenyDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void DenyDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); + + /** + * @brief Revoke access to individual database for a user. + * + * @param db name of the database to revoke + * @param name user's username + * @return true on success + * @throw AuthException if unable to find or update the user + */ + Result RevokeDatabase(const std::string &db, const std::string &name, system::Transaction *system_tx = nullptr); + void RevokeDatabase(const std::string &db, User &user, system::Transaction *system_tx = nullptr); + void RevokeDatabase(const std::string &db, Role &role, system::Transaction *system_tx = nullptr); /** * @brief Delete a database from all users. 
@@ -265,9 +352,17 @@ class Auth final { * @return true on success * @throw AuthException if unable to find or update the user */ - bool SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr); + Result SetMainDatabase(std::string_view db, const std::string &name, system::Transaction *system_tx = nullptr); + void SetMainDatabase(std::string_view db, User &user, system::Transaction *system_tx = nullptr); + void SetMainDatabase(std::string_view db, Role &role, system::Transaction *system_tx = nullptr); #endif + bool UpToDate(Epoch &e) const { + bool res = e == epoch_; + e = epoch_; + return res; + } + private: /** * @brief @@ -278,11 +373,18 @@ class Auth final { */ bool NameRegexMatch(const std::string &user_or_role) const; + void UpdateEpoch() { ++epoch_; } + + void DisableIfModuleUsed() const { + if (module_.IsUsed()) throw AuthException("Operation not permited when using an authentication module."); + } + // Even though the `kvstore::KVStore` class is guaranteed to be thread-safe, // Auth is not thread-safe because modifying users and roles might require // more than one operation on the storage. 
kvstore::KVStore storage_; auth::Module module_; Config config_; + Epoch epoch_{kStartEpoch}; }; } // namespace memgraph::auth diff --git a/src/auth/crypto.hpp b/src/auth/crypto.hpp index c5dfc1c05..a0458a067 100644 --- a/src/auth/crypto.hpp +++ b/src/auth/crypto.hpp @@ -8,10 +8,12 @@ #pragma once -#include +#include #include #include +#include + namespace memgraph::auth { /// Need to be stable, auth durability depends on this enum class PasswordHashAlgorithm : uint8_t { BCRYPT = 0, SHA256 = 1, SHA256_MULTIPLE = 2 }; diff --git a/src/auth/models.cpp b/src/auth/models.cpp index f75e6fe32..51b13329a 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -425,10 +425,11 @@ Role::Role(const std::string &rolename, const Permissions &permissions) : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions) {} #ifdef MG_ENTERPRISE Role::Role(const std::string &rolename, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler) + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access) : rolename_(utils::ToLowerCase(rolename)), permissions_(permissions), - fine_grained_access_handler_(std::move(fine_grained_access_handler)) {} + fine_grained_access_handler_(std::move(fine_grained_access_handler)), + db_access_(std::move(db_access)) {} #endif const std::string &Role::rolename() const { return rolename_; } @@ -454,8 +455,10 @@ nlohmann::json Role::Serialize() const { #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { data[kFineGrainedAccessHandler] = fine_grained_access_handler_.Serialize(); + data[kDatabases] = db_access_.Serialize(); } else { data[kFineGrainedAccessHandler] = {}; + data[kDatabases] = {}; } #endif return data; @@ -471,12 +474,21 @@ Role Role::Deserialize(const nlohmann::json &data) { auto permissions = Permissions::Deserialize(data[kPermissions]); #ifdef MG_ENTERPRISE if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { + Databases 
db_access; + if (data[kDatabases].is_structured()) { + db_access = Databases::Deserialize(data[kDatabases]); + } else { + // Back-compatibility + spdlog::warn("Role without specified database access. Given access to the default database."); + db_access.Grant(dbms::kDefaultDB); + db_access.SetMain(dbms::kDefaultDB); + } FineGrainedAccessHandler fine_grained_access_handler; // We can have an empty fine_grained if the user was created without a valid license if (data[kFineGrainedAccessHandler].is_object()) { fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]); } - return {data[kRoleName], permissions, std::move(fine_grained_access_handler)}; + return {data[kRoleName], permissions, std::move(fine_grained_access_handler), std::move(db_access)}; } #endif return {data[kRoleName], permissions}; @@ -493,7 +505,7 @@ bool operator==(const Role &first, const Role &second) { } #ifdef MG_ENTERPRISE -void Databases::Add(std::string_view db) { +void Databases::Grant(std::string_view db) { if (allow_all_) { grants_dbs_.clear(); allow_all_ = false; @@ -502,19 +514,19 @@ void Databases::Add(std::string_view db) { denies_dbs_.erase(std::string{db}); // TODO: C++23 use transparent key compare } -void Databases::Remove(const std::string &db) { +void Databases::Deny(const std::string &db) { denies_dbs_.emplace(db); grants_dbs_.erase(db); } -void Databases::Delete(const std::string &db) { +void Databases::Revoke(const std::string &db) { denies_dbs_.erase(db); if (!allow_all_) { grants_dbs_.erase(db); } // Reset if default deleted - if (default_db_ == db) { - default_db_ = ""; + if (main_db_ == db) { + main_db_ = ""; } } @@ -530,9 +542,16 @@ void Databases::DenyAll() { denies_dbs_.clear(); } -bool Databases::SetDefault(std::string_view db) { +void Databases::RevokeAll() { + allow_all_ = false; + grants_dbs_.clear(); + denies_dbs_.clear(); + main_db_ = ""; +} + +bool Databases::SetMain(std::string_view db) { if (!Contains(db)) return false; - 
default_db_ = db; + main_db_ = db; return true; } @@ -540,11 +559,11 @@ bool Databases::SetDefault(std::string_view db) { return !denies_dbs_.contains(db) && (allow_all_ || grants_dbs_.contains(db)); } -const std::string &Databases::GetDefault() const { - if (!Contains(default_db_)) { - throw AuthException("No access to the set default database \"{}\".", default_db_); +const std::string &Databases::GetMain() const { + if (!Contains(main_db_)) { + throw AuthException("No access to the set default database \"{}\".", main_db_); } - return default_db_; + return main_db_; } nlohmann::json Databases::Serialize() const { @@ -552,7 +571,7 @@ nlohmann::json Databases::Serialize() const { data[kGrants] = grants_dbs_; data[kDenies] = denies_dbs_; data[kAllowAll] = allow_all_; - data[kDefault] = default_db_; + data[kDefault] = main_db_; return data; } @@ -719,15 +738,16 @@ User User::Deserialize(const nlohmann::json &data) { } else { // Back-compatibility spdlog::warn("User without specified database access. 
Given access to the default database."); - db_access.Add(dbms::kDefaultDB); - db_access.SetDefault(dbms::kDefaultDB); + db_access.Grant(dbms::kDefaultDB); + db_access.SetMain(dbms::kDefaultDB); } FineGrainedAccessHandler fine_grained_access_handler; // We can have an empty fine_grained if the user was created without a valid license if (data[kFineGrainedAccessHandler].is_object()) { fine_grained_access_handler = FineGrainedAccessHandler::Deserialize(data[kFineGrainedAccessHandler]); } - return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler), db_access}; + return {data[kUsername], std::move(password_hash), permissions, std::move(fine_grained_access_handler), + std::move(db_access)}; } #endif return {data[kUsername], std::move(password_hash), permissions}; diff --git a/src/auth/models.hpp b/src/auth/models.hpp index b65d172ff..9b12abee4 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -205,52 +205,10 @@ class FineGrainedAccessHandler final { bool operator==(const FineGrainedAccessHandler &first, const FineGrainedAccessHandler &second); #endif -class Role final { - public: - Role() = default; - - explicit Role(const std::string &rolename); - Role(const std::string &rolename, const Permissions &permissions); -#ifdef MG_ENTERPRISE - Role(const std::string &rolename, const Permissions &permissions, - FineGrainedAccessHandler fine_grained_access_handler); -#endif - Role(const Role &) = default; - Role &operator=(const Role &) = default; - Role(Role &&) noexcept = default; - Role &operator=(Role &&) noexcept = default; - ~Role() = default; - - const std::string &rolename() const; - const Permissions &permissions() const; - Permissions &permissions(); -#ifdef MG_ENTERPRISE - const FineGrainedAccessHandler &fine_grained_access_handler() const; - FineGrainedAccessHandler &fine_grained_access_handler(); - const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const; - const 
FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const; -#endif - nlohmann::json Serialize() const; - - /// @throw AuthException if unable to deserialize. - static Role Deserialize(const nlohmann::json &data); - - friend bool operator==(const Role &first, const Role &second); - - private: - std::string rolename_; - Permissions permissions_; -#ifdef MG_ENTERPRISE - FineGrainedAccessHandler fine_grained_access_handler_; -#endif -}; - -bool operator==(const Role &first, const Role &second); - #ifdef MG_ENTERPRISE class Databases final { public: - Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), default_db_(dbms::kDefaultDB) {} + Databases() : grants_dbs_{std::string{dbms::kDefaultDB}}, allow_all_(false), main_db_(dbms::kDefaultDB) {} Databases(const Databases &) = default; Databases &operator=(const Databases &) = default; @@ -263,7 +221,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Add(std::string_view db); + void Grant(std::string_view db); /** * @brief Remove database to the list of granted access. @@ -272,7 +230,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Remove(const std::string &db); + void Deny(const std::string &db); /** * @brief Called when database is dropped. Removes it from granted (if allow_all is false) and denied set. @@ -280,7 +238,7 @@ class Databases final { * * @param db name of the database to grant access to */ - void Delete(const std::string &db); + void Revoke(const std::string &db); /** * @brief Set allow_all_ to true and clears grants and denied sets. @@ -292,10 +250,15 @@ class Databases final { */ void DenyAll(); + /** + * @brief Set allow_all_ to false and clears grants and denied sets. + */ + void RevokeAll(); + /** * @brief Set the default database. */ - bool SetDefault(std::string_view db); + bool SetMain(std::string_view db); /** * @brief Checks if access is granted to the database. 
@@ -304,11 +267,13 @@ class Databases final { * @return true if allow_all and not denied or granted */ bool Contains(std::string_view db) const; + bool Denies(std::string_view db_name) const { return denies_dbs_.contains(db_name); } + bool Grants(std::string_view db_name) const { return allow_all_ || grants_dbs_.contains(db_name); } bool GetAllowAll() const { return allow_all_; } const std::set> &GetGrants() const { return grants_dbs_; } const std::set> &GetDenies() const { return denies_dbs_; } - const std::string &GetDefault() const; + const std::string &GetMain() const; nlohmann::json Serialize() const; /// @throw AuthException if unable to deserialize. @@ -320,15 +285,69 @@ class Databases final { : grants_dbs_(std::move(grant)), denies_dbs_(std::move(deny)), allow_all_(allow_all), - default_db_(std::move(default_db)) {} + main_db_(std::move(default_db)) {} std::set> grants_dbs_; //!< set of databases with granted access std::set> denies_dbs_; //!< set of databases with denied access bool allow_all_; //!< flag to allow access to everything (denied overrides this) - std::string default_db_; //!< user's default database + std::string main_db_; //!< user's default database }; #endif +class Role { + public: + Role() = default; + + explicit Role(const std::string &rolename); + Role(const std::string &rolename, const Permissions &permissions); +#ifdef MG_ENTERPRISE + Role(const std::string &rolename, const Permissions &permissions, + FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {}); +#endif + Role(const Role &) = default; + Role &operator=(const Role &) = default; + Role(Role &&) noexcept = default; + Role &operator=(Role &&) noexcept = default; + ~Role() = default; + + const std::string &rolename() const; + const Permissions &permissions() const; + Permissions &permissions(); + Permissions GetPermissions() const { return permissions_; } +#ifdef MG_ENTERPRISE + const FineGrainedAccessHandler &fine_grained_access_handler() const; + 
FineGrainedAccessHandler &fine_grained_access_handler(); + const FineGrainedAccessPermissions &GetFineGrainedAccessLabelPermissions() const; + const FineGrainedAccessPermissions &GetFineGrainedAccessEdgeTypePermissions() const; +#endif + +#ifdef MG_ENTERPRISE + Databases &db_access() { return db_access_; } + const Databases &db_access() const { return db_access_; } + + bool DeniesDB(std::string_view db_name) const { return db_access_.Denies(db_name); } + bool GrantsDB(std::string_view db_name) const { return db_access_.Grants(db_name); } + bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); } +#endif + + nlohmann::json Serialize() const; + + /// @throw AuthException if unable to deserialize. + static Role Deserialize(const nlohmann::json &data); + + friend bool operator==(const Role &first, const Role &second); + + private: + std::string rolename_; + Permissions permissions_; +#ifdef MG_ENTERPRISE + FineGrainedAccessHandler fine_grained_access_handler_; + Databases db_access_; +#endif +}; + +bool operator==(const Role &first, const Role &second); + // TODO (mferencevic): Implement password expiry. 
class User final { public: @@ -388,6 +407,18 @@ class User final { #ifdef MG_ENTERPRISE Databases &db_access() { return database_access_; } const Databases &db_access() const { return database_access_; } + + bool DeniesDB(std::string_view db_name) const { + bool denies = database_access_.Denies(db_name); + if (role_) denies |= role_->DeniesDB(db_name); + return denies; + } + bool GrantsDB(std::string_view db_name) const { + bool grants = database_access_.Grants(db_name); + if (role_) grants |= role_->GrantsDB(db_name); + return grants; + } + bool HasAccess(std::string_view db_name) const { return !DeniesDB(db_name) && GrantsDB(db_name); } #endif nlohmann::json Serialize() const; @@ -403,7 +434,7 @@ class User final { Permissions permissions_; #ifdef MG_ENTERPRISE FineGrainedAccessHandler fine_grained_access_handler_; - Databases database_access_; + Databases database_access_{}; #endif std::optional role_; }; diff --git a/src/auth/module.cpp b/src/auth/module.cpp index 45b93182a..04fa7fa73 100644 --- a/src/auth/module.cpp +++ b/src/auth/module.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -403,7 +403,7 @@ nlohmann::json Module::Call(const nlohmann::json ¶ms, int timeout_millisec) return ret; } -bool Module::IsUsed() { return !module_executable_path_.empty(); } +bool Module::IsUsed() const { return !module_executable_path_.empty(); } void Module::Shutdown() { if (pid_ == -1) return; diff --git a/src/auth/module.hpp b/src/auth/module.hpp index e711708f7..712466950 100644 --- a/src/auth/module.hpp +++ b/src/auth/module.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -49,7 +49,7 @@ class Module final { /// specified executable path and can thus be used. /// /// @return boolean indicating whether the module can be used - bool IsUsed(); + bool IsUsed() const; ~Module(); diff --git a/src/auth/rpc.cpp b/src/auth/rpc.cpp index b658c9491..6f264ccdf 100644 --- a/src/auth/rpc.cpp +++ b/src/auth/rpc.cpp @@ -18,11 +18,9 @@ #include "utils/enum.hpp" namespace memgraph::slk { - // Serialize code for auth::Role -void Save(const auth::Role &self, memgraph::slk::Builder *builder) { - memgraph::slk::Save(self.Serialize().dump(), builder); -} +void Save(const auth::Role &self, Builder *builder) { memgraph::slk::Save(self.Serialize().dump(), builder); } + namespace { auth::Role LoadAuthRole(memgraph::slk::Reader *reader) { std::string tmp; diff --git a/src/communication/bolt/client.cpp b/src/communication/bolt/client.cpp index 39cd24730..29f7d237a 100644 --- a/src/communication/bolt/client.cpp +++ b/src/communication/bolt/client.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -15,6 +15,9 @@ #include "communication/bolt/v1/value.hpp" #include "utils/logging.hpp" +#include "communication/bolt/v1/fmt.hpp" +#include "io/network/fmt.hpp" + namespace { constexpr uint8_t kBoltV43Version[4] = {0x00, 0x00, 0x03, 0x04}; constexpr uint8_t kEmptyBoltVersion[4] = {0x00, 0x00, 0x00, 0x00}; diff --git a/src/communication/bolt/v1/fmt.hpp b/src/communication/bolt/v1/fmt.hpp new file mode 100644 index 000000000..0a6808643 --- /dev/null +++ b/src/communication/bolt/v1/fmt.hpp @@ -0,0 +1,27 @@ +// Copyright 2024 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#if FMT_VERSION > 90000 +#include + +#include "communication/bolt/v1/value.hpp" + +template <> +class fmt::formatter : public fmt::ostream_formatter {}; + +template <> +class fmt::formatter> : public fmt::ostream_formatter {}; + +template <> +class fmt::formatter> : public fmt::ostream_formatter {}; +#endif diff --git a/src/communication/fmt.hpp b/src/communication/fmt.hpp new file mode 100644 index 000000000..ab65066b2 --- /dev/null +++ b/src/communication/fmt.hpp @@ -0,0 +1,20 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#if FMT_VERSION > 90000 +#include +#include + +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/communication/http/listener.hpp b/src/communication/http/listener.hpp index fac4cfaf3..aa3e7e2f5 100644 --- a/src/communication/http/listener.hpp +++ b/src/communication/http/listener.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. 
+// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -21,6 +21,7 @@ #include #include "communication/context.hpp" +#include "communication/fmt.hpp" #include "communication/http/session.hpp" #include "utils/spin_lock.hpp" #include "utils/synchronized.hpp" @@ -82,7 +83,7 @@ class Listener final : public std::enable_shared_from_this #include "communication/context.hpp" +#include "communication/fmt.hpp" #include "communication/init.hpp" #include "communication/v2/listener.hpp" #include "communication/v2/pool.hpp" @@ -129,7 +130,7 @@ bool Server::Start() { listener_->Start(); spdlog::info("{} server is fully armed and operational", service_name_); - spdlog::info("{} listening on {}", service_name_, endpoint_.address()); + spdlog::info("{} listening on {}", service_name_, endpoint_); context_thread_pool_.Run(); return true; diff --git a/src/communication/v2/session.hpp b/src/communication/v2/session.hpp index b54607729..0b23d9301 100644 --- a/src/communication/v2/session.hpp +++ b/src/communication/v2/session.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -47,6 +47,7 @@ #include "communication/buffer.hpp" #include "communication/context.hpp" #include "communication/exceptions.hpp" +#include "communication/fmt.hpp" #include "dbms/global.hpp" #include "utils/event_counter.hpp" #include "utils/logging.hpp" @@ -212,14 +213,11 @@ class WebsocketSession : public std::enable_shared_from_this +#include "utils/variant_helpers.hpp" namespace memgraph::communication::websocket { bool SafeAuth::Authenticate(const std::string &username, const std::string &password) const { - return auth_->Lock()->Authenticate(username, password).has_value(); + user_or_role_ = auth_->Lock()->Authenticate(username, password); + return user_or_role_.has_value(); } -bool SafeAuth::HasUserPermission(const std::string &username, const auth::Permission permission) const { - if (const auto user = auth_->ReadLock()->GetUser(username); user) { - return user->GetPermissions().Has(permission) == auth::PermissionLevel::GRANT; +bool SafeAuth::HasPermission(const auth::Permission permission) const { + auto locked_auth = auth_->ReadLock(); + // Update if cache invalidated + if (!locked_auth->UpToDate(auth_epoch_) && user_or_role_) { + bool success = true; + std::visit(utils::Overloaded{[&](auth::User &user) { + auto tmp = locked_auth->GetUser(user.username()); + if (!tmp) success = false; + user = std::move(*tmp); + }, + [&](auth::Role &role) { + auto tmp = locked_auth->GetRole(role.rolename()); + if (!tmp) success = false; + role = std::move(*tmp); + }}, + *user_or_role_); + // Missing user/role; delete from cache + if (!success) user_or_role_.reset(); } + // Check permissions + if (user_or_role_) { + return std::visit(utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetPermissions().Has(permission) == auth::PermissionLevel::GRANT; + }}, + 
*user_or_role_); + } + // NOTE: websocket authenticates only if there is a user, so no need to check if access controlled return false; } -bool SafeAuth::HasAnyUsers() const { return auth_->ReadLock()->HasUsers(); } +bool SafeAuth::AccessControlled() const { return auth_->ReadLock()->AccessControlled(); } } // namespace memgraph::communication::websocket diff --git a/src/communication/websocket/auth.hpp b/src/communication/websocket/auth.hpp index 1ab865a2a..cb838382c 100644 --- a/src/communication/websocket/auth.hpp +++ b/src/communication/websocket/auth.hpp @@ -21,9 +21,9 @@ class AuthenticationInterface { public: virtual bool Authenticate(const std::string &username, const std::string &password) const = 0; - virtual bool HasUserPermission(const std::string &username, auth::Permission permission) const = 0; + virtual bool HasPermission(auth::Permission permission) const = 0; - virtual bool HasAnyUsers() const = 0; + virtual bool AccessControlled() const = 0; }; class SafeAuth : public AuthenticationInterface { @@ -32,11 +32,13 @@ class SafeAuth : public AuthenticationInterface { bool Authenticate(const std::string &username, const std::string &password) const override; - bool HasUserPermission(const std::string &username, auth::Permission permission) const override; + bool HasPermission(auth::Permission permission) const override; - bool HasAnyUsers() const override; + bool AccessControlled() const override; private: auth::SynchedAuth *auth_; + mutable std::optional user_or_role_; + mutable auth::Auth::Epoch auth_epoch_{}; }; } // namespace memgraph::communication::websocket diff --git a/src/communication/websocket/listener.cpp b/src/communication/websocket/listener.cpp index 7c8efe203..05c6199dd 100644 --- a/src/communication/websocket/listener.cpp +++ b/src/communication/websocket/listener.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -10,6 +10,7 @@ // licenses/APL.txt. #include "communication/websocket/listener.hpp" +#include "communication/fmt.hpp" namespace memgraph::communication::websocket { namespace { @@ -61,7 +62,7 @@ Listener::Listener(boost::asio::io_context &ioc, ServerContext *context, tcp::en return; } - spdlog::info("WebSocket server is listening on {}:{}", endpoint.address(), endpoint.port()); + spdlog::info("WebSocket server is listening on {}", endpoint); } void Listener::DoAccept() { diff --git a/src/communication/websocket/session.cpp b/src/communication/websocket/session.cpp index 13c788ddd..094ed8f83 100644 --- a/src/communication/websocket/session.cpp +++ b/src/communication/websocket/session.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -80,7 +80,7 @@ bool Session::Run() { return false; } - authenticated_ = !auth_.HasAnyUsers(); + authenticated_ = !auth_.AccessControlled(); connected_.store(true, std::memory_order_relaxed); // run on the strand @@ -162,7 +162,7 @@ utils::BasicResult Session::Authorize(const nlohmann::json &creds) return {"Authentication failed!"}; } #ifdef MG_ENTERPRISE - if (!auth_.HasUserPermission(creds.at("username").get(), auth::Permission::WEBSOCKET)) { + if (!auth_.HasPermission(auth::Permission::WEBSOCKET)) { return {"Authorization failed!"}; } #endif diff --git a/src/coordination/CMakeLists.txt b/src/coordination/CMakeLists.txt index 4937c0ad3..936d7a5c2 100644 --- a/src/coordination/CMakeLists.txt +++ b/src/coordination/CMakeLists.txt @@ -10,12 +10,11 @@ target_sources(mg-coordination 
include/coordination/coordinator_exceptions.hpp include/coordination/coordinator_slk.hpp include/coordination/coordinator_instance.hpp - include/coordination/coordinator_cluster_config.hpp include/coordination/coordinator_handlers.hpp - include/coordination/constants.hpp include/coordination/instance_status.hpp include/coordination/replication_instance.hpp include/coordination/raft_state.hpp + include/coordination/rpc_errors.hpp include/nuraft/coordinator_log_store.hpp include/nuraft/coordinator_state_machine.hpp diff --git a/src/coordination/coordinator_client.cpp b/src/coordination/coordinator_client.cpp index a30e504b7..84044b04a 100644 --- a/src/coordination/coordinator_client.cpp +++ b/src/coordination/coordinator_client.cpp @@ -17,6 +17,7 @@ #include "coordination/coordinator_config.hpp" #include "coordination/coordinator_rpc.hpp" #include "replication_coordination_glue/messages.hpp" +#include "utils/result.hpp" namespace memgraph::coordination { @@ -41,16 +42,25 @@ CoordinatorClient::CoordinatorClient(CoordinatorInstance *coord_instance, Coordi auto CoordinatorClient::InstanceName() const -> std::string { return config_.instance_name; } auto CoordinatorClient::SocketAddress() const -> std::string { return rpc_client_.Endpoint().SocketAddress(); } +auto CoordinatorClient::InstanceDownTimeoutSec() const -> std::chrono::seconds { + return config_.instance_down_timeout_sec; +} + +auto CoordinatorClient::InstanceGetUUIDFrequencySec() const -> std::chrono::seconds { + return config_.instance_get_uuid_frequency_sec; +} + void CoordinatorClient::StartFrequentCheck() { if (instance_checker_.IsRunning()) { return; } - MG_ASSERT(config_.health_check_frequency_sec > std::chrono::seconds(0), + MG_ASSERT(config_.instance_health_check_frequency_sec > std::chrono::seconds(0), "Health check frequency must be greater than 0"); instance_checker_.Run( - config_.instance_name, config_.health_check_frequency_sec, [this, instance_name = config_.instance_name] { + 
config_.instance_name, config_.instance_health_check_frequency_sec, + [this, instance_name = config_.instance_name] { try { spdlog::trace("Sending frequent heartbeat to machine {} on {}", instance_name, rpc_client_.Endpoint().SocketAddress()); @@ -121,5 +131,45 @@ auto CoordinatorClient::SendSwapMainUUIDRpc(const utils::UUID &uuid) const -> bo return false; } +auto CoordinatorClient::SendUnregisterReplicaRpc(std::string const &instance_name) const -> bool { + try { + auto stream{rpc_client_.Stream(instance_name)}; + if (!stream.AwaitResponse().success) { + spdlog::error("Failed to receive successful RPC response for unregistering replica!"); + return false; + } + return true; + } catch (rpc::RpcFailedException const &) { + spdlog::error("Failed to unregister replica!"); + } + return false; +} + +auto CoordinatorClient::SendGetInstanceUUIDRpc() const + -> utils::BasicResult> { + try { + auto stream{rpc_client_.Stream()}; + auto res = stream.AwaitResponse(); + return res.uuid; + } catch (const rpc::RpcFailedException &) { + spdlog::error("RPC error occurred while sending GetInstance UUID RPC"); + return GetInstanceUUIDError::RPC_EXCEPTION; + } +} + +auto CoordinatorClient::SendEnableWritingOnMainRpc() const -> bool { + try { + auto stream{rpc_client_.Stream()}; + if (!stream.AwaitResponse().success) { + spdlog::error("Failed to receive successful RPC response for enabling writing on main!"); + return false; + } + return true; + } catch (rpc::RpcFailedException const &) { + spdlog::error("Failed to enable writing on main!"); + } + return false; +} + } // namespace memgraph::coordination #endif diff --git a/src/coordination/coordinator_handlers.cpp b/src/coordination/coordinator_handlers.cpp index fb0750935..f605069fe 100644 --- a/src/coordination/coordinator_handlers.cpp +++ b/src/coordination/coordinator_handlers.cpp @@ -39,6 +39,24 @@ void CoordinatorHandlers::Register(memgraph::coordination::CoordinatorServer &se spdlog::info("Received SwapMainUUIDRPC on coordinator
server"); CoordinatorHandlers::SwapMainUUIDHandler(replication_handler, req_reader, res_builder); }); + + server.Register( + [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + spdlog::info("Received UnregisterReplicaRpc on coordinator server"); + CoordinatorHandlers::UnregisterReplicaHandler(replication_handler, req_reader, res_builder); + }); + + server.Register( + [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + spdlog::info("Received EnableWritingOnMainRpc on coordinator server"); + CoordinatorHandlers::EnableWritingOnMainHandler(replication_handler, req_reader, res_builder); + }); + + server.Register( + [&replication_handler](slk::Reader *req_reader, slk::Builder *res_builder) -> void { + spdlog::info("Received GetInstanceUUIDRpc on coordinator server"); + CoordinatorHandlers::GetInstanceUUIDHandler(replication_handler, req_reader, res_builder); + }); } void CoordinatorHandlers::SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, @@ -62,12 +80,6 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHan slk::Reader *req_reader, slk::Builder *res_builder) { spdlog::info("Executing DemoteMainToReplicaHandler"); - if (!replication_handler.IsMain()) { - spdlog::error("Setting to replica must be performed on main."); - slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder); - return; - } - coordination::DemoteMainToReplicaReq req; slk::Load(&req, req_reader); @@ -77,11 +89,18 @@ void CoordinatorHandlers::DemoteMainToReplicaHandler(replication::ReplicationHan if (!replication_handler.SetReplicationRoleReplica(clients_config, std::nullopt)) { spdlog::error("Demoting main to replica failed!"); - slk::Save(coordination::PromoteReplicaToMainRes{false}, res_builder); + slk::Save(coordination::DemoteMainToReplicaRes{false}, res_builder); return; } - slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); + 
slk::Save(coordination::DemoteMainToReplicaRes{true}, res_builder); +} + +void CoordinatorHandlers::GetInstanceUUIDHandler(replication::ReplicationHandler &replication_handler, + slk::Reader * /*req_reader*/, slk::Builder *res_builder) { + spdlog::info("Executing GetInstanceUUIDHandler"); + + slk::Save(coordination::GetInstanceUUIDRes{replication_handler.GetReplicaUUID()}, res_builder); } void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHandler &replication_handler, @@ -113,7 +132,7 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHa // registering replicas for (auto const &config : req.replication_clients_info | ranges::views::transform(converter)) { - auto instance_client = replication_handler.RegisterReplica(config, false); + auto instance_client = replication_handler.RegisterReplica(config); if (instance_client.HasError()) { using enum memgraph::replication::RegisterReplicaError; switch (instance_client.GetError()) { @@ -142,9 +161,58 @@ void CoordinatorHandlers::PromoteReplicaToMainHandler(replication::ReplicationHa } } } - spdlog::error(fmt::format("FICO : Promote replica to main was success {}", std::string(req.main_uuid_))); + spdlog::info("Promote replica to main was success {}", std::string(req.main_uuid_)); slk::Save(coordination::PromoteReplicaToMainRes{true}, res_builder); } +void CoordinatorHandlers::UnregisterReplicaHandler(replication::ReplicationHandler &replication_handler, + slk::Reader *req_reader, slk::Builder *res_builder) { + if (!replication_handler.IsMain()) { + spdlog::error("Unregistering replica must be performed on main."); + slk::Save(coordination::UnregisterReplicaRes{false}, res_builder); + return; + } + + coordination::UnregisterReplicaReq req; + slk::Load(&req, req_reader); + + auto res = replication_handler.UnregisterReplica(req.instance_name); + switch (res) { + using enum memgraph::query::UnregisterReplicaResult; + case SUCCESS: + 
slk::Save(coordination::UnregisterReplicaRes{true}, res_builder); + break; + case NOT_MAIN: + spdlog::error("Unregistering replica must be performed on main."); + slk::Save(coordination::UnregisterReplicaRes{false}, res_builder); + break; + case CAN_NOT_UNREGISTER: + spdlog::error("Could not unregister replica."); + slk::Save(coordination::UnregisterReplicaRes{false}, res_builder); + break; + case COULD_NOT_BE_PERSISTED: + spdlog::error("Could not persist replica unregistration."); + slk::Save(coordination::UnregisterReplicaRes{false}, res_builder); + break; + } +} + +void CoordinatorHandlers::EnableWritingOnMainHandler(replication::ReplicationHandler &replication_handler, + slk::Reader * /*req_reader*/, slk::Builder *res_builder) { + if (!replication_handler.IsMain()) { + spdlog::error("Enable writing on main must be performed on main!"); + slk::Save(coordination::EnableWritingOnMainRes{false}, res_builder); + return; + } + + if (!replication_handler.GetReplState().EnableWritingOnMain()) { + spdlog::error("Enabling writing on main failed!"); + slk::Save(coordination::EnableWritingOnMainRes{false}, res_builder); + return; + } + + slk::Save(coordination::EnableWritingOnMainRes{true}, res_builder); +} + } // namespace memgraph::dbms #endif diff --git a/src/coordination/coordinator_instance.cpp b/src/coordination/coordinator_instance.cpp index 4c3f3e646..90674cf3c 100644 --- a/src/coordination/coordinator_instance.cpp +++ b/src/coordination/coordinator_instance.cpp @@ -14,9 +14,11 @@ #include "coordination/coordinator_instance.hpp" #include "coordination/coordinator_exceptions.hpp" +#include "coordination/fmt.hpp" #include "nuraft/coordinator_state_machine.hpp" #include "nuraft/coordinator_state_manager.hpp" #include "utils/counter.hpp" +#include "utils/functional.hpp" #include #include @@ -47,9 +49,12 @@ CoordinatorInstance::CoordinatorInstance() spdlog::trace("Instance {} performing replica successful callback", repl_instance_name); auto &repl_instance = 
find_repl_instance(self, repl_instance_name); + // We need to get the replica's UUID from time to time to ensure the replica is listening to the correct main + // and that it didn't go down for less time than we could notice. + // We need to get the UUID of the main the replica is listening to + // and swap it if necessary if (!repl_instance.EnsureReplicaHasCorrectMainUUID(self->GetMainUUID())) { - spdlog::error( - fmt::format("Failed to swap uuid for replica instance {} which is alive", repl_instance.InstanceName())); + spdlog::error("Failed to swap uuid for replica instance {} which is alive", repl_instance.InstanceName()); return; } @@ -61,14 +66,6 @@ CoordinatorInstance::CoordinatorInstance() spdlog::trace("Instance {} performing replica failure callback", repl_instance_name); auto &repl_instance = find_repl_instance(self, repl_instance_name); repl_instance.OnFailPing(); - // We need to restart main uuid from instance since it was "down" at least a second - // There is slight delay, if we choose to use isAlive, instance can be down and back up in less than - // our isAlive time difference, which would lead to instance setting UUID to nullopt and stopping accepting any - // incoming RPCs from valid main - // TODO(antoniofilipovic) this needs here more complex logic - // We need to get id of main replica is listening to on successful ping - // and swap it to correct uuid if it failed - repl_instance.ResetMainUUID(); }; main_succ_cb_ = [find_repl_instance](CoordinatorInstance *self, std::string_view repl_instance_name) -> void { @@ -87,6 +84,11 @@ CoordinatorInstance::CoordinatorInstance() auto const curr_main_uuid = self->GetMainUUID(); if (curr_main_uuid == repl_instance_uuid.value()) { + if (!repl_instance.EnableWritingOnMain()) { + spdlog::error("Failed to enable writing on main instance {}", repl_instance_name); + return; + } + repl_instance.OnSuccessPing(); return; } @@ -122,17 +124,9 @@ CoordinatorInstance::CoordinatorInstance() }; } -auto CoordinatorInstance::ClusterHasAliveMain_()
const -> bool { - auto const alive_main = [](ReplicationInstance const &instance) { return instance.IsMain() && instance.IsAlive(); }; - return std::ranges::any_of(repl_instances_, alive_main); -} - auto CoordinatorInstance::ShowInstances() const -> std::vector { auto const coord_instances = raft_state_.GetAllCoordinators(); - std::vector instances_status; - instances_status.reserve(repl_instances_.size() + coord_instances.size()); - auto const stringify_repl_role = [](ReplicationInstance const &instance) -> std::string { if (!instance.IsAlive()) return "unknown"; if (instance.IsMain()) return "main"; @@ -154,8 +148,7 @@ auto CoordinatorInstance::ShowInstances() const -> std::vector { // CoordinatorState to every instance, we can be smarter about this using our RPC. }; - std::ranges::transform(coord_instances, std::back_inserter(instances_status), coord_instance_to_status); - + auto instances_status = utils::fmap(coord_instance_to_status, coord_instances); { auto lock = std::shared_lock{coord_instance_lock_}; std::ranges::transform(repl_instances_, std::back_inserter(instances_status), repl_instance_to_status); @@ -194,10 +187,9 @@ auto CoordinatorInstance::TryFailover() -> void { } } - ReplicationClientsInfo repl_clients_info; - repl_clients_info.reserve(repl_instances_.size() - 1); - std::ranges::transform(repl_instances_ | ranges::views::filter(is_not_new_main), - std::back_inserter(repl_clients_info), &ReplicationInstance::ReplicationClientInfo); + auto repl_clients_info = repl_instances_ | ranges::views::filter(is_not_new_main) | + ranges::views::transform(&ReplicationInstance::ReplicationClientInfo) | + ranges::to(); if (!new_main->PromoteToMain(new_main_uuid, std::move(repl_clients_info), main_succ_cb_, main_fail_cb_)) { spdlog::warn("Failover failed since promoting replica to main failed!"); @@ -213,6 +205,10 @@ auto CoordinatorInstance::SetReplicationInstanceToMain(std::string instance_name -> SetInstanceToMainCoordinatorStatus { auto lock = 
std::lock_guard{coord_instance_lock_}; + if (std::ranges::any_of(repl_instances_, &ReplicationInstance::IsMain)) { + return SetInstanceToMainCoordinatorStatus::MAIN_ALREADY_EXISTS; + } + auto const is_new_main = [&instance_name](ReplicationInstance const &instance) { return instance.InstanceName() == instance_name; }; @@ -308,6 +304,35 @@ auto CoordinatorInstance::RegisterReplicationInstance(CoordinatorClientConfig co return RegisterInstanceCoordinatorStatus::SUCCESS; } +auto CoordinatorInstance::UnregisterReplicationInstance(std::string instance_name) + -> UnregisterInstanceCoordinatorStatus { + auto lock = std::lock_guard{coord_instance_lock_}; + + auto const name_matches = [&instance_name](ReplicationInstance const &instance) { + return instance.InstanceName() == instance_name; + }; + + auto inst_to_remove = std::ranges::find_if(repl_instances_, name_matches); + if (inst_to_remove == repl_instances_.end()) { + return UnregisterInstanceCoordinatorStatus::NO_INSTANCE_WITH_NAME; + } + + if (inst_to_remove->IsMain() && inst_to_remove->IsAlive()) { + return UnregisterInstanceCoordinatorStatus::IS_MAIN; + } + + inst_to_remove->StopFrequentCheck(); + auto curr_main = std::ranges::find_if(repl_instances_, &ReplicationInstance::IsMain); + MG_ASSERT(curr_main != repl_instances_.end(), "There must be a main instance when unregistering a replica"); + if (!curr_main->SendUnregisterReplicaRpc(instance_name)) { + inst_to_remove->StartFrequentCheck(); + return UnregisterInstanceCoordinatorStatus::RPC_FAILED; + } + std::erase_if(repl_instances_, name_matches); + + return UnregisterInstanceCoordinatorStatus::SUCCESS; +} + auto CoordinatorInstance::AddCoordinatorInstance(uint32_t raft_server_id, uint32_t raft_port, std::string raft_address) -> void { raft_state_.AddCoordinatorInstance(raft_server_id, raft_port, std::move(raft_address)); diff --git a/src/coordination/coordinator_rpc.cpp b/src/coordination/coordinator_rpc.cpp index 2b5752a07..4115f1979 100644 --- 
a/src/coordination/coordinator_rpc.cpp +++ b/src/coordination/coordinator_rpc.cpp @@ -52,6 +52,51 @@ void DemoteMainToReplicaRes::Load(DemoteMainToReplicaRes *self, memgraph::slk::R memgraph::slk::Load(self, reader); } +void UnregisterReplicaReq::Save(UnregisterReplicaReq const &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void UnregisterReplicaReq::Load(UnregisterReplicaReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void UnregisterReplicaRes::Save(UnregisterReplicaRes const &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void UnregisterReplicaRes::Load(UnregisterReplicaRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void EnableWritingOnMainRes::Save(EnableWritingOnMainRes const &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void EnableWritingOnMainRes::Load(EnableWritingOnMainRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void EnableWritingOnMainReq::Save(EnableWritingOnMainReq const &self, memgraph::slk::Builder *builder) {} + +void EnableWritingOnMainReq::Load(EnableWritingOnMainReq *self, memgraph::slk::Reader *reader) {} + +// GetInstanceUUID +void GetInstanceUUIDReq::Save(const GetInstanceUUIDReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void GetInstanceUUIDReq::Load(GetInstanceUUIDReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + +void GetInstanceUUIDRes::Save(const GetInstanceUUIDRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} + +void GetInstanceUUIDRes::Load(GetInstanceUUIDRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} + } // namespace coordination constexpr utils::TypeInfo coordination::PromoteReplicaToMainReq::kType{utils::TypeId::COORD_FAILOVER_REQ, @@ -64,10 +109,31 
@@ constexpr utils::TypeInfo coordination::DemoteMainToReplicaReq::kType{utils::Typ "CoordDemoteToReplicaReq", nullptr}; constexpr utils::TypeInfo coordination::DemoteMainToReplicaRes::kType{utils::TypeId::COORD_SET_REPL_MAIN_RES, + "CoordDemoteToReplicaRes", nullptr}; +constexpr utils::TypeInfo coordination::UnregisterReplicaReq::kType{utils::TypeId::COORD_UNREGISTER_REPLICA_REQ, + "UnregisterReplicaReq", nullptr}; + +constexpr utils::TypeInfo coordination::UnregisterReplicaRes::kType{utils::TypeId::COORD_UNREGISTER_REPLICA_RES, + "UnregisterReplicaRes", nullptr}; + +constexpr utils::TypeInfo coordination::EnableWritingOnMainReq::kType{utils::TypeId::COORD_ENABLE_WRITING_ON_MAIN_REQ, + "CoordEnableWritingOnMainReq", nullptr}; + +constexpr utils::TypeInfo coordination::EnableWritingOnMainRes::kType{utils::TypeId::COORD_ENABLE_WRITING_ON_MAIN_RES, + "CoordEnableWritingOnMainRes", nullptr}; + +constexpr utils::TypeInfo coordination::GetInstanceUUIDReq::kType{utils::TypeId::COORD_GET_UUID_REQ, "CoordGetUUIDReq", + nullptr}; + +constexpr utils::TypeInfo coordination::GetInstanceUUIDRes::kType{utils::TypeId::COORD_GET_UUID_RES, "CoordGetUUIDRes", + nullptr}; + namespace slk { +// PromoteReplicaToMainRpc + void Save(const memgraph::coordination::PromoteReplicaToMainRes &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self.success, builder); } @@ -86,6 +152,7 @@ void Load(memgraph::coordination::PromoteReplicaToMainReq *self, memgraph::slk:: memgraph::slk::Load(&self->replication_clients_info, reader); } +// DemoteMainToReplicaRpc void Save(const memgraph::coordination::DemoteMainToReplicaReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self.replication_client_info, builder); } @@ -102,6 +169,50 @@ void Load(memgraph::coordination::DemoteMainToReplicaRes *self, memgraph::slk::R memgraph::slk::Load(&self->success, reader); } +// UnregisterReplicaRpc + +void Save(memgraph::coordination::UnregisterReplicaReq const &self, memgraph::slk::Builder 
*builder) { + memgraph::slk::Save(self.instance_name, builder); +} + +void Load(memgraph::coordination::UnregisterReplicaReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->instance_name, reader); +} + +void Save(memgraph::coordination::UnregisterReplicaRes const &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} + +void Load(memgraph::coordination::UnregisterReplicaRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +void Save(memgraph::coordination::EnableWritingOnMainRes const &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} + +void Load(memgraph::coordination::EnableWritingOnMainRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +// GetInstanceUUIDRpc + +void Save(const memgraph::coordination::GetInstanceUUIDReq & /*self*/, memgraph::slk::Builder * /*builder*/) { + /* nothing to serialize*/ +} + +void Load(memgraph::coordination::GetInstanceUUIDReq * /*self*/, memgraph::slk::Reader * /*reader*/) { + /* nothing to serialize*/ +} + +void Save(const memgraph::coordination::GetInstanceUUIDRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.uuid, builder); +} + +void Load(memgraph::coordination::GetInstanceUUIDRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->uuid, reader); +} + } // namespace slk } // namespace memgraph diff --git a/src/coordination/coordinator_state.cpp b/src/coordination/coordinator_state.cpp index a2f6c9cee..28d6c604e 100644 --- a/src/coordination/coordinator_state.cpp +++ b/src/coordination/coordinator_state.cpp @@ -56,6 +56,20 @@ auto CoordinatorState::RegisterReplicationInstance(CoordinatorClientConfig confi data_); } +auto CoordinatorState::UnregisterReplicationInstance(std::string instance_name) -> UnregisterInstanceCoordinatorStatus { + MG_ASSERT(std::holds_alternative(data_), + "Coordinator cannot 
unregister instance since variant holds wrong alternative"); + + return std::visit( + memgraph::utils::Overloaded{[](const CoordinatorMainReplicaData & /*coordinator_main_replica_data*/) { + return UnregisterInstanceCoordinatorStatus::NOT_COORDINATOR; + }, + [&instance_name](CoordinatorInstance &coordinator_instance) { + return coordinator_instance.UnregisterReplicationInstance(instance_name); + }}, + data_); +} + auto CoordinatorState::SetReplicationInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus { MG_ASSERT(std::holds_alternative(data_), "Coordinator cannot register replica since variant holds wrong alternative"); diff --git a/src/coordination/fmt.hpp b/src/coordination/fmt.hpp new file mode 100644 index 000000000..192f4e725 --- /dev/null +++ b/src/coordination/fmt.hpp @@ -0,0 +1,60 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#if FMT_VERSION > 90000 +#include +#include + +#include +#include "utils/logging.hpp" + +inline std::string ToString(const nuraft::cmd_result_code &code) { + switch (code) { + case nuraft::cmd_result_code::OK: + return "OK"; + case nuraft::cmd_result_code::FAILED: + return "FAILED"; + case nuraft::cmd_result_code::RESULT_NOT_EXIST_YET: + return "RESULT_NOT_EXIST_YET"; + case nuraft::cmd_result_code::TERM_MISMATCH: + return "TERM_MISMATCH"; + case nuraft::cmd_result_code::SERVER_IS_LEAVING: + return "SERVER_IS_LEAVING"; + case nuraft::cmd_result_code::CANNOT_REMOVE_LEADER: + return "CANNOT_REMOVE_LEADER"; + case nuraft::cmd_result_code::SERVER_NOT_FOUND: + return "SERVER_NOT_FOUND"; + case nuraft::cmd_result_code::SERVER_IS_JOINING: + return "SERVER_IS_JOINING"; + case nuraft::cmd_result_code::CONFIG_CHANGING: + return "CONFIG_CHANGING"; + case nuraft::cmd_result_code::SERVER_ALREADY_EXISTS: + return "SERVER_ALREADY_EXISTS"; + case nuraft::cmd_result_code::BAD_REQUEST: + return "BAD_REQUEST"; + case nuraft::cmd_result_code::NOT_LEADER: + return "NOT_LEADER"; + case nuraft::cmd_result_code::TIMEOUT: + return "TIMEOUT"; + case nuraft::cmd_result_code::CANCELLED: + return "CANCELLED"; + } + LOG_FATAL("ToString of a nuraft::cmd_result_code -> check missing switch case"); +} +inline std::ostream &operator<<(std::ostream &os, const nuraft::cmd_result_code &code) { + os << ToString(code); + return os; +} +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/coordination/include/coordination/coordinator_client.hpp b/src/coordination/include/coordination/coordinator_client.hpp index 02bae1c03..5e10af89d 100644 --- a/src/coordination/include/coordination/coordinator_client.hpp +++ b/src/coordination/include/coordination/coordinator_client.hpp @@ -11,12 +11,14 @@ #pragma once -#include "utils/uuid.hpp" #ifdef MG_ENTERPRISE #include "coordination/coordinator_config.hpp" #include "rpc/client.hpp" +#include 
"rpc_errors.hpp" +#include "utils/result.hpp" #include "utils/scheduler.hpp" +#include "utils/uuid.hpp" namespace memgraph::coordination { @@ -46,17 +48,28 @@ class CoordinatorClient { auto SocketAddress() const -> std::string; [[nodiscard]] auto DemoteToReplica() const -> bool; + auto SendPromoteReplicaToMainRpc(const utils::UUID &uuid, ReplicationClientsInfo replication_clients_info) const -> bool; auto SendSwapMainUUIDRpc(const utils::UUID &uuid) const -> bool; + auto SendUnregisterReplicaRpc(std::string const &instance_name) const -> bool; + + auto SendEnableWritingOnMainRpc() const -> bool; + + auto SendGetInstanceUUIDRpc() const -> memgraph::utils::BasicResult>; + auto ReplicationClientInfo() const -> ReplClientInfo; auto SetCallbacks(HealthCheckCallback succ_cb, HealthCheckCallback fail_cb) -> void; auto RpcClient() -> rpc::Client & { return rpc_client_; } + auto InstanceDownTimeoutSec() const -> std::chrono::seconds; + + auto InstanceGetUUIDFrequencySec() const -> std::chrono::seconds; + friend bool operator==(CoordinatorClient const &first, CoordinatorClient const &second) { return first.config_ == second.config_; } @@ -64,7 +77,6 @@ class CoordinatorClient { private: utils::Scheduler instance_checker_; - // TODO: (andi) Pimpl? 
communication::ClientContext rpc_context_; mutable rpc::Client rpc_client_; diff --git a/src/coordination/include/coordination/coordinator_config.hpp b/src/coordination/include/coordination/coordinator_config.hpp index f72b3a6ad..df7a5f94f 100644 --- a/src/coordination/include/coordination/coordinator_config.hpp +++ b/src/coordination/include/coordination/coordinator_config.hpp @@ -28,7 +28,9 @@ struct CoordinatorClientConfig { std::string instance_name; std::string ip_address; uint16_t port{}; - std::chrono::seconds health_check_frequency_sec{1}; + std::chrono::seconds instance_health_check_frequency_sec{1}; + std::chrono::seconds instance_down_timeout_sec{5}; + std::chrono::seconds instance_get_uuid_frequency_sec{10}; auto SocketAddress() const -> std::string { return ip_address + ":" + std::to_string(port); } diff --git a/src/coordination/include/coordination/coordinator_handlers.hpp b/src/coordination/include/coordination/coordinator_handlers.hpp index 4aa4656c3..b9ed4b519 100644 --- a/src/coordination/include/coordination/coordinator_handlers.hpp +++ b/src/coordination/include/coordination/coordinator_handlers.hpp @@ -33,6 +33,14 @@ class CoordinatorHandlers { slk::Builder *res_builder); static void SwapMainUUIDHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, slk::Builder *res_builder); + + static void UnregisterReplicaHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, + slk::Builder *res_builder); + static void EnableWritingOnMainHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, + slk::Builder *res_builder); + + static void GetInstanceUUIDHandler(replication::ReplicationHandler &replication_handler, slk::Reader *req_reader, + slk::Builder *res_builder); }; } // namespace memgraph::dbms diff --git a/src/coordination/include/coordination/coordinator_instance.hpp b/src/coordination/include/coordination/coordinator_instance.hpp index bc6954b37..15b377ed9 
100644 --- a/src/coordination/include/coordination/coordinator_instance.hpp +++ b/src/coordination/include/coordination/coordinator_instance.hpp @@ -30,6 +30,7 @@ class CoordinatorInstance { CoordinatorInstance(); [[nodiscard]] auto RegisterReplicationInstance(CoordinatorClientConfig config) -> RegisterInstanceCoordinatorStatus; + [[nodiscard]] auto UnregisterReplicationInstance(std::string instance_name) -> UnregisterInstanceCoordinatorStatus; [[nodiscard]] auto SetReplicationInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus; @@ -44,8 +45,6 @@ class CoordinatorInstance { auto SetMainUUID(utils::UUID new_uuid) -> void; private: - auto ClusterHasAliveMain_() const -> bool; - HealthCheckCallback main_succ_cb_, main_fail_cb_, replica_succ_cb_, replica_fail_cb_; // NOTE: Must be std::list because we rely on pointer stability diff --git a/src/coordination/include/coordination/coordinator_rpc.hpp b/src/coordination/include/coordination/coordinator_rpc.hpp index 56cfdb403..1578b4577 100644 --- a/src/coordination/include/coordination/coordinator_rpc.hpp +++ b/src/coordination/include/coordination/coordinator_rpc.hpp @@ -82,6 +82,85 @@ struct DemoteMainToReplicaRes { using DemoteMainToReplicaRpc = rpc::RequestResponse; +struct UnregisterReplicaReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(UnregisterReplicaReq *self, memgraph::slk::Reader *reader); + static void Save(UnregisterReplicaReq const &self, memgraph::slk::Builder *builder); + + explicit UnregisterReplicaReq(std::string instance_name) : instance_name(std::move(instance_name)) {} + + UnregisterReplicaReq() = default; + + std::string instance_name; +}; + +struct UnregisterReplicaRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(UnregisterReplicaRes *self, memgraph::slk::Reader *reader); + static void Save(const 
UnregisterReplicaRes &self, memgraph::slk::Builder *builder); + + explicit UnregisterReplicaRes(bool success) : success(success) {} + UnregisterReplicaRes() = default; + + bool success; +}; + +using UnregisterReplicaRpc = rpc::RequestResponse; + +struct EnableWritingOnMainReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(EnableWritingOnMainReq *self, memgraph::slk::Reader *reader); + static void Save(EnableWritingOnMainReq const &self, memgraph::slk::Builder *builder); + + EnableWritingOnMainReq() = default; +}; + +struct EnableWritingOnMainRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(EnableWritingOnMainRes *self, memgraph::slk::Reader *reader); + static void Save(EnableWritingOnMainRes const &self, memgraph::slk::Builder *builder); + + explicit EnableWritingOnMainRes(bool success) : success(success) {} + EnableWritingOnMainRes() = default; + + bool success; +}; + +using EnableWritingOnMainRpc = rpc::RequestResponse; + +struct GetInstanceUUIDReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(GetInstanceUUIDReq *self, memgraph::slk::Reader *reader); + static void Save(const GetInstanceUUIDReq &self, memgraph::slk::Builder *builder); + + GetInstanceUUIDReq() = default; +}; + +struct GetInstanceUUIDRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(GetInstanceUUIDRes *self, memgraph::slk::Reader *reader); + static void Save(const GetInstanceUUIDRes &self, memgraph::slk::Builder *builder); + + explicit GetInstanceUUIDRes(std::optional uuid) : uuid(uuid) {} + GetInstanceUUIDRes() = default; + + std::optional uuid; +}; + +using GetInstanceUUIDRpc = rpc::RequestResponse; + } // namespace memgraph::coordination // SLK serialization declarations @@ -99,6 
+178,19 @@ void Load(memgraph::coordination::DemoteMainToReplicaRes *self, memgraph::slk::R void Save(const memgraph::coordination::DemoteMainToReplicaReq &self, memgraph::slk::Builder *builder); void Load(memgraph::coordination::DemoteMainToReplicaReq *self, memgraph::slk::Reader *reader); +// GetInstanceUUIDRpc +void Save(const memgraph::coordination::GetInstanceUUIDReq &self, memgraph::slk::Builder *builder); +void Load(memgraph::coordination::GetInstanceUUIDReq *self, memgraph::slk::Reader *reader); +void Save(const memgraph::coordination::GetInstanceUUIDRes &self, memgraph::slk::Builder *builder); +void Load(memgraph::coordination::GetInstanceUUIDRes *self, memgraph::slk::Reader *reader); +// UnregisterReplicaRpc +void Save(memgraph::coordination::UnregisterReplicaRes const &self, memgraph::slk::Builder *builder); +void Load(memgraph::coordination::UnregisterReplicaRes *self, memgraph::slk::Reader *reader); +void Save(memgraph::coordination::UnregisterReplicaReq const &self, memgraph::slk::Builder *builder); +void Load(memgraph::coordination::UnregisterReplicaReq *self, memgraph::slk::Reader *reader); + +void Save(memgraph::coordination::EnableWritingOnMainRes const &self, memgraph::slk::Builder *builder); +void Load(memgraph::coordination::EnableWritingOnMainRes *self, memgraph::slk::Reader *reader); } // namespace memgraph::slk diff --git a/src/coordination/include/coordination/coordinator_state.hpp b/src/coordination/include/coordination/coordinator_state.hpp index 8830d1b49..256af66f9 100644 --- a/src/coordination/include/coordination/coordinator_state.hpp +++ b/src/coordination/include/coordination/coordinator_state.hpp @@ -34,6 +34,7 @@ class CoordinatorState { CoordinatorState &operator=(CoordinatorState &&) noexcept = delete; [[nodiscard]] auto RegisterReplicationInstance(CoordinatorClientConfig config) -> RegisterInstanceCoordinatorStatus; + [[nodiscard]] auto UnregisterReplicationInstance(std::string instance_name) -> 
UnregisterInstanceCoordinatorStatus; [[nodiscard]] auto SetReplicationInstanceToMain(std::string instance_name) -> SetInstanceToMainCoordinatorStatus; diff --git a/src/coordination/include/coordination/raft_state.hpp b/src/coordination/include/coordination/raft_state.hpp index 6b53197a0..b6ef06008 100644 --- a/src/coordination/include/coordination/raft_state.hpp +++ b/src/coordination/include/coordination/raft_state.hpp @@ -15,8 +15,6 @@ #include -#include - #include namespace memgraph::coordination { diff --git a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp index 3a0df5607..3aa7e3ca1 100644 --- a/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp +++ b/src/coordination/include/coordination/register_main_replica_coordinator_status.hpp @@ -28,8 +28,18 @@ enum class RegisterInstanceCoordinatorStatus : uint8_t { SUCCESS }; +enum class UnregisterInstanceCoordinatorStatus : uint8_t { + NO_INSTANCE_WITH_NAME, + IS_MAIN, + NOT_COORDINATOR, + NOT_LEADER, + RPC_FAILED, + SUCCESS, +}; + enum class SetInstanceToMainCoordinatorStatus : uint8_t { NO_INSTANCE_WITH_NAME, + MAIN_ALREADY_EXISTS, NOT_COORDINATOR, SUCCESS, COULD_NOT_PROMOTE_TO_MAIN, diff --git a/src/coordination/include/coordination/replication_instance.hpp b/src/coordination/include/coordination/replication_instance.hpp index 713a66fd8..8001d0905 100644 --- a/src/coordination/include/coordination/replication_instance.hpp +++ b/src/coordination/include/coordination/replication_instance.hpp @@ -14,11 +14,11 @@ #ifdef MG_ENTERPRISE #include "coordination/coordinator_client.hpp" -#include "coordination/coordinator_cluster_config.hpp" #include "coordination/coordinator_exceptions.hpp" #include "replication_coordination_glue/role.hpp" #include +#include "utils/result.hpp" #include "utils/uuid.hpp" namespace memgraph::coordination { @@ -38,6 +38,9 @@ class 
ReplicationInstance { auto OnSuccessPing() -> void; auto OnFailPing() -> bool; + auto IsReadyForUUIDPing() -> bool; + + void UpdateReplicaLastResponseUUID(); auto IsAlive() const -> bool; @@ -59,18 +62,26 @@ class ReplicationInstance { auto ReplicationClientInfo() const -> ReplClientInfo; auto EnsureReplicaHasCorrectMainUUID(utils::UUID const &curr_main_uuid) -> bool; + auto SendSwapAndUpdateUUID(const utils::UUID &new_main_uuid) -> bool; + auto SendUnregisterReplicaRpc(std::string const &instance_name) -> bool; + + + auto SendGetInstanceUUID() -> utils::BasicResult>; auto GetClient() -> CoordinatorClient &; + auto EnableWritingOnMain() -> bool; + auto SetNewMainUUID(utils::UUID const &main_uuid) -> void; auto ResetMainUUID() -> void; - auto GetMainUUID() -> const std::optional &; + auto GetMainUUID() const -> const std::optional &; private: CoordinatorClient client_; replication_coordination_glue::ReplicationRole replication_role_; std::chrono::system_clock::time_point last_response_time_{}; bool is_alive_{false}; + std::chrono::system_clock::time_point last_check_of_uuid_{}; // for replica this is main uuid of current main // for "main" main this same as in CoordinatorData diff --git a/src/coordination/include/coordination/constants.hpp b/src/coordination/include/coordination/rpc_errors.hpp similarity index 82% rename from src/coordination/include/coordination/constants.hpp rename to src/coordination/include/coordination/rpc_errors.hpp index 819b9fa05..f6bfbf3e0 100644 --- a/src/coordination/include/coordination/constants.hpp +++ b/src/coordination/include/coordination/rpc_errors.hpp @@ -9,14 +9,6 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. 
-#pragma once - namespace memgraph::coordination { - -#ifdef MG_EXPERIMENTAL_HIGH_AVAILABILITY -constexpr bool allow_ha = true; -#else -constexpr bool allow_ha = false; -#endif - +enum class GetInstanceUUIDError { NO_RESPONSE, RPC_EXCEPTION }; } // namespace memgraph::coordination diff --git a/src/coordination/replication_instance.cpp b/src/coordination/replication_instance.cpp index 0fb13998c..0d16db648 100644 --- a/src/coordination/replication_instance.cpp +++ b/src/coordination/replication_instance.cpp @@ -14,6 +14,7 @@ #include "coordination/replication_instance.hpp" #include "replication_coordination_glue/handler.hpp" +#include "utils/result.hpp" namespace memgraph::coordination { @@ -34,12 +35,16 @@ auto ReplicationInstance::OnSuccessPing() -> void { } auto ReplicationInstance::OnFailPing() -> bool { - is_alive_ = - std::chrono::duration_cast(std::chrono::system_clock::now() - last_response_time_).count() < - CoordinatorClusterConfig::alive_response_time_difference_sec_; + auto elapsed_time = std::chrono::system_clock::now() - last_response_time_; + is_alive_ = elapsed_time < client_.InstanceDownTimeoutSec(); return is_alive_; } +auto ReplicationInstance::IsReadyForUUIDPing() -> bool { + return std::chrono::duration_cast(std::chrono::system_clock::now() - last_check_of_uuid_) > + client_.InstanceGetUUIDFrequencySec(); +} + auto ReplicationInstance::InstanceName() const -> std::string { return client_.InstanceName(); } auto ReplicationInstance::SocketAddress() const -> std::string { return client_.SocketAddress(); } auto ReplicationInstance::IsAlive() const -> bool { return is_alive_; } @@ -86,15 +91,26 @@ auto ReplicationInstance::ReplicationClientInfo() const -> CoordinatorClientConf } auto ReplicationInstance::GetClient() -> CoordinatorClient & { return client_; } + auto ReplicationInstance::SetNewMainUUID(utils::UUID const &main_uuid) -> void { main_uuid_ = main_uuid; } auto ReplicationInstance::ResetMainUUID() -> void { main_uuid_ = std::nullopt; } -auto 
ReplicationInstance::GetMainUUID() -> const std::optional & { return main_uuid_; } +auto ReplicationInstance::GetMainUUID() const -> std::optional const & { return main_uuid_; } auto ReplicationInstance::EnsureReplicaHasCorrectMainUUID(utils::UUID const &curr_main_uuid) -> bool { - if (!main_uuid_ || *main_uuid_ != curr_main_uuid) { - return SendSwapAndUpdateUUID(curr_main_uuid); + if (!IsReadyForUUIDPing()) { + return true; } - return true; + auto res = SendGetInstanceUUID(); + if (res.HasError()) { + return false; + } + UpdateReplicaLastResponseUUID(); + + if (res.GetValue().has_value() && res.GetValue().value() == curr_main_uuid) { + return true; + } + + return SendSwapAndUpdateUUID(curr_main_uuid); } auto ReplicationInstance::SendSwapAndUpdateUUID(const utils::UUID &new_main_uuid) -> bool { @@ -105,5 +121,18 @@ auto ReplicationInstance::SendSwapAndUpdateUUID(const utils::UUID &new_main_uuid return true; } +auto ReplicationInstance::SendUnregisterReplicaRpc(std::string const &instance_name) -> bool { + return client_.SendUnregisterReplicaRpc(instance_name); +} + +auto ReplicationInstance::EnableWritingOnMain() -> bool { return client_.SendEnableWritingOnMainRpc(); } + +auto ReplicationInstance::SendGetInstanceUUID() + -> utils::BasicResult> { + return client_.SendGetInstanceUUIDRpc(); +} + +void ReplicationInstance::UpdateReplicaLastResponseUUID() { last_check_of_uuid_ = std::chrono::system_clock::now(); } + } // namespace memgraph::coordination #endif diff --git a/src/dbms/coordinator_handler.cpp b/src/dbms/coordinator_handler.cpp index b623e1db6..f8e14e2a0 100644 --- a/src/dbms/coordinator_handler.cpp +++ b/src/dbms/coordinator_handler.cpp @@ -25,6 +25,11 @@ auto CoordinatorHandler::RegisterReplicationInstance(memgraph::coordination::Coo return coordinator_state_.RegisterReplicationInstance(config); } +auto CoordinatorHandler::UnregisterReplicationInstance(std::string instance_name) + -> coordination::UnregisterInstanceCoordinatorStatus { + return 
coordinator_state_.UnregisterReplicationInstance(std::move(instance_name)); +} + auto CoordinatorHandler::SetReplicationInstanceToMain(std::string instance_name) -> coordination::SetInstanceToMainCoordinatorStatus { return coordinator_state_.SetReplicationInstanceToMain(std::move(instance_name)); diff --git a/src/dbms/coordinator_handler.hpp b/src/dbms/coordinator_handler.hpp index 03d45ee41..d06e70676 100644 --- a/src/dbms/coordinator_handler.hpp +++ b/src/dbms/coordinator_handler.hpp @@ -28,9 +28,13 @@ class CoordinatorHandler { public: explicit CoordinatorHandler(coordination::CoordinatorState &coordinator_state); + // TODO: (andi) When moving coordinator state on same instances, rename from RegisterReplicationInstance to + // RegisterInstance auto RegisterReplicationInstance(coordination::CoordinatorClientConfig config) -> coordination::RegisterInstanceCoordinatorStatus; + auto UnregisterReplicationInstance(std::string instance_name) -> coordination::UnregisterInstanceCoordinatorStatus; + auto SetReplicationInstanceToMain(std::string instance_name) -> coordination::SetInstanceToMainCoordinatorStatus; auto ShowInstances() const -> std::vector; diff --git a/src/dbms/database.hpp b/src/dbms/database.hpp index 2d7d3fe88..d144276da 100644 --- a/src/dbms/database.hpp +++ b/src/dbms/database.hpp @@ -110,9 +110,9 @@ class Database { * @param force_directory Use the configured directory, do not try to decipher the multi-db version * @return DatabaseInfo */ - DatabaseInfo GetInfo(bool force_directory, replication_coordination_glue::ReplicationRole replication_role) const { + DatabaseInfo GetInfo(replication_coordination_glue::ReplicationRole replication_role) const { DatabaseInfo info; - info.storage_info = storage_->GetInfo(force_directory, replication_role); + info.storage_info = storage_->GetInfo(replication_role); info.triggers = trigger_store_.GetTriggerInfo().size(); info.streams = streams_.GetStreamInfo().size(); return info; diff --git a/src/dbms/dbms_handler.hpp 
b/src/dbms/dbms_handler.hpp index 1bdb4a8fa..7b1d45335 100644 --- a/src/dbms/dbms_handler.hpp +++ b/src/dbms/dbms_handler.hpp @@ -302,7 +302,7 @@ class DbmsHandler { auto db_acc_opt = db_gk.access(); if (db_acc_opt) { auto &db_acc = *db_acc_opt; - const auto &info = db_acc->GetInfo(false, replication_role); + const auto &info = db_acc->GetInfo(replication_role); const auto &storage_info = info.storage_info; stats.num_vertex += storage_info.vertex_count; stats.num_edges += storage_info.edge_count; @@ -338,7 +338,7 @@ class DbmsHandler { auto db_acc_opt = db_gk.access(); if (db_acc_opt) { auto &db_acc = *db_acc_opt; - res.push_back(db_acc->GetInfo(false, replication_role)); + res.push_back(db_acc->GetInfo(replication_role)); } } return res; diff --git a/src/dbms/inmemory/replication_handlers.cpp b/src/dbms/inmemory/replication_handlers.cpp index b7b2146f4..3fc174d3c 100644 --- a/src/dbms/inmemory/replication_handlers.cpp +++ b/src/dbms/inmemory/replication_handlers.cpp @@ -19,6 +19,7 @@ #include "storage/v2/durability/durability.hpp" #include "storage/v2/durability/snapshot.hpp" #include "storage/v2/durability/version.hpp" +#include "storage/v2/fmt.hpp" #include "storage/v2/indices/label_index_stats.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/inmemory/unique_constraints.hpp" diff --git a/src/dbms/utils.hpp b/src/dbms/utils.hpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/distributed/include/distributed/lamport_clock.hpp b/src/distributed/include/distributed/lamport_clock.hpp index f3e91e47a..2bbc0b447 100644 --- a/src/distributed/include/distributed/lamport_clock.hpp +++ b/src/distributed/include/distributed/lamport_clock.hpp @@ -10,6 +10,7 @@ // licenses/APL.txt. 
#pragma once +#include #include #include #include diff --git a/src/flags/experimental.cpp b/src/flags/experimental.cpp index 7bd26a837..123903c96 100644 --- a/src/flags/experimental.cpp +++ b/src/flags/experimental.cpp @@ -19,13 +19,14 @@ // Bolt server flags. // NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) DEFINE_string(experimental_enabled, "", - "Experimental features to be used, comma seperated. Options [system-replication]"); + "Experimental features to be used, comma seperated. Options [system-replication, high-availability]"); using namespace std::string_view_literals; namespace memgraph::flags { -auto const mapping = std::map{std::pair{"system-replication"sv, Experiments::SYSTEM_REPLICATION}}; +auto const mapping = std::map{std::pair{"system-replication"sv, Experiments::SYSTEM_REPLICATION}, + std::pair{"high-availability"sv, Experiments::HIGH_AVAILABILITY}}; auto ExperimentsInstance() -> Experiments & { static auto instance = Experiments{}; diff --git a/src/flags/experimental.hpp b/src/flags/experimental.hpp index ec4db2037..5a19889fe 100644 --- a/src/flags/experimental.hpp +++ b/src/flags/experimental.hpp @@ -23,6 +23,7 @@ namespace memgraph::flags { // old experiments can be reused once code cleanup has happened enum class Experiments : uint8_t { SYSTEM_REPLICATION = 1 << 0, + HIGH_AVAILABILITY = 1 << 1, }; bool AreExperimentsEnabled(Experiments experiments); diff --git a/src/flags/replication.cpp b/src/flags/replication.cpp index 29c7bfbda..e6b71b942 100644 --- a/src/flags/replication.cpp +++ b/src/flags/replication.cpp @@ -18,6 +18,12 @@ DEFINE_uint32(coordinator_server_port, 0, "Port on which coordinator servers wil DEFINE_uint32(raft_server_port, 0, "Port on which raft servers will be started."); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_uint32(raft_server_id, 0, "Unique ID of the raft server."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) 
+DEFINE_uint32(instance_down_timeout_sec, 5, "Time duration after which an instance is considered down."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_uint32(instance_health_check_frequency_sec, 1, "The time duration between two health checks/pings."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_uint32(instance_get_uuid_frequency_sec, 10, "The time duration between two instance uuid checks."); #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/src/flags/replication.hpp b/src/flags/replication.hpp index 025079271..0a4982f12 100644 --- a/src/flags/replication.hpp +++ b/src/flags/replication.hpp @@ -20,6 +20,12 @@ DECLARE_uint32(coordinator_server_port); DECLARE_uint32(raft_server_port); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DECLARE_uint32(raft_server_id); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_uint32(instance_down_timeout_sec); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_uint32(instance_health_check_frequency_sec); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DECLARE_uint32(instance_get_uuid_frequency_sec); #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/src/glue/CMakeLists.txt b/src/glue/CMakeLists.txt index da287179f..8f3aec412 100644 --- a/src/glue/CMakeLists.txt +++ b/src/glue/CMakeLists.txt @@ -6,5 +6,6 @@ target_sources(mg-glue PRIVATE auth.cpp SessionHL.cpp ServerT.cpp MonitoringServerT.cpp - run_id.cpp) + run_id.cpp + query_user.cpp) target_link_libraries(mg-glue mg-query mg-auth mg-audit mg-flags) diff --git a/src/glue/SessionHL.cpp b/src/glue/SessionHL.cpp index 07e1bf6e8..6c901516c 100644 --- a/src/glue/SessionHL.cpp +++ b/src/glue/SessionHL.cpp @@ -11,6 +11,7 @@ #include #include +#include "auth/auth.hpp" #include "gflags/gflags.h" #include "audit/log.hpp" @@ -19,17 +20,22 
@@ #include "glue/SessionHL.hpp" #include "glue/auth_checker.hpp" #include "glue/communication.hpp" +#include "glue/query_user.hpp" #include "glue/run_id.hpp" #include "license/license.hpp" +#include "query/auth_checker.hpp" #include "query/discard_value_stream.hpp" #include "query/interpreter_context.hpp" +#include "query/query_user.hpp" #include "utils/event_map.hpp" #include "utils/spin_lock.hpp" +#include "utils/variant_helpers.hpp" namespace memgraph::metrics { extern const Event ActiveBoltSessions; } // namespace memgraph::metrics +namespace { auto ToQueryExtras(const memgraph::communication::bolt::Value &extra) -> memgraph::query::QueryExtras { auto const &as_map = extra.ValueMap(); @@ -97,20 +103,24 @@ std::vector TypedValueResultStreamBase::De } return decoded_values; } + TypedValueResultStreamBase::TypedValueResultStreamBase(memgraph::storage::Storage *storage) : storage_(storage) {} -namespace memgraph::glue { - #ifdef MG_ENTERPRISE -inline static void MultiDatabaseAuth(const std::optional &user, std::string_view db) { - if (user && !AuthChecker::IsUserAuthorized(*user, {}, std::string(db))) { +void MultiDatabaseAuth(memgraph::query::QueryUserOrRole *user, std::string_view db) { + if (user && !user->IsAuthorized({}, std::string(db), &memgraph::query::session_long_policy)) { throw memgraph::communication::bolt::ClientError( "You are not authorized on the database \"{}\"! 
Please contact your database administrator.", db); } } +#endif +} // namespace +namespace memgraph::glue { + +#ifdef MG_ENTERPRISE std::string SessionHL::GetDefaultDB() { - if (user_.has_value()) { - return user_->db_access().GetDefault(); + if (user_or_role_) { + return user_or_role_->GetDefaultDB(); } return std::string{memgraph::dbms::kDefaultDB}; } @@ -132,13 +142,18 @@ bool SessionHL::Authenticate(const std::string &username, const std::string &pas interpreter_.ResetUser(); { auto locked_auth = auth_->Lock(); - if (locked_auth->HasUsers()) { - user_ = locked_auth->Authenticate(username, password); - if (user_.has_value()) { - interpreter_.SetUser(user_->username()); + if (locked_auth->AccessControlled()) { + const auto user_or_role = locked_auth->Authenticate(username, password); + if (user_or_role.has_value()) { + user_or_role_ = AuthChecker::GenQueryUser(auth_, *user_or_role); + interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, *user_or_role)); } else { res = false; } + } else { + // No access control -> give empty user + user_or_role_ = AuthChecker::GenQueryUser(auth_, std::nullopt); + interpreter_.SetUser(AuthChecker::GenQueryUser(auth_, std::nullopt)); } } #ifdef MG_ENTERPRISE @@ -195,21 +210,17 @@ std::pair, std::optional> SessionHL::Interpret( } #ifdef MG_ENTERPRISE - const std::string *username{nullptr}; - if (user_) { - username = &user_->username(); - } - if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) { auto &db = interpreter_.current_db_.db_acc_; - audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query, - memgraph::storage::PropertyValue(params_pv), db ? db->get()->name() : "no known database"); + const auto username = user_or_role_ ? (user_or_role_->username() ? *user_or_role_->username() : "") : ""; + audit_log_->Record(endpoint_.address().to_string(), username, query, memgraph::storage::PropertyValue(params_pv), + db ? 
db->get()->name() : "no known database"); } #endif try { auto result = interpreter_.Prepare(query, params_pv, ToQueryExtras(extra)); const std::string db_name = result.db ? *result.db : ""; - if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges, db_name)) { + if (user_or_role_ && !user_or_role_->IsAuthorized(result.privileges, db_name, &query::session_long_policy)) { interpreter_.Abort(); if (db_name.empty()) { throw memgraph::communication::bolt::ClientError( @@ -311,7 +322,7 @@ void SessionHL::Configure(const std::mapinterpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); }); } diff --git a/src/glue/SessionHL.hpp b/src/glue/SessionHL.hpp index 64dcddda5..cf0280fcc 100644 --- a/src/glue/SessionHL.hpp +++ b/src/glue/SessionHL.hpp @@ -15,6 +15,7 @@ #include "communication/v2/server.hpp" #include "communication/v2/session.hpp" #include "dbms/database.hpp" +#include "glue/query_user.hpp" #include "query/interpreter.hpp" namespace memgraph::glue { @@ -82,7 +83,7 @@ class SessionHL final : public memgraph::communication::bolt::Session user_; + std::unique_ptr user_or_role_; #ifdef MG_ENTERPRISE memgraph::audit::Log *audit_log_; bool in_explicit_db_{false}; //!< If true, the user has defined the database to use via metadata diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 4db6c827e..99463d323 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -14,53 +14,74 @@ #include "auth/auth.hpp" #include "auth/models.hpp" #include "glue/auth.hpp" +#include "glue/query_user.hpp" #include "license/license.hpp" +#include "query/auth_checker.hpp" #include "query/constants.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" +#include "utils/logging.hpp" #include "utils/synchronized.hpp" +#include "utils/variant_helpers.hpp" #ifdef MG_ENTERPRISE namespace { -bool IsUserAuthorizedLabels(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba, - const 
std::vector &labels, - const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { +bool IsAuthorizedLabels(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba, + const std::vector &labels, + const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return std::all_of(labels.begin(), labels.end(), [dba, &user, fine_grained_privilege](const auto &label) { - return user.GetFineGrainedAccessLabelPermissions().Has( - dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( - fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT; + return std::all_of(labels.begin(), labels.end(), [dba, &user_or_role, fine_grained_privilege](const auto &label) { + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessLabelPermissions().Has( + dba->LabelToName(label), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( + fine_grained_privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); }); } -bool IsUserAuthorizedGloballyLabels(const memgraph::auth::User &user, - const memgraph::auth::FineGrainedPermission fine_grained_permission) { +bool IsAuthorizedGloballyLabels(const memgraph::auth::UserOrRole &user_or_role, + const memgraph::auth::FineGrainedPermission fine_grained_permission) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) == - memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessLabelPermissions().Has(memgraph::query::kAsterisk, + fine_grained_permission) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } 
-bool IsUserAuthorizedGloballyEdges(const memgraph::auth::User &user, - const memgraph::auth::FineGrainedPermission fine_grained_permission) { +bool IsAuthorizedGloballyEdges(const memgraph::auth::UserOrRole &user_or_role, + const memgraph::auth::FineGrainedPermission fine_grained_permission) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk, fine_grained_permission) == - memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has(memgraph::query::kAsterisk, + fine_grained_permission) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } -bool IsUserAuthorizedEdgeType(const memgraph::auth::User &user, const memgraph::query::DbAccessor *dba, - const memgraph::storage::EdgeTypeId &edgeType, - const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { +bool IsAuthorizedEdgeType(const memgraph::auth::UserOrRole &user_or_role, const memgraph::query::DbAccessor *dba, + const memgraph::storage::EdgeTypeId &edgeType, + const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return user.GetFineGrainedAccessEdgeTypePermissions().Has( - dba->EdgeTypeToName(edgeType), memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission( - fine_grained_privilege)) == memgraph::auth::PermissionLevel::GRANT; + return std::visit(memgraph::utils::Overloaded{[&](auto &user_or_role) { + return user_or_role.GetFineGrainedAccessEdgeTypePermissions().Has( + dba->EdgeTypeToName(edgeType), + memgraph::glue::FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }}, + user_or_role); } } // namespace #endif @@ -68,47 +89,54 @@ namespace 
memgraph::glue { AuthChecker::AuthChecker(memgraph::auth::SynchedAuth *auth) : auth_(auth) {} -bool AuthChecker::IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const { - std::optional maybe_user; - { - auto locked_auth = auth_->ReadLock(); - if (!locked_auth->HasUsers()) { - return true; - } - if (username.has_value()) { - maybe_user = locked_auth->GetUser(*username); - } +std::shared_ptr AuthChecker::GenQueryUser(const std::optional &username, + const std::optional &rolename) const { + const auto user_or_role = auth_->ReadLock()->GetUserOrRole(username, rolename); + if (user_or_role) { + return std::make_shared(auth_, *user_or_role); } + // No user or role + return std::make_shared(auth_); +} - return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges, db_name); +std::unique_ptr AuthChecker::GenQueryUser(auth::SynchedAuth *auth, + const std::optional &user_or_role) { + if (user_or_role) { + return std::visit( + utils::Overloaded{[&](auto &user_or_role) { return std::make_unique(auth, user_or_role); }}, + *user_or_role); + } + // No user or role + return std::make_unique(auth); } #ifdef MG_ENTERPRISE std::unique_ptr AuthChecker::GetFineGrainedAuthChecker( - const std::string &username, const memgraph::query::DbAccessor *dba) const { + std::shared_ptr user_or_role, const memgraph::query::DbAccessor *dba) const { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return {}; } - try { - auto user = user_.Lock(); - if (username != user->username()) { - auto maybe_user = auth_->ReadLock()->GetUser(username); - if (!maybe_user) { - throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); - } - *user = std::move(*maybe_user); - } - return std::make_unique(*user, dba); - - } catch (const memgraph::auth::AuthException &e) { - throw memgraph::query::QueryRuntimeException(e.what()); + if (!user_or_role || !*user_or_role) { + throw 
query::QueryRuntimeException("No user specified for fine grained authorization!"); } -} -void AuthChecker::ClearCache() const { - user_.WithLock([](auto &user) mutable { user = {}; }); + // Convert from query user to auth user or role + try { + auto glue_user = dynamic_cast(*user_or_role); + if (glue_user.user_) { + return std::make_unique(std::move(*glue_user.user_), dba); + } + if (glue_user.role_) { + return std::make_unique( + auth::RoleWUsername{*glue_user.username(), std::move(*glue_user.role_)}, dba); + } + DMG_ASSERT(false, "Glue user has neither user not role"); + } catch (std::bad_cast &e) { + DMG_ASSERT(false, "Using a non-glue user in glue..."); + } + + // Should never get here + return {}; } #endif @@ -116,7 +144,7 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, const std::vector &privileges, const std::string &db_name) { // NOLINT #ifdef MG_ENTERPRISE - if (!db_name.empty() && !user.db_access().Contains(db_name)) { + if (!db_name.empty() && !user.HasAccess(db_name)) { return false; } #endif @@ -127,9 +155,34 @@ bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, }); } +bool AuthChecker::IsRoleAuthorized(const memgraph::auth::Role &role, + const std::vector &privileges, + const std::string &db_name) { // NOLINT #ifdef MG_ENTERPRISE -FineGrainedAuthChecker::FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba) - : user_{std::move(user)}, dba_(dba){}; + if (!db_name.empty() && !role.HasAccess(db_name)) { + return false; + } +#endif + const auto role_permissions = role.permissions(); + return std::all_of(privileges.begin(), privileges.end(), [&role_permissions](const auto privilege) { + return role_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege)) == + memgraph::auth::PermissionLevel::GRANT; + }); +} + +bool AuthChecker::IsUserOrRoleAuthorized(const memgraph::auth::UserOrRole &user_or_role, + const std::vector &privileges, + const std::string &db_name) { + return 
std::visit( + utils::Overloaded{ + [&](const auth::User &user) -> bool { return AuthChecker::IsUserAuthorized(user, privileges, db_name); }, + [&](const auth::Role &role) -> bool { return AuthChecker::IsRoleAuthorized(role, privileges, db_name); }}, + user_or_role); +} + +#ifdef MG_ENTERPRISE +FineGrainedAuthChecker::FineGrainedAuthChecker(auth::UserOrRole user_or_role, const memgraph::query::DbAccessor *dba) + : user_or_role_{std::move(user_or_role)}, dba_(dba){}; bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex, const memgraph::storage::View view, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { @@ -147,22 +200,22 @@ bool FineGrainedAuthChecker::Has(const memgraph::query::VertexAccessor &vertex, } } - return IsUserAuthorizedLabels(user_, dba_, *maybe_labels, fine_grained_privilege); + return IsAuthorizedLabels(user_or_role_, dba_, *maybe_labels, fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const memgraph::query::EdgeAccessor &edge, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedEdgeType(user_, dba_, edge.EdgeType(), fine_grained_privilege); + return IsAuthorizedEdgeType(user_or_role_, dba_, edge.EdgeType(), fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const std::vector &labels, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedLabels(user_, dba_, labels, fine_grained_privilege); + return IsAuthorizedLabels(user_or_role_, dba_, labels, fine_grained_privilege); } bool FineGrainedAuthChecker::Has(const memgraph::storage::EdgeTypeId &edge_type, const memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const { - return IsUserAuthorizedEdgeType(user_, dba_, edge_type, fine_grained_privilege); + return IsAuthorizedEdgeType(user_or_role_, dba_, edge_type, fine_grained_privilege); } bool 
FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices( @@ -170,7 +223,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnVertices( if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return IsUserAuthorizedGloballyLabels(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); + return IsAuthorizedGloballyLabels(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); } bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges( @@ -178,7 +231,7 @@ bool FineGrainedAuthChecker::HasGlobalPrivilegeOnEdges( if (!memgraph::license::global_license_checker.IsEnterpriseValidFast()) { return true; } - return IsUserAuthorizedGloballyEdges(user_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); + return IsAuthorizedGloballyEdges(user_or_role_, FineGrainedPrivilegeToFineGrainedPermission(fine_grained_privilege)); }; #endif } // namespace memgraph::glue diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index 217ac0c74..ef8e993df 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -22,53 +22,59 @@ namespace memgraph::glue { class AuthChecker : public query::AuthChecker { public: - explicit AuthChecker(memgraph::auth::SynchedAuth *auth); + explicit AuthChecker(auth::SynchedAuth *auth); - bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const override; + std::shared_ptr GenQueryUser(const std::optional &username, + const std::optional &rolename) const override; + + static std::unique_ptr GenQueryUser(auth::SynchedAuth *auth, + const std::optional &user_or_role); #ifdef MG_ENTERPRISE - std::unique_ptr GetFineGrainedAuthChecker( - const std::string &username, const memgraph::query::DbAccessor *dba) const override; - - void ClearCache() const override; - + std::unique_ptr GetFineGrainedAuthChecker(std::shared_ptr user, + const query::DbAccessor *dba) const override; #endif - 
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, - const std::vector &privileges, + + [[nodiscard]] static bool IsUserAuthorized(const auth::User &user, + const std::vector &privileges, const std::string &db_name = ""); + [[nodiscard]] static bool IsRoleAuthorized(const auth::Role &role, + const std::vector &privileges, + const std::string &db_name = ""); + + [[nodiscard]] static bool IsUserOrRoleAuthorized(const auth::UserOrRole &user_or_role, + const std::vector &privileges, + const std::string &db_name = ""); + private: - memgraph::auth::SynchedAuth *auth_; - mutable memgraph::utils::Synchronized user_; // cached user + auth::SynchedAuth *auth_; + mutable utils::Synchronized user_or_role_; // cached user }; #ifdef MG_ENTERPRISE class FineGrainedAuthChecker : public query::FineGrainedAuthChecker { public: - explicit FineGrainedAuthChecker(auth::User user, const memgraph::query::DbAccessor *dba); + explicit FineGrainedAuthChecker(auth::UserOrRole user, const query::DbAccessor *dba); - bool Has(const query::VertexAccessor &vertex, memgraph::storage::View view, + bool Has(const query::VertexAccessor &vertex, storage::View view, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; bool Has(const query::EdgeAccessor &edge, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool Has(const std::vector &labels, + bool Has(const std::vector &labels, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool Has(const memgraph::storage::EdgeTypeId &edge_type, + bool Has(const storage::EdgeTypeId &edge_type, query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool HasGlobalPrivilegeOnVertices( - memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; + bool HasGlobalPrivilegeOnVertices(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; - bool HasGlobalPrivilegeOnEdges( - 
memgraph::query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; + bool HasGlobalPrivilegeOnEdges(query::AuthQuery::FineGrainedPrivilege fine_grained_privilege) const override; private: - auth::User user_; - const memgraph::query::DbAccessor *dba_; + auth::UserOrRole user_or_role_; + const query::DbAccessor *dba_; }; #endif } // namespace memgraph::glue diff --git a/src/glue/auth_handler.cpp b/src/glue/auth_handler.cpp index 2d7260b3c..6178b152e 100644 --- a/src/glue/auth_handler.cpp +++ b/src/glue/auth_handler.cpp @@ -15,6 +15,7 @@ #include +#include "auth/auth.hpp" #include "auth/models.hpp" #include "dbms/constants.hpp" #include "glue/auth.hpp" @@ -123,6 +124,29 @@ std::vector> ShowRolePrivileges( } #ifdef MG_ENTERPRISE +std::vector> ShowDatabasePrivileges( + const std::optional &role) { + if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !role) { + return {}; + } + + const auto &db = role->db_access(); + const auto &allows = db.GetAllowAll(); + const auto &grants = db.GetGrants(); + const auto &denies = db.GetDenies(); + + std::vector res; // First element is a list of granted databases, second of revoked ones + if (allows) { + res.emplace_back("*"); + } else { + std::vector grants_vec(grants.cbegin(), grants.cend()); + res.emplace_back(std::move(grants_vec)); + } + std::vector denies_vec(denies.cbegin(), denies.cend()); + res.emplace_back(std::move(denies_vec)); + return {res}; +} + std::vector> ShowDatabasePrivileges( const std::optional &user) { if (!memgraph::license::global_license_checker.IsEnterpriseValidFast() || !user) { @@ -130,9 +154,15 @@ std::vector> ShowDatabasePrivileges( } const auto &db = user->db_access(); - const auto &allows = db.GetAllowAll(); - const auto &grants = db.GetGrants(); - const auto &denies = db.GetDenies(); + auto allows = db.GetAllowAll(); + auto grants = db.GetGrants(); + auto denies = db.GetDenies(); + if (const auto *role = user->role()) { + const auto &role_db = 
role->db_access(); + allows |= role_db.GetAllowAll(); + grants.insert(role_db.GetGrants().begin(), role_db.GetGrants().end()); + denies.insert(role_db.GetDenies().begin(), role_db.GetDenies().end()); + } std::vector res; // First element is a list of granted databases, second of revoked ones if (allows) { @@ -287,7 +317,7 @@ bool AuthQueryHandler::CreateUser(const std::string &username, const std::option , system_tx); #ifdef MG_ENTERPRISE - GrantDatabaseToUser(auth::kAllDatabases, username, system_tx); + GrantDatabase(auth::kAllDatabases, username, system_tx); SetMainDatabase(dbms::kDefaultDB, username, system_tx); #endif } @@ -334,51 +364,97 @@ bool AuthQueryHandler::CreateRole(const std::string &rolename, system::Transacti } #ifdef MG_ENTERPRISE -bool AuthQueryHandler::RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) { +void AuthQueryHandler::GrantDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->RevokeDatabaseFromUser(db_name, username, system_tx); + const auto res = locked_auth->GrantDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::GrantDatabaseToUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) { +void AuthQueryHandler::DenyDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { try { auto 
locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->GrantDatabaseToUser(db_name, username, system_tx); + const auto res = locked_auth->DenyDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } +} + +void AuthQueryHandler::RevokeDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) { + try { + auto locked_auth = auth_->Lock(); + const auto res = locked_auth->RevokeDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } std::vector> AuthQueryHandler::GetDatabasePrivileges( - const std::string &username) { + const std::string &user_or_role) { try { auto locked_auth = auth_->ReadLock(); - auto user = locked_auth->GetUser(username); - if (!user) { - throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username); + if (auto user = locked_auth->GetUser(user_or_role)) { + return ShowDatabasePrivileges(user); } - return ShowDatabasePrivileges(user); + if (auto role = locked_auth->GetRole(user_or_role)) { + return ShowDatabasePrivileges(role); + } + throw memgraph::query::QueryRuntimeException("Neither user nor role '{}' exist.", 
user_or_role); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } -bool AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &username, +void AuthQueryHandler::SetMainDatabase(std::string_view db_name, const std::string &user_or_role, system::Transaction *system_tx) { try { auto locked_auth = auth_->Lock(); - auto user = locked_auth->GetUser(username); - if (!user) return false; - return locked_auth->SetMainDatabase(db_name, username, system_tx); + const auto res = locked_auth->SetMainDatabase(db_name, user_or_role, system_tx); + switch (res) { + using enum auth::Auth::Result; + case SUCCESS: + return; + case NO_USER_ROLE: + throw query::QueryRuntimeException("No user nor role '{}' found.", user_or_role); + case NO_ROLE: + throw query::QueryRuntimeException("Using auth module, no role '{}' found.", user_or_role); + break; + } } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } diff --git a/src/glue/auth_handler.hpp b/src/glue/auth_handler.hpp index 52db6075f..d78daaea4 100644 --- a/src/glue/auth_handler.hpp +++ b/src/glue/auth_handler.hpp @@ -37,15 +37,19 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { system::Transaction *system_tx) override; #ifdef MG_ENTERPRISE - bool RevokeDatabaseFromUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) override; + void GrantDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; - bool GrantDatabaseToUser(const std::string &db_name, const std::string &username, - system::Transaction *system_tx) override; + void DenyDatabase(const std::string &db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; - std::vector> GetDatabasePrivileges(const std::string &username) override; + void RevokeDatabase(const std::string &db_name, const 
std::string &user_or_role, + system::Transaction *system_tx) override; - bool SetMainDatabase(std::string_view db_name, const std::string &username, system::Transaction *system_tx) override; + std::vector> GetDatabasePrivileges(const std::string &user_or_role) override; + + void SetMainDatabase(std::string_view db_name, const std::string &user_or_role, + system::Transaction *system_tx) override; void DeleteDatabase(std::string_view db_name, system::Transaction *system_tx) override; #endif diff --git a/src/glue/query_user.cpp b/src/glue/query_user.cpp new file mode 100644 index 000000000..5cd6e6750 --- /dev/null +++ b/src/glue/query_user.cpp @@ -0,0 +1,41 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#include "glue/query_user.hpp" + +#include "glue/auth_checker.hpp" + +namespace memgraph::glue { + +bool QueryUserOrRole::IsAuthorized(const std::vector &privileges, + const std::string &db_name, query::UserPolicy *policy) const { + auto locked_auth = auth_->Lock(); + // Check policy and update if behind (and policy permits it) + if (policy->DoUpdate() && !locked_auth->UpToDate(auth_epoch_)) { + if (user_) user_ = locked_auth->GetUser(user_->username()); + if (role_) role_ = locked_auth->GetRole(role_->rolename()); + } + + if (user_) return AuthChecker::IsUserAuthorized(*user_, privileges, db_name); + if (role_) return AuthChecker::IsRoleAuthorized(*role_, privileges, db_name); + + return !policy->DoUpdate() || !locked_auth->AccessControlled(); +} + +#ifdef MG_ENTERPRISE +std::string QueryUserOrRole::GetDefaultDB() const { + if (user_) return user_->db_access().GetMain(); + if (role_) return role_->db_access().GetMain(); + return std::string{dbms::kDefaultDB}; +} +#endif + +} // namespace memgraph::glue diff --git a/src/glue/query_user.hpp b/src/glue/query_user.hpp new file mode 100644 index 000000000..22f3598db --- /dev/null +++ b/src/glue/query_user.hpp @@ -0,0 +1,57 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include + +#include "auth/auth.hpp" +#include "query/query_user.hpp" +#include "utils/variant_helpers.hpp" + +namespace memgraph::glue { + +struct QueryUserOrRole : public query::QueryUserOrRole { + bool IsAuthorized(const std::vector &privileges, const std::string &db_name, + query::UserPolicy *policy) const override; + +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const override; +#endif + + explicit QueryUserOrRole(auth::SynchedAuth *auth) : query::QueryUserOrRole{std::nullopt, std::nullopt}, auth_{auth} {} + + QueryUserOrRole(auth::SynchedAuth *auth, auth::UserOrRole user_or_role) + : query::QueryUserOrRole{std::visit( + utils::Overloaded{[](const auto &user_or_role) { return user_or_role.username(); }}, + user_or_role), + std::visit(utils::Overloaded{[&](const auth::User &) -> std::optional { + return std::nullopt; + }, + [&](const auth::Role &role) -> std::optional { + return role.rolename(); + }}, + user_or_role)}, + auth_{auth} { + std::visit(utils::Overloaded{[&](auth::User &&user) { user_.emplace(std::move(user)); }, + [&](auth::Role &&role) { role_.emplace(std::move(role)); }}, + std::move(user_or_role)); + } + + private: + friend class AuthChecker; + auth::SynchedAuth *auth_; + mutable std::optional user_{}; + mutable std::optional role_{}; + mutable auth::Auth::Epoch auth_epoch_{auth::Auth::kStartEpoch}; +}; + +} // namespace memgraph::glue diff --git a/src/integrations/kafka/consumer.cpp b/src/integrations/kafka/consumer.cpp index 9889fe46b..c5604e85a 100644 --- a/src/integrations/kafka/consumer.cpp +++ b/src/integrations/kafka/consumer.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,6 +22,7 @@ #include "integrations/constants.hpp" #include "integrations/kafka/exceptions.hpp" +#include "integrations/kafka/fmt.hpp" #include "utils/exceptions.hpp" #include "utils/logging.hpp" #include "utils/on_scope_exit.hpp" diff --git a/src/integrations/kafka/fmt.hpp b/src/integrations/kafka/fmt.hpp new file mode 100644 index 000000000..f85f74b49 --- /dev/null +++ b/src/integrations/kafka/fmt.hpp @@ -0,0 +1,25 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#if FMT_VERSION > 90000 +#include + +#include + +inline std::ostream &operator<<(std::ostream &os, const RdKafka::ErrorCode &code) { + os << RdKafka::err2str(code); + return os; +} +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/integrations/pulsar/consumer.cpp b/src/integrations/pulsar/consumer.cpp index f004cf6dc..1cfd8159c 100644 --- a/src/integrations/pulsar/consumer.cpp +++ b/src/integrations/pulsar/consumer.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -15,12 +15,12 @@ #include #include -#include #include #include #include "integrations/constants.hpp" #include "integrations/pulsar/exceptions.hpp" +#include "integrations/pulsar/fmt.hpp" #include "utils/concepts.hpp" #include "utils/logging.hpp" #include "utils/on_scope_exit.hpp" diff --git a/src/integrations/pulsar/fmt.hpp b/src/integrations/pulsar/fmt.hpp new file mode 100644 index 000000000..7585d87c7 --- /dev/null +++ b/src/integrations/pulsar/fmt.hpp @@ -0,0 +1,21 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#if FMT_VERSION > 90000 +#include + +#include "integrations/pulsar/consumer.hpp" + +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/coordination/include/coordination/coordinator_cluster_config.hpp b/src/io/network/fmt.hpp similarity index 74% rename from src/coordination/include/coordination/coordinator_cluster_config.hpp rename to src/io/network/fmt.hpp index e1d91ff7d..014de5353 100644 --- a/src/coordination/include/coordination/coordinator_cluster_config.hpp +++ b/src/io/network/fmt.hpp @@ -11,12 +11,11 @@ #pragma once -#ifdef MG_ENTERPRISE -namespace memgraph::coordination { +#if FMT_VERSION > 90000 +#include -struct CoordinatorClusterConfig { - static constexpr int alive_response_time_difference_sec_{5}; -}; +#include "io/network/endpoint.hpp" -} // namespace memgraph::coordination +template <> +class fmt::formatter : public fmt::ostream_formatter {}; #endif diff --git a/src/io/network/stream_buffer.hpp b/src/io/network/stream_buffer.hpp index 5ed7fc69e..5a9f01bf7 100644 --- a/src/io/network/stream_buffer.hpp +++ b/src/io/network/stream_buffer.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,6 +11,7 @@ #pragma once +#include #include namespace memgraph::io::network { diff --git a/src/kvstore/kvstore.hpp b/src/kvstore/kvstore.hpp index b9675d75b..84aa27009 100644 --- a/src/kvstore/kvstore.hpp +++ b/src/kvstore/kvstore.hpp @@ -160,13 +160,14 @@ class KVStore final { * and behaves as if all of those pairs are stored in a single iterable * collection of std::pair. 
*/ - class iterator final : public std::iterator, // value_type - long, // difference_type - const std::pair *, // pointer - const std::pair & // reference - > { + class iterator final { public: + using iterator_concept [[maybe_unused]] = std::input_iterator_tag; + using value_type = std::pair; + using difference_type = long; + using pointer = const std::pair *; + using reference = const std::pair &; + explicit iterator(const KVStore *kvstore, const std::string &prefix = "", bool at_end = false); iterator(const iterator &other) = delete; diff --git a/src/memgraph.cpp b/src/memgraph.cpp index b965b82a9..34d64f434 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -27,6 +27,7 @@ #include "helpers.hpp" #include "license/license_sender.hpp" #include "memory/global_memory_control.hpp" +#include "query/auth_checker.hpp" #include "query/auth_query_handler.hpp" #include "query/config.hpp" #include "query/discard_value_stream.hpp" @@ -57,8 +58,13 @@ constexpr uint64_t kMgVmMaxMapCount = 262144; void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbms::DatabaseAccess &db_acc, std::string cypherl_file_path, memgraph::audit::Log *audit_log = nullptr) { memgraph::query::Interpreter interpreter(&ctx, db_acc); - std::ifstream file(cypherl_file_path); + // Temporary empty user + // TODO: Double check with buda + memgraph::query::AllowEverythingAuthChecker tmp_auth_checker; + auto tmp_user = tmp_auth_checker.GenQueryUser(std::nullopt, std::nullopt); + interpreter.SetUser(tmp_user); + std::ifstream file(cypherl_file_path); if (!file.is_open()) { spdlog::trace("Could not find init file {}", cypherl_file_path); return; @@ -356,6 +362,11 @@ int main(int argc, char **argv) { memgraph::query::InterpreterConfig interp_config{ .query = {.allow_load_csv = FLAGS_allow_load_csv}, .replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec), +#ifdef MG_ENTERPRISE + .instance_down_timeout_sec = 
std::chrono::seconds(FLAGS_instance_down_timeout_sec), + .instance_health_check_frequency_sec = std::chrono::seconds(FLAGS_instance_health_check_frequency_sec), + .instance_get_uuid_frequency_sec = std::chrono::seconds(FLAGS_instance_get_uuid_frequency_sec), +#endif .default_kafka_bootstrap_servers = FLAGS_kafka_bootstrap_servers, .default_pulsar_service_url = FLAGS_pulsar_service_url, .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, diff --git a/src/memory/global_memory_control.cpp b/src/memory/global_memory_control.cpp index 9c3f5db32..bcf12bd2c 100644 --- a/src/memory/global_memory_control.cpp +++ b/src/memory/global_memory_control.cpp @@ -61,10 +61,12 @@ void *my_alloc(extent_hooks_t *extent_hooks, void *new_addr, size_t size, size_t // This needs to be before, to throw exception in case of too big alloc if (*commit) [[likely]] { if (GetQueriesMemoryControl().IsThreadTracked()) [[unlikely]] { - GetQueriesMemoryControl().TrackAllocOnCurrentThread(size); + bool ok = GetQueriesMemoryControl().TrackAllocOnCurrentThread(size); + if (!ok) return nullptr; } // This needs to be here so it doesn't get incremented in case the first TrackAlloc throws an exception - memgraph::utils::total_memory_tracker.Alloc(static_cast(size)); + bool ok = memgraph::utils::total_memory_tracker.Alloc(static_cast(size)); + if (!ok) return nullptr; } auto *ptr = old_hooks->alloc(extent_hooks, new_addr, size, alignment, zero, commit, arena_ind); @@ -118,10 +120,14 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz return err; } + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; if (GetQueriesMemoryControl().IsThreadTracked()) [[unlikely]] { - GetQueriesMemoryControl().TrackAllocOnCurrentThread(length); + bool ok = GetQueriesMemoryControl().TrackAllocOnCurrentThread(length); + DMG_ASSERT(ok); } - memgraph::utils::total_memory_tracker.Alloc(static_cast(length)); + + auto ok = 
memgraph::utils::total_memory_tracker.Alloc(static_cast(length)); + DMG_ASSERT(ok); return false; } diff --git a/src/memory/new_delete.cpp b/src/memory/new_delete.cpp index 2f982ec67..32ed4d4be 100644 --- a/src/memory/new_delete.cpp +++ b/src/memory/new_delete.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -28,6 +28,12 @@ void *newImpl(const std::size_t size) { return ptr; } + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; + auto maybe_msg = memgraph::utils::MemoryErrorStatus().msg(); + if (maybe_msg) { + throw memgraph::utils::OutOfMemoryException{std::move(*maybe_msg)}; + } + throw std::bad_alloc{}; } @@ -37,11 +43,21 @@ void *newImpl(const std::size_t size, const std::align_val_t align) { return ptr; } + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; + auto maybe_msg = memgraph::utils::MemoryErrorStatus().msg(); + if (maybe_msg) { + throw memgraph::utils::OutOfMemoryException{std::move(*maybe_msg)}; + } + throw std::bad_alloc{}; } -void *newNoExcept(const std::size_t size) noexcept { return malloc(size); } +void *newNoExcept(const std::size_t size) noexcept { + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; + return malloc(size); +} void *newNoExcept(const std::size_t size, const std::align_val_t align) noexcept { + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; return aligned_alloc(size, static_cast(align)); } diff --git a/src/memory/query_memory_control.cpp b/src/memory/query_memory_control.cpp index 1a3ede032..119d936ec 100644 --- a/src/memory/query_memory_control.cpp +++ b/src/memory/query_memory_control.cpp @@ -54,14 +54,14 @@ void 
QueriesMemoryControl::EraseThreadToTransactionId(const std::thread::id &thr } } -void QueriesMemoryControl::TrackAllocOnCurrentThread(size_t size) { +bool QueriesMemoryControl::TrackAllocOnCurrentThread(size_t size) { auto thread_id_to_transaction_id_accessor = thread_id_to_transaction_id.access(); // we might be just constructing mapping between thread id and transaction id // so we miss this allocation auto thread_id_to_transaction_id_elem = thread_id_to_transaction_id_accessor.find(std::this_thread::get_id()); if (thread_id_to_transaction_id_elem == thread_id_to_transaction_id_accessor.end()) { - return; + return true; } auto transaction_id_to_tracker_accessor = transaction_id_to_tracker.access(); @@ -71,10 +71,10 @@ void QueriesMemoryControl::TrackAllocOnCurrentThread(size_t size) { // It can happen that some allocation happens between mapping thread to // transaction id, so we miss this allocation if (transaction_id_to_tracker == transaction_id_to_tracker_accessor.end()) [[unlikely]] { - return; + return true; } auto &query_tracker = transaction_id_to_tracker->tracker; - query_tracker.TrackAlloc(size); + return query_tracker.TrackAlloc(size); } void QueriesMemoryControl::TrackFreeOnCurrentThread(size_t size) { diff --git a/src/memory/query_memory_control.hpp b/src/memory/query_memory_control.hpp index 3852027a5..b598a8c73 100644 --- a/src/memory/query_memory_control.hpp +++ b/src/memory/query_memory_control.hpp @@ -62,7 +62,7 @@ class QueriesMemoryControl { // Find tracker for current thread if exists, track // query allocation and procedure allocation if // necessary - void TrackAllocOnCurrentThread(size_t size); + bool TrackAllocOnCurrentThread(size_t size); // Find tracker for current thread if exists, track // query allocation and procedure allocation if diff --git a/src/mg_import_csv.cpp b/src/mg_import_csv.cpp index cbfb905aa..2d77c2db2 100644 --- a/src/mg_import_csv.cpp +++ b/src/mg_import_csv.cpp @@ -139,6 +139,11 @@ struct NodeId { std::string 
id_space; }; +#if FMT_VERSION > 90000 +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif + bool operator==(const NodeId &a, const NodeId &b) { return a.id == b.id && a.id_space == b.id_space; } std::ostream &operator<<(std::ostream &stream, const NodeId &node_id) { diff --git a/src/py/py.hpp b/src/py/py.hpp index 14d54d657..7b25b595e 100644 --- a/src/py/py.hpp +++ b/src/py/py.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -274,3 +274,8 @@ inline void RestoreError(ExceptionInfo exc_info) { } } // namespace memgraph::py + +#if FMT_VERSION > 90000 +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt index 3bc7c9499..d70ede482 100644 --- a/src/query/CMakeLists.txt +++ b/src/query/CMakeLists.txt @@ -40,6 +40,7 @@ set(mg_query_sources db_accessor.cpp auth_query_handler.cpp interpreter_context.cpp + query_user.cpp ) add_library(mg-query STATIC ${mg_query_sources}) diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index 1eb9d02e9..183cbd900 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,7 +16,9 @@ #include #include +#include "dbms/constants.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" #include "storage/v2/id_types.hpp" namespace memgraph::query { @@ -29,15 +31,12 @@ class AuthChecker { public: virtual ~AuthChecker() = default; - [[nodiscard]] virtual bool IsUserAuthorized(const std::optional &username, - const std::vector &privileges, - const std::string &db_name) const = 0; + virtual std::shared_ptr GenQueryUser(const std::optional &username, + const std::optional &rolename) const = 0; #ifdef MG_ENTERPRISE [[nodiscard]] virtual std::unique_ptr GetFineGrainedAuthChecker( - const std::string &username, const DbAccessor *db_accessor) const = 0; - - virtual void ClearCache() const = 0; + std::shared_ptr user, const DbAccessor *db_accessor) const = 0; #endif }; #ifdef MG_ENTERPRISE @@ -98,19 +97,29 @@ class AllowEverythingFineGrainedAuthChecker final : public FineGrainedAuthChecke class AllowEverythingAuthChecker final : public AuthChecker { public: - bool IsUserAuthorized(const std::optional & /*username*/, - const std::vector & /*privileges*/, - const std::string & /*db*/) const override { - return true; + struct User : query::QueryUserOrRole { + User() : query::QueryUserOrRole{std::nullopt, std::nullopt} {} + User(std::string name) : query::QueryUserOrRole{std::move(name), std::nullopt} {} + bool IsAuthorized(const std::vector & /*privileges*/, const std::string & /*db_name*/, + UserPolicy * /*policy*/) const override { + return true; + } +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const override { return std::string{dbms::kDefaultDB}; } +#endif + }; + + std::shared_ptr GenQueryUser(const std::optional &name, + const std::optional & /*role*/) const override { + if (name) return std::make_shared(std::move(*name)); + 
return std::make_shared(); } #ifdef MG_ENTERPRISE - std::unique_ptr GetFineGrainedAuthChecker(const std::string & /*username*/, + std::unique_ptr GetFineGrainedAuthChecker(std::shared_ptr /*user*/, const DbAccessor * /*dba*/) const override { return std::make_unique(); } - - void ClearCache() const override {} #endif }; diff --git a/src/query/auth_query_handler.hpp b/src/query/auth_query_handler.hpp index 0258005c3..acc90c2c5 100644 --- a/src/query/auth_query_handler.hpp +++ b/src/query/auth_query_handler.hpp @@ -46,15 +46,17 @@ class AuthQueryHandler { system::Transaction *system_tx) = 0; #ifdef MG_ENTERPRISE - /// Return true if access revoked successfully - /// @throw QueryRuntimeException if an error ocurred. - virtual bool RevokeDatabaseFromUser(const std::string &db, const std::string &username, - system::Transaction *system_tx) = 0; - /// Return true if access granted successfully /// @throw QueryRuntimeException if an error ocurred. - virtual bool GrantDatabaseToUser(const std::string &db, const std::string &username, - system::Transaction *system_tx) = 0; + virtual void GrantDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; + + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual void DenyDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; + + /// Return true if access revoked successfully + /// @throw QueryRuntimeException if an error ocurred. + virtual void RevokeDatabase(const std::string &db, const std::string &username, system::Transaction *system_tx) = 0; /// Returns database access rights for the user /// @throw QueryRuntimeException if an error ocurred. @@ -62,7 +64,7 @@ class AuthQueryHandler { /// Return true if main database set successfully /// @throw QueryRuntimeException if an error ocurred. 
- virtual bool SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0; + virtual void SetMainDatabase(std::string_view db, const std::string &username, system::Transaction *system_tx) = 0; /// Delete database from all users /// @throw QueryRuntimeException if an error ocurred. diff --git a/src/query/common.hpp b/src/query/common.hpp index 054714164..36ba07791 100644 --- a/src/query/common.hpp +++ b/src/query/common.hpp @@ -19,6 +19,7 @@ #include "query/db_accessor.hpp" #include "query/exceptions.hpp" +#include "query/fmt.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/typed_value.hpp" diff --git a/src/query/config.hpp b/src/query/config.hpp index 88c3dd00e..fe8c2ae4c 100644 --- a/src/query/config.hpp +++ b/src/query/config.hpp @@ -22,6 +22,10 @@ struct InterpreterConfig { // The same as \ref memgraph::replication::ReplicationClientConfig std::chrono::seconds replication_replica_check_frequency{1}; + std::chrono::seconds instance_down_timeout_sec{5}; + std::chrono::seconds instance_health_check_frequency_sec{1}; + std::chrono::seconds instance_get_uuid_frequency_sec{10}; + std::string default_kafka_bootstrap_servers; std::string default_pulsar_service_url; uint32_t stream_transaction_conflict_retries; diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index 71b997d9e..e10102ee5 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -54,6 +54,10 @@ class EdgeAccessor final { return impl_.GetProperty(key, view); } + storage::Result GetPropertySize(storage::PropertyId key, storage::View view) const { + return impl_.GetPropertySize(key, view); + } + storage::Result SetProperty(storage::PropertyId key, const storage::PropertyValue &value) { return impl_.SetProperty(key, value); } @@ -129,6 +133,10 @@ class VertexAccessor final { return impl_.GetProperty(key, view); } + storage::Result GetPropertySize(storage::PropertyId key, 
storage::View view) const { + return impl_.GetPropertySize(key, view); + } + storage::Result SetProperty(storage::PropertyId key, const storage::PropertyValue &value) { return impl_.SetProperty(key, value); } @@ -268,6 +276,10 @@ class SubgraphVertexAccessor final { return impl_.GetProperty(view, key); } + storage::Result GetPropertySize(storage::PropertyId key, storage::View view) const { + return impl_.GetPropertySize(key, view); + } + storage::Gid Gid() const noexcept { return impl_.Gid(); } storage::Result InDegree(storage::View view) const { return impl_.InDegree(view); } @@ -529,6 +541,10 @@ class DbAccessor final { storage::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); } + std::optional NameToPropertyIfExists(std::string_view name) const { + return accessor_->NameToPropertyIfExists(name); + } + storage::LabelId NameToLabel(const std::string_view name) { return accessor_->NameToLabel(name); } storage::EdgeTypeId NameToEdgeType(const std::string_view name) { return accessor_->NameToEdgeType(name); } diff --git a/src/query/fmt.hpp b/src/query/fmt.hpp new file mode 100644 index 000000000..50a915715 --- /dev/null +++ b/src/query/fmt.hpp @@ -0,0 +1,23 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#if FMT_VERSION > 90000 +#include + +#include "query/typed_value.hpp" + +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index b9f637df8..65f1b58ef 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -2819,6 +2819,7 @@ class AuthQuery : public memgraph::query::Query { SHOW_ROLE_FOR_USER, SHOW_USERS_FOR_ROLE, GRANT_DATABASE_TO_USER, + DENY_DATABASE_FROM_USER, REVOKE_DATABASE_FROM_USER, SHOW_DATABASE_PRIVILEGES, SET_MAIN_DATABASE, @@ -3034,8 +3035,6 @@ class ReplicationQuery : public memgraph::query::Query { enum class SyncMode { SYNC, ASYNC }; - enum class ReplicaState { READY, REPLICATING, RECOVERY, MAYBE_BEHIND, DIVERGED_FROM_MAIN }; - ReplicationQuery() = default; DEFVISITABLE(QueryVisitor); @@ -3071,7 +3070,13 @@ class CoordinatorQuery : public memgraph::query::Query { static const utils::TypeInfo kType; const utils::TypeInfo &GetTypeInfo() const override { return kType; } - enum class Action { REGISTER_INSTANCE, SET_INSTANCE_TO_MAIN, SHOW_INSTANCES, ADD_COORDINATOR_INSTANCE }; + enum class Action { + REGISTER_INSTANCE, + UNREGISTER_INSTANCE, + SET_INSTANCE_TO_MAIN, + SHOW_INSTANCES, + ADD_COORDINATOR_INSTANCE + }; enum class SyncMode { SYNC, ASYNC }; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 3316b9490..e8425a8ed 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -405,6 +405,14 @@ antlrcpp::Any CypherMainVisitor::visitRegisterInstanceOnCoordinator( return coordinator_query; } +antlrcpp::Any CypherMainVisitor::visitUnregisterInstanceOnCoordinator( + MemgraphCypher::UnregisterInstanceOnCoordinatorContext *ctx) { + auto *coordinator_query = storage_->Create(); + coordinator_query->action_ = 
CoordinatorQuery::Action::UNREGISTER_INSTANCE; + coordinator_query->instance_name_ = std::any_cast(ctx->instanceName()->symbolicName()->accept(this)); + return coordinator_query; +} + antlrcpp::Any CypherMainVisitor::visitAddCoordinatorInstance(MemgraphCypher::AddCoordinatorInstanceContext *ctx) { auto *coordinator_query = storage_->Create(); @@ -1778,22 +1786,35 @@ antlrcpp::Any CypherMainVisitor::visitShowUsersForRole(MemgraphCypher::ShowUsers /** * @return AuthQuery* */ -antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) { +antlrcpp::Any CypherMainVisitor::visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::GRANT_DATABASE_TO_USER; auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } /** * @return AuthQuery* */ -antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) { +antlrcpp::Any CypherMainVisitor::visitDenyDatabaseFromUserOrRole( + MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) { + auto *auth = storage_->Create(); + auth->action_ = AuthQuery::Action::DENY_DATABASE_FROM_USER; + auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); + return auth; +} + +/** + * @return AuthQuery* + */ +antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUserOrRole( + MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::REVOKE_DATABASE_FROM_USER; auth->database_ = std::any_cast(ctx->wildcardName()->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } @@ -1803,7 
+1824,7 @@ antlrcpp::Any CypherMainVisitor::visitRevokeDatabaseFromUser(MemgraphCypher::Rev antlrcpp::Any CypherMainVisitor::visitShowDatabasePrivileges(MemgraphCypher::ShowDatabasePrivilegesContext *ctx) { auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SHOW_DATABASE_PRIVILEGES; - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } @@ -1814,7 +1835,7 @@ antlrcpp::Any CypherMainVisitor::visitSetMainDatabase(MemgraphCypher::SetMainDat auto *auth = storage_->Create(); auth->action_ = AuthQuery::Action::SET_MAIN_DATABASE; auth->database_ = std::any_cast(ctx->db->accept(this)); - auth->user_ = std::any_cast(ctx->user->accept(this)); + auth->user_ = std::any_cast(ctx->userOrRole->accept(this)); return auth; } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index d011484f8..d627c0d08 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -248,6 +248,12 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitRegisterInstanceOnCoordinator(MemgraphCypher::RegisterInstanceOnCoordinatorContext *ctx) override; + /** + * @return CoordinatorQuery* + */ + antlrcpp::Any visitUnregisterInstanceOnCoordinator( + MemgraphCypher::UnregisterInstanceOnCoordinatorContext *ctx) override; + /** * @return CoordinatorQuery* */ @@ -604,12 +610,17 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { /** * @return AuthQuery* */ - antlrcpp::Any visitGrantDatabaseToUser(MemgraphCypher::GrantDatabaseToUserContext *ctx) override; + antlrcpp::Any visitGrantDatabaseToUserOrRole(MemgraphCypher::GrantDatabaseToUserOrRoleContext *ctx) override; /** * @return AuthQuery* */ - antlrcpp::Any visitRevokeDatabaseFromUser(MemgraphCypher::RevokeDatabaseFromUserContext *ctx) override; + antlrcpp::Any 
visitDenyDatabaseFromUserOrRole(MemgraphCypher::DenyDatabaseFromUserOrRoleContext *ctx) override; + + /** + * @return AuthQuery* + */ + antlrcpp::Any visitRevokeDatabaseFromUserOrRole(MemgraphCypher::RevokeDatabaseFromUserOrRoleContext *ctx) override; /** * @return AuthQuery* diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 779f822d4..769f75b6c 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -177,8 +177,9 @@ authQuery : createRole | showPrivileges | showRoleForUser | showUsersForRole - | grantDatabaseToUser - | revokeDatabaseFromUser + | grantDatabaseToUserOrRole + | denyDatabaseFromUserOrRole + | revokeDatabaseFromUserOrRole | showDatabasePrivileges | setMainDatabase ; @@ -191,6 +192,7 @@ replicationQuery : setReplicationRole ; coordinatorQuery : registerInstanceOnCoordinator + | unregisterInstanceOnCoordinator | setInstanceToMain | showInstances | addCoordinatorInstance @@ -303,13 +305,15 @@ denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegesList ) TO userOrRol revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=revokePrivilegesList ) FROM userOrRole=userOrRoleName ; -grantDatabaseToUser : GRANT DATABASE db=wildcardName TO user=symbolicName ; +grantDatabaseToUserOrRole : GRANT DATABASE db=wildcardName TO userOrRole=userOrRoleName ; -revokeDatabaseFromUser : REVOKE DATABASE db=wildcardName FROM user=symbolicName ; +denyDatabaseFromUserOrRole : DENY DATABASE db=wildcardName FROM userOrRole=userOrRoleName ; -showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR user=symbolicName ; +revokeDatabaseFromUserOrRole : REVOKE DATABASE db=wildcardName FROM userOrRole=userOrRoleName ; -setMainDatabase : SET MAIN DATABASE db=symbolicName FOR user=symbolicName ; +showDatabasePrivileges : SHOW DATABASE PRIVILEGES FOR userOrRole=userOrRoleName ; + +setMainDatabase : SET MAIN DATABASE db=symbolicName 
FOR userOrRole=userOrRoleName ; privilege : CREATE | DELETE @@ -393,6 +397,8 @@ registerReplica : REGISTER REPLICA instanceName ( SYNC | ASYNC ) registerInstanceOnCoordinator : REGISTER INSTANCE instanceName ON coordinatorSocketAddress ( AS ASYNC ) ? WITH replicationSocketAddress ; +unregisterInstanceOnCoordinator : UNREGISTER INSTANCE instanceName ; + setInstanceToMain : SET INSTANCE instanceName TO MAIN ; raftServerId : literal ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index b2d4de661..d40c5d3dc 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -141,6 +141,7 @@ TRIGGER : T R I G G E R ; TRIGGERS : T R I G G E R S ; UNCOMMITTED : U N C O M M I T T E D ; UNLOCK : U N L O C K ; +UNREGISTER : U N R E G I S T E R ; UPDATE : U P D A T E ; USE : U S E ; USER : U S E R ; diff --git a/src/query/interpret/awesome_memgraph_functions.cpp b/src/query/interpret/awesome_memgraph_functions.cpp index ece0aec78..6be8c4837 100644 --- a/src/query/interpret/awesome_memgraph_functions.cpp +++ b/src/query/interpret/awesome_memgraph_functions.cpp @@ -442,6 +442,29 @@ TypedValue Size(const TypedValue *args, int64_t nargs, const FunctionContext &ct } } +TypedValue PropertySize(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { + FType, Or>("propertySize", args, nargs); + + auto *dba = ctx.db_accessor; + + const auto &property_name = args[1].ValueString(); + const auto maybe_property_id = dba->NameToPropertyIfExists(property_name); + + if (!maybe_property_id) { + return TypedValue(0, ctx.memory); + } + + uint64_t property_size = 0; + const auto &graph_entity = args[0]; + if (graph_entity.IsVertex()) { + property_size = graph_entity.ValueVertex().GetPropertySize(*maybe_property_id, ctx.view).GetValue(); + } else if (graph_entity.IsEdge()) { + property_size = 
graph_entity.ValueEdge().GetPropertySize(*maybe_property_id, ctx.view).GetValue(); + } + + return TypedValue(static_cast(property_size), ctx.memory); +} + TypedValue StartNode(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) { FType>("startNode", args, nargs); if (args[0].IsNull()) return TypedValue(ctx.memory); @@ -1325,6 +1348,7 @@ std::function timeout; - uint64_t current_timestamp_of_replica; - uint64_t current_number_of_timestamp_behind_master; - ReplicationQuery::ReplicaState state; - }; - explicit ReplQueryHandler(query::ReplicationQueryHandler &replication_query_handler) : handler_{&replication_query_handler} {} @@ -328,7 +319,7 @@ class ReplQueryHandler { .port = static_cast(*port), }; - if (!handler_->SetReplicationRoleReplica(config, std::nullopt)) { + if (!handler_->TrySetReplicationRoleReplica(config, std::nullopt)) { throw QueryRuntimeException("Couldn't set role to replica!"); } } @@ -369,7 +360,7 @@ class ReplQueryHandler { .replica_check_frequency = replica_check_frequency, .ssl = std::nullopt}; - const auto error = handler_->TryRegisterReplica(replication_config, true).HasError(); + const auto error = handler_->TryRegisterReplica(replication_config).HasError(); if (error) { throw QueryRuntimeException(fmt::format("Couldn't register replica '{}'!", name)); @@ -396,58 +387,16 @@ class ReplQueryHandler { } } - std::vector ShowReplicas(const dbms::Database &db) const { - if (handler_->IsReplica()) { - // replica can't show registered replicas (it shouldn't have any) - throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!"); + std::vector ShowReplicas() const { + auto info = handler_->ShowReplicas(); + if (info.HasError()) { + switch (info.GetError()) { + case ShowReplicaError::NOT_MAIN: + throw QueryRuntimeException("Replica can't show registered replicas (it shouldn't have any)!"); + } } - // TODO: Combine results? Have a single place with clients??? 
- // Also authentication checks (replica + database visibility) - const auto repl_infos = db.storage()->ReplicasInfo(); - std::vector replicas; - replicas.reserve(repl_infos.size()); - - const auto from_info = [](const auto &repl_info) -> ReplicaInfo { - ReplicaInfo replica; - replica.name = repl_info.name; - replica.socket_address = repl_info.endpoint.SocketAddress(); - switch (repl_info.mode) { - case replication_coordination_glue::ReplicationMode::SYNC: - replica.sync_mode = ReplicationQuery::SyncMode::SYNC; - break; - case replication_coordination_glue::ReplicationMode::ASYNC: - replica.sync_mode = ReplicationQuery::SyncMode::ASYNC; - break; - } - - replica.current_timestamp_of_replica = repl_info.timestamp_info.current_timestamp_of_replica; - replica.current_number_of_timestamp_behind_master = - repl_info.timestamp_info.current_number_of_timestamp_behind_master; - - switch (repl_info.state) { - case storage::replication::ReplicaState::READY: - replica.state = ReplicationQuery::ReplicaState::READY; - break; - case storage::replication::ReplicaState::REPLICATING: - replica.state = ReplicationQuery::ReplicaState::REPLICATING; - break; - case storage::replication::ReplicaState::RECOVERY: - replica.state = ReplicationQuery::ReplicaState::RECOVERY; - break; - case storage::replication::ReplicaState::MAYBE_BEHIND: - replica.state = ReplicationQuery::ReplicaState::MAYBE_BEHIND; - break; - case storage::replication::ReplicaState::DIVERGED_FROM_MAIN: - replica.state = ReplicationQuery::ReplicaState::DIVERGED_FROM_MAIN; - break; - } - - return replica; - }; - - std::transform(repl_infos.begin(), repl_infos.end(), std::back_inserter(replicas), from_info); - return replicas; + return info.GetValue().entries_; } private: @@ -461,11 +410,33 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { : coordinator_handler_(coordinator_state) {} - /// @throw QueryRuntimeException if an error ocurred. 
- void RegisterReplicationInstance(const std::string &coordinator_socket_address, - const std::string &replication_socket_address, - const std::chrono::seconds instance_check_frequency, - const std::string &instance_name, CoordinatorQuery::SyncMode sync_mode) override { + void UnregisterInstance(std::string const &instance_name) override { + auto status = coordinator_handler_.UnregisterReplicationInstance(instance_name); + switch (status) { + using enum memgraph::coordination::UnregisterInstanceCoordinatorStatus; + case NO_INSTANCE_WITH_NAME: + throw QueryRuntimeException("No instance with such name!"); + case IS_MAIN: + throw QueryRuntimeException( + "Alive main instance can't be unregistered! Shut it down to trigger failover and then unregister it!"); + case NOT_COORDINATOR: + throw QueryRuntimeException("UNREGISTER INSTANCE query can only be run on a coordinator!"); + case NOT_LEADER: + throw QueryRuntimeException("Couldn't unregister replica instance since coordinator is not a leader!"); + case RPC_FAILED: + throw QueryRuntimeException( + "Couldn't unregister replica instance because current main instance couldn't unregister replica!"); + case SUCCESS: + break; + } + } + + void RegisterReplicationInstance(std::string const &coordinator_socket_address, + std::string const &replication_socket_address, + std::chrono::seconds const &instance_check_frequency, + std::chrono::seconds const &instance_down_timeout, + std::chrono::seconds const &instance_get_uuid_frequency, + std::string const &instance_name, CoordinatorQuery::SyncMode sync_mode) override { const auto maybe_replication_ip_port = io::network::Endpoint::ParseSocketOrAddress(replication_socket_address, std::nullopt); if (!maybe_replication_ip_port) { @@ -490,7 +461,9 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { coordination::CoordinatorClientConfig{.instance_name = instance_name, .ip_address = coordinator_server_ip, .port = coordinator_server_port, - .health_check_frequency_sec 
= instance_check_frequency, + .instance_health_check_frequency_sec = instance_check_frequency, + .instance_down_timeout_sec = instance_down_timeout, + .instance_get_uuid_frequency_sec = instance_get_uuid_frequency, .replication_client_info = repl_config, .ssl = std::nullopt}; @@ -538,6 +511,8 @@ class CoordQueryHandler final : public query::CoordinatorQueryHandler { using enum memgraph::coordination::SetInstanceToMainCoordinatorStatus; case NO_INSTANCE_WITH_NAME: throw QueryRuntimeException("No instance with such name!"); + case MAIN_ALREADY_EXISTS: + throw QueryRuntimeException("Couldn't set instance to main since there is already a main instance in cluster!"); case NOT_COORDINATOR: throw QueryRuntimeException("SET INSTANCE TO MAIN query can only be run on a coordinator!"); case COULD_NOT_PROMOTE_TO_MAIN: @@ -603,6 +578,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ AuthQuery::Action::SHOW_USERS_FOR_ROLE, AuthQuery::Action::SHOW_ROLE_FOR_USER, AuthQuery::Action::GRANT_DATABASE_TO_USER, + AuthQuery::Action::DENY_DATABASE_FROM_USER, AuthQuery::Action::REVOKE_DATABASE_FROM_USER, AuthQuery::Action::SHOW_DATABASE_PRIVILEGES, AuthQuery::Action::SET_MAIN_DATABASE}; @@ -862,9 +838,31 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->GrantDatabaseToUser(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to grant database {} to user {}.", database, username); + auth->GrantDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception + } catch (memgraph::dbms::UnknownDatabaseException &e) { + throw QueryRuntimeException(e.what()); + } +#else + callback.fn = [] { +#endif + return std::vector>(); + }; + return callback; + case 
AuthQuery::Action::DENY_DATABASE_FROM_USER: + forbid_on_replica(); +#ifdef MG_ENTERPRISE + callback.fn = [auth, database, username, db_handler, interpreter = &interpreter] { // NOLINT + if (!interpreter->system_transaction_) { + throw QueryException("Expected to be in a system transaction"); + } + + try { + std::optional db = + std::nullopt; // Hold pointer to database to protect it until query is done + if (database != memgraph::auth::kAllDatabases) { + db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } + auth->DenyDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { throw QueryRuntimeException(e.what()); } @@ -888,9 +886,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ if (database != memgraph::auth::kAllDatabases) { db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull } - if (!auth->RevokeDatabaseFromUser(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to revoke database {} from user {}.", database, username); - } + auth->RevokeDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { throw QueryRuntimeException(e.what()); } @@ -923,9 +919,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ try { const auto db = db_handler->Get(database); // Will throw if databases doesn't exist and protect it during pull - if (!auth->SetMainDatabase(database, username, &*interpreter->system_transaction_)) { - throw QueryRuntimeException("Failed to set main database {} for user {}.", database, username); - } + auth->SetMainDatabase(database, username, &*interpreter->system_transaction_); // Can throws query exception } catch (memgraph::dbms::UnknownDatabaseException &e) { 
throw QueryRuntimeException(e.what()); } @@ -1046,50 +1040,98 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & } #endif - callback.header = { - "name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master", - "state"}; + bool full_info = false; +#ifdef MG_ENTERPRISE + full_info = license::global_license_checker.IsEnterpriseValidFast(); +#endif + + callback.header = {"name", "socket_address", "sync_mode", "system_info", "data_info"}; + callback.fn = [handler = ReplQueryHandler{replication_query_handler}, replica_nfields = callback.header.size(), - db_acc = current_db.db_acc_] { - const auto &replicas = handler.ShowReplicas(*db_acc->get()); + full_info] { + auto const sync_mode_to_tv = [](memgraph::replication_coordination_glue::ReplicationMode sync_mode) { + using namespace std::string_view_literals; + switch (sync_mode) { + using enum memgraph::replication_coordination_glue::ReplicationMode; + case SYNC: + return TypedValue{"sync"sv}; + case ASYNC: + return TypedValue{"async"sv}; + } + }; + + auto const replica_sys_state_to_tv = [](memgraph::replication::ReplicationClient::State state) { + using namespace std::string_view_literals; + switch (state) { + using enum memgraph::replication::ReplicationClient::State; + case BEHIND: + return TypedValue{"invalid"sv}; + case READY: + return TypedValue{"ready"sv}; + case RECOVERY: + return TypedValue{"recovery"sv}; + } + }; + + auto const sys_info_to_tv = [&](ReplicaSystemInfoState orig) { + auto info = std::map{}; + info.emplace("ts", TypedValue{static_cast(orig.ts_)}); + // TODO: behind not implemented + info.emplace("behind", TypedValue{/* static_cast(orig.behind_) */}); + info.emplace("status", replica_sys_state_to_tv(orig.state_)); + return TypedValue{std::move(info)}; + }; + + auto const replica_state_to_tv = [](memgraph::storage::replication::ReplicaState state) { + using namespace std::string_view_literals; + switch (state) { + using enum 
memgraph::storage::replication::ReplicaState; + case READY: + return TypedValue{"ready"sv}; + case REPLICATING: + return TypedValue{"replicating"sv}; + case RECOVERY: + return TypedValue{"recovery"sv}; + case MAYBE_BEHIND: + return TypedValue{"invalid"sv}; + case DIVERGED_FROM_MAIN: + return TypedValue{"diverged"sv}; + } + }; + + auto const info_to_tv = [&](ReplicaInfoState orig) { + auto info = std::map{}; + info.emplace("ts", TypedValue{static_cast(orig.ts_)}); + info.emplace("behind", TypedValue{static_cast(orig.behind_)}); + info.emplace("status", replica_state_to_tv(orig.state_)); + return TypedValue{std::move(info)}; + }; + + auto const data_info_to_tv = [&](std::map orig) { + auto data_info = std::map{}; + for (auto &[name, info] : orig) { + data_info.emplace(name, info_to_tv(info)); + } + return TypedValue{std::move(data_info)}; + }; + + auto replicas = handler.ShowReplicas(); auto typed_replicas = std::vector>{}; typed_replicas.reserve(replicas.size()); - for (const auto &replica : replicas) { + for (auto &replica : replicas) { std::vector typed_replica; typed_replica.reserve(replica_nfields); - typed_replica.emplace_back(replica.name); - typed_replica.emplace_back(replica.socket_address); - - switch (replica.sync_mode) { - case ReplicationQuery::SyncMode::SYNC: - typed_replica.emplace_back("sync"); - break; - case ReplicationQuery::SyncMode::ASYNC: - typed_replica.emplace_back("async"); - break; - } - - typed_replica.emplace_back(static_cast(replica.current_timestamp_of_replica)); - typed_replica.emplace_back(static_cast(replica.current_number_of_timestamp_behind_master)); - - switch (replica.state) { - case ReplicationQuery::ReplicaState::READY: - typed_replica.emplace_back("ready"); - break; - case ReplicationQuery::ReplicaState::REPLICATING: - typed_replica.emplace_back("replicating"); - break; - case ReplicationQuery::ReplicaState::RECOVERY: - typed_replica.emplace_back("recovery"); - break; - case ReplicationQuery::ReplicaState::MAYBE_BEHIND: - 
typed_replica.emplace_back("invalid"); - break; - case ReplicationQuery::ReplicaState::DIVERGED_FROM_MAIN: - typed_replica.emplace_back("diverged"); - break; + typed_replica.emplace_back(replica.name_); + typed_replica.emplace_back(replica.socket_address_); + typed_replica.emplace_back(sync_mode_to_tv(replica.sync_mode_)); + if (full_info) { + typed_replica.emplace_back(sys_info_to_tv(replica.system_info_)); + } else { + // Set to NULL + typed_replica.emplace_back(TypedValue{}); } + typed_replica.emplace_back(data_info_to_tv(replica.data_info_)); typed_replicas.emplace_back(std::move(typed_replica)); } @@ -1104,17 +1146,21 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters & Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Parameters ¶meters, coordination::CoordinatorState *coordinator_state, const query::InterpreterConfig &config, std::vector *notifications) { + using enum memgraph::flags::Experiments; + + if (!license::global_license_checker.IsEnterpriseValidFast()) { + throw QueryRuntimeException("High availability is only available in Memgraph Enterprise."); + } + + if (!flags::AreExperimentsEnabled(HIGH_AVAILABILITY)) { + throw QueryRuntimeException( + "High availability is experimental feature. If you want to use it, add high-availability option to the " + "--experimental-enabled flag."); + } + Callback callback; switch (coordinator_query->action_) { case CoordinatorQuery::Action::ADD_COORDINATOR_INSTANCE: { - if (!license::global_license_checker.IsEnterpriseValidFast()) { - throw QueryException("Trying to use enterprise feature without a valid license."); - } - if constexpr (!coordination::allow_ha) { - throw QueryRuntimeException( - "High availability is experimental feature. 
Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " - "be able to use this functionality."); - } if (!FLAGS_raft_server_id) { throw QueryRuntimeException("Only coordinator can add coordinator instance!"); } @@ -1138,15 +1184,6 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param return callback; } case CoordinatorQuery::Action::REGISTER_INSTANCE: { - if (!license::global_license_checker.IsEnterpriseValidFast()) { - throw QueryException("Trying to use enterprise feature without a valid license."); - } - - if constexpr (!coordination::allow_ha) { - throw QueryRuntimeException( - "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " - "be able to use this functionality."); - } if (!FLAGS_raft_server_id) { throw QueryRuntimeException("Only coordinator can register coordinator server!"); } @@ -1158,12 +1195,16 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param auto coordinator_socket_address_tv = coordinator_query->coordinator_socket_address_->Accept(evaluator); auto replication_socket_address_tv = coordinator_query->replication_socket_address_->Accept(evaluator); callback.fn = [handler = CoordQueryHandler{*coordinator_state}, coordinator_socket_address_tv, - replication_socket_address_tv, main_check_frequency = config.replication_replica_check_frequency, + replication_socket_address_tv, + instance_health_check_frequency_sec = config.instance_health_check_frequency_sec, instance_name = coordinator_query->instance_name_, + instance_down_timeout_sec = config.instance_down_timeout_sec, + instance_get_uuid_frequency_sec = config.instance_get_uuid_frequency_sec, sync_mode = coordinator_query->sync_mode_]() mutable { handler.RegisterReplicationInstance(std::string(coordinator_socket_address_tv.ValueString()), std::string(replication_socket_address_tv.ValueString()), - main_check_frequency, instance_name, sync_mode); + 
instance_health_check_frequency_sec, instance_down_timeout_sec, + instance_get_uuid_frequency_sec, instance_name, sync_mode); return std::vector>(); }; @@ -1173,15 +1214,22 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param coordinator_socket_address_tv.ValueString(), coordinator_query->instance_name_)); return callback; } + case CoordinatorQuery::Action::UNREGISTER_INSTANCE: + if (!FLAGS_raft_server_id) { + throw QueryRuntimeException("Only coordinator can register coordinator server!"); + } + callback.fn = [handler = CoordQueryHandler{*coordinator_state}, + instance_name = coordinator_query->instance_name_]() mutable { + handler.UnregisterInstance(instance_name); + return std::vector>(); + }; + notifications->emplace_back( + SeverityLevel::INFO, NotificationCode::UNREGISTER_INSTANCE, + fmt::format("Coordinator has unregistered instance {}.", coordinator_query->instance_name_)); + + return callback; + case CoordinatorQuery::Action::SET_INSTANCE_TO_MAIN: { - if (!license::global_license_checker.IsEnterpriseValidFast()) { - throw QueryException("Trying to use enterprise feature without a valid license."); - } - if constexpr (!coordination::allow_ha) { - throw QueryRuntimeException( - "High availability is experimental feature. Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " - "be able to use this functionality."); - } if (!FLAGS_raft_server_id) { throw QueryRuntimeException("Only coordinator can register coordinator server!"); } @@ -1199,14 +1247,6 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param return callback; } case CoordinatorQuery::Action::SHOW_INSTANCES: { - if (!license::global_license_checker.IsEnterpriseValidFast()) { - throw QueryException("Trying to use enterprise feature without a valid license."); - } - if constexpr (!coordination::allow_ha) { - throw QueryRuntimeException( - "High availability is experimental feature. 
Please set MG_EXPERIMENTAL_HIGH_AVAILABILITY compile flag to " - "be able to use this functionality."); - } if (!FLAGS_raft_server_id) { throw QueryRuntimeException("Only coordinator can run SHOW INSTANCES."); } @@ -1215,17 +1255,13 @@ Callback HandleCoordinatorQuery(CoordinatorQuery *coordinator_query, const Param callback.fn = [handler = CoordQueryHandler{*coordinator_state}, replica_nfields = callback.header.size()]() mutable { auto const instances = handler.ShowInstances(); - std::vector> result{}; - result.reserve(result.size()); + auto const converter = [](const auto &status) -> std::vector { + return {TypedValue{status.instance_name}, TypedValue{status.raft_socket_address}, + TypedValue{status.coord_socket_address}, TypedValue{status.is_alive}, + TypedValue{status.cluster_role}}; + }; - std::ranges::transform(instances, std::back_inserter(result), - [](const auto &status) -> std::vector { - return {TypedValue{status.instance_name}, TypedValue{status.raft_socket_address}, - TypedValue{status.coord_socket_address}, TypedValue{status.is_alive}, - TypedValue{status.cluster_role}}; - }); - - return result; + return utils::fmap(converter, instances); }; return callback; } @@ -1255,7 +1291,7 @@ std::vector EvaluateTopicNames(ExpressionVisitor &evalu Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, memgraph::dbms::DatabaseAccess db_acc, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer"; std::string consumer_group{stream_query->consumer_group_.empty() ? 
kDefaultConsumerGroup : stream_query->consumer_group_}; @@ -1282,10 +1318,13 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename()); + return [db_acc = std::move(db_acc), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), consumer_group = std::move(consumer_group), common_stream_info = std::move(common_stream_info), - bootstrap_servers = std::move(bootstrap), owner = username, + bootstrap_servers = std::move(bootstrap), owner = std::move(owner), configs = get_config_map(stream_query->configs_, "Configs"), credentials = get_config_map(stream_query->credentials_, "Credentials"), default_server = interpreter_context->config.default_kafka_bootstrap_servers]() mutable { @@ -1307,7 +1346,7 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor &evaluator, memgraph::dbms::DatabaseAccess db, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator); if (service_url && service_url->empty()) { throw SemanticException("Service URL must not be an empty string!"); @@ -1315,9 +1354,13 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex auto common_stream_info = GetCommonStreamInfo(stream_query, evaluator); memgraph::metrics::IncrementCounter(memgraph::metrics::StreamsCreated); + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), 
user_or_role->rolename()); + return [db = std::move(db), interpreter_context, stream_name = stream_query->stream_name_, topic_names = EvaluateTopicNames(evaluator, stream_query->topic_names_), - common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), owner = username, + common_stream_info = std::move(common_stream_info), service_url = std::move(service_url), + owner = std::move(owner), default_service = interpreter_context->config.default_pulsar_service_url]() mutable { std::string url = service_url ? std::move(*service_url) : std::move(default_service); db->streams()->Create( @@ -1331,7 +1374,7 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, memgraph::dbms::DatabaseAccess &db_acc, InterpreterContext *interpreter_context, - const std::optional &username, std::vector *notifications) { + std::shared_ptr user_or_role, std::vector *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. 
EvaluationContext evaluation_context; @@ -1344,10 +1387,12 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete case StreamQuery::Action::CREATE_STREAM: { switch (stream_query->type_) { case StreamQuery::Type::KAFKA: - callback.fn = GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); + callback.fn = + GetKafkaCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role)); break; case StreamQuery::Type::PULSAR: - callback.fn = GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, username); + callback.fn = + GetPulsarCreateCallback(stream_query, evaluator, db_acc, interpreter_context, std::move(user_or_role)); break; } notifications->emplace_back(SeverityLevel::INFO, NotificationCode::CREATE_STREAM, @@ -1620,7 +1665,7 @@ struct TxTimeout { struct PullPlan { explicit PullPlan(std::shared_ptr plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional username, std::atomic *transaction_status, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, TriggerContextCollector *trigger_context_collector = nullptr, std::optional memory_limit = {}, bool use_monotonic_memory = true, @@ -1660,7 +1705,7 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional username, std::atomic *transaction_status, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, TriggerContextCollector *trigger_context_collector, const std::optional memory_limit, bool use_monotonic_memory, FrameChangeCollector *frame_change_collector) @@ -1676,10 +1721,9 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &pa 
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); #ifdef MG_ENTERPRISE - if (license::global_license_checker.IsEnterpriseValidFast() && username.has_value() && dba) { - // TODO How can we avoid creating this every time? If we must create it, it would be faster with an auth::User - // instead of the username - auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(*username, dba); + if (license::global_license_checker.IsEnterpriseValidFast() && user_or_role && *user_or_role && dba) { + // Create only if an explicit user is defined + auto auth_checker = interpreter_context->auth_checker->GetFineGrainedAuthChecker(std::move(user_or_role), dba); // if the user has global privileges to read, edit and write anything, we don't need to perform authorization // otherwise, we do assign the auth checker to check for label access control @@ -1969,7 +2013,7 @@ bool IsCallBatchedProcedureQuery(const std::vector &c PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, CurrentDB ¤t_db, utils::MemoryResource *execution_memory, std::vector *notifications, - std::optional const &username, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, FrameChangeCollector *frame_change_collector = nullptr) { @@ -2037,8 +2081,8 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map( - plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, username, transaction_status, - std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory, + plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, std::move(user_or_role), + transaction_status, std::move(tx_timer), trigger_context_collector, memory_limit, use_monotonic_memory, 
frame_change_collector->IsTrackingValues() ? frame_change_collector : nullptr); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( @@ -2110,7 +2154,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map *summary, std::vector *notifications, InterpreterContext *interpreter_context, CurrentDB ¤t_db, - utils::MemoryResource *execution_memory, std::optional const &username, + utils::MemoryResource *execution_memory, + std::shared_ptr user_or_role, std::atomic *transaction_status, std::shared_ptr tx_timer, FrameChangeCollector *frame_change_collector) { @@ -2188,37 +2233,37 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra rw_type_checker.InferRWType(const_cast(cypher_query_plan->plan())); - return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, - std::move(parsed_query.required_privileges), - [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), - summary, dba, interpreter_context, execution_memory, memory_limit, username, - // We want to execute the query we are profiling lazily, so we delay - // the construction of the corresponding context. - stats_and_total_time = std::optional{}, - pull_plan = std::shared_ptr(nullptr), transaction_status, use_monotonic_memory, - frame_change_collector, tx_timer = std::move(tx_timer)]( - AnyStream *stream, std::optional n) mutable -> std::optional { - // No output symbols are given so that nothing is streamed. - if (!stats_and_total_time) { - stats_and_total_time = - PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, username, - transaction_status, std::move(tx_timer), nullptr, memory_limit, - use_monotonic_memory, - frame_change_collector->IsTrackingValues() ? 
frame_change_collector : nullptr) - .Pull(stream, {}, {}, summary); - pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); - } + return PreparedQuery{ + {"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, + std::move(parsed_query.required_privileges), + [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), summary, dba, + interpreter_context, execution_memory, memory_limit, user_or_role = std::move(user_or_role), + // We want to execute the query we are profiling lazily, so we delay + // the construction of the corresponding context. + stats_and_total_time = std::optional{}, + pull_plan = std::shared_ptr(nullptr), transaction_status, use_monotonic_memory, + frame_change_collector, tx_timer = std::move(tx_timer)]( + AnyStream *stream, std::optional n) mutable -> std::optional { + // No output symbols are given so that nothing is streamed. + if (!stats_and_total_time) { + stats_and_total_time = + PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, std::move(user_or_role), + transaction_status, std::move(tx_timer), nullptr, memory_limit, use_monotonic_memory, + frame_change_collector->IsTrackingValues() ? 
frame_change_collector : nullptr) + .Pull(stream, {}, {}, summary); + pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); + } - MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); + MG_ASSERT(stats_and_total_time, "Failed to execute the query!"); - if (pull_plan->Pull(stream, n)) { - summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); - return QueryHandlerResult::ABORT; - } + if (pull_plan->Pull(stream, n)) { + summary->insert_or_assign("profile", ProfilingStatsToJson(*stats_and_total_time).dump()); + return QueryHandlerResult::ABORT; + } - return std::nullopt; - }, - rw_type_checker.type}; + return std::nullopt; + }, + rw_type_checker.type}; } PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, CurrentDB ¤t_db) { @@ -2642,26 +2687,22 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, interpreter); - return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), - [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr), - interpreter_context]( // NOLINT - AnyStream *stream, std::optional n) mutable -> std::optional { - if (!pull_plan) { - // Run the specific query - auto results = handler(); - pull_plan = std::make_shared(std::move(results)); -#ifdef MG_ENTERPRISE - // Invalidate auth cache after every type of AuthQuery - interpreter_context->auth_checker->ClearCache(); -#endif - } + return PreparedQuery{ + std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), pull_plan = std::shared_ptr(nullptr)]( // NOLINT + AnyStream *stream, std::optional n) mutable -> std::optional { + if (!pull_plan) { + // Run the specific query + auto results = handler(); + pull_plan = std::make_shared(std::move(results)); + } - if (pull_plan->Pull(stream, n)) { - return 
QueryHandlerResult::COMMIT; - } - return std::nullopt; - }, - RWType::NONE}; + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; } PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, @@ -2865,17 +2906,18 @@ TriggerEventType ToTriggerEventType(const TriggerQuery::EventType event_type) { Callback CreateTrigger(TriggerQuery *trigger_query, const std::map &user_parameters, TriggerStore *trigger_store, InterpreterContext *interpreter_context, DbAccessor *dba, - std::optional owner) { + std::shared_ptr user_or_role) { + // Make a copy of the user and pass it to the subsystem + auto owner = interpreter_context->auth_checker->GenQueryUser(user_or_role->username(), user_or_role->rolename()); return {{}, [trigger_name = std::move(trigger_query->trigger_name_), trigger_statement = std::move(trigger_query->statement_), event_type = trigger_query->event_type_, before_commit = trigger_query->before_commit_, trigger_store, interpreter_context, dba, user_parameters, owner = std::move(owner)]() mutable -> std::vector> { - trigger_store->AddTrigger(std::move(trigger_name), trigger_statement, user_parameters, - ToTriggerEventType(event_type), - before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, - &interpreter_context->ast_cache, dba, interpreter_context->config.query, - std::move(owner), interpreter_context->auth_checker); + trigger_store->AddTrigger( + std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type), + before_commit ? 
TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, + &interpreter_context->ast_cache, dba, interpreter_context->config.query, std::move(owner)); memgraph::metrics::IncrementCounter(memgraph::metrics::TriggersCreated); return {}; }}; @@ -2917,7 +2959,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra std::vector *notifications, CurrentDB ¤t_db, InterpreterContext *interpreter_context, const std::map &user_parameters, - std::optional const &username) { + std::shared_ptr user_or_role) { if (in_explicit_transaction) { throw TriggerModificationInMulticommandTxException(); } @@ -2931,8 +2973,9 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra MG_ASSERT(trigger_query); std::optional trigger_notification; + auto callback = std::invoke([trigger_query, trigger_store, interpreter_context, dba, &user_parameters, - owner = username, &trigger_notification]() mutable { + owner = std::move(user_or_role), &trigger_notification]() mutable { switch (trigger_query->action_) { case TriggerQuery::Action::CREATE_TRIGGER: trigger_notification.emplace(SeverityLevel::INFO, NotificationCode::CREATE_TRIGGER, @@ -2970,7 +3013,8 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector *notifications, CurrentDB ¤t_db, - InterpreterContext *interpreter_context, const std::optional &username) { + InterpreterContext *interpreter_context, + std::shared_ptr user_or_role) { if (in_explicit_transaction) { throw StreamQueryInMulticommandTxException(); } @@ -2980,8 +3024,8 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran auto *stream_query = utils::Downcast(parsed_query.query); MG_ASSERT(stream_query); - auto callback = - HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, username, notifications); + auto callback = 
HandleStreamQuery(stream_query, parsed_query.parameters, db_acc, interpreter_context, + std::move(user_or_role), notifications); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -3327,7 +3371,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra } template -auto ShowTransactions(const std::unordered_set &interpreters, const std::optional &username, +auto ShowTransactions(const std::unordered_set &interpreters, QueryUserOrRole *user_or_role, Func &&privilege_checker) -> std::vector> { std::vector> results; results.reserve(interpreters.size()); @@ -3347,11 +3391,21 @@ auto ShowTransactions(const std::unordered_set &interpreters, con static std::string all; return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all; }; - if (transaction_id.has_value() && - (interpreter->username_ == username || privilege_checker(get_interpreter_db_name()))) { + + auto same_user = [](const auto &lv, const auto &rv) { + if (lv.get() == rv) return true; + if (lv && rv) return *lv == *rv; + return false; + }; + + if (transaction_id.has_value() && (same_user(interpreter->user_or_role_, user_or_role) || + privilege_checker(user_or_role, get_interpreter_db_name()))) { const auto &typed_queries = interpreter->GetQueries(); - results.push_back({TypedValue(interpreter->username_.value_or("")), - TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)}); + results.push_back( + {TypedValue(interpreter->user_or_role_ + ? (interpreter->user_or_role_->username() ? 
*interpreter->user_or_role_->username() : "") + : ""), + TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)}); // Handle user-defined metadata std::map metadata_tv; if (interpreter->metadata_) { @@ -3366,17 +3420,19 @@ auto ShowTransactions(const std::unordered_set &interpreters, con } Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, - const std::optional &username, const Parameters ¶meters, + std::shared_ptr user_or_role, const Parameters ¶meters, InterpreterContext *interpreter_context) { - auto privilege_checker = [username, auth_checker = interpreter_context->auth_checker](std::string const &db_name) { - return auth_checker->IsUserAuthorized(username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name); + auto privilege_checker = [](QueryUserOrRole *user_or_role, std::string const &db_name) { + return user_or_role && user_or_role->IsAuthorized({query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, db_name, + &query::up_to_date_policy); }; Callback callback; switch (transaction_query->action_) { case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: { - auto show_transactions = [username, privilege_checker = std::move(privilege_checker)](const auto &interpreters) { - return ShowTransactions(interpreters, username, privilege_checker); + auto show_transactions = [user_or_role = std::move(user_or_role), + privilege_checker = std::move(privilege_checker)](const auto &interpreters) { + return ShowTransactions(interpreters, user_or_role.get(), privilege_checker); }; callback.header = {"username", "transaction_id", "query", "metadata"}; callback.fn = [interpreter_context, show_transactions = std::move(show_transactions)] { @@ -3394,9 +3450,10 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, return std::string(expression->Accept(evaluator).ValueString()); }); callback.header = {"transaction_id", "killed"}; - callback.fn = [interpreter_context, maybe_kill_transaction_ids 
= std::move(maybe_kill_transaction_ids), username, + callback.fn = [interpreter_context, maybe_kill_transaction_ids = std::move(maybe_kill_transaction_ids), + user_or_role = std::move(user_or_role), privilege_checker = std::move(privilege_checker)]() mutable { - return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), username, + return interpreter_context->TerminateTransactions(std::move(maybe_kill_transaction_ids), user_or_role.get(), std::move(privilege_checker)); }; break; @@ -3406,12 +3463,12 @@ Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, return callback; } -PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional &username, +PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, std::shared_ptr user_or_role, InterpreterContext *interpreter_context) { auto *transaction_queue_query = utils::Downcast(parsed_query.query); MG_ASSERT(transaction_queue_query); - auto callback = - HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context); + auto callback = HandleTransactionQueueQuery(transaction_queue_query, std::move(user_or_role), parsed_query.parameters, + interpreter_context); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr{nullptr}]( @@ -4044,7 +4101,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur } PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context, - const std::optional &username) { + std::shared_ptr user_or_role) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); @@ -4055,7 +4112,8 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon Callback 
callback; callback.header = {"Name"}; - callback.fn = [auth, db_handler, username]() mutable -> std::vector> { + callback.fn = [auth, db_handler, + user_or_role = std::move(user_or_role)]() mutable -> std::vector> { std::vector> status; auto gen_status = [&](T all, K denied) { Sort(all); @@ -4077,12 +4135,12 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon status.erase(iter, status.end()); }; - if (!username) { + if (!user_or_role || !*user_or_role) { // No user, return all gen_status(db_handler->All(), std::vector{}); } else { // User has a subset of accessible dbs; this is synched with the SessionContextHandler - const auto &db_priv = auth->GetDatabasePrivileges(*username); + const auto &db_priv = auth->GetDatabasePrivileges(user_or_role->key()); const auto &allowed = db_priv[0][0]; const auto &denied = db_priv[0][1].ValueList(); if (allowed.IsString() && allowed.ValueString() == auth::kAllDatabases) { @@ -4150,6 +4208,7 @@ void Interpreter::SetCurrentDB(std::string_view db_name, bool in_explicit_db) { Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::map ¶ms, QueryExtras const &extras) { + MG_ASSERT(user_or_role_, "Trying to prepare a query without a query user."); // Handle transaction control queries. 
const auto upper_case_query = utils::ToUpperCase(query_string); const auto trimmed_query = utils::Trim(upper_case_query); @@ -4292,7 +4351,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, frame_change_collector_.emplace(); if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, - current_db_, memory_resource, &query_execution->notifications, username_, + current_db_, memory_resource, &query_execution->notifications, user_or_role_, &transaction_status_, current_timeout_timer_, &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, @@ -4300,7 +4359,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, &query_execution->notifications, interpreter_context_, current_db_, - &query_execution->execution_memory_with_exception, username_, + &query_execution->execution_memory_with_exception, user_or_role_, &transaction_status_, current_timeout_timer_, &*frame_change_collector_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), current_db_); @@ -4344,11 +4403,11 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - current_db_, interpreter_context_, params, username_); + current_db_, interpreter_context_, params, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, 
&query_execution->notifications, - current_db_, interpreter_context_, username_); + current_db_, interpreter_context_, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, current_db_, this); } else if (utils::Downcast(parsed_query.query)) { @@ -4369,7 +4428,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (in_explicit_transaction_) { throw TransactionQueueInMulticommandTxException(); } - prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, interpreter_context_); + prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), user_or_role_, interpreter_context_); } else if (utils::Downcast(parsed_query.query)) { if (in_explicit_transaction_) { throw MultiDatabaseQueryInMulticommandTxException(); @@ -4379,7 +4438,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, *this); } else if (utils::Downcast(parsed_query.query)) { - prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, username_); + prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, user_or_role_); } else if (utils::Downcast(parsed_query.query)) { if (in_explicit_transaction_) { throw EdgeImportModeModificationInMulticommandTxException(); @@ -4402,9 +4461,19 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, UpdateTypeCount(rw_type); - if (interpreter_context_->repl_state->IsReplica() && IsQueryWrite(rw_type)) { - query_execution = nullptr; - throw QueryException("Write query forbidden on the replica!"); + bool const write_query = IsQueryWrite(rw_type); + if (write_query) { + if (interpreter_context_->repl_state->IsReplica()) { + query_execution = nullptr; + throw 
QueryException("Write query forbidden on the replica!"); + } +#ifdef MG_ENTERPRISE + if (FLAGS_coordinator_server_port && !interpreter_context_->repl_state->IsMainWriteable()) { + query_execution = nullptr; + throw QueryException( + "Write query forbidden on the main! Coordinator needs to enable writing on main by sending RPC message."); + } +#endif } // Set the target db to the current db (some queries have different target from the current db) @@ -4450,6 +4519,12 @@ std::vector Interpreter::GetQueries() { void Interpreter::Abort() { bool decrement = true; + + // System tx + // TODO Implement system transaction scope and the ability to abort + system_transaction_.reset(); + + // Data tx auto expected = TransactionStatus::ACTIVE; while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) { if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) { @@ -4501,8 +4576,7 @@ void RunTriggersAfterCommit(dbms::DatabaseAccess db_acc, InterpreterContext *int trigger_context.AdaptForAccessor(&db_accessor); try { trigger.Execute(&db_accessor, &execution_memory, flags::run_time::GetExecutionTimeout(), - &interpreter_context->is_shutting_down, transaction_status, trigger_context, - interpreter_context->auth_checker); + &interpreter_context->is_shutting_down, transaction_status, trigger_context); } catch (const utils::BasicException &exception) { spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); db_accessor.Abort(); @@ -4659,8 +4733,7 @@ void Interpreter::Commit() { AdvanceCommand(); try { trigger.Execute(&*current_db_.execution_db_accessor_, &execution_memory, flags::run_time::GetExecutionTimeout(), - &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context, - interpreter_context_->auth_checker); + &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context); } catch (const utils::BasicException &e) { throw utils::BasicException( 
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); @@ -4775,7 +4848,7 @@ void Interpreter::SetNextTransactionIsolationLevel(const storage::IsolationLevel void Interpreter::SetSessionIsolationLevel(const storage::IsolationLevel isolation_level) { interpreter_isolation_level.emplace(isolation_level); } -void Interpreter::ResetUser() { username_.reset(); } -void Interpreter::SetUser(std::string_view username) { username_ = username; } +void Interpreter::ResetUser() { user_or_role_.reset(); } +void Interpreter::SetUser(std::shared_ptr user_or_role) { user_or_role_ = std::move(user_or_role); } } // namespace memgraph::query diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index da032b8e3..01a443d6d 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -84,16 +84,6 @@ class CoordinatorQueryHandler { CoordinatorQueryHandler(CoordinatorQueryHandler &&) = default; CoordinatorQueryHandler &operator=(CoordinatorQueryHandler &&) = default; - struct Replica { - std::string name; - std::string socket_address; - ReplicationQuery::SyncMode sync_mode; - std::optional timeout; - uint64_t current_timestamp_of_replica; - uint64_t current_number_of_timestamp_behind_master; - ReplicationQuery::ReplicaState state; - }; - struct MainReplicaStatus { std::string_view name; std::string_view socket_address; @@ -105,10 +95,15 @@ class CoordinatorQueryHandler { }; /// @throw QueryRuntimeException if an error ocurred. 
- virtual void RegisterReplicationInstance(const std::string &coordinator_socket_address, - const std::string &replication_socket_address, - const std::chrono::seconds instance_check_frequency, - const std::string &instance_name, CoordinatorQuery::SyncMode sync_mode) = 0; + virtual void RegisterReplicationInstance(std::string const &coordinator_socket_address, + std::string const &replication_socket_address, + std::chrono::seconds const &instance_health_check_frequency, + std::chrono::seconds const &instance_down_timeout, + std::chrono::seconds const &instance_get_uuid_frequency, + std::string const &instance_name, CoordinatorQuery::SyncMode sync_mode) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual void UnregisterInstance(std::string const &instance_name) = 0; /// @throw QueryRuntimeException if an error ocurred. virtual void SetReplicationInstanceToMain(const std::string &instance_name) = 0; @@ -205,7 +200,7 @@ class Interpreter final { std::optional db; }; - std::optional username_; + std::shared_ptr user_or_role_{}; bool in_explicit_transaction_{false}; CurrentDB current_db_; @@ -295,7 +290,7 @@ class Interpreter final { void ResetUser(); - void SetUser(std::string_view username); + void SetUser(std::shared_ptr user); std::optional system_transaction_{}; diff --git a/src/query/interpreter_context.cpp b/src/query/interpreter_context.cpp index f7b4584ba..eb35dbf03 100644 --- a/src/query/interpreter_context.cpp +++ b/src/query/interpreter_context.cpp @@ -35,13 +35,13 @@ InterpreterContext::InterpreterContext(InterpreterConfig interpreter_config, dbm } std::vector> InterpreterContext::TerminateTransactions( - std::vector maybe_kill_transaction_ids, const std::optional &username, - std::function privilege_checker) { + std::vector maybe_kill_transaction_ids, QueryUserOrRole *user_or_role, + std::function privilege_checker) { auto not_found_midpoint = maybe_kill_transaction_ids.end(); // Multiple simultaneous TERMINATE TRANSACTIONS aren't 
allowed // TERMINATE and SHOW TRANSACTIONS are mutually exclusive - interpreters.WithLock([¬_found_midpoint, &maybe_kill_transaction_ids, username, + interpreters.WithLock([¬_found_midpoint, &maybe_kill_transaction_ids, user_or_role, privilege_checker = std::move(privilege_checker)](const auto &interpreters) { for (Interpreter *interpreter : interpreters) { TransactionStatus alive_status = TransactionStatus::ACTIVE; @@ -73,7 +73,15 @@ std::vector> InterpreterContext::TerminateTransactions( static std::string all; return interpreter->current_db_.db_acc_ ? interpreter->current_db_.db_acc_->get()->name() : all; }; - if (interpreter->username_ == username || privilege_checker(get_interpreter_db_name())) { + + auto same_user = [](const auto &lv, const auto &rv) { + if (lv.get() == rv) return true; + if (lv && rv) return *lv == *rv; + return false; + }; + + if (same_user(interpreter->user_or_role_, user_or_role) || + privilege_checker(user_or_role, get_interpreter_db_name())) { killed = true; // Note: this is used by the above `clean_status` (OnScopeExit) spdlog::warn("Transaction {} successfully killed", transaction_id); } else { diff --git a/src/query/interpreter_context.hpp b/src/query/interpreter_context.hpp index c5fe00d2d..559ea3342 100644 --- a/src/query/interpreter_context.hpp +++ b/src/query/interpreter_context.hpp @@ -46,6 +46,7 @@ constexpr uint64_t kInterpreterTransactionInitialId = 1ULL << 63U; class AuthQueryHandler; class AuthChecker; class Interpreter; +struct QueryUserOrRole; /** * Holds data shared between multiple `Interpreter` instances (which might be @@ -95,8 +96,8 @@ struct InterpreterContext { void Shutdown() { is_shutting_down.store(true, std::memory_order_release); } std::vector> TerminateTransactions( - std::vector maybe_kill_transaction_ids, const std::optional &username, - std::function privilege_checker); + std::vector maybe_kill_transaction_ids, QueryUserOrRole *user_or_role, + std::function privilege_checker); }; } // namespace 
memgraph::query diff --git a/src/query/metadata.cpp b/src/query/metadata.cpp index 59d65e077..e339aad57 100644 --- a/src/query/metadata.cpp +++ b/src/query/metadata.cpp @@ -71,6 +71,8 @@ constexpr std::string_view GetCodeString(const NotificationCode code) { return "RegisterCoordinatorServer"sv; case NotificationCode::ADD_COORDINATOR_INSTANCE: return "AddCoordinatorInstance"sv; + case NotificationCode::UNREGISTER_INSTANCE: + return "UnregisterInstance"sv; #endif case NotificationCode::REPLICA_PORT_WARNING: return "ReplicaPortWarning"sv; diff --git a/src/query/metadata.hpp b/src/query/metadata.hpp index 2f357a555..dd8c2db07 100644 --- a/src/query/metadata.hpp +++ b/src/query/metadata.hpp @@ -43,8 +43,9 @@ enum class NotificationCode : uint8_t { REPLICA_PORT_WARNING, REGISTER_REPLICA, #ifdef MG_ENTERPRISE - REGISTER_COORDINATOR_SERVER, + REGISTER_COORDINATOR_SERVER, // TODO: (andi) What is this? ADD_COORDINATOR_INSTANCE, + UNREGISTER_INSTANCE, #endif SET_REPLICA, START_STREAM, diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 8dfaac81f..75b531261 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1549,15 +1549,15 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { // for the given (edge, vertex) pair checks if they satisfy the // "where" condition. if so, places them in the to_visit_ structure. 
- auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) { + auto expand_pair = [this, &evaluator, &frame, &context](EdgeAccessor edge, VertexAccessor vertex) -> bool { // if we already processed the given vertex it doesn't get expanded - if (processed_.find(vertex) != processed_.end()) return; + if (processed_.find(vertex) != processed_.end()) return false; #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && !(context.auth_checker->Has(vertex, storage::View::OLD, memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { - return; + return false; } #endif frame[self_.filter_lambda_.inner_edge_symbol] = edge; @@ -1576,9 +1576,9 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { TypedValue result = self_.filter_lambda_.expression->Accept(evaluator); switch (result.type()) { case TypedValue::Type::Null: - return; + return true; case TypedValue::Type::Bool: - if (!result.ValueBool()) return; + if (!result.ValueBool()) return true; break; default: throw QueryRuntimeException("Expansion condition must evaluate to boolean or null."); @@ -1586,10 +1586,11 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { } to_visit_next_.emplace_back(edge, vertex, std::move(curr_acc_path)); processed_.emplace(vertex, edge); + return true; }; - auto restore_frame_state_after_expansion = [this, &frame]() { - if (self_.filter_lambda_.accumulated_path_symbol) { + auto restore_frame_state_after_expansion = [this, &frame](bool was_expanded) { + if (was_expanded && self_.filter_lambda_.accumulated_path_symbol) { frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath().Shrink(); } }; @@ -1601,15 +1602,15 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = 
UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { - expand_pair(edge, edge.To()); - restore_frame_state_after_expansion(); + bool was_expanded = expand_pair(edge, edge.To()); + restore_frame_state_after_expansion(was_expanded); } } if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { - expand_pair(edge, edge.From()); - restore_frame_state_after_expansion(); + bool was_expanded = expand_pair(edge, edge.From()); + restore_frame_state_after_expansion(was_expanded); } } }; @@ -1800,18 +1801,8 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // For the given (edge, vertex, weight, depth) tuple checks if they // satisfy the "where" condition. if so, places them in the priority // queue. - auto expand_pair = [this, &evaluator, &frame, &create_state, &context]( - const EdgeAccessor &edge, const VertexAccessor &vertex, const TypedValue &total_weight, - int64_t depth) { -#ifdef MG_ENTERPRISE - if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && - !(context.auth_checker->Has(vertex, storage::View::OLD, - memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && - context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { - return; - } -#endif - + auto expand_pair = [this, &evaluator, &frame, &create_state](const EdgeAccessor &edge, const VertexAccessor &vertex, + const TypedValue &total_weight, int64_t depth) { frame[self_.weight_lambda_->inner_edge_symbol] = edge; frame[self_.weight_lambda_->inner_node_symbol] = vertex; TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator); @@ -1854,11 +1845,19 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { // Populates the priority queue structure with expansions // from 
the given vertex. skips expansions that don't satisfy // the "where" condition. - auto expand_from_vertex = [this, &expand_pair, &restore_frame_state_after_expansion]( + auto expand_from_vertex = [this, &context, &expand_pair, &restore_frame_state_after_expansion]( const VertexAccessor &vertex, const TypedValue &weight, int64_t depth) { if (self_.common_.direction != EdgeAtom::Direction::IN) { auto out_edges = UnwrapEdgesResult(vertex.OutEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : out_edges) { +#ifdef MG_ENTERPRISE + if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && + !(context.auth_checker->Has(edge.To(), storage::View::OLD, + memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && + context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { + continue; + } +#endif expand_pair(edge, edge.To(), weight, depth); restore_frame_state_after_expansion(); } @@ -1866,6 +1865,14 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor { if (self_.common_.direction != EdgeAtom::Direction::OUT) { auto in_edges = UnwrapEdgesResult(vertex.InEdges(storage::View::OLD, self_.common_.edge_types)).edges; for (const auto &edge : in_edges) { +#ifdef MG_ENTERPRISE + if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && + !(context.auth_checker->Has(edge.From(), storage::View::OLD, + memgraph::query::AuthQuery::FineGrainedPrivilege::READ) && + context.auth_checker->Has(edge, memgraph::query::AuthQuery::FineGrainedPrivilege::READ))) { + continue; + } +#endif expand_pair(edge, edge.From(), weight, depth); restore_frame_state_after_expansion(); } diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index cf8ad9c97..c3bfdf462 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -313,7 +313,7 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, auto 
*property_lookup = storage.Create(atom->filter_lambda_.inner_edge, prop_pair.first); auto *prop_equal = storage.Create(property_lookup, prop_pair.second); // Currently, variable expand has no gains if we set PropertyFilter. - all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic, prop_equal, collector.symbols_}); + all_filters_.emplace_back(FilterInfo::Type::Generic, prop_equal, collector.symbols_); } { collector.symbols_.clear(); @@ -328,9 +328,9 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, auto *prop_equal = storage.Create(property_lookup, prop_pair.second); // Currently, variable expand has no gains if we set PropertyFilter. all_filters_.emplace_back( - FilterInfo{FilterInfo::Type::Generic, - storage.Create(identifier, atom->identifier_, storage.Create(prop_equal)), - collector.symbols_}); + FilterInfo::Type::Generic, + storage.Create(identifier, atom->identifier_, storage.Create(prop_equal)), + collector.symbols_); } } return; @@ -639,6 +639,12 @@ void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &stor } } +PatternFilterVisitor::PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage) + : symbol_table_(symbol_table), storage_(storage) {} +PatternFilterVisitor::PatternFilterVisitor(const PatternFilterVisitor &) = default; +PatternFilterVisitor::PatternFilterVisitor(PatternFilterVisitor &&) noexcept = default; +PatternFilterVisitor::~PatternFilterVisitor() = default; + void PatternFilterVisitor::Visit(Exists &op) { std::vector patterns; patterns.push_back(op.pattern_); @@ -652,6 +658,8 @@ void PatternFilterVisitor::Visit(Exists &op) { matchings_.push_back(std::move(filter_matching)); } +std::vector PatternFilterVisitor::getMatchings() { return matchings_; } + static void ParseForeach(query::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage, SymbolTable &symbol_table) { for (auto *clause : foreach.clauses_) { @@ -723,4 +731,18 @@ QueryParts 
CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage, Cyp return QueryParts{query_parts, distinct}; } +FilterInfo::FilterInfo(Type type, Expression *expression, std::unordered_set used_symbols, + std::optional property_filter, std::optional id_filter) + : type(type), + expression(expression), + used_symbols(std::move(used_symbols)), + property_filter(std::move(property_filter)), + id_filter(std::move(id_filter)), + matchings({}) {} +FilterInfo::FilterInfo(const FilterInfo &) = default; +FilterInfo &FilterInfo::operator=(const FilterInfo &) = default; +FilterInfo::FilterInfo(FilterInfo &&) noexcept = default; +FilterInfo &FilterInfo::operator=(FilterInfo &&) noexcept = default; +FilterInfo::~FilterInfo() = default; + } // namespace memgraph::query::plan diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 2b53fb7b0..01b10ebaf 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -19,6 +19,7 @@ #include #include "query/frontend/ast/ast.hpp" +#include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/semantic/symbol_table.hpp" namespace memgraph::query::plan { @@ -159,8 +160,12 @@ enum class PatternFilterType { EXISTS }; /// Collects matchings from filters that include patterns class PatternFilterVisitor : public ExpressionVisitor { public: - explicit PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage) - : symbol_table_(symbol_table), storage_(storage) {} + explicit PatternFilterVisitor(SymbolTable &symbol_table, AstStorage &storage); + PatternFilterVisitor(const PatternFilterVisitor &); + PatternFilterVisitor &operator=(const PatternFilterVisitor &) = delete; + PatternFilterVisitor(PatternFilterVisitor &&) noexcept; + PatternFilterVisitor &operator=(PatternFilterVisitor &&) noexcept = delete; + ~PatternFilterVisitor() override; using ExpressionVisitor::Visit; @@ -232,7 +237,7 @@ class PatternFilterVisitor : public ExpressionVisitor { void Visit(RegexMatch &op) override{}; 
void Visit(PatternComprehension &op) override{}; - std::vector getMatchings() { return matchings_; } + std::vector getMatchings(); SymbolTable &symbol_table_; AstStorage &storage_; @@ -298,9 +303,23 @@ struct FilterInfo { /// elements. enum class Type { Generic, Label, Property, Id, Pattern }; - Type type; + // FilterInfo is tricky because FilterMatching is not yet defined: + // * if no declared constructor -> FilterInfo is std::__is_complete_or_unbounded + // * if any user-declared constructor -> non-aggregate type -> no designated initializers are possible + // * IMPORTANT: Matchings will always be initialized to an empty container. + explicit FilterInfo(Type type = Type::Generic, Expression *expression = nullptr, + std::unordered_set used_symbols = {}, std::optional property_filter = {}, + std::optional id_filter = {}); + // All other constructors are also defined in the cpp file because this struct is incomplete here. + FilterInfo(const FilterInfo &); + FilterInfo &operator=(const FilterInfo &); + FilterInfo(FilterInfo &&) noexcept; + FilterInfo &operator=(FilterInfo &&) noexcept; + ~FilterInfo(); + + Type type{Type::Generic}; /// The original filter expression which must be satisfied. - Expression *expression; + Expression *expression{nullptr}; /// Set of used symbols by the filter @c expression. std::unordered_set used_symbols{}; /// Labels for Type::Label filtering. @@ -310,7 +329,8 @@ struct FilterInfo { /// Information for Type::Id filtering. std::optional id_filter{}; /// Matchings for filters that include patterns - std::vector matchings{}; + /// NOTE: The vector is not defined here because FilterMatching is forward declared above. + std::vector matchings; }; /// Stores information on filters used inside the @c Matching of a @c QueryPart. 
@@ -329,34 +349,15 @@ class Filters final { auto empty() const { return all_filters_.empty(); } - auto erase(iterator pos) { return all_filters_.erase(pos); } - auto erase(const_iterator pos) { return all_filters_.erase(pos); } - auto erase(iterator first, iterator last) { return all_filters_.erase(first, last); } - auto erase(const_iterator first, const_iterator last) { return all_filters_.erase(first, last); } + auto erase(iterator pos) -> iterator; + auto erase(const_iterator pos) -> iterator; + auto erase(iterator first, iterator last) -> iterator; + auto erase(const_iterator first, const_iterator last) -> iterator; void SetFilters(std::vector &&all_filters) { all_filters_ = std::move(all_filters); } - auto FilteredLabels(const Symbol &symbol) const { - std::unordered_set labels; - for (const auto &filter : all_filters_) { - if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) { - MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter"); - labels.insert(filter.labels.begin(), filter.labels.end()); - } - } - return labels; - } - - auto FilteredProperties(const Symbol &symbol) const -> std::unordered_set { - std::unordered_set properties; - - for (const auto &filter : all_filters_) { - if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) { - properties.insert(filter.property_filter->property_); - } - } - return properties; - } + auto FilteredLabels(const Symbol &symbol) const -> std::unordered_set; + auto FilteredProperties(const Symbol &symbol) const -> std::unordered_set; /// Remove a filter; may invalidate iterators. /// Removal is done by comparing only the expression, so that multiple @@ -370,26 +371,10 @@ class Filters final { std::vector *removed_filters = nullptr); /// Returns a vector of FilterInfo for properties. 
- auto PropertyFilters(const Symbol &symbol) const { - std::vector filters; - for (const auto &filter : all_filters_) { - if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) { - filters.push_back(filter); - } - } - return filters; - } + auto PropertyFilters(const Symbol &symbol) const -> std::vector; /// Return a vector of FilterInfo for ID equality filtering. - auto IdFilters(const Symbol &symbol) const { - std::vector filters; - for (const auto &filter : all_filters_) { - if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) { - filters.push_back(filter); - } - } - return filters; - } + auto IdFilters(const Symbol &symbol) const -> std::vector; /// Collects filtering information from a pattern. /// @@ -459,6 +444,57 @@ struct FilterMatching : Matching { std::optional symbol; }; +inline auto Filters::erase(Filters::iterator pos) -> iterator { return all_filters_.erase(pos); } +inline auto Filters::erase(Filters::const_iterator pos) -> iterator { return all_filters_.erase(pos); } +inline auto Filters::erase(Filters::iterator first, Filters::iterator last) -> iterator { + return all_filters_.erase(first, last); +} +inline auto Filters::erase(Filters::const_iterator first, Filters::const_iterator last) -> iterator { + return all_filters_.erase(first, last); +} + +inline auto Filters::FilteredLabels(const Symbol &symbol) const -> std::unordered_set { + std::unordered_set labels; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) { + MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter"); + labels.insert(filter.labels.begin(), filter.labels.end()); + } + } + return labels; +} + +inline auto Filters::FilteredProperties(const Symbol &symbol) const -> std::unordered_set { + std::unordered_set properties; + + for (const auto &filter : all_filters_) { + if (filter.type == 
FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) { + properties.insert(filter.property_filter->property_); + } + } + return properties; +} + +inline auto Filters::PropertyFilters(const Symbol &symbol) const -> std::vector { + std::vector filters; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) { + filters.push_back(filter); + } + } + return filters; +} + +inline auto Filters::IdFilters(const Symbol &symbol) const -> std::vector { + std::vector filters; + for (const auto &filter : all_filters_) { + if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) { + filters.push_back(filter); + } + } + return filters; +} + /// @brief Represents a read (+ write) part of a query. Parts are split on /// `WITH` clauses. /// diff --git a/src/query/procedure/fmt.hpp b/src/query/procedure/fmt.hpp new file mode 100644 index 000000000..85775da46 --- /dev/null +++ b/src/query/procedure/fmt.hpp @@ -0,0 +1,82 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#if FMT_VERSION > 90000 +#include +#include + +#include "mg_procedure.h" +#include "utils/logging.hpp" + +inline std::string ToString(const mgp_log_level &log_level) { + switch (log_level) { + case mgp_log_level::MGP_LOG_LEVEL_CRITICAL: + return "CRITICAL"; + case mgp_log_level::MGP_LOG_LEVEL_ERROR: + return "ERROR"; + case mgp_log_level::MGP_LOG_LEVEL_WARN: + return "WARN"; + case mgp_log_level::MGP_LOG_LEVEL_INFO: + return "INFO"; + case mgp_log_level::MGP_LOG_LEVEL_DEBUG: + return "DEBUG"; + case mgp_log_level::MGP_LOG_LEVEL_TRACE: + return "TRACE"; + } + LOG_FATAL("ToString of a wrong mgp_log_level -> check missing switch case"); +} +inline std::ostream &operator<<(std::ostream &os, const mgp_log_level &log_level) { + os << ToString(log_level); + return os; +} +template <> +class fmt::formatter : public fmt::ostream_formatter {}; + +inline std::string ToString(const mgp_error &error) { + switch (error) { + case mgp_error::MGP_ERROR_NO_ERROR: + return "NO ERROR"; + case mgp_error::MGP_ERROR_UNKNOWN_ERROR: + return "UNKNOWN ERROR"; + case mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE: + return "UNABLE TO ALLOCATE ERROR"; + case mgp_error::MGP_ERROR_INSUFFICIENT_BUFFER: + return "INSUFFICIENT BUFFER ERROR"; + case mgp_error::MGP_ERROR_OUT_OF_RANGE: + return "OUT OF RANGE ERROR"; + case mgp_error::MGP_ERROR_LOGIC_ERROR: + return "LOGIC ERROR"; + case mgp_error::MGP_ERROR_DELETED_OBJECT: + return "DELETED OBJECT ERROR"; + case mgp_error::MGP_ERROR_INVALID_ARGUMENT: + return "INVALID ARGUMENT ERROR"; + case mgp_error::MGP_ERROR_KEY_ALREADY_EXISTS: + return "KEY ALREADY EXISTS ERROR"; + case mgp_error::MGP_ERROR_IMMUTABLE_OBJECT: + return "IMMUTABLE OBJECT ERROR"; + case mgp_error::MGP_ERROR_VALUE_CONVERSION: + return "VALUE CONVERSION ERROR"; + case mgp_error::MGP_ERROR_SERIALIZATION_ERROR: + return "SERIALIZATION ERROR"; + case mgp_error::MGP_ERROR_AUTHORIZATION_ERROR: + return "AUTHORIZATION ERROR"; + } + LOG_FATAL("ToString of a wrong mgp_error -> 
check missing switch case"); +} +inline std::ostream &operator<<(std::ostream &os, const mgp_error &error) { + os << ToString(error); + return os; +} +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/query/procedure/mg_procedure_helpers.cpp b/src/query/procedure/mg_procedure_helpers.cpp index 6b206e7dc..a6590a287 100644 --- a/src/query/procedure/mg_procedure_helpers.cpp +++ b/src/query/procedure/mg_procedure_helpers.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -10,6 +10,7 @@ // licenses/APL.txt. #include "query/procedure/mg_procedure_helpers.hpp" +#include "query/procedure/fmt.hpp" namespace memgraph::query::procedure { MgpUniquePtr GetStringValueOrSetError(const char *string, mgp_memory *memory, mgp_result *result) { diff --git a/src/query/procedure/mg_procedure_helpers.hpp b/src/query/procedure/mg_procedure_helpers.hpp index cb8bd55db..d0032c521 100644 --- a/src/query/procedure/mg_procedure_helpers.hpp +++ b/src/query/procedure/mg_procedure_helpers.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include #include "mg_procedure.h" +#include "query/procedure/fmt.hpp" namespace memgraph::query::procedure { template diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 2eea1cecb..647f3e14d 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -29,6 +29,7 @@ #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/procedure/cypher_types.hpp" +#include "query/procedure/fmt.hpp" #include "query/procedure/mg_procedure_helpers.hpp" #include "query/stream/common.hpp" #include "storage/v2/property_value.hpp" @@ -187,6 +188,7 @@ template spdlog::error("Memory allocation error during mg API call: {}", bae.what()); return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE; } catch (const memgraph::utils::OutOfMemoryException &oome) { + [[maybe_unused]] auto blocker = memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker{}; spdlog::error("Memory limit exceeded during mg API call: {}", oome.what()); return mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE; } catch (const std::out_of_range &oore) { @@ -198,12 +200,12 @@ template } catch (const std::logic_error &lee) { spdlog::error("Logic error during mg API call: {}", lee.what()); return mgp_error::MGP_ERROR_LOGIC_ERROR; - } catch (const std::exception &e) { - spdlog::error("Unexpected error during mg API call: {}", e.what()); - return mgp_error::MGP_ERROR_UNKNOWN_ERROR; } catch (const memgraph::utils::temporal::InvalidArgumentException &e) { spdlog::error("Invalid argument was sent to an mg API call for temporal types: {}", e.what()); return mgp_error::MGP_ERROR_INVALID_ARGUMENT; + } catch (const std::exception &e) { + spdlog::error("Unexpected error during mg API call: {}", e.what()); + return mgp_error::MGP_ERROR_UNKNOWN_ERROR; } catch (...) { spdlog::error("Unexpected error during mg API call"); return mgp_error::MGP_ERROR_UNKNOWN_ERROR; diff --git a/src/query/query_user.cpp b/src/query/query_user.cpp new file mode 100644 index 000000000..005601f81 --- /dev/null +++ b/src/query/query_user.cpp @@ -0,0 +1,21 @@ +// Copyright 2024 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/query_user.hpp" + +namespace memgraph::query { +// The variables below are used to define a user auth policy. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +SessionLongPolicy session_long_policy; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +UpToDatePolicy up_to_date_policy; + +} // namespace memgraph::query diff --git a/src/query/query_user.hpp b/src/query/query_user.hpp new file mode 100644 index 000000000..62d2e32b1 --- /dev/null +++ b/src/query/query_user.hpp @@ -0,0 +1,61 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include +#include + +#include "query/frontend/ast/ast.hpp" + +namespace memgraph::query { + +class UserPolicy { + public: + virtual bool DoUpdate() const = 0; +}; +extern struct SessionLongPolicy : UserPolicy { + public: + bool DoUpdate() const override { return false; } +} session_long_policy; +extern struct UpToDatePolicy : UserPolicy { + public: + bool DoUpdate() const override { return true; } +} up_to_date_policy; + +struct QueryUserOrRole { + QueryUserOrRole(std::optional username, std::optional rolename) + : username_{std::move(username)}, rolename_{std::move(rolename)} {} + virtual ~QueryUserOrRole() = default; + + virtual bool IsAuthorized(const std::vector &privileges, const std::string &db_name, + UserPolicy *policy) const = 0; + +#ifdef MG_ENTERPRISE + virtual std::string GetDefaultDB() const = 0; +#endif + + std::string key() const { + // NOTE: Each role has an associated username, that's why we check it with higher priority + return rolename_ ? *rolename_ : (username_ ? 
*username_ : ""); + } + const std::optional &username() const { return username_; } + const std::optional &rolename() const { return rolename_; } + + bool operator==(const QueryUserOrRole &other) const = default; + operator bool() const { return username_.has_value(); } + + private: + std::optional username_; + std::optional rolename_; +}; + +} // namespace memgraph::query diff --git a/src/query/replication_query_handler.hpp b/src/query/replication_query_handler.hpp index aa0611a43..9c5814eef 100644 --- a/src/query/replication_query_handler.hpp +++ b/src/query/replication_query_handler.hpp @@ -11,6 +11,8 @@ #pragma once +#include "replication/replication_client.hpp" +#include "replication_coordination_glue/mode.hpp" #include "replication_coordination_glue/role.hpp" #include "utils/result.hpp" #include "utils/uuid.hpp" @@ -31,6 +33,7 @@ enum class RegisterReplicaError : uint8_t { COULD_NOT_BE_PERSISTED, ERROR_ACCEPTING_MAIN }; + enum class UnregisterReplicaResult : uint8_t { NOT_MAIN, COULD_NOT_BE_PERSISTED, @@ -38,6 +41,47 @@ enum class UnregisterReplicaResult : uint8_t { SUCCESS, }; +enum class ShowReplicaError : uint8_t { + NOT_MAIN, +}; + +struct ReplicaSystemInfoState { + uint64_t ts_; + uint64_t behind_; + replication::ReplicationClient::State state_; +}; + +struct ReplicaInfoState { + ReplicaInfoState(uint64_t ts, uint64_t behind, storage::replication::ReplicaState state) + : ts_(ts), behind_(behind), state_(state) {} + + uint64_t ts_; + uint64_t behind_; + storage::replication::ReplicaState state_; +}; + +struct ReplicasInfo { + ReplicasInfo(std::string name, std::string socket_address, replication_coordination_glue::ReplicationMode sync_mode, + ReplicaSystemInfoState system_info, std::map data_info) + : name_(std::move(name)), + socket_address_(std::move(socket_address)), + sync_mode_(sync_mode), + system_info_(std::move(system_info)), + data_info_(std::move(data_info)) {} + + std::string name_; + std::string socket_address_; + 
memgraph::replication_coordination_glue::ReplicationMode sync_mode_; + ReplicaSystemInfoState system_info_; + std::map data_info_; +}; + +struct ReplicasInfos { + explicit ReplicasInfos(std::vector entries) : entries_(std::move(entries)) {} + + std::vector entries_; +}; + /// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage struct ReplicationQueryHandler { virtual ~ReplicationQueryHandler() = default; @@ -49,11 +93,14 @@ struct ReplicationQueryHandler { virtual bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, const std::optional &main_uuid) = 0; + virtual bool TrySetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, + const std::optional &main_uuid) = 0; + // as MAIN, define and connect to REPLICAs - virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) + virtual auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> utils::BasicResult = 0; - virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) + virtual auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> utils::BasicResult = 0; // as MAIN, remove a REPLICA connection @@ -63,6 +110,8 @@ struct ReplicationQueryHandler { virtual auto GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole = 0; virtual bool IsMain() const = 0; virtual bool IsReplica() const = 0; + + virtual auto ShowReplicas() const -> utils::BasicResult = 0; }; } // namespace memgraph::query diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index 101ca592c..b8984b94b 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -29,6 +29,7 @@ #include "query/procedure/mg_procedure_helpers.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "query/procedure/module.hpp" 
+#include "query/query_user.hpp" #include "query/stream/sources.hpp" #include "query/typed_value.hpp" #include "utils/event_counter.hpp" @@ -131,6 +132,7 @@ StreamStatus CreateStatus(std::string stream_name, std::string transfor const std::string kStreamName{"name"}; const std::string kIsRunningKey{"is_running"}; const std::string kOwner{"owner"}; +const std::string kOwnerRole{"owner_role"}; const std::string kType{"type"}; } // namespace @@ -142,6 +144,11 @@ void to_json(nlohmann::json &data, StreamStatus &&status) { if (status.owner.has_value()) { data[kOwner] = std::move(*status.owner); + if (status.owner_role.has_value()) { + data[kOwnerRole] = std::move(*status.owner_role); + } else { + data[kOwnerRole] = nullptr; + } } else { data[kOwner] = nullptr; } @@ -156,6 +163,11 @@ void from_json(const nlohmann::json &data, StreamStatus &status) { if (const auto &owner = data.at(kOwner); !owner.is_null()) { status.owner = owner.get(); + if (const auto &owner_role = data.at(kOwnerRole); !owner_role.is_null()) { + owner_role.get_to(status.owner_role); + } else { + status.owner_role = {}; + } } else { status.owner = {}; } @@ -449,7 +461,7 @@ void Streams::RegisterPulsarProcedures() { template void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info, - std::optional owner, TDbAccess db_acc, InterpreterContext *ic) { + std::shared_ptr owner, TDbAccess db_acc, InterpreterContext *ic) { auto locked_streams = streams_.Lock(); auto it = CreateConsumer(*locked_streams, stream_name, std::move(info), std::move(owner), std::move(db_acc), ic); @@ -469,31 +481,39 @@ void Streams::Create(const std::string &stream_name, typename TStream::StreamInf template void Streams::Create(const std::string &stream_name, KafkaStream::StreamInfo info, - std::optional owner, + std::shared_ptr owner, dbms::DatabaseAccess db, InterpreterContext *ic); template void Streams::Create(const std::string &stream_name, PulsarStream::StreamInfo info, - std::optional owner, + 
std::shared_ptr owner, dbms::DatabaseAccess db, InterpreterContext *ic); template Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name, typename TStream::StreamInfo stream_info, - std::optional owner, TDbAccess db_acc, + std::shared_ptr owner, TDbAccess db_acc, InterpreterContext *interpreter_context) { if (map.contains(stream_name)) { throw StreamsException{"Stream already exists with name '{}'", stream_name}; } + auto ownername = owner->username(); + auto rolename = owner->rolename(); + auto *memory_resource = utils::NewDeleteResource(); auto consumer_function = [interpreter_context, memory_resource, stream_name, - transformation_name = stream_info.common_info.transformation_name, owner = owner, + transformation_name = stream_info.common_info.transformation_name, owner = std::move(owner), interpreter = std::make_shared(interpreter_context, std::move(db_acc)), result = mgp_result{nullptr, memory_resource}, total_retries = interpreter_context->config.stream_transaction_conflict_retries, retry_interval = interpreter_context->config.stream_transaction_retry_interval]( const std::vector &messages) mutable { + // Set interpreter's user to the stream owner + // NOTE: We generate an empty user to avoid generating interpreter's fine grained access control and rely only on + // the global auth_checker used in the stream itself + // TODO: Fix auth inconsistency + interpreter->SetUser(interpreter_context->auth_checker->GenQueryUser(std::nullopt, std::nullopt)); #ifdef MG_ENTERPRISE interpreter->OnChangeCB([](auto) { return false; }); // Disable database change #endif @@ -523,12 +543,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std spdlog::trace("Processing row in stream '{}'", stream_name); auto [query_value, params_value] = ExtractTransformationResult(row.values, transformation_name, stream_name); storage::PropertyValue params_prop{params_value}; - std::string 
query{query_value.ValueString()}; spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name); auto prepare_result = interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), {}); - if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges, "")) { + if (!owner->IsAuthorized(prepare_result.privileges, "", &up_to_date_policy)) { throw StreamsException{ "Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the " "query!", @@ -553,7 +572,8 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std }; auto insert_result = map.try_emplace( - stream_name, StreamData{std::move(stream_info.common_info.transformation_name), std::move(owner), + stream_name, StreamData{std::move(stream_info.common_info.transformation_name), std::move(ownername), + std::move(rolename), std::make_unique>( stream_name, std::move(stream_info), std::move(consumer_function))}); MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name); @@ -575,6 +595,7 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) { const auto create_consumer = [&, &stream_name = stream_name](StreamStatus status, auto &&stream_json_data) { try { + // TODO: Migration stream_json_data.get_to(status); } catch (const nlohmann::json::type_error &exception) { spdlog::warn(get_failed_message("invalid type conversion", exception.what())); @@ -586,8 +607,8 @@ void Streams::RestoreStreams(TDbAccess db, InterpreterContext *ic) { MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", status.name, stream_name); try { - auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner), - db, ic); + auto owner = ic->auth_checker->GenQueryUser(status.owner, status.owner_role); + auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info), std::move(owner), 
db, ic); if (status.is_running) { std::visit( [&](const auto &stream_data) { @@ -745,7 +766,7 @@ std::vector> Streams::GetStreamInfo() const { auto info = locked_stream_source->Info(stream_data.transformation_name); result.emplace_back(StreamStatus<>{stream_name, StreamType(*locked_stream_source), locked_stream_source->IsRunning(), std::move(info.common_info), - stream_data.owner}); + stream_data.owner, stream_data.owner_role}); }, stream_data); } diff --git a/src/query/stream/streams.hpp b/src/query/stream/streams.hpp index bad1f8c98..e1660bdb4 100644 --- a/src/query/stream/streams.hpp +++ b/src/query/stream/streams.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,6 +67,7 @@ struct StreamStatus { bool is_running; StreamInfoType info; std::optional owner; + std::optional owner_role; }; using TransformationResult = std::vector>; @@ -100,7 +101,7 @@ class Streams final { /// /// @throws StreamsException if the stream with the same name exists or if the creation of Kafka consumer fails template - void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional owner, + void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::shared_ptr owner, TDbAccess db, InterpreterContext *interpreter_context); /// Deletes an existing stream and all the data that was persisted. 
@@ -182,6 +183,7 @@ class Streams final { struct StreamData { std::string transformation_name; std::optional owner; + std::optional owner_role; std::unique_ptr> stream_source; }; @@ -191,7 +193,7 @@ class Streams final { template StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, - typename TStream::StreamInfo stream_info, std::optional owner, + typename TStream::StreamInfo stream_info, std::shared_ptr owner, TDbAccess db, InterpreterContext *interpreter_context); template diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index 7998714c1..437389128 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,14 +11,13 @@ #include "query/trigger.hpp" -#include - #include "query/config.hpp" #include "query/context.hpp" #include "query/cypher_query_interpreter.hpp" #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpret/frame.hpp" +#include "query/query_user.hpp" #include "query/serialization/property_value.hpp" #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" @@ -154,20 +153,19 @@ Trigger::Trigger(std::string name, const std::string &query, const std::map &user_parameters, const TriggerEventType event_type, utils::SkipList *query_cache, DbAccessor *db_accessor, const InterpreterConfig::Query &query_config, - std::optional owner, const query::AuthChecker *auth_checker) + std::shared_ptr owner) : name_{std::move(name)}, parsed_statements_{ParseQuery(query, user_parameters, query_cache, query_config)}, event_type_{event_type}, owner_{std::move(owner)} { // We check immediately if the query is valid by trying to create a plan. 
- GetPlan(db_accessor, auth_checker); + GetPlan(db_accessor); } Trigger::TriggerPlan::TriggerPlan(std::unique_ptr logical_plan, std::vector identifiers) : cached_plan(std::move(logical_plan)), identifiers(std::move(identifiers)) {} -std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, - const query::AuthChecker *auth_checker) const { +std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor) const { std::lock_guard plan_guard{plan_lock_}; if (!parsed_statements_.is_cacheable || !trigger_plan_) { auto identifiers = GetPredefinedIdentifiers(event_type_); @@ -187,7 +185,7 @@ std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, trigger_plan_ = std::make_shared(std::move(logical_plan), std::move(identifiers)); } - if (!auth_checker->IsUserAuthorized(owner_, parsed_statements_.required_privileges, "")) { + if (!owner_->IsAuthorized(parsed_statements_.required_privileges, "", &up_to_date_policy)) { throw utils::BasicException("The owner of trigger '{}' is not authorized to execute the query!", name_); } return trigger_plan_; @@ -195,14 +193,13 @@ std::shared_ptr Trigger::GetPlan(DbAccessor *db_accessor, void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, const double max_execution_time_sec, std::atomic *is_shutting_down, - std::atomic *transaction_status, const TriggerContext &context, - const AuthChecker *auth_checker) const { + std::atomic *transaction_status, const TriggerContext &context) const { if (!context.ShouldEventTrigger(event_type_)) { return; } spdlog::debug("Executing trigger '{}'", name_); - auto trigger_plan = GetPlan(dba, auth_checker); + auto trigger_plan = GetPlan(dba); MG_ASSERT(trigger_plan, "Invalid trigger plan received"); auto &[plan, identifiers] = *trigger_plan; @@ -308,6 +305,7 @@ void TriggerStore::RestoreTriggers(utils::SkipList *query_cache } const auto user_parameters = serialization::DeserializePropertyValueMap(json_trigger_data["user_parameters"]); + // TODO: Migration const auto owner_json = 
json_trigger_data["owner"]; std::optional owner{}; if (owner_json.is_string()) { @@ -317,10 +315,21 @@ void TriggerStore::RestoreTriggers(utils::SkipList *query_cache continue; } + const auto owner_role_json = json_trigger_data["owner_role"]; + std::optional role{}; + if (owner_role_json.is_string()) { + owner.emplace(owner_role_json.get()); + } else if (!owner_role_json.is_null()) { + spdlog::warn(invalid_state_message); + continue; + } + + auto user = auth_checker->GenQueryUser(owner, role); + std::optional trigger; try { trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, query_config, - std::move(owner), auth_checker); + std::move(user)); } catch (const utils::BasicException &e) { spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what()); continue; @@ -338,8 +347,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker) { + const InterpreterConfig::Query &query_config, std::shared_ptr owner) { std::unique_lock store_guard{store_lock_}; if (storage_.Get(name)) { throw utils::BasicException("Trigger with the same name already exists."); @@ -348,7 +356,7 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, std::optional trigger; try { trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, query_config, - std::move(owner), auth_checker); + std::move(owner)); } catch (const utils::BasicException &e) { const auto identifiers = GetPredefinedIdentifiers(event_type); std::stringstream identifier_names_stream; @@ -370,10 +378,23 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query, data["phase"] = phase; data["version"] = kVersion; - if (const auto 
&owner_from_trigger = trigger->Owner(); owner_from_trigger.has_value()) { - data["owner"] = *owner_from_trigger; + if (const auto &owner_from_trigger = trigger->Owner(); owner_from_trigger && *owner_from_trigger) { + const auto &maybe_username = owner_from_trigger->username(); + if (maybe_username) { + data["owner"] = *maybe_username; + // Roles need to be associated with a username + const auto &maybe_rolename = owner_from_trigger->rolename(); + if (maybe_rolename) { + data["owner_role"] = *maybe_rolename; + } else { + data["owner_role"] = nullptr; + } + } else { + data["owner"] = nullptr; + } } else { data["owner"] = nullptr; + data["owner_role"] = nullptr; } storage_.Put(trigger->Name(), data.dump()); store_guard.unlock(); @@ -417,7 +438,9 @@ std::vector TriggerStore::GetTriggerInfo() const { const auto add_info = [&](const utils::SkipList &trigger_list, const TriggerPhase phase) { for (const auto &trigger : trigger_list.access()) { - info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, trigger.Owner()}); + std::optional owner_str{}; + if (const auto &owner = trigger.Owner(); owner && *owner) owner_str = owner->username(); + info.push_back({trigger.Name(), trigger.OriginalStatement(), trigger.EventType(), phase, std::move(owner_str)}); } }; diff --git a/src/query/trigger.hpp b/src/query/trigger.hpp index a6e19032e..91c74579e 100644 --- a/src/query/trigger.hpp +++ b/src/query/trigger.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -37,12 +37,11 @@ struct Trigger { explicit Trigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker); + const InterpreterConfig::Query &query_config, std::shared_ptr owner); void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, std::atomic *is_shutting_down, std::atomic *transaction_status, - const TriggerContext &context, const AuthChecker *auth_checker) const; + const TriggerContext &context) const; bool operator==(const Trigger &other) const { return name_ == other.name_; } // NOLINTNEXTLINE (modernize-use-nullptr) @@ -65,7 +64,7 @@ struct Trigger { PlanWrapper cached_plan; std::vector identifiers; }; - std::shared_ptr GetPlan(DbAccessor *db_accessor, const query::AuthChecker *auth_checker) const; + std::shared_ptr GetPlan(DbAccessor *db_accessor) const; std::string name_; ParsedQuery parsed_statements_; @@ -74,7 +73,7 @@ struct Trigger { mutable utils::SpinLock plan_lock_; mutable std::shared_ptr trigger_plan_; - std::optional owner_; + std::shared_ptr owner_; }; enum class TriggerPhase : uint8_t { BEFORE_COMMIT, AFTER_COMMIT }; @@ -88,8 +87,7 @@ struct TriggerStore { void AddTrigger(std::string name, const std::string &query, const std::map &user_parameters, TriggerEventType event_type, TriggerPhase phase, utils::SkipList *query_cache, DbAccessor *db_accessor, - const InterpreterConfig::Query &query_config, std::optional owner, - const query::AuthChecker *auth_checker); + const InterpreterConfig::Query &query_config, std::shared_ptr owner); void DropTrigger(const std::string &name); diff --git 
a/src/query/typed_value.cpp b/src/query/typed_value.cpp index 4cb79508e..86d25f01b 100644 --- a/src/query/typed_value.cpp +++ b/src/query/typed_value.cpp @@ -19,6 +19,7 @@ #include #include +#include "query/fmt.hpp" #include "storage/v2/temporal.hpp" #include "utils/exceptions.hpp" #include "utils/fnv.hpp" @@ -326,13 +327,11 @@ TypedValue::operator storage::PropertyValue() const { throw TypedValueException("TypedValue is of type '{}', not '{}'", type_, Type::type_enum); \ return field; \ } \ - \ const type_param &TypedValue::Value##type_enum() const { \ if (type_ != Type::type_enum) [[unlikely]] \ throw TypedValueException("TypedValue is of type '{}', not '{}'", type_, Type::type_enum); \ return field; \ } \ - \ bool TypedValue::Is##type_enum() const { return type_ == Type::type_enum; } DEFINE_VALUE_AND_TYPE_GETTERS(bool, Bool, bool_v) @@ -783,10 +782,13 @@ TypedValue operator<(const TypedValue &a, const TypedValue &b) { return false; } }; - if (!is_legal(a.type()) || !is_legal(b.type())) + if (!is_legal(a.type()) || !is_legal(b.type())) { throw TypedValueException("Invalid 'less' operand types({} + {})", a.type(), b.type()); + } - if (a.IsNull() || b.IsNull()) return TypedValue(a.GetMemoryResource()); + if (a.IsNull() || b.IsNull()) { + return TypedValue(a.GetMemoryResource()); + } if (a.IsString() || b.IsString()) { if (a.type() != b.type()) { @@ -956,8 +958,9 @@ inline void EnsureArithmeticallyOk(const TypedValue &a, const TypedValue &b, boo // checked here because they are handled before this check is performed in // arithmetic op implementations. 
- if (!is_legal(a) || !is_legal(b)) + if (!is_legal(a) || !is_legal(b)) { throw TypedValueException("Invalid {} operand types {}, {}", op_name, a.type(), b.type()); + } } namespace { @@ -1107,8 +1110,9 @@ TypedValue operator%(const TypedValue &a, const TypedValue &b) { } inline void EnsureLogicallyOk(const TypedValue &a, const TypedValue &b, const std::string &op_name) { - if (!((a.IsBool() || a.IsNull()) && (b.IsBool() || b.IsNull()))) + if (!((a.IsBool() || a.IsNull()) && (b.IsBool() || b.IsNull()))) { throw TypedValueException("Invalid {} operand types({} && {})", op_name, a.type(), b.type()); + } } TypedValue operator&&(const TypedValue &a, const TypedValue &b) { diff --git a/src/replication/include/replication/messages.hpp b/src/replication/include/replication/messages.hpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/replication/include/replication/replication_client.hpp b/src/replication/include/replication/replication_client.hpp index 92c321ac6..7ad980e8f 100644 --- a/src/replication/include/replication/replication_client.hpp +++ b/src/replication/include/replication/replication_client.hpp @@ -14,7 +14,9 @@ #include "replication/config.hpp" #include "replication_coordination_glue/messages.hpp" #include "rpc/client.hpp" +#include "utils/rw_lock.hpp" #include "utils/scheduler.hpp" +#include "utils/spin_lock.hpp" #include "utils/synchronized.hpp" #include "utils/thread_pool.hpp" @@ -114,8 +116,9 @@ struct ReplicationClient { enum class State { BEHIND, READY, + RECOVERY, }; - utils::Synchronized state_{State::BEHIND}; + utils::Synchronized state_{State::BEHIND}; replication_coordination_glue::ReplicationMode mode_{replication_coordination_glue::ReplicationMode::SYNC}; // This thread pool is used for background tasks so we don't diff --git a/src/replication/include/replication/state.hpp b/src/replication/include/replication/state.hpp index 18f9efd4e..fb47185fc 100644 --- a/src/replication/include/replication/state.hpp +++ 
b/src/replication/include/replication/state.hpp @@ -39,7 +39,8 @@ enum class RegisterReplicaError : uint8_t { NAME_EXISTS, ENDPOINT_EXISTS, COULD_ struct RoleMainData { RoleMainData() = default; - explicit RoleMainData(ReplicationEpoch e, std::optional uuid = std::nullopt) : epoch_(std::move(e)) { + explicit RoleMainData(ReplicationEpoch e, bool writing_enabled, std::optional uuid = std::nullopt) + : epoch_(std::move(e)), writing_enabled_(writing_enabled) { if (uuid) { uuid_ = *uuid; } @@ -54,6 +55,7 @@ struct RoleMainData { ReplicationEpoch epoch_; std::list registered_replicas_{}; // TODO: data race issues utils::UUID uuid_; + bool writing_enabled_{false}; }; struct RoleReplicaData { @@ -90,6 +92,21 @@ struct ReplicationState { bool IsMain() const { return GetRole() == replication_coordination_glue::ReplicationRole::MAIN; } bool IsReplica() const { return GetRole() == replication_coordination_glue::ReplicationRole::REPLICA; } + auto IsMainWriteable() const -> bool { + if (auto const *main = std::get_if(&replication_data_)) { + return main->writing_enabled_; + } + return false; + } + + auto EnableWritingOnMain() -> bool { + if (auto *main = std::get_if(&replication_data_)) { + main->writing_enabled_ = true; + return true; + } + return false; + } + bool HasDurability() const { return nullptr != durability_; } bool TryPersistRoleMain(std::string new_epoch, utils::UUID main_uuid); diff --git a/src/replication/messages.cpp b/src/replication/messages.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/replication/replication_client.cpp b/src/replication/replication_client.cpp index ed46ea471..262d698bf 100644 --- a/src/replication/replication_client.cpp +++ b/src/replication/replication_client.cpp @@ -10,6 +10,7 @@ // licenses/APL.txt. 
#include "replication/replication_client.hpp" +#include "io/network/fmt.hpp" namespace memgraph::replication { @@ -30,7 +31,7 @@ ReplicationClient::ReplicationClient(const memgraph::replication::ReplicationCli ReplicationClient::~ReplicationClient() { try { auto const &endpoint = rpc_client_.Endpoint(); - spdlog::trace("Closing replication client on {}:{}", endpoint.address, endpoint.port); + spdlog::trace("Closing replication client on {}", endpoint); } catch (...) { // Logging can throw. Not a big deal, just ignore. } diff --git a/src/replication/state.cpp b/src/replication/state.cpp index 6b1d128ec..1155fdb51 100644 --- a/src/replication/state.cpp +++ b/src/replication/state.cpp @@ -62,8 +62,9 @@ ReplicationState::ReplicationState(std::optional durabili } #endif if (std::holds_alternative(replication_data)) { - spdlog::trace("Recovered main's uuid for replica {}", - std::string(std::get(replication_data).uuid_.value())); + auto &replica_uuid = std::get(replication_data).uuid_; + std::string uuid = replica_uuid.has_value() ? std::string(replica_uuid.value()) : ""; + spdlog::trace("Recovered main's uuid for replica {}", uuid); } else { spdlog::trace("Recovered uuid for main {}", std::string(std::get(replication_data).uuid_)); } @@ -144,8 +145,8 @@ auto ReplicationState::FetchReplicationData() -> FetchReplicationResult_t { return std::visit( utils::Overloaded{ [&](durability::MainRole &&r) -> FetchReplicationResult_t { - auto res = - RoleMainData{std::move(r.epoch), r.main_uuid.has_value() ? r.main_uuid.value() : utils::UUID{}}; + auto res = RoleMainData{std::move(r.epoch), false, + r.main_uuid.has_value() ? 
r.main_uuid.value() : utils::UUID{}}; auto b = durability_->begin(durability::kReplicationReplicaPrefix); auto e = durability_->end(durability::kReplicationReplicaPrefix); for (; b != e; ++b) { @@ -253,7 +254,7 @@ bool ReplicationState::SetReplicationRoleMain(const utils::UUID &main_uuid) { return false; } - replication_data_ = RoleMainData{ReplicationEpoch{new_epoch}, main_uuid}; + replication_data_ = RoleMainData{ReplicationEpoch{new_epoch}, true, main_uuid}; return true; } diff --git a/src/replication_handler/include/replication_handler/replication_handler.hpp b/src/replication_handler/include/replication_handler/replication_handler.hpp index 7882fa3c0..b110e6015 100644 --- a/src/replication_handler/include/replication_handler/replication_handler.hpp +++ b/src/replication_handler/include/replication_handler/replication_handler.hpp @@ -14,6 +14,7 @@ #include "dbms/dbms_handler.hpp" #include "flags/experimental.hpp" #include "replication/include/replication/state.hpp" +#include "replication_handler/system_replication.hpp" #include "replication_handler/system_rpc.hpp" #include "utils/result.hpp" @@ -38,10 +39,12 @@ void SystemRestore(replication::ReplicationClient &client, system::System &syste const utils::UUID &main_uuid, auth::SynchedAuth &auth) { // Check if system is up to date if (client.state_.WithLock( - [](auto &state) { return state == memgraph::replication::ReplicationClient::State::READY; })) + [](auto &state) { return state != memgraph::replication::ReplicationClient::State::BEHIND; })) return; // Try to recover... 
+ client.state_.WithLock( + [](auto &state) { return state != memgraph::replication::ReplicationClient::State::RECOVERY; }); { using enum memgraph::flags::Experiments; bool full_system_replication = @@ -113,15 +116,19 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { // as REPLICA, become MAIN bool SetReplicationRoleMain() override; - // as MAIN, become REPLICA + // as MAIN, become REPLICA, can be called on MAIN and REPLICA bool SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, const std::optional &main_uuid) override; + // as MAIN, become REPLICA, can be called only on MAIN + bool TrySetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, + const std::optional &main_uuid) override; + // as MAIN, define and connect to REPLICAs - auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) + auto TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> memgraph::utils::BasicResult override; - auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) + auto RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> memgraph::utils::BasicResult override; // as MAIN, remove a REPLICA connection @@ -134,15 +141,19 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { bool IsMain() const override; bool IsReplica() const override; + auto ShowReplicas() const + -> utils::BasicResult override; + auto GetReplState() const -> const memgraph::replication::ReplicationState &; auto GetReplState() -> memgraph::replication::ReplicationState &; + auto GetReplicaUUID() -> std::optional; + private: - template - auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config, bool send_swap_uuid) + template + auto RegisterReplica_(const memgraph::replication::ReplicationClientConfig &config) -> 
memgraph::utils::BasicResult { MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!"); - auto maybe_client = repl_state_.RegisterReplica(config); if (maybe_client.HasError()) { switch (maybe_client.GetError()) { @@ -159,7 +170,6 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { break; } } - using enum memgraph::flags::Experiments; bool system_replication_enabled = flags::AreExperimentsEnabled(SYSTEM_REPLICATION); if (!system_replication_enabled && dbms_handler_.Count() > 1) { @@ -167,25 +177,21 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { } const auto main_uuid = std::get(dbms_handler_.ReplicationState().ReplicationData()).uuid_; - - if (send_swap_uuid) { + if constexpr (SendSwapUUID) { if (!memgraph::replication_coordination_glue::SendSwapMainUUIDRpc(maybe_client.GetValue()->rpc_client_, main_uuid)) { return memgraph::query::RegisterReplicaError::ERROR_ACCEPTING_MAIN; } } - #ifdef MG_ENTERPRISE // Update system before enabling individual storage <-> replica clients SystemRestore(*maybe_client.GetValue(), system_, dbms_handler_, main_uuid, auth_); #endif - const auto dbms_error = HandleRegisterReplicaStatus(maybe_client); if (dbms_error.has_value()) { return *dbms_error; } auto &instance_client_ptr = maybe_client.GetValue(); - bool all_clients_good = true; // Add database specific clients (NOTE Currently all databases are connected to each replica) dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) { @@ -195,7 +201,6 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { } // TODO: ATM only IN_MEMORY_TRANSACTIONAL, fix other modes if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; - all_clients_good &= storage->repl_storage_state_.replication_clients_.WithLock( [storage, &instance_client_ptr, db_acc = std::move(db_acc), main_uuid](auto &storage_clients) mutable { // NOLINT @@ -203,9 +208,9 @@ struct 
ReplicationHandler : public memgraph::query::ReplicationQueryHandler { client->Start(storage, std::move(db_acc)); bool const success = std::invoke([state = client->State()]() { if (state == storage::replication::ReplicaState::DIVERGED_FROM_MAIN) { - return AllowReplicaToDivergeFromMain; + return false; } - return state != storage::replication::ReplicaState::MAYBE_BEHIND; + return true; }); if (success) { @@ -214,14 +219,12 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { return success; }); }); - // NOTE Currently if any databases fails, we revert back if (!all_clients_good) { spdlog::error("Failed to register all databases on the REPLICA \"{}\"", config.name); UnregisterReplica(config.name); return memgraph::query::RegisterReplicaError::CONNECTION_FAILED; } - // No client error, start instance level client #ifdef MG_ENTERPRISE StartReplicaClient(*instance_client_ptr, system_, dbms_handler_, main_uuid, auth_); @@ -231,6 +234,57 @@ struct ReplicationHandler : public memgraph::query::ReplicationQueryHandler { return {}; } + template + bool SetReplicationRoleReplica_(const memgraph::replication::ReplicationServerConfig &config, + const std::optional &main_uuid) { + if (repl_state_.IsReplica()) { + if (!AllowIdempotency) { + return false; + } + // We don't want to restart the server if we're already a REPLICA with correct config + auto &replica_data = std::get(repl_state_.ReplicationData()); + if (replica_data.config == config) { + return true; + } + repl_state_.SetReplicationRoleReplica(config, main_uuid); +#ifdef MG_ENTERPRISE + return StartRpcServer(dbms_handler_, replica_data, auth_, system_); +#else + return StartRpcServer(dbms_handler_, replica_data); +#endif + } + + // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are + // deleting the replica. 
+ // Remove database specific clients + dbms_handler_.ForEach([&](memgraph::dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); + }); + // Remove instance level clients + std::get(repl_state_.ReplicationData()).registered_replicas_.clear(); + + // Creates the server + repl_state_.SetReplicationRoleReplica(config, main_uuid); + + // Start + const auto success = + std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData &) { + // ASSERT + return false; + }, + [this](memgraph::replication::RoleReplicaData &data) { +#ifdef MG_ENTERPRISE + return StartRpcServer(dbms_handler_, data, auth_, system_); +#else + return StartRpcServer(dbms_handler_, data); +#endif + }}, + repl_state_.ReplicationData()); + // TODO Handle error (restore to main?) + return success; + } + memgraph::replication::ReplicationState &repl_state_; memgraph::dbms::DbmsHandler &dbms_handler_; diff --git a/src/replication_handler/replication_handler.cpp b/src/replication_handler/replication_handler.cpp index 8d07d5af5..5f807779d 100644 --- a/src/replication_handler/replication_handler.cpp +++ b/src/replication_handler/replication_handler.cpp @@ -10,26 +10,28 @@ // licenses/APL.txt. 
#include "replication_handler/replication_handler.hpp" +#include "dbms/constants.hpp" #include "dbms/dbms_handler.hpp" +#include "replication/replication_client.hpp" #include "replication_handler/system_replication.hpp" namespace memgraph::replication { namespace { #ifdef MG_ENTERPRISE -void RecoverReplication(memgraph::replication::ReplicationState &repl_state, memgraph::system::System &system, - memgraph::dbms::DbmsHandler &dbms_handler, memgraph::auth::SynchedAuth &auth) { +void RecoverReplication(replication::ReplicationState &repl_state, system::System &system, + dbms::DbmsHandler &dbms_handler, auth::SynchedAuth &auth) { /* * REPLICATION RECOVERY AND STARTUP */ // Startup replication state (if recovered at startup) - auto replica = [&dbms_handler, &auth, &system](memgraph::replication::RoleReplicaData &data) { - return memgraph::replication::StartRpcServer(dbms_handler, data, auth, system); + auto replica = [&dbms_handler, &auth, &system](replication::RoleReplicaData &data) { + return replication::StartRpcServer(dbms_handler, data, auth, system); }; // Replication recovery and frequent check start - auto main = [&system, &dbms_handler, &auth](memgraph::replication::RoleMainData &mainData) { + auto main = [&system, &dbms_handler, &auth](replication::RoleMainData &mainData) { for (auto &client : mainData.registered_replicas_) { if (client.try_set_uuid && replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, mainData.uuid_)) { @@ -38,7 +40,7 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem SystemRestore(client, system, dbms_handler, mainData.uuid_, auth); } // DBMS here - dbms_handler.ForEach([&mainData](memgraph::dbms::DatabaseAccess db_acc) { + dbms_handler.ForEach([&mainData](dbms::DatabaseAccess db_acc) { dbms::DbmsHandler::RecoverStorageReplication(std::move(db_acc), mainData); }); @@ -48,7 +50,7 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem // Warning if 
(dbms_handler.default_config().durability.snapshot_wal_mode == - memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) { + storage::Config::Durability::SnapshotWalMode::DISABLED) { spdlog::warn( "The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please " "consider " @@ -59,19 +61,18 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, mem return true; }; - auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData()); + auto result = std::visit(utils::Overloaded{replica, main}, repl_state.ReplicationData()); MG_ASSERT(result, "Replica recovery failure!"); } #else -void RecoverReplication(memgraph::replication::ReplicationState &repl_state, - memgraph::dbms::DbmsHandler &dbms_handler) { +void RecoverReplication(replication::ReplicationState &repl_state, dbms::DbmsHandler &dbms_handler) { // Startup replication state (if recovered at startup) - auto replica = [&dbms_handler](memgraph::replication::RoleReplicaData &data) { - return memgraph::replication::StartRpcServer(dbms_handler, data); + auto replica = [&dbms_handler](replication::RoleReplicaData &data) { + return replication::StartRpcServer(dbms_handler, data); }; // Replication recovery and frequent check start - auto main = [&dbms_handler](memgraph::replication::RoleMainData &mainData) { + auto main = [&dbms_handler](replication::RoleMainData &mainData) { dbms::DbmsHandler::RecoverStorageReplication(dbms_handler.Get(), mainData); for (auto &client : mainData.registered_replicas_) { @@ -79,12 +80,12 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, mainData.uuid_)) { client.try_set_uuid = false; } - memgraph::replication::StartReplicaClient(client, dbms_handler, mainData.uuid_); + replication::StartReplicaClient(client, dbms_handler, mainData.uuid_); } // Warning if 
(dbms_handler.default_config().durability.snapshot_wal_mode == - memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED) { + storage::Config::Durability::SnapshotWalMode::DISABLED) { spdlog::warn( "The instance has the MAIN replication role, but durability logs and snapshots are disabled. Please " "consider " @@ -95,7 +96,7 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, return true; }; - auto result = std::visit(memgraph::utils::Overloaded{replica, main}, repl_state.ReplicationData()); + auto result = std::visit(utils::Overloaded{replica, main}, repl_state.ReplicationData()); MG_ASSERT(result, "Replica recovery failure!"); } #endif @@ -103,7 +104,8 @@ void RecoverReplication(memgraph::replication::ReplicationState &repl_state, inline std::optional HandleRegisterReplicaStatus( utils::BasicResult &instance_client) { - if (instance_client.HasError()) switch (instance_client.GetError()) { + if (instance_client.HasError()) { + switch (instance_client.GetError()) { case replication::RegisterReplicaError::NOT_MAIN: MG_ASSERT(false, "Only main instance can register a replica!"); return {}; @@ -116,6 +118,7 @@ inline std::optional HandleRegisterReplicaStatus( case replication::RegisterReplicaError::SUCCESS: break; } + } return {}; } @@ -131,20 +134,19 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port); client.StartFrequentCheck([&, license = license::global_license_checker.IsEnterpriseValidFast(), main_uuid]( bool reconnect, replication::ReplicationClient &client) mutable { - if (client.try_set_uuid && - memgraph::replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, main_uuid)) { + if (client.try_set_uuid && replication_coordination_glue::SendSwapMainUUIDRpc(client.rpc_client_, main_uuid)) { client.try_set_uuid = false; } // Working connection // Check if system needs restoration if (reconnect) { 
- client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + client.state_.WithLock([](auto &state) { state = replication::ReplicationClient::State::BEHIND; }); } // Check if license has changed const auto new_license = license::global_license_checker.IsEnterpriseValidFast(); if (new_license != license) { license = new_license; - client.state_.WithLock([](auto &state) { state = memgraph::replication::ReplicationClient::State::BEHIND; }); + client.state_.WithLock([](auto &state) { state = replication::ReplicationClient::State::BEHIND; }); } #ifdef MG_ENTERPRISE SystemRestore(client, system, dbms_handler, main_uuid, auth); @@ -152,10 +154,10 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle // Check if any database has been left behind dbms_handler.ForEach([&name = client.name_, reconnect](dbms::DatabaseAccess db_acc) { // Specific database <-> replica client - db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) { - if (reconnect || client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) { + db_acc->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient &client) { + if (reconnect || client.State() == storage::replication::ReplicaState::MAYBE_BEHIND) { // Database <-> replica might be behind, check and recover - client->TryCheckReplicaStateAsync(db_acc->storage(), db_acc); + client.TryCheckReplicaStateAsync(db_acc->storage(), db_acc); } }); }); @@ -163,9 +165,8 @@ void StartReplicaClient(replication::ReplicationClient &client, dbms::DbmsHandle } #ifdef MG_ENTERPRISE -ReplicationHandler::ReplicationHandler(memgraph::replication::ReplicationState &repl_state, - memgraph::dbms::DbmsHandler &dbms_handler, memgraph::system::System &system, - memgraph::auth::SynchedAuth &auth) +ReplicationHandler::ReplicationHandler(replication::ReplicationState &repl_state, dbms::DbmsHandler &dbms_handler, + 
system::System &system, auth::SynchedAuth &auth) : repl_state_{repl_state}, dbms_handler_{dbms_handler}, system_{system}, auth_{auth} { RecoverReplication(repl_state_, system_, dbms_handler_, auth_); } @@ -177,56 +178,27 @@ ReplicationHandler::ReplicationHandler(replication::ReplicationState &repl_state #endif bool ReplicationHandler::SetReplicationRoleMain() { - auto const main_handler = [](memgraph::replication::RoleMainData &) { + auto const main_handler = [](replication::RoleMainData &) { // If we are already MAIN, we don't want to change anything return false; }; - auto const replica_handler = [this](memgraph::replication::RoleReplicaData const &) { + auto const replica_handler = [this](replication::RoleReplicaData const &) { return DoReplicaToMainPromotion(utils::UUID{}); }; // TODO: under lock - return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); + return std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); } -bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, +bool ReplicationHandler::SetReplicationRoleReplica(const replication::ReplicationServerConfig &config, const std::optional &main_uuid) { - // We don't want to restart the server if we're already a REPLICA - if (repl_state_.IsReplica()) { - spdlog::trace("Instance has already has replica role."); - return false; - } + return SetReplicationRoleReplica_(config, main_uuid); +} - // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are - // deleting the replica. 
- // Remove database specific clients - dbms_handler_.ForEach([&](memgraph::dbms::DatabaseAccess db_acc) { - auto *storage = db_acc->storage(); - storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); - }); - // Remove instance level clients - std::get(repl_state_.ReplicationData()).registered_replicas_.clear(); - - // Creates the server - repl_state_.SetReplicationRoleReplica(config, main_uuid); - - // Start - const auto success = - std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData &) { - // ASSERT - return false; - }, - [this](memgraph::replication::RoleReplicaData &data) { -#ifdef MG_ENTERPRISE - return StartRpcServer(dbms_handler_, data, auth_, system_); -#else - return StartRpcServer(dbms_handler_, data); -#endif - }}, - repl_state_.ReplicationData()); - // TODO Handle error (restore to main?) - return success; +bool ReplicationHandler::TrySetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config, + const std::optional &main_uuid) { + return SetReplicationRoleReplica_(config, main_uuid); } bool ReplicationHandler::DoReplicaToMainPromotion(const utils::UUID &main_uuid) { @@ -255,30 +227,26 @@ bool ReplicationHandler::DoReplicaToMainPromotion(const utils::UUID &main_uuid) }; // as MAIN, define and connect to REPLICAs -auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config, - bool send_swap_uuid) +auto ReplicationHandler::TryRegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> memgraph::utils::BasicResult { - return RegisterReplica_(config, send_swap_uuid); + return RegisterReplica_(config); } -auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config, - bool send_swap_uuid) +auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) -> memgraph::utils::BasicResult { - return RegisterReplica_(config, 
send_swap_uuid); + return RegisterReplica_(config); } -auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::query::UnregisterReplicaResult { - auto const replica_handler = - [](memgraph::replication::RoleReplicaData const &) -> memgraph::query::UnregisterReplicaResult { - return memgraph::query::UnregisterReplicaResult::NOT_MAIN; +auto ReplicationHandler::UnregisterReplica(std::string_view name) -> query::UnregisterReplicaResult { + auto const replica_handler = [](replication::RoleReplicaData const &) -> query::UnregisterReplicaResult { + return query::UnregisterReplicaResult::NOT_MAIN; }; - auto const main_handler = - [this, name](memgraph::replication::RoleMainData &mainData) -> memgraph::query::UnregisterReplicaResult { + auto const main_handler = [this, name](replication::RoleMainData &mainData) -> query::UnregisterReplicaResult { if (!repl_state_.TryPersistUnregisterReplica(name)) { - return memgraph::query::UnregisterReplicaResult::COULD_NOT_BE_PERSISTED; + return query::UnregisterReplicaResult::COULD_NOT_BE_PERSISTED; } // Remove database specific clients - dbms_handler_.ForEach([name](memgraph::dbms::DatabaseAccess db_acc) { + dbms_handler_.ForEach([name](dbms::DatabaseAccess db_acc) { db_acc->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) { std::erase_if(clients, [name](const auto &client) { return client->Name() == name; }); }); @@ -286,23 +254,75 @@ auto ReplicationHandler::UnregisterReplica(std::string_view name) -> memgraph::q // Remove instance level clients auto const n_unregistered = std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; }); - return n_unregistered != 0 ? memgraph::query::UnregisterReplicaResult::SUCCESS - : memgraph::query::UnregisterReplicaResult::CAN_NOT_UNREGISTER; + return n_unregistered != 0 ? 
query::UnregisterReplicaResult::SUCCESS + : query::UnregisterReplicaResult::CAN_NOT_UNREGISTER; }; - return std::visit(memgraph::utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); + return std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); } -auto ReplicationHandler::GetRole() const -> memgraph::replication_coordination_glue::ReplicationRole { +auto ReplicationHandler::GetRole() const -> replication_coordination_glue::ReplicationRole { return repl_state_.GetRole(); } +auto ReplicationHandler::GetReplicaUUID() -> std::optional { + MG_ASSERT(repl_state_.IsReplica()); + return std::get(repl_state_.ReplicationData()).uuid_; +} + auto ReplicationHandler::GetReplState() const -> const memgraph::replication::ReplicationState & { return repl_state_; } -auto ReplicationHandler::GetReplState() -> memgraph::replication::ReplicationState & { return repl_state_; } +auto ReplicationHandler::GetReplState() -> replication::ReplicationState & { return repl_state_; } bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); } bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); } +auto ReplicationHandler::ShowReplicas() const -> utils::BasicResult { + using res_t = utils::BasicResult; + auto main = [this](RoleMainData const &main) -> res_t { + auto entries = std::vector{}; + entries.reserve(main.registered_replicas_.size()); + + const bool full_info = license::global_license_checker.IsEnterpriseValidFast(); + + for (auto const &replica : main.registered_replicas_) { + // STEP 1: data_info + auto data_info = std::map{}; + this->dbms_handler_.ForEach([&](dbms::DatabaseAccess db_acc) { + auto *storage = db_acc->storage(); + // ATM we only support IN_MEMORY_TRANSACTIONAL + if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; + if (!full_info && storage->name() == dbms::kDefaultDB) return; + auto ok = + 
storage->repl_storage_state_.WithClient(replica.name_, [&](storage::ReplicationStorageClient &client) { + auto ts_info = client.GetTimestampInfo(storage); + auto state = client.State(); + + data_info.emplace(storage->name(), + query::ReplicaInfoState{ts_info.current_timestamp_of_replica, + ts_info.current_number_of_timestamp_behind_main, state}); + }); + DMG_ASSERT(ok); + }); + +// STEP 2: system_info +#ifdef MG_ENTERPRISE + // Already locked on system transaction via the interpreter + const auto ts = system_.LastCommittedSystemTimestamp(); + // NOTE: no system behind at the moment + query::ReplicaSystemInfoState system_info{ts, 0 /* behind ts not implemented */, *replica.state_.ReadLock()}; +#else + query::ReplicaSystemInfoState system_info{}; +#endif + // STEP 3: add entry + entries.emplace_back(replica.name_, replica.rpc_client_.Endpoint().SocketAddress(), replica.mode_, system_info, + std::move(data_info)); + } + return query::ReplicasInfos{std::move(entries)}; + }; + auto replica = [](RoleReplicaData const &) -> res_t { return query::ShowReplicaError::NOT_MAIN; }; + + return std::visit(utils::Overloaded{main, replica}, repl_state_.ReplicationData()); +} } // namespace memgraph::replication diff --git a/src/rpc/client.hpp b/src/rpc/client.hpp index 3a2fefd57..d14746313 100644 --- a/src/rpc/client.hpp +++ b/src/rpc/client.hpp @@ -27,6 +27,8 @@ #include "utils/on_scope_exit.hpp" #include "utils/typeinfo.hpp" +#include "io/network/fmt.hpp" + namespace memgraph::rpc { /// Client is thread safe, but it is recommended to use thread_local clients. 
diff --git a/src/storage/v2/disk/storage.cpp b/src/storage/v2/disk/storage.cpp index adc0e92f4..f9cd2ac13 100644 --- a/src/storage/v2/disk/storage.cpp +++ b/src/storage/v2/disk/storage.cpp @@ -825,7 +825,7 @@ uint64_t DiskStorage::GetDiskSpaceUsage() const { durability_disk_storage_size; } -StorageInfo DiskStorage::GetBaseInfo(bool /* unused */) { +StorageInfo DiskStorage::GetBaseInfo() { StorageInfo info{}; info.vertex_count = vertex_count_; info.edge_count = edge_count_.load(std::memory_order_acquire); @@ -838,9 +838,8 @@ StorageInfo DiskStorage::GetBaseInfo(bool /* unused */) { return info; } -StorageInfo DiskStorage::GetInfo(bool force_dir, - memgraph::replication_coordination_glue::ReplicationRole replication_role) { - StorageInfo info = GetBaseInfo(force_dir); +StorageInfo DiskStorage::GetInfo(memgraph::replication_coordination_glue::ReplicationRole replication_role) { + StorageInfo info = GetBaseInfo(); { auto access = Access(replication_role); const auto &lbl = access->ListAllIndices(); @@ -1278,7 +1277,7 @@ bool DiskStorage::DeleteEdgeFromConnectivityIndex(Transaction *transaction, cons /// std::map /// Here we also do flushing of too many things, we don't need to serialize edges in read-only txn, check that... 
[[nodiscard]] utils::BasicResult DiskStorage::FlushModifiedEdges( - Transaction *transaction, const auto &edge_acc) { + Transaction *transaction, const auto &edges_acc) { for (const auto &modified_edge : transaction->modified_edges_) { const std::string edge_gid = modified_edge.first.ToString(); const Delta::Action root_action = modified_edge.second.delta_action; @@ -1304,8 +1303,8 @@ bool DiskStorage::DeleteEdgeFromConnectivityIndex(Transaction *transaction, cons return StorageManipulationError{SerializationError{}}; } - const auto &edge = edge_acc.find(modified_edge.first); - MG_ASSERT(edge != edge_acc.end(), + const auto &edge = edges_acc.find(modified_edge.first); + MG_ASSERT(edge != edges_acc.end(), "Database in invalid state, commit not possible! Please restart your DB and start the import again."); /// TODO: (andi) I think this is not wrong but it would be better to use AtomicWrites across column families. @@ -1693,9 +1692,8 @@ utils::BasicResult DiskStorage::DiskAccessor::Co transaction_.commit_timestamp->store(*commit_timestamp_, std::memory_order_release); if (edge_import_mode_active) { - if (auto res = - disk_storage->FlushModifiedEdges(&transaction_, disk_storage->edge_import_mode_cache_->AccessToEdges()); - res.HasError()) { + auto edges_acc = disk_storage->edge_import_mode_cache_->AccessToEdges(); + if (auto res = disk_storage->FlushModifiedEdges(&transaction_, edges_acc); res.HasError()) { Abort(); return res; } @@ -1717,7 +1715,8 @@ utils::BasicResult DiskStorage::DiskAccessor::Co return del_vertices_res.GetError(); } - if (auto modified_edges_res = disk_storage->FlushModifiedEdges(&transaction_, transaction_.edges_->access()); + auto tx_edges_acc = transaction_.edges_->access(); + if (auto modified_edges_res = disk_storage->FlushModifiedEdges(&transaction_, tx_edges_acc); modified_edges_res.HasError()) { Abort(); return modified_edges_res.GetError(); diff --git a/src/storage/v2/disk/storage.hpp b/src/storage/v2/disk/storage.hpp index 
cc3b24c2f..ea2e6714e 100644 --- a/src/storage/v2/disk/storage.hpp +++ b/src/storage/v2/disk/storage.hpp @@ -197,7 +197,7 @@ class DiskStorage final : public Storage { [[nodiscard]] utils::BasicResult FlushDeletedVertices(Transaction *transaction); [[nodiscard]] utils::BasicResult FlushDeletedEdges(Transaction *transaction); [[nodiscard]] utils::BasicResult FlushModifiedEdges(Transaction *transaction, - const auto &edge_acc); + const auto &edges_acc); [[nodiscard]] utils::BasicResult ClearDanglingVertices(Transaction *transaction); /// Writing methods @@ -309,9 +309,8 @@ class DiskStorage final : public Storage { std::vector> SerializeVerticesForLabelPropertyIndex(LabelId label, PropertyId property); - StorageInfo GetBaseInfo(bool force_directory) override; - StorageInfo GetInfo(bool force_directory, - memgraph::replication_coordination_glue::ReplicationRole replication_role) override; + StorageInfo GetBaseInfo() override; + StorageInfo GetInfo(memgraph::replication_coordination_glue::ReplicationRole replication_role) override; void FreeMemory(std::unique_lock /*lock*/) override {} diff --git a/src/storage/v2/durability/snapshot.cpp b/src/storage/v2/durability/snapshot.cpp index 0d434fadf..eee099870 100644 --- a/src/storage/v2/durability/snapshot.cpp +++ b/src/storage/v2/durability/snapshot.cpp @@ -22,6 +22,7 @@ #include "storage/v2/edge.hpp" #include "storage/v2/edge_accessor.hpp" #include "storage/v2/edge_ref.hpp" +#include "storage/v2/fmt.hpp" #include "storage/v2/id_types.hpp" #include "storage/v2/indices/label_index_stats.hpp" #include "storage/v2/indices/label_property_index_stats.hpp" diff --git a/src/storage/v2/edge_accessor.cpp b/src/storage/v2/edge_accessor.cpp index 03522ba16..62a9f4bcd 100644 --- a/src/storage/v2/edge_accessor.cpp +++ b/src/storage/v2/edge_accessor.cpp @@ -17,6 +17,7 @@ #include "storage/v2/delta.hpp" #include "storage/v2/mvcc.hpp" +#include "storage/v2/property_store.hpp" #include "storage/v2/property_value.hpp" #include 
"storage/v2/result.hpp" #include "storage/v2/storage.hpp" @@ -264,6 +265,27 @@ Result EdgeAccessor::GetProperty(PropertyId property, View view) return *std::move(value); } +Result EdgeAccessor::GetPropertySize(PropertyId property, View view) const { + if (!storage_->config_.salient.items.properties_on_edges) return 0; + + auto guard = std::shared_lock{edge_.ptr->lock}; + Delta *delta = edge_.ptr->delta; + if (!delta) { + return edge_.ptr->properties.PropertySize(property); + } + + auto property_result = this->GetProperty(property, view); + + if (property_result.HasError()) { + return property_result.GetError(); + } + + auto property_store = storage::PropertyStore(); + property_store.SetProperty(property, *property_result); + + return property_store.PropertySize(property); +}; + Result> EdgeAccessor::Properties(View view) const { if (!storage_->config_.salient.items.properties_on_edges) return std::map{}; bool exists = true; diff --git a/src/storage/v2/edge_accessor.hpp b/src/storage/v2/edge_accessor.hpp index 83a3e549d..6b76ddbe8 100644 --- a/src/storage/v2/edge_accessor.hpp +++ b/src/storage/v2/edge_accessor.hpp @@ -82,6 +82,9 @@ class EdgeAccessor final { /// @throw std::bad_alloc Result GetProperty(PropertyId property, View view) const; + /// Returns the size of the encoded edge property in bytes. + Result GetPropertySize(PropertyId property, View view) const; + /// @throw std::bad_alloc Result> Properties(View view) const; diff --git a/src/storage/v2/fmt.hpp b/src/storage/v2/fmt.hpp new file mode 100644 index 000000000..e200d7299 --- /dev/null +++ b/src/storage/v2/fmt.hpp @@ -0,0 +1,23 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#if FMT_VERSION > 90000 +#include + +#include "storage/v2/property_value.hpp" + +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +template <> +class fmt::formatter : public fmt::ostream_formatter {}; +#endif diff --git a/src/storage/v2/inmemory/storage.cpp b/src/storage/v2/inmemory/storage.cpp index ac4ba45a9..8134be070 100644 --- a/src/storage/v2/inmemory/storage.cpp +++ b/src/storage/v2/inmemory/storage.cpp @@ -11,6 +11,7 @@ #include "storage/v2/inmemory/storage.hpp" #include +#include #include #include #include "dbms/constants.hpp" @@ -1758,7 +1759,7 @@ void InMemoryStorage::CollectGarbage(std::unique_lock main_ template void InMemoryStorage::CollectGarbage(std::unique_lock); template void InMemoryStorage::CollectGarbage(std::unique_lock); -StorageInfo InMemoryStorage::GetBaseInfo(bool force_directory) { +StorageInfo InMemoryStorage::GetBaseInfo() { StorageInfo info{}; info.vertex_count = vertices_.size(); info.edge_count = edge_count_.load(std::memory_order_acquire); @@ -1769,27 +1770,23 @@ StorageInfo InMemoryStorage::GetBaseInfo(bool force_directory) { info.memory_res = utils::GetMemoryRES(); // Special case for the default database auto update_path = [&](const std::filesystem::path &dir) { - if (!force_directory && std::filesystem::is_directory(dir) && dir.has_filename()) { - const auto end = dir.end(); - auto it = end; - --it; - if (it != end) { - --it; - if (it != end && *it != "databases") { - // Default DB points to the root (for back-compatibility); update to the "database" dir - return dir / dbms::kMultiTenantDir / dbms::kDefaultDB; - } +#ifdef MG_ENTERPRISE + if (config_.salient.name == dbms::kDefaultDB) { + // Default DB points to the root (for back-compatibility); update to the 
"database" dir + std::filesystem::path new_dir = dir / "databases" / dbms::kDefaultDB; + if (std::filesystem::exists(new_dir) && std::filesystem::is_directory(new_dir)) { + return new_dir; } } +#endif return dir; }; info.disk_usage = utils::GetDirDiskUsage(update_path(config_.durability.storage_directory)); return info; } -StorageInfo InMemoryStorage::GetInfo(bool force_directory, - memgraph::replication_coordination_glue::ReplicationRole replication_role) { - StorageInfo info = GetBaseInfo(force_directory); +StorageInfo InMemoryStorage::GetInfo(memgraph::replication_coordination_glue::ReplicationRole replication_role) { + StorageInfo info = GetBaseInfo(); { auto access = Access(replication_role); // TODO: override isolation level? const auto &lbl = access->ListAllIndices(); diff --git a/src/storage/v2/inmemory/storage.hpp b/src/storage/v2/inmemory/storage.hpp index 8d0a5e0c9..15d9b4e61 100644 --- a/src/storage/v2/inmemory/storage.hpp +++ b/src/storage/v2/inmemory/storage.hpp @@ -372,9 +372,8 @@ class InMemoryStorage final : public Storage { bool InitializeWalFile(memgraph::replication::ReplicationEpoch &epoch); void FinalizeWalFile(); - StorageInfo GetBaseInfo(bool force_directory) override; - StorageInfo GetInfo(bool force_directory, - memgraph::replication_coordination_glue::ReplicationRole replication_role) override; + StorageInfo GetBaseInfo() override; + StorageInfo GetInfo(memgraph::replication_coordination_glue::ReplicationRole replication_role) override; /// Return true in all cases excepted if any sync replicas have not sent confirmation. [[nodiscard]] bool AppendToWal(const Transaction &transaction, uint64_t final_commit_timestamp, diff --git a/src/storage/v2/name_id_mapper.hpp b/src/storage/v2/name_id_mapper.hpp index bb91e3647..2c5aee352 100644 --- a/src/storage/v2/name_id_mapper.hpp +++ b/src/storage/v2/name_id_mapper.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -83,6 +83,18 @@ class NameIdMapper { return id; } + /// This method unlike NameToId does not insert the new property id if not found + /// but just returns either std::nullopt or the value of the property id if it + /// finds it. + virtual std::optional NameToIdIfExists(const std::string_view name) { + auto name_to_id_acc = name_to_id_.access(); + auto found = name_to_id_acc.find(name); + if (found == name_to_id_acc.end()) { + return std::nullopt; + } + return found->id; + } + // NOTE: Currently this function returns a `const std::string &` instead of a // `std::string` to avoid making unnecessary copies of the string. // Usually, this wouldn't be correct because the accessor to the diff --git a/src/storage/v2/property_store.cpp b/src/storage/v2/property_store.cpp index 427998fbe..e6e4dbbaf 100644 --- a/src/storage/v2/property_store.cpp +++ b/src/storage/v2/property_store.cpp @@ -93,6 +93,19 @@ enum class Size : uint8_t { INT64 = 0x03, }; +uint64_t SizeToByteSize(Size size) { + switch (size) { + case Size::INT8: + return 1; + case Size::INT16: + return 2; + case Size::INT32: + return 4; + case Size::INT64: + return 8; + } +} + // All of these values must have the lowest 4 bits set to zero because they are // used to store two `Size` values as described in the comment above. 
enum class Type : uint8_t { @@ -486,6 +499,27 @@ std::optional DecodeTemporalData(Reader &reader) { return TemporalData{static_cast(*type_value), *microseconds_value}; } +std::optional DecodeTemporalDataSize(Reader &reader) { + uint64_t temporal_data_size = 0; + + auto metadata = reader.ReadMetadata(); + if (!metadata || metadata->type != Type::TEMPORAL_DATA) return std::nullopt; + + temporal_data_size += 1; + + auto type_value = reader.ReadUint(metadata->id_size); + if (!type_value) return std::nullopt; + + temporal_data_size += SizeToByteSize(metadata->id_size); + + auto microseconds_value = reader.ReadInt(metadata->payload_size); + if (!microseconds_value) return std::nullopt; + + temporal_data_size += SizeToByteSize(metadata->payload_size); + + return temporal_data_size; +} + } // namespace // Function used to decode a PropertyValue from a byte stream. @@ -572,6 +606,92 @@ std::optional DecodeTemporalData(Reader &reader) { } } +[[nodiscard]] bool DecodePropertyValueSize(Reader *reader, Type type, Size payload_size, uint64_t &property_size) { + switch (type) { + case Type::EMPTY: { + return false; + } + case Type::NONE: + case Type::BOOL: { + return true; + } + case Type::INT: { + reader->ReadInt(payload_size); + property_size += SizeToByteSize(payload_size); + return true; + } + case Type::DOUBLE: { + reader->ReadDouble(payload_size); + property_size += SizeToByteSize(payload_size); + return true; + } + case Type::STRING: { + auto size = reader->ReadUint(payload_size); + if (!size) return false; + property_size += SizeToByteSize(payload_size); + + std::string str_v(*size, '\0'); + if (!reader->SkipBytes(*size)) return false; + property_size += *size; + + return true; + } + case Type::LIST: { + auto size = reader->ReadUint(payload_size); + if (!size) return false; + + uint64_t list_property_size = SizeToByteSize(payload_size); + + for (uint64_t i = 0; i < *size; ++i) { + auto metadata = reader->ReadMetadata(); + if (!metadata) return false; + + 
list_property_size += 1; + if (!DecodePropertyValueSize(reader, metadata->type, metadata->payload_size, list_property_size)) return false; + } + + property_size += list_property_size; + return true; + } + case Type::MAP: { + auto size = reader->ReadUint(payload_size); + if (!size) return false; + + uint64_t map_property_size = SizeToByteSize(payload_size); + + for (uint64_t i = 0; i < *size; ++i) { + auto metadata = reader->ReadMetadata(); + if (!metadata) return false; + + map_property_size += 1; + + auto key_size = reader->ReadUint(metadata->id_size); + if (!key_size) return false; + + map_property_size += SizeToByteSize(metadata->id_size); + + std::string key(*key_size, '\0'); + if (!reader->ReadBytes(key.data(), *key_size)) return false; + + map_property_size += *key_size; + + if (!DecodePropertyValueSize(reader, metadata->type, metadata->payload_size, map_property_size)) return false; + } + + property_size += map_property_size; + return true; + } + + case Type::TEMPORAL_DATA: { + const auto maybe_temporal_data_size = DecodeTemporalDataSize(*reader); + if (!maybe_temporal_data_size) return false; + + property_size += *maybe_temporal_data_size; + return true; + } + } +} + // Function used to skip a PropertyValue from a byte stream. 
// // @sa ComparePropertyValue @@ -788,6 +908,27 @@ enum class ExpectedPropertyStatus { : ExpectedPropertyStatus::GREATER; } +[[nodiscard]] ExpectedPropertyStatus DecodeExpectedPropertySize(Reader *reader, PropertyId expected_property, + uint64_t &size) { + auto metadata = reader->ReadMetadata(); + if (!metadata) return ExpectedPropertyStatus::MISSING_DATA; + + auto property_id = reader->ReadUint(metadata->id_size); + if (!property_id) return ExpectedPropertyStatus::MISSING_DATA; + + if (*property_id == expected_property.AsUint()) { + // Add one byte for reading metadata + add the number of bytes for the property key + size += (1 + SizeToByteSize(metadata->id_size)); + if (!DecodePropertyValueSize(reader, metadata->type, metadata->payload_size, size)) + return ExpectedPropertyStatus::MISSING_DATA; + return ExpectedPropertyStatus::EQUAL; + } + // Don't load the value if this isn't the expected property. + if (!SkipPropertyValue(reader, metadata->type, metadata->payload_size)) return ExpectedPropertyStatus::MISSING_DATA; + return (*property_id < expected_property.AsUint()) ? ExpectedPropertyStatus::SMALLER + : ExpectedPropertyStatus::GREATER; +} + // Function used to check a property exists (PropertyId) from a byte stream. // It will skip the encoded PropertyValue. // @@ -875,6 +1016,13 @@ enum class ExpectedPropertyStatus { } } +[[nodiscard]] ExpectedPropertyStatus FindSpecificPropertySize(Reader *reader, PropertyId property, uint64_t &size) { + ExpectedPropertyStatus ret = ExpectedPropertyStatus::SMALLER; + while ((ret = DecodeExpectedPropertySize(reader, property, size)) == ExpectedPropertyStatus::SMALLER) { + } + return ret; +} + // Function used to find if property is set. It relies on the fact that the properties // are sorted (by ID) in the buffer. 
// @@ -983,6 +1131,31 @@ std::pair GetSizeData(const uint8_t *buffer) { return {size, data}; } +struct BufferInfo { + uint64_t size; + uint8_t *data{nullptr}; + bool in_local_buffer; +}; + +template +BufferInfo GetBufferInfo(const uint8_t (&buffer)[N]) { + uint64_t size = 0; + const uint8_t *data = nullptr; + bool in_local_buffer = false; + std::tie(size, data) = GetSizeData(buffer); + if (size % 8 != 0) { + // We are storing the data in the local buffer. + size = sizeof(buffer) - 1; + data = &buffer[1]; + in_local_buffer = true; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto *non_const_data = const_cast(data); + + return {size, non_const_data, in_local_buffer}; +} + void SetSizeData(uint8_t *buffer, uint64_t size, uint8_t *data) { memcpy(buffer, &size, sizeof(uint64_t)); memcpy(buffer + sizeof(uint64_t), &data, sizeof(uint8_t *)); @@ -1023,30 +1196,27 @@ PropertyStore::~PropertyStore() { } PropertyValue PropertyStore::GetProperty(PropertyId property) const { - uint64_t size; - const uint8_t *data; - std::tie(size, data) = GetSizeData(buffer_); - if (size % 8 != 0) { - // We are storing the data in the local buffer. 
- size = sizeof(buffer_) - 1; - data = &buffer_[1]; - } - Reader reader(data, size); + BufferInfo buffer_info = GetBufferInfo(buffer_); + Reader reader(buffer_info.data, buffer_info.size); + PropertyValue value; if (FindSpecificProperty(&reader, property, value) != ExpectedPropertyStatus::EQUAL) return {}; return value; } +uint64_t PropertyStore::PropertySize(PropertyId property) const { + auto data_size_localbuffer = GetBufferInfo(buffer_); + Reader reader(data_size_localbuffer.data, data_size_localbuffer.size); + + uint64_t property_size = 0; + if (FindSpecificPropertySize(&reader, property, property_size) != ExpectedPropertyStatus::EQUAL) return 0; + return property_size; +} + bool PropertyStore::HasProperty(PropertyId property) const { - uint64_t size; - const uint8_t *data; - std::tie(size, data) = GetSizeData(buffer_); - if (size % 8 != 0) { - // We are storing the data in the local buffer. - size = sizeof(buffer_) - 1; - data = &buffer_[1]; - } - Reader reader(data, size); + BufferInfo buffer_info = GetBufferInfo(buffer_); + Reader reader(buffer_info.data, buffer_info.size); + return ExistsSpecificProperty(&reader, property) == ExpectedPropertyStatus::EQUAL; } @@ -1081,32 +1251,20 @@ std::optional> PropertyStore::ExtractPropertyValues( } bool PropertyStore::IsPropertyEqual(PropertyId property, const PropertyValue &value) const { - uint64_t size; - const uint8_t *data; - std::tie(size, data) = GetSizeData(buffer_); - if (size % 8 != 0) { - // We are storing the data in the local buffer. 
- size = sizeof(buffer_) - 1; - data = &buffer_[1]; - } - Reader reader(data, size); + BufferInfo buffer_info = GetBufferInfo(buffer_); + Reader reader(buffer_info.data, buffer_info.size); + auto info = FindSpecificPropertyAndBufferInfo(&reader, property); if (info.property_size == 0) return value.IsNull(); - Reader prop_reader(data + info.property_begin, info.property_size); + Reader prop_reader(buffer_info.data + info.property_begin, info.property_size); if (!CompareExpectedProperty(&prop_reader, property, value)) return false; return prop_reader.GetPosition() == info.property_size; } std::map PropertyStore::Properties() const { - uint64_t size; - const uint8_t *data; - std::tie(size, data) = GetSizeData(buffer_); - if (size % 8 != 0) { - // We are storing the data in the local buffer. - size = sizeof(buffer_) - 1; - data = &buffer_[1]; - } - Reader reader(data, size); + BufferInfo buffer_info = GetBufferInfo(buffer_); + Reader reader(buffer_info.data, buffer_info.size); + std::map props; while (true) { PropertyValue value; @@ -1340,33 +1498,20 @@ bool PropertyStore::InitProperties(std::vector(data[i]); + BufferInfo buffer_info = GetBufferInfo(buffer_); + + std::string arr(buffer_info.size, ' '); + for (uint i = 0; i < buffer_info.size; ++i) { + arr[i] = static_cast(buffer_info.data[i]); } return arr; } diff --git a/src/storage/v2/property_store.hpp b/src/storage/v2/property_store.hpp index c217cbd81..eee83f5df 100644 --- a/src/storage/v2/property_store.hpp +++ b/src/storage/v2/property_store.hpp @@ -45,6 +45,11 @@ class PropertyStore { /// @throw std::bad_alloc PropertyValue GetProperty(PropertyId property) const; + /// Returns the size of the encoded property in bytes. + /// Returns 0 if the property does not exist. + /// The time complexity of this function is O(n). + uint64_t PropertySize(PropertyId property) const; + /// Checks whether the property `property` exists in the store. The time /// complexity of this function is O(n). 
bool HasProperty(PropertyId property) const; diff --git a/src/storage/v2/replication/global.hpp b/src/storage/v2/replication/global.hpp index ebcec1206..9c479edc9 100644 --- a/src/storage/v2/replication/global.hpp +++ b/src/storage/v2/replication/global.hpp @@ -26,7 +26,7 @@ namespace memgraph::storage { struct TimestampInfo { uint64_t current_timestamp_of_replica; - uint64_t current_number_of_timestamp_behind_master; + uint64_t current_number_of_timestamp_behind_main; }; struct ReplicaInfo { diff --git a/src/storage/v2/replication/replication_client.cpp b/src/storage/v2/replication/replication_client.cpp index 16247de57..16429d11f 100644 --- a/src/storage/v2/replication/replication_client.cpp +++ b/src/storage/v2/replication/replication_client.cpp @@ -9,6 +9,8 @@ // by the Apache License, Version 2.0, included in the file // licenses/APL.txt. +#include + #include "replication/replication_client.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/storage.hpp" @@ -17,7 +19,7 @@ #include "utils/uuid.hpp" #include "utils/variant_helpers.hpp" -#include +#include "io/network/fmt.hpp" namespace { template @@ -93,7 +95,7 @@ void ReplicationStorageClient::UpdateReplicaState(Storage *storage, DatabaseAcce TimestampInfo ReplicationStorageClient::GetTimestampInfo(Storage const *storage) { TimestampInfo info; info.current_timestamp_of_replica = 0; - info.current_number_of_timestamp_behind_master = 0; + info.current_number_of_timestamp_behind_main = 0; try { auto stream{client_.rpc_client_.Stream(main_uuid_, storage->uuid())}; @@ -102,9 +104,9 @@ TimestampInfo ReplicationStorageClient::GetTimestampInfo(Storage const *storage) auto main_time_stamp = storage->repl_storage_state_.last_commit_timestamp_.load(); info.current_timestamp_of_replica = response.current_commit_timestamp; - info.current_number_of_timestamp_behind_master = response.current_commit_timestamp - main_time_stamp; + info.current_number_of_timestamp_behind_main = response.current_commit_timestamp 
- main_time_stamp; - if (!is_success || info.current_number_of_timestamp_behind_master != 0) { + if (!is_success || info.current_number_of_timestamp_behind_main != 0) { replica_state_.WithLock([](auto &val) { val = replication::ReplicaState::MAYBE_BEHIND; }); LogRpcFailure(); } diff --git a/src/storage/v2/replication/replication_storage_state.hpp b/src/storage/v2/replication/replication_storage_state.hpp index adbf87aa9..91cec563c 100644 --- a/src/storage/v2/replication/replication_storage_state.hpp +++ b/src/storage/v2/replication/replication_storage_state.hpp @@ -63,7 +63,7 @@ struct ReplicationStorageState { return replication_clients_.WithLock([replica_name, cb = std::forward(callback)](auto &clients) { for (const auto &client : clients) { if (client->Name() == replica_name) { - cb(client.get()); + cb(*client); return true; } } diff --git a/src/storage/v2/replication/serialization.cpp b/src/storage/v2/replication/serialization.cpp index 6651b8999..d0ba2e8ac 100644 --- a/src/storage/v2/replication/serialization.cpp +++ b/src/storage/v2/replication/serialization.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 79e99456b..d2c42c33a 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -250,6 +250,10 @@ class Storage { PropertyId NameToProperty(std::string_view name) { return storage_->NameToProperty(name); } + std::optional NameToPropertyIfExists(std::string_view name) const { + return storage_->NameToPropertyIfExists(name); + } + EdgeTypeId NameToEdgeType(std::string_view name) { return storage_->NameToEdgeType(name); } StorageMode GetCreationStorageMode() const noexcept; @@ -320,6 +324,14 @@ class Storage { return PropertyId::FromUint(name_id_mapper_->NameToId(name)); } + std::optional NameToPropertyIfExists(std::string_view name) const { + const auto id = name_id_mapper_->NameToIdIfExists(name); + if (!id) { + return std::nullopt; + } + return PropertyId::FromUint(*id); + } + EdgeTypeId NameToEdgeType(const std::string_view name) const { return EdgeTypeId::FromUint(name_id_mapper_->NameToId(name)); } @@ -349,18 +361,9 @@ class Storage { utils::BasicResult SetIsolationLevel(IsolationLevel isolation_level); IsolationLevel GetIsolationLevel() const noexcept; - virtual StorageInfo GetBaseInfo(bool force_directory) = 0; - StorageInfo GetBaseInfo() { -#if MG_ENTERPRISE - const bool force_dir = false; -#else - const bool force_dir = true; //!< Use the configured directory (multi-tenancy reroutes to another dir) -#endif - return GetBaseInfo(force_dir); - } + virtual StorageInfo GetBaseInfo() = 0; - virtual StorageInfo GetInfo(bool force_directory, - memgraph::replication_coordination_glue::ReplicationRole replication_role) = 0; + virtual StorageInfo GetInfo(memgraph::replication_coordination_glue::ReplicationRole replication_role) = 0; virtual Transaction CreateTransaction(IsolationLevel 
isolation_level, StorageMode storage_mode, memgraph::replication_coordination_glue::ReplicationRole replication_role) = 0; diff --git a/src/storage/v2/vertex_accessor.cpp b/src/storage/v2/vertex_accessor.cpp index ff5062444..ef0a6ab3e 100644 --- a/src/storage/v2/vertex_accessor.cpp +++ b/src/storage/v2/vertex_accessor.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -438,6 +438,26 @@ Result VertexAccessor::GetProperty(PropertyId property, View view return std::move(value); } +Result VertexAccessor::GetPropertySize(PropertyId property, View view) const { + { + auto guard = std::shared_lock{vertex_->lock}; + Delta *delta = vertex_->delta; + if (!delta) { + return vertex_->properties.PropertySize(property); + } + } + + auto property_result = this->GetProperty(property, view); + if (property_result.HasError()) { + return property_result.GetError(); + } + + auto property_store = storage::PropertyStore(); + property_store.SetProperty(property, *property_result); + + return property_store.PropertySize(property); +}; + Result> VertexAccessor::Properties(View view) const { bool exists = true; bool deleted = false; diff --git a/src/storage/v2/vertex_accessor.hpp b/src/storage/v2/vertex_accessor.hpp index 0e5972d14..18fad3dcc 100644 --- a/src/storage/v2/vertex_accessor.hpp +++ b/src/storage/v2/vertex_accessor.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -80,6 +80,9 @@ class VertexAccessor final { /// @throw std::bad_alloc Result GetProperty(PropertyId property, View view) const; + /// Returns the size of the encoded vertex property in bytes. + Result GetPropertySize(PropertyId property, View view) const; + /// @throw std::bad_alloc Result> Properties(View view) const; diff --git a/src/system/include/system/system.hpp b/src/system/include/system/system.hpp index eb15a553f..c549a7e0e 100644 --- a/src/system/include/system/system.hpp +++ b/src/system/include/system/system.hpp @@ -34,7 +34,7 @@ struct System { if (!system_unique.try_lock_for(try_time)) { return std::nullopt; } - return Transaction{state_, std::move(system_unique), timestamp_++}; + return Transaction{state_, std::move(system_unique), ++timestamp_}; } // TODO: this and LastCommittedSystemTimestamp maybe not needed @@ -46,7 +46,7 @@ struct System { private: State state_; std::timed_mutex mtx_{}; - std::uint64_t timestamp_{}; + std::uint64_t timestamp_{0}; }; } // namespace memgraph::system diff --git a/src/utils/async_timer.cpp b/src/utils/async_timer.cpp index dd5789172..b72be4d45 100644 --- a/src/utils/async_timer.cpp +++ b/src/utils/async_timer.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -56,12 +56,11 @@ void EraseFlag(uint64_t flag_id) { expiration_flags.access().remove(flag_id); } std::weak_ptr> GetFlag(uint64_t flag_id) { const auto flag_accessor = expiration_flags.access(); - const auto it = flag_accessor.find(flag_id); - if (it == flag_accessor.end()) { + const auto iter = flag_accessor.find(flag_id); + if (iter == flag_accessor.end()) { return {}; } - - return it->flag; + return iter->flag; } void MarkDone(const uint64_t flag_id) { diff --git a/src/utils/exceptions.hpp b/src/utils/exceptions.hpp index fa49f770e..929be1d58 100644 --- a/src/utils/exceptions.hpp +++ b/src/utils/exceptions.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -42,12 +42,26 @@ namespace memgraph::utils { class BasicException : public std::exception { public: /** - * @brief Constructor (C++ STL strings). + * @brief Constructor (C++ STL strings_view). * * @param message The error message. */ explicit BasicException(std::string_view message) noexcept : msg_(message) {} + /** + * @brief Constructor (string literal). + * + * @param message The error message. + */ + explicit BasicException(const char *message) noexcept : msg_(message) {} + + /** + * @brief Constructor (C++ STL strings). + * + * @param message The error message. + */ + explicit BasicException(std::string message) noexcept : msg_(std::move(message)) {} + /** * @brief Constructor with format string (C++ STL strings). 
* diff --git a/src/utils/functional.hpp b/src/utils/functional.hpp new file mode 100644 index 000000000..e0714de2a --- /dev/null +++ b/src/utils/functional.hpp @@ -0,0 +1,26 @@ +// Copyright 2024 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include +#include + +#include + +namespace memgraph::utils { + +template ::type> +auto fmap(F &&f, std::vector const &v) -> std::vector { + return v | ranges::views::transform(std::forward(f)) | ranges::to>(); +} + +} // namespace memgraph::utils diff --git a/src/utils/logging.hpp b/src/utils/logging.hpp index 02389beab..adc5db51a 100644 --- a/src/utils/logging.hpp +++ b/src/utils/logging.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -23,6 +23,11 @@ #include #include +// NOTE: fmt 9+ introduced fmt/std.h, it's important because of, e.g., std::path formatting. toolchain-v4 has fmt 8, +// the guard is here because of fmt 8 compatibility. +#if FMT_VERSION > 90000 +#include +#endif #include #include #include diff --git a/src/utils/memory_tracker.cpp b/src/utils/memory_tracker.cpp index 7dfd88416..3a2ec4dec 100644 --- a/src/utils/memory_tracker.cpp +++ b/src/utils/memory_tracker.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -104,7 +104,7 @@ void MemoryTracker::SetMaximumHardLimit(const int64_t limit) { maximum_hard_limit_ = limit; } -void MemoryTracker::Alloc(const int64_t size) { +bool MemoryTracker::Alloc(int64_t const size) { MG_ASSERT(size >= 0, "Negative size passed to the MemoryTracker."); const int64_t will_be = size + amount_.fetch_add(size, std::memory_order_relaxed); @@ -116,12 +116,13 @@ void MemoryTracker::Alloc(const int64_t size) { amount_.fetch_sub(size, std::memory_order_relaxed); - throw OutOfMemoryException( - fmt::format("Memory limit exceeded! Attempting to allocate a chunk of {} which would put the current " - "use to {}, while the maximum allowed size for allocation is set to {}.", - GetReadableSize(size), GetReadableSize(will_be), GetReadableSize(current_hard_limit))); + // register our error data, we will pick this up on the other side of jemalloc + MemoryErrorStatus().set({size, will_be, current_hard_limit}); + + return false; } UpdatePeak(will_be); + return true; } void MemoryTracker::DoCheck() { @@ -139,4 +140,23 @@ void MemoryTracker::DoCheck() { void MemoryTracker::Free(const int64_t size) { amount_.fetch_sub(size, std::memory_order_relaxed); } +// DEVNOTE: important that this is allocated at thread construction time +// otherwise subtle bug where jemalloc will try to lock an non-recursive mutex +// that it already owns +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +thread_local MemoryTrackerStatus status; +auto MemoryErrorStatus() -> MemoryTrackerStatus & { return status; } + +auto MemoryTrackerStatus::msg() -> std::optional { + if (!data_) return std::nullopt; + + auto [size, will_be, hard_limit] = *data_; + data_.reset(); + return fmt::format( + "Memory limit exceeded! 
Attempting to allocate a chunk of {} which would put the current " + "use to {}, while the maximum allowed size for allocation is set to {}.", + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + GetReadableSize(size), GetReadableSize(will_be), GetReadableSize(hard_limit)); +} + } // namespace memgraph::utils diff --git a/src/utils/memory_tracker.hpp b/src/utils/memory_tracker.hpp index a6d7221ff..5d82c6f2a 100644 --- a/src/utils/memory_tracker.hpp +++ b/src/utils/memory_tracker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,15 +12,35 @@ #pragma once #include +#include +#include #include #include "utils/exceptions.hpp" namespace memgraph::utils { +struct MemoryTrackerStatus { + struct data { + int64_t size; + int64_t will_be; + int64_t hard_limit; + }; + + // DEVNOTE: Do not call from within allocator, will cause another allocation + auto msg() -> std::optional; + + void set(data d) { data_ = d; } + + private: + std::optional data_; +}; + +auto MemoryErrorStatus() -> MemoryTrackerStatus &; + class OutOfMemoryException : public utils::BasicException { public: - explicit OutOfMemoryException(const std::string &msg) : utils::BasicException(msg) {} + explicit OutOfMemoryException(std::string msg) : utils::BasicException(std::move(msg)) {} SPECIALIZE_GET_EXCEPTION_NAME(OutOfMemoryException) }; @@ -47,7 +67,7 @@ class MemoryTracker final { MemoryTracker &operator=(MemoryTracker &&) = delete; - void Alloc(int64_t size); + bool Alloc(int64_t size); void Free(int64_t size); void DoCheck(); diff --git a/src/utils/message.hpp b/src/utils/message.hpp index c301b3878..009bea032 100644 --- a/src/utils/message.hpp +++ b/src/utils/message.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. 
+// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,8 +17,13 @@ namespace memgraph::utils { template std::string MessageWithLink(fmt::format_string fmt, Args &&...args) { +#if FMT_VERSION > 90000 + return fmt::format(fmt::runtime(fmt::format(fmt::runtime("{} For more details, visit {{}}."), fmt.get())), + std::forward(args)...); +#else return fmt::format(fmt::runtime(fmt::format(fmt::runtime("{} For more details, visit {{}}."), fmt)), std::forward(args)...); +#endif } } // namespace memgraph::utils diff --git a/src/utils/query_memory_tracker.cpp b/src/utils/query_memory_tracker.cpp index a9b61cdf3..46cb6d871 100644 --- a/src/utils/query_memory_tracker.cpp +++ b/src/utils/query_memory_tracker.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,18 +17,19 @@ namespace memgraph::utils { -void QueryMemoryTracker::TrackAlloc(size_t size) { +bool QueryMemoryTracker::TrackAlloc(size_t size) { if (query_tracker_.has_value()) [[likely]] { - query_tracker_->Alloc(static_cast(size)); + bool ok = query_tracker_->Alloc(static_cast(size)); + if (!ok) return false; } auto *proc_tracker = GetActiveProc(); if (proc_tracker == nullptr) { - return; + return true; } - proc_tracker->Alloc(static_cast(size)); + return proc_tracker->Alloc(static_cast(size)); } void QueryMemoryTracker::TrackFree(size_t size) { if (query_tracker_.has_value()) [[likely]] { diff --git a/src/utils/query_memory_tracker.hpp b/src/utils/query_memory_tracker.hpp index 87975adf8..acfdca07f 100644 --- a/src/utils/query_memory_tracker.hpp +++ b/src/utils/query_memory_tracker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 
Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -44,7 +44,7 @@ class QueryMemoryTracker { ~QueryMemoryTracker() = default; // Track allocation on query and procedure if active - void TrackAlloc(size_t); + bool TrackAlloc(size_t size); // Track Free on query and procedure if active void TrackFree(size_t); diff --git a/src/utils/stat.hpp b/src/utils/stat.hpp index 4c2eec6d6..de806f853 100644 --- a/src/utils/stat.hpp +++ b/src/utils/stat.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include #include "utils/file.hpp" +#include "utils/logging.hpp" #include "utils/string.hpp" namespace memgraph::utils { @@ -39,19 +40,19 @@ inline uint64_t GetDirDiskUsage(const std::filesystem::path &path) { return 0; } uint64_t size = 0; - for (const auto &p : std::filesystem::directory_iterator(path)) { - if (IgnoreSymlink && std::filesystem::is_symlink(p)) continue; - if (std::filesystem::is_directory(p)) { - size += GetDirDiskUsage(p); - } else if (std::filesystem::is_regular_file(p)) { - if (!utils::HasReadAccess(p)) { + for (const auto &dir_entry : std::filesystem::directory_iterator(path)) { + if (IgnoreSymlink && std::filesystem::is_symlink(dir_entry)) continue; + if (std::filesystem::is_directory(dir_entry)) { + size += GetDirDiskUsage(dir_entry); + } else if (std::filesystem::is_regular_file(dir_entry)) { + if (!utils::HasReadAccess(dir_entry)) { spdlog::warn( "Skipping file path on collecting directory disk usage '{}' because it is not readable, check file " "ownership and read permissions!", - p); + dir_entry.path()); continue; } - size += 
std::filesystem::file_size(p); + size += std::filesystem::file_size(dir_entry); } } diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index f2f1e4407..8ee7cdc33 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -107,6 +107,13 @@ enum class TypeId : uint64_t { COORD_SET_REPL_MAIN_RES, COORD_SWAP_UUID_REQ, COORD_SWAP_UUID_RES, + COORD_UNREGISTER_REPLICA_REQ, + COORD_UNREGISTER_REPLICA_RES, + COORD_ENABLE_WRITING_ON_MAIN_REQ, + COORD_ENABLE_WRITING_ON_MAIN_RES, + + COORD_GET_UUID_REQ, + COORD_GET_UUID_RES, // AST AST_LABELIX = 3000, diff --git a/tests/benchmark/expansion.cpp b/tests/benchmark/expansion.cpp index 0c4579476..d47ca1aca 100644 --- a/tests/benchmark/expansion.cpp +++ b/tests/benchmark/expansion.cpp @@ -27,6 +27,7 @@ std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "e class ExpansionBenchFixture : public benchmark::Fixture { protected: std::optional system; + std::optional auth_checker; std::optional interpreter_context; std::optional interpreter; std::optional> db_gk; @@ -43,6 +44,7 @@ class ExpansionBenchFixture : public benchmark::Fixture { auto &db_acc = *db_acc_opt; system.emplace(); + auth_checker.emplace(); interpreter_context.emplace(memgraph::query::InterpreterConfig{}, nullptr, &repl_state.value(), *system #ifdef MG_ENTERPRISE , @@ -73,13 +75,15 @@ class ExpansionBenchFixture : public benchmark::Fixture { } interpreter.emplace(&*interpreter_context, std::move(db_acc)); + interpreter->SetUser(auth_checker->GenQueryUser(std::nullopt, std::nullopt)); } void TearDown(const benchmark::State &) override { interpreter = std::nullopt; interpreter_context = std::nullopt; - system.reset(); db_gk.reset(); + auth_checker.reset(); + system.reset(); std::filesystem::remove_all(data_directory); } }; diff --git a/tests/drivers/run.sh b/tests/drivers/run.sh index d82b81ea9..ec99e410d 100755 --- a/tests/drivers/run.sh +++ b/tests/drivers/run.sh @@ -1,7 +1,7 @@ #!/bin/bash -# Old v1 tests -run_v1.sh +DIR="$( 
cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$DIR" # New tests pushd () { command pushd "$@" > /dev/null; } @@ -15,8 +15,9 @@ function wait_for_server { sleep 1 } -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -cd "$DIR" +# Old v1 tests +tests_v1="$DIR/run_v1.sh" +$tests_v1 # Create a temporary directory. tmpdir=/tmp/memgraph_drivers diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt index 9a4406d2b..1876074ee 100644 --- a/tests/e2e/CMakeLists.txt +++ b/tests/e2e/CMakeLists.txt @@ -40,7 +40,7 @@ endfunction() add_subdirectory(fine_grained_access) add_subdirectory(server) add_subdirectory(replication) -#add_subdirectory(memory) +add_subdirectory(memory) add_subdirectory(triggers) add_subdirectory(isolation_levels) add_subdirectory(streams) @@ -56,7 +56,6 @@ add_subdirectory(python_query_modules_reloading) add_subdirectory(analyze_graph) add_subdirectory(transaction_queue) add_subdirectory(mock_api) -#add_subdirectory(graphql) add_subdirectory(disk_storage) add_subdirectory(load_csv) add_subdirectory(init_file_flags) @@ -76,10 +75,8 @@ add_subdirectory(queries) add_subdirectory(query_modules_storage_modes) add_subdirectory(garbage_collection) add_subdirectory(query_planning) - -if (MG_EXPERIMENTAL_HIGH_AVAILABILITY) - add_subdirectory(high_availability_experimental) -endif () +add_subdirectory(awesome_functions) +add_subdirectory(high_availability) add_subdirectory(replication_experimental) diff --git a/tests/e2e/awesome_functions/CMakeLists.txt b/tests/e2e/awesome_functions/CMakeLists.txt new file mode 100644 index 000000000..9d6e0143b --- /dev/null +++ b/tests/e2e/awesome_functions/CMakeLists.txt @@ -0,0 +1,6 @@ +function(copy_awesome_functions_e2e_python_files FILE_NAME) + copy_e2e_python_files(awesome_functions ${FILE_NAME}) +endfunction() + +copy_awesome_functions_e2e_python_files(common.py) +copy_awesome_functions_e2e_python_files(awesome_functions.py) diff --git a/tests/e2e/awesome_functions/awesome_functions.py 
b/tests/e2e/awesome_functions/awesome_functions.py new file mode 100644 index 000000000..9761708ed --- /dev/null +++ b/tests/e2e/awesome_functions/awesome_functions.py @@ -0,0 +1,269 @@ +# Copyright 2023 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import sys + +import pytest +from common import get_bytes, memgraph + + +def test_property_size_on_null_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.null_prop = null; + """ + ) + + null_bytes = get_bytes(memgraph, "null_prop") + + # No property stored, no bytes allocated + assert null_bytes == 0 + + +def test_property_size_on_bool_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.bool_prop = True; + """ + ) + + bool_bytes = get_bytes(memgraph, "bool_prop") + + # 1 byte metadata, 1 byte prop id, but value is encoded in the metadata + assert bool_bytes == 2 + + +def test_property_size_on_one_byte_int_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.S_int_prop = 4; + """ + ) + + s_int_bytes = get_bytes(memgraph, "S_int_prop") + + # 1 byte metadata, 1 byte prop id + payload size 1 byte to store the int + assert s_int_bytes == 3 + + +def test_property_size_on_two_byte_int_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.M_int_prop = 500; + """ + ) + + m_int_bytes = get_bytes(memgraph, "M_int_prop") + + # 1 byte metadata, 1 byte prop id + payload size 2 bytes to store the int + assert m_int_bytes == 4 + + +def test_property_size_on_four_byte_int_prop(memgraph): + 
memgraph.execute( + """ + CREATE (n:Node) + SET n.L_int_prop = 1000000000; + """ + ) + + l_int_bytes = get_bytes(memgraph, "L_int_prop") + + # 1 byte metadata, 1 byte prop id + payload size 4 bytes to store the int + assert l_int_bytes == 6 + + +def test_property_size_on_eight_byte_int_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.XL_int_prop = 1000000000000; + """ + ) + + xl_int_bytes = get_bytes(memgraph, "XL_int_prop") + + # 1 byte metadata, 1 byte prop id + payload size 8 bytes to store the int + assert xl_int_bytes == 10 + + +def test_property_size_on_float_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.float_prop = 4.0; + """ + ) + + float_bytes = get_bytes(memgraph, "float_prop") + + # 1 byte metadata, 1 byte prop id + payload size 8 bytes to store the float + assert float_bytes == 10 + + +def test_property_size_on_string_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.str_prop = 'str_value'; + """ + ) + + str_bytes = get_bytes(memgraph, "str_prop") + + # 1 byte metadata + # 1 byte prop id + # - the payload size contains the amount of bytes stored for the size in the next sequence + # X bytes for the length of the string (1, 2, 4 or 8 bytes) -> "str_value" has 1 byte for the length of 9 + # Y bytes for the string content -> 9 bytes for "str_value" + assert str_bytes == 12 + + +def test_property_size_on_list_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.list_prop = [1, 2, 3]; + """ + ) + + list_bytes = get_bytes(memgraph, "list_prop") + + # 1 byte metadata + # 1 byte prop id + # - the payload size contains the amount of bytes stored for the size of the list + # X bytes for the size of the list (1, 2, 4 or 8 bytes) + # for each list element: + # - 1 byte for the metadata + # - the amount of bytes for the payload of the type (a small int is 1 additional byte) + # in this case 1 + 1 + 3 * (1 + 1) + assert list_bytes == 9 + + +def test_property_size_on_map_prop(memgraph): + 
memgraph.execute( + """ + CREATE (n:Node) + SET n.map_prop = {key1: 'value', key2: 4}; + """ + ) + + map_bytes = get_bytes(memgraph, "map_prop") + + # 1 byte metadata + # 1 byte prop id + # - the payload size contains the amount of bytes stored for the size of the map + # X bytes for the size of the map (1, 2, 4 or 8 bytes - in this case 1) + # for every map element: + # - 1 byte for metadata + # - 1, 2, 4 or 8 bytes for the key length (read from the metadata payload) -> this case 1 + # - Y bytes for the key content -> this case 4 + # - Z amount of bytes for the type + # - for 'value' -> 1 byte for size and 5 for length + # - for 4 -> 1 byte for content read from payload + # total: 1 + 1 + (1 + 1 + 4 + (1 + 5)) + (1 + 1 + 4 + (1)) + assert map_bytes == 22 + + +def test_property_size_on_date_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.date_prop = date('2023-01-01'); + """ + ) + + date_bytes = get_bytes(memgraph, "date_prop") + + # 1 byte metadata (to see that it's temporal data) + # 1 byte prop id + # 1 byte metadata + # - type is again the same + # - id field contains the length of the specific temporal type (1, 2, 4 or 8 bytes) -> probably always 1 + # - payload field contains the length of the microseconds (1, 2, 4, or 8 bytes) -> probably always 8 + assert date_bytes == 12 + + +def test_property_size_on_local_time_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.localtime_prop = localtime('23:00:00'); + """ + ) + + local_time_bytes = get_bytes(memgraph, "localtime_prop") + + # 1 byte metadata (to see that it's temporal data) + # 1 byte prop id + # 1 byte metadata + # - type is again the same + # - id field contains the length of the specific temporal type (1, 2, 4 or 8 bytes) -> probably always 1 + # - payload field contains the length of the microseconds (1, 2, 4, or 8 bytes) -> probably always 8 + assert local_time_bytes == 12 + + +def test_property_size_on_local_date_time_prop(memgraph): + memgraph.execute( + """ + 
CREATE (n:Node) + SET n.localdatetime_prop = localdatetime('2022-01-01T23:00:00'); + """ + ) + + local_date_time_bytes = get_bytes(memgraph, "localdatetime_prop") + + # 1 byte metadata (to see that it's temporal data) + # 1 byte prop id + # 1 byte metadata + # - type is again the same + # - id field contains the length of the specific temporal type (1, 2, 4 or 8 bytes) -> probably always 1 + # - payload field contains the length of the microseconds (1, 2, 4, or 8 bytes) -> probably always 8 + assert local_date_time_bytes == 12 + + +def test_property_size_on_duration_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node) + SET n.duration_prop = duration('P5DT2M2.33S'); + """ + ) + + duration_bytes = get_bytes(memgraph, "duration_prop") + + # 1 byte metadata (to see that it's temporal data) + # 1 byte prop id + # 1 byte metadata + # - type is again the same + # - id field contains the length of the specific temporal type (1, 2, 4 or 8 bytes) -> probably always 1 + # - payload field contains the length of the microseconds (1, 2, 4, or 8 bytes) -> probably always 8 + assert duration_bytes == 12 + + +def test_property_size_on_nonexistent_prop(memgraph): + memgraph.execute( + """ + CREATE (n:Node); + """ + ) + + nonexistent_bytes = get_bytes(memgraph, "nonexistent_prop") + + assert nonexistent_bytes == 0 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/awesome_functions/common.py b/tests/e2e/awesome_functions/common.py new file mode 100644 index 000000000..14f272c23 --- /dev/null +++ b/tests/e2e/awesome_functions/common.py @@ -0,0 +1,29 @@ +# Copyright 2023 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. 
+# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import pytest +from gqlalchemy import Memgraph + + +@pytest.fixture +def memgraph(**kwargs) -> Memgraph: + memgraph = Memgraph() + + yield memgraph + + memgraph.drop_indexes() + memgraph.ensure_constraints([]) + memgraph.drop_database() + + +def get_bytes(memgraph, prop_name): + res = list(memgraph.execute_and_fetch(f"MATCH (n) RETURN propertySize(n, '{prop_name}') AS size")) + return res[0]["size"] diff --git a/tests/e2e/awesome_functions/workloads.yaml b/tests/e2e/awesome_functions/workloads.yaml new file mode 100644 index 000000000..37e5e8813 --- /dev/null +++ b/tests/e2e/awesome_functions/workloads.yaml @@ -0,0 +1,14 @@ +awesome_functions_cluster: &awesome_functions_cluster + cluster: + main: + args: ["--bolt-port", "7687", "--log-level=TRACE"] + log_file: "awesome_functions.log" + setup_queries: [] + validation_queries: [] + + +workloads: + - name: "Awesome Functions" + binary: "tests/e2e/pytest_runner.sh" + args: ["awesome_functions/awesome_functions.py"] + <<: *awesome_functions_cluster diff --git a/tests/e2e/configuration/default_config.py b/tests/e2e/configuration/default_config.py index 3c58e2c49..65a850f0b 100644 --- a/tests/e2e/configuration/default_config.py +++ b/tests/e2e/configuration/default_config.py @@ -14,14 +14,7 @@ # If you wish to modify these, update the startup_config_dict and workloads.yaml ! 
startup_config_dict = { - "auth_module_create_missing_role": ("true", "true", "Set to false to disable creation of missing roles."), - "auth_module_create_missing_user": ("true", "true", "Set to false to disable creation of missing users."), "auth_module_executable": ("", "", "Absolute path to the auth module executable that should be used."), - "auth_module_manage_roles": ( - "true", - "true", - "Set to false to disable management of roles through the auth module.", - ), "auth_module_timeout_ms": ( "10000", "10000", @@ -69,6 +62,9 @@ startup_config_dict = { "coordinator_server_port": ("0", "0", "Port on which coordinator servers will be started."), "raft_server_port": ("0", "0", "Port on which raft servers will be started."), "raft_server_id": ("0", "0", "Unique ID of the raft server."), + "instance_down_timeout_sec": ("5", "5", "Time duration after which an instance is considered down."), + "instance_health_check_frequency_sec": ("1", "1", "The time duration between two health checks/pings."), + "instance_get_uuid_frequency_sec": ("10", "10", "The time duration between two instance uuid checks."), "data_directory": ("mg_data", "mg_data", "Path to directory in which to save all permanent data."), "data_recovery_on_startup": ( "false", @@ -225,6 +221,6 @@ startup_config_dict = { "experimental_enabled": ( "", "", - "Experimental features to be used, comma seperated. Options [system-replication]", + "Experimental features to be used, comma seperated. 
Options [system-replication, high-availability]", ), } diff --git a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py index b087fa6c5..bc8f107da 100644 --- a/tests/e2e/fine_grained_access/create_delete_filtering_tests.py +++ b/tests/e2e/fine_grained_access/create_delete_filtering_tests.py @@ -19,10 +19,10 @@ from mgclient import DatabaseError @pytest.mark.parametrize("switch", [False, True]) def test_create_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") @@ -33,10 +33,10 @@ def test_create_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -47,10 +47,10 @@ def test_create_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = 
common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "CREATE (n:label1) RETURN n;") @@ -61,10 +61,10 @@ def test_create_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -75,10 +75,10 @@ def test_create_node_specific_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") @@ -91,10 +91,10 @@ def test_delete_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def 
test_delete_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -105,10 +105,10 @@ def test_delete_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n:test_delete) DELETE n;") @@ -123,10 +123,10 @@ def test_delete_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -137,11 +137,11 @@ def test_delete_node_specific_label_denied(switch): 
@pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -156,11 +156,11 @@ def test_create_edge_all_labels_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -174,11 +174,11 @@ def test_create_edge_all_labels_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") 
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -192,11 +192,11 @@ def test_create_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -210,7 +210,6 @@ def test_create_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -218,6 +217,7 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -231,7 +231,6 @@ def test_create_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", 
[False, True]) def test_create_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") @@ -240,6 +239,7 @@ def test_create_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -253,7 +253,6 @@ def test_create_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_create_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") @@ -262,6 +261,7 @@ def test_create_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -275,11 +275,11 @@ def test_create_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) 
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -293,11 +293,11 @@ def test_delete_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -311,7 +311,6 @@ def test_delete_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -319,6 +318,7 @@ def test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -332,7 +332,6 @@ def 
test_delete_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_first_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;") @@ -341,6 +340,7 @@ def test_delete_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -354,7 +354,6 @@ def test_delete_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_2 TO user;") @@ -363,6 +362,7 @@ def test_delete_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type_delete TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -376,13 +376,13 @@ def test_delete_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_with_edge_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") 
common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT UPDATE ON LABELS :test_delete_1 TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -393,13 +393,13 @@ def test_delete_node_with_edge_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_delete_node_with_edge_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all( admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete_1 TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -415,10 +415,10 @@ def test_delete_node_with_edge_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") @@ -429,10 +429,10 @@ def test_merge_node_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - 
user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -443,10 +443,10 @@ def test_merge_node_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_specific_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MERGE (n:label1) RETURN n;") @@ -457,10 +457,10 @@ def test_merge_node_specific_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_node_specific_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :label1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -471,11 +471,11 @@ def test_merge_node_specific_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_all_edge_types_granted(switch): admin_connection = 
common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all( @@ -489,11 +489,11 @@ def test_merge_edge_all_labels_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -507,11 +507,11 @@ def test_merge_edge_all_labels_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_denied_all_edge_types_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO 
user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -525,11 +525,11 @@ def test_merge_edge_all_labels_denied_all_edge_types_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_granted_all_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -543,7 +543,6 @@ def test_merge_edge_all_labels_granted_all_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") @@ -551,6 +550,7 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): admin_connection.cursor(), "GRANT UPDATE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -564,7 +564,6 @@ def test_merge_edge_all_labels_granted_specific_edge_types_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_first_node_label_granted(switch): admin_connection = 
common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label1 TO user;") @@ -573,6 +572,7 @@ def test_merge_edge_first_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -586,7 +586,6 @@ def test_merge_edge_first_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_edge_second_node_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :label2 TO user;") @@ -595,6 +594,7 @@ def test_merge_edge_second_node_label_granted(switch): admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES :edge_type TO user;", ) + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -608,10 +608,10 @@ def test_merge_edge_second_node_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_set_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :update_label_2 TO user;") + user_connection = 
common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -621,12 +621,12 @@ def test_set_label_when_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_set_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS :update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -637,11 +637,11 @@ def test_set_label_when_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_remove_label_when_label_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -651,12 +651,12 @@ def test_remove_label_when_label_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_remove_label_when_label_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT UPDATE ON LABELS 
:update_label_2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :test_delete TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -667,12 +667,12 @@ def test_remove_label_when_label_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_merge_nodes_pass_when_having_create_delete(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.reset_and_prepare(admin_connection.cursor()) common.create_multi_db(admin_connection.cursor(), switch) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT CREATE_DELETE ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) diff --git a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py index f2071c54b..4bae2b2f4 100644 --- a/tests/e2e/fine_grained_access/edge_type_filtering_tests.py +++ b/tests/e2e/fine_grained_access/edge_type_filtering_tests.py @@ -7,9 +7,9 @@ import pytest @pytest.mark.parametrize("switch", [False, True]) def test_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -21,9 +21,9 @@ def test_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, 
True]) def test_deny_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -35,9 +35,9 @@ def test_deny_all_edge_types_and_all_labels(switch): @pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_edge_types_and_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -49,10 +49,10 @@ def test_revoke_all_edge_types_and_all_labels(switch): @pytest.mark.parametrize("switch", [False, True]) def test_deny_edge_type(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1, :label2, :label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edgeType1 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -64,10 +64,10 @@ def test_deny_edge_type(switch): @pytest.mark.parametrize("switch", [False, True]) def 
test_denied_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label3 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label2 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -79,10 +79,10 @@ def test_denied_node_label(switch): @pytest.mark.parametrize("switch", [False, True]) def test_denied_one_of_node_label(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS :label1,:label2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edgeType1, :edgeType2 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -94,8 +94,8 @@ def test_denied_one_of_node_label(switch): @pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_labels(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") @@ -106,8 +106,8 @@ def test_revoke_all_labels(switch): 
@pytest.mark.parametrize("switch", [False, True]) def test_revoke_all_edge_types(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) results = common.execute_and_fetch_all(user_connection.cursor(), "MATCH (n)-[r]->(m) RETURN n,r,m;") diff --git a/tests/e2e/fine_grained_access/path_filtering_tests.py b/tests/e2e/fine_grained_access/path_filtering_tests.py index c5b873972..e8c395b2e 100644 --- a/tests/e2e/fine_grained_access/path_filtering_tests.py +++ b/tests/e2e/fine_grained_access/path_filtering_tests.py @@ -7,11 +7,11 @@ import pytest @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -54,11 +54,11 @@ def test_weighted_shortest_path_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") 
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -72,7 +72,6 @@ def test_weighted_shortest_path_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -80,6 +79,7 @@ def test_weighted_shortest_path_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -94,7 +94,6 @@ def test_weighted_shortest_path_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -102,6 +101,7 @@ def 
test_weighted_shortest_path_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -116,7 +116,6 @@ def test_weighted_shortest_path_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -124,6 +123,7 @@ def test_weighted_shortest_path_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -162,7 +162,6 @@ def test_weighted_shortest_path_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_weighted_shortest_path_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -170,6 +169,7 @@ def test_weighted_shortest_path_denied_edge_type_3(switch): admin_connection.cursor(), 
"GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -213,11 +213,11 @@ def test_weighted_shortest_path_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -235,11 +235,11 @@ def test_dfs_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -251,7 +251,6 @@ def 
test_dfs_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -259,6 +258,7 @@ def test_dfs_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -272,7 +272,6 @@ def test_dfs_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -280,6 +279,7 @@ def test_dfs_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -293,7 +293,6 @@ def test_dfs_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", 
password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -301,6 +300,7 @@ def test_dfs_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -318,7 +318,6 @@ def test_dfs_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_dfs_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") @@ -327,6 +326,7 @@ def test_dfs_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -344,11 +344,11 @@ def test_dfs_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") 
common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -366,11 +366,11 @@ def test_bfs_sts_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -384,7 +384,6 @@ def test_bfs_sts_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -392,6 +391,7 @@ def test_bfs_sts_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: 
common.switch_db(user_connection.cursor()) @@ -405,7 +405,6 @@ def test_bfs_sts_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -413,6 +412,7 @@ def test_bfs_sts_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -426,7 +426,6 @@ def test_bfs_sts_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -434,6 +433,7 @@ def test_bfs_sts_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -450,7 +450,6 @@ def test_bfs_sts_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_sts_denied_edge_type_3(switch): admin_connection = 
common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -458,6 +457,7 @@ def test_bfs_sts_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -474,11 +474,11 @@ def test_bfs_sts_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -496,11 +496,11 @@ def test_bfs_single_source_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") 
common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -512,7 +512,6 @@ def test_bfs_single_source_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -520,6 +519,7 @@ def test_bfs_single_source_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -533,7 +533,6 @@ def test_bfs_single_source_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -541,6 +540,7 @@ def 
test_bfs_single_source_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -554,7 +554,6 @@ def test_bfs_single_source_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -562,6 +561,7 @@ def test_bfs_single_source_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -579,7 +579,6 @@ def test_bfs_single_source_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_bfs_single_source_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -587,6 +586,7 @@ def test_bfs_single_source_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES 
:edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -604,11 +604,11 @@ def test_bfs_single_source_denied_edge_type_3(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -651,11 +651,11 @@ def test_all_shortest_paths_when_all_edge_types_all_labels_granted(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -669,7 +669,6 @@ def 
test_all_shortest_paths_when_all_edge_types_all_labels_denied(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_start(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -677,6 +676,7 @@ def test_all_shortest_paths_when_denied_start(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label0 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -691,7 +691,6 @@ def test_all_shortest_paths_when_denied_start(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_destination(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -699,6 +698,7 @@ def test_all_shortest_paths_when_denied_destination(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label4 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -713,7 +713,6 @@ def test_all_shortest_paths_when_denied_destination(switch): @pytest.mark.parametrize("switch", [False, True]) def 
test_all_shortest_paths_when_denied_label_1(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all( @@ -721,6 +720,7 @@ def test_all_shortest_paths_when_denied_label_1(switch): ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON LABELS :label1 TO user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON EDGE_TYPES * TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) @@ -759,7 +759,6 @@ def test_all_shortest_paths_when_denied_label_1(switch): @pytest.mark.parametrize("switch", [False, True]) def test_all_shortest_paths_when_denied_edge_type_3(switch): admin_connection = common.connect(username="admin", password="test") - user_connection = common.connect(username="user", password="test") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE LABELS * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "REVOKE EDGE_TYPES * FROM user;") common.execute_and_fetch_all(admin_connection.cursor(), "GRANT READ ON LABELS * TO user;") @@ -767,6 +766,7 @@ def test_all_shortest_paths_when_denied_edge_type_3(switch): admin_connection.cursor(), "GRANT READ ON EDGE_TYPES :edge_type_1, :edge_type_2, :edge_type_4 TO user;" ) common.execute_and_fetch_all(admin_connection.cursor(), "GRANT NOTHING ON EDGE_TYPES :edge_type_3 TO user;") + user_connection = common.connect(username="user", password="test") if switch: common.switch_db(user_connection.cursor()) diff --git a/tests/e2e/fine_grained_access/workloads.yaml b/tests/e2e/fine_grained_access/workloads.yaml index ad1dd43b2..6128b4b7d 100644 --- a/tests/e2e/fine_grained_access/workloads.yaml +++ 
b/tests/e2e/fine_grained_access/workloads.yaml @@ -84,11 +84,12 @@ show_databases_w_user_setup_queries: &show_databases_w_user_setup_queries - "GRANT DATABASE db1 TO user;" - "GRANT ALL PRIVILEGES TO user2;" - "GRANT DATABASE db2 TO user2;" + - "GRANT DATABASE memgraph TO user2;" - "REVOKE DATABASE memgraph FROM user2;" - "SET MAIN DATABASE db2 FOR user2" - "GRANT ALL PRIVILEGES TO user3;" - "GRANT DATABASE * TO user3;" - - "REVOKE DATABASE memgraph FROM user3;" + - "DENY DATABASE memgraph FROM user3;" - "SET MAIN DATABASE db1 FOR user3" create_delete_filtering_in_memory_cluster: &create_delete_filtering_in_memory_cluster diff --git a/tests/e2e/high_availability/CMakeLists.txt b/tests/e2e/high_availability/CMakeLists.txt new file mode 100644 index 000000000..47a1781aa --- /dev/null +++ b/tests/e2e/high_availability/CMakeLists.txt @@ -0,0 +1,15 @@ +find_package(gflags REQUIRED) + +copy_e2e_python_files(high_availability coordinator.py) +copy_e2e_python_files(high_availability single_coordinator.py) +copy_e2e_python_files(high_availability coord_cluster_registration.py) +copy_e2e_python_files(high_availability distributed_coords.py) +copy_e2e_python_files(high_availability disable_writing_on_main_after_restart.py) +copy_e2e_python_files(high_availability manual_setting_replicas.py) +copy_e2e_python_files(high_availability not_replicate_from_old_main.py) +copy_e2e_python_files(high_availability common.py) +copy_e2e_python_files(high_availability workloads.yaml) + +copy_e2e_python_files_from_parent_folder(high_availability ".." memgraph.py) +copy_e2e_python_files_from_parent_folder(high_availability ".." interactive_mg_runner.py) +copy_e2e_python_files_from_parent_folder(high_availability ".." 
mg_utils.py) diff --git a/tests/e2e/high_availability_experimental/common.py b/tests/e2e/high_availability/common.py similarity index 77% rename from tests/e2e/high_availability_experimental/common.py rename to tests/e2e/high_availability/common.py index adfabd87a..2157b29ca 100644 --- a/tests/e2e/high_availability_experimental/common.py +++ b/tests/e2e/high_availability/common.py @@ -30,3 +30,14 @@ def safe_execute(function, *args): function(*args) except: pass + + +# NOTE: Repeated execution because it can fail if Raft server is not up +def add_coordinator(cursor, query): + for _ in range(10): + try: + execute_and_fetch_all(cursor, query) + return True + except Exception: + pass + return False diff --git a/tests/e2e/high_availability_experimental/coord_cluster_registration.py b/tests/e2e/high_availability/coord_cluster_registration.py similarity index 61% rename from tests/e2e/high_availability_experimental/coord_cluster_registration.py rename to tests/e2e/high_availability/coord_cluster_registration.py index 5feb0bb11..68a387281 100644 --- a/tests/e2e/high_availability_experimental/coord_cluster_registration.py +++ b/tests/e2e/high_availability/coord_cluster_registration.py @@ -16,7 +16,7 @@ import tempfile import interactive_mg_runner import pytest -from common import connect, execute_and_fetch_all, safe_execute +from common import add_coordinator, connect, execute_and_fetch_all, safe_execute from mg_utils import mg_sleep_and_assert interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -31,6 +31,7 @@ TEMP_DIR = tempfile.TemporaryDirectory().name MEMGRAPH_INSTANCES_DESCRIPTION = { "instance_1": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7687", "--log-level", @@ -44,6 +45,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_2": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7688", "--log-level", @@ -57,6 +59,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_3": { "args": [ + 
"--experimental-enabled=high-availability", "--bolt-port", "7689", "--log-level", @@ -70,6 +73,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_1": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7690", "--log-level=TRACE", @@ -81,6 +85,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_2": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7691", "--log-level=TRACE", @@ -92,6 +97,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_3": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7692", "--log-level=TRACE", @@ -104,17 +110,6 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { } -# NOTE: Repeated execution because it can fail if Raft server is not up -def add_coordinator(cursor, query): - for _ in range(10): - try: - execute_and_fetch_all(cursor, query) - return True - except Exception: - pass - return False - - def test_register_repl_instances_then_coordinators(): safe_execute(shutil.rmtree, TEMP_DIR) interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) @@ -280,5 +275,172 @@ def test_coordinators_communication_with_restarts(): mg_sleep_and_assert(expected_cluster_not_shared, check_coordinator2) +# TODO: (andi) Test when dealing with distributed coordinators that you can register on one coordinator and unregister from any other coordinator +@pytest.mark.parametrize( + "kill_instance", + [True, False], +) +def test_unregister_replicas(kill_instance): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + coordinator3_cursor = connect(host="localhost", port=7692).cursor() + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001'" + ) + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_2 ON '127.0.0.1:10012' WITH '127.0.0.1:10002'" + ) + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_3 ON '127.0.0.1:10013' WITH 
'127.0.0.1:10003'" + ) + execute_and_fetch_all(coordinator3_cursor, "SET INSTANCE instance_3 TO MAIN") + + def check_coordinator3(): + return sorted(list(execute_and_fetch_all(coordinator3_cursor, "SHOW INSTANCES"))) + + main_cursor = connect(host="localhost", port=7689).cursor() + + def check_main(): + return sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS"))) + + expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + + expected_replicas = [ + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + mg_sleep_and_assert(expected_replicas, check_main) + + if kill_instance: + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_1") + execute_and_fetch_all(coordinator3_cursor, "UNREGISTER INSTANCE instance_1") + + expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + + expected_replicas = [ + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + mg_sleep_and_assert(expected_replicas, check_main) + + if kill_instance: + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") + execute_and_fetch_all(coordinator3_cursor, "UNREGISTER INSTANCE instance_2") + + 
expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + expected_replicas = [] + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + mg_sleep_and_assert(expected_replicas, check_main) + + +def test_unregister_main(): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + coordinator3_cursor = connect(host="localhost", port=7692).cursor() + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001'" + ) + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_2 ON '127.0.0.1:10012' WITH '127.0.0.1:10002'" + ) + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_3 ON '127.0.0.1:10013' WITH '127.0.0.1:10003'" + ) + execute_and_fetch_all(coordinator3_cursor, "SET INSTANCE instance_3 TO MAIN") + + def check_coordinator3(): + return sorted(list(execute_and_fetch_all(coordinator3_cursor, "SHOW INSTANCES"))) + + expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + + try: + execute_and_fetch_all(coordinator3_cursor, "UNREGISTER INSTANCE instance_3") + except Exception as e: + assert ( + str(e) + == "Alive main instance can't be unregistered! Shut it down to trigger failover and then unregister it!" 
+ ) + + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + + expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), + ] + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + + execute_and_fetch_all(coordinator3_cursor, "UNREGISTER INSTANCE instance_3") + + expected_cluster = [ + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "main"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ] + + expected_replicas = [ + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + + main_cursor = connect(host="localhost", port=7687).cursor() + + def check_main(): + return sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS"))) + + mg_sleep_and_assert(expected_cluster, check_coordinator3) + mg_sleep_and_assert(expected_replicas, check_main) + + if __name__ == "__main__": sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/coordinator.py b/tests/e2e/high_availability/coordinator.py similarity index 100% rename from tests/e2e/high_availability_experimental/coordinator.py rename to tests/e2e/high_availability/coordinator.py diff --git a/tests/e2e/high_availability/disable_writing_on_main_after_restart.py b/tests/e2e/high_availability/disable_writing_on_main_after_restart.py new file mode 100644 index 000000000..53d570a6d --- /dev/null +++ b/tests/e2e/high_availability/disable_writing_on_main_after_restart.py @@ -0,0 +1,187 @@ +# Copyright 2022 Memgraph Ltd. 
+# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import os +import shutil +import sys +import tempfile + +import interactive_mg_runner +import pytest +from common import add_coordinator, connect, execute_and_fetch_all, safe_execute +from mg_utils import mg_sleep_and_assert + +interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +interactive_mg_runner.PROJECT_DIR = os.path.normpath( + os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") +) +interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) +interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) + +TEMP_DIR = tempfile.TemporaryDirectory().name + +MEMGRAPH_INSTANCES_DESCRIPTION = { + "instance_1": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7687", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10011", + "--also-log-to-stderr", + "--instance-health-check-frequency-sec", + "1", + "--instance-down-timeout-sec", + "5", + ], + "log_file": "instance_1.log", + "data_directory": f"{TEMP_DIR}/instance_1", + "setup_queries": [], + }, + "instance_2": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7688", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10012", + "--also-log-to-stderr", + "--instance-health-check-frequency-sec", + "1", + "--instance-down-timeout-sec", + "5", + ], + "log_file": "instance_2.log", + 
"data_directory": f"{TEMP_DIR}/instance_2", + "setup_queries": [], + }, + "instance_3": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7689", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10013", + "--also-log-to-stderr", + "--instance-health-check-frequency-sec", + "5", + "--instance-down-timeout-sec", + "10", + ], + "log_file": "instance_3.log", + "data_directory": f"{TEMP_DIR}/instance_3", + "setup_queries": [], + }, + "coordinator_1": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7690", + "--log-level=TRACE", + "--raft-server-id=1", + "--raft-server-port=10111", + ], + "log_file": "coordinator1.log", + "setup_queries": [], + }, + "coordinator_2": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7691", + "--log-level=TRACE", + "--raft-server-id=2", + "--raft-server-port=10112", + ], + "log_file": "coordinator2.log", + "setup_queries": [], + }, + "coordinator_3": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7692", + "--log-level=TRACE", + "--raft-server-id=3", + "--raft-server-port=10113", + "--also-log-to-stderr", + ], + "log_file": "coordinator3.log", + "setup_queries": [], + }, +} + + +def test_writing_disabled_on_main_restart(): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + coordinator3_cursor = connect(host="localhost", port=7692).cursor() + + execute_and_fetch_all( + coordinator3_cursor, "REGISTER INSTANCE instance_3 ON '127.0.0.1:10013' WITH '127.0.0.1:10003'" + ) + execute_and_fetch_all(coordinator3_cursor, "SET INSTANCE instance_3 TO MAIN") + assert add_coordinator(coordinator3_cursor, "ADD COORDINATOR 1 ON '127.0.0.1:10111'") + assert add_coordinator(coordinator3_cursor, "ADD COORDINATOR 2 ON '127.0.0.1:10112'") + + def check_coordinator3(): + return sorted(list(execute_and_fetch_all(coordinator3_cursor, "SHOW INSTANCES"))) + + 
expected_cluster_coord3 = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("coordinator_2", "127.0.0.1:10112", "", True, "coordinator"), + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + mg_sleep_and_assert(expected_cluster_coord3, check_coordinator3) + + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + + expected_cluster_coord3 = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("coordinator_2", "127.0.0.1:10112", "", True, "coordinator"), + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_3", "", "127.0.0.1:10013", False, "unknown"), + ] + + mg_sleep_and_assert(expected_cluster_coord3, check_coordinator3) + + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") + + try: + instance3_cursor = connect(host="localhost", port=7689).cursor() + execute_and_fetch_all(instance3_cursor, "CREATE (n:Node {name: 'node'})") + except Exception as e: + assert ( + str(e) + == "Write query forbidden on the main! Coordinator needs to enable writing on main by sending RPC message." 
+ ) + + expected_cluster_coord3 = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("coordinator_2", "127.0.0.1:10112", "", True, "coordinator"), + ("coordinator_3", "127.0.0.1:10113", "", True, "coordinator"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + + mg_sleep_and_assert(expected_cluster_coord3, check_coordinator3) + execute_and_fetch_all(instance3_cursor, "CREATE (n:Node {name: 'node'})") + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/distributed_coords.py b/tests/e2e/high_availability/distributed_coords.py similarity index 71% rename from tests/e2e/high_availability_experimental/distributed_coords.py rename to tests/e2e/high_availability/distributed_coords.py index 052cb6dba..07b6eefe0 100644 --- a/tests/e2e/high_availability_experimental/distributed_coords.py +++ b/tests/e2e/high_availability/distributed_coords.py @@ -17,7 +17,7 @@ import tempfile import interactive_mg_runner import pytest from common import connect, execute_and_fetch_all, safe_execute -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert, mg_sleep_and_assert_collection interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) interactive_mg_runner.PROJECT_DIR = os.path.normpath( @@ -31,6 +31,7 @@ TEMP_DIR = tempfile.TemporaryDirectory().name MEMGRAPH_INSTANCES_DESCRIPTION = { "instance_1": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7687", "--log-level", @@ -44,6 +45,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_2": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7688", "--log-level", @@ -57,6 +59,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_3": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7689", "--log-level", @@ -70,6 +73,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_1": { "args": [ + 
"--experimental-enabled=high-availability", "--bolt-port", "7690", "--log-level=TRACE", @@ -81,6 +85,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_2": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7691", "--log-level=TRACE", @@ -92,6 +97,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "coordinator_3": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7692", "--log-level=TRACE", @@ -117,11 +123,23 @@ def test_distributed_automatic_failover(): main_cursor = connect(host="localhost", port=7689).cursor() expected_data_on_main = [ - ("instance_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) - assert actual_data_on_main == expected_data_on_main + assert actual_data_on_main == sorted(expected_data_on_main) interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") @@ -146,18 +164,42 @@ def test_distributed_automatic_failover(): return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) expected_data_on_new_main = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "invalid"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "invalid"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), ] - mg_sleep_and_assert(expected_data_on_new_main, 
retrieve_data_show_replicas) + mg_sleep_and_assert_collection(expected_data_on_new_main, retrieve_data_show_replicas) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") expected_data_on_new_main_old_alive = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "ready"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] - mg_sleep_and_assert(expected_data_on_new_main_old_alive, retrieve_data_show_replicas) + mg_sleep_and_assert_collection(expected_data_on_new_main_old_alive, retrieve_data_show_replicas) if __name__ == "__main__": diff --git a/tests/e2e/high_availability_experimental/manual_setting_replicas.py b/tests/e2e/high_availability/manual_setting_replicas.py similarity index 84% rename from tests/e2e/high_availability_experimental/manual_setting_replicas.py rename to tests/e2e/high_availability/manual_setting_replicas.py index f2d48ffd7..b0b0965bc 100644 --- a/tests/e2e/high_availability_experimental/manual_setting_replicas.py +++ b/tests/e2e/high_availability/manual_setting_replicas.py @@ -14,8 +14,7 @@ import sys import interactive_mg_runner import pytest -from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert +from common import connect, execute_and_fetch_all interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) interactive_mg_runner.PROJECT_DIR = os.path.normpath( @@ -26,20 +25,28 @@ interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactiv MEMGRAPH_INSTANCES_DESCRIPTION = { "instance_3": { - "args": ["--bolt-port", "7687", "--log-level", "TRACE", "--coordinator-server-port", "10013"], + "args": [ + 
"--experimental-enabled=high-availability", + "--bolt-port", + "7687", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10013", + ], "log_file": "main.log", "setup_queries": [], }, } -def test_no_manual_setup_on_main(connection): +def test_no_manual_setup_on_main(): # Goal of this test is to check that all manual registration actions are disabled on instances with coordiantor server port # 1 interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) - any_main = connection(7687, "instance_3").cursor() + any_main = connect(host="localhost", port=7687).cursor() with pytest.raises(Exception) as e: execute_and_fetch_all(any_main, "REGISTER REPLICA replica_1 SYNC TO '127.0.0.1:10001';") assert str(e.value) == "Can't register replica manually on instance with coordinator server port." diff --git a/tests/e2e/high_availability/not_replicate_from_old_main.py b/tests/e2e/high_availability/not_replicate_from_old_main.py new file mode 100644 index 000000000..c2cc93cb1 --- /dev/null +++ b/tests/e2e/high_availability/not_replicate_from_old_main.py @@ -0,0 +1,287 @@ +# Copyright 2024 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. 
+ +import os +import shutil +import sys +import tempfile + +import interactive_mg_runner +import pytest +from common import connect, execute_and_fetch_all +from mg_utils import mg_sleep_and_assert + +interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +interactive_mg_runner.PROJECT_DIR = os.path.normpath( + os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") +) +interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) +interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) + +MEMGRAPH_FIRST_CLUSTER_DESCRIPTION = { + "shared_replica": { + "args": ["--experimental-enabled=high-availability", "--bolt-port", "7688", "--log-level", "TRACE"], + "log_file": "replica2.log", + "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10001;"], + }, + "main1": { + "args": ["--experimental-enabled=high-availability", "--bolt-port", "7687", "--log-level", "TRACE"], + "log_file": "main.log", + "setup_queries": ["REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;"], + }, +} + + +MEMGRAPH_SECOND_CLUSTER_DESCRIPTION = { + "replica": { + "args": ["--experimental-enabled=high-availability", "--bolt-port", "7689", "--log-level", "TRACE"], + "log_file": "replica.log", + "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10002;"], + }, + "main_2": { + "args": ["--experimental-enabled=high-availability", "--bolt-port", "7690", "--log-level", "TRACE"], + "log_file": "main_2.log", + "setup_queries": [ + "REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;", + "REGISTER REPLICA replica SYNC TO '127.0.0.1:10002' ; ", + ], + }, +} + + +def test_replication_works_on_failover(): + # Goal of this test is to check that after changing `shared_replica` + # to be part of new cluster, `main` (old cluster) can't write any more to it + + # 1 + 
interactive_mg_runner.start_all_keep_others(MEMGRAPH_FIRST_CLUSTER_DESCRIPTION) + + # 2 + main_cursor = connect(host="localhost", port=7687).cursor() + expected_data_on_main = [ + ( + "shared_replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + assert actual_data_on_main == expected_data_on_main + + # 3 + interactive_mg_runner.start_all_keep_others(MEMGRAPH_SECOND_CLUSTER_DESCRIPTION) + + # 4 + new_main_cursor = connect(host="localhost", port=7690).cursor() + + def retrieve_data_show_replicas(): + return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) + + expected_data_on_new_main = [ + ( + "replica", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "shared_replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) + + # 5 + shared_replica_cursor = connect(host="localhost", port=7688).cursor() + + with pytest.raises(Exception) as e: + execute_and_fetch_all(main_cursor, "CREATE ();") + assert "At least one SYNC replica has not confirmed committing last transaction." 
in str(e.value) + + res = execute_and_fetch_all(main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be created" + + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 0, "Vertex shouldn't be replicated" + + # 7 + execute_and_fetch_all(new_main_cursor, "CREATE ();") + + res = execute_and_fetch_all(new_main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be created" + + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] + assert res == 1, "Vertex should be replicated" + + interactive_mg_runner.stop_all() + + +def test_not_replicate_old_main_register_new_cluster(): + # Goal of this test is to check that although replica is registered in one cluster + # it can be re-registered to new cluster + # This flow checks if Registering replica is idempotent and that old main cannot talk to replica + # 1. We start all replicas and main in one cluster + # 2. Main from first cluster can see all replicas + # 3. We start all replicas and main in second cluster, by reusing one replica from first cluster + # 4. New main should see replica. Registration should pass (idempotent registration) + # 5. Old main should not talk to new replica + # 6. 
New main should talk to replica + + TEMP_DIR = tempfile.TemporaryDirectory().name + MEMGRAPH_FISRT_COORD_CLUSTER_DESCRIPTION = { + "shared_instance": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7688", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10011", + ], + "log_file": "instance_1.log", + "data_directory": f"{TEMP_DIR}/shared_instance", + "setup_queries": [], + }, + "instance_2": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7689", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10012", + ], + "log_file": "instance_2.log", + "data_directory": f"{TEMP_DIR}/instance_2", + "setup_queries": [], + }, + "coordinator_1": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7690", + "--log-level=TRACE", + "--raft-server-id=1", + "--raft-server-port=10111", + ], + "log_file": "coordinator.log", + "setup_queries": [ + "REGISTER INSTANCE shared_instance ON '127.0.0.1:10011' WITH '127.0.0.1:10001';", + "REGISTER INSTANCE instance_2 ON '127.0.0.1:10012' WITH '127.0.0.1:10002';", + "SET INSTANCE instance_2 TO MAIN", + ], + }, + } + + # 1 + interactive_mg_runner.start_all_keep_others(MEMGRAPH_FISRT_COORD_CLUSTER_DESCRIPTION) + + # 2 + + first_cluster_coord_cursor = connect(host="localhost", port=7690).cursor() + + def show_repl_cluster(): + return sorted(list(execute_and_fetch_all(first_cluster_coord_cursor, "SHOW INSTANCES;"))) + + expected_data_up_first_cluster = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_2", "", "127.0.0.1:10012", True, "main"), + ("shared_instance", "", "127.0.0.1:10011", True, "replica"), + ] + + mg_sleep_and_assert(expected_data_up_first_cluster, show_repl_cluster) + + # 3 + + MEMGRAPH_SECOND_COORD_CLUSTER_DESCRIPTION = { + "instance_3": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7687", + "--log-level", + "TRACE", + "--coordinator-server-port", + "10013", 
+ ], + "log_file": "instance_3.log", + "data_directory": f"{TEMP_DIR}/instance_3", + "setup_queries": [], + }, + "coordinator_2": { + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7691", + "--log-level=TRACE", + "--raft-server-id=1", + "--raft-server-port=10112", + ], + "log_file": "coordinator.log", + "setup_queries": [], + }, + } + + interactive_mg_runner.start_all_keep_others(MEMGRAPH_SECOND_COORD_CLUSTER_DESCRIPTION) + second_cluster_coord_cursor = connect(host="localhost", port=7691).cursor() + execute_and_fetch_all( + second_cluster_coord_cursor, "REGISTER INSTANCE shared_instance ON '127.0.0.1:10011' WITH '127.0.0.1:10001';" + ) + execute_and_fetch_all( + second_cluster_coord_cursor, "REGISTER INSTANCE instance_3 ON '127.0.0.1:10013' WITH '127.0.0.1:10003';" + ) + execute_and_fetch_all(second_cluster_coord_cursor, "SET INSTANCE instance_3 TO MAIN") + + # 4 + + def show_repl_cluster(): + return sorted(list(execute_and_fetch_all(second_cluster_coord_cursor, "SHOW INSTANCES;"))) + + expected_data_up_second_cluster = [ + ("coordinator_1", "127.0.0.1:10112", "", True, "coordinator"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ("shared_instance", "", "127.0.0.1:10011", True, "replica"), + ] + + mg_sleep_and_assert(expected_data_up_second_cluster, show_repl_cluster) + + # 5 + main_1_cursor = connect(host="localhost", port=7689).cursor() + with pytest.raises(Exception) as e: + execute_and_fetch_all(main_1_cursor, "CREATE ();") + assert "At least one SYNC replica has not confirmed committing last transaction." 
in str(e.value) + + shared_replica_cursor = connect(host="localhost", port=7688).cursor() + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n);")[0][0] + assert res == 0, "Old main should not replicate to 'shared' replica" + + # 6 + main_2_cursor = connect(host="localhost", port=7687).cursor() + + execute_and_fetch_all(main_2_cursor, "CREATE ();") + + shared_replica_cursor = connect(host="localhost", port=7688).cursor() + res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n);")[0][0] + assert res == 1, "New main should replicate to 'shared' replica" + + interactive_mg_runner.stop_all() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/single_coordinator.py b/tests/e2e/high_availability/single_coordinator.py similarity index 63% rename from tests/e2e/high_availability_experimental/single_coordinator.py rename to tests/e2e/high_availability/single_coordinator.py index d490a36ba..ecf063092 100644 --- a/tests/e2e/high_availability_experimental/single_coordinator.py +++ b/tests/e2e/high_availability/single_coordinator.py @@ -16,7 +16,7 @@ import tempfile import interactive_mg_runner import pytest from common import connect, execute_and_fetch_all, safe_execute -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert, mg_sleep_and_assert_collection interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) interactive_mg_runner.PROJECT_DIR = os.path.normpath( @@ -30,6 +30,7 @@ TEMP_DIR = tempfile.TemporaryDirectory().name MEMGRAPH_INSTANCES_DESCRIPTION = { "instance_1": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7688", "--log-level", @@ -43,6 +44,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_2": { "args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7689", "--log-level", @@ -56,6 +58,7 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { }, "instance_3": { 
"args": [ + "--experimental-enabled=high-availability", "--bolt-port", "7687", "--log-level", @@ -68,7 +71,14 @@ MEMGRAPH_INSTANCES_DESCRIPTION = { "setup_queries": [], }, "coordinator": { - "args": ["--bolt-port", "7690", "--log-level=TRACE", "--raft-server-id=1", "--raft-server-port=10111"], + "args": [ + "--experimental-enabled=high-availability", + "--bolt-port", + "7690", + "--log-level=TRACE", + "--raft-server-id=1", + "--raft-server-port=10111", + ], "log_file": "coordinator.log", "setup_queries": [ "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001';", @@ -96,8 +106,20 @@ def test_replication_works_on_failover(): # 2 main_cursor = connect(host="localhost", port=7687).cursor() expected_data_on_main = [ - ("instance_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) assert actual_data_on_main == expected_data_on_main @@ -125,17 +147,41 @@ def test_replication_works_on_failover(): return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) expected_data_on_new_main = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "invalid"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "invalid"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), ] - 
mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) + mg_sleep_and_assert_collection(expected_data_on_new_main, retrieve_data_show_replicas) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") expected_data_on_new_main = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "ready"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] - mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) + mg_sleep_and_assert_collection(expected_data_on_new_main, retrieve_data_show_replicas) # 5 execute_and_fetch_all(new_main_cursor, "CREATE ();") @@ -147,6 +193,150 @@ def test_replication_works_on_failover(): interactive_mg_runner.stop_all(MEMGRAPH_INSTANCES_DESCRIPTION) +def test_replication_works_on_replica_instance_restart(): + # Goal of this test is to check the replication works after replica goes down and restarts + # 1. We start all replicas, main and coordinator manually: we want to be able to kill them ourselves without relying on external tooling to kill processes. + # 2. We check that main has correct state + # 3. We kill replica + # 4. We check that main cannot replicate to replica + # 5. We bring replica back up + # 6. 
We check that replica gets data + safe_execute(shutil.rmtree, TEMP_DIR) + + # 1 + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + # 2 + main_cursor = connect(host="localhost", port=7687).cursor() + expected_data_on_main = [ + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + assert actual_data_on_main == expected_data_on_main + + # 3 + coord_cursor = connect(host="localhost", port=7690).cursor() + + interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") + + def retrieve_data_show_repl_cluster(): + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) + + expected_data_on_coord = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", False, "unknown"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + mg_sleep_and_assert_collection(expected_data_on_coord, retrieve_data_show_repl_cluster) + + def retrieve_data_show_replicas(): + return sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + + expected_data_on_main = [ + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data_on_main, retrieve_data_show_replicas) + + # 4 + instance_1_cursor = connect(host="localhost", port=7688).cursor() + with 
pytest.raises(Exception) as e: + execute_and_fetch_all(main_cursor, "CREATE ();") + assert "At least one SYNC replica has not confirmed committing last transaction." in str(e.value) + + res_instance_1 = execute_and_fetch_all(instance_1_cursor, "MATCH (n) RETURN count(n)")[0][0] + assert res_instance_1 == 1 + + def retrieve_data_show_replicas(): + return sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + + expected_data_on_main = [ + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data_on_main, retrieve_data_show_replicas) + + # 5. + + interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_2") + + def retrieve_data_show_repl_cluster(): + return sorted(list(execute_and_fetch_all(coord_cursor, "SHOW INSTANCES;"))) + + expected_data_on_coord = [ + ("coordinator_1", "127.0.0.1:10111", "", True, "coordinator"), + ("instance_1", "", "127.0.0.1:10011", True, "replica"), + ("instance_2", "", "127.0.0.1:10012", True, "replica"), + ("instance_3", "", "127.0.0.1:10013", True, "main"), + ] + mg_sleep_and_assert(expected_data_on_coord, retrieve_data_show_repl_cluster) + + def retrieve_data_show_replicas(): + return sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) + + expected_data_on_main = [ + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data_on_main, retrieve_data_show_replicas) + + # 6. 
+ instance_2_cursor = connect(port=7689, host="localhost").cursor() + execute_and_fetch_all(main_cursor, "CREATE ();") + res_instance_2 = execute_and_fetch_all(instance_2_cursor, "MATCH (n) RETURN count(n)")[0][0] + assert res_instance_2 == 2 + + def test_show_instances(): safe_execute(shutil.rmtree, TEMP_DIR) interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) @@ -207,11 +397,23 @@ def test_simple_automatic_failover(): main_cursor = connect(host="localhost", port=7687).cursor() expected_data_on_main = [ - ("instance_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), + ( + "instance_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] actual_data_on_main = sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) - assert actual_data_on_main == expected_data_on_main + assert actual_data_on_main == sorted(expected_data_on_main) interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") @@ -234,18 +436,42 @@ def test_simple_automatic_failover(): return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) expected_data_on_new_main = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "invalid"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "invalid"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), ] - mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) + 
mg_sleep_and_assert_collection(expected_data_on_new_main, retrieve_data_show_replicas) interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "instance_3") expected_data_on_new_main_old_alive = [ - ("instance_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("instance_3", "127.0.0.1:10003", "sync", 0, 0, "ready"), + ( + "instance_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "instance_3", + "127.0.0.1:10003", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), ] - mg_sleep_and_assert(expected_data_on_new_main_old_alive, retrieve_data_show_replicas) + mg_sleep_and_assert_collection(expected_data_on_new_main_old_alive, retrieve_data_show_replicas) def test_registering_replica_fails_name_exists(): @@ -416,5 +642,20 @@ def test_automatic_failover_main_back_as_main(): mg_sleep_and_assert([("main",)], retrieve_data_show_repl_role_instance3) +def test_disable_multiple_mains(): + safe_execute(shutil.rmtree, TEMP_DIR) + interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) + + coord_cursor = connect(host="localhost", port=7690).cursor() + + try: + execute_and_fetch_all( + coord_cursor, + "SET INSTANCE instance_1 TO MAIN;", + ) + except Exception as e: + assert str(e) == "Couldn't set instance to main since there is already a main instance in cluster!" 
+ + if __name__ == "__main__": sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/high_availability_experimental/workloads.yaml b/tests/e2e/high_availability/workloads.yaml similarity index 53% rename from tests/e2e/high_availability_experimental/workloads.yaml rename to tests/e2e/high_availability/workloads.yaml index e624d35c0..75f17b2f7 100644 --- a/tests/e2e/high_availability_experimental/workloads.yaml +++ b/tests/e2e/high_availability/workloads.yaml @@ -1,19 +1,19 @@ ha_cluster: &ha_cluster cluster: replica_1: - args: ["--bolt-port", "7688", "--log-level=TRACE", "--coordinator-server-port=10011"] + args: ["--experimental-enabled=high-availability", "--bolt-port", "7688", "--log-level=TRACE", "--coordinator-server-port=10011"] log_file: "replication-e2e-replica1.log" setup_queries: [] replica_2: - args: ["--bolt-port", "7689", "--log-level=TRACE", "--coordinator-server-port=10012"] + args: ["--experimental-enabled=high-availability", "--bolt-port", "7689", "--log-level=TRACE", "--coordinator-server-port=10012"] log_file: "replication-e2e-replica2.log" setup_queries: [] main: - args: ["--bolt-port", "7687", "--log-level=TRACE", "--coordinator-server-port=10013"] + args: ["--experimental-enabled=high-availability", "--bolt-port", "7687", "--log-level=TRACE", "--coordinator-server-port=10013"] log_file: "replication-e2e-main.log" setup_queries: [] coordinator: - args: ["--bolt-port", "7690", "--log-level=TRACE", "--raft-server-id=1", "--raft-server-port=10111"] + args: ["--experimental-enabled=high-availability", "--bolt-port", "7690", "--log-level=TRACE", "--raft-server-id=1", "--raft-server-port=10111"] log_file: "replication-e2e-coordinator.log" setup_queries: [ "REGISTER INSTANCE instance_1 ON '127.0.0.1:10011' WITH '127.0.0.1:10001';", @@ -25,25 +25,29 @@ ha_cluster: &ha_cluster workloads: - name: "Coordinator" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/coordinator.py"] + args: 
["high_availability/coordinator.py"] <<: *ha_cluster - name: "Single coordinator" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/single_coordinator.py"] + args: ["high_availability/single_coordinator.py"] - name: "Disabled manual setting of replication cluster" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/manual_setting_replicas.py"] + args: ["high_availability/manual_setting_replicas.py"] - name: "Coordinator cluster registration" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/coord_cluster_registration.py"] + args: ["high_availability/coord_cluster_registration.py"] - name: "Not replicate from old main" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/not_replicate_from_old_main.py"] + args: ["high_availability/not_replicate_from_old_main.py"] + + - name: "Disable writing on main after restart" + binary: "tests/e2e/pytest_runner.sh" + args: ["high_availability/disable_writing_on_main_after_restart.py"] - name: "Distributed coordinators" binary: "tests/e2e/pytest_runner.sh" - args: ["high_availability_experimental/distributed_coords.py"] + args: ["high_availability/distributed_coords.py"] diff --git a/tests/e2e/high_availability_experimental/CMakeLists.txt b/tests/e2e/high_availability_experimental/CMakeLists.txt deleted file mode 100644 index d97080585..000000000 --- a/tests/e2e/high_availability_experimental/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -find_package(gflags REQUIRED) - -copy_e2e_python_files(ha_experimental coordinator.py) -copy_e2e_python_files(ha_experimental single_coordinator.py) -copy_e2e_python_files(ha_experimental coord_cluster_registration.py) -copy_e2e_python_files(ha_experimental distributed_coords.py) -copy_e2e_python_files(ha_experimental manual_setting_replicas.py) -copy_e2e_python_files(ha_experimental not_replicate_from_old_main.py) -copy_e2e_python_files(ha_experimental common.py) 
-copy_e2e_python_files(ha_experimental workloads.yaml) - -copy_e2e_python_files_from_parent_folder(ha_experimental ".." memgraph.py) -copy_e2e_python_files_from_parent_folder(ha_experimental ".." interactive_mg_runner.py) -copy_e2e_python_files_from_parent_folder(ha_experimental ".." mg_utils.py) diff --git a/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py b/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py deleted file mode 100644 index d6f6f7da4..000000000 --- a/tests/e2e/high_availability_experimental/not_replicate_from_old_main.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2024 Memgraph Ltd. -# -# Use of this software is governed by the Business Source License -# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -# License, and you may not use this file except in compliance with the Business Source License. -# -# As of the Change Date specified in that file, in accordance with -# the Business Source License, use of this software will be governed -# by the Apache License, Version 2.0, included in the file -# licenses/APL.txt. 
- -import os -import sys - -import interactive_mg_runner -import pytest -from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert - -interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) -interactive_mg_runner.PROJECT_DIR = os.path.normpath( - os.path.join(interactive_mg_runner.SCRIPT_DIR, "..", "..", "..", "..") -) -interactive_mg_runner.BUILD_DIR = os.path.normpath(os.path.join(interactive_mg_runner.PROJECT_DIR, "build")) -interactive_mg_runner.MEMGRAPH_BINARY = os.path.normpath(os.path.join(interactive_mg_runner.BUILD_DIR, "memgraph")) - -MEMGRAPH_FIRST_CLUSTER_DESCRIPTION = { - "shared_replica": { - "args": ["--bolt-port", "7688", "--log-level", "TRACE"], - "log_file": "replica2.log", - "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10001;"], - }, - "main1": { - "args": ["--bolt-port", "7687", "--log-level", "TRACE"], - "log_file": "main.log", - "setup_queries": ["REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;"], - }, -} - - -MEMGRAPH_INSTANCES_DESCRIPTION = { - "replica": { - "args": ["--bolt-port", "7689", "--log-level", "TRACE"], - "log_file": "replica.log", - "setup_queries": ["SET REPLICATION ROLE TO REPLICA WITH PORT 10002;"], - }, - "main_2": { - "args": ["--bolt-port", "7690", "--log-level", "TRACE"], - "log_file": "main_2.log", - "setup_queries": [ - "REGISTER REPLICA shared_replica SYNC TO '127.0.0.1:10001' ;", - "REGISTER REPLICA replica SYNC TO '127.0.0.1:10002' ; ", - ], - }, -} - - -def test_replication_works_on_failover(connection): - # Goal of this test is to check that after changing `shared_replica` - # to be part of new cluster, `main` (old cluster) can't write any more to it - - # 1 - interactive_mg_runner.start_all_keep_others(MEMGRAPH_FIRST_CLUSTER_DESCRIPTION) - - # 2 - main_cursor = connection(7687, "main1").cursor() - expected_data_on_main = [ - ("shared_replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ] - actual_data_on_main = 
sorted(list(execute_and_fetch_all(main_cursor, "SHOW REPLICAS;"))) - assert actual_data_on_main == expected_data_on_main - - # 3 - interactive_mg_runner.start_all_keep_others(MEMGRAPH_INSTANCES_DESCRIPTION) - - # 4 - new_main_cursor = connection(7690, "main_2").cursor() - - def retrieve_data_show_replicas(): - return sorted(list(execute_and_fetch_all(new_main_cursor, "SHOW REPLICAS;"))) - - expected_data_on_new_main = [ - ("replica", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("shared_replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ] - mg_sleep_and_assert(expected_data_on_new_main, retrieve_data_show_replicas) - - # 5 - shared_replica_cursor = connection(7688, "shared_replica").cursor() - - with pytest.raises(Exception) as e: - execute_and_fetch_all(main_cursor, "CREATE ();") - assert ( - str(e.value) - == "Replication Exception: At least one SYNC replica has not confirmed committing last transaction. Check the status of the replicas using 'SHOW REPLICAS' query." - ) - - res = execute_and_fetch_all(main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] - assert res == 1, "Vertex should be created" - - res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] - assert res == 0, "Vertex shouldn't be replicated" - - # 7 - execute_and_fetch_all(new_main_cursor, "CREATE ();") - - res = execute_and_fetch_all(new_main_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] - assert res == 1, "Vertex should be created" - - res = execute_and_fetch_all(shared_replica_cursor, "MATCH (n) RETURN count(n) as count;")[0][0] - assert res == 1, "Vertex should be replicated" - - interactive_mg_runner.stop_all() - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/interactive_mg_runner.py b/tests/e2e/interactive_mg_runner.py index 06908747e..efa4dc3d5 100755 --- a/tests/e2e/interactive_mg_runner.py +++ b/tests/e2e/interactive_mg_runner.py @@ -160,6 +160,12 @@ def kill(context, name, 
keep_directories=True): MEMGRAPH_INSTANCES.pop(name) +def kill_all(context, keep_directories=True): + for key in MEMGRAPH_INSTANCES.keys(): + MEMGRAPH_INSTANCES[key].kill(keep_directories) + MEMGRAPH_INSTANCES.clear() + + def cleanup_directories_on_exit(value=True): CLEANUP_DIRECTORIES_ON_EXIT = value diff --git a/tests/e2e/lba_procedures/read_permission_queries.py b/tests/e2e/lba_procedures/read_permission_queries.py index 4348e6bda..4c02910da 100644 --- a/tests/e2e/lba_procedures/read_permission_queries.py +++ b/tests/e2e/lba_procedures/read_permission_queries.py @@ -107,18 +107,21 @@ def execute_read_node_assertion( operation_case: List[str], queries: List[str], create_index: bool, expected_size: int, switch: bool ) -> None: admin_cursor = get_admin_cursor() - user_cursor = get_user_cursor() if switch: create_multi_db(admin_cursor) switch_db(admin_cursor) - switch_db(user_cursor) reset_permissions(admin_cursor, create_index) for operation in operation_case: execute_and_fetch_all(admin_cursor, operation) + # Connect after possible auth changes + user_cursor = get_user_cursor() + if switch: + switch_db(user_cursor) + for mq in queries: results = execute_and_fetch_all(user_cursor, mq) assert len(results) == expected_size diff --git a/tests/e2e/memory/CMakeLists.txt b/tests/e2e/memory/CMakeLists.txt index 256107724..97fd8f9dc 100644 --- a/tests/e2e/memory/CMakeLists.txt +++ b/tests/e2e/memory/CMakeLists.txt @@ -22,9 +22,6 @@ target_link_libraries(memgraph__e2e__memory__limit_accumulation gflags mgclient add_executable(memgraph__e2e__memory__limit_edge_create memory_limit_edge_create.cpp) target_link_libraries(memgraph__e2e__memory__limit_edge_create gflags mgclient mg-utils mg-io) -add_executable(memgraph__e2e__memory_limit_global_multi_thread_proc_create memory_limit_global_multi_thread_proc_create.cpp) -target_link_libraries(memgraph__e2e__memory_limit_global_multi_thread_proc_create gflags mgclient mg-utils mg-io) - 
add_executable(memgraph__e2e__memory_limit_global_thread_alloc_proc memory_limit_global_thread_alloc_proc.cpp) target_link_libraries(memgraph__e2e__memory_limit_global_thread_alloc_proc gflags mgclient mg-utils mg-io) diff --git a/tests/e2e/memory/memory_control.cpp b/tests/e2e/memory/memory_control.cpp index 0d969220b..d4c4e431c 100644 --- a/tests/e2e/memory/memory_control.cpp +++ b/tests/e2e/memory/memory_control.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -44,7 +44,7 @@ int main(int argc, char **argv) { client->DiscardAll(); } - const auto *create_query = "UNWIND range(1, 50) as u CREATE (n {string: \"Some longer string\"}) RETURN n;"; + const auto *create_query = "UNWIND range(1, 100) as u CREATE (n {string: \"Some longer string\"}) RETURN n;"; memgraph::utils::Timer timer; while (true) { diff --git a/tests/e2e/memory/memory_limit_global_multi_thread_proc_create.cpp b/tests/e2e/memory/memory_limit_global_multi_thread_proc_create.cpp deleted file mode 100644 index e44c91ea7..000000000 --- a/tests/e2e/memory/memory_limit_global_multi_thread_proc_create.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -#include -#include -#include -#include -#include -#include - -#include "utils/logging.hpp" -#include "utils/timer.hpp" - -DEFINE_uint64(bolt_port, 7687, "Bolt port"); -DEFINE_uint64(timeout, 120, "Timeout seconds"); -DEFINE_bool(multi_db, false, "Run test in multi db environment"); - -int main(int argc, char **argv) { - google::SetUsageMessage("Memgraph E2E Global Memory Limit In Multi-Thread Create For Local Allocators"); - gflags::ParseCommandLineFlags(&argc, &argv, true); - memgraph::logging::RedirectToStderr(); - - mg::Client::Init(); - - auto client = - mg::Client::Connect({.host = "127.0.0.1", .port = static_cast(FLAGS_bolt_port), .use_ssl = false}); - if (!client) { - LOG_FATAL("Failed to connect!"); - } - - if (FLAGS_multi_db) { - client->Execute("CREATE DATABASE clean;"); - client->DiscardAll(); - client->Execute("USE DATABASE clean;"); - client->DiscardAll(); - client->Execute("MATCH (n) DETACH DELETE n;"); - client->DiscardAll(); - } - - bool error{false}; - try { - client->Execute( - "CALL libglobal_memory_limit_multi_thread_create_proc.multi_create() PROCEDURE MEMORY UNLIMITED YIELD " - "allocated_all RETURN allocated_all " - "QUERY MEMORY LIMIT 50MB;"); - auto result_rows = client->FetchAll(); - if (result_rows) { - auto row = *result_rows->begin(); - error = !row[0].ValueBool(); - } - - } catch (const std::exception &e) { - error = true; - } - - MG_ASSERT(error, "Error should have happend"); - - return 0; -} diff --git a/tests/e2e/memory/procedures/CMakeLists.txt b/tests/e2e/memory/procedures/CMakeLists.txt index df7acee31..8f9d625f3 100644 --- a/tests/e2e/memory/procedures/CMakeLists.txt +++ b/tests/e2e/memory/procedures/CMakeLists.txt @@ -6,7 +6,7 @@ target_include_directories(global_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/ add_library(query_memory_limit_proc_multi_thread SHARED query_memory_limit_proc_multi_thread.cpp) target_include_directories(query_memory_limit_proc_multi_thread PRIVATE ${CMAKE_SOURCE_DIR}/include) 
-target_link_libraries(query_memory_limit_proc_multi_thread mg-utils) +target_link_libraries(query_memory_limit_proc_multi_thread mg-utils ) add_library(query_memory_limit_proc SHARED query_memory_limit_proc.cpp) target_include_directories(query_memory_limit_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) @@ -16,10 +16,6 @@ add_library(global_memory_limit_thread_proc SHARED global_memory_limit_thread_pr target_include_directories(global_memory_limit_thread_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) target_link_libraries(global_memory_limit_thread_proc mg-utils) -add_library(global_memory_limit_multi_thread_create_proc SHARED global_memory_limit_multi_thread_create_proc.cpp) -target_include_directories(global_memory_limit_multi_thread_create_proc PRIVATE ${CMAKE_SOURCE_DIR}/include) -target_link_libraries(global_memory_limit_multi_thread_create_proc mg-utils) - add_library(proc_memory_limit SHARED proc_memory_limit.cpp) target_include_directories(proc_memory_limit PRIVATE ${CMAKE_SOURCE_DIR}/include) target_link_libraries(proc_memory_limit mg-utils) diff --git a/tests/e2e/memory/procedures/global_memory_limit_multi_thread_create_proc.cpp b/tests/e2e/memory/procedures/global_memory_limit_multi_thread_create_proc.cpp deleted file mode 100644 index 2ccaac631..000000000 --- a/tests/e2e/memory/procedures/global_memory_limit_multi_thread_create_proc.cpp +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2023 Memgraph Ltd. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -// License, and you may not use this file except in compliance with the Business Source License. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. 
- -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "mg_procedure.h" -#include "mgp.hpp" -#include "utils/on_scope_exit.hpp" - -// change communication between threads with feature and promise -std::atomic created_vertices{0}; -constexpr int num_vertices_per_thread{100'000}; -constexpr int num_threads{2}; - -void CallCreate(mgp_graph *graph, mgp_memory *memory) { - [[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph); - for (int i = 0; i < num_vertices_per_thread; i++) { - struct mgp_vertex *vertex{nullptr}; - auto enum_error = mgp_graph_create_vertex(graph, memory, &vertex); - if (enum_error != mgp_error::MGP_ERROR_NO_ERROR) { - break; - } - created_vertices.fetch_add(1, std::memory_order_acq_rel); - } - [[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph); -} - -void AllocFunc(mgp_graph *graph, mgp_memory *memory) { - try { - CallCreate(graph, memory); - } catch (const std::exception &e) { - return; - } -} - -void MultiCreate(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) { - mgp::MemoryDispatcherGuard guard{memory}; - const auto arguments = mgp::List(args); - const auto record_factory = mgp::RecordFactory(result); - try { - std::vector threads; - - for (int i = 0; i < 2; i++) { - threads.emplace_back(AllocFunc, memgraph_graph, memory); - } - - for (int i = 0; i < num_threads; i++) { - threads[i].join(); - } - if (created_vertices.load(std::memory_order_acquire) != num_vertices_per_thread * num_threads) { - record_factory.SetErrorMessage("Unable to allocate"); - return; - } - - auto new_record = record_factory.NewRecord(); - new_record.Insert("allocated_all", - created_vertices.load(std::memory_order_acquire) == num_vertices_per_thread * num_threads); - } catch (std::exception &e) { - record_factory.SetErrorMessage(e.what()); - } -} - -extern "C" int 
mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { - try { - mgp::MemoryDispatcherGuard guard{memory}; - - AddProcedure(MultiCreate, std::string("multi_create").c_str(), mgp::ProcedureType::Write, {}, - {mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory); - - } catch (const std::exception &e) { - return 1; - } - - return 0; -} - -extern "C" int mgp_shutdown_module() { return 0; } diff --git a/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp b/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp index 0a1d8f125..4a1a96845 100644 --- a/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp +++ b/tests/e2e/memory/procedures/query_memory_limit_proc_multi_thread.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,73 +13,122 @@ #include #include #include +#include +#include +#include +#include +#include #include -#include #include #include -#include #include #include "mg_procedure.h" #include "mgp.hpp" +#include "utils/memory_tracker.hpp" #include "utils/on_scope_exit.hpp" -enum mgp_error Alloc(void *ptr) { - const size_t mb_size_268 = 1 << 28; +using safe_ptr = std::unique_ptr; +enum class AllocFuncRes { NoIssues, UnableToAlloc, Unexpected }; +using result_t = std::pair>; - return mgp_global_alloc(mb_size_268, (void **)(&ptr)); -} +constexpr auto N_THREADS = 2; +static_assert(N_THREADS > 0 && (N_THREADS & (N_THREADS - 1)) == 0); -// change communication between threads with feature and promise -std::atomic num_allocations{0}; -std::vector ptrs_; +constexpr auto mb_size_512 = 1 << 29; +constexpr auto mb_size_16 = 1 << 24; -void AllocFunc(mgp_graph *graph) { +static_assert(mb_size_512 % N_THREADS == 0); +static_assert(mb_size_16 % N_THREADS 
== 0); +static_assert(mb_size_512 % mb_size_16 == 0); + +void AllocFunc(std::latch &start_latch, std::promise promise, mgp_graph *graph) { [[maybe_unused]] const enum mgp_error tracking_error = mgp_track_current_thread_allocations(graph); - void *ptr = nullptr; + auto on_exit = memgraph::utils::OnScopeExit{[&]() { + [[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph); + }}; + + std::list ptrs; + + // Ensure test would concurrently run these allocations, wait until both are ready + start_latch.arrive_and_wait(); - ptrs_.emplace_back(ptr); try { - enum mgp_error alloc_err { mgp_error::MGP_ERROR_NO_ERROR }; - alloc_err = Alloc(ptr); - if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { - num_allocations.fetch_add(1, std::memory_order_relaxed); - } - if (alloc_err != mgp_error::MGP_ERROR_NO_ERROR) { - assert(false); + constexpr auto allocation_limit = mb_size_512 / N_THREADS; + // many allocation to increase chance of seeing any concurent issues + for (auto total = 0; total < allocation_limit; total += mb_size_16) { + void *ptr = nullptr; + auto alloc_err = mgp_global_alloc(mb_size_16, &ptr); + if (alloc_err != mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE && ptr != nullptr) { + ptrs.emplace_back(ptr); + } else if (alloc_err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) { + // this is expected, the test checks N threads allocating to a limit of 512MB + promise.set_value({AllocFuncRes::UnableToAlloc, std::move(ptrs)}); + return; + } else { + promise.set_value({AllocFuncRes::Unexpected, std::move(ptrs)}); + return; + } } } catch (const std::exception &e) { - assert(false); + promise.set_value({AllocFuncRes::Unexpected, std::move(ptrs)}); + return; } - - [[maybe_unused]] const enum mgp_error untracking_error = mgp_untrack_current_thread_allocations(graph); + promise.set_value({AllocFuncRes::NoIssues, std::move(ptrs)}); + return; } void DualThread(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory 
*memory) { mgp::MemoryDispatcherGuard guard{memory}; const auto arguments = mgp::List(args); const auto record_factory = mgp::RecordFactory(result); - num_allocations.store(0, std::memory_order_relaxed); + + // 1 byte allocation to + auto ptr = std::invoke([&] { + void *ptr; + [[maybe_unused]] auto alloc_err = mgp_global_alloc(1, &ptr); + return safe_ptr{ptr}; + }); + try { - std::vector threads; - - for (int i = 0; i < 2; i++) { - threads.emplace_back(AllocFunc, memgraph_graph); - } - - for (int i = 0; i < 2; i++) { - threads[i].join(); - } - for (void *ptr : ptrs_) { - if (ptr != nullptr) { - mgp_global_free(ptr); + auto futures = std::vector>{}; + futures.reserve(N_THREADS); + std::latch start_latch{N_THREADS}; + { + auto threads = std::vector{}; + threads.reserve(N_THREADS); + for (int i = 0; i < N_THREADS; i++) { + auto promise = std::promise{}; + futures.emplace_back(promise.get_future()); + threads.emplace_back([&, promise = std::move(promise)]() mutable { + AllocFunc(start_latch, std::move(promise), memgraph_graph); + }); } + } // ~jthread will join + + int alloc_errors = 0; + int unexpected_errors = 0; + for (auto &x : futures) { + auto [res, ptrs] = x.get(); + alloc_errors += (res == AllocFuncRes::UnableToAlloc); + unexpected_errors += (res == AllocFuncRes::Unexpected); + // regardless of outcome, we want this thread to do the deallocation + ptrs.clear(); + } + + if (unexpected_errors != 0) { + record_factory.SetErrorMessage("Unanticipated error happened"); + return; + } + + if (alloc_errors < 1) { + record_factory.SetErrorMessage("Didn't hit the QUERY MEMORY LIMIT we expected"); + return; } auto new_record = record_factory.NewRecord(); - - new_record.Insert("allocated_all", num_allocations.load(std::memory_order_relaxed) == 2); + new_record.Insert("test_passed", true); } catch (std::exception &e) { record_factory.SetErrorMessage(e.what()); } @@ -90,7 +139,7 @@ extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *mem mgp::memory = 
memory; AddProcedure(DualThread, std::string("dual_thread").c_str(), mgp::ProcedureType::Read, {}, - {mgp::Return(std::string("allocated_all").c_str(), mgp::Type::Bool)}, module, memory); + {mgp::Return(std::string("test_passed").c_str(), mgp::Type::Bool)}, module, memory); } catch (const std::exception &e) { return 1; diff --git a/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp b/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp index 5acac5404..7d1d06b33 100644 --- a/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp +++ b/tests/e2e/memory/query_memory_limit_proc_multi_thread.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -46,21 +46,21 @@ int main(int argc, char **argv) { } MG_ASSERT( - client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD allocated_all RETURN " - "allocated_all QUERY MEMORY LIMIT 500MB")); + client->Execute("CALL libquery_memory_limit_proc_multi_thread.dual_thread() YIELD test_passed RETURN " + "test_passed QUERY MEMORY LIMIT 500MB")); bool error{false}; try { auto result_rows = client->FetchAll(); if (result_rows) { auto row = *result_rows->begin(); - error = !row[0].ValueBool(); + MG_ASSERT(row[0].ValueBool(), "Execpected the procedure to pass"); + } else { + MG_ASSERT(false, "Expected at least one row"); } } catch (const std::exception &e) { - error = true; + MG_ASSERT(error, "This error should not have happend {}", e.what()); } - MG_ASSERT(error, "Error should have happend"); - return 0; } diff --git a/tests/e2e/memory/workloads.yaml b/tests/e2e/memory/workloads.yaml index 21924c880..bf29e484c 100644 --- a/tests/e2e/memory/workloads.yaml +++ b/tests/e2e/memory/workloads.yaml @@ -6,7 +6,22 @@ args: &args - "--storage-gc-cycle-sec=180" - "--log-level=TRACE" 
-in_memory_cluster: &in_memory_cluster +args_150_MiB_limit: &args_150_MiB_limit + - "--bolt-port" + - *bolt_port + - "--memory-limit=150" + - "--storage-gc-cycle-sec=180" + - "--log-level=TRACE" + +in_memory_150_MiB_limit_cluster: &in_memory_150_MiB_limit_cluster + cluster: + main: + args: *args_150_MiB_limit + log_file: "memory-e2e.log" + setup_queries: [] + validation_queries: [] + +in_memory_1024_MiB_limit_cluster: &in_memory_1024_MiB_limit_cluster cluster: main: args: *args @@ -61,6 +76,30 @@ disk_450_MiB_limit_cluster: &disk_450_MiB_limit_cluster setup_queries: [] validation_queries: [] +args_300_MiB_limit: &args_300_MiB_limit + - "--bolt-port" + - *bolt_port + - "--memory-limit=300" + - "--storage-gc-cycle-sec=180" + - "--log-level=INFO" + +in_memory_300_MiB_limit_cluster: &in_memory_300_MiB_limit_cluster + cluster: + main: + args: *args_300_MiB_limit + log_file: "memory-e2e.log" + setup_queries: [] + validation_queries: [] + + +disk_300_MiB_limit_cluster: &disk_300_MiB_limit_cluster + cluster: + main: + args: *args_300_MiB_limit + log_file: "memory-e2e.log" + setup_queries: [] + validation_queries: [] + args_global_limit_1024_MiB: &args_global_limit_1024_MiB - "--bolt-port" @@ -80,36 +119,36 @@ workloads: - name: "Memory control" binary: "tests/e2e/memory/memgraph__e2e__memory__control" args: ["--bolt-port", *bolt_port, "--timeout", "180"] - <<: *in_memory_cluster + <<: *in_memory_150_MiB_limit_cluster - name: "Memory control multi database" binary: "tests/e2e/memory/memgraph__e2e__memory__control" args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] - <<: *in_memory_cluster + <<: *in_memory_150_MiB_limit_cluster - name: "Memory limit for modules upon loading" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc" args: ["--bolt-port", *bolt_port, "--timeout", "180"] proc: "tests/e2e/memory/procedures/" - <<: *in_memory_cluster + <<: *in_memory_1024_MiB_limit_cluster - name: "Memory limit for modules upon loading multi 
database" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc" args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] proc: "tests/e2e/memory/procedures/" - <<: *in_memory_cluster + <<: *in_memory_1024_MiB_limit_cluster - name: "Memory limit for modules inside a procedure" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc" args: ["--bolt-port", *bolt_port, "--timeout", "180"] proc: "tests/e2e/memory/procedures/" - <<: *in_memory_cluster + <<: *in_memory_1024_MiB_limit_cluster - name: "Memory limit for modules inside a procedure multi database" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc_proc" args: ["--bolt-port", *bolt_port, "--timeout", "180", "--multi-db", "true"] proc: "tests/e2e/memory/procedures/" - <<: *in_memory_cluster + <<: *in_memory_1024_MiB_limit_cluster - name: "Memory limit for modules upon loading for on-disk storage" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_global_alloc" @@ -143,12 +182,12 @@ workloads: - name: "Memory control for detach delete" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete" args: ["--bolt-port", *bolt_port] - <<: *in_memory_450_MiB_limit_cluster + <<: *in_memory_300_MiB_limit_cluster - name: "Memory control for detach delete on disk storage" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_delete" args: ["--bolt-port", *bolt_port] - <<: *disk_450_MiB_limit_cluster + <<: *disk_300_MiB_limit_cluster - name: "Memory control for accumulation" binary: "tests/e2e/memory/memgraph__e2e__memory__limit_accumulation" @@ -170,17 +209,11 @@ workloads: args: ["--bolt-port", *bolt_port] <<: *disk_450_MiB_limit_cluster - - name: "Memory control for create from multi thread proc create" - binary: "tests/e2e/memory/memgraph__e2e__memory_limit_global_multi_thread_proc_create" - proc: "tests/e2e/memory/procedures/" - args: ["--bolt-port", *bolt_port] - <<: *in_memory_cluster - - name: "Memory control for memory limit global thread 
alloc" binary: "tests/e2e/memory/memgraph__e2e__memory_limit_global_thread_alloc_proc" proc: "tests/e2e/memory/procedures/" args: ["--bolt-port", *bolt_port] - <<: *in_memory_cluster + <<: *in_memory_1024_MiB_limit_cluster - name: "Procedure memory control for single procedure" binary: "tests/e2e/memory/memgraph__e2e__procedure_memory_limit" diff --git a/tests/e2e/mg_utils.py b/tests/e2e/mg_utils.py index 74cc8dc3a..3a475bf3c 100644 --- a/tests/e2e/mg_utils.py +++ b/tests/e2e/mg_utils.py @@ -15,3 +15,21 @@ def mg_sleep_and_assert(expected_value, function_to_retrieve_data, max_duration= result = function_to_retrieve_data() return result + + +def mg_sleep_and_assert_collection( + expected_value, function_to_retrieve_data, max_duration=20, time_between_attempt=0.2 +): + result = function_to_retrieve_data() + start_time = time.time() + while len(result) != len(expected_value) or any((x not in result for x in expected_value)): + duration = time.time() - start_time + if duration > max_duration: + assert ( + False + ), f" mg_sleep_and_assert has tried for too long and did not get the expected result! Last result was: {result}" + + time.sleep(time_between_attempt) + result = function_to_retrieve_data() + + return result diff --git a/tests/e2e/replication/constraints.cpp b/tests/e2e/replication/constraints.cpp index 6f7e2991a..de090007f 100644 --- a/tests/e2e/replication/constraints.cpp +++ b/tests/e2e/replication/constraints.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include #include "common.hpp" +#include "io/network/fmt.hpp" #include "utils/logging.hpp" #include "utils/thread.hpp" #include "utils/timer.hpp" diff --git a/tests/e2e/replication/edge_delete.py b/tests/e2e/replication/edge_delete.py index 0e25faee1..32a261a70 100755 --- a/tests/e2e/replication/edge_delete.py +++ b/tests/e2e/replication/edge_delete.py @@ -14,7 +14,7 @@ import time import pytest from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert_collection # BUGFIX: for issue https://github.com/memgraph/memgraph/issues/1515 @@ -28,28 +28,52 @@ def test_replication_handles_delete_when_multiple_edges_of_same_type(connection) conn = connection(7687, "main") conn.autocommit = True cursor = conn.cursor() - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "async", 0, 0, "ready"), - } - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + assert all([x in actual_data for x in expected_data]) # 1/ execute_and_fetch_all(cursor, "CREATE (a)-[r:X]->(b) CREATE (a)-[:X]->(b) DELETE r;") # 2/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 2, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "async", 2, 0, "ready"), - } + expected_data = [ + ( + "replica_1", + 
"127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ] def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) if __name__ == "__main__": diff --git a/tests/e2e/replication/indices.cpp b/tests/e2e/replication/indices.cpp index c0eee23a7..d4b5397f1 100644 --- a/tests/e2e/replication/indices.cpp +++ b/tests/e2e/replication/indices.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include #include "common.hpp" +#include "io/network/fmt.hpp" #include "utils/logging.hpp" #include "utils/thread.hpp" #include "utils/timer.hpp" diff --git a/tests/e2e/replication/read_write_benchmark.cpp b/tests/e2e/replication/read_write_benchmark.cpp index 243aab2a8..b7719faf0 100644 --- a/tests/e2e/replication/read_write_benchmark.cpp +++ b/tests/e2e/replication/read_write_benchmark.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -19,6 +19,7 @@ #include #include "common.hpp" +#include "io/network/fmt.hpp" #include "utils/logging.hpp" #include "utils/thread.hpp" #include "utils/timer.hpp" diff --git a/tests/e2e/replication/show.py b/tests/e2e/replication/show.py index 4a32f300d..315cb7142 100755 --- a/tests/e2e/replication/show.py +++ b/tests/e2e/replication/show.py @@ -10,12 +10,11 @@ # licenses/APL.txt. import sys - -import pytest import time +import pytest from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert_collection @pytest.mark.parametrize( @@ -31,25 +30,42 @@ def test_show_replication_role(port, role, connection): def test_show_replicas(connection): cursor = connection(7687, "main").cursor() - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") expected_column_names = { "name", "socket_address", "sync_mode", - "current_timestamp_of_replica", - "number_of_timestamp_behind_master", - "state", + "system_info", + "data_info", } actual_column_names = {x.name for x in cursor.description} assert actual_column_names == expected_column_names - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - } - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + 
"async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + assert all([x in actual_data for x in expected_data]) def test_show_replicas_while_inserting_data(connection): @@ -62,49 +78,108 @@ def test_show_replicas_while_inserting_data(connection): # 0/ cursor = connection(7687, "main").cursor() - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") expected_column_names = { "name", "socket_address", "sync_mode", - "current_timestamp_of_replica", - "number_of_timestamp_behind_master", - "state", + "system_info", + "data_info", } actual_column_names = {x.name for x in cursor.description} assert actual_column_names == expected_column_names - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - } - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + assert all([x in actual_data for x in expected_data]) # 1/ execute_and_fetch_all(cursor, "CREATE (n1:Number {name: 'forty_two', value:42});") # 2/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 4, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 4, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 4, 0, "ready"), - } + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, 
"status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ] def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 3/ res = execute_and_fetch_all(cursor, "MATCH (node) return node;") assert len(res) == 1 # 4/ - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 4, "behind": 0, "status": "ready"}}, + ), + ] + + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) if __name__ == "__main__": diff --git a/tests/e2e/replication/show_while_creating_invalid_state.py b/tests/e2e/replication/show_while_creating_invalid_state.py index abd5b5f48..be7cd2b54 100644 --- a/tests/e2e/replication/show_while_creating_invalid_state.py +++ b/tests/e2e/replication/show_while_creating_invalid_state.py @@ -18,7 +18,7 @@ import interactive_mg_runner 
import mgclient import pytest from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert, mg_sleep_and_assert_collection interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) interactive_mg_runner.PROJECT_DIR = os.path.normpath( @@ -74,36 +74,77 @@ def test_show_replicas(connection): cursor = connection(7687, "main").cursor() # 1/ - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") EXPECTED_COLUMN_NAMES = { "name", "socket_address", "sync_mode", - "current_timestamp_of_replica", - "number_of_timestamp_behind_master", - "state", + "system_info", + "data_info", } actual_column_names = {x.name for x in cursor.description} assert actual_column_names == EXPECTED_COLUMN_NAMES - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + assert all([x in actual_data for x in expected_data]) # 2/ execute_and_fetch_all(cursor, "DROP REPLICA replica_2") - actual_data = set(execute_and_fetch_all(cursor, "SHOW 
REPLICAS;")) - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - assert actual_data == expected_data + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + assert all([x in actual_data for x in expected_data]) # 3/ interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_1") @@ -112,15 +153,33 @@ def test_show_replicas(connection): # We leave some time for the main to realise the replicas are down. 
def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "invalid"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "invalid"), - } - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) def test_drop_replicas(connection): @@ -140,7 +199,7 @@ def test_drop_replicas(connection): # 12/ Drop all and check status def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") # 0/ interactive_mg_runner.start_all(MEMGRAPH_INSTANCES_DESCRIPTION) @@ -148,89 +207,208 @@ def test_drop_replicas(connection): cursor = connection(7687, "main").cursor() # 1/ - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") EXPECTED_COLUMN_NAMES = { "name", "socket_address", "sync_mode", - "current_timestamp_of_replica", - "number_of_timestamp_behind_master", - "state", + "system_info", + "data_info", } actual_column_names = {x.name for x in cursor.description} assert actual_column_names == EXPECTED_COLUMN_NAMES - expected_data = { - 
("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 2/ interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_3") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "invalid"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, 
"behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 3/ execute_and_fetch_all(cursor, "DROP REPLICA replica_3") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 4/ interactive_mg_runner.stop(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_4") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "invalid"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 5/ execute_and_fetch_all(cursor, 
"DROP REPLICA replica_4") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 6/ interactive_mg_runner.kill(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_1") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 7/ execute_and_fetch_all(cursor, "DROP REPLICA replica_1") - expected_data = { - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 8/ interactive_mg_runner.stop(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_2") - expected_data = { - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "invalid"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( 
+ "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 9/ execute_and_fetch_all(cursor, "DROP REPLICA replica_2") expected_data = set() - mg_sleep_and_assert(expected_data, retrieve_data) + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 10/ interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, "replica_1") @@ -243,13 +421,37 @@ def test_drop_replicas(connection): execute_and_fetch_all(cursor, "REGISTER REPLICA replica_4 ASYNC TO '127.0.0.1:10004';") # 11/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, retrieve_data) # 12/ execute_and_fetch_all(cursor, "DROP REPLICA replica_1") @@ -257,7 +459,7 @@ def test_drop_replicas(connection): execute_and_fetch_all(cursor, "DROP REPLICA replica_3") execute_and_fetch_all(cursor, "DROP REPLICA replica_4") expected_data = set() - mg_sleep_and_assert(expected_data, 
retrieve_data) + mg_sleep_and_assert_collection(expected_data, retrieve_data) @pytest.mark.parametrize( @@ -379,15 +581,39 @@ def test_basic_recovery(recover_data_on_startup, connection): execute_and_fetch_all(cursor, "REGISTER REPLICA replica_4 ASYNC TO '127.0.0.1:10004';") # 1/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 0, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 0, 0, "ready"), - } - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) def check_roles(): assert "main" == interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICATION ROLE;")[0][0] @@ -409,10 +635,10 @@ def test_basic_recovery(recover_data_on_startup, connection): # 4/ def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in 
actual_data for x in expected_data]) # 5/ execute_and_fetch_all(cursor, "DROP REPLICA replica_2;") @@ -431,13 +657,31 @@ def test_basic_recovery(recover_data_on_startup, connection): for index in (1, 3, 4): assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES[f"replica_{index}"].query(QUERY_TO_CHECK) - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 2, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 2, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 2, 0, "ready"), - } - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # Replica_2 was dropped, we check it does not have the data from main. 
assert len(interactive_mg_runner.MEMGRAPH_INSTANCES["replica_2"].query(QUERY_TO_CHECK)) == 0 @@ -454,59 +698,155 @@ def test_basic_recovery(recover_data_on_startup, connection): execute_and_fetch_all(cursor, "REGISTER REPLICA replica_2 SYNC TO '127.0.0.1:10002';") interactive_mg_runner.start(CONFIGURATION, "replica_3") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 6, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 6, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 6, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 6, 0, "ready"), - } + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) for index in (1, 2, 3, 4): assert interactive_mg_runner.MEMGRAPH_INSTANCES[f"replica_{index}"].query(QUERY_TO_CHECK) == res_from_main # 11/ interactive_mg_runner.kill(CONFIGURATION, "replica_1") - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("replica_2", "127.0.0.1:10002", "sync", 6, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 6, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 6, 0, "ready"), - } + expected_data = [ + ( + "replica_1", + 
"127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 12/ with pytest.raises(mgclient.DatabaseError): interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query( "CREATE (p1:Number {name:'Magic_again_again', value:44})" ) - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("replica_2", "127.0.0.1:10002", "sync", 9, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 9, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 9, 0, "ready"), - } + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ] - actual_data = 
mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 13/ interactive_mg_runner.start(CONFIGURATION, "replica_1") # 14/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 9, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 9, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 9, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 9, 0, "ready"), - } - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 9, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) print("actual=", actual_data) - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) res_from_main = execute_and_fetch_all(cursor, QUERY_TO_CHECK) assert len(res_from_main) == 3 @@ -519,14 +859,38 @@ def test_basic_recovery(recover_data_on_startup, connection): ) # 16/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 12, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 12, 0, "ready"), - ("replica_3", "127.0.0.1:10003", "async", 12, 0, "ready"), - ("replica_4", "127.0.0.1:10004", "async", 12, 0, "ready"), - } - actual_data = set(execute_and_fetch_all(cursor, "SHOW 
REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 12, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 12, "behind": 0, "status": "ready"}}, + ), + ( + "replica_3", + "127.0.0.1:10003", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 12, "behind": 0, "status": "ready"}}, + ), + ( + "replica_4", + "127.0.0.1:10004", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 12, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) res_from_main = execute_and_fetch_all(cursor, QUERY_TO_CHECK) assert len(res_from_main) == 4 @@ -590,12 +954,18 @@ def test_replication_role_recovery(connection): execute_and_fetch_all(cursor, "REGISTER REPLICA replica SYNC TO '127.0.0.1:10001';") # 1/ - expected_data = { - ("replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), - } - actual_data = set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + expected_data = [ + ( + "replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = execute_and_fetch_all(cursor, "SHOW REPLICAS;") - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) def check_roles(): assert "main" == interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICATION ROLE;")[0][0] @@ -612,33 +982,45 @@ def test_replication_role_recovery(connection): check_roles() def retrieve_data(): - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") - actual_data = mg_sleep_and_assert(expected_data, 
retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 4/ interactive_mg_runner.kill(CONFIGURATION, "replica") # 5/ - expected_data = { - ("replica", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - } - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) + expected_data = [ + ( + "replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ] + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) # 6/ interactive_mg_runner.start(CONFIGURATION, "replica") check_roles() # 7/ - expected_data = { - ("replica", "127.0.0.1:10001", "sync", 0, 0, "ready"), - } + expected_data = [ + ( + "replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 8/ interactive_mg_runner.kill(CONFIGURATION, "replica") @@ -651,11 +1033,17 @@ def test_replication_role_recovery(connection): interactive_mg_runner.start(CONFIGURATION, "replica") check_roles() - expected_data = { - ("replica", "127.0.0.1:10001", "sync", 2, 0, "ready"), - } - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + expected_data = [ + ( + "replica", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in 
actual_data for x in expected_data]) QUERY_TO_CHECK = "MATCH (node) return node;" res_from_main = execute_and_fetch_all(cursor, QUERY_TO_CHECK) @@ -733,13 +1121,25 @@ def test_basic_recovery_when_replica_is_kill_when_main_is_down(): interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("REGISTER REPLICA replica_2 SYNC TO '127.0.0.1:10002';") # 1/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") - assert actual_data == expected_data + assert all([x in actual_data for x in expected_data]) def check_roles(): assert "main" == interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICATION ROLE;")[0][0] @@ -759,12 +1159,24 @@ def test_basic_recovery_when_replica_is_kill_when_main_is_down(): interactive_mg_runner.start(CONFIGURATION, "main") # 4/ - expected_data = { - ("replica_1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("replica_2", "127.0.0.1:10002", "sync", 0, 0, "invalid"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "replica_1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "invalid"}, + {"memgraph": {"ts": 0, "behind": 0, "status": 
"invalid"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) def test_async_replication_when_main_is_killed(): @@ -795,7 +1207,7 @@ def test_async_replication_when_main_is_killed(): "data_directory": f"{data_directory_main.name}", }, } - + interactive_mg_runner.kill_all(CONFIGURATION) interactive_mg_runner.start_all(CONFIGURATION) # 1/ @@ -812,12 +1224,12 @@ def test_async_replication_when_main_is_killed(): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, ip, mode, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, ip, mode, info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) for index in range(5, 50): interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(f"CREATE (p:Number {{name:{index}}})") @@ -878,7 +1290,7 @@ def test_sync_replication_when_main_is_killed(): "data_directory": f"{data_directory_main.name}", }, } - + interactive_mg_runner.kill_all(CONFIGURATION) interactive_mg_runner.start_all(CONFIGURATION) # 1/ @@ -941,12 +1353,24 @@ def test_attempt_to_write_data_on_main_when_async_replica_is_down(): interactive_mg_runner.start_all(CONFIGURATION) # 1/ - expected_data = { - ("async_replica1", "127.0.0.1:10001", "async", 0, 0, "ready"), - ("async_replica2", "127.0.0.1:10002", "async", 0, 0, "ready"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "async_replica1", + "127.0.0.1:10001", + "async", + {"ts": 0, "behind": None, "status": 
"ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "async_replica2", + "127.0.0.1:10002", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 2/ interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("CREATE (p:Number {name:1});") @@ -972,12 +1396,12 @@ def test_attempt_to_write_data_on_main_when_async_replica_is_down(): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 6/ res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) @@ -1038,12 +1462,24 @@ def test_attempt_to_write_data_on_main_when_sync_replica_is_down(connection): execute_and_fetch_all(main_cursor, "REGISTER REPLICA sync_replica2 SYNC TO '127.0.0.1:10002';") # 1/ - expected_data = { - ("sync_replica1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("sync_replica2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "sync_replica1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "sync_replica2", + 
"127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 2/ interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("CREATE (p:Number {name:1});") @@ -1064,12 +1500,12 @@ def test_attempt_to_write_data_on_main_when_sync_replica_is_down(connection): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 4/ with pytest.raises(mgclient.DatabaseError): @@ -1080,13 +1516,25 @@ def test_attempt_to_write_data_on_main_when_sync_replica_is_down(connection): assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica2"].query(QUERY_TO_CHECK) # 5/ - expected_data = { - ("sync_replica1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("sync_replica2", "127.0.0.1:10002", "sync", 5, 0, "ready"), - } + expected_data = [ + ( + "sync_replica1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "sync_replica2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 5, "behind": 0, "status": "ready"}}, + ), + ] res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) - actual_data = 
set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 6/ interactive_mg_runner.start(CONFIGURATION, "sync_replica1") @@ -1094,8 +1542,8 @@ def test_attempt_to_write_data_on_main_when_sync_replica_is_down(connection): ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 2 assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1137,12 +1585,24 @@ def test_attempt_to_create_indexes_on_main_when_async_replica_is_down(): interactive_mg_runner.start_all(CONFIGURATION) # 1/ - expected_data = { - ("async_replica1", "127.0.0.1:10001", "async", 0, 0, "ready"), - ("async_replica2", "127.0.0.1:10002", "async", 0, 0, "ready"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "async_replica1", + "127.0.0.1:10001", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "async_replica2", + "127.0.0.1:10002", + "async", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 2/ interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("CREATE INDEX ON :Number(value);") 
@@ -1168,12 +1628,12 @@ def test_attempt_to_create_indexes_on_main_when_async_replica_is_down(): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 6/ res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) @@ -1234,12 +1694,24 @@ def test_attempt_to_create_indexes_on_main_when_sync_replica_is_down(connection) execute_and_fetch_all(cursor, "REGISTER REPLICA sync_replica2 SYNC TO '127.0.0.1:10002';") # 1/ - expected_data = { - ("sync_replica1", "127.0.0.1:10001", "sync", 0, 0, "ready"), - ("sync_replica2", "127.0.0.1:10002", "sync", 0, 0, "ready"), - } - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + expected_data = [ + ( + "sync_replica1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "sync_replica2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 2/ interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("CREATE INDEX ON :Number(value);") @@ -1260,12 +1732,12 @@ def test_attempt_to_create_indexes_on_main_when_sync_replica_is_down(connection) 
def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 4/ with pytest.raises(mgclient.DatabaseError): @@ -1276,14 +1748,26 @@ def test_attempt_to_create_indexes_on_main_when_sync_replica_is_down(connection) assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica2"].query(QUERY_TO_CHECK) # 5/ - expected_data = { - ("sync_replica1", "127.0.0.1:10001", "sync", 0, 0, "invalid"), - ("sync_replica2", "127.0.0.1:10002", "sync", 6, 0, "ready"), - } + expected_data = [ + ( + "sync_replica1", + "127.0.0.1:10001", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 0, "behind": 0, "status": "invalid"}}, + ), + ( + "sync_replica2", + "127.0.0.1:10002", + "sync", + {"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 6, "behind": 0, "status": "ready"}}, + ), + ] res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica2"].query(QUERY_TO_CHECK) - actual_data = set(interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;")) - assert actual_data == expected_data + actual_data = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") + assert all([x in actual_data for x in expected_data]) # 6/ interactive_mg_runner.start(CONFIGURATION, "sync_replica1") @@ -1291,8 +1775,8 @@ def 
test_attempt_to_create_indexes_on_main_when_sync_replica_is_down(connection) ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 2 assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1388,12 +1872,12 @@ def test_trigger_on_create_before_commit_with_offline_sync_replica(connection): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 6/ with pytest.raises(mgclient.DatabaseError): @@ -1410,8 +1894,8 @@ def test_trigger_on_create_before_commit_with_offline_sync_replica(connection): ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 2 assert res_from_main == 
interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1511,12 +1995,12 @@ def test_trigger_on_update_before_commit_with_offline_sync_replica(connection): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 7/ with pytest.raises(mgclient.DatabaseError): @@ -1533,8 +2017,8 @@ def test_trigger_on_update_before_commit_with_offline_sync_replica(connection): ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 2 assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1637,12 +2121,12 @@ def test_trigger_on_delete_before_commit_with_offline_sync_replica(connection): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - 
actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 7/ with pytest.raises(mgclient.DatabaseError): @@ -1660,8 +2144,8 @@ def test_trigger_on_delete_before_commit_with_offline_sync_replica(connection): ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 1 assert res_from_main[0][0].properties["name"] == "Node_created_by_trigger" @@ -1762,12 +2246,12 @@ def test_trigger_on_create_before_and_after_commit_with_offline_sync_replica(con def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 6/ with pytest.raises(mgclient.DatabaseError): @@ -1784,8 +2268,8 @@ def test_trigger_on_create_before_and_after_commit_with_offline_sync_replica(con ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = 
mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 3 assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1885,12 +2369,12 @@ def test_triggers_on_create_before_commit_with_offline_sync_replica(connection): def retrieve_data(): replicas = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query("SHOW REPLICAS;") return [ - (replica_name, mode, timestamp_behind_main, status) - for replica_name, ip, mode, timestamp, timestamp_behind_main, status in replicas + (replica_name, mode, info["memgraph"]["behind"], info["memgraph"]["status"]) + for replica_name, ip, mode, sys_info, info in replicas ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 6/ with pytest.raises(mgclient.DatabaseError): @@ -1912,8 +2396,8 @@ def test_triggers_on_create_before_commit_with_offline_sync_replica(connection): ("sync_replica1", "sync", 0, "ready"), ("sync_replica2", "sync", 0, "ready"), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) res_from_main = interactive_mg_runner.MEMGRAPH_INSTANCES["main"].query(QUERY_TO_CHECK) assert len(res_from_main) == 3 assert res_from_main == interactive_mg_runner.MEMGRAPH_INSTANCES["sync_replica1"].query(QUERY_TO_CHECK) @@ -1976,10 +2460,16 @@ def test_replication_not_messed_up_by_ShowIndexInfo(connection): return replicas expected_data = [ - ("replica_1", "127.0.0.1:10001", "async", 2, 0, "ready"), + ( + "replica_1", + "127.0.0.1:10001", + "async", + 
{"ts": 0, "behind": None, "status": "ready"}, + {"memgraph": {"ts": 2, "behind": 0, "status": "ready"}}, + ), ] - actual_data = mg_sleep_and_assert(expected_data, retrieve_data) - assert actual_data == expected_data + actual_data = mg_sleep_and_assert_collection(expected_data, retrieve_data) + assert all([x in actual_data for x in expected_data]) # 3/ cursor = connection(7688, "replica_1").cursor() @@ -1990,5 +2480,4 @@ def test_replication_not_messed_up_by_ShowIndexInfo(connection): if __name__ == "__main__": - sys.exit(pytest.main([__file__, "-k", "test_basic_recovery"])) sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/replication_experimental/auth.py b/tests/e2e/replication_experimental/auth.py index 44738ccbd..60da2f513 100644 --- a/tests/e2e/replication_experimental/auth.py +++ b/tests/e2e/replication_experimental/auth.py @@ -121,6 +121,7 @@ def only_main_queries(cursor): n_exceptions += try_and_count(cursor, f"REVOKE EDGE_TYPES :e FROM user_name") n_exceptions += try_and_count(cursor, f"GRANT DATABASE memgraph TO user_name;") n_exceptions += try_and_count(cursor, f"SET MAIN DATABASE memgraph FOR user_name") + n_exceptions += try_and_count(cursor, f"DENY DATABASE memgraph FROM user_name;") n_exceptions += try_and_count(cursor, f"REVOKE DATABASE memgraph FROM user_name;") return n_exceptions @@ -198,8 +199,8 @@ def test_auth_queries_on_replica(connection): # 1/ assert only_main_queries(cursor_main) == 0 - assert only_main_queries(cursor_replica_1) == 17 - assert only_main_queries(cursor_replica_2) == 17 + assert only_main_queries(cursor_replica_1) == 18 + assert only_main_queries(cursor_replica_2) == 18 assert main_and_repl_queries(cursor_main) == 0 assert main_and_repl_queries(cursor_replica_1) == 0 assert main_and_repl_queries(cursor_replica_2) == 0 @@ -383,6 +384,7 @@ def test_manual_roles_recovery(connection): "--log-level=TRACE", "--data_directory", TEMP_DIR + "/replica1", + "--also-log-to-stderr", ], "log_file": "replica1.log", 
"setup_queries": [ @@ -818,13 +820,15 @@ def test_auth_replication(connection): {("LABEL :l3", "UPDATE", "LABEL PERMISSION GRANTED TO ROLE")}, ) - # GRANT/REVOKE DATABASE + # GRANT/DENY DATABASE execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test") execute_and_fetch_all(cursor_main, "CREATE DATABASE auth_test2") execute_and_fetch_all(cursor_main, "GRANT DATABASE auth_test TO user4") check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], [])]) - execute_and_fetch_all(cursor_main, "REVOKE DATABASE auth_test2 FROM user4") + execute_and_fetch_all(cursor_main, "DENY DATABASE auth_test2 FROM user4") check(partial(show_database_privileges_func, user="user4"), [(["auth_test", "memgraph"], ["auth_test2"])]) + execute_and_fetch_all(cursor_main, "REVOKE DATABASE memgraph FROM user4") + check(partial(show_database_privileges_func, user="user4"), [(["auth_test"], ["auth_test2"])]) # SET MAIN DATABASE execute_and_fetch_all(cursor_main, "GRANT ALL PRIVILEGES TO user4") diff --git a/tests/e2e/replication_experimental/multitenancy.py b/tests/e2e/replication_experimental/multitenancy.py index cad5f24ca..8715a7261 100644 --- a/tests/e2e/replication_experimental/multitenancy.py +++ b/tests/e2e/replication_experimental/multitenancy.py @@ -22,7 +22,7 @@ import interactive_mg_runner import mgclient import pytest from common import execute_and_fetch_all -from mg_utils import mg_sleep_and_assert +from mg_utils import mg_sleep_and_assert, mg_sleep_and_assert_collection interactive_mg_runner.SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) interactive_mg_runner.PROJECT_DIR = os.path.normpath( @@ -35,6 +35,10 @@ BOLT_PORTS = {"main": 7687, "replica_1": 7688, "replica_2": 7689} REPLICATION_PORTS = {"replica_1": 10001, "replica_2": 10002} +def set_eq(actual, expected): + return len(actual) == len(expected) and all([x in actual for x in expected]) + + def create_memgraph_instances_with_role_recovery(data_directory: Any) -> Dict[str, Any]: 
return { "replica_1": { @@ -174,10 +178,9 @@ def setup_main(main_cursor): execute_and_fetch_all(main_cursor, "CREATE (:Node{on:'B'});") -def show_replicas_func(cursor, db_name): +def show_replicas_func(cursor): def func(): - execute_and_fetch_all(cursor, f"USE DATABASE {db_name};") - return set(execute_and_fetch_all(cursor, "SHOW REPLICAS;")) + return execute_and_fetch_all(cursor, "SHOW REPLICAS;") return func @@ -271,17 +274,31 @@ def test_manual_databases_create_multitenancy_replication(connection): execute_and_fetch_all(cursor, "CREATE ()-[:EDGE]->();") # 2/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 2, "behind": None, "status": "ready"}, + { + "A": {"ts": 1, "behind": 0, "status": "ready"}, + "B": {"ts": 1, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 2, "behind": None, "status": "ready"}, + { + "A": {"ts": 1, "behind": 0, "status": "ready"}, + "B": {"ts": 1, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(cursor)) cursor_replica = connection(BOLT_PORTS["replica_1"], "replica").cursor() assert get_number_of_nodes_func(cursor_replica, "A")() == 1 @@ -523,11 +540,23 @@ def 
test_manual_databases_create_multitenancy_replication_main_behind(connection execute_and_fetch_all(main_cursor, "CREATE DATABASE A;") # 2/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(main_cursor)) databases_on_main = show_databases_func(main_cursor)() @@ -567,17 +596,31 @@ def test_automatic_databases_create_multitenancy_replication(connection): execute_and_fetch_all(main_cursor, "CREATE (:Node)-[:EDGE]->(:Node)") # 3/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 7, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 7, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 0, "behind": 
0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 0, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(main_cursor)) cursor_replica = connection(BOLT_PORTS["replica_1"], "replica").cursor() assert get_number_of_nodes_func(cursor_replica, "A")() == 7 @@ -640,19 +683,26 @@ def test_automatic_databases_multitenancy_replication_predefined(connection): execute_and_fetch_all(cursor, "CREATE ()-[:EDGE]->();") # 2/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 2, "behind": None, "status": "ready"}, + { + "A": {"ts": 1, "behind": 0, "status": "ready"}, + "B": {"ts": 1, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(cursor)) cursor_replica = connection(BOLT_PORTS["replica_1"], "replica").cursor() assert get_number_of_nodes_func(cursor_replica, "A")() == 1 assert get_number_of_edges_func(cursor_replica, "A")() == 0 + assert get_number_of_nodes_func(cursor_replica, "B")() == 2 + assert get_number_of_edges_func(cursor_replica, "B")() == 1 def test_automatic_databases_create_multitenancy_replication_dirty_main(connection): @@ -698,10 +748,16 @@ def 
test_automatic_databases_create_multitenancy_replication_dirty_main(connecti cursor = connection(BOLT_PORTS["main"], "main").cursor() # 1/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 1, "behind": None, "status": "ready"}, + {"A": {"ts": 1, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(cursor)) cursor_replica = connection(BOLT_PORTS["replica_1"], "replica").cursor() execute_and_fetch_all(cursor_replica, "USE DATABASE A;") @@ -740,31 +796,85 @@ def test_multitenancy_replication_restart_replica_w_fc(connection, replica_name) time.sleep(3) # In order for the frequent check to run # Check that the FC did invalidate expected_data = { - "replica_1": { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "invalid"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 7, 0, "ready"), - }, - "replica_2": { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 7, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "invalid"), - }, + "replica_1": [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 0, "behind": 0, "status": "invalid"}, + "B": {"ts": 0, "behind": 0, "status": "invalid"}, + "memgraph": {"ts": 0, "behind": 0, "status": "invalid"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": 
"ready"}, + }, + ), + ], + "replica_2": [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 0, "behind": 0, "status": "invalid"}, + "B": {"ts": 0, "behind": 0, "status": "invalid"}, + "memgraph": {"ts": 0, "behind": 0, "status": "invalid"}, + }, + ), + ], } - assert expected_data[replica_name] == show_replicas_func(main_cursor, "A")() + assert set_eq(expected_data[replica_name], show_replicas_func(main_cursor)()) # Restart interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, replica_name) # 4/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 7, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 7, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 3, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 3, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + 
"memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(main_cursor)) cursor_replica = connection(BOLT_PORTS[replica_name], "replica").cursor() @@ -805,19 +915,33 @@ def test_multitenancy_replication_restart_replica_wo_fc(connection, replica_name interactive_mg_runner.start(MEMGRAPH_INSTANCES_DESCRIPTION, replica_name) # 4/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 7, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 7, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(main_cursor)) - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 3, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 3, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "B")) - - cursor_replica = connection(BOLT_PORTS[replica_name], "replica").cursor() + cursor_replica = connection(BOLT_PORTS[replica_name], replica_name).cursor() assert get_number_of_nodes_func(cursor_replica, "A")() == 7 assert get_number_of_edges_func(cursor_replica, "A")() == 3 assert get_number_of_nodes_func(cursor_replica, "B")() == 2 @@ 
-899,17 +1023,28 @@ def test_multitenancy_replication_drop_replica(connection, replica_name): ) # 4/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 7, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 7, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 3, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 3, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + { + "A": {"ts": 7, "behind": 0, "status": "ready"}, + "B": {"ts": 3, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] cursor_replica = connection(BOLT_PORTS[replica_name], "replica").cursor() assert get_number_of_nodes_func(cursor_replica, "A")() == 7 @@ -993,17 +1128,31 @@ def test_automatic_databases_drop_multitenancy_replication(connection): execute_and_fetch_all(main_cursor, "CREATE (:Node{on:'A'});") # 3/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) - - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, 
show_replicas_func(main_cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 1, "behind": 0, "status": "ready"}, + "B": {"ts": 0, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 4, "behind": None, "status": "ready"}, + { + "A": {"ts": 1, "behind": 0, "status": "ready"}, + "B": {"ts": 0, "behind": 0, "status": "ready"}, + "memgraph": {"ts": 0, "behind": 0, "status": "ready"}, + }, + ), + ] + mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor)) # 4/ execute_and_fetch_all(main_cursor, "USE DATABASE memgraph;") @@ -1088,11 +1237,23 @@ def test_multitenancy_drop_while_replica_using(connection): execute_and_fetch_all(main_cursor, "CREATE (:Node{on:'A'});") # 3/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 1, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 1, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, show_replicas_func(main_cursor)) # 4/ replica1_cursor = connection(BOLT_PORTS["replica_1"], "replica").cursor() @@ -1110,11 +1271,23 @@ def test_multitenancy_drop_while_replica_using(connection): execute_and_fetch_all(main_cursor, 
"CREATE DATABASE B;") execute_and_fetch_all(main_cursor, "USE DATABASE B;") - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "B")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 8, "behind": None, "status": "ready"}, + {"B": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 8, "behind": None, "status": "ready"}, + {"B": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor)) # 6/ assert execute_and_fetch_all(replica1_cursor, "MATCH(n) RETURN count(*);")[0][0] == 1 @@ -1163,11 +1336,23 @@ def test_multitenancy_drop_and_recreate_while_replica_using(connection): execute_and_fetch_all(main_cursor, "CREATE (:Node{on:'A'});") # 3/ - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 1, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 1, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 1, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 3, "behind": None, "status": "ready"}, + {"A": {"ts": 1, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert_collection(expected_data, 
show_replicas_func(main_cursor)) # 4/ replica1_cursor = connection(BOLT_PORTS["replica_1"], "replica").cursor() @@ -1184,11 +1369,23 @@ def test_multitenancy_drop_and_recreate_while_replica_using(connection): execute_and_fetch_all(main_cursor, "CREATE DATABASE A;") execute_and_fetch_all(main_cursor, "USE DATABASE A;") - expected_data = { - ("replica_1", f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", "sync", 0, 0, "ready"), - ("replica_2", f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", "async", 0, 0, "ready"), - } - mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor, "A")) + expected_data = [ + ( + "replica_1", + f"127.0.0.1:{REPLICATION_PORTS['replica_1']}", + "sync", + {"ts": 8, "behind": None, "status": "ready"}, + {"A": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ( + "replica_2", + f"127.0.0.1:{REPLICATION_PORTS['replica_2']}", + "async", + {"ts": 8, "behind": None, "status": "ready"}, + {"A": {"ts": 0, "behind": 0, "status": "ready"}, "memgraph": {"ts": 0, "behind": 0, "status": "ready"}}, + ), + ] + mg_sleep_and_assert(expected_data, show_replicas_func(main_cursor)) # 6/ assert execute_and_fetch_all(replica1_cursor, "MATCH(n) RETURN count(*);")[0][0] == 1 diff --git a/tests/e2e/transaction_queue/test_transaction_queue.py b/tests/e2e/transaction_queue/test_transaction_queue.py index 221243c50..a563b0aff 100644 --- a/tests/e2e/transaction_queue/test_transaction_queue.py +++ b/tests/e2e/transaction_queue/test_transaction_queue.py @@ -70,21 +70,26 @@ def test_multitenant_transactions(): # TODO Add SHOW TRANSACTIONS ON * that should return all transactions -def test_admin_has_one_transaction(): +def test_admin_has_one_transaction(request): """Creates admin and tests that he sees only one transaction.""" # a_cursor is used for creating admin user, simulates main thread superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") 
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + admin_cursor = connect(username="admin", password="").cursor() process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1)) process.start() process.join() - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") -def test_user_can_see_its_transaction(): +def test_user_can_see_its_transaction(request): """Tests that user without privileges can see its own transaction""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") @@ -92,20 +97,31 @@ def test_user_can_see_its_transaction(): execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + user_cursor = connect(username="user", password="").cursor() process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1)) process.start() process.join() admin_cursor = connect(username="admin", password="").cursor() - execute_and_fetch_all(admin_cursor, "DROP USER user") - execute_and_fetch_all(admin_cursor, "DROP USER admin") -def test_explicit_transaction_output(): +def test_explicit_transaction_output(request): superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") 
+ + request.addfinalizer(on_exit) + admin_connection = connect(username="admin", password="") admin_cursor = admin_connection.cursor() # Admin starts running explicit transaction @@ -123,10 +139,9 @@ def test_explicit_transaction_output(): assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"] execute_and_fetch_all(superadmin_cursor, "ROLLBACK") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") -def test_superadmin_cannot_see_admin_can_see_admin(): +def test_superadmin_cannot_see_admin_can_see_admin(request): """Tests that superadmin cannot see the transaction created by admin but two admins can see and kill each other's transactions.""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") @@ -135,6 +150,13 @@ def test_superadmin_cannot_see_admin_can_see_admin(): execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + + request.addfinalizer(on_exit) + # Admin starts running infinite query admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() @@ -160,19 +182,23 @@ def test_superadmin_cannot_see_admin_can_see_admin(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") admin_connection_1.close() admin_connection_2.close() -def test_admin_sees_superadmin(): +def test_admin_sees_superadmin(request): """Tests that admin created by superadmin can see the superadmin's 
transaction.""" superadmin_connection = connect() superadmin_cursor = superadmin_connection.cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + # Admin starts running infinite query process = multiprocessing.Process( target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) @@ -194,17 +220,23 @@ def test_admin_sees_superadmin(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(admin_cursor, "DROP USER admin") superadmin_connection.close() -def test_admin_can_see_user_transaction(): +def test_admin_can_see_user_transaction(request): """Tests that admin can see user's transaction and kill it.""" superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + # Admin starts running infinite query admin_connection = connect(username="admin", password="") admin_cursor = admin_connection.cursor() @@ -229,13 +261,11 @@ def test_admin_can_see_user_transaction(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin") 
- execute_and_fetch_all(superadmin_cursor, "DROP USER user") admin_connection.close() user_connection.close() -def test_user_cannot_see_admin_transaction(): +def test_user_cannot_see_admin_transaction(request): """User cannot see admin's transaction but other admin can and he can kill it.""" # Superadmin creates two admins and one user superadmin_cursor = connect().cursor() @@ -246,6 +276,14 @@ def test_user_cannot_see_admin_transaction(): execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin2") execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + + def on_exit(): + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + + request.addfinalizer(on_exit) + admin_connection_1 = connect(username="admin1", password="") admin_cursor_1 = admin_connection_1.cursor() admin_connection_2 = connect(username="admin2", password="") @@ -274,9 +312,6 @@ def test_user_cannot_see_admin_transaction(): # Kill transaction long_transaction_id = show_results[1 - executing_index][1] execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") - execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") - execute_and_fetch_all(superadmin_cursor, "DROP USER user") admin_connection_1.close() admin_connection_2.close() user_connection.close() @@ -300,12 +335,18 @@ def test_killing_multiple_non_existing_transactions(): assert results[i][1] == False # not killed -def test_admin_killing_multiple_non_existing_transactions(): +def test_admin_killing_multiple_non_existing_transactions(request): # Starting, superadmin admin superadmin_cursor = connect().cursor() execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") execute_and_fetch_all(superadmin_cursor, "GRANT 
TRANSACTION_MANAGEMENT TO admin") execute_and_fetch_all(superadmin_cursor, "GRANT DATABASE * TO admin") + + def on_exit(): + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + request.addfinalizer(on_exit) + # Connect with admin admin_cursor = connect(username="admin", password="").cursor() transactions_id = ["'1'", "'2'", "'3'"] @@ -314,7 +355,6 @@ def test_admin_killing_multiple_non_existing_transactions(): for i in range(len(results)): assert results[i][0] == eval(transactions_id[i]) # transaction id assert results[i][1] == False # not killed - execute_and_fetch_all(admin_cursor, "DROP USER admin") def test_user_killing_some_transactions(): diff --git a/tests/e2e/triggers/privilige_check.cpp b/tests/e2e/triggers/privilige_check.cpp index f2cad40d4..2d7ac0f1c 100644 --- a/tests/e2e/triggers/privilige_check.cpp +++ b/tests/e2e/triggers/privilige_check.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -13,7 +13,6 @@ #include #include -#include #include #include "common.hpp" #include "utils/logging.hpp" diff --git a/tests/gql_behave/tests/memgraph_V1/features/match.feature b/tests/gql_behave/tests/memgraph_V1/features/match.feature index eaf8d3f44..47da2fadf 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/match.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/match.feature @@ -761,6 +761,7 @@ Feature: Match Then the result should be: | path | | <(:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3})> | + | <(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3})> | Scenario: Test DFS variable expand using IN edges with filter by edge type1 Given graph "graph_edges" @@ -771,6 +772,7 @@ Feature: Match Then the result should be: | path | | <(:label3 {id: 3})<-[:type2 {id: 
10}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type2 {id: 10}]-(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})> | Scenario: Test DFS variable expand with filter by edge type2 Given graph "graph_edges" @@ -781,6 +783,7 @@ Feature: Match Then the result should be: | path | | <(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | + | <(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | Scenario: Test DFS variable expand using IN edges with filter by edge type2 Given graph "graph_edges" @@ -791,6 +794,7 @@ Feature: Match Then the result should be: | path | | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})> | + | <(:label3 {id: 3})<-[:type1 {id: 2}]-(:label2 {id: 2})<-[:type1 {id: 1}]-(:label1 {id: 1})-[:same {id: 30}]->(:label1 {id: 1})> | Scenario: Using path indentifier from CREATE in MERGE Given an empty graph diff --git a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature index 23edc69cd..01855e548 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_bfs.feature @@ -249,3 +249,15 @@ Feature: Bfs Then the result should be: | path | | <(:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})> | + + Scenario: Test BFS variable expand with already processed vertex and loop with filter by path + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label1)-[*BFS 1..1 (e, n, p | True)]-() RETURN path; + """ + Then the result should be: + | path | + | < (:label1 {id: 1})-[:type3 {id: 20}]->(:label5 {id: 5}) > | + | < (:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3}) > | + | < (:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2}) > | diff --git 
a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature index afd484696..a160e471a 100644 --- a/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature +++ b/tests/gql_behave/tests/memgraph_V1/features/memgraph_wshortest.feature @@ -269,3 +269,15 @@ Feature: Weighted Shortest Path Then the result should be: | path | total_weight | | <(:station {arrival: 08:00:00.000000000, departure: 08:15:00.000000000, name: 'A'})-[:ride {duration: PT1H5M, id: 1}]->(:station {arrival: 09:20:00.000000000, departure: 09:30:00.000000000, name: 'B'})-[:ride {duration: PT30M, id: 2}]->(:station {arrival: 10:00:00.000000000, departure: 10:20:00.000000000, name: 'C'})> | PT2H20M | + + Scenario: Test wShortest variable expand with already processed vertex and loop with filter by path + Given graph "graph_edges" + When executing query: + """ + MATCH path=(:label1)-[*WSHORTEST ..1 (r, n | r.id) total_weight (e, n, p | True)]-() RETURN path; + """ + Then the result should be: + | path | + | < (:label1 {id: 1})-[:type3 {id: 20}]->(:label5 {id: 5}) > | + | < (:label1 {id: 1})-[:type2 {id: 10}]->(:label3 {id: 3}) > | + | < (:label1 {id: 1})-[:type1 {id: 1}]->(:label2 {id: 2}) > | diff --git a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher index 3657b855b..1aa081cc1 100644 --- a/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher +++ b/tests/gql_behave/tests/memgraph_V1/graphs/graph_edges.cypher @@ -1,3 +1,4 @@ CREATE (:label1 {id: 1})-[:type1 {id:1}]->(:label2 {id: 2})-[:type1 {id: 2}]->(:label3 {id: 3})-[:type1 {id: 3}]->(:label4 {id: 4}); MATCH (n :label1), (m :label3) CREATE (n)-[:type2 {id: 10}]->(m); MATCH (n :label1) CREATE (n)-[:type3 {id: 20}]->(:label5 { id: 5 }); +MATCH (n :label1) CREATE (n)-[:same {id: 30}]->(n); diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py 
index a74b19a4e..6dcd42f38 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -193,12 +193,12 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): "GRANT DATABASE db2 TO user", "CREATE USER useR2 IDENTIFIED BY 'user'", "GRANT DATABASE db2 TO user2", - "REVOKE DATABASE memgraph FROM user2", + "DENY DATABASE memgraph FROM user2", "SET MAIN DATABASE db2 FOR user2", "CREATE USER user3 IDENTIFIED BY 'user'", "GRANT ALL PRIVILEGES TO user3", "GRANT DATABASE * TO user3", - "REVOKE DATABASE memgraph FROM user3", + "DENY DATABASE memgraph FROM user3", ] ) diff --git a/tests/integration/ldap/runner.py b/tests/integration/ldap/runner.py index 8fc3af913..9e1a20f71 100755 --- a/tests/integration/ldap/runner.py +++ b/tests/integration/ldap/runner.py @@ -139,7 +139,7 @@ class Memgraph: def initialize_test(memgraph, tester_binary, **kwargs): memgraph.start(module_executable="") - execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) + execute_tester(tester_binary, ["CREATE ROLE root_role", "GRANT ALL PRIVILEGES TO root_role"]) check_login = kwargs.pop("check_login", True) memgraph.restart(**kwargs) if check_login: @@ -149,20 +149,24 @@ def initialize_test(memgraph, tester_binary, **kwargs): # Tests -def test_basic(memgraph, tester_binary): +def test_module_ux(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + execute_tester(tester_binary, ["CREATE USER user1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE role1"], "root", query_should_fail=False) + execute_tester(tester_binary, ["DROP USER user1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["DROP ROLE role1"], "root", query_should_fail=False) + execute_tester(tester_binary, ["SET ROLE FOR user1 TO 
role1"], "root", query_should_fail=True) + execute_tester(tester_binary, ["CLEAR ROLE FOR user1"], "root", query_should_fail=True) memgraph.stop() -def test_only_existing_users(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, create_missing_user=False) +def test_user_auth(memgraph, tester_binary): + initialize_test(memgraph, tester_binary) execute_tester(tester_binary, [], "alice", auth_should_fail=True) - execute_tester(tester_binary, ["CREATE USER alice"], "root") + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["GRANT MATCH TO alice"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() @@ -170,77 +174,50 @@ def test_only_existing_users(memgraph, tester_binary): def test_role_mapping(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") + execute_tester(tester_binary, [], "alice", auth_should_fail=True) + execute_tester(tester_binary, [], "bob", auth_should_fail=True) + execute_tester(tester_binary, [], "carol", auth_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") + execute_tester(tester_binary, ["CREATE ROLE admin"], "root") + execute_tester(tester_binary, [], "alice", auth_should_fail=False) + execute_tester(tester_binary, [], "bob", auth_should_fail=True) + execute_tester(tester_binary, [], "carol", auth_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + execute_tester(tester_binary, 
["MATCH (n) RETURN n"], "alice", query_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=True) + execute_tester(tester_binary, ["GRANT MATCH TO admin"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=False) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "carol", query_should_fail=False) - execute_tester(tester_binary, [], "bob") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "bob", query_should_fail=True) - - execute_tester(tester_binary, [], "carol") + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True) execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=True) execute_tester(tester_binary, ["GRANT CREATE TO admin"], "root") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol") - execute_tester(tester_binary, ["CREATE (n) RETURN n"], "dave") + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["CREATE (n) RETURN n"], "carol", query_should_fail=False) memgraph.stop() +def test_instance_restart(memgraph, tester_binary): + initialize_test(memgraph, tester_binary) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") + execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + memgraph.restart() + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") + memgraph.stop() + + def test_role_removal(memgraph, tester_binary): initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "alice") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) + execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root") execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - 
memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - execute_tester(tester_binary, ["CLEAR ROLE FOR alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.stop() - - -def test_only_existing_roles(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, create_missing_role=False) - execute_tester(tester_binary, [], "bob") + execute_tester(tester_binary, ["DROP ROLE moderator"], "root") execute_tester(tester_binary, [], "alice", auth_should_fail=True) - execute_tester(tester_binary, ["CREATE ROLE moderator"], "root") - execute_tester(tester_binary, [], "alice") - memgraph.stop() - - -def test_role_is_user(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "admin") - execute_tester(tester_binary, [], "carol", auth_should_fail=True) - memgraph.stop() - - -def test_user_is_role(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, [], "carol") - execute_tester(tester_binary, [], "admin", auth_should_fail=True) - memgraph.stop() - - -def test_user_permissions_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_role_permissions_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary) - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_only_authentication(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, manage_roles=False) - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) 
RETURN n"], "alice", query_should_fail=True) memgraph.stop() @@ -258,36 +235,36 @@ def test_wrong_suffix(memgraph, tester_binary): def test_suffix_with_spaces(memgraph, tester_binary): initialize_test(memgraph, tester_binary, suffix=", ou= people, dc = memgraph, dc = com") - execute_tester(tester_binary, ["CREATE USER alice", "GRANT MATCH TO alice"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() - - -def test_role_mapping_wrong_root_dn(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") memgraph.stop() -def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, root_objectclass="person") - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() +# def test_role_mapping_wrong_root_dn(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, root_dn="ou=invalid,dc=memgraph,dc=com") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() -def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, user_attribute="cn") - execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") - 
execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") - memgraph.stop() +# def test_role_mapping_wrong_root_objectclass(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, root_objectclass="person") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() + + +# def test_role_mapping_wrong_user_attribute(memgraph, tester_binary): +# initialize_test(memgraph, tester_binary, user_attribute="cn") +# execute_tester(tester_binary, ["CREATE ROLE moderator", "GRANT MATCH TO moderator"], "root") +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice", query_should_fail=True) +# memgraph.restart() +# execute_tester(tester_binary, ["MATCH (n) RETURN n"], "alice") +# memgraph.stop() def test_wrong_password(memgraph, tester_binary): @@ -297,31 +274,9 @@ def test_wrong_password(memgraph, tester_binary): memgraph.stop() -def test_password_persistancy(memgraph, tester_binary): - initialize_test(memgraph, tester_binary, check_login=False) - memgraph.restart(module_executable="") - execute_tester(tester_binary, ["SHOW USERS"], "root", password="sudo") - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.restart() - execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.restart(module_executable="") - execute_tester(tester_binary, [], "root", password="sudo", auth_should_fail=True) - execute_tester(tester_binary, ["SHOW USERS"], "root", password="root") - memgraph.stop() - - def test_user_multiple_roles(memgraph, tester_binary): - initialize_test(memgraph, 
tester_binary, check_login=False) - memgraph.restart() - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) - memgraph.restart(manage_roles=False) - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) - memgraph.restart(manage_roles=False, root_dn="") - execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", query_should_fail=True) - execute_tester(tester_binary, ["GRANT MATCH TO moderator"], "root", query_should_fail=True) + initialize_test(memgraph, tester_binary) + execute_tester(tester_binary, ["MATCH (n) RETURN n"], "eve", auth_should_fail=True) memgraph.stop() diff --git a/tests/integration/ldap/schema.ldif b/tests/integration/ldap/schema.ldif index f47ca5e8f..730c04415 100644 --- a/tests/integration/ldap/schema.ldif +++ b/tests/integration/ldap/schema.ldif @@ -84,6 +84,13 @@ objectclass: organizationalUnit objectclass: top ou: roles +# Role root +dn: cn=root_role,ou=roles,dc=memgraph,dc=com +cn: root_role +member: cn=root,ou=people,dc=memgraph,dc=com +objectclass: groupOfNames +objectclass: top + # Role moderator dn: cn=moderator,ou=roles,dc=memgraph,dc=com cn: moderator diff --git a/tests/integration/ldap/tester.cpp b/tests/integration/ldap/tester.cpp index 8f79938c7..fc8acfd82 100644 --- a/tests/integration/ldap/tester.cpp +++ b/tests/integration/ldap/tester.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -48,6 +48,7 @@ int main(int argc, char **argv) { } if (FLAGS_auth_should_fail) { MG_ASSERT(!what.empty(), "The authentication should have failed!"); + return 0; // Auth failed, nothing left to do } else { MG_ASSERT(what.empty(), "The authentication should have succeeded, but " diff --git a/tests/macro_benchmark/clients/pokec_client.cpp b/tests/macro_benchmark/clients/pokec_client.cpp index ba6f96941..40854707e 100644 --- a/tests/macro_benchmark/clients/pokec_client.cpp +++ b/tests/macro_benchmark/clients/pokec_client.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -24,11 +24,13 @@ #include #include +#include "communication/bolt/v1/value.hpp" #include "io/network/utils.hpp" +#include "long_running_common.hpp" #include "utils/algorithm.hpp" #include "utils/timer.hpp" -#include "long_running_common.hpp" +#include "communication/bolt/v1/fmt.hpp" using memgraph::communication::bolt::Edge; using memgraph::communication::bolt::Value; diff --git a/tests/manual/interactive_planning.cpp b/tests/manual/interactive_planning.cpp index f550b9724..3f64c4f37 100644 --- a/tests/manual/interactive_planning.cpp +++ b/tests/manual/interactive_planning.cpp @@ -27,6 +27,7 @@ #include "query/plan/planner.hpp" #include "query/plan/pretty_print.hpp" #include "query/typed_value.hpp" +#include "storage/v2/fmt.hpp" #include "storage/v2/property_value.hpp" #include "utils/string.hpp" diff --git a/tests/manual/query_hash.cpp b/tests/manual/query_hash.cpp index 8688da351..fb16f6db5 100644 --- a/tests/manual/query_hash.cpp +++ b/tests/manual/query_hash.cpp @@ -1,4 +1,4 @@ -// 
Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -15,6 +15,7 @@ #include #include "query/frontend/stripped.hpp" +#include "storage/v2/fmt.hpp" DEFINE_string(q, "CREATE (n) RETURN n", "Query"); diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index f2ce9c572..32e093a7a 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -50,6 +50,8 @@ int main(int argc, char *argv[]) { memgraph::query::Interpreter interpreter{&interpreter_context, db_acc}; ResultStreamFaker stream(db_acc->storage()); + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); auto [header, _1, qid, _2] = interpreter.Prepare(argv[1], {}, {}); stream.Header(header); auto summary = interpreter.PullAll(&stream); diff --git a/tests/unit/auth.cpp b/tests/unit/auth.cpp index bc2947a12..8bac5a05b 100644 --- a/tests/unit/auth.cpp +++ b/tests/unit/auth.cpp @@ -280,6 +280,8 @@ TEST_F(AuthWithStorage, RoleManipulations) { } { + const auto all = auth->AllUsernames(); + for (const auto &user : all) std::cout << user << std::endl; auto users = auth->AllUsers(); std::sort(users.begin(), users.end(), [](const User &a, const User &b) { return a.username() < b.username(); }); ASSERT_EQ(users.size(), 2); @@ -774,14 +776,16 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // Authenticate { - auto user = auth->Authenticate("alice", "alice"); - ASSERT_TRUE(user); - ASSERT_EQ(user->username(), "alice"); + auto user_or_role = auth->Authenticate("alice", "alice"); + ASSERT_TRUE(user_or_role); + const auto &user = std::get(*user_or_role); + ASSERT_EQ(user.username(), "alice"); } { - auto user = auth->Authenticate("alICe", "alice"); - ASSERT_TRUE(user); - ASSERT_EQ(user->username(), "alice"); + 
auto user_or_role = auth->Authenticate("alICe", "alice"); + ASSERT_TRUE(user_or_role); + const auto &user = std::get(*user_or_role); + ASSERT_EQ(user.username(), "alice"); } // GetUser @@ -809,6 +813,8 @@ TEST_F(AuthWithStorage, CaseInsensitivity) { // AllUsers { + const auto all = auth->AllUsernames(); + for (const auto &user : all) std::cout << user << std::endl; auto users = auth->AllUsers(); ASSERT_EQ(users.size(), 2); std::sort(users.begin(), users.end(), [](const auto &a, const auto &b) { return a.username() < b.username(); }); diff --git a/tests/unit/auth_checker.cpp b/tests/unit/auth_checker.cpp index f4c499cd7..50bec4cbc 100644 --- a/tests/unit/auth_checker.cpp +++ b/tests/unit/auth_checker.cpp @@ -12,11 +12,14 @@ #include #include +#include "auth/exceptions.hpp" #include "auth/models.hpp" #include "disk_test_utils.hpp" #include "glue/auth_checker.hpp" #include "license/license.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/query_user.hpp" #include "query_plan_common.hpp" #include "storage/v2/config.hpp" #include "storage/v2/disk/storage.hpp" @@ -225,4 +228,123 @@ TYPED_TEST(FineGrainedAuthCheckerFixture, GrantAndDenySpecificEdgeTypes) { ASSERT_FALSE(auth_checker.Has(this->r3, memgraph::query::AuthQuery::FineGrainedPrivilege::READ)); ASSERT_FALSE(auth_checker.Has(this->r4, memgraph::query::AuthQuery::FineGrainedPrivilege::READ)); } + +TEST(AuthChecker, Generate) { + std::filesystem::path auth_dir{std::filesystem::temp_directory_path() / "MG_auth_checker"}; + memgraph::utils::OnScopeExit clean([&]() { + if (std::filesystem::exists(auth_dir)) { + std::filesystem::remove_all(auth_dir); + } + }); + memgraph::auth::SynchedAuth auth(auth_dir, memgraph::auth::Auth::Config{/* default config */}); + memgraph::glue::AuthChecker auth_checker(&auth); + + auto empty_user = auth_checker.GenQueryUser(std::nullopt, std::nullopt); + ASSERT_THROW(auth_checker.GenQueryUser("does_not_exist", std::nullopt), memgraph::auth::AuthException); + + 
EXPECT_FALSE(empty_user && *empty_user); + // Still empty auth, so the above should have su permissions + using enum memgraph::query::AuthQuery::Privilege; + EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy)); + EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph", + &memgraph::query::session_long_policy)); + EXPECT_TRUE( + empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy)); + + // Add user + auth->AddUser("new_user"); + + // ~Empty user should now fail~ + // NOTE: Cache invalidation has been disabled, so this will pass; change if it is ever turned on + EXPECT_TRUE(empty_user->IsAuthorized({AUTH, REMOVE, REPLICATION}, "", &memgraph::query::session_long_policy)); + EXPECT_TRUE(empty_user->IsAuthorized({FREE_MEMORY, WEBSOCKET, MULTI_DATABASE_EDIT}, "memgraph", + &memgraph::query::session_long_policy)); + EXPECT_TRUE( + empty_user->IsAuthorized({TRIGGER, DURABILITY, STORAGE_MODE}, "some_db", &memgraph::query::session_long_policy)); + + // Add role and new user + auto new_role = *auth->AddRole("new_role"); + auto new_user2 = *auth->AddUser("new_user2"); + auto role = auth_checker.GenQueryUser("anyuser", "new_role"); + auto user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + + // Should be permission-less by default + EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); 
+ + // Update permissions and recheck + new_user2.permissions().Grant(memgraph::auth::Permission::AUTH); + new_role.permissions().Grant(memgraph::auth::Permission::TRIGGER); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + role = auth_checker.GenQueryUser("no check", "new_role"); + user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + EXPECT_FALSE(role->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + + // Connect role and recheck + new_user2.SetRole(new_role); + auth->SaveUser(new_user2); + user2 = auth_checker.GenQueryUser("new_user2", std::nullopt); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({FREE_MEMORY}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + + // Add database and recheck + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", 
&memgraph::query::session_long_policy)); + new_user2.db_access().Grant("another"); + new_role.db_access().Grant("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + // Session policy test + // Session long policy + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::session_long_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::session_long_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::session_long_policy)); + // Up to date policy + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); + + new_user2.db_access().Deny("memgraph"); + new_role.db_access().Deny("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", 
&memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); + + new_user2.db_access().Revoke("memgraph"); + new_role.db_access().Revoke("non_default"); + auth->SaveUser(new_user2); + auth->SaveRole(new_role); + EXPECT_FALSE(user2->IsAuthorized({AUTH}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(user2->IsAuthorized({AUTH}, "memgraph", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "non_default", &memgraph::query::up_to_date_policy)); + EXPECT_FALSE(role->IsAuthorized({TRIGGER}, "another", &memgraph::query::up_to_date_policy)); + EXPECT_TRUE(role->IsAuthorized({TRIGGER}, "memgraph", &memgraph::query::up_to_date_policy)); +} #endif diff --git a/tests/unit/database_get_info.cpp b/tests/unit/database_get_info.cpp index a8a275a61..be6885b37 100644 --- a/tests/unit/database_get_info.cpp +++ b/tests/unit/database_get_info.cpp @@ -15,6 +15,7 @@ #include #include "dbms/database.hpp" +#include "dbms/dbms_handler.hpp" #include "disk_test_utils.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" #include "query/interpreter_context.hpp" @@ -30,15 +31,31 @@ using namespace memgraph::storage; constexpr auto testSuite = "database_v2_get_info"; const std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / testSuite}; -template +struct TestConfig {}; +struct DefaultConfig : TestConfig {}; +struct TenantConfig : TestConfig {}; + +template class InfoTest : public testing::Test { + using StorageType = typename TestType::first_type; + using ConfigType = typename TestType::second_type; + protected: void SetUp() { - repl_state.emplace(memgraph::storage::ReplicationStateRootPath(config)); - db_gk.emplace(config, *repl_state); - auto db_acc_opt = db_gk->access(); - MG_ASSERT(db_acc_opt, "Failed to access db"); - auto &db_acc = 
*db_acc_opt; + repl_state_.emplace(ReplicationStateRootPath(config)); +#ifdef MG_ENTERPRISE + dbms_handler_.emplace(config, *repl_state_, auth_, false); + auto db_acc = dbms_handler_->Get(); // Default db + if (std::is_same_v) { + constexpr std::string_view db_name = "test_db"; + MG_ASSERT(dbms_handler_->New(std::string{db_name}).HasValue(), "Failed to create database."); + db_acc = dbms_handler_->Get(db_name); + } +#else + dbms_handler_.emplace(config, *repl_state_); + auto db_acc = dbms_handler_->Get(); +#endif + MG_ASSERT(db_acc, "Failed to access db"); MG_ASSERT(db_acc->GetStorageMode() == (std::is_same_v ? memgraph::storage::StorageMode::ON_DISK_TRANSACTIONAL : memgraph::storage::StorageMode::IN_MEMORY_TRANSACTIONAL), @@ -48,8 +65,8 @@ class InfoTest : public testing::Test { void TearDown() { db_acc_.reset(); - db_gk.reset(); - repl_state.reset(); + dbms_handler_.reset(); + repl_state_.reset(); if (std::is_same::value) { disk_test_utils::RemoveRocksDbDirs(testSuite); } @@ -59,9 +76,9 @@ class InfoTest : public testing::Test { StorageMode mode{std::is_same_v ? 
StorageMode::ON_DISK_TRANSACTIONAL : StorageMode::IN_MEMORY_TRANSACTIONAL}; - std::optional repl_state; - std::optional db_acc_; - std::optional> db_gk; +#ifdef MG_ENTERPRISE + memgraph::auth::SynchedAuth auth_{storage_directory, memgraph::auth::Auth::Config {}}; +#endif memgraph::storage::Config config{ [&]() { memgraph::storage::Config config{}; @@ -69,18 +86,27 @@ class InfoTest : public testing::Test { config.durability.snapshot_wal_mode = memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL; if constexpr (std::is_same_v) { - config.disk = disk_test_utils::GenerateOnDiskConfig(testSuite).disk; config.force_on_disk = true; } return config; }() // iile }; + std::optional repl_state_; + std::optional dbms_handler_; + std::optional db_acc_; }; -using StorageTypes = ::testing::Types; +using TestTypes = ::testing::Types, + std::pair -TYPED_TEST_CASE(InfoTest, StorageTypes); -// TYPED_TEST_CASE(IndexTest, InMemoryStorageType); +#ifdef MG_ENTERPRISE + , + std::pair, + std::pair +#endif + >; + +TYPED_TEST_CASE(InfoTest, TestTypes); // NOLINTNEXTLINE(hicpp-special-member-functions) TYPED_TEST(InfoTest, InfoCheck) { @@ -166,13 +192,13 @@ TYPED_TEST(InfoTest, InfoCheck) { } const auto &info = db_acc->GetInfo( - true, memgraph::replication_coordination_glue::ReplicationRole::MAIN); // force to use configured directory + memgraph::replication_coordination_glue::ReplicationRole::MAIN); // force to use configured directory ASSERT_EQ(info.storage_info.vertex_count, 5); ASSERT_EQ(info.storage_info.edge_count, 2); ASSERT_EQ(info.storage_info.average_degree, 0.8); - ASSERT_GT(info.storage_info.memory_res, 10'000'000); // 200MB < > 10MB - ASSERT_LT(info.storage_info.memory_res, 200'000'000); + ASSERT_GT(info.storage_info.memory_res, 10'000'000); // 250MB < > 10MB + ASSERT_LT(info.storage_info.memory_res, 250'000'000); ASSERT_GT(info.storage_info.disk_usage, 100); // 1MB < > 100B ASSERT_LT(info.storage_info.disk_usage, 1000'000); 
ASSERT_EQ(info.storage_info.label_indices, 1); diff --git a/tests/unit/integrations_kafka_consumer.cpp b/tests/unit/integrations_kafka_consumer.cpp index 3d5feb80b..2265aa310 100644 --- a/tests/unit/integrations_kafka_consumer.cpp +++ b/tests/unit/integrations_kafka_consumer.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -113,7 +113,6 @@ struct ConsumerTest : public ::testing::Test { void SeedTopicWithInt(const std::string &topic_name, int value) { std::array int_as_char{}; std::memcpy(int_as_char.data(), &value, int_as_char.size()); - cluster.SeedTopic(topic_name, int_as_char); } diff --git a/tests/unit/interpreter_faker.hpp b/tests/unit/interpreter_faker.hpp index 3b6075911..c1e3b4b06 100644 --- a/tests/unit/interpreter_faker.hpp +++ b/tests/unit/interpreter_faker.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ struct InterpreterFaker { : interpreter_context(interpreter_context), interpreter(interpreter_context, db) { interpreter_context->auth_checker = &auth_checker; interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); }); + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); } auto Prepare(const std::string &query, const std::map ¶ms = {}) { diff --git a/tests/unit/kafka_mock.cpp b/tests/unit/kafka_mock.cpp index 7cf788479..0ea9bcac4 100644 --- a/tests/unit/kafka_mock.cpp +++ b/tests/unit/kafka_mock.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -78,10 +78,6 @@ void KafkaClusterMock::CreateTopic(const std::string &topic_name) { } } -void KafkaClusterMock::SeedTopic(const std::string &topic_name, std::string_view message) { - SeedTopic(topic_name, std::span{message.data(), message.size()}); -} - void KafkaClusterMock::SeedTopic(const std::string &topic_name, std::span message) { char errstr[256] = {'\0'}; std::string bootstraps_servers = Bootstraps(); diff --git a/tests/unit/kafka_mock.hpp b/tests/unit/kafka_mock.hpp index fce563fda..5905d281f 100644 --- a/tests/unit/kafka_mock.hpp +++ b/tests/unit/kafka_mock.hpp @@ -1,4 +1,4 @@ -// Copyright 2021 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -41,7 +41,6 @@ class KafkaClusterMock { std::string Bootstraps() const; void CreateTopic(const std::string &topic_name); void SeedTopic(const std::string &topic_name, std::span message); - void SeedTopic(const std::string &topic_name, std::string_view message); private: RdKafkaUniquePtr rk_{nullptr}; diff --git a/tests/unit/monitoring.cpp b/tests/unit/monitoring.cpp index e04e091e5..26dc6ad47 100644 --- a/tests/unit/monitoring.cpp +++ b/tests/unit/monitoring.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -44,11 +44,9 @@ struct MockAuth : public memgraph::communication::websocket::AuthenticationInter return authentication; } - bool HasUserPermission(const std::string & /*username*/, memgraph::auth::Permission /*permission*/) const override { - return authorization; - } + bool HasPermission(memgraph::auth::Permission /*permission*/) const override { return authorization; } - bool HasAnyUsers() const override { return has_any_users; } + bool AccessControlled() const override { return has_any_users; } bool authentication{true}; bool authorization{true}; diff --git a/tests/unit/query_dump.cpp b/tests/unit/query_dump.cpp index 5ecf598b2..a2ca2864d 100644 --- a/tests/unit/query_dump.cpp +++ b/tests/unit/query_dump.cpp @@ -21,6 +21,8 @@ #include "communication/result_stream_faker.hpp" #include "dbms/database.hpp" #include "disk_test_utils.hpp" +#include "glue/auth_checker.hpp" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/dump.hpp" #include "query/interpreter.hpp" @@ -216,6 +218,8 @@ DatabaseState GetState(memgraph::storage::Storage *db) { auto Execute(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db, const std::string &query) { memgraph::query::Interpreter interpreter(context, db); + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); ResultStreamFaker stream(db->storage()); auto [header, _1, qid, _2] = interpreter.Prepare(query, {}, {}); @@ -915,7 +919,10 @@ TYPED_TEST(DumpTest, ExecuteDumpDatabase) { class StatefulInterpreter { public: explicit StatefulInterpreter(memgraph::query::InterpreterContext *context, memgraph::dbms::DatabaseAccess db) - : context_(context), interpreter_(context_, db) {} + : context_(context), 
interpreter_(context_, db) { + memgraph::query::AllowEverythingAuthChecker auth_checker; + interpreter_.SetUser(auth_checker.GenQueryUser(std::nullopt, std::nullopt)); + } auto Execute(const std::string &query) { ResultStreamFaker stream(interpreter_.current_db_.db_acc_->get()->storage()); @@ -1138,7 +1145,7 @@ TYPED_TEST(DumpTest, DumpDatabaseWithTriggers) { memgraph::query::DbAccessor dba(acc.get()); const std::map props; trigger_store->AddTrigger(trigger_name, trigger_statement, props, trigger_event_type, trigger_phase, &ast_cache, - &dba, query_config, std::nullopt, &auth_checker); + &dba, query_config, auth_checker.GenQueryUser(std::nullopt, std::nullopt)); } { ResultStreamFaker stream(this->db->storage()); diff --git a/tests/unit/query_plan_edge_cases.cpp b/tests/unit/query_plan_edge_cases.cpp index ac04cabdd..262ebd4e1 100644 --- a/tests/unit/query_plan_edge_cases.cpp +++ b/tests/unit/query_plan_edge_cases.cpp @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "communication/result_stream_faker.hpp" +#include "query/auth_checker.hpp" #include "query/interpreter.hpp" #include "query/interpreter_context.hpp" #include "query/stream/streams.hpp" @@ -36,6 +37,7 @@ class QueryExecution : public testing::Test { const std::string testSuite = "query_plan_edge_cases"; std::optional db_acc_; std::optional interpreter_context_; + std::optional auth_checker_; std::optional interpreter_; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_query_plan_edge_cases"}; @@ -73,11 +75,14 @@ class QueryExecution : public testing::Test { nullptr #endif ); + auth_checker_.emplace(); interpreter_.emplace(&*interpreter_context_, *db_acc_); + interpreter_->SetUser(auth_checker_->GenQueryUser(std::nullopt, std::nullopt)); } void TearDown() override { interpreter_ = std::nullopt; + auth_checker_.reset(); interpreter_context_ = std::nullopt; system_state.reset(); db_acc_.reset(); diff --git a/tests/unit/query_plan_operator_to_string.cpp 
b/tests/unit/query_plan_operator_to_string.cpp index 694552cf0..9696050f2 100644 --- a/tests/unit/query_plan_operator_to_string.cpp +++ b/tests/unit/query_plan_operator_to_string.cpp @@ -214,23 +214,22 @@ TYPED_TEST(OperatorToStringTest, Filter) { auto node_ident = IDENT("person"); auto property = this->dba.NameToProperty("name"); auto property_ix = this->storage.GetPropertyIx("name"); - - FilterInfo generic_filter_info = {.type = FilterInfo::Type::Generic, .used_symbols = {node}}; + auto generic_filter_info = FilterInfo{FilterInfo::Type::Generic, nullptr, {node}}; auto id_filter = IdFilter(this->symbol_table, node, LITERAL(42)); - FilterInfo id_filter_info = {.type = FilterInfo::Type::Id, .id_filter = id_filter}; + auto id_filter_info = FilterInfo{FilterInfo::Type::Id, nullptr, {}, {}, id_filter}; std::vector labels{this->storage.GetLabelIx("Customer"), this->storage.GetLabelIx("Visitor")}; auto labels_test = LABELS_TEST(node_ident, labels); - FilterInfo label_filter_info = {.type = FilterInfo::Type::Label, .expression = labels_test}; + auto label_filter_info = FilterInfo{FilterInfo::Type::Label, labels_test}; auto labels_test_2 = LABELS_TEST(PROPERTY_LOOKUP(this->dba, "person", property), labels); - FilterInfo label_filter_2_info = {.type = FilterInfo::Type::Label, .expression = labels_test_2}; + auto label_filter_2_info = FilterInfo{FilterInfo::Type::Label, labels_test_2}; auto property_filter = PropertyFilter(node, property_ix, PropertyFilter::Type::EQUAL); - FilterInfo property_filter_info = {.type = FilterInfo::Type::Property, .property_filter = property_filter}; + auto property_filter_info = FilterInfo{FilterInfo::Type::Property, nullptr, {}, property_filter}; - FilterInfo pattern_filter_info = {.type = FilterInfo::Type::Pattern}; + auto pattern_filter_info = FilterInfo{FilterInfo::Type::Pattern}; Filters filters; filters.SetFilters({generic_filter_info, id_filter_info, label_filter_info, label_filter_2_info, property_filter_info, diff --git 
a/tests/unit/query_streams.cpp b/tests/unit/query_streams.cpp index cde3d937a..5b246468f 100644 --- a/tests/unit/query_streams.cpp +++ b/tests/unit/query_streams.cpp @@ -20,9 +20,11 @@ #include "integrations/constants.hpp" #include "integrations/kafka/exceptions.hpp" #include "kafka_mock.hpp" +#include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/interpreter.hpp" #include "query/interpreter_context.hpp" +#include "query/query_user.hpp" #include "query/stream/streams.hpp" #include "storage/v2/config.hpp" #include "storage/v2/disk/storage.hpp" @@ -35,11 +37,23 @@ using StreamStatus = memgraph::query::stream::StreamStatus &privileges, const std::string &db_name, + memgraph::query::UserPolicy *policy) const { + return true; + } +#ifdef MG_ENTERPRISE + std::string GetDefaultDB() const { return "memgraph"; } +#endif +}; + struct StreamCheckData { std::string name; StreamInfo info; bool is_running; - std::optional owner; + std::shared_ptr owner; }; std::string GetDefaultStreamName() { @@ -105,13 +119,16 @@ class StreamsTestFixture : public ::testing::Test { }() // iile }; memgraph::system::System system_state; - memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, nullptr, &repl_state, - system_state + memgraph::query::AllowEverythingAuthChecker auth_checker; + memgraph::query::InterpreterContext interpreter_context_{memgraph::query::InterpreterConfig{}, + nullptr, + &repl_state, + system_state, #ifdef MG_ENTERPRISE - , - nullptr + nullptr, #endif - }; + nullptr, + &auth_checker}; std::filesystem::path streams_data_directory_{data_directory_ / "separate-dir-for-test"}; std::optional proxyStreams_; @@ -173,7 +190,7 @@ class StreamsTestFixture : public ::testing::Test { } StreamCheckData CreateDefaultStreamCheckData() { - return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::nullopt}; + return {GetDefaultStreamName(), CreateDefaultStreamInfo(), false, std::make_unique()}; } void Clear() { @@ 
-215,11 +232,11 @@ TYPED_TEST(StreamsTestFixture, CreateAlreadyExisting) { auto stream_info = this->CreateDefaultStreamInfo(); auto stream_name = GetDefaultStreamName(); this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); try { this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); FAIL() << "Creating already existing stream should throw\n"; } catch (memgraph::query::stream::StreamsException &exception) { EXPECT_EQ(exception.what(), fmt::format("Stream already exists with name '{}'", stream_name)); @@ -231,7 +248,7 @@ TYPED_TEST(StreamsTestFixture, DropNotExistingStream) { const auto stream_name = GetDefaultStreamName(); const std::string not_existing_stream_name{"ThisDoesn'tExists"}; this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); try { this->proxyStreams_->streams_->Drop(not_existing_stream_name); @@ -262,7 +279,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { if (i > 0) { stream_info.common_info.batch_interval = std::chrono::milliseconds((i + 1) * 10); stream_info.common_info.batch_size = 1000 + i; - stream_check_data.owner = std::string{"owner"} + iteration_postfix; + stream_check_data.owner = std::make_unique(); // These are just random numbers to make the CONFIGS and CREDENTIALS map vary between consumers: // - 0 means no config, no credential @@ -280,7 +297,7 @@ TYPED_TEST(StreamsTestFixture, RestoreStreams) { this->mock_cluster_.CreateTopic(stream_info.topics[0]); } - stream_check_datas[3].owner = {}; + stream_check_datas[3].owner = std::make_unique(); 
const auto check_restore_logic = [&stream_check_datas, this]() { // Reset the Streams object to trigger reloading @@ -336,7 +353,7 @@ TYPED_TEST(StreamsTestFixture, CheckWithTimeout) { const auto stream_info = this->CreateDefaultStreamInfo(); const auto stream_name = GetDefaultStreamName(); this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_); + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_); std::chrono::milliseconds timeout{3000}; @@ -360,9 +377,10 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidConfig) { EXPECT_TRUE(message.find(kInvalidConfigName) != std::string::npos) << message; EXPECT_TRUE(message.find(kConfigValue) != std::string::npos) << message; }; - EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), - memgraph::integrations::kafka::SettingCustomConfigFailed, checker); + EXPECT_THROW_WITH_MSG( + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_), + memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) { @@ -376,7 +394,8 @@ TYPED_TEST(StreamsTestFixture, CheckInvalidCredentials) { EXPECT_TRUE(message.find(memgraph::integrations::kReducted) != std::string::npos) << message; EXPECT_TRUE(message.find(kCredentialValue) == std::string::npos) << message; }; - EXPECT_THROW_WITH_MSG(this->proxyStreams_->streams_->template Create( - stream_name, stream_info, std::nullopt, this->db_, &this->interpreter_context_), - memgraph::integrations::kafka::SettingCustomConfigFailed, checker); + EXPECT_THROW_WITH_MSG( + this->proxyStreams_->streams_->template Create( + stream_name, stream_info, std::make_unique(), this->db_, &this->interpreter_context_), + 
memgraph::integrations::kafka::SettingCustomConfigFailed, checker); } diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index 1b2ca5e9c..06aa1dbd9 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -21,6 +21,7 @@ #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpreter.hpp" +#include "query/query_user.hpp" #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "storage/v2/config.hpp" @@ -42,16 +43,27 @@ const std::unordered_set kAllEventTypes{ class MockAuthChecker : public memgraph::query::AuthChecker { public: - MOCK_CONST_METHOD3(IsUserAuthorized, - bool(const std::optional &username, - const std::vector &privileges, const std::string &db)); + MOCK_CONST_METHOD2(GenQueryUser, + std::shared_ptr(const std::optional &username, + const std::optional &rolename)); #ifdef MG_ENTERPRISE - MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, - std::unique_ptr( - const std::string &username, const memgraph::query::DbAccessor *db_accessor)); + MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr( + std::shared_ptr user, + const memgraph::query::DbAccessor *db_accessor)); MOCK_CONST_METHOD0(ClearCache, void()); #endif }; + +class MockQueryUser : public memgraph::query::QueryUserOrRole { + public: + MockQueryUser(std::optional name) : memgraph::query::QueryUserOrRole(std::move(name), std::nullopt) {} + MOCK_CONST_METHOD3(IsAuthorized, bool(const std::vector &privileges, + const std::string &db_name, memgraph::query::UserPolicy *policy)); + +#ifdef MG_ENTERPRISE + MOCK_CONST_METHOD0(GetDefaultDB, std::string()); +#endif +}; } // namespace const std::string testSuite = "query_trigger"; @@ -966,12 +978,12 @@ TYPED_TEST(TriggerStoreTest, Restore) { trigger_name_before, trigger_statement, std::map{{"parameter", memgraph::storage::PropertyValue{1}}}, event_type, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - 
memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); store->AddTrigger( trigger_name_after, trigger_statement, std::map{{"parameter", memgraph::storage::PropertyValue{"value"}}}, event_type, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, {owner}, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, this->auth_checker.GenQueryUser(owner, std::nullopt)); const auto check_triggers = [&] { ASSERT_EQ(store->GetTriggerInfo().size(), 2); @@ -981,9 +993,9 @@ TYPED_TEST(TriggerStoreTest, Restore) { ASSERT_EQ(trigger.OriginalStatement(), trigger_statement); ASSERT_EQ(trigger.EventType(), event_type); if (owner != nullptr) { - ASSERT_EQ(*trigger.Owner(), *owner); + ASSERT_EQ(trigger.Owner()->username(), *owner); } else { - ASSERT_FALSE(trigger.Owner().has_value()); + ASSERT_FALSE(trigger.Owner()->username()); } }; @@ -1022,32 +1034,38 @@ TYPED_TEST(TriggerStoreTest, AddTrigger) { // Invalid query in statements ASSERT_THROW(store.AddTrigger("trigger", "RETUR 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN createdEdges", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); 
ASSERT_THROW(store.AddTrigger("trigger", "RETURN $parameter", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_NO_THROW(store.AddTrigger( "trigger", "RETURN $parameter", std::map{{"parameter", memgraph::storage::PropertyValue{1}}}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, - &*this->dba, memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker)); + &*this->dba, memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt))); // Inserting with the same name ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_THROW(store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker), + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)), memgraph::utils::BasicException); ASSERT_EQ(store.GetTriggerInfo().size(), 1); @@ -1063,7 +1081,8 @@ TYPED_TEST(TriggerStoreTest, DropTrigger) { const auto *trigger_name = "trigger"; store.AddTrigger(trigger_name, "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, 
memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); ASSERT_THROW(store.DropTrigger("Unknown"), memgraph::utils::BasicException); ASSERT_NO_THROW(store.DropTrigger(trigger_name)); @@ -1076,7 +1095,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { std::vector expected_info; store.AddTrigger("trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::VERTEX_CREATE, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); expected_info.push_back({"trigger", "RETURN 1", memgraph::query::TriggerEventType::VERTEX_CREATE, @@ -1099,7 +1119,8 @@ TYPED_TEST(TriggerStoreTest, TriggerInfo) { store.AddTrigger("edge_update_trigger", "RETURN 1", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt)); expected_info.push_back({"edge_update_trigger", "RETURN 1", memgraph::query::TriggerEventType::EDGE_UPDATE, @@ -1216,7 +1237,8 @@ TYPED_TEST(TriggerStoreTest, AnyTriggerAllKeywords) { SCOPED_TRACE(keyword); EXPECT_NO_THROW(store.AddTrigger(trigger_name, fmt::format("RETURN {}", keyword), {}, event_type, memgraph::query::TriggerPhase::BEFORE_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &this->auth_checker)); + memgraph::query::InterpreterConfig::Query{}, + this->auth_checker.GenQueryUser(std::nullopt, std::nullopt))); 
store.DropTrigger(trigger_name); } } @@ -1228,45 +1250,50 @@ TYPED_TEST(TriggerStoreTest, AuthCheckerUsage) { using ::testing::ElementsAre; using ::testing::Return; std::optional store{this->testing_directory}; - const std::optional owner{"testing_owner"}; MockAuthChecker mock_checker; + const std::optional owner{"mock_user"}; + MockQueryUser mock_user(owner); + std::shared_ptr mock_user_ptr( + &mock_user, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ }); + MockQueryUser mock_userless(std::nullopt); + std::shared_ptr mock_userless_ptr( + &mock_userless, [](memgraph::query::QueryUserOrRole *) { /* do nothing */ }); ::testing::InSequence s; - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) - .Times(1) + // TODO Userless + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(true)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) - .Times(1) - .WillOnce(Return(true)); - ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_1", "CREATE (n:VERTEX) RETURN n", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker)); + memgraph::query::InterpreterConfig::Query{}, mock_user_ptr)); + EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) + .WillOnce(Return(true)); ASSERT_NO_THROW(store->AddTrigger("successfull_trigger_2", "CREATE (n:VERTEX) RETURN n", {}, memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, owner, &mock_checker)); + memgraph::query::InterpreterConfig::Query{}, mock_userless_ptr)); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, 
ElementsAre(Privilege::MATCH), "")) - .Times(1) + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::MATCH), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(false)); + ASSERT_THROW( + store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {}, + memgraph::query::TriggerEventType::EDGE_UPDATE, memgraph::query::TriggerPhase::AFTER_COMMIT, + &this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, mock_user_ptr); + , memgraph::utils::BasicException); - ASSERT_THROW(store->AddTrigger("unprivileged_trigger", "MATCH (n:VERTEX) RETURN n", {}, - memgraph::query::TriggerEventType::EDGE_UPDATE, - memgraph::query::TriggerPhase::AFTER_COMMIT, &this->ast_cache, &*this->dba, - memgraph::query::InterpreterConfig::Query{}, std::nullopt, &mock_checker); - , memgraph::utils::BasicException); - + // Restore store.emplace(this->testing_directory); - EXPECT_CALL(mock_checker, IsUserAuthorized(std::optional{}, ElementsAre(Privilege::CREATE), "")) - .Times(1) - .WillOnce(Return(false)); - EXPECT_CALL(mock_checker, IsUserAuthorized(owner, ElementsAre(Privilege::CREATE), "")) - .Times(1) + + std::optional nopt{}; + EXPECT_CALL(mock_checker, GenQueryUser(owner, nopt)).WillOnce(Return(mock_user_ptr)); + EXPECT_CALL(mock_user, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) .WillOnce(Return(true)); + EXPECT_CALL(mock_checker, GenQueryUser(nopt, nopt)).WillOnce(Return(mock_userless_ptr)); + EXPECT_CALL(mock_userless, IsAuthorized(ElementsAre(Privilege::CREATE), "", &memgraph::query::up_to_date_policy)) + .WillOnce(Return(false)); ASSERT_NO_THROW(store->RestoreTriggers(&this->ast_cache, &*this->dba, memgraph::query::InterpreterConfig::Query{}, &mock_checker)); diff --git a/tests/unit/storage_v2_get_info.cpp b/tests/unit/storage_v2_get_info.cpp index c0f7e2dbc..71dbc1a8d 100644 --- a/tests/unit/storage_v2_get_info.cpp +++ b/tests/unit/storage_v2_get_info.cpp @@ -13,12 +13,11 @@ #include #include -#include 
"disk_test_utils.hpp" +#include "dbms/constants.hpp" #include "storage/v2/disk/storage.hpp" #include "storage/v2/inmemory/storage.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/storage.hpp" -#include "storage/v2/storage_error.hpp" // NOLINTNEXTLINE(google-build-using-namespace) using namespace memgraph::storage; @@ -31,6 +30,7 @@ class InfoTest : public testing::Test { protected: void SetUp() override { std::filesystem::remove_all(storage_directory); + config_.salient.name = memgraph::dbms::kDefaultDB; memgraph::storage::UpdatePaths(config_, storage_directory); config_.durability.snapshot_wal_mode = memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL; @@ -135,7 +135,7 @@ TYPED_TEST(InfoTest, InfoCheck) { ASSERT_FALSE(unique_acc->Commit().HasError()); } - StorageInfo info = this->storage->GetInfo(true, ReplicationRole::MAIN); // force to use configured directory + StorageInfo info = this->storage->GetInfo(ReplicationRole::MAIN); ASSERT_EQ(info.vertex_count, 5); ASSERT_EQ(info.edge_count, 2); diff --git a/tests/unit/storage_v2_replication.cpp b/tests/unit/storage_v2_replication.cpp index c5e1ad543..4ae2101cb 100644 --- a/tests/unit/storage_v2_replication.cpp +++ b/tests/unit/storage_v2_replication.cpp @@ -142,21 +142,19 @@ TEST_F(ReplicationTest, BasicSynchronousReplicationTest) { MinMemgraph replica(repl_conf); auto replica_store_handler = replica.repl_handler; - replica_store_handler.SetReplicationRoleReplica( + replica_store_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], }, std::nullopt); - const auto ® = main.repl_handler.TryRegisterReplica( - ReplicationClientConfig{ - .name = "REPLICA", - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true); + const auto ® = main.repl_handler.TryRegisterReplica(ReplicationClientConfig{ + .name = "REPLICA", + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = 
ports[0], + }); ASSERT_FALSE(reg.HasError()) << (int)reg.GetError(); // vertex create @@ -439,13 +437,13 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) { MinMemgraph replica1(repl_conf); MinMemgraph replica2(repl2_conf); - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], }, std::nullopt); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[1], @@ -453,24 +451,20 @@ TEST_F(ReplicationTest, MultipleSynchronousReplicationTest) { std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }) .HasError()); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[1], - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[1], + }) .HasError()); const auto *vertex_label = "label"; @@ -597,21 +591,19 @@ TEST_F(ReplicationTest, RecoveryProcess) { MinMemgraph replica(repl_conf); auto replica_store_handler = replica.repl_handler; - replica_store_handler.SetReplicationRoleReplica( + replica_store_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], }, std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], 
- }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }) .HasError()); ASSERT_EQ(main.db.storage()->GetReplicaState(replicas[0]), ReplicaState::RECOVERY); @@ -676,7 +668,7 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) { MinMemgraph replica_async(repl_conf); auto replica_store_handler = replica_async.repl_handler; - replica_store_handler.SetReplicationRoleReplica( + replica_store_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[1], @@ -684,14 +676,12 @@ TEST_F(ReplicationTest, BasicAsynchronousReplicationTest) { std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = "REPLICA_ASYNC", - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = ports[1], - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = "REPLICA_ASYNC", + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = ports[1], + }) .HasError()); static constexpr size_t vertices_create_num = 10; @@ -726,7 +716,7 @@ TEST_F(ReplicationTest, EpochTest) { MinMemgraph main(main_conf); MinMemgraph replica1(repl_conf); - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], @@ -734,7 +724,7 @@ TEST_F(ReplicationTest, EpochTest) { std::nullopt); MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = 10001, @@ -742,25 +732,21 @@ TEST_F(ReplicationTest, EpochTest) { std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true) + 
.TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }) .HasError()); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = 10001, - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = 10001, + }) .HasError()); std::optional vertex_gid; @@ -789,15 +775,12 @@ TEST_F(ReplicationTest, EpochTest) { ASSERT_TRUE(replica1.repl_handler.SetReplicationRoleMain()); ASSERT_FALSE(replica1.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = 10001, - }, - true) - + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = 10001, + }) .HasError()); { @@ -819,22 +802,19 @@ TEST_F(ReplicationTest, EpochTest) { ASSERT_FALSE(acc->Commit().HasError()); } - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], }, std::nullopt); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true) - + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }) .HasError()); { @@ -858,7 +838,7 @@ TEST_F(ReplicationTest, ReplicationInformation) { MinMemgraph replica1(repl_conf); uint16_t replica1_port = 10001; - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = 
local_host, .port = replica1_port, @@ -867,7 +847,7 @@ TEST_F(ReplicationTest, ReplicationInformation) { uint16_t replica2_port = 10002; MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = replica2_port, @@ -875,27 +855,21 @@ TEST_F(ReplicationTest, ReplicationInformation) { std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = replica1_port, - }, - true) - + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = replica1_port, + }) .HasError()); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = replica2_port, - }, - true) - + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = replica2_port, + }) .HasError()); ASSERT_TRUE(main.repl_state.IsMain()); @@ -923,7 +897,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) { MinMemgraph replica1(repl_conf); uint16_t replica1_port = 10001; - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = replica1_port, @@ -932,32 +906,28 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingName) { uint16_t replica2_port = 10002; MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = replica2_port, }, std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name 
= replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = replica1_port, - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = replica1_port, + }) .HasError()); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = replica2_port, - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = replica2_port, + }) .GetError() == RegisterReplicaError::NAME_EXISTS); } @@ -966,7 +936,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) { MinMemgraph main(main_conf); MinMemgraph replica1(repl_conf); - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = common_port, @@ -974,7 +944,7 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) { std::nullopt); MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = common_port, @@ -982,25 +952,21 @@ TEST_F(ReplicationTest, ReplicationReplicaWithExistingEndPoint) { std::nullopt); ASSERT_FALSE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = common_port, - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = common_port, + }) .HasError()); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::ASYNC, - .ip_address = local_host, - .port = 
common_port, - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::ASYNC, + .ip_address = local_host, + .port = common_port, + }) .GetError() == RegisterReplicaError::ENDPOINT_EXISTS); } @@ -1023,7 +989,7 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) { std::optional main(main_config); MinMemgraph replica1(replica1_config); - replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], @@ -1031,30 +997,26 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartupAfterDroppingReplica) { std::nullopt); MinMemgraph replica2(replica2_config); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[1], }, std::nullopt); - auto res = main->repl_handler.TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true); + auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }); ASSERT_FALSE(res.HasError()) << (int)res.GetError(); - res = main->repl_handler.TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[1], - }, - true); + res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[1], + }); ASSERT_FALSE(res.HasError()) << (int)res.GetError(); auto replica_infos = main->db.storage()->ReplicasInfo(); @@ -1088,7 +1050,7 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) { std::optional main(main_config); MinMemgraph replica1(repl_conf); - 
replica1.repl_handler.SetReplicationRoleReplica( + replica1.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[0], @@ -1097,29 +1059,25 @@ TEST_F(ReplicationTest, RestoringReplicationAtStartup) { MinMemgraph replica2(repl2_conf); - replica2.repl_handler.SetReplicationRoleReplica( + replica2.repl_handler.TrySetReplicationRoleReplica( ReplicationServerConfig{ .ip_address = local_host, .port = ports[1], }, std::nullopt); - auto res = main->repl_handler.TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[0], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true); + auto res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[0], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }); ASSERT_FALSE(res.HasError()); - res = main->repl_handler.TryRegisterReplica( - ReplicationClientConfig{ - .name = replicas[1], - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[1], - }, - true); + res = main->repl_handler.TryRegisterReplica(ReplicationClientConfig{ + .name = replicas[1], + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[1], + }); ASSERT_FALSE(res.HasError()); auto replica_infos = main->db.storage()->ReplicasInfo(); @@ -1157,13 +1115,11 @@ TEST_F(ReplicationTest, AddingInvalidReplica) { MinMemgraph main(main_conf); ASSERT_TRUE(main.repl_handler - .TryRegisterReplica( - ReplicationClientConfig{ - .name = "REPLICA", - .mode = ReplicationMode::SYNC, - .ip_address = local_host, - .port = ports[0], - }, - true) + .TryRegisterReplica(ReplicationClientConfig{ + .name = "REPLICA", + .mode = ReplicationMode::SYNC, + .ip_address = local_host, + .port = ports[0], + }) .GetError() == RegisterReplicaError::ERROR_ACCEPTING_MAIN); } diff --git a/tests/unit/utils_memory_tracker.cpp b/tests/unit/utils_memory_tracker.cpp index cfc9e32b1..5f92b493b 100644 --- 
a/tests/unit/utils_memory_tracker.cpp +++ b/tests/unit/utils_memory_tracker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -36,13 +36,13 @@ TEST(MemoryTrackerTest, ExceptionEnabler) { can_continue = true; }}; - ASSERT_NO_THROW(memory_tracker.Alloc(hard_limit + 1)); + ASSERT_TRUE(memory_tracker.Alloc(hard_limit + 1)); }}; std::thread t2{[&] { memgraph::utils::MemoryTracker::OutOfMemoryExceptionEnabler exception_enabler; enabler_created = true; - ASSERT_THROW(memory_tracker.Alloc(hard_limit + 1), memgraph::utils::OutOfMemoryException); + ASSERT_FALSE(memory_tracker.Alloc(hard_limit + 1)); // hold the enabler until the first thread finishes while (!can_continue) @@ -63,8 +63,8 @@ TEST(MemoryTrackerTest, ExceptionBlocker) { { memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker exception_blocker; - ASSERT_NO_THROW(memory_tracker.Alloc(hard_limit + 1)); + ASSERT_TRUE(memory_tracker.Alloc(hard_limit + 1)); ASSERT_EQ(memory_tracker.Amount(), hard_limit + 1); } - ASSERT_THROW(memory_tracker.Alloc(hard_limit + 1), memgraph::utils::OutOfMemoryException); + ASSERT_FALSE(memory_tracker.Alloc(hard_limit + 1)); }