From c4167bafdd9b529517edcafe95604a71bf20aaa0 Mon Sep 17 00:00:00 2001 From: Jure Bajic <jure.bajic@memgraph.com> Date: Tue, 14 Mar 2023 19:24:55 +0100 Subject: [PATCH 1/6] Add support for Amazon Linux 2 and stop generating C++ using Lisp/LCP (#814) --- .github/workflows/package_all.yaml | 17 + .gitignore | 7 - environment/os/amzn-2.sh | 156 ++ environment/toolchain/v4.sh | 234 +- init | 49 +- release/CMakeLists.txt | 36 +- release/package/amzn-2/Dockerfile | 14 + release/package/docker-compose.yml | 4 + release/package/run.sh | 11 +- src/CMakeLists.txt | 1 - src/query/CMakeLists.txt | 13 +- src/query/frontend/ast/ast.cpp | 263 ++ src/query/frontend/ast/ast.hpp | 3246 ++++++++++++++++++++++++ src/query/frontend/semantic/symbol.cpp | 18 + src/query/frontend/semantic/symbol.hpp | 75 + src/query/plan/operator.hpp | 2296 +++++++++++++++++ src/query/plan/operator_type_info.cpp | 148 ++ src/rpc/client.hpp | 7 +- src/rpc/protocol.cpp | 5 +- src/rpc/server.hpp | 6 +- src/slk/serialization.hpp | 16 +- src/storage/v2/CMakeLists.txt | 9 +- src/storage/v2/replication/.gitignore | 2 - src/storage/v2/replication/rpc.cpp | 263 ++ src/storage/v2/replication/rpc.hpp | 278 ++ src/utils/typeinfo.hpp | 170 +- tests/benchmark/rpc.cpp | 6 +- tests/unit/CMakeLists.txt | 9 - tests/unit/rpc_messages.hpp | 8 +- 29 files changed, 7155 insertions(+), 212 deletions(-) create mode 100755 environment/os/amzn-2.sh create mode 100644 release/package/amzn-2/Dockerfile create mode 100644 src/query/frontend/ast/ast.cpp create mode 100644 src/query/frontend/ast/ast.hpp create mode 100644 src/query/frontend/semantic/symbol.cpp create mode 100644 src/query/frontend/semantic/symbol.hpp create mode 100644 src/query/plan/operator.hpp create mode 100644 src/query/plan/operator_type_info.cpp delete mode 100644 src/storage/v2/replication/.gitignore create mode 100644 src/storage/v2/replication/rpc.cpp create mode 100644 src/storage/v2/replication/rpc.hpp diff --git a/.github/workflows/package_all.yaml b/.github/workflows/package_all.yaml index b0ef18fde..7d225bac7 100644 --- a/.github/workflows/package_all.yaml +++ b/.github/workflows/package_all.yaml @@ -177,6 +177,23 @@ jobs: name: fedora-36 path: build/output/fedora-36/memgraph*.rpm + amzn-2: + runs-on: [self-hosted, DockerMgBuild, X64] + timeout-minutes: 60 + steps: + - name: "Set up repository" + uses: actions/checkout@v3 + with: + fetch-depth: 0 # Required because of release/get_version.py + - name: "Build package" + run: | + ./release/package/run.sh package amzn-2 + - name: "Upload package" + uses: actions/upload-artifact@v3 + with: + name: amzn-2 + path: build/output/amzn-2/memgraph*.rpm + debian-11-arm: runs-on: [self-hosted, DockerMgBuild, ARM64, strange] timeout-minutes: 60 diff --git a/.gitignore b/.gitignore index e1a4187b0..83b9a5dc7 100644 --- a/.gitignore +++ b/.gitignore @@ -34,9 +34,6 @@ TAGS *.fas *.fasl -# LCP generated C++ files -*.lcp.cpp - src/database/distributed/serialization.hpp src/database/single_node_ha/serialization.hpp src/distributed/bfs_rpc_messages.hpp @@ -50,15 +47,11 @@ src/distributed/pull_produce_rpc_messages.hpp src/distributed/storage_gc_rpc_messages.hpp src/distributed/token_sharing_rpc_messages.hpp src/distributed/updates_rpc_messages.hpp -src/query/frontend/ast/ast.hpp -src/query/distributed/frontend/ast/ast_serialization.hpp src/durability/distributed/state_delta.hpp src/durability/single_node/state_delta.hpp src/durability/single_node_ha/state_delta.hpp -src/query/frontend/semantic/symbol.hpp 
 src/query/distributed/frontend/semantic/symbol_serialization.hpp
 src/query/distributed/plan/ops.hpp
-src/query/plan/operator.hpp
 src/raft/log_entry.hpp
 src/raft/raft_rpc_messages.hpp
 src/raft/snapshot_metadata.hpp
diff --git a/environment/os/amzn-2.sh b/environment/os/amzn-2.sh
new file mode 100755
index 000000000..6df12312c
--- /dev/null
+++ b/environment/os/amzn-2.sh
@@ -0,0 +1,149 @@
+#!/bin/bash
+
+set -Eeuo pipefail
+
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+source "$DIR/../util.sh"
+
+check_operating_system "amzn-2"
+check_architecture "x86_64"
+
+TOOLCHAIN_BUILD_DEPS=(
+    gcc gcc-c++ make # generic build tools
+    wget # used for archive download
+    gnupg2 # used for archive signature verification
+    tar gzip bzip2 xz unzip # used for archive unpacking
+    zlib-devel # zlib library used for all builds
+    expat-devel xz-devel python3-devel texinfo
+    curl libcurl-devel # for cmake
+    readline-devel # for cmake and llvm
+    libffi-devel libxml2-devel # for llvm
+    libedit-devel pcre-devel automake bison # for swig
+    file
+    openssl-devel
+    gmp-devel
+    gperf
+    diffutils
+    patch
+    libipt libipt-devel # intel
+    perl # for openssl
+)
+
+TOOLCHAIN_RUN_DEPS=(
+    make # generic build tools
+    tar gzip bzip2 xz # used for archive unpacking
+    zlib # zlib library used for all builds
+    expat xz-libs python3 # for gdb
+    readline # for cmake and llvm
+    libffi libxml2 # for llvm
+    openssl-devel
+)
+
+MEMGRAPH_BUILD_DEPS=(
+    git # source code control
+    make # build system
+    wget # for downloading libs
+    libuuid-devel java-11-openjdk # required by antlr
+    readline-devel # for memgraph console
+    python3-devel # for query modules
+    openssl-devel
+    libseccomp-devel
+    python3 python3-pip nmap-ncat # for tests
+    #
+    # IMPORTANT: python3-yaml does NOT exist on CentOS
+    # Install it using `pip3 install PyYAML`
+    #
+    PyYAML # Package name here does not correspond to the yum package!
+    libcurl-devel # mg-requests
+    rpm-build rpmlint # for RPM package building
+    doxygen graphviz # source documentation generators
+    which nodejs golang zip unzip java-11-openjdk-devel # for driver tests
+    autoconf # for jemalloc code generation
+    libtool # for protobuf code generation
+)
+
+list() {
+    echo "$1"
+}
+
+check() {
+    local missing=""
+    # On Fedora, yum/dnf and python3.10 use a newer glibc which is not
+    # compatible with ours, so we need to momentarily disable the env
+    local OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-""}
+    LD_LIBRARY_PATH=""
+    for pkg in $1; do
+        if [ "$pkg" == "PyYAML" ]; then
+            if ! python3 -c "import yaml" >/dev/null 2>/dev/null; then
+                missing="$pkg $missing"
+            fi
+            continue
+        fi
+        if ! yum list installed "$pkg" >/dev/null 2>/dev/null; then
+            missing="$pkg $missing"
+        fi
+    done
+    if [ "$missing" != "" ]; then
+        echo "MISSING PACKAGES: $missing"
+        exit 1
+    fi
+    LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}
+}
+
+install() {
+    cd "$DIR"
+    if [ "$EUID" -ne 0 ]; then
+        echo "Please run as root."
+        exit 1
+    fi
+    # If the GitHub Actions runner is installed, append LANG to the environment.
+    # Python related tests don't work without the LANG export.
+    if [ -d "/home/gh/actions-runner" ]; then
+        echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env
+    else
+        echo "NOTE: export LANG=en_US.utf8"
+    fi
+    yum update -y
+    for pkg in $1; do
+        if [ "$pkg" == libipt ]; then
+            if ! yum list installed libipt >/dev/null 2>/dev/null; then
+                yum install -y http://repo.okay.com.mx/centos/8/x86_64/release/libipt-1.6.1-8.el8.x86_64.rpm
+            fi
+            continue
+        fi
+        if [ "$pkg" == libipt-devel ]; then
+            if ! yum list installed libipt-devel >/dev/null 2>/dev/null; then
+                yum install -y http://repo.okay.com.mx/centos/8/x86_64/release/libipt-devel-1.6.1-8.el8.x86_64.rpm
+            fi
+            continue
+        fi
+        if [ "$pkg" == nodejs ]; then
+            curl -sL https://rpm.nodesource.com/setup_16.x | bash -
+            if ! yum list installed nodejs >/dev/null 2>/dev/null; then
+                yum install -y nodejs
+            fi
+            continue
+        fi
+        if [ "$pkg" == PyYAML ]; then
+            if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker).
+                pip3 install --user PyYAML
+            else # Running using sudo.
+                sudo -H -u "$SUDO_USER" bash -c "pip3 install --user PyYAML"
+            fi
+            continue
+        fi
+        if [ "$pkg" == java-11-openjdk ]; then
+            amazon-linux-extras install -y java-openjdk11
+            continue
+        fi
+        if [ "$pkg" == java-11-openjdk-devel ]; then
+            amazon-linux-extras install -y java-openjdk11
+            yum install -y java-11-openjdk-devel
+            continue
+        fi
+        yum install -y "$pkg"
+    done
+}
+
+deps=$2"[*]"
+"$1" "${!deps}"
diff --git a/environment/toolchain/v4.sh b/environment/toolchain/v4.sh
index e3fddaa1f..2ac92441f 100755
--- a/environment/toolchain/v4.sh
+++ b/environment/toolchain/v4.sh
@@ -415,6 +415,34 @@ if [ ! -f $PREFIX/bin/gdb ]; then
             --with-intel-pt \
             --enable-tui \
             --with-python=python3
+    elif [[ "${DISTRO}" == "amzn-2" ]]; then
+        # Remove readline; gdb does not compile with it on amzn-2
+        env \
+            CC=gcc \
+            CXX=g++ \
+            CFLAGS="-g -O2 -fstack-protector-strong -Wformat -Werror=format-security" \
+            CXXFLAGS="-g -O2 -fstack-protector-strong -Wformat -Werror=format-security" \
+            CPPFLAGS="-Wdate-time -D_FORTIFY_SOURCE=2 -fPIC" \
+            LDFLAGS="-Wl,-z,relro" \
+            PYTHON="" \
+            ../configure \
+                --build=x86_64-linux-gnu \
+                --host=x86_64-linux-gnu \
+                --prefix=$PREFIX \
+                --disable-maintainer-mode \
+                --disable-dependency-tracking \
+                --disable-silent-rules \
+                --disable-gdbtk \
+                --disable-shared \
+                --without-guile \
+                --with-system-gdbinit=$PREFIX/etc/gdb/gdbinit \
+                --with-expat \
+                --with-system-zlib \
+                --with-lzma \
+                --with-babeltrace \
+                --with-intel-pt \
+                --enable-tui \
+                --with-python=python3
     else
         # https://buildd.debian.org/status/fetch.php?pkg=gdb&arch=amd64&ver=8.2.1-2&stamp=1550831554&raw=0
         env \
@@ -1143,119 +1171,121 @@ if [ ! -f $PREFIX/include/libaio.h ]; then
     popd
 fi
 
-log_tool_name "folly $FBLIBS_VERSION"
-if [ ! -d $PREFIX/include/folly ]; then
-    if [ -d folly-$FBLIBS_VERSION ]; then
-        rm -rf folly-$FBLIBS_VERSION
+if [[ "${DISTRO}" != "amzn-2" ]]; then
+    log_tool_name "folly $FBLIBS_VERSION"
+    if [ ! -d $PREFIX/include/folly ]; then
+        if [ -d folly-$FBLIBS_VERSION ]; then
+            rm -rf folly-$FBLIBS_VERSION
+        fi
+        mkdir folly-$FBLIBS_VERSION
+        tar -xzf ../archives/folly-$FBLIBS_VERSION.tar.gz -C folly-$FBLIBS_VERSION
+        pushd folly-$FBLIBS_VERSION
+        patch -p1 < ../../folly.patch
+        # build is used by facebook builder
+        mkdir _build
+        pushd _build
+        cmake .. $COMMON_CMAKE_FLAGS \
+            -DBOOST_LINK_STATIC=ON \
+            -DBUILD_TESTS=OFF \
+            -DGFLAGS_NOTHREADS=OFF \
+            -DCXX_STD="c++20"
+        make -j$CPUS install
+        popd && popd
     fi
-    mkdir folly-$FBLIBS_VERSION
-    tar -xzf ../archives/folly-$FBLIBS_VERSION.tar.gz -C folly-$FBLIBS_VERSION
-    pushd folly-$FBLIBS_VERSION
-    patch -p1 < ../../folly.patch
-    # build is used by facebook builder
-    mkdir _build
-    pushd _build
-    cmake .. 
$COMMON_CMAKE_FLAGS \ - -DBOOST_LINK_STATIC=ON \ - -DBUILD_TESTS=OFF \ - -DGFLAGS_NOTHREADS=OFF \ - -DCXX_STD="c++20" - make -j$CPUS install - popd && popd -fi -log_tool_name "fizz $FBLIBS_VERSION" -if [ ! -d $PREFIX/include/fizz ]; then - if [ -d fizz-$FBLIBS_VERSION ]; then - rm -rf fizz-$FBLIBS_VERSION + log_tool_name "fizz $FBLIBS_VERSION" + if [ ! -d $PREFIX/include/fizz ]; then + if [ -d fizz-$FBLIBS_VERSION ]; then + rm -rf fizz-$FBLIBS_VERSION + fi + mkdir fizz-$FBLIBS_VERSION + tar -xzf ../archives/fizz-$FBLIBS_VERSION.tar.gz -C fizz-$FBLIBS_VERSION + pushd fizz-$FBLIBS_VERSION + # build is used by facebook builder + mkdir _build + pushd _build + cmake ../fizz $COMMON_CMAKE_FLAGS \ + -DBUILD_TESTS=OFF \ + -DBUILD_EXAMPLES=OFF \ + -DGFLAGS_NOTHREADS=OFF + make -j$CPUS install + popd && popd fi - mkdir fizz-$FBLIBS_VERSION - tar -xzf ../archives/fizz-$FBLIBS_VERSION.tar.gz -C fizz-$FBLIBS_VERSION - pushd fizz-$FBLIBS_VERSION - # build is used by facebook builder - mkdir _build - pushd _build - cmake ../fizz $COMMON_CMAKE_FLAGS \ - -DBUILD_TESTS=OFF \ - -DBUILD_EXAMPLES=OFF \ - -DGFLAGS_NOTHREADS=OFF - make -j$CPUS install - popd && popd -fi -log_tool_name "wangle FBLIBS_VERSION" -if [ ! -d $PREFIX/include/wangle ]; then - if [ -d wangle-$FBLIBS_VERSION ]; then - rm -rf wangle-$FBLIBS_VERSION + log_tool_name "wangle FBLIBS_VERSION" + if [ ! -d $PREFIX/include/wangle ]; then + if [ -d wangle-$FBLIBS_VERSION ]; then + rm -rf wangle-$FBLIBS_VERSION + fi + mkdir wangle-$FBLIBS_VERSION + tar -xzf ../archives/wangle-$FBLIBS_VERSION.tar.gz -C wangle-$FBLIBS_VERSION + pushd wangle-$FBLIBS_VERSION + # build is used by facebook builder + mkdir _build + pushd _build + cmake ../wangle $COMMON_CMAKE_FLAGS \ + -DBUILD_TESTS=OFF \ + -DBUILD_EXAMPLES=OFF \ + -DGFLAGS_NOTHREADS=OFF + make -j$CPUS install + popd && popd fi - mkdir wangle-$FBLIBS_VERSION - tar -xzf ../archives/wangle-$FBLIBS_VERSION.tar.gz -C wangle-$FBLIBS_VERSION - pushd wangle-$FBLIBS_VERSION - # build is used by facebook builder - mkdir _build - pushd _build - cmake ../wangle $COMMON_CMAKE_FLAGS \ - -DBUILD_TESTS=OFF \ - -DBUILD_EXAMPLES=OFF \ - -DGFLAGS_NOTHREADS=OFF - make -j$CPUS install - popd && popd -fi -log_tool_name "proxygen $FBLIBS_VERSION" -if [ ! -d $PREFIX/include/proxygen ]; then - if [ -d proxygen-$FBLIBS_VERSION ]; then - rm -rf proxygen-$FBLIBS_VERSION + log_tool_name "proxygen $FBLIBS_VERSION" + if [ ! -d $PREFIX/include/proxygen ]; then + if [ -d proxygen-$FBLIBS_VERSION ]; then + rm -rf proxygen-$FBLIBS_VERSION + fi + mkdir proxygen-$FBLIBS_VERSION + tar -xzf ../archives/proxygen-$FBLIBS_VERSION.tar.gz -C proxygen-$FBLIBS_VERSION + pushd proxygen-$FBLIBS_VERSION + patch -p1 < ../../proxygen.patch + # build is used by facebook builder + mkdir _build + pushd _build + cmake .. $COMMON_CMAKE_FLAGS \ + -DBUILD_TESTS=OFF \ + -DBUILD_SAMPLES=OFF \ + -DGFLAGS_NOTHREADS=OFF \ + -DBUILD_QUIC=OFF + make -j$CPUS install + popd && popd fi - mkdir proxygen-$FBLIBS_VERSION - tar -xzf ../archives/proxygen-$FBLIBS_VERSION.tar.gz -C proxygen-$FBLIBS_VERSION - pushd proxygen-$FBLIBS_VERSION - patch -p1 < ../../proxygen.patch - # build is used by facebook builder - mkdir _build - pushd _build - cmake .. $COMMON_CMAKE_FLAGS \ - -DBUILD_TESTS=OFF \ - -DBUILD_SAMPLES=OFF \ - -DGFLAGS_NOTHREADS=OFF \ - -DBUILD_QUIC=OFF - make -j$CPUS install - popd && popd -fi -log_tool_name "flex $FBLIBS_VERSION" -if [ ! 
-f $PREFIX/include/FlexLexer.h ]; then - if [ -d flex-$FLEX_VERSION ]; then - rm -rf flex-$FLEX_VERSION + log_tool_name "flex $FBLIBS_VERSION" + if [ ! -f $PREFIX/include/FlexLexer.h ]; then + if [ -d flex-$FLEX_VERSION ]; then + rm -rf flex-$FLEX_VERSION + fi + tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz + pushd flex-$FLEX_VERSION + ./configure $COMMON_CONFIGURE_FLAGS + make -j$CPUS install + popd fi - tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz - pushd flex-$FLEX_VERSION - ./configure $COMMON_CONFIGURE_FLAGS - make -j$CPUS install - popd -fi -log_tool_name "fbthrift $FBLIBS_VERSION" -if [ ! -d $PREFIX/include/thrift ]; then - if [ -d fbthrift-$FBLIBS_VERSION ]; then - rm -rf fbthrift-$FBLIBS_VERSION + log_tool_name "fbthrift $FBLIBS_VERSION" + if [ ! -d $PREFIX/include/thrift ]; then + if [ -d fbthrift-$FBLIBS_VERSION ]; then + rm -rf fbthrift-$FBLIBS_VERSION + fi + git clone --depth 1 --branch v$FBLIBS_VERSION https://github.com/facebook/fbthrift.git fbthrift-$FBLIBS_VERSION + pushd fbthrift-$FBLIBS_VERSION + # build is used by facebook builder + mkdir _build + pushd _build + if [ "$TOOLCHAIN_STDCXX" = "libstdc++" ]; then + CMAKE_CXX_FLAGS="-fsized-deallocation" + else + CMAKE_CXX_FLAGS="-fsized-deallocation -stdlib=libc++" + fi + cmake .. $COMMON_CMAKE_FLAGS \ + -Denable_tests=OFF \ + -DGFLAGS_NOTHREADS=OFF \ + -DCMAKE_CXX_FLAGS="$CMAKE_CXX_FLAGS" + make -j$CPUS install + popd fi - git clone --depth 1 --branch v$FBLIBS_VERSION https://github.com/facebook/fbthrift.git fbthrift-$FBLIBS_VERSION - pushd fbthrift-$FBLIBS_VERSION - # build is used by facebook builder - mkdir _build - pushd _build - if [ "$TOOLCHAIN_STDCXX" = "libstdc++" ]; then - CMAKE_CXX_FLAGS="-fsized-deallocation" - else - CMAKE_CXX_FLAGS="-fsized-deallocation -stdlib=libc++" - fi - cmake .. $COMMON_CMAKE_FLAGS \ - -Denable_tests=OFF \ - -DGFLAGS_NOTHREADS=OFF \ - -DCMAKE_CXX_FLAGS="$CMAKE_CXX_FLAGS" - make -j$CPUS install - popd fi popd diff --git a/init b/init index a4754fb3e..95a2438eb 100755 --- a/init +++ b/init @@ -14,7 +14,6 @@ function print_help () { echo "Optional arguments:" echo -e " -h\tdisplay this help and exit" echo -e " --without-libs-setup\tskip the step for setting up libs" - echo -e " --wsl-quicklisp-proxy \"host:port\"\tquicklist HTTP proxy (this flag + HTTP proxy are required on WSL)" } function setup_virtualenv () { @@ -35,7 +34,6 @@ function setup_virtualenv () { popd > /dev/null } -wsl_quicklisp_proxy="" setup_libs=true if [[ $# -eq 1 && "$1" == "-h" ]]; then print_help @@ -43,16 +41,6 @@ if [[ $# -eq 1 && "$1" == "-h" ]]; then else while(($#)); do case "$1" in - --wsl-quicklisp-proxy) - shift - if [[ $# -eq 0 ]]; then - echo "Missing proxy URL" - print_help - exit 1 - fi - wsl_quicklisp_proxy=":proxy \"http://$1/\"" - shift - ;; --without-libs-setup) shift setup_libs=false @@ -79,41 +67,16 @@ echo "All packages are in-place..." # create a default build directory mkdir -p ./build -# quicklisp package manager for Common Lisp -quicklisp_install_dir="$HOME/quicklisp" -if [[ -v QUICKLISP_HOME ]]; then - quicklisp_install_dir="${QUICKLISP_HOME}" -fi - -if [[ ! 
-f "${quicklisp_install_dir}/setup.lisp" ]]; then - wget -nv https://beta.quicklisp.org/quicklisp.lisp -O quicklisp.lisp || exit 1 - echo \ - " - (load \"${DIR}/quicklisp.lisp\") - (quicklisp-quickstart:install $wsl_quicklisp_proxy :path \"${quicklisp_install_dir}\") - " | sbcl --script || exit 1 - rm -rf quicklisp.lisp || exit 1 -fi -ln -Tfs "$DIR/src/lisp" "${quicklisp_install_dir}/local-projects/lcp" -# Install LCP dependencies -# TODO: We should at some point cache or have a mirror of packages we use. -# TODO: move the installation of LCP's dependencies into ./setup.sh -echo \ - " - (load \"${quicklisp_install_dir}/setup.lisp\") - (ql:quickload '(:lcp :lcp/test) :silent t) - " | sbcl --script - if [[ "$setup_libs" == "true" ]]; then - # Setup libs (download). - cd libs - ./cleanup.sh - ./setup.sh - cd .. + # Setup libs (download). + cd libs + ./cleanup.sh + ./setup.sh + cd .. fi # Fix for centos 7 during release -if [ "${DISTRO}" = "centos-7" ] || [ "${DISTRO}" = "debian-11" ]; then +if [ "${DISTRO}" = "centos-7" ] || [ "${DISTRO}" = "debian-11" ] || [ "${DISTRO}" = "amzn-2" ]; then python3 -m pip uninstall -y virtualenv python3 -m pip install virtualenv fi diff --git a/release/CMakeLists.txt b/release/CMakeLists.txt index 489aea989..4bdb4bf69 100644 --- a/release/CMakeLists.txt +++ b/release/CMakeLists.txt @@ -1,6 +1,10 @@ # Install systemd service (must use absolute path). install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/memgraph.service - DESTINATION /lib/systemd/system) + DESTINATION /lib/systemd/system) + +# Set parameters to recognize the host distro +cmake_host_system_information(RESULT DISTRO QUERY DISTRIB_NAME) +cmake_host_system_information(RESULT DISTRO_VERSION QUERY DISTRIB_VERSION) # ---- Setup CPack -------- @@ -12,10 +16,11 @@ set(CPACK_PACKAGE_DESCRIPTION_SUMMARY # Setting arhitecture extension for deb packages set(MG_ARCH_EXTENSION_DEB "all") -if (${MG_ARCH} STREQUAL "x86_64") + +if(${MG_ARCH} STREQUAL "x86_64") set(MG_ARCH_EXTENSION_DEB "amd64") -elseif (${MG_ARCH} STREQUAL "ARM64") - set(MG_ARCH_EXTENSION_DEB "arm64") +elseif(${MG_ARCH} STREQUAL "ARM64") + set(MG_ARCH_EXTENSION_DEB "arm64") endif() # DEB specific @@ -34,21 +39,24 @@ set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;" "${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;") set(CPACK_DEBIAN_PACKAGE_SHLIBDEPS ON) + # Description formatting is important, summary must be followed with a newline and 1 space. set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY} Contains Memgraph, the graph database. It aims to deliver developers the speed, simplicity and scale required to build the next generation of applications driver by real-time connected data.") + # Add `openssl` package to dependencies list. Used to generate SSL certificates. # We also depend on `python3` because we embed it in Memgraph. 
set(CPACK_DEBIAN_PACKAGE_DEPENDS "openssl (>= 1.1.0), python3 (>= 3.5.0), libstdc++6") # Setting arhitecture extension for rpm packages set(MG_ARCH_EXTENSION_RPM "noarch") -if (${MG_ARCH} STREQUAL "x86_64") + +if(${MG_ARCH} STREQUAL "x86_64") set(MG_ARCH_EXTENSION_RPM "x86_64") -elseif (${MG_ARCH} STREQUAL "ARM64") - set(MG_ARCH_EXTENSION_RPM "aarch64") +elseif(${MG_ARCH} STREQUAL "ARM64") + set(MG_ARCH_EXTENSION_RPM "aarch64") endif() # RPM specific @@ -56,18 +64,26 @@ set(CPACK_RPM_PACKAGE_URL https://memgraph.com) set(CPACK_RPM_PACKAGE_VERSION "${MEMGRAPH_VERSION_RPM}") set(CPACK_RPM_FILE_NAME "memgraph-${MEMGRAPH_VERSION_RPM}-1.${MG_ARCH_EXTENSION_RPM}.rpm") set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION - /var /var/lib /var/log /etc/logrotate.d - /lib /lib/systemd /lib/systemd/system /lib/systemd/system/memgraph.service) + /var /var/lib /var/log /etc/logrotate.d + /lib /lib/systemd /lib/systemd/system /lib/systemd/system/memgraph.service) set(CPACK_RPM_PACKAGE_REQUIRES_PRE "shadow-utils") set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_CURRENT_SOURCE_DIR}/rpm/memgraph.spec.in") set(CPACK_RPM_PACKAGE_LICENSE "Memgraph License") + # Description formatting is important, no line must be greater than 80 characters. set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database. It aims to deliver developers the speed, simplicity and scale required to build the next generation of applications driver by real-time connected data.") + # Add `openssl` package to dependencies list. Used to generate SSL certificates. # We also depend on `python3` because we embed it in Memgraph. -set(CPACK_RPM_PACKAGE_REQUIRES "openssl >= 1.0.0, curl >= 7.29.0, python3 >= 3.5.0, libstdc++ >= 6, logrotate") +set(CPACK_RPM_PACKAGE_REQUIRES "openssl >= 1.0.0, curl >= 7.29.0, python3 >= 3.5.0, libstdc++ >= 3.4.29, logrotate") + +# If amzn-2 +if(DISTRO STREQUAL "Amazon Linux" AND DISTRO_VERSION STREQUAL "2") + # It causes issues with glibcxx 2.4 + set(CPACK_RPM_PACKAGE_AUTOREQ " no") +endif() # All variables must be set before including. include(CPack) diff --git a/release/package/amzn-2/Dockerfile b/release/package/amzn-2/Dockerfile new file mode 100644 index 000000000..3bcc8ad72 --- /dev/null +++ b/release/package/amzn-2/Dockerfile @@ -0,0 +1,14 @@ +FROM amazonlinux:2 + +ARG TOOLCHAIN_VERSION + +RUN yum -y update \ + && yum install -y wget git tar +# Do NOT be smart here and clean the cache because the container is used in the +# stateful context. 
+ +RUN wget -q https://s3-eu-west-1.amazonaws.com/deps.memgraph.io/${TOOLCHAIN_VERSION}/${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz \ + -O ${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz \ + && tar xzvf ${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz -C /opt + +ENTRYPOINT ["sleep", "infinity"] diff --git a/release/package/docker-compose.yml b/release/package/docker-compose.yml index 1f285bf84..4da0526ba 100644 --- a/release/package/docker-compose.yml +++ b/release/package/docker-compose.yml @@ -32,3 +32,7 @@ services: build: context: fedora-36 container_name: "mgbuild_fedora-36" + mgbuild_amzn-2: + build: + context: amzn-2 + container_name: "mgbuild_amzn-2" diff --git a/release/package/run.sh b/release/package/run.sh index 7d068eaa9..cdf95466a 100755 --- a/release/package/run.sh +++ b/release/package/run.sh @@ -3,7 +3,14 @@ set -Eeuo pipefail SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SUPPORTED_OS=(centos-7 centos-9 debian-10 debian-11 ubuntu-18.04 ubuntu-20.04 ubuntu-22.04 debian-11-arm fedora-36 ubuntu-22.04-arm) +SUPPORTED_OS=( + centos-7 centos-9 + debian-10 debian-11 debian-11-arm + ubuntu-18.04 ubuntu-20.04 ubuntu-22.04 ubuntu-22.04-arm + fedora-36 + amzn-2 +) + PROJECT_ROOT="$SCRIPT_DIR/../.." TOOLCHAIN_VERSION="toolchain-v4" ACTIVATE_TOOLCHAIN="source /opt/${TOOLCHAIN_VERSION}/activate" @@ -23,7 +30,7 @@ make_package () { echo "Building Memgraph for $os on $build_container..." package_command="" - if [[ "$os" =~ ^"centos".* ]] || [[ "$os" =~ ^"fedora".* ]]; then + if [[ "$os" =~ ^"centos".* ]] || [[ "$os" =~ ^"fedora".* ]] || [[ "$os" =~ ^"amzn".* ]]; then docker exec "$build_container" bash -c "yum -y update" package_command=" cpack -G RPM --config ../CPackConfig.cmake && rpmlint --file='../../release/rpm/rpmlintrc' memgraph*.rpm " fi diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d6f58c642..ad1cfd711 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,7 +1,6 @@ # CMake configuration for the main memgraph library and executable # add memgraph sub libraries, ordered by dependency -add_subdirectory(lisp) add_subdirectory(utils) add_subdirectory(requests) add_subdirectory(io) diff --git a/src/query/CMakeLists.txt b/src/query/CMakeLists.txt index 77a12a430..e7e2bd66e 100644 --- a/src/query/CMakeLists.txt +++ b/src/query/CMakeLists.txt @@ -1,13 +1,7 @@ -define_add_lcp(add_lcp_query lcp_query_cpp_files generated_lcp_query_files) - -add_lcp_query(frontend/ast/ast.lcp) -add_lcp_query(frontend/semantic/symbol.lcp) -add_lcp_query(plan/operator.lcp) - -add_custom_target(generate_lcp_query DEPENDS ${generated_lcp_query_files}) - set(mg_query_sources - ${lcp_query_cpp_files} + frontend/ast/ast.cpp + frontend/semantic/symbol.cpp + plan/operator_type_info.cpp common.cpp cypher_query_interpreter.cpp dump.cpp @@ -46,7 +40,6 @@ set(mg_query_sources find_package(Boost REQUIRED) add_library(mg-query STATIC ${mg_query_sources}) -add_dependencies(mg-query generate_lcp_query) target_include_directories(mg-query PUBLIC ${CMAKE_SOURCE_DIR}/include) target_link_libraries(mg-query dl cppitertools Boost::headers) target_link_libraries(mg-query mg-integrations-pulsar mg-integrations-kafka mg-storage-v2 mg-license mg-utils mg-kvstore mg-memory) diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp new file mode 100644 index 000000000..6dfd4f85c --- /dev/null +++ b/src/query/frontend/ast/ast.cpp @@ -0,0 +1,263 @@ +// Copyright 2023 Memgraph Ltd. 
+// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query/frontend/ast/ast.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph { + +constexpr utils::TypeInfo query::LabelIx::kType{utils::TypeId::AST_LABELIX, "LabelIx", nullptr}; + +constexpr utils::TypeInfo query::PropertyIx::kType{utils::TypeId::AST_PROPERTYIX, "PropertyIx", nullptr}; + +constexpr utils::TypeInfo query::EdgeTypeIx::kType{utils::TypeId::AST_EDGETYPEIX, "EdgeTypeIx", nullptr}; + +constexpr utils::TypeInfo query::Tree::kType{utils::TypeId::AST_TREE, "Tree", nullptr}; + +constexpr utils::TypeInfo query::Expression::kType{utils::TypeId::AST_EXPRESSION, "Expression", &query::Tree::kType}; + +constexpr utils::TypeInfo query::Where::kType{utils::TypeId::AST_WHERE, "Where", &query::Tree::kType}; + +constexpr utils::TypeInfo query::BinaryOperator::kType{utils::TypeId::AST_BINARY_OPERATOR, "BinaryOperator", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::UnaryOperator::kType{utils::TypeId::AST_UNARY_OPERATOR, "UnaryOperator", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::OrOperator::kType{utils::TypeId::AST_OR_OPERATOR, "OrOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::XorOperator::kType{utils::TypeId::AST_XOR_OPERATOR, "XorOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::AndOperator::kType{utils::TypeId::AST_AND_OPERATOR, "AndOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::AdditionOperator::kType{utils::TypeId::AST_ADDITION_OPERATOR, "AdditionOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::SubtractionOperator::kType{utils::TypeId::AST_SUBTRACTION_OPERATOR, + "SubtractionOperator", &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::MultiplicationOperator::kType{utils::TypeId::AST_MULTIPLICATION_OPERATOR, + "MultiplicationOperator", &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::DivisionOperator::kType{utils::TypeId::AST_DIVISION_OPERATOR, "DivisionOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::ModOperator::kType{utils::TypeId::AST_MOD_OPERATOR, "ModOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::NotEqualOperator::kType{utils::TypeId::AST_NOT_EQUAL_OPERATOR, "NotEqualOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::EqualOperator::kType{utils::TypeId::AST_EQUAL_OPERATOR, "EqualOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::LessOperator::kType{utils::TypeId::AST_LESS_OPERATOR, "LessOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::GreaterOperator::kType{utils::TypeId::AST_GREATER_OPERATOR, "GreaterOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::LessEqualOperator::kType{utils::TypeId::AST_LESS_EQUAL_OPERATOR, "LessEqualOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo 
query::GreaterEqualOperator::kType{utils::TypeId::AST_GREATER_EQUAL_OPERATOR, + "GreaterEqualOperator", &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::InListOperator::kType{utils::TypeId::AST_IN_LIST_OPERATOR, "InListOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::SubscriptOperator::kType{utils::TypeId::AST_SUBSCRIPT_OPERATOR, "SubscriptOperator", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::NotOperator::kType{utils::TypeId::AST_NOT_OPERATOR, "NotOperator", + &query::UnaryOperator::kType}; + +constexpr utils::TypeInfo query::UnaryPlusOperator::kType{utils::TypeId::AST_UNARY_PLUS_OPERATOR, "UnaryPlusOperator", + &query::UnaryOperator::kType}; + +constexpr utils::TypeInfo query::UnaryMinusOperator::kType{utils::TypeId::AST_UNARY_MINUS_OPERATOR, + "UnaryMinusOperator", &query::UnaryOperator::kType}; + +constexpr utils::TypeInfo query::IsNullOperator::kType{utils::TypeId::AST_IS_NULL_OPERATOR, "IsNullOperator", + &query::UnaryOperator::kType}; + +constexpr utils::TypeInfo query::Aggregation::kType{utils::TypeId::AST_AGGREGATION, "Aggregation", + &query::BinaryOperator::kType}; + +constexpr utils::TypeInfo query::ListSlicingOperator::kType{utils::TypeId::AST_LIST_SLICING_OPERATOR, + "ListSlicingOperator", &query::Expression::kType}; + +constexpr utils::TypeInfo query::IfOperator::kType{utils::TypeId::AST_IF_OPERATOR, "IfOperator", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::BaseLiteral::kType{utils::TypeId::AST_BASE_LITERAL, "BaseLiteral", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::PrimitiveLiteral::kType{utils::TypeId::AST_PRIMITIVE_LITERAL, "PrimitiveLiteral", + &query::BaseLiteral::kType}; + +constexpr utils::TypeInfo query::ListLiteral::kType{utils::TypeId::AST_LIST_LITERAL, "ListLiteral", + &query::BaseLiteral::kType}; + +constexpr utils::TypeInfo query::MapLiteral::kType{utils::TypeId::AST_MAP_LITERAL, "MapLiteral", + &query::BaseLiteral::kType}; + +constexpr utils::TypeInfo query::Identifier::kType{utils::TypeId::AST_IDENTIFIER, "Identifier", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::PropertyLookup::kType{utils::TypeId::AST_PROPERTY_LOOKUP, "PropertyLookup", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::LabelsTest::kType{utils::TypeId::AST_LABELS_TEST, "LabelsTest", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::Function::kType{utils::TypeId::AST_FUNCTION, "Function", &query::Expression::kType}; + +constexpr utils::TypeInfo query::Reduce::kType{utils::TypeId::AST_REDUCE, "Reduce", &query::Expression::kType}; + +constexpr utils::TypeInfo query::Coalesce::kType{utils::TypeId::AST_COALESCE, "Coalesce", &query::Expression::kType}; + +constexpr utils::TypeInfo query::Extract::kType{utils::TypeId::AST_EXTRACT, "Extract", &query::Expression::kType}; + +constexpr utils::TypeInfo query::All::kType{utils::TypeId::AST_ALL, "All", &query::Expression::kType}; + +constexpr utils::TypeInfo query::Single::kType{utils::TypeId::AST_SINGLE, "Single", &query::Expression::kType}; + +constexpr utils::TypeInfo query::Any::kType{utils::TypeId::AST_ANY, "Any", &query::Expression::kType}; + +constexpr utils::TypeInfo query::None::kType{utils::TypeId::AST_NONE, "None", &query::Expression::kType}; + +constexpr utils::TypeInfo query::ParameterLookup::kType{utils::TypeId::AST_PARAMETER_LOOKUP, "ParameterLookup", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::RegexMatch::kType{utils::TypeId::AST_REGEX_MATCH, 
"RegexMatch", + &query::Expression::kType}; + +constexpr utils::TypeInfo query::NamedExpression::kType{utils::TypeId::AST_NAMED_EXPRESSION, "NamedExpression", + &query::Tree::kType}; + +constexpr utils::TypeInfo query::PatternAtom::kType{utils::TypeId::AST_PATTERN_ATOM, "PatternAtom", + &query::Tree::kType}; + +constexpr utils::TypeInfo query::NodeAtom::kType{utils::TypeId::AST_NODE_ATOM, "NodeAtom", &query::PatternAtom::kType}; + +constexpr utils::TypeInfo query::EdgeAtom::Lambda::kType{utils::TypeId::AST_EDGE_ATOM_LAMBDA, "Lambda", nullptr}; + +constexpr utils::TypeInfo query::EdgeAtom::kType{utils::TypeId::AST_EDGE_ATOM, "EdgeAtom", &query::PatternAtom::kType}; + +constexpr utils::TypeInfo query::Pattern::kType{utils::TypeId::AST_PATTERN, "Pattern", &query::Tree::kType}; + +constexpr utils::TypeInfo query::Clause::kType{utils::TypeId::AST_CLAUSE, "Clause", &query::Tree::kType}; + +constexpr utils::TypeInfo query::SingleQuery::kType{utils::TypeId::AST_SINGLE_QUERY, "SingleQuery", + &query::Tree::kType}; + +constexpr utils::TypeInfo query::CypherUnion::kType{utils::TypeId::AST_CYPHER_UNION, "CypherUnion", + &query::Tree::kType}; + +constexpr utils::TypeInfo query::Query::kType{utils::TypeId::AST_QUERY, "Query", &query::Tree::kType}; + +constexpr utils::TypeInfo query::CypherQuery::kType{utils::TypeId::AST_CYPHER_QUERY, "CypherQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::ExplainQuery::kType{utils::TypeId::AST_EXPLAIN_QUERY, "ExplainQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::ProfileQuery::kType{utils::TypeId::AST_PROFILE_QUERY, "ProfileQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::IndexQuery::kType{utils::TypeId::AST_INDEX_QUERY, "IndexQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::Create::kType{utils::TypeId::AST_CREATE, "Create", &query::Clause::kType}; + +constexpr utils::TypeInfo query::CallProcedure::kType{utils::TypeId::AST_CALL_PROCEDURE, "CallProcedure", + &query::Clause::kType}; + +constexpr utils::TypeInfo query::Match::kType{utils::TypeId::AST_MATCH, "Match", &query::Clause::kType}; + +constexpr utils::TypeInfo query::SortItem::kType{utils::TypeId::AST_SORT_ITEM, "SortItem", nullptr}; + +constexpr utils::TypeInfo query::ReturnBody::kType{utils::TypeId::AST_RETURN_BODY, "ReturnBody", nullptr}; + +constexpr utils::TypeInfo query::Return::kType{utils::TypeId::AST_RETURN, "Return", &query::Clause::kType}; + +constexpr utils::TypeInfo query::With::kType{utils::TypeId::AST_WITH, "With", &query::Clause::kType}; + +constexpr utils::TypeInfo query::Delete::kType{utils::TypeId::AST_DELETE, "Delete", &query::Clause::kType}; + +constexpr utils::TypeInfo query::SetProperty::kType{utils::TypeId::AST_SET_PROPERTY, "SetProperty", + &query::Clause::kType}; + +constexpr utils::TypeInfo query::SetProperties::kType{utils::TypeId::AST_SET_PROPERTIES, "SetProperties", + &query::Clause::kType}; + +constexpr utils::TypeInfo query::SetLabels::kType{utils::TypeId::AST_SET_LABELS, "SetLabels", &query::Clause::kType}; + +constexpr utils::TypeInfo query::RemoveProperty::kType{utils::TypeId::AST_REMOVE_PROPERTY, "RemoveProperty", + &query::Clause::kType}; + +constexpr utils::TypeInfo query::RemoveLabels::kType{utils::TypeId::AST_REMOVE_LABELS, "RemoveLabels", + &query::Clause::kType}; + +constexpr utils::TypeInfo query::Merge::kType{utils::TypeId::AST_MERGE, "Merge", &query::Clause::kType}; + +constexpr utils::TypeInfo query::Unwind::kType{utils::TypeId::AST_UNWIND, "Unwind", &query::Clause::kType}; + +constexpr 
utils::TypeInfo query::AuthQuery::kType{utils::TypeId::AST_AUTH_QUERY, "AuthQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::InfoQuery::kType{utils::TypeId::AST_INFO_QUERY, "InfoQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::Constraint::kType{utils::TypeId::AST_CONSTRAINT, "Constraint", nullptr}; + +constexpr utils::TypeInfo query::ConstraintQuery::kType{utils::TypeId::AST_CONSTRAINT_QUERY, "ConstraintQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::DumpQuery::kType{utils::TypeId::AST_DUMP_QUERY, "DumpQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::ReplicationQuery::kType{utils::TypeId::AST_REPLICATION_QUERY, "ReplicationQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::LockPathQuery::kType{utils::TypeId::AST_LOCK_PATH_QUERY, "LockPathQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::LoadCsv::kType{utils::TypeId::AST_LOAD_CSV, "LoadCsv", &query::Clause::kType}; + +constexpr utils::TypeInfo query::FreeMemoryQuery::kType{utils::TypeId::AST_FREE_MEMORY_QUERY, "FreeMemoryQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::TriggerQuery::kType{utils::TypeId::AST_TRIGGER_QUERY, "TriggerQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::IsolationLevelQuery::kType{utils::TypeId::AST_ISOLATION_LEVEL_QUERY, + "IsolationLevelQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::CreateSnapshotQuery::kType{utils::TypeId::AST_CREATE_SNAPSHOT_QUERY, + "CreateSnapshotQuery", &query::Query::kType}; + +constexpr utils::TypeInfo query::StreamQuery::kType{utils::TypeId::AST_STREAM_QUERY, "StreamQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::SettingQuery::kType{utils::TypeId::AST_SETTING_QUERY, "SettingQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::VersionQuery::kType{utils::TypeId::AST_VERSION_QUERY, "VersionQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::Foreach::kType{utils::TypeId::AST_FOREACH, "Foreach", &query::Clause::kType}; + +constexpr utils::TypeInfo query::ShowConfigQuery::kType{utils::TypeId::AST_SHOW_CONFIG_QUERY, "ShowConfigQuery", + &query::Query::kType}; + +constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType}; +} // namespace memgraph diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp new file mode 100644 index 000000000..e27717a2c --- /dev/null +++ b/src/query/frontend/ast/ast.hpp @@ -0,0 +1,3246 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +#pragma once + +#include <memory> +#include <unordered_map> +#include <variant> +#include <vector> + +#include "query/frontend/ast/ast_visitor.hpp" +#include "query/frontend/semantic/symbol.hpp" +#include "query/interpret/awesome_memgraph_functions.hpp" +#include "query/typed_value.hpp" +#include "storage/v2/property_value.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph { + +namespace query { + +struct LabelIx { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + std::string name; + int64_t ix; +}; + +struct PropertyIx { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + std::string name; + int64_t ix; +}; + +struct EdgeTypeIx { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + std::string name; + int64_t ix; +}; + +inline bool operator==(const LabelIx &a, const LabelIx &b) { return a.ix == b.ix && a.name == b.name; } + +inline bool operator!=(const LabelIx &a, const LabelIx &b) { return !(a == b); } + +inline bool operator==(const PropertyIx &a, const PropertyIx &b) { return a.ix == b.ix && a.name == b.name; } + +inline bool operator!=(const PropertyIx &a, const PropertyIx &b) { return !(a == b); } + +inline bool operator==(const EdgeTypeIx &a, const EdgeTypeIx &b) { return a.ix == b.ix && a.name == b.name; } + +inline bool operator!=(const EdgeTypeIx &a, const EdgeTypeIx &b) { return !(a == b); } +} // namespace query +} // namespace memgraph +namespace std { + +template <> +struct hash<memgraph::query::LabelIx> { + size_t operator()(const memgraph::query::LabelIx &label) const { return label.ix; } +}; + +template <> +struct hash<memgraph::query::PropertyIx> { + size_t operator()(const memgraph::query::PropertyIx &prop) const { return prop.ix; } +}; + +template <> +struct hash<memgraph::query::EdgeTypeIx> { + size_t operator()(const memgraph::query::EdgeTypeIx &edge_type) const { return edge_type.ix; } +}; + +} // namespace std + +namespace memgraph { + +namespace query { + +class Tree; + +// It would be better to call this AstTree, but we already have a class Tree, +// which could be renamed to Node or AstTreeNode, but we also have a class +// called NodeAtom... +class AstStorage { + public: + AstStorage() = default; + AstStorage(const AstStorage &) = delete; + AstStorage &operator=(const AstStorage &) = delete; + AstStorage(AstStorage &&) = default; + AstStorage &operator=(AstStorage &&) = default; + + template <typename T, typename... 
Args> + T *Create(Args &&...args) { + T *ptr = new T(std::forward<Args>(args)...); + std::unique_ptr<T> tmp(ptr); + storage_.emplace_back(std::move(tmp)); + return ptr; + } + + LabelIx GetLabelIx(const std::string &name) { return LabelIx{name, FindOrAddName(name, &labels_)}; } + + PropertyIx GetPropertyIx(const std::string &name) { return PropertyIx{name, FindOrAddName(name, &properties_)}; } + + EdgeTypeIx GetEdgeTypeIx(const std::string &name) { return EdgeTypeIx{name, FindOrAddName(name, &edge_types_)}; } + + std::vector<std::string> labels_; + std::vector<std::string> edge_types_; + std::vector<std::string> properties_; + + // Public only for serialization access + std::vector<std::unique_ptr<Tree>> storage_; + + private: + int64_t FindOrAddName(const std::string &name, std::vector<std::string> *names) { + for (int64_t i = 0; i < names->size(); ++i) { + if ((*names)[i] == name) { + return i; + } + } + names->push_back(name); + return names->size() - 1; + } +}; + +class Tree { + public: + static const utils::TypeInfo kType; + virtual const utils::TypeInfo &GetTypeInfo() const { return kType; } + + Tree() = default; + virtual ~Tree() {} + + virtual Tree *Clone(AstStorage *storage) const = 0; + + private: + friend class AstStorage; +}; + +class Expression : public memgraph::query::Tree, + public utils::Visitable<HierarchicalTreeVisitor>, + public utils::Visitable<ExpressionVisitor<TypedValue>>, + public utils::Visitable<ExpressionVisitor<TypedValue *>>, + public utils::Visitable<ExpressionVisitor<void>> { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept; + using utils::Visitable<ExpressionVisitor<TypedValue *>>::Accept; + using utils::Visitable<ExpressionVisitor<void>>::Accept; + + Expression() = default; + + Expression *Clone(AstStorage *storage) const override = 0; + + private: + friend class AstStorage; +}; + +class Where : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + using utils::Visitable<HierarchicalTreeVisitor>::Accept; + + Where() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::Expression *expression_{nullptr}; + + Where *Clone(AstStorage *storage) const override { + Where *object = storage->Create<Where>(); + object->expression_ = expression_ ? 
expression_->Clone(storage) : nullptr; + return object; + } + + protected: + explicit Where(Expression *expression) : expression_(expression) {} + + private: + friend class AstStorage; +}; + +class BinaryOperator : public memgraph::query::Expression { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + BinaryOperator() = default; + + memgraph::query::Expression *expression1_{nullptr}; + memgraph::query::Expression *expression2_{nullptr}; + + BinaryOperator *Clone(AstStorage *storage) const override = 0; + + protected: + BinaryOperator(Expression *expression1, Expression *expression2) + : expression1_(expression1), expression2_(expression2) {} + + private: + friend class AstStorage; +}; + +class UnaryOperator : public memgraph::query::Expression { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + UnaryOperator() = default; + + memgraph::query::Expression *expression_{nullptr}; + + UnaryOperator *Clone(AstStorage *storage) const override = 0; + + protected: + explicit UnaryOperator(Expression *expression) : expression_(expression) {} + + private: + friend class AstStorage; +}; + +class OrOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + OrOperator *Clone(AstStorage *storage) const override { + OrOperator *object = storage->Create<OrOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class XorOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + XorOperator *Clone(AstStorage *storage) const override { + XorOperator *object = storage->Create<XorOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? 
expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class AndOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + AndOperator *Clone(AstStorage *storage) const override { + AndOperator *object = storage->Create<AndOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class AdditionOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + AdditionOperator *Clone(AstStorage *storage) const override { + AdditionOperator *object = storage->Create<AdditionOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class SubtractionOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + SubtractionOperator *Clone(AstStorage *storage) const override { + SubtractionOperator *object = storage->Create<SubtractionOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? 
expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class MultiplicationOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + MultiplicationOperator *Clone(AstStorage *storage) const override { + MultiplicationOperator *object = storage->Create<MultiplicationOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class DivisionOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + DivisionOperator *Clone(AstStorage *storage) const override { + DivisionOperator *object = storage->Create<DivisionOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class ModOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + ModOperator *Clone(AstStorage *storage) const override { + ModOperator *object = storage->Create<ModOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? 
expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class NotEqualOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + NotEqualOperator *Clone(AstStorage *storage) const override { + NotEqualOperator *object = storage->Create<NotEqualOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class EqualOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + EqualOperator *Clone(AstStorage *storage) const override { + EqualOperator *object = storage->Create<EqualOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class LessOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + LessOperator *Clone(AstStorage *storage) const override { + LessOperator *object = storage->Create<LessOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? 
expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class GreaterOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + GreaterOperator *Clone(AstStorage *storage) const override { + GreaterOperator *object = storage->Create<GreaterOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class LessEqualOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + LessEqualOperator *Clone(AstStorage *storage) const override { + LessEqualOperator *object = storage->Create<LessEqualOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr; + return object; + } + + protected: + using BinaryOperator::BinaryOperator; + + private: + friend class AstStorage; +}; + +class GreaterEqualOperator : public memgraph::query::BinaryOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(ExpressionVisitor<TypedValue>); + DEFVISITABLE(ExpressionVisitor<TypedValue *>); + DEFVISITABLE(ExpressionVisitor<void>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + expression1_->Accept(visitor) && expression2_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + GreaterEqualOperator *Clone(AstStorage *storage) const override { + GreaterEqualOperator *object = storage->Create<GreaterEqualOperator>(); + object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr; + object->expression2_ = expression2_ ? 
+    return object;
+  }
+
+ protected:
+  using BinaryOperator::BinaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class InListOperator : public memgraph::query::BinaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression1_->Accept(visitor) && expression2_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  InListOperator *Clone(AstStorage *storage) const override {
+    InListOperator *object = storage->Create<InListOperator>();
+    object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr;
+    object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using BinaryOperator::BinaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class SubscriptOperator : public memgraph::query::BinaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression1_->Accept(visitor) && expression2_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  SubscriptOperator *Clone(AstStorage *storage) const override {
+    SubscriptOperator *object = storage->Create<SubscriptOperator>();
+    object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr;
+    object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using BinaryOperator::BinaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
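All of the operator classes above follow one generated pattern: `AstStorage::Create<T>()` allocates the node inside the storage arena, `Accept` drives a visitor over the operands, and `Clone` deep-copies the subtree into a (possibly different) storage. A minimal usage sketch, assuming the public `expression1_`/`expression2_` members inherited from `BinaryOperator` and an argument-forwarding `AstStorage::Create`, both defined elsewhere in this header:

    AstStorage storage;
    // Build `x < y` by hand: operand nodes first, then the operator node.
    auto *lt = storage.Create<LessOperator>();
    lt->expression1_ = storage.Create<Identifier>("x");
    lt->expression2_ = storage.Create<Identifier>("y");

    // Deep-copy into a second storage; the clone shares no pointers
    // with the original tree.
    AstStorage other;
    LessOperator *copy = lt->Clone(&other);
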
+class NotOperator : public memgraph::query::UnaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  NotOperator *Clone(AstStorage *storage) const override {
+    NotOperator *object = storage->Create<NotOperator>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using UnaryOperator::UnaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class UnaryPlusOperator : public memgraph::query::UnaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  UnaryPlusOperator *Clone(AstStorage *storage) const override {
+    UnaryPlusOperator *object = storage->Create<UnaryPlusOperator>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using UnaryOperator::UnaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class UnaryMinusOperator : public memgraph::query::UnaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  UnaryMinusOperator *Clone(AstStorage *storage) const override {
+    UnaryMinusOperator *object = storage->Create<UnaryMinusOperator>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using UnaryOperator::UnaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class IsNullOperator : public memgraph::query::UnaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  IsNullOperator *Clone(AstStorage *storage) const override {
+    IsNullOperator *object = storage->Create<IsNullOperator>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using UnaryOperator::UnaryOperator;
+
+ private:
+  friend class AstStorage;
+};
+
+class Aggregation : public memgraph::query::BinaryOperator {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT_LIST, COLLECT_MAP, PROJECT };
+
+  Aggregation() = default;
+
+  static const constexpr char *const kCount = "COUNT";
+  static const constexpr char *const kMin = "MIN";
+  static const constexpr char *const kMax = "MAX";
+  static const constexpr char *const kSum = "SUM";
+  static const constexpr char *const kAvg = "AVG";
+  static const constexpr char *const kCollect = "COLLECT";
+  static const constexpr char *const kProject = "PROJECT";
+
+  static std::string OpToString(Op op) {
+    const char *op_strings[] = {kCount, kMin, kMax, kSum, kAvg, kCollect, kCollect, kProject};
+    return op_strings[static_cast<int>(op)];
+  }
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      if (expression1_) expression1_->Accept(visitor);
+      if (expression2_) expression2_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  Aggregation *MapTo(const Symbol &symbol) {
+    symbol_pos_ = symbol.position();
+    return this;
+  }
+
+  memgraph::query::Aggregation::Op op_;
+  /// Symbol table position of the symbol this Aggregation is mapped to.
+  int32_t symbol_pos_{-1};
+  bool distinct_{false};
+
+  Aggregation *Clone(AstStorage *storage) const override {
+    Aggregation *object = storage->Create<Aggregation>();
+    object->expression1_ = expression1_ ? expression1_->Clone(storage) : nullptr;
+    object->expression2_ = expression2_ ? expression2_->Clone(storage) : nullptr;
+    object->op_ = op_;
+    object->symbol_pos_ = symbol_pos_;
+    object->distinct_ = distinct_;
+    return object;
+  }
+
+ protected:
+  // Use only for serialization.
+  explicit Aggregation(Op op) : op_(op) {}
+
+  /// Aggregation's first expression is the value being aggregated. The second
+  /// expression is the key used only in COLLECT_MAP.
+  Aggregation(Expression *expression1, Expression *expression2, Op op, bool distinct)
+      : BinaryOperator(expression1, expression2), op_(op), distinct_(distinct) {
+    // COUNT without expression denotes COUNT(*) in cypher.
+    DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT, "All aggregations, except COUNT require expression");
+    DMG_ASSERT((expression2 == nullptr) ^ (op == Aggregation::Op::COLLECT_MAP),
+               "The second expression is obligatory in COLLECT_MAP and "
+               "invalid otherwise");
+  }
+
+ private:
+  friend class AstStorage;
+};
+
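Because `op_strings` is indexed positionally by `Op`, `kCollect` intentionally appears twice so that both collection variants print the same keyword:

    // COLLECT_LIST and COLLECT_MAP both render as "COLLECT".
    Aggregation::OpToString(Aggregation::Op::COLLECT_LIST);  // "COLLECT"
    Aggregation::OpToString(Aggregation::Op::COLLECT_MAP);   // "COLLECT"
    Aggregation::OpToString(Aggregation::Op::PROJECT);       // "PROJECT"
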
+class ListSlicingOperator : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ListSlicingOperator() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = list_->Accept(visitor);
+      if (cont && lower_bound_) {
+        cont = lower_bound_->Accept(visitor);
+      }
+      if (cont && upper_bound_) {
+        upper_bound_->Accept(visitor);
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Expression *list_{nullptr};
+  memgraph::query::Expression *lower_bound_{nullptr};
+  memgraph::query::Expression *upper_bound_{nullptr};
+
+  ListSlicingOperator *Clone(AstStorage *storage) const override {
+    ListSlicingOperator *object = storage->Create<ListSlicingOperator>();
+    object->list_ = list_ ? list_->Clone(storage) : nullptr;
+    object->lower_bound_ = lower_bound_ ? lower_bound_->Clone(storage) : nullptr;
+    object->upper_bound_ = upper_bound_ ? upper_bound_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  ListSlicingOperator(Expression *list, Expression *lower_bound, Expression *upper_bound)
+      : list_(list), lower_bound_(lower_bound), upper_bound_(upper_bound) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class IfOperator : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  IfOperator() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      condition_->Accept(visitor) && then_expression_->Accept(visitor) && else_expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  /// None of the expressions should be nullptr. If there is no else_expression, you should make it a null
+  /// PrimitiveLiteral.
+  memgraph::query::Expression *condition_;
+  memgraph::query::Expression *then_expression_;
+  memgraph::query::Expression *else_expression_;
+
+  IfOperator *Clone(AstStorage *storage) const override {
+    IfOperator *object = storage->Create<IfOperator>();
+    object->condition_ = condition_ ? condition_->Clone(storage) : nullptr;
+    object->then_expression_ = then_expression_ ? then_expression_->Clone(storage) : nullptr;
+    object->else_expression_ = else_expression_ ? else_expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  IfOperator(Expression *condition, Expression *then_expression, Expression *else_expression)
+      : condition_(condition), then_expression_(then_expression), else_expression_(else_expression) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class BaseLiteral : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  BaseLiteral() = default;
+
+  BaseLiteral *Clone(AstStorage *storage) const override = 0;
+
+ private:
+  friend class AstStorage;
+};
+
+class PrimitiveLiteral : public memgraph::query::BaseLiteral {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  PrimitiveLiteral() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  DEFVISITABLE(HierarchicalTreeVisitor);
+
+  storage::PropertyValue value_;
+  /// This field contains the token position of the literal used to create the PrimitiveLiteral object. If the
+  /// PrimitiveLiteral object is not created from a query, leave its value at -1.
+  int32_t token_position_{-1};
+
+  PrimitiveLiteral *Clone(AstStorage *storage) const override {
+    PrimitiveLiteral *object = storage->Create<PrimitiveLiteral>();
+    object->value_ = value_;
+    object->token_position_ = token_position_;
+    return object;
+  }
+
+ protected:
+  template <typename T>
+  explicit PrimitiveLiteral(T value) : value_(value) {}
+  template <typename T>
+  PrimitiveLiteral(T value, int token_position) : value_(value), token_position_(token_position) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class ListLiteral : public memgraph::query::BaseLiteral {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ListLiteral() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto expr_ptr : elements_)
+        if (!expr_ptr->Accept(visitor)) break;
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::Expression *> elements_;
+
+  ListLiteral *Clone(AstStorage *storage) const override {
+    ListLiteral *object = storage->Create<ListLiteral>();
+    object->elements_.resize(elements_.size());
+    for (auto i0 = 0; i0 < elements_.size(); ++i0) {
+      object->elements_[i0] = elements_[i0] ? elements_[i0]->Clone(storage) : nullptr;
+    }
+    return object;
+  }
+
+ protected:
+  explicit ListLiteral(const std::vector<Expression *> &elements) : elements_(elements) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class MapLiteral : public memgraph::query::BaseLiteral {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  MapLiteral() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto pair : elements_)
+        if (!pair.second->Accept(visitor)) break;
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *> elements_;
+
+  MapLiteral *Clone(AstStorage *storage) const override {
+    MapLiteral *object = storage->Create<MapLiteral>();
+    for (const auto &entry : elements_) {
+      PropertyIx key = storage->GetPropertyIx(entry.first.name);
+      object->elements_[key] = entry.second->Clone(storage);
+    }
+    return object;
+  }
+
+ protected:
+  explicit MapLiteral(const std::unordered_map<PropertyIx, Expression *> &elements) : elements_(elements) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Identifier : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Identifier() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  DEFVISITABLE(HierarchicalTreeVisitor);
+
+  Identifier *MapTo(const Symbol &symbol) {
+    symbol_pos_ = symbol.position();
+    return this;
+  }
+
+  explicit Identifier(const std::string &name) : name_(name) {}
+  Identifier(const std::string &name, bool user_declared) : name_(name), user_declared_(user_declared) {}
+
+  std::string name_;
+  bool user_declared_{true};
+  /// Symbol table position of the symbol this Identifier is mapped to.
+  int32_t symbol_pos_{-1};
+
+  Identifier *Clone(AstStorage *storage) const override {
+    Identifier *object = storage->Create<Identifier>();
+    object->name_ = name_;
+    object->user_declared_ = user_declared_;
+    object->symbol_pos_ = symbol_pos_;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class PropertyLookup : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  PropertyLookup() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Expression *expression_{nullptr};
+  memgraph::query::PropertyIx property_;
+
+  PropertyLookup *Clone(AstStorage *storage) const override {
+    PropertyLookup *object = storage->Create<PropertyLookup>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    object->property_ = storage->GetPropertyIx(property_.name);
+    return object;
+  }
+
+ protected:
+  PropertyLookup(Expression *expression, PropertyIx property) : expression_(expression), property_(property) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class LabelsTest : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  LabelsTest() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Expression *expression_{nullptr};
+  std::vector<memgraph::query::LabelIx> labels_;
+
+  LabelsTest *Clone(AstStorage *storage) const override {
+    LabelsTest *object = storage->Create<LabelsTest>();
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    object->labels_.resize(labels_.size());
+    for (auto i = 0; i < object->labels_.size(); ++i) {
+      object->labels_[i] = storage->GetLabelIx(labels_[i].name);
+    }
+    return object;
+  }
+
+ protected:
+  LabelsTest(Expression *expression, const std::vector<LabelIx> &labels) : expression_(expression), labels_(labels) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Function : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Function() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto *argument : arguments_) {
+        if (!argument->Accept(visitor)) break;
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::Expression *> arguments_;
+  std::string function_name_;
+  std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &)> function_;
+
+  Function *Clone(AstStorage *storage) const override {
+    Function *object = storage->Create<Function>();
+    object->arguments_.resize(arguments_.size());
+    for (auto i1 = 0; i1 < arguments_.size(); ++i1) {
+      object->arguments_[i1] = arguments_[i1] ? arguments_[i1]->Clone(storage) : nullptr;
+    }
+    object->function_name_ = function_name_;
+    object->function_ = function_;
+    return object;
+  }
+
+ protected:
+  Function(const std::string &function_name, const std::vector<Expression *> &arguments)
+      : arguments_(arguments), function_name_(function_name), function_(NameToFunction(function_name_)) {
+    if (!function_) {
+      throw SemanticException("Function '{}' doesn't exist.", function_name);
+    }
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class Reduce : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Reduce() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      accumulator_->Accept(visitor) && initializer_->Accept(visitor) && identifier_->Accept(visitor) &&
+          list_->Accept(visitor) && expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  /// Identifier for the accumulating variable
+  memgraph::query::Identifier *accumulator_{nullptr};
+  /// Expression which produces the initial accumulator value.
+  memgraph::query::Expression *initializer_{nullptr};
+  /// Identifier for the list element.
+  memgraph::query::Identifier *identifier_{nullptr};
+  /// Expression which produces a list to be reduced.
+  memgraph::query::Expression *list_{nullptr};
+  /// Expression which does the reduction, i.e. produces the new accumulator value.
+  memgraph::query::Expression *expression_{nullptr};
+
+  Reduce *Clone(AstStorage *storage) const override {
+    Reduce *object = storage->Create<Reduce>();
+    object->accumulator_ = accumulator_ ? accumulator_->Clone(storage) : nullptr;
+    object->initializer_ = initializer_ ? initializer_->Clone(storage) : nullptr;
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_ = list_ ? list_->Clone(storage) : nullptr;
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  Reduce(Identifier *accumulator, Expression *initializer, Identifier *identifier, Expression *list,
+         Expression *expression)
+      : accumulator_(accumulator),
+        initializer_(initializer),
+        identifier_(identifier),
+        list_(list),
+        expression_(expression) {}
+
+ private:
+  friend class AstStorage;
+};
+
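The five `Reduce` members line up one-to-one with the parts of an openCypher `reduce` expression; an illustrative mapping (the query fragment is an example, not taken from this patch):

    // reduce(acc = 0, x IN [1, 2, 3] | acc + x)
    //   accumulator_  ->  acc
    //   initializer_  ->  0
    //   identifier_   ->  x
    //   list_         ->  [1, 2, 3]
    //   expression_   ->  acc + x
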
+class Coalesce : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Coalesce() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto *expr : expressions_) {
+        if (!expr->Accept(visitor)) break;
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  /// A list of expressions to evaluate. None of the expressions should be nullptr.
+  std::vector<memgraph::query::Expression *> expressions_;
+
+  Coalesce *Clone(AstStorage *storage) const override {
+    Coalesce *object = storage->Create<Coalesce>();
+    object->expressions_.resize(expressions_.size());
+    for (auto i2 = 0; i2 < expressions_.size(); ++i2) {
+      object->expressions_[i2] = expressions_[i2] ? expressions_[i2]->Clone(storage) : nullptr;
+    }
+    return object;
+  }
+
+ private:
+  explicit Coalesce(const std::vector<Expression *> &expressions) : expressions_(expressions) {}
+
+  friend class AstStorage;
+};
+
+class Extract : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Extract() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_->Accept(visitor) && expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  /// Identifier for the list element.
+  memgraph::query::Identifier *identifier_{nullptr};
+  /// Expression which produces a list which will be extracted.
+  memgraph::query::Expression *list_{nullptr};
+  /// Expression which produces the new value for list element.
+  memgraph::query::Expression *expression_{nullptr};
+
+  Extract *Clone(AstStorage *storage) const override {
+    Extract *object = storage->Create<Extract>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_ = list_ ? list_->Clone(storage) : nullptr;
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  Extract(Identifier *identifier, Expression *list, Expression *expression)
+      : identifier_(identifier), list_(list), expression_(expression) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class All : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  All() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_expression_->Accept(visitor) && where_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Identifier *identifier_{nullptr};
+  memgraph::query::Expression *list_expression_{nullptr};
+  memgraph::query::Where *where_{nullptr};
+
+  All *Clone(AstStorage *storage) const override {
+    All *object = storage->Create<All>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_expression_ = list_expression_ ? list_expression_->Clone(storage) : nullptr;
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  All(Identifier *identifier, Expression *list_expression, Where *where)
+      : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Single : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Single() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_expression_->Accept(visitor) && where_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Identifier *identifier_{nullptr};
+  memgraph::query::Expression *list_expression_{nullptr};
+  memgraph::query::Where *where_{nullptr};
+
+  Single *Clone(AstStorage *storage) const override {
+    Single *object = storage->Create<Single>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_expression_ = list_expression_ ? list_expression_->Clone(storage) : nullptr;
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  Single(Identifier *identifier, Expression *list_expression, Where *where)
+      : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Any : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Any() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_expression_->Accept(visitor) && where_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Identifier *identifier_{nullptr};
+  memgraph::query::Expression *list_expression_{nullptr};
+  memgraph::query::Where *where_{nullptr};
+
+  Any *Clone(AstStorage *storage) const override {
+    Any *object = storage->Create<Any>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_expression_ = list_expression_ ? list_expression_->Clone(storage) : nullptr;
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  Any(Identifier *identifier, Expression *list_expression, Where *where)
+      : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class None : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  None() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      identifier_->Accept(visitor) && list_expression_->Accept(visitor) && where_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Identifier *identifier_{nullptr};
+  memgraph::query::Expression *list_expression_{nullptr};
+  memgraph::query::Where *where_{nullptr};
+
+  None *Clone(AstStorage *storage) const override {
+    None *object = storage->Create<None>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->list_expression_ = list_expression_ ? list_expression_->Clone(storage) : nullptr;
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  None(Identifier *identifier, Expression *list_expression, Where *where)
+      : identifier_(identifier), list_expression_(list_expression), where_(where) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class ParameterLookup : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ParameterLookup() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  DEFVISITABLE(HierarchicalTreeVisitor);
+
+  /// This field contains the token position of the *literal* used to create the ParameterLookup object. If the
+  /// ParameterLookup object is not created from a literal, leave this value at -1.
+  int32_t token_position_{-1};
+
+  ParameterLookup *Clone(AstStorage *storage) const override {
+    ParameterLookup *object = storage->Create<ParameterLookup>();
+    object->token_position_ = token_position_;
+    return object;
+  }
+
+ protected:
+  explicit ParameterLookup(int token_position) : token_position_(token_position) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class RegexMatch : public memgraph::query::Expression {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  RegexMatch() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      string_expr_->Accept(visitor) && regex_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Expression *string_expr_;
+  memgraph::query::Expression *regex_;
+
+  RegexMatch *Clone(AstStorage *storage) const override {
+    RegexMatch *object = storage->Create<RegexMatch>();
+    object->string_expr_ = string_expr_ ? string_expr_->Clone(storage) : nullptr;
+    object->regex_ = regex_ ? regex_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+  RegexMatch(Expression *string_expr, Expression *regex) : string_expr_(string_expr), regex_(regex) {}
+};
+
+class NamedExpression : public memgraph::query::Tree,
+                        public utils::Visitable<HierarchicalTreeVisitor>,
+                        public utils::Visitable<ExpressionVisitor<TypedValue>>,
+                        public utils::Visitable<ExpressionVisitor<TypedValue *>>,
+                        public utils::Visitable<ExpressionVisitor<void>> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<ExpressionVisitor<TypedValue>>::Accept;
+  using utils::Visitable<ExpressionVisitor<void>>::Accept;
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  NamedExpression() = default;
+
+  DEFVISITABLE(ExpressionVisitor<TypedValue>);
+  DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+  DEFVISITABLE(ExpressionVisitor<void>);
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      expression_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  NamedExpression *MapTo(const Symbol &symbol) {
+    symbol_pos_ = symbol.position();
+    return this;
+  }
+
+  std::string name_;
+  memgraph::query::Expression *expression_{nullptr};
+  /// This field contains the token position of the first token in the named expression used to create name_. If the
+  /// NamedExpression object is not created from a query, or it is aliased, leave this value at -1.
+  int32_t token_position_{-1};
+  /// Symbol table position of the symbol this NamedExpression is mapped to.
+  int32_t symbol_pos_{-1};
+
+  NamedExpression *Clone(AstStorage *storage) const override {
+    NamedExpression *object = storage->Create<NamedExpression>();
+    object->name_ = name_;
+    object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+    object->token_position_ = token_position_;
+    object->symbol_pos_ = symbol_pos_;
+    return object;
+  }
+
+ protected:
+  explicit NamedExpression(const std::string &name) : name_(name) {}
+  NamedExpression(const std::string &name, Expression *expression) : name_(name), expression_(expression) {}
+  NamedExpression(const std::string &name, Expression *expression, int token_position)
+      : name_(name), expression_(expression), token_position_(token_position) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class PatternAtom : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  PatternAtom() = default;
+
+  memgraph::query::Identifier *identifier_{nullptr};
+
+  PatternAtom *Clone(AstStorage *storage) const override = 0;
+
+ protected:
+  explicit PatternAtom(Identifier *identifier) : identifier_(identifier) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class NodeAtom : public memgraph::query::PatternAtom {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&properties_)) {
+        bool cont = identifier_->Accept(visitor);
+        for (auto &property : *properties) {
+          if (cont) {
+            cont = property.second->Accept(visitor);
+          }
+        }
+      } else {
+        std::get<ParameterLookup *>(properties_)->Accept(visitor);
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::LabelIx> labels_;
+  std::variant<std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>,
+               memgraph::query::ParameterLookup *>
+      properties_;
+
+  NodeAtom *Clone(AstStorage *storage) const override {
+    NodeAtom *object = storage->Create<NodeAtom>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->labels_.resize(labels_.size());
+    for (auto i = 0; i < object->labels_.size(); ++i) {
+      object->labels_[i] = storage->GetLabelIx(labels_[i].name);
+    }
+    if (const auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&properties_)) {
+      auto &new_obj_properties = std::get<std::unordered_map<PropertyIx, Expression *>>(object->properties_);
+      for (const auto &[property, value_expression] : *properties) {
+        PropertyIx key = storage->GetPropertyIx(property.name);
+        new_obj_properties[key] = value_expression->Clone(storage);
+      }
+    } else {
+      object->properties_ = std::get<ParameterLookup *>(properties_)->Clone(storage);
+    }
+    return object;
+  }
+
+ protected:
+  using PatternAtom::PatternAtom;
+
+ private:
+  friend class AstStorage;
+};
+
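Note that `Clone` never copies a `LabelIx` or `PropertyIx` by value: as the loops above show, it re-interns each name via `storage->GetLabelIx(...)`/`storage->GetPropertyIx(...)`, so the indices in the clone belong to the destination storage. A sketch of the consequence (the construction path is assumed; only the `GetLabelIx` call is taken from the code above):

    AstStorage a;
    auto *node = a.Create<NodeAtom>();
    node->identifier_ = a.Create<Identifier>("n");
    node->labels_.push_back(a.GetLabelIx("Person"));

    // The clone's LabelIx is interned in `b`: same label name, but its
    // index position may differ from the one in `a`.
    AstStorage b;
    NodeAtom *copy = node->Clone(&b);
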
+class EdgeAtom : public memgraph::query::PatternAtom {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  enum class Type { SINGLE, DEPTH_FIRST, BREADTH_FIRST, WEIGHTED_SHORTEST_PATH, ALL_SHORTEST_PATHS };
+
+  enum class Direction { IN, OUT, BOTH };
+
+  /// Lambda for use in filtering or weight calculation during variable expand.
+  struct Lambda {
+    static const utils::TypeInfo kType;
+    const utils::TypeInfo &GetTypeInfo() const { return kType; }
+
+    /// Argument identifier for the edge currently being traversed.
+    memgraph::query::Identifier *inner_edge{nullptr};
+    /// Argument identifier for the destination node of the edge.
+    memgraph::query::Identifier *inner_node{nullptr};
+    /// Evaluates the result of the lambda.
+    memgraph::query::Expression *expression{nullptr};
+
+    Lambda Clone(AstStorage *storage) const {
+      Lambda object;
+      object.inner_edge = inner_edge ? inner_edge->Clone(storage) : nullptr;
+      object.inner_node = inner_node ? inner_node->Clone(storage) : nullptr;
+      object.expression = expression ? expression->Clone(storage) : nullptr;
+      return object;
+    }
+  };
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = identifier_->Accept(visitor);
+      if (auto *properties = std::get_if<std::unordered_map<query::PropertyIx, query::Expression *>>(&properties_)) {
+        for (auto &property : *properties) {
+          if (cont) {
+            cont = property.second->Accept(visitor);
+          }
+        }
+      } else {
+        std::get<ParameterLookup *>(properties_)->Accept(visitor);
+      }
+      if (cont && lower_bound_) {
+        cont = lower_bound_->Accept(visitor);
+      }
+      if (cont && upper_bound_) {
+        cont = upper_bound_->Accept(visitor);
+      }
+      if (cont && total_weight_) {
+        total_weight_->Accept(visitor);
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  bool IsVariable() const {
+    switch (type_) {
+      case Type::DEPTH_FIRST:
+      case Type::BREADTH_FIRST:
+      case Type::WEIGHTED_SHORTEST_PATH:
+      case Type::ALL_SHORTEST_PATHS:
+        return true;
+      case Type::SINGLE:
+        return false;
+    }
+  }
+
+  memgraph::query::EdgeAtom::Type type_{Type::SINGLE};
+  memgraph::query::EdgeAtom::Direction direction_{Direction::BOTH};
+  std::vector<memgraph::query::EdgeTypeIx> edge_types_;
+  std::variant<std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>,
+               memgraph::query::ParameterLookup *>
+      properties_;
+  /// Evaluates to the lower bound in variable length expands.
+  memgraph::query::Expression *lower_bound_{nullptr};
+  /// Evaluates to the upper bound in variable length expands.
+  memgraph::query::Expression *upper_bound_{nullptr};
+  /// Filter lambda for variable length expands. Can have an empty expression, but identifiers must be valid, because
+  /// an optimization pass may inline other expressions into this lambda.
+  memgraph::query::EdgeAtom::Lambda filter_lambda_;
+  /// Used in weighted shortest path. It must have valid expressions and identifiers. In all other expand types, it is
+  /// empty.
+  memgraph::query::EdgeAtom::Lambda weight_lambda_;
+  /// Variable where the total weight for weighted shortest path will be stored.
+  memgraph::query::Identifier *total_weight_{nullptr};
+
+  EdgeAtom *Clone(AstStorage *storage) const override {
+    EdgeAtom *object = storage->Create<EdgeAtom>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->type_ = type_;
+    object->direction_ = direction_;
+    object->edge_types_.resize(edge_types_.size());
+    for (auto i = 0; i < object->edge_types_.size(); ++i) {
+      object->edge_types_[i] = storage->GetEdgeTypeIx(edge_types_[i].name);
+    }
+    if (const auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&properties_)) {
+      auto &new_obj_properties = std::get<std::unordered_map<PropertyIx, Expression *>>(object->properties_);
+      for (const auto &[property, value_expression] : *properties) {
+        PropertyIx key = storage->GetPropertyIx(property.name);
+        new_obj_properties[key] = value_expression->Clone(storage);
+      }
+    } else {
+      object->properties_ = std::get<ParameterLookup *>(properties_)->Clone(storage);
+    }
+    object->lower_bound_ = lower_bound_ ? lower_bound_->Clone(storage) : nullptr;
+    object->upper_bound_ = upper_bound_ ? upper_bound_->Clone(storage) : nullptr;
+    object->filter_lambda_ = filter_lambda_.Clone(storage);
+    object->weight_lambda_ = weight_lambda_.Clone(storage);
+    object->total_weight_ = total_weight_ ? total_weight_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  using PatternAtom::PatternAtom;
+  EdgeAtom(Identifier *identifier, Type type, Direction direction)
+      : PatternAtom(identifier), type_(type), direction_(direction) {}
+
+  // Creates an edge atom for a SINGLE expansion with the given edge types.
+  EdgeAtom(Identifier *identifier, Type type, Direction direction, const std::vector<EdgeTypeIx> &edge_types)
+      : PatternAtom(identifier), type_(type), direction_(direction), edge_types_(edge_types) {}
+
+ private:
+  friend class AstStorage;
+};
+
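`filter_lambda_`, `weight_lambda_` and `total_weight_` carry the lambda part of Memgraph's variable-expansion syntax. An illustrative correspondence (the concrete query syntax is assumed from Memgraph's query language, not defined in this patch):

    // MATCH (a)-[e *wShortest (e, v | e.weight) total]->(b)
    //   weight_lambda_.inner_edge  ->  e
    //   weight_lambda_.inner_node  ->  v
    //   weight_lambda_.expression  ->  e.weight
    //   total_weight_              ->  total
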
+class Pattern : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  Pattern() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = identifier_->Accept(visitor);
+      for (auto &part : atoms_) {
+        if (cont) {
+          cont = part->Accept(visitor);
+        }
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::Identifier *identifier_{nullptr};
+  std::vector<memgraph::query::PatternAtom *> atoms_;
+
+  Pattern *Clone(AstStorage *storage) const override {
+    Pattern *object = storage->Create<Pattern>();
+    object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr;
+    object->atoms_.resize(atoms_.size());
+    for (auto i3 = 0; i3 < atoms_.size(); ++i3) {
+      object->atoms_[i3] = atoms_[i3] ? atoms_[i3]->Clone(storage) : nullptr;
+    }
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class Clause : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  Clause() = default;
+
+  Clause *Clone(AstStorage *storage) const override = 0;
+
+ private:
+  friend class AstStorage;
+};
+
+class SingleQuery : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  SingleQuery() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto &clause : clauses_) {
+        if (!clause->Accept(visitor)) break;
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::Clause *> clauses_;
+
+  SingleQuery *Clone(AstStorage *storage) const override {
+    SingleQuery *object = storage->Create<SingleQuery>();
+    object->clauses_.resize(clauses_.size());
+    for (auto i4 = 0; i4 < clauses_.size(); ++i4) {
+      object->clauses_[i4] = clauses_[i4] ? clauses_[i4]->Clone(storage) : nullptr;
+    }
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class CypherUnion : public memgraph::query::Tree, public utils::Visitable<HierarchicalTreeVisitor> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<HierarchicalTreeVisitor>::Accept;
+
+  CypherUnion() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      single_query_->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::SingleQuery *single_query_{nullptr};
+  bool distinct_{false};
+  /// Holds symbols that are created during symbol generation phase. These symbols are used when UNION/UNION ALL
+  /// combines single query results.
+  std::vector<Symbol> union_symbols_;
+
+  CypherUnion *Clone(AstStorage *storage) const override {
+    CypherUnion *object = storage->Create<CypherUnion>();
+    object->single_query_ = single_query_ ? single_query_->Clone(storage) : nullptr;
+    object->distinct_ = distinct_;
+    object->union_symbols_ = union_symbols_;
+    return object;
+  }
+
+ protected:
+  explicit CypherUnion(bool distinct) : distinct_(distinct) {}
+  CypherUnion(bool distinct, SingleQuery *single_query, std::vector<Symbol> union_symbols)
+      : single_query_(single_query), distinct_(distinct), union_symbols_(union_symbols) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Query : public memgraph::query::Tree, public utils::Visitable<QueryVisitor<void>> {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  using utils::Visitable<QueryVisitor<void>>::Accept;
+
+  Query() = default;
+
+  Query *Clone(AstStorage *storage) const override = 0;
+
+ private:
+  friend class AstStorage;
+};
+
+class CypherQuery : public memgraph::query::Query {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  CypherQuery() = default;
+
+  DEFVISITABLE(QueryVisitor<void>);
+
+  /// First and potentially only query.
+  memgraph::query::SingleQuery *single_query_{nullptr};
+  /// Contains remaining queries that should form a union with `single_query_`.
+  std::vector<memgraph::query::CypherUnion *> cypher_unions_;
+  memgraph::query::Expression *memory_limit_{nullptr};
+  size_t memory_scale_{1024U};
+
+  CypherQuery *Clone(AstStorage *storage) const override {
+    CypherQuery *object = storage->Create<CypherQuery>();
+    object->single_query_ = single_query_ ? single_query_->Clone(storage) : nullptr;
+    object->cypher_unions_.resize(cypher_unions_.size());
+    for (auto i5 = 0; i5 < cypher_unions_.size(); ++i5) {
+      object->cypher_unions_[i5] = cypher_unions_[i5] ? cypher_unions_[i5]->Clone(storage) : nullptr;
+    }
+    object->memory_limit_ = memory_limit_ ? memory_limit_->Clone(storage) : nullptr;
+    object->memory_scale_ = memory_scale_;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class ExplainQuery : public memgraph::query::Query {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ExplainQuery() = default;
+
+  DEFVISITABLE(QueryVisitor<void>);
+
+  /// The CypherQuery to explain.
+  memgraph::query::CypherQuery *cypher_query_{nullptr};
+
+  ExplainQuery *Clone(AstStorage *storage) const override {
+    ExplainQuery *object = storage->Create<ExplainQuery>();
+    object->cypher_query_ = cypher_query_ ? cypher_query_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class ProfileQuery : public memgraph::query::Query {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  ProfileQuery() = default;
+
+  DEFVISITABLE(QueryVisitor<void>);
+
+  /// The CypherQuery to profile.
+  memgraph::query::CypherQuery *cypher_query_{nullptr};
+
+  ProfileQuery *Clone(AstStorage *storage) const override {
+    ProfileQuery *object = storage->Create<ProfileQuery>();
+    object->cypher_query_ = cypher_query_ ? cypher_query_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class IndexQuery : public memgraph::query::Query {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  enum class Action { CREATE, DROP };
+
+  IndexQuery() = default;
+
+  DEFVISITABLE(QueryVisitor<void>);
+
+  memgraph::query::IndexQuery::Action action_;
+  memgraph::query::LabelIx label_;
+  std::vector<memgraph::query::PropertyIx> properties_;
+
+  IndexQuery *Clone(AstStorage *storage) const override {
+    IndexQuery *object = storage->Create<IndexQuery>();
+    object->action_ = action_;
+    object->label_ = storage->GetLabelIx(label_.name);
+    object->properties_.resize(properties_.size());
+    for (auto i = 0; i < object->properties_.size(); ++i) {
+      object->properties_[i] = storage->GetPropertyIx(properties_[i].name);
+    }
+    return object;
+  }
+
+ protected:
+  IndexQuery(Action action, LabelIx label, std::vector<PropertyIx> properties)
+      : action_(action), label_(label), properties_(properties) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class Create : public memgraph::query::Clause {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Create() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      for (auto &pattern : patterns_) {
+        if (!pattern->Accept(visitor)) break;
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::Pattern *> patterns_;
+
+  Create *Clone(AstStorage *storage) const override {
+    Create *object = storage->Create<Create>();
+    object->patterns_.resize(patterns_.size());
+    for (auto i6 = 0; i6 < patterns_.size(); ++i6) {
+      object->patterns_[i6] = patterns_[i6] ? patterns_[i6]->Clone(storage) : nullptr;
+    }
+    return object;
+  }
+
+ protected:
+  explicit Create(std::vector<Pattern *> patterns) : patterns_(patterns) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class CallProcedure : public memgraph::query::Clause {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  CallProcedure() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = true;
+      for (auto &arg : arguments_) {
+        if (!arg->Accept(visitor)) {
+          cont = false;
+          break;
+        }
+      }
+      if (cont) {
+        for (auto &ident : result_identifiers_) {
+          if (!ident->Accept(visitor)) {
+            cont = false;
+            break;
+          }
+        }
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::string procedure_name_;
+  std::vector<memgraph::query::Expression *> arguments_;
+  std::vector<std::string> result_fields_;
+  std::vector<memgraph::query::Identifier *> result_identifiers_;
+  memgraph::query::Expression *memory_limit_{nullptr};
+  size_t memory_scale_{1024U};
+  bool is_write_;
+
+  CallProcedure *Clone(AstStorage *storage) const override {
+    CallProcedure *object = storage->Create<CallProcedure>();
+    object->procedure_name_ = procedure_name_;
+    object->arguments_.resize(arguments_.size());
+    for (auto i7 = 0; i7 < arguments_.size(); ++i7) {
+      object->arguments_[i7] = arguments_[i7] ? arguments_[i7]->Clone(storage) : nullptr;
+    }
+    object->result_fields_ = result_fields_;
+    object->result_identifiers_.resize(result_identifiers_.size());
+    for (auto i8 = 0; i8 < result_identifiers_.size(); ++i8) {
+      object->result_identifiers_[i8] = result_identifiers_[i8] ? result_identifiers_[i8]->Clone(storage) : nullptr;
+    }
+    object->memory_limit_ = memory_limit_ ? memory_limit_->Clone(storage) : nullptr;
+    object->memory_scale_ = memory_scale_;
+    object->is_write_ = is_write_;
+    return object;
+  }
+
+ private:
+  friend class AstStorage;
+};
+
+class Match : public memgraph::query::Clause {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Match() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = true;
+      for (auto &pattern : patterns_) {
+        if (!pattern->Accept(visitor)) {
+          cont = false;
+          break;
+        }
+      }
+      if (cont && where_) {
+        where_->Accept(visitor);
+      }
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  std::vector<memgraph::query::Pattern *> patterns_;
+  memgraph::query::Where *where_{nullptr};
+  bool optional_{false};
+
+  Match *Clone(AstStorage *storage) const override {
+    Match *object = storage->Create<Match>();
+    object->patterns_.resize(patterns_.size());
+    for (auto i9 = 0; i9 < patterns_.size(); ++i9) {
+      object->patterns_[i9] = patterns_[i9] ? patterns_[i9]->Clone(storage) : nullptr;
+    }
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    object->optional_ = optional_;
+    return object;
+  }
+
+ protected:
+  explicit Match(bool optional) : optional_(optional) {}
+  Match(bool optional, Where *where, std::vector<Pattern *> patterns)
+      : patterns_(patterns), where_(where), optional_(optional) {}
+
+ private:
+  friend class AstStorage;
+};
+
+/// Defines the order for sorting values (ascending or descending).
+enum class Ordering { ASC, DESC };
+
+struct SortItem {
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const { return kType; }
+
+  memgraph::query::Ordering ordering;
+  memgraph::query::Expression *expression;
+
+  SortItem Clone(AstStorage *storage) const {
+    SortItem object;
+    object.ordering = ordering;
+    object.expression = expression ? expression->Clone(storage) : nullptr;
+    return object;
+  }
+};
+
+/// Contents common to @c Return and @c With clauses.
+struct ReturnBody {
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const { return kType; }
+
+  /// True if distinct results should be produced.
+  bool distinct{false};
+  /// True if asterisk was found in the return body.
+  bool all_identifiers{false};
+  /// Expressions which are used to produce results.
+  std::vector<memgraph::query::NamedExpression *> named_expressions;
+  /// Expressions used for ordering the results.
+  std::vector<memgraph::query::SortItem> order_by;
+  /// Optional expression on how many results to skip.
+  memgraph::query::Expression *skip{nullptr};
+  /// Optional expression on how many results to produce.
+  memgraph::query::Expression *limit{nullptr};
+
+  ReturnBody Clone(AstStorage *storage) const {
+    ReturnBody object;
+    object.distinct = distinct;
+    object.all_identifiers = all_identifiers;
+    object.named_expressions.resize(named_expressions.size());
+    for (auto i10 = 0; i10 < named_expressions.size(); ++i10) {
+      object.named_expressions[i10] = named_expressions[i10] ? named_expressions[i10]->Clone(storage) : nullptr;
+    }
+    object.order_by.resize(order_by.size());
+    for (auto i11 = 0; i11 < order_by.size(); ++i11) {
+      object.order_by[i11] = order_by[i11].Clone(storage);
+    }
+    object.skip = skip ? skip->Clone(storage) : nullptr;
+    object.limit = limit ? limit->Clone(storage) : nullptr;
+    return object;
+  }
+};
+
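`ReturnBody` collects everything that may follow the RETURN or WITH keyword; for an illustrative clause the fields fill in as:

    // RETURN DISTINCT n.name AS name ORDER BY name DESC SKIP 1 LIMIT 10
    //   distinct           ->  true
    //   named_expressions  ->  [n.name AS name]
    //   order_by           ->  [(name, DESC)]
    //   skip               ->  1
    //   limit              ->  10
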
+class Return : public memgraph::query::Clause {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  Return() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = true;
+      for (auto &expr : body_.named_expressions) {
+        if (!expr->Accept(visitor)) {
+          cont = false;
+          break;
+        }
+      }
+      if (cont) {
+        for (auto &order_by : body_.order_by) {
+          if (!order_by.expression->Accept(visitor)) {
+            cont = false;
+            break;
+          }
+        }
+      }
+      if (cont && body_.skip) cont = body_.skip->Accept(visitor);
+      if (cont && body_.limit) cont = body_.limit->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::ReturnBody body_;
+
+  Return *Clone(AstStorage *storage) const override {
+    Return *object = storage->Create<Return>();
+    object->body_ = body_.Clone(storage);
+    return object;
+  }
+
+ protected:
+  explicit Return(ReturnBody &body) : body_(body) {}
+
+ private:
+  friend class AstStorage;
+};
+
+class With : public memgraph::query::Clause {
+ public:
+  static const utils::TypeInfo kType;
+  const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+  With() = default;
+
+  bool Accept(HierarchicalTreeVisitor &visitor) override {
+    if (visitor.PreVisit(*this)) {
+      bool cont = true;
+      for (auto &expr : body_.named_expressions) {
+        if (!expr->Accept(visitor)) {
+          cont = false;
+          break;
+        }
+      }
+      if (cont) {
+        for (auto &order_by : body_.order_by) {
+          if (!order_by.expression->Accept(visitor)) {
+            cont = false;
+            break;
+          }
+        }
+      }
+      if (cont && where_) cont = where_->Accept(visitor);
+      if (cont && body_.skip) cont = body_.skip->Accept(visitor);
+      if (cont && body_.limit) cont = body_.limit->Accept(visitor);
+    }
+    return visitor.PostVisit(*this);
+  }
+
+  memgraph::query::ReturnBody body_;
+  memgraph::query::Where *where_{nullptr};
+
+  With *Clone(AstStorage *storage) const override {
+    With *object = storage->Create<With>();
+    object->body_ = body_.Clone(storage);
+    object->where_ = where_ ? where_->Clone(storage) : nullptr;
+    return object;
+  }
+
+ protected:
+  With(ReturnBody &body, Where *where) : body_(body), where_(where) {}
+
+ private:
+  friend class AstStorage;
+};
+
expressions_[i12]->Clone(storage) : nullptr; + } + object->detach_ = detach_; + return object; + } + + protected: + Delete(bool detach, std::vector<Expression *> expressions) : expressions_(expressions), detach_(detach) {} + + private: + friend class AstStorage; +}; + +class SetProperty : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + SetProperty() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + property_lookup_->Accept(visitor) && expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::PropertyLookup *property_lookup_{nullptr}; + memgraph::query::Expression *expression_{nullptr}; + + SetProperty *Clone(AstStorage *storage) const override { + SetProperty *object = storage->Create<SetProperty>(); + object->property_lookup_ = property_lookup_ ? property_lookup_->Clone(storage) : nullptr; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } + + protected: + SetProperty(PropertyLookup *property_lookup, Expression *expression) + : property_lookup_(property_lookup), expression_(expression) {} + + private: + friend class AstStorage; +}; + +class SetProperties : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + SetProperties() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::Identifier *identifier_{nullptr}; + memgraph::query::Expression *expression_{nullptr}; + bool update_{false}; + + SetProperties *Clone(AstStorage *storage) const override { + SetProperties *object = storage->Create<SetProperties>(); + object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + object->update_ = update_; + return object; + } + + protected: + SetProperties(Identifier *identifier, Expression *expression, bool update = false) + : identifier_(identifier), expression_(expression), update_(update) {} + + private: + friend class AstStorage; +}; + +class SetLabels : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + SetLabels() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::Identifier *identifier_{nullptr}; + std::vector<memgraph::query::LabelIx> labels_; + + SetLabels *Clone(AstStorage *storage) const override { + SetLabels *object = storage->Create<SetLabels>(); + object->identifier_ = identifier_ ? 
identifier_->Clone(storage) : nullptr; + object->labels_.resize(labels_.size()); + for (auto i = 0; i < object->labels_.size(); ++i) { + object->labels_[i] = storage->GetLabelIx(labels_[i].name); + } + return object; + } + + protected: + SetLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {} + + private: + friend class AstStorage; +}; + +class RemoveProperty : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + RemoveProperty() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + property_lookup_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::PropertyLookup *property_lookup_{nullptr}; + + RemoveProperty *Clone(AstStorage *storage) const override { + RemoveProperty *object = storage->Create<RemoveProperty>(); + object->property_lookup_ = property_lookup_ ? property_lookup_->Clone(storage) : nullptr; + return object; + } + + protected: + explicit RemoveProperty(PropertyLookup *property_lookup) : property_lookup_(property_lookup) {} + + private: + friend class AstStorage; +}; + +class RemoveLabels : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + RemoveLabels() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::Identifier *identifier_{nullptr}; + std::vector<memgraph::query::LabelIx> labels_; + + RemoveLabels *Clone(AstStorage *storage) const override { + RemoveLabels *object = storage->Create<RemoveLabels>(); + object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr; + object->labels_.resize(labels_.size()); + for (auto i = 0; i < object->labels_.size(); ++i) { + object->labels_[i] = storage->GetLabelIx(labels_[i].name); + } + return object; + } + + protected: + RemoveLabels(Identifier *identifier, const std::vector<LabelIx> &labels) : identifier_(identifier), labels_(labels) {} + + private: + friend class AstStorage; +}; + +class Merge : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Merge() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = pattern_->Accept(visitor); + if (cont) { + for (auto &set : on_match_) { + if (!set->Accept(visitor)) { + cont = false; + break; + } + } + } + if (cont) { + for (auto &set : on_create_) { + if (!set->Accept(visitor)) { + cont = false; + break; + } + } + } + } + return visitor.PostVisit(*this); + } + + memgraph::query::Pattern *pattern_{nullptr}; + std::vector<memgraph::query::Clause *> on_match_; + std::vector<memgraph::query::Clause *> on_create_; + + Merge *Clone(AstStorage *storage) const override { + Merge *object = storage->Create<Merge>(); + object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr; + object->on_match_.resize(on_match_.size()); + for (auto i13 = 0; i13 < on_match_.size(); ++i13) { + object->on_match_[i13] = on_match_[i13] ? on_match_[i13]->Clone(storage) : nullptr; + } + object->on_create_.resize(on_create_.size()); + for (auto i14 = 0; i14 < on_create_.size(); ++i14) { + object->on_create_[i14] = on_create_[i14] ? 
on_create_[i14]->Clone(storage) : nullptr; + } + return object; + } + + protected: + Merge(Pattern *pattern, std::vector<Clause *> on_match, std::vector<Clause *> on_create) + : pattern_(pattern), on_match_(on_match), on_create_(on_create) {} + + private: + friend class AstStorage; +}; + +class Unwind : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Unwind() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + named_expression_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::NamedExpression *named_expression_{nullptr}; + + Unwind *Clone(AstStorage *storage) const override { + Unwind *object = storage->Create<Unwind>(); + object->named_expression_ = named_expression_ ? named_expression_->Clone(storage) : nullptr; + return object; + } + + protected: + explicit Unwind(NamedExpression *named_expression) : named_expression_(named_expression) { + DMG_ASSERT(named_expression, "Unwind cannot take nullptr for named_expression"); + } + + private: + friend class AstStorage; +}; + +class AuthQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { + CREATE_ROLE, + DROP_ROLE, + SHOW_ROLES, + CREATE_USER, + SET_PASSWORD, + DROP_USER, + SHOW_USERS, + SET_ROLE, + CLEAR_ROLE, + GRANT_PRIVILEGE, + DENY_PRIVILEGE, + REVOKE_PRIVILEGE, + SHOW_PRIVILEGES, + SHOW_ROLE_FOR_USER, + SHOW_USERS_FOR_ROLE + }; + + enum class Privilege { + CREATE, + DELETE, + MATCH, + MERGE, + SET, + REMOVE, + INDEX, + STATS, + AUTH, + CONSTRAINT, + DUMP, + REPLICATION, + DURABILITY, + READ_FILE, + FREE_MEMORY, + TRIGGER, + CONFIG, + STREAM, + MODULE_READ, + MODULE_WRITE, + WEBSOCKET + }; + + enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE }; + + AuthQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::AuthQuery::Action action_; + std::string user_; + std::string role_; + std::string user_or_role_; + memgraph::query::Expression *password_{nullptr}; + std::vector<memgraph::query::AuthQuery::Privilege> privileges_; + std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> + label_privileges_; + std::vector<std::unordered_map<memgraph::query::AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> + edge_type_privileges_; + + AuthQuery *Clone(AstStorage *storage) const override { + AuthQuery *object = storage->Create<AuthQuery>(); + object->action_ = action_; + object->user_ = user_; + object->role_ = role_; + object->user_or_role_ = user_or_role_; + object->password_ = password_ ? 
password_->Clone(storage) : nullptr; + object->privileges_ = privileges_; + object->label_privileges_ = label_privileges_; + object->edge_type_privileges_ = edge_type_privileges_; + return object; + } + + protected: + AuthQuery(Action action, std::string user, std::string role, std::string user_or_role, Expression *password, + std::vector<Privilege> privileges, + std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> label_privileges, + std::vector<std::unordered_map<FineGrainedPrivilege, std::vector<std::string>>> edge_type_privileges) + : action_(action), + user_(user), + role_(role), + user_or_role_(user_or_role), + password_(password), + privileges_(privileges), + label_privileges_(label_privileges), + edge_type_privileges_(edge_type_privileges) {} + + private: + friend class AstStorage; +}; + +/// Constant that holds all available privileges. +const std::vector<AuthQuery::Privilege> kPrivilegesAll = { + AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, AuthQuery::Privilege::MATCH, + AuthQuery::Privilege::MERGE, AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, + AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, AuthQuery::Privilege::AUTH, + AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION, + AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, AuthQuery::Privilege::FREE_MEMORY, + AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, + AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, AuthQuery::Privilege::WEBSOCKET}; + +class InfoQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class InfoType { STORAGE, INDEX, CONSTRAINT }; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::InfoQuery::InfoType info_type_; + + InfoQuery *Clone(AstStorage *storage) const override { + InfoQuery *object = storage->Create<InfoQuery>(); + object->info_type_ = info_type_; + return object; + } +}; + +struct Constraint { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + enum class Type { EXISTS, UNIQUE, NODE_KEY }; + + memgraph::query::Constraint::Type type; + memgraph::query::LabelIx label; + std::vector<memgraph::query::PropertyIx> properties; + + Constraint Clone(AstStorage *storage) const { + Constraint object; + object.type = type; + object.label = storage->GetLabelIx(label.name); + object.properties.resize(properties.size()); + for (auto i = 0; i < object.properties.size(); ++i) { + object.properties[i] = storage->GetPropertyIx(properties[i].name); + } + return object; + } +}; + +class ConstraintQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class ActionType { CREATE, DROP }; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::ConstraintQuery::ActionType action_type_; + memgraph::query::Constraint constraint_; + + ConstraintQuery *Clone(AstStorage *storage) const override { + ConstraintQuery *object = storage->Create<ConstraintQuery>(); + object->action_type_ = action_type_; + object->constraint_ = constraint_.Clone(storage); + return object; + } +}; + +class DumpQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + 
DEFVISITABLE(QueryVisitor<void>); + + DumpQuery *Clone(AstStorage *storage) const override { + DumpQuery *object = storage->Create<DumpQuery>(); + return object; + } +}; + +class ReplicationQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { SET_REPLICATION_ROLE, SHOW_REPLICATION_ROLE, REGISTER_REPLICA, DROP_REPLICA, SHOW_REPLICAS }; + + enum class ReplicationRole { MAIN, REPLICA }; + + enum class SyncMode { SYNC, ASYNC }; + + enum class ReplicaState { READY, REPLICATING, RECOVERY, INVALID }; + + ReplicationQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::ReplicationQuery::Action action_; + memgraph::query::ReplicationQuery::ReplicationRole role_; + std::string replica_name_; + memgraph::query::Expression *socket_address_{nullptr}; + memgraph::query::Expression *port_{nullptr}; + memgraph::query::ReplicationQuery::SyncMode sync_mode_; + + ReplicationQuery *Clone(AstStorage *storage) const override { + ReplicationQuery *object = storage->Create<ReplicationQuery>(); + object->action_ = action_; + object->role_ = role_; + object->replica_name_ = replica_name_; + object->socket_address_ = socket_address_ ? socket_address_->Clone(storage) : nullptr; + object->port_ = port_ ? port_->Clone(storage) : nullptr; + object->sync_mode_ = sync_mode_; + return object; + } + + private: + friend class AstStorage; +}; + +class LockPathQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { LOCK_PATH, UNLOCK_PATH }; + + LockPathQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::LockPathQuery::Action action_; + + LockPathQuery *Clone(AstStorage *storage) const override { + LockPathQuery *object = storage->Create<LockPathQuery>(); + object->action_ = action_; + return object; + } + + private: + friend class AstStorage; +}; + +class LoadCsv : public memgraph::query::Clause { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + LoadCsv() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + row_var_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + memgraph::query::Expression *file_; + bool with_header_; + bool ignore_bad_; + memgraph::query::Expression *delimiter_{nullptr}; + memgraph::query::Expression *quote_{nullptr}; + memgraph::query::Identifier *row_var_{nullptr}; + + LoadCsv *Clone(AstStorage *storage) const override { + LoadCsv *object = storage->Create<LoadCsv>(); + object->file_ = file_ ? file_->Clone(storage) : nullptr; + object->with_header_ = with_header_; + object->ignore_bad_ = ignore_bad_; + object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr; + object->quote_ = quote_ ? quote_->Clone(storage) : nullptr; + object->row_var_ = row_var_ ? 
row_var_->Clone(storage) : nullptr; + return object; + } + + protected: + explicit LoadCsv(Expression *file, bool with_header, bool ignore_bad, Expression *delimiter, Expression *quote, + Identifier *row_var) + : file_(file), + with_header_(with_header), + ignore_bad_(ignore_bad), + delimiter_(delimiter), + quote_(quote), + row_var_(row_var) { + DMG_ASSERT(row_var, "LoadCsv cannot take nullptr for identifier"); + } + + private: + friend class AstStorage; +}; + +class FreeMemoryQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(QueryVisitor<void>); + + FreeMemoryQuery *Clone(AstStorage *storage) const override { + FreeMemoryQuery *object = storage->Create<FreeMemoryQuery>(); + return object; + } +}; + +class TriggerQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { CREATE_TRIGGER, DROP_TRIGGER, SHOW_TRIGGERS }; + + enum class EventType { + ANY, + VERTEX_CREATE, + EDGE_CREATE, + CREATE, + VERTEX_DELETE, + EDGE_DELETE, + DELETE, + VERTEX_UPDATE, + EDGE_UPDATE, + UPDATE + }; + + TriggerQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::TriggerQuery::Action action_; + memgraph::query::TriggerQuery::EventType event_type_; + std::string trigger_name_; + bool before_commit_; + std::string statement_; + + TriggerQuery *Clone(AstStorage *storage) const override { + TriggerQuery *object = storage->Create<TriggerQuery>(); + object->action_ = action_; + object->event_type_ = event_type_; + object->trigger_name_ = trigger_name_; + object->before_commit_ = before_commit_; + object->statement_ = statement_; + return object; + } + + private: + friend class AstStorage; +}; + +class IsolationLevelQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class IsolationLevel { SNAPSHOT_ISOLATION, READ_COMMITTED, READ_UNCOMMITTED }; + + enum class IsolationLevelScope { NEXT, SESSION, GLOBAL }; + + IsolationLevelQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::IsolationLevelQuery::IsolationLevel isolation_level_; + memgraph::query::IsolationLevelQuery::IsolationLevelScope isolation_level_scope_; + + IsolationLevelQuery *Clone(AstStorage *storage) const override { + IsolationLevelQuery *object = storage->Create<IsolationLevelQuery>(); + object->isolation_level_ = isolation_level_; + object->isolation_level_scope_ = isolation_level_scope_; + return object; + } + + private: + friend class AstStorage; +}; + +class CreateSnapshotQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + DEFVISITABLE(QueryVisitor<void>); + + CreateSnapshotQuery *Clone(AstStorage *storage) const override { + CreateSnapshotQuery *object = storage->Create<CreateSnapshotQuery>(); + return object; + } +}; + +class StreamQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { + CREATE_STREAM, + DROP_STREAM, + START_STREAM, + STOP_STREAM, + START_ALL_STREAMS, + STOP_ALL_STREAMS, + SHOW_STREAMS, + CHECK_STREAM + }; + + enum class Type { KAFKA, PULSAR }; + + StreamQuery() = default; + + 
DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::StreamQuery::Action action_; + memgraph::query::StreamQuery::Type type_; + std::string stream_name_; + memgraph::query::Expression *batch_limit_{nullptr}; + memgraph::query::Expression *timeout_{nullptr}; + std::string transform_name_; + memgraph::query::Expression *batch_interval_{nullptr}; + memgraph::query::Expression *batch_size_{nullptr}; + std::variant<memgraph::query::Expression *, std::vector<std::string>> topic_names_{nullptr}; + std::string consumer_group_; + memgraph::query::Expression *bootstrap_servers_{nullptr}; + memgraph::query::Expression *service_url_{nullptr}; + std::unordered_map<memgraph::query::Expression *, memgraph::query::Expression *> configs_; + std::unordered_map<memgraph::query::Expression *, memgraph::query::Expression *> credentials_; + + StreamQuery *Clone(AstStorage *storage) const override { + StreamQuery *object = storage->Create<StreamQuery>(); + object->action_ = action_; + object->type_ = type_; + object->stream_name_ = stream_name_; + object->batch_limit_ = batch_limit_ ? batch_limit_->Clone(storage) : nullptr; + object->timeout_ = timeout_ ? timeout_->Clone(storage) : nullptr; + object->transform_name_ = transform_name_; + object->batch_interval_ = batch_interval_ ? batch_interval_->Clone(storage) : nullptr; + object->batch_size_ = batch_size_ ? batch_size_->Clone(storage) : nullptr; + if (auto *topic_expression = std::get_if<Expression *>(&topic_names_)) { + if (*topic_expression == nullptr) { + object->topic_names_ = nullptr; + } else { + object->topic_names_ = (*topic_expression)->Clone(storage); + } + } else { + object->topic_names_ = std::get<std::vector<std::string>>(topic_names_); + } + object->consumer_group_ = consumer_group_; + object->bootstrap_servers_ = bootstrap_servers_ ? bootstrap_servers_->Clone(storage) : nullptr; + object->service_url_ = service_url_ ? service_url_->Clone(storage) : nullptr; + for (const auto &[key, value] : configs_) { + object->configs_[key->Clone(storage)] = value->Clone(storage); + } + for (const auto &[key, value] : credentials_) { + object->credentials_[key->Clone(storage)] = value->Clone(storage); + } + return object; + } + + private: + friend class AstStorage; +}; + +class SettingQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { SHOW_SETTING, SHOW_ALL_SETTINGS, SET_SETTING }; + + SettingQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::SettingQuery::Action action_; + memgraph::query::Expression *setting_name_{nullptr}; + memgraph::query::Expression *setting_value_{nullptr}; + + SettingQuery *Clone(AstStorage *storage) const override { + SettingQuery *object = storage->Create<SettingQuery>(); + object->action_ = action_; + object->setting_name_ = setting_name_ ? setting_name_->Clone(storage) : nullptr; + object->setting_value_ = setting_value_ ? 
setting_value_->Clone(storage) : nullptr;
+ return object;
+ }
+
+ private:
+ friend class AstStorage;
+};
+
+class VersionQuery : public memgraph::query::Query {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ DEFVISITABLE(QueryVisitor<void>);
+
+ VersionQuery *Clone(AstStorage *storage) const override {
+ VersionQuery *object = storage->Create<VersionQuery>();
+ return object;
+ }
+};
+
+class Foreach : public memgraph::query::Clause {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Foreach() = default;
+
+ bool Accept(HierarchicalTreeVisitor &visitor) override {
+ if (visitor.PreVisit(*this)) {
+ named_expression_->Accept(visitor);
+ for (auto &clause : clauses_) {
+ clause->Accept(visitor);
+ }
+ }
+ return visitor.PostVisit(*this);
+ }
+
+ memgraph::query::NamedExpression *named_expression_{nullptr};
+ std::vector<memgraph::query::Clause *> clauses_;
+
+ Foreach *Clone(AstStorage *storage) const override {
+ Foreach *object = storage->Create<Foreach>();
+ object->named_expression_ = named_expression_ ? named_expression_->Clone(storage) : nullptr;
+ object->clauses_.resize(clauses_.size());
+ for (auto i15 = 0; i15 < clauses_.size(); ++i15) {
+ object->clauses_[i15] = clauses_[i15] ? clauses_[i15]->Clone(storage) : nullptr;
+ }
+ return object;
+ }
+
+ protected:
+ Foreach(NamedExpression *expression, std::vector<Clause *> clauses)
+ : named_expression_(expression), clauses_(clauses) {}
+
+ private:
+ friend class AstStorage;
+};
+
+class ShowConfigQuery : public memgraph::query::Query {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ DEFVISITABLE(QueryVisitor<void>);
+
+ ShowConfigQuery *Clone(AstStorage *storage) const override {
+ ShowConfigQuery *object = storage->Create<ShowConfigQuery>();
+ return object;
+ }
+};
+
+class Exists : public memgraph::query::Expression {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Exists() = default;
+
+ DEFVISITABLE(ExpressionVisitor<TypedValue>);
+ DEFVISITABLE(ExpressionVisitor<TypedValue *>);
+ DEFVISITABLE(ExpressionVisitor<void>);
+ bool Accept(HierarchicalTreeVisitor &visitor) override {
+ if (visitor.PreVisit(*this)) {
+ pattern_->Accept(visitor);
+ }
+ return visitor.PostVisit(*this);
+ }
+ Exists *MapTo(const Symbol &symbol) {
+ symbol_pos_ = symbol.position();
+ return this;
+ }
+
+ memgraph::query::Pattern *pattern_{nullptr};
+ /// Symbol table position of the symbol this Exists is mapped to.
+ int32_t symbol_pos_{-1};
+
+ Exists *Clone(AstStorage *storage) const override {
+ Exists *object = storage->Create<Exists>();
+ object->pattern_ = pattern_ ? pattern_->Clone(storage) : nullptr;
+ object->symbol_pos_ = symbol_pos_;
+ return object;
+ }
+
+ protected:
+ Exists(Pattern *pattern) : pattern_(pattern) {}
+
+ private:
+ friend class AstStorage;
+};
+
+} // namespace query
+} // namespace memgraph
diff --git a/src/query/frontend/semantic/symbol.cpp b/src/query/frontend/semantic/symbol.cpp
new file mode 100644
index 000000000..57d19b25d
--- /dev/null
+++ b/src/query/frontend/semantic/symbol.cpp
@@ -0,0 +1,18 @@
+// Copyright 2023 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#include "query/frontend/semantic/symbol.hpp"
+#include "utils/typeinfo.hpp"
+
+namespace memgraph {
+
+constexpr utils::TypeInfo query::Symbol::kType{utils::TypeId::SYMBOL, "Symbol", nullptr};
+} // namespace memgraph
diff --git a/src/query/frontend/semantic/symbol.hpp b/src/query/frontend/semantic/symbol.hpp
new file mode 100644
index 000000000..5381cb48d
--- /dev/null
+++ b/src/query/frontend/semantic/symbol.hpp
@@ -0,0 +1,75 @@
+// Copyright 2023 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+#pragma once
+
+#include <string>
+
+#include "utils/typeinfo.hpp"
+
+namespace memgraph {
+
+namespace query {
+
+class Symbol {
+ public:
+ static const utils::TypeInfo kType;
+ static const utils::TypeInfo &GetTypeInfo() { return kType; }
+
+ enum class Type { ANY, VERTEX, EDGE, PATH, NUMBER, EDGE_LIST };
+
+ // TODO: Generate enum to string conversion from LCP. Note that this is
+ // displayed to the end user, so we may want to have a pretty name of each
+ // value.
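+ // For example, TypeToString(Type::EDGE_LIST) returns "EdgeList".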
+ static std::string TypeToString(Type type) { + const char *enum_string[] = {"Any", "Vertex", "Edge", "Path", "Number", "EdgeList"}; + return enum_string[static_cast<int>(type)]; + } + + Symbol() {} + Symbol(const std::string &name, int position, bool user_declared, Type type = Type::ANY, int token_position = -1) + : name_(name), position_(position), user_declared_(user_declared), type_(type), token_position_(token_position) {} + + bool operator==(const Symbol &other) const { + return position_ == other.position_ && name_ == other.name_ && type_ == other.type_; + } + bool operator!=(const Symbol &other) const { return !operator==(other); } + + // TODO: Remove these since members are public + const auto &name() const { return name_; } + int position() const { return position_; } + Type type() const { return type_; } + bool user_declared() const { return user_declared_; } + int token_position() const { return token_position_; } + + std::string name_; + int64_t position_; + bool user_declared_{true}; + memgraph::query::Symbol::Type type_{Type::ANY}; + int64_t token_position_{-1}; +}; + +} // namespace query +} // namespace memgraph +namespace std { + +template <> +struct hash<memgraph::query::Symbol> { + size_t operator()(const memgraph::query::Symbol &symbol) const { + size_t prime = 265443599u; + size_t hash = std::hash<int>{}(symbol.position()); + hash ^= prime * std::hash<std::string>{}(symbol.name()); + hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type())); + return hash; + } +}; + +} // namespace std diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp new file mode 100644 index 000000000..08ced64aa --- /dev/null +++ b/src/query/plan/operator.hpp @@ -0,0 +1,2296 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <memory> +#include <optional> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <variant> +#include <vector> + +#include "query/common.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol.hpp" +#include "query/typed_value.hpp" +#include "storage/v2/id_types.hpp" +#include "utils/bound.hpp" +#include "utils/fnv.hpp" +#include "utils/logging.hpp" +#include "utils/memory.hpp" +#include "utils/visitor.hpp" + +namespace memgraph { + +namespace query { + +struct ExecutionContext; +class ExpressionEvaluator; +class Frame; +class SymbolTable; + +namespace plan { + +/// Base class for iteration cursors of @c LogicalOperator classes. +/// +/// Each @c LogicalOperator must produce a concrete @c Cursor, which provides +/// the iteration mechanism. +class Cursor { + public: + /// Run an iteration of a @c LogicalOperator. + /// + /// Since operators may be chained, the iteration may pull results from + /// multiple operators. + /// + /// @param Frame May be read from or written to while performing the + /// iteration. + /// @param ExecutionContext Used to get the position of symbols in frame and + /// other information. 
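+ ///
+ /// A typical caller (an illustrative sketch, not code from this patch)
+ /// drains a cursor like so:
+ /// @code
+ /// while (cursor->Pull(frame, context)) {
+ ///   // each successful Pull leaves one result row in the frame
+ /// }
+ /// cursor->Shutdown();
+ /// @endcode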
+ /// + /// @throws QueryRuntimeException if something went wrong with execution + virtual bool Pull(Frame &, ExecutionContext &) = 0; + + /// Resets the Cursor to its initial state. + virtual void Reset() = 0; + + /// Perform cleanup which may throw an exception + virtual void Shutdown() = 0; + + virtual ~Cursor() {} +}; + +/// unique_ptr to Cursor managed with a custom deleter. +/// This allows us to use utils::MemoryResource for allocation. +using UniqueCursorPtr = std::unique_ptr<Cursor, std::function<void(Cursor *)>>; + +template <class TCursor, class... TArgs> +std::unique_ptr<Cursor, std::function<void(Cursor *)>> MakeUniqueCursorPtr(utils::Allocator<TCursor> allocator, + TArgs &&...args) { + auto *ptr = allocator.allocate(1); + try { + auto *cursor = new (ptr) TCursor(std::forward<TArgs>(args)...); + return std::unique_ptr<Cursor, std::function<void(Cursor *)>>(cursor, [allocator](Cursor *base_ptr) mutable { + auto *p = static_cast<TCursor *>(base_ptr); + p->~TCursor(); + allocator.deallocate(p, 1); + }); + } catch (...) { + allocator.deallocate(ptr, 1); + throw; + } +} + +class Once; +class CreateNode; +class CreateExpand; +class ScanAll; +class ScanAllByLabel; +class ScanAllByLabelPropertyRange; +class ScanAllByLabelPropertyValue; +class ScanAllByLabelProperty; +class ScanAllById; +class Expand; +class ExpandVariable; +class ConstructNamedPath; +class Filter; +class Produce; +class Delete; +class SetProperty; +class SetProperties; +class SetLabels; +class RemoveProperty; +class RemoveLabels; +class EdgeUniquenessFilter; +class Accumulate; +class Aggregate; +class Skip; +class Limit; +class OrderBy; +class Merge; +class Optional; +class Unwind; +class Distinct; +class Union; +class Cartesian; +class CallProcedure; +class LoadCsv; +class Foreach; +class EmptyResult; +class EvaluatePatternFilter; + +using LogicalOperatorCompositeVisitor = + utils::CompositeVisitor<Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, + ScanAllByLabelPropertyValue, ScanAllByLabelProperty, ScanAllById, Expand, ExpandVariable, + ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, + RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, + OrderBy, Merge, Optional, Unwind, Distinct, Union, Cartesian, CallProcedure, LoadCsv, + Foreach, EmptyResult, EvaluatePatternFilter>; + +using LogicalOperatorLeafVisitor = utils::LeafVisitor<Once>; + +/** + * @brief Base class for hierarchical visitors of @c LogicalOperator class + * hierarchy. + */ +class HierarchicalLogicalOperatorVisitor : public LogicalOperatorCompositeVisitor, public LogicalOperatorLeafVisitor { + public: + using LogicalOperatorCompositeVisitor::PostVisit; + using LogicalOperatorCompositeVisitor::PreVisit; + using LogicalOperatorLeafVisitor::Visit; + using typename LogicalOperatorLeafVisitor::ReturnType; +}; + +/// Base class for logical operators. +/// +/// Each operator describes an operation, which is to be performed on the +/// database. Operators are iterated over using a @c Cursor. Various operators +/// can serve as inputs to others and thus a sequence of operations is formed. +class LogicalOperator : public utils::Visitable<HierarchicalLogicalOperatorVisitor> { + public: + static const utils::TypeInfo kType; + virtual const utils::TypeInfo &GetTypeInfo() const { return kType; } + + virtual ~LogicalOperator() {} + + /** Construct a @c Cursor which is used to run this operator. 
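+ *
+ * A typical override (an illustrative sketch; `MyOperator` and `MyCursor`
+ * are hypothetical) allocates the concrete cursor through
+ * @c MakeUniqueCursorPtr so it is destroyed via the same memory resource:
+ * @code
+ * UniqueCursorPtr MyOperator::MakeCursor(utils::MemoryResource *mem) const {
+ *   return MakeUniqueCursorPtr<MyCursor>(utils::Allocator<MyCursor>(mem), *this, mem);
+ * }
+ * @endcode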
+ *
+ * @param utils::MemoryResource Memory resource used for allocations during
+ * the lifetime of the returned Cursor.
+ */
+ virtual UniqueCursorPtr MakeCursor(utils::MemoryResource *) const = 0;
+
+ /** Return @c Symbol vector where the query results will be stored.
+ *
+ * Currently, output symbols are generated in @c Produce, @c Union and
+ * @c CallProcedure operators. @c Skip, @c Limit, @c OrderBy and @c Distinct
+ * propagate the symbols from @c Produce (if it exists as input operator).
+ *
+ * @param SymbolTable used to find symbols for expressions.
+ * @return std::vector<Symbol> used for results.
+ */
+ virtual std::vector<Symbol> OutputSymbols(const SymbolTable &) const { return std::vector<Symbol>(); }
+
+ /**
+ * Symbol vector whose values are modified by this operator sub-tree.
+ *
+ * This is different from @c OutputSymbols, because it returns all of the
+ * modified symbols, including those that may not be returned as the
+ * result of the query. Note that the modified symbols will not contain
+ * those that should not be read after the operator is processed.
+ *
+ * For example, `MATCH (n)-[e]-(m) RETURN n AS l` will generate `ScanAll (n) >
+ * Expand (e, m) > Produce (l)`. The modified symbols on Produce sub-tree will
+ * be `l`, the same as output symbols, because it isn't valid to read `n`, `e`
+ * nor `m` after Produce. On the other hand, modified symbols from Expand
+ * contain `e` and `m`, as well as `n`, while output symbols are empty.
+ * Modified symbols from ScanAll contain only `n`, while output symbols are
+ * also empty.
+ */
+ virtual std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const = 0;
+
+ /**
+ * Returns true if the operator takes only one input operator.
+ * NOTE: When this method returns true, you may use `input` and `set_input`
+ * methods.
+ */
+ virtual bool HasSingleInput() const = 0;
+
+ /**
+ * Returns the input operator if it has any.
+ * NOTE: This should only be called if `HasSingleInput() == true`.
+ */
+ virtual std::shared_ptr<LogicalOperator> input() const = 0;
+ /**
+ * Set a different input on this operator.
+ * NOTE: This should only be called if `HasSingleInput() == true`.
+ */
+ virtual void set_input(std::shared_ptr<LogicalOperator>) = 0;
+
+ struct SaveHelper {
+ std::vector<LogicalOperator *> saved_ops;
+ };
+
+ struct LoadHelper {
+ AstStorage ast_storage;
+ std::vector<std::pair<uint64_t, std::shared_ptr<LogicalOperator>>> loaded_ops;
+ };
+
+ struct SlkLoadHelper {
+ AstStorage ast_storage;
+ std::vector<std::shared_ptr<LogicalOperator>> loaded_ops;
+ };
+
+ virtual std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const = 0;
+};
+
+/// A logical operator whose Cursor returns true on the first Pull
+/// and false on every following Pull.
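+///
+/// @c Once is the leaf of an operator tree. For example (an illustrative
+/// sketch, not a plan taken from this patch), a plain `CREATE ()` is planned
+/// roughly as `Once > CreateNode`, where the single successful Pull from
+/// Once drives exactly one node creation.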
+class Once : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Once(std::vector<Symbol> symbols = {}) : symbols_{std::move(symbols)} {} + DEFVISITABLE(HierarchicalLogicalOperatorVisitor); + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override { return symbols_; } + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + + std::vector<Symbol> symbols_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Once>(); + object->symbols_ = symbols_; + return object; + } + + private: + class OnceCursor : public Cursor { + public: + OnceCursor() {} + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + bool did_pull_{false}; + }; +}; + +using PropertiesMapList = std::vector<std::pair<storage::PropertyId, Expression *>>; + +struct NodeCreationInfo { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + NodeCreationInfo() = default; + + NodeCreationInfo(Symbol symbol, std::vector<storage::LabelId> labels, + std::variant<PropertiesMapList, ParameterLookup *> properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; + + NodeCreationInfo(Symbol symbol, std::vector<storage::LabelId> labels, PropertiesMapList properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; + + NodeCreationInfo(Symbol symbol, std::vector<storage::LabelId> labels, ParameterLookup *properties) + : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{properties} {}; + + Symbol symbol; + std::vector<storage::LabelId> labels; + std::variant<PropertiesMapList, ParameterLookup *> properties; + + NodeCreationInfo Clone(AstStorage *storage) const { + NodeCreationInfo object; + object.symbol = symbol; + object.labels = labels; + if (const auto *props = std::get_if<PropertiesMapList>(&properties)) { + auto &destination_props = std::get<PropertiesMapList>(object.properties); + destination_props.resize(props->size()); + for (auto i0 = 0; i0 < props->size(); ++i0) { + { + storage::PropertyId first1 = (*props)[i0].first; + Expression *second2; + second2 = (*props)[i0].second ? (*props)[i0].second->Clone(storage) : nullptr; + destination_props[i0] = std::make_pair(std::move(first1), std::move(second2)); + } + } + } else { + object.properties = std::get<ParameterLookup *>(properties)->Clone(storage); + } + return object; + } +}; + +/// Operator for creating a node. +/// +/// This op is used both for creating a single node (`CREATE` statement without +/// a preceding `MATCH`), or multiple nodes (`MATCH ... CREATE` or +/// `CREATE (), () ...`). +/// +/// @sa CreateExpand +class CreateNode : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + CreateNode() {} + + /** + * @param input Optional. If @c nullptr, then a single node will be + * created (a single successful @c Cursor::Pull from this op's @c Cursor). + * If a valid input, then a node will be created for each + * successful pull from the given input. 
+ * @param node_info @c NodeCreationInfo + */ + CreateNode(const std::shared_ptr<LogicalOperator> &input, const NodeCreationInfo &node_info); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + memgraph::query::plan::NodeCreationInfo node_info_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<CreateNode>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->node_info_ = node_info_.Clone(storage); + return object; + } + + private: + class CreateNodeCursor : public Cursor { + public: + CreateNodeCursor(const CreateNode &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const CreateNode &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +struct EdgeCreationInfo { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + EdgeCreationInfo() = default; + + EdgeCreationInfo(Symbol symbol, std::variant<PropertiesMapList, ParameterLookup *> properties, + storage::EdgeTypeId edge_type, EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{std::move(properties)}, edge_type{edge_type}, direction{direction} {}; + + EdgeCreationInfo(Symbol symbol, PropertiesMapList properties, storage::EdgeTypeId edge_type, + EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{std::move(properties)}, edge_type{edge_type}, direction{direction} {}; + + EdgeCreationInfo(Symbol symbol, ParameterLookup *properties, storage::EdgeTypeId edge_type, + EdgeAtom::Direction direction) + : symbol{std::move(symbol)}, properties{properties}, edge_type{edge_type}, direction{direction} {}; + + Symbol symbol; + std::variant<PropertiesMapList, ParameterLookup *> properties; + storage::EdgeTypeId edge_type; + EdgeAtom::Direction direction{EdgeAtom::Direction::BOTH}; + + EdgeCreationInfo Clone(AstStorage *storage) const { + EdgeCreationInfo object; + object.symbol = symbol; + if (const auto *props = std::get_if<PropertiesMapList>(&properties)) { + auto &destination_props = std::get<PropertiesMapList>(object.properties); + destination_props.resize(props->size()); + for (auto i0 = 0; i0 < props->size(); ++i0) { + { + storage::PropertyId first1 = (*props)[i0].first; + Expression *second2; + second2 = (*props)[i0].second ? (*props)[i0].second->Clone(storage) : nullptr; + destination_props[i0] = std::make_pair(std::move(first1), std::move(second2)); + } + } + } else { + object.properties = std::get<ParameterLookup *>(properties)->Clone(storage); + } + object.edge_type = edge_type; + object.direction = direction; + return object; + } +}; + +/// Operator for creating edges and destination nodes. +/// +/// This operator extends already created nodes with an edge. If the other node +/// on the edge does not exist, it will be created. For example, in `MATCH (n) +/// CREATE (n) -[r:r]-> (n)` query, this operator will create just the edge `r`. 
+/// In `MATCH (n) CREATE (n) -[r:r]-> (m)` query, the operator will create both
+/// the edge `r` and the node `m`. In case of `CREATE (n) -[r:r]-> (m)`, the
+/// first node `n` is created by the @c CreateNode operator, while @c CreateExpand
+/// will create the edge `r` and `m`. Similarly, multiple @c CreateExpand are
+/// chained in cases when longer paths need creating.
+///
+/// @sa CreateNode
+class CreateExpand : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ CreateExpand() {}
+
+ /** @brief Construct @c CreateExpand.
+ *
+ * @param node_info @c NodeCreationInfo at the end of the edge.
+ * Used to create a node, unless it refers to an existing one.
+ * @param edge_info @c EdgeCreationInfo for the edge to be created.
+ * @param input Optional. Previous @c LogicalOperator which will be pulled.
+ * For each successful @c Cursor::Pull, this operator will create an
+ * expansion.
+ * @param input_symbol @c Symbol for the node at the start of the edge.
+ * @param existing_node @c bool indicating whether the @c node_atom refers to
+ * an existing node. If @c false, the operator will also create the node.
+ */
+ CreateExpand(const NodeCreationInfo &node_info, const EdgeCreationInfo &edge_info,
+ const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ memgraph::query::plan::NodeCreationInfo node_info_;
+ memgraph::query::plan::EdgeCreationInfo edge_info_;
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol input_symbol_;
+ /// if the given node atom refers to an existing node (either matched or created)
+ bool existing_node_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<CreateExpand>();
+ object->node_info_ = node_info_.Clone(storage);
+ object->edge_info_ = edge_info_.Clone(storage);
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->input_symbol_ = input_symbol_;
+ object->existing_node_ = existing_node_;
+ return object;
+ }
+
+ private:
+ class CreateExpandCursor : public Cursor {
+ public:
+ CreateExpandCursor(const CreateExpand &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const CreateExpand &self_;
+ const UniqueCursorPtr input_cursor_;
+
+ // Get the existing node (if existing_node_ == true), or create a new node
+ VertexAccessor &OtherVertex(Frame &frame, ExecutionContext &context);
+ };
+};
+
+/// Operator which iterates over all the nodes currently in the database.
+/// When given an input (optional), does a cartesian product.
+///
+/// It accepts an optional input. If provided, this op scans all the nodes
+/// currently in the database for each successful Pull from its input, thereby
+/// producing a cartesian product of input Pulls and database elements.
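+///
+/// For example (an illustrative query, not taken from this patch),
+/// `MATCH (n), (m) RETURN n, m` chains two ScanAll operators, so every node
+/// bound to `m` is produced once for each node bound to `n`.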
+///
+/// ScanAll can either iterate over the previous graph state (state before
+/// the current transaction+command) or over the current state. This is controlled
+/// with a constructor argument.
+///
+/// @sa ScanAllByLabel
+/// @sa ScanAllByLabelPropertyRange
+/// @sa ScanAllByLabelPropertyValue
+class ScanAll : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ ScanAll() {}
+ ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view = storage::View::OLD);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol output_symbol_;
+ /// Controls which graph state is used to produce vertices.
+ ///
+ /// If @c storage::View::OLD, @c ScanAll will produce vertices visible in the
+ /// previous graph state, before modifications done by the current transaction &
+ /// command. With @c storage::View::NEW, all vertices the current transaction
+ /// sees will be produced, along with their modifications.
+ storage::View view_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ScanAll>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->output_symbol_ = output_symbol_;
+ object->view_ = view_;
+ return object;
+ }
+};
+
+/// Behaves like @c ScanAll, but this operator produces only vertices with
+/// given label.
+///
+/// @sa ScanAll
+/// @sa ScanAllByLabelPropertyRange
+/// @sa ScanAllByLabelPropertyValue
+class ScanAllByLabel : public memgraph::query::plan::ScanAll {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ ScanAllByLabel() {}
+ ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label,
+ storage::View view = storage::View::OLD);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+
+ storage::LabelId label_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ScanAllByLabel>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->output_symbol_ = output_symbol_;
+ object->view_ = view_;
+ object->label_ = label_;
+ return object;
+ }
+};
+
+/// Behaves like @c ScanAll, but produces only vertices with given label and
+/// property value which is inside a range (inclusive or exclusive).
+///
+/// @sa ScanAll
+/// @sa ScanAllByLabel
+/// @sa ScanAllByLabelPropertyValue
+class ScanAllByLabelPropertyRange : public memgraph::query::plan::ScanAll {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ /** Bound with expression which when evaluated produces the bound value. */
+ using Bound = utils::Bound<Expression *>;
+ ScanAllByLabelPropertyRange() {}
+ /**
+ * Constructs the operator for given label and property value in range
+ * (inclusive).
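+ * For example (an illustrative query, not taken from this patch),
+ * `MATCH (n:Person) WHERE 18 <= n.age AND n.age < 65 RETURN n` can be served
+ * by this operator with an inclusive lower bound of 18 and an exclusive
+ * upper bound of 65 on the `age` property.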
+ *
+ * Range bounds are optional, but only one bound can be left out.
+ *
+ * @param input Preceding operator which will serve as the input.
+ * @param output_symbol Symbol where the vertices will be stored.
+ * @param label Label which the vertex must have.
+ * @param property Property from which the value will be looked up.
+ * @param lower_bound Optional lower @c Bound.
+ * @param upper_bound Optional upper @c Bound.
+ * @param view storage::View used when obtaining vertices.
+ */
+ ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
+ storage::LabelId label, storage::PropertyId property, const std::string &property_name,
+ std::optional<Bound> lower_bound, std::optional<Bound> upper_bound,
+ storage::View view = storage::View::OLD);
+
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+
+ storage::LabelId label_;
+ storage::PropertyId property_;
+ std::string property_name_;
+ std::optional<Bound> lower_bound_;
+ std::optional<Bound> upper_bound_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ScanAllByLabelPropertyRange>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->output_symbol_ = output_symbol_;
+ object->view_ = view_;
+ object->label_ = label_;
+ object->property_ = property_;
+ object->property_name_ = property_name_;
+ if (lower_bound_) {
+ object->lower_bound_.emplace(
+ utils::Bound<Expression *>(lower_bound_->value()->Clone(storage), lower_bound_->type()));
+ } else {
+ object->lower_bound_ = std::nullopt;
+ }
+ if (upper_bound_) {
+ object->upper_bound_.emplace(
+ utils::Bound<Expression *>(upper_bound_->value()->Clone(storage), upper_bound_->type()));
+ } else {
+ object->upper_bound_ = std::nullopt;
+ }
+ return object;
+ }
+};
+
+/// Behaves like @c ScanAll, but produces only vertices with given label and
+/// property value.
+///
+/// @sa ScanAll
+/// @sa ScanAllByLabel
+/// @sa ScanAllByLabelPropertyRange
+class ScanAllByLabelPropertyValue : public memgraph::query::plan::ScanAll {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ ScanAllByLabelPropertyValue() {}
+ /**
+ * Constructs the operator for given label and property value.
+ *
+ * @param input Preceding operator which will serve as the input.
+ * @param output_symbol Symbol where the vertices will be stored.
+ * @param label Label which the vertex must have.
+ * @param property Property from which the value will be looked up.
+ * @param expression Expression producing the value of the vertex property.
+ * @param view storage::View used when obtaining vertices.
+ */
+ ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
+ storage::LabelId label, storage::PropertyId property, const std::string &property_name,
+ Expression *expression, storage::View view = storage::View::OLD);
+
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+
+ storage::LabelId label_;
+ storage::PropertyId property_;
+ std::string property_name_;
+ Expression *expression_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ScanAllByLabelPropertyValue>();
+ object->input_ = input_ ?
input_->Clone(storage) : nullptr; + object->output_symbol_ = output_symbol_; + object->view_ = view_; + object->label_ = label_; + object->property_ = property_; + object->property_name_ = property_name_; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } +}; + +/// Behaves like @c ScanAll, but this operator produces only vertices with +/// given label and property. +/// +/// @sa ScanAll +/// @sa ScanAllByLabelPropertyRange +/// @sa ScanAllByLabelPropertyValue +class ScanAllByLabelProperty : public memgraph::query::plan::ScanAll { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + ScanAllByLabelProperty() {} + ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, + storage::PropertyId property, const std::string &property_name, + storage::View view = storage::View::OLD); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + + storage::LabelId label_; + storage::PropertyId property_; + std::string property_name_; + Expression *expression_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<ScanAllByLabelProperty>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->output_symbol_ = output_symbol_; + object->view_ = view_; + object->label_ = label_; + object->property_ = property_; + object->property_name_ = property_name_; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } +}; + +/// ScanAll producing a single node with ID equal to evaluated expression +class ScanAllById : public memgraph::query::plan::ScanAll { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + ScanAllById() {} + ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, + storage::View view = storage::View::OLD); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + + Expression *expression_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<ScanAllById>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->output_symbol_ = output_symbol_; + object->view_ = view_; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } +}; + +struct ExpandCommon { + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const { return kType; } + + /// Symbol pointing to the node to be expanded. + /// This is where the new node will be stored. + Symbol node_symbol; + /// Symbol for the edges to be expanded. + /// This is where a TypedValue containing a list of expanded edges will be stored. + Symbol edge_symbol; + /// EdgeAtom::Direction determining the direction of edge + /// expansion. The direction is relative to the starting vertex for each expansion. + EdgeAtom::Direction direction; + /// storage::EdgeTypeId specifying which edges we want + /// to expand. If empty, all edges are valid. If not empty, only edges with one of + /// the given types are valid. 
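+ /// For example (illustrative), a pattern such as `-[:KNOWS|:LIKES]->` would
+ /// populate this with the two corresponding @c storage::EdgeTypeId values.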
+ std::vector<storage::EdgeTypeId> edge_types;
+ /// If the given node atom refers to a symbol
+ /// that has already been expanded and should just be validated in the frame.
+ bool existing_node;
+};
+
+/// Expansion operator. For a node existing in the frame it
+/// expands one edge and one node and places them on the frame.
+///
+/// This class does not handle node/edge filtering based on
+/// properties, labels and edge types. However, it does handle
+/// filtering on existing node / edge.
+///
+/// Filtering on existing means that for a pattern that references
+/// an already declared node or edge (for example in
+/// MATCH (a) MATCH (a)--(b)),
+/// only expansions that match defined equalities are successfully
+/// pulled.
+class Expand : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ /**
+ * Creates an expansion. All parameters except input and input_symbol are
+ * forwarded to @c ExpandCommon and are documented there.
+ *
+ * @param input Optional logical operator that precedes this one.
+ * @param input_symbol Symbol that points to a VertexAccessor in the frame
+ * that expansion should emanate from.
+ */
+ Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, Symbol edge_symbol,
+ EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, bool existing_node,
+ storage::View view);
+
+ Expand() {}
+
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ class ExpandCursor : public Cursor {
+ public:
+ ExpandCursor(const Expand &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ using InEdgeT = std::remove_reference_t<decltype(*std::declval<VertexAccessor>().InEdges(storage::View::OLD))>;
+ using InEdgeIteratorT = decltype(std::declval<InEdgeT>().begin());
+ using OutEdgeT = std::remove_reference_t<decltype(*std::declval<VertexAccessor>().OutEdges(storage::View::OLD))>;
+ using OutEdgeIteratorT = decltype(std::declval<OutEdgeT>().begin());
+
+ const Expand &self_;
+ const UniqueCursorPtr input_cursor_;
+
+ // The iterable over edges and the current edge iterator are referenced via
+ // optional because they cannot be initialized in the constructor of
+ // this class. They are initialized once for each pull from the input.
+ std::optional<InEdgeT> in_edges_;
+ std::optional<InEdgeIteratorT> in_edges_it_;
+ std::optional<OutEdgeT> out_edges_;
+ std::optional<OutEdgeIteratorT> out_edges_it_;
+
+ bool InitEdges(Frame &, ExecutionContext &);
+ };
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol input_symbol_;
+ memgraph::query::plan::ExpandCommon common_;
+ /// State from which the input node should get expanded.
+ storage::View view_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Expand>();
+ object->input_ = input_ ?
input_->Clone(storage) : nullptr;
+ object->input_symbol_ = input_symbol_;
+ object->common_ = common_;
+ object->view_ = view_;
+ return object;
+ }
+};
+
+struct ExpansionLambda {
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const { return kType; }
+
+ /// Currently expanded edge symbol.
+ Symbol inner_edge_symbol;
+ /// Currently expanded node symbol.
+ Symbol inner_node_symbol;
+ /// Expression used in lambda during expansion.
+ Expression *expression;
+
+ ExpansionLambda Clone(AstStorage *storage) const {
+ ExpansionLambda object;
+ object.inner_edge_symbol = inner_edge_symbol;
+ object.inner_node_symbol = inner_node_symbol;
+ object.expression = expression ? expression->Clone(storage) : nullptr;
+ return object;
+ }
+};
+
+/// Variable-length expansion operator. For a node existing in
+/// the frame it expands a variable number of edges and places them
+/// (in a list-type TypedValue), as well as the final destination node,
+/// on the frame.
+///
+/// This class does not handle node/edge filtering based on
+/// properties, labels and edge types. However, it does handle
+/// filtering on existing node / edge. Additionally it handles
+/// edge-uniqueness (cyphermorphism) because it's not feasible to do
+/// later.
+///
+/// Filtering on existing means that for a pattern that references
+/// an already declared node or edge (for example in
+/// MATCH (a) MATCH (a)--(b)),
+/// only expansions that match defined equalities are successfully
+/// pulled.
+class ExpandVariable : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ ExpandVariable() {}
+
+ /**
+ * Creates a variable-length expansion. Most params are forwarded
+ * to the @c ExpandCommon constructor, and are documented there.
+ *
+ * Expansion length bounds are both inclusive (as in Neo's Cypher
+ * implementation).
+ *
+ * @param input Optional logical operator that precedes this one.
+ * @param input_symbol Symbol that points to a VertexAccessor in the frame
+ * that expansion should emanate from.
+ * @param type - Either Type::DEPTH_FIRST (default variable-length expansion),
+ * or Type::BREADTH_FIRST.
+ * @param is_reverse Set to `true` if the edges written on frame should expand
+ * from `node_symbol` to `input_symbol`. As opposed to the usual expansion
+ * from `input_symbol` to `node_symbol`.
+ * @param lower_bound An optional indicator of the minimum number of edges
+ * that get expanded (inclusive).
+ * @param upper_bound An optional indicator of the maximum number of edges
+ * that get expanded (inclusive).
+ * @param inner_edge_symbol Like `inner_node_symbol`
+ * @param inner_node_symbol For each expansion the node expanded into is
+ * assigned to this symbol so it can be evaluated by the 'where'
+ * expression.
+ * @param filter_ The filter that must be satisfied for an expansion to
+ * succeed. Can use inner(node/edge) symbols. If nullptr, it is ignored.
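+ *
+ * For example (an illustrative note, not part of the original comment): if
+ * lower_bound evaluates to 2 and upper_bound to 3, only paths of exactly
+ * 2 or 3 edges are produced, while a nullptr bound falls back to the
+ * defaults (1, inf).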
+ */
+ ExpandVariable(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol,
+ Symbol edge_symbol, EdgeAtom::Type type, EdgeAtom::Direction direction,
+ const std::vector<storage::EdgeTypeId> &edge_types, bool is_reverse, Expression *lower_bound,
+ Expression *upper_bound, bool existing_node, ExpansionLambda filter_lambda,
+ std::optional<ExpansionLambda> weight_lambda, std::optional<Symbol> total_weight);
+
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol input_symbol_;
+ memgraph::query::plan::ExpandCommon common_;
+ EdgeAtom::Type type_;
+ /// True if the path should be written as expanding from node_symbol to input_symbol.
+ bool is_reverse_;
+ /// Optional lower bound of the variable length expansion, defaults are (1, inf)
+ Expression *lower_bound_;
+ /// Optional upper bound of the variable length expansion, defaults are (1, inf)
+ Expression *upper_bound_;
+ memgraph::query::plan::ExpansionLambda filter_lambda_;
+ std::optional<memgraph::query::plan::ExpansionLambda> weight_lambda_;
+ std::optional<Symbol> total_weight_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ExpandVariable>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->input_symbol_ = input_symbol_;
+ object->common_ = common_;
+ object->type_ = type_;
+ object->is_reverse_ = is_reverse_;
+ object->lower_bound_ = lower_bound_ ? lower_bound_->Clone(storage) : nullptr;
+ object->upper_bound_ = upper_bound_ ? upper_bound_->Clone(storage) : nullptr;
+ object->filter_lambda_ = filter_lambda_.Clone(storage);
+ if (weight_lambda_) {
+ memgraph::query::plan::ExpansionLambda value0;
+ value0 = (*weight_lambda_).Clone(storage);
+ object->weight_lambda_.emplace(std::move(value0));
+ } else {
+ object->weight_lambda_ = std::nullopt;
+ }
+ object->total_weight_ = total_weight_;
+ return object;
+ }
+
+ private:
+ // the Cursors are not declared in the header because
+ // its edges_ and edges_it_ are decltyped using a helper function
+ // that should be inaccessible (private class function won't compile)
+ friend class ExpandVariableCursor;
+ friend class ExpandWeightedShortestPathCursor;
+ friend class ExpandAllShortestPathCursor;
+};
+
+/// Constructs a named path from its elements and places it on the frame.
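+/// Example (an illustrative sketch, not part of the original comment): for
+/// `MATCH p = (n)-[e]->(m) ...`, assuming a plan `match` that binds the
+/// symbols `n_sym`, `e_sym` and `m_sym`, the named path `p_sym` could be
+/// built as:
+/// @code
+/// auto named_path = std::make_shared<ConstructNamedPath>(
+///     match, p_sym, std::vector<Symbol>{n_sym, e_sym, m_sym});
+/// @endcode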
+class ConstructNamedPath : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ ConstructNamedPath() {}
+ ConstructNamedPath(const std::shared_ptr<LogicalOperator> &input, Symbol path_symbol,
+ const std::vector<Symbol> &path_elements)
+ : input_(input), path_symbol_(path_symbol), path_elements_(path_elements) {}
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol path_symbol_;
+ std::vector<Symbol> path_elements_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<ConstructNamedPath>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->path_symbol_ = path_symbol_;
+ object->path_elements_ = path_elements_;
+ return object;
+ }
+};
+
+/// Filter whose Pull returns true only when the given expression
+/// evaluates to true.
+///
+/// The given expression is assumed to return either NULL (treated as false) or
+/// a boolean value.
+class Filter : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Filter() {}
+
+ Filter(const std::shared_ptr<LogicalOperator> &input_,
+ const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters_, Expression *expression_);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::vector<std::shared_ptr<memgraph::query::plan::LogicalOperator>> pattern_filters_;
+ Expression *expression_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Filter>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->pattern_filters_.resize(pattern_filters_.size());
+ for (auto i1 = 0; i1 < pattern_filters_.size(); ++i1) {
+ object->pattern_filters_[i1] = pattern_filters_[i1] ? pattern_filters_[i1]->Clone(storage) : nullptr;
+ }
+ object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+ return object;
+ }
+
+ private:
+ class FilterCursor : public Cursor {
+ public:
+ FilterCursor(const Filter &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const Filter &self_;
+ const UniqueCursorPtr input_cursor_;
+ const std::vector<UniqueCursorPtr> pattern_filter_cursors_;
+ };
+};
+
+/// A logical operator that places an arbitrary number
+/// of named expressions on the frame (the logical operator
+/// for the RETURN clause).
+///
+/// Supports optional input.
When the input is provided,
+/// it is Pulled from and the Produce succeeds once for
+/// every input Pull (typically a MATCH/RETURN query).
+/// When the input is not provided (typically a standalone
+/// RETURN clause) the Produce's pull succeeds exactly once.
+class Produce : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Produce() {}
+
+ Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::vector<NamedExpression *> named_expressions_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Produce>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->named_expressions_.resize(named_expressions_.size());
+ for (auto i1 = 0; i1 < named_expressions_.size(); ++i1) {
+ object->named_expressions_[i1] = named_expressions_[i1] ? named_expressions_[i1]->Clone(storage) : nullptr;
+ }
+ return object;
+ }
+
+ private:
+ class ProduceCursor : public Cursor {
+ public:
+ ProduceCursor(const Produce &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const Produce &self_;
+ const UniqueCursorPtr input_cursor_;
+ };
+};
+
+/// Operator for deleting vertices and edges.
+///
+/// Has a flag for using DETACH DELETE when deleting vertices.
+class Delete : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Delete() {}
+
+ Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions, bool detach_);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::vector<Expression *> expressions_;
+ /// Whether the vertex should be detached before deletion. If not detached
+ /// and the vertex still has edges, an error is raised when it is deleted.
+ bool detach_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Delete>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->expressions_.resize(expressions_.size());
+ for (auto i2 = 0; i2 < expressions_.size(); ++i2) {
+ object->expressions_[i2] = expressions_[i2] ?
expressions_[i2]->Clone(storage) : nullptr; + } + object->detach_ = detach_; + return object; + } + + private: + class DeleteCursor : public Cursor { + public: + DeleteCursor(const Delete &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Delete &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Logical operator for setting a single property on a single vertex or edge. +/// +/// The property value is an expression that must evaluate to some type that +/// can be stored (a TypedValue that can be converted to PropertyValue). +class SetProperty : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + SetProperty() {} + + SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs, + Expression *rhs); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + storage::PropertyId property_; + PropertyLookup *lhs_; + Expression *rhs_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<SetProperty>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->property_ = property_; + object->lhs_ = lhs_ ? lhs_->Clone(storage) : nullptr; + object->rhs_ = rhs_ ? rhs_->Clone(storage) : nullptr; + return object; + } + + private: + class SetPropertyCursor : public Cursor { + public: + SetPropertyCursor(const SetProperty &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetProperty &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Logical operator for setting the whole property set on a vertex or an edge. +/// +/// The value being set is an expression that must evaluate to a vertex, edge or +/// map (literal or parameter). +/// +/// Supports setting (replacing the whole properties set with another) and +/// updating. +class SetProperties : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + /// Defines how setting the properties works. + /// + /// @c UPDATE means that the current property set is augmented with additional + /// ones (existing props of the same name are replaced), while @c REPLACE means + /// that the old properties are discarded and replaced with new ones. 
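+ ///
+ /// For example (an illustrative note, not part of the original comment):
+ /// given current properties {a: 1, b: 2} and a right-hand side evaluating
+ /// to the map {b: 3}, @c UPDATE yields {a: 1, b: 3} while @c REPLACE
+ /// yields {b: 3}.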
+ enum class Op { UPDATE, REPLACE }; + + SetProperties() {} + + SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + Symbol input_symbol_; + Expression *rhs_; + memgraph::query::plan::SetProperties::Op op_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<SetProperties>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->input_symbol_ = input_symbol_; + object->rhs_ = rhs_ ? rhs_->Clone(storage) : nullptr; + object->op_ = op_; + return object; + } + + private: + class SetPropertiesCursor : public Cursor { + public: + SetPropertiesCursor(const SetProperties &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetProperties &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Logical operator for setting an arbitrary number of labels on a Vertex. +/// +/// It does NOT remove labels that are already set on that Vertex. +class SetLabels : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + SetLabels() {} + + SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + const std::vector<storage::LabelId> &labels); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + Symbol input_symbol_; + std::vector<storage::LabelId> labels_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<SetLabels>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->input_symbol_ = input_symbol_; + object->labels_ = labels_; + return object; + } + + private: + class SetLabelsCursor : public Cursor { + public: + SetLabelsCursor(const SetLabels &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const SetLabels &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Logical operator for removing a property from an edge or a vertex. 
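+///
+/// For example (an illustrative note, not part of the original comment): the
+/// plan for `MATCH (n) REMOVE n.prop` would place a RemoveProperty operator
+/// above the matching scan, with the `n.prop` PropertyLookup as its `lhs`.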
+class RemoveProperty : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + RemoveProperty() {} + + RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + storage::PropertyId property_; + PropertyLookup *lhs_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<RemoveProperty>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->property_ = property_; + object->lhs_ = lhs_ ? lhs_->Clone(storage) : nullptr; + return object; + } + + private: + class RemovePropertyCursor : public Cursor { + public: + RemovePropertyCursor(const RemoveProperty &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const RemoveProperty &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Logical operator for removing an arbitrary number of labels on a Vertex. +/// +/// If a label does not exist on a Vertex, nothing happens. +class RemoveLabels : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + RemoveLabels() {} + + RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, + const std::vector<storage::LabelId> &labels); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + Symbol input_symbol_; + std::vector<storage::LabelId> labels_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<RemoveLabels>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->input_symbol_ = input_symbol_; + object->labels_ = labels_; + return object; + } + + private: + class RemoveLabelsCursor : public Cursor { + public: + RemoveLabelsCursor(const RemoveLabels &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const RemoveLabels &self_; + const UniqueCursorPtr input_cursor_; + }; +}; + +/// Filter whose Pull returns true only when the given expand_symbol frame +/// value (the latest expansion) is not equal to any of the previous_symbols frame +/// values. +/// +/// Used for implementing Cyphermorphism. +/// Isomorphism is vertex-uniqueness. 
It means that two different vertices in a
+/// pattern cannot map to the same data vertex.
+/// Cyphermorphism is edge-uniqueness (the above explanation applies). By default
+/// Neo4j uses Cyphermorphism (that's where the name stems from, it is not a valid
+/// graph-theory term).
+///
+/// Supports variable-length-edges (uniqueness comparisons between edges and
+/// edge lists).
+class EdgeUniquenessFilter : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ EdgeUniquenessFilter() {}
+
+ EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol,
+ const std::vector<Symbol> &previous_symbols);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol expand_symbol_;
+ std::vector<Symbol> previous_symbols_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<EdgeUniquenessFilter>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->expand_symbol_ = expand_symbol_;
+ object->previous_symbols_ = previous_symbols_;
+ return object;
+ }
+
+ private:
+ class EdgeUniquenessFilterCursor : public Cursor {
+ public:
+ EdgeUniquenessFilterCursor(const EdgeUniquenessFilter &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const EdgeUniquenessFilter &self_;
+ const UniqueCursorPtr input_cursor_;
+ };
+};
+
+/// Pulls everything from the input and discards it.
+///
+/// On the first Pull from this operator's Cursor the input Cursor will be Pulled
+/// until it is empty. The results won't be accumulated in the temporary cache.
+///
+/// This technique is used for ensuring that the cursor has been exhausted after
+/// a WriteHandleClause. A typical use case is a `MATCH--SET` query with the
+/// RETURN statement missing.
+/// @param input Input @c LogicalOperator.
+class EmptyResult : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ EmptyResult() {}
+
+ EmptyResult(const std::shared_ptr<LogicalOperator> &input);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<EmptyResult>();
+ object->input_ = input_ ?
input_->Clone(storage) : nullptr; + return object; + } +}; + +/// Pulls everything from the input before passing it through. +/// Optionally advances the command after accumulation and before emitting. +/// +/// On the first Pull from this operator's Cursor the input Cursor will be Pulled +/// until it is empty. The results will be accumulated in the temporary cache. Once +/// the input Cursor is empty, this operator's Cursor will start returning cached +/// stuff from its Pull. +/// +/// This technique is used for ensuring all the operations from the +/// previous logical operator have been performed before exposing data +/// to the next. A typical use case is a `MATCH--SET--RETURN` +/// query in which every SET iteration must be performed before +/// RETURN starts iterating (see Memgraph Wiki for detailed reasoning). +/// +/// IMPORTANT: This operator does not cache all the results but only those +/// elements from the frame whose symbols (frame positions) it was given. +/// All other frame positions will contain undefined junk after this +/// operator has executed, and should not be used. +/// +/// This operator can also advance the command after the accumulation and +/// before emitting. If the command gets advanced, every value that +/// has been cached will be reconstructed before Pull returns. +/// +/// @param input Input @c LogicalOperator. +/// @param symbols A vector of Symbols that need to be accumulated +/// and exposed to the next op. +class Accumulate : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Accumulate() {} + + Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols, + bool advance_command = false); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + std::vector<Symbol> symbols_; + bool advance_command_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Accumulate>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->symbols_ = symbols_; + object->advance_command_ = advance_command_; + return object; + } +}; + +/// Performs an arbitrary number of aggregations of data +/// from the given input grouped by the given criteria. +/// +/// Aggregations are defined by triples that define +/// (input data expression, type of aggregation, output symbol). +/// Input data is grouped based on the given set of named +/// expressions. Grouping is done on unique values. +/// +/// IMPORTANT: +/// Operators taking their input from an aggregation are only +/// allowed to use frame values that are either aggregation +/// outputs or group-by named-expressions. All other frame +/// elements are in an undefined state after aggregation. 
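+///
+/// Example (an illustrative sketch, not part of the original comment): a plan
+/// for `MATCH (n) RETURN n.x, count(*)` could use a single COUNT element
+/// grouped by the `n.x` expression; `count_sym` and `n_x_expr` are assumed to
+/// come from the symbol table and the AST:
+/// @code
+/// Aggregate::Element count_elem{nullptr, nullptr, Aggregation::Op::COUNT, count_sym, false};
+/// auto aggregate = std::make_shared<Aggregate>(match, std::vector<Aggregate::Element>{count_elem},
+///                                              std::vector<Expression *>{n_x_expr}, std::vector<Symbol>{});
+/// @endcode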
+class Aggregate : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ /// An aggregation element, contains:
+ /// (input data expression, key expression - only used in COLLECT_MAP, type of
+ /// aggregation, output symbol).
+ struct Element {
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const { return kType; }
+
+ Expression *value;
+ Expression *key;
+ Aggregation::Op op;
+ Symbol output_sym;
+ bool distinct{false};
+
+ Element Clone(AstStorage *storage) const {
+ Element object;
+ object.value = value ? value->Clone(storage) : nullptr;
+ object.key = key ? key->Clone(storage) : nullptr;
+ object.op = op;
+ object.output_sym = output_sym;
+ object.distinct = distinct;
+ return object;
+ }
+ };
+
+ Aggregate() = default;
+ Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Element> &aggregations,
+ const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::vector<memgraph::query::plan::Aggregate::Element> aggregations_;
+ std::vector<Expression *> group_by_;
+ std::vector<Symbol> remember_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Aggregate>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->aggregations_.resize(aggregations_.size());
+ for (auto i3 = 0; i3 < aggregations_.size(); ++i3) {
+ object->aggregations_[i3] = aggregations_[i3].Clone(storage);
+ }
+ object->group_by_.resize(group_by_.size());
+ for (auto i4 = 0; i4 < group_by_.size(); ++i4) {
+ object->group_by_[i4] = group_by_[i4] ? group_by_[i4]->Clone(storage) : nullptr;
+ }
+ object->remember_ = remember_;
+ return object;
+ }
+};
+
+/// Skips a number of Pulls from the input op.
+///
+/// The given expression determines how many Pulls from the input
+/// should be skipped (ignored).
+/// All other successful Pulls from the
+/// input are simply passed through.
+///
+/// The given expression is evaluated after the first Pull from
+/// the input, and only once. Neo does not allow this expression
+/// to contain identifiers, and neither does Memgraph, but this
+/// operator's implementation does not rely on that restriction.
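+///
+/// Example (an illustrative sketch, not part of the original comment): a plan
+/// for `MATCH (n) RETURN n SKIP 10` could place a Skip above the Produce,
+/// with `ten` being an assumed Expression * for the literal 10:
+/// @code
+/// auto skip = std::make_shared<Skip>(produce, ten);
+/// @endcode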
+class Skip : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Skip() {}
+
+ Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> OutputSymbols(const SymbolTable &) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Expression *expression_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Skip>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->expression_ = expression_ ? expression_->Clone(storage) : nullptr;
+ return object;
+ }
+
+ private:
+ class SkipCursor : public Cursor {
+ public:
+ SkipCursor(const Skip &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const Skip &self_;
+ const UniqueCursorPtr input_cursor_;
+ // init to_skip_ to -1, indicating
+ // that it's still unknown (input has not been Pulled yet)
+ int64_t to_skip_{-1};
+ int64_t skipped_{0};
+ };
+};
+
+/// Applies the pattern filter by putting the value of the input cursor on the frame.
+class EvaluatePatternFilter : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ EvaluatePatternFilter() {}
+
+ EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Symbol output_symbol_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<EvaluatePatternFilter>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->output_symbol_ = output_symbol_;
+ return object;
+ }
+
+ private:
+ class EvaluatePatternFilterCursor : public Cursor {
+ public:
+ EvaluatePatternFilterCursor(const EvaluatePatternFilter &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const EvaluatePatternFilter &self_;
+ UniqueCursorPtr input_cursor_;
+ };
+};
+
+/// Limits the number of Pulls from the input op.
+///
+/// The given expression determines how many
+/// input Pulls should be passed through. The input is not
+/// Pulled once this limit is reached. Note that this has
+/// implications: the out-of-bounds input Pulls are never
+/// evaluated.
+///
+/// The limit expression must NOT use anything from the
+/// Frame.
It is evaluated before the first Pull from the +/// input. This is consistent with Neo (they don't allow +/// identifiers in limit expressions), and it's necessary +/// when limit evaluates to 0 (because 0 Pulls from the +/// input should be performed). +class Limit : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Limit() {} + + Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + Expression *expression_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Limit>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->expression_ = expression_ ? expression_->Clone(storage) : nullptr; + return object; + } + + private: + class LimitCursor : public Cursor { + public: + LimitCursor(const Limit &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Limit &self_; + UniqueCursorPtr input_cursor_; + // init limit_ to -1, indicating + // that it's still unknown (Cursor has not been Pulled yet) + int64_t limit_{-1}; + int64_t pulled_{0}; + }; +}; + +/// Logical operator for ordering (sorting) results. +/// +/// Sorts the input rows based on an arbitrary number of +/// Expressions. Ascending or descending ordering can be chosen +/// for each independently (not providing enough orderings +/// results in a runtime error). +/// +/// For each row an arbitrary number of Frame elements can be +/// remembered. Only these elements (defined by their Symbols) +/// are valid for usage after the OrderBy operator. +class OrderBy : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + OrderBy() {} + + OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by, + const std::vector<Symbol> &output_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + TypedValueVectorCompare compare_; + std::vector<Expression *> order_by_; + std::vector<Symbol> output_symbols_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<OrderBy>(); + object->input_ = input_ ? 
input_->Clone(storage) : nullptr;
+ object->compare_ = compare_;
+ object->order_by_.resize(order_by_.size());
+ for (auto i5 = 0; i5 < order_by_.size(); ++i5) {
+ object->order_by_[i5] = order_by_[i5] ? order_by_[i5]->Clone(storage) : nullptr;
+ }
+ object->output_symbols_ = output_symbols_;
+ return object;
+ }
+};
+
+/// Merge operator. For every successful Pull from the
+/// input operator a Pull from the merge_match is attempted. All
+/// successful Pulls from the merge_match are passed on as output.
+/// If the merge_match Pull does not yield any elements, a single Pull
+/// from the merge_create op is performed.
+///
+/// The input logical op is optional. If not provided (nullptr),
+/// it will be replaced by a Once op.
+///
+/// For the reasoning behind this implementation, see the wiki
+/// documentation.
+class Merge : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Merge() {}
+
+ Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match,
+ const std::shared_ptr<LogicalOperator> &merge_create);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ // TODO: Consider whether we want to treat Merge as having single input. It
+ // makes sense that we do, because other branches are executed depending on
+ // the input.
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> merge_match_;
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> merge_create_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Merge>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->merge_match_ = merge_match_ ? merge_match_->Clone(storage) : nullptr;
+ object->merge_create_ = merge_create_ ? merge_create_->Clone(storage) : nullptr;
+ return object;
+ }
+
+ private:
+ class MergeCursor : public Cursor {
+ public:
+ MergeCursor(const Merge &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const UniqueCursorPtr input_cursor_;
+ const UniqueCursorPtr merge_match_cursor_;
+ const UniqueCursorPtr merge_create_cursor_;
+
+ // indicates if the next Pull from this cursor
+ // should perform a pull from input_cursor_
+ // this is true when:
+ // - first Pulling from this cursor
+ // - previous Pull from this cursor exhausted the merge_match_cursor
+ bool pull_input_{true};
+ };
+};
+
+/// Optional operator. Used for optional match. For every
+/// successful Pull from the input branch a Pull from the optional
+/// branch is attempted (and Pulled from till exhausted). If zero
+/// Pulls succeed from the optional branch, the Optional operator
+/// sets the optional symbols to TypedValue::Null on the Frame
+/// and returns true, once.
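+///
+/// Example (an illustrative sketch, not part of the original comment): for
+/// `MATCH (n) OPTIONAL MATCH (n)-[e]->(m) ...`, the optional branch binds
+/// `e_sym` and `m_sym`; whenever it yields nothing for an input row, both
+/// are set to Null once:
+/// @code
+/// auto optional = std::make_shared<Optional>(
+///     match_n, expand_branch, std::vector<Symbol>{e_sym, m_sym});
+/// @endcode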
+class Optional : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Optional() {}
+
+ Optional(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &optional,
+ const std::vector<Symbol> &optional_symbols);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> optional_;
+ std::vector<Symbol> optional_symbols_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Optional>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->optional_ = optional_ ? optional_->Clone(storage) : nullptr;
+ object->optional_symbols_ = optional_symbols_;
+ return object;
+ }
+
+ private:
+ class OptionalCursor : public Cursor {
+ public:
+ OptionalCursor(const Optional &, utils::MemoryResource *);
+ bool Pull(Frame &, ExecutionContext &) override;
+ void Shutdown() override;
+ void Reset() override;
+
+ private:
+ const Optional &self_;
+ const UniqueCursorPtr input_cursor_;
+ const UniqueCursorPtr optional_cursor_;
+ // indicates if the next Pull from this cursor should
+ // perform a Pull from the input_cursor_
+ // this is true when:
+ // - first pulling from this Cursor
+ // - previous Pull from this cursor exhausted the optional_cursor_
+ bool pull_input_{true};
+ };
+};
+
+/// Takes a list TypedValue as its input and yields each
+/// element as its output.
+///
+/// Input is optional (unwind can be the first clause in a query).
+class Unwind : public memgraph::query::plan::LogicalOperator {
+ public:
+ static const utils::TypeInfo kType;
+ const utils::TypeInfo &GetTypeInfo() const override { return kType; }
+
+ Unwind() {}
+
+ Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression_, Symbol output_symbol);
+ bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
+ UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
+ std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
+
+ bool HasSingleInput() const override { return true; }
+ std::shared_ptr<LogicalOperator> input() const override { return input_; }
+ void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; }
+
+ std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
+ Expression *input_expression_;
+ Symbol output_symbol_;
+
+ std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override {
+ auto object = std::make_unique<Unwind>();
+ object->input_ = input_ ? input_->Clone(storage) : nullptr;
+ object->input_expression_ = input_expression_ ? input_expression_->Clone(storage) : nullptr;
+ object->output_symbol_ = output_symbol_;
+ return object;
+ }
+};
+
+/// Ensures that only distinct rows are yielded.
+/// This implementation accepts a vector of Symbols
+/// which define a row. Only those Symbols are valid
+/// for use in operators following Distinct.
+/// +/// This implementation maintains input ordering. +class Distinct : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Distinct() {} + + Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + std::vector<Symbol> value_symbols_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Distinct>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->value_symbols_ = value_symbols_; + return object; + } +}; + +/// A logical operator that applies UNION operator on inputs and places the +/// result on the frame. +/// +/// This operator takes two inputs, a vector of symbols for the result, and vectors +/// of symbols used by each of the inputs. +class Union : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Union() {} + + Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op, + const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols, + const std::vector<Symbol> &right_symbols); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + + std::shared_ptr<memgraph::query::plan::LogicalOperator> left_op_; + std::shared_ptr<memgraph::query::plan::LogicalOperator> right_op_; + std::vector<Symbol> union_symbols_; + std::vector<Symbol> left_symbols_; + std::vector<Symbol> right_symbols_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Union>(); + object->left_op_ = left_op_ ? left_op_->Clone(storage) : nullptr; + object->right_op_ = right_op_ ? 
right_op_->Clone(storage) : nullptr; + object->union_symbols_ = union_symbols_; + object->left_symbols_ = left_symbols_; + object->right_symbols_ = right_symbols_; + return object; + } + + private: + class UnionCursor : public Cursor { + public: + UnionCursor(const Union &, utils::MemoryResource *); + bool Pull(Frame &, ExecutionContext &) override; + void Shutdown() override; + void Reset() override; + + private: + const Union &self_; + const UniqueCursorPtr left_cursor_, right_cursor_; + }; +}; + +/// Operator for producing a Cartesian product from 2 input branches +class Cartesian : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Cartesian() {} + /** Construct the operator with left input branch and right input branch. */ + Cartesian(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols, + const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols) + : left_op_(left_op), left_symbols_(left_symbols), right_op_(right_op), right_symbols_(right_symbols) {} + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator>) override; + + std::shared_ptr<memgraph::query::plan::LogicalOperator> left_op_; + std::vector<Symbol> left_symbols_; + std::shared_ptr<memgraph::query::plan::LogicalOperator> right_op_; + std::vector<Symbol> right_symbols_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Cartesian>(); + object->left_op_ = left_op_ ? left_op_->Clone(storage) : nullptr; + object->left_symbols_ = left_symbols_; + object->right_op_ = right_op_ ? 
right_op_->Clone(storage) : nullptr; + object->right_symbols_ = right_symbols_; + return object; + } +}; + +/// An operator that outputs a table, producing a single row on each pull +class OutputTable : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + OutputTable() {} + OutputTable(std::vector<Symbol> output_symbols, + std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback); + OutputTable(std::vector<Symbol> output_symbols, std::vector<std::vector<TypedValue>> rows); + + bool Accept(HierarchicalLogicalOperatorVisitor &) override { + LOG_FATAL("OutputTable operator should not be visited!"); + } + + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override { return output_symbols_; } + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override { return output_symbols_; } + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator> input) override; + + std::vector<Symbol> output_symbols_; + std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<OutputTable>(); + object->output_symbols_ = output_symbols_; + object->callback_ = callback_; + return object; + } +}; + +/// An operator that outputs a table, producing a single row on each pull. +/// This class is different from @c OutputTable in that its callback doesn't fetch all rows +/// at once. Instead, each call of the callback should return a single row of the table. 
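+///
+/// Example (an illustrative sketch, not part of the original comment): a
+/// callback streaming rows from a captured `rows` container, signalling
+/// exhaustion with std::nullopt:
+/// @code
+/// auto it = rows.begin();
+/// OutputTableStream stream(symbols, [it, end = rows.end()](Frame *, ExecutionContext *) mutable
+///                                       -> std::optional<std::vector<TypedValue>> {
+///   if (it == end) return std::nullopt;
+///   return *it++;
+/// });
+/// @endcode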
+class OutputTableStream : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + OutputTableStream() {} + OutputTableStream(std::vector<Symbol> output_symbols, + std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback); + + bool Accept(HierarchicalLogicalOperatorVisitor &) override { + LOG_FATAL("OutputTableStream operator should not be visited!"); + } + + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override { return output_symbols_; } + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override { return output_symbols_; } + + bool HasSingleInput() const override; + std::shared_ptr<LogicalOperator> input() const override; + void set_input(std::shared_ptr<LogicalOperator> input) override; + + std::vector<Symbol> output_symbols_; + std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<OutputTableStream>(); + object->output_symbols_ = output_symbols_; + object->callback_ = callback_; + return object; + } +}; + +class CallProcedure : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + CallProcedure() = default; + CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> arguments, + std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit, + size_t memory_scale, bool is_write); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + static void IncrementCounter(const std::string &procedure_name); + static std::unordered_map<std::string, int64_t> GetAndResetCounters(); + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + std::string procedure_name_; + std::vector<Expression *> arguments_; + std::vector<std::string> result_fields_; + std::vector<Symbol> result_symbols_; + Expression *memory_limit_{nullptr}; + size_t memory_scale_{1024U}; + bool is_write_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<CallProcedure>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->procedure_name_ = procedure_name_; + object->arguments_.resize(arguments_.size()); + for (auto i6 = 0; i6 < arguments_.size(); ++i6) { + object->arguments_[i6] = arguments_[i6] ? arguments_[i6]->Clone(storage) : nullptr; + } + object->result_fields_ = result_fields_; + object->result_symbols_ = result_symbols_; + object->memory_limit_ = memory_limit_ ? 
memory_limit_->Clone(storage) : nullptr; + object->memory_scale_ = memory_scale_; + object->is_write_ = is_write_; + return object; + } + + private: + inline static utils::Synchronized<std::unordered_map<std::string, int64_t>, utils::SpinLock> procedure_counters_; +}; + +class LoadCsv : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + LoadCsv() = default; + LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad, + Expression *delimiter, Expression *quote, Symbol row_var); + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = input; } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + Expression *file_; + bool with_header_; + bool ignore_bad_; + Expression *delimiter_{nullptr}; + Expression *quote_{nullptr}; + Symbol row_var_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<LoadCsv>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->file_ = file_ ? file_->Clone(storage) : nullptr; + object->with_header_ = with_header_; + object->ignore_bad_ = ignore_bad_; + object->delimiter_ = delimiter_ ? delimiter_->Clone(storage) : nullptr; + object->quote_ = quote_ ? quote_->Clone(storage) : nullptr; + object->row_var_ = row_var_; + return object; + } +}; + +/// Iterates over a collection of elements and applies one or more update +/// clauses. +/// +class Foreach : public memgraph::query::plan::LogicalOperator { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + Foreach() = default; + Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<LogicalOperator> updates, Expression *named_expr, + Symbol loop_variable_symbol); + + bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; + UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; + std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; + bool HasSingleInput() const override { return true; } + std::shared_ptr<LogicalOperator> input() const override { return input_; } + void set_input(std::shared_ptr<LogicalOperator> input) override { input_ = std::move(input); } + + std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; + std::shared_ptr<memgraph::query::plan::LogicalOperator> update_clauses_; + Expression *expression_; + Symbol loop_variable_symbol_; + + std::unique_ptr<LogicalOperator> Clone(AstStorage *storage) const override { + auto object = std::make_unique<Foreach>(); + object->input_ = input_ ? input_->Clone(storage) : nullptr; + object->update_clauses_ = update_clauses_ ? update_clauses_->Clone(storage) : nullptr; + object->expression_ = expression_ ? 
expression_->Clone(storage) : nullptr; + object->loop_variable_symbol_ = loop_variable_symbol_; + return object; + } +}; + +} // namespace plan +} // namespace query +} // namespace memgraph diff --git a/src/query/plan/operator_type_info.cpp b/src/query/plan/operator_type_info.cpp new file mode 100644 index 000000000..c20dd1e77 --- /dev/null +++ b/src/query/plan/operator_type_info.cpp @@ -0,0 +1,148 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <cstdint> + +#include "query/plan/operator.hpp" + +namespace memgraph { + +constexpr utils::TypeInfo query::plan::LogicalOperator::kType{utils::TypeId::LOGICAL_OPERATOR, "LogicalOperator", + nullptr}; + +constexpr utils::TypeInfo query::plan::Once::kType{utils::TypeId::ONCE, "Once", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::NodeCreationInfo::kType{utils::TypeId::NODE_CREATION_INFO, "NodeCreationInfo", + nullptr}; + +constexpr utils::TypeInfo query::plan::CreateNode::kType{utils::TypeId::CREATE_NODE, "CreateNode", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::EdgeCreationInfo::kType{utils::TypeId::EDGE_CREATION_INFO, "EdgeCreationInfo", + nullptr}; + +constexpr utils::TypeInfo query::plan::CreateExpand::kType{utils::TypeId::CREATE_EXPAND, "CreateExpand", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::ScanAll::kType{utils::TypeId::SCAN_ALL, "ScanAll", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::ScanAllByLabel::kType{utils::TypeId::SCAN_ALL_BY_LABEL, "ScanAllByLabel", + &query::plan::ScanAll::kType}; + +constexpr utils::TypeInfo query::plan::ScanAllByLabelPropertyRange::kType{ + utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY_RANGE, "ScanAllByLabelPropertyRange", &query::plan::ScanAll::kType}; + +constexpr utils::TypeInfo query::plan::ScanAllByLabelPropertyValue::kType{ + utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY_VALUE, "ScanAllByLabelPropertyValue", &query::plan::ScanAll::kType}; + +constexpr utils::TypeInfo query::plan::ScanAllByLabelProperty::kType{ + utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY, "ScanAllByLabelProperty", &query::plan::ScanAll::kType}; + +constexpr utils::TypeInfo query::plan::ScanAllById::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllById", + &query::plan::ScanAll::kType}; + +constexpr utils::TypeInfo query::plan::ExpandCommon::kType{utils::TypeId::EXPAND_COMMON, "ExpandCommon", nullptr}; + +constexpr utils::TypeInfo query::plan::Expand::kType{utils::TypeId::EXPAND, "Expand", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::ExpansionLambda::kType{utils::TypeId::EXPANSION_LAMBDA, "ExpansionLambda", + nullptr}; + +constexpr utils::TypeInfo query::plan::ExpandVariable::kType{utils::TypeId::EXPAND_VARIABLE, "ExpandVariable", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::ConstructNamedPath::kType{ + utils::TypeId::CONSTRUCT_NAMED_PATH, "ConstructNamedPath", &query::plan::LogicalOperator::kType}; + +constexpr 
utils::TypeInfo query::plan::Filter::kType{utils::TypeId::FILTER, "Filter", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Produce::kType{utils::TypeId::PRODUCE, "Produce", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Delete::kType{utils::TypeId::DELETE, "Delete", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::SetProperty::kType{utils::TypeId::SET_PROPERTY, "SetProperty", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::SetProperties::kType{utils::TypeId::SET_PROPERTIES, "SetProperties", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::SetLabels::kType{utils::TypeId::SET_LABELS, "SetLabels", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::RemoveProperty::kType{utils::TypeId::REMOVE_PROPERTY, "RemoveProperty", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::RemoveLabels::kType{utils::TypeId::REMOVE_LABELS, "RemoveLabels", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::EdgeUniquenessFilter::kType{ + utils::TypeId::EDGE_UNIQUENESS_FILTER, "EdgeUniquenessFilter", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::EmptyResult::kType{utils::TypeId::EMPTY_RESULT, "EmptyResult", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Accumulate::kType{utils::TypeId::ACCUMULATE, "Accumulate", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Aggregate::Element::kType{utils::TypeId::AGGREGATE_ELEMENT, "Element", nullptr}; + +constexpr utils::TypeInfo query::plan::Aggregate::kType{utils::TypeId::AGGREGATE, "Aggregate", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Skip::kType{utils::TypeId::SKIP, "Skip", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::EvaluatePatternFilter::kType{ + utils::TypeId::EVALUATE_PATTERN_FILTER, "EvaluatePatternFilter", &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Limit::kType{utils::TypeId::LIMIT, "Limit", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::OrderBy::kType{utils::TypeId::ORDERBY, "OrderBy", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Merge::kType{utils::TypeId::MERGE, "Merge", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Optional::kType{utils::TypeId::OPTIONAL, "Optional", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Unwind::kType{utils::TypeId::UNWIND, "Unwind", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Distinct::kType{utils::TypeId::DISTINCT, "Distinct", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Union::kType{utils::TypeId::UNION, "Union", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Cartesian::kType{utils::TypeId::CARTESIAN, "Cartesian", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::OutputTable::kType{utils::TypeId::OUTPUT_TABLE, "OutputTable", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::OutputTableStream::kType{utils::TypeId::OUTPUT_TABLE_STREAM, "OutputTableStream", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo 
query::plan::CallProcedure::kType{utils::TypeId::CALL_PROCEDURE, "CallProcedure", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::LoadCsv::kType{utils::TypeId::LOAD_CSV, "LoadCsv", + &query::plan::LogicalOperator::kType}; + +constexpr utils::TypeInfo query::plan::Foreach::kType{utils::TypeId::FOREACH, "Foreach", + &query::plan::LogicalOperator::kType}; +} // namespace memgraph diff --git a/src/rpc/client.hpp b/src/rpc/client.hpp index 5c9d2ad41..4756cd6a5 100644 --- a/src/rpc/client.hpp +++ b/src/rpc/client.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -23,6 +23,7 @@ #include "slk/streams.hpp" #include "utils/logging.hpp" #include "utils/on_scope_exit.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::rpc { @@ -84,11 +85,11 @@ class Client { slk::Reader res_reader(self_->client_->GetData(), response_data_size); utils::OnScopeExit res_cleanup([&, response_data_size] { self_->client_->ShiftData(response_data_size); }); - uint64_t res_id = 0; + utils::TypeId res_id{utils::TypeId::UNKNOWN}; slk::Load(&res_id, &res_reader); // Check the response ID. - if (res_id != res_type.id) { + if (res_id != res_type.id && res_id != utils::TypeId::UNKNOWN) { spdlog::error("Message response was of unexpected type"); self_->client_ = std::nullopt; throw RpcFailedException(self_->endpoint_); diff --git a/src/rpc/protocol.cpp b/src/rpc/protocol.cpp index 63ac58760..ac74d754b 100644 --- a/src/rpc/protocol.cpp +++ b/src/rpc/protocol.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,6 +16,7 @@ #include "slk/serialization.hpp" #include "slk/streams.hpp" #include "utils/on_scope_exit.hpp" +#include "utils/typeinfo.hpp" namespace memgraph::rpc { @@ -41,7 +42,7 @@ void Session::Execute() { [&](const uint8_t *data, size_t size, bool have_more) { output_stream_->Write(data, size, have_more); }); // Load the request ID. - uint64_t req_id = 0; + utils::TypeId req_id{utils::TypeId::UNKNOWN}; slk::Load(&req_id, &req_reader); // Access to `callbacks_` and `extended_callbacks_` is done here without diff --git a/src/rpc/server.hpp b/src/rpc/server.hpp index af0c60dde..877420381 100644 --- a/src/rpc/server.hpp +++ b/src/rpc/server.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -86,8 +86,8 @@ class Server { }; std::mutex lock_; - std::map<uint64_t, RpcCallback> callbacks_; - std::map<uint64_t, RpcExtendedCallback> extended_callbacks_; + std::map<utils::TypeId, RpcCallback> callbacks_; + std::map<utils::TypeId, RpcExtendedCallback> extended_callbacks_; communication::Server<Session, Server> server_; }; diff --git a/src/slk/serialization.hpp b/src/slk/serialization.hpp index f7961368f..2b3dab796 100644 --- a/src/slk/serialization.hpp +++ b/src/slk/serialization.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -29,8 +29,10 @@
 #include "slk/streams.hpp"
 
 #include "utils/cast.hpp"
+#include "utils/concepts.hpp"
 #include "utils/endian.hpp"
 #include "utils/exceptions.hpp"
+#include "utils/typeinfo.hpp"
 
 // The namespace name stands for SaveLoadKit. It should not be mistaken for the
 // Mercedes car model line.
@@ -308,6 +310,10 @@ inline void Save(const std::optional<T> &obj, Builder *builder) {
   }
 }
 
+inline void Save(const utils::TypeId &obj, Builder *builder) {
+  Save(static_cast<std::underlying_type_t<utils::TypeId>>(obj), builder);
+}
+
 template <typename T>
 inline void Load(std::optional<T> *obj, Reader *reader) {
   bool exists = false;
@@ -471,4 +477,12 @@ inline void Load(std::optional<T> *obj, Reader *reader, std::function<void(T *,
     *obj = std::nullopt;
   }
 }
+
+inline void Load(utils::TypeId *obj, Reader *reader) {
+  using enum_type = std::underlying_type_t<utils::TypeId>;
+  enum_type obj_encoded;
+  slk::Load(&obj_encoded, reader);
+  *obj = utils::TypeId(utils::MemcpyCast<enum_type>(obj_encoded));
+}
+
 }  // namespace memgraph::slk
diff --git a/src/storage/v2/CMakeLists.txt b/src/storage/v2/CMakeLists.txt
index d46563183..505ae45ec 100644
--- a/src/storage/v2/CMakeLists.txt
+++ b/src/storage/v2/CMakeLists.txt
@@ -12,12 +12,6 @@ set(storage_v2_src_files
     vertex_accessor.cpp
     storage.cpp)
 
-##### Replication #####
-define_add_lcp(add_lcp_storage lcp_storage_cpp_files generated_lcp_storage_files)
-
-add_lcp_storage(replication/rpc.lcp SLK_SERIALIZE)
-
-add_custom_target(generate_lcp_storage DEPENDS ${generated_lcp_storage_files})
 
 set(storage_v2_src_files
     ${storage_v2_src_files}
@@ -26,7 +20,7 @@ set(storage_v2_src_files
     replication/serialization.cpp
     replication/slk.cpp
     replication/replication_persistence_helper.cpp
-    ${lcp_storage_cpp_files})
+    replication/rpc.cpp)
 
 #######################
 find_package(gflags REQUIRED)
@@ -35,5 +29,4 @@ find_package(Threads REQUIRED)
 
 add_library(mg-storage-v2 STATIC ${storage_v2_src_files})
 target_link_libraries(mg-storage-v2 Threads::Threads mg-utils gflags)
-add_dependencies(mg-storage-v2 generate_lcp_storage)
 target_link_libraries(mg-storage-v2 mg-rpc mg-slk)
diff --git a/src/storage/v2/replication/.gitignore b/src/storage/v2/replication/.gitignore
deleted file mode 100644
index 8fb0c720c..000000000
--- a/src/storage/v2/replication/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-# autogenerated files
-rpc.hpp
diff --git a/src/storage/v2/replication/rpc.cpp b/src/storage/v2/replication/rpc.cpp
new file mode 100644
index 000000000..783e3b4f5
--- /dev/null
+++ b/src/storage/v2/replication/rpc.cpp
@@ -0,0 +1,263 @@
+// Copyright 2023 Memgraph Ltd.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
+// License, and you may not use this file except in compliance with the Business Source License.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
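+
+// This file holds the SLK serialization code that was previously generated
+// from replication/rpc.lcp by LCP; it is now maintained by hand.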
+ +#include "storage/v2/replication/rpc.hpp" +#include "utils/typeinfo.hpp" + +namespace memgraph { + +namespace storage { + +namespace replication { + +void AppendDeltasReq::Save(const AppendDeltasReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void AppendDeltasReq::Load(AppendDeltasReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void AppendDeltasRes::Save(const AppendDeltasRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void AppendDeltasRes::Load(AppendDeltasRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void HeartbeatReq::Save(const HeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void HeartbeatReq::Load(HeartbeatReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void HeartbeatRes::Save(const HeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void HeartbeatRes::Load(HeartbeatRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void FrequentHeartbeatReq::Save(const FrequentHeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void FrequentHeartbeatReq::Load(FrequentHeartbeatReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void FrequentHeartbeatRes::Save(const FrequentHeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void FrequentHeartbeatRes::Load(FrequentHeartbeatRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(self, reader); +} +void SnapshotReq::Save(const SnapshotReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); } +void SnapshotReq::Load(SnapshotReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void SnapshotRes::Save(const SnapshotRes &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); } +void SnapshotRes::Load(SnapshotRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void WalFilesReq::Save(const WalFilesReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); } +void WalFilesReq::Load(WalFilesReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void WalFilesRes::Save(const WalFilesRes &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); } +void WalFilesRes::Load(WalFilesRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void CurrentWalReq::Save(const CurrentWalReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void CurrentWalReq::Load(CurrentWalReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void CurrentWalRes::Save(const CurrentWalRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void CurrentWalRes::Load(CurrentWalRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void TimestampReq::Save(const TimestampReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void TimestampReq::Load(TimestampReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } +void TimestampRes::Save(const TimestampRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self, builder); +} +void TimestampRes::Load(TimestampRes *self, 
memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); } + +} // namespace replication +} // namespace storage + +constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ, + "AppendDeltasReq", nullptr}; + +constexpr utils::TypeInfo storage::replication::AppendDeltasRes::kType{utils::TypeId::REP_APPEND_DELTAS_RES, + "AppendDeltasRes", nullptr}; + +constexpr utils::TypeInfo storage::replication::HeartbeatReq::kType{utils::TypeId::REP_HEARTBEAT_REQ, "HeartbeatReq", + nullptr}; + +constexpr utils::TypeInfo storage::replication::HeartbeatRes::kType{utils::TypeId::REP_HEARTBEAT_RES, "HeartbeatRes", + nullptr}; + +constexpr utils::TypeInfo storage::replication::FrequentHeartbeatReq::kType{utils::TypeId::REP_FREQUENT_HEARTBEAT_REQ, + "FrequentHeartbeatReq", nullptr}; + +constexpr utils::TypeInfo storage::replication::FrequentHeartbeatRes::kType{utils::TypeId::REP_FREQUENT_HEARTBEAT_RES, + "FrequentHeartbeatRes", nullptr}; + +constexpr utils::TypeInfo storage::replication::SnapshotReq::kType{utils::TypeId::REP_SNAPSHOT_REQ, "SnapshotReq", + nullptr}; + +constexpr utils::TypeInfo storage::replication::SnapshotRes::kType{utils::TypeId::REP_SNAPSHOT_RES, "SnapshotRes", + nullptr}; + +constexpr utils::TypeInfo storage::replication::WalFilesReq::kType{utils::TypeId::REP_WALFILES_REQ, "WalFilesReq", + nullptr}; + +constexpr utils::TypeInfo storage::replication::WalFilesRes::kType{utils::TypeId::REP_WALFILES_RES, "WalFilesRes", + nullptr}; + +constexpr utils::TypeInfo storage::replication::CurrentWalReq::kType{utils::TypeId::REP_CURRENT_WAL_REQ, + "CurrentWalReq", nullptr}; + +constexpr utils::TypeInfo storage::replication::CurrentWalRes::kType{utils::TypeId::REP_CURRENT_WAL_RES, + "CurrentWalRes", nullptr}; + +constexpr utils::TypeInfo storage::replication::TimestampReq::kType{utils::TypeId::REP_TIMESTAMP_REQ, "TimestampReq", + nullptr}; + +constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes", + nullptr}; + +// Autogenerated SLK serialization code +namespace slk { +// Serialize code for TimestampRes + +void Save(const memgraph::storage::replication::TimestampRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); +} + +void Load(memgraph::storage::replication::TimestampRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); +} + +// Serialize code for TimestampReq + +void Save(const memgraph::storage::replication::TimestampReq &self, memgraph::slk::Builder *builder) {} + +void Load(memgraph::storage::replication::TimestampReq *self, memgraph::slk::Reader *reader) {} + +// Serialize code for CurrentWalRes + +void Save(const memgraph::storage::replication::CurrentWalRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); +} + +void Load(memgraph::storage::replication::CurrentWalRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); +} + +// Serialize code for CurrentWalReq + +void Save(const memgraph::storage::replication::CurrentWalReq &self, memgraph::slk::Builder *builder) {} + +void Load(memgraph::storage::replication::CurrentWalReq *self, memgraph::slk::Reader *reader) {} + +// 
Serialize code for WalFilesRes + +void Save(const memgraph::storage::replication::WalFilesRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); +} + +void Load(memgraph::storage::replication::WalFilesRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); +} + +// Serialize code for WalFilesReq + +void Save(const memgraph::storage::replication::WalFilesReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.file_number, builder); +} + +void Load(memgraph::storage::replication::WalFilesReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->file_number, reader); +} + +// Serialize code for SnapshotRes + +void Save(const memgraph::storage::replication::SnapshotRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); +} + +void Load(memgraph::storage::replication::SnapshotRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); +} + +// Serialize code for SnapshotReq + +void Save(const memgraph::storage::replication::SnapshotReq &self, memgraph::slk::Builder *builder) {} + +void Load(memgraph::storage::replication::SnapshotReq *self, memgraph::slk::Reader *reader) {} + +// Serialize code for FrequentHeartbeatRes + +void Save(const memgraph::storage::replication::FrequentHeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); +} + +void Load(memgraph::storage::replication::FrequentHeartbeatRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); +} + +// Serialize code for FrequentHeartbeatReq + +void Save(const memgraph::storage::replication::FrequentHeartbeatReq &self, memgraph::slk::Builder *builder) {} + +void Load(memgraph::storage::replication::FrequentHeartbeatReq *self, memgraph::slk::Reader *reader) {} + +// Serialize code for HeartbeatRes + +void Save(const memgraph::storage::replication::HeartbeatRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); + memgraph::slk::Save(self.epoch_id, builder); +} + +void Load(memgraph::storage::replication::HeartbeatRes *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); + memgraph::slk::Load(&self->epoch_id, reader); +} + +// Serialize code for HeartbeatReq + +void Save(const memgraph::storage::replication::HeartbeatReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.main_commit_timestamp, builder); + memgraph::slk::Save(self.epoch_id, builder); +} + +void Load(memgraph::storage::replication::HeartbeatReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->main_commit_timestamp, reader); + memgraph::slk::Load(&self->epoch_id, reader); +} + +// Serialize code for AppendDeltasRes + +void Save(const memgraph::storage::replication::AppendDeltasRes &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.success, builder); + memgraph::slk::Save(self.current_commit_timestamp, builder); +} + +void Load(memgraph::storage::replication::AppendDeltasRes *self, memgraph::slk::Reader *reader) { + 
memgraph::slk::Load(&self->success, reader); + memgraph::slk::Load(&self->current_commit_timestamp, reader); +} + +// Serialize code for AppendDeltasReq + +void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph::slk::Builder *builder) { + memgraph::slk::Save(self.previous_commit_timestamp, builder); + memgraph::slk::Save(self.seq_num, builder); +} + +void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader) { + memgraph::slk::Load(&self->previous_commit_timestamp, reader); + memgraph::slk::Load(&self->seq_num, reader); +} +} // namespace slk +} // namespace memgraph diff --git a/src/storage/v2/replication/rpc.hpp b/src/storage/v2/replication/rpc.hpp new file mode 100644 index 000000000..f466b1880 --- /dev/null +++ b/src/storage/v2/replication/rpc.hpp @@ -0,0 +1,278 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include <cstdint> +#include <cstring> +#include <string> + +#include "rpc/messages.hpp" +#include "slk/serialization.hpp" +#include "slk/streams.hpp" + +namespace memgraph { + +namespace storage { + +namespace replication { + +struct AppendDeltasReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(AppendDeltasReq *self, memgraph::slk::Reader *reader); + static void Save(const AppendDeltasReq &self, memgraph::slk::Builder *builder); + AppendDeltasReq() {} + AppendDeltasReq(uint64_t previous_commit_timestamp, uint64_t seq_num) + : previous_commit_timestamp(previous_commit_timestamp), seq_num(seq_num) {} + + uint64_t previous_commit_timestamp; + uint64_t seq_num; +}; + +struct AppendDeltasRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(AppendDeltasRes *self, memgraph::slk::Reader *reader); + static void Save(const AppendDeltasRes &self, memgraph::slk::Builder *builder); + AppendDeltasRes() {} + AppendDeltasRes(bool success, uint64_t current_commit_timestamp) + : success(success), current_commit_timestamp(current_commit_timestamp) {} + + bool success; + uint64_t current_commit_timestamp; +}; + +using AppendDeltasRpc = rpc::RequestResponse<AppendDeltasReq, AppendDeltasRes>; + +struct HeartbeatReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(HeartbeatReq *self, memgraph::slk::Reader *reader); + static void Save(const HeartbeatReq &self, memgraph::slk::Builder *builder); + HeartbeatReq() {} + HeartbeatReq(uint64_t main_commit_timestamp, std::string epoch_id) + : main_commit_timestamp(main_commit_timestamp), epoch_id(epoch_id) {} + + uint64_t main_commit_timestamp; + std::string epoch_id; +}; + +struct HeartbeatRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(HeartbeatRes *self, memgraph::slk::Reader *reader); + static void Save(const HeartbeatRes &self, memgraph::slk::Builder *builder); + 
HeartbeatRes() {} + HeartbeatRes(bool success, uint64_t current_commit_timestamp, std::string epoch_id) + : success(success), current_commit_timestamp(current_commit_timestamp), epoch_id(epoch_id) {} + + bool success; + uint64_t current_commit_timestamp; + std::string epoch_id; +}; + +using HeartbeatRpc = rpc::RequestResponse<HeartbeatReq, HeartbeatRes>; + +struct FrequentHeartbeatReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(FrequentHeartbeatReq *self, memgraph::slk::Reader *reader); + static void Save(const FrequentHeartbeatReq &self, memgraph::slk::Builder *builder); + FrequentHeartbeatReq() {} +}; + +struct FrequentHeartbeatRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(FrequentHeartbeatRes *self, memgraph::slk::Reader *reader); + static void Save(const FrequentHeartbeatRes &self, memgraph::slk::Builder *builder); + FrequentHeartbeatRes() {} + explicit FrequentHeartbeatRes(bool success) : success(success) {} + + bool success; +}; + +using FrequentHeartbeatRpc = rpc::RequestResponse<FrequentHeartbeatReq, FrequentHeartbeatRes>; + +struct SnapshotReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SnapshotReq *self, memgraph::slk::Reader *reader); + static void Save(const SnapshotReq &self, memgraph::slk::Builder *builder); + SnapshotReq() {} +}; + +struct SnapshotRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(SnapshotRes *self, memgraph::slk::Reader *reader); + static void Save(const SnapshotRes &self, memgraph::slk::Builder *builder); + SnapshotRes() {} + SnapshotRes(bool success, uint64_t current_commit_timestamp) + : success(success), current_commit_timestamp(current_commit_timestamp) {} + + bool success; + uint64_t current_commit_timestamp; +}; + +using SnapshotRpc = rpc::RequestResponse<SnapshotReq, SnapshotRes>; + +struct WalFilesReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(WalFilesReq *self, memgraph::slk::Reader *reader); + static void Save(const WalFilesReq &self, memgraph::slk::Builder *builder); + WalFilesReq() {} + explicit WalFilesReq(uint64_t file_number) : file_number(file_number) {} + + uint64_t file_number; +}; + +struct WalFilesRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(WalFilesRes *self, memgraph::slk::Reader *reader); + static void Save(const WalFilesRes &self, memgraph::slk::Builder *builder); + WalFilesRes() {} + WalFilesRes(bool success, uint64_t current_commit_timestamp) + : success(success), current_commit_timestamp(current_commit_timestamp) {} + + bool success; + uint64_t current_commit_timestamp; +}; + +using WalFilesRpc = rpc::RequestResponse<WalFilesReq, WalFilesRes>; + +struct CurrentWalReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(CurrentWalReq *self, memgraph::slk::Reader *reader); + static void Save(const CurrentWalReq &self, memgraph::slk::Builder *builder); + CurrentWalReq() {} +}; + +struct CurrentWalRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(CurrentWalRes *self, memgraph::slk::Reader *reader); + 
static void Save(const CurrentWalRes &self, memgraph::slk::Builder *builder); + CurrentWalRes() {} + CurrentWalRes(bool success, uint64_t current_commit_timestamp) + : success(success), current_commit_timestamp(current_commit_timestamp) {} + + bool success; + uint64_t current_commit_timestamp; +}; + +using CurrentWalRpc = rpc::RequestResponse<CurrentWalReq, CurrentWalRes>; + +struct TimestampReq { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(TimestampReq *self, memgraph::slk::Reader *reader); + static void Save(const TimestampReq &self, memgraph::slk::Builder *builder); + TimestampReq() {} +}; + +struct TimestampRes { + static const utils::TypeInfo kType; + static const utils::TypeInfo &GetTypeInfo() { return kType; } + + static void Load(TimestampRes *self, memgraph::slk::Reader *reader); + static void Save(const TimestampRes &self, memgraph::slk::Builder *builder); + TimestampRes() {} + TimestampRes(bool success, uint64_t current_commit_timestamp) + : success(success), current_commit_timestamp(current_commit_timestamp) {} + + bool success; + uint64_t current_commit_timestamp; +}; + +using TimestampRpc = rpc::RequestResponse<TimestampReq, TimestampRes>; +} // namespace replication +} // namespace storage +} // namespace memgraph + +// SLK serialization declarations +#include "slk/serialization.hpp" +namespace memgraph::slk { + +void Save(const memgraph::storage::replication::TimestampRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::TimestampRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::TimestampReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::TimestampReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::CurrentWalRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::CurrentWalRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::CurrentWalReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::CurrentWalReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::WalFilesRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::WalFilesRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::WalFilesReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::WalFilesReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::SnapshotRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::SnapshotRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::SnapshotReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::SnapshotReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::FrequentHeartbeatRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::FrequentHeartbeatRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::FrequentHeartbeatReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::FrequentHeartbeatReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::HeartbeatRes &self, 
memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::HeartbeatRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::HeartbeatReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::HeartbeatReq *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::AppendDeltasRes &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::AppendDeltasRes *self, memgraph::slk::Reader *reader); + +void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph::slk::Builder *builder); + +void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader); + +} // namespace memgraph::slk diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index 7d8085bce..ca0cbe39d 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,6 +16,172 @@ namespace memgraph::utils { +enum class TypeId : uint64_t { + // Operators + UNKNOWN, + LOGICAL_OPERATOR, + ONCE, + NODE_CREATION_INFO, + CREATE_NODE, + EDGE_CREATION_INFO, + CREATE_EXPAND, + SCAN_ALL, + SCAN_ALL_BY_LABEL, + SCAN_ALL_BY_LABEL_PROPERTY_RANGE, + SCAN_ALL_BY_LABEL_PROPERTY_VALUE, + SCAN_ALL_BY_LABEL_PROPERTY, + SCAN_ALL_BY_ID, + EXPAND_COMMON, + EXPAND, + EXPANSION_LAMBDA, + EXPAND_VARIABLE, + CONSTRUCT_NAMED_PATH, + FILTER, + PRODUCE, + DELETE, + SET_PROPERTY, + SET_PROPERTIES, + SET_LABELS, + REMOVE_PROPERTY, + REMOVE_LABELS, + EDGE_UNIQUENESS_FILTER, + EMPTY_RESULT, + ACCUMULATE, + AGGREGATE, + AGGREGATE_ELEMENT, + SKIP, + EVALUATE_PATTERN_FILTER, + LIMIT, + ORDERBY, + MERGE, + OPTIONAL, + UNWIND, + DISTINCT, + UNION, + CARTESIAN, + OUTPUT_TABLE, + OUTPUT_TABLE_STREAM, + CALL_PROCEDURE, + LOAD_CSV, + FOREACH, + + // Replication + REP_APPEND_DELTAS_REQ, + REP_APPEND_DELTAS_RES, + REP_HEARTBEAT_REQ, + REP_HEARTBEAT_RES, + REP_FREQUENT_HEARTBEAT_REQ, + REP_FREQUENT_HEARTBEAT_RES, + REP_SNAPSHOT_REQ, + REP_SNAPSHOT_RES, + REP_WALFILES_REQ, + REP_WALFILES_RES, + REP_CURRENT_WAL_REQ, + REP_CURRENT_WAL_RES, + REP_TIMESTAMP_REQ, + REP_TIMESTAMP_RES, + + // AST + AST_LABELIX, + AST_PROPERTYIX, + AST_EDGETYPEIX, + AST_TREE, + AST_EXPRESSION, + AST_WHERE, + AST_BINARY_OPERATOR, + AST_UNARY_OPERATOR, + AST_OR_OPERATOR, + AST_XOR_OPERATOR, + AST_AND_OPERATOR, + AST_ADDITION_OPERATOR, + AST_SUBTRACTION_OPERATOR, + AST_MULTIPLICATION_OPERATOR, + AST_DIVISION_OPERATOR, + AST_MOD_OPERATOR, + AST_NOT_EQUAL_OPERATOR, + AST_EQUAL_OPERATOR, + AST_LESS_OPERATOR, + AST_GREATER_OPERATOR, + AST_LESS_EQUAL_OPERATOR, + AST_GREATER_EQUAL_OPERATOR, + AST_IN_LIST_OPERATOR, + AST_SUBSCRIPT_OPERATOR, + AST_NOT_OPERATOR, + AST_UNARY_PLUS_OPERATOR, + AST_UNARY_MINUS_OPERATOR, + AST_IS_NULL_OPERATOR, + AST_AGGREGATION, + AST_LIST_SLICING_OPERATOR, + AST_IF_OPERATOR, + AST_BASE_LITERAL, + AST_PRIMITIVE_LITERAL, + AST_LIST_LITERAL, + AST_MAP_LITERAL, + AST_IDENTIFIER, + AST_PROPERTY_LOOKUP, + AST_LABELS_TEST, + AST_FUNCTION, + AST_REDUCE, + AST_COALESCE, + AST_EXTRACT, + AST_ALL, + AST_SINGLE, + AST_ANY, + AST_NONE, + AST_PARAMETER_LOOKUP, + AST_REGEX_MATCH, + AST_NAMED_EXPRESSION, + AST_PATTERN_ATOM, + AST_NODE_ATOM, + AST_EDGE_ATOM_LAMBDA, + AST_EDGE_ATOM, + AST_PATTERN, + AST_CLAUSE, + AST_SINGLE_QUERY, 
+ AST_CYPHER_UNION, + AST_QUERY, + AST_CYPHER_QUERY, + AST_EXPLAIN_QUERY, + AST_PROFILE_QUERY, + AST_INDEX_QUERY, + AST_CREATE, + AST_CALL_PROCEDURE, + AST_MATCH, + AST_SORT_ITEM, + AST_RETURN_BODY, + AST_RETURN, + AST_WITH, + AST_DELETE, + AST_SET_PROPERTY, + AST_SET_PROPERTIES, + AST_SET_LABELS, + AST_REMOVE_PROPERTY, + AST_REMOVE_LABELS, + AST_MERGE, + AST_UNWIND, + AST_AUTH_QUERY, + AST_INFO_QUERY, + AST_CONSTRAINT, + AST_CONSTRAINT_QUERY, + AST_DUMP_QUERY, + AST_REPLICATION_QUERY, + AST_LOCK_PATH_QUERY, + AST_LOAD_CSV, + AST_FREE_MEMORY_QUERY, + AST_TRIGGER_QUERY, + AST_ISOLATION_LEVEL_QUERY, + AST_CREATE_SNAPSHOT_QUERY, + AST_STREAM_QUERY, + AST_SETTING_QUERY, + AST_VERSION_QUERY, + AST_FOREACH, + AST_SHOW_CONFIG_QUERY, + AST_EXISTS, + + // Symbol + SYMBOL, +}; + /// Type information on a C++ type. /// /// You should embed this structure as a static constant member `kType` and make @@ -24,7 +190,7 @@ namespace memgraph::utils { /// runtime type. struct TypeInfo { /// Unique ID for the type. - uint64_t id; + TypeId id; /// Pretty name of the type. const char *name; /// `TypeInfo *` for superclass of this type. diff --git a/tests/benchmark/rpc.cpp b/tests/benchmark/rpc.cpp index fc60b76b1..3dace0531 100644 --- a/tests/benchmark/rpc.cpp +++ b/tests/benchmark/rpc.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,7 +25,7 @@ struct EchoMessage { static const memgraph::utils::TypeInfo kType; EchoMessage() {} // Needed for serialization. - EchoMessage(const std::string &data) : data(data) {} + explicit EchoMessage(const std::string &data) : data(data) {} static void Load(EchoMessage *obj, memgraph::slk::Reader *reader); static void Save(const EchoMessage &obj, memgraph::slk::Builder *builder); @@ -41,7 +41,7 @@ void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); } void EchoMessage::Load(EchoMessage *obj, memgraph::slk::Reader *reader) { memgraph::slk::Load(obj, reader); } void EchoMessage::Save(const EchoMessage &obj, memgraph::slk::Builder *builder) { memgraph::slk::Save(obj, builder); } -const memgraph::utils::TypeInfo EchoMessage::kType{2, "EchoMessage"}; +const memgraph::utils::TypeInfo EchoMessage::kType{memgraph::utils::TypeId::UNKNOWN, "EchoMessage"}; using Echo = memgraph::rpc::RequestResponse<EchoMessage, EchoMessage>; diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index fe3cea930..dfcbc9697 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -360,15 +360,6 @@ if(MG_ENTERPRISE) target_link_libraries(${test_prefix}rpc mg-rpc) endif() -# Test LCP -add_custom_command( - OUTPUT test_lcp - DEPENDS ${lcp_src_files} lcp test_lcp.lisp - COMMAND sbcl --script ${CMAKE_CURRENT_SOURCE_DIR}/test_lcp.lisp) -add_custom_target(test_lcp ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/test_lcp) -add_test(test_lcp ${CMAKE_CURRENT_BINARY_DIR}/test_lcp) -add_dependencies(memgraph__unit test_lcp) - # Test websocket find_package(Boost REQUIRED) diff --git a/tests/unit/rpc_messages.hpp b/tests/unit/rpc_messages.hpp index aa8b89fe3..6058c37cf 100644 --- a/tests/unit/rpc_messages.hpp +++ b/tests/unit/rpc_messages.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
 //
 // Use of this software is governed by the Business Source License
 // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@@ -28,7 +28,7 @@ struct SumReq {
   int y;
 };
 
-const memgraph::utils::TypeInfo SumReq::kType{0, "SumReq"};
+const memgraph::utils::TypeInfo SumReq::kType{memgraph::utils::TypeId::UNKNOWN, "SumReq"};
 
 struct SumRes {
   static const memgraph::utils::TypeInfo kType;
@@ -42,7 +42,7 @@ struct SumRes {
   int sum;
 };
 
-const memgraph::utils::TypeInfo SumRes::kType{1, "SumRes"};
+const memgraph::utils::TypeInfo SumRes::kType{memgraph::utils::TypeId::UNKNOWN, "SumRes"};
 
 namespace memgraph::slk {
 void Save(const SumReq &sum, Builder *builder);
@@ -66,7 +66,7 @@ struct EchoMessage {
   std::string data;
 };
 
-const memgraph::utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};
+const memgraph::utils::TypeInfo EchoMessage::kType{memgraph::utils::TypeId::UNKNOWN, "EchoMessage"};
 
 namespace memgraph::slk {
 void Save(const EchoMessage &echo, Builder *builder);

From 6349fc950133be5b2186315419d077c041369a8f Mon Sep 17 00:00:00 2001
From: Ante Javor <javor.ante@gmail.com>
Date: Sat, 18 Mar 2023 20:18:58 +0100
Subject: [PATCH 2/6] Add time-dependent execution to the mgbench client (#805)

---
 tests/mgbench/client.cpp | 128 ++++++++++++++++++++++++++++++++++++---
 1 file changed, 121 insertions(+), 7 deletions(-)

diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp
index c7e002df6..b18567fda 100644
--- a/tests/mgbench/client.cpp
+++ b/tests/mgbench/client.cpp
@@ -58,6 +58,10 @@ DEFINE_bool(validation, false,
             "Set to true to run client in validation mode. "
             "Validation mode works for a single query and returns results for validation "
             "with metadata");
+DEFINE_int64(time_dependent_execution, 0,
+             "Time-dependent execution runs the queries for the specified number of seconds. "
+             "If all queries are executed and there is still time left, the queries are rerun. "
+             "When the time runs out, the client finishes the job and returns the results.");
 
 std::pair<std::map<std::string, memgraph::communication::bolt::Value>, uint64_t> ExecuteNTimesTillSuccess(
     memgraph::communication::bolt::Client *client, const std::string &query,
@@ -220,7 +224,113 @@ nlohmann::json LatencyStatistics(std::vector<std::vector<double>> &worker_query_
   return statistics;
 }
 
-void Execute(
+void ExecuteTimeDependentWorkload(
+    const std::vector<std::pair<std::string, std::map<std::string, memgraph::communication::bolt::Value>>> &queries,
+    std::ostream *stream) {
+  std::vector<std::thread> threads;
+  threads.reserve(FLAGS_num_workers);
+
+  std::vector<uint64_t> worker_retries(FLAGS_num_workers, 0);
+  std::vector<Metadata> worker_metadata(FLAGS_num_workers, Metadata());
+  std::vector<double> worker_duration(FLAGS_num_workers, 0.0);
+  std::vector<std::vector<double>> worker_query_durations(FLAGS_num_workers);
+
+  // Start workers and execute queries.
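+  // Each worker repeatedly claims the next query index from the shared atomic
+  // counter, wrapping around to the start of the list, until the configured
+  // time limit expires.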
+  auto size = queries.size();
+  std::atomic<bool> run(false);
+  std::atomic<uint64_t> ready(0);
+  std::atomic<uint64_t> position(0);
+  std::atomic<bool> start_workload_timer(false);
+
+  std::chrono::time_point<std::chrono::steady_clock> workload_start;
+  std::chrono::duration<double> time_limit = std::chrono::seconds(FLAGS_time_dependent_execution);
+  for (int worker = 0; worker < FLAGS_num_workers; ++worker) {
+    threads.push_back(std::thread([&, worker]() {
+      memgraph::io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
+      memgraph::communication::ClientContext context(FLAGS_use_ssl);
+      memgraph::communication::bolt::Client client(context);
+      client.Connect(endpoint, FLAGS_username, FLAGS_password);
+
+      ready.fetch_add(1, std::memory_order_acq_rel);
+      while (!run.load(std::memory_order_acq_rel))
+        ;
+      auto &retries = worker_retries[worker];
+      auto &metadata = worker_metadata[worker];
+      auto &duration = worker_duration[worker];
+      auto &query_duration = worker_query_durations[worker];
+
+      // After all threads have been initialised, start the workload timer.
+      if (!start_workload_timer.load()) {
+        workload_start = std::chrono::steady_clock::now();
+        start_workload_timer.store(true);
+      }
+
+      memgraph::utils::Timer worker_timer;
+      while (std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() -
+                                                                       workload_start) < time_limit) {
+        auto pos = position.fetch_add(1, std::memory_order_acq_rel);
+        if (pos >= size) {
+          /// Get back to the initial position.
+          position.store(0, std::memory_order_acq_rel);
+          pos = position.fetch_add(1, std::memory_order_acq_rel);
+        }
+        const auto &query = queries[pos];
+        memgraph::utils::Timer query_timer;
+        auto ret = ExecuteNTimesTillSuccess(&client, query.first, query.second, FLAGS_max_retries);
+        query_duration.emplace_back(query_timer.Elapsed().count());
+        retries += ret.second;
+        metadata.Append(ret.first);
+        duration = worker_timer.Elapsed().count();
+      }
+      client.Close();
+    }));
+  }
+
+  // Synchronize workers and collect runtime.
+  while (ready.load(std::memory_order_acq_rel) < FLAGS_num_workers)
+    ;
+  run.store(true);
+  for (int i = 0; i < FLAGS_num_workers; ++i) {
+    threads[i].join();
+  }
+
+  // Create and output summary.
+  Metadata final_metadata;
+  uint64_t final_retries = 0;
+  double final_duration = 0.0;
+  for (int i = 0; i < FLAGS_num_workers; ++i) {
+    final_metadata += worker_metadata[i];
+    final_retries += worker_retries[i];
+    final_duration += worker_duration[i];
+  }
+
+  int total_iterations = 0;
+  std::for_each(worker_query_durations.begin(), worker_query_durations.end(),
+                [&](const std::vector<double> &v) { total_iterations += v.size(); });
+
+  final_duration /= FLAGS_num_workers;
+  double execution_delta = time_limit.count() / final_duration;
+  // Throughput adjusted for how much longer than the time limit the workload actually ran.
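+  // For example (illustrative numbers): with a 10 s time limit and an average
+  // worker duration of 12 s, execution_delta = 10 / 12 ~= 0.83, so the raw
+  // throughput is scaled down to compensate for the overrun.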
+  double throughput = (total_iterations / final_duration) * execution_delta;
+  double raw_throughput = total_iterations / final_duration;
+
+  nlohmann::json summary = nlohmann::json::object();
+  summary["count"] = queries.size();
+  summary["duration"] = final_duration;
+  summary["time_limit"] = FLAGS_time_dependent_execution;
+  summary["queries_executed"] = total_iterations;
+
+  summary["throughput"] = throughput;
+  summary["raw_throughput"] = raw_throughput;
+  summary["latency_stats"] = LatencyStatistics(worker_query_durations);
+  summary["retries"] = final_retries;
+  summary["metadata"] = final_metadata.Export();
+  summary["num_workers"] = FLAGS_num_workers;
+
+  (*stream) << summary.dump() << std::endl;
+}
+
+void ExecuteWorkload(
     const std::vector<std::pair<std::string, std::map<std::string, memgraph::communication::bolt::Value>>> &queries,
     std::ostream *stream) {
   std::vector<std::thread> threads;
@@ -259,7 +369,7 @@ void Execute(
       const auto &query = queries[pos];
       memgraph::utils::Timer query_timer;
       auto ret = ExecuteNTimesTillSuccess(&client, query.first, query.second, FLAGS_max_retries);
-      query_duration.push_back(query_timer.Elapsed().count());
+      query_duration.emplace_back(query_timer.Elapsed().count());
       retries += ret.second;
       metadata.Append(ret.first);
     }
@@ -272,6 +382,7 @@ void Execute(
   while (ready.load(std::memory_order_acq_rel) < FLAGS_num_workers)
     ;
   run.store(true, std::memory_order_acq_rel);
+
   for (int i = 0; i < FLAGS_num_workers; ++i) {
     threads[i].join();
   }
@@ -363,6 +474,7 @@ int main(int argc, char **argv) {
   spdlog::info("Input: {}", FLAGS_input);
   spdlog::info("Output: {}", FLAGS_output);
   spdlog::info("Validation: {}", FLAGS_validation);
+  spdlog::info("Time-dependent execution: {}", FLAGS_time_dependent_execution);
 
   memgraph::communication::SSLInit sslInit;
 
@@ -390,7 +502,7 @@ int main(int argc, char **argv) {
   while (std::getline(*istream, query)) {
     auto trimmed = memgraph::utils::Trim(query);
     if (trimmed == "" || trimmed == ";") {
-      Execute(queries, ostream);
+      ExecuteWorkload(queries, ostream);
       queries.clear();
      continue;
     }
@@ -406,7 +518,7 @@
                 "array!");
       MG_ASSERT(data.is_array() && data.size() == 2, "Each item of the loaded JSON queries must be an array!");
       if (data.size() == 0) {
-        Execute(queries, ostream);
+        ExecuteWorkload(queries, ostream);
         queries.clear();
         continue;
       }
@@ -424,10 +536,12 @@
     }
   }
 
-  if (!FLAGS_validation) {
-    Execute(queries, ostream);
-  } else {
+  if (FLAGS_validation) {
     ExecuteValidation(queries, ostream);
+  } else if (FLAGS_time_dependent_execution > 0) {
+    ExecuteTimeDependentWorkload(queries, ostream);
+  } else {
+    ExecuteWorkload(queries, ostream);
   }
 
   return 0;

From cb813c307050b019f5dca2676a4c730530f8930c Mon Sep 17 00:00:00 2001
From: Ante Javor <javor.ante@gmail.com>
Date: Tue, 21 Mar 2023 21:44:11 +0100
Subject: [PATCH 3/6] Add bigger LDBC dataset to mgbench (#747)

---
 tests/mgbench/README.md                       |    2 +-
 tests/mgbench/benchmark.py                    | 1069 ++++++++---------
 tests/mgbench/benchmark_context.py            |   57 +
 tests/mgbench/client.cpp                      |    3 +-
 tests/mgbench/compare_results.py              |   11 +-
 tests/mgbench/cypher/__init__.py              |    0
 tests/mgbench/cypher/ldbc_to_cypher.py        |  500 ++++++++
 tests/mgbench/graph_bench.py                  |  208 ++--
 tests/mgbench/helpers.py                      |  198 ++-
 tests/mgbench/log.py                          |   32 +-
 tests/mgbench/runners.py                      |  297 +++--
 tests/mgbench/validation.py                   |  244 ++++
 tests/mgbench/workloads/__init__.py           |    4 +
 tests/mgbench/workloads/base.py               |  197 +++
 tests/mgbench/workloads/demo.py               |   28 +
tests/mgbench/workloads/importers/__init__.py | 0 .../workloads/importers/importer_ldbc_bi.py | 213 ++++ .../importers/importer_ldbc_interactive.py | 163 +++ .../workloads/importers/importer_pokec.py | 41 + tests/mgbench/workloads/ldbc_bi.py | 708 +++++++++++ tests/mgbench/workloads/ldbc_interactive.py | 684 +++++++++++ .../{datasets.py => workloads/pokec.py} | 151 +-- 22 files changed, 3907 insertions(+), 903 deletions(-) create mode 100644 tests/mgbench/benchmark_context.py create mode 100644 tests/mgbench/cypher/__init__.py create mode 100644 tests/mgbench/cypher/ldbc_to_cypher.py create mode 100644 tests/mgbench/validation.py create mode 100644 tests/mgbench/workloads/__init__.py create mode 100644 tests/mgbench/workloads/base.py create mode 100644 tests/mgbench/workloads/demo.py create mode 100644 tests/mgbench/workloads/importers/__init__.py create mode 100644 tests/mgbench/workloads/importers/importer_ldbc_bi.py create mode 100644 tests/mgbench/workloads/importers/importer_ldbc_interactive.py create mode 100644 tests/mgbench/workloads/importers/importer_pokec.py create mode 100644 tests/mgbench/workloads/ldbc_bi.py create mode 100644 tests/mgbench/workloads/ldbc_interactive.py rename tests/mgbench/{datasets.py => workloads/pokec.py} (72%) diff --git a/tests/mgbench/README.md b/tests/mgbench/README.md index 12c114186..6c51e4850 100644 --- a/tests/mgbench/README.md +++ b/tests/mgbench/README.md @@ -247,7 +247,7 @@ Index queries for each supported vendor can be downloaded from “https://s3.eu- |Q19|pattern_short| analytical | MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1| |Q20|single_edge_write| write | MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m CREATE (n)-[e:Temp]->(m) RETURN e| |Q21|single_vertex_write| write |CREATE (n:UserTemp {id : $id}) RETURN n| -|Q22|single_vertex_property_update| update | MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1| +|Q22|single_vertex_property_update| update | MATCH (n:User {id: $id}) SET n.property = -1| |Q23|single_vertex_read| read | MATCH (n:User {id : $id}) RETURN n| ## :computer: Platform diff --git a/tests/mgbench/benchmark.py b/tests/mgbench/benchmark.py index f08d6e3fa..376396650 100755 --- a/tests/mgbench/benchmark.py +++ b/tests/mgbench/benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,166 +12,165 @@ # licenses/APL.txt. import argparse -import collections -import copy -import fnmatch -import inspect import json -import math import multiprocessing +import platform import random -import statistics -import sys -import datasets import helpers import log import runners +from benchmark_context import BenchmarkContext +from workloads import * WITH_FINE_GRAINED_AUTHORIZATION = "with_fine_grained_authorization" WITHOUT_FINE_GRAINED_AUTHORIZATION = "without_fine_grained_authorization" - -# Parse options. 
-parser = argparse.ArgumentParser( - description="Memgraph benchmark executor.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, -) -parser.add_argument( - "benchmarks", - nargs="*", - default="", - help="descriptions of benchmarks that should be run; " - "multiple descriptions can be specified to run multiple " - "benchmarks; the description is specified as " - "dataset/variant/group/query; Unix shell-style wildcards " - "can be used in the descriptions; variant, group and query " - "are optional and they can be left out; the default " - "variant is '' which selects the default dataset variant; " - "the default group is '*' which selects all groups; the" - "default query is '*' which selects all queries", -) -parser.add_argument( - "--vendor-binary", - help="Vendor binary used for benchmarking, by defuault it is memgraph", - default=helpers.get_binary_path("memgraph"), -) - -parser.add_argument( - "--vendor-name", - default="memgraph", - help="Input vendor binary name (memgraph, neo4j)", -) -parser.add_argument( - "--client-binary", - default=helpers.get_binary_path("tests/mgbench/client"), - help="Client binary used for benchmarking", -) -parser.add_argument( - "--num-workers-for-import", - type=int, - default=multiprocessing.cpu_count() // 2, - help="number of workers used to import the dataset", -) -parser.add_argument( - "--num-workers-for-benchmark", - type=int, - default=1, - help="number of workers used to execute the benchmark", -) -parser.add_argument( - "--single-threaded-runtime-sec", - type=int, - default=10, - help="single threaded duration of each query", -) -parser.add_argument( - "--no-load-query-counts", - action="store_true", - help="disable loading of cached query counts", -) -parser.add_argument( - "--no-save-query-counts", - action="store_true", - help="disable storing of cached query counts", -) -parser.add_argument( - "--export-results", - default="", - help="file path into which results should be exported", -) -parser.add_argument( - "--temporary-directory", - default="/tmp", - help="directory path where temporary data should " "be stored", -) -parser.add_argument("--no-properties-on-edges", action="store_true", help="disable properties on edges") - -parser.add_argument("--bolt-port", default=7687, help="memgraph bolt port") - -parser.add_argument( - "--no-authorization", - action="store_false", - default=True, - help="Run each query with authorization", -) - -parser.add_argument( - "--warmup-run", - action="store_true", - default=False, - help="Run warmup before benchmarks", -) - -parser.add_argument( - "--mixed-workload", - nargs="*", - type=int, - default=[], - help="""Define combination that defines the mixed workload. - Mixed workload can be run as a single configuration for all groups of queries, - Pass the positional arguments as values of what percentage of - write/read/update/analytical queries you want to have in your workload. - Example: --mixed-workload 1000 20 70 10 0 will execute 1000 queries, 20% write, - 70% read, 10% update and 0% analytical. - - Mixed workload can also be run on each query under some defined load. - By passing one more positional argument, you are defining what percentage of that query - will be in mixed workload, and this is executed for each query. 
The rest of the queries will be
-    selected from the appropriate groups
-    Running --mixed-workload 1000 30 0 0 0 70, will execute each query 700 times or 70%,
-    with the presence of 300 write queries from write type or 30%""",
-)
-
-parser.add_argument("--tail-latency", type=int, default=0, help="Number of queries for the tail latency statistics")
-
-parser.add_argument(
-    "--performance-tracking",
-    action="store_true",
-    default=False,
-    help="Flag for runners performance tracking, this logs RES through time and vendor specific performance tracking.",
-)
-
-args = parser.parse_args()
+QUERY_COUNT_LOWER_BOUND = 30
 
 
-class Workload:
-    def __init__(self, config):
-        config_len = len(config)
-        if config_len == 0:
-            self.name = "Isolated"
-            self.config = config
-        elif config_len >= 5:
-            if sum(config[1:]) != 100:
-                raise Exception(
-                    "Please make sure that passed arguments % sum to 100% percent!, passed: ",
-                    config,
-                )
-            if config_len == 5:
-                self.name = "Realistic"
-                self.config = config
-            else:
-                self.name = "Mixed"
-                self.config = config
+def parse_args():
+
+    parser = argparse.ArgumentParser(
+        description="Memgraph benchmark executor.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "benchmarks",
+        nargs="*",
+        default=None,
+        help="descriptions of benchmarks that should be run; "
+        "multiple descriptions can be specified to run multiple "
+        "benchmarks; the description is specified as "
+        "dataset/variant/group/query; Unix shell-style wildcards "
+        "can be used in the descriptions; variant, group and query "
+        "are optional and they can be left out; the default "
+        "variant is '' which selects the default dataset variant; "
+        "the default group is '*' which selects all groups; the "
+        "default query is '*' which selects all queries",
+    )
+    parser.add_argument(
+        "--vendor-binary",
+        help="Vendor binary used for benchmarking, by default it is memgraph",
+        default=helpers.get_binary_path("memgraph"),
+    )
+
+    parser.add_argument(
+        "--vendor-name",
+        default="memgraph",
+        choices=["memgraph", "neo4j"],
+        help="Input vendor binary name (memgraph, neo4j)",
+    )
+    parser.add_argument(
+        "--client-binary",
+        default=helpers.get_binary_path("tests/mgbench/client"),
+        help="Client binary used for benchmarking",
+    )
+    parser.add_argument(
+        "--num-workers-for-import",
+        type=int,
+        default=multiprocessing.cpu_count() // 2,
+        help="number of workers used to import the dataset",
+    )
+    parser.add_argument(
+        "--num-workers-for-benchmark",
+        type=int,
+        default=1,
+        help="number of workers used to execute the benchmark",
+    )
+    parser.add_argument(
+        "--single-threaded-runtime-sec",
+        type=int,
+        default=10,
+        help="single-threaded duration of each query",
+    )
+    parser.add_argument(
+        "--no-load-query-counts",
+        action="store_true",
+        default=False,
+        help="disable loading of cached query counts",
+    )
+    parser.add_argument(
+        "--no-save-query-counts",
+        action="store_true",
+        default=False,
+        help="disable storing of cached query counts",
+    )
+
+    parser.add_argument(
+        "--export-results",
+        default=None,
+        help="file path into which results should be exported",
+    )
+    parser.add_argument(
+        "--temporary-directory",
+        default="/tmp",
+        help="directory path where temporary data should be stored",
+    )
+
+    parser.add_argument(
+        "--no-authorization",
+        action="store_false",
+        default=True,
+        help="Run each query with authorization",
+    )
+
+    parser.add_argument(
+        "--warm-up",
+        default="cold",
+        choices=["cold", "hot", "vulcanic"],
+        help="Run a different warm-up before 
the benchmark sample starts",
+    )
+
+    parser.add_argument(
+        "--workload-realistic",
+        nargs="*",
+        type=int,
+        default=None,
+        help="""Define the combination that makes up the realistic workload.
+        A realistic workload runs as a single configuration for all groups of queries.
+        Pass the positional arguments as values of what percentage of
+        write/read/update/analytical queries you want to have in your workload.
+        Example: --workload-realistic 1000 20 70 10 0 will execute 1000 queries: 20% write,
+        70% read, 10% update and 0% analytical.""",
+    )
+
+    parser.add_argument(
+        "--workload-mixed",
+        nargs="*",
+        type=int,
+        default=None,
+        help="""Mixed workload can be run on each query under some defined load.
+        By passing one more positional argument, you define what percentage of that query
+        will be in the mixed workload; this is executed for each query. The rest of the queries are
+        selected from the appropriate groups.
+        Running --workload-mixed 1000 30 0 0 0 70 will execute each query 700 times (70%),
+        with the presence of 300 write queries (30%).""",
+    )
+
+    parser.add_argument(
+        "--time-dependent-execution",
+        type=int,
+        default=0,
+        help="Execute a defined number of queries (based on single-threaded-runtime-sec) for a defined duration of wall-clock time",
+    )
+
+    parser.add_argument(
+        "--performance-tracking",
+        action="store_true",
+        default=False,
+        help="Flag for runner performance tracking; this logs RES through time and vendor-specific performance tracking.",
+    )
+
+    parser.add_argument("--customer-workloads", default=None, help="Path to customer workloads")
+
+    parser.add_argument(
+        "--vendor-specific",
+        nargs="*",
+        default=[],
+        help="Vendor-specific arguments that can be applied to each vendor, format: [key=value, key=value ...]",
+    )
+
+    return parser.parse_args()
 
 
 def get_queries(gen, count):
@@ -184,117 +183,41 @@ def get_queries(gen, count):
     return ret
 
 
-def match_patterns(dataset, variant, group, query, is_default_variant, patterns):
-    for pattern in patterns:
-        verdict = [fnmatch.fnmatchcase(dataset, pattern[0])]
-        if pattern[1] != "":
-            verdict.append(fnmatch.fnmatchcase(variant, pattern[1]))
-        else:
-            verdict.append(is_default_variant)
-        verdict.append(fnmatch.fnmatchcase(group, pattern[2]))
-        verdict.append(fnmatch.fnmatchcase(query, pattern[3]))
-        if all(verdict):
-            return True
-    return False
-
-
-def filter_benchmarks(generators, patterns):
-    patterns = copy.deepcopy(patterns)
-    for i in range(len(patterns)):
-        pattern = patterns[i].split("/")
-        if len(pattern) > 5 or len(pattern) == 0:
-            raise Exception("Invalid benchmark description '" + pattern + "'!")
-        pattern.extend(["", "*", "*"][len(pattern) - 1 :])
-        patterns[i] = pattern
-    filtered = []
-    for dataset in sorted(generators.keys()):
-        generator, queries = generators[dataset]
-        for variant in generator.VARIANTS:
-            is_default_variant = variant == generator.DEFAULT_VARIANT
-            current = collections.defaultdict(list)
-            for group in queries:
-                for query_name, query_func in queries[group]:
-                    if match_patterns(
-                        dataset,
-                        variant,
-                        group,
-                        query_name,
-                        is_default_variant,
-                        patterns,
-                    ):
-                        current[group].append((query_name, query_func))
-            if len(current) == 0:
-                continue
-
-            # Ignore benchgraph "basic" queries in standard CI/CD run
-            for pattern in patterns:
-                res = pattern.count("*")
-                key = "basic"
-                if res >= 2 and key in current.keys():
-                    current.pop(key)
-
-            filtered.append((generator(variant, args.vendor_name), dict(current)))
-    return filtered
-
-
-def warmup(client):
-    print("Executing warm-up queries")
-    client.execute(
-        queries=[
-            ("CREATE ();", {}),
-            ("CREATE ()-[:TempEdge]->();", {}),
-            ("MATCH (n) RETURN n LIMIT 1;", {}),
-        ],
-        num_workers=1,
-    )
-
-
-def tail_latency(vendor, client, func):
-    iteration = args.tail_latency
-    if iteration >= 10:
-        vendor.start_benchmark("tail_latency")
-        if args.warmup_run:
-            warmup(client)
-        latency = []
-
-        query_list = get_queries(func, iteration)
-        for i in range(0, iteration):
-            ret = client.execute(queries=[query_list[i]], num_workers=1)
-            latency.append(ret[0]["duration"])
-        latency.sort()
-        query_stats = {
-            "iterations": iteration,
-            "min": latency[0],
-            "max": latency[iteration - 1],
-            "mean": statistics.mean(latency),
-            "p99": latency[math.floor(iteration * 0.99) - 1],
-            "p95": latency[math.floor(iteration * 0.95) - 1],
-            "p90": latency[math.floor(iteration * 0.90) - 1],
-            "p75": latency[math.floor(iteration * 0.75) - 1],
-            "p50": latency[math.floor(iteration * 0.50) - 1],
-        }
-        print("Query statistics for tail latency: ")
-        print(query_stats)
-        vendor.stop("tail_latency")
+def warmup(condition: str, client: runners.BaseClient, queries: list = None):
+    log.log("Database condition: {}".format(condition))
+    if condition == "hot":
+        log.log("Executing warm-up to match condition: {}".format(condition))
+        client.execute(
+            queries=[
+                ("CREATE ();", {}),
+                ("CREATE ()-[:TempEdge]->();", {}),
+                ("MATCH (n) RETURN n LIMIT 1;", {}),
+            ],
+            num_workers=1,
+        )
+    elif condition == "vulcanic":
+        log.log("Executing warm-up to match condition: {}".format(condition))
+        client.execute(queries=queries)
     else:
-        query_stats = {}
-    return query_stats
+        log.log("No warm-up for condition: {}".format(condition))
 
 
-def mixed_workload(vendor, client, dataset, group, queries, workload):
+def mixed_workload(
+    vendor: runners.BaseRunner, client: runners.BaseClient, dataset, group, queries, benchmark_context: BenchmarkContext
+):
 
-    num_of_queries = workload.config[0]
-    percentage_distribution = workload.config[1:]
+    num_of_queries = benchmark_context.mode_config[0]
+    percentage_distribution = benchmark_context.mode_config[1:]
     if sum(percentage_distribution) != 100:
         raise Exception(
             "Please make sure that passed arguments % sum to 100% percent!, passed: ",
             percentage_distribution,
         )
-    s = [str(i) for i in workload.config]
+    s = [str(i) for i in benchmark_context.mode_config]
     config_distribution = "_".join(s)
-    print("Generating mixed workload.")
+    log.log("Generating mixed workload...")
 
     percentages_by_type = {
         "write": percentage_distribution[0],
@@ -324,7 +247,7 @@ def mixed_workload(vendor, client, dataset, group, queries, workload):
     random.seed(config_distribution)
 
     # Executing mixed workload for each test
-    if workload.name == "Mixed":
+    if benchmark_context.mode == "Mixed":
         for query, funcname in queries[group]:
 
             full_workload = []
@@ -352,17 +275,16 @@ def mixed_workload(vendor, client, dataset, group, queries, workload):
                         full_workload.append(base_query())
                 else:
                     funcname = random.choices(queries_by_type[t], k=1)[0]
-                    aditional_query = getattr(dataset, funcname)
-                    full_workload.append(aditional_query())
+                    additional_query = getattr(dataset, funcname)
+                    full_workload.append(additional_query())
 
             vendor.start_benchmark(
                 dataset.NAME + dataset.get_variant() + "_" + "mixed" + "_" + query + "_" + config_distribution
            )
-            if args.warmup_run:
-                warmup(client)
+            warmup(benchmark_context.warm_up, client=client)
             ret = client.execute(
                 queries=full_workload,
-                num_workers=args.num_workers_for_benchmark,
+
num_workers=benchmark_context.num_workers_for_benchmark, )[0] usage_workload = vendor.stop( dataset.NAME + dataset.get_variant() + "_" + "mixed" + "_" + query + "_" + config_distribution @@ -386,20 +308,19 @@ def mixed_workload(vendor, client, dataset, group, queries, workload): function_type = random.choices(population=options, weights=percentage_distribution, k=num_of_queries) for t in function_type: - # Get the apropropriate functions with same probabilty + # Get the appropriate functions with same probability funcname = random.choices(queries_by_type[t], k=1)[0] - aditional_query = getattr(dataset, funcname) - full_workload.append(aditional_query()) + additional_query = getattr(dataset, funcname) + full_workload.append(additional_query()) - vendor.start_benchmark(dataset.NAME + dataset.get_variant() + "_" + workload.name + "_" + config_distribution) - if args.warmup_run: - warmup(client) + vendor.start_benchmark(dataset.NAME + dataset.get_variant() + "_" + "realistic" + "_" + config_distribution) + warmup(benchmark_context.warm_up, client=client) ret = client.execute( queries=full_workload, - num_workers=args.num_workers_for_benchmark, + num_workers=benchmark_context.num_workers_for_benchmark, )[0] usage_workload = vendor.stop( - dataset.NAME + dataset.get_variant() + "_" + workload.name + "_" + config_distribution + dataset.NAME + dataset.get_variant() + "_" + "realistic" + "_" + config_distribution ) mixed_workload = { "count": ret["count"], @@ -421,29 +342,34 @@ def mixed_workload(vendor, client, dataset, group, queries, workload): print(mixed_workload) -def get_query_cache_count(vendor, client, func, config_key): - cached_count = config.get_value(*config_key) +def get_query_cache_count( + vendor: runners.BaseRunner, + client: runners.BaseClient, + queries: list, + config_key: list, + benchmark_context: BenchmarkContext, +): + cached_count = config.get_value(*config_key) if cached_count is None: - print( - "Determining the number of queries necessary for", - args.single_threaded_runtime_sec, - "seconds of single-threaded runtime...", + log.info( + "Determining the number of queries necessary for {} seconds of single-threaded runtime...".format( + benchmark_context.single_threaded_runtime_sec + ) ) # First run to prime the query caches. vendor.start_benchmark("cache") - if args.warmup_run: - warmup(client) - client.execute(queries=get_queries(func, 1), num_workers=1) + client.execute(queries=queries, num_workers=1) # Get a sense of the runtime. 
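         # Note: the sample below grows by 10x each pass; the loop stops once the
         # estimated count for the requested runtime lands within one order of
         # magnitude of the current sample. As a rough, assumed example, a query
         # taking ~1 ms settles at about 10000 queries for a 10 s runtime.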
         count = 1
         while True:
             ret = client.execute(queries=get_queries(func, count), num_workers=1)
             duration = ret[0]["duration"]
-            should_execute = int(args.single_threaded_runtime_sec / (duration / count))
-            print(
-                "executed_queries={}, total_duration={}, "
-                "query_duration={}, estimated_count={}".format(count, duration, duration / count, should_execute)
+            should_execute = int(benchmark_context.single_threaded_runtime_sec / (duration / count))
+            log.log(
+                "executed_queries={}, total_duration={}, query_duration={}, estimated_count={}".format(
+                    count, duration, duration / count, should_execute
+                )
             )
             # We don't have to execute the next iteration when
             # `should_execute` becomes the same order of magnitude as
             # `count * 10`.
             if should_execute < count * 10:
                 break
             count = count * 10
         vendor.stop("cache")
 
-        # Lower bound for count
-        if count < 20:
-            count = 20
+        if count < QUERY_COUNT_LOWER_BOUND:
+            count = QUERY_COUNT_LOWER_BOUND
 
         config.set_value(
             *config_key,
             value={
                 "count": count,
-                "duration": args.single_threaded_runtime_sec,
+                "duration": benchmark_context.single_threaded_runtime_sec,
             },
         )
     else:
-        print(
-            "Using cached query count of",
-            cached_count["count"],
-            "queries for",
-            cached_count["duration"],
-            "seconds of single-threaded runtime.",
+        log.log(
+            "Using cached query count of {} queries for {} seconds of single-threaded runtime.".format(
+                cached_count["count"], cached_count["duration"]
+            ),
         )
-        count = int(cached_count["count"] * args.single_threaded_runtime_sec / cached_count["duration"])
+        count = int(cached_count["count"] * benchmark_context.single_threaded_runtime_sec / cached_count["duration"])
     return count
 
 
-# Testing pre commit.
+if __name__ == "__main__":
 
-# Detect available datasets.
-generators = {}
-for key in dir(datasets):
-    if key.startswith("_"):
-        continue
-    dataset = getattr(datasets, key)
-    if not inspect.isclass(dataset) or dataset == datasets.Dataset or not issubclass(dataset, datasets.Dataset):
-        continue
-    queries = collections.defaultdict(list)
-    for funcname in dir(dataset):
-        if not funcname.startswith("benchmark__"):
-            continue
-        group, query = funcname.split("__")[1:]
-        queries[group].append((query, funcname))
-    generators[dataset.NAME] = (dataset, dict(queries))
-    if dataset.PROPERTIES_ON_EDGES and args.no_properties_on_edges:
-        raise Exception(
-            'The "{}" dataset requires properties on edges, ' "but you have disabled them!".format(dataset.NAME)
-        )
+    args = parse_args()
+    vendor_specific_args = helpers.parse_kwargs(args.vendor_specific)
 
-# List datasets if there is no specified dataset.
-if len(args.benchmarks) == 0:
-    log.init("Available queries")
-    for name in sorted(generators.keys()):
-        print("Dataset:", name)
-        dataset, queries = generators[name]
-        print(
-            "    Variants:",
-            ", ".join(dataset.VARIANTS),
-            "(default: " + dataset.DEFAULT_VARIANT + ")",
-        )
-        for group in sorted(queries.keys()):
-            print("        Group:", group)
-            for query_name, query_func in queries[group]:
-                print("            Query:", query_name)
-    sys.exit(0)
+    assert args.benchmarks is not None, helpers.list_available_workloads()
+    assert args.vendor_name in ("memgraph", "neo4j"), "Unsupported vendor"
+    assert args.vendor_binary is not None, "Pass the database binary for the runner"
+    assert args.client_binary is not None, "Pass the client binary used for benchmarking"
+    assert args.num_workers_for_import > 0
+    assert args.num_workers_for_benchmark > 0
+    assert args.export_results is not None, "Pass a file path where the results will be saved"
+    assert (
+        args.single_threaded_runtime_sec >= 10
+    ), "Low runtime value; consider extending the time for more accurate results"
+    assert (
+        args.workload_realistic is None or args.workload_mixed is None
+    ), "Cannot run both realistic and mixed workloads; only one mode can run at a time"
 
-# Create cache, config and results objects.
-cache = helpers.Cache()
-if not args.no_load_query_counts:
-    config = cache.load_config()
-else:
-    config = helpers.RecursiveDict()
-results = helpers.RecursiveDict()
+    benchmark_context = BenchmarkContext(
+        benchmark_target_workload=args.benchmarks,
+        vendor_binary=args.vendor_binary,
+        vendor_name=args.vendor_name,
+        client_binary=args.client_binary,
+        num_workers_for_import=args.num_workers_for_import,
+        num_workers_for_benchmark=args.num_workers_for_benchmark,
+        single_threaded_runtime_sec=args.single_threaded_runtime_sec,
+        no_load_query_counts=args.no_load_query_counts,
+        no_save_query_counts=args.no_save_query_counts,
+        export_results=args.export_results,
+        temporary_directory=args.temporary_directory,
+        workload_mixed=args.workload_mixed,
+        workload_realistic=args.workload_realistic,
+        time_dependent_execution=args.time_dependent_execution,
+        warm_up=args.warm_up,
+        performance_tracking=args.performance_tracking,
+        no_authorization=args.no_authorization,
+        customer_workloads=args.customer_workloads,
+        vendor_args=vendor_specific_args,
+    )
 
-# Filter out the generators.
-benchmarks = filter_benchmarks(generators, args.benchmarks)
 
-# Run all specified benchmarks.
-for dataset, queries in benchmarks:
+    log.init("Executing benchmark with the following arguments: ")
+    for key, value in benchmark_context.__dict__.items():
+        log.log(str(key) + " : " + str(value))
 
-    workload = Workload(args.mixed_workload)
+    log.log("Creating cache folder for datasets, configurations, indexes, results, etc.")
+    # Create cache, config and results objects.
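+    # The same cache directory also keeps downloaded/generated datasets, index
+    # files and cached per-query count estimates, so repeated runs can skip the
+    # more expensive estimation and preparation steps where possible.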
+ cache = helpers.Cache() + log.init("Folder in use: " + cache.get_default_cache_directory()) + if not benchmark_context.no_load_query_counts: + log.log("Using previous cached query count data from cache directory.") + config = cache.load_config() + else: + config = helpers.RecursiveDict() + results = helpers.RecursiveDict() + + log.init("Creating vendor runner for DB: " + benchmark_context.vendor_name) + vendor_runner = runners.BaseRunner.create( + benchmark_context=benchmark_context, + ) + log.log("Class in use: " + str(vendor_runner)) run_config = { - "vendor": args.vendor_name, - "condition": "hot" if args.warmup_run else "cold", - "workload": workload.name, - "workload_config": workload.config, + "vendor": benchmark_context.vendor_name, + "condition": benchmark_context.warm_up, + "benchmark_mode": benchmark_context.mode, + "benchmark_mode_config": benchmark_context.mode_config, + "platform": platform.platform(), } results.set_value("__run_configuration__", value=run_config) - log.init("Preparing", dataset.NAME + "/" + dataset.get_variant(), "dataset") - dataset.prepare(cache.cache_directory("datasets", dataset.NAME, dataset.get_variant())) + available_workloads = helpers.get_available_workloads(benchmark_context.customer_workloads) - # TODO: Create some abstract class for vendors, that will hold this data - if args.vendor_name == "neo4j": - vendor = runners.Neo4j( - args.vendor_binary, - args.temporary_directory, - args.bolt_port, - args.performance_tracking, - ) - else: - vendor = runners.Memgraph( - args.vendor_binary, - args.temporary_directory, - not args.no_properties_on_edges, - args.bolt_port, - args.performance_tracking, - ) + log.init("Currently available workloads: ") + log.log(helpers.list_available_workloads(benchmark_context.customer_workloads)) - client = runners.Client(args.client_binary, args.temporary_directory, args.bolt_port) + # Filter out the workloads based on the pattern + target_workloads = helpers.filter_workloads( + available_workloads=available_workloads, benchmark_context=benchmark_context + ) - ret = None - usage = None - if args.vendor_name == "neo4j": - vendor.start_preparation("preparation") - print("Executing database cleanup and index setup...") - ret = client.execute(file_path=dataset.get_index(), num_workers=args.num_workers_for_import) - usage = vendor.stop("preparation") - dump_dir = cache.cache_directory("datasets", dataset.NAME, dataset.get_variant()) - dump_file, exists = dump_dir.get_file("neo4j.dump") - if exists: - vendor.load_db_from_dump(path=dump_dir.get_path()) + # Run all target workloads. 
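+    # Each workload below goes through the same phases: clean the database,
+    # prepare the dataset (generated queries or a dataset file), import it while
+    # recording client and database usage, and then run every query group under
+    # the selected benchmark mode (Isolated, Mixed or Realistic).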
+    for workload, queries in target_workloads:
+        log.info("Started running the following workload: " + str(workload.NAME))
+
+        log.info("Cleaning the database of any previous data")
+        vendor_runner.clean_db()
+
+        client = vendor_runner.fetch_client()
+        log.log("Using the appropriate client for vendor: " + str(client))
+
+        ret = None
+        usage = None
+
+        log.init("Preparing workload: " + workload.NAME + "/" + workload.get_variant())
+        workload.prepare(cache.cache_directory("datasets", workload.NAME, workload.get_variant()))
+        generated_queries = workload.dataset_generator()
+        if generated_queries:
+            vendor_runner.start_preparation("import")
+            log.info("Using workload as dataset generator...")
+            if workload.get_index():
+                log.info("Using index from specified file: {}".format(workload.get_index()))
+                client.execute(file_path=workload.get_index(), num_workers=benchmark_context.num_workers_for_import)
+            else:
+                log.warning("Make sure proper indexes/constraints are created in generated queries!")
+            ret = client.execute(queries=generated_queries, num_workers=benchmark_context.num_workers_for_import)
+            usage = vendor_runner.stop("import")
         else:
-        vendor.start_preparation("import")
-        print("Importing dataset...")
-        ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
-        usage = vendor.stop("import")
+            log.info("Using workload dataset information for import...")
+            imported = workload.custom_import()
+            if not imported:
+                log.log("Basic import execution")
+                vendor_runner.start_preparation("import")
+                log.log("Executing database index setup...")
+                client.execute(file_path=workload.get_index(), num_workers=benchmark_context.num_workers_for_import)
+                log.log("Importing dataset...")
+                ret = client.execute(
+                    file_path=workload.get_file(), num_workers=benchmark_context.num_workers_for_import
+                )
+                usage = vendor_runner.stop("import")
+            else:
+                log.info("Custom import executed...")
 
-        vendor.dump_db(path=dump_dir.get_path())
-    else:
-        vendor.start_preparation("import")
-        print("Executing database cleanup and index setup...")
-        ret = client.execute(file_path=dataset.get_index(), num_workers=args.num_workers_for_import)
-        print("Importing dataset...")
-        ret = client.execute(file_path=dataset.get_file(), num_workers=args.num_workers_for_import)
-        usage = vendor.stop("import")
 
         # Save import results.
-    import_key = [dataset.NAME, dataset.get_variant(), "__import__"]
-    if ret != None and usage != None:
-        # Display import statistics.
-        print()
-        for row in ret:
-            print(
-                "Executed",
-                row["count"],
-                "queries in",
-                row["duration"],
-                "seconds using",
-                row["num_workers"],
-                "workers with a total throughput of",
-                row["throughput"],
-                "queries/second.",
+        import_key = [workload.NAME, workload.get_variant(), "__import__"]
+        if ret is not None and usage is not None:
+            # Display import statistics.
+            for row in ret:
+                log.success(
+                    "Executed {} queries in {} seconds using {} workers with a total throughput of {} Q/S.".format(
+                        row["count"], row["duration"], row["num_workers"], row["throughput"]
+                    )
+                )
+
+            log.success(
+                "The database used {} seconds of CPU time and peaked at {} MiB of RAM.".format(
+                    usage["cpu"], usage["memory"] / 1024 / 1024
+                ),
+            )
-        print()
-        print(
-            "The database used",
-            usage["cpu"],
-            "seconds of CPU time and peaked at",
-            usage["memory"] / 1024 / 1024,
-            "MiB of RAM.",
-        )
-        results.set_value(*import_key, value={"client": ret, "database": usage})
-    else:
-        results.set_value(*import_key, value={"client": "dump_load", "database": "dump_load"})
-
-    # Run all benchmarks in all available groups.
-    for group in sorted(queries.keys()):
-
-        # Running queries in mixed workload
-        if workload.name == "Mixed" or workload.name == "Realistic":
-            mixed_workload(vendor, client, dataset, group, queries, workload)
+            results.set_value(*import_key, value={"client": ret, "database": usage})
+        else:
+            results.set_value(*import_key, value={"client": "custom_load", "database": "custom_load"})
 
+        # Run all benchmarks in all available groups.
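+        # Mixed and Realistic modes both delegate to mixed_workload(); only the
+        # Isolated branch below measures each query on its own.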
+        for group in sorted(queries.keys()):
+            log.init("Running benchmark in " + benchmark_context.mode)
+            if benchmark_context.mode in ("Mixed", "Realistic"):
+                mixed_workload(vendor_runner, client, workload, group, queries, benchmark_context)
+            else:
+                for query, funcname in queries[group]:
+                    log.init(
+                        "Running query: "
+                        + "{}/{}/{}/{}".format(group, query, funcname, WITHOUT_FINE_GRAINED_AUTHORIZATION),
+                    )
+                    func = getattr(workload, funcname)
+
+                    # Query count
+                    config_key = [
+                        workload.NAME,
+                        workload.get_variant(),
+                        group,
+                        query,
+                    ]
+                    count = get_query_cache_count(
+                        vendor_runner, client, get_queries(func, 1), config_key, benchmark_context
+                    )
 
+                    # Benchmark run.
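+                    # The query count used for this run comes from get_query_cache_count:
+                    # a cached estimate is rescaled as
+                    # count = cached_count * single_threaded_runtime_sec / cached_duration,
+                    # so (as an assumed example) 1000 queries cached against a 10 s
+                    # runtime become 3000 queries for a 30 s target.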
+                    log.info("Sample query: {}".format(get_queries(func, 1)[0][0]))
+                    log.log(
+                        "Executing benchmark with {} queries that should yield a single-threaded runtime of {} seconds.".format(
+                            count, benchmark_context.single_threaded_runtime_sec
+                        )
+                    )
+                    log.log(
+                        "Queries are executed using {} concurrent clients.".format(
+                            benchmark_context.num_workers_for_benchmark
+                        )
+                    )
+                    vendor_runner.start_benchmark(
+                        workload.NAME + workload.get_variant() + "_" + benchmark_context.mode + "_" + query
+                    )
+                    warmup(condition=benchmark_context.warm_up, client=client, queries=get_queries(func, count))
+                    if benchmark_context.time_dependent_execution != 0:
+                        ret = client.execute(
+                            queries=get_queries(func, count),
+                            num_workers=benchmark_context.num_workers_for_benchmark,
+                            time_dependent_execution=benchmark_context.time_dependent_execution,
+                        )[0]
+                    else:
+                        ret = client.execute(
+                            queries=get_queries(func, count),
+                            num_workers=benchmark_context.num_workers_for_benchmark,
+                        )[0]
+
+                    usage = vendor_runner.stop(
+                        workload.NAME + workload.get_variant() + "_" + benchmark_context.mode + "_" + query
+                    )
+                    ret["database"] = usage
+                    # Output summary.
+
+                    log.log("Executed {} queries in {} seconds.".format(ret["count"], ret["duration"]))
+                    log.log("Queries have been retried {} times.".format(ret["retries"]))
+                    log.log("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
+                    log.log("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
+                    log.log("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max"))
+                    metadata = ret["metadata"]
+                    for key in sorted(metadata.keys()):
+                        log.log(
+                            "{name:>30}: {minimum:>20.06f} {average:>20.06f} "
+                            "{maximum:>20.06f}".format(name=key, **metadata[key])
+                        )
+                    log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
+
+                    # Save results.
+                    results_key = [
+                        workload.NAME,
+                        workload.get_variant(),
+                        group,
+                        query,
+                        WITHOUT_FINE_GRAINED_AUTHORIZATION,
+                    ]
+                    results.set_value(*results_key, value=ret)
+
+            # If there is a need for authorization testing.
+            if benchmark_context.no_authorization:
+                log.info("Running queries with authorization...")
+                vendor_runner.start_benchmark("authorization")
+                client.execute(
+                    queries=[
+                        ("CREATE USER user IDENTIFIED BY 'test';", {}),
+                        ("GRANT ALL PRIVILEGES TO user;", {}),
+                        ("GRANT CREATE_DELETE ON EDGE_TYPES * TO user;", {}),
+                        ("GRANT CREATE_DELETE ON LABELS * TO user;", {}),
+                    ]
+                )
+
+                client.set_credentials(username="user", password="test")
+                vendor_runner.stop("authorization")
+
+                for query, funcname in queries[group]:
+
+                    log.info(
+                        "Running query: "
+                        + "{}/{}/{}/{}".format(group, query, funcname, WITH_FINE_GRAINED_AUTHORIZATION),
+                    )
+                    func = getattr(workload, funcname)
+
+                    config_key = [
+                        workload.NAME,
+                        workload.get_variant(),
+                        group,
+                        query,
+                    ]
+                    count = get_query_cache_count(
+                        vendor_runner, client, get_queries(func, 1), config_key, benchmark_context
+                    )
+
+                    vendor_runner.start_benchmark("authorization")
+                    warmup(condition=benchmark_context.warm_up, client=client, queries=get_queries(func, count))
+                    ret = client.execute(
+                        queries=get_queries(func, count),
+                        num_workers=benchmark_context.num_workers_for_benchmark,
+                    )[0]
+                    usage = vendor_runner.stop("authorization")
+                    ret["database"] = usage
+                    # Output summary.
+                    log.log("Executed {} queries in {} seconds.".format(ret["count"], ret["duration"]))
+                    log.log("Queries have been retried {} times.".format(ret["retries"]))
+                    log.log("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
+                    log.log("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
+                    log.log("{:<31} {:>20} {:>20} {:>20}".format("Metadata:", "min", "avg", "max"))
+                    metadata = ret["metadata"]
+                    for key in sorted(metadata.keys()):
+                        log.log(
+                            "{name:>30}: {minimum:>20.06f} {average:>20.06f} "
+                            "{maximum:>20.06f}".format(name=key, **metadata[key])
+                        )
+                    log.success("Throughput: {:02f} QPS".format(ret["throughput"]))
+
+                    # Save results.
+                    results_key = [
+                        workload.NAME,
+                        workload.get_variant(),
+                        group,
+                        query,
+                        WITH_FINE_GRAINED_AUTHORIZATION,
+                    ]
+                    results.set_value(*results_key, value=ret)
+
+                # Clean up any roles and users created for the authorization job.
+                vendor_runner.start_benchmark("authorization")
+                ret = client.execute(
+                    queries=[
+                        ("REVOKE LABELS * FROM user;", {}),
+                        ("REVOKE EDGE_TYPES * FROM user;", {}),
+                        ("DROP USER user;", {}),
+                    ]
+                )
+                vendor_runner.stop("authorization")
+
+    # Save configuration.
+    if not benchmark_context.no_save_query_counts:
+        cache.save_config(config)
+
+    # Export results.
+    if benchmark_context.export_results:
+        with open(benchmark_context.export_results, "w") as f:
+            json.dump(results.get_data(), f)
diff --git a/tests/mgbench/benchmark_context.py b/tests/mgbench/benchmark_context.py
new file mode 100644
index 000000000..a01f253ce
--- /dev/null
+++ b/tests/mgbench/benchmark_context.py
@@ -0,0 +1,57 @@
+# Describes all the information of a single benchmark.py run.
+class BenchmarkContext:
+    """
+    Class for holding information on what type of benchmark is being executed.
+    """
+
+    def __init__(
+        self,
+        benchmark_target_workload: str = None,  # Workload that needs to be executed (dataset/variant/group/query)
+        vendor_binary: str = None,  # Benchmark vendor binary
+        vendor_name: str = None,
+        client_binary: str = None,
+        num_workers_for_import: int = None,
+        num_workers_for_benchmark: int = None,
+        single_threaded_runtime_sec: int = 0,
+        no_load_query_counts: bool = False,
+        no_save_query_counts: bool = False,
+        export_results: str = None,
+        temporary_directory: str = None,
+        workload_mixed: str = None,  # Default mode is Isolated; set only when a mixed workload is requested
+        workload_realistic: str = None,  # Default mode is Isolated; set only when a realistic workload is requested
+        time_dependent_execution: int = 0,
+        warm_up: str = None,
+        performance_tracking: bool = False,
+        no_authorization: bool = True,
+        customer_workloads: str = None,
+        vendor_args: dict = {},
+    ) -> None:
+
+        self.benchmark_target_workload = benchmark_target_workload
+        self.vendor_binary = vendor_binary
+        self.vendor_name = vendor_name
+        self.client_binary = client_binary
+        self.num_workers_for_import = num_workers_for_import
+        self.num_workers_for_benchmark = num_workers_for_benchmark
+        self.single_threaded_runtime_sec = single_threaded_runtime_sec
+        self.no_load_query_counts = no_load_query_counts
+        self.no_save_query_counts = no_save_query_counts
+        self.export_results = export_results
+        self.temporary_directory = temporary_directory
+
+        if workload_mixed is not None:
+            self.mode = "Mixed"
+            self.mode_config = workload_mixed
+        elif workload_realistic is not None:
+            self.mode = "Realistic"
+            self.mode_config = workload_realistic
+        else:
+            self.mode = "Isolated"
+            self.mode_config = "Isolated run does not have a config."
+
+        self.time_dependent_execution = time_dependent_execution
+        self.performance_tracking = performance_tracking
+        self.warm_up = warm_up
+        self.no_authorization = no_authorization
+        self.customer_workloads = customer_workloads
+        self.vendor_args = vendor_args
diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp
index b18567fda..12495a50c 100644
--- a/tests/mgbench/client.cpp
+++ b/tests/mgbench/client.cpp
@@ -289,6 +289,7 @@ void ExecuteTimeDependentWorkload(
   // Synchronize workers and collect runtime.
   while (ready.load(std::memory_order_acq_rel) < FLAGS_num_workers)
     ;
+  run.store(true);
 
   for (int i = 0; i < FLAGS_num_workers; ++i) {
     threads[i].join();
@@ -310,6 +311,7 @@ void ExecuteTimeDependentWorkload(
   final_duration /= FLAGS_num_workers;
 
   double execution_delta = time_limit.count() / final_duration;
+  // This is the adjusted throughput, scaled by how much longer the workload execution took than the time limit.
   double throughput = (total_iterations / final_duration) * execution_delta;
   double raw_throughput = total_iterations / final_duration;
 
@@ -319,7 +321,6 @@ void ExecuteTimeDependentWorkload(
   summary["duration"] = final_duration;
   summary["time_limit"] = FLAGS_time_dependent_execution;
   summary["queries_executed"] = total_iterations;
-
   summary["throughput"] = throughput;
   summary["raw_throughput"] = raw_throughput;
   summary["latency_stats"] = LatencyStatistics(worker_query_durations);
diff --git a/tests/mgbench/compare_results.py b/tests/mgbench/compare_results.py
index a13b84c7e..65c77fc3f 100755
--- a/tests/mgbench/compare_results.py
+++ b/tests/mgbench/compare_results.py
@@ -77,10 +77,10 @@ def compare_results(results_from, results_to, fields, ignored, different_vendors
                     recursive_get(summary_from, "database", key, value=None),
                     summary_to["database"][key],
                 )
-            elif summary_to.get("query_statistics") != None and key in summary_to["query_statistics"]:
+            elif summary_to.get("latency_stats") != None and key in summary_to["latency_stats"]:
                 row[key] = compute_diff(
-                    recursive_get(summary_from, "query_statistics", key, value=None),
-                    summary_to["query_statistics"][key],
+                    recursive_get(summary_from, "latency_stats", key, value=None),
+                    summary_to["latency_stats"][key],
                 )
             elif not different_vendors:
                 row[key] = compute_diff(
@@ -160,7 +160,10 @@ if __name__ == "__main__":
         help="Comparing different vendors, there is no need for metadata, duration, count check.",
     )
     parser.add_argument(
-        "--difference-threshold", type=float, help="Difference threshold for memory and throughput, 0.02 = 2% "
+        "--difference-threshold",
+        type=float,
+        default=0.02,
+        help="Difference threshold for memory and throughput, 0.02 = 2% ",
     )
 
     args = parser.parse_args()
diff --git a/tests/mgbench/cypher/__init__.py b/tests/mgbench/cypher/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/mgbench/cypher/ldbc_to_cypher.py b/tests/mgbench/cypher/ldbc_to_cypher.py
new file mode 100644
index 000000000..0bd06b471
--- /dev/null
+++ b/tests/mgbench/cypher/ldbc_to_cypher.py
@@ -0,0 +1,500 @@
+import argparse
+import csv
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+import helpers
+
+# Most recent list of LDBC datasets available at: https://github.com/ldbc/data-sets-surf-repository
+INTERACTIVE_LINK = {
+    "sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.1.tar.zst",
+    "sf0.3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.3.tar.zst",
+    "sf1": 
"https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf1.tar.zst", + "sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf3.tar.zst", + "sf10": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf10.tar.zst", +} + + +BI_LINK = { + "sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf1-composite-projected-fk.tar.zst", + "sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf3-composite-projected-fk.tar.zst", + "sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf10-composite-projected-fk.tar.zst", +} + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + prog="LDBC CSV to CYPHERL converter", + description="""Converts all LDBC CSV files to CYPHERL transactions, for faster Memgraph load""", + ) + parser.add_argument( + "--size", + required=True, + choices=["0.1", "0.3", "1", "3", "10"], + help="Interactive: (0.1 , 0.3, 1, 3, 10) BI: (1, 3, 10)", + ) + parser.add_argument("--type", required=True, choices=["interactive", "bi"], help="interactive or bi") + + args = parser.parse_args() + output_directory = Path().absolute() / ".cache" / "LDBC_generated" + output_directory.mkdir(exist_ok=True) + + if args.type == "interactive": + + NODES_INTERACTIVE = [ + {"filename": "Place", "label": "Place"}, + {"filename": "Organisation", "label": "Organisation"}, + {"filename": "TagClass", "label": "TagClass"}, + {"filename": "Tag", "label": "Tag"}, + {"filename": "Comment", "label": "Message:Comment"}, + {"filename": "Forum", "label": "Forum"}, + {"filename": "Person", "label": "Person"}, + {"filename": "Post", "label": "Message:Post"}, + ] + + EDGES_INTERACTIVE = [ + { + "filename": "Place_isPartOf_Place", + "source_label": "Place", + "type": "IS_PART_OF", + "target_label": "Place", + }, + { + "filename": "TagClass_isSubclassOf_TagClass", + "source_label": "TagClass", + "type": "IS_SUBCLASS_OF", + "target_label": "TagClass", + }, + { + "filename": "Organisation_isLocatedIn_Place", + "source_label": "Organisation", + "type": "IS_LOCATED_IN", + "target_label": "Place", + }, + {"filename": "Tag_hasType_TagClass", "source_label": "Tag", "type": "HAS_TYPE", "target_label": "TagClass"}, + { + "filename": "Comment_hasCreator_Person", + "source_label": "Comment", + "type": "HAS_CREATOR", + "target_label": "Person", + }, + { + "filename": "Comment_isLocatedIn_Place", + "source_label": "Comment", + "type": "IS_LOCATED_IN", + "target_label": "Place", + }, + { + "filename": "Comment_replyOf_Comment", + "source_label": "Comment", + "type": "REPLY_OF", + "target_label": "Comment", + }, + {"filename": "Comment_replyOf_Post", "source_label": "Comment", "type": "REPLY_OF", "target_label": "Post"}, + { + "filename": "Forum_containerOf_Post", + "source_label": "Forum", + "type": "CONTAINER_OF", + "target_label": "Post", + }, + { + "filename": "Forum_hasMember_Person", + "source_label": "Forum", + "type": "HAS_MEMBER", + "target_label": "Person", + }, + { + "filename": "Forum_hasModerator_Person", + "source_label": "Forum", + "type": "HAS_MODERATOR", + "target_label": "Person", + }, + {"filename": "Forum_hasTag_Tag", "source_label": "Forum", "type": "HAS_TAG", "target_label": "Tag"}, + { + "filename": "Person_hasInterest_Tag", + "source_label": "Person", + "type": "HAS_INTEREST", + "target_label": "Tag", + }, + { + "filename": "Person_isLocatedIn_Place", + "source_label": 
"Person", + "type": "IS_LOCATED_IN", + "target_label": "Place", + }, + {"filename": "Person_knows_Person", "source_label": "Person", "type": "KNOWS", "target_label": "Person"}, + {"filename": "Person_likes_Comment", "source_label": "Person", "type": "LIKES", "target_label": "Comment"}, + {"filename": "Person_likes_Post", "source_label": "Person", "type": "LIKES", "target_label": "Post"}, + { + "filename": "Post_hasCreator_Person", + "source_label": "Post", + "type": "HAS_CREATOR", + "target_label": "Person", + }, + {"filename": "Comment_hasTag_Tag", "source_label": "Comment", "type": "HAS_TAG", "target_label": "Tag"}, + {"filename": "Post_hasTag_Tag", "source_label": "Post", "type": "HAS_TAG", "target_label": "Tag"}, + { + "filename": "Post_isLocatedIn_Place", + "source_label": "Post", + "type": "IS_LOCATED_IN", + "target_label": "Place", + }, + { + "filename": "Person_studyAt_Organisation", + "source_label": "Person", + "type": "STUDY_AT", + "target_label": "Organisation", + }, + { + "filename": "Person_workAt_Organisation", + "source_label": "Person", + "type": "WORK_AT", + "target_label": "Organisation", + }, + ] + + file_size = "sf{}".format(args.size) + out_file = "ldbc_interactive_{}.cypher".format(file_size) + output = output_directory / out_file + if output.exists(): + output.unlink() + + files_present = None + for file in output_directory.glob("**/*.tar.zst"): + if "basic-" + file_size in file.name: + files_present = file.with_suffix("").with_suffix("") + break + + if not files_present: + try: + print("Downloading the file... " + INTERACTIVE_LINK[file_size]) + downloaded_file = helpers.download_file(INTERACTIVE_LINK[file_size], output_directory.absolute()) + print("Unpacking the file..." + downloaded_file) + files_present = helpers.unpack_tar_zst(Path(downloaded_file)) + except: + print("Issue with downloading and unpacking the file, check if links are working properly.") + raise + + input_files = {} + for file in files_present.glob("**/*.csv"): + name = file.name.replace("_0_0.csv", "").lower() + input_files[name] = file + + for node_file in NODES_INTERACTIVE: + key = node_file["filename"].lower() + default_label = node_file["label"] + query = None + if key in input_files.keys(): + with input_files[key].open("r") as input_f, output.open("a") as output_f: + reader = csv.DictReader(input_f, delimiter="|") + + for row in reader: + if "type" in row.keys(): + label = default_label + ":" + row.pop("type").capitalize() + else: + label = default_label + + query = "CREATE (:{} {{id:{}, ".format(label, row.pop("id")) + # Format properties to fit Memgraph + for k, v in row.items(): + if k == "creationDate": + row[k] = 'localDateTime("{}")'.format(v[0:-5]) + elif k == "birthday": + row[k] = 'date("{}")'.format(v) + elif k == "length": + row[k] = "toInteger({})".format(v) + else: + row[k] = '"{}"'.format(v) + + prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items()) + query = query + prop_string + "});" + output_f.write(query + "\n") + print("Converted file: " + input_files[key].name + " to " + output.name) + else: + print("Didn't process node file: " + key) + raise Exception("Didn't find the file that was needed!") + + for edge_file in EDGES_INTERACTIVE: + key = edge_file["filename"].lower() + source_label = edge_file["source_label"] + edge_type = edge_file["type"] + target_label = edge_file["target_label"] + if key in input_files.keys(): + query = None + with input_files[key].open("r") as input_f, output.open("a") as output_f: + sufixl = ".id" + sufixr = ".id" + # Handle 
identical label/key in CSV header + if source_label == target_label: + sufixl = "l" + sufixr = "r" + # Move a place from header + header = next(input_f).strip().split("|") + reader = csv.DictReader( + input_f, delimiter="|", fieldnames=([source_label + sufixl, target_label + sufixr] + header[2:]) + ) + + for row in reader: + query = "MATCH (n1:{} {{id:{}}}), (n2:{} {{id:{}}}) ".format( + source_label, row.pop(source_label + sufixl), target_label, row.pop(target_label + sufixr) + ) + for k, v in row.items(): + if "date" in k.lower(): + # Take time zone out + row[k] = 'localDateTime("{}")'.format(v[0:-5]) + elif "workfrom" in k.lower() or "classyear" in k.lower(): + row[k] = 'toInteger("{}")'.format(v) + else: + row[k] = '"{}"'.format(v) + + edge_part = "CREATE (n1)-[:{}{{".format(edge_type) + prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items()) + + query = query + edge_part + prop_string + "}]->(n2);" + output_f.write(query + "\n") + print("Converted file: " + input_files[key].name + " to " + output.name) + else: + print("Didn't process Edge file: " + key) + raise Exception("Didn't find the file that was needed!") + + elif args.type == "bi": + + NODES_BI = [ + {"filename": "Place", "label": "Place"}, + {"filename": "Organisation", "label": "Organisation"}, + {"filename": "TagClass", "label": "TagClass"}, + {"filename": "Tag", "label": "Tag"}, + {"filename": "Comment", "label": "Message:Comment"}, + {"filename": "Forum", "label": "Forum"}, + {"filename": "Person", "label": "Person"}, + {"filename": "Post", "label": "Message:Post"}, + ] + + EDGES_BI = [ + { + "filename": "Place_isPartOf_Place", + "source_label": "Place", + "type": "IS_PART_OF", + "target_label": "Place", + }, + { + "filename": "TagClass_isSubclassOf_TagClass", + "source_label": "TagClass", + "type": "IS_SUBCLASS_OF", + "target_label": "TagClass", + }, + { + "filename": "Organisation_isLocatedIn_Place", + "source_label": "Organisation", + "type": "IS_LOCATED_IN", + "target_label": "Place", + }, + {"filename": "Tag_hasType_TagClass", "source_label": "Tag", "type": "HAS_TYPE", "target_label": "TagClass"}, + { + "filename": "Comment_hasCreator_Person", + "source_label": "Comment", + "type": "HAS_CREATOR", + "target_label": "Person", + }, + # Change place to Country + { + "filename": "Comment_isLocatedIn_Country", + "source_label": "Comment", + "type": "IS_LOCATED_IN", + "target_label": "Country", + }, + { + "filename": "Comment_replyOf_Comment", + "source_label": "Comment", + "type": "REPLY_OF", + "target_label": "Comment", + }, + {"filename": "Comment_replyOf_Post", "source_label": "Comment", "type": "REPLY_OF", "target_label": "Post"}, + { + "filename": "Forum_containerOf_Post", + "source_label": "Forum", + "type": "CONTAINER_OF", + "target_label": "Post", + }, + { + "filename": "Forum_hasMember_Person", + "source_label": "Forum", + "type": "HAS_MEMBER", + "target_label": "Person", + }, + { + "filename": "Forum_hasModerator_Person", + "source_label": "Forum", + "type": "HAS_MODERATOR", + "target_label": "Person", + }, + {"filename": "Forum_hasTag_Tag", "source_label": "Forum", "type": "HAS_TAG", "target_label": "Tag"}, + { + "filename": "Person_hasInterest_Tag", + "source_label": "Person", + "type": "HAS_INTEREST", + "target_label": "Tag", + }, + # Changed place to City + { + "filename": "Person_isLocatedIn_City", + "source_label": "Person", + "type": "IS_LOCATED_IN", + "target_label": "City", + }, + {"filename": "Person_knows_Person", "source_label": "Person", "type": "KNOWS", "target_label": "Person"}, + 
{"filename": "Person_likes_Comment", "source_label": "Person", "type": "LIKES", "target_label": "Comment"}, + {"filename": "Person_likes_Post", "source_label": "Person", "type": "LIKES", "target_label": "Post"}, + { + "filename": "Post_hasCreator_Person", + "source_label": "Post", + "type": "HAS_CREATOR", + "target_label": "Person", + }, + {"filename": "Comment_hasTag_Tag", "source_label": "Comment", "type": "HAS_TAG", "target_label": "Tag"}, + {"filename": "Post_hasTag_Tag", "source_label": "Post", "type": "HAS_TAG", "target_label": "Tag"}, + # Change place to Country + { + "filename": "Post_isLocatedIn_Country", + "source_label": "Post", + "type": "IS_LOCATED_IN", + "target_label": "Country", + }, + # Changed organisation to University + { + "filename": "Person_studyAt_University", + "source_label": "Person", + "type": "STUDY_AT", + "target_label": "University", + }, + # Changed organisation to Company + { + "filename": "Person_workAt_Company", + "source_label": "Person", + "type": "WORK_AT", + "target_label": "Company", + }, + ] + + file_size = "sf{}".format(args.size) + out_file = "ldbc_bi_{}.cypher".format(file_size) + output = output_directory / out_file + if output.exists(): + output.unlink() + + files_present = None + for file in output_directory.glob("**/*.tar.zst"): + if "bi-" + file_size in file.name: + files_present = file.with_suffix("").with_suffix("") + break + + if not files_present: + try: + print("Downloading the file... " + BI_LINK[file_size]) + downloaded_file = helpers.download_file(BI_LINK[file_size], output_directory.absolute()) + print("Unpacking the file..." + downloaded_file) + files_present = helpers.unpack_tar_zst(Path(downloaded_file)) + except: + print("Issue with downloading and unpacking the file, check if links are working properly.") + raise + + for file in files_present.glob("**/*.csv.gz"): + if "initial_snapshot" in file.parts: + helpers.unpack_gz(file) + + input_files = defaultdict(list) + for file in files_present.glob("**/*.csv"): + key = file.parents[0].name + input_files[file.parents[0].name].append(file) + + for node_file in NODES_BI: + key = node_file["filename"] + default_label = node_file["label"] + query = None + if key in input_files.keys(): + for part_file in input_files[key]: + with part_file.open("r") as input_f, output.open("a") as output_f: + reader = csv.DictReader(input_f, delimiter="|") + + for row in reader: + if "type" in row.keys(): + label = default_label + ":" + row.pop("type") + else: + label = default_label + + query = "CREATE (:{} {{id:{}, ".format(label, row.pop("id")) + # Format properties to fit Memgraph + for k, v in row.items(): + if k == "creationDate": + row[k] = 'localDateTime("{}")'.format(v[0:-6]) + elif k == "birthday": + row[k] = 'date("{}")'.format(v) + elif k == "length": + row[k] = "toInteger({})".format(v) + else: + row[k] = '"{}"'.format(v) + + prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items()) + query = query + prop_string + "});" + output_f.write(query + "\n") + print("Key: " + key + " Converted file: " + part_file.name + " to " + output.name) + else: + print("Didn't process node file: " + key) + + for edge_file in EDGES_BI: + key = edge_file["filename"] + source_label = edge_file["source_label"] + edge_type = edge_file["type"] + target_label = edge_file["target_label"] + if key in input_files.keys(): + for part_file in input_files[key]: + query = None + with part_file.open("r") as input_f, output.open("a") as output_f: + sufixl = "Id" + sufixr = "Id" + # Handle identical label/key in CSV 
header + if source_label == target_label: + sufixl = "l" + sufixr = "r" + # Move a place from header + header = next(input_f).strip().split("|") + if len(header) >= 3: + reader = csv.DictReader( + input_f, + delimiter="|", + fieldnames=(["date", source_label + sufixl, target_label + sufixr] + header[3:]), + ) + else: + reader = csv.DictReader( + input_f, + delimiter="|", + fieldnames=([source_label + sufixl, target_label + sufixr] + header[2:]), + ) + + for row in reader: + query = "MATCH (n1:{} {{id:{}}}), (n2:{} {{id:{}}}) ".format( + source_label, + row.pop(source_label + sufixl), + target_label, + row.pop(target_label + sufixr), + ) + for k, v in row.items(): + if "date" in k.lower(): + # Take time zone out + row[k] = 'localDateTime("{}")'.format(v[0:-6]) + elif k == "classYear" or k == "workFrom": + row[k] = 'toInteger("{}")'.format(v) + else: + row[k] = '"{}"'.format(v) + + edge_part = "CREATE (n1)-[:{}{{".format(edge_type) + prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items()) + + query = query + edge_part + prop_string + "}]->(n2);" + output_f.write(query + "\n") + print("Key: " + key + " Converted file: " + part_file.name + " to " + output.name) + else: + print("Didn't process Edge file: " + key) + raise Exception("Didn't find the file that was needed!") diff --git a/tests/mgbench/graph_bench.py b/tests/mgbench/graph_bench.py index d1a633081..e0a4fac85 100644 --- a/tests/mgbench/graph_bench.py +++ b/tests/mgbench/graph_bench.py @@ -16,14 +16,20 @@ def parse_arguments(): help="Forward name and paths to vendors binary" "Example: --vendor memgraph /path/to/binary --vendor neo4j /path/to/binary", ) + parser.add_argument( - "--dataset-size", - default="small", - choices=["small", "medium", "large"], - help="Pick a dataset size (small, medium, large)", + "--dataset-name", + default="", + help="Dataset name you wish to execute", ) - parser.add_argument("--dataset-group", default="basic", help="Select a group of queries") + parser.add_argument( + "--dataset-size", + default="", + help="Pick a dataset variant you wish to execute", + ) + + parser.add_argument("--dataset-group", default="", help="Select a group of queries") parser.add_argument( "--realistic", @@ -53,88 +59,110 @@ def parse_arguments(): return args -def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic, mixed): +def run_full_benchmarks(vendor, binary, dataset, dataset_size, dataset_group, realistic, mixed): configurations = [ - # Basic full group test cold + # Basic isolated test cold [ "--export-results", - vendor + "_" + dataset_size + "_cold_isolated.json", + vendor + "_" + dataset + "_" + dataset_size + "_cold_isolated.json", ], - # Basic full group test hot + # Basic isolated test hot [ "--export-results", - vendor + "_" + dataset_size + "_hot_isolated.json", - "--warmup-run", + vendor + "_" + dataset + "_" + dataset_size + "_hot_isolated.json", + "--warm-up", + "hot", + ], + # Basic isolated test vulcanic + [ + "--export-results", + vendor + "_" + dataset + "_" + dataset_size + "_vulcanic_isolated.json", + "--warm-up", + "vulcanic", ], ] - # Configurations for full workload - for count, write, read, update, analytical in realistic: - cold = [ - "--export-results", - vendor - + "_" - + dataset_size - + "_cold_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical), - "--mixed-workload", - count, - write, - read, - update, - analytical, - ] + if realistic: + # Configurations for full workload + for count, write, read, update, analytical in realistic: + cold 
= [ + "--export-results", + vendor + + "_" + + dataset + + "_" + + dataset_size + + "_cold_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical), + "--workload-realistic", + count, + write, + read, + update, + analytical, + ] - hot = [ - "--export-results", - vendor - + "_" - + dataset_size - + "_hot_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical), - "--warmup-run", - "--mixed-workload", - count, - write, - read, - update, - analytical, - ] - configurations.append(cold) - configurations.append(hot) + hot = [ + "--export-results", + vendor + + "_" + + dataset + + "_" + + dataset_size + + "_hot_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical), + "--warm-up", + "hot", + "--workload-realistic", + count, + write, + read, + update, + analytical, + ] - # Configurations for workload per query - for count, write, read, update, analytical, query in mixed: - cold = [ - "--export-results", - vendor - + "_" - + dataset_size - + "_cold_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query), - "--mixed-workload", - count, - write, - read, - update, - analytical, - query, - ] - hot = [ - "--export-results", - vendor - + "_" - + dataset_size - + "_hot_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query), - "--warmup-run", - "--mixed-workload", - count, - write, - read, - update, - analytical, - query, - ] - configurations.append(cold) - configurations.append(hot) + configurations.append(cold) + configurations.append(hot) + + if mixed: + # Configurations for workload per query + for count, write, read, update, analytical, query in mixed: + cold = [ + "--export-results", + vendor + + "_" + + dataset + + "_" + + dataset_size + + "_cold_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query), + "--workload-mixed", + count, + write, + read, + update, + analytical, + query, + ] + hot = [ + "--export-results", + vendor + + "_" + + dataset + + "_" + + dataset_size + + "_hot_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query), + "--warm-up", + "hot", + "--workload-mixed", + count, + write, + read, + update, + analytical, + query, + ] + + configurations.append(cold) + configurations.append(hot) default_args = [ "python3", @@ -146,9 +174,7 @@ def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic, "--num-workers-for-benchmark", "12", "--no-authorization", - "pokec/" + dataset_size + "/" + dataset_group + "/*", - "--tail-latency", - "100", + dataset + "/" + dataset_size + "/" + dataset_group + "/*", ] for config in configurations: @@ -157,11 +183,11 @@ def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic, subprocess.run(args=full_config, check=True) -def collect_all_results(vendor_name, dataset_size, dataset_group): +def collect_all_results(vendor_name, dataset, dataset_size, dataset_group): working_directory = Path().absolute() print(working_directory) - results = sorted(working_directory.glob(vendor_name + "_" + dataset_size + "_*.json")) - summary = {"pokec": {dataset_size: {dataset_group: {}}}} + results = sorted(working_directory.glob(vendor_name + "_" + dataset + "_" + dataset_size + "_*.json")) + summary = {dataset: {dataset_size: {dataset_group: {}}}} for file in results: if "summary" in file.name: @@ -169,19 +195,22 @@ def collect_all_results(vendor_name, dataset_size, dataset_group): f = file.open() data = json.loads(f.read()) if 
data["__run_configuration__"]["condition"] == "hot": - for key, value in data["pokec"][dataset_size][dataset_group].items(): + for key, value in data[dataset][dataset_size][dataset_group].items(): key_condition = key + "_hot" - summary["pokec"][dataset_size][dataset_group][key_condition] = value + summary[dataset][dataset_size][dataset_group][key_condition] = value elif data["__run_configuration__"]["condition"] == "cold": - for key, value in data["pokec"][dataset_size][dataset_group].items(): + for key, value in data[dataset][dataset_size][dataset_group].items(): key_condition = key + "_cold" - summary["pokec"][dataset_size][dataset_group][key_condition] = value - + summary[dataset][dataset_size][dataset_group][key_condition] = value + elif data["__run_configuration__"]["condition"] == "vulcanic": + for key, value in data[dataset][dataset_size][dataset_group].items(): + key_condition = key + "_vulcanic" + summary[dataset][dataset_size][dataset_group][key_condition] = value print(summary) json_object = json.dumps(summary, indent=4) print(json_object) - with open(vendor_name + "_" + dataset_size + "_summary.json", "w") as f: + with open(vendor_name + "_" + dataset + "_" + dataset_size + "_summary.json", "w") as f: json.dump(summary, f) @@ -194,16 +223,17 @@ if __name__ == "__main__": vendor_names = {"memgraph", "neo4j"} for vendor_name, vendor_binary in args.vendor: path = Path(vendor_binary) - if vendor_name.lower() in vendor_names and (path.is_file() or path.is_dir()): + if vendor_name.lower() in vendor_names and path.is_file(): run_full_benchmarks( vendor_name, vendor_binary, + args.dataset_name, args.dataset_size, args.dataset_group, realistic, mixed, ) - collect_all_results(vendor_name, args.dataset_size, args.dataset_group) + collect_all_results(vendor_name, args.dataset_name, args.dataset_size, args.dataset_group) else: raise Exception( "Check that vendor: {} is supported and you are passing right path: {} to binary.".format( diff --git a/tests/mgbench/helpers.py b/tests/mgbench/helpers.py index 7488b1443..d90cbe9a3 100644 --- a/tests/mgbench/helpers.py +++ b/tests/mgbench/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2021 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -9,11 +9,21 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. 
+import collections import copy +import fnmatch +import importlib +import inspect import json import os import subprocess +import sys +from pathlib import Path +import workloads +from benchmark_context import BenchmarkContext +from workloads import * +from workloads import base SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -28,22 +38,70 @@ def get_binary_path(path, base=""): def download_file(url, path): - ret = subprocess.run(["wget", "-nv", "--content-disposition", url], - stderr=subprocess.PIPE, cwd=path, check=True) + ret = subprocess.run(["wget", "-nv", "--content-disposition", url], stderr=subprocess.PIPE, cwd=path, check=True) data = ret.stderr.decode("utf-8") tmp = data.split("->")[1] - name = tmp[tmp.index('"') + 1:tmp.rindex('"')] + name = tmp[tmp.index('"') + 1 : tmp.rindex('"')] return os.path.join(path, name) -def unpack_and_move_file(input_path, output_path): +def unpack_gz_and_move_file(input_path, output_path): if input_path.endswith(".gz"): - subprocess.run(["gunzip", input_path], - stdout=subprocess.DEVNULL, check=True) + subprocess.run(["gunzip", input_path], stdout=subprocess.DEVNULL, check=True) input_path = input_path[:-3] os.rename(input_path, output_path) +def unpack_gz(input_path: Path): + if input_path.suffix == ".gz": + subprocess.run(["gzip", "-d", input_path], capture_output=True, check=True) + input_path = input_path.with_suffix("") + return input_path + + +def unpack_zip(input_path: Path): + if input_path.suffix == ".zip": + subprocess.run(["unzip", input_path], capture_output=True, check=True, cwd=input_path.parent) + input_path = input_path.with_suffix("") + return input_path + + +def unpack_tar_zst(input_path: Path): + if input_path.suffix == ".zst": + subprocess.run( + ["tar", "--use-compress-program=unzstd", "-xvf", input_path], + cwd=input_path.parent, + capture_output=True, + check=True, + ) + input_path = input_path.with_suffix("").with_suffix("") + return input_path + + +def unpack_tar_gz(input_path: Path): + if input_path.suffix == ".gz": + subprocess.run( + ["tar", "-xvf", input_path], + cwd=input_path.parent, + capture_output=True, + check=True, + ) + input_path = input_path.with_suffix("").with_suffix("") + return input_path + + +def unpack_tar_zst_and_move(input_path: Path, output_path: Path): + if input_path.suffix == ".zst": + subprocess.run( + ["tar", "--use-compress-program=unzstd", "-xvf", input_path], + cwd=input_path.parent, + capture_output=True, + check=True, + ) + input_path = input_path.with_suffix("").with_suffix("") + return input_path.rename(output_path) + + def ensure_directory(path): if not os.path.exists(path): os.makedirs(path) @@ -51,6 +109,129 @@ def ensure_directory(path): raise Exception("The path '{}' should be a directory!".format(path)) +def get_available_workloads(customer_workloads: str = None) -> dict: + generators = {} + for module in map(workloads.__dict__.get, workloads.__all__): + for key in dir(module): + if key.startswith("_"): + continue + base_class = getattr(module, key) + if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload): + continue + queries = collections.defaultdict(list) + for funcname in dir(base_class): + if not funcname.startswith("benchmark__"): + continue + group, query = funcname.split("__")[1:] + queries[group].append((query, funcname)) + generators[base_class.NAME] = (base_class, dict(queries)) + + if customer_workloads: + head_tail = os.path.split(customer_workloads) + path_without_dataset_name = head_tail[0] + dataset_name = head_tail[1].split(".")[0] 
+ sys.path.append(path_without_dataset_name) + dataset_to_use = importlib.import_module(dataset_name) + + for key in dir(dataset_to_use): + if key.startswith("_"): + continue + base_class = getattr(dataset_to_use, key) + if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload): + continue + queries = collections.defaultdict(list) + for funcname in dir(base_class): + if not funcname.startswith("benchmark__"): + continue + group, query = funcname.split("__")[1:] + queries[group].append((query, funcname)) + generators[base_class.NAME] = (base_class, dict(queries)) + + return generators + + +def list_available_workloads(customer_workloads: str = None): + generators = get_available_workloads(customer_workloads) + for name in sorted(generators.keys()): + print("Dataset:", name) + dataset, queries = generators[name] + print( + " Variants:", + ", ".join(dataset.VARIANTS), + "(default: " + dataset.DEFAULT_VARIANT + ")", + ) + for group in sorted(queries.keys()): + print(" Group:", group) + for query_name, query_func in queries[group]: + print(" Query:", query_name) + + +def match_patterns(workload, variant, group, query, is_default_variant, patterns): + for pattern in patterns: + verdict = [fnmatch.fnmatchcase(workload, pattern[0])] + if pattern[1] != "": + verdict.append(fnmatch.fnmatchcase(variant, pattern[1])) + else: + verdict.append(is_default_variant) + verdict.append(fnmatch.fnmatchcase(group, pattern[2])) + verdict.append(fnmatch.fnmatchcase(query, pattern[3])) + if all(verdict): + return True + return False + + +def filter_workloads(available_workloads: dict, benchmark_context: BenchmarkContext) -> list: + patterns = benchmark_context.benchmark_target_workload + for i in range(len(patterns)): + pattern = patterns[i].split("/") + if len(pattern) > 5 or len(pattern) == 0: + raise Exception("Invalid benchmark description '" + pattern + "'!") + pattern.extend(["", "*", "*"][len(pattern) - 1 :]) + patterns[i] = pattern + filtered = [] + for workload in sorted(available_workloads.keys()): + generator, queries = available_workloads[workload] + for variant in generator.VARIANTS: + is_default_variant = variant == generator.DEFAULT_VARIANT + current = collections.defaultdict(list) + for group in queries: + for query_name, query_func in queries[group]: + if match_patterns( + workload, + variant, + group, + query_name, + is_default_variant, + patterns, + ): + current[group].append((query_name, query_func)) + if len(current) == 0: + continue + + # Ignore benchgraph "basic" queries in standard CI/CD run + for pattern in patterns: + res = pattern.count("*") + key = "basic" + if res >= 2 and key in current.keys(): + current.pop(key) + + filtered.append((generator(variant=variant, benchmark_context=benchmark_context), dict(current))) + return filtered + + +def parse_kwargs(items): + """ + Parse a series of key-value pairs and return a dictionary + """ + d = {} + + if items: + for item in items: + key, value = item.split("=") + d[key] = value + return d + + class Directory: def __init__(self, path): self._path = path @@ -103,6 +284,9 @@ class Cache: ensure_directory(path) return Directory(path) + def get_default_cache_directory(self): + return self._directory + def load_config(self): if not os.path.isfile(self._config): return RecursiveDict() diff --git a/tests/mgbench/log.py b/tests/mgbench/log.py index 126c3b082..01a6771b3 100644 --- a/tests/mgbench/log.py +++ b/tests/mgbench/log.py @@ -9,6 +9,8 @@ # by the Apache License, Version 2.0, included in the file # licenses/APL.txt. 
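Note on filter_workloads/match_patterns in the helpers.py hunk above: benchmark descriptions of the form dataset/variant/group/query are padded with defaults and then matched field-by-field with fnmatch. A self-contained illustration of that padding-plus-wildcard scheme, assuming the "demo" workload names from this patch; the pattern strings themselves are made up:

    import fnmatch

    def expand(description):
        # "demo" -> ["demo", "", "*", "*"]; an empty variant field later
        # means "match only the workload's default variant".
        pattern = description.split("/")
        pattern.extend(["", "*", "*"][len(pattern) - 1:])
        return pattern

    def matches(pattern, workload, variant, group, query, is_default_variant):
        verdict = [fnmatch.fnmatchcase(workload, pattern[0])]
        if pattern[1] != "":
            verdict.append(fnmatch.fnmatchcase(variant, pattern[1]))
        else:
            verdict.append(is_default_variant)
        verdict.append(fnmatch.fnmatchcase(group, pattern[2]))
        verdict.append(fnmatch.fnmatchcase(query, pattern[3]))
        return all(verdict)

    print(matches(expand("demo"), "demo", "default", "test", "sample_query1", True))    # True
    print(matches(expand("demo/x"), "demo", "default", "test", "sample_query1", True))  # False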
+import logging + COLOR_GRAY = 0 COLOR_RED = 1 COLOR_GREEN = 2 @@ -16,27 +18,45 @@ COLOR_YELLOW = 3 COLOR_BLUE = 4 COLOR_VIOLET = 5 COLOR_CYAN = 6 +COLOR_WHITE = 7 -def log(color, *args): +logger = logging.Logger("mgbench_logger") +file_handler = logging.FileHandler("mgbench_logs.log") +file_format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +file_handler.setFormatter(file_format) +logger.addHandler(file_handler) + + +def _log(color, *args): print("\033[1;3{}m~~".format(color), *args, "~~\033[0m") +def log(msg): + print(msg) + logger.info(msg=msg) + + def init(*args): - log(COLOR_BLUE, *args) + _log(COLOR_BLUE, *args) + logger.info(*args) def info(*args): - log(COLOR_CYAN, *args) + _log(COLOR_WHITE, *args) + logger.info(*args) def success(*args): - log(COLOR_GREEN, *args) + _log(COLOR_GREEN, *args) + logger.info(*args) def warning(*args): - log(COLOR_YELLOW, *args) + _log(COLOR_YELLOW, *args) + logger.warning(*args) def error(*args): - log(COLOR_RED, *args) + _log(COLOR_RED, *args) + logger.critical(*args) diff --git a/tests/mgbench/runners.py b/tests/mgbench/runners.py index 3d3aa966e..923854dec 100644 --- a/tests/mgbench/runners.py +++ b/tests/mgbench/runners.py @@ -1,4 +1,4 @@ -# Copyright 2022 Memgraph Ltd. +# Copyright 2023 Memgraph Ltd. # # Use of this software is governed by the Business Source License # included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,10 +17,13 @@ import subprocess import tempfile import threading import time +from abc import ABC, abstractmethod from pathlib import Path +from benchmark_context import BenchmarkContext -def wait_for_server(port, delay=0.1): + +def _wait_for_server(port, delay=0.1): cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)] while subprocess.call(cmd) != 0: time.sleep(0.01) @@ -62,50 +65,165 @@ def _get_current_usage(pid): return rss / 1024 -class Memgraph: - def __init__(self, memgraph_binary, temporary_dir, properties_on_edges, bolt_port, performance_tracking): - self._memgraph_binary = memgraph_binary - self._directory = tempfile.TemporaryDirectory(dir=temporary_dir) - self._properties_on_edges = properties_on_edges +class BaseClient(ABC): + @abstractmethod + def __init__(self, benchmark_context: BenchmarkContext): + self.benchmark_context = benchmark_context + + @abstractmethod + def execute(self): + pass + + +class BoltClient(BaseClient): + def __init__(self, benchmark_context: BenchmarkContext): + self._client_binary = benchmark_context.client_binary + self._directory = tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory) + self._username = "" + self._password = "" + self._bolt_port = ( + benchmark_context.vendor_args["bolt-port"] if "bolt-port" in benchmark_context.vendor_args.keys() else 7687 + ) + + def _get_args(self, **kwargs): + return _convert_args_to_flags(self._client_binary, **kwargs) + + def set_credentials(self, username: str, password: str): + self._username = username + self._password = password + + def execute( + self, + queries=None, + file_path=None, + num_workers=1, + max_retries: int = 50, + validation: bool = False, + time_dependent_execution: int = 0, + ): + if (queries is None and file_path is None) or (queries is not None and file_path is not None): + raise ValueError("Either queries or input_path must be specified!") + + queries_json = False + if queries is not None: + queries_json = True + file_path = os.path.join(self._directory.name, "queries.json") + with open(file_path, "w") as f: + 
for query in queries:
+                    json.dump(query, f)
+                    f.write("\n")
+        args = self._get_args(
+            input=file_path,
+            num_workers=num_workers,
+            max_retries=max_retries,
+            queries_json=queries_json,
+            username=self._username,
+            password=self._password,
+            port=self._bolt_port,
+            validation=validation,
+            time_dependent_execution=time_dependent_execution,
+        )
+
+        ret = subprocess.run(args, capture_output=True)
+        error = ret.stderr.decode("utf-8").strip().split("\n")
+        data = ret.stdout.decode("utf-8").strip().split("\n")
+        if error and error[0] != "":
+            print("Reported errors from client")
+            print(error)
+        data = [x for x in data if not x.startswith("[")]
+        return list(map(json.loads, data))
+
+
+class BaseRunner(ABC):
+    subclasses = {}
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        super().__init_subclass__(**kwargs)
+        cls.subclasses[cls.__name__.lower()] = cls
+        return
+
+    @classmethod
+    def create(cls, benchmark_context: BenchmarkContext):
+        if benchmark_context.vendor_name not in cls.subclasses:
+            raise ValueError("Missing runner with name: {}".format(benchmark_context.vendor_name))
+
+        return cls.subclasses[benchmark_context.vendor_name](
+            benchmark_context=benchmark_context,
+        )
+
+    @abstractmethod
+    def __init__(self, benchmark_context: BenchmarkContext):
+        self.benchmark_context = benchmark_context
+
+    @abstractmethod
+    def start_benchmark(self):
+        pass
+
+    @abstractmethod
+    def start_preparation(self):
+        pass
+
+    @abstractmethod
+    def stop(self):
+        pass
+
+    @abstractmethod
+    def clean_db(self):
+        pass
+
+    @abstractmethod
+    def fetch_client(self) -> BaseClient:
+        pass
+
+
+class Memgraph(BaseRunner):
+    def __init__(self, benchmark_context: BenchmarkContext):
+        super().__init__(benchmark_context=benchmark_context)
+        self._memgraph_binary = benchmark_context.vendor_binary
+        self._performance_tracking = benchmark_context.performance_tracking
+        self._directory = tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory)
+        self._vendor_args = benchmark_context.vendor_args
+        self._properties_on_edges = (
+            self._vendor_args["no-properties-on-edges"]
+            if "no-properties-on-edges" in self._vendor_args.keys()
+            else False
+        )
+        self._bolt_port = self._vendor_args["bolt-port"] if "bolt-port" in self._vendor_args.keys() else 7687
         self._proc_mg = None
-        self._bolt_port = bolt_port
-        self.performance_tracking = performance_tracking
         self._stop_event = threading.Event()
         self._rss = []
-        atexit.register(self._cleanup)
 
         # Determine Memgraph version
-        ret = subprocess.run([memgraph_binary, "--version"], stdout=subprocess.PIPE, check=True)
+        ret = subprocess.run([self._memgraph_binary, "--version"], stdout=subprocess.PIPE, check=True)
         version = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", ret.stdout.decode("utf-8")).group(0)
         self._memgraph_version = tuple(map(int, version.split(".")))
 
+        atexit.register(self._cleanup)
+
     def __del__(self):
         self._cleanup()
         atexit.unregister(self._cleanup)
 
-    def _get_args(self, **kwargs):
+    def _set_args(self, **kwargs):
         data_directory = os.path.join(self._directory.name, "memgraph")
         kwargs["bolt_port"] = self._bolt_port
-        if self._memgraph_version >= (0, 50, 0):
-            kwargs["data_directory"] = data_directory
-        else:
-            kwargs["durability_directory"] = data_directory
-        if self._memgraph_version >= (0, 50, 0):
-            kwargs["storage_properties_on_edges"] = self._properties_on_edges
-        else:
-            assert self._properties_on_edges, "Older versions of Memgraph can't disable properties on edges!"
+        kwargs["data_directory"] = data_directory
+        kwargs["storage_properties_on_edges"] = self._properties_on_edges
         return _convert_args_to_flags(self._memgraph_binary, **kwargs)
 
     def _start(self, **kwargs):
         if self._proc_mg is not None:
             raise Exception("The database process is already running!")
-        args = self._get_args(**kwargs)
+        args = self._set_args(**kwargs)
         self._proc_mg = subprocess.Popen(args, stdout=subprocess.DEVNULL)
         time.sleep(0.2)
         if self._proc_mg.poll() is not None:
             self._proc_mg = None
             raise Exception("The database process died prematurely!")
-        wait_for_server(self._bolt_port)
+        _wait_for_server(self._bolt_port)
         ret = self._proc_mg.poll()
         assert ret is None, "The database process died prematurely " "({})!".format(ret)
@@ -119,7 +237,7 @@ class Memgraph:
         return ret, usage
 
     def start_preparation(self, workload):
-        if self.performance_tracking:
+        if self._performance_tracking:
             p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
             self._stop_event.clear()
             self._rss.clear()
@@ -127,13 +245,26 @@ class Memgraph:
         self._start(storage_snapshot_on_exit=True)
 
     def start_benchmark(self, workload):
-        if self.performance_tracking:
+        if self._performance_tracking:
             p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
             self._stop_event.clear()
             self._rss.clear()
             p.start()
         self._start(storage_recover_on_startup=True)
 
+    def clean_db(self):
+        if self._proc_mg is not None:
+            raise Exception("The database process is already running, cannot clear its data!")
+        else:
+            out = subprocess.run(
+                args="rm -Rf memgraph/snapshots/*",
+                cwd=self._directory.name,
+                capture_output=True,
+                shell=True,
+            )
+            print(out.stderr.decode("utf-8"))
+            print(out.stdout.decode("utf-8"))
+
     def res_background_tracking(self, res, stop_event):
         print("Started rss tracking.")
         while not stop_event.is_set():
@@ -154,35 +285,46 @@ class Memgraph:
         f.close()
 
     def stop(self, workload):
-        if self.performance_tracking:
+        if self._performance_tracking:
             self._stop_event.set()
             self.dump_rss(workload)
         ret, usage = self._cleanup()
         assert ret == 0, "The database process exited with a non-zero " "status ({})!".format(ret)
         return usage
 
+    def fetch_client(self) -> BoltClient:
+        return BoltClient(benchmark_context=self.benchmark_context)
 
-class Neo4j:
-    def __init__(self, neo4j_path, temporary_dir, bolt_port, performance_tracking):
-        self._neo4j_path = Path(neo4j_path)
-        self._neo4j_binary = Path(neo4j_path) / "bin" / "neo4j"
-        self._neo4j_config = Path(neo4j_path) / "conf" / "neo4j.conf"
-        self._neo4j_pid = Path(neo4j_path) / "run" / "neo4j.pid"
-        self._neo4j_admin = Path(neo4j_path) / "bin" / "neo4j-admin"
-        self.performance_tracking = performance_tracking
+
+class Neo4j(BaseRunner):
+    def __init__(self, benchmark_context: BenchmarkContext):
+        super().__init__(benchmark_context=benchmark_context)
+        self._neo4j_binary = Path(benchmark_context.vendor_binary)
+        self._neo4j_path = Path(benchmark_context.vendor_binary).parents[1]
+        self._neo4j_config = self._neo4j_path / "conf" / "neo4j.conf"
+        self._neo4j_pid = self._neo4j_path / "run" / "neo4j.pid"
+        self._neo4j_admin = self._neo4j_path / "bin" / "neo4j-admin"
+        self._performance_tracking = benchmark_context.performance_tracking
+        self._vendor_args = benchmark_context.vendor_args
         self._stop_event = threading.Event()
         self._rss = []
 
         if not self._neo4j_binary.is_file():
             raise Exception("Wrong path to binary!")
+
-        self._directory = tempfile.TemporaryDirectory(dir=temporary_dir)
-        self._bolt_port = bolt_port
+
+        
tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory) + self._bolt_port = ( + self.benchmark_context.vendor_args["bolt-port"] + if "bolt-port" in self.benchmark_context.vendor_args.keys() + else 7687 + ) atexit.register(self._cleanup) configs = [] memory_flag = "server.jvm.additional=-XX:NativeMemoryTracking=detail" auth_flag = "dbms.security.auth_enabled=false" - - if self.performance_tracking: + bolt_flag = "server.bolt.listen_address=:7687" + http_flag = "server.http.listen_address=:7474" + if self._performance_tracking: configs.append(memory_flag) else: lines = [] @@ -201,6 +343,8 @@ class Neo4j: file.close() configs.append(auth_flag) + configs.append(bolt_flag) + configs.append(http_flag) print("Check neo4j config flags:") for conf in configs: with self._neo4j_config.open("r+") as file: @@ -234,7 +378,7 @@ class Neo4j: else: raise Exception("The database process died prematurely!") print("Run server check:") - wait_for_server(self._bolt_port) + _wait_for_server(self._bolt_port) def _cleanup(self): if self._neo4j_pid.exists(): @@ -248,7 +392,7 @@ class Neo4j: return 0 def start_preparation(self, workload): - if self.performance_tracking: + if self._performance_tracking: p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event)) self._stop_event.clear() self._rss.clear() @@ -257,11 +401,11 @@ class Neo4j: # Start DB self._start() - if self.performance_tracking: + if self._performance_tracking: self.get_memory_usage("start_" + workload) def start_benchmark(self, workload): - if self.performance_tracking: + if self._performance_tracking: p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event)) self._stop_event.clear() self._rss.clear() @@ -269,7 +413,7 @@ class Neo4j: # Start DB self._start() - if self.performance_tracking: + if self._performance_tracking: self.get_memory_usage("start_" + workload) def dump_db(self, path): @@ -290,6 +434,20 @@ class Neo4j: check=True, ) + def clean_db(self): + print("Cleaning the database") + if self._neo4j_pid.exists(): + raise Exception("Cannot clean DB because it is running.") + else: + out = subprocess.run( + args="rm -Rf data/databases/* data/transactions/*", + cwd=self._neo4j_path, + capture_output=True, + shell=True, + ) + print(out.stderr.decode("utf-8")) + print(out.stdout.decode("utf-8")) + def load_db_from_dump(self, path): print("Loading the neo4j database from dump...") if self._neo4j_pid.exists(): @@ -300,7 +458,8 @@ class Neo4j: self._neo4j_admin, "database", "load", - "--from-path=" + path, + "--from-path", + path, "--overwrite-destination=true", "neo4j", ], @@ -325,7 +484,7 @@ class Neo4j: return True def stop(self, workload): - if self.performance_tracking: + if self._performance_tracking: self._stop_event.set() self.get_memory_usage("stop_" + workload) self.dump_rss(workload) @@ -360,51 +519,5 @@ class Neo4j: f.write(memory_usage.stdout) f.close() - -class Client: - def __init__( - self, client_binary: str, temporary_directory: str, bolt_port: int, username: str = "", password: str = "" - ): - self._client_binary = client_binary - self._directory = tempfile.TemporaryDirectory(dir=temporary_directory) - self._username = username - self._password = password - self._bolt_port = bolt_port - - def _get_args(self, **kwargs): - return _convert_args_to_flags(self._client_binary, **kwargs) - - def execute(self, queries=None, file_path=None, num_workers=1): - if (queries is None and file_path is None) or (queries is not None and file_path is not None): - 
raise ValueError("Either queries or input_path must be specified!") - - # TODO: check `file_path.endswith(".json")` to support advanced - # input queries - - queries_json = False - if queries is not None: - queries_json = True - file_path = os.path.join(self._directory.name, "queries.json") - with open(file_path, "w") as f: - for query in queries: - json.dump(query, f) - f.write("\n") - - args = self._get_args( - input=file_path, - num_workers=num_workers, - queries_json=queries_json, - username=self._username, - password=self._password, - port=self._bolt_port, - ) - - ret = subprocess.run(args, capture_output=True, check=True) - error = ret.stderr.decode("utf-8").strip().split("\n") - if error and error[0] != "": - print("Reported errros from client") - print(error) - - data = ret.stdout.decode("utf-8").strip().split("\n") - data = [x for x in data if not x.startswith("[")] - return list(map(json.loads, data)) + def fetch_client(self) -> BoltClient: + return BoltClient(benchmark_context=self.benchmark_context) diff --git a/tests/mgbench/validation.py b/tests/mgbench/validation.py new file mode 100644 index 000000000..d32d1d61b --- /dev/null +++ b/tests/mgbench/validation.py @@ -0,0 +1,244 @@ +import argparse +import copy +import multiprocessing +import random + +import helpers +import runners +import workloads +from benchmark_context import BenchmarkContext +from workloads import base + + +def pars_args(): + + parser = argparse.ArgumentParser( + prog="Validator for individual query checking", + description="""Validates that query is running, and validates output between different vendors""", + ) + parser.add_argument( + "benchmarks", + nargs="*", + default="", + help="descriptions of benchmarks that should be run; " + "multiple descriptions can be specified to run multiple " + "benchmarks; the description is specified as " + "dataset/variant/group/query; Unix shell-style wildcards " + "can be used in the descriptions; variant, group and query " + "are optional and they can be left out; the default " + "variant is '' which selects the default dataset variant; " + "the default group is '*' which selects all groups; the" + "default query is '*' which selects all queries", + ) + + parser.add_argument( + "--vendor-binary-1", + help="Vendor binary used for benchmarking, by default it is memgraph", + default=helpers.get_binary_path("memgraph"), + ) + + parser.add_argument( + "--vendor-name-1", + default="memgraph", + choices=["memgraph", "neo4j"], + help="Input vendor binary name (memgraph, neo4j)", + ) + + parser.add_argument( + "--vendor-binary-2", + help="Vendor binary used for benchmarking, by default it is memgraph", + default=helpers.get_binary_path("memgraph"), + ) + + parser.add_argument( + "--vendor-name-2", + default="memgraph", + choices=["memgraph", "neo4j"], + help="Input vendor binary name (memgraph, neo4j)", + ) + + parser.add_argument( + "--client-binary", + default=helpers.get_binary_path("tests/mgbench/client"), + help="Client binary used for benchmarking", + ) + + parser.add_argument( + "--temporary-directory", + default="/tmp", + help="directory path where temporary data should " "be stored", + ) + + parser.add_argument( + "--num-workers-for-import", + type=int, + default=multiprocessing.cpu_count() // 2, + help="number of workers used to import the dataset", + ) + + return parser.parse_args() + + +def get_queries(gen, count): + # Make the generator deterministic. + random.seed(gen.__name__) + # Generate queries. 
+ ret = [] + for i in range(count): + ret.append(gen()) + return ret + + +if __name__ == "__main__": + + args = pars_args() + + benchmark_context_db_1 = BenchmarkContext( + vendor_name=args.vendor_name_1, + vendor_binary=args.vendor_binary_1, + benchmark_target_workload=copy.copy(args.benchmarks), + client_binary=args.client_binary, + num_workers_for_import=args.num_workers_for_import, + temporary_directory=args.temporary_directory, + ) + + available_workloads = helpers.get_available_workloads() + + print(helpers.list_available_workloads()) + + vendor_runner = runners.BaseRunner.create( + benchmark_context=benchmark_context_db_1, + ) + + cache = helpers.Cache() + client = vendor_runner.fetch_client() + + workloads = helpers.filter_workloads( + available_workloads=available_workloads, benchmark_context=benchmark_context_db_1 + ) + + results_db_1 = {} + + for workload, queries in workloads: + + vendor_runner.clean_db() + + generated_queries = workload.dataset_generator() + if generated_queries: + vendor_runner.start_preparation("import") + client.execute(queries=generated_queries, num_workers=benchmark_context_db_1.num_workers_for_import) + vendor_runner.stop("import") + else: + workload.prepare(cache.cache_directory("datasets", workload.NAME, workload.get_variant())) + imported = workload.custom_import() + if not imported: + vendor_runner.start_preparation("import") + print("Executing database cleanup and index setup...") + client.execute( + file_path=workload.get_index(), num_workers=benchmark_context_db_1.num_workers_for_import + ) + print("Importing dataset...") + ret = client.execute( + file_path=workload.get_file(), num_workers=benchmark_context_db_1.num_workers_for_import + ) + usage = vendor_runner.stop("import") + + for group in sorted(queries.keys()): + for query, funcname in queries[group]: + print("Running query:{}/{}/{}".format(group, query, funcname)) + func = getattr(workload, funcname) + count = 1 + vendor_runner.start_benchmark("validation") + try: + ret = client.execute(queries=get_queries(func, count), num_workers=1, validation=True)[0] + results_db_1[funcname] = ret["results"].items() + except Exception as e: + print("Issue running the query" + funcname) + print(e) + results_db_1[funcname] = "Query not executed properly" + finally: + usage = vendor_runner.stop("validation") + print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"])) + print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0)) + + benchmark_context_db_2 = BenchmarkContext( + vendor_name=args.vendor_name_2, + vendor_binary=args.vendor_binary_2, + benchmark_target_workload=copy.copy(args.benchmarks), + client_binary=args.client_binary, + num_workers_for_import=args.num_workers_for_import, + temporary_directory=args.temporary_directory, + ) + + vendor_runner = runners.BaseRunner.create( + benchmark_context=benchmark_context_db_2, + ) + available_workloads = helpers.get_available_workloads() + + workloads = helpers.filter_workloads(available_workloads, benchmark_context=benchmark_context_db_2) + + client = vendor_runner.fetch_client() + + results_db_2 = {} + + for workload, queries in workloads: + + vendor_runner.clean_db() + + generated_queries = workload.dataset_generator() + if generated_queries: + vendor_runner.start_preparation("import") + client.execute(queries=generated_queries, num_workers=benchmark_context_db_2.num_workers_for_import) + vendor_runner.stop("import") + else: + workload.prepare(cache.cache_directory("datasets", workload.NAME, 
workload.get_variant())) + imported = workload.custom_import() + if not imported: + vendor_runner.start_preparation("import") + print("Executing database cleanup and index setup...") + client.execute( + file_path=workload.get_index(), num_workers=benchmark_context_db_2.num_workers_for_import + ) + print("Importing dataset...") + ret = client.execute( + file_path=workload.get_file(), num_workers=benchmark_context_db_2.num_workers_for_import + ) + usage = vendor_runner.stop("import") + + for group in sorted(queries.keys()): + for query, funcname in queries[group]: + print("Running query:{}/{}/{}".format(group, query, funcname)) + func = getattr(workload, funcname) + count = 1 + vendor_runner.start_benchmark("validation") + try: + ret = client.execute(queries=get_queries(func, count), num_workers=1, validation=True)[0] + results_db_2[funcname] = ret["results"].items() + except Exception as e: + print("Issue running the query" + funcname) + print(e) + results_db_2[funcname] = "Query not executed properly" + finally: + usage = vendor_runner.stop("validation") + print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"])) + print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0)) + + validation = {} + for key in results_db_1.keys(): + if type(results_db_1[key]) is str: + validation[key] = "Query not executed properly." + else: + db_1_values = set() + for index, value in results_db_1[key]: + db_1_values.add(value) + neo4j_values = set() + for index, value in results_db_2[key]: + neo4j_values.add(value) + + if db_1_values == neo4j_values: + validation[key] = "Identical results" + else: + validation[key] = "Different results, check manually." + + for key, value in validation.items(): + print(key + " " + value) diff --git a/tests/mgbench/workloads/__init__.py b/tests/mgbench/workloads/__init__.py new file mode 100644 index 000000000..ed172b041 --- /dev/null +++ b/tests/mgbench/workloads/__init__.py @@ -0,0 +1,4 @@ +from pathlib import Path + +modules = Path(__file__).resolve().parent.glob("*.py") +__all__ = [f.name[:-3] for f in modules if f.is_file() and not f.name == "__init__.py"] diff --git a/tests/mgbench/workloads/base.py b/tests/mgbench/workloads/base.py new file mode 100644 index 000000000..d6125ab16 --- /dev/null +++ b/tests/mgbench/workloads/base.py @@ -0,0 +1,197 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +from abc import ABC, abstractclassmethod +from pathlib import Path + +import helpers +from benchmark_context import BenchmarkContext + + +# Base dataset class used as a template to create each individual dataset. All +# common logic is handled here. +class Workload(ABC): + + # Name of the workload/dataset. + NAME = "" + # List of all variants of the workload/dataset that exist. + VARIANTS = ["default"] + # One of the available variants that should be used as the default variant. + DEFAULT_VARIANT = "default" + + # List of local files that should be used to import the dataset. 
+ LOCAL_FILE = None + + # URLs of remote dataset files that should be used to import the dataset, compressed in gz format. + URL_FILE = None + + # Index files + LOCAL_INDEX_FILE = None + URL_INDEX_FILE = None + + # Number of vertices/edges for each variant. + SIZES = { + "default": {"vertices": 0, "edges": 0}, + } + + # Indicates whether the dataset has properties on edges. + PROPERTIES_ON_EDGES = False + + def __init_subclass__(cls) -> None: + name_prerequisite = "NAME" in cls.__dict__ + generator_prerequisite = "dataset_generator" in cls.__dict__ + custom_import_prerequisite = "custom_import" in cls.__dict__ + basic_import_prerequisite = ("LOCAL_FILE" in cls.__dict__ or "URL_FILE" in cls.__dict__) and ( + "LOCAL_INDEX_FILE" in cls.__dict__ or "URL_INDEX_FILE" in cls.__dict__ + ) + + if not name_prerequisite: + raise ValueError( + """Can't define a workload class {} without NAME property: + NAME = "dataset name" + Name property defines the workload you want to execute, for example: "demo/*/*/*" + """.format( + cls.__name__ + ) + ) + + # Check workload is in generator or dataset mode during interpretation (not both), not runtime + if generator_prerequisite and (custom_import_prerequisite or basic_import_prerequisite): + raise ValueError( + """ + The workload class {} cannot have defined dataset import and generate dataset at + the same time. + """.format( + cls.__name__ + ) + ) + + if not generator_prerequisite and (not custom_import_prerequisite and not basic_import_prerequisite): + raise ValueError( + """ + The workload class {} need to have defined dataset import or dataset generator + """.format( + cls.__name__ + ) + ) + + return super().__init_subclass__() + + def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None): + """ + Accepts a `variant` variable that indicates which variant + of the dataset should be executed + """ + self.benchmark_context = benchmark_context + self._variant = variant + self._vendor = benchmark_context.vendor_name + self._file = None + self._file_index = None + + if self.NAME == "": + raise ValueError("Give your workload a name, by setting self.NAME") + + if variant is None: + variant = self.DEFAULT_VARIANT + if variant not in self.VARIANTS: + raise ValueError("Invalid test variant!") + if (self.LOCAL_FILE and variant not in self.LOCAL_FILE) and (self.URL_FILE and variant not in self.URL_FILE): + raise ValueError("The variant doesn't have a defined URL or LOCAL file path!") + if variant not in self.SIZES: + raise ValueError("The variant doesn't have a defined dataset " "size!") + + if (self.LOCAL_INDEX_FILE and self._vendor not in self.LOCAL_INDEX_FILE) and ( + self.URL_INDEX_FILE and self._vendor not in self.URL_INDEX_FILE + ): + raise ValueError("Vendor does not have INDEX for dataset!") + + if self.LOCAL_FILE is not None: + self._local_file = self.LOCAL_FILE.get(variant, None) + else: + self._local_file = None + + if self.URL_FILE is not None: + self._url_file = self.URL_FILE.get(variant, None) + else: + self._url_file = None + + if self.LOCAL_INDEX_FILE is not None: + self._local_index = self.LOCAL_INDEX_FILE.get(self._vendor, None) + else: + self._local_index = None + + if self.URL_INDEX_FILE is not None: + self._url_index = self.URL_INDEX_FILE.get(self._vendor, None) + else: + self._url_index = None + + self._size = self.SIZES[variant] + if "vertices" in self._size or "edges" in self._size: + self._num_vertices = self._size["vertices"] + self._num_edges = self._size["edges"] + + def prepare(self, directory): + if 
self._local_file is not None: + print("Using local dataset file:", self._local_file) + self._file = self._local_file + elif self._url_file is not None: + cached_input, exists = directory.get_file("dataset.cypher") + if not exists: + print("Downloading dataset file:", self._url_file) + downloaded_file = helpers.download_file(self._url_file, directory.get_path()) + print("Unpacking and caching file:", downloaded_file) + helpers.unpack_gz_and_move_file(downloaded_file, cached_input) + print("Using cached dataset file:", cached_input) + self._file = cached_input + + if self._local_index is not None: + print("Using local index file:", self._local_index) + self._file_index = self._local_index + elif self._url_index is not None: + cached_index, exists = directory.get_file(self._vendor + ".cypher") + if not exists: + print("Downloading index file:", self._url_index) + downloaded_file = helpers.download_file(self._url_index, directory.get_path()) + print("Unpacking and caching file:", downloaded_file) + helpers.unpack_gz_and_move_file(downloaded_file, cached_index) + print("Using cached index file:", cached_index) + self._file_index = cached_index + + def get_variant(self): + """Returns the current variant of the dataset.""" + return self._variant + + def get_index(self): + """Get index file, defined by vendor""" + return self._file_index + + def get_file(self): + """ + Returns path to the file that contains dataset creation queries. + """ + return self._file + + def get_size(self): + """Returns number of vertices/edges for the current variant.""" + return self._size + + def custom_import(self) -> bool: + print("Workload does not have a custom import") + return False + + def dataset_generator(self) -> list: + print("Workload is not auto generated") + return [] + + # All tests should be query generator functions that output all of the + # queries that should be executed by the runner. The functions should be + # named `benchmark__GROUPNAME__TESTNAME` and should not accept any + # arguments. 
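Note on the benchmark__GROUPNAME__TESTNAME convention documented in the comment above: get_available_workloads in helpers.py discovers queries by walking each Workload subclass with dir() and splitting every matching method name into a (group, query) pair; demo.py below is the smallest real workload that follows the convention. A stand-alone sketch of that discovery step, using a hypothetical class that is not part of the patch:

    import collections

    class DummyWorkload:
        # Hypothetical workload; only the method-name convention matters here.
        def benchmark__test__count_nodes(self):
            return ("MATCH (n) RETURN count(n);", {})

    def discover(cls):
        queries = collections.defaultdict(list)
        for funcname in dir(cls):
            if funcname.startswith("benchmark__"):
                group, query = funcname.split("__")[1:]
                queries[group].append((query, funcname))
        return dict(queries)

    print(discover(DummyWorkload))
    # {'test': [('count_nodes', 'benchmark__test__count_nodes')]}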
diff --git a/tests/mgbench/workloads/demo.py b/tests/mgbench/workloads/demo.py new file mode 100644 index 000000000..2b758d5ef --- /dev/null +++ b/tests/mgbench/workloads/demo.py @@ -0,0 +1,28 @@ +import random + +from workloads.base import Workload + + +class Demo(Workload): + + NAME = "demo" + + def dataset_generator(self): + + queries = [("MATCH (n) DETACH DELETE n;", {})] + for i in range(0, 100): + queries.append(("CREATE (:NodeA{{ id:{}}});".format(i), {})) + queries.append(("CREATE (:NodeB{{ id:{}}});".format(i), {})) + + for i in range(0, 100): + a = random.randint(0, 99) + b = random.randint(0, 99) + queries.append(("MATCH(a:NodeA{{ id: {}}}),(b:NodeB{{id: {}}}) CREATE (a)-[:EDGE]->(b)".format(a, b), {})) + + return queries + + def benchmark__test__sample_query1(self): + return ("MATCH (n) RETURN n", {}) + + def benchmark__test__sample_query2(self): + return ("MATCH (n) RETURN n", {}) diff --git a/tests/mgbench/workloads/importers/__init__.py b/tests/mgbench/workloads/importers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mgbench/workloads/importers/importer_ldbc_bi.py b/tests/mgbench/workloads/importers/importer_ldbc_bi.py new file mode 100644 index 000000000..6c84ba75d --- /dev/null +++ b/tests/mgbench/workloads/importers/importer_ldbc_bi.py @@ -0,0 +1,213 @@ +import csv +import subprocess +from collections import defaultdict +from pathlib import Path + +import helpers +from benchmark_context import BenchmarkContext +from runners import BaseRunner + +HEADERS_URL = "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/headers.tar.gz" + + +class ImporterLDBCBI: + def __init__( + self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, csv_dict: dict + ) -> None: + self._benchmark_context = benchmark_context + self._dataset_name = dataset_name + self._variant = variant + self._index_file = index_file + self._csv_dict = csv_dict + + def execute_import(self): + + vendor_runner = BaseRunner.create( + benchmark_context=self._benchmark_context, + ) + client = vendor_runner.fetch_client() + + if self._benchmark_context.vendor_name == "neo4j": + data_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "data_neo4j" + data_dir.mkdir(parents=True, exist_ok=True) + dir_name = self._csv_dict[self._variant].split("/")[-1:][0].removesuffix(".tar.zst") + if (data_dir / dir_name).exists(): + print("Files downloaded") + data_dir = data_dir / dir_name + else: + print("Downloading files") + downloaded_file = helpers.download_file(self._csv_dict[self._variant], data_dir.absolute()) + print("Unpacking the file..." + downloaded_file) + data_dir = helpers.unpack_tar_zst(Path(downloaded_file)) + + headers_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "headers_neo4j" + headers_dir.mkdir(parents=True, exist_ok=True) + headers = HEADERS_URL.split("/")[-1:][0].removesuffix(".tar.gz") + if (headers_dir / headers).exists(): + print("Header files downloaded.") + else: + print("Downloading files") + downloaded_file = helpers.download_file(HEADERS_URL, headers_dir.absolute()) + print("Unpacking the file..." 
+ downloaded_file) + headers_dir = helpers.unpack_tar_gz(Path(downloaded_file)) + + input_headers = {} + for header_file in headers_dir.glob("**/*.csv"): + key = "/".join(header_file.parts[-2:])[0:-4] + input_headers[key] = header_file.as_posix() + + for data_file in data_dir.glob("**/*.gz"): + if "initial_snapshot" in data_file.parts: + data_file = helpers.unpack_gz(data_file) + output = data_file.parent / (data_file.stem + "_neo" + ".csv") + if not output.exists(): + with data_file.open("r") as input_f, output.open("a") as output_f: + reader = csv.reader(input_f, delimiter="|") + header = next(reader) + writer = csv.writer(output_f, delimiter="|") + for line in reader: + writer.writerow(line) + else: + print("Files converted") + + input_files = defaultdict(list) + for neo_file in data_dir.glob("**/*_neo.csv"): + key = "/".join(neo_file.parts[-3:-1]) + input_files[key].append(neo_file.as_posix()) + + vendor_runner.clean_db() + subprocess.run( + args=[ + vendor_runner._neo4j_admin, + "database", + "import", + "full", + "--id-type=INTEGER", + "--ignore-empty-strings=true", + "--bad-tolerance=0", + "--nodes=Place=" + input_headers["static/Place"] + "," + ",".join(input_files["static/Place"]), + "--nodes=Organisation=" + + input_headers["static/Organisation"] + + "," + + ",".join(input_files["static/Organisation"]), + "--nodes=TagClass=" + + input_headers["static/TagClass"] + + "," + + ",".join(input_files["static/TagClass"]), + "--nodes=Tag=" + input_headers["static/Tag"] + "," + ",".join(input_files["static/Tag"]), + "--nodes=Forum=" + input_headers["dynamic/Forum"] + "," + ",".join(input_files["dynamic/Forum"]), + "--nodes=Person=" + input_headers["dynamic/Person"] + "," + ",".join(input_files["dynamic/Person"]), + "--nodes=Message:Comment=" + + input_headers["dynamic/Comment"] + + "," + + ",".join(input_files["dynamic/Comment"]), + "--nodes=Message:Post=" + + input_headers["dynamic/Post"] + + "," + + ",".join(input_files["dynamic/Post"]), + "--relationships=IS_PART_OF=" + + input_headers["static/Place_isPartOf_Place"] + + "," + + ",".join(input_files["static/Place_isPartOf_Place"]), + "--relationships=IS_SUBCLASS_OF=" + + input_headers["static/TagClass_isSubclassOf_TagClass"] + + "," + + ",".join(input_files["static/TagClass_isSubclassOf_TagClass"]), + "--relationships=IS_LOCATED_IN=" + + input_headers["static/Organisation_isLocatedIn_Place"] + + "," + + ",".join(input_files["static/Organisation_isLocatedIn_Place"]), + "--relationships=HAS_TYPE=" + + input_headers["static/Tag_hasType_TagClass"] + + "," + + ",".join(input_files["static/Tag_hasType_TagClass"]), + "--relationships=HAS_CREATOR=" + + input_headers["dynamic/Comment_hasCreator_Person"] + + "," + + ",".join(input_files["dynamic/Comment_hasCreator_Person"]), + "--relationships=IS_LOCATED_IN=" + + input_headers["dynamic/Comment_isLocatedIn_Country"] + + "," + + ",".join(input_files["dynamic/Comment_isLocatedIn_Country"]), + "--relationships=REPLY_OF=" + + input_headers["dynamic/Comment_replyOf_Comment"] + + "," + + ",".join(input_files["dynamic/Comment_replyOf_Comment"]), + "--relationships=REPLY_OF=" + + input_headers["dynamic/Comment_replyOf_Post"] + + "," + + ",".join(input_files["dynamic/Comment_replyOf_Post"]), + "--relationships=CONTAINER_OF=" + + input_headers["dynamic/Forum_containerOf_Post"] + + "," + + ",".join(input_files["dynamic/Forum_containerOf_Post"]), + "--relationships=HAS_MEMBER=" + + input_headers["dynamic/Forum_hasMember_Person"] + + "," + + ",".join(input_files["dynamic/Forum_hasMember_Person"]), + 
"--relationships=HAS_MODERATOR=" + + input_headers["dynamic/Forum_hasModerator_Person"] + + "," + + ",".join(input_files["dynamic/Forum_hasModerator_Person"]), + "--relationships=HAS_TAG=" + + input_headers["dynamic/Forum_hasTag_Tag"] + + "," + + ",".join(input_files["dynamic/Forum_hasTag_Tag"]), + "--relationships=HAS_INTEREST=" + + input_headers["dynamic/Person_hasInterest_Tag"] + + "," + + ",".join(input_files["dynamic/Person_hasInterest_Tag"]), + "--relationships=IS_LOCATED_IN=" + + input_headers["dynamic/Person_isLocatedIn_City"] + + "," + + ",".join(input_files["dynamic/Person_isLocatedIn_City"]), + "--relationships=KNOWS=" + + input_headers["dynamic/Person_knows_Person"] + + "," + + ",".join(input_files["dynamic/Person_knows_Person"]), + "--relationships=LIKES=" + + input_headers["dynamic/Person_likes_Comment"] + + "," + + ",".join(input_files["dynamic/Person_likes_Comment"]), + "--relationships=LIKES=" + + input_headers["dynamic/Person_likes_Post"] + + "," + + ",".join(input_files["dynamic/Person_likes_Post"]), + "--relationships=HAS_CREATOR=" + + input_headers["dynamic/Post_hasCreator_Person"] + + "," + + ",".join(input_files["dynamic/Post_hasCreator_Person"]), + "--relationships=HAS_TAG=" + + input_headers["dynamic/Comment_hasTag_Tag"] + + "," + + ",".join(input_files["dynamic/Comment_hasTag_Tag"]), + "--relationships=HAS_TAG=" + + input_headers["dynamic/Post_hasTag_Tag"] + + "," + + ",".join(input_files["dynamic/Post_hasTag_Tag"]), + "--relationships=IS_LOCATED_IN=" + + input_headers["dynamic/Post_isLocatedIn_Country"] + + "," + + ",".join(input_files["dynamic/Post_isLocatedIn_Country"]), + "--relationships=STUDY_AT=" + + input_headers["dynamic/Person_studyAt_University"] + + "," + + ",".join(input_files["dynamic/Person_studyAt_University"]), + "--relationships=WORK_AT=" + + input_headers["dynamic/Person_workAt_Company"] + + "," + + ",".join(input_files["dynamic/Person_workAt_Company"]), + "--delimiter", + "|", + "neo4j", + ], + check=True, + ) + + vendor_runner.start_preparation("Index preparation") + print("Executing database index setup") + client.execute(file_path=self._index_file, num_workers=1) + vendor_runner.stop("Stop index preparation") + return True + else: + return False diff --git a/tests/mgbench/workloads/importers/importer_ldbc_interactive.py b/tests/mgbench/workloads/importers/importer_ldbc_interactive.py new file mode 100644 index 000000000..3c78405b7 --- /dev/null +++ b/tests/mgbench/workloads/importers/importer_ldbc_interactive.py @@ -0,0 +1,163 @@ +import csv +import subprocess +from pathlib import Path + +import helpers +from benchmark_context import BenchmarkContext +from runners import BaseRunner + +# Removed speaks/email from person header +HEADERS_INTERACTIVE = { + "static/organisation": "id:ID(Organisation)|:LABEL|name:STRING|url:STRING", + "static/place": "id:ID(Place)|name:STRING|url:STRING|:LABEL", + "static/tagclass": "id:ID(TagClass)|name:STRING|url:STRING", + "static/tag": "id:ID(Tag)|name:STRING|url:STRING", + "static/tagclass_isSubclassOf_tagclass": ":START_ID(TagClass)|:END_ID(TagClass)", + "static/tag_hasType_tagclass": ":START_ID(Tag)|:END_ID(TagClass)", + "static/organisation_isLocatedIn_place": ":START_ID(Organisation)|:END_ID(Place)", + "static/place_isPartOf_place": ":START_ID(Place)|:END_ID(Place)", + "dynamic/comment": "id:ID(Comment)|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING|content:STRING|length:INT", + "dynamic/forum": "id:ID(Forum)|title:STRING|creationDate:LOCALDATETIME", + "dynamic/person": 
"id:ID(Person)|firstName:STRING|lastName:STRING|gender:STRING|birthday:LOCALDATETIME|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING", + "dynamic/post": "id:ID(Post)|imageFile:STRING|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING|language:STRING|content:STRING|length:INT", + "dynamic/comment_hasCreator_person": ":START_ID(Comment)|:END_ID(Person)", + "dynamic/comment_isLocatedIn_place": ":START_ID(Comment)|:END_ID(Place)", + "dynamic/comment_replyOf_comment": ":START_ID(Comment)|:END_ID(Comment)", + "dynamic/comment_replyOf_post": ":START_ID(Comment)|:END_ID(Post)", + "dynamic/forum_containerOf_post": ":START_ID(Forum)|:END_ID(Post)", + "dynamic/forum_hasMember_person": ":START_ID(Forum)|:END_ID(Person)|joinDate:LOCALDATETIME", + "dynamic/forum_hasModerator_person": ":START_ID(Forum)|:END_ID(Person)", + "dynamic/forum_hasTag_tag": ":START_ID(Forum)|:END_ID(Tag)", + "dynamic/person_hasInterest_tag": ":START_ID(Person)|:END_ID(Tag)", + "dynamic/person_isLocatedIn_place": ":START_ID(Person)|:END_ID(Place)", + "dynamic/person_knows_person": ":START_ID(Person)|:END_ID(Person)|creationDate:LOCALDATETIME", + "dynamic/person_likes_comment": ":START_ID(Person)|:END_ID(Comment)|creationDate:LOCALDATETIME", + "dynamic/person_likes_post": ":START_ID(Person)|:END_ID(Post)|creationDate:LOCALDATETIME", + "dynamic/person_studyAt_organisation": ":START_ID(Person)|:END_ID(Organisation)|classYear:INT", + "dynamic/person_workAt_organisation": ":START_ID(Person)|:END_ID(Organisation)|workFrom:INT", + "dynamic/post_hasCreator_person": ":START_ID(Post)|:END_ID(Person)", + "dynamic/comment_hasTag_tag": ":START_ID(Comment)|:END_ID(Tag)", + "dynamic/post_hasTag_tag": ":START_ID(Post)|:END_ID(Tag)", + "dynamic/post_isLocatedIn_place": ":START_ID(Post)|:END_ID(Place)", +} + + +class ImporterLDBCInteractive: + def __init__( + self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, csv_dict: dict + ) -> None: + self._benchmark_context = benchmark_context + self._dataset_name = dataset_name + self._variant = variant + self._index_file = index_file + self._csv_dict = csv_dict + + def execute_import(self): + + vendor_runner = BaseRunner.create( + benchmark_context=self._benchmark_context, + ) + client = vendor_runner.fetch_client() + + if self._benchmark_context.vendor_name == "neo4j": + print("Runnning Neo4j import") + dump_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "dump" + dump_dir.mkdir(parents=True, exist_ok=True) + dir_name = self._csv_dict[self._variant].split("/")[-1:][0].removesuffix(".tar.zst") + if (dump_dir / dir_name).exists(): + print("Files downloaded") + dump_dir = dump_dir / dir_name + else: + print("Downloading files") + downloaded_file = helpers.download_file(self._csv_dict[self._variant], dump_dir.absolute()) + print("Unpacking the file..." 
+ downloaded_file) + dump_dir = helpers.unpack_tar_zst(Path(downloaded_file)) + + input_files = {} + for file in dump_dir.glob("*/*0.csv"): + parts = file.parts[-2:] + key = parts[0] + "/" + parts[1][:-8] + input_files[key] = file + + output_files = {} + for key, file in input_files.items(): + output = file.parent / (file.stem + "_neo" + ".csv") + if not output.exists(): + with file.open("r") as input_f, output.open("a") as output_f: + reader = csv.reader(input_f, delimiter="|") + header = next(reader) + + writer = csv.writer(output_f, delimiter="|") + if key in HEADERS_INTERACTIVE.keys(): + updated_header = HEADERS_INTERACTIVE[key].split("|") + writer.writerow(updated_header) + for line in reader: + if "creationDate" in header: + pos = header.index("creationDate") + line[pos] = line[pos][0:-5] + elif "joinDate" in header: + pos = header.index("joinDate") + line[pos] = line[pos][0:-5] + + if "organisation_0_0.csv" == file.name: + writer.writerow([line[0], line[1].capitalize(), line[2], line[3]]) + elif "place_0_0.csv" == file.name: + writer.writerow([line[0], line[1], line[2], line[3].capitalize()]) + else: + writer.writerow(line) + + output_files[key] = output.as_posix() + vendor_runner.clean_db() + subprocess.run( + args=[ + vendor_runner._neo4j_admin, + "database", + "import", + "full", + "--id-type=INTEGER", + "--nodes=Place=" + output_files["static/place"], + "--nodes=Organisation=" + output_files["static/organisation"], + "--nodes=TagClass=" + output_files["static/tagclass"], + "--nodes=Tag=" + output_files["static/tag"], + "--nodes=Comment:Message=" + output_files["dynamic/comment"], + "--nodes=Forum=" + output_files["dynamic/forum"], + "--nodes=Person=" + output_files["dynamic/person"], + "--nodes=Post:Message=" + output_files["dynamic/post"], + "--relationships=IS_PART_OF=" + output_files["static/place_isPartOf_place"], + "--relationships=IS_SUBCLASS_OF=" + output_files["static/tagclass_isSubclassOf_tagclass"], + "--relationships=IS_LOCATED_IN=" + output_files["static/organisation_isLocatedIn_place"], + "--relationships=HAS_TYPE=" + output_files["static/tag_hasType_tagclass"], + "--relationships=HAS_CREATOR=" + output_files["dynamic/comment_hasCreator_person"], + "--relationships=IS_LOCATED_IN=" + output_files["dynamic/comment_isLocatedIn_place"], + "--relationships=REPLY_OF=" + output_files["dynamic/comment_replyOf_comment"], + "--relationships=REPLY_OF=" + output_files["dynamic/comment_replyOf_post"], + "--relationships=CONTAINER_OF=" + output_files["dynamic/forum_containerOf_post"], + "--relationships=HAS_MEMBER=" + output_files["dynamic/forum_hasMember_person"], + "--relationships=HAS_MODERATOR=" + output_files["dynamic/forum_hasModerator_person"], + "--relationships=HAS_TAG=" + output_files["dynamic/forum_hasTag_tag"], + "--relationships=HAS_INTEREST=" + output_files["dynamic/person_hasInterest_tag"], + "--relationships=IS_LOCATED_IN=" + output_files["dynamic/person_isLocatedIn_place"], + "--relationships=KNOWS=" + output_files["dynamic/person_knows_person"], + "--relationships=LIKES=" + output_files["dynamic/person_likes_comment"], + "--relationships=LIKES=" + output_files["dynamic/person_likes_post"], + "--relationships=HAS_CREATOR=" + output_files["dynamic/post_hasCreator_person"], + "--relationships=HAS_TAG=" + output_files["dynamic/comment_hasTag_tag"], + "--relationships=HAS_TAG=" + output_files["dynamic/post_hasTag_tag"], + "--relationships=IS_LOCATED_IN=" + output_files["dynamic/post_isLocatedIn_place"], + "--relationships=STUDY_AT=" + 
output_files["dynamic/person_studyAt_organisation"], + "--relationships=WORK_AT=" + output_files["dynamic/person_workAt_organisation"], + "--delimiter", + "|", + "neo4j", + ], + check=True, + ) + + vendor_runner.start_preparation("Index preparation") + print("Executing database index setup") + client.execute(file_path=self._index_file, num_workers=1) + vendor_runner.stop("Stop index preparation") + + return True + else: + return False diff --git a/tests/mgbench/workloads/importers/importer_pokec.py b/tests/mgbench/workloads/importers/importer_pokec.py new file mode 100644 index 000000000..ee6621369 --- /dev/null +++ b/tests/mgbench/workloads/importers/importer_pokec.py @@ -0,0 +1,41 @@ +from pathlib import Path + +from benchmark_context import BenchmarkContext +from runners import BaseRunner + + +class ImporterPokec: + def __init__( + self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, dataset_file: str + ) -> None: + self._benchmark_context = benchmark_context + self._dataset_name = dataset_name + self._variant = variant + self._index_file = index_file + self._dataset_file = dataset_file + + def execute_import(self): + if self._benchmark_context.vendor_name == "neo4j": + + vendor_runner = BaseRunner.create( + benchmark_context=self._benchmark_context, + ) + client = vendor_runner.fetch_client() + vendor_runner.clean_db() + vendor_runner.start_preparation("preparation") + print("Executing database cleanup and index setup...") + client.execute(file_path=self._index_file, num_workers=1) + vendor_runner.stop("preparation") + neo4j_dump = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "neo4j.dump" + if neo4j_dump.exists(): + vendor_runner.load_db_from_dump(path=neo4j_dump.parent) + else: + vendor_runner.start_preparation("import") + print("Importing dataset...") + client.execute(file_path=self._dataset_file, num_workers=self._benchmark_context.num_workers_for_import) + vendor_runner.stop("import") + vendor_runner.dump_db(path=neo4j_dump.parent) + + return True + else: + return False diff --git a/tests/mgbench/workloads/ldbc_bi.py b/tests/mgbench/workloads/ldbc_bi.py new file mode 100644 index 000000000..e1d28577b --- /dev/null +++ b/tests/mgbench/workloads/ldbc_bi.py @@ -0,0 +1,708 @@ +import inspect +import random +from pathlib import Path + +import helpers +from benchmark_context import BenchmarkContext +from workloads.base import Workload +from workloads.importers.importer_ldbc_bi import ImporterLDBCBI + + +class LDBC_BI(Workload): + NAME = "ldbc_bi" + VARIANTS = ["sf1", "sf3", "sf10"] + DEFAULT_VARIANT = "sf1" + + URL_FILE = { + "sf1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf1.cypher.gz", + "sf3": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf3.cypher.gz", + "sf10": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf10.cypher.gz", + } + + URL_CSV = { + "sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf1-composite-projected-fk.tar.zst", + "sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf3-composite-projected-fk.tar.zst", + "sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf10-composite-projected-fk.tar.zst", + } + + SIZES = { + "sf1": {"vertices": 2997352, "edges": 17196776}, + "sf3": {"vertices": 1, "edges": 1}, + "sf10": {"vertices": 1, "edges": 1}, + } + + LOCAL_INDEX_FILES = None + + URL_INDEX_FILE = { + 
"memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/memgraph_bi_index.cypher", + "neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/neo4j_bi_index.cypher", + } + + QUERY_PARAMETERS = { + "sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip", + "sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip", + "sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip", + } + + def custom_import(self) -> bool: + importer = ImporterLDBCBI( + benchmark_context=self.benchmark_context, + dataset_name=self.NAME, + variant=self._variant, + index_file=self._file_index, + csv_dict=self.URL_CSV, + ) + return importer.execute_import() + + def _prepare_parameters_directory(self): + parameters = Path() / ".cache" / "datasets" / self.NAME / self._variant / "parameters" + parameters.mkdir(parents=True, exist_ok=True) + if parameters.exists() and any(parameters.iterdir()): + print("Files downloaded.") + else: + print("Downloading files") + downloaded_file = helpers.download_file(self.QUERY_PARAMETERS[self._variant], parameters.parent.absolute()) + print("Unpacking the file..." + downloaded_file) + parameters = helpers.unpack_zip(Path(downloaded_file)) + return parameters / ("parameters-" + self._variant) + + def _get_query_parameters(self) -> dict: + func_name = inspect.stack()[1].function + parameters = {} + for file in self._parameters_dir.glob("bi-*.csv"): + file_name_query_id = file.name.split("-")[1][0:-4] + func_name_id = func_name.split("_")[-1] + if file_name_query_id == func_name_id or file_name_query_id == func_name_id + "a": + with file.open("r") as input: + lines = input.readlines() + header = lines[0].strip("\n").split("|") + position = random.randint(1, len(lines) - 1) + data = lines[position].strip("\n").split("|") + for i in range(len(header)): + key, value_type = header[i].split(":") + if value_type == "DATETIME": + # Drop time zone + converted = data[i][0:-6] + parameters[key] = converted + elif value_type == "DATE": + converted = data[i] + "T00:00:00" + parameters[key] = converted + elif value_type == "INT": + parameters[key] = int(data[i]) + elif value_type == "STRING[]": + elements = data[i].split(";") + parameters[key] = elements + else: + parameters[key] = data[i] + break + + return parameters + + def __init__(self, variant=None, benchmark_context: BenchmarkContext = None): + super().__init__(variant, benchmark_context=benchmark_context) + self._parameters_dir = self._prepare_parameters_directory() + + def benchmark__bi__query_1_analytical(self): + + memgraph = ( + """ + MATCH (message:Message) + WHERE message.creationDate < localDateTime($datetime) + WITH count(message) AS totalMessageCountInt + WITH toFloat(totalMessageCountInt) AS totalMessageCount + MATCH (message:Message) + WHERE message.creationDate < localDateTime($datetime) + AND message.content IS NOT NULL + WITH + totalMessageCount, + message, + message.creationDate.year AS year + WITH + totalMessageCount, + year, + message:Comment AS isComment, + CASE + WHEN message.length < 40 THEN 0 + WHEN message.length < 80 THEN 1 + WHEN message.length < 160 THEN 2 + ELSE 3 + END AS lengthCategory, + count(message) AS messageCount, + sum(message.length) / toFloat(count(message)) AS averageMessageLength, + sum(message.length) AS sumMessageLength + RETURN + year, + isComment, + lengthCategory, + messageCount, + averageMessageLength, + 
sumMessageLength, + messageCount / totalMessageCount AS percentageOfMessages + ORDER BY + year DESC, + isComment ASC, + lengthCategory ASC + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (message:Message) + WHERE message.creationDate < DateTime($datetime) + WITH count(message) AS totalMessageCountInt + WITH toFloat(totalMessageCountInt) AS totalMessageCount + MATCH (message:Message) + WHERE message.creationDate < DateTime($datetime) + AND message.content IS NOT NULL + WITH + totalMessageCount, + message, + message.creationDate.year AS year + WITH + totalMessageCount, + year, + message:Comment AS isComment, + CASE + WHEN message.length < 40 THEN 0 + WHEN message.length < 80 THEN 1 + WHEN message.length < 160 THEN 2 + ELSE 3 + END AS lengthCategory, + count(message) AS messageCount, + sum(message.length) / toFloat(count(message)) AS averageMessageLength, + sum(message.length) AS sumMessageLength + RETURN + year, + isComment, + lengthCategory, + messageCount, + averageMessageLength, + sumMessageLength, + messageCount / totalMessageCount AS percentageOfMessages + ORDER BY + year DESC, + isComment ASC, + lengthCategory ASC + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_2_analytical(self): + + memgraph = ( + """ + MATCH (tag:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass}) + OPTIONAL MATCH (message1:Message)-[:HAS_TAG]->(tag) + WHERE localDateTime($date) <= message1.creationDate + AND message1.creationDate < localDateTime($date) + duration({day: 100}) + WITH tag, count(message1) AS countWindow1 + OPTIONAL MATCH (message2:Message)-[:HAS_TAG]->(tag) + WHERE localDateTime($date) + duration({day: 100}) <= message2.creationDate + AND message2.creationDate < localDateTime($date) + duration({day: 200}) + WITH + tag, + countWindow1, + count(message2) AS countWindow2 + RETURN + tag.name, + countWindow1, + countWindow2, + abs(countWindow1 - countWindow2) AS diff + ORDER BY + diff DESC, + tag.name ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (tag:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass}) + OPTIONAL MATCH (message1:Message)-[:HAS_TAG]->(tag) + WHERE DateTime($date) <= message1.creationDate + AND message1.creationDate < DateTime($date) + duration({days: 100}) + WITH tag, count(message1) AS countWindow1 + OPTIONAL MATCH (message2:Message)-[:HAS_TAG]->(tag) + WHERE DateTime($date) + duration({days: 100}) <= message2.creationDate + AND message2.creationDate < DateTime($date) + duration({days: 200}) + WITH + tag, + countWindow1, + count(message2) AS countWindow2 + RETURN + tag.name, + countWindow1, + countWindow2, + abs(countWindow1 - countWindow2) AS diff + ORDER BY + diff DESC, + tag.name ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_3_analytical(self): + return ( + """ + MATCH + (:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]- + (person:Person)<-[:HAS_MODERATOR]-(forum:Forum)-[:CONTAINER_OF]-> + (post:Post)<-[:REPLY_OF*0..]-(message:Message)-[:HAS_TAG]->(:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass}) + RETURN + forum.id as id, + forum.title, + person.id, + count(DISTINCT message) AS messageCount + ORDER BY + messageCount DESC, + id ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) 
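+
+    # Each benchmark__bi__query_* method in this class returns a
+    # (query, parameters) pair: the Cypher text is flattened to one line via
+    # .replace("\n", ""), and _get_query_parameters() inspects the calling
+    # method's name to pick the matching bi-<id>.csv parameter file. A rough
+    # sketch of that name-based lookup:
+    #
+    #     func_name = "benchmark__bi__query_5_analytical"
+    #     func_name.split("_")[-2]  # -> "5", i.e. bi-5.csv or bi-5a.csv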
+ + def benchmark__bi__query_5_analytical(self): + return ( + """ + MATCH (tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message)-[:HAS_CREATOR]->(person:Person) + OPTIONAL MATCH (message)<-[likes:LIKES]-(:Person) + WITH person, message, count(likes) AS likeCount + OPTIONAL MATCH (message)<-[:REPLY_OF]-(reply:Comment) + WITH person, message, likeCount, count(reply) AS replyCount + WITH person, count(message) AS messageCount, sum(likeCount) AS likeCount, sum(replyCount) AS replyCount + RETURN + person.id, + replyCount, + likeCount, + messageCount, + 1*messageCount + 2*replyCount + 10*likeCount AS score + ORDER BY + score DESC, + person.id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__bi__query_6_analytical(self): + return ( + """ + MATCH (tag:Tag {name: $tag})<-[:HAS_TAG]-(message1:Message)-[:HAS_CREATOR]->(person1:Person) + OPTIONAL MATCH (message1)<-[:LIKES]-(person2:Person) + OPTIONAL MATCH (person2)<-[:HAS_CREATOR]-(message2:Message)<-[like:LIKES]-(person3:Person) + RETURN + person1.id as id, + count(DISTINCT like) AS authorityScore + ORDER BY + authorityScore DESC, + id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__bi__query_7_analytical(self): + + memgraph = ( + """ + MATCH + (tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message), + (message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_TAG]->(relatedTag:Tag) + OPTIONAL MATCH (comment)-[:HAS_TAG]->(tag) + WHERE tag IS NOT NULL + RETURN + relatedTag, + count(DISTINCT comment) AS count + ORDER BY + relatedTag.name ASC, + count DESC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH + (tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message), + (message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_TAG]->(relatedTag:Tag) + WHERE NOT (comment)-[:HAS_TAG]->(tag) + RETURN + relatedTag.name, + count(DISTINCT comment) AS count + ORDER BY + relatedTag.name ASC, + count DESC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_9_analytical(self): + memgraph = ( + """ + MATCH (person:Person)<-[:HAS_CREATOR]-(post:Post)<-[:REPLY_OF*0..]-(reply:Message) + WHERE post.creationDate >= localDateTime($startDate) + AND post.creationDate <= localDateTime($endDate) + AND reply.creationDate >= localDateTime($startDate) + AND reply.creationDate <= localDateTime($endDate) + RETURN + person.id as id, + person.firstName, + person.lastName, + count(DISTINCT post) AS threadCount, + count(DISTINCT reply) AS messageCount + ORDER BY + messageCount DESC, + id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + neo4j = ( + """ + MATCH (person:Person)<-[:HAS_CREATOR]-(post:Post)<-[:REPLY_OF*0..]-(reply:Message) + WHERE post.creationDate >= DateTime($startDate) + AND post.creationDate <= DateTime($endDate) + AND reply.creationDate >= DateTime($startDate) + AND reply.creationDate <= DateTime($endDate) + RETURN + person.id as id, + person.firstName, + person.lastName, + count(DISTINCT post) AS threadCount, + count(DISTINCT reply) AS messageCount + ORDER BY + messageCount DESC, + id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_11_analytical(self): + return ( + """ + MATCH 
(a:Person)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country:Country {name: $country}), + (a)-[k1:KNOWS]-(b:Person) + WHERE a.id < b.id + AND localDateTime($startDate) <= k1.creationDate AND k1.creationDate <= localDateTime($endDate) + WITH DISTINCT country, a, b + MATCH (b)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country) + WITH DISTINCT country, a, b + MATCH (b)-[k2:KNOWS]-(c:Person), + (c)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country) + WHERE b.id < c.id + AND localDateTime($startDate) <= k2.creationDate AND k2.creationDate <= localDateTime($endDate) + WITH DISTINCT a, b, c + MATCH (c)-[k3:KNOWS]-(a) + WHERE localDateTime($startDate) <= k3.creationDate AND k3.creationDate <= localDateTime($endDate) + WITH DISTINCT a, b, c + RETURN count(*) AS count + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__bi__query_12_analytical(self): + return ( + """ + MATCH (person:Person) + OPTIONAL MATCH (person)<-[:HAS_CREATOR]-(message:Message)-[:REPLY_OF*0..]->(post:Post) + WHERE message.content IS NOT NULL + AND message.length < $lengthThreshold + AND message.creationDate > localDateTime($startDate) + AND post.language IN $languages + WITH + person, + count(message) AS messageCount + RETURN + messageCount, + count(person) AS personCount + ORDER BY + personCount DESC, + messageCount DESC + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__bi__query_13_analytical(self): + memgraph = ( + """ + MATCH (country:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]-(zombie:Person) + WHERE zombie.creationDate < localDateTime($endDate) + WITH country, zombie + OPTIONAL MATCH (zombie)<-[:HAS_CREATOR]-(message:Message) + WHERE message.creationDate < localDateTime($endDate) + WITH + country, + zombie, + count(message) AS messageCount + WITH + country, + zombie, + 12 * (localDateTime($endDate).year - zombie.creationDate.year ) + + (localDateTime($endDate).month - zombie.creationDate.month) + + 1 AS months, + messageCount + WHERE messageCount / months < 1 + WITH + country, + collect(zombie) AS zombies + UNWIND zombies AS zombie + OPTIONAL MATCH + (zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerZombie:Person) + WHERE likerZombie IN zombies + WITH + zombie, + count(likerZombie) AS zombieLikeCount + OPTIONAL MATCH + (zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerPerson:Person) + WHERE likerPerson.creationDate < localDateTime($endDate) + WITH + zombie, + zombieLikeCount, + count(likerPerson) AS totalLikeCount + RETURN + zombie.id, + zombieLikeCount, + totalLikeCount, + CASE totalLikeCount + WHEN 0 THEN 0.0 + ELSE zombieLikeCount / toFloat(totalLikeCount) + END AS zombieScore + ORDER BY + zombieScore DESC, + zombie.id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (country:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]-(zombie:Person) + WHERE zombie.creationDate < DateTime($endDate) + WITH country, zombie + OPTIONAL MATCH (zombie)<-[:HAS_CREATOR]-(message:Message) + WHERE message.creationDate < DateTime($endDate) + WITH + country, + zombie, + count(message) AS messageCount + WITH + country, + zombie, + 12 * (DateTime($endDate).year - zombie.creationDate.year ) + + (DateTime($endDate).month - zombie.creationDate.month) + + 1 AS months, + messageCount + WHERE messageCount / months < 1 + WITH + country, + collect(zombie) AS zombies + UNWIND zombies AS zombie + OPTIONAL MATCH + 
(zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerZombie:Person) + WHERE likerZombie IN zombies + WITH + zombie, + count(likerZombie) AS zombieLikeCount + OPTIONAL MATCH + (zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerPerson:Person) + WHERE likerPerson.creationDate < DateTime($endDate) + WITH + zombie, + zombieLikeCount, + count(likerPerson) AS totalLikeCount + RETURN + zombie.id, + zombieLikeCount, + totalLikeCount, + CASE totalLikeCount + WHEN 0 THEN 0.0 + ELSE zombieLikeCount / toFloat(totalLikeCount) + END AS zombieScore + ORDER BY + zombieScore DESC, + zombie.id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_14_analytical(self): + return ( + """ + MATCH + (country1:Country {name: $country1})<-[:IS_PART_OF]-(city1:City)<-[:IS_LOCATED_IN]-(person1:Person), + (country2:Country {name: $country2})<-[:IS_PART_OF]-(city2:City)<-[:IS_LOCATED_IN]-(person2:Person), + (person1)-[:KNOWS]-(person2) + WITH person1, person2, city1, 0 AS score + OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(c:Comment)-[:REPLY_OF]->(:Message)-[:HAS_CREATOR]->(person2) + WITH DISTINCT person1, person2, city1, score + (CASE c WHEN null THEN 0 ELSE 4 END) AS score + OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(m:Message)<-[:REPLY_OF]-(:Comment)-[:HAS_CREATOR]->(person2) + WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 1 END) AS score + OPTIONAL MATCH (person1)-[:LIKES]->(m:Message)-[:HAS_CREATOR]->(person2) + WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 10 END) AS score + OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(m:Message)<-[:LIKES]-(person2) + WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 1 END) AS score + ORDER BY + city1.name ASC, + score DESC, + person1.id ASC, + person2.id ASC + WITH city1, collect({score: score, person1Id: person1.id, person2Id: person2.id})[0] AS top + RETURN + top.person1Id, + top.person2Id, + city1.name, + top.score + ORDER BY + top.score DESC, + top.person1Id ASC, + top.person2Id ASC + LIMIT 100 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__bi__query_17_analytical(self): + + memgraph = ( + """ + MATCH + (tag:Tag {name: $tag}), + (person1:Person)<-[:HAS_CREATOR]-(message1:Message)-[:REPLY_OF*0..]->(post1:Post)<-[:CONTAINER_OF]-(forum1:Forum), + (message1)-[:HAS_TAG]->(tag), + (forum1)<-[:HAS_MEMBER]->(person2:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:HAS_TAG]->(tag), + (forum1)<-[:HAS_MEMBER]->(person3:Person)<-[:HAS_CREATOR]-(message2:Message), + (comment)-[:REPLY_OF]->(message2)-[:REPLY_OF*0..]->(post2:Post)<-[:CONTAINER_OF]-(forum2:Forum) + MATCH (comment)-[:HAS_TAG]->(tag) + MATCH (message2)-[:HAS_TAG]->(tag) + OPTIONAL MATCH (forum2)-[:HAS_MEMBER]->(person1) + WHERE forum1 <> forum2 AND message2.creationDate > message1.creationDate + duration({hours: $delta}) AND person1 IS NULL + RETURN person1, count(DISTINCT message2) AS messageCount + ORDER BY messageCount DESC, person1.id ASC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH + (tag:Tag {name: $tag}), + (person1:Person)<-[:HAS_CREATOR]-(message1:Message)-[:REPLY_OF*0..]->(post1:Post)<-[:CONTAINER_OF]-(forum1:Forum), + (message1)-[:HAS_TAG]->(tag), + (forum1)<-[:HAS_MEMBER]->(person2:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:HAS_TAG]->(tag), + 
(forum1)<-[:HAS_MEMBER]->(person3:Person)<-[:HAS_CREATOR]-(message2:Message), + (comment)-[:REPLY_OF]->(message2)-[:REPLY_OF*0..]->(post2:Post)<-[:CONTAINER_OF]-(forum2:Forum) + MATCH (comment)-[:HAS_TAG]->(tag) + MATCH (message2)-[:HAS_TAG]->(tag) + WHERE forum1 <> forum2 + AND message2.creationDate > message1.creationDate + duration({hours: $delta}) + AND NOT (forum2)-[:HAS_MEMBER]->(person1) + RETURN person1, count(DISTINCT message2) AS messageCount + ORDER BY messageCount DESC, person1.id ASC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__bi__query_18_analytical(self): + + memgraph = ( + """ + MATCH (tag:Tag {name: $tag})<-[:HAS_INTEREST]-(person1:Person)-[:KNOWS]-(mutualFriend:Person)-[:KNOWS]-(person2:Person)-[:HAS_INTEREST]->(tag) + OPTIONAL MATCH (person1)-[:KNOWS]-(person2) + WHERE person1 <> person2 + RETURN person1.id AS person1Id, person2.id AS person2Id, count(DISTINCT mutualFriend) AS mutualFriendCount + ORDER BY mutualFriendCount DESC, person1Id ASC, person2Id ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (tag:Tag {name: $tag})<-[:HAS_INTEREST]-(person1:Person)-[:KNOWS]-(mutualFriend:Person)-[:KNOWS]-(person2:Person)-[:HAS_INTEREST]->(tag) + WHERE person1 <> person2 + AND NOT (person1)-[:KNOWS]-(person2) + RETURN person1.id AS person1Id, person2.id AS person2Id, count(DISTINCT mutualFriend) AS mutualFriendCount + ORDER BY mutualFriendCount DESC, person1Id ASC, person2Id ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j diff --git a/tests/mgbench/workloads/ldbc_interactive.py b/tests/mgbench/workloads/ldbc_interactive.py new file mode 100644 index 000000000..576025949 --- /dev/null +++ b/tests/mgbench/workloads/ldbc_interactive.py @@ -0,0 +1,684 @@ +import inspect +import random +from datetime import datetime +from pathlib import Path + +import helpers +from benchmark_context import BenchmarkContext +from workloads.base import Workload +from workloads.importers.importer_ldbc_interactive import * + + +class LDBC_Interactive(Workload): + + NAME = "ldbc_interactive" + VARIANTS = ["sf0.1", "sf1", "sf3", "sf10"] + DEFAULT_VARIANT = "sf1" + + URL_FILE = { + "sf0.1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf0.1.cypher.gz", + "sf1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf1.cypher.gz", + "sf3": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf3.cypher.gz", + "sf10": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf10.cypher.gz", + } + URL_CSV = { + "sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.1.tar.zst", + "sf1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf1.tar.zst", + "sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf3.tar.zst", + "sf10": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf10.tar.zst", + } + + SIZES = { + "sf0.1": {"vertices": 327588, "edges": 1477965}, + "sf1": {"vertices": 
3181724, "edges": 17256038}, + "sf3": {"vertices": 1, "edges": 1}, + "sf10": {"vertices": 1, "edges": 1}, + } + + URL_INDEX_FILE = { + "memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/memgraph_interactive_index.cypher", + "neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/neo4j_interactive_index.cypher", + } + + PROPERTIES_ON_EDGES = True + + QUERY_PARAMETERS = { + "sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst", + "sf1": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst", + "sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst", + } + + def custom_import(self) -> bool: + importer = ImporterLDBCInteractive( + benchmark_context=self.benchmark_context, + dataset_name=self.NAME, + variant=self._variant, + index_file=self._file_index, + csv_dict=self.URL_CSV, + ) + return importer.execute_import() + + def _prepare_parameters_directory(self): + parameters = Path() / ".cache" / "datasets" / self.NAME / self._variant / "parameters" + parameters.mkdir(parents=True, exist_ok=True) + dir_name = self.QUERY_PARAMETERS[self._variant].split("/")[-1:][0].removesuffix(".tar.zst") + if (parameters / dir_name).exists(): + print("Files downloaded:") + parameters = parameters / dir_name + else: + print("Downloading files") + downloaded_file = helpers.download_file(self.QUERY_PARAMETERS[self._variant], parameters.absolute()) + print("Unpacking the file..." + downloaded_file) + parameters = helpers.unpack_tar_zst(Path(downloaded_file)) + return parameters + + def _get_query_parameters(self) -> dict: + func_name = inspect.stack()[1].function + parameters = {} + for file in self._parameters_dir.glob("interactive_*.txt"): + if file.name.split("_")[1] == func_name.split("_")[-2]: + with file.open("r") as input: + lines = input.readlines() + position = random.randint(1, len(lines) - 1) + header = lines[0].strip("\n").split("|") + data = lines[position].strip("\n").split("|") + for i in range(len(header)): + if "Date" in header[i]: + time = int(data[i]) / 1000 + converted = datetime.utcfromtimestamp(time).strftime("%Y-%m-%dT%H:%M:%S") + parameters[header[i]] = converted + elif data[i].isdigit(): + parameters[header[i]] = int(data[i]) + else: + parameters[header[i]] = data[i] + + return parameters + + def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None): + super().__init__(variant, benchmark_context=benchmark_context) + self._parameters_dir = self._prepare_parameters_directory() + self.benchmark_context = benchmark_context + + def benchmark__interactive__complex_query_1_analytical(self): + memgraph = ( + """ + MATCH (p:Person {id: $personId}), (friend:Person {firstName: $firstName}) + WHERE NOT p=friend + WITH p, friend + MATCH path =((p)-[:KNOWS *BFS 1..3]-(friend)) + WITH min(size(path)) AS distance, friend + ORDER BY + distance ASC, + friend.lastName ASC, + toInteger(friend.id) ASC + LIMIT 20 + + MATCH (friend)-[:IS_LOCATED_IN]->(friendCity:City) + OPTIONAL MATCH (friend)-[studyAt:STUDY_AT]->(uni:University)-[:IS_LOCATED_IN]->(uniCity:City) + WITH friend, collect( + CASE uni.name + WHEN null THEN null + ELSE [uni.name, studyAt.classYear, uniCity.name] + END ) AS unis, friendCity, distance + + OPTIONAL MATCH 
(friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(companyCountry:Country) + WITH friend, collect( + CASE company.name + WHEN null THEN null + ELSE [company.name, workAt.workFrom, companyCountry.name] + END ) AS companies, unis, friendCity, distance + + RETURN + friend.id AS friendId, + friend.lastName AS friendLastName, + distance AS distanceFromPerson, + friend.birthday AS friendBirthday, + friend.gender AS friendGender, + friend.browserUsed AS friendBrowserUsed, + friend.locationIP AS friendLocationIp, + friend.email AS friendEmails, + friend.speaks AS friendLanguages, + friendCity.name AS friendCityName, + unis AS friendUniversities, + companies AS friendCompanies + ORDER BY + distanceFromPerson ASC, + friendLastName ASC, + toInteger(friendId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + neo4j = ( + """ + MATCH (p:Person {id: $personId}), (friend:Person {firstName: $firstName}) + WHERE NOT p=friend + WITH p, friend + MATCH path = shortestPath((p)-[:KNOWS*1..3]-(friend)) + WITH min(length(path)) AS distance, friend + ORDER BY + distance ASC, + friend.lastName ASC, + toInteger(friend.id) ASC + LIMIT 20 + + MATCH (friend)-[:IS_LOCATED_IN]->(friendCity:City) + OPTIONAL MATCH (friend)-[studyAt:STUDY_AT]->(uni:University)-[:IS_LOCATED_IN]->(uniCity:City) + WITH friend, collect( + CASE uni.name + WHEN null THEN null + ELSE [uni.name, studyAt.classYear, uniCity.name] + END ) AS unis, friendCity, distance + + OPTIONAL MATCH (friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(companyCountry:Country) + WITH friend, collect( + CASE company.name + WHEN null THEN null + ELSE [company.name, workAt.workFrom, companyCountry.name] + END ) AS companies, unis, friendCity, distance + + RETURN + friend.id AS friendId, + friend.lastName AS friendLastName, + distance AS distanceFromPerson, + friend.birthday AS friendBirthday, + friend.gender AS friendGender, + friend.browserUsed AS friendBrowserUsed, + friend.locationIP AS friendLocationIp, + friend.email AS friendEmails, + friend.speaks AS friendLanguages, + friendCity.name AS friendCityName, + unis AS friendUniversities, + companies AS friendCompanies + ORDER BY + distanceFromPerson ASC, + friendLastName ASC, + toInteger(friendId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__interactive__complex_query_2_analytical(self): + return ( + """ + MATCH (:Person {id: $personId })-[:KNOWS]-(friend:Person)<-[:HAS_CREATOR]-(message:Message) + WHERE message.creationDate <= localDateTime($maxDate) + RETURN + friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + message.id AS postOrCommentId, + coalesce(message.content,message.imageFile) AS postOrCommentContent, + message.creationDate AS postOrCommentCreationDate + ORDER BY + postOrCommentCreationDate DESC, + toInteger(postOrCommentId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_3_analytical(self): + + memgraph = ( + """ + MATCH (countryX:Country {name: $countryXName }), + (countryY:Country {name: $countryYName }), + (person:Person {id: $personId }) + WITH person, countryX, countryY + LIMIT 1 + MATCH (city:City)-[:IS_PART_OF]->(country:Country) + WHERE country IN [countryX, countryY] + WITH person, countryX, countryY, collect(city) AS cities + MATCH 
(person)-[:KNOWS*1..2]-(friend)-[:IS_LOCATED_IN]->(city) + WHERE NOT person=friend AND NOT city IN cities + WITH DISTINCT friend, countryX, countryY + MATCH (friend)<-[:HAS_CREATOR]-(message), + (message)-[:IS_LOCATED_IN]->(country) + WHERE localDateTime($startDate) + duration({day:$durationDays}) > message.creationDate >= localDateTime($startDate) AND + country IN [countryX, countryY] + WITH friend, + CASE WHEN country=countryX THEN 1 ELSE 0 END AS messageX, + CASE WHEN country=countryY THEN 1 ELSE 0 END AS messageY + WITH friend, sum(messageX) AS xCount, sum(messageY) AS yCount + WHERE xCount>0 AND yCount>0 + RETURN friend.id AS friendId, + friend.firstName AS friendFirstName, + friend.lastName AS friendLastName, + xCount, + yCount, + xCount + yCount AS xyCount + ORDER BY xyCount DESC, friendId ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + neo4j = ( + """ + MATCH (countryX:Country {name: $countryXName }), + (countryY:Country {name: $countryYName }), + (person:Person {id: $personId }) + WITH person, countryX, countryY + LIMIT 1 + MATCH (city:City)-[:IS_PART_OF]->(country:Country) + WHERE country IN [countryX, countryY] + WITH person, countryX, countryY, collect(city) AS cities + MATCH (person)-[:KNOWS*1..2]-(friend)-[:IS_LOCATED_IN]->(city) + WHERE NOT person=friend AND NOT city IN cities + WITH DISTINCT friend, countryX, countryY + MATCH (friend)<-[:HAS_CREATOR]-(message), + (message)-[:IS_LOCATED_IN]->(country) + WHERE localDateTime($startDate) + duration({days:$durationDays}) > message.creationDate >= localDateTime($startDate) AND + country IN [countryX, countryY] + WITH friend, + CASE WHEN country=countryX THEN 1 ELSE 0 END AS messageX, + CASE WHEN country=countryY THEN 1 ELSE 0 END AS messageY + WITH friend, sum(messageX) AS xCount, sum(messageY) AS yCount + WHERE xCount>0 AND yCount>0 + RETURN friend.id AS friendId, + friend.firstName AS friendFirstName, + friend.lastName AS friendLastName, + xCount, + yCount, + xCount + yCount AS xyCount + ORDER BY xyCount DESC, friendId ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__interactive__complex_query_4_analytical(self): + memgraph = ( + """ + MATCH (person:Person {id: $personId })-[:KNOWS]-(friend:Person), + (friend)<-[:HAS_CREATOR]-(post:Post)-[:HAS_TAG]->(tag) + WITH DISTINCT tag, post + WITH tag, + CASE + WHEN localDateTime($startDate) + duration({day:$durationDays}) > post.creationDate >= localDateTime($startDate) THEN 1 + ELSE 0 + END AS valid, + CASE + WHEN localDateTime($startDate) > post.creationDate THEN 1 + ELSE 0 + END AS inValid + WITH tag, sum(valid) AS postCount, sum(inValid) AS inValidPostCount + WHERE postCount>0 AND inValidPostCount=0 + RETURN tag.name AS tagName, postCount + ORDER BY postCount DESC, tagName ASC + LIMIT 10 + + """, + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (person:Person {id: $personId })-[:KNOWS]-(friend:Person), + (friend)<-[:HAS_CREATOR]-(post:Post)-[:HAS_TAG]->(tag) + WITH DISTINCT tag, post + WITH tag, + CASE + WHEN localDateTime($startDate) + duration({days:$durationDays}) > post.creationDate >= localDateTime($startDate) THEN 1 + ELSE 0 + END AS valid, + CASE + WHEN localDateTime($startDate) > post.creationDate THEN 1 + ELSE 0 + END AS inValid + WITH tag, sum(valid) AS postCount, sum(inValid) AS inValidPostCount + WHERE postCount>0 AND inValidPostCount=0 + RETURN tag.name AS tagName, postCount + ORDER BY 
postCount DESC, tagName ASC + LIMIT 10 + + """, + self._get_query_parameters(), + ) + + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__interactive__complex_query_5_analytical(self): + return ( + """ + MATCH (person:Person { id: $personId })-[:KNOWS*1..2]-(friend) + WHERE + NOT person=friend + WITH DISTINCT friend + MATCH (friend)<-[membership:HAS_MEMBER]-(forum) + WHERE + membership.joinDate > localDateTime($minDate) + WITH + forum, + collect(friend) AS friends + OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post)<-[:CONTAINER_OF]-(forum) + WHERE + friend IN friends + WITH + forum, + count(post) AS postCount + RETURN + forum.title AS forumName, + postCount + ORDER BY + postCount DESC, + forum.id ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_6_analytical(self): + return ( + """ + MATCH (knownTag:Tag { name: $tagName }) + WITH knownTag.id as knownTagId + + MATCH (person:Person { id: $personId })-[:KNOWS*1..2]-(friend) + WHERE NOT person=friend + WITH + knownTagId, + collect(distinct friend) as friends + UNWIND friends as f + MATCH (f)<-[:HAS_CREATOR]-(post:Post), + (post)-[:HAS_TAG]->(t:Tag{id: knownTagId}), + (post)-[:HAS_TAG]->(tag:Tag) + WHERE NOT t = tag + WITH + tag.name as tagName, + count(post) as postCount + RETURN + tagName, + postCount + ORDER BY + postCount DESC, + tagName ASC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_7_analytical(self): + memgraph = ( + """ + MATCH (person:Person {id: $personId})<-[:HAS_CREATOR]-(message:Message)<-[like:LIKES]-(liker:Person) + WITH liker, message, like.creationDate AS likeTime, person + ORDER BY likeTime DESC, toInteger(message.id) ASC + WITH liker, head(collect({msg: message, likeTime: likeTime})) AS latestLike, person + OPTIONAL MATCH (liker)-[:KNOWS]-(person) + WITH liker, latestLike, person, + CASE WHEN person IS null THEN TRUE ELSE FALSE END AS isNew + RETURN + liker.id AS personId, + liker.firstName AS personFirstName, + liker.lastName AS personLastName, + latestLike.likeTime AS likeCreationDate, + latestLike.msg.id AS commentOrPostId, + coalesce(latestLike.msg.content, latestLike.msg.imageFile) AS commentOrPostContent, + (latestLike.likeTime - latestLike.msg.creationDate).minute AS minutesLatency + ORDER BY + likeCreationDate DESC, + toInteger(personId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + neo4j = ( + """ + MATCH (person:Person {id: $personId})<-[:HAS_CREATOR]-(message:Message)<-[like:LIKES]-(liker:Person) + WITH liker, message, like.creationDate AS likeTime, person + ORDER BY likeTime DESC, toInteger(message.id) ASC + WITH liker, head(collect({msg: message, likeTime: likeTime})) AS latestLike, person + RETURN + liker.id AS personId, + liker.firstName AS personFirstName, + liker.lastName AS personLastName, + latestLike.likeTime AS likeCreationDate, + latestLike.msg.id AS commentOrPostId, + coalesce(latestLike.msg.content, latestLike.msg.imageFile) AS commentOrPostContent, + duration.between(latestLike.likeTime, latestLike.msg.creationDate).minutes AS minutesLatency, + not((liker)-[:KNOWS]-(person)) AS isNew + ORDER BY + likeCreationDate DESC, + toInteger(personId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__interactive__complex_query_8_analytical(self): + return ( + 
""" + MATCH (start:Person {id: $personId})<-[:HAS_CREATOR]-(:Message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_CREATOR]->(person:Person) + RETURN + person.id AS personId, + person.firstName AS personFirstName, + person.lastName AS personLastName, + comment.creationDate AS commentCreationDate, + comment.id AS commentId, + comment.content AS commentContent + ORDER BY + commentCreationDate DESC, + commentId ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_9_analytical(self): + return ( + """ + MATCH (root:Person {id: $personId })-[:KNOWS*1..2]-(friend:Person) + WHERE NOT friend = root + WITH collect(distinct friend) as friends + UNWIND friends as friend + MATCH (friend)<-[:HAS_CREATOR]-(message:Message) + WHERE message.creationDate < localDateTime($maxDate) + RETURN + friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + message.id AS commentOrPostId, + coalesce(message.content,message.imageFile) AS commentOrPostContent, + message.creationDate AS commentOrPostCreationDate + ORDER BY + commentOrPostCreationDate DESC, + message.id ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_10_analytical(self): + memgraph = ( + """ + MATCH (person:Person {id: $personId})-[:KNOWS*2..2]-(friend), + (friend)-[:IS_LOCATED_IN]->(city:City) + WHERE NOT friend=person AND + NOT (friend)-[:KNOWS]-(person) + WITH person, city, friend, datetime({epochMillis: friend.birthday}) as birthday + WHERE (birthday.month=$month AND birthday.day>=21) OR + (birthday.month=($month%12)+1 AND birthday.day<22) + WITH DISTINCT friend, city, person + OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post:Post) + WITH friend, city, collect(post) AS posts, person + WITH friend, + city, + size(posts) AS postCount, + size([p IN posts WHERE (p)-[:HAS_TAG]->()<-[:HAS_INTEREST]-(person)]) AS commonPostCount + RETURN friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + commonPostCount - (postCount - commonPostCount) AS commonInterestScore, + friend.gender AS personGender, + city.name AS personCityName + ORDER BY commonInterestScore DESC, personId ASC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH (person:Person {id: $personId})-[:KNOWS*2..2]-(friend), + (friend)-[:IS_LOCATED_IN]->(city:City) + WHERE NOT friend=person AND + NOT (friend)-[:KNOWS]-(person) + WITH person, city, friend, datetime({epochMillis: friend.birthday}) as birthday + WHERE (birthday.month=$month AND birthday.day>=21) OR + (birthday.month=($month%12)+1 AND birthday.day<22) + WITH DISTINCT friend, city, person + OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post:Post) + WITH friend, city, collect(post) AS posts, person + WITH friend, + city, + size(posts) AS postCount, + size([p IN posts WHERE (p)-[:HAS_TAG]->()<-[:HAS_INTEREST]-(person)]) AS commonPostCount + RETURN friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + commonPostCount - (postCount - commonPostCount) AS commonInterestScore, + friend.gender AS personGender, + city.name AS personCityName + ORDER BY commonInterestScore DESC, personId ASC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + if self._vendor == "memgraph": + return memgraph + else: + return neo4j + + def benchmark__interactive__complex_query_11_analytical(self): + return ( + """ + 
MATCH (person:Person {id: $personId })-[:KNOWS*1..2]-(friend:Person) + WHERE not(person=friend) + WITH DISTINCT friend + MATCH (friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(:Country {name: $countryName }) + WHERE workAt.workFrom < $workFromYear + RETURN + friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + company.name AS organizationName, + workAt.workFrom AS organizationWorkFromYear + ORDER BY + organizationWorkFromYear ASC, + toInteger(personId) ASC, + organizationName DESC + LIMIT 10 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_12_analytical(self): + return ( + """ + MATCH (tag:Tag)-[:HAS_TYPE|IS_SUBCLASS_OF*0..]->(baseTagClass:TagClass) + WHERE tag.name = $tagClassName OR baseTagClass.name = $tagClassName + WITH collect(tag.id) as tags + MATCH (:Person {id: $personId })-[:KNOWS]-(friend:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:REPLY_OF]->(:Post)-[:HAS_TAG]->(tag:Tag) + WHERE tag.id in tags + RETURN + friend.id AS personId, + friend.firstName AS personFirstName, + friend.lastName AS personLastName, + collect(DISTINCT tag.name) AS tagNames, + count(DISTINCT comment) AS replyCount + ORDER BY + replyCount DESC, + toInteger(personId) ASC + LIMIT 20 + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + def benchmark__interactive__complex_query_13_analytical(self): + memgraph = ( + """ + MATCH + (person1:Person {id: $person1Id}), + (person2:Person {id: $person2Id}), + path = (person1)-[:KNOWS *BFS]-(person2) + RETURN + CASE path IS NULL + WHEN true THEN -1 + ELSE size(path) + END AS shortestPathLength + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + neo4j = ( + """ + MATCH + (person1:Person {id: $person1Id}), + (person2:Person {id: $person2Id}), + path = shortestPath((person1)-[:KNOWS*]-(person2)) + RETURN + CASE path IS NULL + WHEN true THEN -1 + ELSE length(path) + END AS shortestPathLength + """.replace( + "\n", "" + ), + self._get_query_parameters(), + ) + + if self._vendor == "memgraph": + return memgraph + else: + return neo4j diff --git a/tests/mgbench/datasets.py b/tests/mgbench/workloads/pokec.py similarity index 72% rename from tests/mgbench/datasets.py rename to tests/mgbench/workloads/pokec.py index 455d86af5..afecf0b6e 100644 --- a/tests/mgbench/datasets.py +++ b/tests/mgbench/workloads/pokec.py @@ -1,134 +1,17 @@ -# Copyright 2022 Memgraph Ltd. -# -# Use of this software is governed by the Business Source License -# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source -# License, and you may not use this file except in compliance with the Business Source License. -# -# As of the Change Date specified in that file, in accordance with -# the Business Source License, use of this software will be governed -# by the Apache License, Version 2.0, included in the file -# licenses/APL.txt. - import random -import helpers +from benchmark_context import BenchmarkContext +from workloads.base import Workload +from workloads.importers.importer_pokec import ImporterPokec -# Base dataset class used as a template to create each individual dataset. All -# common logic is handled here. -class Dataset: - # Name of the dataset. - NAME = "Base dataset" - # List of all variants of the dataset that exist. - VARIANTS = ["default"] - # One of the available variants that should be used as the default variant. 
- DEFAULT_VARIANT = "default" - # List of query files that should be used to import the dataset. - FILES = { - "default": "/foo/bar", - } - INDEX = None - INDEX_FILES = {"default": ""} - # List of query file URLs that should be used to import the dataset. - URLS = None - # Number of vertices/edges for each variant. - SIZES = { - "default": {"vertices": 0, "edges": 0}, - } - # Indicates whether the dataset has properties on edges. - PROPERTIES_ON_EDGES = False - - def __init__(self, variant=None, vendor=None): - """ - Accepts a `variant` variable that indicates which variant - of the dataset should be executed. - """ - if variant is None: - variant = self.DEFAULT_VARIANT - if variant not in self.VARIANTS: - raise ValueError("Invalid test variant!") - if (self.FILES and variant not in self.FILES) and (self.URLS and variant not in self.URLS): - raise ValueError("The variant doesn't have a defined URL or " "file path!") - if variant not in self.SIZES: - raise ValueError("The variant doesn't have a defined dataset " "size!") - if vendor not in self.INDEX_FILES: - raise ValueError("Vendor does not have INDEX for dataset!") - self._variant = variant - self._vendor = vendor - if self.FILES is not None: - self._file = self.FILES.get(variant, None) - else: - self._file = None - if self.URLS is not None: - self._url = self.URLS.get(variant, None) - else: - self._url = None - - if self.INDEX_FILES is not None: - self._index = self.INDEX_FILES.get(vendor, None) - else: - self._index = None - - self._size = self.SIZES[variant] - if "vertices" not in self._size or "edges" not in self._size: - raise ValueError("The size defined for this variant doesn't " "have the number of vertices and/or edges!") - self._num_vertices = self._size["vertices"] - self._num_edges = self._size["edges"] - - def prepare(self, directory): - if self._file is not None: - print("Using dataset file:", self._file) - else: - # TODO: add support for JSON datasets - cached_input, exists = directory.get_file("dataset.cypher") - if not exists: - print("Downloading dataset file:", self._url) - downloaded_file = helpers.download_file(self._url, directory.get_path()) - print("Unpacking and caching file:", downloaded_file) - helpers.unpack_and_move_file(downloaded_file, cached_input) - print("Using cached dataset file:", cached_input) - self._file = cached_input - - cached_index, exists = directory.get_file(self._vendor + ".cypher") - if not exists: - print("Downloading index file:", self._index) - downloaded_file = helpers.download_file(self._index, directory.get_path()) - print("Unpacking and caching file:", downloaded_file) - helpers.unpack_and_move_file(downloaded_file, cached_index) - print("Using cached index file:", cached_index) - self._index = cached_index - - def get_variant(self): - """Returns the current variant of the dataset.""" - return self._variant - - def get_index(self): - """Get index file, defined by vendor""" - return self._index - - def get_file(self): - """ - Returns path to the file that contains dataset creation queries. - """ - return self._file - - def get_size(self): - """Returns number of vertices/edges for the current variant.""" - return self._size - - # All tests should be query generator functions that output all of the - # queries that should be executed by the runner. The functions should be - # named `benchmark__GROUPNAME__TESTNAME` and should not accept any - # arguments. 
- - -class Pokec(Dataset): +class Pokec(Workload): NAME = "pokec" VARIANTS = ["small", "medium", "large"] DEFAULT_VARIANT = "small" - FILES = None + FILE = None - URLS = { + URL_FILE = { "small": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_small_import.cypher", "medium": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_medium_import.cypher", "large": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_large.setup.cypher.gz", @@ -138,16 +21,28 @@ class Pokec(Dataset): "medium": {"vertices": 100000, "edges": 1768515}, "large": {"vertices": 1632803, "edges": 30622564}, } - INDEX = None - INDEX_FILES = { + + URL_INDEX_FILE = { "memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/memgraph.cypher", "neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/neo4j.cypher", } PROPERTIES_ON_EDGES = False - # Helpers used to generate the queries + def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None): + super().__init__(variant, benchmark_context=benchmark_context) + def custom_import(self) -> bool: + importer = ImporterPokec( + benchmark_context=self.benchmark_context, + dataset_name=self.NAME, + index_file=self._file_index, + dataset_file=self._file, + variant=self._variant, + ) + return importer.execute_import() + + # Helpers used to generate the queries def _get_random_vertex(self): # All vertices in the Pokec dataset have an ID in the range # [1, _num_vertices]. @@ -343,7 +238,7 @@ class Pokec(Dataset): return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__match__vertex_on_property(self): - return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()}) + return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()}) def benchmark__update__vertex_on_property(self): return ( @@ -364,7 +259,7 @@ class Pokec(Dataset): def benchmark__basic__single_vertex_property_update_update(self): return ( - "MATCH (n {id: $id}) SET n.property = -1", + "MATCH (n:User {id: $id}) SET n.property = -1", {"id": self._get_random_vertex()}, ) From 8b0dca9eabbe8e19ec1f6052503f2b38996efbf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Budiseli=C4=87?= <marko.budiselic@memgraph.com> Date: Sun, 26 Mar 2023 17:34:51 +0200 Subject: [PATCH 4/6] Upgrade pre-commit hook to use isort 5.12 (#840) --- .pre-commit-config.yaml | 6 +++--- init | 21 ++++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79bf689df..c75df7b98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,16 +1,16 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v4.4.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/init b/init index 95a2438eb..3a89013fb 100755 --- a/init +++ b/init @@ -106,15 +106,18 @@ for hook in $(find $DIR/.githooks -type f -printf "%f\n"); do echo "Added $hook hook" done; -# Install precommit hook -python3 -m pip install pre-commit -python3 -m pre_commit install - -# Install py format tools -echo "Install black formatter" -python3 -m pip install black==22.8.* -echo "Install isort" -python3 -m pip install 
isort==5.10.* +# Install precommit hook except on old operating systems because we don't +# develop on them -> pre-commit hook not required -> we can use latest +# packages. +if [ "${DISTRO}" != "centos-7" ] && [ "$DISTRO" != "debian-10" ] && [ "${DISTRO}" != "ubuntu-18.04" ]; then + python3 -m pip install pre-commit + python3 -m pre_commit install + # Install py format tools for usage during the development. + echo "Install black formatter" + python3 -m pip install black==23.1.* + echo "Install isort" + python3 -m pip install isort==5.12.* +fi # Link `include/mgp.py` with `release/mgp/mgp.py` ln -v -f include/mgp.py release/mgp/mgp.py From a9dc344b4965029fc76421a3677fa46e335616c1 Mon Sep 17 00:00:00 2001 From: Aidar Samerkhanov <darych90@gmail.com> Date: Mon, 27 Mar 2023 13:26:10 +0400 Subject: [PATCH 5/6] Add automatic CPU architecture detection to CMake (#838) --- CMakeLists.txt | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 916389424..592930a40 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,7 +231,15 @@ endif() message(STATUS "CMake build type: ${CMAKE_BUILD_TYPE}") # ----------------------------------------------------------------------------- -set(MG_ARCH "x86_64" CACHE STRING "Host architecture to build Memgraph on. Supported values are x86_64 (default), ARM64.") +if (NOT MG_ARCH) + set(MG_ARCH_DESCR "Host architecture to build Memgraph on. Supported values are x86_64, ARM64.") + if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + set(MG_ARCH "ARM64" CACHE STRING ${MG_ARCH_DESCR}) + else() + set(MG_ARCH "x86_64" CACHE STRING ${MG_ARCH_DESCR}) + endif() +endif() +message(STATUS "MG_ARCH: ${MG_ARCH}") # setup external dependencies ------------------------------------------------- From 029be10f1de2adaf1948040c5c7ada9b2722d7e3 Mon Sep 17 00:00:00 2001 From: Andi <andi8647@gmail.com> Date: Mon, 27 Mar 2023 15:46:00 +0200 Subject: [PATCH 6/6] Add queries to show or terminate active transactions (#790) --- include/mgp.hpp | 21 ++ src/auth/models.cpp | 22 +- src/auth/models.hpp | 5 +- src/glue/auth.cpp | 4 +- src/glue/auth_checker.cpp | 3 +- src/glue/auth_checker.hpp | 4 +- src/memgraph.cpp | 9 +- src/query/auth_checker.hpp | 2 +- src/query/context.hpp | 16 +- src/query/exceptions.hpp | 14 +- src/query/frontend/ast/ast.cpp | 4 + src/query/frontend/ast/ast.hpp | 43 ++- src/query/frontend/ast/ast.lcp | 24 +- src/query/frontend/ast/ast_visitor.hpp | 10 +- .../frontend/ast/cypher_main_visitor.cpp | 32 ++ .../frontend/ast/cypher_main_visitor.hpp | 20 ++ .../opencypher/grammar/MemgraphCypher.g4 | 16 + .../opencypher/grammar/MemgraphCypherLexer.g4 | 5 +- .../frontend/semantic/required_privileges.cpp | 4 +- src/query/interpreter.cpp | 303 ++++++++++++++-- src/query/interpreter.hpp | 50 ++- src/query/stream/streams.cpp | 7 +- src/query/trigger.cpp | 6 +- src/query/trigger.hpp | 8 +- src/storage/v2/constraints.cpp | 5 +- src/storage/v2/mvcc.hpp | 7 +- src/storage/v2/storage.cpp | 17 +- src/storage/v2/storage.hpp | 5 +- src/storage/v2/transaction.hpp | 20 +- src/utils/typeinfo.hpp | 2 +- tests/e2e/CMakeLists.txt | 1 + tests/e2e/lba_procedures/show_privileges.py | 4 +- tests/e2e/transaction_queue/CMakeLists.txt | 8 + tests/e2e/transaction_queue/common.py | 26 ++ .../procedures/CMakeLists.txt | 1 + .../procedures/infinite_query.py | 27 ++ .../test_transaction_queue.py | 338 ++++++++++++++++++ tests/e2e/transaction_queue/workloads.yaml | 14 + tests/unit/CMakeLists.txt | 6 + tests/unit/interpreter.cpp | 69 +--- 
tests/unit/interpreter_faker.hpp | 49 +++ tests/unit/query_trigger.cpp | 3 +- tests/unit/transaction_queue.cpp | 75 ++++ tests/unit/transaction_queue_multiple.cpp | 118 ++++++ 44 files changed, 1255 insertions(+), 172 deletions(-) create mode 100644 tests/e2e/transaction_queue/CMakeLists.txt create mode 100644 tests/e2e/transaction_queue/common.py create mode 100644 tests/e2e/transaction_queue/procedures/CMakeLists.txt create mode 100644 tests/e2e/transaction_queue/procedures/infinite_query.py create mode 100644 tests/e2e/transaction_queue/test_transaction_queue.py create mode 100644 tests/e2e/transaction_queue/workloads.yaml create mode 100644 tests/unit/interpreter_faker.hpp create mode 100644 tests/unit/transaction_queue.cpp create mode 100644 tests/unit/transaction_queue_multiple.cpp diff --git a/include/mgp.hpp b/include/mgp.hpp index 4d39a7c86..a40d76fd3 100644 --- a/include/mgp.hpp +++ b/include/mgp.hpp @@ -55,6 +55,15 @@ class NotEnoughMemoryException : public std::exception { const char *what() const throw() { return "Not enough memory!"; } }; +class MustAbortException : public std::exception { + public: + explicit MustAbortException(const std::string &message) : message_(message) {} + const char *what() const noexcept override { return message_.c_str(); } + + private: + std::string message_; +}; + // Forward declarations class Nodes; using GraphNodes = Nodes; @@ -141,6 +150,10 @@ class Graph { /// @brief Deletes a relationship from the graph. void DeleteRelationship(const Relationship &relationship); + bool MustAbort() const; + + void CheckMustAbort() const; + private: mgp_graph *graph_; }; @@ -1572,6 +1585,14 @@ inline Id::Id(int64_t id) : id_(id) {} inline Graph::Graph(mgp_graph *graph) : graph_(graph) {} +inline bool Graph::MustAbort() const { return must_abort(graph_); } + +inline void Graph::CheckMustAbort() const { + if (MustAbort()) { + throw MustAbortException("Query was asked to abort."); + } +} + inline int64_t Graph::Order() const { int64_t i = 0; for (const auto _ : Nodes()) { diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 78276bbae..18574d369 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -34,13 +34,17 @@ namespace memgraph::auth { namespace { // Constant list of all available permissions. 
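The mgp.hpp additions above give long-running C++ query modules a cooperative way to notice that their transaction has been terminated. A minimal sketch of the intended call pattern, assuming the usual `mgp` namespace and a hypothetical per-node work loop:

#include "mgp.hpp"

// Hypothetical procedure body: poll the abort flag once per node so that a
// TERMINATE TRANSACTIONS request (or server shutdown) can interrupt the scan
// instead of letting it run to completion.
void ScanAllNodes(mgp::Graph &graph) {
  for (const auto node : graph.Nodes()) {
    graph.CheckMustAbort();  // throws mgp::MustAbortException once termination is requested
    // ... per-node work ...
  }
}

Modules that prefer not to unwind through an exception can call `graph.MustAbort()` directly and return early.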
-const std::vector<Permission> kPermissionsAll = { - Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, - Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, - Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION, - Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER, - Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE, - Permission::WEBSOCKET}; +const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permission::CREATE, + Permission::MERGE, Permission::DELETE, + Permission::SET, Permission::REMOVE, + Permission::INDEX, Permission::STATS, + Permission::CONSTRAINT, Permission::DUMP, + Permission::AUTH, Permission::REPLICATION, + Permission::DURABILITY, Permission::READ_FILE, + Permission::FREE_MEMORY, Permission::TRIGGER, + Permission::CONFIG, Permission::STREAM, + Permission::MODULE_READ, Permission::MODULE_WRITE, + Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT}; } // namespace std::string PermissionToString(Permission permission) { @@ -87,6 +91,8 @@ std::string PermissionToString(Permission permission) { return "MODULE_WRITE"; case Permission::WEBSOCKET: return "WEBSOCKET"; + case Permission::TRANSACTION_MANAGEMENT: + return "TRANSACTION_MANAGEMENT"; } } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 726586bdb..b902b6960 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Licensed as a Memgraph Enterprise file under the Memgraph Enterprise // License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use @@ -40,7 +40,8 @@ enum class Permission : uint64_t { STREAM = 1U << 17U, MODULE_READ = 1U << 18U, MODULE_WRITE = 1U << 19U, - WEBSOCKET = 1U << 20U + WEBSOCKET = 1U << 20U, + TRANSACTION_MANAGEMENT = 1U << 21U }; // clang-format on diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 639f722c4..0811ff2e1 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -58,6 +58,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::MODULE_WRITE; case query::AuthQuery::Privilege::WEBSOCKET: return auth::Permission::WEBSOCKET; + case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT: + return auth::Permission::TRANSACTION_MANAGEMENT; } } diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 1c58daed4..debdc0f5b 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
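Because `Permission` is a bitmask enum over `uint64_t`, the new `TRANSACTION_MANAGEMENT = 1U << 21U` entry simply claims the next free bit, so existing stored grants keep their meaning. An illustration (not code from the patch) of how such one-bit-per-permission sets combine and test:

// Illustrative only: a permission set is a plain uint64_t bit union,
// so a membership check compiles down to a single AND.
uint64_t granted = static_cast<uint64_t>(Permission::MATCH) |
                   static_cast<uint64_t>(Permission::TRANSACTION_MANAGEMENT);
bool can_manage_transactions =
    (granted & static_cast<uint64_t>(Permission::TRANSACTION_MANAGEMENT)) != 0U;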
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -84,6 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username, return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges); } + #ifdef MG_ENTERPRISE std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGrainedAuthChecker( const std::string &username, const memgraph::query::DbAccessor *dba) const { diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index f5e9bc526..22f6515c3 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -26,9 +26,11 @@ class AuthChecker : public query::AuthChecker { bool IsUserAuthorized(const std::optional<std::string> &username, const std::vector<query::AuthQuery::Privilege> &privileges) const override; + #ifdef MG_ENTERPRISE std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker( const std::string &username, const memgraph::query::DbAccessor *dba) const override; + #endif [[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges); diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 91ec8ea99..6c0c28e89 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -522,6 +522,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph : memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, memgraph::communication::v2::OutputStream>(input_stream, output_stream), db_(data->db), + interpreter_context_(data->interpreter_context), interpreter_(data->interpreter_context), auth_(data->auth), #if MG_ENTERPRISE @@ -529,6 +530,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph #endif endpoint_(endpoint), run_id_(data->run_id) { + interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); }); + } + + ~BoltSession() override { + interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); }); } using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream, @@ -674,6 +680,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph // NOTE: Needed only for ToBoltValue conversions const memgraph::storage::Storage *db_; + memgraph::query::InterpreterContext *interpreter_context_; memgraph::query::Interpreter interpreter_; memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; std::optional<memgraph::auth::User> user_; diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index 00abebc4a..4f6cb1419 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
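The BoltSession constructor/destructor pair above is the hook that makes the new queries possible: every live session registers its `Interpreter` in the shared `interpreters` set, so SHOW TRANSACTIONS and TERMINATE TRANSACTIONS can later enumerate all running transactions. The set is guarded by `utils::Synchronized`, whose `WithLock` runs a callback while holding the lock. Distilled from the registration code above (context simplified):

// All access to the shared registry goes through WithLock, so insertion at
// session start, removal at session end, and the enumeration performed by
// the transaction queue queries never race with one another.
utils::Synchronized<std::unordered_set<Interpreter *>, utils::SpinLock> interpreters;

interpreters.WithLock([&](auto &set) { set.insert(&interpreter_); });  // constructor
interpreters.WithLock([&](auto &set) { set.erase(&interpreter_); });   // destructor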
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/context.hpp b/src/query/context.hpp index 18f8eb27b..3b89ede13 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -24,6 +24,15 @@ namespace memgraph::query { +enum class TransactionStatus { + IDLE, + ACTIVE, + VERIFYING, + TERMINATED, + STARTED_COMMITTING, + STARTED_ROLLBACK, +}; + struct EvaluationContext { /// Memory for allocations during evaluation of a *single* Pull call. /// @@ -66,6 +75,7 @@ struct ExecutionContext { SymbolTable symbol_table; EvaluationContext evaluation_context; std::atomic<bool> *is_shutting_down{nullptr}; + std::atomic<TransactionStatus> *transaction_status{nullptr}; bool is_profile_query{false}; std::chrono::duration<double> profile_execution_time; plan::ProfilingStats stats; @@ -82,7 +92,9 @@ static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext mus static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!"); inline bool MustAbort(const ExecutionContext &context) noexcept { - return (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) || + return (context.transaction_status != nullptr && + context.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) || + (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) || context.timer.IsExpired(); } diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index ce2904882..e48b3e525 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -120,9 +120,8 @@ class HintedAbortError : public utils::BasicException { using utils::BasicException::BasicException; HintedAbortError() : utils::BasicException( - "Transaction was asked to abort, most likely because it was " - "executing longer than time specified by " - "--query-execution-timeout-sec flag.") {} + "Transaction was asked to abort, either because it was executing longer than the time specified by the " + "--query-execution-timeout-sec flag or because another user asked it to abort.") {} }; class ExplicitTransactionUsageException : public QueryRuntimeException { @@ -237,4 +236,11 @@ class ReplicationException : public utils::BasicException { : utils::BasicException("Replication Exception: {} Check the status of the replicas using 'SHOW REPLICA' query.", message) {} }; + +class TransactionQueueInMulticommandTxException : public QueryException { + public: + TransactionQueueInMulticommandTxException() + : QueryException("Transaction queue queries are not allowed in multicommand transactions.") {} +}; + } // namespace memgraph::query diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 6dfd4f85c..6f7b0bddd 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -10,6 +10,7 @@ // licenses/APL.txt.
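The `TransactionStatus` flag is what connects termination to query execution: `MustAbort` now also fires when the status reads `TERMINATED`, alongside the existing shutdown and timeout checks, and the plan operators poll `MustAbort(context)` while producing rows. Schematically (an illustration, not code from the patch):

// Illustrative pull loop: a TERMINATE TRANSACTIONS issued from another
// session stores TERMINATED into the shared status; the next MustAbort()
// poll observes it and the running query unwinds with HintedAbortError.
while (cursor->Pull(frame, context)) {
  if (MustAbort(context)) throw HintedAbortError();
  // ... stream the produced row ...
}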
#include "query/frontend/ast/ast.hpp" +#include "query/frontend/ast/ast_visitor.hpp" #include "utils/typeinfo.hpp" namespace memgraph { @@ -259,5 +260,8 @@ constexpr utils::TypeInfo query::Foreach::kType{utils::TypeId::AST_FOREACH, "For constexpr utils::TypeInfo query::ShowConfigQuery::kType{utils::TypeId::AST_SHOW_CONFIG_QUERY, "ShowConfigQuery", &query::Query::kType}; +constexpr utils::TypeInfo query::TransactionQueueQuery::kType{utils::TypeId::AST_TRANSACTION_QUEUE_QUERY, + "TransactionQueueQuery", &query::Query::kType}; + constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType}; } // namespace memgraph diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index e27717a2c..a2b7bd9f4 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -2699,7 +2699,8 @@ class AuthQuery : public memgraph::query::Query { STREAM, MODULE_READ, MODULE_WRITE, - WEBSOCKET + WEBSOCKET, + TRANSACTION_MANAGEMENT }; enum class FineGrainedPrivilege { NOTHING, READ, UPDATE, CREATE_DELETE }; @@ -2752,13 +2753,17 @@ class AuthQuery : public memgraph::query::Query { /// Constant that holds all available privileges. const std::vector<AuthQuery::Privilege> kPrivilegesAll = { - AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, AuthQuery::Privilege::MATCH, - AuthQuery::Privilege::MERGE, AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, - AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, AuthQuery::Privilege::AUTH, - AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION, - AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, AuthQuery::Privilege::FREE_MEMORY, - AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, - AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, AuthQuery::Privilege::WEBSOCKET}; + AuthQuery::Privilege::CREATE, AuthQuery::Privilege::DELETE, + AuthQuery::Privilege::MATCH, AuthQuery::Privilege::MERGE, + AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, + AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, + AuthQuery::Privilege::AUTH, AuthQuery::Privilege::CONSTRAINT, + AuthQuery::Privilege::DUMP, AuthQuery::Privilege::REPLICATION, + AuthQuery::Privilege::READ_FILE, AuthQuery::Privilege::DURABILITY, + AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, + AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, + AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT}; class InfoQuery : public memgraph::query::Query { public: @@ -3203,6 +3208,28 @@ class ShowConfigQuery : public memgraph::query::Query { } }; +class TransactionQueueQuery : public memgraph::query::Query { + public: + static const utils::TypeInfo kType; + const utils::TypeInfo &GetTypeInfo() const override { return kType; } + + enum class Action { SHOW_TRANSACTIONS, TERMINATE_TRANSACTIONS }; + + TransactionQueueQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + + memgraph::query::TransactionQueueQuery::Action action_; + std::vector<Expression *> transaction_id_list_; + + TransactionQueueQuery *Clone(AstStorage *storage) const override { + auto *object = storage->Create<TransactionQueueQuery>(); + object->action_ = action_; + object->transaction_id_list_ = transaction_id_list_; + return object; + } +}; + class Exists : public memgraph::query::Expression { public: static 
const utils::TypeInfo kType; diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 24b460c4f..9bc82e31c 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2284,7 +2284,7 @@ cpp<# (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint dump replication durability read_file free_memory trigger config stream module_read module_write - websocket) + websocket transaction_management) (:serialize)) (lcp:define-enum fine-grained-privilege (nothing read update create_delete) @@ -2333,7 +2333,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = { AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER, AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM, AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE, - AuthQuery::Privilege::WEBSOCKET}; + AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT}; cpp<# (lcp:define-class info-query (query) @@ -2661,6 +2661,26 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class transaction-queue-query (query) + ((action "Action" :scope :public) + (transaction_id_list "std::vector<Expression*>" :scope :public)) + + (:public + (lcp:define-enum action + (show-transactions terminate-transactions) + (:serialize)) + #>cpp + TransactionQueueQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:define-class version-query (query) () (:public #>cpp diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index e860cad1d..81ef7f5c6 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -95,6 +95,7 @@ class SettingQuery; class VersionQuery; class Foreach; class ShowConfigQuery; +class TransactionQueueQuery; class Exists; using TreeCompositeVisitor = utils::CompositeVisitor< @@ -127,9 +128,10 @@ class ExpressionVisitor None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {}; template <class TResult> -class QueryVisitor : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, - InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, - FreeMemoryQuery, TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, - StreamQuery, SettingQuery, VersionQuery, ShowConfigQuery> {}; +class QueryVisitor + : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery, + ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery, + IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, TransactionQueueQuery, + VersionQuery, ShowConfigQuery> {}; } // namespace memgraph::query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index f1c74e554..8425f8fe2 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -11,8 +11,10 @@ #include "query/frontend/ast/cypher_main_visitor.hpp" #include <support/Any.h> +#include <tree/ParseTreeVisitor.h> #include <algorithm> +#include <any> #include <climits> #include <codecvt> #include <cstring> @@ -631,6 +633,7 @@ void GetTopicNames(auto &destination, MemgraphCypher::TopicNamesContext *topic_n destination = std::any_cast<Expression *>(topic_names_ctx->accept(&visitor)); } } + } // namespace antlrcpp::Any 
CypherMainVisitor::visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) { @@ -883,6 +886,34 @@ antlrcpp::Any CypherMainVisitor::visitShowSettings(MemgraphCypher::ShowSettingsC return setting_query; } +antlrcpp::Any CypherMainVisitor::visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "TransactionQueueQuery should have exactly one child!"); + auto *transaction_queue_query = std::any_cast<TransactionQueueQuery *>(ctx->children[0]->accept(this)); + query_ = transaction_queue_query; + return transaction_queue_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowTransactions(MemgraphCypher::ShowTransactionsContext * /*ctx*/) { + auto *transaction_shower = storage_->Create<TransactionQueueQuery>(); + transaction_shower->action_ = TransactionQueueQuery::Action::SHOW_TRANSACTIONS; + return transaction_shower; +} + +antlrcpp::Any CypherMainVisitor::visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) { + auto *terminator = storage_->Create<TransactionQueueQuery>(); + terminator->action_ = TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS; + terminator->transaction_id_list_ = std::any_cast<std::vector<Expression *>>(ctx->transactionIdList()->accept(this)); + return terminator; +} + +antlrcpp::Any CypherMainVisitor::visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) { + std::vector<Expression *> transaction_ids; + for (auto *transaction_id : ctx->transactionId()) { + transaction_ids.push_back(std::any_cast<Expression *>(transaction_id->accept(this))); + } + return transaction_ids; +} + antlrcpp::Any CypherMainVisitor::visitVersionQuery(MemgraphCypher::VersionQueryContext * /*ctx*/) { auto *version_query = storage_->Create<VersionQuery>(); query_ = version_query; @@ -1451,6 +1482,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ; if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE; if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET; + if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT; LOG_FATAL("Should not get here - unknown privilege!"); } diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index e5c5f55f4..aa37b383f 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -358,6 +358,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override; + /** + * @return TransactionQueueQuery* + */ + antlrcpp::Any visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) override; + + /** + * @return ShowTransactions* + */ + antlrcpp::Any visitShowTransactions(MemgraphCypher::ShowTransactionsContext *ctx) override; + + /** + * @return TerminateTransactions* + */ + antlrcpp::Any visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) override; + + /** + * @return TransactionIdList* + */ + antlrcpp::Any visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) override; + /** * @return VersionQuery* */ diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 22b11ab61..c189168c8 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ 
b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -102,6 +102,8 @@ memgraphCypherKeyword : cypherKeyword | USER | USERS | VERSION + | TERMINATE + | TRANSACTIONS ; symbolicName : UnescapedSymbolicName @@ -127,6 +129,7 @@ query : cypherQuery | settingQuery | versionQuery | showConfigQuery + | transactionQueueQuery ; authQuery : createRole @@ -197,6 +200,14 @@ settingQuery : setSetting | showSettings ; +transactionQueueQuery : showTransactions + | terminateTransactions + ; + +showTransactions : SHOW TRANSACTIONS ; + +terminateTransactions : TERMINATE TRANSACTIONS transactionIdList; + loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -259,6 +270,7 @@ privilege : CREATE | MODULE_READ | MODULE_WRITE | WEBSOCKET + | TRANSACTION_MANAGEMENT ; granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ; @@ -402,3 +414,7 @@ showSettings : SHOW DATABASE SETTINGS ; showConfigQuery : SHOW CONFIG ; versionQuery : SHOW VERSION ; + +transactionIdList : transactionId ( ',' transactionId )* ; + +transactionId : literal ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index fbe6f7725..86911785a 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -53,6 +53,7 @@ DIRECTORY : D I R E C T O R Y ; DROP : D R O P ; DUMP : D U M P ; DURABILITY : D U R A B I L I T Y ; +EDGE_TYPES : E D G E UNDERSCORE T Y P E S ; EXECUTE : E X E C U T E ; FOR : F O R ; FOREACH : F O R E A C H; @@ -103,10 +104,13 @@ STOP : S T O P ; STREAM : S T R E A M ; STREAMS : S T R E A M S ; SYNC : S Y N C ; +TERMINATE : T E R M I N A T E ; TIMEOUT : T I M E O U T ; TO : T O ; TOPICS : T O P I C S; TRANSACTION : T R A N S A C T I O N ; +TRANSACTION_MANAGEMENT : T R A N S A C T I O N UNDERSCORE M A N A G E M E N T ; +TRANSACTIONS : T R A N S A C T I O N S ; TRANSFORM : T R A N S F O R M ; TRIGGER : T R I G G E R ; TRIGGERS : T R I G G E R S ; @@ -117,4 +121,3 @@ USER : U S E R ; USERS : U S E R S ; VERSION : V E R S I O N ; WEBSOCKET : W E B S O C K E T ; -EDGE_TYPES : E D G E UNDERSCORE T Y P E S ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 2dfada82d..ffb5b703c 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); } + void Visit(TransactionQueueQuery & /*transaction_queue_query*/) override {} + void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); } bool PreVisit(Create & /*unused*/) override { diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 37672cca7..180cc707d 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
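Taken together, the grammar and lexer additions admit exactly two new statements: `SHOW TRANSACTIONS` and `TERMINATE TRANSACTIONS <literal> ( ',' <literal> )*`, for example `TERMINATE TRANSACTIONS '42', '43'`; the ids are ordinary literals that the interpreter later evaluates and stringifies. Note also that the new `Visit(TransactionQueueQuery &)` override in required_privileges.cpp deliberately requests no static privilege: authorization is decided per transaction at execution time, where a user can always see and terminate their own transactions, and the `TRANSACTION_MANAGEMENT` privilege extends that to everyone else's.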
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,9 +18,12 @@ #include <cstddef> #include <cstdint> #include <functional> +#include <iterator> #include <limits> #include <optional> +#include <thread> #include <unordered_map> +#include <utility> #include <variant> #include "auth/models.hpp" @@ -59,6 +62,7 @@ #include "utils/logging.hpp" #include "utils/memory.hpp" #include "utils/memory_tracker.hpp" +#include "utils/on_scope_exit.hpp" #include "utils/readable_size.hpp" #include "utils/settings.hpp" #include "utils/string.hpp" @@ -975,7 +979,8 @@ struct PullPlanVector { struct PullPlan { explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr, + std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status, + TriggerContextCollector *trigger_context_collector = nullptr, std::optional<size_t> memory_limit = {}); std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n, const std::vector<Symbol> &output_symbols, @@ -1004,8 +1009,8 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional<std::string> username, TriggerContextCollector *trigger_context_collector, - const std::optional<size_t> memory_limit) + std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status, + TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit) : plan_(plan), cursor_(plan->plan().MakeCursor(execution_memory)), frame_(plan->symbol_table().max_position(), execution_memory), @@ -1025,6 +1030,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; } ctx_.is_shutting_down = &interpreter_context->is_shutting_down; + ctx_.transaction_status = transaction_status; ctx_.is_profile_query = is_profile_query; ctx_.trigger_context_collector = trigger_context_collector; } @@ -1137,12 +1143,14 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) if (in_explicit_transaction_) { throw ExplicitTransactionUsageException("Nested transactions are not supported."); } + in_explicit_transaction_ = true; expect_rollback_ = false; db_accessor_ = std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride())); execution_db_accessor_.emplace(db_accessor_.get()); + transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); if (interpreter_context_->trigger_store.HasTriggers()) { trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); @@ -1194,7 +1202,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, std::vector<Notification> *notifications, - const std::string *username, + 
const std::string *username, std::atomic<TransactionStatus> *transaction_status, TriggerContextCollector *trigger_context_collector = nullptr) { auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); @@ -1239,9 +1247,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, header.push_back( utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first); } - auto pull_plan = - std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, - StringPointerToOptional(username), trigger_context_collector, memory_limit); + auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, + execution_memory, StringPointerToOptional(username), transaction_status, + trigger_context_collector, memory_limit); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { @@ -1301,8 +1309,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory, - const std::string *username) { + DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, + std::atomic<TransactionStatus> *transaction_status) { const std::string kProfileQueryStart = "profile "; MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart), @@ -1363,13 +1371,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra // We want to execute the query we are profiling lazily, so we delay // the construction of the corresponding context. stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{}, - pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( + pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status]( AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { // No output symbols are given so that nothing is streamed. 
if (!stats_and_total_time) { - stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context, - execution_memory, optional_username, nullptr, memory_limit) - .Pull(stream, {}, {}, summary); + stats_and_total_time = + PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, + optional_username, transaction_status, nullptr, memory_limit) + .Pull(stream, {}, {}, summary); pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time)); } @@ -1524,7 +1533,8 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) { + DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, + std::atomic<TransactionStatus> *transaction_status) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } @@ -1545,7 +1555,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa 0.0, AstStorage{}, symbol_table)); auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, StringPointerToOptional(username)); + execution_memory, StringPointerToOptional(username), transaction_status); return PreparedQuery{ callback.header, std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), @@ -1558,7 +1568,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa RWType::NONE}; } -PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector<Notification> *notifications, InterpreterContext *interpreter_context, DbAccessor *dba) { if (in_explicit_transaction) { @@ -1586,7 +1596,7 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_ex // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } -PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction, InterpreterContext *interpreter_context, DbAccessor *dba) { if (in_explicit_transaction) { throw LockPathModificationInMulticommandTxException(); @@ -1615,7 +1625,7 @@ PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_expli RWType::NONE}; } -PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, bool in_explicit_transaction, InterpreterContext *interpreter_context) { if (in_explicit_transaction) { throw FreeMemoryModificationInMulticommandTxException(); @@ -1632,7 +1642,7 @@ PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_exp RWType::NONE}; } -PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) { +PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, bool in_explicit_transaction) { if (in_explicit_transaction) { throw ShowConfigModificationInMulticommandTxException(); } @@ -1736,7 +1746,7 @@ Callback ShowTriggers(InterpreterContext 
*interpreter_context) { }}; } -PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector<Notification> *notifications, InterpreterContext *interpreter_context, DbAccessor *dba, const std::map<std::string, storage::PropertyValue> &user_parameters, const std::string *username) { @@ -1786,7 +1796,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } -PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector<Notification> *notifications, InterpreterContext *interpreter_context, DbAccessor *dba, const std::map<std::string, storage::PropertyValue> & /*user_parameters*/, @@ -1828,7 +1838,7 @@ constexpr auto ToStorageIsolationLevel(const IsolationLevelQuery::IsolationLevel } } -PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, +PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, bool in_explicit_transaction, InterpreterContext *interpreter_context, Interpreter *interpreter) { if (in_explicit_transaction) { throw IsolationLevelModificationInMulticommandTxException(); @@ -1883,7 +1893,7 @@ PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_expli RWType::NONE}; } -PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, DbAccessor *dba) { +PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_transaction, DbAccessor *dba) { if (in_explicit_transaction) { throw SettingConfigInMulticommandTxException{}; } @@ -1909,7 +1919,155 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explic // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) } -PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) { +std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::ShowTransactions( + const std::unordered_set<Interpreter *> &interpreters, const std::optional<std::string> &username, + bool hasTransactionManagementPrivilege) { + std::vector<std::vector<TypedValue>> results; + results.reserve(interpreters.size()); + for (Interpreter *interpreter : interpreters) { + TransactionStatus alive_status = TransactionStatus::ACTIVE; + // if it is just checking status, commit and abort should wait for the end of the check + // ignore interpreters that already started committing or rollback + if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) { + continue; + } + utils::OnScopeExit clean_status([interpreter]() { + interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); + }); + std::optional<uint64_t> transaction_id = interpreter->GetTransactionId(); + if (transaction_id.has_value() && (interpreter->username_ == username || hasTransactionManagementPrivilege)) { + const auto &typed_queries = interpreter->GetQueries(); + results.push_back({TypedValue(interpreter->username_.value_or("")), + TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)}); + } + } + return results; +} + +std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::KillTransactions( + InterpreterContext *interpreter_context, const 
std::vector<std::string> &maybe_kill_transaction_ids, + const std::optional<std::string> &username, bool hasTransactionManagementPrivilege) { + std::vector<std::vector<TypedValue>> results; + for (const std::string &transaction_id : maybe_kill_transaction_ids) { + bool killed = false; + bool transaction_found = false; + // Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed + // TERMINATE and SHOW TRANSACTIONS are mutually exclusive + interpreter_context->interpreters.WithLock([&transaction_id, &killed, &transaction_found, username, + hasTransactionManagementPrivilege](const auto &interpreters) { + for (Interpreter *interpreter : interpreters) { + TransactionStatus alive_status = TransactionStatus::ACTIVE; + // if it is just checking kill, commit and abort should wait for the end of the check + // The only way to start checking if the transaction will get killed is if the transaction_status is + // active + if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) { + continue; + } + utils::OnScopeExit clean_status([interpreter, &killed]() { + if (killed) { + interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release); + } else { + interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); + } + }); + + std::optional<uint64_t> intr_trans = interpreter->GetTransactionId(); + if (intr_trans.has_value() && std::to_string(intr_trans.value()) == transaction_id) { + transaction_found = true; + if (interpreter->username_ == username || hasTransactionManagementPrivilege) { + killed = true; + spdlog::warn("Transaction {} successfully killed", transaction_id); + } else { + spdlog::warn("Not enough rights to kill the transaction"); + } + break; + } + } + }); + if (!transaction_found) { + spdlog::warn("Transaction {} not found", transaction_id); + } + results.push_back({TypedValue(transaction_id), TypedValue(killed)}); + } + return results; +} + +Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, + const std::optional<std::string> &username, const Parameters ¶meters, + InterpreterContext *interpreter_context, DbAccessor *db_accessor) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + evaluation_context.timestamp = QueryTimestamp(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + + bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized( + username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}); + + Callback callback; + switch (transaction_query->action_) { + case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: { + callback.header = {"username", "transaction_id", "query"}; + callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, username, + hasTransactionManagementPrivilege]() mutable { + std::vector<std::vector<TypedValue>> results; + // Multiple simultaneous SHOW TRANSACTIONS aren't allowed + interpreter_context->interpreters.WithLock( + [&results, handler, username, hasTransactionManagementPrivilege](const auto &interpreters) { + results = handler.ShowTransactions(interpreters, username, hasTransactionManagementPrivilege); + }); + return results; + }; + break; + } + case TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS: { + std::vector<std::string> maybe_kill_transaction_ids; + 
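`ShowTransactions` and `KillTransactions` both rest on the claim/release idiom visible above: a monitor may only inspect a transaction it has atomically moved from `ACTIVE` to `VERIFYING`, and `utils::OnScopeExit` guarantees the status is restored (to `ACTIVE`, or to `TERMINATED` when a kill succeeded) however the inspection ends. The idiom in isolation, as a sketch assembled from the patch's own types:

// Claim: only an ACTIVE transaction is inspectable; anything that is already
// committing, rolling back, or being verified by another monitor is skipped.
auto expected = TransactionStatus::ACTIVE;
if (!interpreter->transaction_status_.compare_exchange_strong(expected, TransactionStatus::VERIFYING)) {
  continue;  // in the patch, the loops simply move on to the next interpreter
}
// Release: put a stable state back no matter how the enclosing block exits.
utils::OnScopeExit release([&]() {
  interpreter->transaction_status_.store(killed ? TransactionStatus::TERMINATED : TransactionStatus::ACTIVE,
                                         std::memory_order_release);
});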
std::transform(transaction_query->transaction_id_list_.begin(), transaction_query->transaction_id_list_.end(), +                     std::back_inserter(maybe_kill_transaction_ids), [&evaluator](Expression *expression) { +                       return std::string(expression->Accept(evaluator).ValueString()); +                     }); +      callback.header = {"transaction_id", "killed"}; +      callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, maybe_kill_transaction_ids, +                     username, hasTransactionManagementPrivilege]() mutable { +        return handler.KillTransactions(interpreter_context, maybe_kill_transaction_ids, username, +                                        hasTransactionManagementPrivilege); +      }; +      break; +    } +  } + +  return callback; +} + +PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional<std::string> &username, +                                           bool in_explicit_transaction, InterpreterContext *interpreter_context, +                                           DbAccessor *dba) { +  if (in_explicit_transaction) { +    throw TransactionQueueInMulticommandTxException(); +  } + +  auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query); +  MG_ASSERT(transaction_queue_query); +  auto callback = +      HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context, dba); + +  return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), +                       [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( +                           AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { +                         if (UNLIKELY(!pull_plan)) { +                           pull_plan = std::make_shared<PullPlanVector>(callback_fn()); +                         } + +                         if (pull_plan->Pull(stream, n)) { +                           return QueryHandlerResult::COMMIT; +                         } +                         return std::nullopt; +                       }, +                       RWType::NONE}; +} + +PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, bool in_explicit_transaction) { if (in_explicit_transaction) { throw VersionInfoInMulticommandTxException(); } @@ -2263,6 +2421,13 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_ RWType::NONE}; } +std::optional<uint64_t> Interpreter::GetTransactionId() const { +  if (db_accessor_) { +    return db_accessor_->GetTransactionId(); +  } +  return {}; +} + void Interpreter::BeginTransaction() { const auto prepared_query = PrepareTransactionQuery("BEGIN"); prepared_query.query_handler(nullptr, {}); @@ -2272,12 +2437,14 @@ void Interpreter::CommitTransaction() { const auto prepared_query = PrepareTransactionQuery("COMMIT"); prepared_query.query_handler(nullptr, {}); query_executions_.clear(); +  transaction_queries_->clear(); } void Interpreter::RollbackTransaction() { const auto prepared_query = PrepareTransactionQuery("ROLLBACK"); prepared_query.query_handler(nullptr, {}); query_executions_.clear(); +  transaction_queries_->clear(); } Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, @@ -2285,10 +2452,17 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, const std::string *username) { if (!in_explicit_transaction_) { query_executions_.clear(); +    transaction_queries_->clear(); } + +  // Save the username here so that the transaction queue handler can pass it on to SHOW TRANSACTIONS and +  // TERMINATE TRANSACTIONS. +  std::optional<std::string> user = StringPointerToOptional(username); +  username_ = user; + query_executions_.emplace_back(std::make_unique<QueryExecution>()); auto &query_execution = query_executions_.back(); + std::optional<int> qid = in_explicit_transaction_ ?
static_cast<int>(query_executions_.size() - 1) : std::optional<int>{}; @@ -2302,6 +2476,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid}; } + // Don't save BEGIN, COMMIT or ROLLBACK + transaction_queries_->push_back(query_string); + // All queries other than transaction control queries advance the command in // an explicit transaction block. if (in_explicit_transaction_) { @@ -2327,10 +2504,12 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (!in_explicit_transaction_ && (utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) || utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) || - utils::Downcast<TriggerQuery>(parsed_query.query))) { + utils::Downcast<TriggerQuery>(parsed_query.query) || + utils::Downcast<TransactionQueueQuery>(parsed_query.query))) { db_accessor_ = std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride())); execution_db_accessor_.emplace(db_accessor_.get()); + transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release); if (utils::Downcast<CypherQuery>(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) { trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes()); @@ -2343,15 +2522,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (utils::Downcast<CypherQuery>(parsed_query.query)) { prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory, - &query_execution->notifications, username, + &query_execution->notifications, username, &transaction_status_, trigger_context_collector_ ? 
&*trigger_context_collector_ : nullptr); } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory_with_exception); } else if (utils::Downcast<ProfileQuery>(parsed_query.query)) { - prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception, username); + prepared_query = PrepareProfileQuery( + std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_); } else if (utils::Downcast<DumpQuery>(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, &query_execution->execution_memory); @@ -2359,9 +2538,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, interpreter_context_); } else if (utils::Downcast<AuthQuery>(parsed_query.query)) { - prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception, username); + prepared_query = PrepareAuthQuery( + std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_); } else if (utils::Downcast<InfoQuery>(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, interpreter_context_->db, @@ -2398,6 +2577,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else if (utils::Downcast<VersionQuery>(parsed_query.query)) { prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_); + } else if (utils::Downcast<TransactionQueueQuery>(parsed_query.query)) { + prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_, + interpreter_context_, &*execution_db_accessor_); } else { LOG_FATAL("Should not get here -- unknown query type!"); } @@ -2425,7 +2607,29 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } } +std::vector<TypedValue> Interpreter::GetQueries() { + auto typed_queries = std::vector<TypedValue>(); + transaction_queries_.WithLock([&typed_queries](const auto &transaction_queries) { + std::for_each(transaction_queries.begin(), transaction_queries.end(), + [&typed_queries](const auto &query) { typed_queries.emplace_back(query); }); + }); + return typed_queries; +} + void Interpreter::Abort() { + auto expected = TransactionStatus::ACTIVE; + while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) { + if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) { + transaction_status_.store(TransactionStatus::STARTED_ROLLBACK); + break; 
+ } + expected = TransactionStatus::ACTIVE; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + utils::OnScopeExit clean_status( + [this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); }); + expect_rollback_ = false; in_explicit_transaction_ = false; if (!db_accessor_) return; @@ -2437,7 +2641,7 @@ void Interpreter::Abort() { namespace { void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context, - TriggerContext trigger_context) { + TriggerContext trigger_context, std::atomic<TransactionStatus> *transaction_status) { // Run the triggers for (const auto &trigger : triggers.access()) { utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; @@ -2449,7 +2653,8 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret trigger_context.AdaptForAccessor(&db_accessor); try { trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec, - &interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker); + &interpreter_context->is_shutting_down, transaction_status, trigger_context, + interpreter_context->auth_checker); } catch (const utils::BasicException &exception) { spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what()); db_accessor.Abort(); @@ -2504,6 +2709,25 @@ void Interpreter::Commit() { // a query. if (!db_accessor_) return; + /* + At this point we must check that the transaction is alive to start committing. The only other possible state is + verifying and in that case we must check if the transaction was terminated and if yes abort committing. Exception + should suffice. + */ + auto expected = TransactionStatus::ACTIVE; + while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_COMMITTING)) { + if (expected == TransactionStatus::TERMINATED) { + throw memgraph::utils::BasicException( + "Aborting transaction commit because the transaction was requested to stop from other session. "); + } + expected = TransactionStatus::ACTIVE; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + // Clean transaction status if something went wrong + utils::OnScopeExit clean_status( + [this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); }); + std::optional<TriggerContext> trigger_context = std::nullopt; if (trigger_context_collector_) { trigger_context.emplace(std::move(*trigger_context_collector_).TransformToTriggerContext()); @@ -2517,7 +2741,8 @@ void Interpreter::Commit() { AdvanceCommand(); try { trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec, - &interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker); + &interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context, + interpreter_context_->auth_checker); } catch (const utils::BasicException &e) { throw utils::BasicException( fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what())); @@ -2579,10 +2804,10 @@ void Interpreter::Commit() { // This means the ordered execution of after commit triggers are not guaranteed. 
if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) { interpreter_context_->after_commit_trigger_pool.AddTask( - [trigger_context = std::move(*trigger_context), interpreter_context = this->interpreter_context_, + [this, trigger_context = std::move(*trigger_context), user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { - RunTriggersIndividually(interpreter_context->trigger_store.AfterCommitTriggers(), interpreter_context, - std::move(trigger_context)); + RunTriggersIndividually(this->interpreter_context_->trigger_store.AfterCommitTriggers(), + this->interpreter_context_, std::move(trigger_context), &this->transaction_status_); user_transaction->FinalizeTransaction(); SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) }); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index c74304773..132cdc790 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,6 +11,8 @@ #pragma once +#include <unordered_set> + #include <gflags/gflags.h> #include "query/auth_checker.hpp" @@ -37,6 +39,7 @@ #include "utils/settings.hpp" #include "utils/skip_list.hpp" #include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" #include "utils/thread_pool.hpp" #include "utils/timer.hpp" #include "utils/tsc.hpp" @@ -179,12 +182,12 @@ struct PreparedQuery { plan::ReadWriteTypeChecker::RWType rw_type; }; +class Interpreter; + /** * Holds data shared between multiple `Interpreter` instances (which might be * running concurrently). * - * Users should initialize the context but should not modify it after it has - * been passed to an `Interpreter` instance. */ struct InterpreterContext { explicit InterpreterContext(storage::Storage *db, InterpreterConfig config, @@ -214,6 +217,7 @@ struct InterpreterContext { const InterpreterConfig config; query::stream::Streams streams; + utils::Synchronized<std::unordered_set<Interpreter *>, utils::SpinLock> interpreters; }; /// Function that is used to tell all active interpreters that they should stop @@ -235,6 +239,10 @@ class Interpreter final { std::optional<int> qid; }; + std::optional<std::string> username_; + bool in_explicit_transaction_{false}; + bool expect_rollback_{false}; + /** * Prepare a query for execution. * @@ -290,6 +298,11 @@ class Interpreter final { void BeginTransaction(); + /* + Returns transaction id or empty if the db_accessor is not initialized. + */ + std::optional<uint64_t> GetTransactionId() const; + void CommitTransaction(); void RollbackTransaction(); @@ -297,11 +310,15 @@ class Interpreter final { void SetNextTransactionIsolationLevel(storage::IsolationLevel isolation_level); void SetSessionIsolationLevel(storage::IsolationLevel isolation_level); + std::vector<TypedValue> GetQueries(); + /** * Abort the current multicommand transaction. */ void Abort(); + std::atomic<TransactionStatus> transaction_status_{TransactionStatus::IDLE}; + private: struct QueryExecution { std::optional<PreparedQuery> prepared_query; @@ -338,6 +355,8 @@ class Interpreter final { // and deletion of a single query execution, i.e. when a query finishes, // we reset the corresponding unique_ptr. 
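// The Synchronized members this patch adds (transaction_queries_ just below, and the interpreters set in
// InterpreterContext above) expose their data only while the lock is held; a brief sketch of that access
// pattern, assuming utils::Synchronized behaves as its uses in this patch suggest:
//
//   utils::Synchronized<std::vector<std::string>, utils::SpinLock> queries;
//   queries.WithLock([](auto &qs) { qs.emplace_back("MATCH (n) RETURN n;"); });  // mutate under the lock
//   queries->clear();  // operator-> also takes the lock for the duration of the single call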
std::vector<std::unique_ptr<QueryExecution>> query_executions_; + // all queries that are run as part of the current transaction + utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_; InterpreterContext *interpreter_context_; @@ -347,8 +366,6 @@ class Interpreter final { std::unique_ptr<storage::Storage::Accessor> db_accessor_; std::optional<DbAccessor> execution_db_accessor_; std::optional<TriggerContextCollector> trigger_context_collector_; - bool in_explicit_transaction_{false}; - bool expect_rollback_{false}; std::optional<storage::IsolationLevel> interpreter_isolation_level; std::optional<storage::IsolationLevel> next_transaction_isolation_level; @@ -365,12 +382,32 @@ class Interpreter final { } }; +class TransactionQueueQueryHandler { + public: + TransactionQueueQueryHandler() = default; + virtual ~TransactionQueueQueryHandler() = default; + + TransactionQueueQueryHandler(const TransactionQueueQueryHandler &) = default; + TransactionQueueQueryHandler &operator=(const TransactionQueueQueryHandler &) = default; + + TransactionQueueQueryHandler(TransactionQueueQueryHandler &&) = default; + TransactionQueueQueryHandler &operator=(TransactionQueueQueryHandler &&) = default; + + static std::vector<std::vector<TypedValue>> ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, + const std::optional<std::string> &username, + bool hasTransactionManagementPrivilege); + + static std::vector<std::vector<TypedValue>> KillTransactions( + InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids, + const std::optional<std::string> &username, bool hasTransactionManagementPrivilege); +}; + template <typename TStream> std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n, std::optional<int> qid) { MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!"); - const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1); + const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1); if (qid_value < 0 || qid_value >= query_executions_.size()) { throw InvalidArgumentsException("qid", "Query with specified ID does not exist!"); } @@ -430,6 +467,7 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std: // methods as we will delete summary contained in them which we need // after our query finished executing. query_executions_.clear(); + transaction_queries_->clear(); } else { // We can only clear this execution as some of the queries // in the transaction can be in unfinished state diff --git a/src/query/stream/streams.cpp b/src/query/stream/streams.cpp index 5f7564b0d..e5b4241ae 100644 --- a/src/query/stream/streams.cpp +++ b/src/query/stream/streams.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -490,6 +490,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std retry_interval = interpreter_context_->config.stream_transaction_retry_interval]( const std::vector<typename TStream::Message> &messages) mutable { auto accessor = interpreter_context->db->Access(); + // register new interpreter into interpreter_context_ + interpreter_context->interpreters->insert(interpreter.get()); + utils::OnScopeExit interpreter_cleanup{ + [interpreter_context, interpreter]() { interpreter_context->interpreters->erase(interpreter.get()); }}; + EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size()); CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name); diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index ee8e84e3f..1c1d14a64 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -195,7 +195,8 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor, void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, const double max_execution_time_sec, std::atomic<bool> *is_shutting_down, - const TriggerContext &context, const AuthChecker *auth_checker) const { + std::atomic<TransactionStatus> *transaction_status, const TriggerContext &context, + const AuthChecker *auth_checker) const { if (!context.ShouldEventTrigger(event_type_)) { return; } @@ -214,6 +215,7 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba); ctx.timer = utils::AsyncTimer(max_execution_time_sec); ctx.is_shutting_down = is_shutting_down; + ctx.transaction_status = transaction_status; ctx.is_profile_query = false; // Set up temporary memory for a single Pull. Initial memory comes from the diff --git a/src/query/trigger.hpp b/src/query/trigger.hpp index a1b6b7012..499bf634c 100644 --- a/src/query/trigger.hpp +++ b/src/query/trigger.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -31,6 +31,8 @@ #include "utils/spin_lock.hpp" namespace memgraph::query { + +enum class TransactionStatus; struct Trigger { explicit Trigger(std::string name, const std::string &query, const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type, @@ -39,8 +41,8 @@ struct Trigger { const query::AuthChecker *auth_checker); void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec, - std::atomic<bool> *is_shutting_down, const TriggerContext &context, - const AuthChecker *auth_checker) const; + std::atomic<bool> *is_shutting_down, std::atomic<TransactionStatus> *transaction_status, + const TriggerContext &context, const AuthChecker *auth_checker) const; bool operator==(const Trigger &other) const { return name_ == other.name_; } // NOLINTNEXTLINE (modernize-use-nullptr) diff --git a/src/storage/v2/constraints.cpp b/src/storage/v2/constraints.cpp index fab6ee4c4..12243f173 100644 --- a/src/storage/v2/constraints.cpp +++ b/src/storage/v2/constraints.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,6 +12,7 @@ #include "storage/v2/constraints.hpp" #include <algorithm> +#include <atomic> #include <cstring> #include <map> @@ -71,7 +72,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c while (delta != nullptr) { auto ts = delta->timestamp->load(std::memory_order_acquire); - if (ts < commit_timestamp || ts == transaction.transaction_id) { + if (ts < commit_timestamp || ts == transaction.transaction_id.load(std::memory_order_acquire)) { break; } diff --git a/src/storage/v2/mvcc.hpp b/src/storage/v2/mvcc.hpp index 52154fe4a..4c0e55461 100644 --- a/src/storage/v2/mvcc.hpp +++ b/src/storage/v2/mvcc.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,6 +11,7 @@ #pragma once +#include <atomic> #include "storage/v2/property_value.hpp" #include "storage/v2/transaction.hpp" #include "storage/v2/view.hpp" @@ -30,7 +31,7 @@ inline void ApplyDeltasForRead(Transaction *transaction, const Delta *delta, Vie // This allows the transaction to see its changes even though it's committed. const auto commit_timestamp = transaction->commit_timestamp ? 
transaction->commit_timestamp->load(std::memory_order_acquire) - : transaction->transaction_id; + : transaction->transaction_id.load(std::memory_order_acquire); while (delta != nullptr) { auto ts = delta->timestamp->load(std::memory_order_acquire); auto cid = delta->command_id; @@ -80,7 +81,7 @@ inline bool PrepareForWrite(Transaction *transaction, TObj *object) { if (object->delta == nullptr) return true; auto ts = object->delta->timestamp->load(std::memory_order_acquire); - if (ts == transaction->transaction_id || ts < transaction->start_timestamp) { + if (ts == transaction->transaction_id.load(std::memory_order_acquire) || ts < transaction->start_timestamp) { return true; } diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index b51607b9e..5d418b805 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -985,8 +985,8 @@ void Storage::Accessor::Abort() { auto vertex = prev.vertex; std::lock_guard<utils::SpinLock> guard(vertex->lock); Delta *current = vertex->delta; - while (current != nullptr && - current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { + while (current != nullptr && current->timestamp->load(std::memory_order_acquire) == + transaction_.transaction_id.load(std::memory_order_acquire)) { switch (current->action) { case Delta::Action::REMOVE_LABEL: { auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label); @@ -1072,8 +1072,8 @@ void Storage::Accessor::Abort() { auto edge = prev.edge; std::lock_guard<utils::SpinLock> guard(edge->lock); Delta *current = edge->delta; - while (current != nullptr && - current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) { + while (current != nullptr && current->timestamp->load(std::memory_order_acquire) == + transaction_.transaction_id.load(std::memory_order_acquire)) { switch (current->action) { case Delta::Action::SET_PROPERTY: { edge->properties.SetProperty(current->property.key, current->property.value); @@ -1144,6 +1144,13 @@ void Storage::Accessor::FinalizeTransaction() { } } +std::optional<uint64_t> Storage::Accessor::GetTransactionId() const { + if (is_transaction_active_) { + return transaction_.transaction_id.load(std::memory_order_acquire); + } + return {}; +} + const std::string &Storage::LabelToName(LabelId label) const { return name_id_mapper_.IdToName(label.AsUint()); } const std::string &Storage::PropertyToName(PropertyId property) const { diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 0a6bf5c68..407b0c090 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,6 +12,7 @@ #pragma once #include <atomic> +#include <cstdint> #include <filesystem> #include <optional> #include <shared_mutex> @@ -324,6 +325,8 @@ class Storage final { void FinalizeTransaction(); + std::optional<uint64_t> GetTransactionId() const; + private: /// @throw std::bad_alloc VertexAccessor CreateVertex(storage::Gid gid); diff --git a/src/storage/v2/transaction.hpp b/src/storage/v2/transaction.hpp index d7a770598..348c3e605 100644 --- a/src/storage/v2/transaction.hpp +++ b/src/storage/v2/transaction.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -39,7 +39,7 @@ struct Transaction { isolation_level(isolation_level) {} Transaction(Transaction &&other) noexcept - : transaction_id(other.transaction_id), + : transaction_id(other.transaction_id.load(std::memory_order_acquire)), start_timestamp(other.start_timestamp), commit_timestamp(std::move(other.commit_timestamp)), command_id(other.command_id), @@ -56,10 +56,10 @@ struct Transaction { /// @throw std::bad_alloc if failed to create the `commit_timestamp` void EnsureCommitTimestampExists() { if (commit_timestamp != nullptr) return; - commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id); + commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id.load(std::memory_order_relaxed)); } - uint64_t transaction_id; + std::atomic<uint64_t> transaction_id; uint64_t start_timestamp; // The `Transaction` object is stack allocated, but the `commit_timestamp` // must be heap allocated because `Delta`s have a pointer to it, and that @@ -73,12 +73,16 @@ struct Transaction { }; inline bool operator==(const Transaction &first, const Transaction &second) { - return first.transaction_id == second.transaction_id; + return first.transaction_id.load(std::memory_order_acquire) == second.transaction_id.load(std::memory_order_acquire); } inline bool operator<(const Transaction &first, const Transaction &second) { - return first.transaction_id < second.transaction_id; + return first.transaction_id.load(std::memory_order_acquire) < second.transaction_id.load(std::memory_order_acquire); +} +inline bool operator==(const Transaction &first, const uint64_t &second) { + return first.transaction_id.load(std::memory_order_acquire) == second; +} +inline bool operator<(const Transaction &first, const uint64_t &second) { + return first.transaction_id.load(std::memory_order_acquire) < second; } -inline bool operator==(const Transaction &first, const uint64_t &second) { return first.transaction_id == second; } -inline bool operator<(const Transaction &first, const uint64_t &second) { return first.transaction_id < second; } } // namespace memgraph::storage diff --git a/src/utils/typeinfo.hpp b/src/utils/typeinfo.hpp index ca0cbe39d..bad53f9c6 100644 --- a/src/utils/typeinfo.hpp +++ b/src/utils/typeinfo.hpp @@ -176,8 +176,8 @@ enum class TypeId : uint64_t { AST_VERSION_QUERY, AST_FOREACH, AST_SHOW_CONFIG_QUERY, + AST_TRANSACTION_QUEUE_QUERY, AST_EXISTS, - // Symbol SYMBOL, }; diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt index 1d2eee643..bb7b6839e 100644 --- a/tests/e2e/CMakeLists.txt +++ 
b/tests/e2e/CMakeLists.txt @@ -44,6 +44,7 @@ add_subdirectory(module_file_manager) add_subdirectory(monitoring_server) add_subdirectory(lba_procedures) add_subdirectory(python_query_modules_reloading) +add_subdirectory(transaction_queue) add_subdirectory(mock_api) copy_e2e_python_files(pytest_runner pytest_runner.sh "") diff --git a/tests/e2e/lba_procedures/show_privileges.py b/tests/e2e/lba_procedures/show_privileges.py index 6e3c85572..29f834896 100644 --- a/tests/e2e/lba_procedures/show_privileges.py +++ b/tests/e2e/lba_procedures/show_privileges.py @@ -10,6 +10,7 @@ # licenses/APL.txt. import sys + import pytest from common import connect, execute_and_fetch_all @@ -35,6 +36,7 @@ BASIC_PRIVILEGES = [ "MODULE_READ", "WEBSOCKET", "MODULE_WRITE", + "TRANSACTION_MANAGEMENT", ] @@ -58,7 +60,7 @@ def test_lba_procedures_show_privileges_first_user(): cursor = connect(username="Josip", password="").cursor() result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;") - assert len(result) == 30 + assert len(result) == 31 fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES] diff --git a/tests/e2e/transaction_queue/CMakeLists.txt b/tests/e2e/transaction_queue/CMakeLists.txt new file mode 100644 index 000000000..574c46bfd --- /dev/null +++ b/tests/e2e/transaction_queue/CMakeLists.txt @@ -0,0 +1,8 @@ +function(copy_transaction_queue_e2e_python_files FILE_NAME) + copy_e2e_python_files(transaction_queue ${FILE_NAME}) +endfunction() + +copy_transaction_queue_e2e_python_files(common.py) +copy_transaction_queue_e2e_python_files(test_transaction_queue.py) + +add_subdirectory(procedures) diff --git a/tests/e2e/transaction_queue/common.py b/tests/e2e/transaction_queue/common.py new file mode 100644 index 000000000..fdab67a72 --- /dev/null +++ b/tests/e2e/transaction_queue/common.py @@ -0,0 +1,26 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import typing + +import mgclient +import pytest + + +def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]: + cursor.execute(query, params) + return cursor.fetchall() + + +def connect(**kwargs) -> mgclient.Connection: + connection = mgclient.connect(host="localhost", port=7687, **kwargs) + connection.autocommit = True + return connection diff --git a/tests/e2e/transaction_queue/procedures/CMakeLists.txt b/tests/e2e/transaction_queue/procedures/CMakeLists.txt new file mode 100644 index 000000000..0a22ee796 --- /dev/null +++ b/tests/e2e/transaction_queue/procedures/CMakeLists.txt @@ -0,0 +1 @@ +copy_e2e_python_files(transaction_queue infinite_query.py) diff --git a/tests/e2e/transaction_queue/procedures/infinite_query.py b/tests/e2e/transaction_queue/procedures/infinite_query.py new file mode 100644 index 000000000..33d5f4d40 --- /dev/null +++ b/tests/e2e/transaction_queue/procedures/infinite_query.py @@ -0,0 +1,25 @@ +# Copyright 2021 Memgraph Ltd.
+# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + +import mgp + + +@mgp.read_proc +def long_query(ctx: mgp.ProcCtx) -> mgp.Record(my_id=int): + id = 1 + try: + while True: + if ctx.check_must_abort(): + break + id += 1 + except mgp.AbortError: + return mgp.Record(my_id=id) + return mgp.Record(my_id=id) diff --git a/tests/e2e/transaction_queue/test_transaction_queue.py b/tests/e2e/transaction_queue/test_transaction_queue.py new file mode 100644 index 000000000..ba9eac19a --- /dev/null +++ b/tests/e2e/transaction_queue/test_transaction_queue.py @@ -0,0 +1,337 @@ +# Copyright 2022 Memgraph Ltd. +# +# Use of this software is governed by the Business Source License +# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +# License, and you may not use this file except in compliance with the Business Source License. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0, included in the file +# licenses/APL.txt. + + +import multiprocessing +import sys +import threading +import time +from typing import List + +import mgclient +import pytest +from common import connect, execute_and_fetch_all + +# Utility functions +# ------------------------- + + +def get_non_show_transaction_id(results): + """Returns transaction id of the first transaction that is not SHOW TRANSACTIONS;""" + for res in results: + if res[2] != ["SHOW TRANSACTIONS"]: + return res[1] + + +def show_transactions_test(cursor, expected_num_results: int): + results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS") + assert len(results) == expected_num_results + return results + + +def process_function(cursor, queries: List[str]): + try: + for query in queries: + cursor.execute(query, {}) + except mgclient.DatabaseError: + pass + + +# Tests +# ------------------------- + + +def test_self_transaction(): + """Tests that a simple SHOW TRANSACTIONS works when no other transaction is running.""" + cursor = connect().cursor() + results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS") + assert len(results) == 1 + + +def test_admin_has_one_transaction(): + """Creates an admin and tests that it sees only one transaction.""" + # superadmin_cursor is used for creating the admin user and simulates the main thread + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + admin_cursor = connect(username="admin", password="").cursor() + process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1)) + process.start() + process.join() + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + + +def test_user_can_see_its_transaction(): + """Tests that a user without privileges can see its own transaction.""" + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin") + 
execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user") + user_cursor = connect(username="user", password="").cursor() + process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1)) + process.start() + process.join() + admin_cursor = connect(username="admin", password="").cursor() + execute_and_fetch_all(admin_cursor, "DROP USER user") + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + +def test_explicit_transaction_output(): + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + admin_connection = connect(username="admin", password="") + admin_cursor = admin_connection.cursor() + # Superadmin starts running an explicit transaction + process = multiprocessing.Process( + target=process_function, + args=(superadmin_cursor, ["BEGIN", "CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]), + ) + process.start() + time.sleep(0.5) + show_results = show_transactions_test(admin_cursor, 2) + if show_results[0][2] == ["SHOW TRANSACTIONS"]: + executing_index = 0 + else: + executing_index = 1 + assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"] + + execute_and_fetch_all(superadmin_cursor, "ROLLBACK") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + + +def test_superadmin_cannot_see_admin_can_see_admin(): + """Tests that the superadmin cannot see the transaction created by an admin, but two admins can see and kill each other's transactions.""" + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1") + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") + # Admin starts running infinite query + admin_connection_1 = connect(username="admin1", password="") + admin_cursor_1 = admin_connection_1.cursor() + admin_connection_2 = connect(username="admin2", password="") + admin_cursor_2 = admin_connection_2.cursor() + process = multiprocessing.Process( + target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) + ) + process.start() + time.sleep(0.5) + # Superadmin shouldn't see the execution of the admin + show_transactions_test(superadmin_cursor, 1) + show_results = show_transactions_test(admin_cursor_2, 2) + # Don't rely on the order of interpreters in Memgraph + if show_results[0][2] == ["SHOW TRANSACTIONS"]: + executing_index = 0 + else: + executing_index = 1 + assert show_results[executing_index][0] == "admin2" + assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"] + assert show_results[1 - executing_index][0] == "admin1" + assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"] + # Kill transaction + long_transaction_id = show_results[1 - executing_index][1] + execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + admin_connection_1.close() + admin_connection_2.close() + + +def test_admin_sees_superadmin(): + """Tests that an admin created by the superadmin can see the superadmin's transaction.""" + 
superadmin_connection = connect() + superadmin_cursor = superadmin_connection.cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + # Superadmin starts running an infinite query + process = multiprocessing.Process( + target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) + ) + process.start() + time.sleep(0.5) + admin_cursor = connect(username="admin", password="").cursor() + show_results = show_transactions_test(admin_cursor, 2) + # Don't rely on the order of interpreters in Memgraph + if show_results[0][2] == ["SHOW TRANSACTIONS"]: + executing_index = 0 + else: + executing_index = 1 + assert show_results[executing_index][0] == "admin" + assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"] + assert show_results[1 - executing_index][0] == "" + assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"] + # Kill transaction + long_transaction_id = show_results[1 - executing_index][1] + execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") + execute_and_fetch_all(admin_cursor, "DROP USER admin") + superadmin_connection.close() + + +def test_admin_can_see_user_transaction(): + """Tests that an admin can see a user's transaction and kill it.""" + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + # User starts running an infinite query + admin_connection = connect(username="admin", password="") + admin_cursor = admin_connection.cursor() + user_connection = connect(username="user", password="") + user_cursor = user_connection.cursor() + process = multiprocessing.Process( + target=process_function, args=(user_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) + ) + process.start() + time.sleep(0.5) + # Admin should see the user's transaction. 
+ show_results = show_transactions_test(admin_cursor, 2) + # Don't rely on the order of interpreters in Memgraph + if show_results[0][2] == ["SHOW TRANSACTIONS"]: + executing_index = 0 + else: + executing_index = 1 + assert show_results[executing_index][0] == "admin" + assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"] + assert show_results[1 - executing_index][0] == "user" + assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"] + # Kill transaction + long_transaction_id = show_results[1 - executing_index][1] + execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + admin_connection.close() + user_connection.close() + + +def test_user_cannot_see_admin_transaction(): + """Tests that a user cannot see an admin's transaction, but another admin can see and kill it.""" + # Superadmin creates two admins and one user + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1") + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2") + execute_and_fetch_all(superadmin_cursor, "CREATE USER user") + admin_connection_1 = connect(username="admin1", password="") + admin_cursor_1 = admin_connection_1.cursor() + admin_connection_2 = connect(username="admin2", password="") + admin_cursor_2 = admin_connection_2.cursor() + user_connection = connect(username="user", password="") + user_cursor = user_connection.cursor() + # Admin1 starts running a long-running query + process = multiprocessing.Process( + target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) + ) + process.start() + time.sleep(0.5) + # User should not see the admin's transaction. 
+ show_transactions_test(user_cursor, 1) + # Second admin should see other admin's transactions + show_results = show_transactions_test(admin_cursor_2, 2) + # Don't rely on the order of interpreters in Memgraph + if show_results[0][2] == ["SHOW TRANSACTIONS"]: + executing_index = 0 + else: + executing_index = 1 + assert show_results[executing_index][0] == "admin2" + assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"] + assert show_results[1 - executing_index][0] == "admin1" + assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"] + # Kill transaction + long_transaction_id = show_results[1 - executing_index][1] + execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin1") + execute_and_fetch_all(superadmin_cursor, "DROP USER admin2") + execute_and_fetch_all(superadmin_cursor, "DROP USER user") + admin_connection_1.close() + admin_connection_2.close() + user_connection.close() + + +def test_killing_non_existing_transaction(): + cursor = connect().cursor() + results = execute_and_fetch_all(cursor, "TERMINATE TRANSACTIONS '1'") + assert len(results) == 1 + assert results[0][0] == "1" # transaction id + assert results[0][1] == False # not killed + + +def test_killing_multiple_non_existing_transactions(): + cursor = connect().cursor() + transactions_id = ["'1'", "'2'", "'3'"] + results = execute_and_fetch_all(cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}") + assert len(results) == 3 + for i in range(len(results)): + assert results[i][0] == eval(transactions_id[i]) # transaction id + assert results[i][1] == False # not killed + + +def test_admin_killing_multiple_non_existing_transactions(): + # Superadmin creates an admin + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin") + # Connect with admin + admin_cursor = connect(username="admin", password="").cursor() + transactions_id = ["'1'", "'2'", "'3'"] + results = execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}") + assert len(results) == 3 + for i in range(len(results)): + assert results[i][0] == eval(transactions_id[i]) # transaction id + assert results[i][1] == False # not killed + execute_and_fetch_all(admin_cursor, "DROP USER admin") + + +def test_user_killing_some_transactions(): + """Tests what happens when a user can kill only some of the given transactions.""" + superadmin_cursor = connect().cursor() + execute_and_fetch_all(superadmin_cursor, "CREATE USER admin") + execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin") + execute_and_fetch_all(superadmin_cursor, "CREATE USER user1") + execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1") + + # Admin creates a second user; connect as both users in separate sessions + admin_cursor = connect(username="admin", password="").cursor() + execute_and_fetch_all(admin_cursor, "CREATE USER user2") + execute_and_fetch_all(admin_cursor, "GRANT ALL PRIVILEGES TO user2") + user_connection_1 = connect(username="user1", password="") + user_cursor_1 = user_connection_1.cursor() + user_connection_2 = connect(username="user2", password="") + user_cursor_2 = user_connection_2.cursor() + process_1 = multiprocessing.Process( + target=process_function, args=(user_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]) + ) + process_2 = 
multiprocessing.Process(target=process_function, args=(user_cursor_2, ["BEGIN", "MATCH (n) RETURN n"])) + process_1.start() + process_2.start() + # Create another user1 connection + user_connection_1_copy = connect(username="user1", password="") + user_cursor_1_copy = user_connection_1_copy.cursor() + show_user_1_results = show_transactions_test(user_cursor_1_copy, 2) + if show_user_1_results[0][2] == ["SHOW TRANSACTIONS"]: + execution_index = 0 + else: + execution_index = 1 + assert show_user_1_results[1 - execution_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"] + # Admin should see both user transactions + time.sleep(0.5) + show_admin_results = show_transactions_test(admin_cursor, 3) + for show_admin_res in show_admin_results: + if show_admin_res[2] != ["SHOW TRANSACTIONS"]: + execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{show_admin_res[1]}'") + user_connection_1.close() + user_connection_2.close() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-rA"])) diff --git a/tests/e2e/transaction_queue/workloads.yaml b/tests/e2e/transaction_queue/workloads.yaml new file mode 100644 index 000000000..b5f15facf --- /dev/null +++ b/tests/e2e/transaction_queue/workloads.yaml @@ -0,0 +1,14 @@ +test_transaction_queue: &test_transaction_queue + cluster: + main: + args: ["--bolt-port", "7687", "--log-level=TRACE", "--also-log-to-stderr"] + log_file: "transaction_queue.log" + setup_queries: [] + validation_queries: [] + +workloads: + - name: "test-transaction-queue" # should be the same as the python file + binary: "tests/e2e/pytest_runner.sh" + proc: "tests/e2e/transaction_queue/procedures/" + args: ["transaction_queue/test_transaction_queue.py"] + <<: *test_transaction_queue diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index dfcbc9697..9fcb275ad 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -130,6 +130,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query) add_unit_test(query_streams.cpp) target_link_libraries(${test_prefix}query_streams mg-query kafka-mock) +add_unit_test(transaction_queue.cpp) +target_link_libraries(${test_prefix}transaction_queue mg-communication mg-query mg-glue) + +add_unit_test(transaction_queue_multiple.cpp) +target_link_libraries(${test_prefix}transaction_queue_multiple mg-communication mg-query mg-glue) + # Test query functions add_unit_test(query_function_mgp_module.cpp) target_link_libraries(${test_prefix}query_function_mgp_module mg-query) diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp index 4097a8cab..d7ad623d6 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/interpreter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -18,6 +18,7 @@ #include "glue/communication.hpp" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "interpreter_faker.hpp" #include "query/auth_checker.hpp" #include "query/config.hpp" #include "query/exceptions.hpp" @@ -40,57 +41,18 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) { return list; }; -struct InterpreterFaker { - InterpreterFaker(memgraph::storage::Storage *db, const memgraph::query::InterpreterConfig config, - const std::filesystem::path &data_directory) - : interpreter_context(db, config, data_directory), interpreter(&interpreter_context) { - interpreter_context.auth_checker = &auth_checker; - } - - auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { - ResultStreamFaker stream(interpreter_context.db); - - const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); - stream.Header(header); - return std::make_pair(std::move(stream), qid); - } - - void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) { - const auto summary = interpreter.Pull(stream, n, qid); - stream->Summary(summary); - } - - /** - * Execute the given query and commit the transaction. - * - * Return the query stream. - */ - auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { - auto prepare_result = Prepare(query, params); - - auto &stream = prepare_result.first; - auto summary = interpreter.Pull(&stream, {}, prepare_result.second); - stream.Summary(summary); - - return std::move(stream); - } - - memgraph::query::AllowEverythingAuthChecker auth_checker; - memgraph::query::InterpreterContext interpreter_context; - memgraph::query::Interpreter interpreter; -}; - } // namespace // TODO: This is not a unit test, but tests/integration dir is chaotic at the // moment. After tests refactoring is done, move/rename this. 
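// One plausible shape for TransactionQueueQueryHandler::ShowTransactions declared in interpreter.hpp above;
// a hedged sketch rather than the literal implementation, using only members this patch introduces
// (GetTransactionId(), GetQueries(), username_). The row layout (username, transaction id, query list)
// matches what the unit and e2e tests below assert:
//
//   std::vector<std::vector<TypedValue>> ShowTransactions(const std::unordered_set<Interpreter *> &interpreters,
//                                                         const std::optional<std::string> &username,
//                                                         bool hasTransactionManagementPrivilege) {
//     std::vector<std::vector<TypedValue>> results;
//     for (Interpreter *interpreter : interpreters) {
//       // Plain users only see their own transactions; TRANSACTION_MANAGEMENT sees everyone's.
//       if (!hasTransactionManagementPrivilege && interpreter->username_ != username) continue;
//       if (auto transaction_id = interpreter->GetTransactionId()) {
//         results.push_back({TypedValue(interpreter->username_.value_or("")),
//                            TypedValue(std::to_string(*transaction_id)),
//                            TypedValue(interpreter->GetQueries())});
//       }
//     }
//     return results;
//   }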
class InterpreterTest : public ::testing::Test { - protected: + public: memgraph::storage::Storage db_; std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"}; + memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory}; - InterpreterFaker default_interpreter{&db_, {}, data_directory}; + InterpreterFaker default_interpreter{&interpreter_context}; auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { return default_interpreter.Prepare(query, params); @@ -638,8 +600,6 @@ TEST_F(InterpreterTest, UniqueConstraintTest) { } TEST_F(InterpreterTest, ExplainQuery) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;"); @@ -663,8 +623,6 @@ } TEST_F(InterpreterTest, ExplainQueryMultiplePulls) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;"); @@ -698,8 +656,6 @@ } TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); Interpret("BEGIN"); @@ -725,8 +681,6 @@ } TEST_F(InterpreterTest, ExplainQueryWithParams) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = @@ -751,8 +705,6 @@ } TEST_F(InterpreterTest, ProfileQuery) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = Interpret("PROFILE MATCH (n) RETURN *;"); @@ -776,8 +728,6 @@ } TEST_F(InterpreterTest, ProfileQueryMultiplePulls) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;"); @@ -820,8 +770,6 @@ } TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) { } TEST_F(InterpreterTest, ProfileQueryWithParams) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = @@ -846,8 +794,6 @@ } TEST_F(InterpreterTest, ProfileQueryWithLiterals) { - const auto &interpreter_context = default_interpreter.interpreter_context; - EXPECT_EQ(interpreter_context.plan_cache.size(), 0U); EXPECT_EQ(interpreter_context.ast_cache.size(), 0U); auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {}); @@ -1087,7 +1033,6 @@ TEST_F(InterpreterTest, LoadCsvClause) { } TEST_F(InterpreterTest, CacheableQueries) { - 
const auto &interpreter_context = default_interpreter.interpreter_context; // This should be cached { SCOPED_TRACE("Cacheable query"); @@ -1120,7 +1065,9 @@ TEST_F(InterpreterTest, AllowLoadCsvConfig) { "CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN " "row"}; - InterpreterFaker interpreter_faker{&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()}; + memgraph::query::InterpreterContext csv_interpreter_context{ + &db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()}; + InterpreterFaker interpreter_faker{&csv_interpreter_context}; for (const auto &query : queries) { if (allow_load_csv) { SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query)); diff --git a/tests/unit/interpreter_faker.hpp b/tests/unit/interpreter_faker.hpp new file mode 100644 index 000000000..5b14a543e --- /dev/null +++ b/tests/unit/interpreter_faker.hpp @@ -0,0 +1,51 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "communication/result_stream_faker.hpp" +#include "query/interpreter.hpp" + +struct InterpreterFaker { + explicit InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context) + : interpreter_context(interpreter_context), interpreter(interpreter_context) { + interpreter_context->auth_checker = &auth_checker; + interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); }); + } + + auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { + ResultStreamFaker stream(interpreter_context->db); + const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr); + stream.Header(header); + return std::make_pair(std::move(stream), qid); + } + + void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) { + const auto summary = interpreter.Pull(stream, n, qid); + stream->Summary(summary); + } + + /** + * Execute the given query and commit the transaction. + * + * Return the query stream. + */ + auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) { + auto prepare_result = Prepare(query, params); + auto &stream = prepare_result.first; + auto summary = interpreter.Pull(&stream, {}, prepare_result.second); + stream.Summary(summary); + return std::move(stream); + } + memgraph::query::AllowEverythingAuthChecker auth_checker; + memgraph::query::InterpreterContext *interpreter_context; + memgraph::query::Interpreter interpreter; +}; diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index 2c0559f50..8a20d1294 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. 
// // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -39,7 +39,6 @@ class MockAuthChecker : public memgraph::query::AuthChecker { public: MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges)); - #ifdef MG_ENTERPRISE MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr<memgraph::query::FineGrainedAuthChecker>( diff --git a/tests/unit/transaction_queue.cpp b/tests/unit/transaction_queue.cpp new file mode 100644 index 000000000..812f9cb2b --- /dev/null +++ b/tests/unit/transaction_queue.cpp @@ -0,0 +1,76 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <atomic> +#include <chrono> +#include <stop_token> +#include <string> +#include <thread> + +#include <gtest/gtest.h> +#include "gmock/gmock.h" + +#include "interpreter_faker.hpp" + +/* +Tests rely on the fact that interpreters are sequentially added to the list of running interpreters to get the +transaction_id of the corresponding interpreter. +*/ +class TransactionQueueSimpleTest : public ::testing::Test { + protected: + memgraph::storage::Storage db_; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_intr"}; + memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory}; + InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context}; +}; + +TEST_F(TransactionQueueSimpleTest, TwoInterpretersInterleaving) { + std::atomic<bool> started{false}; + std::jthread running_thread = std::jthread( + [this, &started](std::stop_token st, int thread_index) { + running_interpreter.Interpret("BEGIN"); + started = true; + }, + 0); + + { + while (!started) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + main_interpreter.Interpret("CREATE (:Person {prop: 1})"); + auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS"); + ASSERT_EQ(show_stream.GetResults().size(), 2U); + // superadmin executing the transaction + EXPECT_EQ(show_stream.GetResults()[0][0].ValueString(), ""); + ASSERT_TRUE(show_stream.GetResults()[0][1].IsString()); + EXPECT_EQ(show_stream.GetResults()[0][2].ValueList().at(0).ValueString(), "SHOW TRANSACTIONS"); + // Also anonymous user executing + EXPECT_EQ(show_stream.GetResults()[1][0].ValueString(), ""); + ASSERT_TRUE(show_stream.GetResults()[1][1].IsString()); + // Kill the other transaction + std::string run_trans_id = show_stream.GetResults()[1][1].ValueString(); + std::string esc_run_trans_id = "'" + run_trans_id + "'"; + auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id); + // check result of killing + ASSERT_EQ(terminate_stream.GetResults().size(), 1U); + EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id); + ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually 
killed + // check the number of transactions now + auto show_stream_after_killing = main_interpreter.Interpret("SHOW TRANSACTIONS"); + ASSERT_EQ(show_stream_after_killing.GetResults().size(), 1U); + // test the state of the database + auto results_stream = main_interpreter.Interpret("MATCH (n) RETURN n"); + ASSERT_EQ(results_stream.GetResults().size(), 1U); // from the main interpreter + main_interpreter.Interpret("MATCH (n) DETACH DELETE n"); + // stop the worker thread + running_thread.request_stop(); + } +} diff --git a/tests/unit/transaction_queue_multiple.cpp b/tests/unit/transaction_queue_multiple.cpp new file mode 100644 index 000000000..64cd10ad8 --- /dev/null +++ b/tests/unit/transaction_queue_multiple.cpp @@ -0,0 +1,119 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <array> +#include <atomic> +#include <chrono> +#include <random> +#include <stop_token> +#include <string> +#include <thread> + +#include <gtest/gtest.h> +#include "gmock/gmock.h" +#include "spdlog/spdlog.h" + +#include "interpreter_faker.hpp" +#include "query/exceptions.hpp" + +constexpr int NUM_INTERPRETERS = 4, INSERTIONS = 4000; + +/* +Tests rely on the fact that interpreters are sequentially added to running_interpreters to get transaction_id of its +corresponding interpreter. +*/ +class TransactionQueueMultipleTest : public ::testing::Test { + protected: + memgraph::storage::Storage db_; + std::filesystem::path data_directory{std::filesystem::temp_directory_path() / + "MG_tests_unit_transaction_queue_multiple_intr"}; + memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory}; + InterpreterFaker main_interpreter{&interpreter_context}; + std::vector<InterpreterFaker *> running_interpreters; + + TransactionQueueMultipleTest() { + for (int i = 0; i < NUM_INTERPRETERS; ++i) { + InterpreterFaker *faker = new InterpreterFaker(&interpreter_context); + running_interpreters.push_back(faker); + } + } + + ~TransactionQueueMultipleTest() override { + for (int i = 0; i < NUM_INTERPRETERS; ++i) { + delete running_interpreters[i]; + } + } +}; + +// Tests that a randomly chosen concurrently running transaction can be terminated +TEST_F(TransactionQueueMultipleTest, TerminateTransaction) { + std::array<std::atomic<bool>, NUM_INTERPRETERS> started{}; + auto thread_func = [this, &started](int thread_index) { + try { + running_interpreters[thread_index]->Interpret("BEGIN"); + started[thread_index] = true; + for (int j = 0; j < INSERTIONS; ++j) { + running_interpreters[thread_index]->Interpret("CREATE (:Person {prop: " + std::to_string(thread_index) + "})"); + } + } catch (const memgraph::query::HintedAbortError &) { + } + }; + + { + std::vector<std::jthread> running_threads; + running_threads.reserve(NUM_INTERPRETERS); + for (int i = 0; i < NUM_INTERPRETERS; ++i) { + running_threads.emplace_back(thread_func, i); + } + + while (!std::all_of(started.begin(), started.end(), [](const auto &v) { return v.load(); })) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS"); + 
ASSERT_EQ(show_stream.GetResults().size(), NUM_INTERPRETERS + 1); + // Choose random transaction to kill + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<int> distr(0, NUM_INTERPRETERS - 1); + int index_to_terminate = distr(gen); + // Kill random transaction + std::string run_trans_id = + std::to_string(running_interpreters[index_to_terminate]->interpreter.GetTransactionId().value()); + std::string esc_run_trans_id = "'" + run_trans_id + "'"; + auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id); + // check result of killing + ASSERT_EQ(terminate_stream.GetResults().size(), 1U); + EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id); + ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually killed + // the killed transaction should no longer show up + auto show_stream_after_kill = main_interpreter.Interpret("SHOW TRANSACTIONS"); + ASSERT_EQ(show_stream_after_kill.GetResults().size(), NUM_INTERPRETERS); + // wait for the threads to finish + for (int i = 0; i < NUM_INTERPRETERS; ++i) { + running_threads[i].join(); + } + // test the state of the database + for (int i = 0; i < NUM_INTERPRETERS; ++i) { + if (i != index_to_terminate) { + running_interpreters[i]->Interpret("COMMIT"); + } + std::string fetch_query = "MATCH (n:Person) WHERE n.prop=" + std::to_string(i) + " RETURN n"; + auto results_stream = main_interpreter.Interpret(fetch_query); + if (i == index_to_terminate) { + ASSERT_EQ(results_stream.GetResults().size(), 0); + } else { + ASSERT_EQ(results_stream.GetResults().size(), INSERTIONS); + } + } + main_interpreter.Interpret("MATCH (n) DETACH DELETE n"); + } +}
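To close, the round-trip the new unit tests exercise, condensed into one sketch; InterpreterFaker, Interpret, GetResults, and the row layout (username, transaction id, query list) are all as defined in this patch, while ShowAndKillOne itself is only an illustrative helper:

    // Sketch: list the running transactions, then terminate one by its id.
    void ShowAndKillOne(memgraph::query::InterpreterContext *interpreter_context) {
      InterpreterFaker admin{interpreter_context};
      auto show_stream = admin.Interpret("SHOW TRANSACTIONS");
      // Column 1 of every row holds the transaction id as a string.
      std::string victim_id = show_stream.GetResults()[0][1].ValueString();
      auto terminate_stream = admin.Interpret("TERMINATE TRANSACTIONS '" + victim_id + "'");
      // Each result row pairs the requested id with a bool: was that transaction killed?
      bool killed = terminate_stream.GetResults()[0][1].ValueBool();
      (void)killed;
    }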