Merge branch 'master' into add-bug-tracking-workflow

Commit 6636cbb104 by Marko Budiselić, 2023-03-28 15:14:26 +02:00 (committed by GitHub)
93 changed files with 12452 additions and 1297 deletions


@ -177,6 +177,23 @@ jobs:
name: fedora-36
path: build/output/fedora-36/memgraph*.rpm
amzn-2:
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package amzn-2
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: amzn-2
path: build/output/amzn-2/memgraph*.rpm
debian-11-arm:
runs-on: [self-hosted, DockerMgBuild, ARM64, strange]
timeout-minutes: 60

.gitignore vendored

@ -34,9 +34,6 @@ TAGS
*.fas
*.fasl
# LCP generated C++ files
*.lcp.cpp
src/database/distributed/serialization.hpp
src/database/single_node_ha/serialization.hpp
src/distributed/bfs_rpc_messages.hpp
@ -50,15 +47,11 @@ src/distributed/pull_produce_rpc_messages.hpp
src/distributed/storage_gc_rpc_messages.hpp
src/distributed/token_sharing_rpc_messages.hpp
src/distributed/updates_rpc_messages.hpp
src/query/frontend/ast/ast.hpp
src/query/distributed/frontend/ast/ast_serialization.hpp
src/durability/distributed/state_delta.hpp
src/durability/single_node/state_delta.hpp
src/durability/single_node_ha/state_delta.hpp
src/query/frontend/semantic/symbol.hpp
src/query/distributed/frontend/semantic/symbol_serialization.hpp
src/query/distributed/plan/ops.hpp
src/query/plan/operator.hpp
src/raft/log_entry.hpp
src/raft/raft_rpc_messages.hpp
src/raft/snapshot_metadata.hpp


@ -1,16 +1,16 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
name: isort (python)


@ -231,7 +231,15 @@ endif()
message(STATUS "CMake build type: ${CMAKE_BUILD_TYPE}")
# -----------------------------------------------------------------------------
set(MG_ARCH "x86_64" CACHE STRING "Host architecture to build Memgraph on. Supported values are x86_64 (default), ARM64.")
if (NOT MG_ARCH)
set(MG_ARCH_DESCR "Host architecture to build Memgraph on. Supported values are x86_64, ARM64.")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64")
set(MG_ARCH "ARM64" CACHE STRING ${MG_ARCH_DESCR})
else()
set(MG_ARCH "x86_64" CACHE STRING ${MG_ARCH_DESCR})
endif()
endif()
message(STATUS "MG_ARCH: ${MG_ARCH}")
# setup external dependencies -------------------------------------------------

environment/os/amzn-2.sh Executable file

@ -0,0 +1,156 @@
#!/bin/bash
set -Eeuo pipefail
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
source "$DIR/../util.sh"
check_operating_system "amzn-2"
check_architecture "x86_64"
TOOLCHAIN_BUILD_DEPS=(
gcc gcc-c++ make # generic build tools
wget # used for archive download
gnupg2 # used for archive signature verification
tar gzip bzip2 xz unzip # used for archive unpacking
zlib-devel # zlib library used for all builds
expat-devel xz-devel python3-devel texinfo
curl libcurl-devel # for cmake
readline-devel # for cmake and llvm
libffi-devel libxml2-devel # for llvm
libedit-devel pcre-devel automake bison # for swig
file
openssl-devel
gmp-devel
gperf
diffutils
patch
libipt libipt-devel # intel
perl # for openssl
)
TOOLCHAIN_RUN_DEPS=(
make # generic build tools
tar gzip bzip2 xz # used for archive unpacking
zlib # zlib library used for all builds
expat xz-libs python3 # for gdb
readline # for cmake and llvm
libffi libxml2 # for llvm
openssl-devel
)
MEMGRAPH_BUILD_DEPS=(
git # source code control
make # build system
wget # for downloading libs
libuuid-devel java-11-openjdk # required by antlr
readline-devel # for memgraph console
python3-devel # for query modules
openssl-devel
libseccomp-devel
python3 python3-pip nmap-ncat # for tests
#
# IMPORTANT: python3-yaml does NOT exist on CentOS
# Install it using `pip3 install PyYAML`
#
PyYAML # Package name here does not correspond to the yum package!
libcurl-devel # mg-requests
rpm-build rpmlint # for RPM package building
doxygen graphviz # source documentation generators
which nodejs golang zip unzip java-11-openjdk-devel # for driver tests
autoconf # for jemalloc code generation
libtool # for protobuf code generation
)
list() {
echo "$1"
}
check() {
local missing=""
# On Fedora yum/dnf and python10 use newer glibc which is not compatible
# with ours, so we need to momentarily disable the env
local OLD_LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-""}
LD_LIBRARY_PATH=""
for pkg in $1; do
if [ "$pkg" == "PyYAML" ]; then
if ! python3 -c "import yaml" >/dev/null 2>/dev/null; then
missing="$pkg $missing"
fi
continue
fi
if ! yum list installed "$pkg" >/dev/null 2>/dev/null; then
missing="$pkg $missing"
fi
done
if [ "$missing" != "" ]; then
echo "MISSING PACKAGES: $missing"
exit 1
fi
LD_LIBRARY_PATH=${OLD_LD_LIBRARY_PATH}
}
install() {
cd "$DIR"
if [ "$EUID" -ne 0 ]; then
echo "Please run as root."
exit 1
fi
# If GitHub Actions runner is installed, append LANG to the environment.
# Python related tests don't work without the LANG export.
if [ -d "/home/gh/actions-runner" ]; then
echo "LANG=en_US.utf8" >> /home/gh/actions-runner/.env
else
echo "NOTE: export LANG=en_US.utf8"
fi
yum update -y
for pkg in $1; do
if [ "$pkg" == libipt ]; then
if ! yum list installed libipt >/dev/null 2>/dev/null; then
yum install -y http://repo.okay.com.mx/centos/8/x86_64/release/libipt-1.6.1-8.el8.x86_64.rpm
fi
continue
fi
if [ "$pkg" == libipt-devel ]; then
if ! yum list installed libipt-devel >/dev/null 2>/dev/null; then
yum install -y http://repo.okay.com.mx/centos/8/x86_64/release/libipt-devel-1.6.1-8.el8.x86_64.rpm
fi
continue
fi
if [ "$pkg" == nodejs ]; then
curl -sL https://rpm.nodesource.com/setup_16.x | bash -
if ! yum list installed nodejs >/dev/null 2>/dev/null; then
yum install -y nodejs
fi
continue
fi
if [ "$pkg" == PyYAML ]; then
if [ -z ${SUDO_USER+x} ]; then # Running as root (e.g. Docker).
pip3 install --user PyYAML
else # Running using sudo.
sudo -H -u "$SUDO_USER" bash -c "pip3 install --user PyYAML"
fi
continue
fi
if [ "$pkg" == nodejs ]; then
curl -sL https://rpm.nodesource.com/setup_16.x | bash -
if ! yum list installed nodejs >/dev/null 2>/dev/null; then
yum install -y nodejs
fi
continue
fi
if [ "$pkg" == java-11-openjdk ]; then
amazon-linux-extras install -y java-openjdk11
continue
fi
if [ "$pkg" == java-11-openjdk-devel ]; then
amazon-linux-extras install -y java-openjdk11
yum install -y java-11-openjdk-devel
continue
fi
yum install -y "$pkg"
done
}
deps=$2"[*]"
"$1" "${!deps}"


@ -415,6 +415,34 @@ if [ ! -f $PREFIX/bin/gdb ]; then
--with-intel-pt \
--enable-tui \
--with-python=python3
elif [[ "${DISTRO}" == "amzn-2" ]]; then
# Remove readline, gdb does not compile
env \
CC=gcc \
CXX=g++ \
CFLAGS="-g -O2 -fstack-protector-strong -Wformat -Werror=format-security" \
CXXFLAGS="-g -O2 -fstack-protector-strong -Wformat -Werror=format-security" \
CPPFLAGS="-Wdate-time -D_FORTIFY_SOURCE=2 -fPIC" \
LDFLAGS="-Wl,-z,relro" \
PYTHON="" \
../configure \
--build=x86_64-linux-gnu \
--host=x86_64-linux-gnu \
--prefix=$PREFIX \
--disable-maintainer-mode \
--disable-dependency-tracking \
--disable-silent-rules \
--disable-gdbtk \
--disable-shared \
--without-guile \
--with-system-gdbinit=$PREFIX/etc/gdb/gdbinit \
--with-expat \
--with-system-zlib \
--with-lzma \
--with-babeltrace \
--with-intel-pt \
--enable-tui \
--with-python=python3
else
# https://buildd.debian.org/status/fetch.php?pkg=gdb&arch=amd64&ver=8.2.1-2&stamp=1550831554&raw=0
env \
@ -1143,119 +1171,121 @@ if [ ! -f $PREFIX/include/libaio.h ]; then
popd
fi
log_tool_name "folly $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/folly ]; then
if [ -d folly-$FBLIBS_VERSION ]; then
rm -rf folly-$FBLIBS_VERSION
if [[ "${DISTRO}" != "amzn-2" ]]; then
log_tool_name "folly $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/folly ]; then
if [ -d folly-$FBLIBS_VERSION ]; then
rm -rf folly-$FBLIBS_VERSION
fi
mkdir folly-$FBLIBS_VERSION
tar -xzf ../archives/folly-$FBLIBS_VERSION.tar.gz -C folly-$FBLIBS_VERSION
pushd folly-$FBLIBS_VERSION
patch -p1 < ../../folly.patch
# build is used by facebook builder
mkdir _build
pushd _build
cmake .. $COMMON_CMAKE_FLAGS \
-DBOOST_LINK_STATIC=ON \
-DBUILD_TESTS=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DCXX_STD="c++20"
make -j$CPUS install
popd && popd
fi
mkdir folly-$FBLIBS_VERSION
tar -xzf ../archives/folly-$FBLIBS_VERSION.tar.gz -C folly-$FBLIBS_VERSION
pushd folly-$FBLIBS_VERSION
patch -p1 < ../../folly.patch
# build is used by facebook builder
mkdir _build
pushd _build
cmake .. $COMMON_CMAKE_FLAGS \
-DBOOST_LINK_STATIC=ON \
-DBUILD_TESTS=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DCXX_STD="c++20"
make -j$CPUS install
popd && popd
fi
log_tool_name "fizz $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/fizz ]; then
if [ -d fizz-$FBLIBS_VERSION ]; then
rm -rf fizz-$FBLIBS_VERSION
log_tool_name "fizz $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/fizz ]; then
if [ -d fizz-$FBLIBS_VERSION ]; then
rm -rf fizz-$FBLIBS_VERSION
fi
mkdir fizz-$FBLIBS_VERSION
tar -xzf ../archives/fizz-$FBLIBS_VERSION.tar.gz -C fizz-$FBLIBS_VERSION
pushd fizz-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
cmake ../fizz $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_EXAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF
make -j$CPUS install
popd && popd
fi
mkdir fizz-$FBLIBS_VERSION
tar -xzf ../archives/fizz-$FBLIBS_VERSION.tar.gz -C fizz-$FBLIBS_VERSION
pushd fizz-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
cmake ../fizz $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_EXAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF
make -j$CPUS install
popd && popd
fi
log_tool_name "wangle FBLIBS_VERSION"
if [ ! -d $PREFIX/include/wangle ]; then
if [ -d wangle-$FBLIBS_VERSION ]; then
rm -rf wangle-$FBLIBS_VERSION
log_tool_name "wangle FBLIBS_VERSION"
if [ ! -d $PREFIX/include/wangle ]; then
if [ -d wangle-$FBLIBS_VERSION ]; then
rm -rf wangle-$FBLIBS_VERSION
fi
mkdir wangle-$FBLIBS_VERSION
tar -xzf ../archives/wangle-$FBLIBS_VERSION.tar.gz -C wangle-$FBLIBS_VERSION
pushd wangle-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
cmake ../wangle $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_EXAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF
make -j$CPUS install
popd && popd
fi
mkdir wangle-$FBLIBS_VERSION
tar -xzf ../archives/wangle-$FBLIBS_VERSION.tar.gz -C wangle-$FBLIBS_VERSION
pushd wangle-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
cmake ../wangle $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_EXAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF
make -j$CPUS install
popd && popd
fi
log_tool_name "proxygen $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/proxygen ]; then
if [ -d proxygen-$FBLIBS_VERSION ]; then
rm -rf proxygen-$FBLIBS_VERSION
log_tool_name "proxygen $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/proxygen ]; then
if [ -d proxygen-$FBLIBS_VERSION ]; then
rm -rf proxygen-$FBLIBS_VERSION
fi
mkdir proxygen-$FBLIBS_VERSION
tar -xzf ../archives/proxygen-$FBLIBS_VERSION.tar.gz -C proxygen-$FBLIBS_VERSION
pushd proxygen-$FBLIBS_VERSION
patch -p1 < ../../proxygen.patch
# build is used by facebook builder
mkdir _build
pushd _build
cmake .. $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_SAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DBUILD_QUIC=OFF
make -j$CPUS install
popd && popd
fi
mkdir proxygen-$FBLIBS_VERSION
tar -xzf ../archives/proxygen-$FBLIBS_VERSION.tar.gz -C proxygen-$FBLIBS_VERSION
pushd proxygen-$FBLIBS_VERSION
patch -p1 < ../../proxygen.patch
# build is used by facebook builder
mkdir _build
pushd _build
cmake .. $COMMON_CMAKE_FLAGS \
-DBUILD_TESTS=OFF \
-DBUILD_SAMPLES=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DBUILD_QUIC=OFF
make -j$CPUS install
popd && popd
fi
log_tool_name "flex $FBLIBS_VERSION"
if [ ! -f $PREFIX/include/FlexLexer.h ]; then
if [ -d flex-$FLEX_VERSION ]; then
rm -rf flex-$FLEX_VERSION
log_tool_name "flex $FBLIBS_VERSION"
if [ ! -f $PREFIX/include/FlexLexer.h ]; then
if [ -d flex-$FLEX_VERSION ]; then
rm -rf flex-$FLEX_VERSION
fi
tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz
pushd flex-$FLEX_VERSION
./configure $COMMON_CONFIGURE_FLAGS
make -j$CPUS install
popd
fi
tar -xzf ../archives/flex-$FLEX_VERSION.tar.gz
pushd flex-$FLEX_VERSION
./configure $COMMON_CONFIGURE_FLAGS
make -j$CPUS install
popd
fi
log_tool_name "fbthrift $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/thrift ]; then
if [ -d fbthrift-$FBLIBS_VERSION ]; then
rm -rf fbthrift-$FBLIBS_VERSION
log_tool_name "fbthrift $FBLIBS_VERSION"
if [ ! -d $PREFIX/include/thrift ]; then
if [ -d fbthrift-$FBLIBS_VERSION ]; then
rm -rf fbthrift-$FBLIBS_VERSION
fi
git clone --depth 1 --branch v$FBLIBS_VERSION https://github.com/facebook/fbthrift.git fbthrift-$FBLIBS_VERSION
pushd fbthrift-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
if [ "$TOOLCHAIN_STDCXX" = "libstdc++" ]; then
CMAKE_CXX_FLAGS="-fsized-deallocation"
else
CMAKE_CXX_FLAGS="-fsized-deallocation -stdlib=libc++"
fi
cmake .. $COMMON_CMAKE_FLAGS \
-Denable_tests=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DCMAKE_CXX_FLAGS="$CMAKE_CXX_FLAGS"
make -j$CPUS install
popd
fi
git clone --depth 1 --branch v$FBLIBS_VERSION https://github.com/facebook/fbthrift.git fbthrift-$FBLIBS_VERSION
pushd fbthrift-$FBLIBS_VERSION
# build is used by facebook builder
mkdir _build
pushd _build
if [ "$TOOLCHAIN_STDCXX" = "libstdc++" ]; then
CMAKE_CXX_FLAGS="-fsized-deallocation"
else
CMAKE_CXX_FLAGS="-fsized-deallocation -stdlib=libc++"
fi
cmake .. $COMMON_CMAKE_FLAGS \
-Denable_tests=OFF \
-DGFLAGS_NOTHREADS=OFF \
-DCMAKE_CXX_FLAGS="$CMAKE_CXX_FLAGS"
make -j$CPUS install
popd
fi
popd


@ -55,6 +55,15 @@ class NotEnoughMemoryException : public std::exception {
const char *what() const throw() { return "Not enough memory!"; }
};
class MustAbortException : public std::exception {
public:
explicit MustAbortException(const std::string &message) : message_(message) {}
const char *what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
};
// Forward declarations
class Nodes;
using GraphNodes = Nodes;
@ -141,6 +150,10 @@ class Graph {
/// @brief Deletes a relationship from the graph.
void DeleteRelationship(const Relationship &relationship);
bool MustAbort() const;
void CheckMustAbort() const;
private:
mgp_graph *graph_;
};
@ -1572,6 +1585,14 @@ inline Id::Id(int64_t id) : id_(id) {}
inline Graph::Graph(mgp_graph *graph) : graph_(graph) {}
inline bool Graph::MustAbort() const { return must_abort(graph_); }
inline void Graph::CheckMustAbort() const {
if (MustAbort()) {
throw MustAbortException("Query was asked to abort.");
}
}
inline int64_t Graph::Order() const {
int64_t i = 0;
for (const auto _ : Nodes()) {

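The hunks above add MustAbortException and Graph::CheckMustAbort() to the C++ query-module API in mgp.hpp. A minimal sketch of how a long-running procedure might use them, assuming the wrapper namespace is spelled mgp; only Graph, Nodes(), CheckMustAbort() and MustAbortException come from the diff, the rest is illustrative:

#include <cstdint>
#include <iostream>
#include "mgp.hpp"  // the header extended by this diff

// Hedged sketch: cooperatively check for abort between units of work.
int64_t CountNodesOrAbort(mgp::Graph &graph) {
  int64_t count = 0;
  try {
    for (const auto node : graph.Nodes()) {
      graph.CheckMustAbort();  // throws MustAbortException if termination was requested
      (void)node;              // hypothetical per-node work would go here
      ++count;
    }
  } catch (const mgp::MustAbortException &e) {
    std::cerr << e.what() << '\n';  // "Query was asked to abort."
  }
  return count;
}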
init

@ -14,7 +14,6 @@ function print_help () {
echo "Optional arguments:"
echo -e " -h\tdisplay this help and exit"
echo -e " --without-libs-setup\tskip the step for setting up libs"
echo -e " --wsl-quicklisp-proxy \"host:port\"\tquicklist HTTP proxy (this flag + HTTP proxy are required on WSL)"
}
function setup_virtualenv () {
@ -35,7 +34,6 @@ function setup_virtualenv () {
popd > /dev/null
}
wsl_quicklisp_proxy=""
setup_libs=true
if [[ $# -eq 1 && "$1" == "-h" ]]; then
print_help
@ -43,16 +41,6 @@ if [[ $# -eq 1 && "$1" == "-h" ]]; then
else
while(($#)); do
case "$1" in
--wsl-quicklisp-proxy)
shift
if [[ $# -eq 0 ]]; then
echo "Missing proxy URL"
print_help
exit 1
fi
wsl_quicklisp_proxy=":proxy \"http://$1/\""
shift
;;
--without-libs-setup)
shift
setup_libs=false
@ -79,41 +67,16 @@ echo "All packages are in-place..."
# create a default build directory
mkdir -p ./build
# quicklisp package manager for Common Lisp
quicklisp_install_dir="$HOME/quicklisp"
if [[ -v QUICKLISP_HOME ]]; then
quicklisp_install_dir="${QUICKLISP_HOME}"
fi
if [[ ! -f "${quicklisp_install_dir}/setup.lisp" ]]; then
wget -nv https://beta.quicklisp.org/quicklisp.lisp -O quicklisp.lisp || exit 1
echo \
"
(load \"${DIR}/quicklisp.lisp\")
(quicklisp-quickstart:install $wsl_quicklisp_proxy :path \"${quicklisp_install_dir}\")
" | sbcl --script || exit 1
rm -rf quicklisp.lisp || exit 1
fi
ln -Tfs "$DIR/src/lisp" "${quicklisp_install_dir}/local-projects/lcp"
# Install LCP dependencies
# TODO: We should at some point cache or have a mirror of packages we use.
# TODO: move the installation of LCP's dependencies into ./setup.sh
echo \
"
(load \"${quicklisp_install_dir}/setup.lisp\")
(ql:quickload '(:lcp :lcp/test) :silent t)
" | sbcl --script
if [[ "$setup_libs" == "true" ]]; then
# Setup libs (download).
cd libs
./cleanup.sh
./setup.sh
cd ..
# Setup libs (download).
cd libs
./cleanup.sh
./setup.sh
cd ..
fi
# Fix for centos 7 during release
if [ "${DISTRO}" = "centos-7" ] || [ "${DISTRO}" = "debian-11" ]; then
if [ "${DISTRO}" = "centos-7" ] || [ "${DISTRO}" = "debian-11" ] || [ "${DISTRO}" = "amzn-2" ]; then
python3 -m pip uninstall -y virtualenv
python3 -m pip install virtualenv
fi
@ -143,15 +106,18 @@ for hook in $(find $DIR/.githooks -type f -printf "%f\n"); do
echo "Added $hook hook"
done;
# Install precommit hook
python3 -m pip install pre-commit
python3 -m pre_commit install
# Install py format tools
echo "Install black formatter"
python3 -m pip install black==22.8.*
echo "Install isort"
python3 -m pip install isort==5.10.*
# Install the pre-commit hook, except on old operating systems: we don't
# develop on them, so the hook is not required and we can use the latest
# packages.
if [ "${DISTRO}" != "centos-7" ] && [ "$DISTRO" != "debian-10" ] && [ "${DISTRO}" != "ubuntu-18.04" ]; then
python3 -m pip install pre-commit
python3 -m pre_commit install
# Install py format tools for usage during the development.
echo "Install black formatter"
python3 -m pip install black==23.1.*
echo "Install isort"
python3 -m pip install isort==5.12.*
fi
# Link `include/mgp.py` with `release/mgp/mgp.py`
ln -v -f include/mgp.py release/mgp/mgp.py


@ -1,6 +1,10 @@
# Install systemd service (must use absolute path).
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/memgraph.service
DESTINATION /lib/systemd/system)
DESTINATION /lib/systemd/system)
# Set parameters to recognize the host distro
cmake_host_system_information(RESULT DISTRO QUERY DISTRIB_NAME)
cmake_host_system_information(RESULT DISTRO_VERSION QUERY DISTRIB_VERSION)
# ---- Setup CPack --------
@ -12,10 +16,11 @@ set(CPACK_PACKAGE_DESCRIPTION_SUMMARY
# Setting architecture extension for deb packages
set(MG_ARCH_EXTENSION_DEB "all")
if (${MG_ARCH} STREQUAL "x86_64")
if(${MG_ARCH} STREQUAL "x86_64")
set(MG_ARCH_EXTENSION_DEB "amd64")
elseif (${MG_ARCH} STREQUAL "ARM64")
set(MG_ARCH_EXTENSION_DEB "arm64")
elseif(${MG_ARCH} STREQUAL "ARM64")
set(MG_ARCH_EXTENSION_DEB "arm64")
endif()
# DEB specific
@ -34,21 +39,24 @@ set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA
"${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;"
"${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;")
set(CPACK_DEBIAN_PACKAGE_SHLIBDEPS ON)
# Description formatting is important, summary must be followed with a newline and 1 space.
set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY}
Contains Memgraph, the graph database. It aims to deliver developers the
speed, simplicity and scale required to build the next generation of
applications driven by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
# We also depend on `python3` because we embed it in Memgraph.
set(CPACK_DEBIAN_PACKAGE_DEPENDS "openssl (>= 1.1.0), python3 (>= 3.5.0), libstdc++6")
# Setting architecture extension for rpm packages
set(MG_ARCH_EXTENSION_RPM "noarch")
if (${MG_ARCH} STREQUAL "x86_64")
if(${MG_ARCH} STREQUAL "x86_64")
set(MG_ARCH_EXTENSION_RPM "x86_64")
elseif (${MG_ARCH} STREQUAL "ARM64")
set(MG_ARCH_EXTENSION_RPM "aarch64")
elseif(${MG_ARCH} STREQUAL "ARM64")
set(MG_ARCH_EXTENSION_RPM "aarch64")
endif()
# RPM specific
@ -56,18 +64,26 @@ set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
set(CPACK_RPM_PACKAGE_VERSION "${MEMGRAPH_VERSION_RPM}")
set(CPACK_RPM_FILE_NAME "memgraph-${MEMGRAPH_VERSION_RPM}-1.${MG_ARCH_EXTENSION_RPM}.rpm")
set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION
/var /var/lib /var/log /etc/logrotate.d
/lib /lib/systemd /lib/systemd/system /lib/systemd/system/memgraph.service)
/var /var/lib /var/log /etc/logrotate.d
/lib /lib/systemd /lib/systemd/system /lib/systemd/system/memgraph.service)
set(CPACK_RPM_PACKAGE_REQUIRES_PRE "shadow-utils")
set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_CURRENT_SOURCE_DIR}/rpm/memgraph.spec.in")
set(CPACK_RPM_PACKAGE_LICENSE "Memgraph License")
# Description formatting is important, no line must be greater than 80 characters.
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
It aims to deliver developers the speed, simplicity and scale required to build
the next generation of applications driven by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
# We also depend on `python3` because we embed it in Memgraph.
set(CPACK_RPM_PACKAGE_REQUIRES "openssl >= 1.0.0, curl >= 7.29.0, python3 >= 3.5.0, libstdc++ >= 6, logrotate")
set(CPACK_RPM_PACKAGE_REQUIRES "openssl >= 1.0.0, curl >= 7.29.0, python3 >= 3.5.0, libstdc++ >= 3.4.29, logrotate")
# If amzn-2
if(DISTRO STREQUAL "Amazon Linux" AND DISTRO_VERSION STREQUAL "2")
# It causes issues with glibcxx 2.4
set(CPACK_RPM_PACKAGE_AUTOREQ " no")
endif()
# All variables must be set before including.
include(CPack)


@ -0,0 +1,14 @@
FROM amazonlinux:2
ARG TOOLCHAIN_VERSION
RUN yum -y update \
&& yum install -y wget git tar
# Do NOT try to be clever and clean the cache here; the container is used in a
# stateful context.
RUN wget -q https://s3-eu-west-1.amazonaws.com/deps.memgraph.io/${TOOLCHAIN_VERSION}/${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz \
-O ${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz \
&& tar xzvf ${TOOLCHAIN_VERSION}-binaries-amzn-2-x86_64.tar.gz -C /opt
ENTRYPOINT ["sleep", "infinity"]


@ -32,3 +32,7 @@ services:
build:
context: fedora-36
container_name: "mgbuild_fedora-36"
mgbuild_amzn-2:
build:
context: amzn-2
container_name: "mgbuild_amzn-2"


@ -3,7 +3,14 @@
set -Eeuo pipefail
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
SUPPORTED_OS=(centos-7 centos-9 debian-10 debian-11 ubuntu-18.04 ubuntu-20.04 ubuntu-22.04 debian-11-arm fedora-36 ubuntu-22.04-arm)
SUPPORTED_OS=(
centos-7 centos-9
debian-10 debian-11 debian-11-arm
ubuntu-18.04 ubuntu-20.04 ubuntu-22.04 ubuntu-22.04-arm
fedora-36
amzn-2
)
PROJECT_ROOT="$SCRIPT_DIR/../.."
TOOLCHAIN_VERSION="toolchain-v4"
ACTIVATE_TOOLCHAIN="source /opt/${TOOLCHAIN_VERSION}/activate"
@ -23,7 +30,7 @@ make_package () {
echo "Building Memgraph for $os on $build_container..."
package_command=""
if [[ "$os" =~ ^"centos".* ]] || [[ "$os" =~ ^"fedora".* ]]; then
if [[ "$os" =~ ^"centos".* ]] || [[ "$os" =~ ^"fedora".* ]] || [[ "$os" =~ ^"amzn".* ]]; then
docker exec "$build_container" bash -c "yum -y update"
package_command=" cpack -G RPM --config ../CPackConfig.cmake && rpmlint --file='../../release/rpm/rpmlintrc' memgraph*.rpm "
fi


@ -1,7 +1,6 @@
# CMake configuration for the main memgraph library and executable
# add memgraph sub libraries, ordered by dependency
add_subdirectory(lisp)
add_subdirectory(utils)
add_subdirectory(requests)
add_subdirectory(io)


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -34,13 +34,17 @@ namespace memgraph::auth {
namespace {
// Constant list of all available permissions.
const std::vector<Permission> kPermissionsAll = {
Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE,
Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS,
Permission::CONSTRAINT, Permission::DUMP, Permission::AUTH, Permission::REPLICATION,
Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY, Permission::TRIGGER,
Permission::CONFIG, Permission::STREAM, Permission::MODULE_READ, Permission::MODULE_WRITE,
Permission::WEBSOCKET};
const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permission::CREATE,
Permission::MERGE, Permission::DELETE,
Permission::SET, Permission::REMOVE,
Permission::INDEX, Permission::STATS,
Permission::CONSTRAINT, Permission::DUMP,
Permission::AUTH, Permission::REPLICATION,
Permission::DURABILITY, Permission::READ_FILE,
Permission::FREE_MEMORY, Permission::TRIGGER,
Permission::CONFIG, Permission::STREAM,
Permission::MODULE_READ, Permission::MODULE_WRITE,
Permission::WEBSOCKET, Permission::TRANSACTION_MANAGEMENT};
} // namespace
std::string PermissionToString(Permission permission) {
@ -87,6 +91,8 @@ std::string PermissionToString(Permission permission) {
return "MODULE_WRITE";
case Permission::WEBSOCKET:
return "WEBSOCKET";
case Permission::TRANSACTION_MANAGEMENT:
return "TRANSACTION_MANAGEMENT";
}
}


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Licensed as a Memgraph Enterprise file under the Memgraph Enterprise
// License (the "License"); by using this file, you agree to be bound by the terms of the License, and you may not use
@ -40,7 +40,8 @@ enum class Permission : uint64_t {
STREAM = 1U << 17U,
MODULE_READ = 1U << 18U,
MODULE_WRITE = 1U << 19U,
WEBSOCKET = 1U << 20U
WEBSOCKET = 1U << 20U,
TRANSACTION_MANAGEMENT = 1U << 21U
};
// clang-format on

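Each Permission value above is a distinct power of two, so a user's grants fit in a single 64-bit mask and the new TRANSACTION_MANAGEMENT bit composes with the existing ones. A self-contained sketch of that bitmask arithmetic; the Grant/Has helpers are illustrative only, not Memgraph's API:

#include <cstdint>

// Only the two relevant enumerators are reproduced here.
enum class Permission : uint64_t {
  WEBSOCKET = 1U << 20U,
  TRANSACTION_MANAGEMENT = 1U << 21U,
};

constexpr uint64_t Grant(uint64_t mask, Permission p) { return mask | static_cast<uint64_t>(p); }
constexpr bool Has(uint64_t mask, Permission p) { return (mask & static_cast<uint64_t>(p)) != 0U; }

static_assert(Has(Grant(0U, Permission::TRANSACTION_MANAGEMENT), Permission::TRANSACTION_MANAGEMENT));
static_assert(!Has(Grant(0U, Permission::WEBSOCKET), Permission::TRANSACTION_MANAGEMENT));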

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -58,6 +58,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
return auth::Permission::MODULE_WRITE;
case query::AuthQuery::Privilege::WEBSOCKET:
return auth::Permission::WEBSOCKET;
case query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT:
return auth::Permission::TRANSACTION_MANAGEMENT;
}
}


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -84,6 +84,7 @@ bool AuthChecker::IsUserAuthorized(const std::optional<std::string> &username,
return maybe_user.has_value() && IsUserAuthorized(*maybe_user, privileges);
}
#ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const {


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -26,9 +26,11 @@ class AuthChecker : public query::AuthChecker {
bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::AuthQuery::Privilege> &privileges) const override;
#ifdef MG_ENTERPRISE
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const override;
#endif
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges);


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -522,6 +522,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
db_(data->db),
interpreter_context_(data->interpreter_context),
interpreter_(data->interpreter_context),
auth_(data->auth),
#if MG_ENTERPRISE
@ -529,6 +530,11 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
#endif
endpoint_(endpoint),
run_id_(data->run_id) {
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter_); });
}
~BoltSession() override {
interpreter_context_->interpreters.WithLock([this](auto &interpreters) { interpreters.erase(&interpreter_); });
}
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
@ -674,6 +680,7 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
// NOTE: Needed only for ToBoltValue conversions
const memgraph::storage::Storage *db_;
memgraph::query::InterpreterContext *interpreter_context_;
memgraph::query::Interpreter interpreter_;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
std::optional<memgraph::auth::User> user_;

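The constructor/destructor pair above registers every live session's Interpreter in a shared, lock-protected set and removes it on teardown; that registry is what lets transaction-management queries enumerate and terminate running interpreters. A simplified sketch of the pattern, with a plain std::mutex and std::unordered_set standing in for Memgraph's utils::Synchronized:

#include <mutex>
#include <unordered_set>

struct Interpreter {};  // stand-in for memgraph::query::Interpreter

struct InterpreterRegistry {
  std::mutex mutex;
  std::unordered_set<Interpreter *> interpreters;
};

class Session {
 public:
  explicit Session(InterpreterRegistry &registry) : registry_(registry) {
    std::lock_guard<std::mutex> guard(registry_.mutex);
    registry_.interpreters.insert(&interpreter_);  // visible to SHOW TRANSACTIONS-style queries
  }
  ~Session() {
    std::lock_guard<std::mutex> guard(registry_.mutex);
    registry_.interpreters.erase(&interpreter_);   // never leave a dangling pointer behind
  }

 private:
  InterpreterRegistry &registry_;
  Interpreter interpreter_;
};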

@ -1,13 +1,7 @@
define_add_lcp(add_lcp_query lcp_query_cpp_files generated_lcp_query_files)
add_lcp_query(frontend/ast/ast.lcp)
add_lcp_query(frontend/semantic/symbol.lcp)
add_lcp_query(plan/operator.lcp)
add_custom_target(generate_lcp_query DEPENDS ${generated_lcp_query_files})
set(mg_query_sources
${lcp_query_cpp_files}
frontend/ast/ast.cpp
frontend/semantic/symbol.cpp
plan/operator_type_info.cpp
common.cpp
cypher_query_interpreter.cpp
dump.cpp
@ -46,7 +40,6 @@ set(mg_query_sources
find_package(Boost REQUIRED)
add_library(mg-query STATIC ${mg_query_sources})
add_dependencies(mg-query generate_lcp_query)
target_include_directories(mg-query PUBLIC ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(mg-query dl cppitertools Boost::headers)
target_link_libraries(mg-query mg-integrations-pulsar mg-integrations-kafka mg-storage-v2 mg-license mg-utils mg-kvstore mg-memory)


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -24,6 +24,15 @@
namespace memgraph::query {
enum class TransactionStatus {
IDLE,
ACTIVE,
VERIFYING,
TERMINATED,
STARTED_COMMITTING,
STARTED_ROLLBACK,
};
struct EvaluationContext {
/// Memory for allocations during evaluation of a *single* Pull call.
///
@ -66,6 +75,7 @@ struct ExecutionContext {
SymbolTable symbol_table;
EvaluationContext evaluation_context;
std::atomic<bool> *is_shutting_down{nullptr};
std::atomic<TransactionStatus> *transaction_status{nullptr};
bool is_profile_query{false};
std::chrono::duration<double> profile_execution_time;
plan::ProfilingStats stats;
@ -82,7 +92,9 @@ static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext mus
static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!");
inline bool MustAbort(const ExecutionContext &context) noexcept {
return (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) ||
return (context.transaction_status != nullptr &&
context.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) ||
(context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) ||
context.timer.IsExpired();
}

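With this change MustAbort() also fires when another thread publishes TransactionStatus::TERMINATED, which is how a terminate request reaches a running query. A self-contained reduction of that logic; the timer check is omitted and MiniContext is an illustrative stand-in for ExecutionContext:

#include <atomic>

enum class TransactionStatus { IDLE, ACTIVE, VERIFYING, TERMINATED, STARTED_COMMITTING, STARTED_ROLLBACK };

struct MiniContext {
  std::atomic<bool> *is_shutting_down{nullptr};
  std::atomic<TransactionStatus> *transaction_status{nullptr};
};

inline bool MustAbort(const MiniContext &ctx) noexcept {
  return (ctx.transaction_status != nullptr &&
          ctx.transaction_status->load(std::memory_order_acquire) == TransactionStatus::TERMINATED) ||
         (ctx.is_shutting_down != nullptr && ctx.is_shutting_down->load(std::memory_order_acquire));
}

int main() {
  std::atomic<TransactionStatus> status{TransactionStatus::ACTIVE};
  MiniContext ctx{nullptr, &status};
  const bool before = MustAbort(ctx);  // false while the transaction is ACTIVE
  // Another thread (e.g. a TERMINATE TRANSACTIONS handler) flips the status...
  status.store(TransactionStatus::TERMINATED, std::memory_order_release);
  const bool after = MustAbort(ctx);   // ...and the executor aborts on its next check
  return (!before && after) ? 0 : 1;
}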

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -120,9 +120,8 @@ class HintedAbortError : public utils::BasicException {
using utils::BasicException::BasicException;
HintedAbortError()
: utils::BasicException(
"Transaction was asked to abort, most likely because it was "
"executing longer than time specified by "
"--query-execution-timeout-sec flag.") {}
"Transaction was asked to abort either because it was executing longer than time specified or another user "
"asked it to abort.") {}
};
class ExplicitTransactionUsageException : public QueryRuntimeException {
@ -237,4 +236,11 @@ class ReplicationException : public utils::BasicException {
: utils::BasicException("Replication Exception: {} Check the status of the replicas using 'SHOW REPLICA' query.",
message) {}
};
class TransactionQueueInMulticommandTxException : public QueryException {
public:
TransactionQueueInMulticommandTxException()
: QueryException("Transaction queue queries not allowed in multicommand transactions.") {}
};
} // namespace memgraph::query


@ -0,0 +1,267 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/ast_visitor.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph {
constexpr utils::TypeInfo query::LabelIx::kType{utils::TypeId::AST_LABELIX, "LabelIx", nullptr};
constexpr utils::TypeInfo query::PropertyIx::kType{utils::TypeId::AST_PROPERTYIX, "PropertyIx", nullptr};
constexpr utils::TypeInfo query::EdgeTypeIx::kType{utils::TypeId::AST_EDGETYPEIX, "EdgeTypeIx", nullptr};
constexpr utils::TypeInfo query::Tree::kType{utils::TypeId::AST_TREE, "Tree", nullptr};
constexpr utils::TypeInfo query::Expression::kType{utils::TypeId::AST_EXPRESSION, "Expression", &query::Tree::kType};
constexpr utils::TypeInfo query::Where::kType{utils::TypeId::AST_WHERE, "Where", &query::Tree::kType};
constexpr utils::TypeInfo query::BinaryOperator::kType{utils::TypeId::AST_BINARY_OPERATOR, "BinaryOperator",
&query::Expression::kType};
constexpr utils::TypeInfo query::UnaryOperator::kType{utils::TypeId::AST_UNARY_OPERATOR, "UnaryOperator",
&query::Expression::kType};
constexpr utils::TypeInfo query::OrOperator::kType{utils::TypeId::AST_OR_OPERATOR, "OrOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::XorOperator::kType{utils::TypeId::AST_XOR_OPERATOR, "XorOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::AndOperator::kType{utils::TypeId::AST_AND_OPERATOR, "AndOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::AdditionOperator::kType{utils::TypeId::AST_ADDITION_OPERATOR, "AdditionOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::SubtractionOperator::kType{utils::TypeId::AST_SUBTRACTION_OPERATOR,
"SubtractionOperator", &query::BinaryOperator::kType};
constexpr utils::TypeInfo query::MultiplicationOperator::kType{utils::TypeId::AST_MULTIPLICATION_OPERATOR,
"MultiplicationOperator", &query::BinaryOperator::kType};
constexpr utils::TypeInfo query::DivisionOperator::kType{utils::TypeId::AST_DIVISION_OPERATOR, "DivisionOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::ModOperator::kType{utils::TypeId::AST_MOD_OPERATOR, "ModOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::NotEqualOperator::kType{utils::TypeId::AST_NOT_EQUAL_OPERATOR, "NotEqualOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::EqualOperator::kType{utils::TypeId::AST_EQUAL_OPERATOR, "EqualOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::LessOperator::kType{utils::TypeId::AST_LESS_OPERATOR, "LessOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::GreaterOperator::kType{utils::TypeId::AST_GREATER_OPERATOR, "GreaterOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::LessEqualOperator::kType{utils::TypeId::AST_LESS_EQUAL_OPERATOR, "LessEqualOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::GreaterEqualOperator::kType{utils::TypeId::AST_GREATER_EQUAL_OPERATOR,
"GreaterEqualOperator", &query::BinaryOperator::kType};
constexpr utils::TypeInfo query::InListOperator::kType{utils::TypeId::AST_IN_LIST_OPERATOR, "InListOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::SubscriptOperator::kType{utils::TypeId::AST_SUBSCRIPT_OPERATOR, "SubscriptOperator",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::NotOperator::kType{utils::TypeId::AST_NOT_OPERATOR, "NotOperator",
&query::UnaryOperator::kType};
constexpr utils::TypeInfo query::UnaryPlusOperator::kType{utils::TypeId::AST_UNARY_PLUS_OPERATOR, "UnaryPlusOperator",
&query::UnaryOperator::kType};
constexpr utils::TypeInfo query::UnaryMinusOperator::kType{utils::TypeId::AST_UNARY_MINUS_OPERATOR,
"UnaryMinusOperator", &query::UnaryOperator::kType};
constexpr utils::TypeInfo query::IsNullOperator::kType{utils::TypeId::AST_IS_NULL_OPERATOR, "IsNullOperator",
&query::UnaryOperator::kType};
constexpr utils::TypeInfo query::Aggregation::kType{utils::TypeId::AST_AGGREGATION, "Aggregation",
&query::BinaryOperator::kType};
constexpr utils::TypeInfo query::ListSlicingOperator::kType{utils::TypeId::AST_LIST_SLICING_OPERATOR,
"ListSlicingOperator", &query::Expression::kType};
constexpr utils::TypeInfo query::IfOperator::kType{utils::TypeId::AST_IF_OPERATOR, "IfOperator",
&query::Expression::kType};
constexpr utils::TypeInfo query::BaseLiteral::kType{utils::TypeId::AST_BASE_LITERAL, "BaseLiteral",
&query::Expression::kType};
constexpr utils::TypeInfo query::PrimitiveLiteral::kType{utils::TypeId::AST_PRIMITIVE_LITERAL, "PrimitiveLiteral",
&query::BaseLiteral::kType};
constexpr utils::TypeInfo query::ListLiteral::kType{utils::TypeId::AST_LIST_LITERAL, "ListLiteral",
&query::BaseLiteral::kType};
constexpr utils::TypeInfo query::MapLiteral::kType{utils::TypeId::AST_MAP_LITERAL, "MapLiteral",
&query::BaseLiteral::kType};
constexpr utils::TypeInfo query::Identifier::kType{utils::TypeId::AST_IDENTIFIER, "Identifier",
&query::Expression::kType};
constexpr utils::TypeInfo query::PropertyLookup::kType{utils::TypeId::AST_PROPERTY_LOOKUP, "PropertyLookup",
&query::Expression::kType};
constexpr utils::TypeInfo query::LabelsTest::kType{utils::TypeId::AST_LABELS_TEST, "LabelsTest",
&query::Expression::kType};
constexpr utils::TypeInfo query::Function::kType{utils::TypeId::AST_FUNCTION, "Function", &query::Expression::kType};
constexpr utils::TypeInfo query::Reduce::kType{utils::TypeId::AST_REDUCE, "Reduce", &query::Expression::kType};
constexpr utils::TypeInfo query::Coalesce::kType{utils::TypeId::AST_COALESCE, "Coalesce", &query::Expression::kType};
constexpr utils::TypeInfo query::Extract::kType{utils::TypeId::AST_EXTRACT, "Extract", &query::Expression::kType};
constexpr utils::TypeInfo query::All::kType{utils::TypeId::AST_ALL, "All", &query::Expression::kType};
constexpr utils::TypeInfo query::Single::kType{utils::TypeId::AST_SINGLE, "Single", &query::Expression::kType};
constexpr utils::TypeInfo query::Any::kType{utils::TypeId::AST_ANY, "Any", &query::Expression::kType};
constexpr utils::TypeInfo query::None::kType{utils::TypeId::AST_NONE, "None", &query::Expression::kType};
constexpr utils::TypeInfo query::ParameterLookup::kType{utils::TypeId::AST_PARAMETER_LOOKUP, "ParameterLookup",
&query::Expression::kType};
constexpr utils::TypeInfo query::RegexMatch::kType{utils::TypeId::AST_REGEX_MATCH, "RegexMatch",
&query::Expression::kType};
constexpr utils::TypeInfo query::NamedExpression::kType{utils::TypeId::AST_NAMED_EXPRESSION, "NamedExpression",
&query::Tree::kType};
constexpr utils::TypeInfo query::PatternAtom::kType{utils::TypeId::AST_PATTERN_ATOM, "PatternAtom",
&query::Tree::kType};
constexpr utils::TypeInfo query::NodeAtom::kType{utils::TypeId::AST_NODE_ATOM, "NodeAtom", &query::PatternAtom::kType};
constexpr utils::TypeInfo query::EdgeAtom::Lambda::kType{utils::TypeId::AST_EDGE_ATOM_LAMBDA, "Lambda", nullptr};
constexpr utils::TypeInfo query::EdgeAtom::kType{utils::TypeId::AST_EDGE_ATOM, "EdgeAtom", &query::PatternAtom::kType};
constexpr utils::TypeInfo query::Pattern::kType{utils::TypeId::AST_PATTERN, "Pattern", &query::Tree::kType};
constexpr utils::TypeInfo query::Clause::kType{utils::TypeId::AST_CLAUSE, "Clause", &query::Tree::kType};
constexpr utils::TypeInfo query::SingleQuery::kType{utils::TypeId::AST_SINGLE_QUERY, "SingleQuery",
&query::Tree::kType};
constexpr utils::TypeInfo query::CypherUnion::kType{utils::TypeId::AST_CYPHER_UNION, "CypherUnion",
&query::Tree::kType};
constexpr utils::TypeInfo query::Query::kType{utils::TypeId::AST_QUERY, "Query", &query::Tree::kType};
constexpr utils::TypeInfo query::CypherQuery::kType{utils::TypeId::AST_CYPHER_QUERY, "CypherQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::ExplainQuery::kType{utils::TypeId::AST_EXPLAIN_QUERY, "ExplainQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::ProfileQuery::kType{utils::TypeId::AST_PROFILE_QUERY, "ProfileQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::IndexQuery::kType{utils::TypeId::AST_INDEX_QUERY, "IndexQuery", &query::Query::kType};
constexpr utils::TypeInfo query::Create::kType{utils::TypeId::AST_CREATE, "Create", &query::Clause::kType};
constexpr utils::TypeInfo query::CallProcedure::kType{utils::TypeId::AST_CALL_PROCEDURE, "CallProcedure",
&query::Clause::kType};
constexpr utils::TypeInfo query::Match::kType{utils::TypeId::AST_MATCH, "Match", &query::Clause::kType};
constexpr utils::TypeInfo query::SortItem::kType{utils::TypeId::AST_SORT_ITEM, "SortItem", nullptr};
constexpr utils::TypeInfo query::ReturnBody::kType{utils::TypeId::AST_RETURN_BODY, "ReturnBody", nullptr};
constexpr utils::TypeInfo query::Return::kType{utils::TypeId::AST_RETURN, "Return", &query::Clause::kType};
constexpr utils::TypeInfo query::With::kType{utils::TypeId::AST_WITH, "With", &query::Clause::kType};
constexpr utils::TypeInfo query::Delete::kType{utils::TypeId::AST_DELETE, "Delete", &query::Clause::kType};
constexpr utils::TypeInfo query::SetProperty::kType{utils::TypeId::AST_SET_PROPERTY, "SetProperty",
&query::Clause::kType};
constexpr utils::TypeInfo query::SetProperties::kType{utils::TypeId::AST_SET_PROPERTIES, "SetProperties",
&query::Clause::kType};
constexpr utils::TypeInfo query::SetLabels::kType{utils::TypeId::AST_SET_LABELS, "SetLabels", &query::Clause::kType};
constexpr utils::TypeInfo query::RemoveProperty::kType{utils::TypeId::AST_REMOVE_PROPERTY, "RemoveProperty",
&query::Clause::kType};
constexpr utils::TypeInfo query::RemoveLabels::kType{utils::TypeId::AST_REMOVE_LABELS, "RemoveLabels",
&query::Clause::kType};
constexpr utils::TypeInfo query::Merge::kType{utils::TypeId::AST_MERGE, "Merge", &query::Clause::kType};
constexpr utils::TypeInfo query::Unwind::kType{utils::TypeId::AST_UNWIND, "Unwind", &query::Clause::kType};
constexpr utils::TypeInfo query::AuthQuery::kType{utils::TypeId::AST_AUTH_QUERY, "AuthQuery", &query::Query::kType};
constexpr utils::TypeInfo query::InfoQuery::kType{utils::TypeId::AST_INFO_QUERY, "InfoQuery", &query::Query::kType};
constexpr utils::TypeInfo query::Constraint::kType{utils::TypeId::AST_CONSTRAINT, "Constraint", nullptr};
constexpr utils::TypeInfo query::ConstraintQuery::kType{utils::TypeId::AST_CONSTRAINT_QUERY, "ConstraintQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::DumpQuery::kType{utils::TypeId::AST_DUMP_QUERY, "DumpQuery", &query::Query::kType};
constexpr utils::TypeInfo query::ReplicationQuery::kType{utils::TypeId::AST_REPLICATION_QUERY, "ReplicationQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::LockPathQuery::kType{utils::TypeId::AST_LOCK_PATH_QUERY, "LockPathQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::LoadCsv::kType{utils::TypeId::AST_LOAD_CSV, "LoadCsv", &query::Clause::kType};
constexpr utils::TypeInfo query::FreeMemoryQuery::kType{utils::TypeId::AST_FREE_MEMORY_QUERY, "FreeMemoryQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::TriggerQuery::kType{utils::TypeId::AST_TRIGGER_QUERY, "TriggerQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::IsolationLevelQuery::kType{utils::TypeId::AST_ISOLATION_LEVEL_QUERY,
"IsolationLevelQuery", &query::Query::kType};
constexpr utils::TypeInfo query::CreateSnapshotQuery::kType{utils::TypeId::AST_CREATE_SNAPSHOT_QUERY,
"CreateSnapshotQuery", &query::Query::kType};
constexpr utils::TypeInfo query::StreamQuery::kType{utils::TypeId::AST_STREAM_QUERY, "StreamQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::SettingQuery::kType{utils::TypeId::AST_SETTING_QUERY, "SettingQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::VersionQuery::kType{utils::TypeId::AST_VERSION_QUERY, "VersionQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::Foreach::kType{utils::TypeId::AST_FOREACH, "Foreach", &query::Clause::kType};
constexpr utils::TypeInfo query::ShowConfigQuery::kType{utils::TypeId::AST_SHOW_CONFIG_QUERY, "ShowConfigQuery",
&query::Query::kType};
constexpr utils::TypeInfo query::TransactionQueueQuery::kType{utils::TypeId::AST_TRANSACTION_QUEUE_QUERY,
"TransactionQueueQuery", &query::Query::kType};
constexpr utils::TypeInfo query::Exists::kType{utils::TypeId::AST_EXISTS, "Exists", &query::Expression::kType};
} // namespace memgraph

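Each kType constant above records a node's TypeId, a printable name, and a pointer to the parent type's kType, so a runtime "is-a" check is just a walk up that chain. A hedged sketch of the idea with a simplified stand-in for utils::TypeInfo (the real struct and any subtype helper in Memgraph may differ):

#include <cstdint>

struct TypeInfo {
  uint64_t id;
  const char *name;
  const TypeInfo *superclass;  // nullptr for root types such as Tree
};

constexpr bool IsSubtype(const TypeInfo &type, const TypeInfo &base) {
  for (const TypeInfo *current = &type; current != nullptr; current = current->superclass) {
    if (current->id == base.id) return true;
  }
  return false;
}

// Mirrors the shape of the definitions above: OrOperator -> BinaryOperator -> Expression -> Tree.
constexpr TypeInfo kTree{1, "Tree", nullptr};
constexpr TypeInfo kExpression{2, "Expression", &kTree};
constexpr TypeInfo kBinaryOperator{3, "BinaryOperator", &kExpression};
constexpr TypeInfo kOrOperator{4, "OrOperator", &kBinaryOperator};

static_assert(IsSubtype(kOrOperator, kTree));
static_assert(!IsSubtype(kTree, kOrOperator));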
File diff suppressed because it is too large.


@ -2284,7 +2284,7 @@ cpp<#
(lcp:define-enum privilege
(create delete match merge set remove index stats auth constraint
dump replication durability read_file free_memory trigger config stream module_read module_write
websocket)
websocket transaction_management)
(:serialize))
(lcp:define-enum fine-grained-privilege
(nothing read update create_delete)
@ -2333,7 +2333,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM,
AuthQuery::Privilege::MODULE_READ, AuthQuery::Privilege::MODULE_WRITE,
AuthQuery::Privilege::WEBSOCKET};
AuthQuery::Privilege::WEBSOCKET, AuthQuery::Privilege::TRANSACTION_MANAGEMENT};
cpp<#
(lcp:define-class info-query (query)
@ -2661,6 +2661,26 @@ cpp<#
(:serialize (:slk))
(:clone))
(lcp:define-class transaction-queue-query (query)
((action "Action" :scope :public)
(transaction_id_list "std::vector<Expression*>" :scope :public))
(:public
(lcp:define-enum action
(show-transactions terminate-transactions)
(:serialize))
#>cpp
TransactionQueueQuery() = default;
DEFVISITABLE(QueryVisitor<void>);
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(lcp:define-class version-query (query) ()
(:public
#>cpp


@ -95,6 +95,7 @@ class SettingQuery;
class VersionQuery;
class Foreach;
class ShowConfigQuery;
class TransactionQueueQuery;
class Exists;
using TreeCompositeVisitor = utils::CompositeVisitor<
@ -127,9 +128,10 @@ class ExpressionVisitor
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch, Exists> {};
template <class TResult>
class QueryVisitor : public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery,
InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery,
FreeMemoryQuery, TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery,
StreamQuery, SettingQuery, VersionQuery, ShowConfigQuery> {};
class QueryVisitor
: public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, TransactionQueueQuery,
VersionQuery, ShowConfigQuery> {};
} // namespace memgraph::query


@ -11,8 +11,10 @@
#include "query/frontend/ast/cypher_main_visitor.hpp"
#include <support/Any.h>
#include <tree/ParseTreeVisitor.h>
#include <algorithm>
#include <any>
#include <climits>
#include <codecvt>
#include <cstring>
@ -631,6 +633,7 @@ void GetTopicNames(auto &destination, MemgraphCypher::TopicNamesContext *topic_n
destination = std::any_cast<Expression *>(topic_names_ctx->accept(&visitor));
}
}
} // namespace
antlrcpp::Any CypherMainVisitor::visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) {
@ -883,6 +886,34 @@ antlrcpp::Any CypherMainVisitor::visitShowSettings(MemgraphCypher::ShowSettingsC
return setting_query;
}
antlrcpp::Any CypherMainVisitor::visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) {
MG_ASSERT(ctx->children.size() == 1, "TransactionQueueQuery should have exactly one child!");
auto *transaction_queue_query = std::any_cast<TransactionQueueQuery *>(ctx->children[0]->accept(this));
query_ = transaction_queue_query;
return transaction_queue_query;
}
antlrcpp::Any CypherMainVisitor::visitShowTransactions(MemgraphCypher::ShowTransactionsContext * /*ctx*/) {
auto *transaction_shower = storage_->Create<TransactionQueueQuery>();
transaction_shower->action_ = TransactionQueueQuery::Action::SHOW_TRANSACTIONS;
return transaction_shower;
}
antlrcpp::Any CypherMainVisitor::visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) {
auto *terminator = storage_->Create<TransactionQueueQuery>();
terminator->action_ = TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS;
terminator->transaction_id_list_ = std::any_cast<std::vector<Expression *>>(ctx->transactionIdList()->accept(this));
return terminator;
}
antlrcpp::Any CypherMainVisitor::visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) {
std::vector<Expression *> transaction_ids;
for (auto *transaction_id : ctx->transactionId()) {
transaction_ids.push_back(std::any_cast<Expression *>(transaction_id->accept(this)));
}
return transaction_ids;
}
antlrcpp::Any CypherMainVisitor::visitVersionQuery(MemgraphCypher::VersionQueryContext * /*ctx*/) {
auto *version_query = storage_->Create<VersionQuery>();
query_ = version_query;
@ -1451,6 +1482,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext
if (ctx->MODULE_READ()) return AuthQuery::Privilege::MODULE_READ;
if (ctx->MODULE_WRITE()) return AuthQuery::Privilege::MODULE_WRITE;
if (ctx->WEBSOCKET()) return AuthQuery::Privilege::WEBSOCKET;
if (ctx->TRANSACTION_MANAGEMENT()) return AuthQuery::Privilege::TRANSACTION_MANAGEMENT;
LOG_FATAL("Should not get here - unknown privilege!");
}


@ -358,6 +358,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override;
/**
* @return TransactionQueueQuery*
*/
antlrcpp::Any visitTransactionQueueQuery(MemgraphCypher::TransactionQueueQueryContext *ctx) override;
/**
* @return ShowTransactions*
*/
antlrcpp::Any visitShowTransactions(MemgraphCypher::ShowTransactionsContext *ctx) override;
/**
* @return TerminateTransactions*
*/
antlrcpp::Any visitTerminateTransactions(MemgraphCypher::TerminateTransactionsContext *ctx) override;
/**
* @return TransactionIdList*
*/
antlrcpp::Any visitTransactionIdList(MemgraphCypher::TransactionIdListContext *ctx) override;
/**
* @return VersionQuery*
*/


@ -102,6 +102,8 @@ memgraphCypherKeyword : cypherKeyword
| USER
| USERS
| VERSION
| TERMINATE
| TRANSACTIONS
;
symbolicName : UnescapedSymbolicName
@ -127,6 +129,7 @@ query : cypherQuery
| settingQuery
| versionQuery
| showConfigQuery
| transactionQueueQuery
;
authQuery : createRole
@ -197,6 +200,14 @@ settingQuery : setSetting
| showSettings
;
transactionQueueQuery : showTransactions
| terminateTransactions
;
showTransactions : SHOW TRANSACTIONS ;
terminateTransactions : TERMINATE TRANSACTIONS transactionIdList;
loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER
( IGNORE BAD ) ?
( DELIMITER delimiter ) ?
@ -259,6 +270,7 @@ privilege : CREATE
| MODULE_READ
| MODULE_WRITE
| WEBSOCKET
| TRANSACTION_MANAGEMENT
;
granularPrivilege : NOTHING | READ | UPDATE | CREATE_DELETE ;
@ -402,3 +414,7 @@ showSettings : SHOW DATABASE SETTINGS ;
showConfigQuery : SHOW CONFIG ;
versionQuery : SHOW VERSION ;
transactionIdList : transactionId ( ',' transactionId )* ;
transactionId : literal ;


@ -53,6 +53,7 @@ DIRECTORY : D I R E C T O R Y ;
DROP : D R O P ;
DUMP : D U M P ;
DURABILITY : D U R A B I L I T Y ;
EDGE_TYPES : E D G E UNDERSCORE T Y P E S ;
EXECUTE : E X E C U T E ;
FOR : F O R ;
FOREACH : F O R E A C H;
@ -103,10 +104,13 @@ STOP : S T O P ;
STREAM : S T R E A M ;
STREAMS : S T R E A M S ;
SYNC : S Y N C ;
TERMINATE : T E R M I N A T E ;
TIMEOUT : T I M E O U T ;
TO : T O ;
TOPICS : T O P I C S;
TRANSACTION : T R A N S A C T I O N ;
TRANSACTION_MANAGEMENT : T R A N S A C T I O N UNDERSCORE M A N A G E M E N T ;
TRANSACTIONS : T R A N S A C T I O N S ;
TRANSFORM : T R A N S F O R M ;
TRIGGER : T R I G G E R ;
TRIGGERS : T R I G G E R S ;
@ -117,4 +121,3 @@ USER : U S E R ;
USERS : U S E R S ;
VERSION : V E R S I O N ;
WEBSOCKET : W E B S O C K E T ;
EDGE_TYPES : E D G E UNDERSCORE T Y P E S ;


@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -80,6 +80,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); }
void Visit(TransactionQueueQuery & /*transaction_queue_query*/) override {}
void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); }
bool PreVisit(Create & /*unused*/) override {


@ -0,0 +1,18 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/frontend/semantic/symbol.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph {
constexpr utils::TypeInfo query::Symbol::kType{utils::TypeId::SYMBOL, "Symbol", nullptr};
} // namespace memgraph


@ -0,0 +1,75 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <string>
#include "utils/typeinfo.hpp"
namespace memgraph {
namespace query {
class Symbol {
public:
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
enum class Type { ANY, VERTEX, EDGE, PATH, NUMBER, EDGE_LIST };
// TODO: Generate enum to string conversion from LCP. Note that this is
// displayed to the end user, so we may want to have a pretty name of each
// value.
static std::string TypeToString(Type type) {
const char *enum_string[] = {"Any", "Vertex", "Edge", "Path", "Number", "EdgeList"};
return enum_string[static_cast<int>(type)];
}
Symbol() {}
Symbol(const std::string &name, int position, bool user_declared, Type type = Type::ANY, int token_position = -1)
: name_(name), position_(position), user_declared_(user_declared), type_(type), token_position_(token_position) {}
bool operator==(const Symbol &other) const {
return position_ == other.position_ && name_ == other.name_ && type_ == other.type_;
}
bool operator!=(const Symbol &other) const { return !operator==(other); }
// TODO: Remove these since members are public
const auto &name() const { return name_; }
int position() const { return position_; }
Type type() const { return type_; }
bool user_declared() const { return user_declared_; }
int token_position() const { return token_position_; }
std::string name_;
int64_t position_;
bool user_declared_{true};
memgraph::query::Symbol::Type type_{Type::ANY};
int64_t token_position_{-1};
};
} // namespace query
} // namespace memgraph
namespace std {
template <>
struct hash<memgraph::query::Symbol> {
size_t operator()(const memgraph::query::Symbol &symbol) const {
size_t prime = 265443599u;
size_t hash = std::hash<int>{}(symbol.position());
hash ^= prime * std::hash<std::string>{}(symbol.name());
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
return hash;
}
};
} // namespace std
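Aside (not part of the diff): a minimal usage sketch of the hash specialization above, assuming only the header shown here; the helper function name is hypothetical.
// Sketch: with std::hash<memgraph::query::Symbol> specialized, symbols can be used
// directly as keys in unordered containers; hashing and operator== both consider
// position, name and type.
#include <unordered_map>
#include "query/frontend/semantic/symbol.hpp"
void SketchSymbolAsKey() {
  using memgraph::query::Symbol;
  std::unordered_map<Symbol, int> usages;
  Symbol n("n", /*position=*/0, /*user_declared=*/true, Symbol::Type::VERTEX);
  ++usages[n];  // looked up via the hash/equality pair defined above
}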

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,9 +18,12 @@
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <optional>
#include <thread>
#include <unordered_map>
#include <utility>
#include <variant>
#include "auth/models.hpp"
@ -59,6 +62,7 @@
#include "utils/logging.hpp"
#include "utils/memory.hpp"
#include "utils/memory_tracker.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/readable_size.hpp"
#include "utils/settings.hpp"
#include "utils/string.hpp"
@ -975,7 +979,8 @@ struct PullPlanVector {
struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
@ -1004,8 +1009,8 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector,
const std::optional<size_t> memory_limit)
std::optional<std::string> username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory),
@ -1025,6 +1030,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
}
ctx_.is_shutting_down = &interpreter_context->is_shutting_down;
ctx_.transaction_status = transaction_status;
ctx_.is_profile_query = is_profile_query;
ctx_.trigger_context_collector = trigger_context_collector;
}
@ -1137,12 +1143,14 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
if (in_explicit_transaction_) {
throw ExplicitTransactionUsageException("Nested transactions are not supported.");
}
in_explicit_transaction_ = true;
expect_rollback_ = false;
db_accessor_ =
std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get());
transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
if (interpreter_context_->trigger_store.HasTriggers()) {
trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes());
@ -1194,7 +1202,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username,
const std::string *username, std::atomic<TransactionStatus> *transaction_status,
TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
@ -1239,9 +1247,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), trigger_context_collector, memory_limit);
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username), transaction_status,
trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
@ -1301,8 +1309,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::string *username) {
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
std::atomic<TransactionStatus> *transaction_status) {
const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1363,13 +1371,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr)](
pull_plan = std::shared_ptr<PullPlanVector>(nullptr), transaction_status](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, optional_username, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
stats_and_total_time =
PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory,
optional_username, transaction_status, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
@ -1524,7 +1533,8 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) {
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
std::atomic<TransactionStatus> *transaction_status) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
@ -1545,7 +1555,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
0.0, AstStorage{}, symbol_table));
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username));
execution_memory, StringPointerToOptional(username), transaction_status);
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
@ -1558,7 +1568,7 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
RWType::NONE};
}
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba) {
if (in_explicit_transaction) {
@ -1586,7 +1596,7 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, const bool in_ex
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba) {
if (in_explicit_transaction) {
throw LockPathModificationInMulticommandTxException();
@ -1615,7 +1625,7 @@ PreparedQuery PrepareLockPathQuery(ParsedQuery parsed_query, const bool in_expli
RWType::NONE};
}
PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context) {
if (in_explicit_transaction) {
throw FreeMemoryModificationInMulticommandTxException();
@ -1632,7 +1642,7 @@ PreparedQuery PrepareFreeMemoryQuery(ParsedQuery parsed_query, const bool in_exp
RWType::NONE};
}
PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) {
PreparedQuery PrepareShowConfigQuery(ParsedQuery parsed_query, bool in_explicit_transaction) {
if (in_explicit_transaction) {
throw ShowConfigModificationInMulticommandTxException();
}
@ -1736,7 +1746,7 @@ Callback ShowTriggers(InterpreterContext *interpreter_context) {
}};
}
PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba, const std::map<std::string, storage::PropertyValue> &user_parameters,
const std::string *username) {
@ -1786,7 +1796,7 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> & /*user_parameters*/,
@ -1828,7 +1838,7 @@ constexpr auto ToStorageIsolationLevel(const IsolationLevelQuery::IsolationLevel
}
}
PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
PreparedQuery PrepareIsolationLevelQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
InterpreterContext *interpreter_context, Interpreter *interpreter) {
if (in_explicit_transaction) {
throw IsolationLevelModificationInMulticommandTxException();
@ -1883,7 +1893,7 @@ PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_expli
RWType::NONE};
}
PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, DbAccessor *dba) {
PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_transaction, DbAccessor *dba) {
if (in_explicit_transaction) {
throw SettingConfigInMulticommandTxException{};
}
@ -1909,7 +1919,155 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explic
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, const bool in_explicit_transaction) {
std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::ShowTransactions(
const std::unordered_set<Interpreter *> &interpreters, const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege) {
std::vector<std::vector<TypedValue>> results;
results.reserve(interpreters.size());
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
    // If we are only checking the status, commit and abort should wait for the check to finish;
    // interpreters that have already started committing or rolling back are skipped.
if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) {
continue;
}
utils::OnScopeExit clean_status([interpreter]() {
interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
});
std::optional<uint64_t> transaction_id = interpreter->GetTransactionId();
if (transaction_id.has_value() && (interpreter->username_ == username || hasTransactionManagementPrivilege)) {
const auto &typed_queries = interpreter->GetQueries();
results.push_back({TypedValue(interpreter->username_.value_or("")),
TypedValue(std::to_string(transaction_id.value())), TypedValue(typed_queries)});
}
}
return results;
}
std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::KillTransactions(
InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids,
const std::optional<std::string> &username, bool hasTransactionManagementPrivilege) {
std::vector<std::vector<TypedValue>> results;
for (const std::string &transaction_id : maybe_kill_transaction_ids) {
bool killed = false;
bool transaction_found = false;
// Multiple simultaneous TERMINATE TRANSACTIONS aren't allowed
// TERMINATE and SHOW TRANSACTIONS are mutually exclusive
interpreter_context->interpreters.WithLock([&transaction_id, &killed, &transaction_found, username,
hasTransactionManagementPrivilege](const auto &interpreters) {
for (Interpreter *interpreter : interpreters) {
TransactionStatus alive_status = TransactionStatus::ACTIVE;
        // If we are only checking whether to kill, commit and abort should wait for the check to finish.
        // The kill check can start only while the transaction_status is ACTIVE.
if (!interpreter->transaction_status_.compare_exchange_strong(alive_status, TransactionStatus::VERIFYING)) {
continue;
}
utils::OnScopeExit clean_status([interpreter, &killed]() {
if (killed) {
interpreter->transaction_status_.store(TransactionStatus::TERMINATED, std::memory_order_release);
} else {
interpreter->transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
}
});
std::optional<uint64_t> intr_trans = interpreter->GetTransactionId();
if (intr_trans.has_value() && std::to_string(intr_trans.value()) == transaction_id) {
transaction_found = true;
if (interpreter->username_ == username || hasTransactionManagementPrivilege) {
killed = true;
spdlog::warn("Transaction {} successfully killed", transaction_id);
} else {
spdlog::warn("Not enough rights to kill the transaction");
}
break;
}
}
});
if (!transaction_found) {
spdlog::warn("Transaction {} not found", transaction_id);
}
results.push_back({TypedValue(transaction_id), TypedValue(killed)});
}
return results;
}
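Both handlers above rely on the same compare-exchange guard on the transaction status; a condensed sketch of that guard follows (the free function is illustrative only, the enum values come from this change, and the enum is assumed to be visible through the interpreter headers).
#include <atomic>
// Only an ACTIVE transaction may be inspected; transactions that have already started
// committing or rolling back are skipped by the caller when this returns false.
bool TryEnterVerifying(std::atomic<memgraph::query::TransactionStatus> &status) {
  using memgraph::query::TransactionStatus;
  auto expected = TransactionStatus::ACTIVE;
  return status.compare_exchange_strong(expected, TransactionStatus::VERIFYING);
}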
Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
const std::optional<std::string> &username, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized(
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT});
Callback callback;
switch (transaction_query->action_) {
case TransactionQueueQuery::Action::SHOW_TRANSACTIONS: {
callback.header = {"username", "transaction_id", "query"};
callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, username,
hasTransactionManagementPrivilege]() mutable {
std::vector<std::vector<TypedValue>> results;
// Multiple simultaneous SHOW TRANSACTIONS aren't allowed
interpreter_context->interpreters.WithLock(
[&results, handler, username, hasTransactionManagementPrivilege](const auto &interpreters) {
results = handler.ShowTransactions(interpreters, username, hasTransactionManagementPrivilege);
});
return results;
};
break;
}
case TransactionQueueQuery::Action::TERMINATE_TRANSACTIONS: {
std::vector<std::string> maybe_kill_transaction_ids;
std::transform(transaction_query->transaction_id_list_.begin(), transaction_query->transaction_id_list_.end(),
std::back_inserter(maybe_kill_transaction_ids), [&evaluator](Expression *expression) {
return std::string(expression->Accept(evaluator).ValueString());
});
callback.header = {"transaction_id", "killed"};
callback.fn = [handler = TransactionQueueQueryHandler(), interpreter_context, maybe_kill_transaction_ids,
username, hasTransactionManagementPrivilege]() mutable {
return handler.KillTransactions(interpreter_context, maybe_kill_transaction_ids, username,
hasTransactionManagementPrivilege);
};
break;
}
}
return callback;
}
PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::optional<std::string> &username,
bool in_explicit_transaction, InterpreterContext *interpreter_context,
DbAccessor *dba) {
if (in_explicit_transaction) {
throw TransactionQueueInMulticommandTxException();
}
auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query);
MG_ASSERT(transaction_queue_query);
auto callback =
HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context, dba);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (UNLIKELY(!pull_plan)) {
pull_plan = std::make_shared<PullPlanVector>(callback_fn());
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE};
}
PreparedQuery PrepareVersionQuery(ParsedQuery parsed_query, bool in_explicit_transaction) {
if (in_explicit_transaction) {
throw VersionInfoInMulticommandTxException();
}
@ -2263,6 +2421,13 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
RWType::NONE};
}
std::optional<uint64_t> Interpreter::GetTransactionId() const {
if (db_accessor_) {
return db_accessor_->GetTransactionId();
}
return {};
}
void Interpreter::BeginTransaction() {
const auto prepared_query = PrepareTransactionQuery("BEGIN");
prepared_query.query_handler(nullptr, {});
@ -2272,12 +2437,14 @@ void Interpreter::CommitTransaction() {
const auto prepared_query = PrepareTransactionQuery("COMMIT");
prepared_query.query_handler(nullptr, {});
query_executions_.clear();
transaction_queries_->clear();
}
void Interpreter::RollbackTransaction() {
const auto prepared_query = PrepareTransactionQuery("ROLLBACK");
prepared_query.query_handler(nullptr, {});
query_executions_.clear();
transaction_queries_->clear();
}
Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
@ -2285,10 +2452,17 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
const std::string *username) {
if (!in_explicit_transaction_) {
query_executions_.clear();
transaction_queries_->clear();
}
  // The username is saved here so the transaction queue query handler can forward it to the
  // kill and show transactions handlers.
std::optional<std::string> user = StringPointerToOptional(username);
username_ = user;
query_executions_.emplace_back(std::make_unique<QueryExecution>());
auto &query_execution = query_executions_.back();
std::optional<int> qid =
in_explicit_transaction_ ? static_cast<int>(query_executions_.size() - 1) : std::optional<int>{};
@ -2302,6 +2476,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
return {query_execution->prepared_query->header, query_execution->prepared_query->privileges, qid};
}
// Don't save BEGIN, COMMIT or ROLLBACK
transaction_queries_->push_back(query_string);
// All queries other than transaction control queries advance the command in
// an explicit transaction block.
if (in_explicit_transaction_) {
@ -2327,10 +2504,12 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (!in_explicit_transaction_ &&
(utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) ||
utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) ||
utils::Downcast<TriggerQuery>(parsed_query.query))) {
utils::Downcast<TriggerQuery>(parsed_query.query) ||
utils::Downcast<TransactionQueueQuery>(parsed_query.query))) {
db_accessor_ =
std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access(GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get());
transaction_status_.store(TransactionStatus::ACTIVE, std::memory_order_release);
if (utils::Downcast<CypherQuery>(parsed_query.query) && interpreter_context_->trigger_store.HasTriggers()) {
trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes());
@ -2343,15 +2522,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications, username,
&query_execution->notifications, username, &transaction_status_,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username);
prepared_query = PrepareProfileQuery(
std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory);
@ -2359,9 +2538,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->notifications, interpreter_context_);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username);
prepared_query = PrepareAuthQuery(
std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception, username, &transaction_status_);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db,
@ -2398,6 +2577,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_);
} else if (utils::Downcast<VersionQuery>(parsed_query.query)) {
prepared_query = PrepareVersionQuery(std::move(parsed_query), in_explicit_transaction_);
} else if (utils::Downcast<TransactionQueueQuery>(parsed_query.query)) {
prepared_query = PrepareTransactionQueueQuery(std::move(parsed_query), username_, in_explicit_transaction_,
interpreter_context_, &*execution_db_accessor_);
} else {
LOG_FATAL("Should not get here -- unknown query type!");
}
@ -2425,7 +2607,29 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
}
}
std::vector<TypedValue> Interpreter::GetQueries() {
auto typed_queries = std::vector<TypedValue>();
transaction_queries_.WithLock([&typed_queries](const auto &transaction_queries) {
std::for_each(transaction_queries.begin(), transaction_queries.end(),
[&typed_queries](const auto &query) { typed_queries.emplace_back(query); });
});
return typed_queries;
}
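For reference, transaction_queries_ is a utils::Synchronized container; a minimal sketch of the two access patterns used in this diff (operator-> locks for a single call, WithLock locks for a whole block); the wrapper function name is hypothetical.
#include <string>
#include <vector>
#include <spdlog/spdlog.h>
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
void SketchSynchronizedQueries() {
  memgraph::utils::Synchronized<std::vector<std::string>, memgraph::utils::SpinLock> queries;
  queries->push_back("MATCH (n) RETURN n;");  // operator-> locks for this single call
  queries.WithLock([](const auto &qs) {       // WithLock holds the lock for the whole lambda
    for (const auto &q : qs) spdlog::info("transaction query: {}", q);
  });
  queries->clear();
}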
void Interpreter::Abort() {
auto expected = TransactionStatus::ACTIVE;
while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_ROLLBACK)) {
if (expected == TransactionStatus::TERMINATED || expected == TransactionStatus::IDLE) {
transaction_status_.store(TransactionStatus::STARTED_ROLLBACK);
break;
}
expected = TransactionStatus::ACTIVE;
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
utils::OnScopeExit clean_status(
[this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); });
expect_rollback_ = false;
in_explicit_transaction_ = false;
if (!db_accessor_) return;
@ -2437,7 +2641,7 @@ void Interpreter::Abort() {
namespace {
void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context,
TriggerContext trigger_context) {
TriggerContext trigger_context, std::atomic<TransactionStatus> *transaction_status) {
// Run the triggers
for (const auto &trigger : triggers.access()) {
utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize};
@ -2449,7 +2653,8 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret
trigger_context.AdaptForAccessor(&db_accessor);
try {
trigger.Execute(&db_accessor, &execution_memory, interpreter_context->config.execution_timeout_sec,
&interpreter_context->is_shutting_down, trigger_context, interpreter_context->auth_checker);
&interpreter_context->is_shutting_down, transaction_status, trigger_context,
interpreter_context->auth_checker);
} catch (const utils::BasicException &exception) {
spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.Name(), exception.what());
db_accessor.Abort();
@ -2504,6 +2709,25 @@ void Interpreter::Commit() {
// a query.
if (!db_accessor_) return;
/*
    At this point the transaction must be ACTIVE for committing to start. The only other possible state is
    VERIFYING, in which case we must check whether the transaction was terminated and, if so, abort the
    commit. Throwing an exception suffices for that.
*/
auto expected = TransactionStatus::ACTIVE;
while (!transaction_status_.compare_exchange_weak(expected, TransactionStatus::STARTED_COMMITTING)) {
if (expected == TransactionStatus::TERMINATED) {
throw memgraph::utils::BasicException(
"Aborting transaction commit because the transaction was requested to stop from other session. ");
}
expected = TransactionStatus::ACTIVE;
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
  // Reset the transaction status on scope exit, even if something goes wrong
utils::OnScopeExit clean_status(
[this]() { transaction_status_.store(TransactionStatus::IDLE, std::memory_order_release); });
std::optional<TriggerContext> trigger_context = std::nullopt;
if (trigger_context_collector_) {
trigger_context.emplace(std::move(*trigger_context_collector_).TransformToTriggerContext());
@ -2517,7 +2741,8 @@ void Interpreter::Commit() {
AdvanceCommand();
try {
trigger.Execute(&*execution_db_accessor_, &execution_memory, interpreter_context_->config.execution_timeout_sec,
&interpreter_context_->is_shutting_down, *trigger_context, interpreter_context_->auth_checker);
&interpreter_context_->is_shutting_down, &transaction_status_, *trigger_context,
interpreter_context_->auth_checker);
} catch (const utils::BasicException &e) {
throw utils::BasicException(
fmt::format("Trigger '{}' caused the transaction to fail.\nException: {}", trigger.Name(), e.what()));
@ -2579,10 +2804,10 @@ void Interpreter::Commit() {
  // This means the ordered execution of after-commit triggers is not guaranteed.
if (trigger_context && interpreter_context_->trigger_store.AfterCommitTriggers().size() > 0) {
interpreter_context_->after_commit_trigger_pool.AddTask(
[trigger_context = std::move(*trigger_context), interpreter_context = this->interpreter_context_,
[this, trigger_context = std::move(*trigger_context),
user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable {
RunTriggersIndividually(interpreter_context->trigger_store.AfterCommitTriggers(), interpreter_context,
std::move(trigger_context));
RunTriggersIndividually(this->interpreter_context_->trigger_store.AfterCommitTriggers(),
this->interpreter_context_, std::move(trigger_context), &this->transaction_status_);
user_transaction->FinalizeTransaction();
SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name)
});
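Both Abort() and Commit() above claim the transaction with the same spin-until-ownership loop against a concurrent VERIFYING watcher; a condensed sketch, assuming the TransactionStatus enum introduced in this change (the helper name is illustrative, not code from the diff).
#include <atomic>
#include <chrono>
#include <thread>
// Spin until the transaction can be claimed for committing. Observing TERMINATED here
// means another session asked this transaction to stop, so the caller should abort.
bool ClaimForCommit(std::atomic<memgraph::query::TransactionStatus> &status) {
  using memgraph::query::TransactionStatus;
  auto expected = TransactionStatus::ACTIVE;
  while (!status.compare_exchange_weak(expected, TransactionStatus::STARTED_COMMITTING)) {
    if (expected == TransactionStatus::TERMINATED) return false;
    expected = TransactionStatus::ACTIVE;
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  return true;
}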

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,6 +11,8 @@
#pragma once
#include <unordered_set>
#include <gflags/gflags.h>
#include "query/auth_checker.hpp"
@ -37,6 +39,7 @@
#include "utils/settings.hpp"
#include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
#include "utils/thread_pool.hpp"
#include "utils/timer.hpp"
#include "utils/tsc.hpp"
@ -179,12 +182,12 @@ struct PreparedQuery {
plan::ReadWriteTypeChecker::RWType rw_type;
};
class Interpreter;
/**
* Holds data shared between multiple `Interpreter` instances (which might be
* running concurrently).
*
* Users should initialize the context but should not modify it after it has
* been passed to an `Interpreter` instance.
*/
struct InterpreterContext {
explicit InterpreterContext(storage::Storage *db, InterpreterConfig config,
@ -214,6 +217,7 @@ struct InterpreterContext {
const InterpreterConfig config;
query::stream::Streams streams;
utils::Synchronized<std::unordered_set<Interpreter *>, utils::SpinLock> interpreters;
};
/// Function that is used to tell all active interpreters that they should stop
@ -235,6 +239,10 @@ class Interpreter final {
std::optional<int> qid;
};
std::optional<std::string> username_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
/**
* Prepare a query for execution.
*
@ -290,6 +298,11 @@ class Interpreter final {
void BeginTransaction();
/*
  Returns the transaction id, or an empty optional if the db_accessor is not initialized.
*/
std::optional<uint64_t> GetTransactionId() const;
void CommitTransaction();
void RollbackTransaction();
@ -297,11 +310,15 @@ class Interpreter final {
void SetNextTransactionIsolationLevel(storage::IsolationLevel isolation_level);
void SetSessionIsolationLevel(storage::IsolationLevel isolation_level);
std::vector<TypedValue> GetQueries();
/**
* Abort the current multicommand transaction.
*/
void Abort();
std::atomic<TransactionStatus> transaction_status_{TransactionStatus::IDLE};
private:
struct QueryExecution {
std::optional<PreparedQuery> prepared_query;
@ -338,6 +355,8 @@ class Interpreter final {
// and deletion of a single query execution, i.e. when a query finishes,
// we reset the corresponding unique_ptr.
std::vector<std::unique_ptr<QueryExecution>> query_executions_;
// all queries that are run as part of the current transaction
utils::Synchronized<std::vector<std::string>, utils::SpinLock> transaction_queries_;
InterpreterContext *interpreter_context_;
@ -347,8 +366,6 @@ class Interpreter final {
std::unique_ptr<storage::Storage::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_;
std::optional<TriggerContextCollector> trigger_context_collector_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
std::optional<storage::IsolationLevel> interpreter_isolation_level;
std::optional<storage::IsolationLevel> next_transaction_isolation_level;
@ -365,12 +382,32 @@ class Interpreter final {
}
};
class TransactionQueueQueryHandler {
public:
TransactionQueueQueryHandler() = default;
virtual ~TransactionQueueQueryHandler() = default;
TransactionQueueQueryHandler(const TransactionQueueQueryHandler &) = default;
TransactionQueueQueryHandler &operator=(const TransactionQueueQueryHandler &) = default;
TransactionQueueQueryHandler(TransactionQueueQueryHandler &&) = default;
TransactionQueueQueryHandler &operator=(TransactionQueueQueryHandler &&) = default;
static std::vector<std::vector<TypedValue>> ShowTransactions(const std::unordered_set<Interpreter *> &interpreters,
const std::optional<std::string> &username,
bool hasTransactionManagementPrivilege);
static std::vector<std::vector<TypedValue>> KillTransactions(
InterpreterContext *interpreter_context, const std::vector<std::string> &maybe_kill_transaction_ids,
const std::optional<std::string> &username, bool hasTransactionManagementPrivilege);
};
template <typename TStream>
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
std::optional<int> qid) {
MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
if (qid_value < 0 || qid_value >= query_executions_.size()) {
throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
}
@ -430,6 +467,7 @@ std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std:
// methods as we will delete summary contained in them which we need
// after our query finished executing.
query_executions_.clear();
transaction_queries_->clear();
} else {
// We can only clear this execution as some of the queries
// in the transaction can be in unfinished state

2296 src/query/plan/operator.hpp Normal file

File diff suppressed because it is too large

View File

@ -0,0 +1,148 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <cstdint>
#include "query/plan/operator.hpp"
namespace memgraph {
constexpr utils::TypeInfo query::plan::LogicalOperator::kType{utils::TypeId::LOGICAL_OPERATOR, "LogicalOperator",
nullptr};
constexpr utils::TypeInfo query::plan::Once::kType{utils::TypeId::ONCE, "Once", &query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::NodeCreationInfo::kType{utils::TypeId::NODE_CREATION_INFO, "NodeCreationInfo",
nullptr};
constexpr utils::TypeInfo query::plan::CreateNode::kType{utils::TypeId::CREATE_NODE, "CreateNode",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::EdgeCreationInfo::kType{utils::TypeId::EDGE_CREATION_INFO, "EdgeCreationInfo",
nullptr};
constexpr utils::TypeInfo query::plan::CreateExpand::kType{utils::TypeId::CREATE_EXPAND, "CreateExpand",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::ScanAll::kType{utils::TypeId::SCAN_ALL, "ScanAll",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::ScanAllByLabel::kType{utils::TypeId::SCAN_ALL_BY_LABEL, "ScanAllByLabel",
&query::plan::ScanAll::kType};
constexpr utils::TypeInfo query::plan::ScanAllByLabelPropertyRange::kType{
utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY_RANGE, "ScanAllByLabelPropertyRange", &query::plan::ScanAll::kType};
constexpr utils::TypeInfo query::plan::ScanAllByLabelPropertyValue::kType{
utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY_VALUE, "ScanAllByLabelPropertyValue", &query::plan::ScanAll::kType};
constexpr utils::TypeInfo query::plan::ScanAllByLabelProperty::kType{
utils::TypeId::SCAN_ALL_BY_LABEL_PROPERTY, "ScanAllByLabelProperty", &query::plan::ScanAll::kType};
constexpr utils::TypeInfo query::plan::ScanAllById::kType{utils::TypeId::SCAN_ALL_BY_ID, "ScanAllById",
&query::plan::ScanAll::kType};
constexpr utils::TypeInfo query::plan::ExpandCommon::kType{utils::TypeId::EXPAND_COMMON, "ExpandCommon", nullptr};
constexpr utils::TypeInfo query::plan::Expand::kType{utils::TypeId::EXPAND, "Expand",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::ExpansionLambda::kType{utils::TypeId::EXPANSION_LAMBDA, "ExpansionLambda",
nullptr};
constexpr utils::TypeInfo query::plan::ExpandVariable::kType{utils::TypeId::EXPAND_VARIABLE, "ExpandVariable",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::ConstructNamedPath::kType{
utils::TypeId::CONSTRUCT_NAMED_PATH, "ConstructNamedPath", &query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Filter::kType{utils::TypeId::FILTER, "Filter",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Produce::kType{utils::TypeId::PRODUCE, "Produce",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Delete::kType{utils::TypeId::DELETE, "Delete",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::SetProperty::kType{utils::TypeId::SET_PROPERTY, "SetProperty",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::SetProperties::kType{utils::TypeId::SET_PROPERTIES, "SetProperties",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::SetLabels::kType{utils::TypeId::SET_LABELS, "SetLabels",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::RemoveProperty::kType{utils::TypeId::REMOVE_PROPERTY, "RemoveProperty",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::RemoveLabels::kType{utils::TypeId::REMOVE_LABELS, "RemoveLabels",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::EdgeUniquenessFilter::kType{
utils::TypeId::EDGE_UNIQUENESS_FILTER, "EdgeUniquenessFilter", &query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::EmptyResult::kType{utils::TypeId::EMPTY_RESULT, "EmptyResult",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Accumulate::kType{utils::TypeId::ACCUMULATE, "Accumulate",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Aggregate::Element::kType{utils::TypeId::AGGREGATE_ELEMENT, "Element", nullptr};
constexpr utils::TypeInfo query::plan::Aggregate::kType{utils::TypeId::AGGREGATE, "Aggregate",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Skip::kType{utils::TypeId::SKIP, "Skip", &query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::EvaluatePatternFilter::kType{
utils::TypeId::EVALUATE_PATTERN_FILTER, "EvaluatePatternFilter", &query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Limit::kType{utils::TypeId::LIMIT, "Limit",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::OrderBy::kType{utils::TypeId::ORDERBY, "OrderBy",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Merge::kType{utils::TypeId::MERGE, "Merge",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Optional::kType{utils::TypeId::OPTIONAL, "Optional",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Unwind::kType{utils::TypeId::UNWIND, "Unwind",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Distinct::kType{utils::TypeId::DISTINCT, "Distinct",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Union::kType{utils::TypeId::UNION, "Union",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Cartesian::kType{utils::TypeId::CARTESIAN, "Cartesian",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::OutputTable::kType{utils::TypeId::OUTPUT_TABLE, "OutputTable",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::OutputTableStream::kType{utils::TypeId::OUTPUT_TABLE_STREAM, "OutputTableStream",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::CallProcedure::kType{utils::TypeId::CALL_PROCEDURE, "CallProcedure",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::LoadCsv::kType{utils::TypeId::LOAD_CSV, "LoadCsv",
&query::plan::LogicalOperator::kType};
constexpr utils::TypeInfo query::plan::Foreach::kType{utils::TypeId::FOREACH, "Foreach",
&query::plan::LogicalOperator::kType};
} // namespace memgraph

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -490,6 +490,11 @@ Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std
retry_interval = interpreter_context_->config.stream_transaction_retry_interval](
const std::vector<typename TStream::Message> &messages) mutable {
auto accessor = interpreter_context->db->Access();
// register new interpreter into interpreter_context_
interpreter_context->interpreters->insert(interpreter.get());
utils::OnScopeExit interpreter_cleanup{
[interpreter_context, interpreter]() { interpreter_context->interpreters->erase(interpreter.get()); }};
EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size());
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -195,7 +195,8 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory,
const double max_execution_time_sec, std::atomic<bool> *is_shutting_down,
const TriggerContext &context, const AuthChecker *auth_checker) const {
std::atomic<TransactionStatus> *transaction_status, const TriggerContext &context,
const AuthChecker *auth_checker) const {
if (!context.ShouldEventTrigger(event_type_)) {
return;
}
@ -214,6 +215,7 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution
ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba);
ctx.timer = utils::AsyncTimer(max_execution_time_sec);
ctx.is_shutting_down = is_shutting_down;
ctx.transaction_status = transaction_status;
ctx.is_profile_query = false;
// Set up temporary memory for a single Pull. Initial memory comes from the

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -31,6 +31,8 @@
#include "utils/spin_lock.hpp"
namespace memgraph::query {
enum class TransactionStatus;
struct Trigger {
explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
@ -39,8 +41,8 @@ struct Trigger {
const query::AuthChecker *auth_checker);
void Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double max_execution_time_sec,
std::atomic<bool> *is_shutting_down, const TriggerContext &context,
const AuthChecker *auth_checker) const;
std::atomic<bool> *is_shutting_down, std::atomic<TransactionStatus> *transaction_status,
const TriggerContext &context, const AuthChecker *auth_checker) const;
bool operator==(const Trigger &other) const { return name_ == other.name_; }
// NOLINTNEXTLINE (modernize-use-nullptr)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -23,6 +23,7 @@
#include "slk/streams.hpp"
#include "utils/logging.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph::rpc {
@ -84,11 +85,11 @@ class Client {
slk::Reader res_reader(self_->client_->GetData(), response_data_size);
utils::OnScopeExit res_cleanup([&, response_data_size] { self_->client_->ShiftData(response_data_size); });
uint64_t res_id = 0;
utils::TypeId res_id{utils::TypeId::UNKNOWN};
slk::Load(&res_id, &res_reader);
// Check the response ID.
if (res_id != res_type.id) {
if (res_id != res_type.id && res_id != utils::TypeId::UNKNOWN) {
spdlog::error("Message response was of unexpected type");
self_->client_ = std::nullopt;
throw RpcFailedException(self_->endpoint_);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,6 +16,7 @@
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph::rpc {
@ -41,7 +42,7 @@ void Session::Execute() {
[&](const uint8_t *data, size_t size, bool have_more) { output_stream_->Write(data, size, have_more); });
// Load the request ID.
uint64_t req_id = 0;
utils::TypeId req_id{utils::TypeId::UNKNOWN};
slk::Load(&req_id, &req_reader);
// Access to `callbacks_` and `extended_callbacks_` is done here without

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -86,8 +86,8 @@ class Server {
};
std::mutex lock_;
std::map<uint64_t, RpcCallback> callbacks_;
std::map<uint64_t, RpcExtendedCallback> extended_callbacks_;
std::map<utils::TypeId, RpcCallback> callbacks_;
std::map<utils::TypeId, RpcExtendedCallback> extended_callbacks_;
communication::Server<Session, Server> server_;
};
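Keying the callback maps by utils::TypeId (instead of a raw uint64_t) lets the server dispatch on the same kType ids the messages declare; a rough sketch of the registration, with a deliberately simplified callback signature (the real RpcCallback type is not shown in this diff).
#include <functional>
#include <map>
#include "storage/v2/replication/rpc.hpp"
#include "utils/typeinfo.hpp"
std::map<memgraph::utils::TypeId, std::function<void()>> callbacks;  // simplified signature
void SketchRegisterHeartbeat() {
  // kType.id is the same TypeId the client compares against on the response path.
  callbacks[memgraph::storage::replication::HeartbeatReq::kType.id] = [] { /* handle heartbeat */ };
}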

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -29,8 +29,10 @@
#include "slk/streams.hpp"
#include "utils/cast.hpp"
#include "utils/concepts.hpp"
#include "utils/endian.hpp"
#include "utils/exceptions.hpp"
#include "utils/typeinfo.hpp"
// The namespace name stands for SaveLoadKit. It should not be mistaken for the
// Mercedes car model line.
@ -308,6 +310,10 @@ inline void Save(const std::optional<T> &obj, Builder *builder) {
}
}
inline void Save(const utils::TypeId &obj, Builder *builder) {
Save(static_cast<std::underlying_type_t<utils::TypeId>>(obj), builder);
}
template <typename T>
inline void Load(std::optional<T> *obj, Reader *reader) {
bool exists = false;
@ -471,4 +477,12 @@ inline void Load(std::optional<T> *obj, Reader *reader, std::function<void(T *,
*obj = std::nullopt;
}
}
inline void Load(utils::TypeId *obj, Reader *reader) {
using enum_type = std::underlying_type_t<utils::TypeId>;
enum_type obj_encoded;
slk::Load(&obj_encoded, reader);
*obj = utils::TypeId(utils::MemcpyCast<enum_type>(obj_encoded));
}
} // namespace memgraph::slk
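In other words, the TypeId travels on the wire as its underlying integer; a conceptual round trip of the encoding above (the helper names are illustrative and not part of the slk API).
#include <type_traits>
#include "utils/typeinfo.hpp"
using TypeIdRepr = std::underlying_type_t<memgraph::utils::TypeId>;
// Save path: the enum is narrowed to its underlying integer representation.
TypeIdRepr EncodeTypeId(memgraph::utils::TypeId id) { return static_cast<TypeIdRepr>(id); }
// Load path: the integer is converted back into the enum, mirroring Load() above.
memgraph::utils::TypeId DecodeTypeId(TypeIdRepr raw) { return memgraph::utils::TypeId(raw); }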

View File

@ -12,12 +12,6 @@ set(storage_v2_src_files
vertex_accessor.cpp
storage.cpp)
##### Replication #####
define_add_lcp(add_lcp_storage lcp_storage_cpp_files generated_lcp_storage_files)
add_lcp_storage(replication/rpc.lcp SLK_SERIALIZE)
add_custom_target(generate_lcp_storage DEPENDS ${generated_lcp_storage_files})
set(storage_v2_src_files
${storage_v2_src_files}
@ -26,7 +20,7 @@ set(storage_v2_src_files
replication/serialization.cpp
replication/slk.cpp
replication/replication_persistence_helper.cpp
${lcp_storage_cpp_files})
replication/rpc.cpp)
#######################
find_package(gflags REQUIRED)
@ -35,5 +29,4 @@ find_package(Threads REQUIRED)
add_library(mg-storage-v2 STATIC ${storage_v2_src_files})
target_link_libraries(mg-storage-v2 Threads::Threads mg-utils gflags)
add_dependencies(mg-storage-v2 generate_lcp_storage)
target_link_libraries(mg-storage-v2 mg-rpc mg-slk)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,7 @@
#include "storage/v2/constraints.hpp"
#include <algorithm>
#include <atomic>
#include <cstring>
#include <map>
@ -71,7 +72,7 @@ bool LastCommittedVersionHasLabelProperty(const Vertex &vertex, LabelId label, c
while (delta != nullptr) {
auto ts = delta->timestamp->load(std::memory_order_acquire);
if (ts < commit_timestamp || ts == transaction.transaction_id) {
if (ts < commit_timestamp || ts == transaction.transaction_id.load(std::memory_order_acquire)) {
break;
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -11,6 +11,7 @@
#pragma once
#include <atomic>
#include "storage/v2/property_value.hpp"
#include "storage/v2/transaction.hpp"
#include "storage/v2/view.hpp"
@ -30,7 +31,7 @@ inline void ApplyDeltasForRead(Transaction *transaction, const Delta *delta, Vie
// This allows the transaction to see its changes even though it's committed.
const auto commit_timestamp = transaction->commit_timestamp
? transaction->commit_timestamp->load(std::memory_order_acquire)
: transaction->transaction_id;
: transaction->transaction_id.load(std::memory_order_acquire);
while (delta != nullptr) {
auto ts = delta->timestamp->load(std::memory_order_acquire);
auto cid = delta->command_id;
@ -80,7 +81,7 @@ inline bool PrepareForWrite(Transaction *transaction, TObj *object) {
if (object->delta == nullptr) return true;
auto ts = object->delta->timestamp->load(std::memory_order_acquire);
if (ts == transaction->transaction_id || ts < transaction->start_timestamp) {
if (ts == transaction->transaction_id.load(std::memory_order_acquire) || ts < transaction->start_timestamp) {
return true;
}

View File

@ -1,2 +0,0 @@
# autogenerated files
rpc.hpp

View File

@ -0,0 +1,263 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "storage/v2/replication/rpc.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph {
namespace storage {
namespace replication {
void AppendDeltasReq::Save(const AppendDeltasReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void AppendDeltasReq::Load(AppendDeltasReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void AppendDeltasRes::Save(const AppendDeltasRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void AppendDeltasRes::Load(AppendDeltasRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void HeartbeatReq::Save(const HeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void HeartbeatReq::Load(HeartbeatReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void HeartbeatRes::Save(const HeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void HeartbeatRes::Load(HeartbeatRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void FrequentHeartbeatReq::Save(const FrequentHeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void FrequentHeartbeatReq::Load(FrequentHeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void FrequentHeartbeatRes::Save(const FrequentHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void FrequentHeartbeatRes::Load(FrequentHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(self, reader);
}
void SnapshotReq::Save(const SnapshotReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); }
void SnapshotReq::Load(SnapshotReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void SnapshotRes::Save(const SnapshotRes &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); }
void SnapshotRes::Load(SnapshotRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void WalFilesReq::Save(const WalFilesReq &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); }
void WalFilesReq::Load(WalFilesReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void WalFilesRes::Save(const WalFilesRes &self, memgraph::slk::Builder *builder) { memgraph::slk::Save(self, builder); }
void WalFilesRes::Load(WalFilesRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void CurrentWalReq::Save(const CurrentWalReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CurrentWalReq::Load(CurrentWalReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void CurrentWalRes::Save(const CurrentWalRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void CurrentWalRes::Load(CurrentWalRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void TimestampReq::Save(const TimestampReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void TimestampReq::Load(TimestampReq *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
void TimestampRes::Save(const TimestampRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self, builder);
}
void TimestampRes::Load(TimestampRes *self, memgraph::slk::Reader *reader) { memgraph::slk::Load(self, reader); }
} // namespace replication
} // namespace storage
constexpr utils::TypeInfo storage::replication::AppendDeltasReq::kType{utils::TypeId::REP_APPEND_DELTAS_REQ,
"AppendDeltasReq", nullptr};
constexpr utils::TypeInfo storage::replication::AppendDeltasRes::kType{utils::TypeId::REP_APPEND_DELTAS_RES,
"AppendDeltasRes", nullptr};
constexpr utils::TypeInfo storage::replication::HeartbeatReq::kType{utils::TypeId::REP_HEARTBEAT_REQ, "HeartbeatReq",
nullptr};
constexpr utils::TypeInfo storage::replication::HeartbeatRes::kType{utils::TypeId::REP_HEARTBEAT_RES, "HeartbeatRes",
nullptr};
constexpr utils::TypeInfo storage::replication::FrequentHeartbeatReq::kType{utils::TypeId::REP_FREQUENT_HEARTBEAT_REQ,
"FrequentHeartbeatReq", nullptr};
constexpr utils::TypeInfo storage::replication::FrequentHeartbeatRes::kType{utils::TypeId::REP_FREQUENT_HEARTBEAT_RES,
"FrequentHeartbeatRes", nullptr};
constexpr utils::TypeInfo storage::replication::SnapshotReq::kType{utils::TypeId::REP_SNAPSHOT_REQ, "SnapshotReq",
nullptr};
constexpr utils::TypeInfo storage::replication::SnapshotRes::kType{utils::TypeId::REP_SNAPSHOT_RES, "SnapshotRes",
nullptr};
constexpr utils::TypeInfo storage::replication::WalFilesReq::kType{utils::TypeId::REP_WALFILES_REQ, "WalFilesReq",
nullptr};
constexpr utils::TypeInfo storage::replication::WalFilesRes::kType{utils::TypeId::REP_WALFILES_RES, "WalFilesRes",
nullptr};
constexpr utils::TypeInfo storage::replication::CurrentWalReq::kType{utils::TypeId::REP_CURRENT_WAL_REQ,
"CurrentWalReq", nullptr};
constexpr utils::TypeInfo storage::replication::CurrentWalRes::kType{utils::TypeId::REP_CURRENT_WAL_RES,
"CurrentWalRes", nullptr};
constexpr utils::TypeInfo storage::replication::TimestampReq::kType{utils::TypeId::REP_TIMESTAMP_REQ, "TimestampReq",
nullptr};
constexpr utils::TypeInfo storage::replication::TimestampRes::kType{utils::TypeId::REP_TIMESTAMP_RES, "TimestampRes",
nullptr};
// Autogenerated SLK serialization code
namespace slk {
// Serialize code for TimestampRes
void Save(const memgraph::storage::replication::TimestampRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
}
void Load(memgraph::storage::replication::TimestampRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
}
// Serialize code for TimestampReq
void Save(const memgraph::storage::replication::TimestampReq &self, memgraph::slk::Builder *builder) {}
void Load(memgraph::storage::replication::TimestampReq *self, memgraph::slk::Reader *reader) {}
// Serialize code for CurrentWalRes
void Save(const memgraph::storage::replication::CurrentWalRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
}
void Load(memgraph::storage::replication::CurrentWalRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
}
// Serialize code for CurrentWalReq
void Save(const memgraph::storage::replication::CurrentWalReq &self, memgraph::slk::Builder *builder) {}
void Load(memgraph::storage::replication::CurrentWalReq *self, memgraph::slk::Reader *reader) {}
// Serialize code for WalFilesRes
void Save(const memgraph::storage::replication::WalFilesRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
}
void Load(memgraph::storage::replication::WalFilesRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
}
// Serialize code for WalFilesReq
void Save(const memgraph::storage::replication::WalFilesReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.file_number, builder);
}
void Load(memgraph::storage::replication::WalFilesReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->file_number, reader);
}
// Serialize code for SnapshotRes
void Save(const memgraph::storage::replication::SnapshotRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
}
void Load(memgraph::storage::replication::SnapshotRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
}
// Serialize code for SnapshotReq
void Save(const memgraph::storage::replication::SnapshotReq &self, memgraph::slk::Builder *builder) {}
void Load(memgraph::storage::replication::SnapshotReq *self, memgraph::slk::Reader *reader) {}
// Serialize code for FrequentHeartbeatRes
void Save(const memgraph::storage::replication::FrequentHeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
}
void Load(memgraph::storage::replication::FrequentHeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
}
// Serialize code for FrequentHeartbeatReq
void Save(const memgraph::storage::replication::FrequentHeartbeatReq &self, memgraph::slk::Builder *builder) {}
void Load(memgraph::storage::replication::FrequentHeartbeatReq *self, memgraph::slk::Reader *reader) {}
// Serialize code for HeartbeatRes
void Save(const memgraph::storage::replication::HeartbeatRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
memgraph::slk::Save(self.epoch_id, builder);
}
void Load(memgraph::storage::replication::HeartbeatRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
memgraph::slk::Load(&self->epoch_id, reader);
}
// Serialize code for HeartbeatReq
void Save(const memgraph::storage::replication::HeartbeatReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.main_commit_timestamp, builder);
memgraph::slk::Save(self.epoch_id, builder);
}
void Load(memgraph::storage::replication::HeartbeatReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->main_commit_timestamp, reader);
memgraph::slk::Load(&self->epoch_id, reader);
}
// Serialize code for AppendDeltasRes
void Save(const memgraph::storage::replication::AppendDeltasRes &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.success, builder);
memgraph::slk::Save(self.current_commit_timestamp, builder);
}
void Load(memgraph::storage::replication::AppendDeltasRes *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->success, reader);
memgraph::slk::Load(&self->current_commit_timestamp, reader);
}
// Serialize code for AppendDeltasReq
void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph::slk::Builder *builder) {
memgraph::slk::Save(self.previous_commit_timestamp, builder);
memgraph::slk::Save(self.seq_num, builder);
}
void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader) {
memgraph::slk::Load(&self->previous_commit_timestamp, reader);
memgraph::slk::Load(&self->seq_num, reader);
}
} // namespace slk
} // namespace memgraph

View File

@ -0,0 +1,278 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <cstring>
#include <string>
#include "rpc/messages.hpp"
#include "slk/serialization.hpp"
#include "slk/streams.hpp"
namespace memgraph {
namespace storage {
namespace replication {
struct AppendDeltasReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(AppendDeltasReq *self, memgraph::slk::Reader *reader);
static void Save(const AppendDeltasReq &self, memgraph::slk::Builder *builder);
AppendDeltasReq() {}
AppendDeltasReq(uint64_t previous_commit_timestamp, uint64_t seq_num)
: previous_commit_timestamp(previous_commit_timestamp), seq_num(seq_num) {}
uint64_t previous_commit_timestamp;
uint64_t seq_num;
};
struct AppendDeltasRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(AppendDeltasRes *self, memgraph::slk::Reader *reader);
static void Save(const AppendDeltasRes &self, memgraph::slk::Builder *builder);
AppendDeltasRes() {}
AppendDeltasRes(bool success, uint64_t current_commit_timestamp)
: success(success), current_commit_timestamp(current_commit_timestamp) {}
bool success;
uint64_t current_commit_timestamp;
};
using AppendDeltasRpc = rpc::RequestResponse<AppendDeltasReq, AppendDeltasRes>;
struct HeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(HeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const HeartbeatReq &self, memgraph::slk::Builder *builder);
HeartbeatReq() {}
HeartbeatReq(uint64_t main_commit_timestamp, std::string epoch_id)
: main_commit_timestamp(main_commit_timestamp), epoch_id(epoch_id) {}
uint64_t main_commit_timestamp;
std::string epoch_id;
};
struct HeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(HeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const HeartbeatRes &self, memgraph::slk::Builder *builder);
HeartbeatRes() {}
HeartbeatRes(bool success, uint64_t current_commit_timestamp, std::string epoch_id)
: success(success), current_commit_timestamp(current_commit_timestamp), epoch_id(epoch_id) {}
bool success;
uint64_t current_commit_timestamp;
std::string epoch_id;
};
using HeartbeatRpc = rpc::RequestResponse<HeartbeatReq, HeartbeatRes>;
struct FrequentHeartbeatReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(FrequentHeartbeatReq *self, memgraph::slk::Reader *reader);
static void Save(const FrequentHeartbeatReq &self, memgraph::slk::Builder *builder);
FrequentHeartbeatReq() {}
};
struct FrequentHeartbeatRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(FrequentHeartbeatRes *self, memgraph::slk::Reader *reader);
static void Save(const FrequentHeartbeatRes &self, memgraph::slk::Builder *builder);
FrequentHeartbeatRes() {}
explicit FrequentHeartbeatRes(bool success) : success(success) {}
bool success;
};
using FrequentHeartbeatRpc = rpc::RequestResponse<FrequentHeartbeatReq, FrequentHeartbeatRes>;
struct SnapshotReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SnapshotReq *self, memgraph::slk::Reader *reader);
static void Save(const SnapshotReq &self, memgraph::slk::Builder *builder);
SnapshotReq() {}
};
struct SnapshotRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(SnapshotRes *self, memgraph::slk::Reader *reader);
static void Save(const SnapshotRes &self, memgraph::slk::Builder *builder);
SnapshotRes() {}
SnapshotRes(bool success, uint64_t current_commit_timestamp)
: success(success), current_commit_timestamp(current_commit_timestamp) {}
bool success;
uint64_t current_commit_timestamp;
};
using SnapshotRpc = rpc::RequestResponse<SnapshotReq, SnapshotRes>;
struct WalFilesReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(WalFilesReq *self, memgraph::slk::Reader *reader);
static void Save(const WalFilesReq &self, memgraph::slk::Builder *builder);
WalFilesReq() {}
explicit WalFilesReq(uint64_t file_number) : file_number(file_number) {}
uint64_t file_number;
};
struct WalFilesRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(WalFilesRes *self, memgraph::slk::Reader *reader);
static void Save(const WalFilesRes &self, memgraph::slk::Builder *builder);
WalFilesRes() {}
WalFilesRes(bool success, uint64_t current_commit_timestamp)
: success(success), current_commit_timestamp(current_commit_timestamp) {}
bool success;
uint64_t current_commit_timestamp;
};
using WalFilesRpc = rpc::RequestResponse<WalFilesReq, WalFilesRes>;
struct CurrentWalReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CurrentWalReq *self, memgraph::slk::Reader *reader);
static void Save(const CurrentWalReq &self, memgraph::slk::Builder *builder);
CurrentWalReq() {}
};
struct CurrentWalRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(CurrentWalRes *self, memgraph::slk::Reader *reader);
static void Save(const CurrentWalRes &self, memgraph::slk::Builder *builder);
CurrentWalRes() {}
CurrentWalRes(bool success, uint64_t current_commit_timestamp)
: success(success), current_commit_timestamp(current_commit_timestamp) {}
bool success;
uint64_t current_commit_timestamp;
};
using CurrentWalRpc = rpc::RequestResponse<CurrentWalReq, CurrentWalRes>;
struct TimestampReq {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(TimestampReq *self, memgraph::slk::Reader *reader);
static void Save(const TimestampReq &self, memgraph::slk::Builder *builder);
TimestampReq() {}
};
struct TimestampRes {
static const utils::TypeInfo kType;
static const utils::TypeInfo &GetTypeInfo() { return kType; }
static void Load(TimestampRes *self, memgraph::slk::Reader *reader);
static void Save(const TimestampRes &self, memgraph::slk::Builder *builder);
TimestampRes() {}
TimestampRes(bool success, uint64_t current_commit_timestamp)
: success(success), current_commit_timestamp(current_commit_timestamp) {}
bool success;
uint64_t current_commit_timestamp;
};
using TimestampRpc = rpc::RequestResponse<TimestampReq, TimestampRes>;
} // namespace replication
} // namespace storage
} // namespace memgraph
// SLK serialization declarations
#include "slk/serialization.hpp"
namespace memgraph::slk {
void Save(const memgraph::storage::replication::TimestampRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::TimestampRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::TimestampReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::TimestampReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CurrentWalRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CurrentWalRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::CurrentWalReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::CurrentWalReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::WalFilesRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::WalFilesRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::WalFilesReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::WalFilesReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SnapshotRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SnapshotRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::SnapshotReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::SnapshotReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::FrequentHeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::FrequentHeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::FrequentHeartbeatReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::FrequentHeartbeatReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::HeartbeatRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::HeartbeatRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::HeartbeatReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::HeartbeatReq *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::AppendDeltasRes &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::AppendDeltasRes *self, memgraph::slk::Reader *reader);
void Save(const memgraph::storage::replication::AppendDeltasReq &self, memgraph::slk::Builder *builder);
void Load(memgraph::storage::replication::AppendDeltasReq *self, memgraph::slk::Reader *reader);
} // namespace memgraph::slk

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -985,8 +985,8 @@ void Storage::Accessor::Abort() {
auto vertex = prev.vertex;
std::lock_guard<utils::SpinLock> guard(vertex->lock);
Delta *current = vertex->delta;
while (current != nullptr &&
current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) {
while (current != nullptr && current->timestamp->load(std::memory_order_acquire) ==
transaction_.transaction_id.load(std::memory_order_acquire)) {
switch (current->action) {
case Delta::Action::REMOVE_LABEL: {
auto it = std::find(vertex->labels.begin(), vertex->labels.end(), current->label);
@ -1072,8 +1072,8 @@ void Storage::Accessor::Abort() {
auto edge = prev.edge;
std::lock_guard<utils::SpinLock> guard(edge->lock);
Delta *current = edge->delta;
while (current != nullptr &&
current->timestamp->load(std::memory_order_acquire) == transaction_.transaction_id) {
while (current != nullptr && current->timestamp->load(std::memory_order_acquire) ==
transaction_.transaction_id.load(std::memory_order_acquire)) {
switch (current->action) {
case Delta::Action::SET_PROPERTY: {
edge->properties.SetProperty(current->property.key, current->property.value);
@ -1144,6 +1144,13 @@ void Storage::Accessor::FinalizeTransaction() {
}
}
std::optional<uint64_t> Storage::Accessor::GetTransactionId() const {
if (is_transaction_active_) {
return transaction_.transaction_id.load(std::memory_order_acquire);
}
return {};
}
const std::string &Storage::LabelToName(LabelId label) const { return name_id_mapper_.IdToName(label.AsUint()); }
const std::string &Storage::PropertyToName(PropertyId property) const {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,7 @@
#pragma once
#include <atomic>
#include <cstdint>
#include <filesystem>
#include <optional>
#include <shared_mutex>
@ -324,6 +325,8 @@ class Storage final {
void FinalizeTransaction();
std::optional<uint64_t> GetTransactionId() const;
private:
/// @throw std::bad_alloc
VertexAccessor CreateVertex(storage::Gid gid);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,7 @@ struct Transaction {
isolation_level(isolation_level) {}
Transaction(Transaction &&other) noexcept
: transaction_id(other.transaction_id),
: transaction_id(other.transaction_id.load(std::memory_order_acquire)),
start_timestamp(other.start_timestamp),
commit_timestamp(std::move(other.commit_timestamp)),
command_id(other.command_id),
@ -56,10 +56,10 @@ struct Transaction {
/// @throw std::bad_alloc if failed to create the `commit_timestamp`
void EnsureCommitTimestampExists() {
if (commit_timestamp != nullptr) return;
commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id);
commit_timestamp = std::make_unique<std::atomic<uint64_t>>(transaction_id.load(std::memory_order_relaxed));
}
uint64_t transaction_id;
std::atomic<uint64_t> transaction_id;
uint64_t start_timestamp;
// The `Transaction` object is stack allocated, but the `commit_timestamp`
// must be heap allocated because `Delta`s have a pointer to it, and that
@ -73,12 +73,16 @@ struct Transaction {
};
inline bool operator==(const Transaction &first, const Transaction &second) {
return first.transaction_id == second.transaction_id;
return first.transaction_id.load(std::memory_order_acquire) == second.transaction_id.load(std::memory_order_acquire);
}
inline bool operator<(const Transaction &first, const Transaction &second) {
return first.transaction_id < second.transaction_id;
return first.transaction_id.load(std::memory_order_acquire) < second.transaction_id.load(std::memory_order_acquire);
}
inline bool operator==(const Transaction &first, const uint64_t &second) {
return first.transaction_id.load(std::memory_order_acquire) == second;
}
inline bool operator<(const Transaction &first, const uint64_t &second) {
return first.transaction_id.load(std::memory_order_acquire) < second;
}
inline bool operator==(const Transaction &first, const uint64_t &second) { return first.transaction_id == second; }
inline bool operator<(const Transaction &first, const uint64_t &second) { return first.transaction_id < second; }
} // namespace memgraph::storage

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,6 +16,172 @@
namespace memgraph::utils {
enum class TypeId : uint64_t {
// Operators
UNKNOWN,
LOGICAL_OPERATOR,
ONCE,
NODE_CREATION_INFO,
CREATE_NODE,
EDGE_CREATION_INFO,
CREATE_EXPAND,
SCAN_ALL,
SCAN_ALL_BY_LABEL,
SCAN_ALL_BY_LABEL_PROPERTY_RANGE,
SCAN_ALL_BY_LABEL_PROPERTY_VALUE,
SCAN_ALL_BY_LABEL_PROPERTY,
SCAN_ALL_BY_ID,
EXPAND_COMMON,
EXPAND,
EXPANSION_LAMBDA,
EXPAND_VARIABLE,
CONSTRUCT_NAMED_PATH,
FILTER,
PRODUCE,
DELETE,
SET_PROPERTY,
SET_PROPERTIES,
SET_LABELS,
REMOVE_PROPERTY,
REMOVE_LABELS,
EDGE_UNIQUENESS_FILTER,
EMPTY_RESULT,
ACCUMULATE,
AGGREGATE,
AGGREGATE_ELEMENT,
SKIP,
EVALUATE_PATTERN_FILTER,
LIMIT,
ORDERBY,
MERGE,
OPTIONAL,
UNWIND,
DISTINCT,
UNION,
CARTESIAN,
OUTPUT_TABLE,
OUTPUT_TABLE_STREAM,
CALL_PROCEDURE,
LOAD_CSV,
FOREACH,
// Replication
REP_APPEND_DELTAS_REQ,
REP_APPEND_DELTAS_RES,
REP_HEARTBEAT_REQ,
REP_HEARTBEAT_RES,
REP_FREQUENT_HEARTBEAT_REQ,
REP_FREQUENT_HEARTBEAT_RES,
REP_SNAPSHOT_REQ,
REP_SNAPSHOT_RES,
REP_WALFILES_REQ,
REP_WALFILES_RES,
REP_CURRENT_WAL_REQ,
REP_CURRENT_WAL_RES,
REP_TIMESTAMP_REQ,
REP_TIMESTAMP_RES,
// AST
AST_LABELIX,
AST_PROPERTYIX,
AST_EDGETYPEIX,
AST_TREE,
AST_EXPRESSION,
AST_WHERE,
AST_BINARY_OPERATOR,
AST_UNARY_OPERATOR,
AST_OR_OPERATOR,
AST_XOR_OPERATOR,
AST_AND_OPERATOR,
AST_ADDITION_OPERATOR,
AST_SUBTRACTION_OPERATOR,
AST_MULTIPLICATION_OPERATOR,
AST_DIVISION_OPERATOR,
AST_MOD_OPERATOR,
AST_NOT_EQUAL_OPERATOR,
AST_EQUAL_OPERATOR,
AST_LESS_OPERATOR,
AST_GREATER_OPERATOR,
AST_LESS_EQUAL_OPERATOR,
AST_GREATER_EQUAL_OPERATOR,
AST_IN_LIST_OPERATOR,
AST_SUBSCRIPT_OPERATOR,
AST_NOT_OPERATOR,
AST_UNARY_PLUS_OPERATOR,
AST_UNARY_MINUS_OPERATOR,
AST_IS_NULL_OPERATOR,
AST_AGGREGATION,
AST_LIST_SLICING_OPERATOR,
AST_IF_OPERATOR,
AST_BASE_LITERAL,
AST_PRIMITIVE_LITERAL,
AST_LIST_LITERAL,
AST_MAP_LITERAL,
AST_IDENTIFIER,
AST_PROPERTY_LOOKUP,
AST_LABELS_TEST,
AST_FUNCTION,
AST_REDUCE,
AST_COALESCE,
AST_EXTRACT,
AST_ALL,
AST_SINGLE,
AST_ANY,
AST_NONE,
AST_PARAMETER_LOOKUP,
AST_REGEX_MATCH,
AST_NAMED_EXPRESSION,
AST_PATTERN_ATOM,
AST_NODE_ATOM,
AST_EDGE_ATOM_LAMBDA,
AST_EDGE_ATOM,
AST_PATTERN,
AST_CLAUSE,
AST_SINGLE_QUERY,
AST_CYPHER_UNION,
AST_QUERY,
AST_CYPHER_QUERY,
AST_EXPLAIN_QUERY,
AST_PROFILE_QUERY,
AST_INDEX_QUERY,
AST_CREATE,
AST_CALL_PROCEDURE,
AST_MATCH,
AST_SORT_ITEM,
AST_RETURN_BODY,
AST_RETURN,
AST_WITH,
AST_DELETE,
AST_SET_PROPERTY,
AST_SET_PROPERTIES,
AST_SET_LABELS,
AST_REMOVE_PROPERTY,
AST_REMOVE_LABELS,
AST_MERGE,
AST_UNWIND,
AST_AUTH_QUERY,
AST_INFO_QUERY,
AST_CONSTRAINT,
AST_CONSTRAINT_QUERY,
AST_DUMP_QUERY,
AST_REPLICATION_QUERY,
AST_LOCK_PATH_QUERY,
AST_LOAD_CSV,
AST_FREE_MEMORY_QUERY,
AST_TRIGGER_QUERY,
AST_ISOLATION_LEVEL_QUERY,
AST_CREATE_SNAPSHOT_QUERY,
AST_STREAM_QUERY,
AST_SETTING_QUERY,
AST_VERSION_QUERY,
AST_FOREACH,
AST_SHOW_CONFIG_QUERY,
AST_TRANSACTION_QUEUE_QUERY,
AST_EXISTS,
// Symbol
SYMBOL,
};
/// Type information on a C++ type.
///
/// You should embed this structure as a static constant member `kType` and make
@ -24,7 +190,7 @@ namespace memgraph::utils {
/// runtime type.
struct TypeInfo {
/// Unique ID for the type.
uint64_t id;
TypeId id;
/// Pretty name of the type.
const char *name;
/// `TypeInfo *` for superclass of this type.
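/// A minimal usage sketch (the type name below is hypothetical, for illustration
/// only): embed a static constant `kType` member and expose it through a
/// `GetTypeInfo()` accessor, the same pattern the replication messages above use.
///
///   struct ExampleMessage {
///     static const utils::TypeInfo kType;
///     static const utils::TypeInfo &GetTypeInfo() { return kType; }
///   };
///   constexpr utils::TypeInfo ExampleMessage::kType{utils::TypeId::UNKNOWN, "ExampleMessage", nullptr};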

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -25,7 +25,7 @@ struct EchoMessage {
static const memgraph::utils::TypeInfo kType;
EchoMessage() {} // Needed for serialization.
EchoMessage(const std::string &data) : data(data) {}
explicit EchoMessage(const std::string &data) : data(data) {}
static void Load(EchoMessage *obj, memgraph::slk::Reader *reader);
static void Save(const EchoMessage &obj, memgraph::slk::Builder *builder);
@ -41,7 +41,7 @@ void Load(EchoMessage *echo, Reader *reader) { Load(&echo->data, reader); }
void EchoMessage::Load(EchoMessage *obj, memgraph::slk::Reader *reader) { memgraph::slk::Load(obj, reader); }
void EchoMessage::Save(const EchoMessage &obj, memgraph::slk::Builder *builder) { memgraph::slk::Save(obj, builder); }
const memgraph::utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};
const memgraph::utils::TypeInfo EchoMessage::kType{memgraph::utils::TypeId::UNKNOWN, "EchoMessage"};
using Echo = memgraph::rpc::RequestResponse<EchoMessage, EchoMessage>;

View File

@ -44,6 +44,7 @@ add_subdirectory(module_file_manager)
add_subdirectory(monitoring_server)
add_subdirectory(lba_procedures)
add_subdirectory(python_query_modules_reloading)
add_subdirectory(transaction_queue)
add_subdirectory(mock_api)
copy_e2e_python_files(pytest_runner pytest_runner.sh "")

View File

@ -10,6 +10,7 @@
# licenses/APL.txt.
import sys
import pytest
from common import connect, execute_and_fetch_all
@ -35,6 +36,7 @@ BASIC_PRIVILEGES = [
"MODULE_READ",
"WEBSOCKET",
"MODULE_WRITE",
"TRANSACTION_MANAGEMENT",
]
@ -58,7 +60,7 @@ def test_lba_procedures_show_privileges_first_user():
cursor = connect(username="Josip", password="").cursor()
result = execute_and_fetch_all(cursor, "SHOW PRIVILEGES FOR Josip;")
assert len(result) == 30
assert len(result) == 31
fine_privilege_results = [res for res in result if res[0] not in BASIC_PRIVILEGES]

View File

@ -0,0 +1,8 @@
function(copy_query_modules_reloading_procedures_e2e_python_files FILE_NAME)
copy_e2e_python_files(transaction_queue ${FILE_NAME})
endfunction()
copy_query_modules_reloading_procedures_e2e_python_files(common.py)
copy_query_modules_reloading_procedures_e2e_python_files(test_transaction_queue.py)
add_subdirectory(procedures)

View File

@ -0,0 +1,26 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import typing
import mgclient
import pytest
def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
cursor.execute(query, params)
return cursor.fetchall()
def connect(**kwargs) -> mgclient.Connection:
connection = mgclient.connect(host="localhost", port=7687, **kwargs)
connection.autocommit = True
return connection

View File

@ -0,0 +1 @@
copy_e2e_python_files(transaction_queue infinite_query.py)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import threading
import time
import mgp
@mgp.read_proc
def long_query(ctx: mgp.ProcCtx) -> mgp.Record(my_id=int):
id = 1
try:
while True:
if ctx.check_must_abort():
break
id += 1
except mgp.AbortError:
return mgp.Record(my_id=id)

View File

@ -0,0 +1,338 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import multiprocessing
import sys
import threading
import time
from typing import List
import mgclient
import pytest
from common import connect, execute_and_fetch_all
# Utility functions
# -------------------------
def get_non_show_transaction_id(results):
"""Returns transaction id of the first transaction that is not SHOW TRANSACTIONS;"""
for res in results:
if res[2] != ["SHOW TRANSACTIONS"]:
return res[1]
def show_transactions_test(cursor, expected_num_results: int):
results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS")
assert len(results) == expected_num_results
return results
def process_function(cursor, queries: List[str]):
try:
for query in queries:
cursor.execute(query, {})
except mgclient.DatabaseError:
pass
# Tests
# -------------------------
def test_self_transaction():
"""Tests that simple show transactions work when no other is running."""
cursor = connect().cursor()
results = execute_and_fetch_all(cursor, "SHOW TRANSACTIONS")
assert len(results) == 1
def test_admin_has_one_transaction():
"""Creates admin and tests that he sees only one transaction."""
# a_cursor is used for creating admin user, simulates main thread
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
admin_cursor = connect(username="admin", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(admin_cursor, 1))
process.start()
process.join()
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_user_can_see_its_transaction():
"""Tests that user without privileges can see its own transaction"""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user")
user_cursor = connect(username="user", password="").cursor()
process = multiprocessing.Process(target=show_transactions_test, args=(user_cursor, 1))
process.start()
process.join()
admin_cursor = connect(username="admin", password="").cursor()
execute_and_fetch_all(admin_cursor, "DROP USER user")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_explicit_transaction_output():
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
# Superadmin starts running an explicit transaction
process = multiprocessing.Process(
target=process_function,
args=(superadmin_cursor, ["BEGIN", "CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]),
)
process.start()
time.sleep(0.5)
show_results = show_transactions_test(admin_cursor, 2)
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[1 - executing_index][2] == ["CREATE (n:Person {id_: 1})", "CREATE (n:Person {id_: 2})"]
execute_and_fetch_all(superadmin_cursor, "ROLLBACK")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
def test_superadmin_cannot_see_admin_can_see_admin():
"""Tests that superadmin cannot see the transaction created by admin but two admins can see and kill each other's transactions."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
# Admin1 starts running an infinite query
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
admin_connection_2 = connect(username="admin2", password="")
admin_cursor_2 = admin_connection_2.cursor()
process = multiprocessing.Process(
target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# Superadmin shouldn't see the execution of the admin
show_transactions_test(superadmin_cursor, 1)
show_results = show_transactions_test(admin_cursor_2, 2)
# Don't rely on the order of interpreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin2"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "admin1"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
admin_connection_1.close()
admin_connection_2.close()
def test_admin_sees_superadmin():
"""Tests that admin created by superadmin can see the superadmin's transaction."""
superadmin_connection = connect()
superadmin_cursor = superadmin_connection.cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
# Superadmin starts running an infinite query
process = multiprocessing.Process(
target=process_function, args=(superadmin_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
admin_cursor = connect(username="admin", password="").cursor()
show_results = show_transactions_test(admin_cursor, 2)
# show_results_2 = show_transactions_test(admin_cursor, 2)
# Don't rely on the order of interpreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == ""
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(admin_cursor, "DROP USER admin")
superadmin_connection.close()
def test_admin_can_see_user_transaction():
"""Tests that admin can see user's transaction and kill it."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
# User starts running an infinite query
admin_connection = connect(username="admin", password="")
admin_cursor = admin_connection.cursor()
user_connection = connect(username="user", password="")
user_cursor = user_connection.cursor()
process = multiprocessing.Process(
target=process_function, args=(user_cursor, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# Admin should see the user's transaction.
show_results = show_transactions_test(admin_cursor, 2)
# Don't rely on the order of interpreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "user"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection.close()
user_connection.close()
def test_user_cannot_see_admin_transaction():
"""User cannot see admin's transaction but other admin can and he can kill it."""
# Superadmin creates two admins and one user
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin1")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin1")
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin2")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin2")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user")
admin_connection_1 = connect(username="admin1", password="")
admin_cursor_1 = admin_connection_1.cursor()
admin_connection_2 = connect(username="admin2", password="")
admin_cursor_2 = admin_connection_2.cursor()
user_connection = connect(username="user", password="")
user_cursor = user_connection.cursor()
# Admin1 starts running a long-running query
process = multiprocessing.Process(
target=process_function, args=(admin_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process.start()
time.sleep(0.5)
# User should not see the admin's transaction.
show_transactions_test(user_cursor, 1)
# The second admin should see the other admin's transaction
show_results = show_transactions_test(admin_cursor_2, 2)
# Don't rely on the order of interpreters in Memgraph
if show_results[0][2] == ["SHOW TRANSACTIONS"]:
executing_index = 0
else:
executing_index = 1
assert show_results[executing_index][0] == "admin2"
assert show_results[executing_index][2] == ["SHOW TRANSACTIONS"]
assert show_results[1 - executing_index][0] == "admin1"
assert show_results[1 - executing_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Kill transaction
long_transaction_id = show_results[1 - executing_index][1]
execute_and_fetch_all(admin_cursor_2, f"TERMINATE TRANSACTIONS '{long_transaction_id}'")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin1")
execute_and_fetch_all(superadmin_cursor, "DROP USER admin2")
execute_and_fetch_all(superadmin_cursor, "DROP USER user")
admin_connection_1.close()
admin_connection_2.close()
user_connection.close()
def test_killing_non_existing_transaction():
cursor = connect().cursor()
results = execute_and_fetch_all(cursor, "TERMINATE TRANSACTIONS '1'")
assert len(results) == 1
assert results[0][0] == "1" # transaction id
assert results[0][1] == False # not killed
def test_killing_multiple_non_existing_transactions():
cursor = connect().cursor()
transactions_id = ["'1'", "'2'", "'3'"]
results = execute_and_fetch_all(cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}")
assert len(results) == 3
for i in range(len(results)):
assert results[i][0] == eval(transactions_id[i]) # transaction id
assert results[i][1] == False # not killed
def test_admin_killing_multiple_non_existing_transactions():
# Superadmin creates the admin user
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT TRANSACTION_MANAGEMENT TO admin")
# Connect with admin
admin_cursor = connect(username="admin", password="").cursor()
transactions_id = ["'1'", "'2'", "'3'"]
results = execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS {','.join(transactions_id)}")
assert len(results) == 3
for i in range(len(results)):
assert results[i][0] == eval(transactions_id[i]) # transaction id
assert results[i][1] == False # not killed
execute_and_fetch_all(admin_cursor, "DROP USER admin")
def test_user_killing_some_transactions():
"""Tests what happens when user can kill only some of the transactions given."""
superadmin_cursor = connect().cursor()
execute_and_fetch_all(superadmin_cursor, "CREATE USER admin")
execute_and_fetch_all(superadmin_cursor, "GRANT ALL PRIVILEGES TO admin")
execute_and_fetch_all(superadmin_cursor, "CREATE USER user1")
execute_and_fetch_all(superadmin_cursor, "REVOKE ALL PRIVILEGES FROM user1")
# Connect as admin, then as two different users in separate sessions
admin_cursor = connect(username="admin", password="").cursor()
execute_and_fetch_all(admin_cursor, "CREATE USER user2")
execute_and_fetch_all(admin_cursor, "GRANT ALL PRIVILEGES TO user2")
user_connection_1 = connect(username="user1", password="")
user_cursor_1 = user_connection_1.cursor()
user_connection_2 = connect(username="user2", password="")
user_cursor_2 = user_connection_2.cursor()
process_1 = multiprocessing.Process(
target=process_function, args=(user_cursor_1, ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"])
)
process_2 = multiprocessing.Process(target=process_function, args=(user_cursor_2, ["BEGIN", "MATCH (n) RETURN n"]))
process_1.start()
process_2.start()
# Create another user1 connection
user_connection_1_copy = connect(username="user1", password="")
user_cursor_1_copy = user_connection_1_copy.cursor()
show_user_1_results = show_transactions_test(user_cursor_1_copy, 2)
if show_user_1_results[0][2] == ["SHOW TRANSACTIONS"]:
execution_index = 0
else:
execution_index = 1
assert show_user_1_results[1 - execution_index][2] == ["CALL infinite_query.long_query() YIELD my_id RETURN my_id"]
# Give the queries time to start, then check as admin
time.sleep(0.5)
show_admin_results = show_transactions_test(admin_cursor, 3)
for show_admin_res in show_admin_results:
if show_admin_res[2] != "[SHOW TRANSACTIONS]":
execute_and_fetch_all(admin_cursor, f"TERMINATE TRANSACTIONS '{show_admin_res[1]}'")
user_connection_1.close()
user_connection_2.close()
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-rA"]))

View File

@ -0,0 +1,14 @@
test_transaction_queue: &test_transaction_queue
cluster:
main:
args: ["--bolt-port", "7687", "--log-level=TRACE", "--also-log-to-stderr"]
log_file: "transaction_queue.log"
setup_queries: []
validation_queries: []
workloads:
- name: "test-transaction-queue" # should be the same as the python file
binary: "tests/e2e/pytest_runner.sh"
proc: "tests/e2e/transaction_queue/procedures/"
args: ["transaction_queue/test_transaction_queue.py"]
<<: *test_transaction_queue

View File

@ -247,7 +247,7 @@ Index queries for each supported vendor can be downloaded from “https://s3.eu-
|Q19|pattern_short| analytical | MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1|
|Q20|single_edge_write| write | MATCH (n:User {id: $from}), (m:User {id: $to}) WITH n, m CREATE (n)-[e:Temp]->(m) RETURN e|
|Q21|single_vertex_write| write |CREATE (n:UserTemp {id : $id}) RETURN n|
|Q22|single_vertex_property_update| update | MATCH (n:User {id: $id})-[e]->(m) RETURN m LIMIT 1|
|Q22|single_vertex_property_update| update | MATCH (n:User {id: $id}) SET n.property = -1|
|Q23|single_vertex_read| read | MATCH (n:User {id : $id}) RETURN n|
## :computer: Platform

File diff suppressed because it is too large

View File

@ -0,0 +1,57 @@
# Describes all the information of a single benchmark.py run.
class BenchmarkContext:
"""
Class for holding information on what type of benchmark is being executed
"""
def __init__(
self,
benchmark_target_workload: str = None, # Workload that needs to be executed (dataset/variant/group/query)
vendor_binary: str = None, # Benchmark vendor binary
vendor_name: str = None,
client_binary: str = None,
num_workers_for_import: int = None,
num_workers_for_benchmark: int = None,
single_threaded_runtime_sec: int = 0,
no_load_query_counts: bool = False,
no_save_query_counts: bool = False,
export_results: str = None,
temporary_directory: str = None,
workload_mixed: str = None,  # Default mode is isolated; set only when running a mixed workload
workload_realistic: str = None,  # Default mode is isolated; set only when running a realistic workload
time_dependent_execution: int = 0,
warm_up: str = None,
performance_tracking: bool = False,
no_authorization: bool = True,
customer_workloads: str = None,
vendor_args: dict = {},
) -> None:
self.benchmark_target_workload = benchmark_target_workload
self.vendor_binary = vendor_binary
self.vendor_name = vendor_name
self.client_binary = client_binary
self.num_workers_for_import = num_workers_for_import
self.num_workers_for_benchmark = num_workers_for_benchmark
self.single_threaded_runtime_sec = single_threaded_runtime_sec
self.no_load_query_counts = no_load_query_counts
self.no_save_query_counts = no_save_query_counts
self.export_results = export_results
self.temporary_directory = temporary_directory
if workload_mixed is not None:
self.mode = "Mixed"
self.mode_config = workload_mixed
elif workload_realistic is not None:
self.mode = "Realistic"
self.mode_config = workload_realistic
else:
self.mode = "Isolated"
self.mode_config = "Isolated run does not have a config."
self.time_dependent_execution = time_dependent_execution
self.performance_tracking = performance_tracking
self.warm_up = warm_up
self.no_authorization = no_authorization
self.customer_workloads = customer_workloads
self.vendor_args = vendor_args
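# Illustrative construction only; the values below are placeholders, not part of the benchmark suite:
#
#   context = BenchmarkContext(
#       benchmark_target_workload="dataset/variant/group/query",
#       vendor_binary="/path/to/vendor_binary",
#       vendor_name="memgraph",
#       client_binary="/path/to/client_binary",
#       num_workers_for_import=4,
#       num_workers_for_benchmark=12,
#       single_threaded_runtime_sec=10,
#   )
#
# With neither workload_mixed nor workload_realistic set, mode resolves to "Isolated".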

View File

@ -58,6 +58,10 @@ DEFINE_bool(validation, false,
"Set to true to run client in validation mode."
"Validation mode works for singe query and returns results for validation"
"with metadata");
DEFINE_int64(time_dependent_execution, 0,
"Time-dependent executions execute the queries for a specified number of seconds."
"If all queries are executed, and there is still time, queries are rerun again."
"If the time runs out, the client is done with the job and returning results.");
std::pair<std::map<std::string, memgraph::communication::bolt::Value>, uint64_t> ExecuteNTimesTillSuccess(
memgraph::communication::bolt::Client *client, const std::string &query,
@ -220,7 +224,114 @@ nlohmann::json LatencyStatistics(std::vector<std::vector<double>> &worker_query_
return statistics;
}
void Execute(
void ExecuteTimeDependentWorkload(
const std::vector<std::pair<std::string, std::map<std::string, memgraph::communication::bolt::Value>>> &queries,
std::ostream *stream) {
std::vector<std::thread> threads;
threads.reserve(FLAGS_num_workers);
std::vector<uint64_t> worker_retries(FLAGS_num_workers, 0);
std::vector<Metadata> worker_metadata(FLAGS_num_workers, Metadata());
std::vector<double> worker_duration(FLAGS_num_workers, 0.0);
std::vector<std::vector<double>> worker_query_durations(FLAGS_num_workers);
// Start workers and execute queries.
auto size = queries.size();
std::atomic<bool> run(false);
std::atomic<uint64_t> ready(0);
std::atomic<uint64_t> position(0);
std::atomic<bool> start_workload_timer(false);
std::chrono::time_point<std::chrono::steady_clock> workload_start;
std::chrono::duration<double> time_limit = std::chrono::seconds(FLAGS_time_dependent_execution);
for (int worker = 0; worker < FLAGS_num_workers; ++worker) {
threads.push_back(std::thread([&, worker]() {
memgraph::io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
ready.fetch_add(1, std::memory_order_acq_rel);
while (!run.load(std::memory_order_acq_rel))
;
auto &retries = worker_retries[worker];
auto &metadata = worker_metadata[worker];
auto &duration = worker_duration[worker];
auto &query_duration = worker_query_durations[worker];
// After all threads have been initialised, start the workload timer
if (!start_workload_timer.load()) {
workload_start = std::chrono::steady_clock::now();
start_workload_timer.store(true);
}
memgraph::utils::Timer worker_timer;
while (std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::steady_clock::now() -
workload_start) < time_limit) {
auto pos = position.fetch_add(1, std::memory_order_acq_rel);
if (pos >= size) {
/// Get back to the initial position
position.store(0, std::memory_order_acq_rel);
pos = position.fetch_add(1, std::memory_order_acq_rel);
}
const auto &query = queries[pos];
memgraph::utils::Timer query_timer;
auto ret = ExecuteNTimesTillSuccess(&client, query.first, query.second, FLAGS_max_retries);
query_duration.emplace_back(query_timer.Elapsed().count());
retries += ret.second;
metadata.Append(ret.first);
duration = worker_timer.Elapsed().count();
}
client.Close();
}));
}
// Synchronize workers and collect runtime.
while (ready.load(std::memory_order_acq_rel) < FLAGS_num_workers)
;
run.store(true);
for (int i = 0; i < FLAGS_num_workers; ++i) {
threads[i].join();
}
// Create and output summary.
Metadata final_metadata;
uint64_t final_retries = 0;
double final_duration = 0.0;
for (int i = 0; i < FLAGS_num_workers; ++i) {
final_metadata += worker_metadata[i];
final_retries += worker_retries[i];
final_duration += worker_duration[i];
}
int total_iterations = 0;
std::for_each(worker_query_durations.begin(), worker_query_durations.end(),
[&](const std::vector<double> &v) { total_iterations += v.size(); });
final_duration /= FLAGS_num_workers;
double execution_delta = time_limit.count() / final_duration;
// Adjusted throughput: the raw throughput scaled by how much the measured execution time deviates from the requested time limit.
double throughput = (total_iterations / final_duration) * execution_delta;
double raw_throughput = total_iterations / final_duration;
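// Worked example with illustrative (not measured) numbers: with a 10 s time limit,
// an average measured worker duration of 12 s and 1200 executed queries in total,
// raw_throughput is 1200 / 12 = 100 queries/s, execution_delta is 10 / 12 ~ 0.83,
// and the adjusted throughput is 100 * 0.83 ~ 83 queries/s.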
nlohmann::json summary = nlohmann::json::object();
summary["count"] = queries.size();
summary["duration"] = final_duration;
summary["time_limit"] = FLAGS_time_dependent_execution;
summary["queries_executed"] = total_iterations;
summary["throughput"] = throughput;
summary["raw_throughput"] = raw_throughput;
summary["latency_stats"] = LatencyStatistics(worker_query_durations);
summary["retries"] = final_retries;
summary["metadata"] = final_metadata.Export();
summary["num_workers"] = FLAGS_num_workers;
(*stream) << summary.dump() << std::endl;
}
void ExecuteWorkload(
const std::vector<std::pair<std::string, std::map<std::string, memgraph::communication::bolt::Value>>> &queries,
std::ostream *stream) {
std::vector<std::thread> threads;
@ -259,7 +370,7 @@ void Execute(
const auto &query = queries[pos];
memgraph::utils::Timer query_timer;
auto ret = ExecuteNTimesTillSuccess(&client, query.first, query.second, FLAGS_max_retries);
query_duration.push_back(query_timer.Elapsed().count());
query_duration.emplace_back(query_timer.Elapsed().count());
retries += ret.second;
metadata.Append(ret.first);
}
@ -272,6 +383,7 @@ void Execute(
while (ready.load(std::memory_order_acq_rel) < FLAGS_num_workers)
;
run.store(true, std::memory_order_acq_rel);
for (int i = 0; i < FLAGS_num_workers; ++i) {
threads[i].join();
}
@ -363,6 +475,7 @@ int main(int argc, char **argv) {
spdlog::info("Input: {}", FLAGS_input);
spdlog::info("Output: {}", FLAGS_output);
spdlog::info("Validation: {}", FLAGS_validation);
spdlog::info("Time dependend execution: {}", FLAGS_time_dependent_execution);
memgraph::communication::SSLInit sslInit;
@ -390,7 +503,7 @@ int main(int argc, char **argv) {
while (std::getline(*istream, query)) {
auto trimmed = memgraph::utils::Trim(query);
if (trimmed == "" || trimmed == ";") {
Execute(queries, ostream);
ExecuteWorkload(queries, ostream);
queries.clear();
continue;
}
@ -406,7 +519,7 @@ int main(int argc, char **argv) {
"array!");
MG_ASSERT(data.is_array() && data.size() == 2, "Each item of the loaded JSON queries must be an array!");
if (data.size() == 0) {
Execute(queries, ostream);
ExecuteWorkload(queries, ostream);
queries.clear();
continue;
}
@ -424,10 +537,12 @@ int main(int argc, char **argv) {
}
}
if (!FLAGS_validation) {
Execute(queries, ostream);
} else {
if (FLAGS_validation) {
ExecuteValidation(queries, ostream);
} else if (FLAGS_time_dependent_execution > 0) {
ExecuteTimeDependentWorkload(queries, ostream);
} else {
ExecuteWorkload(queries, ostream);
}
return 0;

View File

@ -77,10 +77,10 @@ def compare_results(results_from, results_to, fields, ignored, different_vendors
recursive_get(summary_from, "database", key, value=None),
summary_to["database"][key],
)
elif summary_to.get("query_statistics") != None and key in summary_to["query_statistics"]:
elif summary_to.get("latency_stats") != None and key in summary_to["latency_stats"]:
row[key] = compute_diff(
recursive_get(summary_from, "query_statistics", key, value=None),
summary_to["query_statistics"][key],
recursive_get(summary_from, "latency_stats", key, value=None),
summary_to["latency_stats"][key],
)
elif not different_vendors:
row[key] = compute_diff(
@ -160,7 +160,10 @@ if __name__ == "__main__":
help="Comparing different vendors, there is no need for metadata, duration, count check.",
)
parser.add_argument(
"--difference-threshold", type=float, help="Difference threshold for memory and throughput, 0.02 = 2% "
"--difference-threshold",
type=float,
default=0.02,
help="Difference threshold for memory and throughput, 0.02 = 2% ",
)
args = parser.parse_args()

View File

View File

@ -0,0 +1,500 @@
import argparse
import csv
import sys
from collections import defaultdict
from pathlib import Path
import helpers
# Most recent list of LDBC datasets available at: https://github.com/ldbc/data-sets-surf-repository
INTERACTIVE_LINK = {
"sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.1.tar.zst",
"sf0.3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.3.tar.zst",
"sf1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf1.tar.zst",
"sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf3.tar.zst",
"sf10": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf10.tar.zst",
}
BI_LINK = {
"sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf1-composite-projected-fk.tar.zst",
"sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf3-composite-projected-fk.tar.zst",
"sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf10-composite-projected-fk.tar.zst",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="LDBC CSV to CYPHERL converter",
description="""Converts all LDBC CSV files to CYPHERL transactions, for faster Memgraph load""",
)
parser.add_argument(
"--size",
required=True,
choices=["0.1", "0.3", "1", "3", "10"],
help="Interactive: (0.1 , 0.3, 1, 3, 10) BI: (1, 3, 10)",
)
parser.add_argument("--type", required=True, choices=["interactive", "bi"], help="interactive or bi")
args = parser.parse_args()
output_directory = Path().absolute() / ".cache" / "LDBC_generated"
output_directory.mkdir(exist_ok=True)
if args.type == "interactive":
NODES_INTERACTIVE = [
{"filename": "Place", "label": "Place"},
{"filename": "Organisation", "label": "Organisation"},
{"filename": "TagClass", "label": "TagClass"},
{"filename": "Tag", "label": "Tag"},
{"filename": "Comment", "label": "Message:Comment"},
{"filename": "Forum", "label": "Forum"},
{"filename": "Person", "label": "Person"},
{"filename": "Post", "label": "Message:Post"},
]
EDGES_INTERACTIVE = [
{
"filename": "Place_isPartOf_Place",
"source_label": "Place",
"type": "IS_PART_OF",
"target_label": "Place",
},
{
"filename": "TagClass_isSubclassOf_TagClass",
"source_label": "TagClass",
"type": "IS_SUBCLASS_OF",
"target_label": "TagClass",
},
{
"filename": "Organisation_isLocatedIn_Place",
"source_label": "Organisation",
"type": "IS_LOCATED_IN",
"target_label": "Place",
},
{"filename": "Tag_hasType_TagClass", "source_label": "Tag", "type": "HAS_TYPE", "target_label": "TagClass"},
{
"filename": "Comment_hasCreator_Person",
"source_label": "Comment",
"type": "HAS_CREATOR",
"target_label": "Person",
},
{
"filename": "Comment_isLocatedIn_Place",
"source_label": "Comment",
"type": "IS_LOCATED_IN",
"target_label": "Place",
},
{
"filename": "Comment_replyOf_Comment",
"source_label": "Comment",
"type": "REPLY_OF",
"target_label": "Comment",
},
{"filename": "Comment_replyOf_Post", "source_label": "Comment", "type": "REPLY_OF", "target_label": "Post"},
{
"filename": "Forum_containerOf_Post",
"source_label": "Forum",
"type": "CONTAINER_OF",
"target_label": "Post",
},
{
"filename": "Forum_hasMember_Person",
"source_label": "Forum",
"type": "HAS_MEMBER",
"target_label": "Person",
},
{
"filename": "Forum_hasModerator_Person",
"source_label": "Forum",
"type": "HAS_MODERATOR",
"target_label": "Person",
},
{"filename": "Forum_hasTag_Tag", "source_label": "Forum", "type": "HAS_TAG", "target_label": "Tag"},
{
"filename": "Person_hasInterest_Tag",
"source_label": "Person",
"type": "HAS_INTEREST",
"target_label": "Tag",
},
{
"filename": "Person_isLocatedIn_Place",
"source_label": "Person",
"type": "IS_LOCATED_IN",
"target_label": "Place",
},
{"filename": "Person_knows_Person", "source_label": "Person", "type": "KNOWS", "target_label": "Person"},
{"filename": "Person_likes_Comment", "source_label": "Person", "type": "LIKES", "target_label": "Comment"},
{"filename": "Person_likes_Post", "source_label": "Person", "type": "LIKES", "target_label": "Post"},
{
"filename": "Post_hasCreator_Person",
"source_label": "Post",
"type": "HAS_CREATOR",
"target_label": "Person",
},
{"filename": "Comment_hasTag_Tag", "source_label": "Comment", "type": "HAS_TAG", "target_label": "Tag"},
{"filename": "Post_hasTag_Tag", "source_label": "Post", "type": "HAS_TAG", "target_label": "Tag"},
{
"filename": "Post_isLocatedIn_Place",
"source_label": "Post",
"type": "IS_LOCATED_IN",
"target_label": "Place",
},
{
"filename": "Person_studyAt_Organisation",
"source_label": "Person",
"type": "STUDY_AT",
"target_label": "Organisation",
},
{
"filename": "Person_workAt_Organisation",
"source_label": "Person",
"type": "WORK_AT",
"target_label": "Organisation",
},
]
file_size = "sf{}".format(args.size)
out_file = "ldbc_interactive_{}.cypher".format(file_size)
output = output_directory / out_file
if output.exists():
output.unlink()
files_present = None
for file in output_directory.glob("**/*.tar.zst"):
if "basic-" + file_size in file.name:
files_present = file.with_suffix("").with_suffix("")
break
if not files_present:
try:
print("Downloading the file... " + INTERACTIVE_LINK[file_size])
downloaded_file = helpers.download_file(INTERACTIVE_LINK[file_size], output_directory.absolute())
print("Unpacking the file..." + downloaded_file)
files_present = helpers.unpack_tar_zst(Path(downloaded_file))
except:
print("Issue with downloading and unpacking the file, check if links are working properly.")
raise
input_files = {}
for file in files_present.glob("**/*.csv"):
name = file.name.replace("_0_0.csv", "").lower()
input_files[name] = file
for node_file in NODES_INTERACTIVE:
key = node_file["filename"].lower()
default_label = node_file["label"]
query = None
if key in input_files.keys():
with input_files[key].open("r") as input_f, output.open("a") as output_f:
reader = csv.DictReader(input_f, delimiter="|")
for row in reader:
if "type" in row.keys():
label = default_label + ":" + row.pop("type").capitalize()
else:
label = default_label
query = "CREATE (:{} {{id:{}, ".format(label, row.pop("id"))
# Format properties to fit Memgraph
for k, v in row.items():
if k == "creationDate":
row[k] = 'localDateTime("{}")'.format(v[0:-5])
elif k == "birthday":
row[k] = 'date("{}")'.format(v)
elif k == "length":
row[k] = "toInteger({})".format(v)
else:
row[k] = '"{}"'.format(v)
prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items())
query = query + prop_string + "});"
output_f.write(query + "\n")
print("Converted file: " + input_files[key].name + " to " + output.name)
else:
print("Didn't process node file: " + key)
raise Exception("Didn't find the file that was needed!")
for edge_file in EDGES_INTERACTIVE:
key = edge_file["filename"].lower()
source_label = edge_file["source_label"]
edge_type = edge_file["type"]
target_label = edge_file["target_label"]
if key in input_files.keys():
query = None
with input_files[key].open("r") as input_f, output.open("a") as output_f:
sufixl = ".id"
sufixr = ".id"
# Handle identical label/key in CSV header
if source_label == target_label:
sufixl = "l"
sufixr = "r"
# Read the header row off the stream; its first two columns are replaced by the fieldnames below
header = next(input_f).strip().split("|")
reader = csv.DictReader(
input_f, delimiter="|", fieldnames=([source_label + sufixl, target_label + sufixr] + header[2:])
)
for row in reader:
query = "MATCH (n1:{} {{id:{}}}), (n2:{} {{id:{}}}) ".format(
source_label, row.pop(source_label + sufixl), target_label, row.pop(target_label + sufixr)
)
for k, v in row.items():
if "date" in k.lower():
# Take time zone out
row[k] = 'localDateTime("{}")'.format(v[0:-5])
elif "workfrom" in k.lower() or "classyear" in k.lower():
row[k] = 'toInteger("{}")'.format(v)
else:
row[k] = '"{}"'.format(v)
edge_part = "CREATE (n1)-[:{}{{".format(edge_type)
prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items())
query = query + edge_part + prop_string + "}]->(n2);"
output_f.write(query + "\n")
print("Converted file: " + input_files[key].name + " to " + output.name)
else:
print("Didn't process Edge file: " + key)
raise Exception("Didn't find the file that was needed!")
elif args.type == "bi":
NODES_BI = [
{"filename": "Place", "label": "Place"},
{"filename": "Organisation", "label": "Organisation"},
{"filename": "TagClass", "label": "TagClass"},
{"filename": "Tag", "label": "Tag"},
{"filename": "Comment", "label": "Message:Comment"},
{"filename": "Forum", "label": "Forum"},
{"filename": "Person", "label": "Person"},
{"filename": "Post", "label": "Message:Post"},
]
EDGES_BI = [
{
"filename": "Place_isPartOf_Place",
"source_label": "Place",
"type": "IS_PART_OF",
"target_label": "Place",
},
{
"filename": "TagClass_isSubclassOf_TagClass",
"source_label": "TagClass",
"type": "IS_SUBCLASS_OF",
"target_label": "TagClass",
},
{
"filename": "Organisation_isLocatedIn_Place",
"source_label": "Organisation",
"type": "IS_LOCATED_IN",
"target_label": "Place",
},
{"filename": "Tag_hasType_TagClass", "source_label": "Tag", "type": "HAS_TYPE", "target_label": "TagClass"},
{
"filename": "Comment_hasCreator_Person",
"source_label": "Comment",
"type": "HAS_CREATOR",
"target_label": "Person",
},
# Change place to Country
{
"filename": "Comment_isLocatedIn_Country",
"source_label": "Comment",
"type": "IS_LOCATED_IN",
"target_label": "Country",
},
{
"filename": "Comment_replyOf_Comment",
"source_label": "Comment",
"type": "REPLY_OF",
"target_label": "Comment",
},
{"filename": "Comment_replyOf_Post", "source_label": "Comment", "type": "REPLY_OF", "target_label": "Post"},
{
"filename": "Forum_containerOf_Post",
"source_label": "Forum",
"type": "CONTAINER_OF",
"target_label": "Post",
},
{
"filename": "Forum_hasMember_Person",
"source_label": "Forum",
"type": "HAS_MEMBER",
"target_label": "Person",
},
{
"filename": "Forum_hasModerator_Person",
"source_label": "Forum",
"type": "HAS_MODERATOR",
"target_label": "Person",
},
{"filename": "Forum_hasTag_Tag", "source_label": "Forum", "type": "HAS_TAG", "target_label": "Tag"},
{
"filename": "Person_hasInterest_Tag",
"source_label": "Person",
"type": "HAS_INTEREST",
"target_label": "Tag",
},
# Changed place to City
{
"filename": "Person_isLocatedIn_City",
"source_label": "Person",
"type": "IS_LOCATED_IN",
"target_label": "City",
},
{"filename": "Person_knows_Person", "source_label": "Person", "type": "KNOWS", "target_label": "Person"},
{"filename": "Person_likes_Comment", "source_label": "Person", "type": "LIKES", "target_label": "Comment"},
{"filename": "Person_likes_Post", "source_label": "Person", "type": "LIKES", "target_label": "Post"},
{
"filename": "Post_hasCreator_Person",
"source_label": "Post",
"type": "HAS_CREATOR",
"target_label": "Person",
},
{"filename": "Comment_hasTag_Tag", "source_label": "Comment", "type": "HAS_TAG", "target_label": "Tag"},
{"filename": "Post_hasTag_Tag", "source_label": "Post", "type": "HAS_TAG", "target_label": "Tag"},
# Change place to Country
{
"filename": "Post_isLocatedIn_Country",
"source_label": "Post",
"type": "IS_LOCATED_IN",
"target_label": "Country",
},
# Changed organisation to University
{
"filename": "Person_studyAt_University",
"source_label": "Person",
"type": "STUDY_AT",
"target_label": "University",
},
# Changed organisation to Company
{
"filename": "Person_workAt_Company",
"source_label": "Person",
"type": "WORK_AT",
"target_label": "Company",
},
]
file_size = "sf{}".format(args.size)
out_file = "ldbc_bi_{}.cypher".format(file_size)
output = output_directory / out_file
if output.exists():
output.unlink()
files_present = None
for file in output_directory.glob("**/*.tar.zst"):
if "bi-" + file_size in file.name:
files_present = file.with_suffix("").with_suffix("")
break
if not files_present:
try:
print("Downloading the file... " + BI_LINK[file_size])
downloaded_file = helpers.download_file(BI_LINK[file_size], output_directory.absolute())
print("Unpacking the file..." + downloaded_file)
files_present = helpers.unpack_tar_zst(Path(downloaded_file))
except:
print("Issue with downloading and unpacking the file, check if links are working properly.")
raise
for file in files_present.glob("**/*.csv.gz"):
if "initial_snapshot" in file.parts:
helpers.unpack_gz(file)
input_files = defaultdict(list)
for file in files_present.glob("**/*.csv"):
key = file.parents[0].name
input_files[file.parents[0].name].append(file)
for node_file in NODES_BI:
key = node_file["filename"]
default_label = node_file["label"]
query = None
if key in input_files.keys():
for part_file in input_files[key]:
with part_file.open("r") as input_f, output.open("a") as output_f:
reader = csv.DictReader(input_f, delimiter="|")
for row in reader:
if "type" in row.keys():
label = default_label + ":" + row.pop("type")
else:
label = default_label
query = "CREATE (:{} {{id:{}, ".format(label, row.pop("id"))
# Format properties to fit Memgraph
for k, v in row.items():
if k == "creationDate":
row[k] = 'localDateTime("{}")'.format(v[0:-6])
elif k == "birthday":
row[k] = 'date("{}")'.format(v)
elif k == "length":
row[k] = "toInteger({})".format(v)
else:
row[k] = '"{}"'.format(v)
prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items())
query = query + prop_string + "});"
output_f.write(query + "\n")
print("Key: " + key + " Converted file: " + part_file.name + " to " + output.name)
else:
print("Didn't process node file: " + key)
for edge_file in EDGES_BI:
key = edge_file["filename"]
source_label = edge_file["source_label"]
edge_type = edge_file["type"]
target_label = edge_file["target_label"]
if key in input_files.keys():
for part_file in input_files[key]:
query = None
with part_file.open("r") as input_f, output.open("a") as output_f:
sufixl = "Id"
sufixr = "Id"
# Handle identical label/key in CSV header
if source_label == target_label:
sufixl = "l"
sufixr = "r"
# Read the header row off the stream; its leading columns are replaced by the fieldnames below
header = next(input_f).strip().split("|")
if len(header) >= 3:
reader = csv.DictReader(
input_f,
delimiter="|",
fieldnames=(["date", source_label + sufixl, target_label + sufixr] + header[3:]),
)
else:
reader = csv.DictReader(
input_f,
delimiter="|",
fieldnames=([source_label + sufixl, target_label + sufixr] + header[2:]),
)
for row in reader:
query = "MATCH (n1:{} {{id:{}}}), (n2:{} {{id:{}}}) ".format(
source_label,
row.pop(source_label + sufixl),
target_label,
row.pop(target_label + sufixr),
)
for k, v in row.items():
if "date" in k.lower():
# Take time zone out
row[k] = 'localDateTime("{}")'.format(v[0:-6])
elif k == "classYear" or k == "workFrom":
row[k] = 'toInteger("{}")'.format(v)
else:
row[k] = '"{}"'.format(v)
edge_part = "CREATE (n1)-[:{}{{".format(edge_type)
prop_string = ", ".join("{} : {}".format(k, v) for k, v in row.items())
query = query + edge_part + prop_string + "}]->(n2);"
output_f.write(query + "\n")
print("Key: " + key + " Converted file: " + part_file.name + " to " + output.name)
else:
print("Didn't process Edge file: " + key)
raise Exception("Didn't find the file that was needed!")

View File

@ -16,14 +16,20 @@ def parse_arguments():
help="Forward name and paths to vendors binary"
"Example: --vendor memgraph /path/to/binary --vendor neo4j /path/to/binary",
)
parser.add_argument(
"--dataset-size",
default="small",
choices=["small", "medium", "large"],
help="Pick a dataset size (small, medium, large)",
"--dataset-name",
default="",
help="Dataset name you wish to execute",
)
parser.add_argument("--dataset-group", default="basic", help="Select a group of queries")
parser.add_argument(
"--dataset-size",
default="",
help="Pick a dataset variant you wish to execute",
)
parser.add_argument("--dataset-group", default="", help="Select a group of queries")
parser.add_argument(
"--realistic",
@ -53,88 +59,110 @@ def parse_arguments():
return args
def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic, mixed):
def run_full_benchmarks(vendor, binary, dataset, dataset_size, dataset_group, realistic, mixed):
configurations = [
# Basic full group test cold
# Basic isolated test cold
[
"--export-results",
vendor + "_" + dataset_size + "_cold_isolated.json",
vendor + "_" + dataset + "_" + dataset_size + "_cold_isolated.json",
],
# Basic full group test hot
# Basic isolated test hot
[
"--export-results",
vendor + "_" + dataset_size + "_hot_isolated.json",
"--warmup-run",
vendor + "_" + dataset + "_" + dataset_size + "_hot_isolated.json",
"--warm-up",
"hot",
],
# Basic isolated test vulcanic
[
"--export-results",
vendor + "_" + dataset + "_" + dataset_size + "_vulcanic_isolated.json",
"--warm-up",
"vulcanic",
],
]
# Configurations for full workload
for count, write, read, update, analytical in realistic:
cold = [
"--export-results",
vendor
+ "_"
+ dataset_size
+ "_cold_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical),
"--mixed-workload",
count,
write,
read,
update,
analytical,
]
if realistic:
# Configurations for full workload
for count, write, read, update, analytical in realistic:
cold = [
"--export-results",
vendor
+ "_"
+ dataset
+ "_"
+ dataset_size
+ "_cold_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical),
"--workload-realistic",
count,
write,
read,
update,
analytical,
]
hot = [
"--export-results",
vendor
+ "_"
+ dataset_size
+ "_hot_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical),
"--warmup-run",
"--mixed-workload",
count,
write,
read,
update,
analytical,
]
configurations.append(cold)
configurations.append(hot)
hot = [
"--export-results",
vendor
+ "_"
+ dataset
+ "_"
+ dataset_size
+ "_hot_realistic_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical),
"--warm-up",
"hot",
"--workload-realistic",
count,
write,
read,
update,
analytical,
]
# Configurations for workload per query
for count, write, read, update, analytical, query in mixed:
cold = [
"--export-results",
vendor
+ "_"
+ dataset_size
+ "_cold_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query),
"--mixed-workload",
count,
write,
read,
update,
analytical,
query,
]
hot = [
"--export-results",
vendor
+ "_"
+ dataset_size
+ "_hot_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query),
"--warmup-run",
"--mixed-workload",
count,
write,
read,
update,
analytical,
query,
]
configurations.append(cold)
configurations.append(hot)
configurations.append(cold)
configurations.append(hot)
if mixed:
# Configurations for workload per query
for count, write, read, update, analytical, query in mixed:
cold = [
"--export-results",
vendor
+ "_"
+ dataset
+ "_"
+ dataset_size
+ "_cold_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query),
"--workload-mixed",
count,
write,
read,
update,
analytical,
query,
]
hot = [
"--export-results",
vendor
+ "_"
+ dataset
+ "_"
+ dataset_size
+ "_hot_mixed_{}_{}_{}_{}_{}_{}.json".format(count, write, read, update, analytical, query),
"--warm-up",
"hot",
"--workload-mixed",
count,
write,
read,
update,
analytical,
query,
]
configurations.append(cold)
configurations.append(hot)
default_args = [
"python3",
@ -146,9 +174,7 @@ def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic,
"--num-workers-for-benchmark",
"12",
"--no-authorization",
"pokec/" + dataset_size + "/" + dataset_group + "/*",
"--tail-latency",
"100",
dataset + "/" + dataset_size + "/" + dataset_group + "/*",
]
for config in configurations:
@ -157,11 +183,11 @@ def run_full_benchmarks(vendor, binary, dataset_size, dataset_group, realistic,
subprocess.run(args=full_config, check=True)
def collect_all_results(vendor_name, dataset_size, dataset_group):
def collect_all_results(vendor_name, dataset, dataset_size, dataset_group):
working_directory = Path().absolute()
print(working_directory)
results = sorted(working_directory.glob(vendor_name + "_" + dataset_size + "_*.json"))
summary = {"pokec": {dataset_size: {dataset_group: {}}}}
results = sorted(working_directory.glob(vendor_name + "_" + dataset + "_" + dataset_size + "_*.json"))
summary = {dataset: {dataset_size: {dataset_group: {}}}}
for file in results:
if "summary" in file.name:
@ -169,19 +195,22 @@ def collect_all_results(vendor_name, dataset_size, dataset_group):
f = file.open()
data = json.loads(f.read())
if data["__run_configuration__"]["condition"] == "hot":
for key, value in data["pokec"][dataset_size][dataset_group].items():
for key, value in data[dataset][dataset_size][dataset_group].items():
key_condition = key + "_hot"
summary["pokec"][dataset_size][dataset_group][key_condition] = value
summary[dataset][dataset_size][dataset_group][key_condition] = value
elif data["__run_configuration__"]["condition"] == "cold":
for key, value in data["pokec"][dataset_size][dataset_group].items():
for key, value in data[dataset][dataset_size][dataset_group].items():
key_condition = key + "_cold"
summary["pokec"][dataset_size][dataset_group][key_condition] = value
summary[dataset][dataset_size][dataset_group][key_condition] = value
elif data["__run_configuration__"]["condition"] == "vulcanic":
for key, value in data[dataset][dataset_size][dataset_group].items():
key_condition = key + "_vulcanic"
summary[dataset][dataset_size][dataset_group][key_condition] = value
print(summary)
json_object = json.dumps(summary, indent=4)
print(json_object)
with open(vendor_name + "_" + dataset_size + "_summary.json", "w") as f:
with open(vendor_name + "_" + dataset + "_" + dataset_size + "_summary.json", "w") as f:
json.dump(summary, f)
@ -194,16 +223,17 @@ if __name__ == "__main__":
vendor_names = {"memgraph", "neo4j"}
for vendor_name, vendor_binary in args.vendor:
path = Path(vendor_binary)
if vendor_name.lower() in vendor_names and (path.is_file() or path.is_dir()):
if vendor_name.lower() in vendor_names and path.is_file():
run_full_benchmarks(
vendor_name,
vendor_binary,
args.dataset_name,
args.dataset_size,
args.dataset_group,
realistic,
mixed,
)
collect_all_results(vendor_name, args.dataset_size, args.dataset_group)
collect_all_results(vendor_name, args.dataset_name, args.dataset_size, args.dataset_group)
else:
raise Exception(
"Check that vendor: {} is supported and you are passing right path: {} to binary.".format(

View File

@ -1,4 +1,4 @@
# Copyright 2021 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -9,11 +9,21 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import collections
import copy
import fnmatch
import importlib
import inspect
import json
import os
import subprocess
import sys
from pathlib import Path
import workloads
from benchmark_context import BenchmarkContext
from workloads import *
from workloads import base
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@ -28,22 +38,70 @@ def get_binary_path(path, base=""):
def download_file(url, path):
ret = subprocess.run(["wget", "-nv", "--content-disposition", url],
stderr=subprocess.PIPE, cwd=path, check=True)
ret = subprocess.run(["wget", "-nv", "--content-disposition", url], stderr=subprocess.PIPE, cwd=path, check=True)
data = ret.stderr.decode("utf-8")
tmp = data.split("->")[1]
name = tmp[tmp.index('"') + 1:tmp.rindex('"')]
name = tmp[tmp.index('"') + 1 : tmp.rindex('"')]
return os.path.join(path, name)
def unpack_and_move_file(input_path, output_path):
def unpack_gz_and_move_file(input_path, output_path):
if input_path.endswith(".gz"):
subprocess.run(["gunzip", input_path],
stdout=subprocess.DEVNULL, check=True)
subprocess.run(["gunzip", input_path], stdout=subprocess.DEVNULL, check=True)
input_path = input_path[:-3]
os.rename(input_path, output_path)
def unpack_gz(input_path: Path):
if input_path.suffix == ".gz":
subprocess.run(["gzip", "-d", input_path], capture_output=True, check=True)
input_path = input_path.with_suffix("")
return input_path
def unpack_zip(input_path: Path):
if input_path.suffix == ".zip":
subprocess.run(["unzip", input_path], capture_output=True, check=True, cwd=input_path.parent)
input_path = input_path.with_suffix("")
return input_path
def unpack_tar_zst(input_path: Path):
if input_path.suffix == ".zst":
subprocess.run(
["tar", "--use-compress-program=unzstd", "-xvf", input_path],
cwd=input_path.parent,
capture_output=True,
check=True,
)
input_path = input_path.with_suffix("").with_suffix("")
return input_path
def unpack_tar_gz(input_path: Path):
if input_path.suffix == ".gz":
subprocess.run(
["tar", "-xvf", input_path],
cwd=input_path.parent,
capture_output=True,
check=True,
)
input_path = input_path.with_suffix("").with_suffix("")
return input_path
def unpack_tar_zst_and_move(input_path: Path, output_path: Path):
if input_path.suffix == ".zst":
subprocess.run(
["tar", "--use-compress-program=unzstd", "-xvf", input_path],
cwd=input_path.parent,
capture_output=True,
check=True,
)
input_path = input_path.with_suffix("").with_suffix("")
return input_path.rename(output_path)
def ensure_directory(path):
if not os.path.exists(path):
os.makedirs(path)
@ -51,6 +109,129 @@ def ensure_directory(path):
raise Exception("The path '{}' should be a directory!".format(path))
def get_available_workloads(customer_workloads: str = None) -> dict:
generators = {}
for module in map(workloads.__dict__.get, workloads.__all__):
for key in dir(module):
if key.startswith("_"):
continue
base_class = getattr(module, key)
if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload):
continue
queries = collections.defaultdict(list)
for funcname in dir(base_class):
if not funcname.startswith("benchmark__"):
continue
group, query = funcname.split("__")[1:]
queries[group].append((query, funcname))
generators[base_class.NAME] = (base_class, dict(queries))
if customer_workloads:
head_tail = os.path.split(customer_workloads)
path_without_dataset_name = head_tail[0]
dataset_name = head_tail[1].split(".")[0]
sys.path.append(path_without_dataset_name)
dataset_to_use = importlib.import_module(dataset_name)
for key in dir(dataset_to_use):
if key.startswith("_"):
continue
base_class = getattr(dataset_to_use, key)
if not inspect.isclass(base_class) or not issubclass(base_class, base.Workload):
continue
queries = collections.defaultdict(list)
for funcname in dir(base_class):
if not funcname.startswith("benchmark__"):
continue
group, query = funcname.split("__")[1:]
queries[group].append((query, funcname))
generators[base_class.NAME] = (base_class, dict(queries))
return generators
def list_available_workloads(customer_workloads: str = None):
generators = get_available_workloads(customer_workloads)
for name in sorted(generators.keys()):
print("Dataset:", name)
dataset, queries = generators[name]
print(
" Variants:",
", ".join(dataset.VARIANTS),
"(default: " + dataset.DEFAULT_VARIANT + ")",
)
for group in sorted(queries.keys()):
print(" Group:", group)
for query_name, query_func in queries[group]:
print(" Query:", query_name)
def match_patterns(workload, variant, group, query, is_default_variant, patterns):
for pattern in patterns:
verdict = [fnmatch.fnmatchcase(workload, pattern[0])]
if pattern[1] != "":
verdict.append(fnmatch.fnmatchcase(variant, pattern[1]))
else:
verdict.append(is_default_variant)
verdict.append(fnmatch.fnmatchcase(group, pattern[2]))
verdict.append(fnmatch.fnmatchcase(query, pattern[3]))
if all(verdict):
return True
return False
def filter_workloads(available_workloads: dict, benchmark_context: BenchmarkContext) -> list:
patterns = benchmark_context.benchmark_target_workload
for i in range(len(patterns)):
pattern = patterns[i].split("/")
if len(pattern) > 5 or len(pattern) == 0:
raise Exception("Invalid benchmark description '" + pattern + "'!")
pattern.extend(["", "*", "*"][len(pattern) - 1 :])
patterns[i] = pattern
filtered = []
for workload in sorted(available_workloads.keys()):
generator, queries = available_workloads[workload]
for variant in generator.VARIANTS:
is_default_variant = variant == generator.DEFAULT_VARIANT
current = collections.defaultdict(list)
for group in queries:
for query_name, query_func in queries[group]:
if match_patterns(
workload,
variant,
group,
query_name,
is_default_variant,
patterns,
):
current[group].append((query_name, query_func))
if len(current) == 0:
continue
# Ignore benchgraph "basic" queries in standard CI/CD run
for pattern in patterns:
res = pattern.count("*")
key = "basic"
if res >= 2 and key in current.keys():
current.pop(key)
filtered.append((generator(variant=variant, benchmark_context=benchmark_context), dict(current)))
return filtered
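To make the benchmark description grammar concrete, here is a minimal sketch of how a pattern is normalized and matched by the two functions above; the "demo" names are illustrative and line up with the demo workload added later in this change:

# Illustrative only: a description consisting of just the workload name
pattern = "demo".split("/")                           # ["demo"]
pattern.extend(["", "*", "*"][len(pattern) - 1 :])    # ["demo", "", "*", "*"]
# An empty variant field matches only the workload's default variant,
# while "*" matches every group and every query:
match_patterns("demo", "default", "test", "sample_query1", True, [pattern])  # -> True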
def parse_kwargs(items):
"""
Parse a series of key-value pairs and return a dictionary
"""
d = {}
if items:
for item in items:
key, value = item.split("=")
d[key] = value
return d
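A quick usage note for parse_kwargs: values stay strings, and the keys shown here are the ones the runners below look up in vendor_args (how the key=value pairs reach this function, e.g. via a --vendor-args style flag, is an assumption not visible in this hunk):

parse_kwargs(["bolt-port=7687", "no-properties-on-edges=true"])
# -> {"bolt-port": "7687", "no-properties-on-edges": "true"}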
class Directory:
def __init__(self, path):
self._path = path
@ -103,6 +284,9 @@ class Cache:
ensure_directory(path)
return Directory(path)
def get_default_cache_directory(self):
return self._directory
def load_config(self):
if not os.path.isfile(self._config):
return RecursiveDict()

View File

@ -9,6 +9,8 @@
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import logging
COLOR_GRAY = 0
COLOR_RED = 1
COLOR_GREEN = 2
@ -16,27 +18,45 @@ COLOR_YELLOW = 3
COLOR_BLUE = 4
COLOR_VIOLET = 5
COLOR_CYAN = 6
COLOR_WHITE = 7
def log(color, *args):
logger = logging.Logger("mgbench_logger")
file_handler = logging.FileHandler("mgbench_logs.log")
file_format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(file_format)
logger.addHandler(file_handler)
def _log(color, *args):
print("\033[1;3{}m~~".format(color), *args, "~~\033[0m")
def log(msg):
print(msg)
logger.info(msg=msg)
def init(*args):
log(COLOR_BLUE, *args)
_log(COLOR_BLUE, *args)
logger.info(*args)
def info(*args):
log(COLOR_CYAN, *args)
_log(COLOR_WHITE, *args)
logger.info(*args)
def success(*args):
log(COLOR_GREEN, *args)
_log(COLOR_GREEN, *args)
logger.info(*args)
def warning(*args):
log(COLOR_YELLOW, *args)
_log(COLOR_YELLOW, *args)
logger.warning(*args)
def error(*args):
log(COLOR_RED, *args)
_log(COLOR_RED, *args)
logger.critical(*args)

View File

@ -1,4 +1,4 @@
# Copyright 2022 Memgraph Ltd.
# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -17,10 +17,13 @@ import subprocess
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from pathlib import Path
from benchmark_context import BenchmarkContext
def wait_for_server(port, delay=0.1):
def _wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
while subprocess.call(cmd) != 0:
time.sleep(0.01)
@ -62,50 +65,165 @@ def _get_current_usage(pid):
return rss / 1024
class Memgraph:
def __init__(self, memgraph_binary, temporary_dir, properties_on_edges, bolt_port, performance_tracking):
self._memgraph_binary = memgraph_binary
self._directory = tempfile.TemporaryDirectory(dir=temporary_dir)
self._properties_on_edges = properties_on_edges
class BaseClient(ABC):
@abstractmethod
def __init__(self, benchmark_context: BenchmarkContext):
self.benchmark_context = benchmark_context
@abstractmethod
def execute(self):
pass
class BoltClient(BaseClient):
def __init__(self, benchmark_context: BenchmarkContext):
self._client_binary = benchmark_context.client_binary
self._directory = tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory)
self._username = ""
self._password = ""
self._bolt_port = (
benchmark_context.vendor_args["bolt-port"] if "bolt-port" in benchmark_context.vendor_args.keys() else 7687
)
def _get_args(self, **kwargs):
return _convert_args_to_flags(self._client_binary, **kwargs)
def set_credentials(self, username: str, password: str):
self._username = username
self._password = password
def execute(
self,
queries=None,
file_path=None,
num_workers=1,
max_retries: int = 50,
validation: bool = False,
time_dependent_execution: int = 0,
):
if (queries is None and file_path is None) or (queries is not None and file_path is not None):
raise ValueError("Either queries or input_path must be specified!")
queries_json = False
if queries is not None:
queries_json = True
file_path = os.path.join(self._directory.name, "queries.json")
with open(file_path, "w") as f:
for query in queries:
json.dump(query, f)
f.write("\n")
args = self._get_args(
input=file_path,
num_workers=num_workers,
max_retries=max_retries,
queries_json=queries_json,
username=self._username,
password=self._password,
port=self._bolt_port,
validation=validation,
time_dependent_execution=time_dependent_execution,
)
ret = None
try:
ret = subprocess.run(args, capture_output=True)
finally:
error = ret.stderr.decode("utf-8").strip().split("\n")
data = ret.stdout.decode("utf-8").strip().split("\n")
if error and error[0] != "":
print("Reported errros from client")
print(error)
data = [x for x in data if not x.startswith("[")]
return list(map(json.loads, data))
class BaseRunner(ABC):
subclasses = {}
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls.subclasses[cls.__name__.lower()] = cls
return
@classmethod
def create(cls, benchmark_context: BenchmarkContext):
if benchmark_context.vendor_name not in cls.subclasses:
raise ValueError("Missing runner with name: {}".format(benchmark_context.vendor_name))
return cls.subclasses[benchmark_context.vendor_name](
benchmark_context=benchmark_context,
)
@abstractmethod
def __init__(self, benchmark_context: BenchmarkContext):
self.benchmark_context = benchmark_context
@abstractmethod
def start_benchmark(self):
pass
@abstractmethod
def start_preparation(self):
pass
@abstractmethod
def stop(self):
pass
@abstractmethod
def clean_db(self):
pass
@abstractmethod
def fetch_client(self) -> BaseClient:
pass
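BaseRunner keeps a registry keyed by the lowercase subclass name, so BaseRunner.create() resolves vendor_name "memgraph" and "neo4j" to the two classes below. A hypothetical third vendor would plug in the same way; this is a sketch only, not part of the change:

class Dummyvendor(BaseRunner):  # auto-registered under the key "dummyvendor"
    def __init__(self, benchmark_context: BenchmarkContext):
        super().__init__(benchmark_context=benchmark_context)

    def start_benchmark(self): ...
    def start_preparation(self): ...
    def stop(self): ...
    def clean_db(self): ...

    def fetch_client(self) -> BaseClient:
        return BoltClient(benchmark_context=self.benchmark_context)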
class Memgraph(BaseRunner):
def __init__(self, benchmark_context: BenchmarkContext):
super().__init__(benchmark_context=benchmark_context)
self._memgraph_binary = benchmark_context.vendor_binary
self._performance_tracking = benchmark_context.performance_tracking
self._directory = tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory)
self._vendor_args = benchmark_context.vendor_args
self._properties_on_edges = (
self._vendor_args["no-properties-on-edges"]
if "no-properties-on-edges" in self._vendor_args.keys()
else False
)
self._bolt_port = self._vendor_args["bolt-port"] if "bolt-port" in self._vendor_args.keys() else 7687
self._proc_mg = None
self._bolt_port = bolt_port
self.performance_tracking = performance_tracking
self._stop_event = threading.Event()
self._rss = []
atexit.register(self._cleanup)
# Determine Memgraph version
ret = subprocess.run([memgraph_binary, "--version"], stdout=subprocess.PIPE, check=True)
ret = subprocess.run([self._memgraph_binary, "--version"], stdout=subprocess.PIPE, check=True)
version = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", ret.stdout.decode("utf-8")).group(0)
self._memgraph_version = tuple(map(int, version.split(".")))
atexit.register(self._cleanup)
def __del__(self):
self._cleanup()
atexit.unregister(self._cleanup)
def _get_args(self, **kwargs):
def _set_args(self, **kwargs):
data_directory = os.path.join(self._directory.name, "memgraph")
kwargs["bolt_port"] = self._bolt_port
if self._memgraph_version >= (0, 50, 0):
kwargs["data_directory"] = data_directory
else:
kwargs["durability_directory"] = data_directory
if self._memgraph_version >= (0, 50, 0):
kwargs["storage_properties_on_edges"] = self._properties_on_edges
else:
assert self._properties_on_edges, "Older versions of Memgraph can't disable properties on edges!"
kwargs["data_directory"] = data_directory
kwargs["storage_properties_on_edges"] = self._properties_on_edges
return _convert_args_to_flags(self._memgraph_binary, **kwargs)
def _start(self, **kwargs):
if self._proc_mg is not None:
raise Exception("The database process is already running!")
args = self._get_args(**kwargs)
args = self._set_args(**kwargs)
self._proc_mg = subprocess.Popen(args, stdout=subprocess.DEVNULL)
time.sleep(0.2)
if self._proc_mg.poll() is not None:
self._proc_mg = None
raise Exception("The database process died prematurely!")
wait_for_server(self._bolt_port)
_wait_for_server(self._bolt_port)
ret = self._proc_mg.poll()
assert ret is None, "The database process died prematurely " "({})!".format(ret)
@ -119,7 +237,7 @@ class Memgraph:
return ret, usage
def start_preparation(self, workload):
if self.performance_tracking:
if self._performance_tracking:
p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
self._stop_event.clear()
self._rss.clear()
@ -127,13 +245,26 @@ class Memgraph:
self._start(storage_snapshot_on_exit=True)
def start_benchmark(self, workload):
if self.performance_tracking:
if self._performance_tracking:
p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
self._stop_event.clear()
self._rss.clear()
p.start()
self._start(storage_recover_on_startup=True)
def clean_db(self):
if self._proc_mg is not None:
raise Exception("The database process is already running, cannot clear data it!")
else:
out = subprocess.run(
args="rm -Rf memgraph/snapshots/*",
cwd=self._directory.name,
capture_output=True,
shell=True,
)
print(out.stderr.decode("utf-8"))
print(out.stdout.decode("utf-8"))
def res_background_tracking(self, res, stop_event):
print("Started rss tracking.")
while not stop_event.is_set():
@ -154,35 +285,46 @@ class Memgraph:
f.close()
def stop(self, workload):
if self.performance_tracking:
if self._performance_tracking:
self._stop_event.set()
self.dump_rss(workload)
ret, usage = self._cleanup()
assert ret == 0, "The database process exited with a non-zero " "status ({})!".format(ret)
return usage
def fetch_client(self) -> BoltClient:
return BoltClient(benchmark_context=self.benchmark_context)
class Neo4j:
def __init__(self, neo4j_path, temporary_dir, bolt_port, performance_tracking):
self._neo4j_path = Path(neo4j_path)
self._neo4j_binary = Path(neo4j_path) / "bin" / "neo4j"
self._neo4j_config = Path(neo4j_path) / "conf" / "neo4j.conf"
self._neo4j_pid = Path(neo4j_path) / "run" / "neo4j.pid"
self._neo4j_admin = Path(neo4j_path) / "bin" / "neo4j-admin"
self.performance_tracking = performance_tracking
class Neo4j(BaseRunner):
def __init__(self, benchmark_context: BenchmarkContext):
super().__init__(benchmark_context=benchmark_context)
self._neo4j_binary = Path(benchmark_context.vendor_binary)
self._neo4j_path = Path(benchmark_context.vendor_binary).parents[1]
self._neo4j_config = self._neo4j_path / "conf" / "neo4j.conf"
self._neo4j_pid = self._neo4j_path / "run" / "neo4j.pid"
self._neo4j_admin = self._neo4j_path / "bin" / "neo4j-admin"
self._performance_tracking = benchmark_context.performance_tracking
self._vendor_args = benchmark_context.vendor_args
self._stop_event = threading.Event()
self._rss = []
if not self._neo4j_binary.is_file():
raise Exception("Wrong path to binary!")
self._directory = tempfile.TemporaryDirectory(dir=temporary_dir)
self._bolt_port = bolt_port
tempfile.TemporaryDirectory(dir=benchmark_context.temporary_directory)
self._bolt_port = (
self.benchmark_context.vendor_args["bolt-port"]
if "bolt-port" in self.benchmark_context.vendor_args.keys()
else 7687
)
atexit.register(self._cleanup)
configs = []
memory_flag = "server.jvm.additional=-XX:NativeMemoryTracking=detail"
auth_flag = "dbms.security.auth_enabled=false"
if self.performance_tracking:
bolt_flag = "server.bolt.listen_address=:7687"
http_flag = "server.http.listen_address=:7474"
if self._performance_tracking:
configs.append(memory_flag)
else:
lines = []
@ -201,6 +343,8 @@ class Neo4j:
file.close()
configs.append(auth_flag)
configs.append(bolt_flag)
configs.append(http_flag)
print("Check neo4j config flags:")
for conf in configs:
with self._neo4j_config.open("r+") as file:
@ -234,7 +378,7 @@ class Neo4j:
else:
raise Exception("The database process died prematurely!")
print("Run server check:")
wait_for_server(self._bolt_port)
_wait_for_server(self._bolt_port)
def _cleanup(self):
if self._neo4j_pid.exists():
@ -248,7 +392,7 @@ class Neo4j:
return 0
def start_preparation(self, workload):
if self.performance_tracking:
if self._performance_tracking:
p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
self._stop_event.clear()
self._rss.clear()
@ -257,11 +401,11 @@ class Neo4j:
# Start DB
self._start()
if self.performance_tracking:
if self._performance_tracking:
self.get_memory_usage("start_" + workload)
def start_benchmark(self, workload):
if self.performance_tracking:
if self._performance_tracking:
p = threading.Thread(target=self.res_background_tracking, args=(self._rss, self._stop_event))
self._stop_event.clear()
self._rss.clear()
@ -269,7 +413,7 @@ class Neo4j:
# Start DB
self._start()
if self.performance_tracking:
if self._performance_tracking:
self.get_memory_usage("start_" + workload)
def dump_db(self, path):
@ -290,6 +434,20 @@ class Neo4j:
check=True,
)
def clean_db(self):
print("Cleaning the database")
if self._neo4j_pid.exists():
raise Exception("Cannot clean DB because it is running.")
else:
out = subprocess.run(
args="rm -Rf data/databases/* data/transactions/*",
cwd=self._neo4j_path,
capture_output=True,
shell=True,
)
print(out.stderr.decode("utf-8"))
print(out.stdout.decode("utf-8"))
def load_db_from_dump(self, path):
print("Loading the neo4j database from dump...")
if self._neo4j_pid.exists():
@ -300,7 +458,8 @@ class Neo4j:
self._neo4j_admin,
"database",
"load",
"--from-path=" + path,
"--from-path",
path,
"--overwrite-destination=true",
"neo4j",
],
@ -325,7 +484,7 @@ class Neo4j:
return True
def stop(self, workload):
if self.performance_tracking:
if self._performance_tracking:
self._stop_event.set()
self.get_memory_usage("stop_" + workload)
self.dump_rss(workload)
@ -360,51 +519,5 @@ class Neo4j:
f.write(memory_usage.stdout)
f.close()
class Client:
def __init__(
self, client_binary: str, temporary_directory: str, bolt_port: int, username: str = "", password: str = ""
):
self._client_binary = client_binary
self._directory = tempfile.TemporaryDirectory(dir=temporary_directory)
self._username = username
self._password = password
self._bolt_port = bolt_port
def _get_args(self, **kwargs):
return _convert_args_to_flags(self._client_binary, **kwargs)
def execute(self, queries=None, file_path=None, num_workers=1):
if (queries is None and file_path is None) or (queries is not None and file_path is not None):
raise ValueError("Either queries or input_path must be specified!")
# TODO: check `file_path.endswith(".json")` to support advanced
# input queries
queries_json = False
if queries is not None:
queries_json = True
file_path = os.path.join(self._directory.name, "queries.json")
with open(file_path, "w") as f:
for query in queries:
json.dump(query, f)
f.write("\n")
args = self._get_args(
input=file_path,
num_workers=num_workers,
queries_json=queries_json,
username=self._username,
password=self._password,
port=self._bolt_port,
)
ret = subprocess.run(args, capture_output=True, check=True)
error = ret.stderr.decode("utf-8").strip().split("\n")
if error and error[0] != "":
print("Reported errros from client")
print(error)
data = ret.stdout.decode("utf-8").strip().split("\n")
data = [x for x in data if not x.startswith("[")]
return list(map(json.loads, data))
def fetch_client(self) -> BoltClient:
return BoltClient(benchmark_context=self.benchmark_context)

244
tests/mgbench/validation.py Normal file
View File

@ -0,0 +1,244 @@
import argparse
import copy
import multiprocessing
import random
import helpers
import runners
import workloads
from benchmark_context import BenchmarkContext
from workloads import base
def pars_args():
parser = argparse.ArgumentParser(
prog="Validator for individual query checking",
description="""Validates that query is running, and validates output between different vendors""",
)
parser.add_argument(
"benchmarks",
nargs="*",
default="",
help="descriptions of benchmarks that should be run; "
"multiple descriptions can be specified to run multiple "
"benchmarks; the description is specified as "
"dataset/variant/group/query; Unix shell-style wildcards "
"can be used in the descriptions; variant, group and query "
"are optional and they can be left out; the default "
"variant is '' which selects the default dataset variant; "
"the default group is '*' which selects all groups; the"
"default query is '*' which selects all queries",
)
parser.add_argument(
"--vendor-binary-1",
help="Vendor binary used for benchmarking, by default it is memgraph",
default=helpers.get_binary_path("memgraph"),
)
parser.add_argument(
"--vendor-name-1",
default="memgraph",
choices=["memgraph", "neo4j"],
help="Input vendor binary name (memgraph, neo4j)",
)
parser.add_argument(
"--vendor-binary-2",
help="Vendor binary used for benchmarking, by default it is memgraph",
default=helpers.get_binary_path("memgraph"),
)
parser.add_argument(
"--vendor-name-2",
default="memgraph",
choices=["memgraph", "neo4j"],
help="Input vendor binary name (memgraph, neo4j)",
)
parser.add_argument(
"--client-binary",
default=helpers.get_binary_path("tests/mgbench/client"),
help="Client binary used for benchmarking",
)
parser.add_argument(
"--temporary-directory",
default="/tmp",
help="directory path where temporary data should " "be stored",
)
parser.add_argument(
"--num-workers-for-import",
type=int,
default=multiprocessing.cpu_count() // 2,
help="number of workers used to import the dataset",
)
return parser.parse_args()
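Since validation.py takes the same dataset/variant/group/query descriptions explained in the help text above, a typical comparison run looks roughly like this; the binary paths are illustrative:

python3 tests/mgbench/validation.py "demo/*/test/*" \
    --vendor-binary-1 ./build/memgraph --vendor-name-1 memgraph \
    --vendor-binary-2 /opt/neo4j/bin/neo4j --vendor-name-2 neo4j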
def get_queries(gen, count):
# Make the generator deterministic.
random.seed(gen.__name__)
# Generate queries.
ret = []
for i in range(count):
ret.append(gen())
return ret
if __name__ == "__main__":
args = pars_args()
benchmark_context_db_1 = BenchmarkContext(
vendor_name=args.vendor_name_1,
vendor_binary=args.vendor_binary_1,
benchmark_target_workload=copy.copy(args.benchmarks),
client_binary=args.client_binary,
num_workers_for_import=args.num_workers_for_import,
temporary_directory=args.temporary_directory,
)
available_workloads = helpers.get_available_workloads()
print(helpers.list_available_workloads())
vendor_runner = runners.BaseRunner.create(
benchmark_context=benchmark_context_db_1,
)
cache = helpers.Cache()
client = vendor_runner.fetch_client()
workloads = helpers.filter_workloads(
available_workloads=available_workloads, benchmark_context=benchmark_context_db_1
)
results_db_1 = {}
for workload, queries in workloads:
vendor_runner.clean_db()
generated_queries = workload.dataset_generator()
if generated_queries:
vendor_runner.start_preparation("import")
client.execute(queries=generated_queries, num_workers=benchmark_context_db_1.num_workers_for_import)
vendor_runner.stop("import")
else:
workload.prepare(cache.cache_directory("datasets", workload.NAME, workload.get_variant()))
imported = workload.custom_import()
if not imported:
vendor_runner.start_preparation("import")
print("Executing database cleanup and index setup...")
client.execute(
file_path=workload.get_index(), num_workers=benchmark_context_db_1.num_workers_for_import
)
print("Importing dataset...")
ret = client.execute(
file_path=workload.get_file(), num_workers=benchmark_context_db_1.num_workers_for_import
)
usage = vendor_runner.stop("import")
for group in sorted(queries.keys()):
for query, funcname in queries[group]:
print("Running query:{}/{}/{}".format(group, query, funcname))
func = getattr(workload, funcname)
count = 1
vendor_runner.start_benchmark("validation")
try:
ret = client.execute(queries=get_queries(func, count), num_workers=1, validation=True)[0]
results_db_1[funcname] = ret["results"].items()
except Exception as e:
print("Issue running the query" + funcname)
print(e)
results_db_1[funcname] = "Query not executed properly"
finally:
usage = vendor_runner.stop("validation")
print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
benchmark_context_db_2 = BenchmarkContext(
vendor_name=args.vendor_name_2,
vendor_binary=args.vendor_binary_2,
benchmark_target_workload=copy.copy(args.benchmarks),
client_binary=args.client_binary,
num_workers_for_import=args.num_workers_for_import,
temporary_directory=args.temporary_directory,
)
vendor_runner = runners.BaseRunner.create(
benchmark_context=benchmark_context_db_2,
)
available_workloads = helpers.get_available_workloads()
workloads = helpers.filter_workloads(available_workloads, benchmark_context=benchmark_context_db_2)
client = vendor_runner.fetch_client()
results_db_2 = {}
for workload, queries in workloads:
vendor_runner.clean_db()
generated_queries = workload.dataset_generator()
if generated_queries:
vendor_runner.start_preparation("import")
client.execute(queries=generated_queries, num_workers=benchmark_context_db_2.num_workers_for_import)
vendor_runner.stop("import")
else:
workload.prepare(cache.cache_directory("datasets", workload.NAME, workload.get_variant()))
imported = workload.custom_import()
if not imported:
vendor_runner.start_preparation("import")
print("Executing database cleanup and index setup...")
client.execute(
file_path=workload.get_index(), num_workers=benchmark_context_db_2.num_workers_for_import
)
print("Importing dataset...")
ret = client.execute(
file_path=workload.get_file(), num_workers=benchmark_context_db_2.num_workers_for_import
)
usage = vendor_runner.stop("import")
for group in sorted(queries.keys()):
for query, funcname in queries[group]:
print("Running query:{}/{}/{}".format(group, query, funcname))
func = getattr(workload, funcname)
count = 1
vendor_runner.start_benchmark("validation")
try:
ret = client.execute(queries=get_queries(func, count), num_workers=1, validation=True)[0]
results_db_2[funcname] = ret["results"].items()
except Exception as e:
print("Issue running the query" + funcname)
print(e)
results_db_2[funcname] = "Query not executed properly"
finally:
usage = vendor_runner.stop("validation")
print("Database used {:.3f} seconds of CPU time.".format(usage["cpu"]))
print("Database peaked at {:.3f} MiB of memory.".format(usage["memory"] / 1024.0 / 1024.0))
validation = {}
for key in results_db_1.keys():
if type(results_db_1[key]) is str:
validation[key] = "Query not executed properly."
else:
db_1_values = set()
for index, value in results_db_1[key]:
db_1_values.add(value)
neo4j_values = set()
for index, value in results_db_2[key]:
neo4j_values.add(value)
if db_1_values == neo4j_values:
validation[key] = "Identical results"
else:
validation[key] = "Different results, check manually."
for key, value in validation.items():
print(key + " " + value)

View File

@ -0,0 +1,4 @@
from pathlib import Path
modules = Path(__file__).resolve().parent.glob("*.py")
__all__ = [f.name[:-3] for f in modules if f.is_file() and not f.name == "__init__.py"]
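With the workload files added in this change, the comprehension above would evaluate to something like the following, which is what helpers.py walks when discovering workloads via `from workloads import *` (any additional modules in the directory are picked up automatically):

# e.g. __all__ == ["base", "demo", ...]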

View File

@ -0,0 +1,197 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
from abc import ABC, abstractclassmethod
from pathlib import Path
import helpers
from benchmark_context import BenchmarkContext
# Base dataset class used as a template to create each individual dataset. All
# common logic is handled here.
class Workload(ABC):
# Name of the workload/dataset.
NAME = ""
# List of all variants of the workload/dataset that exist.
VARIANTS = ["default"]
# One of the available variants that should be used as the default variant.
DEFAULT_VARIANT = "default"
# List of local files that should be used to import the dataset.
LOCAL_FILE = None
# URLs of remote dataset files that should be used to import the dataset, compressed in gz format.
URL_FILE = None
# Index files
LOCAL_INDEX_FILE = None
URL_INDEX_FILE = None
# Number of vertices/edges for each variant.
SIZES = {
"default": {"vertices": 0, "edges": 0},
}
# Indicates whether the dataset has properties on edges.
PROPERTIES_ON_EDGES = False
def __init_subclass__(cls) -> None:
name_prerequisite = "NAME" in cls.__dict__
generator_prerequisite = "dataset_generator" in cls.__dict__
custom_import_prerequisite = "custom_import" in cls.__dict__
basic_import_prerequisite = ("LOCAL_FILE" in cls.__dict__ or "URL_FILE" in cls.__dict__) and (
"LOCAL_INDEX_FILE" in cls.__dict__ or "URL_INDEX_FILE" in cls.__dict__
)
if not name_prerequisite:
raise ValueError(
"""Can't define a workload class {} without NAME property:
NAME = "dataset name"
Name property defines the workload you want to execute, for example: "demo/*/*/*"
""".format(
cls.__name__
)
)
# Check at class-definition time (not at runtime) that the workload is in generator mode or dataset mode, not both
if generator_prerequisite and (custom_import_prerequisite or basic_import_prerequisite):
raise ValueError(
"""
The workload class {} cannot define both a dataset import and a dataset generator at
the same time.
""".format(
cls.__name__
)
)
if not generator_prerequisite and (not custom_import_prerequisite and not basic_import_prerequisite):
raise ValueError(
"""
The workload class {} needs to define either a dataset import or a dataset generator
""".format(
cls.__name__
)
)
return super().__init_subclass__()
def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None):
"""
Accepts a `variant` variable that indicates which variant
of the dataset should be executed
"""
self.benchmark_context = benchmark_context
self._variant = variant
self._vendor = benchmark_context.vendor_name
self._file = None
self._file_index = None
if self.NAME == "":
raise ValueError("Give your workload a name, by setting self.NAME")
if variant is None:
variant = self.DEFAULT_VARIANT
if variant not in self.VARIANTS:
raise ValueError("Invalid test variant!")
if (self.LOCAL_FILE and variant not in self.LOCAL_FILE) and (self.URL_FILE and variant not in self.URL_FILE):
raise ValueError("The variant doesn't have a defined URL or LOCAL file path!")
if variant not in self.SIZES:
raise ValueError("The variant doesn't have a defined dataset " "size!")
if (self.LOCAL_INDEX_FILE and self._vendor not in self.LOCAL_INDEX_FILE) and (
self.URL_INDEX_FILE and self._vendor not in self.URL_INDEX_FILE
):
raise ValueError("Vendor does not have INDEX for dataset!")
if self.LOCAL_FILE is not None:
self._local_file = self.LOCAL_FILE.get(variant, None)
else:
self._local_file = None
if self.URL_FILE is not None:
self._url_file = self.URL_FILE.get(variant, None)
else:
self._url_file = None
if self.LOCAL_INDEX_FILE is not None:
self._local_index = self.LOCAL_INDEX_FILE.get(self._vendor, None)
else:
self._local_index = None
if self.URL_INDEX_FILE is not None:
self._url_index = self.URL_INDEX_FILE.get(self._vendor, None)
else:
self._url_index = None
self._size = self.SIZES[variant]
if "vertices" in self._size or "edges" in self._size:
self._num_vertices = self._size["vertices"]
self._num_edges = self._size["edges"]
def prepare(self, directory):
if self._local_file is not None:
print("Using local dataset file:", self._local_file)
self._file = self._local_file
elif self._url_file is not None:
cached_input, exists = directory.get_file("dataset.cypher")
if not exists:
print("Downloading dataset file:", self._url_file)
downloaded_file = helpers.download_file(self._url_file, directory.get_path())
print("Unpacking and caching file:", downloaded_file)
helpers.unpack_gz_and_move_file(downloaded_file, cached_input)
print("Using cached dataset file:", cached_input)
self._file = cached_input
if self._local_index is not None:
print("Using local index file:", self._local_index)
self._file_index = self._local_index
elif self._url_index is not None:
cached_index, exists = directory.get_file(self._vendor + ".cypher")
if not exists:
print("Downloading index file:", self._url_index)
downloaded_file = helpers.download_file(self._url_index, directory.get_path())
print("Unpacking and caching file:", downloaded_file)
helpers.unpack_gz_and_move_file(downloaded_file, cached_index)
print("Using cached index file:", cached_index)
self._file_index = cached_index
def get_variant(self):
"""Returns the current variant of the dataset."""
return self._variant
def get_index(self):
"""Get index file, defined by vendor"""
return self._file_index
def get_file(self):
"""
Returns path to the file that contains dataset creation queries.
"""
return self._file
def get_size(self):
"""Returns number of vertices/edges for the current variant."""
return self._size
def custom_import(self) -> bool:
print("Workload does not have a custom import")
return False
def dataset_generator(self) -> list:
print("Workload is not auto generated")
return []
# All tests should be query generator functions that output all of the
# queries that should be executed by the runner. The functions should be
# named `benchmark__GROUPNAME__TESTNAME` and should not accept any
# arguments.

View File

@ -0,0 +1,28 @@
import random
from workloads.base import Workload
class Demo(Workload):
NAME = "demo"
def dataset_generator(self):
queries = [("MATCH (n) DETACH DELETE n;", {})]
for i in range(0, 100):
queries.append(("CREATE (:NodeA{{ id:{}}});".format(i), {}))
queries.append(("CREATE (:NodeB{{ id:{}}});".format(i), {}))
for i in range(0, 100):
a = random.randint(0, 99)
b = random.randint(0, 99)
queries.append(("MATCH(a:NodeA{{ id: {}}}),(b:NodeB{{id: {}}}) CREATE (a)-[:EDGE]->(b)".format(a, b), {}))
return queries
def benchmark__test__sample_query1(self):
return ("MATCH (n) RETURN n", {})
def benchmark__test__sample_query2(self):
return ("MATCH (n) RETURN n", {})

View File

@ -0,0 +1,213 @@
import csv
import subprocess
from collections import defaultdict
from pathlib import Path
import helpers
from benchmark_context import BenchmarkContext
from runners import BaseRunner
HEADERS_URL = "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/headers.tar.gz"
class ImporterLDBCBI:
def __init__(
self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, csv_dict: dict
) -> None:
self._benchmark_context = benchmark_context
self._dataset_name = dataset_name
self._variant = variant
self._index_file = index_file
self._csv_dict = csv_dict
def execute_import(self):
vendor_runner = BaseRunner.create(
benchmark_context=self._benchmark_context,
)
client = vendor_runner.fetch_client()
if self._benchmark_context.vendor_name == "neo4j":
data_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "data_neo4j"
data_dir.mkdir(parents=True, exist_ok=True)
dir_name = self._csv_dict[self._variant].split("/")[-1:][0].removesuffix(".tar.zst")
if (data_dir / dir_name).exists():
print("Files downloaded")
data_dir = data_dir / dir_name
else:
print("Downloading files")
downloaded_file = helpers.download_file(self._csv_dict[self._variant], data_dir.absolute())
print("Unpacking the file..." + downloaded_file)
data_dir = helpers.unpack_tar_zst(Path(downloaded_file))
headers_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "headers_neo4j"
headers_dir.mkdir(parents=True, exist_ok=True)
headers = HEADERS_URL.split("/")[-1:][0].removesuffix(".tar.gz")
if (headers_dir / headers).exists():
print("Header files downloaded.")
else:
print("Downloading files")
downloaded_file = helpers.download_file(HEADERS_URL, headers_dir.absolute())
print("Unpacking the file..." + downloaded_file)
headers_dir = helpers.unpack_tar_gz(Path(downloaded_file))
input_headers = {}
for header_file in headers_dir.glob("**/*.csv"):
key = "/".join(header_file.parts[-2:])[0:-4]
input_headers[key] = header_file.as_posix()
for data_file in data_dir.glob("**/*.gz"):
if "initial_snapshot" in data_file.parts:
data_file = helpers.unpack_gz(data_file)
output = data_file.parent / (data_file.stem + "_neo" + ".csv")
if not output.exists():
with data_file.open("r") as input_f, output.open("a") as output_f:
reader = csv.reader(input_f, delimiter="|")
header = next(reader)
writer = csv.writer(output_f, delimiter="|")
for line in reader:
writer.writerow(line)
else:
print("Files converted")
input_files = defaultdict(list)
for neo_file in data_dir.glob("**/*_neo.csv"):
key = "/".join(neo_file.parts[-3:-1])
input_files[key].append(neo_file.as_posix())
vendor_runner.clean_db()
subprocess.run(
args=[
vendor_runner._neo4j_admin,
"database",
"import",
"full",
"--id-type=INTEGER",
"--ignore-empty-strings=true",
"--bad-tolerance=0",
"--nodes=Place=" + input_headers["static/Place"] + "," + ",".join(input_files["static/Place"]),
"--nodes=Organisation="
+ input_headers["static/Organisation"]
+ ","
+ ",".join(input_files["static/Organisation"]),
"--nodes=TagClass="
+ input_headers["static/TagClass"]
+ ","
+ ",".join(input_files["static/TagClass"]),
"--nodes=Tag=" + input_headers["static/Tag"] + "," + ",".join(input_files["static/Tag"]),
"--nodes=Forum=" + input_headers["dynamic/Forum"] + "," + ",".join(input_files["dynamic/Forum"]),
"--nodes=Person=" + input_headers["dynamic/Person"] + "," + ",".join(input_files["dynamic/Person"]),
"--nodes=Message:Comment="
+ input_headers["dynamic/Comment"]
+ ","
+ ",".join(input_files["dynamic/Comment"]),
"--nodes=Message:Post="
+ input_headers["dynamic/Post"]
+ ","
+ ",".join(input_files["dynamic/Post"]),
"--relationships=IS_PART_OF="
+ input_headers["static/Place_isPartOf_Place"]
+ ","
+ ",".join(input_files["static/Place_isPartOf_Place"]),
"--relationships=IS_SUBCLASS_OF="
+ input_headers["static/TagClass_isSubclassOf_TagClass"]
+ ","
+ ",".join(input_files["static/TagClass_isSubclassOf_TagClass"]),
"--relationships=IS_LOCATED_IN="
+ input_headers["static/Organisation_isLocatedIn_Place"]
+ ","
+ ",".join(input_files["static/Organisation_isLocatedIn_Place"]),
"--relationships=HAS_TYPE="
+ input_headers["static/Tag_hasType_TagClass"]
+ ","
+ ",".join(input_files["static/Tag_hasType_TagClass"]),
"--relationships=HAS_CREATOR="
+ input_headers["dynamic/Comment_hasCreator_Person"]
+ ","
+ ",".join(input_files["dynamic/Comment_hasCreator_Person"]),
"--relationships=IS_LOCATED_IN="
+ input_headers["dynamic/Comment_isLocatedIn_Country"]
+ ","
+ ",".join(input_files["dynamic/Comment_isLocatedIn_Country"]),
"--relationships=REPLY_OF="
+ input_headers["dynamic/Comment_replyOf_Comment"]
+ ","
+ ",".join(input_files["dynamic/Comment_replyOf_Comment"]),
"--relationships=REPLY_OF="
+ input_headers["dynamic/Comment_replyOf_Post"]
+ ","
+ ",".join(input_files["dynamic/Comment_replyOf_Post"]),
"--relationships=CONTAINER_OF="
+ input_headers["dynamic/Forum_containerOf_Post"]
+ ","
+ ",".join(input_files["dynamic/Forum_containerOf_Post"]),
"--relationships=HAS_MEMBER="
+ input_headers["dynamic/Forum_hasMember_Person"]
+ ","
+ ",".join(input_files["dynamic/Forum_hasMember_Person"]),
"--relationships=HAS_MODERATOR="
+ input_headers["dynamic/Forum_hasModerator_Person"]
+ ","
+ ",".join(input_files["dynamic/Forum_hasModerator_Person"]),
"--relationships=HAS_TAG="
+ input_headers["dynamic/Forum_hasTag_Tag"]
+ ","
+ ",".join(input_files["dynamic/Forum_hasTag_Tag"]),
"--relationships=HAS_INTEREST="
+ input_headers["dynamic/Person_hasInterest_Tag"]
+ ","
+ ",".join(input_files["dynamic/Person_hasInterest_Tag"]),
"--relationships=IS_LOCATED_IN="
+ input_headers["dynamic/Person_isLocatedIn_City"]
+ ","
+ ",".join(input_files["dynamic/Person_isLocatedIn_City"]),
"--relationships=KNOWS="
+ input_headers["dynamic/Person_knows_Person"]
+ ","
+ ",".join(input_files["dynamic/Person_knows_Person"]),
"--relationships=LIKES="
+ input_headers["dynamic/Person_likes_Comment"]
+ ","
+ ",".join(input_files["dynamic/Person_likes_Comment"]),
"--relationships=LIKES="
+ input_headers["dynamic/Person_likes_Post"]
+ ","
+ ",".join(input_files["dynamic/Person_likes_Post"]),
"--relationships=HAS_CREATOR="
+ input_headers["dynamic/Post_hasCreator_Person"]
+ ","
+ ",".join(input_files["dynamic/Post_hasCreator_Person"]),
"--relationships=HAS_TAG="
+ input_headers["dynamic/Comment_hasTag_Tag"]
+ ","
+ ",".join(input_files["dynamic/Comment_hasTag_Tag"]),
"--relationships=HAS_TAG="
+ input_headers["dynamic/Post_hasTag_Tag"]
+ ","
+ ",".join(input_files["dynamic/Post_hasTag_Tag"]),
"--relationships=IS_LOCATED_IN="
+ input_headers["dynamic/Post_isLocatedIn_Country"]
+ ","
+ ",".join(input_files["dynamic/Post_isLocatedIn_Country"]),
"--relationships=STUDY_AT="
+ input_headers["dynamic/Person_studyAt_University"]
+ ","
+ ",".join(input_files["dynamic/Person_studyAt_University"]),
"--relationships=WORK_AT="
+ input_headers["dynamic/Person_workAt_Company"]
+ ","
+ ",".join(input_files["dynamic/Person_workAt_Company"]),
"--delimiter",
"|",
"neo4j",
],
check=True,
)
vendor_runner.start_preparation("Index preparation")
print("Executing database index setup")
client.execute(file_path=self._index_file, num_workers=1)
vendor_runner.stop("Stop index preparation")
return True
else:
return False

View File

@ -0,0 +1,163 @@
import csv
import subprocess
from pathlib import Path
import helpers
from benchmark_context import BenchmarkContext
from runners import BaseRunner
# Removed speaks/email from person header
HEADERS_INTERACTIVE = {
"static/organisation": "id:ID(Organisation)|:LABEL|name:STRING|url:STRING",
"static/place": "id:ID(Place)|name:STRING|url:STRING|:LABEL",
"static/tagclass": "id:ID(TagClass)|name:STRING|url:STRING",
"static/tag": "id:ID(Tag)|name:STRING|url:STRING",
"static/tagclass_isSubclassOf_tagclass": ":START_ID(TagClass)|:END_ID(TagClass)",
"static/tag_hasType_tagclass": ":START_ID(Tag)|:END_ID(TagClass)",
"static/organisation_isLocatedIn_place": ":START_ID(Organisation)|:END_ID(Place)",
"static/place_isPartOf_place": ":START_ID(Place)|:END_ID(Place)",
"dynamic/comment": "id:ID(Comment)|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING|content:STRING|length:INT",
"dynamic/forum": "id:ID(Forum)|title:STRING|creationDate:LOCALDATETIME",
"dynamic/person": "id:ID(Person)|firstName:STRING|lastName:STRING|gender:STRING|birthday:LOCALDATETIME|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING",
"dynamic/post": "id:ID(Post)|imageFile:STRING|creationDate:LOCALDATETIME|locationIP:STRING|browserUsed:STRING|language:STRING|content:STRING|length:INT",
"dynamic/comment_hasCreator_person": ":START_ID(Comment)|:END_ID(Person)",
"dynamic/comment_isLocatedIn_place": ":START_ID(Comment)|:END_ID(Place)",
"dynamic/comment_replyOf_comment": ":START_ID(Comment)|:END_ID(Comment)",
"dynamic/comment_replyOf_post": ":START_ID(Comment)|:END_ID(Post)",
"dynamic/forum_containerOf_post": ":START_ID(Forum)|:END_ID(Post)",
"dynamic/forum_hasMember_person": ":START_ID(Forum)|:END_ID(Person)|joinDate:LOCALDATETIME",
"dynamic/forum_hasModerator_person": ":START_ID(Forum)|:END_ID(Person)",
"dynamic/forum_hasTag_tag": ":START_ID(Forum)|:END_ID(Tag)",
"dynamic/person_hasInterest_tag": ":START_ID(Person)|:END_ID(Tag)",
"dynamic/person_isLocatedIn_place": ":START_ID(Person)|:END_ID(Place)",
"dynamic/person_knows_person": ":START_ID(Person)|:END_ID(Person)|creationDate:LOCALDATETIME",
"dynamic/person_likes_comment": ":START_ID(Person)|:END_ID(Comment)|creationDate:LOCALDATETIME",
"dynamic/person_likes_post": ":START_ID(Person)|:END_ID(Post)|creationDate:LOCALDATETIME",
"dynamic/person_studyAt_organisation": ":START_ID(Person)|:END_ID(Organisation)|classYear:INT",
"dynamic/person_workAt_organisation": ":START_ID(Person)|:END_ID(Organisation)|workFrom:INT",
"dynamic/post_hasCreator_person": ":START_ID(Post)|:END_ID(Person)",
"dynamic/comment_hasTag_tag": ":START_ID(Comment)|:END_ID(Tag)",
"dynamic/post_hasTag_tag": ":START_ID(Post)|:END_ID(Tag)",
"dynamic/post_isLocatedIn_place": ":START_ID(Post)|:END_ID(Place)",
}
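# Illustrative note (an assumption, not part of the original file): each value above is a
# neo4j-admin import header row. For "dynamic/person", "id:ID(Person)" marks the first
# pipe-separated column as the node key in the Person ID space and the remaining columns
# as typed properties, so a matching data row could look like:
#     933|Mahinda|Perera|male|1989-12-03T00:00:00|2010-02-14T15:32:10|119.235.7.103|Firefox
# These headers are written in front of the converted data rows in execute_import() below.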
class ImporterLDBCInteractive:
def __init__(
self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, csv_dict: dict
) -> None:
self._benchmark_context = benchmark_context
self._dataset_name = dataset_name
self._variant = variant
self._index_file = index_file
self._csv_dict = csv_dict
def execute_import(self):
vendor_runner = BaseRunner.create(
benchmark_context=self._benchmark_context,
)
client = vendor_runner.fetch_client()
if self._benchmark_context.vendor_name == "neo4j":
print("Runnning Neo4j import")
dump_dir = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "dump"
dump_dir.mkdir(parents=True, exist_ok=True)
dir_name = self._csv_dict[self._variant].split("/")[-1:][0].removesuffix(".tar.zst")
if (dump_dir / dir_name).exists():
print("Files downloaded")
dump_dir = dump_dir / dir_name
else:
print("Downloading files")
downloaded_file = helpers.download_file(self._csv_dict[self._variant], dump_dir.absolute())
print("Unpacking the file..." + downloaded_file)
dump_dir = helpers.unpack_tar_zst(Path(downloaded_file))
input_files = {}
for file in dump_dir.glob("*/*0.csv"):
parts = file.parts[-2:]
key = parts[0] + "/" + parts[1][:-8]
input_files[key] = file
output_files = {}
for key, file in input_files.items():
output = file.parent / (file.stem + "_neo" + ".csv")
if not output.exists():
with file.open("r") as input_f, output.open("a") as output_f:
reader = csv.reader(input_f, delimiter="|")
header = next(reader)
writer = csv.writer(output_f, delimiter="|")
if key in HEADERS_INTERACTIVE.keys():
updated_header = HEADERS_INTERACTIVE[key].split("|")
writer.writerow(updated_header)
for line in reader:
if "creationDate" in header:
pos = header.index("creationDate")
line[pos] = line[pos][0:-5]
elif "joinDate" in header:
pos = header.index("joinDate")
line[pos] = line[pos][0:-5]
if "organisation_0_0.csv" == file.name:
writer.writerow([line[0], line[1].capitalize(), line[2], line[3]])
elif "place_0_0.csv" == file.name:
writer.writerow([line[0], line[1], line[2], line[3].capitalize()])
else:
writer.writerow(line)
output_files[key] = output.as_posix()
vendor_runner.clean_db()
subprocess.run(
args=[
vendor_runner._neo4j_admin,
"database",
"import",
"full",
"--id-type=INTEGER",
"--nodes=Place=" + output_files["static/place"],
"--nodes=Organisation=" + output_files["static/organisation"],
"--nodes=TagClass=" + output_files["static/tagclass"],
"--nodes=Tag=" + output_files["static/tag"],
"--nodes=Comment:Message=" + output_files["dynamic/comment"],
"--nodes=Forum=" + output_files["dynamic/forum"],
"--nodes=Person=" + output_files["dynamic/person"],
"--nodes=Post:Message=" + output_files["dynamic/post"],
"--relationships=IS_PART_OF=" + output_files["static/place_isPartOf_place"],
"--relationships=IS_SUBCLASS_OF=" + output_files["static/tagclass_isSubclassOf_tagclass"],
"--relationships=IS_LOCATED_IN=" + output_files["static/organisation_isLocatedIn_place"],
"--relationships=HAS_TYPE=" + output_files["static/tag_hasType_tagclass"],
"--relationships=HAS_CREATOR=" + output_files["dynamic/comment_hasCreator_person"],
"--relationships=IS_LOCATED_IN=" + output_files["dynamic/comment_isLocatedIn_place"],
"--relationships=REPLY_OF=" + output_files["dynamic/comment_replyOf_comment"],
"--relationships=REPLY_OF=" + output_files["dynamic/comment_replyOf_post"],
"--relationships=CONTAINER_OF=" + output_files["dynamic/forum_containerOf_post"],
"--relationships=HAS_MEMBER=" + output_files["dynamic/forum_hasMember_person"],
"--relationships=HAS_MODERATOR=" + output_files["dynamic/forum_hasModerator_person"],
"--relationships=HAS_TAG=" + output_files["dynamic/forum_hasTag_tag"],
"--relationships=HAS_INTEREST=" + output_files["dynamic/person_hasInterest_tag"],
"--relationships=IS_LOCATED_IN=" + output_files["dynamic/person_isLocatedIn_place"],
"--relationships=KNOWS=" + output_files["dynamic/person_knows_person"],
"--relationships=LIKES=" + output_files["dynamic/person_likes_comment"],
"--relationships=LIKES=" + output_files["dynamic/person_likes_post"],
"--relationships=HAS_CREATOR=" + output_files["dynamic/post_hasCreator_person"],
"--relationships=HAS_TAG=" + output_files["dynamic/comment_hasTag_tag"],
"--relationships=HAS_TAG=" + output_files["dynamic/post_hasTag_tag"],
"--relationships=IS_LOCATED_IN=" + output_files["dynamic/post_isLocatedIn_place"],
"--relationships=STUDY_AT=" + output_files["dynamic/person_studyAt_organisation"],
"--relationships=WORK_AT=" + output_files["dynamic/person_workAt_organisation"],
"--delimiter",
"|",
"neo4j",
],
check=True,
)
vendor_runner.start_preparation("Index preparation")
print("Executing database index setup")
client.execute(file_path=self._index_file, num_workers=1)
vendor_runner.stop("Stop index preparation")
return True
else:
return False

View File

@ -0,0 +1,41 @@
from pathlib import Path
from benchmark_context import BenchmarkContext
from runners import BaseRunner
class ImporterPokec:
def __init__(
self, benchmark_context: BenchmarkContext, dataset_name: str, variant: str, index_file: str, dataset_file: str
) -> None:
self._benchmark_context = benchmark_context
self._dataset_name = dataset_name
self._variant = variant
self._index_file = index_file
self._dataset_file = dataset_file
def execute_import(self):
if self._benchmark_context.vendor_name == "neo4j":
vendor_runner = BaseRunner.create(
benchmark_context=self._benchmark_context,
)
client = vendor_runner.fetch_client()
vendor_runner.clean_db()
vendor_runner.start_preparation("preparation")
print("Executing database cleanup and index setup...")
client.execute(file_path=self._index_file, num_workers=1)
vendor_runner.stop("preparation")
neo4j_dump = Path() / ".cache" / "datasets" / self._dataset_name / self._variant / "neo4j.dump"
if neo4j_dump.exists():
vendor_runner.load_db_from_dump(path=neo4j_dump.parent)
else:
vendor_runner.start_preparation("import")
print("Importing dataset...")
client.execute(file_path=self._dataset_file, num_workers=self._benchmark_context.num_workers_for_import)
vendor_runner.stop("import")
vendor_runner.dump_db(path=neo4j_dump.parent)
return True
else:
return False

View File

@ -0,0 +1,708 @@
import inspect
import random
from pathlib import Path
import helpers
from benchmark_context import BenchmarkContext
from workloads.base import Workload
from workloads.importers.importer_ldbc_bi import ImporterLDBCBI
class LDBC_BI(Workload):
NAME = "ldbc_bi"
VARIANTS = ["sf1", "sf3", "sf10"]
DEFAULT_VARIANT = "sf1"
URL_FILE = {
"sf1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf1.cypher.gz",
"sf3": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf3.cypher.gz",
"sf10": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/ldbc_bi_sf10.cypher.gz",
}
URL_CSV = {
"sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf1-composite-projected-fk.tar.zst",
"sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf3-composite-projected-fk.tar.zst",
"sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/bi-sf10-composite-projected-fk.tar.zst",
}
SIZES = {
"sf1": {"vertices": 2997352, "edges": 17196776},
"sf3": {"vertices": 1, "edges": 1},
"sf10": {"vertices": 1, "edges": 1},
}
LOCAL_INDEX_FILES = None
URL_INDEX_FILE = {
"memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/memgraph_bi_index.cypher",
"neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/bi/neo4j_bi_index.cypher",
}
QUERY_PARAMETERS = {
"sf1": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip",
"sf3": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip",
"sf10": "https://pub-383410a98aef4cb686f0c7601eddd25f.r2.dev/bi-pre-audit/parameters-2022-10-01.zip",
}
def custom_import(self) -> bool:
importer = ImporterLDBCBI(
benchmark_context=self.benchmark_context,
dataset_name=self.NAME,
variant=self._variant,
index_file=self._file_index,
csv_dict=self.URL_CSV,
)
return importer.execute_import()
def _prepare_parameters_directory(self):
parameters = Path() / ".cache" / "datasets" / self.NAME / self._variant / "parameters"
parameters.mkdir(parents=True, exist_ok=True)
if parameters.exists() and any(parameters.iterdir()):
print("Files downloaded.")
else:
print("Downloading files")
downloaded_file = helpers.download_file(self.QUERY_PARAMETERS[self._variant], parameters.parent.absolute())
print("Unpacking the file..." + downloaded_file)
parameters = helpers.unpack_zip(Path(downloaded_file))
return parameters / ("parameters-" + self._variant)
def _get_query_parameters(self) -> dict:
func_name = inspect.stack()[1].function
parameters = {}
for file in self._parameters_dir.glob("bi-*.csv"):
file_name_query_id = file.name.split("-")[1][0:-4]
func_name_id = func_name.split("_")[-2]  # query id, e.g. "2" from "benchmark__bi__query_2_analytical"
if file_name_query_id == func_name_id or file_name_query_id == func_name_id + "a":
with file.open("r") as input:
lines = input.readlines()
header = lines[0].strip("\n").split("|")
position = random.randint(1, len(lines) - 1)
data = lines[position].strip("\n").split("|")
for i in range(len(header)):
key, value_type = header[i].split(":")
if value_type == "DATETIME":
# Drop time zone
converted = data[i][0:-6]
parameters[key] = converted
elif value_type == "DATE":
converted = data[i] + "T00:00:00"
parameters[key] = converted
elif value_type == "INT":
parameters[key] = int(data[i])
elif value_type == "STRING[]":
elements = data[i].split(";")
parameters[key] = elements
else:
parameters[key] = data[i]
break
return parameters
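# Illustrative example (an assumption, not taken from the original file): for
# benchmark__bi__query_2_analytical() a matching "bi-2.csv" parameter file could contain
#     date:DATE|tagClass:STRING
#     2012-06-01|MusicalArtist
# and a randomly chosen data row would then be returned as
#     {"date": "2012-06-01T00:00:00", "tagClass": "MusicalArtist"}
# because DATE values get a "T00:00:00" suffix and DATETIME values have their time zone dropped.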
def __init__(self, variant=None, benchmark_context: BenchmarkContext = None):
super().__init__(variant, benchmark_context=benchmark_context)
self._parameters_dir = self._prepare_parameters_directory()
def benchmark__bi__query_1_analytical(self):
memgraph = (
"""
MATCH (message:Message)
WHERE message.creationDate < localDateTime($datetime)
WITH count(message) AS totalMessageCountInt
WITH toFloat(totalMessageCountInt) AS totalMessageCount
MATCH (message:Message)
WHERE message.creationDate < localDateTime($datetime)
AND message.content IS NOT NULL
WITH
totalMessageCount,
message,
message.creationDate.year AS year
WITH
totalMessageCount,
year,
message:Comment AS isComment,
CASE
WHEN message.length < 40 THEN 0
WHEN message.length < 80 THEN 1
WHEN message.length < 160 THEN 2
ELSE 3
END AS lengthCategory,
count(message) AS messageCount,
sum(message.length) / toFloat(count(message)) AS averageMessageLength,
sum(message.length) AS sumMessageLength
RETURN
year,
isComment,
lengthCategory,
messageCount,
averageMessageLength,
sumMessageLength,
messageCount / totalMessageCount AS percentageOfMessages
ORDER BY
year DESC,
isComment ASC,
lengthCategory ASC
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (message:Message)
WHERE message.creationDate < DateTime($datetime)
WITH count(message) AS totalMessageCountInt
WITH toFloat(totalMessageCountInt) AS totalMessageCount
MATCH (message:Message)
WHERE message.creationDate < DateTime($datetime)
AND message.content IS NOT NULL
WITH
totalMessageCount,
message,
message.creationDate.year AS year
WITH
totalMessageCount,
year,
message:Comment AS isComment,
CASE
WHEN message.length < 40 THEN 0
WHEN message.length < 80 THEN 1
WHEN message.length < 160 THEN 2
ELSE 3
END AS lengthCategory,
count(message) AS messageCount,
sum(message.length) / toFloat(count(message)) AS averageMessageLength,
sum(message.length) AS sumMessageLength
RETURN
year,
isComment,
lengthCategory,
messageCount,
averageMessageLength,
sumMessageLength,
messageCount / totalMessageCount AS percentageOfMessages
ORDER BY
year DESC,
isComment ASC,
lengthCategory ASC
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_2_analytical(self):
memgraph = (
"""
MATCH (tag:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass})
OPTIONAL MATCH (message1:Message)-[:HAS_TAG]->(tag)
WHERE localDateTime($date) <= message1.creationDate
AND message1.creationDate < localDateTime($date) + duration({day: 100})
WITH tag, count(message1) AS countWindow1
OPTIONAL MATCH (message2:Message)-[:HAS_TAG]->(tag)
WHERE localDateTime($date) + duration({day: 100}) <= message2.creationDate
AND message2.creationDate < localDateTime($date) + duration({day: 200})
WITH
tag,
countWindow1,
count(message2) AS countWindow2
RETURN
tag.name,
countWindow1,
countWindow2,
abs(countWindow1 - countWindow2) AS diff
ORDER BY
diff DESC,
tag.name ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (tag:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass})
OPTIONAL MATCH (message1:Message)-[:HAS_TAG]->(tag)
WHERE DateTime($date) <= message1.creationDate
AND message1.creationDate < DateTime($date) + duration({days: 100})
WITH tag, count(message1) AS countWindow1
OPTIONAL MATCH (message2:Message)-[:HAS_TAG]->(tag)
WHERE DateTime($date) + duration({days: 100}) <= message2.creationDate
AND message2.creationDate < DateTime($date) + duration({days: 200})
WITH
tag,
countWindow1,
count(message2) AS countWindow2
RETURN
tag.name,
countWindow1,
countWindow2,
abs(countWindow1 - countWindow2) AS diff
ORDER BY
diff DESC,
tag.name ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_3_analytical(self):
return (
"""
MATCH
(:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]-
(person:Person)<-[:HAS_MODERATOR]-(forum:Forum)-[:CONTAINER_OF]->
(post:Post)<-[:REPLY_OF*0..]-(message:Message)-[:HAS_TAG]->(:Tag)-[:HAS_TYPE]->(:TagClass {name: $tagClass})
RETURN
forum.id as id,
forum.title,
person.id,
count(DISTINCT message) AS messageCount
ORDER BY
messageCount DESC,
id ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_5_analytical(self):
return (
"""
MATCH (tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message)-[:HAS_CREATOR]->(person:Person)
OPTIONAL MATCH (message)<-[likes:LIKES]-(:Person)
WITH person, message, count(likes) AS likeCount
OPTIONAL MATCH (message)<-[:REPLY_OF]-(reply:Comment)
WITH person, message, likeCount, count(reply) AS replyCount
WITH person, count(message) AS messageCount, sum(likeCount) AS likeCount, sum(replyCount) AS replyCount
RETURN
person.id,
replyCount,
likeCount,
messageCount,
1*messageCount + 2*replyCount + 10*likeCount AS score
ORDER BY
score DESC,
person.id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_6_analytical(self):
return (
"""
MATCH (tag:Tag {name: $tag})<-[:HAS_TAG]-(message1:Message)-[:HAS_CREATOR]->(person1:Person)
OPTIONAL MATCH (message1)<-[:LIKES]-(person2:Person)
OPTIONAL MATCH (person2)<-[:HAS_CREATOR]-(message2:Message)<-[like:LIKES]-(person3:Person)
RETURN
person1.id as id,
count(DISTINCT like) AS authorityScore
ORDER BY
authorityScore DESC,
id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_7_analytical(self):
memgraph = (
"""
MATCH
(tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message),
(message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_TAG]->(relatedTag:Tag)
OPTIONAL MATCH (comment)-[:HAS_TAG]->(tag)
WHERE tag IS NOT NULL
RETURN
relatedTag,
count(DISTINCT comment) AS count
ORDER BY
relatedTag.name ASC,
count DESC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH
(tag:Tag {name: $tag})<-[:HAS_TAG]-(message:Message),
(message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_TAG]->(relatedTag:Tag)
WHERE NOT (comment)-[:HAS_TAG]->(tag)
RETURN
relatedTag.name,
count(DISTINCT comment) AS count
ORDER BY
relatedTag.name ASC,
count DESC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_9_analytical(self):
memgraph = (
"""
MATCH (person:Person)<-[:HAS_CREATOR]-(post:Post)<-[:REPLY_OF*0..]-(reply:Message)
WHERE post.creationDate >= localDateTime($startDate)
AND post.creationDate <= localDateTime($endDate)
AND reply.creationDate >= localDateTime($startDate)
AND reply.creationDate <= localDateTime($endDate)
RETURN
person.id as id,
person.firstName,
person.lastName,
count(DISTINCT post) AS threadCount,
count(DISTINCT reply) AS messageCount
ORDER BY
messageCount DESC,
id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (person:Person)<-[:HAS_CREATOR]-(post:Post)<-[:REPLY_OF*0..]-(reply:Message)
WHERE post.creationDate >= DateTime($startDate)
AND post.creationDate <= DateTime($endDate)
AND reply.creationDate >= DateTime($startDate)
AND reply.creationDate <= DateTime($endDate)
RETURN
person.id as id,
person.firstName,
person.lastName,
count(DISTINCT post) AS threadCount,
count(DISTINCT reply) AS messageCount
ORDER BY
messageCount DESC,
id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_11_analytical(self):
return (
"""
MATCH (a:Person)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country:Country {name: $country}),
(a)-[k1:KNOWS]-(b:Person)
WHERE a.id < b.id
AND localDateTime($startDate) <= k1.creationDate AND k1.creationDate <= localDateTime($endDate)
WITH DISTINCT country, a, b
MATCH (b)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country)
WITH DISTINCT country, a, b
MATCH (b)-[k2:KNOWS]-(c:Person),
(c)-[:IS_LOCATED_IN]->(:City)-[:IS_PART_OF]->(country)
WHERE b.id < c.id
AND localDateTime($startDate) <= k2.creationDate AND k2.creationDate <= localDateTime($endDate)
WITH DISTINCT a, b, c
MATCH (c)-[k3:KNOWS]-(a)
WHERE localDateTime($startDate) <= k3.creationDate AND k3.creationDate <= localDateTime($endDate)
WITH DISTINCT a, b, c
RETURN count(*) AS count
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_12_analytical(self):
return (
"""
MATCH (person:Person)
OPTIONAL MATCH (person)<-[:HAS_CREATOR]-(message:Message)-[:REPLY_OF*0..]->(post:Post)
WHERE message.content IS NOT NULL
AND message.length < $lengthThreshold
AND message.creationDate > localDateTime($startDate)
AND post.language IN $languages
WITH
person,
count(message) AS messageCount
RETURN
messageCount,
count(person) AS personCount
ORDER BY
personCount DESC,
messageCount DESC
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_13_analytical(self):
memgraph = (
"""
MATCH (country:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]-(zombie:Person)
WHERE zombie.creationDate < localDateTime($endDate)
WITH country, zombie
OPTIONAL MATCH (zombie)<-[:HAS_CREATOR]-(message:Message)
WHERE message.creationDate < localDateTime($endDate)
WITH
country,
zombie,
count(message) AS messageCount
WITH
country,
zombie,
12 * (localDateTime($endDate).year - zombie.creationDate.year )
+ (localDateTime($endDate).month - zombie.creationDate.month)
+ 1 AS months,
messageCount
WHERE messageCount / months < 1
WITH
country,
collect(zombie) AS zombies
UNWIND zombies AS zombie
OPTIONAL MATCH
(zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerZombie:Person)
WHERE likerZombie IN zombies
WITH
zombie,
count(likerZombie) AS zombieLikeCount
OPTIONAL MATCH
(zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerPerson:Person)
WHERE likerPerson.creationDate < localDateTime($endDate)
WITH
zombie,
zombieLikeCount,
count(likerPerson) AS totalLikeCount
RETURN
zombie.id,
zombieLikeCount,
totalLikeCount,
CASE totalLikeCount
WHEN 0 THEN 0.0
ELSE zombieLikeCount / toFloat(totalLikeCount)
END AS zombieScore
ORDER BY
zombieScore DESC,
zombie.id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (country:Country {name: $country})<-[:IS_PART_OF]-(:City)<-[:IS_LOCATED_IN]-(zombie:Person)
WHERE zombie.creationDate < DateTime($endDate)
WITH country, zombie
OPTIONAL MATCH (zombie)<-[:HAS_CREATOR]-(message:Message)
WHERE message.creationDate < DateTime($endDate)
WITH
country,
zombie,
count(message) AS messageCount
WITH
country,
zombie,
12 * (DateTime($endDate).year - zombie.creationDate.year )
+ (DateTime($endDate).month - zombie.creationDate.month)
+ 1 AS months,
messageCount
WHERE messageCount / months < 1
WITH
country,
collect(zombie) AS zombies
UNWIND zombies AS zombie
OPTIONAL MATCH
(zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerZombie:Person)
WHERE likerZombie IN zombies
WITH
zombie,
count(likerZombie) AS zombieLikeCount
OPTIONAL MATCH
(zombie)<-[:HAS_CREATOR]-(message:Message)<-[:LIKES]-(likerPerson:Person)
WHERE likerPerson.creationDate < DateTime($endDate)
WITH
zombie,
zombieLikeCount,
count(likerPerson) AS totalLikeCount
RETURN
zombie.id,
zombieLikeCount,
totalLikeCount,
CASE totalLikeCount
WHEN 0 THEN 0.0
ELSE zombieLikeCount / toFloat(totalLikeCount)
END AS zombieScore
ORDER BY
zombieScore DESC,
zombie.id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_14_analytical(self):
return (
"""
MATCH
(country1:Country {name: $country1})<-[:IS_PART_OF]-(city1:City)<-[:IS_LOCATED_IN]-(person1:Person),
(country2:Country {name: $country2})<-[:IS_PART_OF]-(city2:City)<-[:IS_LOCATED_IN]-(person2:Person),
(person1)-[:KNOWS]-(person2)
WITH person1, person2, city1, 0 AS score
OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(c:Comment)-[:REPLY_OF]->(:Message)-[:HAS_CREATOR]->(person2)
WITH DISTINCT person1, person2, city1, score + (CASE c WHEN null THEN 0 ELSE 4 END) AS score
OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(m:Message)<-[:REPLY_OF]-(:Comment)-[:HAS_CREATOR]->(person2)
WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 1 END) AS score
OPTIONAL MATCH (person1)-[:LIKES]->(m:Message)-[:HAS_CREATOR]->(person2)
WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 10 END) AS score
OPTIONAL MATCH (person1)<-[:HAS_CREATOR]-(m:Message)<-[:LIKES]-(person2)
WITH DISTINCT person1, person2, city1, score + (CASE m WHEN null THEN 0 ELSE 1 END) AS score
ORDER BY
city1.name ASC,
score DESC,
person1.id ASC,
person2.id ASC
WITH city1, collect({score: score, person1Id: person1.id, person2Id: person2.id})[0] AS top
RETURN
top.person1Id,
top.person2Id,
city1.name,
top.score
ORDER BY
top.score DESC,
top.person1Id ASC,
top.person2Id ASC
LIMIT 100
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__bi__query_17_analytical(self):
memgraph = (
"""
MATCH
(tag:Tag {name: $tag}),
(person1:Person)<-[:HAS_CREATOR]-(message1:Message)-[:REPLY_OF*0..]->(post1:Post)<-[:CONTAINER_OF]-(forum1:Forum),
(message1)-[:HAS_TAG]->(tag),
(forum1)-[:HAS_MEMBER]->(person2:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:HAS_TAG]->(tag),
(forum1)-[:HAS_MEMBER]->(person3:Person)<-[:HAS_CREATOR]-(message2:Message),
(comment)-[:REPLY_OF]->(message2)-[:REPLY_OF*0..]->(post2:Post)<-[:CONTAINER_OF]-(forum2:Forum)
MATCH (comment)-[:HAS_TAG]->(tag)
MATCH (message2)-[:HAS_TAG]->(tag)
OPTIONAL MATCH (forum2)-[:HAS_MEMBER]->(person1)
WHERE forum1 <> forum2 AND message2.creationDate > message1.creationDate + duration({hours: $delta}) AND person1 IS NULL
RETURN person1, count(DISTINCT message2) AS messageCount
ORDER BY messageCount DESC, person1.id ASC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH
(tag:Tag {name: $tag}),
(person1:Person)<-[:HAS_CREATOR]-(message1:Message)-[:REPLY_OF*0..]->(post1:Post)<-[:CONTAINER_OF]-(forum1:Forum),
(message1)-[:HAS_TAG]->(tag),
(forum1)-[:HAS_MEMBER]->(person2:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:HAS_TAG]->(tag),
(forum1)-[:HAS_MEMBER]->(person3:Person)<-[:HAS_CREATOR]-(message2:Message),
(comment)-[:REPLY_OF]->(message2)-[:REPLY_OF*0..]->(post2:Post)<-[:CONTAINER_OF]-(forum2:Forum)
MATCH (comment)-[:HAS_TAG]->(tag)
MATCH (message2)-[:HAS_TAG]->(tag)
WHERE forum1 <> forum2
AND message2.creationDate > message1.creationDate + duration({hours: $delta})
AND NOT (forum2)-[:HAS_MEMBER]->(person1)
RETURN person1, count(DISTINCT message2) AS messageCount
ORDER BY messageCount DESC, person1.id ASC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__bi__query_18_analytical(self):
memgraph = (
"""
MATCH (tag:Tag {name: $tag})<-[:HAS_INTEREST]-(person1:Person)-[:KNOWS]-(mutualFriend:Person)-[:KNOWS]-(person2:Person)-[:HAS_INTEREST]->(tag)
OPTIONAL MATCH (person1)-[:KNOWS]-(person2)
WHERE person1 <> person2
RETURN person1.id AS person1Id, person2.id AS person2Id, count(DISTINCT mutualFriend) AS mutualFriendCount
ORDER BY mutualFriendCount DESC, person1Id ASC, person2Id ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (tag:Tag {name: $tag})<-[:HAS_INTEREST]-(person1:Person)-[:KNOWS]-(mutualFriend:Person)-[:KNOWS]-(person2:Person)-[:HAS_INTEREST]->(tag)
WHERE person1 <> person2
AND NOT (person1)-[:KNOWS]-(person2)
RETURN person1.id AS person1Id, person2.id AS person2Id, count(DISTINCT mutualFriend) AS mutualFriendCount
ORDER BY mutualFriendCount DESC, person1Id ASC, person2Id ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j

View File

@ -0,0 +1,684 @@
import inspect
import random
from datetime import datetime
from pathlib import Path
import helpers
from benchmark_context import BenchmarkContext
from workloads.base import Workload
from workloads.importers.importer_ldbc_interactive import *
class LDBC_Interactive(Workload):
NAME = "ldbc_interactive"
VARIANTS = ["sf0.1", "sf1", "sf3", "sf10"]
DEFAULT_VARIANT = "sf1"
URL_FILE = {
"sf0.1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf0.1.cypher.gz",
"sf1": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf1.cypher.gz",
"sf3": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf3.cypher.gz",
"sf10": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/ldbc_interactive_sf10.cypher.gz",
}
URL_CSV = {
"sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf0.1.tar.zst",
"sf1": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf1.tar.zst",
"sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf3.tar.zst",
"sf10": "https://repository.surfsara.nl/datasets/cwi/snb/files/social_network-csv_basic/social_network-csv_basic-sf10.tar.zst",
}
SIZES = {
"sf0.1": {"vertices": 327588, "edges": 1477965},
"sf1": {"vertices": 3181724, "edges": 17256038},
"sf3": {"vertices": 1, "edges": 1},
"sf10": {"vertices": 1, "edges": 1},
}
URL_INDEX_FILE = {
"memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/memgraph_interactive_index.cypher",
"neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/ldbc/benchmark/interactive/neo4j_interactive_index.cypher",
}
PROPERTIES_ON_EDGES = True
QUERY_PARAMETERS = {
"sf0.1": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst",
"sf1": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst",
"sf3": "https://repository.surfsara.nl/datasets/cwi/snb/files/substitution_parameters/substitution_parameters-sf0.1.tar.zst",
}
def custom_import(self) -> bool:
importer = ImporterLDBCInteractive(
benchmark_context=self.benchmark_context,
dataset_name=self.NAME,
variant=self._variant,
index_file=self._file_index,
csv_dict=self.URL_CSV,
)
return importer.execute_import()
def _prepare_parameters_directory(self):
parameters = Path() / ".cache" / "datasets" / self.NAME / self._variant / "parameters"
parameters.mkdir(parents=True, exist_ok=True)
dir_name = self.QUERY_PARAMETERS[self._variant].split("/")[-1:][0].removesuffix(".tar.zst")
if (parameters / dir_name).exists():
print("Files downloaded:")
parameters = parameters / dir_name
else:
print("Downloading files")
downloaded_file = helpers.download_file(self.QUERY_PARAMETERS[self._variant], parameters.absolute())
print("Unpacking the file..." + downloaded_file)
parameters = helpers.unpack_tar_zst(Path(downloaded_file))
return parameters
def _get_query_parameters(self) -> dict:
func_name = inspect.stack()[1].function
parameters = {}
for file in self._parameters_dir.glob("interactive_*.txt"):
if file.name.split("_")[1] == func_name.split("_")[-2]:
with file.open("r") as input:
lines = input.readlines()
position = random.randint(1, len(lines) - 1)
header = lines[0].strip("\n").split("|")
data = lines[position].strip("\n").split("|")
for i in range(len(header)):
if "Date" in header[i]:
time = int(data[i]) / 1000
converted = datetime.utcfromtimestamp(time).strftime("%Y-%m-%dT%H:%M:%S")
parameters[header[i]] = converted
elif data[i].isdigit():
parameters[header[i]] = int(data[i])
else:
parameters[header[i]] = data[i]
return parameters
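# Illustrative example (an assumption, not taken from the original file): for
# benchmark__interactive__complex_query_2_analytical() a matching substitution file
# (e.g. "interactive_2_param.txt") could contain
#     personId|maxDate
#     10995116278009|1354060800000
# and a randomly chosen data row would then be returned as
#     {"personId": 10995116278009, "maxDate": "2012-11-28T00:00:00"}
# because epoch-millisecond values in *Date columns are converted to ISO-8601 timestamps
# and purely numeric values are cast to int.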
def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None):
super().__init__(variant, benchmark_context=benchmark_context)
self._parameters_dir = self._prepare_parameters_directory()
self.benchmark_context = benchmark_context
def benchmark__interactive__complex_query_1_analytical(self):
memgraph = (
"""
MATCH (p:Person {id: $personId}), (friend:Person {firstName: $firstName})
WHERE NOT p=friend
WITH p, friend
MATCH path =((p)-[:KNOWS *BFS 1..3]-(friend))
WITH min(size(path)) AS distance, friend
ORDER BY
distance ASC,
friend.lastName ASC,
toInteger(friend.id) ASC
LIMIT 20
MATCH (friend)-[:IS_LOCATED_IN]->(friendCity:City)
OPTIONAL MATCH (friend)-[studyAt:STUDY_AT]->(uni:University)-[:IS_LOCATED_IN]->(uniCity:City)
WITH friend, collect(
CASE uni.name
WHEN null THEN null
ELSE [uni.name, studyAt.classYear, uniCity.name]
END ) AS unis, friendCity, distance
OPTIONAL MATCH (friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(companyCountry:Country)
WITH friend, collect(
CASE company.name
WHEN null THEN null
ELSE [company.name, workAt.workFrom, companyCountry.name]
END ) AS companies, unis, friendCity, distance
RETURN
friend.id AS friendId,
friend.lastName AS friendLastName,
distance AS distanceFromPerson,
friend.birthday AS friendBirthday,
friend.gender AS friendGender,
friend.browserUsed AS friendBrowserUsed,
friend.locationIP AS friendLocationIp,
friend.email AS friendEmails,
friend.speaks AS friendLanguages,
friendCity.name AS friendCityName,
unis AS friendUniversities,
companies AS friendCompanies
ORDER BY
distanceFromPerson ASC,
friendLastName ASC,
toInteger(friendId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (p:Person {id: $personId}), (friend:Person {firstName: $firstName})
WHERE NOT p=friend
WITH p, friend
MATCH path = shortestPath((p)-[:KNOWS*1..3]-(friend))
WITH min(length(path)) AS distance, friend
ORDER BY
distance ASC,
friend.lastName ASC,
toInteger(friend.id) ASC
LIMIT 20
MATCH (friend)-[:IS_LOCATED_IN]->(friendCity:City)
OPTIONAL MATCH (friend)-[studyAt:STUDY_AT]->(uni:University)-[:IS_LOCATED_IN]->(uniCity:City)
WITH friend, collect(
CASE uni.name
WHEN null THEN null
ELSE [uni.name, studyAt.classYear, uniCity.name]
END ) AS unis, friendCity, distance
OPTIONAL MATCH (friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(companyCountry:Country)
WITH friend, collect(
CASE company.name
WHEN null THEN null
ELSE [company.name, workAt.workFrom, companyCountry.name]
END ) AS companies, unis, friendCity, distance
RETURN
friend.id AS friendId,
friend.lastName AS friendLastName,
distance AS distanceFromPerson,
friend.birthday AS friendBirthday,
friend.gender AS friendGender,
friend.browserUsed AS friendBrowserUsed,
friend.locationIP AS friendLocationIp,
friend.email AS friendEmails,
friend.speaks AS friendLanguages,
friendCity.name AS friendCityName,
unis AS friendUniversities,
companies AS friendCompanies
ORDER BY
distanceFromPerson ASC,
friendLastName ASC,
toInteger(friendId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__interactive__complex_query_2_analytical(self):
return (
"""
MATCH (:Person {id: $personId })-[:KNOWS]-(friend:Person)<-[:HAS_CREATOR]-(message:Message)
WHERE message.creationDate <= localDateTime($maxDate)
RETURN
friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
message.id AS postOrCommentId,
coalesce(message.content,message.imageFile) AS postOrCommentContent,
message.creationDate AS postOrCommentCreationDate
ORDER BY
postOrCommentCreationDate DESC,
toInteger(postOrCommentId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_3_analytical(self):
memgraph = (
"""
MATCH (countryX:Country {name: $countryXName }),
(countryY:Country {name: $countryYName }),
(person:Person {id: $personId })
WITH person, countryX, countryY
LIMIT 1
MATCH (city:City)-[:IS_PART_OF]->(country:Country)
WHERE country IN [countryX, countryY]
WITH person, countryX, countryY, collect(city) AS cities
MATCH (person)-[:KNOWS*1..2]-(friend)-[:IS_LOCATED_IN]->(city)
WHERE NOT person=friend AND NOT city IN cities
WITH DISTINCT friend, countryX, countryY
MATCH (friend)<-[:HAS_CREATOR]-(message),
(message)-[:IS_LOCATED_IN]->(country)
WHERE localDateTime($startDate) + duration({day:$durationDays}) > message.creationDate >= localDateTime($startDate) AND
country IN [countryX, countryY]
WITH friend,
CASE WHEN country=countryX THEN 1 ELSE 0 END AS messageX,
CASE WHEN country=countryY THEN 1 ELSE 0 END AS messageY
WITH friend, sum(messageX) AS xCount, sum(messageY) AS yCount
WHERE xCount>0 AND yCount>0
RETURN friend.id AS friendId,
friend.firstName AS friendFirstName,
friend.lastName AS friendLastName,
xCount,
yCount,
xCount + yCount AS xyCount
ORDER BY xyCount DESC, friendId ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (countryX:Country {name: $countryXName }),
(countryY:Country {name: $countryYName }),
(person:Person {id: $personId })
WITH person, countryX, countryY
LIMIT 1
MATCH (city:City)-[:IS_PART_OF]->(country:Country)
WHERE country IN [countryX, countryY]
WITH person, countryX, countryY, collect(city) AS cities
MATCH (person)-[:KNOWS*1..2]-(friend)-[:IS_LOCATED_IN]->(city)
WHERE NOT person=friend AND NOT city IN cities
WITH DISTINCT friend, countryX, countryY
MATCH (friend)<-[:HAS_CREATOR]-(message),
(message)-[:IS_LOCATED_IN]->(country)
WHERE localDateTime($startDate) + duration({days:$durationDays}) > message.creationDate >= localDateTime($startDate) AND
country IN [countryX, countryY]
WITH friend,
CASE WHEN country=countryX THEN 1 ELSE 0 END AS messageX,
CASE WHEN country=countryY THEN 1 ELSE 0 END AS messageY
WITH friend, sum(messageX) AS xCount, sum(messageY) AS yCount
WHERE xCount>0 AND yCount>0
RETURN friend.id AS friendId,
friend.firstName AS friendFirstName,
friend.lastName AS friendLastName,
xCount,
yCount,
xCount + yCount AS xyCount
ORDER BY xyCount DESC, friendId ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__interactive__complex_query_4_analytical(self):
memgraph = (
"""
MATCH (person:Person {id: $personId })-[:KNOWS]-(friend:Person),
(friend)<-[:HAS_CREATOR]-(post:Post)-[:HAS_TAG]->(tag)
WITH DISTINCT tag, post
WITH tag,
CASE
WHEN localDateTime($startDate) + duration({day:$durationDays}) > post.creationDate >= localDateTime($startDate) THEN 1
ELSE 0
END AS valid,
CASE
WHEN localDateTime($startDate) > post.creationDate THEN 1
ELSE 0
END AS inValid
WITH tag, sum(valid) AS postCount, sum(inValid) AS inValidPostCount
WHERE postCount>0 AND inValidPostCount=0
RETURN tag.name AS tagName, postCount
ORDER BY postCount DESC, tagName ASC
LIMIT 10
""",
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (person:Person {id: $personId })-[:KNOWS]-(friend:Person),
(friend)<-[:HAS_CREATOR]-(post:Post)-[:HAS_TAG]->(tag)
WITH DISTINCT tag, post
WITH tag,
CASE
WHEN localDateTime($startDate) + duration({days:$durationDays}) > post.creationDate >= localDateTime($startDate) THEN 1
ELSE 0
END AS valid,
CASE
WHEN localDateTime($startDate) > post.creationDate THEN 1
ELSE 0
END AS inValid
WITH tag, sum(valid) AS postCount, sum(inValid) AS inValidPostCount
WHERE postCount>0 AND inValidPostCount=0
RETURN tag.name AS tagName, postCount
ORDER BY postCount DESC, tagName ASC
LIMIT 10
""",
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__interactive__complex_query_5_analytical(self):
return (
"""
MATCH (person:Person { id: $personId })-[:KNOWS*1..2]-(friend)
WHERE
NOT person=friend
WITH DISTINCT friend
MATCH (friend)<-[membership:HAS_MEMBER]-(forum)
WHERE
membership.joinDate > localDateTime($minDate)
WITH
forum,
collect(friend) AS friends
OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post)<-[:CONTAINER_OF]-(forum)
WHERE
friend IN friends
WITH
forum,
count(post) AS postCount
RETURN
forum.title AS forumName,
postCount
ORDER BY
postCount DESC,
forum.id ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_6_analytical(self):
return (
"""
MATCH (knownTag:Tag { name: $tagName })
WITH knownTag.id as knownTagId
MATCH (person:Person { id: $personId })-[:KNOWS*1..2]-(friend)
WHERE NOT person=friend
WITH
knownTagId,
collect(distinct friend) as friends
UNWIND friends as f
MATCH (f)<-[:HAS_CREATOR]-(post:Post),
(post)-[:HAS_TAG]->(t:Tag{id: knownTagId}),
(post)-[:HAS_TAG]->(tag:Tag)
WHERE NOT t = tag
WITH
tag.name as tagName,
count(post) as postCount
RETURN
tagName,
postCount
ORDER BY
postCount DESC,
tagName ASC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_7_analytical(self):
memgraph = (
"""
MATCH (person:Person {id: $personId})<-[:HAS_CREATOR]-(message:Message)<-[like:LIKES]-(liker:Person)
WITH liker, message, like.creationDate AS likeTime, person
ORDER BY likeTime DESC, toInteger(message.id) ASC
WITH liker, head(collect({msg: message, likeTime: likeTime})) AS latestLike, person
OPTIONAL MATCH (liker)-[:KNOWS]-(person)
WITH liker, latestLike, person,
CASE WHEN person IS null THEN TRUE ELSE FALSE END AS isNew
RETURN
liker.id AS personId,
liker.firstName AS personFirstName,
liker.lastName AS personLastName,
latestLike.likeTime AS likeCreationDate,
latestLike.msg.id AS commentOrPostId,
coalesce(latestLike.msg.content, latestLike.msg.imageFile) AS commentOrPostContent,
(latestLike.likeTime - latestLike.msg.creationDate).minute AS minutesLatency
ORDER BY
likeCreationDate DESC,
toInteger(personId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (person:Person {id: $personId})<-[:HAS_CREATOR]-(message:Message)<-[like:LIKES]-(liker:Person)
WITH liker, message, like.creationDate AS likeTime, person
ORDER BY likeTime DESC, toInteger(message.id) ASC
WITH liker, head(collect({msg: message, likeTime: likeTime})) AS latestLike, person
RETURN
liker.id AS personId,
liker.firstName AS personFirstName,
liker.lastName AS personLastName,
latestLike.likeTime AS likeCreationDate,
latestLike.msg.id AS commentOrPostId,
coalesce(latestLike.msg.content, latestLike.msg.imageFile) AS commentOrPostContent,
duration.between(latestLike.likeTime, latestLike.msg.creationDate).minutes AS minutesLatency,
not((liker)-[:KNOWS]-(person)) AS isNew
ORDER BY
likeCreationDate DESC,
toInteger(personId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__interactive__complex_query_8_analytical(self):
return (
"""
MATCH (start:Person {id: $personId})<-[:HAS_CREATOR]-(:Message)<-[:REPLY_OF]-(comment:Comment)-[:HAS_CREATOR]->(person:Person)
RETURN
person.id AS personId,
person.firstName AS personFirstName,
person.lastName AS personLastName,
comment.creationDate AS commentCreationDate,
comment.id AS commentId,
comment.content AS commentContent
ORDER BY
commentCreationDate DESC,
commentId ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_9_analytical(self):
return (
"""
MATCH (root:Person {id: $personId })-[:KNOWS*1..2]-(friend:Person)
WHERE NOT friend = root
WITH collect(distinct friend) as friends
UNWIND friends as friend
MATCH (friend)<-[:HAS_CREATOR]-(message:Message)
WHERE message.creationDate < localDateTime($maxDate)
RETURN
friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
message.id AS commentOrPostId,
coalesce(message.content,message.imageFile) AS commentOrPostContent,
message.creationDate AS commentOrPostCreationDate
ORDER BY
commentOrPostCreationDate DESC,
message.id ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_10_analytical(self):
memgraph = (
"""
MATCH (person:Person {id: $personId})-[:KNOWS*2..2]-(friend),
(friend)-[:IS_LOCATED_IN]->(city:City)
WHERE NOT friend=person AND
NOT (friend)-[:KNOWS]-(person)
WITH person, city, friend, datetime({epochMillis: friend.birthday}) as birthday
WHERE (birthday.month=$month AND birthday.day>=21) OR
(birthday.month=($month%12)+1 AND birthday.day<22)
WITH DISTINCT friend, city, person
OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post:Post)
WITH friend, city, collect(post) AS posts, person
WITH friend,
city,
size(posts) AS postCount,
size([p IN posts WHERE (p)-[:HAS_TAG]->()<-[:HAS_INTEREST]-(person)]) AS commonPostCount
RETURN friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
commonPostCount - (postCount - commonPostCount) AS commonInterestScore,
friend.gender AS personGender,
city.name AS personCityName
ORDER BY commonInterestScore DESC, personId ASC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH (person:Person {id: $personId})-[:KNOWS*2..2]-(friend),
(friend)-[:IS_LOCATED_IN]->(city:City)
WHERE NOT friend=person AND
NOT (friend)-[:KNOWS]-(person)
WITH person, city, friend, datetime({epochMillis: friend.birthday}) as birthday
WHERE (birthday.month=$month AND birthday.day>=21) OR
(birthday.month=($month%12)+1 AND birthday.day<22)
WITH DISTINCT friend, city, person
OPTIONAL MATCH (friend)<-[:HAS_CREATOR]-(post:Post)
WITH friend, city, collect(post) AS posts, person
WITH friend,
city,
size(posts) AS postCount,
size([p IN posts WHERE (p)-[:HAS_TAG]->()<-[:HAS_INTEREST]-(person)]) AS commonPostCount
RETURN friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
commonPostCount - (postCount - commonPostCount) AS commonInterestScore,
friend.gender AS personGender,
city.name AS personCityName
ORDER BY commonInterestScore DESC, personId ASC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j
def benchmark__interactive__complex_query_11_analytical(self):
return (
"""
MATCH (person:Person {id: $personId })-[:KNOWS*1..2]-(friend:Person)
WHERE not(person=friend)
WITH DISTINCT friend
MATCH (friend)-[workAt:WORK_AT]->(company:Company)-[:IS_LOCATED_IN]->(:Country {name: $countryName })
WHERE workAt.workFrom < $workFromYear
RETURN
friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
company.name AS organizationName,
workAt.workFrom AS organizationWorkFromYear
ORDER BY
organizationWorkFromYear ASC,
toInteger(personId) ASC,
organizationName DESC
LIMIT 10
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_12_analytical(self):
return (
"""
MATCH (tag:Tag)-[:HAS_TYPE|IS_SUBCLASS_OF*0..]->(baseTagClass:TagClass)
WHERE tag.name = $tagClassName OR baseTagClass.name = $tagClassName
WITH collect(tag.id) as tags
MATCH (:Person {id: $personId })-[:KNOWS]-(friend:Person)<-[:HAS_CREATOR]-(comment:Comment)-[:REPLY_OF]->(:Post)-[:HAS_TAG]->(tag:Tag)
WHERE tag.id in tags
RETURN
friend.id AS personId,
friend.firstName AS personFirstName,
friend.lastName AS personLastName,
collect(DISTINCT tag.name) AS tagNames,
count(DISTINCT comment) AS replyCount
ORDER BY
replyCount DESC,
toInteger(personId) ASC
LIMIT 20
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
def benchmark__interactive__complex_query_13_analytical(self):
memgraph = (
"""
MATCH
(person1:Person {id: $person1Id}),
(person2:Person {id: $person2Id}),
path = (person1)-[:KNOWS *BFS]-(person2)
RETURN
CASE path IS NULL
WHEN true THEN -1
ELSE size(path)
END AS shortestPathLength
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
neo4j = (
"""
MATCH
(person1:Person {id: $person1Id}),
(person2:Person {id: $person2Id}),
path = shortestPath((person1)-[:KNOWS*]-(person2))
RETURN
CASE path IS NULL
WHEN true THEN -1
ELSE length(path)
END AS shortestPathLength
""".replace(
"\n", ""
),
self._get_query_parameters(),
)
if self._vendor == "memgraph":
return memgraph
else:
return neo4j

View File

@ -1,134 +1,17 @@
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import random
import helpers
from benchmark_context import BenchmarkContext
from workloads.base import Workload
from workloads.importers.importer_pokec import ImporterPokec
# Base dataset class used as a template to create each individual dataset. All
# common logic is handled here.
class Dataset:
# Name of the dataset.
NAME = "Base dataset"
# List of all variants of the dataset that exist.
VARIANTS = ["default"]
# One of the available variants that should be used as the default variant.
DEFAULT_VARIANT = "default"
# List of query files that should be used to import the dataset.
FILES = {
"default": "/foo/bar",
}
INDEX = None
INDEX_FILES = {"default": ""}
# List of query file URLs that should be used to import the dataset.
URLS = None
# Number of vertices/edges for each variant.
SIZES = {
"default": {"vertices": 0, "edges": 0},
}
# Indicates whether the dataset has properties on edges.
PROPERTIES_ON_EDGES = False
def __init__(self, variant=None, vendor=None):
"""
Accepts a `variant` variable that indicates which variant
of the dataset should be executed.
"""
if variant is None:
variant = self.DEFAULT_VARIANT
if variant not in self.VARIANTS:
raise ValueError("Invalid test variant!")
if (self.FILES and variant not in self.FILES) and (self.URLS and variant not in self.URLS):
raise ValueError("The variant doesn't have a defined URL or " "file path!")
if variant not in self.SIZES:
raise ValueError("The variant doesn't have a defined dataset " "size!")
if vendor not in self.INDEX_FILES:
raise ValueError("Vendor does not have INDEX for dataset!")
self._variant = variant
self._vendor = vendor
if self.FILES is not None:
self._file = self.FILES.get(variant, None)
else:
self._file = None
if self.URLS is not None:
self._url = self.URLS.get(variant, None)
else:
self._url = None
if self.INDEX_FILES is not None:
self._index = self.INDEX_FILES.get(vendor, None)
else:
self._index = None
self._size = self.SIZES[variant]
if "vertices" not in self._size or "edges" not in self._size:
raise ValueError("The size defined for this variant doesn't " "have the number of vertices and/or edges!")
self._num_vertices = self._size["vertices"]
self._num_edges = self._size["edges"]
def prepare(self, directory):
if self._file is not None:
print("Using dataset file:", self._file)
else:
# TODO: add support for JSON datasets
cached_input, exists = directory.get_file("dataset.cypher")
if not exists:
print("Downloading dataset file:", self._url)
downloaded_file = helpers.download_file(self._url, directory.get_path())
print("Unpacking and caching file:", downloaded_file)
helpers.unpack_and_move_file(downloaded_file, cached_input)
print("Using cached dataset file:", cached_input)
self._file = cached_input
cached_index, exists = directory.get_file(self._vendor + ".cypher")
if not exists:
print("Downloading index file:", self._index)
downloaded_file = helpers.download_file(self._index, directory.get_path())
print("Unpacking and caching file:", downloaded_file)
helpers.unpack_and_move_file(downloaded_file, cached_index)
print("Using cached index file:", cached_index)
self._index = cached_index
def get_variant(self):
"""Returns the current variant of the dataset."""
return self._variant
def get_index(self):
"""Get index file, defined by vendor"""
return self._index
def get_file(self):
"""
Returns path to the file that contains dataset creation queries.
"""
return self._file
def get_size(self):
"""Returns number of vertices/edges for the current variant."""
return self._size
# All tests should be query generator functions that output all of the
# queries that should be executed by the runner. The functions should be
# named `benchmark__GROUPNAME__TESTNAME` and should not accept any
# arguments.
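# For illustration only, a minimal sketch of such a generator; the group name
# "match", the test name "vertex_count", and the User label are assumptions,
# not taken from this file. It accepts no arguments and returns a
# (query, parameters) pair for the runner to execute.
def benchmark__match__vertex_count(self):
    return ("MATCH (n:User) RETURN count(n)", {})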
class Pokec(Dataset):
class Pokec(Workload):
NAME = "pokec"
VARIANTS = ["small", "medium", "large"]
DEFAULT_VARIANT = "small"
FILES = None
FILE = None
URLS = {
URL_FILE = {
"small": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_small_import.cypher",
"medium": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_medium_import.cypher",
"large": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/pokec_large.setup.cypher.gz",
@ -138,16 +21,28 @@ class Pokec(Dataset):
"medium": {"vertices": 100000, "edges": 1768515},
"large": {"vertices": 1632803, "edges": 30622564},
}
INDEX = None
INDEX_FILES = {
URL_INDEX_FILE = {
"memgraph": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/memgraph.cypher",
"neo4j": "https://s3.eu-west-1.amazonaws.com/deps.memgraph.io/dataset/pokec/benchmark/neo4j.cypher",
}
PROPERTIES_ON_EDGES = False
# Helpers used to generate the queries
def __init__(self, variant: str = None, benchmark_context: BenchmarkContext = None):
super().__init__(variant, benchmark_context=benchmark_context)
def custom_import(self) -> bool:
importer = ImporterPokec(
benchmark_context=self.benchmark_context,
dataset_name=self.NAME,
index_file=self._file_index,
dataset_file=self._file,
variant=self._variant,
)
return importer.execute_import()
# Helpers used to generate the queries
def _get_random_vertex(self):
# All vertices in the Pokec dataset have an ID in the range
# [1, _num_vertices].
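# A minimal sketch of the body, assuming the `random` module imported at the
# top of this file (the actual implementation is not shown here):
return random.randint(1, self._num_vertices)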
@ -343,7 +238,7 @@ class Pokec(Dataset):
return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()})
def benchmark__match__vertex_on_property(self):
return ("MATCH (n {id: $id}) RETURN n", {"id": self._get_random_vertex()})
return ("MATCH (n:User {id: $id}) RETURN n", {"id": self._get_random_vertex()})
def benchmark__update__vertex_on_property(self):
return (
@ -364,7 +259,7 @@ class Pokec(Dataset):
def benchmark__basic__single_vertex_property_update_update(self):
return (
"MATCH (n {id: $id}) SET n.property = -1",
"MATCH (n:User {id: $id}) SET n.property = -1",
{"id": self._get_random_vertex()},
)

View File

@ -130,6 +130,12 @@ target_link_libraries(${test_prefix}query_serialization_property_value mg-query)
add_unit_test(query_streams.cpp)
target_link_libraries(${test_prefix}query_streams mg-query kafka-mock)
add_unit_test(transaction_queue.cpp)
target_link_libraries(${test_prefix}transaction_queue mg-communication mg-query mg-glue)
add_unit_test(transaction_queue_multiple.cpp)
target_link_libraries(${test_prefix}transaction_queue_multiple mg-communication mg-query mg-glue)
# Test query functions
add_unit_test(query_function_mgp_module.cpp)
target_link_libraries(${test_prefix}query_function_mgp_module mg-query)
@ -360,15 +366,6 @@ if(MG_ENTERPRISE)
target_link_libraries(${test_prefix}rpc mg-rpc)
endif()
# Test LCP
add_custom_command(
OUTPUT test_lcp
DEPENDS ${lcp_src_files} lcp test_lcp.lisp
COMMAND sbcl --script ${CMAKE_CURRENT_SOURCE_DIR}/test_lcp.lisp)
add_custom_target(test_lcp ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/test_lcp)
add_test(test_lcp ${CMAKE_CURRENT_BINARY_DIR}/test_lcp)
add_dependencies(memgraph__unit test_lcp)
# Test websocket
find_package(Boost REQUIRED)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,7 @@
#include "glue/communication.hpp"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "interpreter_faker.hpp"
#include "query/auth_checker.hpp"
#include "query/config.hpp"
#include "query/exceptions.hpp"
@ -40,57 +41,18 @@ auto ToEdgeList(const memgraph::communication::bolt::Value &v) {
return list;
};
struct InterpreterFaker {
InterpreterFaker(memgraph::storage::Storage *db, const memgraph::query::InterpreterConfig config,
const std::filesystem::path &data_directory)
: interpreter_context(db, config, data_directory), interpreter(&interpreter_context) {
interpreter_context.auth_checker = &auth_checker;
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context.db);
const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);
}
void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) {
const auto summary = interpreter.Pull(stream, n, qid);
stream->Summary(summary);
}
/**
* Execute the given query and commit the transaction.
*
* Return the query stream.
*/
auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
auto prepare_result = Prepare(query, params);
auto &stream = prepare_result.first;
auto summary = interpreter.Pull(&stream, {}, prepare_result.second);
stream.Summary(summary);
return std::move(stream);
}
memgraph::query::AllowEverythingAuthChecker auth_checker;
memgraph::query::InterpreterContext interpreter_context;
memgraph::query::Interpreter interpreter;
};
} // namespace
// TODO: This is not a unit test, but tests/integration dir is chaotic at the
// moment. After tests refactoring is done, move/rename this.
class InterpreterTest : public ::testing::Test {
protected:
public:
memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_interpreter"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker default_interpreter{&db_, {}, data_directory};
InterpreterFaker default_interpreter{&interpreter_context};
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
return default_interpreter.Prepare(query, params);
@ -638,8 +600,6 @@ TEST_F(InterpreterTest, UniqueConstraintTest) {
}
TEST_F(InterpreterTest, ExplainQuery) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("EXPLAIN MATCH (n) RETURN *;");
@ -663,8 +623,6 @@ TEST_F(InterpreterTest, ExplainQuery) {
}
TEST_F(InterpreterTest, ExplainQueryMultiplePulls) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;");
@ -698,8 +656,6 @@ TEST_F(InterpreterTest, ExplainQueryMultiplePulls) {
}
TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
Interpret("BEGIN");
@ -725,8 +681,6 @@ TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) {
}
TEST_F(InterpreterTest, ExplainQueryWithParams) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream =
@ -751,8 +705,6 @@ TEST_F(InterpreterTest, ExplainQueryWithParams) {
}
TEST_F(InterpreterTest, ProfileQuery) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("PROFILE MATCH (n) RETURN *;");
@ -776,8 +728,6 @@ TEST_F(InterpreterTest, ProfileQuery) {
}
TEST_F(InterpreterTest, ProfileQueryMultiplePulls) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;");
@ -820,8 +770,6 @@ TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) {
}
TEST_F(InterpreterTest, ProfileQueryWithParams) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream =
@ -846,8 +794,6 @@ TEST_F(InterpreterTest, ProfileQueryWithParams) {
}
TEST_F(InterpreterTest, ProfileQueryWithLiterals) {
const auto &interpreter_context = default_interpreter.interpreter_context;
EXPECT_EQ(interpreter_context.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context.ast_cache.size(), 0U);
auto stream = Interpret("PROFILE UNWIND range(1, 1000) AS x CREATE (:Node {id: x});", {});
@ -1087,7 +1033,6 @@ TEST_F(InterpreterTest, LoadCsvClause) {
}
TEST_F(InterpreterTest, CacheableQueries) {
const auto &interpreter_context = default_interpreter.interpreter_context;
// This should be cached
{
SCOPED_TRACE("Cacheable query");
@ -1120,7 +1065,9 @@ TEST_F(InterpreterTest, AllowLoadCsvConfig) {
"CREATE TRIGGER trigger ON CREATE BEFORE COMMIT EXECUTE LOAD CSV FROM 'file.csv' WITH HEADER AS row RETURN "
"row"};
InterpreterFaker interpreter_faker{&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()};
memgraph::query::InterpreterContext csv_interpreter_context{
&db_, {.query = {.allow_load_csv = allow_load_csv}}, directory_manager.Path()};
InterpreterFaker interpreter_faker{&csv_interpreter_context};
for (const auto &query : queries) {
if (allow_load_csv) {
SCOPED_TRACE(fmt::format("'{}' should not throw because LOAD CSV is allowed", query));

View File

@ -0,0 +1,49 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "communication/result_stream_faker.hpp"
#include "query/interpreter.hpp"
struct InterpreterFaker {
InterpreterFaker(memgraph::query::InterpreterContext *interpreter_context)
: interpreter_context(interpreter_context), interpreter(interpreter_context) {
interpreter_context->auth_checker = &auth_checker;
interpreter_context->interpreters.WithLock([this](auto &interpreters) { interpreters.insert(&interpreter); });
}
auto Prepare(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(interpreter_context->db);
const auto [header, _, qid] = interpreter.Prepare(query, params, nullptr);
stream.Header(header);
return std::make_pair(std::move(stream), qid);
}
void Pull(ResultStreamFaker *stream, std::optional<int> n = {}, std::optional<int> qid = {}) {
const auto summary = interpreter.Pull(stream, n, qid);
stream->Summary(summary);
}
/**
* Execute the given query and commit the transaction.
*
* Return the query stream.
*/
auto Interpret(const std::string &query, const std::map<std::string, memgraph::storage::PropertyValue> &params = {}) {
auto prepare_result = Prepare(query, params);
auto &stream = prepare_result.first;
auto summary = interpreter.Pull(&stream, {}, prepare_result.second);
stream.Summary(summary);
return std::move(stream);
}
memgraph::query::AllowEverythingAuthChecker auth_checker;
memgraph::query::InterpreterContext *interpreter_context;
memgraph::query::Interpreter interpreter;
};
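// A minimal usage sketch (illustrative only; assumes a storage instance and an
// InterpreterContext wired up as in the unit tests that use this header):
//
//   memgraph::query::InterpreterContext interpreter_context{&db, {}, data_directory};
//   InterpreterFaker faker{&interpreter_context};
//   auto stream = faker.Interpret("RETURN 1;");  // prepares, pulls, and commits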

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -39,7 +39,6 @@ class MockAuthChecker : public memgraph::query::AuthChecker {
public:
MOCK_CONST_METHOD2(IsUserAuthorized, bool(const std::optional<std::string> &username,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges));
#ifdef MG_ENTERPRISE
MOCK_CONST_METHOD2(GetFineGrainedAuthChecker,
std::unique_ptr<memgraph::query::FineGrainedAuthChecker>(

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -28,7 +28,7 @@ struct SumReq {
int y;
};
const memgraph::utils::TypeInfo SumReq::kType{0, "SumReq"};
const memgraph::utils::TypeInfo SumReq::kType{memgraph::utils::TypeId::UNKNOWN, "SumReq"};
struct SumRes {
static const memgraph::utils::TypeInfo kType;
@ -42,7 +42,7 @@ struct SumRes {
int sum;
};
const memgraph::utils::TypeInfo SumRes::kType{1, "SumRes"};
const memgraph::utils::TypeInfo SumRes::kType{memgraph::utils::TypeId::UNKNOWN, "SumRes"};
namespace memgraph::slk {
void Save(const SumReq &sum, Builder *builder);
@ -66,7 +66,7 @@ struct EchoMessage {
std::string data;
};
const memgraph::utils::TypeInfo EchoMessage::kType{2, "EchoMessage"};
const memgraph::utils::TypeInfo EchoMessage::kType{memgraph::utils::TypeId::UNKNOWN, "EchoMessage"};
namespace memgraph::slk {
void Save(const EchoMessage &echo, Builder *builder);

View File

@ -0,0 +1,75 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <stop_token>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include "gmock/gmock.h"
#include "interpreter_faker.hpp"
/*
Tests rely on the fact that interpreters are added sequentially to running_interpreters, which is how each test
obtains the transaction_id of its corresponding interpreter.
*/
class TransactionQueueSimpleTest : public ::testing::Test {
protected:
memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() / "MG_tests_unit_transaction_queue_intr"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker running_interpreter{&interpreter_context}, main_interpreter{&interpreter_context};
};
TEST_F(TransactionQueueSimpleTest, TwoInterpretersInterleaving) {
bool started = false;
std::jthread running_thread = std::jthread(
[this, &started](std::stop_token st, int thread_index) {
running_interpreter.Interpret("BEGIN");
started = true;
},
0);
{
while (!started) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
main_interpreter.Interpret("CREATE (:Person {prop: 1})");
auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream.GetResults().size(), 2U);
// superadmin executing the transaction
EXPECT_EQ(show_stream.GetResults()[0][0].ValueString(), "");
ASSERT_TRUE(show_stream.GetResults()[0][1].IsString());
EXPECT_EQ(show_stream.GetResults()[0][2].ValueList().at(0).ValueString(), "SHOW TRANSACTIONS");
// the other running interpreter also executes as an anonymous user
EXPECT_EQ(show_stream.GetResults()[1][0].ValueString(), "");
ASSERT_TRUE(show_stream.GetResults()[1][1].IsString());
// Kill the other transaction
std::string run_trans_id = show_stream.GetResults()[1][1].ValueString();
std::string esc_run_trans_id = "'" + run_trans_id + "'";
auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id);
// check result of killing
ASSERT_EQ(terminate_stream.GetResults().size(), 1U);
EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id);
ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually killed
// check the number of transactions now
auto show_stream_after_killing = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream_after_killing.GetResults().size(), 1U);
// test the state of the database
auto results_stream = main_interpreter.Interpret("MATCH (n) RETURN n");
ASSERT_EQ(results_stream.GetResults().size(), 1U); // from the main interpreter
main_interpreter.Interpret("MATCH (n) DETACH DELETE n");
// finish thread
running_thread.request_stop();
}
}

View File

@ -0,0 +1,118 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <chrono>
#include <random>
#include <stop_token>
#include <string>
#include <thread>
#include <gtest/gtest.h>
#include "gmock/gmock.h"
#include "spdlog/spdlog.h"
#include "interpreter_faker.hpp"
#include "query/exceptions.hpp"
constexpr int NUM_INTERPRETERS = 4, INSERTIONS = 4000;
/*
Tests rely on the fact that interpreters are added sequentially to running_interpreters, which is how each test
obtains the transaction_id of its corresponding interpreter.
*/
class TransactionQueueMultipleTest : public ::testing::Test {
protected:
memgraph::storage::Storage db_;
std::filesystem::path data_directory{std::filesystem::temp_directory_path() /
"MG_tests_unit_transaction_queue_multiple_intr"};
memgraph::query::InterpreterContext interpreter_context{&db_, {}, data_directory};
InterpreterFaker main_interpreter{&interpreter_context};
std::vector<InterpreterFaker *> running_interpreters;
TransactionQueueMultipleTest() {
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
InterpreterFaker *faker = new InterpreterFaker(&interpreter_context);
running_interpreters.push_back(faker);
}
}
~TransactionQueueMultipleTest() override {
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
delete running_interpreters[i];
}
}
};
// Tests whether the admin can see the superadmin's transactions
TEST_F(TransactionQueueMultipleTest, TerminateTransaction) {
std::vector<bool> started(NUM_INTERPRETERS, false);
auto thread_func = [this, &started](int thread_index) {
try {
running_interpreters[thread_index]->Interpret("BEGIN");
started[thread_index] = true;
// insertions are aborted with HintedAbortError if this transaction gets terminated
for (int j = 0; j < INSERTIONS; ++j) {
running_interpreters[thread_index]->Interpret("CREATE (:Person {prop: " + std::to_string(thread_index) + "})");
}
} catch (memgraph::query::HintedAbortError &e) {
}
};
{
std::vector<std::jthread> running_threads;
running_threads.reserve(NUM_INTERPRETERS);
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
running_threads.emplace_back(thread_func, i);
}
while (!std::all_of(started.begin(), started.end(), [](const bool v) { return v; })) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
auto show_stream = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream.GetResults().size(), NUM_INTERPRETERS + 1);
// Choose random transaction to kill
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> distr(0, NUM_INTERPRETERS - 1);
int index_to_terminate = distr(gen);
// Kill random transaction
std::string run_trans_id =
std::to_string(running_interpreters[index_to_terminate]->interpreter.GetTransactionId().value());
std::string esc_run_trans_id = "'" + run_trans_id + "'";
auto terminate_stream = main_interpreter.Interpret("TERMINATE TRANSACTIONS " + esc_run_trans_id);
// check result of killing
ASSERT_EQ(terminate_stream.GetResults().size(), 1U);
EXPECT_EQ(terminate_stream.GetResults()[0][0].ValueString(), run_trans_id);
ASSERT_TRUE(terminate_stream.GetResults()[0][1].ValueBool()); // that the transaction is actually killed
// check SHOW TRANSACTIONS again after the kill
auto show_stream_after_kill = main_interpreter.Interpret("SHOW TRANSACTIONS");
ASSERT_EQ(show_stream_after_kill.GetResults().size(), NUM_INTERPRETERS);
// wait for the threads to finish
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
running_threads[i].join();
}
// test the state of the database
for (int i = 0; i < NUM_INTERPRETERS; ++i) {
if (i != index_to_terminate) {
running_interpreters[i]->Interpret("COMMIT");
}
std::string fetch_query = "MATCH (n:Person) WHERE n.prop=" + std::to_string(i) + " RETURN n";
auto results_stream = main_interpreter.Interpret(fetch_query);
if (i == index_to_terminate) {
ASSERT_EQ(results_stream.GetResults().size(), 0);
} else {
ASSERT_EQ(results_stream.GetResults().size(), INSERTIONS);
}
}
main_interpreter.Interpret("MATCH (n) DETACH DELETE n");
}
}