diff --git a/.github/workflows/diff.yaml b/.github/workflows/diff.yaml index 4d9d31c1b..d6eb61de1 100644 --- a/.github/workflows/diff.yaml +++ b/.github/workflows/diff.yaml @@ -14,6 +14,8 @@ jobs: runs-on: [self-hosted, Linux, X64, Diff] env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository @@ -76,6 +78,8 @@ jobs: runs-on: [self-hosted, Linux, X64, Diff] env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository @@ -140,6 +144,8 @@ jobs: runs-on: [self-hosted, Linux, X64, Diff] env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository @@ -214,6 +220,8 @@ jobs: runs-on: [self-hosted, Linux, X64, Diff] env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository @@ -310,6 +318,8 @@ jobs: #continue-on-error: true env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository @@ -349,6 +359,8 @@ jobs: runs-on: [self-hosted, Linux, X64, Diff, Gen7] env: THREADS: 24 + MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} + MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }} steps: - name: Set up repository diff --git a/.github/workflows/package_all.yaml b/.github/workflows/package_all.yaml index 656f1016a..bc2a64095 100644 --- a/.github/workflows/package_all.yaml +++ b/.github/workflows/package_all.yaml @@ -15,7 +15,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community centos-7 + ./release/package/run.sh package centos-7 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -32,7 +32,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community centos-8 + ./release/package/run.sh package centos-8 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -49,7 +49,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community debian-9 + ./release/package/run.sh package debian-9 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -66,7 +66,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community debian-10 + ./release/package/run.sh package debian-10 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -84,7 +84,7 @@ jobs: - name: "Build package" run: | cd release/package - ./run.sh package community debian-10 --for-docker + ./run.sh package debian-10 --for-docker ./run.sh docker - name: "Upload package" uses: actions/upload-artifact@v2 @@ -102,7 +102,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community ubuntu-18.04 + ./release/package/run.sh package ubuntu-18.04 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -119,7 +119,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package community ubuntu-20.04 + ./release/package/run.sh package ubuntu-20.04 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -136,7 +136,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise centos-7 + ./release/package/run.sh package centos-7 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -153,7 +153,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise centos-8 + ./release/package/run.sh package centos-8 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -170,7 +170,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise debian-9 + ./release/package/run.sh package debian-9 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -187,7 +187,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise debian-10 + ./release/package/run.sh package debian-10 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -205,7 +205,7 @@ jobs: - name: "Build package" run: | cd release/package - ./run.sh package enterprise debian-10 --for-docker + ./run.sh package debian-10 --for-docker ./run.sh docker - name: "Upload package" uses: actions/upload-artifact@v2 @@ -223,7 +223,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise ubuntu-18.04 + ./release/package/run.sh package ubuntu-18.04 - name: "Upload package" uses: actions/upload-artifact@v2 with: @@ -240,7 +240,7 @@ jobs: fetch-depth: 0 # Required because of release/get_version.py - name: "Build package" run: | - ./release/package/run.sh package enterprise ubuntu-20.04 + ./release/package/run.sh package ubuntu-20.04 - name: "Upload package" uses: actions/upload-artifact@v2 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index ad7873a91..0e751a756 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,9 +61,9 @@ set(MEMGRAPH_OVERRIDE_VERSION_SUFFIX "") # Variables used to generate the versions. if (MG_ENTERPRISE) - set(get_version_enterprise "--enterprise") + set(get_version_offering "") else() - set(get_version_enterprise "") + set(get_version_offering "--open-source") endif() set(get_version_script "${CMAKE_SOURCE_DIR}/release/get_version.py") @@ -71,7 +71,7 @@ set(get_version_script "${CMAKE_SOURCE_DIR}/release/get_version.py") execute_process( OUTPUT_VARIABLE MEMGRAPH_VERSION RESULT_VARIABLE MEMGRAPH_VERSION_RESULT - COMMAND "${get_version_script}" ${get_version_enterprise} + COMMAND "${get_version_script}" ${get_version_offering} "${MEMGRAPH_OVERRIDE_VERSION}" "${MEMGRAPH_OVERRIDE_VERSION_SUFFIX}" "--memgraph-root-dir" @@ -87,7 +87,7 @@ endif() execute_process( OUTPUT_VARIABLE MEMGRAPH_VERSION_DEB RESULT_VARIABLE MEMGRAPH_VERSION_DEB_RESULT - COMMAND "${get_version_script}" ${get_version_enterprise} + COMMAND "${get_version_script}" ${get_version_offering} --variant deb "${MEMGRAPH_OVERRIDE_VERSION}" "${MEMGRAPH_OVERRIDE_VERSION_SUFFIX}" @@ -104,7 +104,7 @@ endif() execute_process( OUTPUT_VARIABLE MEMGRAPH_VERSION_RPM RESULT_VARIABLE MEMGRAPH_VERSION_RPM_RESULT - COMMAND "${get_version_script}" ${get_version_enterprise} + COMMAND "${get_version_script}" ${get_version_offering} --variant rpm "${MEMGRAPH_OVERRIDE_VERSION}" "${MEMGRAPH_OVERRIDE_VERSION_SUFFIX}" diff --git a/config/flags.yaml b/config/flags.yaml index 23f39effd..c3e5e754c 100644 --- a/config/flags.yaml +++ b/config/flags.yaml @@ -101,3 +101,5 @@ undocumented: - "help" - "help_xml" - "version" + - "organization_name" + - "license_key" diff --git a/release/CMakeLists.txt b/release/CMakeLists.txt index e29c6a47c..f2b49e6b1 100644 --- a/release/CMakeLists.txt +++ b/release/CMakeLists.txt @@ -1,11 +1,6 @@ # Install the license file. -if (MG_ENTERPRISE) - install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/LICENSE_ENTERPRISE.md - DESTINATION share/doc/memgraph RENAME copyright) -else() - install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/LICENSE_COMMUNITY.md - DESTINATION share/doc/memgraph RENAME copyright) -endif() +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/LICENSE.md + DESTINATION share/doc/memgraph RENAME copyright) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/third-party-licenses DESTINATION share/doc/memgraph) @@ -30,23 +25,13 @@ set(CPACK_DEBIAN_PACKAGE_SECTION non-free/database) set(CPACK_DEBIAN_PACKAGE_HOMEPAGE https://memgraph.com) set(CPACK_DEBIAN_PACKAGE_VERSION "${MEMGRAPH_VERSION_DEB}") set(CPACK_DEBIAN_FILE_NAME "memgraph_${MEMGRAPH_VERSION_DEB}_amd64.deb") -if (MG_ENTERPRISE) - set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA - "${CMAKE_CURRENT_SOURCE_DIR}/debian/enterprise/conffiles;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/enterprise/copyright;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/preinst;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/prerm;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;") -else() - set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA - "${CMAKE_CURRENT_SOURCE_DIR}/debian/community/conffiles;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/community/copyright;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/preinst;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/prerm;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;" - "${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;") -endif() +set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA + "${CMAKE_CURRENT_SOURCE_DIR}/debian/conffiles;" + "${CMAKE_CURRENT_SOURCE_DIR}/debian/copyright;" + "${CMAKE_CURRENT_SOURCE_DIR}/debian/preinst;" + "${CMAKE_CURRENT_SOURCE_DIR}/debian/prerm;" + "${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;" + "${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;") set(CPACK_DEBIAN_PACKAGE_SHLIBDEPS ON) # Description formatting is important, summary must be followed with a newline and 1 space. set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY} @@ -65,13 +50,8 @@ set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION /var /var/lib /var/log /etc/logrotate.d /lib /lib/systemd /lib/systemd/system /lib/systemd/system/memgraph.service) set(CPACK_RPM_PACKAGE_REQUIRES_PRE "shadow-utils") -if (MG_ENTERPRISE) - set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_CURRENT_SOURCE_DIR}/rpm/enterprise/memgraph.spec.in") - set(CPACK_RPM_PACKAGE_LICENSE "Memgraph Enterprise Trial License") -else() - set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_CURRENT_SOURCE_DIR}/rpm/community/memgraph.spec.in") - set(CPACK_RPM_PACKAGE_LICENSE "Memgraph Community License") -endif() +set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_CURRENT_SOURCE_DIR}/rpm/memgraph.spec.in") +set(CPACK_RPM_PACKAGE_LICENSE "Memgraph License") # Description formatting is important, no line must be greater than 80 characters. set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database. It aims to deliver developers the speed, simplicity and scale required to build diff --git a/release/LICENSE_ENTERPRISE.md b/release/LICENSE.md similarity index 99% rename from release/LICENSE_ENTERPRISE.md rename to release/LICENSE.md index b4c00bf7d..a95703965 100644 --- a/release/LICENSE_ENTERPRISE.md +++ b/release/LICENSE.md @@ -1,4 +1,4 @@ -# Memgraph Enterprise - Software Subscription Agreement +# Memgraph - Software Subscription Agreement Memgraph Limited is registered in England under registration 10195084 and has its registered office at Suite 4, Ironstone House, Ironstone Way, Brixworth, diff --git a/release/LICENSE_COMMUNITY.md b/release/LICENSE_COMMUNITY.md deleted file mode 100644 index 34a75c4d5..000000000 --- a/release/LICENSE_COMMUNITY.md +++ /dev/null @@ -1,98 +0,0 @@ -# Memgraph Community User License Agreement - -This License Agreement governs your use of the Memgraph Community Release (the -"Software") and documentation ("Documentation"). - -BY DOWNLOADING AND/OR ACCESSING THIS SOFTWARE, YOU ("LICENSEE") AGREE TO THESE -TERMS. - -1. License Grant - -The Software and Documentation are provided to Licensee at no charge and are -licensed, not sold to Licensee. No ownership of any part of the Software and -Documentation is hereby transferred to Licensee. Subject to (i) the terms and -conditions of this License Agreement, and (ii) any additional license -restrictions and parameters contained on Licensor’s quotation, website, or -order form, Licensor hereby grants Licensee a personal, non-assignable, -non-transferable and non-exclusive license to install, access and use the -Software (in object code form only) and Documentation for Licensee’s internal -business purposes (including for use in a production environment) only. All -rights relating to the Software and Documentation that are not expressly -licensed in this License Agreement, whether now existing or which may hereafter -come into existence are reserved for Licensor. Licensee shall not remove, -obscure, or alter any proprietary rights notices (including without limitation -copyright and trademark notices), which may be affixed to or contained within -the Software or Documentation. - -Licensor may terminate this License Agreement with immediate effect upon -written notice to the Licensee. Upon termination Licensee shall delete all -electronic copies of all or any part of the Software and/or the Documentation -resident in its systems or elsewhere. - -2. Restrictions - -Licensee will not, directly or indirectly, (a) copy the Software or -Documentation in any manner or for any purpose; (b) install, access or use any -component of the Software or Documentation for any purpose not expressly -granted in Section 1 above; (c) resell, distribute, publicly display or -publicly perform the Software or Documentation or any component thereof, by -transfer, lease, loan or any other means, or make it available for use by -others in any time-sharing, service bureau or similar arrangement; (d) -disassemble, decrypt, extract, reverse engineer or reverse compile the -Software, or otherwise attempt to discover the source code, confidential -algorithms or techniques incorporated in the Software; (e) export the Software -or Documentation in violation of any applicable laws or regulations; (f) -modify, translate, adapt, or create derivative works from the Software or -Documentation; (g) circumvent, disable or otherwise interfere with -security-related features of the Software or Documentation; (h) use the -Software or Documentation for any illegal purpose, in any manner that is -inconsistent with the terms of this License Agreement, or to engage in illegal -activity; (i) remove or alter any trademark, logo, copyright or other -proprietary notices, legends, symbols or labels on, or embedded in, the -Software or Documentation; or (j) provide access to the Software or -Documentation to third parties. - -3. Warranty Disclaimer - -THE SOFTWARE AND DOCUMENTATION ARE PROVIDED "AS IS" AND LICENSOR MAKES NO -WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON -INFRINGEMENT OF THIRD PARTIES’ INTELLECTUAL PROPERTY RIGHTS OR OTHER -PROPRIETARY RIGHTS. NEITHER THIS LICENSE AGREEMENT NOR ANY DOCUMENTATION -FURNISHED UNDER IT IS INTENDED TO EXPRESS OR IMPLY ANY WARRANTY THAT THE -OPERATION OF THE SOFTWARE WILL BE UNINTERRUPTED, TIMELY, OR ERROR-FREE. - -4. Limitation of Liability - -Licensor shall not in any circumstances be liable, whether in tort (including -for negligence or breach of statutory duty howsoever arising), contract, -misrepresentation (whether innocent or negligent) or otherwise for: loss of -profits, loss of business, depletion of goodwill or similar losses, loss of -anticipated savings, loss of goods, loss or corruption of data or computer -downtime, or any special, indirect, consequential or pure economic loss, costs, -damages, charges or expenses. - -Licensor's total aggregate liability in contract, tort (including without -limitation negligence or breach of statutory duty howsoever arising), -misrepresentation (whether innocent or negligent), restitution or otherwise, -arising in connection with the performance or contemplated performance of this -License Agreement shall in all circumstances be limited to GBP10.00 (ten pounds -sterling). - -Nothing in this License Agreement shall limit Licensor’s liability in the case -of death or personal injury caused by negligence, fraud, or fraudulent -misrepresentation, or where it otherwise cannot be limited by law. - -5. Technical Data - -Licensor may collect and use technical information (such as usage patterns) -gathered when the Licensee downloads and uses the Software. This is generally -statistical data which does not identify an identified or identifiable -individual. It may also include Licensee’s IP address which is personal data -and is processed in accordance with our Privacy Policy. We only use this -technical information to improve our products. - -6. Law and Jurisdiction - -This License Agreement is governed by the laws of England and is subject to the -non-exclusive jurisdiction of the courts of England. diff --git a/release/debian/community/conffiles b/release/debian/community/conffiles deleted file mode 100644 index e20b1a1c4..000000000 --- a/release/debian/community/conffiles +++ /dev/null @@ -1,2 +0,0 @@ -/etc/memgraph/memgraph.conf -/etc/logrotate.d/memgraph diff --git a/release/debian/community/copyright b/release/debian/community/copyright deleted file mode 120000 index f6386b1af..000000000 --- a/release/debian/community/copyright +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE_COMMUNITY.md \ No newline at end of file diff --git a/release/debian/enterprise/conffiles b/release/debian/conffiles similarity index 100% rename from release/debian/enterprise/conffiles rename to release/debian/conffiles diff --git a/release/debian/copyright b/release/debian/copyright new file mode 120000 index 000000000..7eabdb1c2 --- /dev/null +++ b/release/debian/copyright @@ -0,0 +1 @@ +../LICENSE.md \ No newline at end of file diff --git a/release/debian/enterprise/copyright b/release/debian/enterprise/copyright deleted file mode 120000 index 86de13b66..000000000 --- a/release/debian/enterprise/copyright +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE_ENTERPRISE.md \ No newline at end of file diff --git a/release/debian/preinst b/release/debian/preinst index 189fb5fdb..02ca52991 100644 --- a/release/debian/preinst +++ b/release/debian/preinst @@ -7,7 +7,7 @@ set -e # Manage (remove) /etc/logrotate.d/memgraph_audit file because the whole # logrotate config is moved to /etc/logrotate.d/memgraph since v1.2.0. -# Note: Only used to manage Memgraph Enterprise config but packaged into the +# Note: Only used to manage Memgraph config but it was packaged into the # Memgraph Community as well. if dpkg-maintscript-helper supports rm_conffile 2>/dev/null; then # 1.1.999 is chosen because it's high enough version number. It's highly diff --git a/release/docker/memgraph_enterprise.dockerfile b/release/docker/memgraph.dockerfile similarity index 100% rename from release/docker/memgraph_enterprise.dockerfile rename to release/docker/memgraph.dockerfile diff --git a/release/docker/memgraph_community.dockerfile b/release/docker/memgraph_community.dockerfile deleted file mode 100644 index 1d1a9d2ba..000000000 --- a/release/docker/memgraph_community.dockerfile +++ /dev/null @@ -1,30 +0,0 @@ -FROM debian:buster -# NOTE: If you change the base distro update release/package as well. - -ARG deb_release - -RUN apt-get update && apt-get install -y \ - openssl libcurl4 libssl1.1 python3 libpython3.7 python3-pip \ - --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* - -RUN pip3 install networkx==2.4 numpy==1.19.2 scipy==1.5.2 - -COPY ${deb_release} / - -# Install memgraph package -RUN dpkg -i ${deb_release} - -# Memgraph listens for Bolt Protocol on this port by default. -EXPOSE 7687 -# Snapshots and logging volumes -VOLUME /var/log/memgraph -VOLUME /var/lib/memgraph -# Configuration volume -VOLUME /etc/memgraph - -USER memgraph -WORKDIR /usr/lib/memgraph - -ENTRYPOINT ["/usr/lib/memgraph/memgraph"] -CMD [""] diff --git a/release/get_version.py b/release/get_version.py index dec645e31..73a2567a1 100755 --- a/release/get_version.py +++ b/release/get_version.py @@ -49,13 +49,13 @@ import os # <VERSION>+<DISTANCE>~<SHORTHASH>-<OFFERING>[-<SUFFIX>] # Examples: # Release version: -# 0.50.1-community -# 0.50.1-enterprise -# 0.50.1-enterprise-veryimportantcustomer +# 0.50.1-open-source +# 0.50.1 +# 0.50.1-veryimportantcustomer # Development version (master, 12 commits after release/0.50): -# 0.50.0+12~7e1eef94-community -# 0.50.0+12~7e1eef94-enterprise -# 0.50.0+12~7e1eef94-enterprise-veryimportantcustomer +# 0.50.0+12~7e1eef94-open-source +# 0.50.0+12~7e1eef94 +# 0.50.0+12~7e1eef94-veryimportantcustomer # # The DEB package version is determined using the following two templates: # Release version: @@ -64,13 +64,13 @@ import os # <VERSION>+<DISTANCE>~<SHORTHASH>-<OFFERING>[-<SUFFIX>]-1 # Examples: # Release version: -# 0.50.1-community-1 -# 0.50.1-enterprise-1 -# 0.50.1-enterprise-veryimportantcustomer-1 +# 0.50.1-open-source-1 +# 0.50.1-1 +# 0.50.1-veryimportantcustomer-1 # Development version (master, 12 commits after release/0.50): -# 0.50.0+12~7e1eef94-community-1 -# 0.50.0+12~7e1eef94-enterprise-1 -# 0.50.0+12~7e1eef94-enterprise-veryimportantcustomer-1 +# 0.50.0+12~7e1eef94-open-source-1 +# 0.50.0+12~7e1eef94-1 +# 0.50.0+12~7e1eef94-veryimportantcustomer-1 # For more documentation about the DEB package naming conventions see: # https://www.debian.org/doc/debian-policy/ch-controlfields.html#version # @@ -81,13 +81,13 @@ import os # <VERSION>_0.<DISTANCE>.<SHORTHASH>.<OFFERING>[.<SUFFIX>] # Examples: # Release version: -# 0.50.1_1.community -# 0.50.1_1.enterprise -# 0.50.1_1.enterprise.veryimportantcustomer +# 0.50.1_1.open-source +# 0.50.1_1 +# 0.50.1_1.veryimportantcustomer # Development version: -# 0.50.0_0.12.7e1eef94.community -# 0.50.0_0.12.7e1eef94.enterprise -# 0.50.0_0.12.7e1eef94.enterprise.veryimportantcustomer +# 0.50.0_0.12.7e1eef94.open-source +# 0.50.0_0.12.7e1eef94 +# 0.50.0_0.12.7e1eef94.veryimportantcustomer # For more documentation about the RPM package naming conventions see: # https://docs.fedoraproject.org/en-US/packaging-guidelines/Versioning/ # https://fedoraproject.org/wiki/Package_Versioning_Examples @@ -107,20 +107,20 @@ def format_version(variant, version, offering, distance=None, shorthash=None, # This is a release version. if variant == "deb": # <VERSION>-<OFFERING>[-<SUFFIX>]-1 - ret = "{}-{}".format(version, offering) + ret = "{}{}".format(version, f"-{offering}" if offering else "") if suffix: ret += "-" + suffix ret += "-1" return ret elif variant == "rpm": # <VERSION>_1.<OFFERING>[.<SUFFIX>] - ret = "{}_1.{}".format(version, offering) + ret = "{}_1{}".format(version, f".{offering}" if offering else "") if suffix: ret += "." + suffix return ret else: # <VERSION>-<OFFERING>[-<SUFFIX>] - ret = "{}-{}".format(version, offering) + ret = "{}{}".format(version, f"-{offering}" if offering else "") if suffix: ret += "-" + suffix return ret @@ -128,21 +128,21 @@ def format_version(variant, version, offering, distance=None, shorthash=None, # This is a development version. if variant == "deb": # <VERSION>+<DISTANCE>~<SHORTHASH>-<OFFERING>[-<SUFFIX>]-1 - ret = "{}+{}~{}-{}".format(version, distance, shorthash, offering) + ret = "{}+{}~{}{}".format(version, distance, shorthash, f"-{offering}" if offering else "") if suffix: ret += "-" + suffix ret += "-1" return ret elif variant == "rpm": # <VERSION>_0.<DISTANCE>.<SHORTHASH>.<OFFERING>[.<SUFFIX>] - ret = "{}_0.{}.{}.{}".format( - version, distance, shorthash, offering) + ret = "{}_0.{}.{}{}".format( + version, distance, shorthash, f".{offering}" if offering else "") if suffix: ret += "." + suffix return ret else: # <VERSION>+<DISTANCE>~<SHORTHASH>-<OFFERING>[-<SUFFIX>] - ret = "{}+{}~{}-{}".format(version, distance, shorthash, offering) + ret = "{}+{}~{}{}".format(version, distance, shorthash, f"-{offering}" if offering else "") if suffix: ret += "-" + suffix return ret @@ -152,8 +152,8 @@ def format_version(variant, version, offering, distance=None, shorthash=None, parser = argparse.ArgumentParser( description="Get the current version of Memgraph.") parser.add_argument( - "--enterprise", action="store_true", - help="set the current offering to enterprise (default 'community')") + "--open-source", action="store_true", + help="set the current offering to 'open-source'") parser.add_argument( "version", help="manual version override, if supplied the version isn't " "determined using git") @@ -173,7 +173,7 @@ if not os.path.isdir(args.memgraph_root_dir): os.chdir(args.memgraph_root_dir) -offering = "enterprise" if args.enterprise else "community" +offering = "open-source" if args.open_source else None # Check whether the version was manually supplied. if args.version: diff --git a/release/logrotate_enterprise.conf b/release/logrotate.conf similarity index 86% rename from release/logrotate_enterprise.conf rename to release/logrotate.conf index bfcc4cd7f..b731d97b9 100644 --- a/release/logrotate_enterprise.conf +++ b/release/logrotate.conf @@ -1,7 +1,7 @@ # logrotate configuration for Memgraph Enterprise # see "man logrotate" for details -/var/lib/memgraph/durability/audit/audit.log { +/var/lib/memgraph/audit/audit.log { # rotate log files daily daily # keep one year worth of audit logs diff --git a/release/logrotate_community.conf b/release/logrotate_community.conf deleted file mode 100644 index ceafb0fba..000000000 --- a/release/logrotate_community.conf +++ /dev/null @@ -1,2 +0,0 @@ -# logrotate configuration for Memgraph Community -# see "man logrotate" for details diff --git a/release/package/run.sh b/release/package/run.sh index fa980c7a2..046925d5c 100755 --- a/release/package/run.sh +++ b/release/package/run.sh @@ -3,7 +3,6 @@ set -Eeuo pipefail SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SUPPORTED_OFFERING=(community enterprise) SUPPORTED_OS=(centos-7 centos-8 debian-9 debian-10 ubuntu-18.04 ubuntu-20.04) PROJECT_ROOT="$SCRIPT_DIR/../.." TOOLCHAIN_VERSION="toolchain-v3" @@ -11,23 +10,14 @@ ACTIVATE_TOOLCHAIN="source /opt/${TOOLCHAIN_VERSION}/activate" HOST_OUTPUT_DIR="$PROJECT_ROOT/build/output" print_help () { - echo "$0 init|package {offering} {os} [--for-docker]|docker|test" + echo "$0 init|package {os} [--for-docker]|docker|test" echo "" - echo " offerings: ${SUPPORTED_OFFERING[*]}" echo " OSs: ${SUPPORTED_OS[*]}" exit 1 } make_package () { - offering="$1" - offering_flag=" -DMG_ENTERPRISE=OFF " - if [[ "$offering" == "enterprise" ]]; then - offering_flag=" -DMG_ENTERPRISE=ON " - fi - if [[ "$offering" == "community" ]]; then - offering_flag=" -DMG_ENTERPRISE=OFF " - fi - os="$2" + os="$1" package_command="" if [[ "$os" =~ ^"centos".* ]]; then package_command=" cpack -G RPM --config ../CPackConfig.cmake && rpmlint memgraph*.rpm " @@ -45,7 +35,7 @@ make_package () { fi fi build_container="mgbuild_$os" - echo "Building Memgraph $offering for $os on $build_container..." + echo "Building Memgraph for $os on $build_container..." echo "Copying project files..." # If master is not the current branch, fetch it, because the get_version @@ -74,7 +64,7 @@ make_package () { echo "Building targeted package..." docker exec "$build_container" bash -c "cd /memgraph && ./init" docker exec "$build_container" bash -c "cd $container_build_dir && rm -rf ./*" - docker exec "$build_container" bash -c "cd $container_build_dir && $ACTIVATE_TOOLCHAIN && cmake -DCMAKE_BUILD_TYPE=release $offering_flag $docker_flag .." + docker exec "$build_container" bash -c "cd $container_build_dir && $ACTIVATE_TOOLCHAIN && cmake -DCMAKE_BUILD_TYPE=release $docker_flag .." # ' is used instead of " because we need to run make within the allowed # container resources. # shellcheck disable=SC2016 @@ -121,14 +111,6 @@ case "$1" in if [[ "$#" -lt 2 ]]; then print_help fi - offering="$1" - shift 1 - is_offering_ok=false - for supported_offering in "${SUPPORTED_OFFERING[@]}"; do - if [[ "$supported_offering" == "${offering}" ]]; then - is_offering_ok=true - fi - done os="$1" shift 1 is_os_ok=false @@ -137,8 +119,8 @@ case "$1" in is_os_ok=true fi done - if [[ "$is_offering_ok" == true ]] && [[ "$is_os_ok" == true ]]; then - make_package "$offering" "$os" "$@" + if [[ "$is_os_ok" == true ]]; then + make_package "$os" "$@" else print_help fi diff --git a/release/rpm/community/memgraph.spec.in b/release/rpm/community/memgraph.spec.in deleted file mode 100644 index 5b23f794d..000000000 --- a/release/rpm/community/memgraph.spec.in +++ /dev/null @@ -1,138 +0,0 @@ -# -*- rpm-spec -*- -BuildRoot: %_topdir/@CPACK_PACKAGE_FILE_NAME@@CPACK_RPM_PACKAGE_COMPONENT_PART_PATH@ -Summary: @CPACK_RPM_PACKAGE_SUMMARY@ -Name: @CPACK_RPM_PACKAGE_NAME@ -Version: @CPACK_RPM_PACKAGE_VERSION@ -Release: @CPACK_RPM_PACKAGE_RELEASE@ -License: @CPACK_RPM_PACKAGE_LICENSE@ -# Group field is deprecated -# Group: @CPACK_RPM_PACKAGE_GROUP@ -Vendor: @CPACK_RPM_PACKAGE_VENDOR@ -BuildRequires: systemd - -@TMP_RPM_URL@ -@TMP_RPM_REQUIRES@ -@TMP_RPM_REQUIRES_PRE@ -@TMP_RPM_REQUIRES_POST@ -@TMP_RPM_REQUIRES_PREUN@ -@TMP_RPM_REQUIRES_POSTUN@ -@TMP_RPM_PROVIDES@ -@TMP_RPM_OBSOLETES@ -@TMP_RPM_CONFLICTS@ -@TMP_RPM_AUTOPROV@ -@TMP_RPM_AUTOREQ@ -@TMP_RPM_AUTOREQPROV@ -@TMP_RPM_BUILDARCH@ -@TMP_RPM_PREFIXES@ - -@TMP_RPM_DEBUGINFO@ - -# This is needed to prevent Python compilation errors when building the RPM -# package -# https://github.com/scylladb/scylla/issues/2235 -%if 0%{?rhel} < 8 -%global __os_install_post \ - /usr/lib/rpm/redhat/brp-compress \ - %{!?__debug_package:\ - /usr/lib/rpm/redhat/brp-strip %{__strip} \ - /usr/lib/rpm/redhat/brp-strip-comment-note %{__strip} %{__objdump} \ - } \ - /usr/lib/rpm/redhat/brp-strip-static-archive %{__strip} \ - %{!?__jar_repack:/usr/lib/rpm/redhat/brp-java-repack-jars} \ -%{nil} -%else -%global __os_install_post \ - /usr/lib/rpm/brp-compress \ - %{!?__debug_package:\ - /usr/lib/rpm/brp-strip %{__strip} \ - /usr/lib/rpm/brp-strip-comment-note %{__strip} %{__objdump} \ - } \ - /usr/lib/rpm/brp-strip-static-archive %{__strip} \ -%{nil} -%endif - -%define _rpmdir %_topdir/RPMS -%define _srcrpmdir %_topdir/SRPMS -@FILE_NAME_DEFINE@ -%define _unpackaged_files_terminate_build 0 -@TMP_RPM_SPEC_INSTALL_POST@ -@CPACK_RPM_SPEC_MORE_DEFINE@ -@CPACK_RPM_COMPRESSION_TYPE_TMP@ - -%description -@CPACK_RPM_PACKAGE_DESCRIPTION@ - -# This is a shortcutted spec file generated by CMake RPM generator -# we skip _install step because CPack does that for us. -# We do only save CPack installed tree in _prepr -# and then restore it in build. -%prep -# Put the systemd unit where it is expected on this system -mkdir -p $RPM_BUILD_ROOT/%{_unitdir} -mv $RPM_BUILD_ROOT/lib/systemd/system/memgraph.service $RPM_BUILD_ROOT/%{_unitdir} -rm -rf $RPM_BUILD_ROOT/lib -# Fix the incorrect directory permissions set by cpack (this is fixed in CMake 3.11) -find $RPM_BUILD_ROOT -type d | xargs chmod 755 -# After setting up custom prep, continue with CMake's default -mv $RPM_BUILD_ROOT %_topdir/tmpBBroot - -%install -if [ -e $RPM_BUILD_ROOT ]; -then - rm -rf $RPM_BUILD_ROOT -fi -mv %_topdir/tmpBBroot $RPM_BUILD_ROOT - -@TMP_RPM_DEBUGINFO_INSTALL@ - -%clean - -%post -# memgraph user and group must be set in preinst -chown memgraph:memgraph /var/lib/memgraph || exit 1 -chmod 750 /var/lib/memgraph || exit 1 -chown memgraph:adm /var/log/memgraph || exit 1 -chmod 750 /var/log/memgraph || exit 1 - -# Generate SSL certificates -if [ ! -d /etc/memgraph/ssl ]; then - mkdir /etc/memgraph/ssl || exit 1 - openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \ - -keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \ - -subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1 - chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1 - chmod 400 /etc/memgraph/ssl/* || exit 1 -fi -@RPM_SYMLINK_POSTINSTALL@ -@CPACK_RPM_SPEC_POSTINSTALL@ - -%postun -@CPACK_RPM_SPEC_POSTUNINSTALL@ - -%pre -# Add the 'memgraph' user and group -getent group memgraph >/dev/null || groupadd -r memgraph || exit 1 -getent passwd memgraph >/dev/null || \ - useradd -r -g memgraph -d /var/lib/memgraph -s /bin/bash memgraph || exit 1 -echo "Don't forget to switch to the 'memgraph' user to use Memgraph" || exit 1 -@CPACK_RPM_SPEC_PREINSTALL@ - -%preun -@CPACK_RPM_SPEC_PREUNINSTALL@ - -%files -%defattr(@TMP_DEFAULT_FILE_PERMISSIONS@,@TMP_DEFAULT_USER@,@TMP_DEFAULT_GROUP@,@TMP_DEFAULT_DIR_PERMISSIONS@) -@CPACK_RPM_INSTALL_FILES@ -# Since we moved the memgraph.service file, declare it explicitly here. -# NOTE: memgraph.service must not be marked as configuration file. -%{_unitdir}/memgraph.service - -# Override CPACK_RPM_ABSOLUTE_INSTALL_FILES with our %config(noreplace), cpack -# uses plain %config. -%config(noreplace) "/etc/memgraph/memgraph.conf" -%config(noreplace) "/etc/logrotate.d/memgraph" - -@CPACK_RPM_USER_INSTALL_FILES@ - -%changelog -@CPACK_RPM_SPEC_CHANGELOG@ diff --git a/release/rpm/enterprise/memgraph.spec.in b/release/rpm/memgraph.spec.in similarity index 100% rename from release/rpm/enterprise/memgraph.spec.in rename to release/rpm/memgraph.spec.in diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e71985387..00fc2932b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,9 +14,10 @@ add_subdirectory(integrations) add_subdirectory(query) add_subdirectory(slk) add_subdirectory(rpc) +add_subdirectory(auth) + if (MG_ENTERPRISE) add_subdirectory(audit) - add_subdirectory(auth) endif() string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) @@ -32,18 +33,14 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(mg_single_node_v2_sources glue/communication.cpp memgraph.cpp + glue/auth.cpp ) -if (MG_ENTERPRISE) - set(mg_single_node_v2_sources - ${mg_single_node_v2_sources} - glue/auth.cpp) -endif() set(mg_single_node_v2_libs stdc++fs Threads::Threads - telemetry_lib mg-query mg-communication mg-memory mg-utils) + telemetry_lib mg-query mg-communication mg-memory mg-utils mg-auth) if (MG_ENTERPRISE) # These are enterprise subsystems - set(mg_single_node_v2_libs ${mg_single_node_v2_libs} mg-auth mg-audit) + set(mg_single_node_v2_libs ${mg_single_node_v2_libs} mg-audit) endif() # memgraph main executable @@ -109,13 +106,9 @@ install(FILES ${CMAKE_SOURCE_DIR}/include/mg_procedure.h install(FILES ${CMAKE_BINARY_DIR}/config/memgraph.conf DESTINATION /etc/memgraph RENAME memgraph.conf) # Install logrotate configuration (must use absolute path). -if (MG_ENTERPRISE) -install(FILES ${CMAKE_SOURCE_DIR}/release/logrotate_enterprise.conf +install(FILES ${CMAKE_SOURCE_DIR}/release/logrotate.conf DESTINATION /etc/logrotate.d RENAME memgraph) -else() -install(FILES ${CMAKE_SOURCE_DIR}/release/logrotate_community.conf - DESTINATION /etc/logrotate.d RENAME memgraph) -endif() + # Create empty directories for default location of lib and log. install(CODE "file(MAKE_DIRECTORY \$ENV{DESTDIR}/var/log/memgraph \$ENV{DESTDIR}/var/lib/memgraph)") diff --git a/src/auth/auth.cpp b/src/auth/auth.cpp index d9bda1c60..abf9fc26c 100644 --- a/src/auth/auth.cpp +++ b/src/auth/auth.cpp @@ -9,7 +9,9 @@ #include "auth/exceptions.hpp" #include "utils/flag_validation.hpp" +#include "utils/license.hpp" #include "utils/logging.hpp" +#include "utils/settings.hpp" #include "utils/string.hpp" DEFINE_VALIDATED_string(auth_module_executable, "", "Absolute path to the auth module executable that should be used.", @@ -58,6 +60,13 @@ Auth::Auth(const std::string &storage_directory) : storage_(storage_directory), std::optional<User> Auth::Authenticate(const std::string &username, const std::string &password) { if (module_.IsUsed()) { + const auto license_check_result = utils::license::global_license_checker.IsValidLicense(utils::global_settings); + if (license_check_result.HasError()) { + spdlog::warn( + utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "authentication modules")); + return std::nullopt; + } + nlohmann::json params = nlohmann::json::object(); params["username"] = username; params["password"] = password; diff --git a/src/auth/auth.hpp b/src/auth/auth.hpp index 0174d39b2..706b35457 100644 --- a/src/auth/auth.hpp +++ b/src/auth/auth.hpp @@ -8,6 +8,7 @@ #include "auth/models.hpp" #include "auth/module.hpp" #include "kvstore/kvstore.hpp" +#include "utils/settings.hpp" namespace auth { @@ -20,7 +21,7 @@ namespace auth { */ class Auth final { public: - Auth(const std::string &storage_directory); + explicit Auth(const std::string &storage_directory); /** * Authenticates a user using his username and password. diff --git a/src/auth/models.cpp b/src/auth/models.cpp index 6a9ed1040..04d7cc605 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -7,11 +7,16 @@ #include "auth/crypto.hpp" #include "auth/exceptions.hpp" #include "utils/cast.hpp" +#include "utils/license.hpp" +#include "utils/settings.hpp" #include "utils/string.hpp" +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) DEFINE_bool(auth_password_permit_null, true, "Set to false to disable null passwords."); -DEFINE_string(auth_password_strength_regex, ".+", +constexpr std::string_view default_password_regex = ".+"; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_string(auth_password_strength_regex, default_password_regex.data(), "The regular expression that should be used to match the entire " "entered password to ensure its strength."); @@ -80,7 +85,8 @@ PermissionLevel Permissions::Has(Permission permission) const { // Check for the deny first because it has greater priority than a grant. if (denies_ & utils::UnderlyingCast(permission)) { return PermissionLevel::DENY; - } else if (grants_ & utils::UnderlyingCast(permission)) { + } + if (grants_ & utils::UnderlyingCast(permission)) { return PermissionLevel::GRANT; } return PermissionLevel::NEUTRAL; @@ -190,26 +196,38 @@ User::User(const std::string &username, const std::string &password_hash, const : username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {} bool User::CheckPassword(const std::string &password) { - if (password_hash_ == "") return true; + if (password_hash_.empty()) return true; return VerifyPassword(password, password_hash_); } void User::UpdatePassword(const std::optional<std::string> &password) { - if (password) { - std::regex re(FLAGS_auth_password_strength_regex); - if (!std::regex_match(*password, re)) { - throw AuthException( - "The user password doesn't conform to the required strength! Regex: " - "{}", - FLAGS_auth_password_strength_regex); - } - password_hash_ = EncryptPassword(*password); - } else { + if (!password) { if (!FLAGS_auth_password_permit_null) { throw AuthException("Null passwords aren't permitted!"); } password_hash_ = ""; + return; } + + if (FLAGS_auth_password_strength_regex != default_password_regex) { + if (const auto license_check_result = utils::license::global_license_checker.IsValidLicense(utils::global_settings); + license_check_result.HasError()) { + throw AuthException( + "Custom password regex is a Memgraph Enterprise feature. Please set the config " + "(\"--auth-password-strength-regex\") to its default value (\"{}\") or remove the flag.\n{}", + default_password_regex, + utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "password regex")); + } + } + std::regex re(FLAGS_auth_password_strength_regex); + if (!std::regex_match(*password, re)) { + throw AuthException( + "The user password doesn't conform to the required strength! Regex: " + "\"{}\"", + FLAGS_auth_password_strength_regex); + } + + password_hash_ = EncryptPassword(*password); } void User::SetRole(const Role &role) { role_.emplace(role); } diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 352b9d28b..4c3ad3f26 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -26,6 +26,7 @@ #include "query/auth_checker.hpp" #include "query/discard_value_stream.hpp" #include "query/exceptions.hpp" +#include "query/frontend/ast/ast.hpp" #include "query/interpreter.hpp" #include "query/plan/operator.hpp" #include "query/procedure/module.hpp" @@ -38,10 +39,12 @@ #include "utils/event_counter.hpp" #include "utils/file.hpp" #include "utils/flag_validation.hpp" +#include "utils/license.hpp" #include "utils/logging.hpp" #include "utils/memory_tracker.hpp" #include "utils/readable_size.hpp" #include "utils/rw_lock.hpp" +#include "utils/settings.hpp" #include "utils/signals.hpp" #include "utils/string.hpp" #include "utils/synchronized.hpp" @@ -67,10 +70,11 @@ #include "communication/session.hpp" #include "glue/communication.hpp" -#ifdef MG_ENTERPRISE -#include "audit/log.hpp" #include "auth/auth.hpp" #include "glue/auth.hpp" + +#ifdef MG_ENTERPRISE +#include "audit/log.hpp" #endif namespace { @@ -351,12 +355,18 @@ void ConfigureLogging() { } } // namespace +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_string(license_key, "", "License key for Memgraph Enterprise."); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_string(organization_name, "", "Organization name."); + /// Encapsulates Dbms and Interpreter that are passed through the network server /// and worker to the session. -#ifdef MG_ENTERPRISE struct SessionData { // Explicit constructor here to ensure that pointers to all objects are // supplied. +#if MG_ENTERPRISE + SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, audit::Log *audit_log) : db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {} @@ -364,26 +374,62 @@ struct SessionData { query::InterpreterContext *interpreter_context; utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth; audit::Log *audit_log; + +#else + + SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context, + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth) + : db(db), interpreter_context(interpreter_context), auth(auth) {} + storage::Storage *db; + query::InterpreterContext *interpreter_context; + utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth; + +#endif }; -DEFINE_string(auth_user_or_role_name_regex, "[a-zA-Z0-9_.+-@]+", +constexpr std::string_view default_user_role_regex = "[a-zA-Z0-9_.+-@]+"; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_string(auth_user_or_role_name_regex, default_user_role_regex.data(), "Set to the regular expression that each user or role name must fulfill."); class AuthQueryHandler final : public query::AuthQueryHandler { utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; + std::string name_regex_string_; std::regex name_regex_; public: - AuthQueryHandler(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, const std::regex &name_regex) - : auth_(auth), name_regex_(name_regex) {} + AuthQueryHandler(utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth, std::string name_regex_string) + : auth_(auth), name_regex_string_(std::move(name_regex_string)), name_regex_(name_regex_string_) {} bool CreateUser(const std::string &username, const std::optional<std::string> &password) override { + if (name_regex_string_ != default_user_role_regex) { + if (const auto license_check_result = + utils::license::global_license_checker.IsValidLicense(utils::global_settings); + license_check_result.HasError()) { + throw auth::AuthException( + "Custom user/role regex is a Memgraph Enterprise feature. Please set the config " + "(\"--auth-user-or-role-name-regex\") to its default value (\"{}\") or remove the flag.\n{}", + default_user_role_regex, + utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "user/role regex")); + } + } if (!std::regex_match(username, name_regex_)) { throw query::QueryRuntimeException("Invalid user name."); } try { - auto locked_auth = auth_->Lock(); - return locked_auth->AddUser(username, password).has_value(); + const auto [first_user, user_added] = std::invoke([&, this] { + auto locked_auth = auth_->Lock(); + const auto first_user = !locked_auth->HasUsers(); + const auto user_added = locked_auth->AddUser(username, password).has_value(); + return std::make_pair(first_user, user_added); + }); + + if (first_user) { + spdlog::info("{} is first created user. Granting all privileges.", username); + GrantPrivilege(username, query::kPrivilegesAll); + } + + return user_added; } catch (const auth::AuthException &e) { throw query::QueryRuntimeException(e.what()); } @@ -663,12 +709,12 @@ class AuthQueryHandler final : public query::AuthQueryHandler { throw query::QueryRuntimeException("Invalid user or role name."); } try { - auto locked_auth = auth_->Lock(); std::vector<auth::Permission> permissions; permissions.reserve(privileges.size()); for (const auto &privilege : privileges) { permissions.push_back(glue::PrivilegeToPermission(privilege)); } + auto locked_auth = auth_->Lock(); auto user = locked_auth->GetUser(user_or_role); auto role = locked_auth->GetRole(user_or_role); if (!user && !role) { @@ -722,64 +768,6 @@ class AuthChecker final : public query::AuthChecker { utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; }; -#else - -struct SessionData { - // Explicit constructor here to ensure that pointers to all objects are - // supplied. - SessionData(storage::Storage *db, query::InterpreterContext *interpreter_context) - : db(db), interpreter_context(interpreter_context) {} - storage::Storage *db; - query::InterpreterContext *interpreter_context; -}; - -class NoAuthInCommunity : public query::QueryRuntimeException { - public: - NoAuthInCommunity() - : query::QueryRuntimeException::QueryRuntimeException("Auth is not supported in Memgraph Community!") {} -}; - -class AuthQueryHandler final : public query::AuthQueryHandler { - public: - bool CreateUser(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } - - bool DropUser(const std::string &) override { throw NoAuthInCommunity(); } - - void SetPassword(const std::string &, const std::optional<std::string> &) override { throw NoAuthInCommunity(); } - - bool CreateRole(const std::string &) override { throw NoAuthInCommunity(); } - - bool DropRole(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetUsernames() override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetRolenames() override { throw NoAuthInCommunity(); } - - std::optional<std::string> GetRolenameForUser(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<query::TypedValue> GetUsernamesForRole(const std::string &) override { throw NoAuthInCommunity(); } - - void SetRole(const std::string &, const std::string &) override { throw NoAuthInCommunity(); } - - void ClearRole(const std::string &) override { throw NoAuthInCommunity(); } - - std::vector<std::vector<query::TypedValue>> GetPrivileges(const std::string &) override { throw NoAuthInCommunity(); } - - void GrantPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } - - void DenyPrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } - - void RevokePrivilege(const std::string &, const std::vector<query::AuthQuery::Privilege> &) override { - throw NoAuthInCommunity(); - } -}; - -#endif - class BoltSession final : public communication::bolt::Session<communication::InputStream, communication::OutputStream> { public: BoltSession(SessionData *data, const io::network::Endpoint &endpoint, communication::InputStream *input_stream, @@ -788,8 +776,8 @@ class BoltSession final : public communication::bolt::Session<communication::Inp output_stream), db_(data->db), interpreter_(data->interpreter_context), -#ifdef MG_ENTERPRISE auth_(data->auth), +#if MG_ENTERPRISE audit_log_(data->audit_log), #endif endpoint_(endpoint) { @@ -808,22 +796,22 @@ class BoltSession final : public communication::bolt::Session<communication::Inp std::map<std::string, storage::PropertyValue> params_pv; for (const auto &kv : params) params_pv.emplace(kv.first, glue::ToPropertyValue(kv.second)); const std::string *username{nullptr}; -#ifdef MG_ENTERPRISE if (user_) { username = &user_->username(); } - audit_log_->Record(endpoint_.address, user_ ? *username : "", query, storage::PropertyValue(params_pv)); +#ifdef MG_ENTERPRISE + if (utils::license::global_license_checker.IsValidLicenseFast()) { + audit_log_->Record(endpoint_.address, user_ ? *username : "", query, storage::PropertyValue(params_pv)); + } #endif try { auto result = interpreter_.Prepare(query, params_pv, username); -#ifdef MG_ENTERPRISE if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges)) { interpreter_.Abort(); throw communication::bolt::ClientError( "You are not authorized to execute this query! Please contact " "your database administrator."); } -#endif return {result.headers, result.qid}; } catch (const query::QueryException &e) { @@ -847,16 +835,12 @@ class BoltSession final : public communication::bolt::Session<communication::Inp void Abort() override { interpreter_.Abort(); } bool Authenticate(const std::string &username, const std::string &password) override { -#ifdef MG_ENTERPRISE auto locked_auth = auth_->Lock(); if (!locked_auth->HasUsers()) { return true; } user_ = locked_auth->Authenticate(username, password); return user_.has_value(); -#else - return true; -#endif } std::optional<std::string> GetServerNameForInit() override { @@ -930,9 +914,9 @@ class BoltSession final : public communication::bolt::Session<communication::Inp // NOTE: Needed only for ToBoltValue conversions const storage::Storage *db_; query::Interpreter interpreter_; -#ifdef MG_ENTERPRISE utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> *auth_; std::optional<auth::User> user_; +#ifdef MG_ENTERPRISE audit::Log *audit_log_; #endif io::network::Endpoint endpoint_; @@ -1041,7 +1025,25 @@ int main(int argc, char **argv) { auto data_directory = std::filesystem::path(FLAGS_data_directory); -#ifdef MG_ENTERPRISE + const auto memory_limit = GetMemoryLimit(); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + spdlog::info("Memory limit in config is set to {}", utils::GetReadableSize(memory_limit)); + utils::total_memory_tracker.SetMaximumHardLimit(memory_limit); + utils::total_memory_tracker.SetHardLimit(memory_limit); + + utils::global_settings.Initialize(data_directory / "settings"); + utils::OnScopeExit settings_finalizer([&] { utils::global_settings.Finalize(); }); + + // register all runtime settings + utils::license::RegisterLicenseSettings(utils::license::global_license_checker, utils::global_settings); + + utils::license::global_license_checker.CheckEnvLicense(); + if (!FLAGS_organization_name.empty() && !FLAGS_license_key.empty()) { + utils::license::global_license_checker.SetLicenseInfoOverride(FLAGS_license_key, FLAGS_organization_name); + } + + utils::license::global_license_checker.StartBackgroundLicenseChecker(utils::global_settings); + // All enterprise features should be constructed before the main database // storage. This will cause them to be destructed *after* the main database // storage. That way any errors that happen during enterprise features @@ -1056,6 +1058,7 @@ int main(int argc, char **argv) { // Auth utils::Synchronized<auth::Auth, utils::WritePrioritizedRWLock> auth{data_directory / "auth"}; +#ifdef MG_ENTERPRISE // Audit log audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size, FLAGS_audit_buffer_flush_interval_ms}; // Start the log if enabled. @@ -1070,10 +1073,6 @@ int main(int argc, char **argv) { // End enterprise features initialization #endif - const auto memory_limit = GetMemoryLimit(); - spdlog::info("Memory limit set to {}", utils::GetReadableSize(memory_limit)); - utils::total_memory_tracker.SetHardLimit(memory_limit); - // Main storage and execution engines initialization storage::Config db_config{ .gc = {.type = storage::Config::Gc::Type::PERIODIC, .interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)}, @@ -1102,6 +1101,7 @@ int main(int argc, char **argv) { db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec); } storage::Storage db(db_config); + query::InterpreterContext interpreter_context{ &db, {.query = {.allow_load_csv = FLAGS_allow_load_csv}, .execution_timeout_sec = FLAGS_query_execution_timeout_sec}, @@ -1110,7 +1110,7 @@ int main(int argc, char **argv) { #ifdef MG_ENTERPRISE SessionData session_data{&db, &interpreter_context, &auth, &audit_log}; #else - SessionData session_data{&db, &interpreter_context}; + SessionData session_data{&db, &interpreter_context, &auth}; #endif query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories); @@ -1119,13 +1119,8 @@ int main(int argc, char **argv) { // As the Stream transformations are using modules, they have to be restored after the query modules are loaded. interpreter_context.streams.RestoreStreams(); -#ifdef MG_ENTERPRISE - AuthQueryHandler auth_handler(&auth, std::regex(FLAGS_auth_user_or_role_name_regex)); + AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex); AuthChecker auth_checker{&auth}; -#else - AuthQueryHandler auth_handler; - query::AllowEverythingAuthChecker auth_checker{}; -#endif interpreter_context.auth = &auth_handler; interpreter_context.auth_checker = &auth_checker; diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index f0dad6419..30c04157b 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -188,4 +188,10 @@ class CreateSnapshotInMulticommandTxException final : public QueryException { CreateSnapshotInMulticommandTxException() : QueryException("Snapshot cannot be created in multicommand transactions.") {} }; + +class SettingConfigInMulticommandTxException final : public QueryException { + public: + SettingConfigInMulticommandTxException() + : QueryException("Settings cannot be changed or fetched in multicommand transactions.") {} +}; } // namespace query diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 9bca16e37..7fff2acb4 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2521,4 +2521,25 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class setting-query (query) + ((action "Action" :scope :public) + (setting_name "Expression *" :initval "nullptr" :scope :public) + (setting_value "Expression *" :initval "nullptr" :scope :public)) + + (:public + (lcp:define-enum action + (show-setting show-all-settings set-setting) + (:serialize)) + #>cpp + SettingQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 679f36bfb..e6a78692a 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -80,6 +80,7 @@ class TriggerQuery; class IsolationLevelQuery; class CreateSnapshotQuery; class StreamQuery; +class SettingQuery; using TreeCompositeVisitor = ::utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator, @@ -114,6 +115,6 @@ template <class TResult> class QueryVisitor : public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, - TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, StreamQuery> {}; + TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery> {}; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 572dba5bc..e2b2cb00b 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -575,6 +575,53 @@ antlrcpp::Any CypherMainVisitor::visitCheckStream(MemgraphCypher::CheckStreamCon return stream_query; } +antlrcpp::Any CypherMainVisitor::visitSettingQuery(MemgraphCypher::SettingQueryContext *ctx) { + MG_ASSERT(ctx->children.size() == 1, "SettingQuery should have exactly one child!"); + auto *setting_query = ctx->children[0]->accept(this).as<SettingQuery *>(); + query_ = setting_query; + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitSetSetting(MemgraphCypher::SetSettingContext *ctx) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SET_SETTING; + + if (!ctx->settingName()->literal()->StringLiteral()) { + throw SemanticException("Setting name should be a string literal"); + } + + if (!ctx->settingValue()->literal()->StringLiteral()) { + throw SemanticException("Setting value should be a string literal"); + } + + setting_query->setting_name_ = ctx->settingName()->accept(this); + MG_ASSERT(setting_query->setting_name_); + + setting_query->setting_value_ = ctx->settingValue()->accept(this); + MG_ASSERT(setting_query->setting_value_); + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSetting(MemgraphCypher::ShowSettingContext *ctx) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SHOW_SETTING; + + if (!ctx->settingName()->literal()->StringLiteral()) { + throw SemanticException("Setting name should be a string literal"); + } + + setting_query->setting_name_ = ctx->settingName()->accept(this); + MG_ASSERT(setting_query->setting_name_); + + return setting_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowSettings(MemgraphCypher::ShowSettingsContext * /*ctx*/) { + auto *setting_query = storage_->Create<SettingQuery>(); + setting_query->action_ = SettingQuery::Action::SHOW_ALL_SETTINGS; + return setting_query; +} + antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) { bool distinct = !ctx->ALL(); auto *cypher_union = storage_->Create<CypherUnion>(distinct); diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index f33a48bd4..0ce48d318 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -293,6 +293,26 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override; + /** + * @return SettingQuery* + */ + antlrcpp::Any visitSettingQuery(MemgraphCypher::SettingQueryContext *ctx) override; + + /** + * @return SetSetting* + */ + antlrcpp::Any visitSetSetting(MemgraphCypher::SetSettingContext *ctx) override; + + /** + * @return ShowSetting* + */ + antlrcpp::Any visitShowSetting(MemgraphCypher::ShowSettingContext *ctx) override; + + /** + * @return ShowSettings* + */ + antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override; + /** * @return CypherUnion* */ diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 6f27d3853..3aa1b167e 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -58,6 +58,8 @@ memgraphCypherKeyword : cypherKeyword | ROLES | QUOTE | SESSION + | SETTING + | SETTINGS | SNAPSHOT | START | STATS @@ -98,6 +100,7 @@ query : cypherQuery | isolationLevelQuery | createSnapshotQuery | streamQuery + | settingQuery ; authQuery : createRole @@ -152,6 +155,11 @@ streamQuery : checkStream | showStreams ; +settingQuery : setSetting + | showSetting + | showSettings + ; + loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER ( IGNORE BAD ) ? ( DELIMITER delimiter ) ? @@ -295,3 +303,13 @@ stopAllStreams : STOP ALL STREAMS ; showStreams : SHOW STREAMS ; checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ( TIMEOUT timeout=literal ) ? ; + +settingName : literal ; + +settingValue : literal ; + +setSetting : SET DATABASE SETTING settingName TO settingValue ; + +showSetting : SHOW DATABASE SETTING settingName ; + +showSettings : SHOW DATABASE SETTINGS ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index c2d0e7a0c..578aa2b8e 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -69,6 +69,8 @@ ROLE : R O L E ; ROLES : R O L E S ; QUOTE : Q U O T E ; SESSION : S E S S I O N ; +SETTING : S E T T I N G ; +SETTINGS : S E T T I N G S ; SNAPSHOT : S N A P S H O T ; START : S T A R T ; STATS : S T A T S ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index 5131e3041..6f9a0baef 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -63,6 +63,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis void Visit(CreateSnapshotQuery &create_snapshot_query) override { AddPrivilege(AuthQuery::Privilege::DURABILITY); } + void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); } + bool PreVisit(Create & /*unused*/) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 73bd83346..0a7fa7975 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -131,7 +131,8 @@ const trie::Trie kKeywords = {"union", "all", "batch_size", "consumer_group", "start", "stream", "streams", "transform", - "topics", "check"}; + "topics", "check", + "setting", "settings"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts( diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 2a64609f8..85d4b7133 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -14,6 +14,7 @@ #include "query/dump.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/opencypher/parser.hpp" #include "query/frontend/semantic/required_privileges.hpp" @@ -22,6 +23,7 @@ #include "query/plan/planner.hpp" #include "query/plan/profile.hpp" #include "query/plan/vertex_count_cache.hpp" +#include "query/streams.hpp" #include "query/trigger.hpp" #include "query/typed_value.hpp" #include "storage/v2/property_value.hpp" @@ -30,11 +32,13 @@ #include "utils/event_counter.hpp" #include "utils/exceptions.hpp" #include "utils/flag_validation.hpp" +#include "utils/license.hpp" #include "utils/likely.hpp" #include "utils/logging.hpp" #include "utils/memory.hpp" #include "utils/memory_tracker.hpp" #include "utils/readable_size.hpp" +#include "utils/settings.hpp" #include "utils/string.hpp" #include "utils/tsc.hpp" @@ -233,14 +237,34 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, const Pa Callback callback; + const auto license_check_result = utils::license::global_license_checker.IsValidLicense(utils::global_settings); + + static const std::unordered_set enterprise_only_methods{ + AuthQuery::Action::CREATE_ROLE, AuthQuery::Action::DROP_ROLE, AuthQuery::Action::SET_ROLE, + AuthQuery::Action::CLEAR_ROLE, AuthQuery::Action::GRANT_PRIVILEGE, AuthQuery::Action::DENY_PRIVILEGE, + AuthQuery::Action::REVOKE_PRIVILEGE, AuthQuery::Action::SHOW_PRIVILEGES, AuthQuery::Action::SHOW_USERS_FOR_ROLE, + AuthQuery::Action::SHOW_ROLE_FOR_USER}; + + if (license_check_result.HasError() && enterprise_only_methods.contains(auth_query->action_)) { + throw utils::BasicException( + utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "advanced authentication features")); + } + switch (auth_query->action_) { case AuthQuery::Action::CREATE_USER: - callback.fn = [auth, username, password] { + callback.fn = [auth, username, password, valid_enterprise_license = !license_check_result.HasError()] { MG_ASSERT(password.IsString() || password.IsNull()); if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) : std::nullopt)) { throw QueryRuntimeException("User '{}' already exists.", username); } + + // If the license is not valid we create users with admin access + if (!valid_enterprise_license) { + spdlog::warn("Granting all the privileges to {}.", username); + auth->GrantPrivilege(username, kPrivilegesAll); + } + return std::vector<std::vector<TypedValue>>(); }; return callback; @@ -606,6 +630,87 @@ Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶mete } } +Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶meters, DbAccessor *db_accessor) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + // TODO: MemoryResource for EvaluationContext, it should probably be passed as + // the argument to Callback. + evaluation_context.timestamp = + std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) + .count(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + + Callback callback; + switch (setting_query->action_) { + case SettingQuery::Action::SET_SETTING: { + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + if (!setting_name.IsString()) { + throw utils::BasicException("Setting name should be a string literal"); + } + + const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, &evaluator); + if (!setting_value.IsString()) { + throw utils::BasicException("Setting value should be a string literal"); + } + + callback.fn = [setting_name = std::string{setting_name.ValueString()}, + setting_value = std::string{setting_value.ValueString()}]() mutable { + if (!utils::global_settings.SetValue(setting_name, setting_value)) { + throw utils::BasicException("Unknown setting name '{}'", setting_name); + } + return std::vector<std::vector<TypedValue>>{}; + }; + return callback; + } + case SettingQuery::Action::SHOW_SETTING: { + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + if (!setting_name.IsString()) { + throw utils::BasicException("Setting name should be a string literal"); + } + + callback.header = {"setting_value"}; + callback.fn = [setting_name = std::string{setting_name.ValueString()}] { + auto maybe_value = utils::global_settings.GetValue(setting_name); + if (!maybe_value) { + throw utils::BasicException("Unknown setting name '{}'", setting_name); + } + std::vector<std::vector<TypedValue>> results; + results.reserve(1); + + std::vector<TypedValue> setting_value; + setting_value.reserve(1); + + setting_value.emplace_back(*maybe_value); + results.push_back(std::move(setting_value)); + return results; + }; + return callback; + } + case SettingQuery::Action::SHOW_ALL_SETTINGS: { + callback.header = {"setting_name", "setting_value"}; + callback.fn = [] { + auto all_settings = utils::global_settings.AllSettings(); + std::vector<std::vector<TypedValue>> results; + results.reserve(all_settings.size()); + + for (const auto &[k, v] : all_settings) { + std::vector<TypedValue> setting_info; + setting_info.reserve(2); + + setting_info.emplace_back(k); + setting_info.emplace_back(v); + results.push_back(std::move(setting_info)); + } + + return results; + }; + return callback; + } + } +} + // Struct for lazy pulling from a vector struct PullPlanVector { explicit PullPlanVector(std::vector<std::vector<TypedValue>> values) : values_(std::move(values)) {} @@ -1429,6 +1534,32 @@ PreparedQuery PrepareCreateSnapshotQuery(ParsedQuery parsed_query, bool in_expli RWType::NONE}; } +PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, const bool in_explicit_transaction, DbAccessor *dba) { + if (in_explicit_transaction) { + throw SettingConfigInMulticommandTxException{}; + } + + auto *setting_query = utils::Downcast<SettingQuery>(parsed_query.query); + MG_ASSERT(setting_query); + auto callback = HandleSettingQuery(setting_query, parsed_query.parameters, dba); + + return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), + [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (UNLIKELY(!pull_plan)) { + pull_plan = std::make_shared<PullPlanVector>(callback_fn()); + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; + } + return std::nullopt; + }, + RWType::NONE}; + // False positive report for the std::make_shared above + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + PreparedQuery PrepareInfoQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, storage::Storage *db, utils::MemoryResource *execution_memory) { @@ -1738,29 +1869,29 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, trigger_context_collector_ ? &*trigger_context_collector_ : nullptr); } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, - &*execution_db_accessor_, &query_execution->execution_memory); + &*execution_db_accessor_, &query_execution->execution_memory_with_exception); } else if (utils::Downcast<ProfileQuery>(parsed_query.query)) { - prepared_query = - PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory); + prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception); } else if (utils::Downcast<DumpQuery>(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, &query_execution->execution_memory); } else if (utils::Downcast<IndexQuery>(parsed_query.query)) { prepared_query = PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &query_execution->execution_memory); + interpreter_context_, &query_execution->execution_memory_with_exception); } else if (utils::Downcast<AuthQuery>(parsed_query.query)) { - prepared_query = - PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory); + prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, &*execution_db_accessor_, + &query_execution->execution_memory_with_exception); } else if (utils::Downcast<InfoQuery>(parsed_query.query)) { - prepared_query = - PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, interpreter_context_->db, &query_execution->execution_memory); + prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, + interpreter_context_, interpreter_context_->db, + &query_execution->execution_memory_with_exception); } else if (utils::Downcast<ConstraintQuery>(parsed_query.query)) { prepared_query = PrepareConstraintQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &query_execution->execution_memory); + interpreter_context_, &query_execution->execution_memory_with_exception); } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { prepared_query = PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, &*execution_db_accessor_); @@ -1781,6 +1912,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast<CreateSnapshotQuery>(parsed_query.query)) { prepared_query = PrepareCreateSnapshotQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); + } else if (utils::Downcast<SettingQuery>(parsed_query.query)) { + prepared_query = PrepareSettingQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index ec07b74f9..21e4aefaf 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -22,6 +22,7 @@ #include "utils/event_counter.hpp" #include "utils/logging.hpp" #include "utils/memory.hpp" +#include "utils/settings.hpp" #include "utils/skip_list.hpp" #include "utils/spin_lock.hpp" #include "utils/thread_pool.hpp" @@ -269,8 +270,8 @@ class Interpreter final { private: struct QueryExecution { std::optional<PreparedQuery> prepared_query; - utils::MonotonicBufferResource execution_monotonic_memory{kExecutionMemoryBlockSize}; - utils::ResourceWithOutOfMemoryException execution_memory{&execution_monotonic_memory}; + utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; + utils::ResourceWithOutOfMemoryException execution_memory_with_exception{&execution_memory}; std::map<std::string, TypedValue> summary; @@ -285,7 +286,7 @@ class Interpreter final { // destroy the prepared query which is using that instance // of execution memory. prepared_query.reset(); - execution_monotonic_memory.Release(); + execution_memory.Release(); } }; diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index 07c24c0e7..e55681774 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -1,12 +1,15 @@ set(utils_src_files async_timer.cpp + base64.cpp event_counter.cpp csv_parsing.cpp file.cpp file_locker.cpp + license.cpp memory.cpp memory_tracker.cpp readable_size.cpp + settings.cpp signals.cpp sysinfo/memory.cpp thread.cpp @@ -14,4 +17,4 @@ set(utils_src_files uuid.cpp) add_library(mg-utils STATIC ${utils_src_files}) -target_link_libraries(mg-utils stdc++fs Threads::Threads spdlog fmt gflags uuid rt) +target_link_libraries(mg-utils mg-kvstore mg-slk stdc++fs Threads::Threads spdlog fmt gflags uuid rt) diff --git a/src/utils/base64.cpp b/src/utils/base64.cpp new file mode 100644 index 000000000..0af4249b7 --- /dev/null +++ b/src/utils/base64.cpp @@ -0,0 +1,260 @@ +/* + base64.cpp and base64.h + base64 encoding and decoding with C++. + More information at + https://renenyffenegger.ch/notes/development/Base64/Encoding-and-decoding-base-64-with-cpp + Version: 2.rc.08 (release candidate) + Copyright (C) 2004-2017, 2020, 2021 René Nyffenegger + This source code is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + 1. The origin of this source code must not be misrepresented; you must not + claim that you wrote the original source code. If you use this source code + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original source code. + 3. This notice may not be removed or altered from any source distribution. + René Nyffenegger rene.nyffenegger@adp-gmbh.ch +*/ + +#include "base64.hpp" + +#include <algorithm> +#include <array> +#include <stdexcept> + +namespace utils { +namespace { +// +// Depending on the url parameter in base64_chars, one of +// two sets of base64 characters needs to be chosen. +// They differ in their last two characters. +// +constexpr std::array base64_chars = { + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789" + "+/", + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789" + "-_"}; + +unsigned int pos_of_char(const unsigned char chr) { + // + // Return the position of chr within base64_encode() + // + + if (chr >= 'A' && chr <= 'Z') return chr - 'A'; + if (chr >= 'a' && chr <= 'z') return chr - 'a' + ('Z' - 'A') + 1; + if (chr >= '0' && chr <= '9') return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2; + if (chr == '+' || chr == '-') + return 62; // Be liberal with input and accept both url ('-') and non-url ('+') base 64 characters ( + if (chr == '/' || chr == '_') return 63; // Ditto for '/' and '_' + + // + // 2020-10-23: Throw std::exception rather than const char* + //(Pablo Martin-Gomez, https://github.com/Bouska) + // + throw std::runtime_error("Input is not valid base64-encoded data."); +} + +std::string insert_linebreaks(std::string str, size_t distance) { + // + // Provided by https://github.com/JomaCorpFX, adapted by me. + // + if (!str.length()) { + return ""; + } + + size_t pos = distance; + + while (pos < str.size()) { + str.insert(pos, "\n"); + pos += distance + 1; + } + + return str; +} + +template <typename String, unsigned int line_length> +std::string encode_with_line_breaks(String s) { + return insert_linebreaks(base64_encode(s, false), line_length); +} + +template <typename String> +std::string encode_pem(String s) { + return encode_with_line_breaks<String, 64>(s); +} + +template <typename String> +std::string encode_mime(String s) { + return encode_with_line_breaks<String, 76>(s); +} + +template <typename String> +std::string encode(String s, bool url) { + return base64_encode(reinterpret_cast<const unsigned char *>(s.data()), s.length(), url); +} +} // namespace + +std::string base64_encode(unsigned char const *bytes_to_encode, size_t in_len, bool url) { + size_t len_encoded = (in_len + 2) / 3 * 4; + + unsigned char trailing_char = url ? '.' : '='; + + // + // Choose set of base64 characters. They differ + // for the last two positions, depending on the url + // parameter. + // A bool (as is the parameter url) is guaranteed + // to evaluate to either 0 or 1 in C++ therefore, + // the correct character set is chosen by subscripting + // base64_chars with url. + // + const char *base64_chars_ = base64_chars[url]; + + std::string ret; + ret.reserve(len_encoded); + + unsigned int pos = 0; + + while (pos < in_len) { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(base64_chars_[(bytes_to_encode[pos + 0] & 0xfc) >> 2]); + + if (pos + 1 < in_len) { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(base64_chars_[((bytes_to_encode[pos + 0] & 0x03) << 4) + ((bytes_to_encode[pos + 1] & 0xf0) >> 4)]); + + if (pos + 2 < in_len) { + ret.push_back( + // NOLINTNEXTLINE(hicpp-signed-bitwise) + base64_chars_[((bytes_to_encode[pos + 1] & 0x0f) << 2) + ((bytes_to_encode[pos + 2] & 0xc0) >> 6)]); + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(base64_chars_[bytes_to_encode[pos + 2] & 0x3f]); + } else { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(base64_chars_[(bytes_to_encode[pos + 1] & 0x0f) << 2]); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + ret.push_back(trailing_char); + } + } else { + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(base64_chars_[(bytes_to_encode[pos + 0] & 0x03) << 4]); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + ret.push_back(trailing_char); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + ret.push_back(trailing_char); + } + + pos += 3; + } + + return ret; +} + +template <typename String> +static std::string decode(String encoded_string, bool remove_linebreaks) { + // + // decode(…) is templated so that it can be used with String = const std::string& + // or std::string_view (requires at least C++17) + // + + if (encoded_string.empty()) return std::string(); + + if (remove_linebreaks) { + std::string copy(encoded_string); + + copy.erase(std::remove(copy.begin(), copy.end(), '\n'), copy.end()); + + return base64_decode(copy, false); + } + + size_t length_of_string = encoded_string.length(); + size_t pos = 0; + + // + // The approximate length (bytes) of the decoded string might be one or + // two bytes smaller, depending on the amount of trailing equal signs + // in the encoded string. This approximation is needed to reserve + // enough space in the string to be returned. + // + size_t approx_length_of_decoded_string = length_of_string / 4 * 3; + std::string ret; + ret.reserve(approx_length_of_decoded_string); + + while (pos < length_of_string) { + // + // Iterate over encoded input string in chunks. The size of all + // chunks except the last one is 4 bytes. + // + // The last chunk might be padded with equal signs or dots + // in order to make it 4 bytes in size as well, but this + // is not required as per RFC 2045. + // + // All chunks except the last one produce three output bytes. + // + // The last chunk produces at least one and up to three bytes. + // + + size_t pos_of_char_1 = pos_of_char(encoded_string[pos + 1]); + + // + // Emit the first output byte that is produced in each chunk: + // + // NOLINTNEXTLINE(hicpp-signed-bitwise) + ret.push_back(static_cast<std::string::value_type>(((pos_of_char(encoded_string[pos + 0])) << 2) + + ((pos_of_char_1 & 0x30) >> 4))); // NOLINT(hicpp-signed-bitwise) + + if ((pos + 2 < + length_of_string) && // Check for data that is not padded with equal signs (which is allowed by RFC 2045) + encoded_string[pos + 2] != '=' && + encoded_string[pos + 2] != '.' // accept URL-safe base 64 strings, too, so check for '.' also. + ) { + // + // Emit a chunk's second byte (which might not be produced in the last chunk). + // + unsigned int pos_of_char_2 = pos_of_char(encoded_string[pos + 2]); + ret.push_back( + // NOLINTNEXTLINE(hicpp-signed-bitwise) + static_cast<std::string::value_type>(((pos_of_char_1 & 0x0f) << 4) + ((pos_of_char_2 & 0x3c) >> 2))); + + if ((pos + 3 < length_of_string) && encoded_string[pos + 3] != '=' && encoded_string[pos + 3] != '.') { + // + // Emit a chunk's third byte (which might not be produced in the last chunk). + // + ret.push_back( + // NOLINTNEXTLINE(hicpp-signed-bitwise) + static_cast<std::string::value_type>(((pos_of_char_2 & 0x03) << 6) + pos_of_char(encoded_string[pos + 3]))); + } + } + + pos += 4; + } + + return ret; +} + +std::string base64_decode(std::string const &s, bool remove_linebreaks) { return decode(s, remove_linebreaks); } + +std::string base64_encode(std::string const &s, bool url) { return encode(s, url); } + +std::string base64_encode_pem(std::string const &s) { return encode_pem(s); } + +std::string base64_encode_mime(std::string const &s) { return encode_mime(s); } + +std::string base64_encode(std::string_view s, bool url) { return encode(s, url); } + +std::string base64_encode_pem(std::string_view s) { return encode_pem(s); } + +std::string base64_encode_mime(std::string_view s) { return encode_mime(s); } + +std::string base64_decode(std::string_view s, bool remove_linebreaks) { return decode(s, remove_linebreaks); } + +} // namespace utils diff --git a/src/utils/base64.hpp b/src/utils/base64.hpp new file mode 100644 index 000000000..3af42a10f --- /dev/null +++ b/src/utils/base64.hpp @@ -0,0 +1,35 @@ +// +// base64 encoding and decoding with C++. +// Version: 2.rc.08 (release candidate) +// + +#ifndef BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A +#define BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A + +#include <string> +#include <string_view> + +namespace utils { +std::string base64_encode(std::string const &s, bool url = false); +std::string base64_encode_pem(std::string const &s); +std::string base64_encode_mime(std::string const &s); + +std::string base64_decode(std::string const &s, bool remove_linebreaks = false); +std::string base64_encode(unsigned char const *, size_t len, bool url = false); + +#if __cplusplus >= 201703L +// +// Interface with std::string_view rather than const std::string& +// Requires C++17 +// Provided by Yannic Bonenberger (https://github.com/Yannic) +// +std::string base64_encode(std::string_view s, bool url = false); +std::string base64_encode_pem(std::string_view s); +std::string base64_encode_mime(std::string_view s); + +std::string base64_decode(std::string_view s, bool remove_linebreaks = false); +#endif // __cplusplus >= 201703L + +} // namespace utils + +#endif /* BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A */ diff --git a/src/utils/license.cpp b/src/utils/license.cpp new file mode 100644 index 000000000..0b4ca3598 --- /dev/null +++ b/src/utils/license.cpp @@ -0,0 +1,284 @@ +#include "utils/license.hpp" + +#include <atomic> +#include <charconv> +#include <chrono> +#include <functional> +#include <optional> +#include <unordered_map> + +#include "slk/serialization.hpp" +#include "utils/base64.hpp" +#include "utils/exceptions.hpp" +#include "utils/logging.hpp" +#include "utils/memory_tracker.hpp" +#include "utils/settings.hpp" +#include "utils/spin_lock.hpp" +#include "utils/synchronized.hpp" + +namespace utils::license { + +namespace { +constexpr std::string_view license_key_prefix = "mglk-"; + +std::optional<License> GetLicense(const std::string &license_key) { + if (license_key.empty()) { + return std::nullopt; + } + + static utils::Synchronized<std::pair<std::string, License>, utils::SpinLock> cached_license; + { + auto cache_locked = cached_license.Lock(); + const auto &[cached_key, license] = *cache_locked; + if (cached_key == license_key) { + return license; + } + } + auto license = Decode(license_key); + if (license) { + auto cache_locked = cached_license.Lock(); + *cache_locked = std::make_pair(license_key, *license); + } + return license; +} + +LicenseCheckResult IsValidLicenseInternal(const License &license, const std::string &organization_name) { + if (license.organization_name != organization_name) { + return LicenseCheckError::INVALID_ORGANIZATION_NAME; + } + + const auto now = + std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count(); + + if (license.valid_until != 0 && now > license.valid_until) { + return LicenseCheckError::EXPIRED_LICENSE; + } + + return {}; +} +} // namespace + +void RegisterLicenseSettings(LicenseChecker &license_checker, utils::Settings &settings) { + settings.RegisterSetting(std::string{kEnterpriseLicenseSettingKey}, "", + [&] { license_checker.RevalidateLicense(settings); }); + settings.RegisterSetting(std::string{kOrganizationNameSettingKey}, "", + [&] { license_checker.RevalidateLicense(settings); }); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +LicenseChecker global_license_checker; + +LicenseChecker::~LicenseChecker() { scheduler_.Stop(); } + +std::pair<std::string, std::string> LicenseChecker::GetLicenseInfo(const utils::Settings &settings) const { + if (license_info_override_) { + spdlog::warn("Ignoring license info stored in the settings because a different source was specified."); + return *license_info_override_; + } + + auto license_key = settings.GetValue(std::string{kEnterpriseLicenseSettingKey}); + MG_ASSERT(license_key, "License key is missing from the settings"); + + auto organization_name = settings.GetValue(std::string{kOrganizationNameSettingKey}); + MG_ASSERT(organization_name, "Organization name is missing from the settings"); + return std::make_pair(std::move(*license_key), std::move(*organization_name)); +} + +void LicenseChecker::RevalidateLicense(const utils::Settings &settings) { + const auto license_info = GetLicenseInfo(settings); + RevalidateLicense(license_info.first, license_info.second); +} + +void LicenseChecker::RevalidateLicense(const std::string &license_key, const std::string &organization_name) { + static utils::Synchronized<std::optional<int64_t>, utils::SpinLock> previous_memory_limit; + const auto set_memory_limit = [](const auto memory_limit) { + auto locked_previous_memory_limit_ptr = previous_memory_limit.Lock(); + auto &locked_previous_memory_limit = *locked_previous_memory_limit_ptr; + if (!locked_previous_memory_limit || *locked_previous_memory_limit != memory_limit) { + utils::total_memory_tracker.SetHardLimit(memory_limit); + locked_previous_memory_limit = memory_limit; + } + }; + + if (enterprise_enabled_) [[unlikely]] { + is_valid_.store(true, std::memory_order_relaxed); + set_memory_limit(0); + return; + } + + struct PreviousLicenseInfo { + PreviousLicenseInfo(std::string license_key, std::string organization_name) + : license_key(std::move(license_key)), organization_name(std::move(organization_name)) {} + + std::string license_key; + std::string organization_name; + bool is_valid{false}; + }; + + static utils::Synchronized<std::optional<PreviousLicenseInfo>, utils::SpinLock> previous_license_info; + + auto locked_previous_license_info_ptr = previous_license_info.Lock(); + auto &locked_previous_license_info = *locked_previous_license_info_ptr; + const bool same_license_info = locked_previous_license_info && + locked_previous_license_info->license_key == license_key && + locked_previous_license_info->organization_name == organization_name; + // If we already know it's invalid skip the check + if (same_license_info && !locked_previous_license_info->is_valid) { + return; + } + + locked_previous_license_info.emplace(license_key, organization_name); + + const auto maybe_license = GetLicense(locked_previous_license_info->license_key); + if (!maybe_license) { + spdlog::warn(LicenseCheckErrorToString(LicenseCheckError::INVALID_LICENSE_KEY_STRING, "Enterprise features")); + is_valid_.store(false, std::memory_order_relaxed); + locked_previous_license_info->is_valid = false; + set_memory_limit(0); + return; + } + + const auto license_check_result = + IsValidLicenseInternal(*maybe_license, locked_previous_license_info->organization_name); + + if (license_check_result.HasError()) { + spdlog::warn(LicenseCheckErrorToString(license_check_result.GetError(), "Enterprise features")); + is_valid_.store(false, std::memory_order_relaxed); + locked_previous_license_info->is_valid = false; + set_memory_limit(0); + return; + } + + if (!same_license_info) { + spdlog::info("All Enterprise features are active."); + is_valid_.store(true, std::memory_order_relaxed); + locked_previous_license_info->is_valid = true; + set_memory_limit(maybe_license->memory_limit); + } +} + +void LicenseChecker::EnableTesting() { + enterprise_enabled_ = true; + is_valid_.store(true, std::memory_order_relaxed); + spdlog::info("All Enterprise features are activated for testing."); +} + +void LicenseChecker::CheckEnvLicense() { + // NOLINTNEXTLINE(concurrency-mt-unsafe) + const char *license_key = std::getenv("MEMGRAPH_ENTERPRISE_LICENSE"); + if (!license_key) { + return; + } + + // NOLINTNEXTLINE(concurrency-mt-unsafe) + const char *organization_name = std::getenv("MEMGRAPH_ORGANIZATION_NAME"); + if (!organization_name) { + return; + } + + spdlog::warn("Using license info from environment variables"); + license_info_override_.emplace(license_key, organization_name); + RevalidateLicense(license_key, organization_name); +} + +void LicenseChecker::SetLicenseInfoOverride(std::string license_key, std::string organization_name) { + spdlog::warn("Using license info overrides"); + license_info_override_.emplace(std::move(license_key), std::move(organization_name)); + RevalidateLicense(license_info_override_->first, license_info_override_->second); +} + +std::string LicenseCheckErrorToString(LicenseCheckError error, const std::string_view feature) { + switch (error) { + case LicenseCheckError::INVALID_LICENSE_KEY_STRING: + return fmt::format( + "Invalid license key string. To use {} please set it to a valid string using " + "the following query:\n" + "SET DATABASE SETTING \"enterprise.license\" TO \"your-license-key\"", + feature); + case LicenseCheckError::INVALID_ORGANIZATION_NAME: + return fmt::format( + "The organization name contained in the license key is not the same as the one defined in the settings. To " + "use {} please set the organization name to a valid string using the following query:\n" + "SET DATABASE SETTING \"organization.name\" TO \"your-organization-name\"", + feature); + case LicenseCheckError::EXPIRED_LICENSE: + return fmt::format( + "Your license key has expired. To use {} please renew your license and set the updated license key using the " + "following query:\n" + "SET DATABASE SETTING \"enterprise.license\" TO \"your-license-key\"", + feature); + } +} + +LicenseCheckResult LicenseChecker::IsValidLicense(const utils::Settings &settings) const { + if (enterprise_enabled_) [[unlikely]] { + return {}; + } + + const auto license_info = GetLicenseInfo(settings); + + const auto maybe_license = GetLicense(license_info.first); + if (!maybe_license) { + return LicenseCheckError::INVALID_LICENSE_KEY_STRING; + } + + return IsValidLicenseInternal(*maybe_license, license_info.second); +} + +void LicenseChecker::StartBackgroundLicenseChecker(const utils::Settings &settings) { + RevalidateLicense(settings); + scheduler_.Run("licensechecker", std::chrono::minutes{5}, [&, this] { RevalidateLicense(settings); }); +} + +bool LicenseChecker::IsValidLicenseFast() const { return is_valid_.load(std::memory_order_relaxed); } + +std::string Encode(const License &license) { + std::vector<uint8_t> buffer; + slk::Builder builder([&buffer](const uint8_t *data, size_t size, bool /*have_more*/) { + for (size_t i = 0; i < size; ++i) { + buffer.push_back(data[i]); + } + }); + + slk::Save(license.organization_name, &builder); + slk::Save(license.valid_until, &builder); + slk::Save(license.memory_limit, &builder); + builder.Finalize(); + + return std::string{license_key_prefix} + base64_encode(buffer.data(), buffer.size()); +} + +std::optional<License> Decode(std::string_view license_key) { + if (!license_key.starts_with(license_key_prefix)) { + return std::nullopt; + } + + license_key.remove_prefix(license_key_prefix.size()); + + const auto decoded = std::invoke([license_key]() -> std::optional<std::string> { + try { + return base64_decode(license_key); + } catch (const std::runtime_error & /*exception*/) { + return std::nullopt; + } + }); + + if (!decoded) { + return std::nullopt; + } + + try { + slk::Reader reader(std::bit_cast<uint8_t *>(decoded->c_str()), decoded->size()); + std::string organization_name; + slk::Load(&organization_name, &reader); + int64_t valid_until{0}; + slk::Load(&valid_until, &reader); + int64_t memory_limit{0}; + slk::Load(&memory_limit, &reader); + return License{.organization_name = organization_name, .valid_until = valid_until, .memory_limit = memory_limit}; + } catch (const slk::SlkReaderException &e) { + return std::nullopt; + } +} + +} // namespace utils::license diff --git a/src/utils/license.hpp b/src/utils/license.hpp new file mode 100644 index 000000000..6b087efa9 --- /dev/null +++ b/src/utils/license.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include <cstdint> +#include <string> + +#include "utils/result.hpp" +#include "utils/scheduler.hpp" +#include "utils/settings.hpp" + +namespace utils::license { + +struct License { + std::string organization_name; + int64_t valid_until; + int64_t memory_limit; + + bool operator==(const License &) const = default; +}; + +constexpr std::string_view kEnterpriseLicenseSettingKey = "enterprise.license"; +constexpr std::string_view kOrganizationNameSettingKey = "organization.name"; + +enum class LicenseCheckError : uint8_t { INVALID_LICENSE_KEY_STRING, INVALID_ORGANIZATION_NAME, EXPIRED_LICENSE }; + +std::string LicenseCheckErrorToString(LicenseCheckError error, std::string_view feature); + +using LicenseCheckResult = utils::BasicResult<LicenseCheckError, void>; + +struct LicenseChecker { + public: + explicit LicenseChecker() = default; + ~LicenseChecker(); + + LicenseChecker(const LicenseChecker &) = delete; + LicenseChecker operator=(const LicenseChecker &) = delete; + LicenseChecker(LicenseChecker &&) = delete; + LicenseChecker operator=(LicenseChecker &&) = delete; + + void CheckEnvLicense(); + void SetLicenseInfoOverride(std::string license_key, std::string organization_name); + void EnableTesting(); + LicenseCheckResult IsValidLicense(const utils::Settings &settings) const; + bool IsValidLicenseFast() const; + void StartBackgroundLicenseChecker(const utils::Settings &settings); + + private: + std::pair<std::string, std::string> GetLicenseInfo(const utils::Settings &settings) const; + void RevalidateLicense(const utils::Settings &settings); + void RevalidateLicense(const std::string &license_key, const std::string &organization_name); + + std::optional<std::pair<std::string, std::string>> license_info_override_; + bool enterprise_enabled_{false}; + std::atomic<bool> is_valid_{false}; + utils::Scheduler scheduler_; + + friend void RegisterLicenseSettings(LicenseChecker &license_checker, utils::Settings &settings); +}; + +void RegisterLicenseSettings(LicenseChecker &license_checker, utils::Settings &settings); + +std::optional<License> Decode(std::string_view license_key); +std::string Encode(const License &license); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +extern LicenseChecker global_license_checker; +} // namespace utils::license diff --git a/src/utils/memory_tracker.cpp b/src/utils/memory_tracker.cpp index d1ccceade..e507ac4f1 100644 --- a/src/utils/memory_tracker.cpp +++ b/src/utils/memory_tracker.cpp @@ -52,7 +52,25 @@ void MemoryTracker::UpdatePeak(const int64_t will_be) { } } -void MemoryTracker::SetHardLimit(const int64_t limit) { hard_limit_.store(limit, std::memory_order_relaxed); } +void MemoryTracker::SetHardLimit(const int64_t limit) { + const int64_t next_limit = std::invoke([this, limit] { + if (maximum_hard_limit_ == 0) { + return limit; + } + return limit == 0 ? maximum_hard_limit_ : std::min(maximum_hard_limit_, limit); + }); + + if (next_limit <= 0) { + spdlog::warn("Invalid memory limit."); + return; + } + + const auto previous_limit = hard_limit_.exchange(next_limit, std::memory_order_relaxed); + if (previous_limit != next_limit) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + spdlog::info("Memory limit set to {}", utils::GetReadableSize(next_limit)); + } +} void MemoryTracker::TryRaiseHardLimit(const int64_t limit) { int64_t old_limit = hard_limit_.load(std::memory_order_relaxed); @@ -60,6 +78,14 @@ void MemoryTracker::TryRaiseHardLimit(const int64_t limit) { ; } +void MemoryTracker::SetMaximumHardLimit(const int64_t limit) { + if (maximum_hard_limit_ < 0) { + spdlog::warn("Invalid maximum hard limit."); + return; + } + maximum_hard_limit_ = limit; +} + void MemoryTracker::Alloc(const int64_t size) { MG_ASSERT(size >= 0, "Negative size passed to the MemoryTracker."); diff --git a/src/utils/memory_tracker.hpp b/src/utils/memory_tracker.hpp index 28019664d..536c2177f 100644 --- a/src/utils/memory_tracker.hpp +++ b/src/utils/memory_tracker.hpp @@ -16,6 +16,8 @@ class MemoryTracker final { std::atomic<int64_t> amount_{0}; std::atomic<int64_t> peak_{0}; std::atomic<int64_t> hard_limit_{0}; + // Maximum possible value of a hard limit. If it's set to 0, no upper bound on the hard limit is set. + int64_t maximum_hard_limit_{0}; void UpdatePeak(int64_t will_be); @@ -43,6 +45,7 @@ class MemoryTracker final { void SetHardLimit(int64_t limit); void TryRaiseHardLimit(int64_t limit); + void SetMaximumHardLimit(int64_t limit); // By creating an object of this class, every allocation in its scope that goes over // the set hard limit produces an OutOfMemoryException. diff --git a/src/utils/settings.cpp b/src/utils/settings.cpp new file mode 100644 index 000000000..5fb21ba45 --- /dev/null +++ b/src/utils/settings.cpp @@ -0,0 +1,79 @@ +#include <fmt/format.h> + +#include "utils/logging.hpp" +#include "utils/settings.hpp" + +namespace utils { +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +Settings global_settings; + +void Settings::Initialize(std::filesystem::path storage_path) { + std::lock_guard settings_guard{settings_lock_}; + storage_.emplace(std::move(storage_path)); +} + +void Settings::Finalize() { + std::lock_guard settings_guard{settings_lock_}; + storage_.reset(); + on_change_callbacks_.clear(); +} + +void Settings::RegisterSetting(std::string name, const std::string &default_value, OnChangeCallback callback) { + std::lock_guard settings_guard{settings_lock_}; + MG_ASSERT(storage_); + + if (const auto maybe_value = storage_->Get(name); maybe_value) { + SPDLOG_INFO("The setting with name {} already exists!", name); + } else { + MG_ASSERT(storage_->Put(name, default_value), "Failed to register a setting"); + } + + const auto [it, inserted] = on_change_callbacks_.emplace(std::move(name), callback); + MG_ASSERT(inserted, "Settings storage is out of sync"); +} + +std::optional<std::string> Settings::GetValue(const std::string &setting_name) const { + std::shared_lock settings_guard{settings_lock_}; + MG_ASSERT(storage_); + auto maybe_value = storage_->Get(setting_name); + return maybe_value; +} + +bool Settings::SetValue(const std::string &setting_name, const std::string &new_value) { + const auto settings_change_callback = std::invoke([&, this]() -> std::optional<OnChangeCallback> { + std::lock_guard settings_guard{settings_lock_}; + MG_ASSERT(storage_); + + if (const auto maybe_value = storage_->Get(setting_name); !maybe_value) { + return std::nullopt; + } + + MG_ASSERT(storage_->Put(setting_name, new_value), "Failed to modify the setting"); + + const auto it = on_change_callbacks_.find(setting_name); + MG_ASSERT(it != on_change_callbacks_.end(), "Settings storage is out of sync"); + return it->second; + }); + + if (!settings_change_callback) { + return false; + } + + (*settings_change_callback)(); + return true; +} + +std::vector<std::pair<std::string, std::string>> Settings::AllSettings() const { + std::shared_lock settings_guard{settings_lock_}; + + MG_ASSERT(storage_); + + std::vector<std::pair<std::string, std::string>> settings; + settings.reserve(storage_->Size()); + for (const auto &[k, v] : *storage_) { + settings.emplace_back(k, v); + } + + return settings; +} +} // namespace utils diff --git a/src/utils/settings.hpp b/src/utils/settings.hpp new file mode 100644 index 000000000..02afd772c --- /dev/null +++ b/src/utils/settings.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include <functional> +#include <optional> +#include <unordered_map> + +#include "kvstore/kvstore.hpp" +#include "utils/rw_lock.hpp" +#include "utils/synchronized.hpp" + +namespace utils { +struct Settings { + using OnChangeCallback = std::function<void()>; + + void Initialize(std::filesystem::path storage_path); + // RocksDB depends on statically allocated objects so we need to delete it before the static destruction kicks in + void Finalize(); + + void RegisterSetting(std::string name, const std::string &default_value, OnChangeCallback callback); + std::optional<std::string> GetValue(const std::string &setting_name) const; + bool SetValue(const std::string &setting_name, const std::string &new_value); + std::vector<std::pair<std::string, std::string>> AllSettings() const; + + private: + mutable utils::RWLock settings_lock_{RWLock::Priority::WRITE}; + std::unordered_map<std::string, OnChangeCallback> on_change_callbacks_; + std::optional<kvstore::KVStore> storage_; +}; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +extern Settings global_settings; +} // namespace utils diff --git a/tests/e2e/streams/streams_owner_tests.py b/tests/e2e/streams/streams_owner_tests.py index f48f7439c..1281b8fce 100644 --- a/tests/e2e/streams/streams_owner_tests.py +++ b/tests/e2e/streams/streams_owner_tests.py @@ -4,7 +4,6 @@ import time import mgclient import common - def get_cursor_with_user(username): connection = common.connect(username=username, password="") return connection.cursor() diff --git a/tests/integration/audit/runner.py b/tests/integration/audit/runner.py index 3c39f1be2..d457acd5c 100755 --- a/tests/integration/audit/runner.py +++ b/tests/integration/audit/runner.py @@ -55,10 +55,14 @@ def wait_for_server(port, delay=0.1): def execute_test(memgraph_binary, tester_binary): storage_directory = tempfile.TemporaryDirectory() - memgraph_args = [memgraph_binary, - "--storage-properties-on-edges", - "--data-directory", storage_directory.name, - "--audit-enabled"] + memgraph_args = [ + memgraph_binary, + "--storage-properties-on-edges", + "--data-directory", + storage_directory.name, + "--audit-enabled", + "--log-file=memgraph.log", + "--log-level=TRACE"] # Start the memgraph binary memgraph = subprocess.Popen(list(map(str, memgraph_args))) @@ -73,17 +77,21 @@ def execute_test(memgraph_binary, tester_binary): memgraph.terminate() assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" + def execute_queries(queries): + for query, params in queries: + print(query, params) + args = [tester_binary, "--query", query, + "--params-json", json.dumps(params)] + subprocess.run(args).check_returncode() + # Execute all queries print("\033[1;36m~~ Starting query execution ~~\033[0m") - for query, params in QUERIES: - print(query, params) - args = [tester_binary, "--query", query, - "--params-json", json.dumps(params)] - subprocess.run(args).check_returncode() + execute_queries(QUERIES) print("\033[1;36m~~ Finished query execution ~~\033[0m\n") # Shutdown the memgraph binary memgraph.terminate() + assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!" # Verify the written log @@ -99,6 +107,7 @@ def execute_test(memgraph_binary, tester_binary): params = json.loads(params) queries.append((query, params)) print(query, params) + assert queries == QUERIES, "Logged queries don't match " \ "executed queries!" print("\033[1;36m~~ Finished log verification ~~\033[0m\n") diff --git a/tests/integration/auth/runner.py b/tests/integration/auth/runner.py index 1b2d9e2d4..cb0a8cde4 100755 --- a/tests/integration/auth/runner.py +++ b/tests/integration/auth/runner.py @@ -218,7 +218,7 @@ def execute_test(memgraph_binary, tester_binary, checker_binary): execute_admin_queries([ "CREATE USER ADmin IDENTIFIED BY 'admin'", "GRANT ALL PRIVILEGES TO admIN", - "CREATE USER usEr IDENTIFIED BY 'user'" + "CREATE USER usEr IDENTIFIED BY 'user'", ]) # Find all existing permissions diff --git a/tests/integration/ldap/runner.py b/tests/integration/ldap/runner.py index f5f8b4c5b..305bcf57e 100755 --- a/tests/integration/ldap/runner.py +++ b/tests/integration/ldap/runner.py @@ -100,7 +100,7 @@ class Memgraph: kwargs.pop("module_executable", self._auth_module)] for key, value in kwargs.items(): ldap_key = "--auth-module-" + key.replace("_", "-") - if type(value) == bool: + if isinstance(value, bool): args.append(ldap_key + "=" + str(value).lower()) else: args.append(ldap_key) @@ -124,6 +124,7 @@ class Memgraph: def initialize_test(memgraph, tester_binary, **kwargs): memgraph.start(module_executable="") + execute_tester(tester_binary, ["CREATE USER root", "GRANT ALL PRIVILEGES TO root"]) check_login = kwargs.pop("check_login", True) diff --git a/tests/manual/single_query.cpp b/tests/manual/single_query.cpp index a181db737..ec7dd594d 100644 --- a/tests/manual/single_query.cpp +++ b/tests/manual/single_query.cpp @@ -3,6 +3,7 @@ #include "query/interpreter.hpp" #include "storage/v2/isolation_level.hpp" #include "storage/v2/storage.hpp" +#include "utils/license.hpp" #include "utils/on_scope_exit.hpp" int main(int argc, char *argv[]) { @@ -17,6 +18,8 @@ int main(int argc, char *argv[]) { storage::Storage db; auto data_directory = std::filesystem::temp_directory_path() / "single_query_test"; utils::OnScopeExit([&data_directory] { std::filesystem::remove_all(data_directory); }); + + utils::license::global_license_checker.EnableTesting(); query::InterpreterContext interpreter_context{&db, query::InterpreterConfig{}, data_directory, "non existing bootstrap servers"}; query::Interpreter interpreter{&interpreter_context}; diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 84158c18a..610b43ee2 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -260,6 +260,12 @@ target_link_libraries(${test_prefix}utils_csv_parsing mg-utils fmt) add_unit_test(utils_async_timer.cpp) target_link_libraries(${test_prefix}utils_async_timer mg-utils) +add_unit_test(utils_license.cpp) +target_link_libraries(${test_prefix}utils_license mg-utils) + +add_unit_test(utils_settings.cpp) +target_link_libraries(${test_prefix}utils_settings mg-utils) + # Test mg-storage-v2 add_unit_test(commit_log_v2.cpp) diff --git a/tests/unit/auth.cpp b/tests/unit/auth.cpp index 201248bf4..ce92ee04d 100644 --- a/tests/unit/auth.cpp +++ b/tests/unit/auth.cpp @@ -8,6 +8,7 @@ #include "auth/crypto.hpp" #include "utils/cast.hpp" #include "utils/file.hpp" +#include "utils/license.hpp" using namespace auth; namespace fs = std::filesystem; @@ -21,13 +22,15 @@ class AuthWithStorage : public ::testing::Test { utils::EnsureDir(test_folder_); FLAGS_auth_password_permit_null = true; FLAGS_auth_password_strength_regex = ".+"; + + utils::license::global_license_checker.EnableTesting(); } virtual void TearDown() { fs::remove_all(test_folder_); } - fs::path test_folder_{fs::temp_directory_path() / ("unit_auth_test_" + std::to_string(static_cast<int>(getpid())))}; + fs::path test_folder_{fs::temp_directory_path() / "MG_tests_unit_auth"}; - Auth auth{test_folder_}; + Auth auth{test_folder_ / ("unit_auth_test_" + std::to_string(static_cast<int>(getpid())))}; }; TEST_F(AuthWithStorage, AddRole) { diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index a92c0daea..06282eb8d 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -3432,4 +3432,34 @@ TEST_P(CypherMainVisitorTest, CheckStream) { StreamQuery::Action::CHECK_STREAM, "checkedStream", TypedValue(30), TypedValue(444)); } +TEST_P(CypherMainVisitorTest, SettingQuery) { + auto &ast_generator = *GetParam(); + + TestInvalidQuery("SHOW DB SETTINGS", ast_generator); + TestInvalidQuery("SHOW SETTINGS", ast_generator); + TestInvalidQuery("SHOW DATABASE SETTING", ast_generator); + TestInvalidQuery("SHOW DB SETTING 'setting'", ast_generator); + TestInvalidQuery("SHOW SETTING 'setting'", ast_generator); + TestInvalidQuery<SemanticException>("SHOW DATABASE SETTING 1", ast_generator); + TestInvalidQuery("SET SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery("SET DB SETTING 'setting' TO 'value'", ast_generator); + TestInvalidQuery<SemanticException>("SET DATABASE SETTING 1 TO 'value'", ast_generator); + TestInvalidQuery<SemanticException>("SET DATABASE SETTING 'setting' TO 2", ast_generator); + + const auto validate_setting_query = [&](const auto &query, const auto action, + const std::optional<TypedValue> &expected_setting_name, + const std::optional<TypedValue> &expected_setting_value) { + auto *parsed_query = dynamic_cast<SettingQuery *>(ast_generator.ParseQuery(query)); + EXPECT_EQ(parsed_query->action_, action) << query; + EXPECT_NO_FATAL_FAILURE(CheckOptionalExpression(ast_generator, parsed_query->setting_name_, expected_setting_name)); + EXPECT_NO_FATAL_FAILURE( + CheckOptionalExpression(ast_generator, parsed_query->setting_value_, expected_setting_value)); + }; + + validate_setting_query("SHOW DATABASE SETTINGS", SettingQuery::Action::SHOW_ALL_SETTINGS, std::nullopt, std::nullopt); + validate_setting_query("SHOW DATABASE SETTING 'setting'", SettingQuery::Action::SHOW_SETTING, TypedValue{"setting"}, + std::nullopt); + validate_setting_query("SET DATABASE SETTING 'setting' TO 'value'", SettingQuery::Action::SET_SETTING, + TypedValue{"setting"}, TypedValue{"value"}); +} } // namespace diff --git a/tests/unit/main.cpp b/tests/unit/main.cpp index cb05a4e3b..22e690290 100644 --- a/tests/unit/main.cpp +++ b/tests/unit/main.cpp @@ -4,6 +4,6 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); logging::RedirectToStderr(); - spdlog::set_level(spdlog::level::warn); + spdlog::set_level(spdlog::level::trace); return RUN_ALL_TESTS(); } diff --git a/tests/unit/query_required_privileges.cpp b/tests/unit/query_required_privileges.cpp index 8ee4a3122..dfc0c4cce 100644 --- a/tests/unit/query_required_privileges.cpp +++ b/tests/unit/query_required_privileges.cpp @@ -170,3 +170,8 @@ TEST_F(TestPrivilegeExtractor, StreamQuery) { auto *query = storage.Create<StreamQuery>(); EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STREAM)); } + +TEST_F(TestPrivilegeExtractor, SettingQuery) { + auto *query = storage.Create<SettingQuery>(); + EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::CONFIG)); +} diff --git a/tests/unit/utils_license.cpp b/tests/unit/utils_license.cpp new file mode 100644 index 000000000..3ee0b109b --- /dev/null +++ b/tests/unit/utils_license.cpp @@ -0,0 +1,127 @@ +#include <gtest/gtest.h> + +#include "utils/license.hpp" +#include "utils/settings.hpp" + +class LicenseTest : public ::testing::Test { + public: + void SetUp() override { + settings.emplace(); + settings->Initialize(settings_directory); + + license_checker.emplace(); + utils::license::RegisterLicenseSettings(*license_checker, *settings); + + license_checker->StartBackgroundLicenseChecker(*settings); + } + + void TearDown() override { std::filesystem::remove_all(test_directory); } + + protected: + const std::filesystem::path test_directory{"MG_tests_unit_utils_license"}; + const std::filesystem::path settings_directory{test_directory / "settings"}; + + void CheckLicenseValidity(const bool expected_valid) { + ASSERT_EQ(!license_checker->IsValidLicense(*settings).HasError(), expected_valid); + ASSERT_EQ(license_checker->IsValidLicenseFast(), expected_valid); + } + + std::optional<utils::Settings> settings; + std::optional<utils::license::LicenseChecker> license_checker; +}; + +TEST_F(LicenseTest, EncodeDecode) { + const std::array licenses = { + utils::license::License{"Organization", 1, 2}, + utils::license::License{"", -1, 0}, + utils::license::License{"Some very long name for the organization Ltd", -999, -9999}, + }; + + for (const auto &license : licenses) { + const auto result = utils::license::Encode(license); + auto maybe_license = utils::license::Decode(result); + ASSERT_TRUE(maybe_license); + ASSERT_EQ(*maybe_license, license); + } +} + +TEST_F(LicenseTest, TestingFlag) { + CheckLicenseValidity(false); + + license_checker->EnableTesting(); + CheckLicenseValidity(true); + + SCOPED_TRACE("EnableTesting shouldn't be affected by settings change"); + settings->SetValue("enterprise.license", ""); + CheckLicenseValidity(true); +} + +TEST_F(LicenseTest, LicenseOrganizationName) { + const std::string organization_name{"Memgraph"}; + utils::license::License license{.organization_name = organization_name, .valid_until = 0, .memory_limit = 0}; + + settings->SetValue("enterprise.license", utils::license::Encode(license)); + settings->SetValue("organization.name", organization_name); + CheckLicenseValidity(true); + + settings->SetValue("organization.name", fmt::format("{}modified", organization_name)); + CheckLicenseValidity(false); + + settings->SetValue("organization.name", organization_name); + CheckLicenseValidity(true); +} + +TEST_F(LicenseTest, Expiration) { + const std::string organization_name{"Memgraph"}; + + { + const auto now = + std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()); + const auto delta = std::chrono::seconds(1); + const auto valid_until = now + delta; + utils::license::License license{ + .organization_name = organization_name, .valid_until = valid_until.count(), .memory_limit = 0}; + + settings->SetValue("enterprise.license", utils::license::Encode(license)); + settings->SetValue("organization.name", organization_name); + CheckLicenseValidity(true); + + std::this_thread::sleep_for(delta + std::chrono::seconds(1)); + ASSERT_TRUE(license_checker->IsValidLicense(*settings).HasError()); + // We can't check fast checker because it has unknown refresh rate + } + { + SCOPED_TRACE("License with valid_until = 0 is always valid"); + utils::license::License license{.organization_name = organization_name, .valid_until = 0, .memory_limit = 0}; + settings->SetValue("enterprise.license", utils::license::Encode(license)); + settings->SetValue("organization.name", organization_name); + CheckLicenseValidity(true); + } +} + +TEST_F(LicenseTest, LicenseInfoOverride) { + CheckLicenseValidity(false); + + const std::string organization_name{"Memgraph"}; + utils::license::License license{.organization_name = organization_name, .valid_until = 0, .memory_limit = 0}; + const std::string license_key = utils::license::Encode(license); + + { + SCOPED_TRACE("Checker should use overrides instead of info from the settings"); + license_checker->SetLicenseInfoOverride(license_key, organization_name); + CheckLicenseValidity(true); + } + { + SCOPED_TRACE("License info override shouldn't be affected by settings change"); + settings->SetValue("enterprise.license", "INVALID"); + CheckLicenseValidity(true); + } + { + SCOPED_TRACE("Override with invalid key"); + license_checker->SetLicenseInfoOverride("INVALID", organization_name); + CheckLicenseValidity(false); + settings->SetValue("enterprise.license", license_key); + settings->SetValue("organization.name", organization_name); + CheckLicenseValidity(false); + } +} diff --git a/tests/unit/utils_settings.cpp b/tests/unit/utils_settings.cpp new file mode 100644 index 000000000..cd27fd749 --- /dev/null +++ b/tests/unit/utils_settings.cpp @@ -0,0 +1,158 @@ +#include <filesystem> + +#include <gmock/gmock-generated-matchers.h> +#include <gtest/gtest.h> + +#include "utils/settings.hpp" + +class SettingsTest : public ::testing::Test { + public: + void TearDown() override { std::filesystem::remove_all(test_directory); } + + protected: + const std::filesystem::path test_directory{"MG_tests_unit_utils_settings"}; + const std::filesystem::path settings_directory{test_directory / "settings"}; + static void DummyCallback() {} +}; + +namespace { +void CheckSettingValue(const utils::Settings &settings, const std::string &setting_name, + const std::string &expected_value) { + auto maybe_value = settings.GetValue(setting_name); + ASSERT_TRUE(maybe_value) << "Failed to access registered setting"; + ASSERT_EQ(maybe_value, expected_value); +} +} // namespace + +TEST_F(SettingsTest, RegisterSetting) { + const std::string setting_name{"name"}; + const std::string default_value{"value"}; + + { + utils::Settings settings; + + settings.Initialize(settings_directory); + settings.RegisterSetting(setting_name, default_value, DummyCallback); + CheckSettingValue(settings, setting_name, default_value); + } + { + utils::Settings settings; + settings.Initialize(settings_directory); + // registering the same object shouldn't change its value + settings.RegisterSetting(setting_name, fmt::format("{}-modified", default_value), DummyCallback); + CheckSettingValue(settings, setting_name, default_value); + } +} + +TEST_F(SettingsTest, RegisterSettingCallback) { + const std::string setting_name{"name"}; + const std::string default_value{"value"}; + + utils::Settings settings; + settings.Initialize(settings_directory); + + size_t callback_counter{0}; + const auto callback = [&]() { ++callback_counter; }; + + size_t setting_change_counter{0}; + const auto assert_equal_counters = [&] { ASSERT_EQ(callback_counter, setting_change_counter); }; + + settings.RegisterSetting(setting_name, default_value, callback); + assert_equal_counters(); + + ASSERT_TRUE(settings.SetValue(setting_name, default_value)); + ++setting_change_counter; + assert_equal_counters(); + + ASSERT_TRUE(settings.SetValue(setting_name, fmt::format("{}-modified", default_value))); + ++setting_change_counter; + assert_equal_counters(); +} + +TEST_F(SettingsTest, GetSetRegisteredSetting) { + const std::string setting_name{"name"}; + const std::string setting_value{"value"}; + const std::string default_value{"default"}; + + utils::Settings settings; + settings.Initialize(settings_directory); + settings.RegisterSetting(setting_name, default_value, DummyCallback); + + CheckSettingValue(settings, setting_name, default_value); + ASSERT_TRUE(settings.SetValue(setting_name, setting_value)) << "Failed to modify registered setting"; + CheckSettingValue(settings, setting_name, setting_value); +} + +TEST_F(SettingsTest, GetSetUnregisteredSetting) { + utils::Settings settings; + settings.Initialize(settings_directory); + ASSERT_FALSE(settings.GetValue("Somesetting")) << "Accessed unregistered setting"; + ASSERT_FALSE(settings.SetValue("Somesetting", "Somevalue")) << "Modified unregistered setting"; +} + +TEST_F(SettingsTest, Initialization) { + utils::Settings settings; + settings.Initialize(settings_directory); + ASSERT_NO_FATAL_FAILURE(settings.GetValue("setting")); + ASSERT_NO_FATAL_FAILURE(settings.SetValue("setting", "value")); + ASSERT_NO_FATAL_FAILURE(settings.AllSettings()); +} + +namespace { +std::vector<std::pair<std::string, std::string>> GenerateSettings(const size_t amount) { + std::vector<std::pair<std::string, std::string>> result; + result.reserve(amount); + + for (size_t i = 0; i < amount; ++i) { + result.emplace_back(fmt::format("setting{}", i), fmt::format("value{}", i)); + } + + return result; +} +} // namespace + +TEST_F(SettingsTest, AllSettings) { + const auto generated_settings = GenerateSettings(100); + + utils::Settings settings; + settings.Initialize(settings_directory); + for (const auto &[setting_name, setting_value] : generated_settings) { + settings.RegisterSetting(setting_name, setting_value, DummyCallback); + } + ASSERT_THAT(settings.AllSettings(), testing::UnorderedElementsAreArray(generated_settings)); +} + +TEST_F(SettingsTest, Persistance) { + auto generated_settings = GenerateSettings(100); + + utils::Settings settings; + settings.Initialize(settings_directory); + + for (const auto &[setting_name, setting_value] : generated_settings) { + settings.RegisterSetting(setting_name, setting_value, DummyCallback); + } + + ASSERT_THAT(settings.AllSettings(), testing::UnorderedElementsAreArray(generated_settings)); + + // reinitialize to other directory and then back to the first + settings.Initialize(test_directory / "other_settings"); + ASSERT_TRUE(settings.AllSettings().empty()); + + settings.Initialize(settings_directory); + ASSERT_THAT(settings.AllSettings(), testing::UnorderedElementsAreArray(generated_settings)); + + for (size_t i = 0; i < generated_settings.size(); ++i) { + auto &[setting_name, setting_value] = generated_settings[i]; + setting_value = fmt::format("new_value{}", i); + settings.SetValue(setting_name, setting_value); + } + + ASSERT_THAT(settings.AllSettings(), testing::UnorderedElementsAreArray(generated_settings)); + + // reinitialize to other directory and then back to the first + settings.Initialize(test_directory / "other_settings"); + ASSERT_TRUE(settings.AllSettings().empty()); + + settings.Initialize(settings_directory); + ASSERT_THAT(settings.AllSettings(), testing::UnorderedElementsAreArray(generated_settings)); +}