Merge branch 'master' into T-fix-orderby-spike

This commit is contained in:
Ivan Milinović 2023-12-19 22:46:42 +01:00 committed by GitHub
commit 6e4f0dc683
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
336 changed files with 8716 additions and 2566 deletions

View File

@ -1,4 +1,4 @@
name: Package All name: Package memgraph
# TODO(gitbuda): Cleanup docker container if GHA job was canceled. # TODO(gitbuda): Cleanup docker container if GHA job was canceled.
@ -6,18 +6,56 @@ on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
memgraph_version: memgraph_version:
description: "Memgraph version to upload as. If empty upload is skipped. Format: 'X.Y.Z'" description: "Memgraph version to upload as. Leave this field empty if you don't want to upload binaries to S3. Format: 'X.Y.Z'"
required: false required: false
build_type: build_type:
type: choice type: choice
description: "Memgraph Build type. Default value is Release." description: "Memgraph Build type. Default value is Release"
default: 'Release' default: 'Release'
options: options:
- Release - Release
- RelWithDebInfo - RelWithDebInfo
target_os:
type: choice
description: "Target OS for which memgraph will be packaged. Select 'all' if you want to package for every listed OS. Default is Ubuntu 22.04"
default: 'ubuntu-22_04'
options:
- all
- amzn-2
- centos-7
- centos-9
- debian-10
- debian-11
- debian-11-arm
- debian-11-platform
- docker
- fedora-36
- ubuntu-18_04
- ubuntu-20_04
- ubuntu-22_04
- ubuntu-22_04-arm
jobs: jobs:
amzn-2:
if: ${{ github.event.inputs.target_os == 'amzn-2' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package amzn-2 ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: amzn-2
path: build/output/amzn-2/memgraph*.rpm
centos-7: centos-7:
if: ${{ github.event.inputs.target_os == 'centos-7' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -35,6 +73,7 @@ jobs:
path: build/output/centos-7/memgraph*.rpm path: build/output/centos-7/memgraph*.rpm
centos-9: centos-9:
if: ${{ github.event.inputs.target_os == 'centos-9' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -52,6 +91,7 @@ jobs:
path: build/output/centos-9/memgraph*.rpm path: build/output/centos-9/memgraph*.rpm
debian-10: debian-10:
if: ${{ github.event.inputs.target_os == 'debian-10' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -69,6 +109,7 @@ jobs:
path: build/output/debian-10/memgraph*.deb path: build/output/debian-10/memgraph*.deb
debian-11: debian-11:
if: ${{ github.event.inputs.target_os == 'debian-11' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -85,7 +126,44 @@ jobs:
name: debian-11 name: debian-11
path: build/output/debian-11/memgraph*.deb path: build/output/debian-11/memgraph*.deb
debian-11-arm:
if: ${{ github.event.inputs.target_os == 'debian-11-arm' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, ARM64, strange]
timeout-minutes: 120
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package debian-11-arm ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: debian-11-aarch64
path: build/output/debian-11-arm/memgraph*.deb
debian-11-platform:
if: ${{ github.event.inputs.target_os == 'debian-11-platform' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package debian-11 ${{ github.event.inputs.build_type }} --for-platform
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: debian-11-platform
path: build/output/debian-11/memgraph*.deb
docker: docker:
if: ${{ github.event.inputs.target_os == 'docker' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -104,75 +182,8 @@ jobs:
name: docker name: docker
path: build/output/docker/memgraph*.tar.gz path: build/output/docker/memgraph*.tar.gz
ubuntu-1804:
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package ubuntu-18.04 ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: ubuntu-18.04
path: build/output/ubuntu-18.04/memgraph*.deb
ubuntu-2004:
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package ubuntu-20.04 ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: ubuntu-20.04
path: build/output/ubuntu-20.04/memgraph*.deb
ubuntu-2204:
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package ubuntu-22.04 ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: ubuntu-22.04
path: build/output/ubuntu-22.04/memgraph*.deb
debian-11-platform:
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package debian-11 ${{ github.event.inputs.build_type }} --for-platform
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: debian-11-platform
path: build/output/debian-11/memgraph*.deb
fedora-36: fedora-36:
if: ${{ github.event.inputs.target_os == 'fedora-36' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -189,7 +200,8 @@ jobs:
name: fedora-36 name: fedora-36
path: build/output/fedora-36/memgraph*.rpm path: build/output/fedora-36/memgraph*.rpm
amzn-2: ubuntu-18_04:
if: ${{ github.event.inputs.target_os == 'ubuntu-18_04' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64] runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60 timeout-minutes: 60
steps: steps:
@ -199,16 +211,17 @@ jobs:
fetch-depth: 0 # Required because of release/get_version.py fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package" - name: "Build package"
run: | run: |
./release/package/run.sh package amzn-2 ${{ github.event.inputs.build_type }} ./release/package/run.sh package ubuntu-18.04 ${{ github.event.inputs.build_type }}
- name: "Upload package" - name: "Upload package"
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: amzn-2 name: ubuntu-18.04
path: build/output/amzn-2/memgraph*.rpm path: build/output/ubuntu-18.04/memgraph*.deb
debian-11-arm: ubuntu-20_04:
runs-on: [self-hosted, DockerMgBuild, ARM64, strange] if: ${{ github.event.inputs.target_os == 'ubuntu-20_04' || github.event.inputs.target_os == 'all' }}
timeout-minutes: 120 runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps: steps:
- name: "Set up repository" - name: "Set up repository"
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -216,14 +229,33 @@ jobs:
fetch-depth: 0 # Required because of release/get_version.py fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package" - name: "Build package"
run: | run: |
./release/package/run.sh package debian-11-arm ${{ github.event.inputs.build_type }} ./release/package/run.sh package ubuntu-20.04 ${{ github.event.inputs.build_type }}
- name: "Upload package" - name: "Upload package"
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: debian-11-aarch64 name: ubuntu-20.04
path: build/output/debian-11-arm/memgraph*.deb path: build/output/ubuntu-20.04/memgraph*.deb
ubuntu-2204-arm: ubuntu-22_04:
if: ${{ github.event.inputs.target_os == 'ubuntu-22_04' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, X64]
timeout-minutes: 60
steps:
- name: "Set up repository"
uses: actions/checkout@v3
with:
fetch-depth: 0 # Required because of release/get_version.py
- name: "Build package"
run: |
./release/package/run.sh package ubuntu-22.04 ${{ github.event.inputs.build_type }}
- name: "Upload package"
uses: actions/upload-artifact@v3
with:
name: ubuntu-22.04
path: build/output/ubuntu-22.04/memgraph*.deb
ubuntu-22_04-arm:
if: ${{ github.event.inputs.target_os == 'ubuntu-22_04-arm' || github.event.inputs.target_os == 'all' }}
runs-on: [self-hosted, DockerMgBuild, ARM64, strange] runs-on: [self-hosted, DockerMgBuild, ARM64, strange]
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
@ -243,7 +275,7 @@ jobs:
upload-to-s3: upload-to-s3:
# only run upload if we specified version. Allows for runs without upload # only run upload if we specified version. Allows for runs without upload
if: "${{ github.event.inputs.memgraph_version != '' }}" if: "${{ github.event.inputs.memgraph_version != '' }}"
needs: [centos-7, centos-9, debian-10, debian-11, docker, ubuntu-1804, ubuntu-2004, ubuntu-2204, debian-11-platform, fedora-36, amzn-2, debian-11-arm, ubuntu-2204-arm] needs: [amzn-2, centos-7, centos-9, debian-10, debian-11, debian-11-arm, debian-11-platform, docker, fedora-36, ubuntu-18_04, ubuntu-20_04, ubuntu-22_04, ubuntu-22_04-arm]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Download artifacts - name: Download artifacts

View File

@ -178,7 +178,7 @@ jobs:
release_build: release_build:
name: "Release build" name: "Release build"
runs-on: [self-hosted, Linux, X64, Debian10] runs-on: [self-hosted, Linux, X64, Debian10, BigMemory]
env: env:
THREADS: 24 THREADS: 24
MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }} MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }}

22
.sonarcloud.properties Normal file
View File

@ -0,0 +1,22 @@
# Path to sources
sonar.sources = .
# sonar.exclusions=
sonar.inclusions=src,include,query_modules
# Path to tests
sonar.tests = tests/
# sonar.test.exclusions=
# sonar.test.inclusions=
# Source encoding
# sonar.sourceEncoding=
# Exclusions for copy-paste detection
# sonar.cpd.exclusions=
# Python version (for python projects only)
# sonar.python.version=
# C++ standard version (for C++ projects only)
# If not specified, it defaults to the latest supported standard
# sonar.cfamily.reportingCppStandardOverride=c++98|c++11|c++14|c++17|c++20

View File

@ -292,7 +292,7 @@ if (MG_ENTERPRISE)
add_definitions(-DMG_ENTERPRISE) add_definitions(-DMG_ENTERPRISE)
endif() endif()
set(ENABLE_JEMALLOC ON) option(ENABLE_JEMALLOC "Use jemalloc" ON)
if (ASAN) if (ASAN)
message(WARNING "Disabling jemalloc as it doesn't work well with ASAN") message(WARNING "Disabling jemalloc as it doesn't work well with ASAN")

View File

@ -111,6 +111,14 @@ modifications:
value: "false" value: "false"
override: true override: true
- name: "storage_parallel_schema_recovery"
value: "false"
override: true
- name: "storage_enable_schema_metadata"
value: "false"
override: true
- name: "query_callable_mappings_path" - name: "query_callable_mappings_path"
value: "/etc/memgraph/apoc_compatibility_mappings.json" value: "/etc/memgraph/apoc_compatibility_mappings.json"
override: true override: true

View File

@ -234,7 +234,57 @@ inline mgp_type *type_duration() { return MgInvoke<mgp_type *>(mgp_type_duration
inline mgp_type *type_nullable(mgp_type *type) { return MgInvoke<mgp_type *>(mgp_type_nullable, type); } inline mgp_type *type_nullable(mgp_type *type) { return MgInvoke<mgp_type *>(mgp_type_nullable, type); }
inline bool create_label_index(mgp_graph *graph, const char *label) {
return MgInvoke<int>(mgp_create_label_index, graph, label);
}
inline bool drop_label_index(mgp_graph *graph, const char *label) {
return MgInvoke<int>(mgp_drop_label_index, graph, label);
}
inline mgp_list *list_all_label_indices(mgp_graph *graph, mgp_memory *memory) {
return MgInvoke<mgp_list *>(mgp_list_all_label_indices, graph, memory);
}
inline bool create_label_property_index(mgp_graph *graph, const char *label, const char *property) {
return MgInvoke<int>(mgp_create_label_property_index, graph, label, property);
}
inline bool drop_label_property_index(mgp_graph *graph, const char *label, const char *property) {
return MgInvoke<int>(mgp_drop_label_property_index, graph, label, property);
}
inline mgp_list *list_all_label_property_indices(mgp_graph *graph, mgp_memory *memory) {
return MgInvoke<mgp_list *>(mgp_list_all_label_property_indices, graph, memory);
}
inline bool create_existence_constraint(mgp_graph *graph, const char *label, const char *property) {
return MgInvoke<int>(mgp_create_existence_constraint, graph, label, property);
}
inline bool drop_existence_constraint(mgp_graph *graph, const char *label, const char *property) {
return MgInvoke<int>(mgp_drop_existence_constraint, graph, label, property);
}
inline mgp_list *list_all_existence_constraints(mgp_graph *graph, mgp_memory *memory) {
return MgInvoke<mgp_list *>(mgp_list_all_existence_constraints, graph, memory);
}
inline bool create_unique_constraint(mgp_graph *memgraph_graph, const char *label, mgp_value *properties) {
return MgInvoke<int>(mgp_create_unique_constraint, memgraph_graph, label, properties);
}
inline bool drop_unique_constraint(mgp_graph *memgraph_graph, const char *label, mgp_value *properties) {
return MgInvoke<int>(mgp_drop_unique_constraint, memgraph_graph, label, properties);
}
inline mgp_list *list_all_unique_constraints(mgp_graph *graph, mgp_memory *memory) {
return MgInvoke<mgp_list *>(mgp_list_all_unique_constraints, graph, memory);
}
// mgp_graph // mgp_graph
inline bool graph_is_transactional(mgp_graph *graph) { return MgInvoke<int>(mgp_graph_is_transactional, graph); }
inline bool graph_is_mutable(mgp_graph *graph) { return MgInvoke<int>(mgp_graph_is_mutable, graph); } inline bool graph_is_mutable(mgp_graph *graph) { return MgInvoke<int>(mgp_graph_is_mutable, graph); }
@ -328,6 +378,8 @@ inline mgp_list *list_copy(mgp_list *list, mgp_memory *memory) {
inline void list_destroy(mgp_list *list) { mgp_list_destroy(list); } inline void list_destroy(mgp_list *list) { mgp_list_destroy(list); }
inline bool list_contains_deleted(mgp_list *list) { return MgInvoke<int>(mgp_list_contains_deleted, list); }
inline void list_append(mgp_list *list, mgp_value *val) { MgInvokeVoid(mgp_list_append, list, val); } inline void list_append(mgp_list *list, mgp_value *val) { MgInvokeVoid(mgp_list_append, list, val); }
inline void list_append_extend(mgp_list *list, mgp_value *val) { MgInvokeVoid(mgp_list_append_extend, list, val); } inline void list_append_extend(mgp_list *list, mgp_value *val) { MgInvokeVoid(mgp_list_append_extend, list, val); }
@ -346,6 +398,8 @@ inline mgp_map *map_copy(mgp_map *map, mgp_memory *memory) { return MgInvoke<mgp
inline void map_destroy(mgp_map *map) { mgp_map_destroy(map); } inline void map_destroy(mgp_map *map) { mgp_map_destroy(map); }
inline bool map_contains_deleted(mgp_map *map) { return MgInvoke<int>(mgp_map_contains_deleted, map); }
inline void map_insert(mgp_map *map, const char *key, mgp_value *value) { inline void map_insert(mgp_map *map, const char *key, mgp_value *value) {
MgInvokeVoid(mgp_map_insert, map, key, value); MgInvokeVoid(mgp_map_insert, map, key, value);
} }
@ -394,6 +448,8 @@ inline mgp_vertex *vertex_copy(mgp_vertex *v, mgp_memory *memory) {
inline void vertex_destroy(mgp_vertex *v) { mgp_vertex_destroy(v); } inline void vertex_destroy(mgp_vertex *v) { mgp_vertex_destroy(v); }
inline bool vertex_is_deleted(mgp_vertex *v) { return MgInvoke<int>(mgp_vertex_is_deleted, v); }
inline bool vertex_equal(mgp_vertex *v1, mgp_vertex *v2) { return MgInvoke<int>(mgp_vertex_equal, v1, v2); } inline bool vertex_equal(mgp_vertex *v1, mgp_vertex *v2) { return MgInvoke<int>(mgp_vertex_equal, v1, v2); }
inline size_t vertex_labels_count(mgp_vertex *v) { return MgInvoke<size_t>(mgp_vertex_labels_count, v); } inline size_t vertex_labels_count(mgp_vertex *v) { return MgInvoke<size_t>(mgp_vertex_labels_count, v); }
@ -446,6 +502,8 @@ inline mgp_edge *edge_copy(mgp_edge *e, mgp_memory *memory) { return MgInvoke<mg
inline void edge_destroy(mgp_edge *e) { mgp_edge_destroy(e); } inline void edge_destroy(mgp_edge *e) { mgp_edge_destroy(e); }
inline bool edge_is_deleted(mgp_edge *e) { return MgInvoke<int>(mgp_edge_is_deleted, e); }
inline bool edge_equal(mgp_edge *e1, mgp_edge *e2) { return MgInvoke<int>(mgp_edge_equal, e1, e2); } inline bool edge_equal(mgp_edge *e1, mgp_edge *e2) { return MgInvoke<int>(mgp_edge_equal, e1, e2); }
inline mgp_edge_type edge_get_type(mgp_edge *e) { return MgInvoke<mgp_edge_type>(mgp_edge_get_type, e); } inline mgp_edge_type edge_get_type(mgp_edge *e) { return MgInvoke<mgp_edge_type>(mgp_edge_get_type, e); }
@ -482,6 +540,8 @@ inline mgp_path *path_copy(mgp_path *path, mgp_memory *memory) {
inline void path_destroy(mgp_path *path) { mgp_path_destroy(path); } inline void path_destroy(mgp_path *path) { mgp_path_destroy(path); }
inline bool path_contains_deleted(mgp_path *path) { return MgInvoke<int>(mgp_path_contains_deleted, path); }
inline void path_expand(mgp_path *path, mgp_edge *edge) { MgInvokeVoid(mgp_path_expand, path, edge); } inline void path_expand(mgp_path *path, mgp_edge *edge) { MgInvokeVoid(mgp_path_expand, path, edge); }
inline void path_pop(mgp_path *path) { MgInvokeVoid(mgp_path_pop, path); } inline void path_pop(mgp_path *path) { MgInvokeVoid(mgp_path_pop, path); }

View File

@ -429,6 +429,9 @@ enum mgp_error mgp_list_copy(struct mgp_list *list, struct mgp_memory *memory, s
/// Free the memory used by the given mgp_list and contained elements. /// Free the memory used by the given mgp_list and contained elements.
void mgp_list_destroy(struct mgp_list *list); void mgp_list_destroy(struct mgp_list *list);
/// Return whether the given mgp_list contains any deleted values.
enum mgp_error mgp_list_contains_deleted(struct mgp_list *list, int *result);
/// Append a copy of mgp_value to mgp_list if capacity allows. /// Append a copy of mgp_value to mgp_list if capacity allows.
/// The list copies the given value and therefore does not take ownership of the /// The list copies the given value and therefore does not take ownership of the
/// original value. You still need to call mgp_value_destroy to free the /// original value. You still need to call mgp_value_destroy to free the
@ -469,6 +472,9 @@ enum mgp_error mgp_map_copy(struct mgp_map *map, struct mgp_memory *memory, stru
/// Free the memory used by the given mgp_map and contained items. /// Free the memory used by the given mgp_map and contained items.
void mgp_map_destroy(struct mgp_map *map); void mgp_map_destroy(struct mgp_map *map);
/// Return whether the given mgp_map contains any deleted values.
enum mgp_error mgp_map_contains_deleted(struct mgp_map *map, int *result);
/// Insert a new mapping from a NULL terminated character string to a value. /// Insert a new mapping from a NULL terminated character string to a value.
/// If a mapping with the same key already exists, it is *not* replaced. /// If a mapping with the same key already exists, it is *not* replaced.
/// In case of insertion, both the string and the value are copied into the map. /// In case of insertion, both the string and the value are copied into the map.
@ -552,6 +558,9 @@ enum mgp_error mgp_path_copy(struct mgp_path *path, struct mgp_memory *memory, s
/// Free the memory used by the given mgp_path and contained vertices and edges. /// Free the memory used by the given mgp_path and contained vertices and edges.
void mgp_path_destroy(struct mgp_path *path); void mgp_path_destroy(struct mgp_path *path);
/// Return whether the given mgp_path contains any deleted values.
enum mgp_error mgp_path_contains_deleted(struct mgp_path *path, int *result);
/// Append an edge continuing from the last vertex on the path. /// Append an edge continuing from the last vertex on the path.
/// The edge is copied into the path. Therefore, the path does not take /// The edge is copied into the path. Therefore, the path does not take
/// ownership of the original edge, so you still need to free the edge memory /// ownership of the original edge, so you still need to free the edge memory
@ -725,6 +734,9 @@ enum mgp_error mgp_vertex_copy(struct mgp_vertex *v, struct mgp_memory *memory,
/// Free the memory used by a mgp_vertex. /// Free the memory used by a mgp_vertex.
void mgp_vertex_destroy(struct mgp_vertex *v); void mgp_vertex_destroy(struct mgp_vertex *v);
/// Return whether the given mgp_vertex is deleted.
enum mgp_error mgp_vertex_is_deleted(struct mgp_vertex *v, int *result);
/// Result is non-zero if given vertices are equal, otherwise 0. /// Result is non-zero if given vertices are equal, otherwise 0.
enum mgp_error mgp_vertex_equal(struct mgp_vertex *v1, struct mgp_vertex *v2, int *result); enum mgp_error mgp_vertex_equal(struct mgp_vertex *v1, struct mgp_vertex *v2, int *result);
@ -819,6 +831,9 @@ enum mgp_error mgp_edge_copy(struct mgp_edge *e, struct mgp_memory *memory, stru
/// Free the memory used by a mgp_edge. /// Free the memory used by a mgp_edge.
void mgp_edge_destroy(struct mgp_edge *e); void mgp_edge_destroy(struct mgp_edge *e);
/// Return whether the given mgp_edge is deleted.
enum mgp_error mgp_edge_is_deleted(struct mgp_edge *e, int *result);
/// Result is non-zero if given edges are equal, otherwise 0. /// Result is non-zero if given edges are equal, otherwise 0.
enum mgp_error mgp_edge_equal(struct mgp_edge *e1, struct mgp_edge *e2, int *result); enum mgp_error mgp_edge_equal(struct mgp_edge *e1, struct mgp_edge *e2, int *result);
@ -876,12 +891,77 @@ enum mgp_error mgp_edge_iter_properties(struct mgp_edge *e, struct mgp_memory *m
enum mgp_error mgp_graph_get_vertex_by_id(struct mgp_graph *g, struct mgp_vertex_id id, struct mgp_memory *memory, enum mgp_error mgp_graph_get_vertex_by_id(struct mgp_graph *g, struct mgp_vertex_id id, struct mgp_memory *memory,
struct mgp_vertex **result); struct mgp_vertex **result);
/// Creates label index for given label.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if label index already exists, result will be 0, otherwise 1.
enum mgp_error mgp_create_label_index(struct mgp_graph *graph, const char *label, int *result);
/// Drop label index.
enum mgp_error mgp_drop_label_index(struct mgp_graph *graph, const char *label, int *result);
/// List all label indices.
enum mgp_error mgp_list_all_label_indices(struct mgp_graph *graph, struct mgp_memory *memory, struct mgp_list **result);
/// Creates label-property index for given label and propery.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if label property index already exists, result will be 0, otherwise 1.
enum mgp_error mgp_create_label_property_index(struct mgp_graph *graph, const char *label, const char *property,
int *result);
/// Drops label-property index for given label and propery.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if dropping label property index failed, result will be 0, otherwise 1.
enum mgp_error mgp_drop_label_property_index(struct mgp_graph *graph, const char *label, const char *property,
int *result);
/// List all label+property indices.
enum mgp_error mgp_list_all_label_property_indices(struct mgp_graph *graph, struct mgp_memory *memory,
struct mgp_list **result);
/// Creates existence constraint for given label and property.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if creating existence constraint failed, result will be 0, otherwise 1.
enum mgp_error mgp_create_existence_constraint(struct mgp_graph *graph, const char *label, const char *property,
int *result);
/// Drops existence constraint for given label and property.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if dropping existence constraint failed, result will be 0, otherwise 1.
enum mgp_error mgp_drop_existence_constraint(struct mgp_graph *graph, const char *label, const char *property,
int *result);
/// List all existence constraints.
enum mgp_error mgp_list_all_existence_constraints(struct mgp_graph *graph, struct mgp_memory *memory,
struct mgp_list **result);
/// Creates unique constraint for given label and properties.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if creating unique constraint failed, result will be 0, otherwise 1.
enum mgp_error mgp_create_unique_constraint(struct mgp_graph *graph, const char *label, struct mgp_value *properties,
int *result);
/// Drops unique constraint for given label and properties.
/// mgp_error::MGP_ERROR_NO_ERROR is always returned.
/// if dropping unique constraint failed, result will be 0, otherwise 1.
enum mgp_error mgp_drop_unique_constraint(struct mgp_graph *graph, const char *label, struct mgp_value *properties,
int *result);
/// List all unique constraints
enum mgp_error mgp_list_all_unique_constraints(struct mgp_graph *graph, struct mgp_memory *memory,
struct mgp_list **result);
/// Result is non-zero if the graph can be modified. /// Result is non-zero if the graph can be modified.
/// If a graph is immutable, then vertices cannot be created or deleted, and all of the returned vertices will be /// If a graph is immutable, then vertices cannot be created or deleted, and all of the returned vertices will be
/// immutable also. The same applies for edges. /// immutable also. The same applies for edges.
/// Current implementation always returns without errors. /// Current implementation always returns without errors.
enum mgp_error mgp_graph_is_mutable(struct mgp_graph *graph, int *result); enum mgp_error mgp_graph_is_mutable(struct mgp_graph *graph, int *result);
/// Result is non-zero if the graph is in transactional storage mode.
/// If a graph is not in transactional mode (i.e. analytical mode), then vertices and edges can be missing
/// because changes from other transactions are visible.
/// Current implementation always returns without errors.
enum mgp_error mgp_graph_is_transactional(struct mgp_graph *graph, int *result);
/// Add a new vertex to the graph. /// Add a new vertex to the graph.
/// Resulting vertex must be freed using mgp_vertex_destroy. /// Resulting vertex must be freed using mgp_vertex_destroy.
/// Return mgp_error::MGP_ERROR_IMMUTABLE_OBJECT if `graph` is immutable. /// Return mgp_error::MGP_ERROR_IMMUTABLE_OBJECT if `graph` is immutable.

File diff suppressed because it is too large Load Diff

View File

@ -259,7 +259,7 @@ repo_clone_try_double "${primary_urls[absl]}" "${secondary_urls[absl]}" "absl" "
# jemalloc ea6b3e973b477b8061e0076bb257dbd7f3faa756 # jemalloc ea6b3e973b477b8061e0076bb257dbd7f3faa756
JEMALLOC_COMMIT_VERSION="5.2.1" JEMALLOC_COMMIT_VERSION="5.2.1"
repo_clone_try_double "${secondary_urls[jemalloc]}" "${secondary_urls[jemalloc]}" "jemalloc" "$JEMALLOC_COMMIT_VERSION" repo_clone_try_double "${primary_urls[jemalloc]}" "${secondary_urls[jemalloc]}" "jemalloc" "$JEMALLOC_COMMIT_VERSION"
# this is hack for cmake in libs to set path, and for FindJemalloc to use Jemalloc_INCLUDE_DIR # this is hack for cmake in libs to set path, and for FindJemalloc to use Jemalloc_INCLUDE_DIR
pushd jemalloc pushd jemalloc

View File

@ -36,7 +36,7 @@ ADDITIONAL USE GRANT: You may use the Licensed Work in accordance with the
3. using the Licensed Work to create a work or solution 3. using the Licensed Work to create a work or solution
which competes (or might reasonably be expected to which competes (or might reasonably be expected to
compete) with the Licensed Work. compete) with the Licensed Work.
CHANGE DATE: 2027-30-10 CHANGE DATE: 2027-08-12
CHANGE LICENSE: Apache License, Version 2.0 CHANGE LICENSE: Apache License, Version 2.0
For information about alternative licensing arrangements, please visit: https://memgraph.com/legal. For information about alternative licensing arrangements, please visit: https://memgraph.com/legal.

View File

@ -13,6 +13,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
add_library(example_c SHARED example.c) add_library(example_c SHARED example.c)
target_include_directories(example_c PRIVATE ${CMAKE_SOURCE_DIR}/include) target_include_directories(example_c PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_compile_options(example_c PRIVATE -Wall) target_compile_options(example_c PRIVATE -Wall)
target_link_libraries(example_c PRIVATE -static-libgcc -static-libstdc++)
# Strip C example in release build. # Strip C example in release build.
if (lower_build_type STREQUAL "release") if (lower_build_type STREQUAL "release")
add_custom_command(TARGET example_c POST_BUILD add_custom_command(TARGET example_c POST_BUILD
@ -28,6 +29,7 @@ install(FILES example.c DESTINATION lib/memgraph/query_modules/src)
add_library(example_cpp SHARED example.cpp) add_library(example_cpp SHARED example.cpp)
target_include_directories(example_cpp PRIVATE ${CMAKE_SOURCE_DIR}/include) target_include_directories(example_cpp PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_compile_options(example_cpp PRIVATE -Wall) target_compile_options(example_cpp PRIVATE -Wall)
target_link_libraries(example_cpp PRIVATE -static-libgcc -static-libstdc++)
# Strip C++ example in release build. # Strip C++ example in release build.
if (lower_build_type STREQUAL "release") if (lower_build_type STREQUAL "release")
add_custom_command(TARGET example_cpp POST_BUILD add_custom_command(TARGET example_cpp POST_BUILD
@ -43,6 +45,7 @@ install(FILES example.cpp DESTINATION lib/memgraph/query_modules/src)
add_library(schema SHARED schema.cpp) add_library(schema SHARED schema.cpp)
target_include_directories(schema PRIVATE ${CMAKE_SOURCE_DIR}/include) target_include_directories(schema PRIVATE ${CMAKE_SOURCE_DIR}/include)
target_compile_options(schema PRIVATE -Wall) target_compile_options(schema PRIVATE -Wall)
target_link_libraries(schema PRIVATE -static-libgcc -static-libstdc++)
# Strip C++ example in release build. # Strip C++ example in release build.
if (lower_build_type STREQUAL "release") if (lower_build_type STREQUAL "release")
add_custom_command(TARGET schema POST_BUILD add_custom_command(TARGET schema POST_BUILD

View File

@ -10,18 +10,33 @@
// licenses/APL.txt. // licenses/APL.txt.
#include <mgp.hpp> #include <mgp.hpp>
#include "utils/string.hpp"
#include <optional>
namespace Schema { namespace Schema {
/*NodeTypeProperties and RelTypeProperties constants*/ constexpr std::string_view kStatusKept = "Kept";
constexpr std::string_view kStatusCreated = "Created";
constexpr std::string_view kStatusDropped = "Dropped";
constexpr std::string_view kReturnNodeType = "nodeType"; constexpr std::string_view kReturnNodeType = "nodeType";
constexpr std::string_view kProcedureNodeType = "node_type_properties"; constexpr std::string_view kProcedureNodeType = "node_type_properties";
constexpr std::string_view kProcedureRelType = "rel_type_properties"; constexpr std::string_view kProcedureRelType = "rel_type_properties";
constexpr std::string_view kProcedureAssert = "assert";
constexpr std::string_view kReturnLabels = "nodeLabels"; constexpr std::string_view kReturnLabels = "nodeLabels";
constexpr std::string_view kReturnRelType = "relType"; constexpr std::string_view kReturnRelType = "relType";
constexpr std::string_view kReturnPropertyName = "propertyName"; constexpr std::string_view kReturnPropertyName = "propertyName";
constexpr std::string_view kReturnPropertyType = "propertyTypes"; constexpr std::string_view kReturnPropertyType = "propertyTypes";
constexpr std::string_view kReturnMandatory = "mandatory"; constexpr std::string_view kReturnMandatory = "mandatory";
constexpr std::string_view kReturnLabel = "label";
constexpr std::string_view kReturnKey = "key";
constexpr std::string_view kReturnKeys = "keys";
constexpr std::string_view kReturnUnique = "unique";
constexpr std::string_view kReturnAction = "action";
constexpr std::string_view kParameterIndices = "indices";
constexpr std::string_view kParameterUniqueConstraints = "unique_constraints";
constexpr std::string_view kParameterExistenceConstraints = "existence_constraints";
constexpr std::string_view kParameterDropExisting = "drop_existing";
std::string TypeOf(const mgp::Type &type); std::string TypeOf(const mgp::Type &type);
@ -35,6 +50,7 @@ void ProcessPropertiesRel(mgp::Record &record, const std::string_view &type, con
void NodeTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory); void NodeTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory);
void RelTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory); void RelTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory);
void Assert(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory);
} // namespace Schema } // namespace Schema
/*we have << operator for type in Cpp API, but in it we return somewhat different strings than I would like in this /*we have << operator for type in Cpp API, but in it we return somewhat different strings than I would like in this
@ -92,31 +108,83 @@ void Schema::ProcessPropertiesRel(mgp::Record &record, const std::string_view &t
record.Insert(std::string(kReturnMandatory).c_str(), mandatory); record.Insert(std::string(kReturnMandatory).c_str(), mandatory);
} }
void Schema::NodeTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) { struct Property {
std::string name;
mgp::Value value;
Property(const std::string &name, mgp::Value &&value) : name(name), value(std::move(value)) {}
};
struct LabelsHash {
std::size_t operator()(const std::set<std::string> &set) const {
std::size_t seed = set.size();
for (const auto &i : set) {
seed ^= std::hash<std::string>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
struct LabelsComparator {
bool operator()(const std::set<std::string> &lhs, const std::set<std::string> &rhs) const { return lhs == rhs; }
};
struct PropertyComparator {
bool operator()(const Property &lhs, const Property &rhs) const { return lhs.name < rhs.name; }
};
struct PropertyInfo {
std::set<Property, PropertyComparator> properties;
bool mandatory;
};
void Schema::NodeTypeProperties(mgp_list * /*args*/, mgp_graph *memgraph_graph, mgp_result *result,
mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory}; mgp::MemoryDispatcherGuard guard{memory};
;
const auto record_factory = mgp::RecordFactory(result); const auto record_factory = mgp::RecordFactory(result);
try { try {
const mgp::Graph graph = mgp::Graph(memgraph_graph); std::unordered_map<std::set<std::string>, PropertyInfo, LabelsHash, LabelsComparator> node_types_properties;
for (auto node : graph.Nodes()) {
std::string type = ""; for (auto node : mgp::Graph(memgraph_graph).Nodes()) {
mgp::List labels = mgp::List(); std::set<std::string> labels_set = {};
for (auto label : node.Labels()) { for (auto label : node.Labels()) {
labels.AppendExtend(mgp::Value(label)); labels_set.emplace(label);
type += ":`" + std::string(label) + "`";
} }
if (node.Properties().size() == 0) { if (node_types_properties.find(labels_set) == node_types_properties.end()) {
auto record = record_factory.NewRecord(); node_types_properties[labels_set] = PropertyInfo{std::set<Property, PropertyComparator>(), true};
ProcessPropertiesNode<std::string>(record, type, labels, "", "", false); }
if (node.Properties().empty()) {
node_types_properties[labels_set].mandatory = false; // if there is node with no property, it is not mandatory
continue; continue;
} }
auto &property_info = node_types_properties.at(labels_set);
for (auto &[key, prop] : node.Properties()) { for (auto &[key, prop] : node.Properties()) {
auto property_type = mgp::List(); property_info.properties.emplace(key, std::move(prop));
if (property_info.mandatory) {
property_info.mandatory =
property_info.properties.size() == 1; // if there is only one property, it is mandatory
}
}
}
for (auto &[labels, property_info] : node_types_properties) {
std::string label_type;
mgp::List labels_list = mgp::List();
for (auto const &label : labels) {
label_type += ":`" + std::string(label) + "`";
labels_list.AppendExtend(mgp::Value(label));
}
for (auto const &prop : property_info.properties) {
auto record = record_factory.NewRecord(); auto record = record_factory.NewRecord();
property_type.AppendExtend(mgp::Value(TypeOf(prop.Type()))); ProcessPropertiesNode(record, label_type, labels_list, prop.name, TypeOf(prop.value.Type()),
ProcessPropertiesNode<mgp::List>(record, type, labels, key, property_type, true); property_info.mandatory);
}
if (property_info.properties.empty()) {
auto record = record_factory.NewRecord();
ProcessPropertiesNode<std::string>(record, label_type, labels_list, "", "", false);
} }
} }
@ -126,26 +194,43 @@ void Schema::NodeTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_r
} }
} }
void Schema::RelTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) { void Schema::RelTypeProperties(mgp_list * /*args*/, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory}; mgp::MemoryDispatcherGuard guard{memory};
;
std::unordered_map<std::string, PropertyInfo> rel_types_properties;
const auto record_factory = mgp::RecordFactory(result); const auto record_factory = mgp::RecordFactory(result);
try { try {
const mgp::Graph graph = mgp::Graph(memgraph_graph); const mgp::Graph graph = mgp::Graph(memgraph_graph);
for (auto rel : graph.Relationships()) { for (auto rel : graph.Relationships()) {
std::string type = ":`" + std::string(rel.Type()) + "`"; std::string rel_type = std::string(rel.Type());
if (rel.Properties().size() == 0) { if (rel_types_properties.find(rel_type) == rel_types_properties.end()) {
auto record = record_factory.NewRecord(); rel_types_properties[rel_type] = PropertyInfo{std::set<Property, PropertyComparator>(), true};
ProcessPropertiesRel<std::string>(record, type, "", "", false); }
if (rel.Properties().empty()) {
rel_types_properties[rel_type].mandatory = false; // if there is rel with no property, it is not mandatory
continue; continue;
} }
auto &property_info = rel_types_properties.at(rel_type);
for (auto &[key, prop] : rel.Properties()) { for (auto &[key, prop] : rel.Properties()) {
auto property_type = mgp::List(); property_info.properties.emplace(key, std::move(prop));
if (property_info.mandatory) {
property_info.mandatory =
property_info.properties.size() == 1; // if there is only one property, it is mandatory
}
}
}
for (auto &[type, property_info] : rel_types_properties) {
std::string type_str = ":`" + std::string(type) + "`";
for (auto const &prop : property_info.properties) {
auto record = record_factory.NewRecord(); auto record = record_factory.NewRecord();
property_type.AppendExtend(mgp::Value(TypeOf(prop.Type()))); ProcessPropertiesRel(record, type_str, prop.name, TypeOf(prop.value.Type()), property_info.mandatory);
ProcessPropertiesRel<mgp::List>(record, type, key, property_type, true); }
if (property_info.properties.empty()) {
auto record = record_factory.NewRecord();
ProcessPropertiesRel<std::string>(record, type_str, "", "", false);
} }
} }
@ -155,29 +240,435 @@ void Schema::RelTypeProperties(mgp_list *args, mgp_graph *memgraph_graph, mgp_re
} }
} }
void InsertRecordForLabelIndex(const auto &record_factory, const std::string_view label,
const std::string_view status) {
auto record = record_factory.NewRecord();
record.Insert(std::string(Schema::kReturnLabel).c_str(), label);
record.Insert(std::string(Schema::kReturnKey).c_str(), "");
record.Insert(std::string(Schema::kReturnKeys).c_str(), mgp::List());
record.Insert(std::string(Schema::kReturnUnique).c_str(), false);
record.Insert(std::string(Schema::kReturnAction).c_str(), status);
}
void InsertRecordForUniqueConstraint(const auto &record_factory, const std::string_view label,
const mgp::List &properties, const std::string_view status) {
auto record = record_factory.NewRecord();
record.Insert(std::string(Schema::kReturnLabel).c_str(), label);
record.Insert(std::string(Schema::kReturnKey).c_str(), properties.ToString());
record.Insert(std::string(Schema::kReturnKeys).c_str(), properties);
record.Insert(std::string(Schema::kReturnUnique).c_str(), true);
record.Insert(std::string(Schema::kReturnAction).c_str(), status);
}
void InsertRecordForLabelPropertyIndexAndExistenceConstraint(const auto &record_factory, const std::string_view label,
const std::string_view property,
const std::string_view status) {
auto record = record_factory.NewRecord();
record.Insert(std::string(Schema::kReturnLabel).c_str(), label);
record.Insert(std::string(Schema::kReturnKey).c_str(), property);
record.Insert(std::string(Schema::kReturnKeys).c_str(), mgp::List({mgp::Value(property)}));
record.Insert(std::string(Schema::kReturnUnique).c_str(), false);
record.Insert(std::string(Schema::kReturnAction).c_str(), status);
}
void ProcessCreatingLabelIndex(const std::string_view label, const std::set<std::string_view> &existing_label_indices,
mgp_graph *memgraph_graph, const auto &record_factory) {
if (existing_label_indices.contains(label)) {
InsertRecordForLabelIndex(record_factory, label, Schema::kStatusKept);
} else if (mgp::CreateLabelIndex(memgraph_graph, label)) {
InsertRecordForLabelIndex(record_factory, label, Schema::kStatusCreated);
}
}
template <typename TFunc>
void ProcessCreatingLabelPropertyIndexAndExistenceConstraint(const std::string_view label,
const std::string_view property,
const std::set<std::string_view> &existing_collection,
const TFunc &func_creation, mgp_graph *memgraph_graph,
const auto &record_factory) {
const auto label_property_search_key = std::string(label) + ":" + std::string(property);
if (existing_collection.contains(label_property_search_key)) {
InsertRecordForLabelPropertyIndexAndExistenceConstraint(record_factory, label, property, Schema::kStatusKept);
} else if (func_creation(memgraph_graph, label, property)) {
InsertRecordForLabelPropertyIndexAndExistenceConstraint(record_factory, label, property, Schema::kStatusCreated);
}
}
/// We collect properties for which index was created.
using AssertedIndices = std::set<std::string, std::less<>>;
AssertedIndices CreateIndicesForLabel(const std::string_view label, const mgp::Value &properties_val,
mgp_graph *memgraph_graph, const auto &record_factory,
const std::set<std::string_view> &existing_label_indices,
const std::set<std::string_view> &existing_label_property_indices) {
AssertedIndices asserted_indices;
if (!properties_val.IsList()) {
return {};
}
if (const auto properties = properties_val.ValueList();
properties.Empty() && mgp::CreateLabelIndex(memgraph_graph, label)) {
InsertRecordForLabelIndex(record_factory, label, Schema::kStatusCreated);
asserted_indices.emplace("");
} else {
std::for_each(properties.begin(), properties.end(),
[&label, &existing_label_indices, &existing_label_property_indices, &memgraph_graph, &record_factory,
&asserted_indices](const mgp::Value &property) {
if (!property.IsString()) {
return;
}
const auto property_str = property.ValueString();
if (property_str.empty()) {
ProcessCreatingLabelIndex(label, existing_label_indices, memgraph_graph, record_factory);
asserted_indices.emplace("");
} else {
ProcessCreatingLabelPropertyIndexAndExistenceConstraint(
label, property_str, existing_label_property_indices, mgp::CreateLabelPropertyIndex,
memgraph_graph, record_factory);
asserted_indices.emplace(property_str);
}
});
}
return asserted_indices;
}
void ProcessIndices(const mgp::Map &indices_map, mgp_graph *memgraph_graph, const auto &record_factory,
bool drop_existing) {
auto mgp_existing_label_indices = mgp::ListAllLabelIndices(memgraph_graph);
auto mgp_existing_label_property_indices = mgp::ListAllLabelPropertyIndices(memgraph_graph);
std::set<std::string_view> existing_label_indices;
std::transform(mgp_existing_label_indices.begin(), mgp_existing_label_indices.end(),
std::inserter(existing_label_indices, existing_label_indices.begin()),
[](const mgp::Value &index) { return index.ValueString(); });
std::set<std::string_view> existing_label_property_indices;
std::transform(mgp_existing_label_property_indices.begin(), mgp_existing_label_property_indices.end(),
std::inserter(existing_label_property_indices, existing_label_property_indices.begin()),
[](const mgp::Value &index) { return index.ValueString(); });
std::set<std::string> asserted_label_indices;
std::set<std::string> asserted_label_property_indices;
auto merge_label_property = [](const std::string &label, const std::string &property) {
return label + ":" + property;
};
for (const auto &index : indices_map) {
const std::string_view label = index.key;
const mgp::Value &properties_val = index.value;
AssertedIndices asserted_indices_new = CreateIndicesForLabel(
label, properties_val, memgraph_graph, record_factory, existing_label_indices, existing_label_property_indices);
if (!drop_existing) {
continue;
}
std::ranges::for_each(asserted_indices_new, [&asserted_label_indices, &asserted_label_property_indices, label,
&merge_label_property](const std::string &property) {
if (property.empty()) {
asserted_label_indices.emplace(label);
} else {
asserted_label_property_indices.emplace(merge_label_property(std::string(label), property));
}
});
}
if (!drop_existing) {
return;
}
std::set<std::string_view> label_indices_to_drop;
std::ranges::set_difference(existing_label_indices, asserted_label_indices,
std::inserter(label_indices_to_drop, label_indices_to_drop.begin()));
std::ranges::for_each(label_indices_to_drop, [memgraph_graph, &record_factory](const std::string_view label) {
if (mgp::DropLabelIndex(memgraph_graph, label)) {
InsertRecordForLabelIndex(record_factory, label, Schema::kStatusDropped);
}
});
std::set<std::string_view> label_property_indices_to_drop;
std::ranges::set_difference(existing_label_property_indices, asserted_label_property_indices,
std::inserter(label_property_indices_to_drop, label_property_indices_to_drop.begin()));
auto decouple_label_property = [](const std::string_view label_property) {
const auto label_size = label_property.find(':');
const auto label = std::string(label_property.substr(0, label_size));
const auto property = std::string(label_property.substr(label_size + 1));
return std::make_pair(label, property);
};
std::ranges::for_each(label_property_indices_to_drop, [memgraph_graph, &record_factory, decouple_label_property](
const std::string_view label_property) {
const auto [label, property] = decouple_label_property(label_property);
if (mgp::DropLabelPropertyIndex(memgraph_graph, label, property)) {
InsertRecordForLabelPropertyIndexAndExistenceConstraint(record_factory, label, property, Schema::kStatusDropped);
}
});
}
using ExistenceConstraintsStorage = std::set<std::string_view>;
ExistenceConstraintsStorage CreateExistenceConstraintsForLabel(
const std::string_view label, const mgp::Value &properties_val, mgp_graph *memgraph_graph,
const auto &record_factory, const std::set<std::string_view> &existing_existence_constraints) {
ExistenceConstraintsStorage asserted_existence_constraints;
if (!properties_val.IsList()) {
return asserted_existence_constraints;
}
auto validate_property = [](const mgp::Value &property) -> bool {
return property.IsString() && !property.ValueString().empty();
};
const auto &properties = properties_val.ValueList();
std::for_each(properties.begin(), properties.end(),
[&label, &existing_existence_constraints, &asserted_existence_constraints, &memgraph_graph,
&record_factory, &validate_property](const mgp::Value &property) {
if (!validate_property(property)) {
return;
}
const std::string_view property_str = property.ValueString();
asserted_existence_constraints.emplace(property_str);
ProcessCreatingLabelPropertyIndexAndExistenceConstraint(
label, property_str, existing_existence_constraints, mgp::CreateExistenceConstraint,
memgraph_graph, record_factory);
});
return asserted_existence_constraints;
}
void ProcessExistenceConstraints(const mgp::Map &existence_constraints_map, mgp_graph *memgraph_graph,
const auto &record_factory, bool drop_existing) {
auto mgp_existing_existence_constraints = mgp::ListAllExistenceConstraints(memgraph_graph);
std::set<std::string_view> existing_existence_constraints;
std::transform(mgp_existing_existence_constraints.begin(), mgp_existing_existence_constraints.end(),
std::inserter(existing_existence_constraints, existing_existence_constraints.begin()),
[](const mgp::Value &constraint) { return constraint.ValueString(); });
auto merge_label_property = [](const std::string_view label, const std::string_view property) {
auto str = std::string(label) + ":";
str += property;
return str;
};
ExistenceConstraintsStorage asserted_existence_constraints;
for (const auto &existing_constraint : existence_constraints_map) {
const std::string_view label = existing_constraint.key;
const mgp::Value &properties_val = existing_constraint.value;
auto asserted_existence_constraints_new = CreateExistenceConstraintsForLabel(
label, properties_val, memgraph_graph, record_factory, existing_existence_constraints);
if (!drop_existing) {
continue;
}
std::ranges::for_each(asserted_existence_constraints_new, [&asserted_existence_constraints, &merge_label_property,
label](const std::string_view property) {
asserted_existence_constraints.emplace(merge_label_property(label, property));
});
}
if (!drop_existing) {
return;
}
std::set<std::string_view> existence_constraints_to_drop;
std::ranges::set_difference(existing_existence_constraints, asserted_existence_constraints,
std::inserter(existence_constraints_to_drop, existence_constraints_to_drop.begin()));
auto decouple_label_property = [](const std::string_view label_property) {
const auto label_size = label_property.find(':');
const auto label = std::string(label_property.substr(0, label_size));
const auto property = std::string(label_property.substr(label_size + 1));
return std::make_pair(label, property);
};
std::ranges::for_each(existence_constraints_to_drop, [&](const std::string_view label_property) {
const auto [label, property] = decouple_label_property(label_property);
if (mgp::DropExistenceConstraint(memgraph_graph, label, property)) {
InsertRecordForLabelPropertyIndexAndExistenceConstraint(record_factory, label, property, Schema::kStatusDropped);
}
});
}
using AssertedUniqueConstraintsStorage = std::set<std::set<std::string_view>>;
AssertedUniqueConstraintsStorage CreateUniqueConstraintsForLabel(
const std::string_view label, const mgp::Value &unique_props_nested,
const std::map<std::string_view, AssertedUniqueConstraintsStorage> &existing_unique_constraints,
mgp_graph *memgraph_graph, const auto &record_factory) {
AssertedUniqueConstraintsStorage asserted_unique_constraints;
if (!unique_props_nested.IsList()) {
return asserted_unique_constraints;
}
auto validate_unique_constraint_props = [](const mgp::Value &properties) -> bool {
if (!properties.IsList()) {
return false;
}
const auto &properties_list = properties.ValueList();
if (properties_list.Empty()) {
return false;
}
return std::all_of(properties_list.begin(), properties_list.end(), [](const mgp::Value &property) {
return property.IsString() && !property.ValueString().empty();
});
};
auto unique_constraint_exists =
[](const std::string_view label, const std::set<std::string_view> &properties,
const std::map<std::string_view, AssertedUniqueConstraintsStorage> &existing_unique_constraints) -> bool {
auto iter = existing_unique_constraints.find(label);
if (iter == existing_unique_constraints.end()) {
return false;
}
return iter->second.find(properties) != iter->second.end();
};
for (const auto unique_props_nested_list = unique_props_nested.ValueList();
const auto &properties : unique_props_nested_list) {
if (!validate_unique_constraint_props(properties)) {
continue;
}
const auto properties_list = properties.ValueList();
std::set<std::string_view> properties_coll;
std::transform(properties_list.begin(), properties_list.end(),
std::inserter(properties_coll, properties_coll.begin()),
[](const mgp::Value &property) { return property.ValueString(); });
if (unique_constraint_exists(label, properties_coll, existing_unique_constraints)) {
InsertRecordForUniqueConstraint(record_factory, label, properties_list, Schema::kStatusKept);
} else if (mgp::CreateUniqueConstraint(memgraph_graph, label, properties.ptr())) {
InsertRecordForUniqueConstraint(record_factory, label, properties_list, Schema::kStatusCreated);
}
asserted_unique_constraints.emplace(std::move(properties_coll));
}
return asserted_unique_constraints;
}
void ProcessUniqueConstraints(const mgp::Map &unique_constraints_map, mgp_graph *memgraph_graph,
const auto &record_factory, bool drop_existing) {
auto mgp_existing_unique_constraints = mgp::ListAllUniqueConstraints(memgraph_graph);
// label-unique_constraints pair
std::map<std::string_view, AssertedUniqueConstraintsStorage> existing_unique_constraints;
for (const auto &constraint : mgp_existing_unique_constraints) {
auto constraint_list = constraint.ValueList();
std::set<std::string_view> properties;
for (int i = 1; i < constraint_list.Size(); i++) {
properties.emplace(constraint_list[i].ValueString());
}
const std::string_view label = constraint_list[0].ValueString();
auto [it, inserted] = existing_unique_constraints.try_emplace(label, AssertedUniqueConstraintsStorage{properties});
if (!inserted) {
it->second.emplace(std::move(properties));
}
}
std::map<std::string_view, AssertedUniqueConstraintsStorage> asserted_unique_constraints;
for (const auto &[label, unique_props_nested] : unique_constraints_map) {
auto asserted_unique_constraints_new = CreateUniqueConstraintsForLabel(
label, unique_props_nested, existing_unique_constraints, memgraph_graph, record_factory);
if (drop_existing) {
asserted_unique_constraints.emplace(label, std::move(asserted_unique_constraints_new));
}
}
if (!drop_existing) {
return;
}
std::vector<std::pair<std::string_view, std::set<std::string_view>>> unique_constraints_to_drop;
// Check for each label for we found existing constraint in the DB whether it was asserted.
// If no unique constraint was found with label, we can drop all unique constraints for this label. (if branch)
// If some unique constraint was found with label, we can drop only those unique constraints that were not asserted.
// (else branch.)
std::ranges::for_each(existing_unique_constraints, [&asserted_unique_constraints, &unique_constraints_to_drop](
const auto &existing_label_unique_constraints) {
const auto &label = existing_label_unique_constraints.first;
const auto &existing_unique_constraints_for_label = existing_label_unique_constraints.second;
const auto &asserted_unique_constraints_for_label = asserted_unique_constraints.find(label);
if (asserted_unique_constraints_for_label == asserted_unique_constraints.end()) {
std::ranges::for_each(
std::make_move_iterator(existing_unique_constraints_for_label.begin()),
std::make_move_iterator(existing_unique_constraints_for_label.end()),
[&unique_constraints_to_drop, &label](std::set<std::string_view> existing_unique_constraint_for_label) {
unique_constraints_to_drop.emplace_back(label, std::move(existing_unique_constraint_for_label));
});
} else {
const auto &asserted_unique_constraints_for_label_coll = asserted_unique_constraints_for_label->second;
std::ranges::for_each(
std::make_move_iterator(existing_unique_constraints_for_label.begin()),
std::make_move_iterator(existing_unique_constraints_for_label.end()),
[&unique_constraints_to_drop, &label, &asserted_unique_constraints_for_label_coll](
std::set<std::string_view> existing_unique_constraint_for_label) {
if (!asserted_unique_constraints_for_label_coll.contains(existing_unique_constraint_for_label)) {
unique_constraints_to_drop.emplace_back(label, std::move(existing_unique_constraint_for_label));
}
});
}
});
std::ranges::for_each(
unique_constraints_to_drop, [memgraph_graph, &record_factory](const auto &label_unique_constraint) {
const auto &[label, unique_constraint] = label_unique_constraint;
auto unique_constraint_list = mgp::List();
std::ranges::for_each(unique_constraint, [&unique_constraint_list](const std::string_view &property) {
unique_constraint_list.AppendExtend(mgp::Value(property));
});
if (mgp::DropUniqueConstraint(memgraph_graph, label, mgp::Value(unique_constraint_list).ptr())) {
InsertRecordForUniqueConstraint(record_factory, label, unique_constraint_list, Schema::kStatusDropped);
}
});
}
void Schema::Assert(mgp_list *args, mgp_graph *memgraph_graph, mgp_result *result, mgp_memory *memory) {
mgp::MemoryDispatcherGuard guard{memory};
const auto record_factory = mgp::RecordFactory(result);
auto arguments = mgp::List(args);
auto indices_map = arguments[0].ValueMap();
auto unique_constraints_map = arguments[1].ValueMap();
auto existence_constraints_map = arguments[2].ValueMap();
auto drop_existing = arguments[3].ValueBool();
ProcessIndices(indices_map, memgraph_graph, record_factory, drop_existing);
ProcessExistenceConstraints(existence_constraints_map, memgraph_graph, record_factory, drop_existing);
ProcessUniqueConstraints(unique_constraints_map, memgraph_graph, record_factory, drop_existing);
}
extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) { extern "C" int mgp_init_module(struct mgp_module *module, struct mgp_memory *memory) {
try { try {
mgp::MemoryDispatcherGuard guard{memory}; mgp::MemoryDispatcherGuard guard{memory};
;
AddProcedure(Schema::NodeTypeProperties, std::string(Schema::kProcedureNodeType).c_str(), mgp::ProcedureType::Read, AddProcedure(Schema::NodeTypeProperties, Schema::kProcedureNodeType, mgp::ProcedureType::Read, {},
{}, {mgp::Return(Schema::kReturnNodeType, mgp::Type::String),
{mgp::Return(std::string(Schema::kReturnNodeType).c_str(), mgp::Type::String), mgp::Return(Schema::kReturnLabels, {mgp::Type::List, mgp::Type::String}),
mgp::Return(std::string(Schema::kReturnLabels).c_str(), {mgp::Type::List, mgp::Type::String}), mgp::Return(Schema::kReturnPropertyName, mgp::Type::String),
mgp::Return(std::string(Schema::kReturnPropertyName).c_str(), mgp::Type::String), mgp::Return(Schema::kReturnPropertyType, mgp::Type::Any),
mgp::Return(std::string(Schema::kReturnPropertyType).c_str(), mgp::Type::Any), mgp::Return(Schema::kReturnMandatory, mgp::Type::Bool)},
mgp::Return(std::string(Schema::kReturnMandatory).c_str(), mgp::Type::Bool)},
module, memory); module, memory);
AddProcedure(Schema::RelTypeProperties, std::string(Schema::kProcedureRelType).c_str(), mgp::ProcedureType::Read, AddProcedure(Schema::RelTypeProperties, Schema::kProcedureRelType, mgp::ProcedureType::Read, {},
{}, {mgp::Return(Schema::kReturnRelType, mgp::Type::String),
{mgp::Return(std::string(Schema::kReturnRelType).c_str(), mgp::Type::String), mgp::Return(Schema::kReturnPropertyName, mgp::Type::String),
mgp::Return(std::string(Schema::kReturnPropertyName).c_str(), mgp::Type::String), mgp::Return(Schema::kReturnPropertyType, mgp::Type::Any),
mgp::Return(std::string(Schema::kReturnPropertyType).c_str(), mgp::Type::Any), mgp::Return(Schema::kReturnMandatory, mgp::Type::Bool)},
mgp::Return(std::string(Schema::kReturnMandatory).c_str(), mgp::Type::Bool)},
module, memory); module, memory);
AddProcedure(
Schema::Assert, Schema::kProcedureAssert, mgp::ProcedureType::Read,
{
mgp::Parameter(Schema::kParameterIndices, {mgp::Type::Map, mgp::Type::Any}),
mgp::Parameter(Schema::kParameterUniqueConstraints, {mgp::Type::Map, mgp::Type::Any}),
mgp::Parameter(Schema::kParameterExistenceConstraints, {mgp::Type::Map, mgp::Type::Any},
mgp::Value(mgp::Map{})),
mgp::Parameter(Schema::kParameterDropExisting, mgp::Type::Bool, mgp::Value(true)),
},
{mgp::Return(Schema::kReturnLabel, mgp::Type::String), mgp::Return(Schema::kReturnKey, mgp::Type::String),
mgp::Return(Schema::kReturnKeys, {mgp::Type::List, mgp::Type::String}),
mgp::Return(Schema::kReturnUnique, mgp::Type::Bool), mgp::Return(Schema::kReturnAction, mgp::Type::String)},
module, memory);
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << "Error while initializing query module: " << e.what() << std::endl;
return 1; return 1;
} }

View File

@ -104,7 +104,9 @@ def retry(retry_limit, timeout=100):
except Exception: except Exception:
time.sleep(timeout) time.sleep(timeout)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper return wrapper
return inner_func return inner_func
@ -200,19 +202,19 @@ if args.version:
try: try:
current_branch = get_output("git", "rev-parse", "--abbrev-ref", "HEAD") current_branch = get_output("git", "rev-parse", "--abbrev-ref", "HEAD")
if current_branch != "master": if current_branch != "master":
branches = get_output("git", "branch") branches = get_output("git", "branch", "-r", "--list", "origin/master")
if "master" in branches: if "origin/master" in branches:
# If master is present locally, the fetch is allowed to fail # If master is present locally, the fetch is allowed to fail
# because this script will still be able to compare against the # because this script will still be able to compare against the
# master branch. # master branch.
try: try:
get_output("git", "fetch", "origin", "master:master") get_output("git", "fetch", "origin", "master")
except Exception: except Exception:
pass pass
else: else:
# If master is not present locally, the fetch command has to # If master is not present locally, the fetch command has to
# succeed because something else will fail otherwise. # succeed because something else will fail otherwise.
get_output("git", "fetch", "origin", "master:master") get_output("git", "fetch", "origin", "master")
except Exception: except Exception:
print("Fatal error while ensuring local master branch.") print("Fatal error while ensuring local master branch.")
sys.exit(1) sys.exit(1)
@ -232,7 +234,7 @@ for branch in branches:
match = branch_regex.match(branch) match = branch_regex.match(branch)
if match is not None: if match is not None:
version = tuple(map(int, match.group(1).split("."))) version = tuple(map(int, match.group(1).split(".")))
master_branch_merge = get_output("git", "merge-base", "master", branch) master_branch_merge = get_output("git", "merge-base", "origin/master", branch)
versions.append((version, branch, master_branch_merge)) versions.append((version, branch, master_branch_merge))
versions.sort(reverse=True) versions.sort(reverse=True)
@ -243,7 +245,7 @@ current_version = None
for version in versions: for version in versions:
version_tuple, branch, master_branch_merge = version version_tuple, branch, master_branch_merge = version
current_branch_merge = get_output("git", "merge-base", current_hash, branch) current_branch_merge = get_output("git", "merge-base", current_hash, branch)
master_current_merge = get_output("git", "merge-base", current_hash, "master") master_current_merge = get_output("git", "merge-base", current_hash, "origin/master")
# The first check checks whether this commit is a child of `master` and # The first check checks whether this commit is a child of `master` and
# the version branch was created before us. # the version branch was created before us.
# The second check checks whether this commit is a child of the version # The second check checks whether this commit is a child of the version

View File

@ -13,6 +13,7 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <json/json.hpp> #include <json/json.hpp>
#include <utility>
#include "storage/v2/temporal.hpp" #include "storage/v2/temporal.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
@ -87,8 +88,8 @@ inline nlohmann::json PropertyValueToJson(const storage::PropertyValue &pv) {
return ret; return ret;
} }
Log::Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis) Log::Log(std::filesystem::path storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis)
: storage_directory_(storage_directory), : storage_directory_(std::move(storage_directory)),
buffer_size_(buffer_size), buffer_size_(buffer_size),
buffer_flush_interval_millis_(buffer_flush_interval_millis), buffer_flush_interval_millis_(buffer_flush_interval_millis),
started_(false) {} started_(false) {}

View File

@ -36,7 +36,7 @@ class Log {
}; };
public: public:
Log(const std::filesystem::path &storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis); Log(std::filesystem::path storage_directory, int32_t buffer_size, int32_t buffer_flush_interval_millis);
~Log(); ~Log();

View File

@ -10,6 +10,7 @@
#include <cstdint> #include <cstdint>
#include <regex> #include <regex>
#include <utility>
#include <gflags/gflags.h> #include <gflags/gflags.h>
@ -560,20 +561,20 @@ Databases Databases::Deserialize(const nlohmann::json &data) {
} }
#endif #endif
User::User() {} User::User() = default;
User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {} User::User(const std::string &username) : username_(utils::ToLowerCase(username)) {}
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions) User::User(const std::string &username, std::string password_hash, const Permissions &permissions)
: username_(utils::ToLowerCase(username)), password_hash_(password_hash), permissions_(permissions) {} : username_(utils::ToLowerCase(username)), password_hash_(std::move(password_hash)), permissions_(permissions) {}
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
User::User(const std::string &username, const std::string &password_hash, const Permissions &permissions, User::User(const std::string &username, std::string password_hash, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access) FineGrainedAccessHandler fine_grained_access_handler, Databases db_access)
: username_(utils::ToLowerCase(username)), : username_(utils::ToLowerCase(username)),
password_hash_(password_hash), password_hash_(std::move(password_hash)),
permissions_(permissions), permissions_(permissions),
fine_grained_access_handler_(std::move(fine_grained_access_handler)), fine_grained_access_handler_(std::move(fine_grained_access_handler)),
database_access_(db_access) {} database_access_(std::move(db_access)) {}
#endif #endif
bool User::CheckPassword(const std::string &password) { bool User::CheckPassword(const std::string &password) {

View File

@ -14,6 +14,7 @@
#include <unordered_map> #include <unordered_map>
#include <json/json.hpp> #include <json/json.hpp>
#include <utility>
#include "dbms/constants.hpp" #include "dbms/constants.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
@ -301,8 +302,8 @@ class Databases final {
bool Contains(const std::string &db) const; bool Contains(const std::string &db) const;
bool GetAllowAll() const { return allow_all_; } bool GetAllowAll() const { return allow_all_; }
const std::set<std::string> &GetGrants() const { return grants_dbs_; } const std::set<std::string, std::less<>> &GetGrants() const { return grants_dbs_; }
const std::set<std::string> &GetDenies() const { return denies_dbs_; } const std::set<std::string, std::less<>> &GetDenies() const { return denies_dbs_; }
const std::string &GetDefault() const; const std::string &GetDefault() const;
nlohmann::json Serialize() const; nlohmann::json Serialize() const;
@ -310,14 +311,17 @@ class Databases final {
static Databases Deserialize(const nlohmann::json &data); static Databases Deserialize(const nlohmann::json &data);
private: private:
Databases(bool allow_all, std::set<std::string> grant, std::set<std::string> deny, Databases(bool allow_all, std::set<std::string, std::less<>> grant, std::set<std::string, std::less<>> deny,
const std::string &default_db = dbms::kDefaultDB) std::string default_db = dbms::kDefaultDB)
: grants_dbs_(grant), denies_dbs_(deny), allow_all_(allow_all), default_db_(default_db) {} : grants_dbs_(std::move(grant)),
denies_dbs_(std::move(deny)),
allow_all_(allow_all),
default_db_(std::move(default_db)) {}
std::set<std::string> grants_dbs_; //!< set of databases with granted access std::set<std::string, std::less<>> grants_dbs_; //!< set of databases with granted access
std::set<std::string> denies_dbs_; //!< set of databases with denied access std::set<std::string, std::less<>> denies_dbs_; //!< set of databases with denied access
bool allow_all_; //!< flag to allow access to everything (denied overrides this) bool allow_all_; //!< flag to allow access to everything (denied overrides this)
std::string default_db_; //!< user's default database std::string default_db_; //!< user's default database
}; };
#endif #endif
@ -327,9 +331,9 @@ class User final {
User(); User();
explicit User(const std::string &username); explicit User(const std::string &username);
User(const std::string &username, const std::string &password_hash, const Permissions &permissions); User(const std::string &username, std::string password_hash, const Permissions &permissions);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
User(const std::string &username, const std::string &password_hash, const Permissions &permissions, User(const std::string &username, std::string password_hash, const Permissions &permissions,
FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {}); FineGrainedAccessHandler fine_grained_access_handler, Databases db_access = {});
#endif #endif
User(const User &) = default; User(const User &) = default;

View File

@ -14,6 +14,7 @@
#include <map> #include <map>
#include <optional> #include <optional>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/codes.hpp"
@ -34,8 +35,8 @@ class FailureResponseException : public utils::BasicException {
explicit FailureResponseException(const std::string &message) : utils::BasicException{message} {} explicit FailureResponseException(const std::string &message) : utils::BasicException{message} {}
FailureResponseException(const std::string &code, const std::string &message) FailureResponseException(std::string code, const std::string &message)
: utils::BasicException{message}, code_{code} {} : utils::BasicException{message}, code_{std::move(code)} {}
const std::string &code() const { return code_; } const std::string &code() const { return code_; }
SPECIALIZE_GET_EXCEPTION_NAME(FailureResponseException) SPECIALIZE_GET_EXCEPTION_NAME(FailureResponseException)

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -51,7 +51,7 @@ enum class ChunkState : uint8_t {
template <typename TBuffer> template <typename TBuffer>
class ChunkedDecoderBuffer { class ChunkedDecoderBuffer {
public: public:
ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) { data_.reserve(kChunkMaxDataSize); } explicit ChunkedDecoderBuffer(TBuffer &buffer) : buffer_(buffer) { data_.reserve(kChunkMaxDataSize); }
/** /**
* Reads data from the internal buffer. * Reads data from the internal buffer.

View File

@ -401,11 +401,11 @@ class Decoder {
} }
auto &labels = dv.ValueList(); auto &labels = dv.ValueList();
vertex.labels.reserve(labels.size()); vertex.labels.reserve(labels.size());
for (size_t i = 0; i < labels.size(); ++i) { for (auto &label : labels) {
if (labels[i].type() != Value::Type::String) { if (label.type() != Value::Type::String) {
return false; return false;
} }
vertex.labels.emplace_back(std::move(labels[i].ValueString())); vertex.labels.emplace_back(std::move(label.ValueString()));
} }
// read properties // read properties

View File

@ -111,12 +111,12 @@ class BaseEncoder {
void WriteList(const std::vector<Value> &value) { void WriteList(const std::vector<Value> &value) {
WriteTypeSize(value.size(), MarkerList); WriteTypeSize(value.size(), MarkerList);
for (auto &x : value) WriteValue(x); for (const auto &x : value) WriteValue(x);
} }
void WriteMap(const std::map<std::string, Value> &value) { void WriteMap(const std::map<std::string, Value> &value) {
WriteTypeSize(value.size(), MarkerMap); WriteTypeSize(value.size(), MarkerMap);
for (auto &x : value) { for (const auto &x : value) {
WriteString(x.first); WriteString(x.first);
WriteValue(x.second); WriteValue(x.second);
} }
@ -205,11 +205,11 @@ class BaseEncoder {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct) + 3); WriteRAW(utils::UnderlyingCast(Marker::TinyStruct) + 3);
WriteRAW(utils::UnderlyingCast(Signature::Path)); WriteRAW(utils::UnderlyingCast(Signature::Path));
WriteTypeSize(path.vertices.size(), MarkerList); WriteTypeSize(path.vertices.size(), MarkerList);
for (auto &v : path.vertices) WriteVertex(v); for (const auto &v : path.vertices) WriteVertex(v);
WriteTypeSize(path.edges.size(), MarkerList); WriteTypeSize(path.edges.size(), MarkerList);
for (auto &e : path.edges) WriteEdge(e); for (const auto &e : path.edges) WriteEdge(e);
WriteTypeSize(path.indices.size(), MarkerList); WriteTypeSize(path.indices.size(), MarkerList);
for (auto &i : path.indices) WriteInt(i); for (const auto &i : path.indices) WriteInt(i);
} }
void WriteDate(const utils::Date &date) { void WriteDate(const utils::Date &date) {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -48,7 +48,7 @@ namespace memgraph::communication::bolt {
template <class TOutputStream> template <class TOutputStream>
class ChunkedEncoderBuffer { class ChunkedEncoderBuffer {
public: public:
ChunkedEncoderBuffer(TOutputStream &output_stream) : output_stream_(output_stream) {} explicit ChunkedEncoderBuffer(TOutputStream &output_stream) : output_stream_(output_stream) {}
/** /**
* Writes n values into the buffer. If n is bigger than whole chunk size * Writes n values into the buffer. If n is bigger than whole chunk size

View File

@ -39,7 +39,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
using BaseEncoder<Buffer>::buffer_; using BaseEncoder<Buffer>::buffer_;
public: public:
ClientEncoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) {} explicit ClientEncoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) {}
using BaseEncoder<Buffer>::UpdateVersion; using BaseEncoder<Buffer>::UpdateVersion;

View File

@ -32,7 +32,7 @@ class Encoder : private BaseEncoder<Buffer> {
using BaseEncoder<Buffer>::buffer_; using BaseEncoder<Buffer>::buffer_;
public: public:
Encoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) {} explicit Encoder(Buffer &buffer) : BaseEncoder<Buffer>(buffer) {}
using BaseEncoder<Buffer>::UpdateVersion; using BaseEncoder<Buffer>::UpdateVersion;

View File

@ -93,7 +93,7 @@ State HandlePullDiscard(TSession &session, std::optional<int> n, std::optional<i
return State::Close; return State::Close;
} }
if (summary.count("has_more") && summary.at("has_more").ValueBool()) { if (summary.contains("has_more") && summary.at("has_more").ValueBool()) {
return State::Result; return State::Result;
} }
@ -148,13 +148,13 @@ State HandlePullDiscardV4(TSession &session, const State state, const Marker mar
spdlog::trace("Couldn't read extra field!"); spdlog::trace("Couldn't read extra field!");
} }
const auto &extra_map = extra.ValueMap(); const auto &extra_map = extra.ValueMap();
if (extra_map.count("n")) { if (extra_map.contains("n")) {
if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) { if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) {
n = n_value; n = n_value;
} }
} }
if (extra_map.count("qid")) { if (extra_map.contains("qid")) {
if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) { if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) {
qid = qid_value; qid = qid_value;
} }
@ -367,14 +367,16 @@ State HandleReset(TSession &session, const Marker marker) {
return State::Close; return State::Close;
} }
if (!session.encoder_.MessageSuccess()) { try {
spdlog::trace("Couldn't send success message!"); session.Abort();
return State::Close; if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
} }
session.Abort();
return State::Idle;
} }
template <typename TSession> template <typename TSession>
@ -397,19 +399,17 @@ State HandleBegin(TSession &session, const State state, const Marker marker) {
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
try { try {
session.Configure(extra.ValueMap()); session.Configure(extra.ValueMap());
session.BeginTransaction(extra.ValueMap()); session.BeginTransaction(extra.ValueMap());
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
return State::Idle;
} catch (const std::exception &e) { } catch (const std::exception &e) {
return HandleFailure(session, e); return HandleFailure(session, e);
} }
return State::Idle;
} }
template <typename TSession> template <typename TSession>
@ -427,11 +427,11 @@ State HandleCommit(TSession &session, const State state, const Marker marker) {
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try { try {
session.CommitTransaction();
if (!session.encoder_.MessageSuccess({})) { if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!"); spdlog::trace("Couldn't send success message!");
return State::Close; return State::Close;
} }
session.CommitTransaction();
return State::Idle; return State::Idle;
} catch (const std::exception &e) { } catch (const std::exception &e) {
return HandleFailure(session, e); return HandleFailure(session, e);
@ -453,11 +453,11 @@ State HandleRollback(TSession &session, const State state, const Marker marker)
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try { try {
session.RollbackTransaction();
if (!session.encoder_.MessageSuccess({})) { if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!"); spdlog::trace("Couldn't send success message!");
return State::Close; return State::Close;
} }
session.RollbackTransaction();
return State::Idle; return State::Idle;
} catch (const std::exception &e) { } catch (const std::exception &e) {
return HandleFailure(session, e); return HandleFailure(session, e);

View File

@ -42,11 +42,11 @@ std::optional<State> AuthenticateUser(TSession &session, Value &metadata) {
std::string username; std::string username;
std::string password; std::string password;
if (data["scheme"].ValueString() == "basic") { if (data["scheme"].ValueString() == "basic") {
if (!data.count("principal")) { // Special case principal = "" if (!data.contains("principal")) { // Special case principal = ""
spdlog::warn("The client didn't supply the principal field! Trying with \"\"..."); spdlog::warn("The client didn't supply the principal field! Trying with \"\"...");
data["principal"] = ""; data["principal"] = "";
} }
if (!data.count("credentials")) { // Special case credentials = "" if (!data.contains("credentials")) { // Special case credentials = ""
spdlog::warn("The client didn't supply the credentials field! Trying with \"\"..."); spdlog::warn("The client didn't supply the credentials field! Trying with \"\"...");
data["credentials"] = ""; data["credentials"] = "";
} }
@ -118,7 +118,7 @@ std::optional<Value> GetMetadataV4(TSession &session, const Marker marker) {
} }
auto &data = metadata.ValueMap(); auto &data = metadata.ValueMap();
if (!data.count("user_agent")) { if (!data.contains("user_agent")) {
spdlog::warn("The client didn't supply the user agent!"); spdlog::warn("The client didn't supply the user agent!");
return std::nullopt; return std::nullopt;
} }
@ -142,7 +142,7 @@ std::optional<Value> GetInitDataV5(TSession &session, const Marker marker) {
} }
const auto &data = metadata.ValueMap(); const auto &data = metadata.ValueMap();
if (!data.count("user_agent")) { if (!data.contains("user_agent")) {
spdlog::warn("The client didn't supply the user agent!"); spdlog::warn("The client didn't supply the user agent!");
return std::nullopt; return std::nullopt;
} }

View File

@ -91,7 +91,7 @@ struct UnboundedEdge {
* The decoder writes data into this structure. * The decoder writes data into this structure.
*/ */
struct Path { struct Path {
Path() {} Path() = default;
Path(const std::vector<Vertex> &vertices, const std::vector<Edge> &edges) { Path(const std::vector<Vertex> &vertices, const std::vector<Edge> &edges) {
// Helper function. Looks for the given element in the collection. If found, // Helper function. Looks for the given element in the collection. If found,

View File

@ -132,7 +132,7 @@ class Client final {
*/ */
class ClientInputStream final { class ClientInputStream final {
public: public:
ClientInputStream(Client &client); explicit ClientInputStream(Client &client);
ClientInputStream(const ClientInputStream &) = delete; ClientInputStream(const ClientInputStream &) = delete;
ClientInputStream(ClientInputStream &&) = delete; ClientInputStream(ClientInputStream &&) = delete;
@ -156,7 +156,7 @@ class ClientInputStream final {
*/ */
class ClientOutputStream final { class ClientOutputStream final {
public: public:
ClientOutputStream(Client &client); explicit ClientOutputStream(Client &client);
ClientOutputStream(const ClientOutputStream &) = delete; ClientOutputStream(const ClientOutputStream &) = delete;
ClientOutputStream(ClientOutputStream &&) = delete; ClientOutputStream(ClientOutputStream &&) = delete;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -34,7 +34,7 @@ ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
} }
ClientContext::ClientContext(const std::string &key_file, const std::string &cert_file) : ClientContext(true) { ClientContext::ClientContext(const std::string &key_file, const std::string &cert_file) : ClientContext(true) {
if (key_file != "" && cert_file != "") { if (!key_file.empty() && !cert_file.empty()) {
MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1, MG_ASSERT(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), SSL_FILETYPE_PEM) == 1,
"Couldn't load client certificate from file: {}", cert_file); "Couldn't load client certificate from file: {}", cert_file);
MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1, MG_ASSERT(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == 1,
@ -124,7 +124,7 @@ ServerContext &ServerContext::operator=(ServerContext &&other) noexcept {
return *this; return *this;
} }
ServerContext::~ServerContext() {} ServerContext::~ServerContext() = default;
SSL_CTX *ServerContext::context() { SSL_CTX *ServerContext::context() {
MG_ASSERT(ctx_); MG_ASSERT(ctx_);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -15,7 +15,7 @@
namespace memgraph::communication { namespace memgraph::communication {
const std::string SslGetLastError() { std::string SslGetLastError() {
char buff[2048]; char buff[2048];
auto err = ERR_get_error(); auto err = ERR_get_error();
ERR_error_string_n(err, buff, sizeof(buff)); ERR_error_string_n(err, buff, sizeof(buff));

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -18,6 +18,6 @@ namespace memgraph::communication {
/** /**
* This function reads and returns a string describing the last OpenSSL error. * This function reads and returns a string describing the last OpenSSL error.
*/ */
const std::string SslGetLastError(); std::string SslGetLastError();
} // namespace memgraph::communication } // namespace memgraph::communication

View File

@ -38,7 +38,7 @@ class Listener final : public std::enable_shared_from_this<Listener<TRequestHand
Listener(Listener &&) = delete; Listener(Listener &&) = delete;
Listener &operator=(const Listener &) = delete; Listener &operator=(const Listener &) = delete;
Listener &operator=(Listener &&) = delete; Listener &operator=(Listener &&) = delete;
~Listener() {} ~Listener() = default;
template <typename... Args> template <typename... Args>
static std::shared_ptr<Listener> Create(Args &&...args) { static std::shared_ptr<Listener> Create(Args &&...args) {

View File

@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <utility>
#include <gflags/gflags.h> #include <gflags/gflags.h>
@ -51,13 +52,13 @@ class Listener final {
using SessionHandler = Session<TSession, TSessionContext>; using SessionHandler = Session<TSession, TSessionContext>;
public: public:
Listener(TSessionContext *data, ServerContext *context, int inactivity_timeout_sec, const std::string &service_name, Listener(TSessionContext *data, ServerContext *context, int inactivity_timeout_sec, std::string service_name,
size_t workers_count) size_t workers_count)
: data_(data), : data_(data),
alive_(false), alive_(false),
context_(context), context_(context),
inactivity_timeout_sec_(inactivity_timeout_sec), inactivity_timeout_sec_(inactivity_timeout_sec),
service_name_(service_name), service_name_(std::move(service_name)),
workers_count_(workers_count) {} workers_count_(workers_count) {}
~Listener() { ~Listener() {

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -80,7 +80,7 @@ class ResultStreamFaker {
std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); }); std::transform(header.begin(), header.end(), column_widths.begin(), [](const auto &s) { return s.size(); });
// convert all the results into strings, and track max column width // convert all the results into strings, and track max column width
auto &results_data = results.GetResults(); const auto &results_data = results.GetResults();
std::vector<std::vector<std::string>> result_strings(results_data.size(), std::vector<std::vector<std::string>> result_strings(results_data.size(),
std::vector<std::string>(column_widths.size())); std::vector<std::string>(column_widths.size()));
for (int row_ind = 0; row_ind < static_cast<int>(results_data.size()); ++row_ind) { for (int row_ind = 0; row_ind < static_cast<int>(results_data.size()); ++row_ind) {

View File

@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <utility>
#include <openssl/bio.h> #include <openssl/bio.h>
#include <openssl/err.h> #include <openssl/err.h>
@ -51,7 +52,8 @@ using InputStream = Buffer::ReadEnd;
*/ */
class OutputStream final { class OutputStream final {
public: public:
OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function) : write_function_(write_function) {} explicit OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(std::move(write_function)) {}
OutputStream(const OutputStream &) = delete; OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete; OutputStream(OutputStream &&) = delete;

View File

@ -47,7 +47,7 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
Listener(Listener &&) = delete; Listener(Listener &&) = delete;
Listener &operator=(const Listener &) = delete; Listener &operator=(const Listener &) = delete;
Listener &operator=(Listener &&) = delete; Listener &operator=(Listener &&) = delete;
~Listener() {} ~Listener() = default;
template <typename... Args> template <typename... Args>
static std::shared_ptr<Listener> Create(Args &&...args) { static std::shared_ptr<Listener> Create(Args &&...args) {

View File

@ -76,7 +76,7 @@ using tcp = boost::asio::ip::tcp;
class OutputStream final { class OutputStream final {
public: public:
explicit OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function) explicit OutputStream(std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(write_function) {} : write_function_(std::move(write_function)) {}
OutputStream(const OutputStream &) = delete; OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete; OutputStream(OutputStream &&) = delete;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,6 +16,7 @@
#include <spdlog/sinks/base_sink.h> #include <spdlog/sinks/base_sink.h>
#include <boost/asio/io_context.hpp> #include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/tcp.hpp>
#include <utility>
#include "communication/websocket/listener.hpp" #include "communication/websocket/listener.hpp"
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
@ -45,7 +46,7 @@ class Server final {
class LoggingSink : public spdlog::sinks::base_sink<std::mutex> { class LoggingSink : public spdlog::sinks::base_sink<std::mutex> {
public: public:
explicit LoggingSink(std::weak_ptr<Listener> listener) : listener_(listener) {} explicit LoggingSink(std::weak_ptr<Listener> listener) : listener_(std::move(listener)) {}
private: private:
void sink_it_(const spdlog::details::log_msg &msg) override; void sink_it_(const spdlog::details::log_msg &msg) override;

View File

@ -1,3 +1,3 @@
add_library(mg-dbms STATIC database.cpp replication_handler.cpp inmemory/replication_handlers.cpp) add_library(mg-dbms STATIC dbms_handler.cpp database.cpp replication_handler.cpp replication_client.cpp inmemory/replication_handlers.cpp)
target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query) target_link_libraries(mg-dbms mg-utils mg-storage-v2 mg-query)

View File

@ -15,4 +15,10 @@ namespace memgraph::dbms {
constexpr static const char *kDefaultDB = "memgraph"; //!< Name of the default database constexpr static const char *kDefaultDB = "memgraph"; //!< Name of the default database
#ifdef MG_EXPERIMENTAL_REPLICATION_MULTITENANCY
constexpr bool allow_mt_repl = true;
#else
constexpr bool allow_mt_repl = false;
#endif
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -21,7 +21,7 @@ template struct memgraph::utils::Gatekeeper<memgraph::dbms::Database>;
namespace memgraph::dbms { namespace memgraph::dbms {
Database::Database(storage::Config config, const replication::ReplicationState &repl_state) Database::Database(storage::Config config, replication::ReplicationState &repl_state)
: trigger_store_(config.durability.storage_directory / "triggers"), : trigger_store_(config.durability.storage_directory / "triggers"),
streams_{config.durability.storage_directory / "streams"}, streams_{config.durability.storage_directory / "streams"},
plan_cache_{FLAGS_query_plan_cache_max_size}, plan_cache_{FLAGS_query_plan_cache_max_size},

View File

@ -48,7 +48,7 @@ class Database {
* *
* @param config storage configuration * @param config storage configuration
*/ */
explicit Database(storage::Config config, const replication::ReplicationState &repl_state); explicit Database(storage::Config config, replication::ReplicationState &repl_state);
/** /**
* @brief Returns the raw storage pointer. * @brief Returns the raw storage pointer.
@ -95,7 +95,7 @@ class Database {
* *
* @return storage::StorageMode * @return storage::StorageMode
*/ */
storage::StorageMode GetStorageMode() const { return storage_->GetStorageMode(); } storage::StorageMode GetStorageMode() const noexcept { return storage_->GetStorageMode(); }
/** /**
* @brief Get the storage info * @brief Get the storage info

View File

@ -51,8 +51,7 @@ class DatabaseHandler : public Handler<Database> {
* @param config Storage configuration * @param config Storage configuration
* @return HandlerT::NewResult * @return HandlerT::NewResult
*/ */
HandlerT::NewResult New(std::string_view name, storage::Config config, HandlerT::NewResult New(std::string_view name, storage::Config config, replication::ReplicationState &repl_state) {
const replication::ReplicationState &repl_state) {
// Control that no one is using the same data directory // Control that no one is using the same data directory
if (std::any_of(begin(), end(), [&](auto &elem) { if (std::any_of(begin(), end(), [&](auto &elem) {
auto db_acc = elem.second.access(); auto db_acc = elem.second.access();

75
src/dbms/dbms_handler.cpp Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/dbms_handler.hpp"
namespace memgraph::dbms {
#ifdef MG_ENTERPRISE
DbmsHandler::DbmsHandler(
storage::Config config,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
bool recovery_on_startup, bool delete_on_drop)
: default_config_{std::move(config)},
delete_on_drop_(delete_on_drop),
repl_state_{ReplicationStateRootPath(default_config_)} {
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
storage::UpdatePaths(default_config_, default_config_.durability.storage_directory / "databases");
const auto &db_dir = default_config_.durability.storage_directory;
const auto durability_dir = db_dir / ".durability";
utils::EnsureDirOrDie(db_dir);
utils::EnsureDirOrDie(durability_dir);
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
// Generate the default database
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
// Recover previous databases
if (recovery_on_startup) {
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue; // Already set
spdlog::info("Restoring database {}.", name);
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
spdlog::info("Database {} restored.", name);
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth->Lock();
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue;
locked_auth->DeleteDatabase(name);
durability_->Delete(name);
}
}
// Startup replication state (if recovered at startup)
auto replica = [this](replication::RoleReplicaData const &data) {
// Register handlers
InMemoryReplicationHandlers::Register(this, *data.server);
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
};
// Replication frequent check start
auto main = [this](replication::RoleMainData &data) {
for (auto &client : data.registered_replicas_) {
StartReplicaClient(*this, client);
}
return true;
};
// Startup proccess for main/replica
MG_ASSERT(std::visit(memgraph::utils::Overloaded{replica, main}, repl_state_.ReplicationData()),
"Replica recovery failure!");
}
#endif
} // namespace memgraph::dbms

View File

@ -26,9 +26,11 @@
#include "auth/auth.hpp" #include "auth/auth.hpp"
#include "constants.hpp" #include "constants.hpp"
#include "dbms/database.hpp" #include "dbms/database.hpp"
#include "dbms/inmemory/replication_handlers.hpp"
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
#include "dbms/database_handler.hpp" #include "dbms/database_handler.hpp"
#endif #endif
#include "dbms/replication_client.hpp"
#include "global.hpp" #include "global.hpp"
#include "query/config.hpp" #include "query/config.hpp"
#include "query/interpreter_context.hpp" #include "query/interpreter_context.hpp"
@ -102,52 +104,22 @@ class DbmsHandler {
* @param recovery_on_startup restore databases (and its content) and authentication data * @param recovery_on_startup restore databases (and its content) and authentication data
* @param delete_on_drop when dropping delete any associated directories on disk * @param delete_on_drop when dropping delete any associated directories on disk
*/ */
DbmsHandler(storage::Config config, const replication::ReplicationState &repl_state, auto *auth, DbmsHandler(storage::Config config,
bool recovery_on_startup, bool delete_on_drop) memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
: lock_{utils::RWLock::Priority::READ}, bool recovery_on_startup, bool delete_on_drop); // TODO If more arguments are added use a config strut
default_config_{std::move(config)},
repl_state_(repl_state),
delete_on_drop_(delete_on_drop) {
// TODO: Decouple storage config from dbms config
// TODO: Save individual db configs inside the kvstore and restore from there
storage::UpdatePaths(default_config_, default_config_.durability.storage_directory / "databases");
const auto &db_dir = default_config_.durability.storage_directory;
const auto durability_dir = db_dir / ".durability";
utils::EnsureDirOrDie(db_dir);
utils::EnsureDirOrDie(durability_dir);
durability_ = std::make_unique<kvstore::KVStore>(durability_dir);
// Generate the default database
MG_ASSERT(!NewDefault_().HasError(), "Failed while creating the default DB.");
// Recover previous databases
if (recovery_on_startup) {
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue; // Already set
spdlog::info("Restoring database {}.", name);
MG_ASSERT(!New_(name).HasError(), "Failed while creating database {}.", name);
spdlog::info("Database {} restored.", name);
}
} else { // Clear databases from the durability list and auth
auto locked_auth = auth->Lock();
for (const auto &[name, _] : *durability_) {
if (name == kDefaultDB) continue;
locked_auth->DeleteDatabase(name);
durability_->Delete(name);
}
}
}
#else #else
/** /**
* @brief Initialize the handler. A single database is supported in community edition. * @brief Initialize the handler. A single database is supported in community edition.
* *
* @param configs storage configuration * @param configs storage configuration
*/ */
DbmsHandler(storage::Config config, const replication::ReplicationState &repl_state) DbmsHandler(storage::Config config)
: db_gatekeeper_{[&] { : repl_state_{ReplicationStateRootPath(config)},
db_gatekeeper_{[&] {
config.name = kDefaultDB; config.name = kDefaultDB;
return std::move(config); return std::move(config);
}(), }(),
repl_state} {} repl_state_} {}
#endif #endif
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
@ -248,6 +220,12 @@ class DbmsHandler {
#endif #endif
} }
replication::ReplicationState &ReplicationState() { return repl_state_; }
replication::ReplicationState const &ReplicationState() const { return repl_state_; }
bool IsMain() const { return repl_state_.IsMain(); }
bool IsReplica() const { return repl_state_.IsReplica(); }
/** /**
* @brief Return the statistics all databases. * @brief Return the statistics all databases.
* *
@ -536,14 +514,15 @@ class DbmsHandler {
throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name); throw UnknownDatabaseException("Tried to retrieve an unknown database \"{}\".", name);
} }
mutable LockT lock_; //!< protective lock mutable LockT lock_{utils::RWLock::Priority::READ}; //!< protective lock
storage::Config default_config_; //!< Storage configuration used when creating new databases storage::Config default_config_; //!< Storage configuration used when creating new databases
const replication::ReplicationState &repl_state_; //!< Global replication state DatabaseHandler db_handler_; //!< multi-tenancy storage handler
DatabaseHandler db_handler_; //!< multi-tenancy storage handler std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation)
std::unique_ptr<kvstore::KVStore> durability_; //!< list of active dbs (pointer so we can postpone its creation) bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory
bool delete_on_drop_; //!< Flag defining if dropping storage also deletes its directory std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures
std::set<std::string> defunct_dbs_; //!< Databases that are in an unknown state due to various failures #endif
#else replication::ReplicationState repl_state_; //!< Global replication state
#ifndef MG_ENTERPRISE
mutable utils::Gatekeeper<Database> db_gatekeeper_; //!< Single databases gatekeeper mutable utils::Gatekeeper<Database> db_gatekeeper_; //!< Single databases gatekeeper
#endif #endif
}; };

View File

@ -38,7 +38,7 @@ class Handler {
* @brief Empty Handler constructor. * @brief Empty Handler constructor.
* *
*/ */
Handler() {} Handler() = default;
/** /**
* @brief Generate a new context and corresponding configuration. * @brief Generate a new context and corresponding configuration.

View File

@ -10,6 +10,7 @@
// licenses/APL.txt. // licenses/APL.txt.
#include "dbms/inmemory/replication_handlers.hpp" #include "dbms/inmemory/replication_handlers.hpp"
#include <optional>
#include "dbms/constants.hpp" #include "dbms/constants.hpp"
#include "dbms/dbms_handler.hpp" #include "dbms/dbms_handler.hpp"
#include "replication/replication_server.hpp" #include "replication/replication_server.hpp"
@ -187,9 +188,9 @@ void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handle
storage::replication::Decoder decoder(req_reader); storage::replication::Decoder decoder(req_reader);
auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage()); auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage());
utils::EnsureDirOrDie(storage->snapshot_directory_); utils::EnsureDirOrDie(storage->recovery_.snapshot_directory_);
const auto maybe_snapshot_path = decoder.ReadFile(storage->snapshot_directory_); const auto maybe_snapshot_path = decoder.ReadFile(storage->recovery_.snapshot_directory_);
MG_ASSERT(maybe_snapshot_path, "Failed to load snapshot!"); MG_ASSERT(maybe_snapshot_path, "Failed to load snapshot!");
spdlog::info("Received snapshot saved to {}", *maybe_snapshot_path); spdlog::info("Received snapshot saved to {}", *maybe_snapshot_path);
@ -219,8 +220,12 @@ void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handle
storage->timestamp_ = std::max(storage->timestamp_, recovery_info.next_timestamp); storage->timestamp_ = std::max(storage->timestamp_, recovery_info.next_timestamp);
spdlog::trace("Recovering indices and constraints from snapshot."); spdlog::trace("Recovering indices and constraints from snapshot.");
storage::durability::RecoverIndicesAndConstraints(recovered_snapshot.indices_constraints, &storage->indices_, memgraph::storage::durability::RecoverIndicesAndStats(recovered_snapshot.indices_constraints.indices,
&storage->constraints_, &storage->vertices_); &storage->indices_, &storage->vertices_,
storage->name_id_mapper_.get());
memgraph::storage::durability::RecoverConstraints(recovered_snapshot.indices_constraints.constraints,
&storage->constraints_, &storage->vertices_,
storage->name_id_mapper_.get());
} catch (const storage::durability::RecoveryFailure &e) { } catch (const storage::durability::RecoveryFailure &e) {
LOG_FATAL("Couldn't load the snapshot because of: {}", e.what()); LOG_FATAL("Couldn't load the snapshot because of: {}", e.what());
} }
@ -232,7 +237,7 @@ void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handle
spdlog::trace("Deleting old snapshot files due to snapshot recovery."); spdlog::trace("Deleting old snapshot files due to snapshot recovery.");
// Delete other durability files // Delete other durability files
auto snapshot_files = storage::durability::GetSnapshotFiles(storage->snapshot_directory_, storage->uuid_); auto snapshot_files = storage::durability::GetSnapshotFiles(storage->recovery_.snapshot_directory_, storage->uuid_);
for (const auto &[path, uuid, _] : snapshot_files) { for (const auto &[path, uuid, _] : snapshot_files) {
if (path != *maybe_snapshot_path) { if (path != *maybe_snapshot_path) {
spdlog::trace("Deleting snapshot file {}", path); spdlog::trace("Deleting snapshot file {}", path);
@ -241,7 +246,7 @@ void InMemoryReplicationHandlers::SnapshotHandler(dbms::DbmsHandler *dbms_handle
} }
spdlog::trace("Deleting old WAL files due to snapshot recovery."); spdlog::trace("Deleting old WAL files due to snapshot recovery.");
auto wal_files = storage::durability::GetWalFiles(storage->wal_directory_, storage->uuid_); auto wal_files = storage::durability::GetWalFiles(storage->recovery_.wal_directory_, storage->uuid_);
if (wal_files) { if (wal_files) {
for (const auto &wal_file : *wal_files) { for (const auto &wal_file : *wal_files) {
spdlog::trace("Deleting WAL file {}", wal_file.path); spdlog::trace("Deleting WAL file {}", wal_file.path);
@ -266,7 +271,7 @@ void InMemoryReplicationHandlers::WalFilesHandler(dbms::DbmsHandler *dbms_handle
storage::replication::Decoder decoder(req_reader); storage::replication::Decoder decoder(req_reader);
auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage()); auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage());
utils::EnsureDirOrDie(storage->wal_directory_); utils::EnsureDirOrDie(storage->recovery_.wal_directory_);
for (auto i = 0; i < wal_file_number; ++i) { for (auto i = 0; i < wal_file_number; ++i) {
LoadWal(storage, &decoder); LoadWal(storage, &decoder);
@ -288,7 +293,7 @@ void InMemoryReplicationHandlers::CurrentWalHandler(dbms::DbmsHandler *dbms_hand
storage::replication::Decoder decoder(req_reader); storage::replication::Decoder decoder(req_reader);
auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage()); auto *storage = static_cast<storage::InMemoryStorage *>(db_acc->get()->storage());
utils::EnsureDirOrDie(storage->wal_directory_); utils::EnsureDirOrDie(storage->recovery_.wal_directory_);
LoadWal(storage, &decoder); LoadWal(storage, &decoder);
@ -369,8 +374,9 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
constexpr bool kSharedAccess = false; constexpr bool kSharedAccess = false;
std::optional<std::pair<uint64_t, storage::InMemoryStorage::ReplicationAccessor>> commit_timestamp_and_accessor; std::optional<std::pair<uint64_t, storage::InMemoryStorage::ReplicationAccessor>> commit_timestamp_and_accessor;
auto get_transaction = [storage, &commit_timestamp_and_accessor](uint64_t commit_timestamp, auto const get_transaction = [storage, &commit_timestamp_and_accessor](
bool unique = kSharedAccess) { uint64_t commit_timestamp,
bool unique = kSharedAccess) -> storage::InMemoryStorage::ReplicationAccessor * {
if (!commit_timestamp_and_accessor) { if (!commit_timestamp_and_accessor) {
std::unique_ptr<storage::Storage::Accessor> acc = nullptr; std::unique_ptr<storage::Storage::Accessor> acc = nullptr;
if (unique) { if (unique) {
@ -414,9 +420,11 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
spdlog::trace(" Delete vertex {}", delta.vertex_create_delete.gid.AsUint()); spdlog::trace(" Delete vertex {}", delta.vertex_create_delete.gid.AsUint());
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto vertex = transaction->FindVertex(delta.vertex_create_delete.gid, View::NEW); auto vertex = transaction->FindVertex(delta.vertex_create_delete.gid, View::NEW);
if (!vertex) throw utils::BasicException("Invalid transaction!"); if (!vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto ret = transaction->DeleteVertex(&*vertex); auto ret = transaction->DeleteVertex(&*vertex);
if (ret.HasError() || !ret.GetValue()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError() || !ret.GetValue())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::VERTEX_ADD_LABEL: { case WalDeltaData::Type::VERTEX_ADD_LABEL: {
@ -424,9 +432,11 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
delta.vertex_add_remove_label.label); delta.vertex_add_remove_label.label);
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto vertex = transaction->FindVertex(delta.vertex_add_remove_label.gid, View::NEW); auto vertex = transaction->FindVertex(delta.vertex_add_remove_label.gid, View::NEW);
if (!vertex) throw utils::BasicException("Invalid transaction!"); if (!vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto ret = vertex->AddLabel(transaction->NameToLabel(delta.vertex_add_remove_label.label)); auto ret = vertex->AddLabel(transaction->NameToLabel(delta.vertex_add_remove_label.label));
if (ret.HasError() || !ret.GetValue()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError() || !ret.GetValue())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::VERTEX_REMOVE_LABEL: { case WalDeltaData::Type::VERTEX_REMOVE_LABEL: {
@ -434,9 +444,11 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
delta.vertex_add_remove_label.label); delta.vertex_add_remove_label.label);
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto vertex = transaction->FindVertex(delta.vertex_add_remove_label.gid, View::NEW); auto vertex = transaction->FindVertex(delta.vertex_add_remove_label.gid, View::NEW);
if (!vertex) throw utils::BasicException("Invalid transaction!"); if (!vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto ret = vertex->RemoveLabel(transaction->NameToLabel(delta.vertex_add_remove_label.label)); auto ret = vertex->RemoveLabel(transaction->NameToLabel(delta.vertex_add_remove_label.label));
if (ret.HasError() || !ret.GetValue()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError() || !ret.GetValue())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::VERTEX_SET_PROPERTY: { case WalDeltaData::Type::VERTEX_SET_PROPERTY: {
@ -444,10 +456,12 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
delta.vertex_edge_set_property.property, delta.vertex_edge_set_property.value); delta.vertex_edge_set_property.property, delta.vertex_edge_set_property.value);
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto vertex = transaction->FindVertex(delta.vertex_edge_set_property.gid, View::NEW); auto vertex = transaction->FindVertex(delta.vertex_edge_set_property.gid, View::NEW);
if (!vertex) throw utils::BasicException("Invalid transaction!"); if (!vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto ret = vertex->SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), auto ret = vertex->SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property),
delta.vertex_edge_set_property.value); delta.vertex_edge_set_property.value);
if (ret.HasError()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::EDGE_CREATE: { case WalDeltaData::Type::EDGE_CREATE: {
@ -456,13 +470,16 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
delta.edge_create_delete.from_vertex.AsUint(), delta.edge_create_delete.to_vertex.AsUint()); delta.edge_create_delete.from_vertex.AsUint(), delta.edge_create_delete.to_vertex.AsUint());
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto from_vertex = transaction->FindVertex(delta.edge_create_delete.from_vertex, View::NEW); auto from_vertex = transaction->FindVertex(delta.edge_create_delete.from_vertex, View::NEW);
if (!from_vertex) throw utils::BasicException("Invalid transaction!"); if (!from_vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto to_vertex = transaction->FindVertex(delta.edge_create_delete.to_vertex, View::NEW); auto to_vertex = transaction->FindVertex(delta.edge_create_delete.to_vertex, View::NEW);
if (!to_vertex) throw utils::BasicException("Invalid transaction!"); if (!to_vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto edge = transaction->CreateEdgeEx(&*from_vertex, &*to_vertex, auto edge = transaction->CreateEdgeEx(&*from_vertex, &*to_vertex,
transaction->NameToEdgeType(delta.edge_create_delete.edge_type), transaction->NameToEdgeType(delta.edge_create_delete.edge_type),
delta.edge_create_delete.gid); delta.edge_create_delete.gid);
if (edge.HasError()) throw utils::BasicException("Invalid transaction!"); if (edge.HasError())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::EDGE_DELETE: { case WalDeltaData::Type::EDGE_DELETE: {
@ -471,16 +488,17 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
delta.edge_create_delete.from_vertex.AsUint(), delta.edge_create_delete.to_vertex.AsUint()); delta.edge_create_delete.from_vertex.AsUint(), delta.edge_create_delete.to_vertex.AsUint());
auto *transaction = get_transaction(timestamp); auto *transaction = get_transaction(timestamp);
auto from_vertex = transaction->FindVertex(delta.edge_create_delete.from_vertex, View::NEW); auto from_vertex = transaction->FindVertex(delta.edge_create_delete.from_vertex, View::NEW);
if (!from_vertex) throw utils::BasicException("Invalid transaction!"); if (!from_vertex)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto to_vertex = transaction->FindVertex(delta.edge_create_delete.to_vertex, View::NEW); auto to_vertex = transaction->FindVertex(delta.edge_create_delete.to_vertex, View::NEW);
if (!to_vertex) throw utils::BasicException("Invalid transaction!"); if (!to_vertex)
auto edges = from_vertex->OutEdges(View::NEW, {transaction->NameToEdgeType(delta.edge_create_delete.edge_type)}, throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
&*to_vertex); auto edgeType = transaction->NameToEdgeType(delta.edge_create_delete.edge_type);
if (edges.HasError()) throw utils::BasicException("Invalid transaction!"); auto edge =
if (edges->edges.size() != 1) throw utils::BasicException("Invalid transaction!"); transaction->FindEdge(delta.edge_create_delete.gid, View::NEW, edgeType, &*from_vertex, &*to_vertex);
auto &edge = (*edges).edges[0]; if (!edge) throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
auto ret = transaction->DeleteEdge(&edge); if (auto ret = transaction->DeleteEdge(&*edge); ret.HasError())
if (ret.HasError()) throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::EDGE_SET_PROPERTY: { case WalDeltaData::Type::EDGE_SET_PROPERTY: {
@ -497,7 +515,8 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
// yields an accessor that is only valid for managing the edge's // yields an accessor that is only valid for managing the edge's
// properties. // properties.
auto edge = edge_acc.find(delta.vertex_edge_set_property.gid); auto edge = edge_acc.find(delta.vertex_edge_set_property.gid);
if (edge == edge_acc.end()) throw utils::BasicException("Invalid transaction!"); if (edge == edge_acc.end())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
// The edge visibility check must be done here manually because we // The edge visibility check must be done here manually because we
// don't allow direct access to the edges through the public API. // don't allow direct access to the edges through the public API.
{ {
@ -529,7 +548,8 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
} }
} }
}); });
if (!is_visible) throw utils::BasicException("Invalid transaction!"); if (!is_visible)
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
} }
EdgeRef edge_ref(&*edge); EdgeRef edge_ref(&*edge);
// Here we create an edge accessor that we will use to get the // Here we create an edge accessor that we will use to get the
@ -542,7 +562,8 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property), auto ret = ea.SetProperty(transaction->NameToProperty(delta.vertex_edge_set_property.property),
delta.vertex_edge_set_property.value); delta.vertex_edge_set_property.value);
if (ret.HasError()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
@ -552,7 +573,8 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
throw utils::BasicException("Invalid commit data!"); throw utils::BasicException("Invalid commit data!");
auto ret = auto ret =
commit_timestamp_and_accessor->second.Commit(commit_timestamp_and_accessor->first, false /* not main */); commit_timestamp_and_accessor->second.Commit(commit_timestamp_and_accessor->first, false /* not main */);
if (ret.HasError()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
commit_timestamp_and_accessor = std::nullopt; commit_timestamp_and_accessor = std::nullopt;
break; break;
} }
@ -562,14 +584,14 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
// Need to send the timestamp // Need to send the timestamp
auto *transaction = get_transaction(timestamp, kUniqueAccess); auto *transaction = get_transaction(timestamp, kUniqueAccess);
if (transaction->CreateIndex(storage->NameToLabel(delta.operation_label.label)).HasError()) if (transaction->CreateIndex(storage->NameToLabel(delta.operation_label.label)).HasError())
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::LABEL_INDEX_DROP: { case WalDeltaData::Type::LABEL_INDEX_DROP: {
spdlog::trace(" Drop label index on :{}", delta.operation_label.label); spdlog::trace(" Drop label index on :{}", delta.operation_label.label);
auto *transaction = get_transaction(timestamp, kUniqueAccess); auto *transaction = get_transaction(timestamp, kUniqueAccess);
if (transaction->DropIndex(storage->NameToLabel(delta.operation_label.label)).HasError()) if (transaction->DropIndex(storage->NameToLabel(delta.operation_label.label)).HasError())
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::LABEL_INDEX_STATS_SET: { case WalDeltaData::Type::LABEL_INDEX_STATS_SET: {
@ -600,7 +622,7 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
->CreateIndex(storage->NameToLabel(delta.operation_label_property.label), ->CreateIndex(storage->NameToLabel(delta.operation_label_property.label),
storage->NameToProperty(delta.operation_label_property.property)) storage->NameToProperty(delta.operation_label_property.property))
.HasError()) .HasError())
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::LABEL_PROPERTY_INDEX_DROP: { case WalDeltaData::Type::LABEL_PROPERTY_INDEX_DROP: {
@ -611,7 +633,7 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
->DropIndex(storage->NameToLabel(delta.operation_label_property.label), ->DropIndex(storage->NameToLabel(delta.operation_label_property.label),
storage->NameToProperty(delta.operation_label_property.property)) storage->NameToProperty(delta.operation_label_property.property))
.HasError()) .HasError())
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::LABEL_PROPERTY_INDEX_STATS_SET: { case WalDeltaData::Type::LABEL_PROPERTY_INDEX_STATS_SET: {
@ -643,7 +665,8 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
auto ret = auto ret =
transaction->CreateExistenceConstraint(storage->NameToLabel(delta.operation_label_property.label), transaction->CreateExistenceConstraint(storage->NameToLabel(delta.operation_label_property.label),
storage->NameToProperty(delta.operation_label_property.property)); storage->NameToProperty(delta.operation_label_property.property));
if (ret.HasError()) throw utils::BasicException("Invalid transaction!"); if (ret.HasError())
throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::EXISTENCE_CONSTRAINT_DROP: { case WalDeltaData::Type::EXISTENCE_CONSTRAINT_DROP: {
@ -654,7 +677,7 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
->DropExistenceConstraint(storage->NameToLabel(delta.operation_label_property.label), ->DropExistenceConstraint(storage->NameToLabel(delta.operation_label_property.label),
storage->NameToProperty(delta.operation_label_property.property)) storage->NameToProperty(delta.operation_label_property.property))
.HasError()) .HasError())
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::UNIQUE_CONSTRAINT_CREATE: { case WalDeltaData::Type::UNIQUE_CONSTRAINT_CREATE: {
@ -669,7 +692,7 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
auto ret = transaction->CreateUniqueConstraint(storage->NameToLabel(delta.operation_label_properties.label), auto ret = transaction->CreateUniqueConstraint(storage->NameToLabel(delta.operation_label_properties.label),
properties); properties);
if (!ret.HasValue() || ret.GetValue() != UniqueConstraints::CreationStatus::SUCCESS) if (!ret.HasValue() || ret.GetValue() != UniqueConstraints::CreationStatus::SUCCESS)
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
break; break;
} }
case WalDeltaData::Type::UNIQUE_CONSTRAINT_DROP: { case WalDeltaData::Type::UNIQUE_CONSTRAINT_DROP: {
@ -684,7 +707,7 @@ uint64_t InMemoryReplicationHandlers::ReadAndApplyDelta(storage::InMemoryStorage
auto ret = auto ret =
transaction->DropUniqueConstraint(storage->NameToLabel(delta.operation_label_properties.label), properties); transaction->DropUniqueConstraint(storage->NameToLabel(delta.operation_label_properties.label), properties);
if (ret != UniqueConstraints::DeletionStatus::SUCCESS) { if (ret != UniqueConstraints::DeletionStatus::SUCCESS) {
throw utils::BasicException("Invalid transaction!"); throw utils::BasicException("Invalid transaction! Please raise an issue, {}:{}", __FILE__, __LINE__);
} }
break; break;
} }

View File

@ -22,14 +22,8 @@
namespace memgraph::dbms { namespace memgraph::dbms {
#ifdef MG_EXPERIMENTAL_REPLICATION_MULTITENANCY inline std::unique_ptr<storage::Storage> CreateInMemoryStorage(storage::Config config,
constexpr bool allow_mt_repl = true; ::memgraph::replication::ReplicationState &repl_state) {
#else
constexpr bool allow_mt_repl = false;
#endif
inline std::unique_ptr<storage::Storage> CreateInMemoryStorage(
storage::Config config, const ::memgraph::replication::ReplicationState &repl_state) {
const auto wal_mode = config.durability.snapshot_wal_mode; const auto wal_mode = config.durability.snapshot_wal_mode;
const auto name = config.name; const auto name = config.name;
auto storage = std::make_unique<storage::InMemoryStorage>(std::move(config)); auto storage = std::make_unique<storage::InMemoryStorage>(std::move(config));

View File

@ -0,0 +1,34 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "dbms/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client) {
// No client error, start instance level client
auto const &endpoint = client.rpc_client_.Endpoint();
spdlog::trace("Replication client started at: {}:{}", endpoint.address, endpoint.port);
client.StartFrequentCheck([&dbms_handler](std::string_view name) {
// Working connection, check if any database has been left behind
dbms_handler.ForEach([name](dbms::Database *db) {
// Specific database <-> replica client
db->storage()->repl_storage_state_.WithClient(name, [&](storage::ReplicationStorageClient *client) {
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
// Database <-> replica might be behind, check and recover
client->TryCheckReplicaStateAsync(db->storage());
}
});
});
});
}
} // namespace memgraph::dbms

View File

@ -0,0 +1,21 @@
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "dbms/dbms_handler.hpp"
#include "replication/replication_client.hpp"
namespace memgraph::dbms {
void StartReplicaClient(DbmsHandler &dbms_handler, replication::ReplicationClient &client);
} // namespace memgraph::dbms

View File

@ -15,6 +15,7 @@
#include "dbms/dbms_handler.hpp" #include "dbms/dbms_handler.hpp"
#include "dbms/inmemory/replication_handlers.hpp" #include "dbms/inmemory/replication_handlers.hpp"
#include "dbms/inmemory/storage_helper.hpp" #include "dbms/inmemory/storage_helper.hpp"
#include "dbms/replication_client.hpp"
#include "replication/state.hpp" #include "replication/state.hpp"
using memgraph::replication::ReplicationClientConfig; using memgraph::replication::ReplicationClientConfig;
@ -41,6 +42,8 @@ std::string RegisterReplicaErrorToString(RegisterReplicaError error) {
} }
} // namespace } // namespace
ReplicationHandler::ReplicationHandler(DbmsHandler &dbms_handler) : dbms_handler_(dbms_handler) {}
bool ReplicationHandler::SetReplicationRoleMain() { bool ReplicationHandler::SetReplicationRoleMain() {
auto const main_handler = [](RoleMainData const &) { auto const main_handler = [](RoleMainData const &) {
// If we are already MAIN, we don't want to change anything // If we are already MAIN, we don't want to change anything
@ -56,42 +59,49 @@ bool ReplicationHandler::SetReplicationRoleMain() {
// STEP 2) Change to MAIN // STEP 2) Change to MAIN
// TODO: restore replication servers if false? // TODO: restore replication servers if false?
if (!repl_state_.SetReplicationRoleMain()) { if (!dbms_handler_.ReplicationState().SetReplicationRoleMain()) {
// TODO: Handle recovery on failure??? // TODO: Handle recovery on failure???
return false; return false;
} }
// STEP 3) We are now MAIN, update storage local epoch // STEP 3) We are now MAIN, update storage local epoch
const auto &epoch =
std::get<RoleMainData>(std::as_const(dbms_handler_.ReplicationState()).ReplicationData()).epoch_;
dbms_handler_.ForEach([&](Database *db) { dbms_handler_.ForEach([&](Database *db) {
auto *storage = db->storage(); auto *storage = db->storage();
storage->repl_storage_state_.epoch_ = std::get<RoleMainData>(std::as_const(repl_state_).ReplicationData()).epoch_; storage->repl_storage_state_.epoch_ = epoch;
}); });
return true; return true;
}; };
// TODO: under lock // TODO: under lock
return std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
} }
bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) { bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::ReplicationServerConfig &config) {
// We don't want to restart the server if we're already a REPLICA // We don't want to restart the server if we're already a REPLICA
if (repl_state_.IsReplica()) { if (dbms_handler_.ReplicationState().IsReplica()) {
return false; return false;
} }
// Remove registered replicas // TODO StorageState needs to be synched. Could have a dangling reference if someone adds a database as we are
// deleting the replica.
// Remove database specific clients
dbms_handler_.ForEach([&](Database *db) { dbms_handler_.ForEach([&](Database *db) {
auto *storage = db->storage(); auto *storage = db->storage();
storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); }); storage->repl_storage_state_.replication_clients_.WithLock([](auto &clients) { clients.clear(); });
}); });
// Remove instance level clients
std::get<RoleMainData>(dbms_handler_.ReplicationState().ReplicationData()).registered_replicas_.clear();
// Creates the server // Creates the server
repl_state_.SetReplicationRoleReplica(config); dbms_handler_.ReplicationState().SetReplicationRoleReplica(config);
// Start // Start
const auto success = const auto success =
std::visit(utils::Overloaded{[](auto) { std::visit(utils::Overloaded{[](RoleMainData const &) {
// ASSERT // ASSERT
return false; return false;
}, },
@ -104,36 +114,37 @@ bool ReplicationHandler::SetReplicationRoleReplica(const memgraph::replication::
} }
return true; return true;
}}, }},
repl_state_.ReplicationData()); dbms_handler_.ReplicationState().ReplicationData());
// TODO Handle error (restore to main?) // TODO Handle error (restore to main?)
return success; return success;
} }
auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config) auto ReplicationHandler::RegisterReplica(const memgraph::replication::ReplicationClientConfig &config)
-> memgraph::utils::BasicResult<RegisterReplicaError> { -> memgraph::utils::BasicResult<RegisterReplicaError> {
MG_ASSERT(repl_state_.IsMain(), "Only main instance can register a replica!"); MG_ASSERT(dbms_handler_.ReplicationState().IsMain(), "Only main instance can register a replica!");
auto res = repl_state_.RegisterReplica(config); auto instance_client = dbms_handler_.ReplicationState().RegisterReplica(config);
switch (res) { if (instance_client.HasError()) switch (instance_client.GetError()) {
case memgraph::replication::RegisterReplicaError::NOT_MAIN: case memgraph::replication::RegisterReplicaError::NOT_MAIN:
MG_ASSERT(false, "Only main instance can register a replica!"); MG_ASSERT(false, "Only main instance can register a replica!");
return {}; return {};
case memgraph::replication::RegisterReplicaError::NAME_EXISTS: case memgraph::replication::RegisterReplicaError::NAME_EXISTS:
return memgraph::dbms::RegisterReplicaError::NAME_EXISTS; return memgraph::dbms::RegisterReplicaError::NAME_EXISTS;
case memgraph::replication::RegisterReplicaError::END_POINT_EXISTS: case memgraph::replication::RegisterReplicaError::END_POINT_EXISTS:
return memgraph::dbms::RegisterReplicaError::END_POINT_EXISTS; return memgraph::dbms::RegisterReplicaError::END_POINT_EXISTS;
case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED: case memgraph::replication::RegisterReplicaError::COULD_NOT_BE_PERSISTED:
return memgraph::dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED; return memgraph::dbms::RegisterReplicaError::COULD_NOT_BE_PERSISTED;
case memgraph::replication::RegisterReplicaError::SUCCESS: case memgraph::replication::RegisterReplicaError::SUCCESS:
break; break;
} }
bool all_clients_good = true;
if (!allow_mt_repl && dbms_handler_.All().size() > 1) { if (!allow_mt_repl && dbms_handler_.All().size() > 1) {
spdlog::warn("Multi-tenant replication is currently not supported!"); spdlog::warn("Multi-tenant replication is currently not supported!");
} }
bool all_clients_good = true;
// Add database specific clients (NOTE Currently all databases are connected to each replica)
dbms_handler_.ForEach([&](Database *db) { dbms_handler_.ForEach([&](Database *db) {
auto *storage = db->storage(); auto *storage = db->storage();
if (!allow_mt_repl && storage->id() != kDefaultDB) { if (!allow_mt_repl && storage->id() != kDefaultDB) {
@ -143,18 +154,29 @@ auto ReplicationHandler::RegisterReplica(const memgraph::replication::Replicatio
if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return; if (storage->storage_mode_ != storage::StorageMode::IN_MEMORY_TRANSACTIONAL) return;
all_clients_good &= all_clients_good &=
storage->repl_storage_state_.replication_clients_.WithLock([storage, &config](auto &clients) -> bool { storage->repl_storage_state_.replication_clients_.WithLock([storage, &instance_client](auto &storage_clients) {
auto client = storage->CreateReplicationClient(config, &storage->repl_storage_state_.epoch_); auto client = std::make_unique<storage::ReplicationStorageClient>(*instance_client.GetValue());
client->Start(); client->Start(storage);
// After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
if (client->State() == storage::replication::ReplicaState::INVALID) { // MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
return false; return false;
} }
clients.push_back(std::move(client)); storage_clients.push_back(std::move(client));
return true; return true;
}); });
}); });
if (!all_clients_good) return RegisterReplicaError::CONNECTION_FAILED; // TODO: this happen to 1 or many...what to do
// NOTE Currently if any databases fails, we revert back
if (!all_clients_good) {
spdlog::error("Failed to register all databases to the REPLICA \"{}\"", config.name);
UnregisterReplica(config.name);
return RegisterReplicaError::CONNECTION_FAILED;
}
// No client error, start instance level client
StartReplicaClient(dbms_handler_, *instance_client.GetValue());
return {}; return {};
} }
@ -163,60 +185,66 @@ auto ReplicationHandler::UnregisterReplica(std::string_view name) -> UnregisterR
return UnregisterReplicaResult::NOT_MAIN; return UnregisterReplicaResult::NOT_MAIN;
}; };
auto const main_handler = [this, name](RoleMainData &mainData) -> UnregisterReplicaResult { auto const main_handler = [this, name](RoleMainData &mainData) -> UnregisterReplicaResult {
if (!repl_state_.TryPersistUnregisterReplica(name)) { if (!dbms_handler_.ReplicationState().TryPersistUnregisterReplica(name)) {
return UnregisterReplicaResult::COULD_NOT_BE_PERSISTED; return UnregisterReplicaResult::COULD_NOT_BE_PERSISTED;
} }
auto const n_unregistered = // Remove database specific clients
std::erase_if(mainData.registered_replicas_, dbms_handler_.ForEach([name](Database *db) {
[&](ReplicationClientConfig const &registered_config) { return registered_config.name == name; }); db->storage()->repl_storage_state_.replication_clients_.WithLock([&name](auto &clients) {
std::erase_if(clients, [name](const auto &client) { return client->Name() == name; });
dbms_handler_.ForEach([&](Database *db) { });
db->storage()->repl_storage_state_.replication_clients_.WithLock(
[&](auto &clients) { std::erase_if(clients, [&](const auto &client) { return client->Name() == name; }); });
}); });
// Remove instance level clients
auto const n_unregistered =
std::erase_if(mainData.registered_replicas_, [name](auto const &client) { return client.name_ == name; });
return n_unregistered != 0 ? UnregisterReplicaResult::SUCCESS : UnregisterReplicaResult::CAN_NOT_UNREGISTER; return n_unregistered != 0 ? UnregisterReplicaResult::SUCCESS : UnregisterReplicaResult::CAN_NOT_UNREGISTER;
}; };
return std::visit(utils::Overloaded{main_handler, replica_handler}, repl_state_.ReplicationData()); return std::visit(utils::Overloaded{main_handler, replica_handler},
dbms_handler_.ReplicationState().ReplicationData());
} }
auto ReplicationHandler::GetRole() const -> memgraph::replication::ReplicationRole { return repl_state_.GetRole(); } auto ReplicationHandler::GetRole() const -> memgraph::replication::ReplicationRole {
return dbms_handler_.ReplicationState().GetRole();
}
bool ReplicationHandler::IsMain() const { return repl_state_.IsMain(); } bool ReplicationHandler::IsMain() const { return dbms_handler_.ReplicationState().IsMain(); }
bool ReplicationHandler::IsReplica() const { return repl_state_.IsReplica(); } bool ReplicationHandler::IsReplica() const { return dbms_handler_.ReplicationState().IsReplica(); }
void RestoreReplication(const replication::ReplicationState &repl_state, storage::Storage &storage) { // Per storage
// NOTE Storage will connect to all replicas. Future work might change this
void RestoreReplication(replication::ReplicationState &repl_state, storage::Storage &storage) {
spdlog::info("Restoring replication role."); spdlog::info("Restoring replication role.");
/// MAIN /// MAIN
auto const recover_main = [&storage](RoleMainData const &mainData) { auto const recover_main = [&storage](RoleMainData &mainData) {
for (const auto &config : mainData.registered_replicas_) { // Each individual client has already been restored and started. Here we just go through each database and start its
spdlog::info("Replica {} restoration started for {}.", config.name, storage.id()); // client
for (auto &instance_client : mainData.registered_replicas_) {
spdlog::info("Replica {} restoration started for {}.", instance_client.name_, storage.id());
auto register_replica = [&storage](const memgraph::replication::ReplicationClientConfig &config) const auto &ret = storage.repl_storage_state_.replication_clients_.WithLock(
-> memgraph::utils::BasicResult<RegisterReplicaError> { [&](auto &storage_clients) -> utils::BasicResult<RegisterReplicaError> {
return storage.repl_storage_state_.replication_clients_.WithLock( auto client = std::make_unique<storage::ReplicationStorageClient>(instance_client);
[&storage, &config](auto &clients) -> utils::BasicResult<RegisterReplicaError> { client->Start(&storage);
auto client = storage.CreateReplicationClient(config, &storage.repl_storage_state_.epoch_); // After start the storage <-> replica state should be READY or RECOVERING (if correctly started)
client->Start(); // MAYBE_BEHIND isn't a statement of the current state, this is the default value
// Failed to start due to branching of MAIN and REPLICA
if (client->State() == storage::replication::ReplicaState::MAYBE_BEHIND) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
instance_client.name_);
}
storage_clients.push_back(std::move(client));
return {};
});
if (client->State() == storage::replication::ReplicaState::INVALID) {
spdlog::warn("Connection failed when registering replica {}. Replica will still be registered.",
client->Name());
}
clients.push_back(std::move(client));
return {};
});
};
auto ret = register_replica(config);
if (ret.HasError()) { if (ret.HasError()) {
MG_ASSERT(RegisterReplicaError::CONNECTION_FAILED != ret.GetError()); MG_ASSERT(RegisterReplicaError::CONNECTION_FAILED != ret.GetError());
LOG_FATAL("Failure when restoring replica {}: {}.", config.name, RegisterReplicaErrorToString(ret.GetError())); LOG_FATAL("Failure when restoring replica {}: {}.", instance_client.name_,
RegisterReplicaErrorToString(ret.GetError()));
} }
spdlog::info("Replica {} restored for {}.", config.name, storage.id()); spdlog::info("Replica {} restored for {}.", instance_client.name_, storage.id());
} }
spdlog::info("Replication role restored to MAIN."); spdlog::info("Replication role restored to MAIN.");
}; };
@ -229,6 +257,6 @@ void RestoreReplication(const replication::ReplicationState &repl_state, storage
recover_main, recover_main,
recover_replica, recover_replica,
}, },
std::as_const(repl_state).ReplicationData()); repl_state.ReplicationData());
} }
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -36,8 +36,7 @@ enum class UnregisterReplicaResult : uint8_t {
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage /// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages /// TODO: extend to do multiple storages
struct ReplicationHandler { struct ReplicationHandler {
ReplicationHandler(memgraph::replication::ReplicationState &replState, DbmsHandler &dbms_handler) explicit ReplicationHandler(DbmsHandler &dbms_handler);
: repl_state_(replState), dbms_handler_(dbms_handler) {}
// as REPLICA, become MAIN // as REPLICA, become MAIN
bool SetReplicationRoleMain(); bool SetReplicationRoleMain();
@ -58,12 +57,11 @@ struct ReplicationHandler {
bool IsReplica() const; bool IsReplica() const;
private: private:
memgraph::replication::ReplicationState &repl_state_;
DbmsHandler &dbms_handler_; DbmsHandler &dbms_handler_;
}; };
/// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage /// A handler type that keep in sync current ReplicationState and the MAIN/REPLICA-ness of Storage
/// TODO: extend to do multiple storages /// TODO: extend to do multiple storages
void RestoreReplication(const replication::ReplicationState &repl_state, storage::Storage &storage); void RestoreReplication(replication::ReplicationState &repl_state, storage::Storage &storage);
} // namespace memgraph::dbms } // namespace memgraph::dbms

View File

@ -65,7 +65,7 @@ DEFINE_bool(allow_load_csv, true, "Controls whether LOAD CSV clause is allowed i
// Storage flags. // Storage flags.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector interval (in seconds).", DEFINE_VALIDATED_uint64(storage_gc_cycle_sec, 30, "Storage garbage collector interval (in seconds).",
FLAG_IN_RANGE(1, 24 * 3600)); FLAG_IN_RANGE(1, 24UL * 3600));
// NOTE: The `storage_properties_on_edges` flag must be the same here and in // NOTE: The `storage_properties_on_edges` flag must be the same here and in
// `mg_import_csv`. If you change it, make sure to change it there as well. // `mg_import_csv`. If you change it, make sure to change it there as well.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -104,9 +104,19 @@ DEFINE_bool(storage_snapshot_on_exit, false, "Controls whether the storage creat
DEFINE_uint64(storage_items_per_batch, memgraph::storage::Config::Durability().items_per_batch, DEFINE_uint64(storage_items_per_batch, memgraph::storage::Config::Durability().items_per_batch,
"The number of edges and vertices stored in a batch in a snapshot file."); "The number of edges and vertices stored in a batch in a snapshot file.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,misc-unused-parameters)
DEFINE_VALIDATED_bool(
storage_parallel_index_recovery, false,
"Controls whether the index creation can be done in a multithreaded fashion.", {
spdlog::warn(
"storage_parallel_index_recovery flag is deprecated. Check storage_mode_parallel_schema_recovery for more "
"details.");
return true;
});
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_parallel_index_recovery, false, DEFINE_bool(storage_parallel_schema_recovery, false,
"Controls whether the index creation can be done in a multithreaded fashion."); "Controls whether the indices and constraints creation can be done in a multithreaded fashion.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_uint64(storage_recovery_thread_count, DEFINE_uint64(storage_recovery_thread_count,
@ -114,6 +124,10 @@ DEFINE_uint64(storage_recovery_thread_count,
memgraph::storage::Config::Durability().recovery_thread_count), memgraph::storage::Config::Durability().recovery_thread_count),
"The number of threads used to recover persisted data from disk."); "The number of threads used to recover persisted data from disk.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_enable_schema_metadata, false,
"Controls whether metadata should be collected about the resident labels and edge types.");
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_delete_on_drop, true, DEFINE_bool(storage_delete_on_drop, true,

View File

@ -73,10 +73,15 @@ DECLARE_uint64(storage_wal_file_flush_every_n_tx);
DECLARE_bool(storage_snapshot_on_exit); DECLARE_bool(storage_snapshot_on_exit);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_uint64(storage_items_per_batch); DECLARE_uint64(storage_items_per_batch);
// storage_parallel_index_recovery deprecated; use storage_parallel_schema_recovery instead
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_bool(storage_parallel_index_recovery); DECLARE_bool(storage_parallel_index_recovery);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_bool(storage_parallel_schema_recovery);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_uint64(storage_recovery_thread_count); DECLARE_uint64(storage_recovery_thread_count);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_bool(storage_enable_schema_metadata);
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_bool(storage_delete_on_drop); DECLARE_bool(storage_delete_on_drop);

View File

@ -10,6 +10,7 @@
// licenses/APL.txt. // licenses/APL.txt.
#include <optional> #include <optional>
#include <utility>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "audit/log.hpp" #include "audit/log.hpp"
@ -233,11 +234,55 @@ std::pair<std::vector<std::string>, std::optional<int>> SessionHL::Interpret(
throw memgraph::communication::bolt::ClientError(e.what()); throw memgraph::communication::bolt::ClientError(e.what());
} }
} }
void SessionHL::RollbackTransaction() { interpreter_.RollbackTransaction(); }
void SessionHL::CommitTransaction() { interpreter_.CommitTransaction(); } void SessionHL::RollbackTransaction() {
void SessionHL::BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) { try {
interpreter_.BeginTransaction(ToQueryExtras(extra)); interpreter_.RollbackTransaction();
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
} }
void SessionHL::CommitTransaction() {
try {
interpreter_.CommitTransaction();
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
}
void SessionHL::BeginTransaction(const std::map<std::string, memgraph::communication::bolt::Value> &extra) {
try {
interpreter_.BeginTransaction(ToQueryExtras(extra));
} catch (const memgraph::query::QueryException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
} catch (const memgraph::query::ReplicationException &e) {
// Count the number of specific exceptions thrown
metrics::IncrementCounter(GetExceptionName(e));
throw memgraph::communication::bolt::ClientError(e.what());
}
}
void SessionHL::Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) { void SessionHL::Configure(const std::map<std::string, memgraph::communication::bolt::Value> &run_time_info) {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
std::string db; std::string db;
@ -272,7 +317,7 @@ void SessionHL::Configure(const std::map<std::string, memgraph::communication::b
#endif #endif
} }
SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context, SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
const memgraph::communication::v2::ServerEndpoint &endpoint, memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream, memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth
@ -289,7 +334,7 @@ SessionHL::SessionHL(memgraph::query::InterpreterContext *interpreter_context,
audit_log_(audit_log), audit_log_(audit_log),
#endif #endif
auth_(auth), auth_(auth),
endpoint_(endpoint), endpoint_(std::move(endpoint)),
implicit_db_(dbms::kDefaultDB) { implicit_db_(dbms::kDefaultDB) {
// Metrics update // Metrics update
memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions); memgraph::metrics::IncrementCounter(memgraph::metrics::ActiveBoltSessions);

View File

@ -23,7 +23,7 @@ class SessionHL final : public memgraph::communication::bolt::Session<memgraph::
memgraph::communication::v2::OutputStream> { memgraph::communication::v2::OutputStream> {
public: public:
SessionHL(memgraph::query::InterpreterContext *interpreter_context, SessionHL(memgraph::query::InterpreterContext *interpreter_context,
const memgraph::communication::v2::ServerEndpoint &endpoint, memgraph::communication::v2::ServerEndpoint endpoint,
memgraph::communication::v2::InputStream *input_stream, memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream, memgraph::communication::v2::OutputStream *output_stream,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth

View File

@ -127,6 +127,8 @@ storage::Result<Value> ToBoltValue(const query::TypedValue &value, const storage
return Value(value.ValueLocalDateTime()); return Value(value.ValueLocalDateTime());
case query::TypedValue::Type::Duration: case query::TypedValue::Type::Duration:
return Value(value.ValueDuration()); return Value(value.ValueDuration());
case query::TypedValue::Type::Function:
throw communication::bolt::ValueException("Unsupported conversion from TypedValue::Function to Value");
case query::TypedValue::Type::Graph: case query::TypedValue::Type::Graph:
auto maybe_graph = ToBoltGraph(value.ValueGraph(), db, view); auto maybe_graph = ToBoltGraph(value.ValueGraph(), db, view);
if (maybe_graph.HasError()) return maybe_graph.GetError(); if (maybe_graph.HasError()) return maybe_graph.GetError();
@ -202,7 +204,7 @@ storage::Result<std::map<std::string, Value>> ToBoltGraph(const query::Graph &gr
for (const auto &v : graph.vertices()) { for (const auto &v : graph.vertices()) {
auto maybe_vertex = ToBoltVertex(v, db, view); auto maybe_vertex = ToBoltVertex(v, db, view);
if (maybe_vertex.HasError()) return maybe_vertex.GetError(); if (maybe_vertex.HasError()) return maybe_vertex.GetError();
vertices.emplace_back(Value(std::move(*maybe_vertex))); vertices.emplace_back(std::move(*maybe_vertex));
} }
map.emplace("nodes", Value(vertices)); map.emplace("nodes", Value(vertices));
@ -211,7 +213,7 @@ storage::Result<std::map<std::string, Value>> ToBoltGraph(const query::Graph &gr
for (const auto &e : graph.edges()) { for (const auto &e : graph.edges()) {
auto maybe_edge = ToBoltEdge(e, db, view); auto maybe_edge = ToBoltEdge(e, db, view);
if (maybe_edge.HasError()) return maybe_edge.GetError(); if (maybe_edge.HasError()) return maybe_edge.GetError();
edges.emplace_back(Value(std::move(*maybe_edge))); edges.emplace_back(std::move(*maybe_edge));
} }
map.emplace("edges", Value(edges)); map.emplace("edges", Value(edges));

View File

@ -31,7 +31,7 @@ inline void LoadConfig(const std::string &product_name) {
std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")}; std::vector<fs::path> configs = {fs::path("/etc/memgraph/memgraph.conf")};
if (getenv("HOME") != nullptr) configs.emplace_back(fs::path(getenv("HOME")) / fs::path(".memgraph/config")); if (getenv("HOME") != nullptr) configs.emplace_back(fs::path(getenv("HOME")) / fs::path(".memgraph/config"));
{ {
auto memgraph_config = getenv("MEMGRAPH_CONFIG"); auto *memgraph_config = getenv("MEMGRAPH_CONFIG");
if (memgraph_config != nullptr) { if (memgraph_config != nullptr) {
auto path = fs::path(memgraph_config); auto path = fs::path(memgraph_config);
MG_ASSERT(fs::exists(path), "MEMGRAPH_CONFIG environment variable set to nonexisting path: {}", MG_ASSERT(fs::exists(path), "MEMGRAPH_CONFIG environment variable set to nonexisting path: {}",

View File

@ -23,7 +23,6 @@
#include <utils/event_counter.hpp> #include <utils/event_counter.hpp>
#include <utils/event_gauge.hpp> #include <utils/event_gauge.hpp>
#include "storage/v2/storage.hpp" #include "storage/v2/storage.hpp"
#include "utils/event_gauge.hpp"
#include "utils/event_histogram.hpp" #include "utils/event_histogram.hpp"
namespace memgraph::http { namespace memgraph::http {

View File

@ -171,10 +171,10 @@ class Consumer final : public RdKafka::EventCb {
class ConsumerRebalanceCb : public RdKafka::RebalanceCb { class ConsumerRebalanceCb : public RdKafka::RebalanceCb {
public: public:
ConsumerRebalanceCb(std::string consumer_name); explicit ConsumerRebalanceCb(std::string consumer_name);
void rebalance_cb(RdKafka::KafkaConsumer *consumer, RdKafka::ErrorCode err, void rebalance_cb(RdKafka::KafkaConsumer *consumer, RdKafka::ErrorCode err,
std::vector<RdKafka::TopicPartition *> &partitions) override final; std::vector<RdKafka::TopicPartition *> &partitions) final;
void set_offset(int64_t offset); void set_offset(int64_t offset);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -67,7 +67,7 @@ utils::BasicResult<std::string, std::vector<Message>> GetBatch(TConsumer &consum
return std::move(batch); return std::move(batch);
case pulsar_client::Result::ResultOk: case pulsar_client::Result::ResultOk:
if (message.getMessageId() != last_message_id) { if (message.getMessageId() != last_message_id) {
batch.emplace_back(Message{std::move(message)}); batch.emplace_back(std::move(message));
} }
break; break;
default: default:

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -10,6 +10,7 @@
// licenses/APL.txt. // licenses/APL.txt.
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <optional> #include <optional>
#include <span> #include <span>

View File

@ -48,8 +48,8 @@ struct Endpoint {
uint16_t port{0}; uint16_t port{0};
IpFamily family{IpFamily::NONE}; IpFamily family{IpFamily::NONE};
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrAddress( static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrAddress(const std::string &address,
const std::string &address, const std::optional<uint16_t> default_port); std::optional<uint16_t> default_port);
/** /**
* Tries to parse the given string as either a socket address or ip address. * Tries to parse the given string as either a socket address or ip address.
@ -61,8 +61,8 @@ struct Endpoint {
* it into an ip address and a port number; even if a default port is given, * it into an ip address and a port number; even if a default port is given,
* it won't be used, as we expect that it is given in the address string. * it won't be used, as we expect that it is given in the address string.
*/ */
static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress( static std::optional<std::pair<std::string, uint16_t>> ParseSocketOrIpAddress(const std::string &address,
const std::string &address, const std::optional<uint16_t> default_port); std::optional<uint16_t> default_port);
/** /**
* Tries to parse given string as either socket address or hostname. * Tries to parse given string as either socket address or hostname.
@ -72,7 +72,7 @@ struct Endpoint {
* After we parse hostname and port we try to resolve the hostname into an ip_address. * After we parse hostname and port we try to resolve the hostname into an ip_address.
*/ */
static std::optional<std::pair<std::string, uint16_t>> ParseHostname(const std::string &address, static std::optional<std::pair<std::string, uint16_t>> ParseHostname(const std::string &address,
const std::optional<uint16_t> default_port); std::optional<uint16_t> default_port);
static IpFamily GetIpFamily(const std::string &address); static IpFamily GetIpFamily(const std::string &address);

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -32,7 +32,7 @@ class Epoll {
public: public:
using Event = struct epoll_event; using Event = struct epoll_event;
Epoll(bool set_cloexec = false) : epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) { explicit Epoll(bool set_cloexec = false) : epoll_fd_(epoll_create1(set_cloexec ? EPOLL_CLOEXEC : 0)) {
// epoll_create1 returns an error if there is a logical error in our code // epoll_create1 returns an error if there is a logical error in our code
// (for example invalid flags) or if there is irrecoverable error. In both // (for example invalid flags) or if there is irrecoverable error. In both
// cases it is best to terminate. // cases it is best to terminate.

View File

@ -14,6 +14,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <optional> #include <optional>
#include <utility>
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
@ -201,7 +202,7 @@ class Socket {
bool WaitForReadyWrite(); bool WaitForReadyWrite();
private: private:
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {} Socket(int fd, Endpoint endpoint) : socket_(fd), endpoint_(std::move(endpoint)) {}
int socket_ = -1; int socket_ = -1;
Endpoint endpoint_; Endpoint endpoint_;

View File

@ -128,7 +128,7 @@ KVStore::iterator::iterator(const KVStore *kvstore, const std::string &prefix, b
KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); } KVStore::iterator::iterator(KVStore::iterator &&other) { pimpl_ = std::move(other.pimpl_); }
KVStore::iterator::~iterator() {} KVStore::iterator::~iterator() = default;
KVStore::iterator &KVStore::iterator::operator=(KVStore::iterator &&other) { KVStore::iterator &KVStore::iterator::operator=(KVStore::iterator &&other) {
pimpl_ = std::move(other.pimpl_); pimpl_ = std::move(other.pimpl_);

View File

@ -65,10 +65,13 @@ void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, memgraph::dbm
std::string line; std::string line;
while (std::getline(file, line)) { while (std::getline(file, line)) {
if (!line.empty()) { if (!line.empty()) {
auto results = interpreter.Prepare(line, {}, {}); try {
memgraph::query::DiscardValueResultStream stream; auto results = interpreter.Prepare(line, {}, {});
interpreter.Pull(&stream, {}, results.qid); memgraph::query::DiscardValueResultStream stream;
interpreter.Pull(&stream, {}, results.qid);
} catch (const memgraph::query::UserAlreadyExistsException &e) {
spdlog::warn("{} The rest of the init-file will be run.", e.what());
}
if (audit_log) { if (audit_log) {
audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB); audit_log->Record("", "", line, {}, memgraph::dbms::kDefaultDB);
} }
@ -291,7 +294,8 @@ int main(int argc, char **argv) {
memgraph::storage::Config db_config{ memgraph::storage::Config db_config{
.gc = {.type = memgraph::storage::Config::Gc::Type::PERIODIC, .gc = {.type = memgraph::storage::Config::Gc::Type::PERIODIC,
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)}, .interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges}, .items = {.properties_on_edges = FLAGS_storage_properties_on_edges,
.enable_schema_metadata = FLAGS_storage_enable_schema_metadata},
.durability = {.storage_directory = FLAGS_data_directory, .durability = {.storage_directory = FLAGS_data_directory,
.recover_on_startup = FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup, .recover_on_startup = FLAGS_storage_recover_on_startup || FLAGS_data_recovery_on_startup,
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count, .snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
@ -301,7 +305,9 @@ int main(int argc, char **argv) {
.restore_replication_state_on_startup = FLAGS_replication_restore_state_on_startup, .restore_replication_state_on_startup = FLAGS_replication_restore_state_on_startup,
.items_per_batch = FLAGS_storage_items_per_batch, .items_per_batch = FLAGS_storage_items_per_batch,
.recovery_thread_count = FLAGS_storage_recovery_thread_count, .recovery_thread_count = FLAGS_storage_recovery_thread_count,
.allow_parallel_index_creation = FLAGS_storage_parallel_index_recovery}, // deprecated
.allow_parallel_index_creation = FLAGS_storage_parallel_index_recovery,
.allow_parallel_schema_creation = FLAGS_storage_parallel_schema_recovery},
.transaction = {.isolation_level = memgraph::flags::ParseIsolationLevel()}, .transaction = {.isolation_level = memgraph::flags::ParseIsolationLevel()},
.disk = {.main_storage_directory = FLAGS_data_directory + "/rocksdb_main_storage", .disk = {.main_storage_directory = FLAGS_data_directory + "/rocksdb_main_storage",
.label_index_directory = FLAGS_data_directory + "/rocksdb_label_index", .label_index_directory = FLAGS_data_directory + "/rocksdb_label_index",
@ -368,34 +374,17 @@ int main(int argc, char **argv) {
std::unique_ptr<memgraph::query::AuthChecker> auth_checker; std::unique_ptr<memgraph::query::AuthChecker> auth_checker;
auth_glue(&auth_, auth_handler, auth_checker); auth_glue(&auth_, auth_handler, auth_checker);
memgraph::replication::ReplicationState repl_state(ReplicationStateRootPath(db_config)); memgraph::dbms::DbmsHandler dbms_handler(db_config
memgraph::dbms::DbmsHandler dbms_handler(db_config, repl_state
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
, ,
&auth_, FLAGS_data_recovery_on_startup, FLAGS_storage_delete_on_drop &auth_, FLAGS_data_recovery_on_startup, FLAGS_storage_delete_on_drop
#endif #endif
); );
auto db_acc = dbms_handler.Get(); auto db_acc = dbms_handler.Get();
memgraph::query::InterpreterContext interpreter_context_(interp_config, &dbms_handler, &repl_state,
auth_handler.get(), auth_checker.get());
MG_ASSERT(db_acc, "Failed to access the main database");
// TODO: Move it somewhere better memgraph::query::InterpreterContext interpreter_context_(
// Startup replication state (if recovered at startup) interp_config, &dbms_handler, &dbms_handler.ReplicationState(), auth_handler.get(), auth_checker.get());
MG_ASSERT(std::visit(memgraph::utils::Overloaded{[](memgraph::replication::RoleMainData const &) { return true; }, MG_ASSERT(db_acc, "Failed to access the main database");
[&](memgraph::replication::RoleReplicaData const &data) {
// Register handlers
memgraph::dbms::InMemoryReplicationHandlers::Register(
&dbms_handler, *data.server);
if (!data.server->Start()) {
spdlog::error("Unable to start the replication server.");
return false;
}
return true;
}},
repl_state.ReplicationData()),
"Replica recovery failure!");
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(), memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(memgraph::flags::ParseQueryModulesDirectory(),
FLAGS_data_directory); FLAGS_data_directory);

View File

@ -3,14 +3,14 @@ set(memory_src_files
global_memory_control.cpp global_memory_control.cpp
query_memory_control.cpp) query_memory_control.cpp)
find_package(jemalloc REQUIRED)
add_library(mg-memory STATIC ${memory_src_files}) add_library(mg-memory STATIC ${memory_src_files})
target_link_libraries(mg-memory mg-utils fmt) target_link_libraries(mg-memory mg-utils fmt)
message(STATUS "ENABLE_JEMALLOC: ${ENABLE_JEMALLOC}")
if (ENABLE_JEMALLOC) if (ENABLE_JEMALLOC)
find_package(jemalloc REQUIRED)
target_link_libraries(mg-memory Jemalloc::Jemalloc ${CMAKE_DL_LIBS}) target_link_libraries(mg-memory Jemalloc::Jemalloc ${CMAKE_DL_LIBS})
target_compile_definitions(mg-memory PRIVATE USE_JEMALLOC=1) target_compile_definitions(mg-memory PRIVATE USE_JEMALLOC=1)
else()
target_compile_definitions(mg-memory PRIVATE USE_JEMALLOC=0)
endif() endif()

View File

@ -119,7 +119,7 @@ static bool my_commit(extent_hooks_t *extent_hooks, void *addr, size_t size, siz
memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length)); memgraph::utils::total_memory_tracker.Alloc(static_cast<int64_t>(length));
if (GetQueriesMemoryControl().IsThreadTracked()) [[unlikely]] { if (GetQueriesMemoryControl().IsThreadTracked()) [[unlikely]] {
GetQueriesMemoryControl().TrackFreeOnCurrentThread(size); GetQueriesMemoryControl().TrackAllocOnCurrentThread(size);
} }
return false; return false;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -62,6 +62,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
case TypedValue::Type::Edge: case TypedValue::Type::Edge:
case TypedValue::Type::Path: case TypedValue::Type::Path:
case TypedValue::Type::Graph: case TypedValue::Type::Graph:
case TypedValue::Type::Function:
throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type()); throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type());
case TypedValue::Type::Null: case TypedValue::Type::Null:
LOG_FATAL("Invalid type"); LOG_FATAL("Invalid type");

View File

@ -41,7 +41,7 @@ bool TypedValueCompare(const TypedValue &a, const TypedValue &b);
/// the define how respective elements compare. /// the define how respective elements compare.
class TypedValueVectorCompare final { class TypedValueVectorCompare final {
public: public:
TypedValueVectorCompare() {} TypedValueVectorCompare() = default;
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {} explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {}
template <class TAllocator> template <class TAllocator>
@ -147,8 +147,8 @@ concept AccessorWithUpdateProperties = requires(T accessor,
/// ///
/// @throw QueryRuntimeException if value cannot be set as a property value /// @throw QueryRuntimeException if value cannot be set as a property value
template <AccessorWithUpdateProperties T> template <AccessorWithUpdateProperties T>
auto UpdatePropertiesChecked(T *record, std::map<storage::PropertyId, storage::PropertyValue> &properties) -> auto UpdatePropertiesChecked(T *record, std::map<storage::PropertyId, storage::PropertyValue> &properties)
typename std::remove_reference<decltype(record->UpdateProperties(properties).GetValue())>::type { -> std::remove_reference_t<decltype(record->UpdateProperties(properties).GetValue())> {
try { try {
auto maybe_values = record->UpdateProperties(properties); auto maybe_values = record->UpdateProperties(properties);
if (maybe_values.HasError()) { if (maybe_values.HasError()) {

View File

@ -11,6 +11,8 @@
#pragma once #pragma once
#include <utility>
#include "query/config.hpp" #include "query/config.hpp"
#include "query/frontend/semantic/required_privileges.hpp" #include "query/frontend/semantic/required_privileges.hpp"
#include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_generator.hpp"
@ -98,8 +100,8 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri
class SingleNodeLogicalPlan final : public LogicalPlan { class SingleNodeLogicalPlan final : public LogicalPlan {
public: public:
SingleNodeLogicalPlan(std::unique_ptr<plan::LogicalOperator> root, double cost, AstStorage storage, SingleNodeLogicalPlan(std::unique_ptr<plan::LogicalOperator> root, double cost, AstStorage storage,
const SymbolTable &symbol_table) SymbolTable symbol_table)
: root_(std::move(root)), cost_(cost), storage_(std::move(storage)), symbol_table_(symbol_table) {} : root_(std::move(root)), cost_(cost), storage_(std::move(storage)), symbol_table_(std::move(symbol_table)) {}
const plan::LogicalOperator &GetRoot() const override { return *root_; } const plan::LogicalOperator &GetRoot() const override { return *root_; }
double GetCost() const override { return cost_; } double GetCost() const override { return cost_; }

View File

@ -15,6 +15,7 @@
#include <cppitertools/filter.hpp> #include <cppitertools/filter.hpp>
#include <cppitertools/imap.hpp> #include <cppitertools/imap.hpp>
#include "storage/v2/storage_mode.hpp"
#include "utils/pmr/unordered_set.hpp" #include "utils/pmr/unordered_set.hpp"
namespace memgraph::query { namespace memgraph::query {
@ -139,6 +140,10 @@ std::optional<VertexAccessor> SubgraphDbAccessor::FindVertex(storage::Gid gid, s
query::Graph *SubgraphDbAccessor::getGraph() { return graph_; } query::Graph *SubgraphDbAccessor::getGraph() { return graph_; }
storage::StorageMode SubgraphDbAccessor::GetStorageMode() const noexcept { return db_accessor_.GetStorageMode(); }
DbAccessor *SubgraphDbAccessor::GetAccessor() { return &db_accessor_; }
VertexAccessor SubgraphVertexAccessor::GetVertexAccessor() const { return impl_; } VertexAccessor SubgraphVertexAccessor::GetVertexAccessor() const { return impl_; }
storage::Result<EdgeVertexAccessorResult> SubgraphVertexAccessor::OutEdges(storage::View view) const { storage::Result<EdgeVertexAccessorResult> SubgraphVertexAccessor::OutEdges(storage::View view) const {

View File

@ -40,9 +40,10 @@ class EdgeAccessor final {
public: public:
storage::EdgeAccessor impl_; storage::EdgeAccessor impl_;
public:
explicit EdgeAccessor(storage::EdgeAccessor impl) : impl_(std::move(impl)) {} explicit EdgeAccessor(storage::EdgeAccessor impl) : impl_(std::move(impl)) {}
bool IsDeleted() const { return impl_.IsDeleted(); }
bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } bool IsVisible(storage::View view) const { return impl_.IsVisible(view); }
storage::EdgeTypeId EdgeType() const { return impl_.EdgeType(); } storage::EdgeTypeId EdgeType() const { return impl_.EdgeType(); }
@ -108,7 +109,6 @@ class VertexAccessor final {
static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { return EdgeAccessor(impl); } static EdgeAccessor MakeEdgeAccessor(const storage::EdgeAccessor impl) { return EdgeAccessor(impl); }
public:
explicit VertexAccessor(storage::VertexAccessor impl) : impl_(impl) {} explicit VertexAccessor(storage::VertexAccessor impl) : impl_(impl) {}
bool IsVisible(storage::View view) const { return impl_.IsVisible(view); } bool IsVisible(storage::View view) const { return impl_.IsVisible(view); }
@ -545,7 +545,7 @@ class DbAccessor final {
void Abort() { accessor_->Abort(); } void Abort() { accessor_->Abort(); }
storage::StorageMode GetStorageMode() const { return accessor_->GetCreationStorageMode(); } storage::StorageMode GetStorageMode() const noexcept { return accessor_->GetCreationStorageMode(); }
bool LabelIndexExists(storage::LabelId label) const { return accessor_->LabelIndexExists(label); } bool LabelIndexExists(storage::LabelId label) const { return accessor_->LabelIndexExists(label); }
@ -597,6 +597,13 @@ class DbAccessor final {
return accessor_->ApproximateVertexCount(label, property, lower, upper); return accessor_->ApproximateVertexCount(label, property, lower, upper);
} }
std::vector<storage::LabelId> ListAllPossiblyPresentVertexLabels() const {
return accessor_->ListAllPossiblyPresentVertexLabels();
}
std::vector<storage::EdgeTypeId> ListAllPossiblyPresentEdgeTypes() const {
return accessor_->ListAllPossiblyPresentEdgeTypes();
}
storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); } storage::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); } storage::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
@ -694,6 +701,10 @@ class SubgraphDbAccessor final {
std::optional<VertexAccessor> FindVertex(storage::Gid gid, storage::View view); std::optional<VertexAccessor> FindVertex(storage::Gid gid, storage::View view);
Graph *getGraph(); Graph *getGraph();
storage::StorageMode GetStorageMode() const noexcept;
DbAccessor *GetAccessor();
}; };
} // namespace memgraph::query } // namespace memgraph::query

View File

@ -159,7 +159,7 @@ void DumpProperties(std::ostream *os, query::DbAccessor *dba,
*os << "{"; *os << "{";
if (property_id) { if (property_id) {
*os << kInternalPropertyId << ": " << *property_id; *os << kInternalPropertyId << ": " << *property_id;
if (store.size() > 0) *os << ", "; if (!store.empty()) *os << ", ";
} }
utils::PrintIterable(*os, store, ", ", [&dba](auto &os, const auto &kv) { utils::PrintIterable(*os, store, ", ", [&dba](auto &os, const auto &kv) {
os << EscapeName(dba->PropertyToName(kv.first)) << ": "; os << EscapeName(dba->PropertyToName(kv.first)) << ": ";
@ -228,7 +228,7 @@ void DumpEdge(std::ostream *os, query::DbAccessor *dba, const query::EdgeAccesso
throw query::QueryRuntimeException("Unexpected error when getting properties."); throw query::QueryRuntimeException("Unexpected error when getting properties.");
} }
} }
if (maybe_props->size() > 0) { if (!maybe_props->empty()) {
*os << " "; *os << " ";
DumpProperties(os, dba, *maybe_props); DumpProperties(os, dba, *maybe_props);
} }

View File

@ -126,6 +126,12 @@ class InfoInMulticommandTxException : public QueryException {
SPECIALIZE_GET_EXCEPTION_NAME(InfoInMulticommandTxException) SPECIALIZE_GET_EXCEPTION_NAME(InfoInMulticommandTxException)
}; };
class UserAlreadyExistsException : public QueryException {
public:
using QueryException::QueryException;
SPECIALIZE_GET_EXCEPTION_NAME(UserAlreadyExistsException)
};
/** /**
* An exception for an illegal operation that can not be detected * An exception for an illegal operation that can not be detected
* before the query starts executing over data. * before the query starts executing over data.

View File

@ -8,65 +8,49 @@
// the Business Source License, use of this software will be governed // the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file // by the Apache License, Version 2.0, included in the file
// licenses/APL.txt. // licenses/APL.txt.
#include <tuple>
#include <utility>
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "utils/fnv.hpp"
#include "utils/memory.hpp" #include "utils/memory.hpp"
#include "utils/pmr/unordered_map.hpp" #include "utils/pmr/unordered_map.hpp"
#include "utils/pmr/vector.hpp" #include "utils/pmr/vector.hpp"
namespace memgraph::query { namespace memgraph::query {
// Key is hash output, value is vector of unique elements // Key is hash output, value is vector of unique elements
using CachedType = utils::pmr::unordered_map<size_t, std::vector<TypedValue>>; using CachedType = utils::pmr::unordered_map<size_t, utils::pmr::vector<TypedValue>>;
struct CachedValue { struct CachedValue {
using allocator_type = utils::Allocator<CachedValue>;
// Cached value, this can be probably templateized // Cached value, this can be probably templateized
CachedType cache_; CachedType cache_;
explicit CachedValue(utils::MemoryResource *mem) : cache_(mem) {} explicit CachedValue(utils::MemoryResource *mem) : cache_{mem} {};
CachedValue(const CachedValue &other, utils::MemoryResource *mem) : cache_(other.cache_, mem) {}
CachedValue(CachedValue &&other, utils::MemoryResource *mem) : cache_(std::move(other.cache_), mem){};
CachedValue(CachedType &&cache, memgraph::utils::MemoryResource *memory) : cache_(std::move(cache), memory) {} CachedValue(CachedValue &&other) noexcept : CachedValue(std::move(other), other.GetMemoryResource()) {}
CachedValue(const CachedValue &other, memgraph::utils::MemoryResource *memory) : cache_(other.cache_, memory) {} CachedValue(const CachedValue &other)
: CachedValue(other, std::allocator_traits<allocator_type>::select_on_container_copy_construction(
other.GetMemoryResource())
.GetMemoryResource()) {}
CachedValue(CachedValue &&other, memgraph::utils::MemoryResource *memory) : cache_(std::move(other.cache_), memory) {} utils::MemoryResource *GetMemoryResource() const { return cache_.get_allocator().GetMemoryResource(); }
CachedValue(CachedValue &&other) noexcept = delete;
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
CachedValue(const CachedValue &) = delete;
CachedValue &operator=(const CachedValue &) = delete; CachedValue &operator=(const CachedValue &) = delete;
CachedValue &operator=(CachedValue &&) = delete; CachedValue &operator=(CachedValue &&) = delete;
~CachedValue() = default; ~CachedValue() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return cache_.get_allocator().GetMemoryResource();
}
// Func to check if cache_ contains value
bool CacheValue(TypedValue &&maybe_list) {
if (!maybe_list.IsList()) {
return false;
}
auto &list = maybe_list.ValueList();
TypedValue::Hash hash{};
for (auto &element : list) {
const auto key = hash(element);
auto &vector_values = cache_[key];
if (!IsValueInVec(vector_values, element)) {
vector_values.emplace_back(std::move(element));
}
}
return true;
}
bool CacheValue(const TypedValue &maybe_list) { bool CacheValue(const TypedValue &maybe_list) {
if (!maybe_list.IsList()) { if (!maybe_list.IsList()) {
return false; return false;
} }
auto &list = maybe_list.ValueList(); const auto &list = maybe_list.ValueList();
TypedValue::Hash hash{}; TypedValue::Hash hash{};
for (auto &element : list) { for (const auto &element : list) {
const auto key = hash(element); const auto key = hash(element);
auto &vector_values = cache_[key]; auto &vector_values = cache_[key];
if (!IsValueInVec(vector_values, element)) { if (!IsValueInVec(vector_values, element)) {
@ -87,7 +71,7 @@ struct CachedValue {
} }
private: private:
static bool IsValueInVec(const std::vector<TypedValue> &vec_values, const TypedValue &value) { static bool IsValueInVec(const utils::pmr::vector<TypedValue> &vec_values, const TypedValue &value) {
return std::any_of(vec_values.begin(), vec_values.end(), [&value](auto &vec_value) { return std::any_of(vec_values.begin(), vec_values.end(), [&value](auto &vec_value) {
const auto is_value_equal = vec_value == value; const auto is_value_equal = vec_value == value;
if (is_value_equal.IsNull()) return false; if (is_value_equal.IsNull()) return false;
@ -99,35 +83,70 @@ struct CachedValue {
// Class tracks keys for which user can cache values which help with faster search or faster retrieval // Class tracks keys for which user can cache values which help with faster search or faster retrieval
// in the future. Used for IN LIST operator. // in the future. Used for IN LIST operator.
class FrameChangeCollector { class FrameChangeCollector {
/** Allocator type so that STL containers are aware that we need one */
using allocator_type = utils::Allocator<FrameChangeCollector>;
public: public:
explicit FrameChangeCollector() : tracked_values_(&memory_resource_){}; explicit FrameChangeCollector(utils::MemoryResource *mem = utils::NewDeleteResource()) : tracked_values_{mem} {}
FrameChangeCollector(FrameChangeCollector &&other, utils::MemoryResource *mem)
: tracked_values_(std::move(other.tracked_values_), mem) {}
FrameChangeCollector(const FrameChangeCollector &other, utils::MemoryResource *mem)
: tracked_values_(other.tracked_values_, mem) {}
FrameChangeCollector(const FrameChangeCollector &other)
: FrameChangeCollector(other, std::allocator_traits<allocator_type>::select_on_container_copy_construction(
other.GetMemoryResource())
.GetMemoryResource()){};
FrameChangeCollector(FrameChangeCollector &&other) noexcept
: FrameChangeCollector(std::move(other), other.GetMemoryResource()) {}
/** Copy assign other, utils::MemoryResource of `this` is used */
FrameChangeCollector &operator=(const FrameChangeCollector &) = default;
/** Move assign other, utils::MemoryResource of `this` is used. */
FrameChangeCollector &operator=(FrameChangeCollector &&) noexcept = default;
utils::MemoryResource *GetMemoryResource() const { return tracked_values_.get_allocator().GetMemoryResource(); }
CachedValue &AddTrackingKey(const std::string &key) { CachedValue &AddTrackingKey(const std::string &key) {
const auto &[it, _] = tracked_values_.emplace(key, tracked_values_.get_allocator().GetMemoryResource()); const auto &[it, _] = tracked_values_.emplace(
std::piecewise_construct, std::forward_as_tuple(utils::pmr::string(key, utils::NewDeleteResource())),
std::forward_as_tuple());
return it->second; return it->second;
} }
bool IsKeyTracked(const std::string &key) const { return tracked_values_.contains(key); } bool IsKeyTracked(const std::string &key) const {
return tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource()));
}
bool IsKeyValueCached(const std::string &key) const { bool IsKeyValueCached(const std::string &key) const {
return IsKeyTracked(key) && !tracked_values_.at(key).cache_.empty(); return IsKeyTracked(key) && !tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource())).cache_.empty();
} }
bool ResetTrackingValue(const std::string &key) { bool ResetTrackingValue(const std::string &key) {
if (!tracked_values_.contains(key)) { if (!tracked_values_.contains(utils::pmr::string(key, utils::NewDeleteResource()))) {
return false; return false;
} }
tracked_values_.erase(key); tracked_values_.erase(utils::pmr::string(key, utils::NewDeleteResource()));
AddTrackingKey(key); AddTrackingKey(key);
return true; return true;
} }
CachedValue &GetCachedValue(const std::string &key) { return tracked_values_.at(key); } CachedValue &GetCachedValue(const std::string &key) {
return tracked_values_.at(utils::pmr::string(key, utils::NewDeleteResource()));
}
bool IsTrackingValues() const { return !tracked_values_.empty(); } bool IsTrackingValues() const { return !tracked_values_.empty(); }
~FrameChangeCollector() = default;
private: private:
utils::MonotonicBufferResource memory_resource_{0}; struct PmrStringHash {
memgraph::utils::pmr::unordered_map<std::string, CachedValue> tracked_values_; size_t operator()(const utils::pmr::string &key) const { return utils::Fnv(key); }
};
utils::pmr::unordered_map<utils::pmr::string, CachedValue, PmrStringHash> tracked_values_;
}; };
} // namespace memgraph::query } // namespace memgraph::query

View File

@ -23,9 +23,7 @@
#include "storage/v2/property_value.hpp" #include "storage/v2/property_value.hpp"
#include "utils/typeinfo.hpp" #include "utils/typeinfo.hpp"
namespace memgraph { namespace memgraph::query {
namespace query {
struct LabelIx { struct LabelIx {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
@ -62,8 +60,8 @@ inline bool operator!=(const PropertyIx &a, const PropertyIx &b) { return !(a ==
inline bool operator==(const EdgeTypeIx &a, const EdgeTypeIx &b) { return a.ix == b.ix && a.name == b.name; } inline bool operator==(const EdgeTypeIx &a, const EdgeTypeIx &b) { return a.ix == b.ix && a.name == b.name; }
inline bool operator!=(const EdgeTypeIx &a, const EdgeTypeIx &b) { return !(a == b); } inline bool operator!=(const EdgeTypeIx &a, const EdgeTypeIx &b) { return !(a == b); }
} // namespace query } // namespace memgraph::query
} // namespace memgraph
namespace std { namespace std {
template <> template <>
@ -83,9 +81,7 @@ struct hash<memgraph::query::EdgeTypeIx> {
} // namespace std } // namespace std
namespace memgraph { namespace memgraph::query {
namespace query {
class Tree; class Tree;
@ -1822,6 +1818,10 @@ class EdgeAtom : public memgraph::query::PatternAtom {
memgraph::query::Identifier *inner_edge{nullptr}; memgraph::query::Identifier *inner_edge{nullptr};
/// Argument identifier for the destination node of the edge. /// Argument identifier for the destination node of the edge.
memgraph::query::Identifier *inner_node{nullptr}; memgraph::query::Identifier *inner_node{nullptr};
/// Argument identifier for the currently-accumulated path.
memgraph::query::Identifier *accumulated_path{nullptr};
/// Argument identifier for the weight of the currently-accumulated path.
memgraph::query::Identifier *accumulated_weight{nullptr};
/// Evaluates the result of the lambda. /// Evaluates the result of the lambda.
memgraph::query::Expression *expression{nullptr}; memgraph::query::Expression *expression{nullptr};
@ -1829,6 +1829,8 @@ class EdgeAtom : public memgraph::query::PatternAtom {
Lambda object; Lambda object;
object.inner_edge = inner_edge ? inner_edge->Clone(storage) : nullptr; object.inner_edge = inner_edge ? inner_edge->Clone(storage) : nullptr;
object.inner_node = inner_node ? inner_node->Clone(storage) : nullptr; object.inner_node = inner_node ? inner_node->Clone(storage) : nullptr;
object.accumulated_path = accumulated_path ? accumulated_path->Clone(storage) : nullptr;
object.accumulated_weight = accumulated_weight ? accumulated_weight->Clone(storage) : nullptr;
object.expression = expression ? expression->Clone(storage) : nullptr; object.expression = expression ? expression->Clone(storage) : nullptr;
return object; return object;
} }
@ -2932,7 +2934,7 @@ class DatabaseInfoQuery : public memgraph::query::Query {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
enum class InfoType { INDEX, CONSTRAINT }; enum class InfoType { INDEX, CONSTRAINT, EDGE_TYPES, NODE_LABELS };
DEFVISITABLE(QueryVisitor<void>); DEFVISITABLE(QueryVisitor<void>);
@ -3029,7 +3031,7 @@ class ReplicationQuery : public memgraph::query::Query {
enum class SyncMode { SYNC, ASYNC }; enum class SyncMode { SYNC, ASYNC };
enum class ReplicaState { READY, REPLICATING, RECOVERY, INVALID }; enum class ReplicaState { READY, REPLICATING, RECOVERY, MAYBE_BEHIND };
ReplicationQuery() = default; ReplicationQuery() = default;
@ -3577,5 +3579,4 @@ class ShowDatabasesQuery : public memgraph::query::Query {
} }
}; };
} // namespace query } // namespace memgraph::query
} // namespace memgraph

View File

@ -124,6 +124,14 @@ antlrcpp::Any CypherMainVisitor::visitDatabaseInfoQuery(MemgraphCypher::Database
info_query->info_type_ = DatabaseInfoQuery::InfoType::CONSTRAINT; info_query->info_type_ = DatabaseInfoQuery::InfoType::CONSTRAINT;
return info_query; return info_query;
} }
if (ctx->edgetypeInfo()) {
info_query->info_type_ = DatabaseInfoQuery::InfoType::EDGE_TYPES;
return info_query;
}
if (ctx->nodelabelInfo()) {
info_query->info_type_ = DatabaseInfoQuery::InfoType::NODE_LABELS;
return info_query;
}
// Should never get here // Should never get here
throw utils::NotYetImplemented("Database info query: '{}'", ctx->getText()); throw utils::NotYetImplemented("Database info query: '{}'", ctx->getText());
} }
@ -1216,10 +1224,6 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur
call_proc->memory_limit_ = memory_limit_info->first; call_proc->memory_limit_ = memory_limit_info->first;
call_proc->memory_scale_ = memory_limit_info->second; call_proc->memory_scale_ = memory_limit_info->second;
} }
} else {
// Default to 100 MB
call_proc->memory_limit_ = storage_->Create<PrimitiveLiteral>(TypedValue(100));
call_proc->memory_scale_ = 1024U * 1024U;
} }
const auto &maybe_found = const auto &maybe_found =
@ -1240,11 +1244,13 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur
throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_); throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_);
} }
} }
call_proc->is_write_ = maybe_found->second->info.is_write; if (maybe_found) {
call_proc->is_write_ = maybe_found->second->info.is_write;
}
auto *yield_ctx = ctx->yieldProcedureResults(); auto *yield_ctx = ctx->yieldProcedureResults();
if (!yield_ctx) { if (!yield_ctx) {
if (!maybe_found->second->results.empty() && !call_proc->void_procedure_) { if ((maybe_found && !maybe_found->second->results.empty()) && !call_proc->void_procedure_) {
throw SemanticException( throw SemanticException(
"CALL without YIELD may only be used on procedures which do not " "CALL without YIELD may only be used on procedures which do not "
"return any result fields."); "return any result fields.");
@ -1270,28 +1276,59 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur
call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(result_alias)); call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(result_alias));
} }
} else { } else {
const auto &maybe_found = call_proc->is_write_ = maybe_found->second->info.is_write;
procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource());
if (!maybe_found) { auto *yield_ctx = ctx->yieldProcedureResults();
throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_); if (!yield_ctx) {
if (!maybe_found->second->results.empty() && !call_proc->void_procedure_) {
throw SemanticException(
"CALL without YIELD may only be used on procedures which do not "
"return any result fields.");
}
// When we return, we will release the lock on modules. This means that
// someone may reload the procedure and change the result signature. But to
// keep the implementation simple, we ignore the case as the rest of the
// code doesn't really care whether we yield or not, so it should not break.
return call_proc;
} }
const auto &[module, proc] = *maybe_found; if (yield_ctx->getTokens(MemgraphCypher::ASTERISK).empty()) {
call_proc->result_fields_.reserve(proc->results.size()); call_proc->result_fields_.reserve(yield_ctx->procedureResult().size());
call_proc->result_identifiers_.reserve(proc->results.size()); call_proc->result_identifiers_.reserve(yield_ctx->procedureResult().size());
for (const auto &[result_name, desc] : proc->results) { for (auto *result : yield_ctx->procedureResult()) {
bool is_deprecated = desc.second; MG_ASSERT(result->variable().size() == 1 || result->variable().size() == 2);
if (is_deprecated) continue; call_proc->result_fields_.push_back(std::any_cast<std::string>(result->variable()[0]->accept(this)));
call_proc->result_fields_.emplace_back(result_name); std::string result_alias;
call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(std::string(result_name))); if (result->variable().size() == 2) {
result_alias = std::any_cast<std::string>(result->variable()[1]->accept(this));
} else {
result_alias = std::any_cast<std::string>(result->variable()[0]->accept(this));
}
call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(result_alias));
}
} else {
const auto &maybe_found =
procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource());
if (!maybe_found) {
throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_);
}
const auto &[module, proc] = *maybe_found;
call_proc->result_fields_.reserve(proc->results.size());
call_proc->result_identifiers_.reserve(proc->results.size());
for (const auto &[result_name, desc] : proc->results) {
bool is_deprecated = desc.second;
if (is_deprecated) continue;
call_proc->result_fields_.emplace_back(result_name);
call_proc->result_identifiers_.push_back(storage_->Create<Identifier>(std::string(result_name)));
}
// When we leave the scope, we will release the lock on modules. This means
// that someone may reload the procedure and change its result signature. We
// are fine with this, because if new result fields were added then we yield
// the subset of those and that will appear to a user as if they used the
// procedure before reload. Any subsequent `CALL ... YIELD *` will fetch the
// new fields as well. In case the result signature has had some result
// fields removed, then the query execution will report an error that we are
// yielding missing fields. The user can then just retry the query.
} }
// When we leave the scope, we will release the lock on modules. This means
// that someone may reload the procedure and change its result signature. We
// are fine with this, because if new result fields were added then we yield
// the subset of those and that will appear to a user as if they used the
// procedure before reload. Any subsequent `CALL ... YIELD *` will fetch the
// new fields as well. In case the result signature has had some result
// fields removed, then the query execution will report an error that we are
// yielding missing fields. The user can then just retry the query.
} }
return call_proc; return call_proc;
@ -1980,6 +2017,15 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(MemgraphCypher::Relati
edge_lambda.inner_edge = storage_->Create<Identifier>(traversed_edge_variable); edge_lambda.inner_edge = storage_->Create<Identifier>(traversed_edge_variable);
auto traversed_node_variable = std::any_cast<std::string>(lambda->traversed_node->accept(this)); auto traversed_node_variable = std::any_cast<std::string>(lambda->traversed_node->accept(this));
edge_lambda.inner_node = storage_->Create<Identifier>(traversed_node_variable); edge_lambda.inner_node = storage_->Create<Identifier>(traversed_node_variable);
if (lambda->accumulated_path) {
auto accumulated_path_variable = std::any_cast<std::string>(lambda->accumulated_path->accept(this));
edge_lambda.accumulated_path = storage_->Create<Identifier>(accumulated_path_variable);
if (lambda->accumulated_weight) {
auto accumulated_weight_variable = std::any_cast<std::string>(lambda->accumulated_weight->accept(this));
edge_lambda.accumulated_weight = storage_->Create<Identifier>(accumulated_weight_variable);
}
}
edge_lambda.expression = std::any_cast<Expression *>(lambda->expression()->accept(this)); edge_lambda.expression = std::any_cast<Expression *>(lambda->expression()->accept(this));
return edge_lambda; return edge_lambda;
}; };
@ -2004,6 +2050,15 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(MemgraphCypher::Relati
// In variable expansion inner variables are mandatory. // In variable expansion inner variables are mandatory.
anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge); anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge);
anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node); anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node);
// TODO: In what use case do we need accumulated path and weight here?
if (edge->filter_lambda_.accumulated_path) {
anonymous_identifiers.push_back(&edge->filter_lambda_.accumulated_path);
if (edge->filter_lambda_.accumulated_weight) {
anonymous_identifiers.push_back(&edge->filter_lambda_.accumulated_weight);
}
}
break; break;
case 1: case 1:
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH ||
@ -2015,9 +2070,21 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(MemgraphCypher::Relati
// Add mandatory inner variables for filter lambda. // Add mandatory inner variables for filter lambda.
anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge); anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge);
anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node); anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node);
if (edge->filter_lambda_.accumulated_path) {
anonymous_identifiers.push_back(&edge->filter_lambda_.accumulated_path);
if (edge->filter_lambda_.accumulated_weight) {
anonymous_identifiers.push_back(&edge->filter_lambda_.accumulated_weight);
}
}
} else { } else {
// Other variable expands only have the filter lambda. // Other variable expands only have the filter lambda.
edge->filter_lambda_ = visit_lambda(relationshipLambdas[0]); edge->filter_lambda_ = visit_lambda(relationshipLambdas[0]);
if (edge->filter_lambda_.accumulated_weight) {
throw SemanticException(
"Accumulated weight in filter lambda can be used only with "
"shortest paths expansion.");
}
} }
break; break;
case 2: case 2:

View File

@ -105,7 +105,7 @@ void PrintObject(std::ostream *out, const std::map<K, V> &map);
template <typename T> template <typename T>
void PrintObject(std::ostream *out, const T &arg) { void PrintObject(std::ostream *out, const T &arg) {
static_assert(!std::is_convertible<T, Expression *>::value, static_assert(!std::is_convertible_v<T, Expression *>,
"This overload shouldn't be called with pointers convertible " "This overload shouldn't be called with pointers convertible "
"to Expression *. This means your other PrintObject overloads aren't " "to Expression *. This means your other PrintObject overloads aren't "
"being called for certain AST nodes when they should (or perhaps such " "being called for certain AST nodes when they should (or perhaps such "

View File

@ -47,9 +47,13 @@ indexInfo : INDEX INFO ;
constraintInfo : CONSTRAINT INFO ; constraintInfo : CONSTRAINT INFO ;
edgetypeInfo : EDGE_TYPES INFO ;
nodelabelInfo : NODE_LABELS INFO ;
buildInfo : BUILD INFO ; buildInfo : BUILD INFO ;
databaseInfoQuery : SHOW ( indexInfo | constraintInfo ) ; databaseInfoQuery : SHOW ( indexInfo | constraintInfo | edgetypeInfo | nodelabelInfo ) ;
systemInfoQuery : SHOW ( storageInfo | buildInfo ) ; systemInfoQuery : SHOW ( storageInfo | buildInfo ) ;
@ -175,7 +179,7 @@ relationshipDetail : '[' ( name=variable )? ( relationshipTypes )? ( variableExp
| '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? relationshipLambda ( total_weight=variable )? (relationshipLambda )? ']' | '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? relationshipLambda ( total_weight=variable )? (relationshipLambda )? ']'
| '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? (properties )* ( relationshipLambda total_weight=variable )? (relationshipLambda )? ']'; | '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? (properties )* ( relationshipLambda total_weight=variable )? (relationshipLambda )? ']';
relationshipLambda: '(' traversed_edge=variable ',' traversed_node=variable '|' expression ')'; relationshipLambda: '(' traversed_edge=variable ',' traversed_node=variable ( ',' accumulated_path=variable )? ( ',' accumulated_weight=variable )? '|' expression ')';
variableExpansion : '*' (BFS | WSHORTEST | ALLSHORTEST)? ( expression )? ( '..' ( expression )? )? ; variableExpansion : '*' (BFS | WSHORTEST | ALLSHORTEST)? ( expression )? ( '..' ( expression )? )? ;

View File

@ -61,6 +61,7 @@ memgraphCypherKeyword : cypherKeyword
| GRANT | GRANT
| HEADER | HEADER
| IDENTIFIED | IDENTIFIED
| NODE_LABELS
| NULLIF | NULLIF
| IMPORT | IMPORT
| INACTIVE | INACTIVE

View File

@ -89,6 +89,7 @@ MULTI_DATABASE_EDIT : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE E D I
MULTI_DATABASE_USE : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE U S E ; MULTI_DATABASE_USE : M U L T I UNDERSCORE D A T A B A S E UNDERSCORE U S E ;
NEXT : N E X T ; NEXT : N E X T ;
NO : N O ; NO : N O ;
NODE_LABELS : N O D E UNDERSCORE L A B E L S ;
NOTHING : N O T H I N G ; NOTHING : N O T H I N G ;
ON_DISK_TRANSACTIONAL : O N UNDERSCORE D I S K UNDERSCORE T R A N S A C T I O N A L ; ON_DISK_TRANSACTIONAL : O N UNDERSCORE D I S K UNDERSCORE T R A N S A C T I O N A L ;
NULLIF : N U L L I F ; NULLIF : N U L L I F ;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -31,7 +31,7 @@ class Parser {
* @param query incoming query that has to be compiled into query plan * @param query incoming query that has to be compiled into query plan
* the first step is to generate AST * the first step is to generate AST
*/ */
Parser(const std::string query) : query_(std::move(query)) { explicit Parser(const std::string query) : query_(std::move(query)) {
parser_.removeErrorListeners(); parser_.removeErrorListeners();
parser_.addErrorListener(&error_listener_); parser_.addErrorListener(&error_listener_);
tree_ = parser_.cypher(); tree_ = parser_.cypher();

View File

@ -38,6 +38,9 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(DatabaseInfoQuery &info_query) override { void Visit(DatabaseInfoQuery &info_query) override {
switch (info_query.info_type_) { switch (info_query.info_type_) {
case DatabaseInfoQuery::InfoType::INDEX: case DatabaseInfoQuery::InfoType::INDEX:
// TODO: Reconsider priviliges, this 4 should have the same.
case DatabaseInfoQuery::InfoType::EDGE_TYPES:
case DatabaseInfoQuery::InfoType::NODE_LABELS:
// TODO: This should be INDEX | STATS, but we don't have support for // TODO: This should be INDEX | STATS, but we don't have support for
// *or* with privileges. // *or* with privileges.
AddPrivilege(AuthQuery::Privilege::INDEX); AddPrivilege(AuthQuery::Privilege::INDEX);

View File

@ -12,12 +12,11 @@
#pragma once #pragma once
#include <string> #include <string>
#include <utility>
#include "utils/typeinfo.hpp" #include "utils/typeinfo.hpp"
namespace memgraph { namespace memgraph::query {
namespace query {
class Symbol { class Symbol {
public: public:
@ -34,9 +33,13 @@ class Symbol {
return enum_string[static_cast<int>(type)]; return enum_string[static_cast<int>(type)];
} }
Symbol() {} Symbol() = default;
Symbol(const std::string &name, int position, bool user_declared, Type type = Type::ANY, int token_position = -1) Symbol(std::string name, int position, bool user_declared, Type type = Type::ANY, int token_position = -1)
: name_(name), position_(position), user_declared_(user_declared), type_(type), token_position_(token_position) {} : name_(std::move(name)),
position_(position),
user_declared_(user_declared),
type_(type),
token_position_(token_position) {}
bool operator==(const Symbol &other) const { bool operator==(const Symbol &other) const {
return position_ == other.position_ && name_ == other.name_ && type_ == other.type_; return position_ == other.position_ && name_ == other.name_ && type_ == other.type_;
@ -57,8 +60,8 @@ class Symbol {
int64_t token_position_{-1}; int64_t token_position_{-1};
}; };
} // namespace query } // namespace memgraph::query
} // namespace memgraph
namespace std { namespace std {
template <> template <>

View File

@ -658,8 +658,16 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
scope.in_edge_range = false; scope.in_edge_range = false;
scope.in_pattern = false; scope.in_pattern = false;
if (edge_atom.filter_lambda_.expression) { if (edge_atom.filter_lambda_.expression) {
VisitWithIdentifiers(edge_atom.filter_lambda_.expression, std::vector<Identifier *> filter_lambda_identifiers{edge_atom.filter_lambda_.inner_edge,
{edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node}); edge_atom.filter_lambda_.inner_node};
if (edge_atom.filter_lambda_.accumulated_path) {
filter_lambda_identifiers.emplace_back(edge_atom.filter_lambda_.accumulated_path);
if (edge_atom.filter_lambda_.accumulated_weight) {
filter_lambda_identifiers.emplace_back(edge_atom.filter_lambda_.accumulated_weight);
}
}
VisitWithIdentifiers(edge_atom.filter_lambda_.expression, filter_lambda_identifiers);
} else { } else {
// Create inner symbols, but don't bind them in scope, since they are to // Create inner symbols, but don't bind them in scope, since they are to
// be used in the missing filter expression. // be used in the missing filter expression.
@ -668,6 +676,17 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
auto *inner_node = edge_atom.filter_lambda_.inner_node; auto *inner_node = edge_atom.filter_lambda_.inner_node;
inner_node->MapTo( inner_node->MapTo(
symbol_table_->CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX)); symbol_table_->CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
if (edge_atom.filter_lambda_.accumulated_path) {
auto *accumulated_path = edge_atom.filter_lambda_.accumulated_path;
accumulated_path->MapTo(
symbol_table_->CreateSymbol(accumulated_path->name_, accumulated_path->user_declared_, Symbol::Type::PATH));
if (edge_atom.filter_lambda_.accumulated_weight) {
auto *accumulated_weight = edge_atom.filter_lambda_.accumulated_weight;
accumulated_weight->MapTo(symbol_table_->CreateSymbol(
accumulated_weight->name_, accumulated_weight->user_declared_, Symbol::Type::NUMBER));
}
}
} }
if (edge_atom.weight_lambda_.expression) { if (edge_atom.weight_lambda_.expression) {
VisitWithIdentifiers(edge_atom.weight_lambda_.expression, VisitWithIdentifiers(edge_atom.weight_lambda_.expression,

View File

@ -183,7 +183,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
/// If property lookup for one symbol is visited more times, it is better to fetch all properties /// If property lookup for one symbol is visited more times, it is better to fetch all properties
class PropertyLookupEvaluationModeVisitor : public ExpressionVisitor<void> { class PropertyLookupEvaluationModeVisitor : public ExpressionVisitor<void> {
public: public:
explicit PropertyLookupEvaluationModeVisitor() {} explicit PropertyLookupEvaluationModeVisitor() = default;
using ExpressionVisitor<void>::Visit; using ExpressionVisitor<void>::Visit;

View File

@ -22,7 +22,7 @@ namespace memgraph::query {
class SymbolTable final { class SymbolTable final {
public: public:
SymbolTable() {} SymbolTable() = default;
const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY, const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
int32_t token_position = -1) { int32_t token_position = -1) {
MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(), MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(),

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -16,6 +16,7 @@
#include <iostream> #include <iostream>
#include <span> #include <span>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "query/exceptions.hpp" #include "query/exceptions.hpp"
@ -32,7 +33,7 @@ namespace memgraph::query::frontend {
using namespace lexer_constants; using namespace lexer_constants;
StrippedQuery::StrippedQuery(const std::string &query) : original_(query) { StrippedQuery::StrippedQuery(std::string query) : original_(std::move(query)) {
enum class Token { enum class Token {
UNMATCHED, UNMATCHED,
KEYWORD, // Including true, false and null. KEYWORD, // Including true, false and null.
@ -255,29 +256,29 @@ std::string GetFirstUtf8Symbol(const char *_s) {
// According to // According to
// https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe // https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe
// this checks if casting from const char * to uint8_t is undefined behaviour. // this checks if casting from const char * to uint8_t is undefined behaviour.
static_assert(std::is_same<std::uint8_t, unsigned char>::value, static_assert(std::is_same_v<std::uint8_t, unsigned char>,
"This library requires std::uint8_t to be implemented as " "This library requires std::uint8_t to be implemented as "
"unsigned char."); "unsigned char.");
const uint8_t *s = reinterpret_cast<const uint8_t *>(_s); const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
if ((*s >> 7) == 0x00) return std::string(_s, _s + 1); if ((*s >> 7) == 0x00) return std::string(_s, _s + 1);
if ((*s >> 5) == 0x06) { if ((*s >> 5) == 0x06) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 2); return std::string(_s, _s + 2);
} }
if ((*s >> 4) == 0x0e) { if ((*s >> 4) == 0x0e) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2; const auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 3); return std::string(_s, _s + 3);
} }
if ((*s >> 3) == 0x1e) { if ((*s >> 3) == 0x1e) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2; const auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s3 = s + 3; const auto *s3 = s + 3;
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 4); return std::string(_s, _s + 4);
} }
@ -286,29 +287,29 @@ std::string GetFirstUtf8Symbol(const char *_s) {
// Return codepoint of first utf8 symbol and its encoded length. // Return codepoint of first utf8 symbol and its encoded length.
std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) { std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
static_assert(std::is_same<std::uint8_t, unsigned char>::value, static_assert(std::is_same_v<std::uint8_t, unsigned char>,
"This library requires std::uint8_t to be implemented as " "This library requires std::uint8_t to be implemented as "
"unsigned char."); "unsigned char.");
const uint8_t *s = reinterpret_cast<const uint8_t *>(_s); const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
if ((*s >> 7) == 0x00) return {*s & 0x7f, 1}; if ((*s >> 7) == 0x00) return {*s & 0x7f, 1};
if ((*s >> 5) == 0x06) { if ((*s >> 5) == 0x06) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x1f) << 6) | (*s1 & 0x3f), 2}; return {((*s & 0x1f) << 6) | (*s1 & 0x3f), 2};
} }
if ((*s >> 4) == 0x0e) { if ((*s >> 4) == 0x0e) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2; const auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x0f) << 12) | ((*s1 & 0x3f) << 6) | (*s2 & 0x3f), 3}; return {((*s & 0x0f) << 12) | ((*s1 & 0x3f) << 6) | (*s2 & 0x3f), 3};
} }
if ((*s >> 3) == 0x1e) { if ((*s >> 3) == 0x1e) {
auto *s1 = s + 1; const auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2; const auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s3 = s + 3; const auto *s3 = s + 3;
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character."); if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4}; return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4};
} }
@ -336,7 +337,7 @@ int StrippedQuery::MatchSpecial(int start) const { return kSpecialTokens.Match(o
int StrippedQuery::MatchString(int start) const { int StrippedQuery::MatchString(int start) const {
if (original_[start] != '"' && original_[start] != '\'') return 0; if (original_[start] != '"' && original_[start] != '\'') return 0;
char start_char = original_[start]; char start_char = original_[start];
for (auto *p = original_.data() + start + 1; *p; ++p) { for (const auto *p = original_.data() + start + 1; *p; ++p) {
if (*p == start_char) return p - (original_.data() + start) + 1; if (*p == start_char) return p - (original_.data() + start) + 1;
if (*p == '\\') { if (*p == '\\') {
++p; ++p;
@ -346,7 +347,7 @@ int StrippedQuery::MatchString(int start) const {
continue; continue;
} else if (*p == 'U' || *p == 'u') { } else if (*p == 'U' || *p == 'u') {
int cnt = 0; int cnt = 0;
auto *r = p + 1; const auto *r = p + 1;
while (isxdigit(*r) && cnt < 8) { while (isxdigit(*r) && cnt < 8) {
++cnt; ++cnt;
++r; ++r;

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd. // Copyright 2023 Memgraph Ltd.
// //
// Use of this software is governed by the Business Source License // Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -40,7 +40,7 @@ class StrippedQuery {
* *
* @param query Input query. * @param query Input query.
*/ */
explicit StrippedQuery(const std::string &query); explicit StrippedQuery(std::string query);
/** /**
* Copy constructor is deleted because we don't want to make unnecessary * Copy constructor is deleted because we don't want to make unnecessary

View File

@ -17,8 +17,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
namespace memgraph::query { namespace memgraph::query::lexer_constants {
namespace lexer_constants {
namespace trie { namespace trie {
@ -33,7 +32,7 @@ inline int Noop(int x) { return x; }
class Trie { class Trie {
public: public:
Trie() {} Trie() = default;
Trie(std::initializer_list<std::string> l) { Trie(std::initializer_list<std::string> l) {
for (const auto &s : l) { for (const auto &s : l) {
Insert(s); Insert(s);
@ -2934,5 +2933,4 @@ const trie::Trie kSpecialTokens = {";",
"\xEF\xB9\x98", // u8"\ufe58" "\xEF\xB9\x98", // u8"\ufe58"
"\xEF\xB9\xA3", // u8"\ufe63" "\xEF\xB9\xA3", // u8"\ufe63"
"\xEF\xBC\x8D"}; // u8"\uff0d" "\xEF\xBC\x8D"}; // u8"\uff0d"
} // namespace lexer_constants } // namespace memgraph::query::lexer_constants
} // namespace memgraph::query

View File

@ -593,6 +593,7 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContex
case TypedValue::Type::Duration: case TypedValue::Type::Duration:
return TypedValue("DURATION", ctx.memory); return TypedValue("DURATION", ctx.memory);
case TypedValue::Type::Graph: case TypedValue::Type::Graph:
case TypedValue::Type::Function:
throw QueryRuntimeException("Cannot fetch graph as it is not standardized openCypher type name"); throw QueryRuntimeException("Cannot fetch graph as it is not standardized openCypher type name");
} }
} }

View File

@ -18,6 +18,7 @@
#include <map> #include <map>
#include <optional> #include <optional>
#include <regex> #include <regex>
#include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
@ -28,8 +29,10 @@
#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast.hpp"
#include "query/frontend/semantic/symbol_table.hpp" #include "query/frontend/semantic/symbol_table.hpp"
#include "query/interpret/frame.hpp" #include "query/interpret/frame.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
#include "storage/v2/storage_mode.hpp"
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
#include "utils/frame_change_id.hpp" #include "utils/frame_change_id.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
@ -187,6 +190,8 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; } utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; }
void ResetPropertyLookupCache() { property_lookup_cache_.clear(); }
TypedValue Visit(NamedExpression &named_expression) override { TypedValue Visit(NamedExpression &named_expression) override {
const auto &symbol = symbol_table_->at(named_expression); const auto &symbol = symbol_table_->at(named_expression);
auto value = named_expression.expression_->Accept(*this); auto value = named_expression.expression_->Accept(*this);
@ -315,8 +320,8 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return std::move(*preoperational_checks); return std::move(*preoperational_checks);
} }
auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id); auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id);
cached_value.CacheValue(std::move(list)); // Don't move here because we don't want to remove the element from the frame
spdlog::trace("Value cached {}", *cached_id); cached_value.CacheValue(list);
} }
const auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id); const auto &cached_value = frame_change_collector_->GetCachedValue(*cached_id);
@ -338,7 +343,6 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
} }
const auto &list_value = list.ValueList(); const auto &list_value = list.ValueList();
spdlog::trace("Not using cache on IN LIST operator");
auto has_null = false; auto has_null = false;
for (const auto &element : list_value) { for (const auto &element : list_value) {
auto result = literal == element; auto result = literal == element;
@ -826,8 +830,8 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
throw QueryRuntimeException("'coalesce' requires at least one argument."); throw QueryRuntimeException("'coalesce' requires at least one argument.");
} }
for (int64_t i = 0; i < exprs.size(); ++i) { for (auto &expr : exprs) {
TypedValue val(exprs[i]->Accept(*this), ctx_->memory); TypedValue val(expr->Accept(*this), ctx_->memory);
if (!val.IsNull()) { if (!val.IsNull()) {
return val; return val;
} }
@ -838,6 +842,8 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
TypedValue Visit(Function &function) override { TypedValue Visit(Function &function) override {
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_}; FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_};
bool is_transactional = storage::IsTransactional(dba_->GetStorageMode());
TypedValue res(ctx_->memory);
// Stack allocate evaluated arguments when there's a small number of them. // Stack allocate evaluated arguments when there's a small number of them.
if (function.arguments_.size() <= 8) { if (function.arguments_.size() <= 8) {
TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
@ -846,19 +852,20 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
for (size_t i = 0; i < function.arguments_.size(); ++i) { for (size_t i = 0; i < function.arguments_.size(); ++i) {
arguments[i] = function.arguments_[i]->Accept(*this); arguments[i] = function.arguments_[i]->Accept(*this);
} }
auto res = function.function_(arguments, function.arguments_.size(), function_ctx); res = function.function_(arguments, function.arguments_.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
} else { } else {
TypedValue::TVector arguments(ctx_->memory); TypedValue::TVector arguments(ctx_->memory);
arguments.reserve(function.arguments_.size()); arguments.reserve(function.arguments_.size());
for (const auto &argument : function.arguments_) { for (const auto &argument : function.arguments_) {
arguments.emplace_back(argument->Accept(*this)); arguments.emplace_back(argument->Accept(*this));
} }
auto res = function.function_(arguments.data(), arguments.size(), function_ctx); res = function.function_(arguments.data(), arguments.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
} }
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
if (!is_transactional && res.ContainsDeleted()) [[unlikely]] {
return TypedValue(ctx_->memory);
}
return res;
} }
TypedValue Visit(Reduce &reduce) override { TypedValue Visit(Reduce &reduce) override {
@ -904,7 +911,17 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(std::move(result), ctx_->memory); return TypedValue(std::move(result), ctx_->memory);
} }
TypedValue Visit(Exists &exists) override { return TypedValue{frame_->at(symbol_table_->at(exists)), ctx_->memory}; } TypedValue Visit(Exists &exists) override {
TypedValue &frame_exists_value = frame_->at(symbol_table_->at(exists));
if (!frame_exists_value.IsFunction()) [[unlikely]] {
throw QueryRuntimeException(
"Unexpected behavior: Exists expected a function, got {}. Please report the problem on GitHub issues",
frame_exists_value.type());
}
TypedValue result{ctx_->memory};
frame_exists_value.ValueFunction()(&result);
return result;
}
TypedValue Visit(All &all) override { TypedValue Visit(All &all) override {
auto list_value = all.list_expression_->Accept(*this); auto list_value = all.list_expression_->Accept(*this);

View File

@ -147,6 +147,8 @@ void memgraph::query::CurrentDB::CleanupDBTransaction(bool abort) {
namespace memgraph::query { namespace memgraph::query {
constexpr std::string_view kSchemaAssert = "SCHEMA.ASSERT";
template <typename> template <typename>
constexpr auto kAlwaysFalse = false; constexpr auto kAlwaysFalse = false;
@ -282,8 +284,7 @@ inline auto convertToReplicationMode(const ReplicationQuery::SyncMode &sync_mode
class ReplQueryHandler final : public query::ReplicationQueryHandler { class ReplQueryHandler final : public query::ReplicationQueryHandler {
public: public:
explicit ReplQueryHandler(dbms::DbmsHandler *dbms_handler, memgraph::replication::ReplicationState *repl_state) explicit ReplQueryHandler(dbms::DbmsHandler *dbms_handler) : dbms_handler_(dbms_handler), handler_{*dbms_handler} {}
: dbms_handler_(dbms_handler), handler_{*repl_state, *dbms_handler} {}
/// @throw QueryRuntimeException if an error ocurred. /// @throw QueryRuntimeException if an error ocurred.
void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) override { void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) override {
@ -412,8 +413,8 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
case storage::replication::ReplicaState::RECOVERY: case storage::replication::ReplicaState::RECOVERY:
replica.state = ReplicationQuery::ReplicaState::RECOVERY; replica.state = ReplicationQuery::ReplicaState::RECOVERY;
break; break;
case storage::replication::ReplicaState::INVALID: case storage::replication::ReplicaState::MAYBE_BEHIND:
replica.state = ReplicationQuery::ReplicaState::INVALID; replica.state = ReplicationQuery::ReplicaState::MAYBE_BEHIND;
break; break;
} }
@ -487,7 +488,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
MG_ASSERT(password.IsString() || password.IsNull()); MG_ASSERT(password.IsString() || password.IsNull());
if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString())) if (!auth->CreateUser(username, password.IsString() ? std::make_optional(std::string(password.ValueString()))
: std::nullopt)) { : std::nullopt)) {
throw QueryRuntimeException("User '{}' already exists.", username); throw UserAlreadyExistsException("User '{}' already exists.", username);
} }
// If the license is not valid we create users with admin access // If the license is not valid we create users with admin access
@ -721,8 +722,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters, Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
dbms::DbmsHandler *dbms_handler, const query::InterpreterConfig &config, dbms::DbmsHandler *dbms_handler, const query::InterpreterConfig &config,
std::vector<Notification> *notifications, std::vector<Notification> *notifications) {
memgraph::replication::ReplicationState *repl_state) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as // TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback. // the argument to Callback.
EvaluationContext evaluation_context; EvaluationContext evaluation_context;
@ -742,8 +742,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING, notifications->emplace_back(SeverityLevel::WARNING, NotificationCode::REPLICA_PORT_WARNING,
"Be careful the replication port must be different from the memgraph port!"); "Be careful the replication port must be different from the memgraph port!");
} }
callback.fn = [handler = ReplQueryHandler{dbms_handler, repl_state}, role = repl_query->role_, callback.fn = [handler = ReplQueryHandler{dbms_handler}, role = repl_query->role_, maybe_port]() mutable {
maybe_port]() mutable {
handler.SetReplicationRole(role, maybe_port); handler.SetReplicationRole(role, maybe_port);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
@ -755,7 +754,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
} }
case ReplicationQuery::Action::SHOW_REPLICATION_ROLE: { case ReplicationQuery::Action::SHOW_REPLICATION_ROLE: {
callback.header = {"replication role"}; callback.header = {"replication role"};
callback.fn = [handler = ReplQueryHandler{dbms_handler, repl_state}] { callback.fn = [handler = ReplQueryHandler{dbms_handler}] {
auto mode = handler.ShowReplicationRole(); auto mode = handler.ShowReplicationRole();
switch (mode) { switch (mode) {
case ReplicationQuery::ReplicationRole::MAIN: { case ReplicationQuery::ReplicationRole::MAIN: {
@ -774,7 +773,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
auto socket_address = repl_query->socket_address_->Accept(evaluator); auto socket_address = repl_query->socket_address_->Accept(evaluator);
const auto replica_check_frequency = config.replication_replica_check_frequency; const auto replica_check_frequency = config.replication_replica_check_frequency;
callback.fn = [handler = ReplQueryHandler{dbms_handler, repl_state}, name, socket_address, sync_mode, callback.fn = [handler = ReplQueryHandler{dbms_handler}, name, socket_address, sync_mode,
replica_check_frequency]() mutable { replica_check_frequency]() mutable {
handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency); handler.RegisterReplica(name, std::string(socket_address.ValueString()), sync_mode, replica_check_frequency);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
@ -785,7 +784,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
} }
case ReplicationQuery::Action::DROP_REPLICA: { case ReplicationQuery::Action::DROP_REPLICA: {
const auto &name = repl_query->replica_name_; const auto &name = repl_query->replica_name_;
callback.fn = [handler = ReplQueryHandler{dbms_handler, repl_state}, name]() mutable { callback.fn = [handler = ReplQueryHandler{dbms_handler}, name]() mutable {
handler.DropReplica(name); handler.DropReplica(name);
return std::vector<std::vector<TypedValue>>(); return std::vector<std::vector<TypedValue>>();
}; };
@ -797,7 +796,7 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
callback.header = { callback.header = {
"name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master", "name", "socket_address", "sync_mode", "current_timestamp_of_replica", "number_of_timestamp_behind_master",
"state"}; "state"};
callback.fn = [handler = ReplQueryHandler{dbms_handler, repl_state}, replica_nfields = callback.header.size()] { callback.fn = [handler = ReplQueryHandler{dbms_handler}, replica_nfields = callback.header.size()] {
const auto &replicas = handler.ShowReplicas(); const auto &replicas = handler.ShowReplicas();
auto typed_replicas = std::vector<std::vector<TypedValue>>{}; auto typed_replicas = std::vector<std::vector<TypedValue>>{};
typed_replicas.reserve(replicas.size()); typed_replicas.reserve(replicas.size());
@ -805,34 +804,33 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
std::vector<TypedValue> typed_replica; std::vector<TypedValue> typed_replica;
typed_replica.reserve(replica_nfields); typed_replica.reserve(replica_nfields);
typed_replica.emplace_back(TypedValue(replica.name)); typed_replica.emplace_back(replica.name);
typed_replica.emplace_back(TypedValue(replica.socket_address)); typed_replica.emplace_back(replica.socket_address);
switch (replica.sync_mode) { switch (replica.sync_mode) {
case ReplicationQuery::SyncMode::SYNC: case ReplicationQuery::SyncMode::SYNC:
typed_replica.emplace_back(TypedValue("sync")); typed_replica.emplace_back("sync");
break; break;
case ReplicationQuery::SyncMode::ASYNC: case ReplicationQuery::SyncMode::ASYNC:
typed_replica.emplace_back(TypedValue("async")); typed_replica.emplace_back("async");
break; break;
} }
typed_replica.emplace_back(TypedValue(static_cast<int64_t>(replica.current_timestamp_of_replica))); typed_replica.emplace_back(static_cast<int64_t>(replica.current_timestamp_of_replica));
typed_replica.emplace_back( typed_replica.emplace_back(static_cast<int64_t>(replica.current_number_of_timestamp_behind_master));
TypedValue(static_cast<int64_t>(replica.current_number_of_timestamp_behind_master)));
switch (replica.state) { switch (replica.state) {
case ReplicationQuery::ReplicaState::READY: case ReplicationQuery::ReplicaState::READY:
typed_replica.emplace_back(TypedValue("ready")); typed_replica.emplace_back("ready");
break; break;
case ReplicationQuery::ReplicaState::REPLICATING: case ReplicationQuery::ReplicaState::REPLICATING:
typed_replica.emplace_back(TypedValue("replicating")); typed_replica.emplace_back("replicating");
break; break;
case ReplicationQuery::ReplicaState::RECOVERY: case ReplicationQuery::ReplicaState::RECOVERY:
typed_replica.emplace_back(TypedValue("recovery")); typed_replica.emplace_back("recovery");
break; break;
case ReplicationQuery::ReplicaState::INVALID: case ReplicationQuery::ReplicaState::MAYBE_BEHIND:
typed_replica.emplace_back(TypedValue("invalid")); typed_replica.emplace_back("invalid");
break; break;
} }
@ -1545,7 +1543,6 @@ inline static void TryCaching(const AstStorage &ast_storage, FrameChangeCollecto
continue; continue;
} }
frame_change_collector->AddTrackingKey(*cached_id); frame_change_collector->AddTrackingKey(*cached_id);
spdlog::trace("Tracking {} operator, by id: {}", InListOperator::kType.name, *cached_id);
} }
} }
@ -1971,11 +1968,11 @@ std::vector<std::vector<TypedValue>> AnalyzeGraphQueryHandler::AnalyzeGraphCreat
result.reserve(kComputeStatisticsNumResults); result.reserve(kComputeStatisticsNumResults);
result.emplace_back(execution_db_accessor->LabelToName(stat_entry.first)); result.emplace_back(execution_db_accessor->LabelToName(stat_entry.first));
result.emplace_back(TypedValue()); result.emplace_back();
result.emplace_back(static_cast<int64_t>(stat_entry.second.count)); result.emplace_back(static_cast<int64_t>(stat_entry.second.count));
result.emplace_back(TypedValue()); result.emplace_back();
result.emplace_back(TypedValue()); result.emplace_back();
result.emplace_back(TypedValue()); result.emplace_back();
result.emplace_back(stat_entry.second.avg_degree); result.emplace_back(stat_entry.second.avg_degree);
results.push_back(std::move(result)); results.push_back(std::move(result));
}); });
@ -2273,15 +2270,14 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler, std::vector<Notification> *notifications, dbms::DbmsHandler &dbms_handler,
const InterpreterConfig &config, const InterpreterConfig &config) {
memgraph::replication::ReplicationState *repl_state) {
if (in_explicit_transaction) { if (in_explicit_transaction) {
throw ReplicationModificationInMulticommandTxException(); throw ReplicationModificationInMulticommandTxException();
} }
auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query); auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query);
auto callback = HandleReplicationQuery(replication_query, parsed_query.parameters, &dbms_handler, config, auto callback =
notifications, repl_state); HandleReplicationQuery(replication_query, parsed_query.parameters, &dbms_handler, config, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2892,7 +2888,7 @@ auto ShowTransactions(const std::unordered_set<Interpreter *> &interpreters, con
metadata_tv.emplace(md.first, TypedValue(md.second)); metadata_tv.emplace(md.first, TypedValue(md.second));
} }
} }
results.back().push_back(TypedValue(metadata_tv)); results.back().emplace_back(metadata_tv);
} }
} }
return results; return results;
@ -3056,6 +3052,46 @@ PreparedQuery PrepareDatabaseInfoQuery(ParsedQuery parsed_query, bool in_explici
}; };
break; break;
} }
case DatabaseInfoQuery::InfoType::EDGE_TYPES: {
header = {"edge types"};
handler = [storage = current_db.db_acc_->get()->storage(), dba] {
if (!storage->config_.items.enable_schema_metadata) {
throw QueryRuntimeException(
"The metadata collection for edge-types is disabled. To enable it, restart your instance and set the "
"storage-enable-schema-metadata flag to True.");
}
auto edge_types = dba->ListAllPossiblyPresentEdgeTypes();
std::vector<std::vector<TypedValue>> results;
results.reserve(edge_types.size());
for (auto &edge_type : edge_types) {
results.push_back({TypedValue(storage->EdgeTypeToName(edge_type))});
}
return std::pair{results, QueryHandlerResult::COMMIT};
};
break;
}
case DatabaseInfoQuery::InfoType::NODE_LABELS: {
header = {"node labels"};
handler = [storage = current_db.db_acc_->get()->storage(), dba] {
if (!storage->config_.items.enable_schema_metadata) {
throw QueryRuntimeException(
"The metadata collection for node-labels is disabled. To enable it, restart your instance and set the "
"storage-enable-schema-metadata flag to True.");
}
auto node_labels = dba->ListAllPossiblyPresentVertexLabels();
std::vector<std::vector<TypedValue>> results;
results.reserve(node_labels.size());
for (auto &node_label : node_labels) {
results.push_back({TypedValue(storage->LabelToName(node_label))});
}
return std::pair{results, QueryHandlerResult::COMMIT};
};
break;
}
} }
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
@ -3359,8 +3395,7 @@ PreparedQuery PrepareConstraintQuery(ParsedQuery parsed_query, bool in_explicit_
PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &current_db, PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &current_db,
InterpreterContext *interpreter_context, InterpreterContext *interpreter_context,
std::optional<std::function<void(std::string_view)>> on_change_cb, std::optional<std::function<void(std::string_view)>> on_change_cb) {
memgraph::replication::ReplicationState *repl_state) {
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) { if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license."); throw QueryException("Trying to use enterprise feature without a valid license.");
@ -3371,9 +3406,11 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
auto *query = utils::Downcast<MultiDatabaseQuery>(parsed_query.query); auto *query = utils::Downcast<MultiDatabaseQuery>(parsed_query.query);
auto *db_handler = interpreter_context->dbms_handler; auto *db_handler = interpreter_context->dbms_handler;
const bool is_replica = interpreter_context->repl_state->IsReplica();
switch (query->action_) { switch (query->action_) {
case MultiDatabaseQuery::Action::CREATE: case MultiDatabaseQuery::Action::CREATE:
if (repl_state->IsReplica()) { if (is_replica) {
throw QueryException("Query forbidden on the replica!"); throw QueryException("Query forbidden on the replica!");
} }
return PreparedQuery{ return PreparedQuery{
@ -3418,12 +3455,12 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
if (current_db.in_explicit_db_) { if (current_db.in_explicit_db_) {
throw QueryException("Database switching is prohibited if session explicitly defines the used database"); throw QueryException("Database switching is prohibited if session explicitly defines the used database");
} }
if (!dbms::allow_mt_repl && repl_state->IsReplica()) { if (!dbms::allow_mt_repl && is_replica) {
throw QueryException("Query forbidden on the replica!"); throw QueryException("Query forbidden on the replica!");
} }
return PreparedQuery{{"STATUS"}, return PreparedQuery{{"STATUS"},
std::move(parsed_query.required_privileges), std::move(parsed_query.required_privileges),
[db_name = query->db_name_, db_handler, &current_db, on_change_cb]( [db_name = query->db_name_, db_handler, &current_db, on_change = std::move(on_change_cb)](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
std::vector<std::vector<TypedValue>> status; std::vector<std::vector<TypedValue>> status;
std::string res; std::string res;
@ -3433,7 +3470,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
res = "Already using " + db_name; res = "Already using " + db_name;
} else { } else {
auto tmp = db_handler->Get(db_name); auto tmp = db_handler->Get(db_name);
if (on_change_cb) (*on_change_cb)(db_name); // Will trow if cb fails if (on_change) (*on_change)(db_name); // Will trow if cb fails
current_db.SetCurrentDB(std::move(tmp), false); current_db.SetCurrentDB(std::move(tmp), false);
res = "Using " + db_name; res = "Using " + db_name;
} }
@ -3452,7 +3489,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, CurrentDB &cur
query->db_name_}; query->db_name_};
case MultiDatabaseQuery::Action::DROP: case MultiDatabaseQuery::Action::DROP:
if (repl_state->IsReplica()) { if (is_replica) {
throw QueryException("Query forbidden on the replica!"); throw QueryException("Query forbidden on the replica!");
} }
return PreparedQuery{ return PreparedQuery{
@ -3727,7 +3764,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
// TODO: ATM only a single database, will change when we have multiple database transactions // TODO: ATM only a single database, will change when we have multiple database transactions
bool could_commit = utils::Downcast<CypherQuery>(parsed_query.query) != nullptr; bool could_commit = utils::Downcast<CypherQuery>(parsed_query.query) != nullptr;
bool unique = utils::Downcast<IndexQuery>(parsed_query.query) != nullptr || bool unique = utils::Downcast<IndexQuery>(parsed_query.query) != nullptr ||
utils::Downcast<ConstraintQuery>(parsed_query.query) != nullptr; utils::Downcast<ConstraintQuery>(parsed_query.query) != nullptr ||
upper_case_query.find(kSchemaAssert) != std::string::npos;
SetupDatabaseTransaction(could_commit, unique); SetupDatabaseTransaction(could_commit, unique);
} }
@ -3775,9 +3813,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
&query_execution->notifications, current_db_); &query_execution->notifications, current_db_);
} else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) {
/// TODO: make replication DB agnostic /// TODO: make replication DB agnostic
prepared_query = PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, prepared_query =
&query_execution->notifications, *interpreter_context_->dbms_handler, PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
interpreter_context_->config, interpreter_context_->repl_state); *interpreter_context_->dbms_handler, interpreter_context_->config);
} else if (utils::Downcast<LockPathQuery>(parsed_query.query)) { } else if (utils::Downcast<LockPathQuery>(parsed_query.query)) {
prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_); prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, current_db_);
} else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) { } else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) {
@ -3817,8 +3855,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
throw MultiDatabaseQueryInMulticommandTxException(); throw MultiDatabaseQueryInMulticommandTxException();
} }
/// SYSTEM (Replication) + INTERPRETER /// SYSTEM (Replication) + INTERPRETER
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_, prepared_query =
interpreter_context_->repl_state); PrepareMultiDatabaseQuery(std::move(parsed_query), current_db_, interpreter_context_, on_change_);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) { } else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
/// SYSTEM PURE ("SHOW DATABASES") /// SYSTEM PURE ("SHOW DATABASES")
/// INTERPRETER (TODO: "SHOW DATABASE") /// INTERPRETER (TODO: "SHOW DATABASE")

View File

@ -174,7 +174,7 @@ struct CurrentDB {
class Interpreter final { class Interpreter final {
public: public:
Interpreter(InterpreterContext *interpreter_context); explicit Interpreter(InterpreterContext *interpreter_context);
Interpreter(InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db); Interpreter(InterpreterContext *interpreter_context, memgraph::dbms::DatabaseAccess db);
Interpreter(const Interpreter &) = delete; Interpreter(const Interpreter &) = delete;
Interpreter &operator=(const Interpreter &) = delete; Interpreter &operator=(const Interpreter &) = delete;

View File

@ -57,6 +57,7 @@
#include "utils/likely.hpp" #include "utils/likely.hpp"
#include "utils/logging.hpp" #include "utils/logging.hpp"
#include "utils/memory.hpp" #include "utils/memory.hpp"
#include "utils/memory_tracker.hpp"
#include "utils/message.hpp" #include "utils/message.hpp"
#include "utils/on_scope_exit.hpp" #include "utils/on_scope_exit.hpp"
#include "utils/pmr/deque.hpp" #include "utils/pmr/deque.hpp"
@ -206,8 +207,8 @@ void Once::OnceCursor::Shutdown() {}
void Once::OnceCursor::Reset() { did_pull_ = false; } void Once::OnceCursor::Reset() { did_pull_ = false; }
CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, const NodeCreationInfo &node_info) CreateNode::CreateNode(const std::shared_ptr<LogicalOperator> &input, NodeCreationInfo node_info)
: input_(input ? input : std::make_shared<Once>()), node_info_(node_info) {} : input_(input ? input : std::make_shared<Once>()), node_info_(std::move(node_info)) {}
// Creates a vertex on this GraphDb. Returns a reference to vertex placed on the // Creates a vertex on this GraphDb. Returns a reference to vertex placed on the
// frame. // frame.
@ -297,12 +298,12 @@ void CreateNode::CreateNodeCursor::Shutdown() { input_cursor_->Shutdown(); }
void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); } void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); }
CreateExpand::CreateExpand(const NodeCreationInfo &node_info, const EdgeCreationInfo &edge_info, CreateExpand::CreateExpand(NodeCreationInfo node_info, EdgeCreationInfo edge_info,
const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node) const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node)
: node_info_(node_info), : node_info_(std::move(node_info)),
edge_info_(edge_info), edge_info_(std::move(edge_info)),
input_(input ? input : std::make_shared<Once>()), input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol), input_symbol_(std::move(input_symbol)),
existing_node_(existing_node) {} existing_node_(existing_node) {}
ACCEPT_WITH_INPUT(CreateExpand) ACCEPT_WITH_INPUT(CreateExpand)
@ -446,7 +447,7 @@ class ScanAllCursor : public Cursor {
explicit ScanAllCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor, storage::View view, explicit ScanAllCursor(const ScanAll &self, Symbol output_symbol, UniqueCursorPtr input_cursor, storage::View view,
TVerticesFun get_vertices, const char *op_name) TVerticesFun get_vertices, const char *op_name)
: self_(self), : self_(self),
output_symbol_(output_symbol), output_symbol_(std::move(output_symbol)),
input_cursor_(std::move(input_cursor)), input_cursor_(std::move(input_cursor)),
view_(view), view_(view),
get_vertices_(std::move(get_vertices)), get_vertices_(std::move(get_vertices)),
@ -517,7 +518,7 @@ class ScanAllCursor : public Cursor {
}; };
ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view) ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view)
: input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {} : input_(input ? input : std::make_shared<Once>()), output_symbol_(std::move(output_symbol)), view_(view) {}
ACCEPT_WITH_INPUT(ScanAll) ACCEPT_WITH_INPUT(ScanAll)
@ -560,13 +561,13 @@ UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input,
Symbol output_symbol, storage::LabelId label, Symbol output_symbol, storage::LabelId label,
storage::PropertyId property, const std::string &property_name, storage::PropertyId property, std::string property_name,
std::optional<Bound> lower_bound, std::optional<Bound> lower_bound,
std::optional<Bound> upper_bound, storage::View view) std::optional<Bound> upper_bound, storage::View view)
: ScanAll(input, output_symbol, view), : ScanAll(input, output_symbol, view),
label_(label), label_(label),
property_(property), property_(property),
property_name_(property_name), property_name_(std::move(property_name)),
lower_bound_(lower_bound), lower_bound_(lower_bound),
upper_bound_(upper_bound) { upper_bound_(upper_bound) {
MG_ASSERT(lower_bound_ || upper_bound_, "Only one bound can be left out"); MG_ASSERT(lower_bound_ || upper_bound_, "Only one bound can be left out");
@ -622,12 +623,12 @@ UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource *m
ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input,
Symbol output_symbol, storage::LabelId label, Symbol output_symbol, storage::LabelId label,
storage::PropertyId property, const std::string &property_name, storage::PropertyId property, std::string property_name,
Expression *expression, storage::View view) Expression *expression, storage::View view)
: ScanAll(input, output_symbol, view), : ScanAll(input, output_symbol, view),
label_(label), label_(label),
property_(property), property_(property),
property_name_(property_name), property_name_(std::move(property_name)),
expression_(expression) { expression_(expression) {
DMG_ASSERT(expression, "Expression is not optional."); DMG_ASSERT(expression, "Expression is not optional.");
} }
@ -654,8 +655,11 @@ UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *m
ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
storage::LabelId label, storage::PropertyId property, storage::LabelId label, storage::PropertyId property,
const std::string &property_name, storage::View view) std::string property_name, storage::View view)
: ScanAll(input, output_symbol, view), label_(label), property_(property), property_name_(property_name) {} : ScanAll(input, output_symbol, view),
label_(label),
property_(property),
property_name_(std::move(property_name)) {}
ACCEPT_WITH_INPUT(ScanAllByLabelProperty) ACCEPT_WITH_INPUT(ScanAllByLabelProperty)
@ -727,7 +731,7 @@ Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbo
Symbol edge_symbol, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, Symbol edge_symbol, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types,
bool existing_node, storage::View view) bool existing_node, storage::View view)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol), input_symbol_(std::move(input_symbol)),
common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, common_{node_symbol, edge_symbol, direction, edge_types, existing_node},
view_(view) {} view_(view) {}
@ -961,15 +965,15 @@ ExpandVariable::ExpandVariable(const std::shared_ptr<LogicalOperator> &input, Sy
ExpansionLambda filter_lambda, std::optional<ExpansionLambda> weight_lambda, ExpansionLambda filter_lambda, std::optional<ExpansionLambda> weight_lambda,
std::optional<Symbol> total_weight) std::optional<Symbol> total_weight)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol), input_symbol_(std::move(input_symbol)),
common_{node_symbol, edge_symbol, direction, edge_types, existing_node}, common_{node_symbol, edge_symbol, direction, edge_types, existing_node},
type_(type), type_(type),
is_reverse_(is_reverse), is_reverse_(is_reverse),
lower_bound_(lower_bound), lower_bound_(lower_bound),
upper_bound_(upper_bound), upper_bound_(upper_bound),
filter_lambda_(filter_lambda), filter_lambda_(std::move(filter_lambda)),
weight_lambda_(weight_lambda), weight_lambda_(std::move(weight_lambda)),
total_weight_(total_weight) { total_weight_(std::move(total_weight)) {
DMG_ASSERT(type_ == EdgeAtom::Type::DEPTH_FIRST || type_ == EdgeAtom::Type::BREADTH_FIRST || DMG_ASSERT(type_ == EdgeAtom::Type::DEPTH_FIRST || type_ == EdgeAtom::Type::BREADTH_FIRST ||
type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS, type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS,
"ExpandVariable can only be used with breadth first, depth first, " "ExpandVariable can only be used with breadth first, depth first, "
@ -1134,6 +1138,11 @@ class ExpandVariableCursor : public Cursor {
edges_it_.emplace_back(edges_.back().begin()); edges_it_.emplace_back(edges_.back().begin());
} }
if (self_.filter_lambda_.accumulated_path_symbol) {
// Add initial vertex of path to the accumulated path
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex);
}
// reset the frame value to an empty edge list // reset the frame value to an empty edge list
auto *pull_memory = context.evaluation_context.memory; auto *pull_memory = context.evaluation_context.memory;
frame[self_.common_.edge_symbol] = TypedValue::TVector(pull_memory); frame[self_.common_.edge_symbol] = TypedValue::TVector(pull_memory);
@ -1230,6 +1239,13 @@ class ExpandVariableCursor : public Cursor {
// Skip expanding out of filtered expansion. // Skip expanding out of filtered expansion.
frame[self_.filter_lambda_.inner_edge_symbol] = current_edge.first; frame[self_.filter_lambda_.inner_edge_symbol] = current_edge.first;
frame[self_.filter_lambda_.inner_node_symbol] = current_vertex; frame[self_.filter_lambda_.inner_node_symbol] = current_vertex;
if (self_.filter_lambda_.accumulated_path_symbol) {
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
"Accumulated path must be path");
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
accumulated_path.Expand(current_edge.first);
accumulated_path.Expand(current_vertex);
}
if (self_.filter_lambda_.expression && !EvaluateFilter(evaluator, self_.filter_lambda_.expression)) continue; if (self_.filter_lambda_.expression && !EvaluateFilter(evaluator, self_.filter_lambda_.expression)) continue;
// we are doing depth-first search, so place the current // we are doing depth-first search, so place the current
@ -1542,6 +1558,13 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor {
#endif #endif
frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_edge_symbol] = edge;
frame[self_.filter_lambda_.inner_node_symbol] = vertex; frame[self_.filter_lambda_.inner_node_symbol] = vertex;
if (self_.filter_lambda_.accumulated_path_symbol) {
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
"Accumulated path must have Path type");
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
accumulated_path.Expand(edge);
accumulated_path.Expand(vertex);
}
if (self_.filter_lambda_.expression) { if (self_.filter_lambda_.expression) {
TypedValue result = self_.filter_lambda_.expression->Accept(evaluator); TypedValue result = self_.filter_lambda_.expression->Accept(evaluator);
@ -1603,6 +1626,11 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor {
const auto &vertex = vertex_value.ValueVertex(); const auto &vertex = vertex_value.ValueVertex();
processed_.emplace(vertex, std::nullopt); processed_.emplace(vertex, std::nullopt);
if (self_.filter_lambda_.accumulated_path_symbol) {
// Add initial vertex of path to the accumulated path
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex);
}
expand_from_vertex(vertex); expand_from_vertex(vertex);
// go back to loop start and see if we expanded anything // go back to loop start and see if we expanded anything
@ -1673,6 +1701,10 @@ class SingleSourceShortestPathCursor : public query::plan::Cursor {
namespace { namespace {
void CheckWeightType(TypedValue current_weight, utils::MemoryResource *memory) { void CheckWeightType(TypedValue current_weight, utils::MemoryResource *memory) {
if (current_weight.IsNull()) {
return;
}
if (!current_weight.IsNumeric() && !current_weight.IsDuration()) { if (!current_weight.IsNumeric() && !current_weight.IsDuration()) {
throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type()); throw QueryRuntimeException("Calculated weight must be numeric or a Duration, got {}.", current_weight.type());
} }
@ -1690,6 +1722,34 @@ void CheckWeightType(TypedValue current_weight, utils::MemoryResource *memory) {
} }
} }
void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
if ((lhs.IsNumeric() && rhs.IsNumeric()) || (lhs.IsDuration() && rhs.IsDuration())) {
return;
}
throw QueryRuntimeException(utils::MessageWithLink(
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
"expression or the filter expression.",
"https://memgr.ph/wsp"));
}
TypedValue CalculateNextWeight(const std::optional<memgraph::query::plan::ExpansionLambda> &weight_lambda,
const TypedValue &total_weight, ExpressionEvaluator evaluator) {
if (!weight_lambda) {
return {};
}
auto *memory = evaluator.GetMemoryResource();
TypedValue current_weight = weight_lambda->expression->Accept(evaluator);
CheckWeightType(current_weight, memory);
if (total_weight.IsNull()) {
return current_weight;
}
ValidateWeightTypes(current_weight, total_weight);
return TypedValue(current_weight, memory) + total_weight;
}
} // namespace } // namespace
class ExpandWeightedShortestPathCursor : public query::plan::Cursor { class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
@ -1718,7 +1778,6 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
auto expand_pair = [this, &evaluator, &frame, &create_state, &context]( auto expand_pair = [this, &evaluator, &frame, &create_state, &context](
const EdgeAccessor &edge, const VertexAccessor &vertex, const TypedValue &total_weight, const EdgeAccessor &edge, const VertexAccessor &vertex, const TypedValue &total_weight,
int64_t depth) { int64_t depth) {
auto *memory = evaluator.GetMemoryResource();
#ifdef MG_ENTERPRISE #ifdef MG_ENTERPRISE
if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker &&
!(context.auth_checker->Has(vertex, storage::View::OLD, !(context.auth_checker->Has(vertex, storage::View::OLD,
@ -1727,37 +1786,36 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
return; return;
} }
#endif #endif
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator);
if (self_.filter_lambda_.expression) { if (self_.filter_lambda_.expression) {
frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_edge_symbol] = edge;
frame[self_.filter_lambda_.inner_node_symbol] = vertex; frame[self_.filter_lambda_.inner_node_symbol] = vertex;
if (self_.filter_lambda_.accumulated_path_symbol) {
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
"Accumulated path must be path");
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
accumulated_path.Expand(edge);
accumulated_path.Expand(vertex);
if (self_.filter_lambda_.accumulated_weight_symbol) {
frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight;
}
}
if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return; if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return;
} }
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator);
CheckWeightType(current_weight, memory);
auto next_state = create_state(vertex, depth); auto next_state = create_state(vertex, depth);
TypedValue next_weight = std::invoke([&] {
if (total_weight.IsNull()) {
return current_weight;
}
ValidateWeightTypes(current_weight, total_weight);
return TypedValue(current_weight, memory) + total_weight;
});
auto found_it = total_cost_.find(next_state); auto found_it = total_cost_.find(next_state);
if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool())) if (found_it != total_cost_.end() && (found_it->second.IsNull() || (found_it->second <= next_weight).ValueBool()))
return; return;
pq_.push({next_weight, depth + 1, vertex, edge}); pq_.emplace(next_weight, depth + 1, vertex, edge);
}; };
// Populates the priority queue structure with expansions // Populates the priority queue structure with expansions
@ -1792,6 +1850,10 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
// Skip expansion for such nodes. // Skip expansion for such nodes.
if (node.IsNull()) continue; if (node.IsNull()) continue;
} }
if (self_.filter_lambda_.accumulated_path_symbol) {
// Add initial vertex of path to the accumulated path
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(vertex);
}
if (self_.upper_bound_) { if (self_.upper_bound_) {
upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion"); upper_bound_ = EvaluateInt(&evaluator, self_.upper_bound_, "Max depth in weighted shortest path expansion");
upper_bound_set_ = true; upper_bound_set_ = true;
@ -1804,12 +1866,17 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
"Maximum depth in weighted shortest path expansion must be at " "Maximum depth in weighted shortest path expansion must be at "
"least 1."); "least 1.");
frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue();
frame[self_.weight_lambda_->inner_node_symbol] = vertex;
TypedValue current_weight =
CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator);
// Clear existing data structures. // Clear existing data structures.
previous_.clear(); previous_.clear();
total_cost_.clear(); total_cost_.clear();
yielded_vertices_.clear(); yielded_vertices_.clear();
pq_.push({TypedValue(), 0, vertex, std::nullopt}); pq_.emplace(current_weight, 0, vertex, std::nullopt);
// We are adding the starting vertex to the set of yielded vertices // We are adding the starting vertex to the set of yielded vertices
// because we don't want to yield paths that end with the starting // because we don't want to yield paths that end with the starting
// vertex. // vertex.
@ -1909,15 +1976,6 @@ class ExpandWeightedShortestPathCursor : public query::plan::Cursor {
// Keeps track of vertices for which we yielded a path already. // Keeps track of vertices for which we yielded a path already.
utils::pmr::unordered_set<VertexAccessor> yielded_vertices_; utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
throw QueryRuntimeException(utils::MessageWithLink(
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
"expression or the filter expression.",
"https://memgr.ph/wsp"));
}
}
// Priority queue comparator. Keep lowest weight on top of the queue. // Priority queue comparator. Keep lowest weight on top of the queue.
class PriorityQueueComparator { class PriorityQueueComparator {
public: public:
@ -1975,36 +2033,32 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
// queue. // queue.
auto expand_vertex = [this, &evaluator, &frame](const EdgeAccessor &edge, const EdgeAtom::Direction direction, auto expand_vertex = [this, &evaluator, &frame](const EdgeAccessor &edge, const EdgeAtom::Direction direction,
const TypedValue &total_weight, int64_t depth) { const TypedValue &total_weight, int64_t depth) {
auto *memory = evaluator.GetMemoryResource();
auto const &next_vertex = direction == EdgeAtom::Direction::IN ? edge.From() : edge.To(); auto const &next_vertex = direction == EdgeAtom::Direction::IN ? edge.From() : edge.To();
// Evaluate current weight
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
frame[self_.weight_lambda_->inner_node_symbol] = next_vertex;
TypedValue next_weight = CalculateNextWeight(self_.weight_lambda_, total_weight, evaluator);
// If filter expression exists, evaluate filter // If filter expression exists, evaluate filter
if (self_.filter_lambda_.expression) { if (self_.filter_lambda_.expression) {
frame[self_.filter_lambda_.inner_edge_symbol] = edge; frame[self_.filter_lambda_.inner_edge_symbol] = edge;
frame[self_.filter_lambda_.inner_node_symbol] = next_vertex; frame[self_.filter_lambda_.inner_node_symbol] = next_vertex;
if (self_.filter_lambda_.accumulated_path_symbol) {
MG_ASSERT(frame[self_.filter_lambda_.accumulated_path_symbol.value()].IsPath(),
"Accumulated path must be path");
Path &accumulated_path = frame[self_.filter_lambda_.accumulated_path_symbol.value()].ValuePath();
accumulated_path.Expand(edge);
accumulated_path.Expand(next_vertex);
if (self_.filter_lambda_.accumulated_weight_symbol) {
frame[self_.filter_lambda_.accumulated_weight_symbol.value()] = next_weight;
}
}
if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return; if (!EvaluateFilter(evaluator, self_.filter_lambda_.expression)) return;
} }
// Evaluate current weight
frame[self_.weight_lambda_->inner_edge_symbol] = edge;
frame[self_.weight_lambda_->inner_node_symbol] = next_vertex;
TypedValue current_weight = self_.weight_lambda_->expression->Accept(evaluator);
CheckWeightType(current_weight, memory);
TypedValue next_weight = std::invoke([&] {
if (total_weight.IsNull()) {
return current_weight;
}
ValidateWeightTypes(current_weight, total_weight);
return TypedValue(current_weight, memory) + total_weight;
});
auto found_it = visited_cost_.find(next_vertex); auto found_it = visited_cost_.find(next_vertex);
// Check if the vertex has already been processed. // Check if the vertex has already been processed.
if (found_it != visited_cost_.end()) { if (found_it != visited_cost_.end()) {
@ -2022,7 +2076,7 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
} }
DirectedEdge directed_edge = {edge, direction, next_weight}; DirectedEdge directed_edge = {edge, direction, next_weight};
pq_.push({next_weight, depth + 1, next_vertex, directed_edge}); pq_.emplace(next_weight, depth + 1, next_vertex, directed_edge);
}; };
// Populates the priority queue structure with expansions // Populates the priority queue structure with expansions
@ -2196,7 +2250,17 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
traversal_stack_.clear(); traversal_stack_.clear();
total_cost_.clear(); total_cost_.clear();
expand_from_vertex(*start_vertex, TypedValue(), 0); if (self_.filter_lambda_.accumulated_path_symbol) {
// Add initial vertex of path to the accumulated path
frame[self_.filter_lambda_.accumulated_path_symbol.value()] = Path(*start_vertex);
}
frame[self_.weight_lambda_->inner_edge_symbol] = TypedValue();
frame[self_.weight_lambda_->inner_node_symbol] = *start_vertex;
TypedValue current_weight =
CalculateNextWeight(self_.weight_lambda_, /* total_weight */ TypedValue(), evaluator);
expand_from_vertex(*start_vertex, current_weight, 0);
visited_cost_.emplace(*start_vertex, 0); visited_cost_.emplace(*start_vertex, 0);
frame[self_.common_.edge_symbol] = TypedValue::TVector(memory); frame[self_.common_.edge_symbol] = TypedValue::TVector(memory);
} }
@ -2248,15 +2312,6 @@ class ExpandAllShortestPathsCursor : public query::plan::Cursor {
// Stack indicating the traversal level. // Stack indicating the traversal level.
utils::pmr::list<utils::pmr::list<DirectedEdge>> traversal_stack_; utils::pmr::list<utils::pmr::list<DirectedEdge>> traversal_stack_;
static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
throw QueryRuntimeException(utils::MessageWithLink(
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
"expression or the filter expression.",
"https://memgr.ph/wsp"));
}
}
// Priority queue comparator. Keep lowest weight on top of the queue. // Priority queue comparator. Keep lowest weight on top of the queue.
class PriorityQueueComparator { class PriorityQueueComparator {
public: public:
@ -2313,8 +2368,8 @@ UniqueCursorPtr ExpandVariable::MakeCursor(utils::MemoryResource *mem) const {
class ConstructNamedPathCursor : public Cursor { class ConstructNamedPathCursor : public Cursor {
public: public:
ConstructNamedPathCursor(const ConstructNamedPath &self, utils::MemoryResource *mem) ConstructNamedPathCursor(ConstructNamedPath self, utils::MemoryResource *mem)
: self_(self), input_cursor_(self_.input()->MakeCursor(mem)) {} : self_(std::move(self)), input_cursor_(self_.input()->MakeCursor(mem)) {}
bool Pull(Frame &frame, ExecutionContext &context) override { bool Pull(Frame &frame, ExecutionContext &context) override {
OOMExceptionEnabler oom_exception; OOMExceptionEnabler oom_exception;
@ -2412,11 +2467,11 @@ Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
Filter::Filter(const std::shared_ptr<LogicalOperator> &input, Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression, const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression,
const Filters &all_filters) Filters all_filters)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
pattern_filters_(pattern_filters), pattern_filters_(pattern_filters),
expression_(expression), expression_(expression),
all_filters_(all_filters) {} all_filters_(std::move(all_filters)) {}
bool Filter::Accept(HierarchicalLogicalOperatorVisitor &visitor) { bool Filter::Accept(HierarchicalLogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) { if (visitor.PreVisit(*this)) {
@ -2477,7 +2532,7 @@ void Filter::FilterCursor::Shutdown() { input_cursor_->Shutdown(); }
void Filter::FilterCursor::Reset() { input_cursor_->Reset(); } void Filter::FilterCursor::Reset() { input_cursor_->Reset(); }
EvaluatePatternFilter::EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol) EvaluatePatternFilter::EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol)
: input_(input), output_symbol_(output_symbol) {} : input_(input), output_symbol_(std::move(output_symbol)) {}
ACCEPT_WITH_INPUT(EvaluatePatternFilter); ACCEPT_WITH_INPUT(EvaluatePatternFilter);
@ -2496,13 +2551,16 @@ std::vector<Symbol> EvaluatePatternFilter::ModifiedSymbols(const SymbolTable &ta
} }
bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) { bool EvaluatePatternFilter::EvaluatePatternFilterCursor::Pull(Frame &frame, ExecutionContext &context) {
OOMExceptionEnabler oom_exception;
SCOPED_PROFILE_OP("EvaluatePatternFilter"); SCOPED_PROFILE_OP("EvaluatePatternFilter");
std::function<void(TypedValue *)> function = [&frame, self = this->self_, input_cursor = this->input_cursor_.get(),
&context](TypedValue *return_value) {
OOMExceptionEnabler oom_exception;
input_cursor->Reset();
input_cursor_->Reset(); *return_value = TypedValue(input_cursor->Pull(frame, context), context.evaluation_context.memory);
};
frame[self_.output_symbol_] = TypedValue(input_cursor_->Pull(frame, context), context.evaluation_context.memory);
frame[self_.output_symbol_] = TypedValue(std::move(function));
return true; return true;
} }
@ -2800,7 +2858,7 @@ void SetProperty::SetPropertyCursor::Shutdown() { input_cursor_->Shutdown(); }
void SetProperty::SetPropertyCursor::Reset() { input_cursor_->Reset(); } void SetProperty::SetPropertyCursor::Reset() { input_cursor_->Reset(); }
SetProperties::SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op) SetProperties::SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op)
: input_(input), input_symbol_(input_symbol), rhs_(rhs), op_(op) {} : input_(input), input_symbol_(std::move(input_symbol)), rhs_(rhs), op_(op) {}
ACCEPT_WITH_INPUT(SetProperties) ACCEPT_WITH_INPUT(SetProperties)
@ -2999,7 +3057,7 @@ void SetProperties::SetPropertiesCursor::Reset() { input_cursor_->Reset(); }
SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, SetLabels::SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels) const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(input_symbol), labels_(labels) {} : input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}
ACCEPT_WITH_INPUT(SetLabels) ACCEPT_WITH_INPUT(SetLabels)
@ -3159,7 +3217,7 @@ void RemoveProperty::RemovePropertyCursor::Reset() { input_cursor_->Reset(); }
RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, RemoveLabels::RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels) const std::vector<storage::LabelId> &labels)
: input_(input), input_symbol_(input_symbol), labels_(labels) {} : input_(input), input_symbol_(std::move(input_symbol)), labels_(labels) {}
ACCEPT_WITH_INPUT(RemoveLabels) ACCEPT_WITH_INPUT(RemoveLabels)
@ -3233,7 +3291,7 @@ void RemoveLabels::RemoveLabelsCursor::Reset() { input_cursor_->Reset(); }
EdgeUniquenessFilter::EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol, EdgeUniquenessFilter::EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol,
const std::vector<Symbol> &previous_symbols) const std::vector<Symbol> &previous_symbols)
: input_(input), expand_symbol_(expand_symbol), previous_symbols_(previous_symbols) {} : input_(input), expand_symbol_(std::move(expand_symbol)), previous_symbols_(previous_symbols) {}
ACCEPT_WITH_INPUT(EdgeUniquenessFilter) ACCEPT_WITH_INPUT(EdgeUniquenessFilter)
@ -3463,7 +3521,7 @@ class AggregateCursor : public Cursor {
SCOPED_PROFILE_OP_BY_REF(self_); SCOPED_PROFILE_OP_BY_REF(self_);
if (!pulled_all_input_) { if (!pulled_all_input_) {
ProcessAll(&frame, &context); if (!ProcessAll(&frame, &context) && !self_.group_by_.empty()) return false;
pulled_all_input_ = true; pulled_all_input_ = true;
aggregation_it_ = aggregation_.begin(); aggregation_it_ = aggregation_.begin();
@ -3487,7 +3545,6 @@ class AggregateCursor : public Cursor {
return true; return true;
} }
} }
if (aggregation_it_ == aggregation_.end()) return false; if (aggregation_it_ == aggregation_.end()) return false;
// place aggregation values on the frame // place aggregation values on the frame
@ -3567,12 +3624,16 @@ class AggregateCursor : public Cursor {
* cache cardinality depends on number of * cache cardinality depends on number of
* aggregation results, and not on the number of inputs. * aggregation results, and not on the number of inputs.
*/ */
void ProcessAll(Frame *frame, ExecutionContext *context) { bool ProcessAll(Frame *frame, ExecutionContext *context) {
ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor, ExpressionEvaluator evaluator(frame, context->symbol_table, context->evaluation_context, context->db_accessor,
storage::View::NEW); storage::View::NEW);
bool pulled = false;
while (input_cursor_->Pull(*frame, *context)) { while (input_cursor_->Pull(*frame, *context)) {
ProcessOne(*frame, &evaluator); ProcessOne(*frame, &evaluator);
pulled = true;
} }
if (!pulled) return false;
// post processing // post processing
for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) { for (size_t pos = 0; pos < self_.aggregations_.size(); ++pos) {
@ -3606,6 +3667,7 @@ class AggregateCursor : public Cursor {
break; break;
} }
} }
return true;
} }
/** /**
@ -3614,6 +3676,7 @@ class AggregateCursor : public Cursor {
void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) { void ProcessOne(const Frame &frame, ExpressionEvaluator *evaluator) {
// Preallocated group_by, since most of the time the aggregation key won't be unique // Preallocated group_by, since most of the time the aggregation key won't be unique
reused_group_by_.clear(); reused_group_by_.clear();
evaluator->ResetPropertyLookupCache();
for (Expression *expression : self_.group_by_) { for (Expression *expression : self_.group_by_) {
reused_group_by_.emplace_back(expression->Accept(*evaluator)); reused_group_by_.emplace_back(expression->Accept(*evaluator));
@ -4199,7 +4262,7 @@ void Optional::OptionalCursor::Reset() {
Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression, Symbol output_symbol) Unwind::Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression, Symbol output_symbol)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
input_expression_(input_expression), input_expression_(input_expression),
output_symbol_(output_symbol) {} output_symbol_(std::move(output_symbol)) {}
ACCEPT_WITH_INPUT(Unwind) ACCEPT_WITH_INPUT(Unwind)
@ -4530,7 +4593,7 @@ WITHOUT_SINGLE_INPUT(OutputTable);
class OutputTableCursor : public Cursor { class OutputTableCursor : public Cursor {
public: public:
OutputTableCursor(const OutputTable &self) : self_(self) {} explicit OutputTableCursor(const OutputTable &self) : self_(self) {}
bool Pull(Frame &frame, ExecutionContext &context) override { bool Pull(Frame &frame, ExecutionContext &context) override {
OOMExceptionEnabler oom_exception; OOMExceptionEnabler oom_exception;
@ -4621,10 +4684,10 @@ CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string
std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit, std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit,
size_t memory_scale, bool is_write, int64_t procedure_id, bool void_procedure) size_t memory_scale, bool is_write, int64_t procedure_id, bool void_procedure)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
procedure_name_(name), procedure_name_(std::move(name)),
arguments_(args), arguments_(std::move(args)),
result_fields_(fields), result_fields_(std::move(fields)),
result_symbols_(symbols), result_symbols_(std::move(symbols)),
memory_limit_(memory_limit), memory_limit_(memory_limit),
memory_scale_(memory_scale), memory_scale_(memory_scale),
is_write_(is_write), is_write_(is_write),
@ -4772,6 +4835,12 @@ class CallProcedureCursor : public Cursor {
AbortCheck(context); AbortCheck(context);
auto skip_rows_with_deleted_values = [this]() {
while (result_row_it_ != result_->rows.end() && result_row_it_->has_deleted_values) {
++result_row_it_;
}
};
// We need to fetch new procedure results after pulling from input. // We need to fetch new procedure results after pulling from input.
// TODO: Look into openCypher's distinction between procedures returning an // TODO: Look into openCypher's distinction between procedures returning an
// empty result set vs procedures which return `void`. We currently don't // empty result set vs procedures which return `void`. We currently don't
@ -4781,7 +4850,7 @@ class CallProcedureCursor : public Cursor {
// It might be a good idea to resolve the procedure name once, at the // It might be a good idea to resolve the procedure name once, at the
// start. Unfortunately, this could deadlock if we tried to invoke a // start. Unfortunately, this could deadlock if we tried to invoke a
// procedure from a module (read lock) and reload a module (write lock) // procedure from a module (read lock) and reload a module (write lock)
// inside the same execution thread. Also, our RWLock is setup so that // inside the same execution thread. Also, our RWLock is set up so that
// it's not possible for a single thread to request multiple read locks. // it's not possible for a single thread to request multiple read locks.
// Builtin module registration in query/procedure/module.cpp depends on // Builtin module registration in query/procedure/module.cpp depends on
// this locking scheme. // this locking scheme.
@ -4829,6 +4898,7 @@ class CallProcedureCursor : public Cursor {
graph_view); graph_view);
result_->signature = &proc->results; result_->signature = &proc->results;
result_->is_transactional = storage::IsTransactional(context.db_accessor->GetStorageMode());
// Use special memory as invoking procedure is complex // Use special memory as invoking procedure is complex
// TODO: This will probably need to be changed when we add support for // TODO: This will probably need to be changed when we add support for
@ -4849,9 +4919,13 @@ class CallProcedureCursor : public Cursor {
result_signature_size_ = result_->signature->size(); result_signature_size_ = result_->signature->size();
result_->signature = nullptr; result_->signature = nullptr;
if (result_->error_msg) { if (result_->error_msg) {
memgraph::utils::MemoryTracker::OutOfMemoryExceptionBlocker blocker;
throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg); throw QueryRuntimeException("{}: {}", self_->procedure_name_, *result_->error_msg);
} }
result_row_it_ = result_->rows.begin(); result_row_it_ = result_->rows.begin();
if (!result_->is_transactional) {
skip_rows_with_deleted_values();
}
stream_exhausted = result_row_it_ == result_->rows.end(); stream_exhausted = result_row_it_ == result_->rows.end();
} }
@ -4881,6 +4955,9 @@ class CallProcedureCursor : public Cursor {
} }
} }
++result_row_it_; ++result_row_it_;
if (!result_->is_transactional) {
skip_rows_with_deleted_values();
}
return true; return true;
} }
@ -4977,7 +5054,7 @@ LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool
delimiter_(delimiter), delimiter_(delimiter),
quote_(quote), quote_(quote),
nullif_(nullif), nullif_(nullif),
row_var_(row_var) { row_var_(std::move(row_var)) {
MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__); MG_ASSERT(file_, "Something went wrong - '{}' member file_ shouldn't be a nullptr", __func__);
} }
@ -5191,7 +5268,7 @@ Foreach::Foreach(std::shared_ptr<LogicalOperator> input, std::shared_ptr<Logical
: input_(input ? std::move(input) : std::make_shared<Once>()), : input_(input ? std::move(input) : std::make_shared<Once>()),
update_clauses_(std::move(updates)), update_clauses_(std::move(updates)),
expression_(expr), expression_(expr),
loop_variable_symbol_(loop_variable_symbol) {} loop_variable_symbol_(std::move(loop_variable_symbol)) {}
UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const { UniqueCursorPtr Foreach::MakeCursor(utils::MemoryResource *mem) const {
memgraph::metrics::IncrementCounter(memgraph::metrics::ForeachOperator); memgraph::metrics::IncrementCounter(memgraph::metrics::ForeachOperator);
@ -5404,7 +5481,7 @@ class HashJoinCursor : public Cursor {
// Check if the join value from the pulled frame is shared with any left frames // Check if the join value from the pulled frame is shared with any left frames
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::OLD); storage::View::OLD);
auto right_value = self_.hash_join_condition_->expression1_->Accept(evaluator); auto right_value = self_.hash_join_condition_->expression2_->Accept(evaluator);
if (hashtable_.contains(right_value)) { if (hashtable_.contains(right_value)) {
// If so, finish pulling for now and proceed to joining the pulled frame // If so, finish pulling for now and proceed to joining the pulled frame
right_op_frame_.assign(frame.elems().begin(), frame.elems().end()); right_op_frame_.assign(frame.elems().begin(), frame.elems().end());
@ -5452,7 +5529,7 @@ class HashJoinCursor : public Cursor {
while (left_op_cursor_->Pull(frame, context)) { while (left_op_cursor_->Pull(frame, context)) {
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
storage::View::OLD); storage::View::OLD);
auto left_value = self_.hash_join_condition_->expression2_->Accept(evaluator); auto left_value = self_.hash_join_condition_->expression1_->Accept(evaluator);
if (left_value.type() != TypedValue::Type::Null) { if (left_value.type() != TypedValue::Type::Null) {
hashtable_[left_value].emplace_back(frame.elems().begin(), frame.elems().end()); hashtable_[left_value].emplace_back(frame.elems().begin(), frame.elems().end());
} }

View File

@ -32,9 +32,7 @@
#include "utils/synchronized.hpp" #include "utils/synchronized.hpp"
#include "utils/visitor.hpp" #include "utils/visitor.hpp"
namespace memgraph { namespace memgraph::query {
namespace query {
struct ExecutionContext; struct ExecutionContext;
class ExpressionEvaluator; class ExpressionEvaluator;
@ -68,7 +66,7 @@ class Cursor {
/// Perform cleanup which may throw an exception /// Perform cleanup which may throw an exception
virtual void Shutdown() = 0; virtual void Shutdown() = 0;
virtual ~Cursor() {} virtual ~Cursor() = default;
}; };
/// unique_ptr to Cursor managed with a custom deleter. /// unique_ptr to Cursor managed with a custom deleter.
@ -172,7 +170,7 @@ class LogicalOperator : public utils::Visitable<HierarchicalLogicalOperatorVisit
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
virtual const utils::TypeInfo &GetTypeInfo() const { return kType; } virtual const utils::TypeInfo &GetTypeInfo() const { return kType; }
virtual ~LogicalOperator() {} ~LogicalOperator() override = default;
/** Construct a @c Cursor which is used to run this operator. /** Construct a @c Cursor which is used to run this operator.
* *
@ -274,7 +272,7 @@ class Once : public memgraph::query::plan::LogicalOperator {
private: private:
class OnceCursor : public Cursor { class OnceCursor : public Cursor {
public: public:
OnceCursor() {} OnceCursor() = default;
bool Pull(Frame &, ExecutionContext &) override; bool Pull(Frame &, ExecutionContext &) override;
void Shutdown() override; void Shutdown() override;
void Reset() override; void Reset() override;
@ -340,7 +338,7 @@ class CreateNode : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
CreateNode() {} CreateNode() = default;
/** /**
* @param input Optional. If @c nullptr, then a single node will be * @param input Optional. If @c nullptr, then a single node will be
@ -349,7 +347,7 @@ class CreateNode : public memgraph::query::plan::LogicalOperator {
* successful pull from the given input. * successful pull from the given input.
* @param node_info @c NodeCreationInfo * @param node_info @c NodeCreationInfo
*/ */
CreateNode(const std::shared_ptr<LogicalOperator> &input, const NodeCreationInfo &node_info); CreateNode(const std::shared_ptr<LogicalOperator> &input, NodeCreationInfo node_info);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
@ -445,7 +443,7 @@ class CreateExpand : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
CreateExpand() {} CreateExpand() = default;
/** @brief Construct @c CreateExpand. /** @brief Construct @c CreateExpand.
* *
@ -459,8 +457,8 @@ class CreateExpand : public memgraph::query::plan::LogicalOperator {
* @param existing_node @c bool indicating whether the @c node_atom refers to * @param existing_node @c bool indicating whether the @c node_atom refers to
* an existing node. If @c false, the operator will also create the node. * an existing node. If @c false, the operator will also create the node.
*/ */
CreateExpand(const NodeCreationInfo &node_info, const EdgeCreationInfo &edge_info, CreateExpand(NodeCreationInfo node_info, EdgeCreationInfo edge_info, const std::shared_ptr<LogicalOperator> &input,
const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node); Symbol input_symbol, bool existing_node);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
@ -529,7 +527,7 @@ class ScanAll : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ScanAll() {} ScanAll() = default;
ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view = storage::View::OLD); ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::View view = storage::View::OLD);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
@ -571,7 +569,7 @@ class ScanAllByLabel : public memgraph::query::plan::ScanAll {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ScanAllByLabel() {} ScanAllByLabel() = default;
ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label,
storage::View view = storage::View::OLD); storage::View view = storage::View::OLD);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -606,7 +604,7 @@ class ScanAllByLabelPropertyRange : public memgraph::query::plan::ScanAll {
/** Bound with expression which when evaluated produces the bound value. */ /** Bound with expression which when evaluated produces the bound value. */
using Bound = utils::Bound<Expression *>; using Bound = utils::Bound<Expression *>;
ScanAllByLabelPropertyRange() {} ScanAllByLabelPropertyRange() = default;
/** /**
* Constructs the operator for given label and property value in range * Constructs the operator for given label and property value in range
* (inclusive). * (inclusive).
@ -622,7 +620,7 @@ class ScanAllByLabelPropertyRange : public memgraph::query::plan::ScanAll {
* @param view storage::View used when obtaining vertices. * @param view storage::View used when obtaining vertices.
*/ */
ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, ScanAllByLabelPropertyRange(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
storage::LabelId label, storage::PropertyId property, const std::string &property_name, storage::LabelId label, storage::PropertyId property, std::string property_name,
std::optional<Bound> lower_bound, std::optional<Bound> upper_bound, std::optional<Bound> lower_bound, std::optional<Bound> upper_bound,
storage::View view = storage::View::OLD); storage::View view = storage::View::OLD);
@ -675,7 +673,7 @@ class ScanAllByLabelPropertyValue : public memgraph::query::plan::ScanAll {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ScanAllByLabelPropertyValue() {} ScanAllByLabelPropertyValue() = default;
/** /**
* Constructs the operator for given label and property value. * Constructs the operator for given label and property value.
* *
@ -687,7 +685,7 @@ class ScanAllByLabelPropertyValue : public memgraph::query::plan::ScanAll {
* @param view storage::View used when obtaining vertices. * @param view storage::View used when obtaining vertices.
*/ */
ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
storage::LabelId label, storage::PropertyId property, const std::string &property_name, storage::LabelId label, storage::PropertyId property, std::string property_name,
Expression *expression, storage::View view = storage::View::OLD); Expression *expression, storage::View view = storage::View::OLD);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -727,9 +725,9 @@ class ScanAllByLabelProperty : public memgraph::query::plan::ScanAll {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ScanAllByLabelProperty() {} ScanAllByLabelProperty() = default;
ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label, ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::LabelId label,
storage::PropertyId property, const std::string &property_name, storage::PropertyId property, std::string property_name,
storage::View view = storage::View::OLD); storage::View view = storage::View::OLD);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
@ -763,7 +761,7 @@ class ScanAllById : public memgraph::query::plan::ScanAll {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ScanAllById() {} ScanAllById() = default;
ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression,
storage::View view = storage::View::OLD); storage::View view = storage::View::OLD);
@ -842,7 +840,7 @@ class Expand : public memgraph::query::plan::LogicalOperator {
EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, bool existing_node, EdgeAtom::Direction direction, const std::vector<storage::EdgeTypeId> &edge_types, bool existing_node,
storage::View view); storage::View view);
Expand() {} Expand() = default;
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
@ -919,12 +917,18 @@ struct ExpansionLambda {
Symbol inner_node_symbol; Symbol inner_node_symbol;
/// Expression used in lambda during expansion. /// Expression used in lambda during expansion.
Expression *expression; Expression *expression;
/// Currently expanded accumulated path symbol.
std::optional<Symbol> accumulated_path_symbol;
/// Currently expanded accumulated weight symbol.
std::optional<Symbol> accumulated_weight_symbol;
ExpansionLambda Clone(AstStorage *storage) const { ExpansionLambda Clone(AstStorage *storage) const {
ExpansionLambda object; ExpansionLambda object;
object.inner_edge_symbol = inner_edge_symbol; object.inner_edge_symbol = inner_edge_symbol;
object.inner_node_symbol = inner_node_symbol; object.inner_node_symbol = inner_node_symbol;
object.expression = expression ? expression->Clone(storage) : nullptr; object.expression = expression ? expression->Clone(storage) : nullptr;
object.accumulated_path_symbol = accumulated_path_symbol;
object.accumulated_weight_symbol = accumulated_weight_symbol;
return object; return object;
} }
}; };
@ -950,7 +954,7 @@ class ExpandVariable : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ExpandVariable() {} ExpandVariable() = default;
/** /**
* Creates a variable-length expansion. Most params are forwarded * Creates a variable-length expansion. Most params are forwarded
@ -1073,10 +1077,10 @@ class ConstructNamedPath : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
ConstructNamedPath() {} ConstructNamedPath() = default;
ConstructNamedPath(const std::shared_ptr<LogicalOperator> &input, Symbol path_symbol, ConstructNamedPath(const std::shared_ptr<LogicalOperator> &input, Symbol path_symbol,
const std::vector<Symbol> &path_elements) const std::vector<Symbol> &path_elements)
: input_(input), path_symbol_(path_symbol), path_elements_(path_elements) {} : input_(input), path_symbol_(std::move(path_symbol)), path_elements_(path_elements) {}
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
@ -1108,13 +1112,13 @@ class Filter : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Filter() {} Filter() = default;
Filter(const std::shared_ptr<LogicalOperator> &input, Filter(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression); const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression);
Filter(const std::shared_ptr<LogicalOperator> &input, Filter(const std::shared_ptr<LogicalOperator> &input,
const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression, const std::vector<std::shared_ptr<LogicalOperator>> &pattern_filters, Expression *expression,
const Filters &all_filters); Filters all_filters);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
@ -1126,12 +1130,12 @@ class Filter : public memgraph::query::plan::LogicalOperator {
std::shared_ptr<memgraph::query::plan::LogicalOperator> input_; std::shared_ptr<memgraph::query::plan::LogicalOperator> input_;
std::vector<std::shared_ptr<memgraph::query::plan::LogicalOperator>> pattern_filters_; std::vector<std::shared_ptr<memgraph::query::plan::LogicalOperator>> pattern_filters_;
Expression *expression_; Expression *expression_;
const memgraph::query::plan::Filters all_filters_; memgraph::query::plan::Filters all_filters_;
static std::string SingleFilterName(const query::plan::FilterInfo &single_filter) { static std::string SingleFilterName(const query::plan::FilterInfo &single_filter) {
using Type = query::plan::FilterInfo::Type; using Type = query::plan::FilterInfo::Type;
if (single_filter.type == Type::Generic) { if (single_filter.type == Type::Generic) {
std::set<std::string> symbol_names; std::set<std::string, std::less<>> symbol_names;
for (const auto &symbol : single_filter.used_symbols) { for (const auto &symbol : single_filter.used_symbols) {
symbol_names.insert(symbol.name()); symbol_names.insert(symbol.name());
} }
@ -1144,7 +1148,7 @@ class Filter : public memgraph::query::plan::LogicalOperator {
LOG_FATAL("Label filters not using LabelsTest are not supported for query inspection!"); LOG_FATAL("Label filters not using LabelsTest are not supported for query inspection!");
} }
auto filter_expression = static_cast<LabelsTest *>(single_filter.expression); auto filter_expression = static_cast<LabelsTest *>(single_filter.expression);
std::set<std::string> label_names; std::set<std::string, std::less<>> label_names;
for (const auto &label : filter_expression->labels_) { for (const auto &label : filter_expression->labels_) {
label_names.insert(label.name); label_names.insert(label.name);
} }
@ -1167,7 +1171,7 @@ class Filter : public memgraph::query::plan::LogicalOperator {
} }
std::string ToString() const override { std::string ToString() const override {
std::set<std::string> filter_names; std::set<std::string, std::less<>> filter_names;
for (const auto &filter : all_filters_) { for (const auto &filter : all_filters_) {
filter_names.insert(Filter::SingleFilterName(filter)); filter_names.insert(Filter::SingleFilterName(filter));
} }
@ -1214,7 +1218,7 @@ class Produce : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Produce() {} Produce() = default;
Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions); Produce(const std::shared_ptr<LogicalOperator> &input, const std::vector<NamedExpression *> &named_expressions);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1271,7 +1275,7 @@ class Delete : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Delete() {} Delete() = default;
Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions, bool detach_); Delete(const std::shared_ptr<LogicalOperator> &input_, const std::vector<Expression *> &expressions, bool detach_);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1326,7 +1330,7 @@ class SetProperty : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
SetProperty() {} SetProperty() = default;
SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs, SetProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs,
Expression *rhs); Expression *rhs);
@ -1385,7 +1389,7 @@ class SetProperties : public memgraph::query::plan::LogicalOperator {
/// that the old properties are discarded and replaced with new ones. /// that the old properties are discarded and replaced with new ones.
enum class Op { UPDATE, REPLACE }; enum class Op { UPDATE, REPLACE };
SetProperties() {} SetProperties() = default;
SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op); SetProperties(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Expression *rhs, Op op);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1433,7 +1437,7 @@ class SetLabels : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
SetLabels() {} SetLabels() = default;
SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, SetLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels); const std::vector<storage::LabelId> &labels);
@ -1477,7 +1481,7 @@ class RemoveProperty : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
RemoveProperty() {} RemoveProperty() = default;
RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs); RemoveProperty(const std::shared_ptr<LogicalOperator> &input, storage::PropertyId property, PropertyLookup *lhs);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1522,7 +1526,7 @@ class RemoveLabels : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
RemoveLabels() {} RemoveLabels() = default;
RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, RemoveLabels(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol,
const std::vector<storage::LabelId> &labels); const std::vector<storage::LabelId> &labels);
@ -1578,7 +1582,7 @@ class EdgeUniquenessFilter : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
EdgeUniquenessFilter() {} EdgeUniquenessFilter() = default;
EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol, EdgeUniquenessFilter(const std::shared_ptr<LogicalOperator> &input, Symbol expand_symbol,
const std::vector<Symbol> &previous_symbols); const std::vector<Symbol> &previous_symbols);
@ -1636,7 +1640,7 @@ class EmptyResult : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
EmptyResult() {} EmptyResult() = default;
EmptyResult(const std::shared_ptr<LogicalOperator> &input); EmptyResult(const std::shared_ptr<LogicalOperator> &input);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1688,7 +1692,7 @@ class Accumulate : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Accumulate() {} Accumulate() = default;
Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols, Accumulate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &symbols,
bool advance_command = false); bool advance_command = false);
@ -1758,6 +1762,7 @@ class Aggregate : public memgraph::query::plan::LogicalOperator {
Aggregate() = default; Aggregate() = default;
Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Element> &aggregations, Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Element> &aggregations,
const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember); const std::vector<Expression *> &group_by, const std::vector<Symbol> &remember);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;
std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override; std::vector<Symbol> ModifiedSymbols(const SymbolTable &) const override;
@ -1810,7 +1815,7 @@ class Skip : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Skip() {} Skip() = default;
Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression); Skip(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1856,7 +1861,7 @@ class EvaluatePatternFilter : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
EvaluatePatternFilter() {} EvaluatePatternFilter() = default;
EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol); EvaluatePatternFilter(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1910,7 +1915,7 @@ class Limit : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Limit() {} Limit() = default;
Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression); Limit(const std::shared_ptr<LogicalOperator> &input, Expression *expression);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -1965,7 +1970,7 @@ class OrderBy : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
OrderBy() {} OrderBy() = default;
OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by, OrderBy(const std::shared_ptr<LogicalOperator> &input, const std::vector<SortItem> &order_by,
const std::vector<Symbol> &output_symbols); const std::vector<Symbol> &output_symbols);
@ -2017,7 +2022,7 @@ class Merge : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Merge() {} Merge() = default;
Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match, Merge(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &merge_match,
const std::shared_ptr<LogicalOperator> &merge_create); const std::shared_ptr<LogicalOperator> &merge_create);
@ -2077,7 +2082,7 @@ class Optional : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Optional() {} Optional() = default;
Optional(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &optional, Optional(const std::shared_ptr<LogicalOperator> &input, const std::shared_ptr<LogicalOperator> &optional,
const std::vector<Symbol> &optional_symbols); const std::vector<Symbol> &optional_symbols);
@ -2131,7 +2136,7 @@ class Unwind : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Unwind() {} Unwind() = default;
Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression_, Symbol output_symbol); Unwind(const std::shared_ptr<LogicalOperator> &input, Expression *input_expression_, Symbol output_symbol);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -2166,7 +2171,7 @@ class Distinct : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Distinct() {} Distinct() = default;
Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols); Distinct(const std::shared_ptr<LogicalOperator> &input, const std::vector<Symbol> &value_symbols);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -2199,7 +2204,7 @@ class Union : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Union() {} Union() = default;
Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op, Union(const std::shared_ptr<LogicalOperator> &left_op, const std::shared_ptr<LogicalOperator> &right_op,
const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols, const std::vector<Symbol> &union_symbols, const std::vector<Symbol> &left_symbols,
@ -2255,7 +2260,7 @@ class Cartesian : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Cartesian() {} Cartesian() = default;
/** Construct the operator with left input branch and right input branch. */ /** Construct the operator with left input branch and right input branch. */
Cartesian(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols, Cartesian(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols,
const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols) const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols)
@ -2290,7 +2295,7 @@ class OutputTable : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
OutputTable() {} OutputTable() = default;
OutputTable(std::vector<Symbol> output_symbols, OutputTable(std::vector<Symbol> output_symbols,
std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback); std::function<std::vector<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback);
OutputTable(std::vector<Symbol> output_symbols, std::vector<std::vector<TypedValue>> rows); OutputTable(std::vector<Symbol> output_symbols, std::vector<std::vector<TypedValue>> rows);
@ -2326,7 +2331,7 @@ class OutputTableStream : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
OutputTableStream() {} OutputTableStream() = default;
OutputTableStream(std::vector<Symbol> output_symbols, OutputTableStream(std::vector<Symbol> output_symbols,
std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback); std::function<std::optional<std::vector<TypedValue>>(Frame *, ExecutionContext *)> callback);
@ -2497,7 +2502,7 @@ class Apply : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
Apply() {} Apply() = default;
Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery, Apply(const std::shared_ptr<LogicalOperator> input, const std::shared_ptr<LogicalOperator> subquery,
bool subquery_has_return); bool subquery_has_return);
@ -2544,7 +2549,7 @@ class IndexedJoin : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
IndexedJoin() {} IndexedJoin() = default;
IndexedJoin(std::shared_ptr<LogicalOperator> main_branch, std::shared_ptr<LogicalOperator> sub_branch); IndexedJoin(std::shared_ptr<LogicalOperator> main_branch, std::shared_ptr<LogicalOperator> sub_branch);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
@ -2587,7 +2592,7 @@ class HashJoin : public memgraph::query::plan::LogicalOperator {
static const utils::TypeInfo kType; static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; } const utils::TypeInfo &GetTypeInfo() const override { return kType; }
HashJoin() {} HashJoin() = default;
/** Construct the operator with left input branch and right input branch. */ /** Construct the operator with left input branch and right input branch. */
HashJoin(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols, HashJoin(const std::shared_ptr<LogicalOperator> &left_op, const std::vector<Symbol> &left_symbols,
const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols, const std::shared_ptr<LogicalOperator> &right_op, const std::vector<Symbol> &right_symbols,
@ -2630,5 +2635,4 @@ class HashJoin : public memgraph::query::plan::LogicalOperator {
}; };
} // namespace plan } // namespace plan
} // namespace query } // namespace memgraph::query
} // namespace memgraph

View File

@ -17,6 +17,8 @@
#pragma once #pragma once
#include <utility>
#include "query/plan/cost_estimator.hpp" #include "query/plan/cost_estimator.hpp"
#include "query/plan/operator.hpp" #include "query/plan/operator.hpp"
#include "query/plan/preprocess.hpp" #include "query/plan/preprocess.hpp"
@ -42,11 +44,11 @@ class PostProcessor final {
using ProcessedPlan = std::unique_ptr<LogicalOperator>; using ProcessedPlan = std::unique_ptr<LogicalOperator>;
explicit PostProcessor(const Parameters &parameters) : parameters_(parameters) {} explicit PostProcessor(Parameters parameters) : parameters_(std::move(parameters)) {}
template <class TDbAccessor> template <class TDbAccessor>
PostProcessor(const Parameters &parameters, std::vector<IndexHint> index_hints, TDbAccessor *db) PostProcessor(Parameters parameters, std::vector<IndexHint> index_hints, TDbAccessor *db)
: parameters_(parameters), index_hints_(IndexHints(index_hints, db)) {} : parameters_(std::move(parameters)), index_hints_(IndexHints(index_hints, db)) {}
template <class TPlanningContext> template <class TPlanningContext>
std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) { std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {

View File

@ -14,6 +14,7 @@
#include <stack> #include <stack>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <variant> #include <variant>
#include "query/exceptions.hpp" #include "query/exceptions.hpp"
@ -73,6 +74,13 @@ std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const
// Remove symbols which are bound by lambda arguments. // Remove symbols which are bound by lambda arguments.
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge)); collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge));
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node)); collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node));
if (edge->filter_lambda_.accumulated_path) {
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.accumulated_path));
if (edge->filter_lambda_.accumulated_weight) {
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.accumulated_weight));
}
}
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH || if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH ||
edge->type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS) { edge->type_ == EdgeAtom::Type::ALL_SHORTEST_PATHS) {
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge)); collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge));
@ -199,7 +207,7 @@ auto SplitExpressionOnAnd(Expression *expression) {
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property, PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
Expression *value, Type type) Expression *value, Type type)
: symbol_(symbol), property_(property), type_(type), value_(value) { : symbol_(symbol), property_(std::move(property)), type_(type), value_(value) {
MG_ASSERT(type != Type::RANGE); MG_ASSERT(type != Type::RANGE);
UsedSymbolsCollector collector(symbol_table); UsedSymbolsCollector collector(symbol_table);
value->Accept(collector); value->Accept(collector);
@ -209,7 +217,11 @@ PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &sy
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property, PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
const std::optional<PropertyFilter::Bound> &lower_bound, const std::optional<PropertyFilter::Bound> &lower_bound,
const std::optional<PropertyFilter::Bound> &upper_bound) const std::optional<PropertyFilter::Bound> &upper_bound)
: symbol_(symbol), property_(property), type_(Type::RANGE), lower_bound_(lower_bound), upper_bound_(upper_bound) { : symbol_(symbol),
property_(std::move(property)),
type_(Type::RANGE),
lower_bound_(lower_bound),
upper_bound_(upper_bound) {
UsedSymbolsCollector collector(symbol_table); UsedSymbolsCollector collector(symbol_table);
if (lower_bound) { if (lower_bound) {
lower_bound->value()->Accept(collector); lower_bound->value()->Accept(collector);
@ -220,8 +232,8 @@ PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &sy
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol); is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
} }
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property, Type type) PropertyFilter::PropertyFilter(Symbol symbol, PropertyIx property, Type type)
: symbol_(symbol), property_(property), type_(type) { : symbol_(std::move(symbol)), property_(std::move(property)), type_(type) {
// As this constructor is used for property filters where // As this constructor is used for property filters where
// we don't have to evaluate the filter expression, we set // we don't have to evaluate the filter expression, we set
// the is_symbol_in_value_ to false, although the filter // the is_symbol_in_value_ to false, although the filter
@ -290,6 +302,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
prop_pair.second->Accept(collector); prop_pair.second->Accept(collector);
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node)); collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node));
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge)); collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge));
if (atom->filter_lambda_.accumulated_path) {
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.accumulated_path));
if (atom->filter_lambda_.accumulated_weight) {
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.accumulated_weight));
}
}
// First handle the inline property filter. // First handle the inline property filter.
auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first); auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second); auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);

Some files were not shown because too many files have changed in this diff Show More