Integrate bolt server (#572)

* Use query-v2 in the main executable
* Set up machine manager in memgraph
* Add `ShardRequestManager` to `Interpreter`
* Make vertex creation work
* Make scan all work
* Add edge type map in shard request manager
* Send schema over request
* Empty out DbAccessor
* Store shard mapping at creation
* Remove failing CI steps

Cooltura is the best place in Zagreb!

Co-authored-by: János Benjamin Antal <benjamin.antal@memgraph.io>
This commit is contained in:
Jure Bajic 2022-10-11 16:31:46 +02:00 committed by GitHub
parent 6fd64d31f2
commit 23171e76b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 711 additions and 12567 deletions

View File

@ -164,52 +164,14 @@ jobs:
cmake ..
make -j$THREADS
- name: Run leftover CTest tests
- name: Run simulation tests
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
# Run leftover CTest tests (all except unit and benchmark tests).
# Run simulation tests.
cd build
ctest -E "(memgraph__unit|memgraph__benchmark|memgraph__simulation)" --output-on-failure
- name: Run drivers tests
run: |
./tests/drivers/run.sh
- name: Run integration tests
run: |
cd tests/integration
for name in *; do
if [ ! -d $name ]; then continue; fi
pushd $name >/dev/null
echo "Running: $name"
if [ -x prepare.sh ]; then
./prepare.sh
fi
if [ -x runner.py ]; then
./runner.py
elif [ -x runner.sh ]; then
./runner.sh
fi
echo
popd >/dev/null
done
- name: Run cppcheck and clang-format
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
# Run cppcheck and clang-format.
cd tools/github
./cppcheck_and_clang_format diff
- name: Save cppcheck and clang-format errors
uses: actions/upload-artifact@v2
with:
name: "Code coverage"
path: tools/github/cppcheck_and_clang_format.txt
ctest -R memgraph__simulation --output-on-failure -j$THREADS
release_build:
name: "Release build"
@ -240,19 +202,6 @@ jobs:
cmake -DCMAKE_BUILD_TYPE=release ..
make -j$THREADS
- name: Run GQL Behave tests
run: |
cd tests/gql_behave
./continuous_integration
- name: Save quality assurance status
uses: actions/upload-artifact@v2
with:
name: "GQL Behave Status"
path: |
tests/gql_behave/gql_behave_status.csv
tests/gql_behave/gql_behave_status.html
- name: Run unit tests
run: |
# Activate toolchain.
@ -267,178 +216,6 @@ jobs:
# Activate toolchain.
source /opt/toolchain-v4/activate
# Run unit tests.
# Run simulation tests.
cd build
ctest -R memgraph__simulation --output-on-failure -j$THREADS
- name: Run e2e tests
run: |
# TODO(gitbuda): Setup mgclient and pymgclient properly.
cd tests
./setup.sh
source ve3/bin/activate
cd e2e
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:../../libs/mgclient/lib python runner.py --workloads-root-directory .
- name: Run stress test (plain)
run: |
cd tests/stress
./continuous_integration
- name: Run stress test (SSL)
run: |
cd tests/stress
./continuous_integration --use-ssl
- name: Run durability test
run: |
cd tests/stress
source ve3/bin/activate
python3 durability --num-steps 5
- name: Create enterprise DEB package
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
cd build
# create mgconsole
# we use the -B to force the build
make -j$THREADS -B mgconsole
# Create enterprise DEB package.
mkdir output && cd output
cpack -G DEB --config ../CPackConfig.cmake
- name: Save enterprise DEB package
uses: actions/upload-artifact@v2
with:
name: "Enterprise DEB package"
path: build/output/memgraph*.deb
- name: Save test data
uses: actions/upload-artifact@v2
if: always()
with:
name: "Test data"
path: |
# multiple paths could be defined
build/logs
release_jepsen_test:
name: "Release Jepsen Test"
runs-on: [self-hosted, Linux, X64, Debian10, JepsenControl]
#continue-on-error: true
env:
THREADS: 24
MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }}
MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }}
steps:
- name: Set up repository
uses: actions/checkout@v2
with:
# Number of commits to fetch. `0` indicates all history for all
# branches and tags. (default: 1)
fetch-depth: 0
- name: Build release binaries
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
# Initialize dependencies.
./init
# Build only memgraph release binaries.
cd build
cmake -DCMAKE_BUILD_TYPE=release ..
make -j$THREADS memgraph
- name: Run Jepsen tests
run: |
cd tests/jepsen
./run.sh test --binary ../../build/memgraph --run-args "test-all --node-configs resources/node-config.edn" --ignore-run-stdout-logs --ignore-run-stderr-logs
- name: Save Jepsen report
uses: actions/upload-artifact@v2
if: ${{ always() }}
with:
name: "Jepsen Report"
path: tests/jepsen/Jepsen.tar.gz
release_benchmarks:
name: "Release benchmarks"
runs-on: [self-hosted, Linux, X64, Diff, Gen7]
env:
THREADS: 24
MEMGRAPH_ENTERPRISE_LICENSE: ${{ secrets.MEMGRAPH_ENTERPRISE_LICENSE }}
MEMGRAPH_ORGANIZATION_NAME: ${{ secrets.MEMGRAPH_ORGANIZATION_NAME }}
steps:
- name: Set up repository
uses: actions/checkout@v2
with:
# Number of commits to fetch. `0` indicates all history for all
# branches and tags. (default: 1)
fetch-depth: 0
- name: Build release binaries
run: |
# Activate toolchain.
source /opt/toolchain-v4/activate
# Initialize dependencies.
./init
# Build only memgraph release binaries.
cd build
cmake -DCMAKE_BUILD_TYPE=release ..
make -j$THREADS
- name: Run macro benchmarks
run: |
cd tests/macro_benchmark
./harness QuerySuite MemgraphRunner \
--groups aggregation 1000_create unwind_create dense_expand match \
--no-strict
- name: Get branch name (merge)
if: github.event_name != 'pull_request'
shell: bash
run: echo "BRANCH_NAME=$(echo ${GITHUB_REF#refs/heads/} | tr / -)" >> $GITHUB_ENV
- name: Get branch name (pull request)
if: github.event_name == 'pull_request'
shell: bash
run: echo "BRANCH_NAME=$(echo ${GITHUB_HEAD_REF} | tr / -)" >> $GITHUB_ENV
- name: Upload macro benchmark results
run: |
cd tools/bench-graph-client
virtualenv -p python3 ve3
source ve3/bin/activate
pip install -r requirements.txt
./main.py --benchmark-name "macro_benchmark" \
--benchmark-results-path "../../tests/macro_benchmark/.harness_summary" \
--github-run-id "${{ github.run_id }}" \
--github-run-number "${{ github.run_number }}" \
--head-branch-name "${{ env.BRANCH_NAME }}"
- name: Run mgbench
run: |
cd tests/mgbench
./benchmark.py --num-workers-for-benchmark 12 --export-results benchmark_result.json pokec/medium/*/*
- name: Upload mgbench results
run: |
cd tools/bench-graph-client
virtualenv -p python3 ve3
source ve3/bin/activate
pip install -r requirements.txt
./main.py --benchmark-name "mgbench" \
--benchmark-results-path "../../tests/mgbench/benchmark_result.json" \
--github-run-id "${{ github.run_id }}" \
--github-run-number "${{ github.run_number }}" \
--head-branch-name "${{ env.BRANCH_NAME }}"

View File

@ -83,13 +83,9 @@ modifications:
value: "true"
override: true
- name: "query_modules_directory"
value: "/usr/lib/memgraph/query_modules"
override: true
- name: "auth_module_executable"
value: "/usr/lib/memgraph/auth_module/example.py"
override: false
# - name: "query_modules_directory"
# value: "/usr/lib/memgraph/query_modules"
# override: true
- name: "memory_limit"
value: "0"

View File

@ -37,13 +37,13 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
# Memgraph Single Node v2 Executable
# ----------------------------------------------------------------------------
set(mg_single_node_v2_sources
glue/communication.cpp
memgraph.cpp
glue/auth.cpp
glue/v2/communication.cpp
memgraph.cpp
glue/v2/auth.cpp
)
set(mg_single_node_v2_libs stdc++fs Threads::Threads
telemetry_lib mg-query mg-communication mg-memory mg-utils mg-auth mg-license mg-settings)
telemetry_lib mg-query-v2 mg-communication mg-memory mg-utils mg-auth mg-license mg-settings mg-io mg-coordinator)
if (MG_ENTERPRISE)
# These are enterprise subsystems
set(mg_single_node_v2_libs ${mg_single_node_v2_libs} mg-audit)
@ -126,19 +126,3 @@ install(CODE "file(MAKE_DIRECTORY \$ENV{DESTDIR}/var/log/memgraph
# ----------------------------------------------------------------------------
# Memgraph CSV Import Tool Executable
# ----------------------------------------------------------------------------
add_executable(mg_import_csv mg_import_csv.cpp)
target_link_libraries(mg_import_csv mg-storage-v2)
# Strip the executable in release build.
if (lower_build_type STREQUAL "release")
add_custom_command(TARGET mg_import_csv POST_BUILD
COMMAND strip -s mg_import_csv
COMMENT "Stripping symbols and sections from mg_import_csv")
endif()
install(TARGETS mg_import_csv RUNTIME DESTINATION bin)
# ----------------------------------------------------------------------------
# Memgraph CSV Import Tool Executable
# ----------------------------------------------------------------------------

View File

@ -30,6 +30,7 @@
#include "communication/context.hpp"
#include "communication/v2/pool.hpp"
#include "communication/v2/session.hpp"
#include "utils/message.hpp"
#include "utils/spin_lock.hpp"
#include "utils/synchronized.hpp"
@ -58,10 +59,10 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
bool IsRunning() const noexcept { return alive_.load(std::memory_order_relaxed); }
private:
Listener(boost::asio::io_context &io_context, TSessionData *data, ServerContext *server_context,
Listener(boost::asio::io_context &io_context, TSessionData &data, ServerContext *server_context,
tcp::endpoint &endpoint, const std::string_view service_name, const uint64_t inactivity_timeout_sec)
: io_context_(io_context),
data_(data),
data_(&data),
server_context_(server_context),
acceptor_(io_context_),
endpoint_{endpoint},
@ -110,7 +111,7 @@ class Listener final : public std::enable_shared_from_this<Listener<TSession, TS
return OnError(ec, "accept");
}
auto session = SessionHandler::Create(std::move(socket), data_, *server_context_, endpoint_, inactivity_timeout_,
auto session = SessionHandler::Create(std::move(socket), *data_, *server_context_, endpoint_, inactivity_timeout_,
service_name_);
session->Start();
DoAccept();

View File

@ -72,7 +72,7 @@ class Server final {
* Constructs and binds server to endpoint, operates on session data and
* invokes workers_count workers
*/
Server(ServerEndpoint &endpoint, TSessionData *session_data, ServerContext *server_context,
Server(ServerEndpoint &endpoint, TSessionData &session_data, ServerContext *server_context,
const int inactivity_timeout_sec, const std::string_view service_name,
size_t workers_count = std::thread::hardware_concurrency())
: endpoint_{endpoint},

View File

@ -41,6 +41,7 @@
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/system/detail/error_code.hpp>
#include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/exceptions.hpp"
#include "utils/logging.hpp"
@ -139,7 +140,7 @@ class WebsocketSession : public std::enable_shared_from_this<WebsocketSession<TS
private:
// Take ownership of the socket
explicit WebsocketSession(tcp::socket &&socket, TSessionData *data, tcp::endpoint endpoint,
explicit WebsocketSession(tcp::socket &&socket, TSessionData &data, tcp::endpoint endpoint,
std::string_view service_name)
: ws_(std::move(socket)),
strand_{boost::asio::make_strand(ws_.get_executor())},
@ -311,13 +312,13 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
}
private:
explicit Session(tcp::socket &&socket, TSessionData *data, ServerContext &server_context, tcp::endpoint endpoint,
explicit Session(tcp::socket &&socket, TSessionData &data, ServerContext &server_context, tcp::endpoint endpoint,
const std::chrono::seconds inactivity_timeout_sec, std::string_view service_name)
: socket_(CreateSocket(std::move(socket), server_context)),
strand_{boost::asio::make_strand(GetExecutor())},
output_stream_([this](const uint8_t *data, size_t len, bool have_more) { return Write(data, len, have_more); }),
session_(data, endpoint, input_buffer_.read_end(), &output_stream_),
data_{data},
data_{&data},
endpoint_{endpoint},
remote_endpoint_{GetRemoteEndpoint()},
service_name_{service_name},
@ -373,7 +374,7 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
spdlog::info("Switching {} to websocket connection", remote_endpoint_);
if (std::holds_alternative<TCPSocket>(socket_)) {
auto sock = std::get<TCPSocket>(std::move(socket_));
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), data_, endpoint_, service_name_)
WebsocketSession<TSession, TSessionData>::Create(std::move(sock), *data_, endpoint_, service_name_)
->DoAccept(parser.release());
execution_active_ = false;
return;
@ -465,7 +466,7 @@ class Session final : public std::enable_shared_from_this<Session<TSession, TSes
if (timeout_timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) {
// The deadline has passed. Stop the session. The other actors will
// terminate as soon as possible.
spdlog::info("Shutting down session after {} of inactivity", timeout_seconds_);
spdlog::info("Shutting down session after {} of inactivity", timeout_seconds_.count());
DoShutdown();
} else {
// Put the actor back to sleep.

View File

@ -75,10 +75,11 @@ std::optional<LabelId> ShardMap::InitializeNewLabel(std::string label_name, std:
};
LabelSpace label_space{
.schema = std::move(schema),
.schema = schema,
.shards = shards,
.replication_factor = replication_factor,
};
schemas[label_id] = std::move(schema);
label_spaces.emplace(label_id, label_space);

View File

@ -11,6 +11,7 @@
#pragma once
#include <algorithm>
#include <limits>
#include <map>
#include <set>
@ -27,6 +28,7 @@
#include "storage/v3/property_value.hpp"
#include "storage/v3/schemas.hpp"
#include "storage/v3/temporal.hpp"
#include "utils/exceptions.hpp"
namespace memgraph::coordinator {
@ -34,6 +36,7 @@ constexpr int64_t kNotExistingId{0};
using memgraph::io::Address;
using memgraph::storage::v3::Config;
using memgraph::storage::v3::EdgeTypeId;
using memgraph::storage::v3::LabelId;
using memgraph::storage::v3::PropertyId;
using memgraph::storage::v3::PropertyValue;
@ -58,7 +61,9 @@ using Shard = std::vector<AddressAndStatus>;
using Shards = std::map<PrimaryKey, Shard>;
using LabelName = std::string;
using PropertyName = std::string;
using EdgeTypeName = std::string;
using PropertyMap = std::map<PropertyName, PropertyId>;
using EdgeTypeIdMap = std::map<EdgeTypeName, EdgeTypeId>;
struct ShardToInitialize {
boost::uuids::uuid uuid;
@ -80,7 +85,9 @@ struct LabelSpace {
struct ShardMap {
Hlc shard_map_version;
uint64_t max_property_id{kNotExistingId};
uint64_t max_edge_type_id{kNotExistingId};
std::map<PropertyName, PropertyId> properties;
std::map<EdgeTypeName, EdgeTypeId> edge_types;
uint64_t max_label_id{kNotExistingId};
std::map<LabelName, LabelId> labels;
std::map<LabelId, LabelSpace> label_spaces;
@ -127,7 +134,7 @@ struct ShardMap {
.label_id = label_id,
.min_key = low_key,
.max_key = std::nullopt,
.schema = label_space.schema,
.schema = schemas[label_id],
.config = Config{},
});
}
@ -140,13 +147,12 @@ struct ShardMap {
// TODO(tyler) use deterministic UUID so that coordinators don't diverge here
address.unique_id = boost::uuids::uuid{boost::uuids::random_generator()()},
ret.push_back(ShardToInitialize{
.uuid = address.unique_id,
.label_id = label_id,
.min_key = low_key,
.max_key = std::nullopt,
.config = Config{},
});
ret.push_back(ShardToInitialize{.uuid = address.unique_id,
.label_id = label_id,
.min_key = low_key,
.max_key = std::nullopt,
.schema = schemas[label_id],
.config = Config{}});
AddressAndStatus aas = {
.address = address,
@ -196,6 +202,49 @@ struct ShardMap {
LabelId GetLabelId(const std::string &label) const { return labels.at(label); }
/// Reverse lookup of a label name from its id.
///
/// Walks the `labels` name-to-id map in order and returns the name of the
/// first entry whose id equals `label`.
///
/// @param label id of the label to resolve.
/// @return the label name registered under `label`.
/// @throw utils::BasicException when no entry in `labels` carries this id.
std::string GetLabelName(const LabelId label) const {
  for (const auto &[name, id] : labels) {
    if (id == label) {
      return name;
    }
  }
  throw utils::BasicException("GetLabelName fails on the given label id!");
}
/// Looks up the id registered for a property name.
///
/// @param property_name name of the property to resolve.
/// @return the matching PropertyId, or std::nullopt when the name is not
///         present in `properties`.
std::optional<PropertyId> GetPropertyId(const std::string &property_name) const {
  // A single find() replaces the previous contains() + at() pair, which
  // searched the map twice for the same key.
  if (const auto it = properties.find(property_name); it != properties.end()) {
    return it->second;
  }
  return std::nullopt;
}
/// Reverse lookup of a property name from its id.
///
/// Walks the `properties` name-to-id map in order and returns the name of the
/// first entry whose id equals `property`.
///
/// @param property id of the property to resolve.
/// @return the property name registered under `property`.
/// @throw utils::BasicException when no entry in `properties` carries this id.
std::string GetPropertyName(const PropertyId property) const {
  for (const auto &[name, id] : properties) {
    if (id == property) {
      return name;
    }
  }
  throw utils::BasicException("PropertyId not found!");
}
/// Looks up the id registered for an edge type name.
///
/// @param edge_type name of the edge type to resolve.
/// @return the matching EdgeTypeId, or std::nullopt when the name is not
///         present in `edge_types`.
std::optional<EdgeTypeId> GetEdgeTypeId(const std::string &edge_type) const {
  // A single find() replaces the previous contains() + at() pair, which
  // searched the map twice for the same key.
  if (const auto it = edge_types.find(edge_type); it != edge_types.end()) {
    return it->second;
  }
  return std::nullopt;
}
std::string GetEdgeTypeName(const EdgeTypeId property) const {
if (const auto it = std::ranges::find_if(
edge_types, [property](const auto &name_id_pair) { return name_id_pair.second == property; });
it != edge_types.end()) {
return it->first;
}
throw utils::BasicException("EdgeTypeId not found!");
}
Shards GetShardsForRange(const LabelName &label_name, const PrimaryKey &start_key, const PrimaryKey &end_key) const {
MG_ASSERT(start_key <= end_key);
MG_ASSERT(labels.contains(label_name));
@ -268,12 +317,29 @@ struct ShardMap {
return ret;
}
std::optional<PropertyId> GetPropertyId(const std::string &property_name) const {
if (properties.contains(property_name)) {
return properties.at(property_name);
EdgeTypeIdMap AllocateEdgeTypeIds(const std::vector<EdgeTypeName> &new_edge_types) {
EdgeTypeIdMap ret;
bool mutated = false;
for (const auto &edge_type_name : new_edge_types) {
if (edge_types.contains(edge_type_name)) {
auto edge_type_id = edge_types.at(edge_type_name);
ret.emplace(edge_type_name, edge_type_id);
} else {
mutated = true;
const EdgeTypeId edge_type_id = EdgeTypeId::FromUint(++max_edge_type_id);
ret.emplace(edge_type_name, edge_type_id);
edge_types.emplace(edge_type_name, edge_type_id);
}
}
return std::nullopt;
if (mutated) {
IncrementShardMapVersion();
}
return ret;
}
};

View File

@ -15,9 +15,16 @@
#include <string>
#include <vector>
#include "coordinator/shard_map.hpp"
#include "query/v2/accessors.hpp"
#include "query/v2/requests.hpp"
#include "storage/v3/edge_accessor.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/result.hpp"
#include "storage/v3/shard.hpp"
#include "storage/v3/vertex_accessor.hpp"
#include "storage/v3/view.hpp"
#include "utils/exceptions.hpp"
#include "utils/temporal.hpp"
using memgraph::communication::bolt::Value;
@ -63,17 +70,51 @@ query::v2::TypedValue ToTypedValue(const Value &value) {
}
}
storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const query::v2::VertexAccessor &vertex,
const storage::v3::Shard &db, storage::v3::View view) {
return ToBoltVertex(vertex.impl_, db, view);
/// Converts a query-v2 vertex accessor into a Bolt vertex.
///
/// Label ids and property ids are translated to names through `shard_map`.
///
/// @param vertex accessor whose labels and properties are converted.
/// @param shard_map source of the label and property names.
/// @param view currently unused.
/// @return the Bolt vertex holding the translated labels and properties.
/// @throw utils::BasicException (from ShardMap::GetLabelName /
///        ShardMap::GetPropertyName) when an id has no registered name.
storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const query::v2::accessors::VertexAccessor &vertex,
                                                              const coordinator::ShardMap &shard_map,
                                                              storage::v3::View /*view*/) {
  // NOTE: the Bolt id is a hard-coded 0; nothing is read from `vertex` for it.
  auto id = communication::bolt::Id::FromUint(0);

  auto labels = vertex.Labels();
  std::vector<std::string> new_labels;
  new_labels.reserve(labels.size());
  for (const auto &label : labels) {
    new_labels.push_back(shard_map.GetLabelName(label.id));
  }

  auto properties = vertex.Properties();
  std::map<std::string, Value> new_properties;
  for (const auto &[prop, property_value] : properties) {
    new_properties[shard_map.GetPropertyName(prop)] = ToBoltValue(property_value);
  }

  // Move the freshly built containers into the result instead of copying them.
  return communication::bolt::Vertex{id, std::move(new_labels), std::move(new_properties)};
}
storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const query::v2::EdgeAccessor &edge,
const storage::v3::Shard &db, storage::v3::View view) {
return ToBoltEdge(edge.impl_, db, view);
/// Converts a query-v2 edge accessor into a Bolt edge.
///
/// The edge type id and the property ids are translated to names through
/// `shard_map`.
///
/// @param edge accessor whose type and properties are converted.
/// @param shard_map source of the edge type and property names.
/// @param view currently unused.
/// @return the Bolt edge holding the translated type and properties.
/// @throw utils::BasicException (from ShardMap::GetEdgeTypeName /
///        ShardMap::GetPropertyName) when an id has no registered name.
storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const query::v2::accessors::EdgeAccessor &edge,
                                                          const coordinator::ShardMap &shard_map,
                                                          storage::v3::View /*view*/) {
  // TODO(jbajic) Fix bolt communication
  // The edge id and both endpoint ids are hard-coded to 0 for now.
  auto id = communication::bolt::Id::FromUint(0);
  auto from = communication::bolt::Id::FromUint(0);
  auto to = communication::bolt::Id::FromUint(0);

  // GetEdgeTypeName returns by value, so own the string here and move it into
  // the result rather than binding a const reference and copying later.
  auto type = shard_map.GetEdgeTypeName(edge.EdgeType());

  auto properties = edge.Properties();
  std::map<std::string, Value> new_properties;
  for (const auto &[prop, property_value] : properties) {
    new_properties[shard_map.GetPropertyName(prop)] = ToBoltValue(property_value);
  }

  return communication::bolt::Edge{id, from, to, std::move(type), std::move(new_properties)};
}
storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const storage::v3::Shard &db,
/// Placeholder for converting a query-v2 path into a Bolt path.
///
/// All three arguments are ignored; the function unconditionally returns
/// storage::v3::Error::DELETED_OBJECT rather than a converted path.
storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::accessors::Path & /*path*/,
const coordinator::ShardMap & /*shard_map*/,
storage::v3::View /*view*/) {
// TODO(jbajic) Fix bolt communication
// NOTE(review): DELETED_OBJECT looks like a stand-in for "not implemented" - confirm.
return {storage::v3::Error::DELETED_OBJECT};
}
storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const coordinator::ShardMap &shard_map,
storage::v3::View view) {
switch (value.type()) {
case query::v2::TypedValue::Type::Null:
@ -90,7 +131,7 @@ storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const
std::vector<Value> values;
values.reserve(value.ValueList().size());
for (const auto &v : value.ValueList()) {
auto maybe_value = ToBoltValue(v, db, view);
auto maybe_value = ToBoltValue(v, shard_map, view);
if (maybe_value.HasError()) return maybe_value.GetError();
values.emplace_back(std::move(*maybe_value));
}
@ -99,24 +140,24 @@ storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const
case query::v2::TypedValue::Type::Map: {
std::map<std::string, Value> map;
for (const auto &kv : value.ValueMap()) {
auto maybe_value = ToBoltValue(kv.second, db, view);
auto maybe_value = ToBoltValue(kv.second, shard_map, view);
if (maybe_value.HasError()) return maybe_value.GetError();
map.emplace(kv.first, std::move(*maybe_value));
}
return Value(std::move(map));
}
case query::v2::TypedValue::Type::Vertex: {
auto maybe_vertex = ToBoltVertex(value.ValueVertex(), db, view);
auto maybe_vertex = ToBoltVertex(value.ValueVertex(), shard_map, view);
if (maybe_vertex.HasError()) return maybe_vertex.GetError();
return Value(std::move(*maybe_vertex));
}
case query::v2::TypedValue::Type::Edge: {
auto maybe_edge = ToBoltEdge(value.ValueEdge(), db, view);
auto maybe_edge = ToBoltEdge(value.ValueEdge(), shard_map, view);
if (maybe_edge.HasError()) return maybe_edge.GetError();
return Value(std::move(*maybe_edge));
}
case query::v2::TypedValue::Type::Path: {
auto maybe_path = ToBoltPath(value.ValuePath(), db, view);
auto maybe_path = ToBoltPath(value.ValuePath(), shard_map, view);
if (maybe_path.HasError()) return maybe_path.GetError();
return Value(std::move(*maybe_path));
}
@ -131,59 +172,48 @@ storage::v3::Result<Value> ToBoltValue(const query::v2::TypedValue &value, const
}
}
storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const storage::v3::VertexAccessor &vertex,
const storage::v3::Shard &db, storage::v3::View view) {
// TODO(jbajic) Fix bolt communication
auto id = communication::bolt::Id::FromUint(0);
auto maybe_labels = vertex.Labels(view);
if (maybe_labels.HasError()) return maybe_labels.GetError();
std::vector<std::string> labels;
labels.reserve(maybe_labels->size());
for (const auto &label : *maybe_labels) {
labels.push_back(db.LabelToName(label));
Value ToBoltValue(msgs::Value value) {
switch (value.type) {
case msgs::Value::Type::Null:
return {};
case msgs::Value::Type::Bool:
return {value.bool_v};
case msgs::Value::Type::Int64:
return {value.int_v};
case msgs::Value::Type::Double:
return {value.double_v};
case msgs::Value::Type::String:
return {std::string(value.string_v)};
case msgs::Value::Type::List: {
std::vector<Value> values;
values.reserve(value.list_v.size());
for (const auto &v : value.list_v) {
auto maybe_value = ToBoltValue(v);
values.emplace_back(std::move(maybe_value));
}
return Value{std::move(values)};
}
case msgs::Value::Type::Map: {
std::map<std::string, Value> map;
for (const auto &kv : value.map_v) {
auto maybe_value = ToBoltValue(kv.second);
map.emplace(kv.first, std::move(maybe_value));
}
return Value{std::move(map)};
}
case msgs::Value::Type::Vertex:
case msgs::Value::Type::Edge:
case msgs::Value::Type::Path: {
throw utils::BasicException("Path, Vertex and Edge not supported!");
}
// TODO Value to Date types not supported
}
auto maybe_properties = vertex.Properties(view);
if (maybe_properties.HasError()) return maybe_properties.GetError();
std::map<std::string, Value> properties;
for (const auto &prop : *maybe_properties) {
properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second);
}
return communication::bolt::Vertex{id, labels, properties};
}
storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const storage::v3::EdgeAccessor &edge,
const storage::v3::Shard &db, storage::v3::View view) {
// TODO(jbajic) Fix bolt communication
auto id = communication::bolt::Id::FromUint(0);
auto from = communication::bolt::Id::FromUint(0);
auto to = communication::bolt::Id::FromUint(0);
const auto &type = db.EdgeTypeToName(edge.EdgeType());
auto maybe_properties = edge.Properties(view);
if (maybe_properties.HasError()) return maybe_properties.GetError();
std::map<std::string, Value> properties;
for (const auto &prop : *maybe_properties) {
properties[db.PropertyToName(prop.first)] = ToBoltValue(prop.second);
}
return communication::bolt::Edge{id, from, to, type, properties};
}
storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::Path &path, const storage::v3::Shard &db,
storage::v3::View view) {
std::vector<communication::bolt::Vertex> vertices;
vertices.reserve(path.vertices().size());
for (const auto &v : path.vertices()) {
auto maybe_vertex = ToBoltVertex(v, db, view);
if (maybe_vertex.HasError()) return maybe_vertex.GetError();
vertices.emplace_back(std::move(*maybe_vertex));
}
std::vector<communication::bolt::Edge> edges;
edges.reserve(path.edges().size());
for (const auto &e : path.edges()) {
auto maybe_edge = ToBoltEdge(e, db, view);
if (maybe_edge.HasError()) return maybe_edge.GetError();
edges.emplace_back(std::move(*maybe_edge));
}
return communication::bolt::Path(vertices, edges);
/// Stub overload taking a storage::v3::Shard.
///
/// All three arguments are ignored and a default-constructed
/// communication::bolt::Path is returned.
/// NOTE(review): the header changes visible in this commit declare only the
/// coordinator::ShardMap overload - confirm this overload is still needed.
storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::accessors::Path & /*path*/,
const storage::v3::Shard & /*db*/,
storage::v3::View /*view*/) {
return communication::bolt::Path();
}
storage::v3::PropertyValue ToPropertyValue(const Value &value) {

View File

@ -13,9 +13,11 @@
#pragma once
#include "communication/bolt/v1/value.hpp"
#include "coordinator/shard_map.hpp"
#include "query/v2/bindings/typed_value.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/result.hpp"
#include "storage/v3/shard.hpp"
#include "storage/v3/view.hpp"
namespace memgraph::storage::v3 {
@ -28,36 +30,40 @@ namespace memgraph::glue::v2 {
/// @param storage::v3::VertexAccessor for converting to
/// communication::bolt::Vertex.
/// @param storage::v3::Shard for getting label and property names.
/// @param coordinator::ShardMap for getting label and property names.
/// @param storage::v3::View for deciding which vertex attributes are visible.
///
/// @throw std::bad_alloc
storage::v3::Result<communication::bolt::Vertex> ToBoltVertex(const storage::v3::VertexAccessor &vertex,
const storage::v3::Shard &db, storage::v3::View view);
const coordinator::ShardMap &shard_map,
storage::v3::View view);
/// @param storage::v3::EdgeAccessor for converting to communication::bolt::Edge.
/// @param storage::v3::Shard for getting edge type and property names.
/// @param coordinator::ShardMap for getting edge type and property names.
/// @param storage::v3::View for deciding which edge attributes are visible.
///
/// @throw std::bad_alloc
storage::v3::Result<communication::bolt::Edge> ToBoltEdge(const storage::v3::EdgeAccessor &edge,
const storage::v3::Shard &db, storage::v3::View view);
const coordinator::ShardMap &shard_map,
storage::v3::View view);
/// @param query::v2::Path for converting to communication::bolt::Path.
/// @param storage::v3::Shard for ToBoltVertex and ToBoltEdge.
/// @param coordinator::ShardMap for ToBoltVertex and ToBoltEdge.
/// @param storage::v3::View for ToBoltVertex and ToBoltEdge.
///
/// @throw std::bad_alloc
storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::Path &path, const storage::v3::Shard &db,
storage::v3::Result<communication::bolt::Path> ToBoltPath(const query::v2::accessors::Path &path,
const coordinator::ShardMap &shard_map,
storage::v3::View view);
/// @param query::v2::TypedValue for converting to communication::bolt::Value.
/// @param storage::v3::Shard for ToBoltVertex and ToBoltEdge.
/// @param coordinator::ShardMap for ToBoltVertex and ToBoltEdge.
/// @param storage::v3::View for ToBoltVertex and ToBoltEdge.
///
/// @throw std::bad_alloc
storage::v3::Result<communication::bolt::Value> ToBoltValue(const query::v2::TypedValue &value,
const storage::v3::Shard &db, storage::v3::View view);
const coordinator::ShardMap &shard_map,
storage::v3::View view);
query::v2::TypedValue ToTypedValue(const communication::bolt::Value &value);
@ -65,4 +71,9 @@ communication::bolt::Value ToBoltValue(const storage::v3::PropertyValue &value);
storage::v3::PropertyValue ToPropertyValue(const communication::bolt::Value &value);
communication::bolt::Value ToBoltValue(msgs::Value value);
communication::bolt::Value ToBoltValue(msgs::Value value, const coordinator::ShardMap &shard_map,
storage::v3::View view);
} // namespace memgraph::glue::v2

View File

@ -192,7 +192,7 @@ class Future {
template <typename T>
class Promise {
std::shared_ptr<details::Shared<T>> shared_;
bool filled_or_moved_ = false;
bool filled_or_moved_{false};
public:
explicit Promise(std::shared_ptr<details::Shared<T>> shared) : shared_(shared) {}
@ -212,6 +212,7 @@ class Promise {
Promise(const Promise &) = delete;
Promise &operator=(const Promise &) = delete;
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch)
~Promise() { MG_ASSERT(filled_or_moved_, "Promise destroyed before its associated Future was filled!"); }
// Fill the expected item into the Future.

View File

@ -18,6 +18,7 @@
#include <io/time.hpp>
#include <machine_manager/machine_config.hpp>
#include <storage/v3/shard_manager.hpp>
#include "coordinator/shard_map.hpp"
namespace memgraph::machine_manager {
@ -69,11 +70,11 @@ class MachineManager {
public:
// TODO initialize ShardManager with "real" coordinator addresses instead of io.GetAddress
// which is only true for single-machine config.
MachineManager(io::Io<IoImpl> io, MachineConfig config, Coordinator coordinator)
MachineManager(io::Io<IoImpl> io, MachineConfig config, Coordinator coordinator, coordinator::ShardMap &shard_map)
: io_(io),
config_(config),
coordinator_{std::move(io.ForkLocal()), {}, std::move(coordinator)},
shard_manager_(ShardManager{io.ForkLocal(), coordinator_.GetAddress()}) {}
shard_manager_{io.ForkLocal(), coordinator_.GetAddress(), shard_map} {}
Address CoordinatorAddress() { return coordinator_.GetAddress(); }

View File

@ -35,20 +35,28 @@
#include "communication/bolt/v1/constants.hpp"
#include "communication/websocket/auth.hpp"
#include "communication/websocket/server.hpp"
#include "coordinator/shard_map.hpp"
#include "helpers.hpp"
#include "io/address.hpp"
#include "io/local_transport/local_system.hpp"
#include "io/local_transport/local_transport.hpp"
#include "io/simulator/simulator_transport.hpp"
#include "machine_manager/machine_config.hpp"
#include "machine_manager/machine_manager.hpp"
#include "py/py.hpp"
#include "query/auth_checker.hpp"
#include "query/discard_value_stream.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/interpreter.hpp"
#include "query/plan/operator.hpp"
#include "query/procedure/module.hpp"
#include "query/procedure/py_module.hpp"
#include "query/v2/discard_value_stream.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/interpreter.hpp"
#include "query/v2/plan/operator.hpp"
#include "requests/requests.hpp"
#include "storage/v2/isolation_level.hpp"
#include "storage/v2/storage.hpp"
#include "storage/v2/view.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/isolation_level.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/schemas.hpp"
#include "storage/v3/shard.hpp"
#include "storage/v3/view.hpp"
#include "telemetry/telemetry.hpp"
#include "utils/event_counter.hpp"
#include "utils/file.hpp"
@ -62,7 +70,6 @@
#include "utils/settings.hpp"
#include "utils/signals.hpp"
#include "utils/string.hpp"
#include "utils/synchronized.hpp"
#include "utils/sysinfo/memory.hpp"
#include "utils/terminate_handler.hpp"
#include "version.hpp"
@ -83,10 +90,10 @@
#include "communication/init.hpp"
#include "communication/v2/server.hpp"
#include "communication/v2/session.hpp"
#include "glue/communication.hpp"
#include "glue/v2/communication.hpp"
#include "auth/auth.hpp"
#include "glue/auth.hpp"
#include "glue/v2/auth.hpp"
#ifdef MG_ENTERPRISE
#include "audit/log.hpp"
@ -197,16 +204,6 @@ DEFINE_bool(storage_wal_enabled, false,
DEFINE_VALIDATED_uint64(storage_snapshot_retention_count, 3, "The number of snapshots that should always be kept.",
FLAG_IN_RANGE(1, 1000000));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_uint64(storage_wal_file_size_kib, memgraph::storage::Config::Durability().wal_file_size_kibibytes,
"Minimum file size of each WAL file.",
FLAG_IN_RANGE(1, static_cast<unsigned long>(1000) * 1024));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_uint64(storage_wal_file_flush_every_n_tx,
memgraph::storage::Config::Durability().wal_file_flush_every_n_tx,
"Issue a 'fsync' call after this amount of transactions are written to the "
"WAL file. Set to 1 for fully synchronous operation.",
FLAG_IN_RANGE(1, 1000000));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_bool(storage_snapshot_on_exit, false, "Controls whether the storage creates another snapshot on exit.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -271,9 +268,9 @@ DEFINE_uint64(
namespace {
using namespace std::literals;
inline constexpr std::array isolation_level_mappings{
std::pair{"SNAPSHOT_ISOLATION"sv, memgraph::storage::IsolationLevel::SNAPSHOT_ISOLATION},
std::pair{"READ_COMMITTED"sv, memgraph::storage::IsolationLevel::READ_COMMITTED},
std::pair{"READ_UNCOMMITTED"sv, memgraph::storage::IsolationLevel::READ_UNCOMMITTED}};
std::pair{"SNAPSHOT_ISOLATION"sv, memgraph::storage::v3::IsolationLevel::SNAPSHOT_ISOLATION},
std::pair{"READ_COMMITTED"sv, memgraph::storage::v3::IsolationLevel::READ_COMMITTED},
std::pair{"READ_UNCOMMITTED"sv, memgraph::storage::v3::IsolationLevel::READ_UNCOMMITTED}};
const std::string isolation_level_help_string =
fmt::format("Default isolation level used for the transactions. Allowed values: {}",
@ -302,13 +299,6 @@ DEFINE_VALIDATED_string(isolation_level, "SNAPSHOT_ISOLATION", isolation_level_h
});
namespace {
/// Translates the textual `--isolation-level` flag into the storage isolation level enum
/// by looking it up in `isolation_level_mappings`.
///
/// @return the isolation level matching `FLAGS_isolation_level`.
/// An unmapped flag value trips the `MG_ASSERT` below.
memgraph::storage::IsolationLevel ParseIsolationLevel() {
  const auto parsed_level =
      StringToEnum<memgraph::storage::IsolationLevel>(FLAGS_isolation_level, isolation_level_mappings);
  MG_ASSERT(parsed_level, "Invalid isolation level");
  return *parsed_level;
}
int64_t GetMemoryLimit() {
if (FLAGS_memory_limit == 0) {
auto maybe_total_memory = memgraph::utils::sysinfo::TotalMemory();
@ -329,30 +319,6 @@ int64_t GetMemoryLimit() {
}
} // namespace
namespace {
// Directories parsed from the --query-modules-directory flag; rebuilt by the flag validator below.
std::vector<std::filesystem::path> query_modules_directories;
}  // namespace

// Flag validator: accepts an empty value (no directories) or a comma-separated list in which
// every entry is an existing directory. The list is only stored once all entries have been
// validated, so a rejected value leaves `query_modules_directories` empty.
DEFINE_VALIDATED_string(query_modules_directory, "",
                        "Directory where modules with custom query procedures are stored. "
                        "NOTE: Multiple comma-separated directories can be defined.",
                        {
                          query_modules_directories.clear();
                          if (value.empty()) return true;

                          const auto directories = memgraph::utils::Split(value, ",");
                          // First pass: reject the whole value if any entry is not a directory.
                          for (const auto &candidate : directories) {
                            if (memgraph::utils::DirExists(candidate)) continue;
                            std::cout << "Expected --" << flagname << " to point to directories." << std::endl;
                            std::cout << candidate << " is not a directory." << std::endl;
                            return false;
                          }

                          // Second pass: store every validated entry as a filesystem path.
                          query_modules_directories.reserve(directories.size());
                          for (const auto &candidate : directories) {
                            query_modules_directories.emplace_back(candidate);
                          }
                          return true;
                        });
// Logging flags
DEFINE_bool(also_log_to_stderr, false, "Log messages go to stderr in addition to logfiles");
DEFINE_string(log_file, "", "Path to where the log should be stored.");
@ -424,13 +390,6 @@ void InitializeLogger() {
CreateLoggerFromSink(sinks, ParseLogLevel());
}
/// Rebuilds the default logger so that it writes to `new_sink` in addition to the sinks it
/// already has, keeping the current log level.
///
/// @param new_sink sink to append to the default logger's sink list.
void AddLoggerSink(spdlog::sink_ptr new_sink) {
  const auto current_logger = spdlog::default_logger();
  // Work on a copy of the sink list; the logger is recreated from it below.
  auto combined_sinks = current_logger->sinks();
  combined_sinks.push_back(new_sink);
  CreateLoggerFromSink(combined_sinks, current_logger->level());
}
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@ -443,431 +402,26 @@ DEFINE_string(organization_name, "", "Organization name.");
// Bundle of raw pointers to server-wide objects; `BoltSession` copies these pointers out of
// it in its constructor.
// NOTE(review): this span contains both a storage/auth based variant (guarded by
// MG_ENTERPRISE) and a shard-map based variant of the constructor and members — presumably
// an artifact of how the change was rendered; confirm which variant is current.
struct SessionData {
// Explicit constructor here to ensure that pointers to all objects are
// supplied.
#if MG_ENTERPRISE
// Enterprise variant: additionally carries the audit log pointer.
SessionData(memgraph::storage::Storage *db, memgraph::query::InterpreterContext *interpreter_context,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
memgraph::audit::Log *audit_log)
: db(db), interpreter_context(interpreter_context), auth(auth), audit_log(audit_log) {}
memgraph::storage::Storage *db;
memgraph::query::InterpreterContext *interpreter_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
memgraph::audit::Log *audit_log;
#else
// Non-enterprise variant: same as above without the audit log.
SessionData(memgraph::storage::Storage *db, memgraph::query::InterpreterContext *interpreter_context,
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
: db(db), interpreter_context(interpreter_context), auth(auth) {}
memgraph::storage::Storage *db;
memgraph::query::InterpreterContext *interpreter_context;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth;
#endif
// Shard-map variant: takes the shard map by reference but stores its address.
SessionData(memgraph::coordinator::ShardMap &shard_map, memgraph::query::v2::InterpreterContext *interpreter_context)
: shard_map(&shard_map), interpreter_context(interpreter_context) {}
memgraph::coordinator::ShardMap *shard_map;
memgraph::query::v2::InterpreterContext *interpreter_context;
};
// Default pattern for the --auth-user-or-role-name-regex flag.
// NOTE(review): inside the character class, `+-@` is a range from '+' (0x2B) to '@' (0x40),
// so besides '+', '-' and '@' it also admits ',', '/', ':', ';', '<', '=', '>' and '?'.
// If only the three literal characters were intended, the '-' would need to be escaped or
// moved to the end of the class — confirm before changing, since existing names may rely on it.
inline constexpr std::string_view default_user_role_regex = "[a-zA-Z0-9_.+-@]+";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
// `.data()` is passed because the flag macro takes a C string; the view above wraps a string
// literal, so the pointed-to data is null-terminated.
DEFINE_string(auth_user_or_role_name_regex, default_user_role_regex.data(),
"Set to the regular expression that each user or role name must fulfill.");
/// Implements the query engine's auth-query interface on top of the shared, lock-protected
/// `memgraph::auth::Auth` store.
///
/// Every method that receives a user or role name validates it against `name_regex_` and
/// throws `memgraph::query::QueryRuntimeException` when the name does not match.
/// `memgraph::auth::AuthException`s raised while talking to the auth store are rethrown as
/// `QueryRuntimeException`s carrying the same message.
class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
  memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
  // The textual pattern is kept next to the compiled regex so CreateUser can tell whether a
  // non-default pattern is configured.
  std::string name_regex_string_;
  std::regex name_regex_;

 public:
  /// @param auth pointer to the synchronized auth store; stored as-is, not owned here.
  /// @param name_regex_string pattern every user/role name has to match in full.
  AuthQueryHandler(memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth,
                   std::string name_regex_string)
      : auth_(auth), name_regex_string_(std::move(name_regex_string)), name_regex_(name_regex_string_) {}

  /// Creates a user; the very first user that is successfully created gets all privileges.
  ///
  /// @return true if the user was added, false if the auth store refused to add it.
  /// @throws memgraph::auth::AuthException if a custom name regex is configured without a valid license.
  /// @throws memgraph::query::QueryRuntimeException on an invalid name or an auth store error.
  bool CreateUser(const std::string &username, const std::optional<std::string> &password) override {
    if (name_regex_string_ != default_user_role_regex) {
      if (const auto license_check_result =
              memgraph::utils::license::global_license_checker.IsValidLicense(memgraph::utils::global_settings);
          license_check_result.HasError()) {
        throw memgraph::auth::AuthException(
            "Custom user/role regex is a Memgraph Enterprise feature. Please set the config "
            "(\"--auth-user-or-role-name-regex\") to its default value (\"{}\") or remove the flag.\n{}",
            default_user_role_regex,
            memgraph::utils::license::LicenseCheckErrorToString(license_check_result.GetError(), "user/role regex"));
      }
    }
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    try {
      // The lock is released before GrantPrivilege runs, because GrantPrivilege takes it again.
      const auto [first_user, user_added] = std::invoke([&, this] {
        auto locked_auth = auth_->Lock();
        const auto first_user = !locked_auth->HasUsers();
        const auto user_added = locked_auth->AddUser(username, password).has_value();
        return std::make_pair(first_user, user_added);
      });

      // Only grant when the user was really created. GrantPrivilege resolves the name through
      // both GetUser and GetRole, so granting after a failed AddUser could hand every privilege
      // to an unrelated entity that happens to share the name.
      if (first_user && user_added) {
        spdlog::info("{} is first created user. Granting all privileges.", username);
        GrantPrivilege(username, memgraph::query::kPrivilegesAll);
      }

      return user_added;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Removes a user.
  /// @return false if no such user exists, otherwise the result of the removal.
  bool DropUser(const std::string &username) override {
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      auto user = locked_auth->GetUser(username);
      if (!user) return false;
      return locked_auth->RemoveUser(username);
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Updates the password of an existing user and saves the user.
  /// @throws memgraph::query::QueryRuntimeException if the user does not exist.
  void SetPassword(const std::string &username, const std::optional<std::string> &password) override {
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      auto user = locked_auth->GetUser(username);
      if (!user) {
        throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist.", username);
      }
      user->UpdatePassword(password);
      locked_auth->SaveUser(*user);
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Creates a role.
  /// @return true if the role was added.
  bool CreateRole(const std::string &rolename) override {
    if (!std::regex_match(rolename, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid role name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      return locked_auth->AddRole(rolename).has_value();
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Removes a role.
  /// @return false if no such role exists, otherwise the result of the removal.
  bool DropRole(const std::string &rolename) override {
    if (!std::regex_match(rolename, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid role name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      auto role = locked_auth->GetRole(rolename);
      if (!role) return false;
      return locked_auth->RemoveRole(rolename);
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// @return the names of all users as typed values.
  std::vector<memgraph::query::TypedValue> GetUsernames() override {
    try {
      auto locked_auth = auth_->ReadLock();
      std::vector<memgraph::query::TypedValue> usernames;
      const auto &users = locked_auth->AllUsers();
      usernames.reserve(users.size());
      for (const auto &user : users) {
        usernames.emplace_back(user.username());
      }
      return usernames;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// @return the names of all roles as typed values.
  std::vector<memgraph::query::TypedValue> GetRolenames() override {
    try {
      auto locked_auth = auth_->ReadLock();
      std::vector<memgraph::query::TypedValue> rolenames;
      const auto &roles = locked_auth->AllRoles();
      rolenames.reserve(roles.size());
      for (const auto &role : roles) {
        rolenames.emplace_back(role.rolename());
      }
      return rolenames;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// @return the name of the role assigned to the user, or nullopt if the user has no role.
  /// @throws memgraph::query::QueryRuntimeException if the user does not exist.
  std::optional<std::string> GetRolenameForUser(const std::string &username) override {
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    try {
      auto locked_auth = auth_->ReadLock();
      auto user = locked_auth->GetUser(username);
      if (!user) {
        throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
      }

      if (const auto *role = user->role(); role != nullptr) {
        return role->rolename();
      }
      return std::nullopt;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// @return the names of all users that have the given role.
  /// @throws memgraph::query::QueryRuntimeException if the role does not exist.
  std::vector<memgraph::query::TypedValue> GetUsernamesForRole(const std::string &rolename) override {
    if (!std::regex_match(rolename, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid role name.");
    }
    try {
      auto locked_auth = auth_->ReadLock();
      auto role = locked_auth->GetRole(rolename);
      if (!role) {
        throw memgraph::query::QueryRuntimeException("Role '{}' doesn't exist.", rolename);
      }
      std::vector<memgraph::query::TypedValue> usernames;
      const auto &users = locked_auth->AllUsersForRole(rolename);
      usernames.reserve(users.size());
      for (const auto &user : users) {
        usernames.emplace_back(user.username());
      }
      return usernames;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Assigns a role to a user that currently has none and saves the user.
  /// @throws memgraph::query::QueryRuntimeException if the user or role does not exist, or
  ///         if the user already has a role.
  void SetRole(const std::string &username, const std::string &rolename) override {
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    if (!std::regex_match(rolename, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid role name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      auto user = locked_auth->GetUser(username);
      if (!user) {
        throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
      }
      auto role = locked_auth->GetRole(rolename);
      if (!role) {
        throw memgraph::query::QueryRuntimeException("Role '{}' doesn't exist .", rolename);
      }
      if (const auto *current_role = user->role(); current_role != nullptr) {
        throw memgraph::query::QueryRuntimeException("User '{}' is already a member of role '{}'.", username,
                                                     current_role->rolename());
      }
      user->SetRole(*role);
      locked_auth->SaveUser(*user);
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Removes the role from a user and saves the user.
  /// @throws memgraph::query::QueryRuntimeException if the user does not exist.
  void ClearRole(const std::string &username) override {
    if (!std::regex_match(username, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user name.");
    }
    try {
      auto locked_auth = auth_->Lock();
      auto user = locked_auth->GetUser(username);
      if (!user) {
        throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
      }
      user->ClearRole();
      locked_auth->SaveUser(*user);
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Lists every non-neutral privilege of a user or role.
  ///
  /// Each returned row holds three values: the permission name, its effective level, and a
  /// description of where the level comes from (user and/or role). When the name resolves to
  /// a user, the user branch is taken even if a role with the same name exists.
  ///
  /// @throws memgraph::query::QueryRuntimeException if neither a user nor a role has that name.
  std::vector<std::vector<memgraph::query::TypedValue>> GetPrivileges(const std::string &user_or_role) override {
    if (!std::regex_match(user_or_role, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user or role name.");
    }
    try {
      auto locked_auth = auth_->ReadLock();
      std::vector<std::vector<memgraph::query::TypedValue>> grants;
      auto user = locked_auth->GetUser(user_or_role);
      auto role = locked_auth->GetRole(user_or_role);
      if (!user && !role) {
        throw memgraph::query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
      }
      if (user) {
        // GetPermissions() supplies the effective level; user->permissions() and
        // role->permissions() below supply the per-source levels for the description.
        const auto &permissions = user->GetPermissions();
        for (const auto &privilege : memgraph::query::kPrivilegesAll) {
          auto permission = memgraph::glue::PrivilegeToPermission(privilege);
          auto effective = permissions.Has(permission);
          if (effective != memgraph::auth::PermissionLevel::NEUTRAL) {
            std::vector<std::string> description;
            auto user_level = user->permissions().Has(permission);
            if (user_level == memgraph::auth::PermissionLevel::GRANT) {
              description.emplace_back("GRANTED TO USER");
            } else if (user_level == memgraph::auth::PermissionLevel::DENY) {
              description.emplace_back("DENIED TO USER");
            }
            if (const auto *role = user->role(); role != nullptr) {
              auto role_level = role->permissions().Has(permission);
              if (role_level == memgraph::auth::PermissionLevel::GRANT) {
                description.emplace_back("GRANTED TO ROLE");
              } else if (role_level == memgraph::auth::PermissionLevel::DENY) {
                description.emplace_back("DENIED TO ROLE");
              }
            }
            grants.push_back({memgraph::query::TypedValue(memgraph::auth::PermissionToString(permission)),
                              memgraph::query::TypedValue(memgraph::auth::PermissionLevelToString(effective)),
                              memgraph::query::TypedValue(memgraph::utils::Join(description, ", "))});
          }
        }
      } else {
        const auto &permissions = role->permissions();
        for (const auto &privilege : memgraph::query::kPrivilegesAll) {
          auto permission = memgraph::glue::PrivilegeToPermission(privilege);
          auto effective = permissions.Has(permission);
          if (effective != memgraph::auth::PermissionLevel::NEUTRAL) {
            std::string description;
            if (effective == memgraph::auth::PermissionLevel::GRANT) {
              description = "GRANTED TO ROLE";
            } else if (effective == memgraph::auth::PermissionLevel::DENY) {
              description = "DENIED TO ROLE";
            }
            grants.push_back({memgraph::query::TypedValue(memgraph::auth::PermissionToString(permission)),
                              memgraph::query::TypedValue(memgraph::auth::PermissionLevelToString(effective)),
                              memgraph::query::TypedValue(description)});
          }
        }
      }
      return grants;
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }

  /// Grants the given privileges to a user or role.
  void GrantPrivilege(const std::string &user_or_role,
                      const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
      // TODO (mferencevic): should we first check that the
      // privilege is granted/denied/revoked before
      // unconditionally granting/denying/revoking it?
      permissions->Grant(permission);
    });
  }

  /// Denies the given privileges to a user or role.
  void DenyPrivilege(const std::string &user_or_role,
                     const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
      // TODO (mferencevic): should we first check that the
      // privilege is granted/denied/revoked before
      // unconditionally granting/denying/revoking it?
      permissions->Deny(permission);
    });
  }

  /// Revokes the given privileges from a user or role.
  void RevokePrivilege(const std::string &user_or_role,
                       const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) override {
    EditPermissions(user_or_role, privileges, [](auto *permissions, const auto &permission) {
      // TODO (mferencevic): should we first check that the
      // privilege is granted/denied/revoked before
      // unconditionally granting/denying/revoking it?
      permissions->Revoke(permission);
    });
  }

 private:
  /// Shared implementation of Grant/Deny/RevokePrivilege.
  ///
  /// Converts the privileges to auth permissions, applies `edit_fun` to each of them on the
  /// user's permissions if a user with that name exists (otherwise on the role's), and saves
  /// the modified user or role.
  ///
  /// @param edit_fun callable invoked as `edit_fun(&permissions, permission)`.
  /// @throws memgraph::query::QueryRuntimeException if neither a user nor a role has that name.
  template <class TEditFun>
  void EditPermissions(const std::string &user_or_role,
                       const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, const TEditFun &edit_fun) {
    if (!std::regex_match(user_or_role, name_regex_)) {
      throw memgraph::query::QueryRuntimeException("Invalid user or role name.");
    }
    try {
      std::vector<memgraph::auth::Permission> permissions;
      permissions.reserve(privileges.size());
      for (const auto &privilege : privileges) {
        permissions.push_back(memgraph::glue::PrivilegeToPermission(privilege));
      }
      auto locked_auth = auth_->Lock();
      auto user = locked_auth->GetUser(user_or_role);
      auto role = locked_auth->GetRole(user_or_role);
      if (!user && !role) {
        throw memgraph::query::QueryRuntimeException("User or role '{}' doesn't exist.", user_or_role);
      }
      if (user) {
        for (const auto &permission : permissions) {
          edit_fun(&user->permissions(), permission);
        }
        locked_auth->SaveUser(*user);
      } else {
        for (const auto &permission : permissions) {
          edit_fun(&role->permissions(), permission);
        }
        locked_auth->SaveRole(*role);
      }
    } catch (const memgraph::auth::AuthException &e) {
      throw memgraph::query::QueryRuntimeException(e.what());
    }
  }
};
/// Query-side authorization check backed by the shared, lock-protected auth store.
class AuthChecker final : public memgraph::query::AuthChecker {
 public:
  /// @param auth pointer to the synchronized auth store; stored as-is.
  explicit AuthChecker(
      memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth)
      : auth_{auth} {}

  /// @return true only if every requested privilege maps to a permission the user has at
  ///         level GRANT (an empty privilege list is therefore authorized).
  static bool IsUserAuthorized(const memgraph::auth::User &user,
                               const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) {
    const auto granted_permissions = user.GetPermissions();
    for (const auto privilege : privileges) {
      const auto level = granted_permissions.Has(memgraph::glue::PrivilegeToPermission(privilege));
      if (level != memgraph::auth::PermissionLevel::GRANT) {
        return false;
      }
    }
    return true;
  }

  /// Looks the user up by name and checks the requested privileges.
  ///
  /// @return true when the auth store has no users at all; otherwise true only if
  ///         `username` names an existing user holding every requested privilege.
  bool IsUserAuthorized(const std::optional<std::string> &username,
                        const std::vector<memgraph::query::AuthQuery::Privilege> &privileges) const final {
    std::optional<memgraph::auth::User> found_user;
    {
      // The read lock is held only for the lookup; the privilege check runs on the copy.
      auto locked_auth = auth_->ReadLock();
      if (!locked_auth->HasUsers()) {
        return true;
      }
      if (username) {
        found_user = locked_auth->GetUser(*username);
      }
    }

    if (!found_user) {
      return false;
    }
    return IsUserAuthorized(*found_user, privileges);
  }

 private:
  memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
};
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
public:
BoltSession(SessionData *data, const memgraph::communication::v2::ServerEndpoint &endpoint,
BoltSession(SessionData &data, const memgraph::communication::v2::ServerEndpoint &endpoint,
memgraph::communication::v2::InputStream *input_stream,
memgraph::communication::v2::OutputStream *output_stream)
: memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>(input_stream, output_stream),
db_(data->db),
interpreter_(data->interpreter_context),
auth_(data->auth),
#if MG_ENTERPRISE
audit_log_(data->audit_log),
#endif
endpoint_(endpoint) {
}
shard_map_(data.shard_map),
interpreter_(data.interpreter_context),
endpoint_(endpoint) {}
using memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream>::TEncoder;
@ -880,29 +434,14 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query, const std::map<std::string, memgraph::communication::bolt::Value> &params) override {
std::map<std::string, memgraph::storage::PropertyValue> params_pv;
for (const auto &kv : params) params_pv.emplace(kv.first, memgraph::glue::ToPropertyValue(kv.second));
std::map<std::string, memgraph::storage::v3::PropertyValue> params_pv;
for (const auto &kv : params) params_pv.emplace(kv.first, memgraph::glue::v2::ToPropertyValue(kv.second));
const std::string *username{nullptr};
if (user_) {
username = &user_->username();
}
#ifdef MG_ENTERPRISE
if (memgraph::utils::license::global_license_checker.IsValidLicenseFast()) {
audit_log_->Record(endpoint_.address().to_string(), user_ ? *username : "", query,
memgraph::storage::PropertyValue(params_pv));
}
#endif
try {
auto result = interpreter_.Prepare(query, params_pv, username);
if (user_ && !AuthChecker::IsUserAuthorized(*user_, result.privileges)) {
interpreter_.Abort();
throw memgraph::communication::bolt::ClientError(
"You are not authorized to execute this query! Please contact "
"your database administrator.");
}
return {result.headers, result.qid};
} catch (const memgraph::query::QueryException &e) {
} catch (const memgraph::query::v2::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
@ -911,26 +450,19 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
std::map<std::string, memgraph::communication::bolt::Value> Pull(TEncoder *encoder, std::optional<int> n,
std::optional<int> qid) override {
TypedValueResultStream stream(encoder, db_);
TypedValueResultStream stream(encoder, *shard_map_);
return PullResults(stream, n, qid);
}
std::map<std::string, memgraph::communication::bolt::Value> Discard(std::optional<int> n,
std::optional<int> qid) override {
memgraph::query::DiscardValueResultStream stream;
memgraph::query::v2::DiscardValueResultStream stream;
return PullResults(stream, n, qid);
}
void Abort() override { interpreter_.Abort(); }
bool Authenticate(const std::string &username, const std::string &password) override {
auto locked_auth = auth_->Lock();
if (!locked_auth->HasUsers()) {
return true;
}
user_ = locked_auth->Authenticate(username, password);
return user_.has_value();
}
bool Authenticate(const std::string & /*username*/, const std::string & /*password*/) override { return true; }
std::optional<std::string> GetServerNameForInit() override {
if (FLAGS_bolt_server_name_for_init.empty()) return std::nullopt;
@ -945,21 +477,21 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
const auto &summary = interpreter_.Pull(&stream, n, qid);
std::map<std::string, memgraph::communication::bolt::Value> decoded_summary;
for (const auto &kv : summary) {
auto maybe_value = memgraph::glue::ToBoltValue(kv.second, *db_, memgraph::storage::View::NEW);
auto maybe_value = memgraph::glue::v2::ToBoltValue(kv.second, *shard_map_, memgraph::storage::v3::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case memgraph::storage::Error::DELETED_OBJECT:
case memgraph::storage::Error::SERIALIZATION_ERROR:
case memgraph::storage::Error::VERTEX_HAS_EDGES:
case memgraph::storage::Error::PROPERTIES_DISABLED:
case memgraph::storage::Error::NONEXISTENT_OBJECT:
case memgraph::storage::v3::Error::DELETED_OBJECT:
case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming summary.");
}
}
decoded_summary.emplace(kv.first, std::move(*maybe_value));
}
return decoded_summary;
} catch (const memgraph::query::QueryException &e) {
} catch (const memgraph::query::v2::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
// client to fix their query.
throw memgraph::communication::bolt::ClientError(e.what());
@ -970,22 +502,23 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
/// before forwarding the calls to original TEncoder.
class TypedValueResultStream {
public:
TypedValueResultStream(TEncoder *encoder, const memgraph::storage::Storage *db) : encoder_(encoder), db_(db) {}
TypedValueResultStream(TEncoder *encoder, const memgraph::coordinator::ShardMap &shard_map)
: encoder_(encoder), shard_map_(&shard_map) {}
void Result(const std::vector<memgraph::query::TypedValue> &values) {
void Result(const std::vector<memgraph::query::v2::TypedValue> &values) {
std::vector<memgraph::communication::bolt::Value> decoded_values;
decoded_values.reserve(values.size());
for (const auto &v : values) {
auto maybe_value = memgraph::glue::ToBoltValue(v, *db_, memgraph::storage::View::NEW);
auto maybe_value = memgraph::glue::v2::ToBoltValue(v, *shard_map_, memgraph::storage::v3::View::NEW);
if (maybe_value.HasError()) {
switch (maybe_value.GetError()) {
case memgraph::storage::Error::DELETED_OBJECT:
case memgraph::storage::v3::Error::DELETED_OBJECT:
throw memgraph::communication::bolt::ClientError("Returning a deleted object as a result.");
case memgraph::storage::Error::NONEXISTENT_OBJECT:
case memgraph::storage::v3::Error::NONEXISTENT_OBJECT:
throw memgraph::communication::bolt::ClientError("Returning a nonexistent object as a result.");
case memgraph::storage::Error::VERTEX_HAS_EDGES:
case memgraph::storage::Error::SERIALIZATION_ERROR:
case memgraph::storage::Error::PROPERTIES_DISABLED:
case memgraph::storage::v3::Error::VERTEX_HAS_EDGES:
case memgraph::storage::v3::Error::SERIALIZATION_ERROR:
case memgraph::storage::v3::Error::PROPERTIES_DISABLED:
throw memgraph::communication::bolt::ClientError("Unexpected storage error when streaming results.");
}
}
@ -997,17 +530,12 @@ class BoltSession final : public memgraph::communication::bolt::Session<memgraph
private:
TEncoder *encoder_;
// NOTE: Needed only for ToBoltValue conversions
const memgraph::storage::Storage *db_;
const memgraph::coordinator::ShardMap *shard_map_;
};
// NOTE: Needed only for ToBoltValue conversions
const memgraph::storage::Storage *db_;
memgraph::query::Interpreter interpreter_;
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
std::optional<memgraph::auth::User> user_;
#ifdef MG_ENTERPRISE
memgraph::audit::Log *audit_log_;
#endif
const memgraph::coordinator::ShardMap *shard_map_;
memgraph::query::v2::Interpreter interpreter_;
memgraph::communication::v2::ServerEndpoint endpoint_;
};
@ -1059,80 +587,12 @@ int main(int argc, char **argv) {
// Unhandled exception handler init.
std::set_terminate(&memgraph::utils::TerminateHandler);
// Initialize Python
auto *program_name = Py_DecodeLocale(argv[0], nullptr);
MG_ASSERT(program_name);
// Set program name, so Python can find its way to runtime libraries relative
// to executable.
Py_SetProgramName(program_name);
PyImport_AppendInittab("_mgp", &memgraph::query::procedure::PyInitMgpModule);
Py_InitializeEx(0 /* = initsigs */);
PyEval_InitThreads();
Py_BEGIN_ALLOW_THREADS;
// Add our Python modules to sys.path
try {
auto exe_path = memgraph::utils::GetExecutablePath();
auto py_support_dir = exe_path.parent_path() / "python_support";
if (std::filesystem::is_directory(py_support_dir)) {
auto gil = memgraph::py::EnsureGIL();
auto maybe_exc = memgraph::py::AppendToSysPath(py_support_dir.c_str());
if (maybe_exc) {
spdlog::error(memgraph::utils::MessageWithLink("Unable to load support for embedded Python: {}.", *maybe_exc,
"https://memgr.ph/python"));
} else {
// Change how we load dynamic libraries on Python by using RTLD_NOW and
// RTLD_DEEPBIND flags. This solves an issue with using the wrong version of
// libstd.
auto gil = memgraph::py::EnsureGIL();
// NOLINTNEXTLINE(hicpp-signed-bitwise)
auto *flag = PyLong_FromLong(RTLD_NOW | RTLD_DEEPBIND);
auto *setdl = PySys_GetObject("setdlopenflags");
MG_ASSERT(setdl);
auto *arg = PyTuple_New(1);
MG_ASSERT(arg);
MG_ASSERT(PyTuple_SetItem(arg, 0, flag) == 0);
PyObject_CallObject(setdl, arg);
Py_DECREF(flag);
Py_DECREF(setdl);
Py_DECREF(arg);
}
} else {
spdlog::error(
memgraph::utils::MessageWithLink("Unable to load support for embedded Python: missing directory {}.",
py_support_dir, "https://memgr.ph/python"));
}
} catch (const std::filesystem::filesystem_error &e) {
spdlog::error(memgraph::utils::MessageWithLink("Unable to load support for embedded Python: {}.", e.what(),
"https://memgr.ph/python"));
}
// Initialize the communication library.
memgraph::communication::SSLInit sslInit;
// Initialize the requests library.
memgraph::requests::Init();
// Start memory warning logger.
memgraph::utils::Scheduler mem_log_scheduler;
if (FLAGS_memory_warning_threshold > 0) {
auto free_ram = memgraph::utils::sysinfo::AvailableMemory();
if (free_ram) {
mem_log_scheduler.Run("Memory warning", std::chrono::seconds(3), [] {
auto free_ram = memgraph::utils::sysinfo::AvailableMemory();
if (free_ram && *free_ram / 1024 < FLAGS_memory_warning_threshold)
spdlog::warn(memgraph::utils::MessageWithLink("Running out of available RAM, only {} MB left.",
*free_ram / 1024, "https://memgr.ph/ram"));
});
} else {
// Kernel version for the `MemAvailable` value is from: man procfs
spdlog::warn(
"You have an older kernel version (<3.14) or the /proc "
"filesystem isn't available so remaining memory warnings "
"won't be available.");
}
}
std::cout << "You are running Memgraph v" << gflags::VersionString() << std::endl;
std::cout << "To get started with Memgraph, visit https://memgr.ph/start" << std::endl;
@ -1165,66 +625,35 @@ int main(int argc, char **argv) {
// Example: When the main storage is destructed it makes a snapshot. When
// audit logging is destructed it syncs all pending data to disk and that can
// fail. That is why it must be destructed *after* the main database storage
// to minimise the impact of their failure on the main storage.
// to minimize the impact of their failure on the main storage.
// Begin enterprise features initialization
memgraph::io::local_transport::LocalSystem ls;
auto unique_local_addr_query = memgraph::coordinator::Address::UniqueLocalAddress();
auto io = ls.Register(unique_local_addr_query);
// Auth
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> auth{data_directory /
"auth"};
memgraph::machine_manager::MachineConfig config{
.coordinator_addresses = std::vector<memgraph::io::Address>{unique_local_addr_query},
.is_storage = true,
.is_coordinator = true,
.listen_ip = unique_local_addr_query.last_known_ip,
.listen_port = unique_local_addr_query.last_known_port,
};
#ifdef MG_ENTERPRISE
// Audit log
memgraph::audit::Log audit_log{data_directory / "audit", FLAGS_audit_buffer_size,
FLAGS_audit_buffer_flush_interval_ms};
// Start the log if enabled.
if (FLAGS_audit_enabled) {
audit_log.Start();
}
// Setup SIGUSR2 to be used for reopening audit log files, when e.g. logrotate
// rotates our audit logs.
MG_ASSERT(memgraph::utils::SignalHandler::RegisterHandler(memgraph::utils::Signal::User2,
[&audit_log]() { audit_log.ReopenLog(); }),
"Unable to register SIGUSR2 handler!");
memgraph::coordinator::ShardMap sm;
auto prop_map = sm.AllocatePropertyIds(std::vector<std::string>{"property"});
auto edge_type_map = sm.AllocateEdgeTypeIds(std::vector<std::string>{"edge_type"});
std::vector<memgraph::storage::v3::SchemaProperty> schema{
{prop_map.at("property"), memgraph::common::SchemaType::INT}};
sm.InitializeNewLabel("label", schema, 1, sm.shard_map_version);
// End enterprise features initialization
#endif
memgraph::coordinator::Coordinator coordinator{sm};
// Main storage and execution engines initialization
memgraph::storage::Config db_config{
.gc = {.type = memgraph::storage::Config::Gc::Type::PERIODIC,
.interval = std::chrono::seconds(FLAGS_storage_gc_cycle_sec)},
.items = {.properties_on_edges = FLAGS_storage_properties_on_edges},
.durability = {.storage_directory = FLAGS_data_directory,
.recover_on_startup = FLAGS_storage_recover_on_startup,
.snapshot_retention_count = FLAGS_storage_snapshot_retention_count,
.wal_file_size_kibibytes = FLAGS_storage_wal_file_size_kib,
.wal_file_flush_every_n_tx = FLAGS_storage_wal_file_flush_every_n_tx,
.snapshot_on_exit = FLAGS_storage_snapshot_on_exit,
.restore_replicas_on_startup = FLAGS_storage_restore_replicas_on_startup},
.transaction = {.isolation_level = ParseIsolationLevel()}};
if (FLAGS_storage_snapshot_interval_sec == 0) {
if (FLAGS_storage_wal_enabled) {
LOG_FATAL(
"In order to use write-ahead-logging you must enable "
"periodic snapshots by setting the snapshot interval to a "
"value larger than 0!");
db_config.durability.snapshot_wal_mode = memgraph::storage::Config::Durability::SnapshotWalMode::DISABLED;
}
} else {
if (FLAGS_storage_wal_enabled) {
db_config.durability.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT_WITH_WAL;
} else {
db_config.durability.snapshot_wal_mode =
memgraph::storage::Config::Durability::SnapshotWalMode::PERIODIC_SNAPSHOT;
}
db_config.durability.snapshot_interval = std::chrono::seconds(FLAGS_storage_snapshot_interval_sec);
}
memgraph::storage::Storage db(db_config);
memgraph::machine_manager::MachineManager<memgraph::io::local_transport::LocalTransport> mm{io, config, coordinator,
sm};
std::jthread mm_thread([&mm] { mm.Run(); });
memgraph::query::InterpreterContext interpreter_context{
&db,
memgraph::query::v2::InterpreterContext interpreter_context{
(memgraph::storage::v3::Shard *)(nullptr),
{.query = {.allow_load_csv = FLAGS_allow_load_csv},
.execution_timeout_sec = FLAGS_query_execution_timeout_sec,
.replication_replica_check_frequency = std::chrono::seconds(FLAGS_replication_replica_check_frequency_sec),
@ -1232,32 +661,14 @@ int main(int argc, char **argv) {
.default_pulsar_service_url = FLAGS_pulsar_service_url,
.stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries,
.stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)},
FLAGS_data_directory};
#ifdef MG_ENTERPRISE
SessionData session_data{&db, &interpreter_context, &auth, &audit_log};
#else
SessionData session_data{&db, &interpreter_context, &auth};
#endif
FLAGS_data_directory,
std::move(io),
mm.CoordinatorAddress()};
memgraph::query::procedure::gModuleRegistry.SetModulesDirectory(query_modules_directories, FLAGS_data_directory);
memgraph::query::procedure::gModuleRegistry.UnloadAndLoadModulesFromDirectories();
SessionData session_data{sm, &interpreter_context};
AuthQueryHandler auth_handler(&auth, FLAGS_auth_user_or_role_name_regex);
AuthChecker auth_checker{&auth};
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
auto storage_accessor = interpreter_context.db->Access();
auto dba = memgraph::query::DbAccessor{&storage_accessor};
interpreter_context.trigger_store.RestoreTriggers(
&interpreter_context.ast_cache, &dba, interpreter_context.config.query, interpreter_context.auth_checker);
}
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.
interpreter_context.streams.RestoreStreams();
interpreter_context.auth = nullptr;
interpreter_context.auth_checker = nullptr;
ServerContext context;
std::string service_name = "Bolt";
@ -1272,61 +683,27 @@ int main(int argc, char **argv) {
auto server_endpoint = memgraph::communication::v2::ServerEndpoint{
boost::asio::ip::address::from_string(FLAGS_bolt_address), static_cast<uint16_t>(FLAGS_bolt_port)};
ServerT server(server_endpoint, &session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
ServerT server(server_endpoint, session_data, &context, FLAGS_bolt_session_inactivity_timeout, service_name,
FLAGS_bolt_num_workers);
// Setup telemetry
std::optional<memgraph::telemetry::Telemetry> telemetry;
if (FLAGS_telemetry_enabled) {
telemetry.emplace("https://telemetry.memgraph.com/88b5e7e8-746a-11e8-9f85-538a9e9690cc/",
data_directory / "telemetry", std::chrono::minutes(10));
telemetry->AddCollector("storage", [&db]() -> nlohmann::json {
auto info = db.GetInfo();
return {{"vertices", info.vertex_count}, {"edges", info.edge_count}};
});
telemetry->AddCollector("event_counters", []() -> nlohmann::json {
nlohmann::json ret;
for (size_t i = 0; i < EventCounter::End(); ++i) {
ret[EventCounter::GetName(i)] = EventCounter::global_counters[i].load(std::memory_order_relaxed);
}
return ret;
});
telemetry->AddCollector("query_module_counters", []() -> nlohmann::json {
return memgraph::query::plan::CallProcedure::GetAndResetCounters();
});
}
memgraph::communication::websocket::SafeAuth websocket_auth{&auth};
memgraph::communication::websocket::Server websocket_server{
{FLAGS_monitoring_address, static_cast<uint16_t>(FLAGS_monitoring_port)}, &context, websocket_auth};
AddLoggerSink(websocket_server.GetLoggingSink());
// Handler for regular termination signals
auto shutdown = [&websocket_server, &server, &interpreter_context] {
auto shutdown = [&server, &interpreter_context] {
// Server needs to be shutdown first and then the database. This prevents
// a race condition when a transaction is accepted during server shutdown.
server.Shutdown();
// After the server is notified to stop accepting and processing
// connections we tell the execution engine to stop processing all pending
// queries.
memgraph::query::Shutdown(&interpreter_context);
websocket_server.Shutdown();
memgraph::query::v2::Shutdown(&interpreter_context);
};
InitSignalHandlers(shutdown);
MG_ASSERT(server.Start(), "Couldn't start the Bolt server!");
websocket_server.Start();
server.AwaitShutdown();
websocket_server.AwaitShutdown();
memgraph::query::procedure::gModuleRegistry.UnloadAllModules();
Py_END_ALLOW_THREADS;
// Shutdown Python
Py_Finalize();
PyMem_RawFree(program_name);
memgraph::utils::total_memory_tracker.LogPeakMemoryUsage();
return 0;

View File

@ -9,7 +9,6 @@ set(mg_query_v2_sources
${lcp_query_v2_cpp_files}
common.cpp
cypher_query_interpreter.cpp
dump.cpp
frontend/semantic/required_privileges.cpp
frontend/stripped.cpp
interpret/awesome_memgraph_functions.cpp
@ -23,16 +22,7 @@ set(mg_query_v2_sources
plan/rewrite/index_lookup.cpp
plan/rule_based_planner.cpp
plan/variable_start_planner.cpp
# procedure/mg_procedure_impl.cpp
# procedure/mg_procedure_helpers.cpp
# procedure/module.cpp
# procedure/py_module.cpp
serialization/property_value.cpp
# stream/streams.cpp
# stream/sources.cpp
# stream/common.cpp
# trigger.cpp
# trigger_context.cpp
bindings/typed_value.cpp
accessors.cpp)

View File

@ -11,12 +11,13 @@
#include "query/v2/accessors.hpp"
#include "query/v2/requests.hpp"
#include "storage/v3/id_types.hpp"
namespace memgraph::query::v2::accessors {
// Constructs the accessor from an edge message and its property list; both
// arguments are taken by value and moved into the corresponding members.
EdgeAccessor::EdgeAccessor(Edge edge, std::vector<std::pair<PropertyId, Value>> props)
: edge(std::move(edge)), properties(std::move(props)) {}
uint64_t EdgeAccessor::EdgeType() const { return edge.type.id; }
EdgeTypeId EdgeAccessor::EdgeType() const { return EdgeTypeId::FromUint(edge.type.id); }
std::vector<std::pair<PropertyId, Value>> EdgeAccessor::Properties() const {
return properties;

View File

@ -30,6 +30,7 @@ using Edge = memgraph::msgs::Edge;
using Vertex = memgraph::msgs::Vertex;
using Label = memgraph::msgs::Label;
using PropertyId = memgraph::msgs::PropertyId;
using EdgeTypeId = memgraph::msgs::EdgeTypeId;
class VertexAccessor;
@ -37,7 +38,7 @@ class EdgeAccessor final {
public:
EdgeAccessor(Edge edge, std::vector<std::pair<PropertyId, Value>> props);
uint64_t EdgeType() const;
EdgeTypeId EdgeType() const;
std::vector<std::pair<PropertyId, Value>> Properties() const;

View File

@ -100,41 +100,6 @@ template <typename TRecordAccessor>
concept RecordAccessor =
AccessorWithSetProperty<TRecordAccessor> || AccessorWithSetPropertyAndValidate<TRecordAccessor>;
/// Converts a storage schema violation into a SchemaViolationException with a
/// human readable message.
///
/// Every enumerated validation status throws; the function only returns
/// normally if `schema_violation.status` holds a value outside the enumerators
/// handled below.
///
/// @param schema_violation violation reported by the storage layer
/// @param dba accessor used to translate label ids into label names
/// @throw SchemaViolationException for each handled validation status
inline void HandleSchemaViolation(const storage::v3::SchemaViolation &schema_violation, const DbAccessor &dba) {
  using ValidationStatus = storage::v3::SchemaViolation::ValidationStatus;
  switch (schema_violation.status) {
    case ValidationStatus::VERTEX_HAS_NO_PRIMARY_PROPERTY: {
      throw SchemaViolationException(
          fmt::format("Primary key {} not defined on label :{}",
                      storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type),
                      dba.LabelToName(schema_violation.label)));
    }
    case ValidationStatus::NO_SCHEMA_DEFINED_FOR_LABEL: {
      throw SchemaViolationException(
          fmt::format("Label :{} is not a primary label", dba.LabelToName(schema_violation.label)));
    }
    case ValidationStatus::VERTEX_PROPERTY_WRONG_TYPE: {
      throw SchemaViolationException(
          fmt::format("Wrong type of property {} in schema :{}, should be of type {}",
                      *schema_violation.violated_property_value, dba.LabelToName(schema_violation.label),
                      storage::v3::SchemaTypeToString(schema_violation.violated_schema_property->type)));
    }
    case ValidationStatus::VERTEX_UPDATE_PRIMARY_KEY: {
      throw SchemaViolationException(fmt::format("Updating of primary key {} on schema :{} not supported",
                                                 *schema_violation.violated_property_value,
                                                 dba.LabelToName(schema_violation.label)));
    }
    case ValidationStatus::VERTEX_UPDATE_PRIMARY_LABEL: {
      // The format string previously had no placeholder, so the label never
      // reached the message. Only the label is reported: this violation is
      // about a label, and the property value is not dereferenced because
      // nothing here shows it is populated for label violations.
      throw SchemaViolationException(fmt::format("Adding primary label as secondary or removing primary label: {}",
                                                 dba.LabelToName(schema_violation.label)));
    }
    case ValidationStatus::VERTEX_SECONDARY_LABEL_IS_PRIMARY: {
      throw SchemaViolationException(fmt::format("Cannot create vertex where primary label is secondary:{}",
                                                 dba.LabelToName(schema_violation.label)));
    }
  }
}
inline void HandleErrorOnPropertyUpdate(const storage::v3::Error error) {
switch (error) {
case storage::v3::Error::SERIALIZATION_ERROR:
@ -149,35 +114,5 @@ inline void HandleErrorOnPropertyUpdate(const storage::v3::Error error) {
}
}
/// Writes `value` under `key` on `record` and returns the value carried by the
/// setter's result (the previous property value, going by the setter naming).
///
/// Vertex accessors are updated through the validating setter; any other
/// record accessor is updated without schema validation.
///
/// @throw QueryRuntimeException if value cannot be set as a property value
template <RecordAccessor T>
storage::v3::PropertyValue PropsSetChecked(T *record, const DbAccessor &dba, const storage::v3::PropertyId &key,
                                           const TypedValue &value) {
  try {
    if constexpr (!std::is_same_v<T, VertexAccessor>) {
      // No validation on edge properties
      const auto previous = record->SetProperty(key, storage::v3::TypedToPropertyValue(value));
      if (previous.HasError()) {
        HandleErrorOnPropertyUpdate(previous.GetError());
      }
      return std::move(*previous);
    } else {
      const auto previous = record->SetPropertyAndValidate(key, storage::v3::TypedToPropertyValue(value));
      if (previous.HasError()) {
        // The error is a variant: either a plain storage error or a schema violation.
        std::visit(utils::Overloaded{[](const storage::v3::Error error) { HandleErrorOnPropertyUpdate(error); },
                                     [&dba](const storage::v3::SchemaViolation &violation) {
                                       HandleSchemaViolation(violation, dba);
                                     }},
                   previous.GetError());
      }
      return std::move(*previous);
    }
  } catch (const expr::TypedValueException &) {
    throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
  }
}
int64_t QueryTimestamp();
} // namespace memgraph::query::v2

View File

@ -13,6 +13,7 @@
#include <type_traits>
#include "io/local_transport/local_transport.hpp"
#include "query/v2/bindings/symbol_table.hpp"
#include "query/v2/common.hpp"
#include "query/v2/metadata.hpp"
@ -42,21 +43,28 @@ struct EvaluationContext {
mutable std::unordered_map<std::string, int64_t> counters;
};
inline std::vector<storage::v3::PropertyId> NamesToProperties(const std::vector<std::string> &property_names,
DbAccessor *dba) {
inline std::vector<storage::v3::PropertyId> NamesToProperties(
const std::vector<std::string> &property_names, msgs::ShardRequestManagerInterface *shard_request_manager) {
std::vector<storage::v3::PropertyId> properties;
// TODO Fix by using reference
properties.reserve(property_names.size());
for (const auto &name : property_names) {
properties.push_back(dba->NameToProperty(name));
if (shard_request_manager != nullptr) {
for (const auto &name : property_names) {
properties.push_back(shard_request_manager->NameToProperty(name));
}
}
return properties;
}
inline std::vector<storage::v3::LabelId> NamesToLabels(const std::vector<std::string> &label_names, DbAccessor *dba) {
inline std::vector<storage::v3::LabelId> NamesToLabels(const std::vector<std::string> &label_names,
msgs::ShardRequestManagerInterface *shard_request_manager) {
std::vector<storage::v3::LabelId> labels;
labels.reserve(label_names.size());
for (const auto &name : label_names) {
labels.push_back(dba->NameToLabel(name));
// TODO Fix by using reference
if (shard_request_manager != nullptr) {
for (const auto &name : label_names) {
labels.push_back(shard_request_manager->LabelNameToLabelId(name));
}
}
return labels;
}
@ -73,7 +81,7 @@ struct ExecutionContext {
ExecutionStats execution_stats;
// TriggerContextCollector *trigger_context_collector{nullptr};
utils::AsyncTimer timer;
std::unique_ptr<msgs::ShardRequestManagerInterface> shard_request_manager{nullptr};
msgs::ShardRequestManagerInterface *shard_request_manager{nullptr};
};
static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!");

View File

@ -11,6 +11,7 @@
#include "query/v2/cypher_query_interpreter.hpp"
#include "query/v2/bindings/symbol_generator.hpp"
#include "query/v2/shard_request_manager.hpp"
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_HIDDEN_bool(query_cost_planner, true, "Use the cost-estimating query planner.");
@ -117,9 +118,9 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri
}
std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters,
DbAccessor *db_accessor,
msgs::ShardRequestManagerInterface *shard_manager,
const std::vector<Identifier *> &predefined_identifiers) {
auto vertex_counts = plan::MakeVertexCountCache(db_accessor);
auto vertex_counts = plan::MakeVertexCountCache(shard_manager);
auto symbol_table = expr::MakeSymbolTable(query, predefined_identifiers);
auto planning_context = plan::MakePlanningContext(&ast_storage, &symbol_table, query, &vertex_counts);
auto [root, cost] = plan::MakeLogicalPlan(&planning_context, parameters, FLAGS_query_cost_planner);
@ -129,7 +130,7 @@ std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery
std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query,
const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache,
DbAccessor *db_accessor,
msgs::ShardRequestManagerInterface *shard_manager,
const std::vector<Identifier *> &predefined_identifiers) {
std::optional<utils::SkipList<PlanCacheEntry>::Accessor> plan_cache_access;
if (plan_cache) {
@ -145,7 +146,7 @@ std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_stor
}
auto plan = std::make_shared<CachedPlan>(
MakeLogicalPlan(std::move(ast_storage), query, parameters, db_accessor, predefined_identifiers));
MakeLogicalPlan(std::move(ast_storage), query, parameters, shard_manager, predefined_identifiers));
if (plan_cache_access) {
plan_cache_access->insert({hash, plan});
}

View File

@ -132,7 +132,7 @@ class SingleNodeLogicalPlan final : public LogicalPlan {
};
std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters,
DbAccessor *db_accessor,
msgs::ShardRequestManagerInterface *shard_manager,
const std::vector<Identifier *> &predefined_identifiers);
/**
@ -145,7 +145,7 @@ std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery
*/
std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query,
const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache,
DbAccessor *db_accessor,
msgs::ShardRequestManagerInterface *shard_manager,
const std::vector<Identifier *> &predefined_identifiers = {});
} // namespace memgraph::query::v2

View File

@ -219,195 +219,7 @@ inline VertexAccessor EdgeAccessor::From() const { return *static_cast<VertexAcc
// An edge is a cycle when its destination vertex compares equal to its source vertex.
inline bool EdgeAccessor::IsCycle() const { return To() == From(); }
// Query-layer facade over a storage::v3::Shard::Accessor.
// Holds a non-owning pointer to the shard accessor and forwards every call to
// it, wrapping storage-level vertices/edges into the query-level
// VertexAccessor/EdgeAccessor types.
class DbAccessor final {
// Underlying shard accessor; not owned by this class.
storage::v3::Shard::Accessor *accessor_;
// Adapter that exposes a storage vertices iterable as a range of query-level VertexAccessor objects.
class VerticesIterable final {
storage::v3::VerticesIterable iterable_;
public:
// Iterator that wraps the storage iterator and converts on dereference.
class Iterator final {
storage::v3::VerticesIterable::Iterator it_;
public:
explicit Iterator(storage::v3::VerticesIterable::Iterator it) : it_(it) {}
// Builds a fresh VertexAccessor from the current storage element.
VertexAccessor operator*() const { return VertexAccessor(*it_); }
Iterator &operator++() {
++it_;
return *this;
}
bool operator==(const Iterator &other) const { return it_ == other.it_; }
bool operator!=(const Iterator &other) const { return !(other == *this); }
};
explicit VerticesIterable(storage::v3::VerticesIterable iterable) : iterable_(std::move(iterable)) {}
Iterator begin() { return Iterator(iterable_.begin()); }
Iterator end() { return Iterator(iterable_.end()); }
};
public:
explicit DbAccessor(storage::v3::Shard::Accessor *accessor) : accessor_(accessor) {}
// TODO(jbajic) Fix Remove Gid
// Placeholder overload: ignores its argument and always reports "not found".
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
std::optional<VertexAccessor> FindVertex(uint64_t /*unused*/) { return std::nullopt; }
// Looks a vertex up by primary key; returns std::nullopt when the shard accessor finds nothing.
std::optional<VertexAccessor> FindVertex(storage::v3::PrimaryKey &primary_key, storage::v3::View view) {
auto maybe_vertex = accessor_->FindVertex(primary_key, view);
if (maybe_vertex) return VertexAccessor(*maybe_vertex);
return std::nullopt;
}
// Vertices(...) overloads: each forwards its filter arguments (label, property, value or bounds)
// to the matching shard accessor overload. Note that `view` comes first here but last in the
// underlying call.
VerticesIterable Vertices(storage::v3::View view) { return VerticesIterable(accessor_->Vertices(view)); }
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label) {
return VerticesIterable(accessor_->Vertices(label, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property) {
return VerticesIterable(accessor_->Vertices(label, property, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) {
return VerticesIterable(accessor_->Vertices(label, property, value, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) {
return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view));
}
// Creates a vertex through the validating storage call; on failure the storage error is
// propagated in the result, otherwise the new vertex is wrapped in a VertexAccessor.
storage::v3::ResultSchema<VertexAccessor> InsertVertexAndValidate(
const storage::v3::LabelId primary_label, const std::vector<storage::v3::LabelId> &labels,
const std::vector<std::pair<storage::v3::PropertyId, storage::v3::PropertyValue>> &properties) {
auto maybe_vertex_acc = accessor_->CreateVertexAndValidate(primary_label, labels, properties);
if (maybe_vertex_acc.HasError()) {
return {std::move(maybe_vertex_acc.GetError())};
}
return VertexAccessor{maybe_vertex_acc.GetValue()};
}
// Creates an edge between the two vertices, identified by their ids in the NEW view.
// The edge Gid passed to storage is always the dummy value 0.
storage::v3::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to,
const storage::v3::EdgeTypeId &edge_type) {
static constexpr auto kDummyGid = storage::v3::Gid::FromUint(0);
auto maybe_edge = accessor_->CreateEdge(from->impl_.Id(storage::v3::View::NEW).GetValue(),
to->impl_.Id(storage::v3::View::NEW).GetValue(), edge_type, kDummyGid);
if (maybe_edge.HasError()) return storage::v3::Result<EdgeAccessor>(maybe_edge.GetError());
return EdgeAccessor(*maybe_edge);
}
// Deletes the edge; returns the storage error, an empty optional when storage returned no
// value, or the deleted edge wrapped in an EdgeAccessor.
storage::v3::Result<std::optional<EdgeAccessor>> RemoveEdge(EdgeAccessor *edge) {
auto res = accessor_->DeleteEdge(edge->impl_.FromVertex(), edge->impl_.ToVertex(), edge->impl_.Gid());
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<EdgeAccessor>{};
}
return std::make_optional<EdgeAccessor>(*value);
}
// Detach-deletes the vertex; on success returns the deleted vertex together with the edges
// storage reported as deleted, each converted to a query-level EdgeAccessor.
storage::v3::Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> DetachRemoveVertex(
VertexAccessor *vertex_accessor) {
using ReturnType = std::pair<VertexAccessor, std::vector<EdgeAccessor>>;
auto res = accessor_->DetachDeleteVertex(&vertex_accessor->impl_);
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<ReturnType>{};
}
const auto &[vertex, edges] = *value;
std::vector<EdgeAccessor> deleted_edges;
deleted_edges.reserve(edges.size());
std::transform(edges.begin(), edges.end(), std::back_inserter(deleted_edges),
[](const auto &deleted_edge) { return EdgeAccessor{deleted_edge}; });
return std::make_optional<ReturnType>(vertex, std::move(deleted_edges));
}
// Deletes the vertex via the non-detaching storage call; same result shape as RemoveEdge.
storage::v3::Result<std::optional<VertexAccessor>> RemoveVertex(VertexAccessor *vertex_accessor) {
auto res = accessor_->DeleteVertex(&vertex_accessor->impl_);
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<VertexAccessor>{};
}
return {std::make_optional<VertexAccessor>(*value)};
}
// TODO(jbajic) Query engine should have a map of labels, properties and edge
// types
// Name <-> id translation, forwarded to the shard accessor.
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
storage::v3::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); }
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
storage::v3::LabelId NameToLabel(const std::string_view name) { return accessor_->NameToLabel(name); }
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
storage::v3::EdgeTypeId NameToEdgeType(const std::string_view name) { return accessor_->NameToEdgeType(name); }
const std::string &PropertyToName(storage::v3::PropertyId prop) const { return accessor_->PropertyToName(prop); }
const std::string &LabelToName(storage::v3::LabelId label) const { return accessor_->LabelToName(label); }
const std::string &EdgeTypeToName(storage::v3::EdgeTypeId type) const { return accessor_->EdgeTypeToName(type); }
// Transaction control, forwarded to the shard accessor.
void AdvanceCommand() { accessor_->AdvanceCommand(); }
// Commits with a default-constructed coordinator::Hlc.
void Commit() { return accessor_->Commit(coordinator::Hlc{}); }
void Abort() { accessor_->Abort(); }
// Index existence checks.
bool LabelIndexExists(storage::v3::LabelId label) const { return accessor_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId prop) const {
return accessor_->LabelPropertyIndexExists(label, prop);
}
// VerticesCount(...) overloads return the storage layer's *approximate* vertex counts.
int64_t VerticesCount() const { return accessor_->ApproximateVertexCount(); }
int64_t VerticesCount(storage::v3::LabelId label) const { return accessor_->ApproximateVertexCount(label); }
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) const {
return accessor_->ApproximateVertexCount(label, property);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) const {
return accessor_->ApproximateVertexCount(label, property, value);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) const {
return accessor_->ApproximateVertexCount(label, property, lower, upper);
}
// Metadata listings, forwarded to the shard accessor.
storage::v3::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
const storage::v3::SchemaValidator &GetSchemaValidator() const { return accessor_->GetSchemaValidator(); }
storage::v3::SchemasInfo ListAllSchemas() const { return accessor_->ListAllSchemas(); }
};
} // namespace memgraph::query::v2

View File

@ -1,483 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/dump.hpp"
#include <iomanip>
#include <limits>
#include <map>
#include <optional>
#include <ostream>
#include <utility>
#include <vector>
#include <fmt/format.h>
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/stream.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/storage.hpp"
#include "utils/algorithm.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
namespace memgraph::query::v2 {
namespace {
// Property that is used to make a difference among vertices. It is added to
// property set of vertices to match edges and removed after the entire graph
// is built.
const char *kInternalPropertyId = "__mg_id__";
// Label that is attached to each vertex and is used for easier creation of
// index on internal property id.
const char *kInternalVertexLabel = "__mg_vertex__";
/// A helper function that escapes label, edge type and property names.
/// Quotes a label, edge type or property name for use in a Cypher statement:
/// the name is wrapped in backticks and every backtick inside it is doubled.
std::string EscapeName(const std::string_view value) {
  constexpr char kQuote = '`';
  std::string escaped;
  // Room for the name plus the two surrounding quotes; doubled backticks may grow it further.
  escaped.reserve(value.size() + 2);
  escaped.push_back(kQuote);
  for (const char ch : value) {
    // An embedded backtick is emitted twice, every other character once.
    if (ch == kQuote) escaped.push_back(kQuote);
    escaped.push_back(ch);
  }
  escaped.push_back(kQuote);
  return escaped;
}
/// Writes `value` to `*os` with max_digits10 significant digits.
void DumpPreciseDouble(std::ostream *os, double value) {
  // Format into a scratch stream so the precision and flags of the caller's
  // stream are neither used nor modified.
  std::ostringstream scratch;
  scratch.precision(std::numeric_limits<double>::max_digits10);
  scratch << value;
  *os << scratch.str();
}
namespace {

/// Writes the value as a Cypher `DATE("...")` call.
void DumpDate(std::ostream &os, const storage::v3::TemporalData &value) {
  utils::Date as_date(value.microseconds);
  os << "DATE(\"";
  os << as_date;
  os << "\")";
}

/// Writes the value as a Cypher `LOCALTIME("...")` call.
void DumpLocalTime(std::ostream &os, const storage::v3::TemporalData &value) {
  utils::LocalTime as_local_time(value.microseconds);
  os << "LOCALTIME(\"";
  os << as_local_time;
  os << "\")";
}

/// Writes the value as a Cypher `LOCALDATETIME("...")` call.
void DumpLocalDateTime(std::ostream &os, const storage::v3::TemporalData &value) {
  utils::LocalDateTime as_local_date_time(value.microseconds);
  os << "LOCALDATETIME(\"";
  os << as_local_date_time;
  os << "\")";
}

/// Writes the value as a Cypher `DURATION("...")` call.
void DumpDuration(std::ostream &os, const storage::v3::TemporalData &value) {
  utils::Duration as_duration(value.microseconds);
  os << "DURATION(\"";
  os << as_duration;
  os << "\")";
}

/// Dispatches on the temporal type tag to the matching dump helper.
void DumpTemporalData(std::ostream &os, const storage::v3::TemporalData &value) {
  using storage::v3::TemporalType;
  switch (value.type) {
    case TemporalType::Date:
      DumpDate(os, value);
      break;
    case TemporalType::LocalTime:
      DumpLocalTime(os, value);
      break;
    case TemporalType::LocalDateTime:
      DumpLocalDateTime(os, value);
      break;
    case TemporalType::Duration:
      DumpDuration(os, value);
      break;
  }
}

}  // namespace
/// Writes `value` to `*os` as a Cypher literal. Lists and maps are written
/// recursively, element by element, separated by ", ".
void DumpPropertyValue(std::ostream *os, const storage::v3::PropertyValue &value) {
  using Type = storage::v3::PropertyValue::Type;
  auto &out = *os;
  switch (value.type()) {
    case Type::Null:
      out << "Null";
      break;
    case Type::Bool:
      out << (value.ValueBool() ? "true" : "false");
      break;
    case Type::String:
      out << utils::Escape(value.ValueString());
      break;
    case Type::Int:
      out << value.ValueInt();
      break;
    case Type::Double:
      DumpPreciseDouble(os, value.ValueDouble());
      break;
    case Type::List: {
      out << "[";
      const auto &elements = value.ValueList();
      utils::PrintIterable(out, elements, ", ",
                           [](auto &stream, const auto &element) { DumpPropertyValue(&stream, element); });
      out << "]";
      break;
    }
    case Type::Map: {
      out << "{";
      const auto &entries = value.ValueMap();
      utils::PrintIterable(out, entries, ", ", [](auto &stream, const auto &entry) {
        // Map keys are escaped the same way as names.
        stream << EscapeName(entry.first) << ": ";
        DumpPropertyValue(&stream, entry.second);
      });
      out << "}";
      break;
    }
    case Type::TemporalData:
      DumpTemporalData(out, value.ValueTemporalData());
      break;
  }
}
/// Writes a property map as a Cypher map literal `{...}`.
///
/// When `property_id` is set, it is written first under the internal id
/// property name, ahead of the entries in `store`. Property ids are
/// translated to names through `dba`.
void DumpProperties(std::ostream *os, query::v2::DbAccessor *dba,
                    const std::map<storage::v3::PropertyId, storage::v3::PropertyValue> &store,
                    std::optional<int64_t> property_id = std::nullopt) {
  auto &out = *os;
  out << "{";
  if (property_id.has_value()) {
    out << kInternalPropertyId << ": " << *property_id;
    // A separator is needed only when regular properties follow.
    if (!store.empty()) out << ", ";
  }
  utils::PrintIterable(out, store, ", ", [&dba](auto &stream, const auto &entry) {
    stream << EscapeName(dba->PropertyToName(entry.first)) << ": ";
    DumpPropertyValue(&stream, entry.second);
  });
  out << "}";
}
void DumpVertex(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::VertexAccessor &vertex) {
*os << "CREATE (";
*os << ":" << kInternalVertexLabel;
auto maybe_labels = vertex.Labels(storage::v3::View::OLD);
if (maybe_labels.HasError()) {
switch (maybe_labels.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get labels from a deleted node.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw query::v2::QueryRuntimeException("Unexpected error when getting labels.");
}
}
for (const auto &label : *maybe_labels) {
*os << ":" << EscapeName(dba->LabelToName(label));
}
*os << " ";
auto maybe_props = vertex.Properties(storage::v3::View::OLD);
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from a node that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw query::v2::QueryRuntimeException("Unexpected error when getting properties.");
}
}
DumpProperties(os, dba, *maybe_props, vertex.CypherId());
*os << ");";
}
/// Writes one `MATCH ... CREATE (u)-[:TYPE {props}]->(v);` Cypher statement
/// describing `edge` to `os`.
///
/// The endpoints are matched through the internal vertex label and the
/// internal id property, using the Cypher ids of the edge's endpoints. The
/// property map is emitted only when the edge has at least one property.
///
/// @throw query::v2::QueryRuntimeException if the edge properties cannot be
///        read from storage.
void DumpEdge(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::EdgeAccessor &edge) {
  auto &out = *os;

  // Locate both endpoints by their internal ids.
  out << "MATCH "
      << "(u:" << kInternalVertexLabel << "), "
      << "(v:" << kInternalVertexLabel << ")"
      << " WHERE "
      << "u." << kInternalPropertyId << " = " << edge.From().CypherId() << " AND "
      << "v." << kInternalPropertyId << " = " << edge.To().CypherId() << " ";

  // Recreate the relationship itself.
  out << "CREATE (u)-["
      << ":" << EscapeName(dba->EdgeTypeToName(edge.EdgeType()));

  auto props_result = edge.Properties(storage::v3::View::OLD);
  if (props_result.HasError()) {
    switch (props_result.GetError()) {
      case storage::v3::Error::DELETED_OBJECT:
        throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object.");
      case storage::v3::Error::NONEXISTENT_OBJECT:
        throw query::v2::QueryRuntimeException("Trying to get properties from an edge that doesn't exist.");
      case storage::v3::Error::SERIALIZATION_ERROR:
      case storage::v3::Error::VERTEX_HAS_EDGES:
      case storage::v3::Error::PROPERTIES_DISABLED:
        throw query::v2::QueryRuntimeException("Unexpected error when getting properties.");
    }
  }
  const bool has_properties = props_result->size() > 0;
  if (has_properties) {
    out << " ";
    DumpProperties(os, dba, *props_result);
  }
  out << "]->(v);";
}
/// Writes the `CREATE INDEX ON :<label>;` statement for a label index.
void DumpLabelIndex(std::ostream *os, query::v2::DbAccessor *dba, const storage::v3::LabelId label) {
  auto &out = *os;
  out << "CREATE INDEX ON :";
  out << EscapeName(dba->LabelToName(label));
  out << ";";
}
/// Writes the `CREATE INDEX ON :<label>(<property>);` statement for a
/// label+property index.
void DumpLabelPropertyIndex(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
                            storage::v3::PropertyId property) {
  auto &out = *os;
  out << "CREATE INDEX ON :";
  out << EscapeName(dba->LabelToName(label));
  out << "(";
  out << EscapeName(dba->PropertyToName(property));
  out << ");";
}
/// Writes the `CREATE CONSTRAINT ON (u:<label>) ASSERT EXISTS (u.<property>);`
/// statement for an existence constraint.
void DumpExistenceConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
                             storage::v3::PropertyId property) {
  auto &out = *os;
  out << "CREATE CONSTRAINT ON (u:";
  out << EscapeName(dba->LabelToName(label));
  out << ") ASSERT EXISTS (u.";
  out << EscapeName(dba->PropertyToName(property));
  out << ");";
}
/// Writes the `CREATE CONSTRAINT ON (u:<label>) ASSERT u.<p1>, u.<p2> IS UNIQUE;`
/// statement for a uniqueness constraint over `properties`.
void DumpUniqueConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
                          const std::set<storage::v3::PropertyId> &properties) {
  auto &out = *os;
  out << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT ";

  // Each property is rendered as `u.<name>`; entries are separated by ", ".
  const auto print_property = [dba](auto &stream, const auto &property_id) {
    stream << "u." << EscapeName(dba->PropertyToName(property_id));
  };
  utils::PrintIterable(out, properties, ", ", print_property);

  out << " IS UNIQUE;";
}
} // namespace
// Builds the dump plan for the database behind `dba`.
//
// The vertex iterable is obtained once, from the OLD view, and shared by the
// chunks that need it. `pull_chunks_` lists the stages in the order in which
// Pull() runs them; the factory calls below only create the callables, nothing
// is streamed during construction.
PullPlanDump::PullPlanDump(DbAccessor *dba)
: dba_(dba),
vertices_iterable_(dba->Vertices(storage::v3::View::OLD)),
pull_chunks_{// Dump all label indices
CreateLabelIndicesPullChunk(),
// Dump all label property indices
CreateLabelPropertyIndicesPullChunk(),
// Create internal index for faster edge creation
CreateInternalIndexPullChunk(),
// Dump all vertices
CreateVertexPullChunk(),
// Dump all edges
CreateEdgePullChunk(),
// Drop the internal index
CreateDropInternalIndexPullChunk(),
// Internal index cleanup
CreateInternalIndexCleanupPullChunk()} {}
/// Streams up to `n` dump rows (all remaining rows when `n` is empty).
///
/// Chunks are executed in order starting from the one that was active when
/// the previous call returned. A chunk reports the number of rows it streamed
/// once it is finished, or std::nullopt when it stopped early because the row
/// budget ran out.
///
/// @return true once every chunk has finished, false otherwise.
bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
  const auto has_budget = [&n] { return !n.has_value() || *n > 0; };

  for (; current_chunk_index_ < pull_chunks_.size() && has_budget(); ++current_chunk_index_) {
    const auto streamed = pull_chunks_[current_chunk_index_](stream, n);
    if (!streamed.has_value()) {
      // The budget was exhausted inside this chunk; resume it on the next
      // call (the index is intentionally not advanced).
      break;
    }
    if (n.has_value()) {
      // The chunk is done: charge its rows against the budget that is left
      // for the chunks that follow.
      *n -= *streamed;
    }
  }

  return current_chunk_index_ == pull_chunks_.size();
}
/// Creates the chunk that streams one `CREATE INDEX ON :<label>;` row per
/// label index. The position inside the index list is kept in the closure so
/// the chunk can resume across pulls.
PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
  return [this, next_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
    // The index listing is fetched on first use and cached in indices_info_.
    if (!indices_info_) {
      indices_info_.emplace(dba_->ListAllIndices());
    }
    const auto &label_indices = indices_info_->label;

    size_t streamed = 0;
    for (; next_index < label_indices.size() && (!n || streamed < *n); ++next_index, ++streamed) {
      std::ostringstream statement;
      DumpLabelIndex(&statement, dba_, label_indices[next_index]);
      stream->Result({TypedValue(statement.str())});
    }

    // std::nullopt tells Pull() that this chunk still has rows left.
    if (next_index != label_indices.size()) {
      return std::nullopt;
    }
    return streamed;
  };
}
/// Creates the chunk that streams one `CREATE INDEX ON :<label>(<property>);`
/// row per label+property index. The position inside the index list is kept
/// in the closure so the chunk can resume across pulls.
PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
  return [this, next_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
    // The index listing is fetched on first use and cached in indices_info_.
    if (!indices_info_) {
      indices_info_.emplace(dba_->ListAllIndices());
    }
    const auto &label_property_indices = indices_info_->label_property;

    size_t streamed = 0;
    for (; next_index < label_property_indices.size() && (!n || streamed < *n); ++next_index, ++streamed) {
      const auto &entry = label_property_indices[next_index];
      std::ostringstream statement;
      DumpLabelPropertyIndex(&statement, dba_, entry.first, entry.second);
      stream->Result({TypedValue(statement.str())});
    }

    // std::nullopt tells Pull() that this chunk still has rows left.
    if (next_index != label_property_indices.size()) {
      return std::nullopt;
    }
    return streamed;
  };
}
/// Creates the chunk that emits the helper index on the internal vertex label
/// and internal id property. The row is produced only when there is at least
/// one vertex; in that case internal_index_created_ is set so the later
/// drop/cleanup chunks know they have work to do.
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
  return [this](AnyStream *stream, std::optional<int>) mutable -> std::optional<size_t> {
    const bool has_vertices = vertices_iterable_.begin() != vertices_iterable_.end();
    if (!has_vertices) {
      return 0;
    }

    std::ostringstream statement;
    statement << "CREATE INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
    stream->Result({TypedValue(statement.str())});
    internal_index_created_ = true;
    return 1;
  };
}
/// Creates the chunk that streams one `CREATE (...)` row per vertex. The
/// vertex iterator is stored in the closure so the chunk can resume where the
/// previous pull stopped.
PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
  return [this, vertex_cursor = std::optional<VertexAccessorIterableIterator>{}](
             AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
    // begin() is deliberately called lazily, on the first pull: if several
    // iterators are obtained up front, advancing one of them leaves the
    // others in an undefined state.
    if (!vertex_cursor.has_value()) {
      vertex_cursor.emplace(vertices_iterable_.begin());
    }
    auto &cursor = *vertex_cursor;

    size_t streamed = 0;
    for (; cursor != vertices_iterable_.end() && (!n || streamed < *n); ++streamed, ++cursor) {
      std::ostringstream statement;
      DumpVertex(&statement, dba_, *cursor);
      stream->Result({TypedValue(statement.str())});
    }

    // std::nullopt tells Pull() that there are vertices left to stream.
    if (cursor != vertices_iterable_.end()) {
      return std::nullopt;
    }
    return streamed;
  };
}
// Creates the chunk that streams one `MATCH ... CREATE ...` row per edge.
//
// Edges are visited by walking every vertex and, for each vertex, its
// outgoing edges (OLD view). Three pieces of state live in the closure so a
// pull that runs out of budget can be resumed exactly where it stopped:
//   - the current vertex iterator,
//   - the out-edge iterable of the vertex being processed,
//   - the current position inside that iterable.
// The chunk returns the number of rows streamed by this call once all
// vertices are exhausted, and std::nullopt otherwise.
PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
return [this, maybe_current_vertex_iter = std::optional<VertexAccessorIterableIterator>{},
// we need to save the iterable which contains list of accessor so
// our saved iterator is valid in the next run
maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
// states.
if (!maybe_current_vertex_iter) {
maybe_current_vertex_iter.emplace(vertices_iterable_.begin());
}
auto &current_vertex_iter{*maybe_current_vertex_iter};
size_t local_counter = 0U;
// Outer loop: one iteration per vertex, while the row budget allows.
for (; current_vertex_iter != vertices_iterable_.end() && (!n || local_counter < *n); ++current_vertex_iter) {
const auto &vertex = *current_vertex_iter;
// If we have a saved iterable from a previous pull
// we need to use the same iterable
if (!maybe_edge_iterable) {
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(vertex.OutEdges(storage::v3::View::OLD));
}
auto &maybe_edges = *maybe_edge_iterable;
MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!");
// Resume from the saved edge position if the previous pull stopped in the
// middle of this vertex; otherwise start from the first out-edge.
auto current_edge_iter = maybe_current_edge_iter ? *maybe_current_edge_iter : maybe_edges->begin();
for (; current_edge_iter != maybe_edges->end() && (!n || local_counter < *n); ++current_edge_iter) {
std::ostringstream os;
DumpEdge(&os, dba_, *current_edge_iter);
stream->Result({TypedValue(os.str())});
++local_counter;
}
// Budget ran out before this vertex's edges were exhausted: remember the
// position and report "not finished" without advancing the vertex.
if (current_edge_iter != maybe_edges->end()) {
maybe_current_edge_iter.emplace(current_edge_iter);
return std::nullopt;
}
// All edges of this vertex were streamed: drop the per-vertex state so the
// next vertex starts with a fresh iterable.
maybe_current_edge_iter = std::nullopt;
maybe_edge_iterable = nullptr;
}
if (current_vertex_iter == vertices_iterable_.end()) {
return local_counter;
}
return std::nullopt;
};
}
/// Creates the chunk that emits the `DROP INDEX` row for the internal helper
/// index. Nothing is streamed when the helper index was never created.
PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() {
  return [this](AnyStream *stream, std::optional<int>) {
    if (!internal_index_created_) {
      return 0;
    }
    std::ostringstream statement;
    statement << "DROP INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
    stream->Result({TypedValue(statement.str())});
    return 1;
  };
}
/// Creates the chunk that emits the statement removing the internal label and
/// internal id property from every node. Nothing is streamed when the helper
/// index was never created.
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
  return [this](AnyStream *stream, std::optional<int>) {
    if (!internal_index_created_) {
      return 0;
    }
    std::ostringstream statement;
    statement << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u." << kInternalPropertyId << ";";
    stream->Result({TypedValue(statement.str())});
    return 1;
  };
}
/// Streams the complete dump of the database behind `dba` to `stream` in a
/// single pull (no row limit).
void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream) {
  PullPlanDump dump_plan(dba);
  dump_plan.Pull(stream, std::nullopt);
}
} // namespace memgraph::query::v2

View File

@ -1,63 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <ostream>
#include "query/v2/db_accessor.hpp"
#include "query/v2/stream.hpp"
#include "storage/v3/storage.hpp"
namespace memgraph::query::v2 {
void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream);
/// Resumable plan that streams a database dump as Cypher statements, one
/// statement per result row.
struct PullPlanDump {
explicit PullPlanDump(query::v2::DbAccessor *dba);
/// Pull the dump results lazily
/// @param stream destination of the produced rows
/// @param n maximum number of rows to stream; no limit when empty
/// @return true if all results were returned, false otherwise
bool Pull(AnyStream *stream, std::optional<int> n);
private:
// Accessor used to list indices and translate ids to names; not owned.
query::v2::DbAccessor *dba_ = nullptr;
// Index listing, filled on first use by the index chunks.
std::optional<storage::v3::IndicesInfo> indices_info_ = std::nullopt;
using VertexAccessorIterable = decltype(std::declval<query::v2::DbAccessor>().Vertices(storage::v3::View::OLD));
using VertexAccessorIterableIterator = decltype(std::declval<VertexAccessorIterable>().begin());
using EdgeAccessorIterable = decltype(std::declval<VertexAccessor>().OutEdges(storage::v3::View::OLD));
using EdgeAccessorIterableIterator = decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
// Vertices of the database (OLD view), shared by the vertex and edge chunks.
VertexAccessorIterable vertices_iterable_;
// Set when the helper-index row was emitted; gates the drop/cleanup rows.
bool internal_index_created_ = false;
// Index into pull_chunks_ of the chunk that Pull() runs next.
size_t current_chunk_index_ = 0;
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream, std::optional<int> n)>;
// We define every part of the dump query in a self-contained function.
// Each function is responsible for keeping track of its execution status.
// If a function finished its execution, it should return the number of
// results it streamed so we know how many rows should be pulled from the
// next function; otherwise std::nullopt is returned.
std::vector<PullChunk> pull_chunks_;
PullChunk CreateLabelIndicesPullChunk();
PullChunk CreateLabelPropertyIndicesPullChunk();
PullChunk CreateInternalIndexPullChunk();
PullChunk CreateVertexPullChunk();
PullChunk CreateEdgePullChunk();
PullChunk CreateDropInternalIndexPullChunk();
PullChunk CreateInternalIndexCleanupPullChunk();
};
} // namespace memgraph::query::v2

View File

@ -18,6 +18,7 @@
#include <vector>
#include "query/v2/bindings/ast_visitor.hpp"
#include "common/types.hpp"
#include "query/v2/bindings/symbol.hpp"
#include "query/v2/interpret/awesome_memgraph_functions.hpp"
#include "query/v2/bindings/typed_value.hpp"

View File

@ -24,9 +24,6 @@
#include "query/v2/conversions.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/procedure/cypher_types.hpp"
//#include "query/v2/procedure/mg_procedure_impl.hpp"
//#include "query/v2/procedure/module.hpp"
#include "storage/v3/conversions.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
@ -415,48 +412,6 @@ TypedValue StartNode(const TypedValue *args, int64_t nargs, const FunctionContex
return TypedValue(args[0].ValueEdge().From(), ctx.memory);
}
namespace {
/// Returns the degree stored in `maybe_degree`, translating storage errors
/// into QueryRuntimeException.
size_t UnwrapDegreeResult(storage::v3::Result<size_t> maybe_degree) {
  if (!maybe_degree.HasError()) {
    return *maybe_degree;
  }
  switch (maybe_degree.GetError()) {
    case storage::v3::Error::DELETED_OBJECT:
      throw QueryRuntimeException("Trying to get degree of a deleted node.");
    case storage::v3::Error::NONEXISTENT_OBJECT:
      throw query::v2::QueryRuntimeException("Trying to get degree of a node that doesn't exist.");
    case storage::v3::Error::SERIALIZATION_ERROR:
    case storage::v3::Error::VERTEX_HAS_EDGES:
    case storage::v3::Error::PROPERTIES_DISABLED:
      throw QueryRuntimeException("Unexpected error when getting node degree.");
  }
  // Only reached for an error value not listed above; kept to mirror the
  // fall-through of the switch.
  return *maybe_degree;
}
} // namespace
/// Cypher `degree(vertex)`.
///
/// Returns Null for a Null argument. For a vertex the result is currently a
/// placeholder constant 0 (see the TODO below); the real degree is not
/// computed yet.
TypedValue Degree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
  FType<Or<Null, Vertex>>("degree", args, nargs);
  if (args[0].IsNull()) return TypedValue(ctx.memory);
  // The vertex itself is not inspected until the real implementation lands,
  // so no local is bound to it (avoids an unused-variable warning).
  // TODO(kostasrim) Fix dummy values
  return TypedValue(int64_t(0), ctx.memory);
}
/// Cypher `inDegree(vertex)`.
///
/// Returns Null for a Null argument. For a vertex the result is currently a
/// placeholder constant 0; the real in-degree is not computed yet.
TypedValue InDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
  FType<Or<Null, Vertex>>("inDegree", args, nargs);
  if (args[0].IsNull()) return TypedValue(ctx.memory);
  // The vertex itself is not inspected by the placeholder, so no local is
  // bound to it (avoids an unused-variable warning).
  return TypedValue(int64_t(0), ctx.memory);
}
/// Cypher `outDegree(vertex)`.
///
/// Returns Null for a Null argument. For a vertex the result is currently a
/// placeholder constant 0; the real out-degree is not computed yet.
TypedValue OutDegree(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
  FType<Or<Null, Vertex>>("outDegree", args, nargs);
  if (args[0].IsNull()) return TypedValue(ctx.memory);
  // The vertex itself is not inspected by the placeholder, so no local is
  // bound to it (avoids an unused-variable warning).
  return TypedValue(int64_t(0), ctx.memory);
}
TypedValue ToBoolean(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Bool, Integer, String>>("toBoolean", args, nargs);
const auto &value = args[0];
@ -518,9 +473,8 @@ TypedValue ToInteger(const TypedValue *args, int64_t nargs, const FunctionContex
TypedValue Type(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Edge>>("type", args, nargs);
auto *dba = ctx.db_accessor;
if (args[0].IsNull()) return TypedValue(ctx.memory);
return TypedValue(static_cast<int64_t>(args[0].ValueEdge().EdgeType()), ctx.memory);
return TypedValue(args[0].ValueEdge().EdgeType().AsInt(), ctx.memory);
}
TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
@ -559,30 +513,6 @@ TypedValue ValueType(const TypedValue *args, int64_t nargs, const FunctionContex
}
}
// TODO: How is Keys different from Properties function?
// Cypher `keys(...)`. Placeholder: after the FType call (presumably argument
// type validation — confirm against FType) it returns a default-constructed
// TypedValue for every input; no property keys are collected.
TypedValue Keys(const TypedValue *args, int64_t nargs, const FunctionContext & /*ctx*/) {
FType<Or<Null, Vertex, Edge>>("keys", args, nargs);
return {};
}
// Cypher `labels(...)`. Placeholder: a Null argument yields a TypedValue built
// on ctx.memory; any other argument yields a default-constructed TypedValue.
// No labels are actually collected.
TypedValue Labels(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Vertex>>("labels", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
return {};
}
// Cypher `nodes(...)`. Placeholder: a Null argument yields a TypedValue built
// on ctx.memory; any other argument yields a default-constructed TypedValue.
// The path's nodes are not actually collected.
TypedValue Nodes(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Path>>("nodes", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
return {};
}
// Cypher `relationships(...)`. Placeholder: a Null argument yields a
// TypedValue built on ctx.memory; any other argument yields a
// default-constructed TypedValue. The path's edges are not actually collected.
TypedValue Relationships(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Path>>("relationships", args, nargs);
if (args[0].IsNull()) return TypedValue(ctx.memory);
return {};
}
TypedValue Range(const TypedValue *args, int64_t nargs, const FunctionContext &ctx) {
FType<Or<Null, Integer>, Or<Null, Integer>, Optional<Or<Null, NonZeroInteger>>>("range", args, nargs);
for (int64_t i = 0; i < nargs; ++i)
@ -1098,9 +1028,6 @@ TypedValue Duration(const TypedValue *args, int64_t nargs, const FunctionContext
std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx)> NameToFunction(
const std::string &function_name) {
// Scalar functions
if (function_name == "DEGREE") return Degree;
if (function_name == "INDEGREE") return InDegree;
if (function_name == "OUTDEGREE") return OutDegree;
if (function_name == "ENDNODE") return EndNode;
if (function_name == "HEAD") return Head;
if (function_name == kId) return Id;
@ -1116,11 +1043,7 @@ std::function<TypedValue(const TypedValue *, int64_t, const FunctionContext &ctx
if (function_name == "VALUETYPE") return ValueType;
// List functions
if (function_name == "KEYS") return Keys;
if (function_name == "LABELS") return Labels;
if (function_name == "NODES") return Nodes;
if (function_name == "RANGE") return Range;
if (function_name == "RELATIONSHIPS") return Relationships;
if (function_name == "TAIL") return Tail;
if (function_name == "UNIFORMSAMPLE") return UniformSample;

View File

@ -19,9 +19,13 @@
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include "coordinator/coordinator_client.hpp"
#include "expr/ast/ast_visitor.hpp"
#include "io/local_transport/local_system.hpp"
#include "io/local_transport/local_transport.hpp"
#include "memory/memory_control.hpp"
#include "parser/opencypher/parser.hpp"
#include "query/v2/bindings/eval.hpp"
@ -33,7 +37,6 @@
#include "query/v2/context.hpp"
#include "query/v2/cypher_query_interpreter.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/dump.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/required_privileges.hpp"
@ -41,6 +44,7 @@
#include "query/v2/plan/planner.hpp"
#include "query/v2/plan/profile.hpp"
#include "query/v2/plan/vertex_count_cache.hpp"
#include "query/v2/shard_request_manager.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/shard.hpp"
#include "storage/v3/storage.hpp"
@ -114,17 +118,6 @@ std::optional<TResult> GetOptionalValue(query::v2::Expression *expression, Expre
return {};
};
/// Evaluates `expression` and returns its string value.
///
/// @return the evaluated string, or an empty optional when `expression` is
///         null or evaluates to Null. Any other result type trips MG_ASSERT.
std::optional<std::string> GetOptionalStringValue(query::v2::Expression *expression, ExpressionEvaluator &evaluator) {
  if (expression == nullptr) {
    return {};
  }

  auto evaluated = expression->Accept(evaluator);
  MG_ASSERT(evaluated.IsNull() || evaluated.IsString());
  if (!evaluated.IsString()) {
    return {};
  }

  // Copy the characters out into a plain std::string.
  return {std::string(evaluated.ValueString().begin(), evaluated.ValueString().end())};
}
class ReplQueryHandler final : public query::v2::ReplicationQueryHandler {
public:
explicit ReplQueryHandler(storage::v3::Shard * /*db*/) {}
@ -455,21 +448,6 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
}
}
/// Converts a nullable string pointer into an optional holding a copy of the
/// pointed-to string (empty optional for nullptr).
std::optional<std::string> StringPointerToOptional(const std::string *str) {
  if (str == nullptr) {
    return std::nullopt;
  }
  return std::make_optional(*str);
}
// Resolves the topic names carried by `topic_variant`.
//
// - Expression alternative: the expression is evaluated, must produce a
//   string (enforced with MG_ASSERT), and that string is split on ",".
// - Vector alternative: the names are returned as they are.
std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator,
std::variant<Expression *, std::vector<std::string>> topic_variant) {
return std::visit(utils::Overloaded{[&](Expression *expression) {
auto topic_names = expression->Accept(evaluator);
MG_ASSERT(topic_names.IsString());
return utils::Split(topic_names.ValueString(), ",");
},
[&](std::vector<std::string> topic_names) { return topic_names; }},
std::move(topic_variant));
}
Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &parameters, DbAccessor *db_accessor) {
expr::Frame<TypedValue> frame(0);
SymbolTable symbol_table;
@ -671,6 +649,7 @@ struct PullPlanVector {
struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
msgs::ShardRequestManagerInterface *shard_request_manager = nullptr,
// TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
@ -700,7 +679,7 @@ struct PullPlan {
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
const std::optional<size_t> memory_limit)
msgs::ShardRequestManagerInterface *shard_request_manager, const std::optional<size_t> memory_limit)
// TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
@ -710,14 +689,15 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.symbol_table = plan->symbol_table();
ctx_.evaluation_context.timestamp = QueryTimestamp();
ctx_.evaluation_context.parameters = parameters;
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, shard_request_manager);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, shard_request_manager);
if (interpreter_context->config.execution_timeout_sec > 0) {
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
}
ctx_.is_shutting_down = &interpreter_context->is_shutting_down;
ctx_.is_profile_query = is_profile_query;
// ctx_.trigger_context_collector = trigger_context_collector;
ctx_.shard_request_manager = shard_request_manager;
}
std::optional<plan::ProfilingStatsWithTotalTime> PullPlan::Pull(AnyStream *stream, std::optional<int> n,
@ -813,13 +793,18 @@ using RWType = plan::ReadWriteTypeChecker::RWType;
} // namespace
InterpreterContext::InterpreterContext(storage::v3::Shard *db, const InterpreterConfig config,
const std::filesystem::path &data_directory)
// : db(db), trigger_store(data_directory / "triggers"), config(config), streams{this, data_directory /
// "streams"} {}
: db(db), config(config) {}
const std::filesystem::path & /*data_directory*/,
io::Io<io::local_transport::LocalTransport> io,
coordinator::Address coordinator_addr)
: db(db), config(config), io{std::move(io)}, coordinator_address{coordinator_addr} {}
// Creates an interpreter bound to `interpreter_context` (must not be null).
//
// Each interpreter gets its own Io handle forked from the context's Io and a
// ShardRequestManager built on top of it. The CoordinatorClient is given the
// context's coordinator address both as the address to talk to and as the
// only entry of the coordinator list.
Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) {
MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL");
auto query_io = interpreter_context_->io.ForkLocal();
// NOTE(review): query_io is read for the CoordinatorClient and also passed
// on with std::move in the same call; this presumably relies on the manager
// taking the Io by reference so nothing is moved-from before the client is
// built — confirm against the ShardRequestManager constructor.
shard_request_manager_ = std::make_unique<msgs::ShardRequestManager<io::local_transport::LocalTransport>>(
coordinator::CoordinatorClient<io::local_transport::LocalTransport>(
query_io, interpreter_context_->coordinator_address, std::vector{interpreter_context_->coordinator_address}),
std::move(query_io));
}
PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) {
@ -832,14 +817,6 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
}
in_explicit_transaction_ = true;
expect_rollback_ = false;
db_accessor_ = std::make_unique<storage::v3::Shard::Accessor>(
interpreter_context_->db->Access(coordinator::Hlc{}, GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get());
// if (interpreter_context_->trigger_store.HasTriggers()) {
// trigger_context_collector_.emplace(interpreter_context_->trigger_store.GetEventTypes());
// }
};
} else if (query_upper == "COMMIT") {
handler = [this] {
@ -886,7 +863,8 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications) {
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
msgs::ShardRequestManagerInterface *shard_request_manager) {
// TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
@ -910,10 +888,10 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
"convert the parsed row values to the appropriate type. This can be done using the built-in "
"conversion functions such as ToInteger, ToFloat, ToBoolean etc.");
}
auto plan = CypherQueryToPlan(parsed_query.stripped_query.hash(), std::move(parsed_query.ast_storage), cypher_query,
parsed_query.parameters,
parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
shard_request_manager->StartTransaction();
auto plan = CypherQueryToPlan(
parsed_query.stripped_query.hash(), std::move(parsed_query.ast_storage), cypher_query, parsed_query.parameters,
parsed_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, shard_request_manager);
summary->insert_or_assign("cost_estimate", plan->cost());
auto rw_type_checker = plan::ReadWriteTypeChecker();
@ -932,7 +910,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, memory_limit);
execution_memory, shard_request_manager, memory_limit);
// execution_memory, trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
@ -946,7 +924,8 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
}
PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
InterpreterContext *interpreter_context,
msgs::ShardRequestManagerInterface *shard_request_manager,
utils::MemoryResource *execution_memory) {
const std::string kExplainQueryStart = "explain ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kExplainQueryStart),
@ -965,19 +944,20 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in EXPLAIN");
auto cypher_query_plan = CypherQueryToPlan(
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
auto cypher_query_plan =
CypherQueryToPlan(parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage),
cypher_query, parsed_inner_query.parameters,
parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, nullptr);
std::stringstream printed_plan;
plan::PrettyPrint(*dba, &cypher_query_plan->plan(), &printed_plan);
plan::PrettyPrint(*shard_request_manager, &cypher_query_plan->plan(), &printed_plan);
std::vector<std::vector<TypedValue>> printed_plan_rows;
for (const auto &row : utils::Split(utils::RTrim(printed_plan.str()), "\n")) {
printed_plan_rows.push_back(std::vector<TypedValue>{TypedValue(row)});
}
summary->insert_or_assign("explain", plan::PlanToJson(*dba, &cypher_query_plan->plan()).dump());
summary->insert_or_assign("explain", plan::PlanToJson(*shard_request_manager, &cypher_query_plan->plan()).dump());
return PreparedQuery{{"QUERY PLAN"},
std::move(parsed_query.required_privileges),
@ -993,7 +973,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory,
msgs::ShardRequestManagerInterface *shard_request_manager = nullptr) {
const std::string kProfileQueryStart = "profile ";
MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
@ -1042,14 +1023,15 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
auto cypher_query_plan = CypherQueryToPlan(
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr,
shard_request_manager);
auto rw_type_checker = plan::ReadWriteTypeChecker();
rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));
return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit,
summary, dba, interpreter_context, execution_memory, memory_limit, shard_request_manager,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
@ -1058,7 +1040,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, memory_limit)
execution_memory, shard_request_manager, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
@ -1077,16 +1059,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
PreparedQuery PrepareDumpQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, DbAccessor *dba,
utils::MemoryResource *execution_memory) {
return PreparedQuery{{"QUERY"},
std::move(parsed_query.required_privileges),
[pull_plan = std::make_shared<PullPlanDump>(dba)](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::R};
throw QueryRuntimeException("Dump query is not supported!");
}
PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
@ -1526,30 +1499,20 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
ParseQuery(query_string, params, &interpreter_context_->ast_cache, interpreter_context_->config.query);
query_execution->summary["parsing_time"] = parsing_timer.Elapsed().count();
// Some queries require an active transaction in order to be prepared.
if (!in_explicit_transaction_ &&
(utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) ||
utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query) ||
utils::Downcast<TriggerQuery>(parsed_query.query))) {
db_accessor_ = std::make_unique<storage::v3::Shard::Accessor>(
interpreter_context_->db->Access(coordinator::Hlc{}, GetIsolationLevelOverride()));
execution_db_accessor_.emplace(db_accessor_.get());
}
utils::Timer planning_timer;
PreparedQuery prepared_query;
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications);
&query_execution->notifications, shard_request_manager_.get());
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception);
&*shard_request_manager_, &query_execution->execution_memory_with_exception);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
prepared_query = PrepareProfileQuery(
std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception, shard_request_manager_.get());
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory);
@ -1628,6 +1591,7 @@ void Interpreter::Commit() {
// For now, we will not check if there are some unfinished queries.
// We should document clearly that all results should be pulled to complete
// a query.
shard_request_manager_->Commit();
if (!db_accessor_) return;
const auto reset_necessary_members = [this]() {

View File

@ -13,6 +13,10 @@
#include <gflags/gflags.h>
#include "coordinator/coordinator.hpp"
#include "coordinator/coordinator_client.hpp"
#include "io/local_transport/local_transport.hpp"
#include "io/transport.hpp"
#include "query/v2/auth_checker.hpp"
#include "query/v2/bindings/cypher_main_visitor.hpp"
#include "query/v2/bindings/typed_value.hpp"
@ -165,7 +169,8 @@ struct PreparedQuery {
*/
struct InterpreterContext {
explicit InterpreterContext(storage::v3::Shard *db, InterpreterConfig config,
const std::filesystem::path &data_directory);
const std::filesystem::path &data_directory,
io::Io<io::local_transport::LocalTransport> io, coordinator::Address coordinator_addr);
storage::v3::Shard *db;
@ -180,6 +185,11 @@ struct InterpreterContext {
const InterpreterConfig config;
// TODO (antaljanosbenjamin) Figure out an abstraction for io::Io to make it possible to construct an interpreter
// context with a simulator transport without templatizing it.
io::Io<io::local_transport::LocalTransport> io;
coordinator::Address coordinator_address;
storage::v3::LabelId NameToLabelId(std::string_view label_name) {
return storage::v3::LabelId::FromUint(query_id_mapper.NameToId(label_name));
}
@ -327,6 +337,7 @@ class Interpreter final {
// move this unique_ptr into a shared_ptr.
std::unique_ptr<storage::v3::Shard::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_;
std::unique_ptr<msgs::ShardRequestManagerInterface> shard_request_manager_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};

View File

@ -14,6 +14,8 @@
#include <functional>
#include <utility>
#include "query/db_accessor.hpp"
#include "query/v2/accessors.hpp"
#include "query/v2/db_accessor.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
@ -28,6 +30,8 @@ namespace memgraph::query::v2 {
*/
class Path {
public:
using VertexAccessor = accessors::VertexAccessor;
using EdgeAccessor = accessors::EdgeAccessor;
/** Allocator type so that STL containers are aware that we need one */
using allocator_type = utils::Allocator<char>;

View File

@ -27,6 +27,7 @@
#include <cppitertools/imap.hpp>
#include "expr/exceptions.hpp"
#include "query/exceptions.hpp"
#include "query/v2/accessors.hpp"
#include "query/v2/bindings/eval.hpp"
#include "query/v2/bindings/symbol_table.hpp"
@ -36,7 +37,6 @@
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/path.hpp"
#include "query/v2/plan/scoped_profile.hpp"
#include "query/v2/procedure/cypher_types.hpp"
#include "query/v2/requests.hpp"
#include "query/v2/shard_request_manager.hpp"
#include "storage/v3/conversions.hpp"
@ -156,6 +156,71 @@ uint64_t ComputeProfilingKey(const T *obj) {
#define SCOPED_PROFILE_OP(name) ScopedProfile profile{ComputeProfilingKey(this), name, &context};
/// Cursor that, for every row pulled from its input, builds one `msgs::NewVertex`
/// per entry in `nodes_info_` and hands the batch to the shard request manager
/// found in the execution context.
class DistributedCreateNodeCursor : public Cursor {
public:
using InputOperator = std::shared_ptr<memgraph::query::v2::plan::LogicalOperator>;
/// @param op operator whose cursor supplies the input rows.
/// @param mem memory resource forwarded to the input operator's `MakeCursor`.
/// @param nodes_info descriptions of the nodes to create for each input row; the
///        pointers are stored as-is, so the pointees must outlive this cursor.
DistributedCreateNodeCursor(const InputOperator &op, utils::MemoryResource *mem,
std::vector<const NodeCreationInfo *> nodes_info)
: input_cursor_(op->MakeCursor(mem)), nodes_info_(std::move(nodes_info)) {}
/// Pulls one row from the input and, if there was one, issues a create request
/// built from `nodes_info_`.
/// @return true when the input produced a row, false when the input is exhausted.
bool Pull(Frame &frame, ExecutionContext &context) override {
SCOPED_PROFILE_OP("CreateNode");
if (input_cursor_->Pull(frame, context)) {
auto &shard_manager = context.shard_request_manager;
// The return value of Request (if any) is not inspected here.
shard_manager->Request(state_, NodeCreationInfoToRequest(context, frame));
return true;
}
return false;
}
void Shutdown() override { input_cursor_->Shutdown(); }
// NOTE(review): only the request state is cleared; unlike the other cursors in this
// file, the input cursor is not reset here — confirm this is intended.
void Reset() override { state_ = {}; }
/// Evaluates the property expressions (or the property-map parameter) of every
/// node description against `frame` and converts them into `msgs::NewVertex`
/// requests.
/// @throws QueryRuntimeException when a node description has no labels.
/// @return one request per element of `nodes_info_`, in the same order.
std::vector<msgs::NewVertex> NodeCreationInfoToRequest(ExecutionContext &context, Frame &frame) const {
std::vector<msgs::NewVertex> requests;
for (const auto &node_info : nodes_info_) {
msgs::NewVertex rqst;
// NOTE(review): `properties` is filled below but never copied into `rqst` in this
// block, so only primary-key values and the first label appear to be sent — confirm.
std::map<msgs::PropertyId, msgs::Value> properties;
// The evaluator is built without a db accessor (nullptr) and with View::NEW.
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr,
storage::v3::View::NEW);
// Properties are either an explicit list of (property id, expression) pairs ...
if (const auto *node_info_properties = std::get_if<PropertiesMapList>(&node_info->properties)) {
for (const auto &[key, value_expression] : *node_info_properties) {
TypedValue val = value_expression->Accept(evaluator);
properties[key] = TypedValueToValue(val);
// Primary-key values are appended in the iteration order of the property list.
if (context.shard_request_manager->IsPrimaryKey(key)) {
rqst.primary_key.push_back(storage::v3::TypedValueToValue(val));
}
}
} else {
// ... or a single parameter that evaluates to a map keyed by property name.
auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info->properties)).ValueMap();
for (const auto &[key, value] : property_map) {
auto key_str = std::string(key);
auto property_id = context.shard_request_manager->NameToProperty(key_str);
properties[property_id] = TypedValueToValue(value);
// Here primary-key values are appended in the iteration order of the map.
if (context.shard_request_manager->IsPrimaryKey(property_id)) {
rqst.primary_key.push_back(storage::v3::TypedValueToValue(value));
}
}
}
if (node_info->labels.empty()) {
throw QueryRuntimeException("Primary label must be defined!");
}
// TODO(kostasrim) Copy non primary labels as well
// Only the first label is sent; it is treated as the primary label.
rqst.label_ids.push_back(msgs::Label{node_info->labels[0]});
requests.push_back(std::move(rqst));
}
return requests;
}
private:
const UniqueCursorPtr input_cursor_;
std::vector<const NodeCreationInfo *> nodes_info_;
// State object handed to the shard request manager with every create request.
msgs::ExecutionState<msgs::CreateVerticesRequest> state_;
};
bool Once::OnceCursor::Pull(Frame &, ExecutionContext &context) {
SCOPED_PROFILE_OP("Once");
@ -186,7 +251,7 @@ ACCEPT_WITH_INPUT(CreateNode)
UniqueCursorPtr CreateNode::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::CreateNodeOperator);
return MakeUniqueCursorPtr<CreateNodeCursor>(mem, *this, mem);
return MakeUniqueCursorPtr<DistributedCreateNodeCursor>(mem, input_, mem, std::vector{&this->node_info_});
}
std::vector<Symbol> CreateNode::ModifiedSymbols(const SymbolTable &table) const {
@ -272,15 +337,12 @@ ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_sy
ACCEPT_WITH_INPUT(ScanAll)
class DistributedScanAllCursor;
UniqueCursorPtr ScanAll::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::ScanAllOperator);
auto vertices = [this](Frame & /*unused*/, ExecutionContext &context) {
auto *db = context.db_accessor;
return std::make_optional(db->Vertices(view_));
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAll");
return MakeUniqueCursorPtr<DistributedScanAllCursor>(mem, output_symbol_, input_->MakeCursor(mem), "ScanAll");
}
std::vector<Symbol> ScanAll::ModifiedSymbols(const SymbolTable &table) const {
@ -295,15 +357,10 @@ ScanAllByLabel::ScanAllByLabel(const std::shared_ptr<LogicalOperator> &input, Sy
ACCEPT_WITH_INPUT(ScanAllByLabel)
UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource *mem) const {
UniqueCursorPtr ScanAllByLabel::MakeCursor(utils::MemoryResource * /*mem*/) const {
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelOperator);
auto vertices = [this](Frame & /*unused*/, ExecutionContext &context) {
auto *db = context.db_accessor;
return std::make_optional(db->Vertices(view_, label_));
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAllByLabel");
throw QueryRuntimeException("ScanAllByLabel is not supported");
}
// TODO(buda): Implement ScanAllByLabelProperty operator to iterate over
@ -326,50 +383,10 @@ ScanAllByLabelPropertyRange::ScanAllByLabelPropertyRange(const std::shared_ptr<L
ACCEPT_WITH_INPUT(ScanAllByLabelPropertyRange)
UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource *mem) const {
UniqueCursorPtr ScanAllByLabelPropertyRange::MakeCursor(utils::MemoryResource * /*mem*/) const {
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyRangeOperator);
auto vertices = [this](Frame &frame, ExecutionContext &context)
-> std::optional<decltype(context.db_accessor->Vertices(view_, label_, property_, std::nullopt, std::nullopt))> {
auto *db = context.db_accessor;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);
auto convert = [&evaluator](const auto &bound) -> std::optional<utils::Bound<storage::v3::PropertyValue>> {
if (!bound) return std::nullopt;
const auto &value = bound->value()->Accept(evaluator);
try {
const auto &property_value = storage::v3::TypedToPropertyValue(value);
switch (property_value.type()) {
case storage::v3::PropertyValue::Type::Bool:
case storage::v3::PropertyValue::Type::List:
case storage::v3::PropertyValue::Type::Map:
// Prevent indexed lookup with something that would fail if we did
// the original filter with `operator<`. Note, for some reason,
// Cypher does not support comparing boolean values.
throw QueryRuntimeException("Invalid type {} for '<'.", value.type());
case storage::v3::PropertyValue::Type::Null:
case storage::v3::PropertyValue::Type::Int:
case storage::v3::PropertyValue::Type::Double:
case storage::v3::PropertyValue::Type::String:
case storage::v3::PropertyValue::Type::TemporalData:
// These are all fine, there's also Point, Date and Time data types
// which were added to Cypher, but we don't have support for those
// yet.
return std::make_optional(utils::Bound<storage::v3::PropertyValue>(property_value, bound->type()));
}
} catch (const expr::TypedValueException &) {
throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
}
};
auto maybe_lower = convert(lower_bound_);
auto maybe_upper = convert(upper_bound_);
// If any bound is null, then the comparison would result in nulls. This
// is treated as not satisfying the filter, so return no vertices.
if (maybe_lower && maybe_lower->value().IsNull()) return std::nullopt;
if (maybe_upper && maybe_upper->value().IsNull()) return std::nullopt;
return std::make_optional(db->Vertices(view_, label_, property_, maybe_lower, maybe_upper));
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAllByLabelPropertyRange");
throw QueryRuntimeException("ScanAllByLabelPropertyRange is not supported");
}
ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<LogicalOperator> &input,
@ -387,20 +404,10 @@ ScanAllByLabelPropertyValue::ScanAllByLabelPropertyValue(const std::shared_ptr<L
ACCEPT_WITH_INPUT(ScanAllByLabelPropertyValue)
UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource *mem) const {
UniqueCursorPtr ScanAllByLabelPropertyValue::MakeCursor(utils::MemoryResource * /*mem*/) const {
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyValueOperator);
auto vertices =
[this](Frame &frame, ExecutionContext &context) -> std::optional<decltype(context.db_accessor->Vertices(
view_, label_, property_, storage::v3::PropertyValue()))> {
auto *db = context.db_accessor;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, view_);
auto value = expression_->Accept(evaluator);
if (value.IsNull()) return std::nullopt;
return std::make_optional(db->Vertices(view_, label_, property_, storage::v3::TypedToPropertyValue(value)));
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAllByLabelPropertyValue");
throw QueryRuntimeException("ScanAllByLabelPropertyValue is not supported");
}
ScanAllByLabelProperty::ScanAllByLabelProperty(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol,
@ -412,13 +419,7 @@ ACCEPT_WITH_INPUT(ScanAllByLabelProperty)
UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::ScanAllByLabelPropertyOperator);
auto vertices = [this](Frame & /*frame*/, ExecutionContext &context) {
auto *db = context.db_accessor;
return std::make_optional(db->Vertices(view_, label_, property_));
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAllByLabelProperty");
throw QueryRuntimeException("ScanAllByLabelProperty is not supported");
}
ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression,
@ -431,16 +432,8 @@ ACCEPT_WITH_INPUT(ScanAllById)
UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator);
// TODO Reimplement when we have reliable conversion between hash value and pk
auto vertices = [this](Frame &frame, ExecutionContext &context) -> std::optional<std::vector<VertexAccessor>> {
// auto *db = context.db_accessor;
// ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
// view_); auto value = expression_->Accept(evaluator); if (!value.IsNumeric()) return std::nullopt; int64_t id =
// value.IsInt() ? value.ValueInt() : value.ValueDouble(); if (value.IsDouble() && id != value.ValueDouble()) return
// std::nullopt; auto maybe_vertex = db->FindVertex(storage::v3::Gid::FromInt(id), view_); auto maybe_vertex =
// nullptr; if (!maybe_vertex) return std::nullopt;
auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> {
return std::nullopt;
// return std::vector<VertexAccessor>{*maybe_vertex};
};
return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem),
std::move(vertices), "ScanAllById");
@ -540,265 +533,10 @@ std::vector<Symbol> ExpandVariable::ModifiedSymbols(const SymbolTable &table) co
return symbols;
}
namespace {
/**
 * Helper function that returns an iterable over
 * <EdgeAccessor, EdgeAtom::Direction> pairs
 * for the given params.
 *
 * @param vertex - The vertex to expand from.
 * @param direction - Expansion direction. All directions (IN, OUT, BOTH)
 *    are supported: anything other than OUT yields in-edges, anything other
 *    than IN yields out-edges.
 * @param edge_types - Passed through to `InEdges`/`OutEdges`.
 * @param memory - Used to allocate the result.
 * @return A chained iterable; in-edges (if any) come before out-edges (if any).
 */
auto ExpandFromVertex(const VertexAccessor &vertex, EdgeAtom::Direction direction,
const std::vector<storage::v3::EdgeTypeId> &edge_types, utils::MemoryResource *memory) {
// wraps an EdgeAccessor into a pair <accessor, direction>
auto wrapper = [](EdgeAtom::Direction direction, auto &&edges) {
return iter::imap([direction](const auto &edge) { return std::make_pair(edge, direction); },
std::forward<decltype(edges)>(edges));
};
// Edges are always read with the OLD view.
storage::v3::View view = storage::v3::View::OLD;
// The element type is whatever `wrapper` returns for an edge range; the expression
// inside decltype is never evaluated.
utils::pmr::vector<decltype(wrapper(direction, *vertex.InEdges(view, edge_types)))> chain_elements(memory);
if (direction != EdgeAtom::Direction::OUT) {
auto edges = UnwrapEdgesResult(vertex.InEdges(view, edge_types));
// Empty ranges are not added to the chain.
if (edges.begin() != edges.end()) {
chain_elements.emplace_back(wrapper(EdgeAtom::Direction::IN, std::move(edges)));
}
}
if (direction != EdgeAtom::Direction::IN) {
auto edges = UnwrapEdgesResult(vertex.OutEdges(view, edge_types));
if (edges.begin() != edges.end()) {
chain_elements.emplace_back(wrapper(EdgeAtom::Direction::OUT, std::move(edges)));
}
}
// TODO: Investigate whether itertools perform heap allocation?
return iter::chain.from_iterable(std::move(chain_elements));
}
} // namespace
/// Cursor created for `ExpandVariable`. In this block both `Pull` and `Expand`
/// are stubs that return false, so the cursor never produces a row.
class ExpandVariableCursor : public Cursor {
public:
/// @param self the operator this cursor belongs to; held by reference.
/// @param mem memory resource used for the input cursor and the internal vectors.
ExpandVariableCursor(const ExpandVariable &self, utils::MemoryResource *mem)
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), edges_(mem), edges_it_(mem) {}
// Stub: always reports that no row is available.
bool Pull(Frame & /*frame*/, ExecutionContext & /*context*/) override { return false; }
void Shutdown() override { input_cursor_->Shutdown(); }
void Reset() override {
input_cursor_->Reset();
edges_.clear();
edges_it_.clear();
}
private:
const ExpandVariable &self_;
const UniqueCursorPtr input_cursor_;
// bounds. in the cursor they are not optional but set to
// default values if missing in the ExpandVariable operator
// initialize to arbitrary values, they should only be used
// after a successful pull from the input
// NOTE(review): neither bound is read or written anywhere in this block.
int64_t upper_bound_{-1};
int64_t lower_bound_{-1};
// a stack of edge iterables corresponding to the level/depth of
// the expansion currently being Pulled
// The element type is deduced from ExpandFromVertex; the decltype operand is unevaluated.
using ExpandEdges = decltype(ExpandFromVertex(std::declval<VertexAccessor>(), EdgeAtom::Direction::IN,
self_.common_.edge_types, utils::NewDeleteResource()));
utils::pmr::vector<ExpandEdges> edges_;
// an iterator indicating the position in the corresponding edges_ element
utils::pmr::vector<decltype(edges_.begin()->begin())> edges_it_;
/**
 * Performs a single expansion for the current state of this
 * ExpandVariableCursor.
 *
 * @return True if the expansion was a success and this Cursor's
 * consumer can consume it. False if the expansion failed. In that
 * case no more expansions are available from the current input
 * vertex and another Pull from the input cursor should be performed.
 * The current body is a stub that always returns false.
 */
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
bool Expand(Frame & /*frame*/, ExecutionContext & /*context*/) { return false; }
};
/// Cursor used for breadth-first expansion between two already-bound nodes.
/// `Pull` is a stub here: it never yields a row.
class STShortestPathCursor : public query::v2::plan::Cursor {
 public:
  /// @param self operator this cursor executes; must have `existing_node` set.
  /// @param mem memory resource forwarded to the input operator's cursor.
  STShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
      : self_(self), input_cursor_(self.input()->MakeCursor(mem)) {
    MG_ASSERT(self_.common_.existing_node,
              "s-t shortest path algorithm should only be used when `existing_node` flag is set!");
  }

  /// Stub implementation: reports end of stream immediately.
  bool Pull(Frame & /*frame*/, ExecutionContext & /*context*/) override {
    return false;
  }

  /// Forwards shutdown to the input cursor.
  void Shutdown() override {
    input_cursor_->Shutdown();
  }

  /// Forwards reset to the input cursor.
  void Reset() override {
    input_cursor_->Reset();
  }

 private:
  const ExpandVariable &self_;
  UniqueCursorPtr input_cursor_;

  // Maps a vertex to the (optional) edge associated with it.
  using VertexEdgeMapT = utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>>;
};
/// Cursor used for breadth-first expansion from a single source vertex (the
/// destination is not an already-bound node). The expansion itself is not
/// implemented in this block; `Pull` is a stub.
class SingleSourceShortestPathCursor : public query::v2::plan::Cursor {
 public:
  /// @param self operator this cursor executes; must NOT have `existing_node` set.
  /// @param mem memory resource used for the input cursor and the internal containers.
  SingleSourceShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
      : self_(self),
        input_cursor_(self_.input()->MakeCursor(mem)),
        processed_(mem),
        to_visit_current_(mem),
        to_visit_next_(mem) {
    MG_ASSERT(!self_.common_.existing_node,
              "Single source shortest path algorithm "
              "should not be used when `existing_node` "
              "flag is set, s-t shortest path algorithm "
              "should be used instead!");
  }

  /// Stub implementation: reports end of stream.
  ///
  /// This used to return `true` unconditionally, which tells the consumer that a
  /// new row is available on every call even though the input is never pulled and
  /// the frame is never written, so a consumer looping on `Pull` would never
  /// terminate. Returning `false` matches the other stubbed expansion cursors.
  bool Pull(Frame & /*frame*/, ExecutionContext & /*context*/) override { return false; }

  void Shutdown() override { input_cursor_->Shutdown(); }

  /// Resets the input cursor and drops all traversal bookkeeping.
  void Reset() override {
    input_cursor_->Reset();
    processed_.clear();
    to_visit_next_.clear();
    to_visit_current_.clear();
  }

 private:
  const ExpandVariable &self_;
  const UniqueCursorPtr input_cursor_;

  // Depth bounds. Calculated on each pull from the input, the initial value
  // is irrelevant.
  int64_t lower_bound_{-1};
  int64_t upper_bound_{-1};

  // maps vertices to the edge they got expanded from. it is an optional
  // edge because the root does not get expanded from anything.
  // contains visited vertices as well as those scheduled to be visited.
  utils::pmr::unordered_map<VertexAccessor, std::optional<EdgeAccessor>> processed_;
  // edge/vertex pairs we have yet to visit, for current and next depth
  utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_current_;
  utils::pmr::vector<std::pair<EdgeAccessor, VertexAccessor>> to_visit_next_;
};
class ExpandWeightedShortestPathCursor : public query::v2::plan::Cursor {
public:
ExpandWeightedShortestPathCursor(const ExpandVariable &self, utils::MemoryResource *mem)
: self_(self),
input_cursor_(self_.input_->MakeCursor(mem)),
total_cost_(mem),
previous_(mem),
yielded_vertices_(mem),
pq_(mem) {}
bool Pull(Frame & /*frame*/, ExecutionContext & /*context*/) override { return false; }
void Shutdown() override { input_cursor_->Shutdown(); }
void Reset() override {
input_cursor_->Reset();
previous_.clear();
total_cost_.clear();
yielded_vertices_.clear();
ClearQueue();
}
private:
const ExpandVariable &self_;
const UniqueCursorPtr input_cursor_;
// Upper bound on the path length.
int64_t upper_bound_{-1};
bool upper_bound_set_{false};
struct WspStateHash {
size_t operator()(const std::pair<VertexAccessor, int64_t> &key) const {
return utils::HashCombine<VertexAccessor, int64_t>{}(key.first, key.second);
}
};
// Maps vertices to weights they got in expansion.
utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, TypedValue, WspStateHash> total_cost_;
// Maps vertices to edges used to reach them.
utils::pmr::unordered_map<std::pair<VertexAccessor, int64_t>, std::optional<EdgeAccessor>, WspStateHash> previous_;
// Keeps track of vertices for which we yielded a path already.
utils::pmr::unordered_set<VertexAccessor> yielded_vertices_;
static void ValidateWeightTypes(const TypedValue &lhs, const TypedValue &rhs) {
if (!((lhs.IsNumeric() && lhs.IsNumeric()) || (rhs.IsDuration() && rhs.IsDuration()))) {
throw QueryRuntimeException(utils::MessageWithLink(
"All weights should be of the same type, either numeric or a Duration. Please update the weight "
"expression or the filter expression.",
"https://memgr.ph/wsp"));
}
}
// Priority queue comparator. Keep lowest weight on top of the queue.
class PriorityQueueComparator {
public:
bool operator()(const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &lhs,
const std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>> &rhs) {
const auto &lhs_weight = std::get<0>(lhs);
const auto &rhs_weight = std::get<0>(rhs);
// Null defines minimum value for all types
if (lhs_weight.IsNull()) {
return false;
}
if (rhs_weight.IsNull()) {
return true;
}
ValidateWeightTypes(lhs_weight, rhs_weight);
return (lhs_weight > rhs_weight).ValueBool();
}
};
std::priority_queue<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>,
utils::pmr::vector<std::tuple<TypedValue, int64_t, VertexAccessor, std::optional<EdgeAccessor>>>,
PriorityQueueComparator>
pq_;
void ClearQueue() {
while (!pq_.empty()) pq_.pop();
}
};
UniqueCursorPtr ExpandVariable::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::ExpandVariableOperator);
switch (type_) {
case EdgeAtom::Type::BREADTH_FIRST:
if (common_.existing_node) {
return MakeUniqueCursorPtr<STShortestPathCursor>(mem, *this, mem);
} else {
return MakeUniqueCursorPtr<SingleSourceShortestPathCursor>(mem, *this, mem);
}
case EdgeAtom::Type::DEPTH_FIRST:
return MakeUniqueCursorPtr<ExpandVariableCursor>(mem, *this, mem);
case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH:
return MakeUniqueCursorPtr<ExpandWeightedShortestPathCursor>(mem, *this, mem);
case EdgeAtom::Type::SINGLE:
LOG_FATAL("ExpandVariable should not be planned for a single expansion!");
}
throw QueryRuntimeException("ExpandVariable is not supported");
}
class ConstructNamedPathCursor : public Cursor {
@ -1039,43 +777,6 @@ SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, utils::Memory
bool SetLabels::SetLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
SCOPED_PROFILE_OP("SetLabels");
return false;
// if (!input_cursor_->Pull(frame, context)) return false;
//
// TypedValue &vertex_value = frame[self_.input_symbol_];
// // Skip setting labels on Null (can occur in optional match).
// if (vertex_value.IsNull()) return true;
// ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex);
//
// auto &dba = *context.db_accessor;
// auto &vertex = vertex_value.ValueVertex();
// for (const auto label : self_.labels_) {
// auto maybe_value = vertex.AddLabelAndValidate(label);
// if (maybe_value.HasError()) {
// std::visit(utils::Overloaded{[](const storage::v3::Error error) {
// switch (error) {
// case storage::v3::Error::SERIALIZATION_ERROR:
// throw TransactionSerializationException();
// case storage::v3::Error::DELETED_OBJECT:
// throw QueryRuntimeException("Trying to set a label on a deleted node.");
// case storage::v3::Error::VERTEX_HAS_EDGES:
// case storage::v3::Error::PROPERTIES_DISABLED:
// case storage::v3::Error::NONEXISTENT_OBJECT:
// throw QueryRuntimeException("Unexpected error when setting a label.");
// }
// },
// [&dba](const storage::v3::SchemaViolation schema_violation) {
// HandleSchemaViolation(schema_violation, dba);
// }},
// maybe_value.GetError());
// }
//
// context.execution_stats[ExecutionStats::Key::CREATED_LABELS]++;
// if (context.trigger_context_collector && *maybe_value) {
// context.trigger_context_collector->RegisterSetVertexLabel(vertex, label);
// }
// }
//
// return true;
}
void SetLabels::SetLabelsCursor::Shutdown() { input_cursor_->Shutdown(); }
@ -1104,36 +805,6 @@ RemoveProperty::RemovePropertyCursor::RemovePropertyCursor(const RemoveProperty
bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, ExecutionContext &context) {
SCOPED_PROFILE_OP("RemoveProperty");
return false;
// if (!input_cursor_->Pull(frame, context)) return false;
//
// // Remove, just like Delete needs to see the latest changes.
// ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
// storage::v3::View::NEW);
// TypedValue lhs = self_.lhs_->expression_->Accept(evaluator);
//
// auto remove_prop = [property = self_.property_, &context](auto *record) {
// auto old_value = PropsSetChecked(record, *context.db_accessor, property, TypedValue{});
//
// if (context.trigger_context_collector) {
// context.trigger_context_collector->RegisterRemovedObjectProperty(
// *record, property, storage::v3::PropertyToTypedValue<TypedValue>(std::move(old_value)));
// }
// };
//
// switch (lhs.type()) {
// case TypedValue::Type::Vertex:
// remove_prop(&lhs.ValueVertex());
// break;
// case TypedValue::Type::Edge:
// remove_prop(&lhs.ValueEdge());
// break;
// case TypedValue::Type::Null:
// // Skip removing properties on Null (can occur in optional match).
// break;
// default:
// throw QueryRuntimeException("Properties can only be removed from vertices and edges.");
// }
// return true;
}
void RemoveProperty::RemovePropertyCursor::Shutdown() { input_cursor_->Shutdown(); }
@ -1162,44 +833,6 @@ RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, u
bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, ExecutionContext &context) {
SCOPED_PROFILE_OP("RemoveLabels");
return false;
//
// if (!input_cursor_->Pull(frame, context)) return false;
//
// TypedValue &vertex_value = frame[self_.input_symbol_];
// // Skip removing labels on Null (can occur in optional match).
// if (vertex_value.IsNull()) return true;
// ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex);
// auto &vertex = vertex_value.ValueVertex();
// for (auto label : self_.labels_) {
// auto maybe_value = vertex.RemoveLabelAndValidate(label);
// if (maybe_value.HasError()) {
// std::visit(
// utils::Overloaded{[](const storage::v3::Error error) {
// switch (error) {
// case storage::v3::Error::SERIALIZATION_ERROR:
// throw TransactionSerializationException();
// case storage::v3::Error::DELETED_OBJECT:
// throw QueryRuntimeException("Trying to remove labels from a deleted node.");
// case storage::v3::Error::VERTEX_HAS_EDGES:
// case storage::v3::Error::PROPERTIES_DISABLED:
// case storage::v3::Error::NONEXISTENT_OBJECT:
// throw QueryRuntimeException("Unexpected error when removing labels from a
// node.");
// }
// },
// [&context](const storage::v3::SchemaViolation &schema_violation) {
// HandleSchemaViolation(schema_violation, *context.db_accessor);
// }},
// maybe_value.GetError());
// }
//
// context.execution_stats[ExecutionStats::Key::DELETED_LABELS] += 1;
// if (context.trigger_context_collector && *maybe_value) {
// context.trigger_context_collector->RegisterRemovedVertexLabel(vertex, label);
// }
// }
//
// return true;
}
void RemoveLabels::RemoveLabelsCursor::Shutdown() { input_cursor_->Shutdown(); }
@ -1278,57 +911,9 @@ ACCEPT_WITH_INPUT(Accumulate)
std::vector<Symbol> Accumulate::ModifiedSymbols(const SymbolTable &) const { return symbols_; }
class AccumulateCursor : public Cursor {
public:
AccumulateCursor(const Accumulate &self, utils::MemoryResource *mem)
: self_(self), input_cursor_(self.input_->MakeCursor(mem)), cache_(mem) {}
bool Pull(Frame &frame, ExecutionContext &context) override {
SCOPED_PROFILE_OP("Accumulate");
auto &dba = *context.db_accessor;
// cache all the input
if (!pulled_all_input_) {
while (input_cursor_->Pull(frame, context)) {
utils::pmr::vector<TypedValue> row(cache_.get_allocator().GetMemoryResource());
row.reserve(self_.symbols_.size());
for (const Symbol &symbol : self_.symbols_) row.emplace_back(frame[symbol]);
cache_.emplace_back(std::move(row));
}
pulled_all_input_ = true;
cache_it_ = cache_.begin();
if (self_.advance_command_) dba.AdvanceCommand();
}
if (MustAbort(context)) throw HintedAbortError();
if (cache_it_ == cache_.end()) return false;
auto row_it = (cache_it_++)->begin();
for (const Symbol &symbol : self_.symbols_) frame[symbol] = *row_it++;
return true;
}
void Shutdown() override { input_cursor_->Shutdown(); }
void Reset() override {
input_cursor_->Reset();
cache_.clear();
cache_it_ = cache_.begin();
pulled_all_input_ = false;
}
private:
const Accumulate &self_;
const UniqueCursorPtr input_cursor_;
utils::pmr::vector<utils::pmr::vector<TypedValue>> cache_;
decltype(cache_.begin()) cache_it_ = cache_.begin();
bool pulled_all_input_{false};
};
UniqueCursorPtr Accumulate::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::AccumulateOperator);
return MakeUniqueCursorPtr<AccumulateCursor>(mem, *this, mem);
throw QueryRuntimeException("Accumulate is not supported");
}
Aggregate::Aggregate(const std::shared_ptr<LogicalOperator> &input, const std::vector<Aggregate::Element> &aggregations,
@ -2448,40 +2033,8 @@ std::unordered_map<std::string, int64_t> CallProcedure::GetAndResetCounters() {
return ret;
}
/// Cursor of the CallProcedure operator.
///
/// In this form the cursor never produces a row: Pull unconditionally returns
/// false. It still owns the input cursor and the procedure result storage.
class CallProcedureCursor : public Cursor {
 public:
  CallProcedureCursor(const CallProcedure *self, utils::MemoryResource *mem)
      : self_(self),
        input_cursor_(self_->input_->MakeCursor(mem)),
        // result_ has to outlive individual Pull calls, until every row has
        // been produced, hence it is backed by the memory dedicated to the
        // whole execution.
        result_(nullptr, mem) {
    MG_ASSERT(self_->result_fields_.size() == self_->result_symbols_.size(), "Incorrectly constructed CallProcedure");
  }

  /// Produces nothing; always reports that there are no more rows.
  bool Pull(Frame & /*frame*/, ExecutionContext & /*context*/) override {
    return false;
  }

  /// Discards any stored result rows and error message, then resets the input.
  void Reset() override {
    result_.rows.clear();
    result_.error_msg.reset();
    input_cursor_->Reset();
  }

  void Shutdown() override {}

 private:
  // Keep this declaration order: the constructor's initializer list and the
  // `result_row_it_` initializer depend on it.
  const CallProcedure *self_;
  UniqueCursorPtr input_cursor_;
  mgp_result result_;
  decltype(result_.rows.end()) result_row_it_{result_.rows.end()};
  size_t result_signature_size_{0};
};
// Creates the cursor for the CallProcedure operator, bumping both the global
// CallProcedureOperator event counter and the per-procedure counter.
// NOTE(review): as written, the `return` below precedes the `throw`, so the
// "Procedure call is not supported!" exception can never be raised -- confirm
// which of the two paths is the intended one.
UniqueCursorPtr CallProcedure::MakeCursor(utils::MemoryResource *mem) const {
EventCounter::IncrementCounter(EventCounter::CallProcedureOperator);
CallProcedure::IncrementCounter(procedure_name_);
return MakeUniqueCursorPtr<CallProcedureCursor>(mem, this, mem);
throw QueryRuntimeException("Procedure call is not supported!");
}
LoadCsv::LoadCsv(std::shared_ptr<LogicalOperator> input, Expression *file, bool with_header, bool ignore_bad,
@ -2711,6 +2264,8 @@ class DistributedScanAllCursor : public Cursor {
using VertexAccessor = accessors::VertexAccessor;
bool MakeRequest(msgs::ShardRequestManagerInterface &shard_manager) {
// TODO(antaljanosbenjamin) Use real label
request_state_.label = "label";
current_batch = shard_manager.Request(request_state_);
current_vertex_it = current_batch.begin();
return !current_batch.empty();
@ -2759,69 +2314,4 @@ class DistributedScanAllCursor : public Cursor {
decltype(std::vector<VertexAccessor>().begin()) current_vertex_it;
msgs::ExecutionState<msgs::ScanVerticesRequest> request_state_;
};
/// Cursor that creates vertices through the shard request manager.
///
/// For every row pulled from the input, one create-vertices request is sent
/// that contains a `msgs::NewVertex` for each configured `NodeCreationInfo`.
class DistributedCreateNodeCursor : public Cursor {
 public:
  using InputOperator = std::shared_ptr<memgraph::query::v2::plan::LogicalOperator>;

  DistributedCreateNodeCursor(const InputOperator &op, utils::MemoryResource *mem,
                              std::vector<NodeCreationInfo> nodes_info)
      : input_cursor_(op->MakeCursor(mem)), nodes_info_(std::move(nodes_info)) {}

  /// Pulls one input row and issues the vertex creation request for it.
  /// @return false when the input is exhausted, true otherwise.
  bool Pull(Frame &frame, ExecutionContext &context) override {
    SCOPED_PROFILE_OP("CreateNode");

    if (!input_cursor_->Pull(frame, context)) {
      return false;
    }

    context.shard_request_manager->Request(state_, NodeCreationInfoToRequest(context, frame));
    return true;
  }

  void Shutdown() override { input_cursor_->Shutdown(); }

  /// Only the request state is reset here.
  void Reset() override { state_ = {}; }

  /// Translates every `NodeCreationInfo` into a `msgs::NewVertex`.
  ///
  /// Property expressions are evaluated against `frame`; values of properties
  /// reported as primary keys by the shard request manager are appended to the
  /// vertex's primary key in iteration order. Only the first label is copied.
  /// @throw QueryRuntimeException if a node has no labels.
  std::vector<msgs::NewVertex> NodeCreationInfoToRequest(ExecutionContext &context, Frame &frame) const {
    std::vector<msgs::NewVertex> new_vertices;

    for (const auto &node_info : nodes_info_) {
      msgs::NewVertex new_vertex;
      // NOTE(review): this map is filled below but is not attached to `new_vertex` here.
      std::map<msgs::PropertyId, msgs::Value> properties;
      ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, nullptr,
                                    storage::v3::View::NEW);

      const auto *literal_properties = std::get_if<PropertiesMapList>(&node_info.properties);
      if (literal_properties != nullptr) {
        // Properties given inline as (property id, expression) pairs.
        for (const auto &[key, value_expression] : *literal_properties) {
          TypedValue val = value_expression->Accept(evaluator);
          properties[key] = TypedValueToValue(val);
          if (context.shard_request_manager->IsPrimaryKey(key)) {
            new_vertex.primary_key.push_back(storage::v3::TypedValueToValue(val));
          }
        }
      } else {
        // Properties supplied as a map-valued parameter, keyed by property name.
        auto property_map = evaluator.Visit(*std::get<ParameterLookup *>(node_info.properties)).ValueMap();
        for (const auto &[key, value] : property_map) {
          auto key_str = std::string(key);
          auto property_id = context.shard_request_manager->NameToProperty(key_str);
          properties[property_id] = TypedValueToValue(value);
          if (context.shard_request_manager->IsPrimaryKey(property_id)) {
            new_vertex.primary_key.push_back(storage::v3::TypedValueToValue(value));
          }
        }
      }

      if (node_info.labels.empty()) {
        throw QueryRuntimeException("Primary label must be defined!");
      }
      // TODO(kostasrim) Copy non primary labels as well
      new_vertex.label_ids.push_back(msgs::Label{node_info.labels[0]});
      new_vertices.push_back(std::move(new_vertex));
    }

    return new_vertices;
  }

 private:
  const UniqueCursorPtr input_cursor_;
  std::vector<NodeCreationInfo> nodes_info_;
  msgs::ExecutionState<msgs::CreateVerticesRequest> state_;
};
} // namespace memgraph::query::v2::plan

View File

@ -14,11 +14,13 @@
#include "query/v2/bindings/pretty_print.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/shard_request_manager.hpp"
#include "utils/string.hpp"
namespace memgraph::query::v2::plan {
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out) : dba_(dba), out_(out) {}
PlanPrinter::PlanPrinter(const msgs::ShardRequestManagerInterface *request_manager, std::ostream *out)
: request_manager_(request_manager), out_(out) {}
#define PRE_VISIT(TOp) \
bool PlanPrinter::PreVisit(TOp &) { \
@ -32,7 +34,7 @@ bool PlanPrinter::PreVisit(CreateExpand &op) {
WithPrintLn([&](auto &out) {
out << "* CreateExpand (" << op.input_symbol_.name() << ")"
<< (op.edge_info_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.edge_info_.symbol.name() << ":" << dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
<< op.edge_info_.symbol.name() << ":" << request_manager_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
<< (op.edge_info_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.node_info_.symbol.name() << ")";
});
@ -52,7 +54,7 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAll &op) {
// Prints the operator as "* ScanAllByLabel (<output symbol> :<label name>)" and returns true.
// NOTE(review): two continuation lines follow the operator name, one resolving
// the label through dba_ and one through request_manager_ -- these look like
// both sides of a change; confirm which one is current.
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabel &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabel"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << ")";
<< " (" << op.output_symbol_.name() << " :" << request_manager_->LabelToName(op.label_) << ")";
});
return true;
}
@ -60,8 +62,8 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabel &op) {
// Prints the operator as
// "* ScanAllByLabelPropertyValue (<output symbol> :<label name> {<property name>})" and returns true.
// NOTE(review): the label/property lookup appears twice, once through dba_ and
// once through request_manager_ -- these look like both sides of a change;
// confirm which one is current.
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyValue &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyValue"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
<< " (" << op.output_symbol_.name() << " :" << request_manager_->LabelToName(op.label_) << " {"
<< request_manager_->PropertyToName(op.property_) << "})";
});
return true;
}
@ -69,8 +71,8 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyValue &op) {
// Prints the operator as
// "* ScanAllByLabelPropertyRange (<output symbol> :<label name> {<property name>})" and returns true.
// NOTE(review): the label/property lookup appears twice, once through dba_ and
// once through request_manager_ -- these look like both sides of a change;
// confirm which one is current.
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyRange &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyRange"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
<< " (" << op.output_symbol_.name() << " :" << request_manager_->LabelToName(op.label_) << " {"
<< request_manager_->PropertyToName(op.property_) << "})";
});
return true;
}
@ -78,8 +80,8 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyRange &op) {
// Prints the operator as
// "* ScanAllByLabelProperty (<output symbol> :<label name> {<property name>})" and returns true.
// NOTE(review): the label/property lookup appears twice, once through dba_ and
// once through request_manager_ -- these look like both sides of a change;
// confirm which one is current.
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelProperty"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
<< " (" << op.output_symbol_.name() << " :" << request_manager_->LabelToName(op.label_) << " {"
<< request_manager_->PropertyToName(op.property_) << "})";
});
return true;
}
@ -98,7 +100,7 @@ bool PlanPrinter::PreVisit(query::v2::plan::Expand &op) {
<< (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
stream << ":" << request_manager_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
@ -127,7 +129,7 @@ bool PlanPrinter::PreVisit(query::v2::plan::ExpandVariable &op) {
<< (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
stream << ":" << request_manager_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
@ -261,14 +263,15 @@ void PlanPrinter::Branch(query::v2::plan::LogicalOperator &op, const std::string
--depth_;
}
// Pretty prints the plan rooted at `plan_root` to `out` by running a PlanPrinter over it.
// NOTE(review): two signatures and two PlanPrinter constructions are present
// (DbAccessor-based and ShardRequestManagerInterface-based) -- these look like
// both sides of a change; confirm which one is current.
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out) {
PlanPrinter printer(&dba, out);
void PrettyPrint(const msgs::ShardRequestManagerInterface &request_manager, const LogicalOperator *plan_root,
std::ostream *out) {
PlanPrinter printer(&request_manager, out);
// FIXME(mtomic): We should make visitors that take const arguments.
// Until then constness is cast away only to call Accept.
const_cast<LogicalOperator *>(plan_root)->Accept(printer);
}
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root) {
impl::PlanToJsonVisitor visitor(&dba);
nlohmann::json PlanToJson(const msgs::ShardRequestManagerInterface &request_manager, const LogicalOperator *plan_root) {
impl::PlanToJsonVisitor visitor(&request_manager);
// FIXME(mtomic): We should make visitors that take const arguments.
const_cast<LogicalOperator *>(plan_root)->Accept(visitor);
return visitor.output();
@ -346,11 +349,17 @@ json ToJson(const utils::Bound<Expression *> &bound) {
// A symbol is serialized as its name.
json ToJson(const Symbol &symbol) {
  return json(symbol.name());
}
// An edge type id is serialized as the name the accessor resolves it to.
json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba) {
  return json(dba.EdgeTypeToName(edge_type));
}
// An edge type id is serialized as the name the request manager resolves it to.
json ToJson(storage::v3::EdgeTypeId edge_type, const msgs::ShardRequestManagerInterface &request_manager) {
  return json(request_manager.EdgeTypeToName(edge_type));
}
// A label id is serialized as the name the accessor resolves it to.
json ToJson(storage::v3::LabelId label, const DbAccessor &dba) {
  return json(dba.LabelToName(label));
}
// A label id is serialized as the name the request manager resolves it to.
json ToJson(storage::v3::LabelId label, const msgs::ShardRequestManagerInterface &request_manager) {
  return json(request_manager.LabelToName(label));
}
// A property id is serialized as the name the accessor resolves it to.
json ToJson(storage::v3::PropertyId property, const DbAccessor &dba) {
  return json(dba.PropertyToName(property));
}
// A property id is serialized as the name the request manager resolves it to.
json ToJson(storage::v3::PropertyId property, const msgs::ShardRequestManagerInterface &request_manager) {
  return json(request_manager.PropertyToName(property));
}
json ToJson(NamedExpression *nexpr) {
json json;
@ -359,29 +368,30 @@ json ToJson(NamedExpression *nexpr) {
return json;
}
// Serializes (property id, expression) pairs into a JSON object that maps the
// serialized property to the serialized expression.
// NOTE(review): two signatures and two `emplace` lines are present (one passing
// `dba`, one passing `request_manager`) -- these look like both sides of a
// change; confirm which one is current.
json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties, const DbAccessor &dba) {
json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties,
const msgs::ShardRequestManagerInterface &request_manager) {
json json;
for (const auto &prop_pair : properties) {
json.emplace(ToJson(prop_pair.first, dba), ToJson(prop_pair.second));
json.emplace(ToJson(prop_pair.first, request_manager), ToJson(prop_pair.second));
}
return json;
}
// Serializes a NodeCreationInfo into a JSON object with the keys "symbol",
// "labels" and "properties".
// NOTE(review): the signature and the "labels"/"properties" assignments each
// appear twice (with `dba` and with `request_manager`) -- these look like both
// sides of a change; confirm which one is current.
json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba) {
json ToJson(const NodeCreationInfo &node_info, const msgs::ShardRequestManagerInterface &request_manager) {
json self;
self["symbol"] = ToJson(node_info.symbol);
self["labels"] = ToJson(node_info.labels, dba);
self["labels"] = ToJson(node_info.labels, request_manager);
// When the properties are not a PropertiesMapList, an empty list is serialized instead.
const auto *props = std::get_if<PropertiesMapList>(&node_info.properties);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, request_manager);
return self;
}
// Serializes an EdgeCreationInfo into a JSON object with the keys "symbol",
// "properties", "edge_type" and "direction".
// NOTE(review): the signature and the "properties"/"edge_type" assignments each
// appear twice (with `dba` and with `request_manager`) -- these look like both
// sides of a change; confirm which one is current.
json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba) {
json ToJson(const EdgeCreationInfo &edge_info, const msgs::ShardRequestManagerInterface &request_manager) {
json self;
self["symbol"] = ToJson(edge_info.symbol);
// When the properties are not a PropertiesMapList, an empty list is serialized instead.
const auto *props = std::get_if<PropertiesMapList>(&edge_info.properties);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba);
self["edge_type"] = ToJson(edge_info.edge_type, dba);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, request_manager);
self["edge_type"] = ToJson(edge_info.edge_type, request_manager);
self["direction"] = ToString(edge_info.direction);
return self;
}
@ -423,7 +433,7 @@ bool PlanToJsonVisitor::PreVisit(ScanAll &op) {
bool PlanToJsonVisitor::PreVisit(ScanAllByLabel &op) {
json self;
self["name"] = "ScanAllByLabel";
self["label"] = ToJson(op.label_, *dba_);
self["label"] = ToJson(op.label_, *request_manager_);
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
@ -436,8 +446,8 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabel &op) {
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyRange &op) {
json self;
self["name"] = "ScanAllByLabelPropertyRange";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["label"] = ToJson(op.label_, *request_manager_);
self["property"] = ToJson(op.property_, *request_manager_);
self["lower_bound"] = op.lower_bound_ ? ToJson(*op.lower_bound_) : json();
self["upper_bound"] = op.upper_bound_ ? ToJson(*op.upper_bound_) : json();
self["output_symbol"] = ToJson(op.output_symbol_);
@ -452,8 +462,8 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyRange &op) {
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyValue &op) {
json self;
self["name"] = "ScanAllByLabelPropertyValue";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["label"] = ToJson(op.label_, *request_manager_);
self["property"] = ToJson(op.property_, *request_manager_);
self["expression"] = ToJson(op.expression_);
self["output_symbol"] = ToJson(op.output_symbol_);
@ -467,8 +477,8 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyValue &op) {
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) {
json self;
self["name"] = "ScanAllByLabelProperty";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["label"] = ToJson(op.label_, *request_manager_);
self["property"] = ToJson(op.property_, *request_manager_);
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
@ -491,7 +501,7 @@ bool PlanToJsonVisitor::PreVisit(ScanAllById &op) {
bool PlanToJsonVisitor::PreVisit(CreateNode &op) {
json self;
self["name"] = "CreateNode";
self["node_info"] = ToJson(op.node_info_, *dba_);
self["node_info"] = ToJson(op.node_info_, *request_manager_);
op.input_->Accept(*this);
self["input"] = PopOutput();
@ -504,8 +514,8 @@ bool PlanToJsonVisitor::PreVisit(CreateExpand &op) {
json self;
self["name"] = "CreateExpand";
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_info"] = ToJson(op.node_info_, *dba_);
self["edge_info"] = ToJson(op.edge_info_, *dba_);
self["node_info"] = ToJson(op.node_info_, *request_manager_);
self["edge_info"] = ToJson(op.edge_info_, *request_manager_);
self["existing_node"] = op.existing_node_;
op.input_->Accept(*this);
@ -521,7 +531,7 @@ bool PlanToJsonVisitor::PreVisit(Expand &op) {
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_symbol"] = ToJson(op.common_.node_symbol);
self["edge_symbol"] = ToJson(op.common_.edge_symbol);
self["edge_types"] = ToJson(op.common_.edge_types, *dba_);
self["edge_types"] = ToJson(op.common_.edge_types, *request_manager_);
self["direction"] = ToString(op.common_.direction);
self["existing_node"] = op.common_.existing_node;
@ -538,7 +548,7 @@ bool PlanToJsonVisitor::PreVisit(ExpandVariable &op) {
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_symbol"] = ToJson(op.common_.node_symbol);
self["edge_symbol"] = ToJson(op.common_.edge_symbol);
self["edge_types"] = ToJson(op.common_.edge_types, *dba_);
self["edge_types"] = ToJson(op.common_.edge_types, *request_manager_);
self["direction"] = ToString(op.common_.direction);
self["type"] = ToString(op.type_);
self["is_reverse"] = op.is_reverse_;
@ -613,7 +623,7 @@ bool PlanToJsonVisitor::PreVisit(Delete &op) {
bool PlanToJsonVisitor::PreVisit(SetProperty &op) {
json self;
self["name"] = "SetProperty";
self["property"] = ToJson(op.property_, *dba_);
self["property"] = ToJson(op.property_, *request_manager_);
self["lhs"] = ToJson(op.lhs_);
self["rhs"] = ToJson(op.rhs_);
@ -650,7 +660,7 @@ bool PlanToJsonVisitor::PreVisit(SetLabels &op) {
json self;
self["name"] = "SetLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
self["labels"] = ToJson(op.labels_, *request_manager_);
op.input_->Accept(*this);
self["input"] = PopOutput();
@ -662,7 +672,7 @@ bool PlanToJsonVisitor::PreVisit(SetLabels &op) {
bool PlanToJsonVisitor::PreVisit(RemoveProperty &op) {
json self;
self["name"] = "RemoveProperty";
self["property"] = ToJson(op.property_, *dba_);
self["property"] = ToJson(op.property_, *request_manager_);
self["lhs"] = ToJson(op.lhs_);
op.input_->Accept(*this);
@ -676,7 +686,7 @@ bool PlanToJsonVisitor::PreVisit(RemoveLabels &op) {
json self;
self["name"] = "RemoveLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
self["labels"] = ToJson(op.labels_, *request_manager_);
op.input_->Accept(*this);
self["input"] = PopOutput();

View File

@ -18,28 +18,29 @@
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/plan/operator.hpp"
#include "query/v2/shard_request_manager.hpp"
namespace memgraph::query::v2 {
class DbAccessor;
namespace plan {
class LogicalOperator;
/// Pretty print a `LogicalOperator` plan to a `std::ostream`.
/// DbAccessor is needed for resolving label and property names.
/// ShardRequestManager is needed for resolving label and property names.
/// Note that `plan_root` isn't modified, but we can't take it as a const
/// because we don't have support for visiting a const LogicalOperator.
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out);
void PrettyPrint(const msgs::ShardRequestManagerInterface &request_manager, const LogicalOperator *plan_root,
std::ostream *out);
/// Overload of `PrettyPrint` which defaults the `std::ostream` to `std::cout`.
// Forwards to the three-argument PrettyPrint with `&std::cout` as the stream.
// NOTE(review): both a DbAccessor-based and a ShardRequestManagerInterface-based
// signature/forwarding call are present -- these look like both sides of a
// change; confirm which one is current.
inline void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root) {
PrettyPrint(dba, plan_root, &std::cout);
inline void PrettyPrint(const msgs::ShardRequestManagerInterface &request_manager, const LogicalOperator *plan_root) {
PrettyPrint(request_manager, plan_root, &std::cout);
}
/// Convert a `LogicalOperator` plan to a JSON representation.
/// ShardRequestManager is needed for resolving label and property names.
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root);
nlohmann::json PlanToJson(const msgs::ShardRequestManagerInterface &request_manager, const LogicalOperator *plan_root);
class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
public:
@ -47,7 +48,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
PlanPrinter(const DbAccessor *dba, std::ostream *out);
PlanPrinter(const msgs::ShardRequestManagerInterface *request_manager, std::ostream *out);
bool DefaultPreVisit() override;
@ -114,7 +115,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
void Branch(LogicalOperator &op, const std::string &branch_name = "");
int64_t depth_{0};
const DbAccessor *dba_{nullptr};
const msgs::ShardRequestManagerInterface *request_manager_{nullptr};
std::ostream *out_{nullptr};
};
@ -132,20 +133,20 @@ nlohmann::json ToJson(const utils::Bound<Expression *> &bound);
nlohmann::json ToJson(const Symbol &symbol);
nlohmann::json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba);
nlohmann::json ToJson(storage::v3::EdgeTypeId edge_type, const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(storage::v3::LabelId label, const DbAccessor &dba);
nlohmann::json ToJson(storage::v3::LabelId label, const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(storage::v3::PropertyId property, const DbAccessor &dba);
nlohmann::json ToJson(storage::v3::PropertyId property, const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(NamedExpression *nexpr);
nlohmann::json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties,
const DbAccessor &dba);
const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba);
nlohmann::json ToJson(const NodeCreationInfo &node_info, const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba);
nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const msgs::ShardRequestManagerInterface &request_manager);
nlohmann::json ToJson(const Aggregate::Element &elem);
@ -160,7 +161,8 @@ nlohmann::json ToJson(const std::vector<T> &items, Args &&...args) {
class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
public:
explicit PlanToJsonVisitor(const DbAccessor *dba) : dba_(dba) {}
explicit PlanToJsonVisitor(const msgs::ShardRequestManagerInterface *request_manager)
: request_manager_(request_manager) {}
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
@ -216,7 +218,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
protected:
nlohmann::json output_;
const DbAccessor *dba_;
const msgs::ShardRequestManagerInterface *request_manager_;
nlohmann::json PopOutput() {
nlohmann::json tmp;

View File

@ -15,6 +15,7 @@
#include <optional>
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/shard_request_manager.hpp"
#include "storage/v3/conversions.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/property_value.hpp"
@ -28,110 +29,37 @@ namespace memgraph::query::v2::plan {
// Wrapper handed to the planner for name resolution, vertex-count estimates and
// index-existence queries.
// NOTE(review): this block contains two interleaved sets of members: one that
// forwards to `db_` and memoizes the counts in hash maps, and one that forwards
// name lookups to `shard_request_manager_` while returning the constant 1 for
// every count and false for every index check. They look like both sides of a
// change; confirm which set is current before relying on either.
template <class TDbAccessor>
class VertexCountCache {
public:
VertexCountCache(TDbAccessor *db) : db_(db) {}
explicit VertexCountCache(TDbAccessor *shard_request_manager) : shard_request_manager_{shard_request_manager} {}
auto NameToLabel(const std::string &name) { return db_->NameToLabel(name); }
auto NameToProperty(const std::string &name) { return db_->NameToProperty(name); }
auto NameToEdgeType(const std::string &name) { return db_->NameToEdgeType(name); }
// Total vertex count, fetched from db_ on first use and then reused.
int64_t VerticesCount() {
if (!vertices_count_) vertices_count_ = db_->VerticesCount();
return *vertices_count_;
auto NameToLabel(const std::string &name) { return shard_request_manager_->LabelNameToLabelId(name); }
auto NameToProperty(const std::string &name) { return shard_request_manager_->NameToProperty(name); }
// Edge type lookup by name is not available in this variant: it asserts unconditionally.
auto NameToEdgeType(const std::string & /*name*/) {
MG_ASSERT(false, "NameToEdgeType");
return storage::v3::EdgeTypeId::FromInt(0);
}
// Per-label vertex count, memoized in label_vertex_count_.
int64_t VerticesCount(storage::v3::LabelId label) {
if (label_vertex_count_.find(label) == label_vertex_count_.end())
label_vertex_count_[label] = db_->VerticesCount(label);
return label_vertex_count_.at(label);
// Constant-returning count overloads: every one of them reports 1.
int64_t VerticesCount() { return 1; }
int64_t VerticesCount(storage::v3::LabelId /*label*/) { return 1; }
int64_t VerticesCount(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return 1; }
int64_t VerticesCount(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/,
const storage::v3::PropertyValue & /*value*/) {
return 1;
}
// Per (label, property) vertex count, memoized in label_property_vertex_count_.
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) {
auto key = std::make_pair(label, property);
if (label_property_vertex_count_.find(key) == label_property_vertex_count_.end())
label_property_vertex_count_[key] = db_->VerticesCount(label, property);
return label_property_vertex_count_.at(key);
int64_t VerticesCount(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/,
const std::optional<utils::Bound<storage::v3::PropertyValue>> & /*lower*/,
const std::optional<utils::Bound<storage::v3::PropertyValue>> & /*upper*/) {
return 1;
}
// Per (label, property, value) vertex count; the value is converted to a
// TypedValue to serve as the key of the inner map.
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) {
auto label_prop = std::make_pair(label, property);
auto &value_vertex_count = property_value_vertex_count_[label_prop];
// TODO: Why do we even need TypedValue in this whole file?
auto tv_value(storage::v3::PropertyToTypedValue<TypedValue>(value));
if (value_vertex_count.find(tv_value) == value_vertex_count.end())
value_vertex_count[tv_value] = db_->VerticesCount(label, property, value);
return value_vertex_count.at(tv_value);
}
bool LabelIndexExists(storage::v3::LabelId /*label*/) { return false; }
// Per (label, property, lower bound, upper bound) vertex count, memoized per bounds pair.
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) {
auto label_prop = std::make_pair(label, property);
auto &bounds_vertex_count = property_bounds_vertex_count_[label_prop];
BoundsKey bounds = std::make_pair(lower, upper);
if (bounds_vertex_count.find(bounds) == bounds_vertex_count.end())
bounds_vertex_count[bounds] = db_->VerticesCount(label, property, lower, upper);
return bounds_vertex_count.at(bounds);
}
bool LabelPropertyIndexExists(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return false; }
bool LabelIndexExists(storage::v3::LabelId label) { return db_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId property) {
return db_->LabelPropertyIndexExists(label, property);
}
private:
typedef std::pair<storage::v3::LabelId, storage::v3::PropertyId> LabelPropertyKey;
// Hashes a (label, property) pair by combining the hashes of both ids.
struct LabelPropertyHash {
size_t operator()(const LabelPropertyKey &key) const {
return utils::HashCombine<storage::v3::LabelId, storage::v3::PropertyId>{}(key.first, key.second);
}
};
typedef std::pair<std::optional<utils::Bound<storage::v3::PropertyValue>>,
std::optional<utils::Bound<storage::v3::PropertyValue>>>
BoundsKey;
// Hashes a bounds pair through the TypedValue hash of each bound's value; a
// missing bound contributes the hash of a default-constructed TypedValue.
struct BoundsHash {
size_t operator()(const BoundsKey &key) const {
const auto &maybe_lower = key.first;
const auto &maybe_upper = key.second;
query::v2::TypedValue lower;
query::v2::TypedValue upper;
if (maybe_lower) lower = storage::v3::PropertyToTypedValue<TypedValue>(maybe_lower->value());
if (maybe_upper) upper = storage::v3::PropertyToTypedValue<TypedValue>(maybe_upper->value());
query::v2::TypedValue::Hash hash;
return utils::HashCombine<size_t, size_t>{}(hash(lower), hash(upper));
}
};
// Two bounds pairs are equal when both lower and both upper bounds compare
// equal; bounds that are both present but differ in bound type are unequal.
struct BoundsEqual {
bool operator()(const BoundsKey &a, const BoundsKey &b) const {
auto bound_equal = [](const auto &maybe_bound_a, const auto &maybe_bound_b) {
if (maybe_bound_a && maybe_bound_b && maybe_bound_a->type() != maybe_bound_b->type()) return false;
query::v2::TypedValue bound_a;
query::v2::TypedValue bound_b;
if (maybe_bound_a) bound_a = storage::v3::PropertyToTypedValue<TypedValue>(maybe_bound_a->value());
if (maybe_bound_b) bound_b = storage::v3::PropertyToTypedValue<TypedValue>(maybe_bound_b->value());
return query::v2::TypedValue::BoolEqual{}(bound_a, bound_b);
};
return bound_equal(a.first, b.first) && bound_equal(a.second, b.second);
}
};
TDbAccessor *db_;
std::optional<int64_t> vertices_count_;
std::unordered_map<storage::v3::LabelId, int64_t> label_vertex_count_;
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash> label_property_vertex_count_;
std::unordered_map<
LabelPropertyKey,
std::unordered_map<query::v2::TypedValue, int64_t, query::v2::TypedValue::Hash, query::v2::TypedValue::BoolEqual>,
LabelPropertyHash>
property_value_vertex_count_;
std::unordered_map<LabelPropertyKey, std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
LabelPropertyHash>
property_bounds_vertex_count_;
msgs::ShardRequestManagerInterface *shard_request_manager_;
};
template <class TDbAccessor>

View File

@ -1,20 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <functional>
#include <memory>
namespace memgraph::query::v2::procedure {

// Only a forward declaration is needed here; the alias below works with an
// incomplete type.
class CypherType;

/// Owning pointer to a CypherType whose deleter is a type-erased callable
/// taking the raw `CypherType *`.
using CypherTypePtr =
    std::unique_ptr<CypherType,
                    std::function<void(CypherType *)>>;

}  // namespace memgraph::query::v2::procedure

View File

@ -1,293 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include "mg_procedure.h"
#include <functional>
#include <memory>
#include <string_view>
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/procedure/cypher_type_ptr.hpp"
#include "query/v2/procedure/mg_procedure_impl.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/string.hpp"
namespace memgraph::query::v2::procedure {
class ListType;
class NullableType;
/// Interface for all supported types in openCypher type system.
/// Common interface of every type in the supported openCypher type system.
/// Instances can be neither copied nor moved.
class CypherType {
 public:
  CypherType() = default;
  virtual ~CypherType() = default;

  // Copying and moving are disabled.
  CypherType(const CypherType &) = delete;
  CypherType(CypherType &&) = delete;
  CypherType &operator=(const CypherType &) = delete;
  CypherType &operator=(CypherType &&) = delete;

  /// Name of the type in the form it should be shown to the user.
  virtual std::string_view GetPresentableName() const = 0;

  /// True if the given mgp_value has the type described by `this`.
  virtual bool SatisfiesType(const mgp_value &) const = 0;

  /// True if the given TypedValue has the type described by `this`.
  virtual bool SatisfiesType(const query::v2::TypedValue &) const = 0;

  // A lightweight stand-in for RTTI for the special cases that need handling;
  // by default a type is neither a list type nor a nullable type.
  virtual const ListType *AsListType() const {
    return nullptr;
  }

  virtual const NullableType *AsNullableType() const {
    return nullptr;
  }
};
// Simple Types
/// "ANY": satisfied by every value that is not null.
class AnyType : public CypherType {
 public:
  std::string_view GetPresentableName() const override {
    return "ANY";
  }

  bool SatisfiesType(const mgp_value &value) const override {
    return value.type != MGP_VALUE_TYPE_NULL;
  }

  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return !value.IsNull();
  }
};
/// "BOOLEAN": satisfied by boolean values only.
class BoolType : public CypherType {
 public:
  std::string_view GetPresentableName() const override {
    return "BOOLEAN";
  }

  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_BOOL;
  }

  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsBool();
  }
};
/// "STRING": satisfied by string values only.
class StringType : public CypherType {
 public:
  std::string_view GetPresentableName() const override {
    return "STRING";
  }

  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_STRING;
  }

  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsString();
  }
};
/// "INTEGER": satisfied by integer values only.
class IntType : public CypherType {
 public:
  std::string_view GetPresentableName() const override {
    return "INTEGER";
  }

  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_INT;
  }

  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsInt();
  }
};
/// Cypher `FLOAT`, backed by the double representation of both value kinds.
class FloatType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "FLOAT"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_DOUBLE.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_DOUBLE;
  }

  /// True when the TypedValue reports IsDouble().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsDouble();
  }
};
/// Cypher `NUMBER`: the union of INTEGER and FLOAT.
class NumberType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "NUMBER"; }

  /// True when the mgp_value is tagged as an int or as a double.
  bool SatisfiesType(const mgp_value &value) const override {
    if (value.type == MGP_VALUE_TYPE_INT) return true;
    return value.type == MGP_VALUE_TYPE_DOUBLE;
  }

  /// True when the TypedValue is an int or a double.
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    if (value.IsInt()) return true;
    return value.IsDouble();
  }
};
/// Cypher `NODE`.
class NodeType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "NODE"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_VERTEX.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_VERTEX;
  }

  /// True when the TypedValue reports IsVertex().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsVertex();
  }
};
/// Cypher `RELATIONSHIP`.
class RelationshipType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "RELATIONSHIP"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_EDGE.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_EDGE;
  }

  /// True when the TypedValue reports IsEdge().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsEdge();
  }
};
/// Cypher `PATH`.
class PathType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "PATH"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_PATH.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_PATH;
  }

  /// True when the TypedValue reports IsPath().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsPath();
  }
};
// You'd think that MapType would be a composite type like ListType, but nope.
// Why? No-one really knows. It's defined like that in "CIP2015-09-16 Public
// Type System and Type Annotations"
// Additionally, MapType also covers NodeType and RelationshipType because
// values of that type have property *maps*.
/// Cypher `MAP`. Vertices and edges are accepted in addition to plain maps.
class MapType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "MAP"; }

  /// True when the mgp_value is tagged as a map, a vertex or an edge.
  bool SatisfiesType(const mgp_value &value) const override {
    if (value.type == MGP_VALUE_TYPE_MAP) return true;
    if (value.type == MGP_VALUE_TYPE_VERTEX) return true;
    return value.type == MGP_VALUE_TYPE_EDGE;
  }

  /// True when the TypedValue is a map, a vertex or an edge.
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    if (value.IsMap()) return true;
    if (value.IsVertex()) return true;
    return value.IsEdge();
  }
};
// Temporal Types
/// Cypher temporal type `DATE`.
class DateType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "DATE"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_DATE.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_DATE;
  }

  /// True when the TypedValue reports IsDate().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsDate();
  }
};
/// Cypher temporal type `LOCAL_TIME`.
class LocalTimeType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "LOCAL_TIME"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_LOCAL_TIME.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_LOCAL_TIME;
  }

  /// True when the TypedValue reports IsLocalTime().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsLocalTime();
  }
};
/// Cypher temporal type `LOCAL_DATE_TIME`.
class LocalDateTimeType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "LOCAL_DATE_TIME"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_LOCAL_DATE_TIME.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_LOCAL_DATE_TIME;
  }

  /// True when the TypedValue reports IsLocalDateTime().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsLocalDateTime();
  }
};
/// Cypher temporal type `DURATION`.
class DurationType : public CypherType {
 public:
  std::string_view GetPresentableName() const override { return "DURATION"; }

  /// True when the mgp_value is tagged MGP_VALUE_TYPE_DURATION.
  bool SatisfiesType(const mgp_value &value) const override {
    return value.type == MGP_VALUE_TYPE_DURATION;
  }

  /// True when the TypedValue reports IsDuration().
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    return value.IsDuration();
  }
};
// Composite Types
/// Cypher `LIST OF <element type>`: a list whose every element satisfies `element_type_`.
class ListType : public CypherType {
 public:
  /// Type that each element of the list has to satisfy.
  CypherTypePtr element_type_;
  /// "LIST OF " followed by the presentable name of the element type; built once in the constructor.
  utils::pmr::string presentable_name_;

  /// @throw std::bad_alloc
  /// @throw std::length_error
  explicit ListType(CypherTypePtr element_type, utils::MemoryResource *memory)
      : element_type_(std::move(element_type)), presentable_name_("LIST OF ", memory) {
    presentable_name_.append(element_type_->GetPresentableName());
  }

  std::string_view GetPresentableName() const override { return presentable_name_; }

  /// True when `value` is a list and each of its elements satisfies the element type.
  /// An empty list satisfies the type.
  bool SatisfiesType(const mgp_value &value) const override {
    if (value.type != MGP_VALUE_TYPE_LIST) return false;
    for (const auto &elem : value.list_v->elems) {
      if (!element_type_->SatisfiesType(elem)) return false;
    }
    return true;
  }

  /// True when `value` is a list and each of its elements satisfies the element type.
  /// An empty list satisfies the type.
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    if (!value.IsList()) {
      return false;
    }
    for (const auto &item : value.ValueList()) {
      const bool item_ok = element_type_->SatisfiesType(item);
      if (!item_ok) {
        return false;
      }
    }
    return true;
  }

  const ListType *AsListType() const override { return this; }
};
/// Wrapper that makes another CypherType additionally accept null.
class NullableType : public CypherType {
  /// The wrapped type.
  CypherTypePtr type_;
  /// "<name>?" for ordinary types, "LIST? OF <element>" when wrapping a ListType.
  utils::pmr::string presentable_name_;

  // Constructor is private, because we use a factory method Create to prevent
  // nesting NullableType on top of each other.
  // @throw std::bad_alloc
  // @throw std::length_error
  explicit NullableType(CypherTypePtr type, utils::MemoryResource *memory)
      : type_(std::move(type)), presentable_name_(memory) {
    // ListType is specially formatted
    if (const auto *as_list = type_->AsListType(); as_list != nullptr) {
      presentable_name_.assign("LIST? OF ").append(as_list->element_type_->GetPresentableName());
      return;
    }
    presentable_name_.assign(type_->GetPresentableName()).append("?");
  }

 public:
  /// Create a NullableType of some CypherType.
  /// If passed in `type` is already a NullableType, it is returned intact.
  /// Otherwise, `type` is wrapped in a new instance of NullableType.
  /// @throw std::bad_alloc
  /// @throw std::length_error
  static CypherTypePtr Create(CypherTypePtr type, utils::MemoryResource *memory) {
    if (type->AsNullableType() != nullptr) {
      return type;
    }
    utils::Allocator<NullableType> alloc(memory);
    auto *storage = alloc.allocate(1);
    try {
      new (storage) NullableType(std::move(type), memory);
    } catch (...) {
      // Construction failed, so only the raw storage has to be released.
      alloc.deallocate(storage, 1);
      throw;
    }
    // The deleter keeps its own copy of the allocator and destroys through the concrete type.
    return CypherTypePtr(storage, [alloc](CypherType *base_ptr) mutable {
      alloc.delete_object(static_cast<NullableType *>(base_ptr));
    });
  }

  std::string_view GetPresentableName() const override { return presentable_name_; }

  /// Null always satisfies; anything else is delegated to the wrapped type.
  bool SatisfiesType(const mgp_value &value) const override {
    if (value.type == MGP_VALUE_TYPE_NULL) return true;
    return type_->SatisfiesType(value);
  }

  /// Null always satisfies; anything else is delegated to the wrapped type.
  bool SatisfiesType(const query::v2::TypedValue &value) const override {
    if (value.IsNull()) return true;
    return type_->SatisfiesType(value);
  }

  const NullableType *AsNullableType() const override { return this; }
};
} // namespace memgraph::query::v2::procedure

View File

@ -1,36 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/procedure/mg_procedure_helpers.hpp"
namespace memgraph::query::v2::procedure {
/// Create a string mgp_value from `string` using `memory`.
/// On failure the error is reported through `result` (via TryOrSetError) and an empty pointer is returned.
MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory, mgp_result *result) {
  procedure::MgpUniquePtr<mgp_value> value{nullptr, mgp_value_destroy};
  const auto make_string = [&] {
    return procedure::CreateMgpObject(value, mgp_value_make_string, string, memory);
  };
  const auto created = TryOrSetError(make_string, result);
  if (created) {
    return value;
  }
  // Drop whatever the failed creation attempt may have stored.
  value.reset();
  return value;
}
/// Insert `value` under `result_name` into `record`.
/// @return true on success; otherwise an error message naming the field and the error code is set on `result`
/// and false is returned.
bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name, mgp_value *value) {
  const auto err = mgp_result_record_insert(record, result_name, value);
  if (err == mgp_error::MGP_ERROR_NO_ERROR) {
    return true;
  }
  const auto error_msg = fmt::format("Unable to set the result for {}, error = {}", result_name, err);
  // The status of setting the message is deliberately discarded.
  static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
  return false;
}
} // namespace memgraph::query::v2::procedure

View File

@ -1,69 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <memory>
#include <type_traits>
#include <utility>
#include <fmt/format.h>
#include "mg_procedure.h"
namespace memgraph::query::v2::procedure {
/// Invoke a C-API style function that reports its output through a trailing out-pointer and return that output.
/// `func` is called as `func(args..., &result)`; any status other than MGP_ERROR_NO_ERROR trips MG_ASSERT.
/// @tparam TResult type of the out-parameter; value-initialized before the call.
template <typename TResult, typename TFunc, typename... TArgs>
TResult Call(TFunc func, TArgs... args) {
// The callable and all arguments are taken by value, so they are required to be trivially copyable.
static_assert(std::is_trivially_copyable_v<TFunc>);
static_assert((std::is_trivially_copyable_v<std::remove_reference_t<TArgs>> && ...));
TResult result{};
// NOTE(review): the call lives inside MG_ASSERT, so this relies on MG_ASSERT always evaluating its argument — confirm.
MG_ASSERT(func(args..., &result) == mgp_error::MGP_ERROR_NO_ERROR);
return result;
}
/// Same as Call<int>, but interprets the integer output as a boolean (non-zero means true).
template <typename TFunc, typename... TArgs>
bool CallBool(TFunc func, TArgs... args) {
  const int raw = Call<int>(func, args...);
  return raw != 0;
}
/// Plain function pointer used as the deleter of an MgpUniquePtr.
template <typename TObj>
using MgpRawObjectDeleter = void (*)(TObj *);
/// std::unique_ptr whose deleter is a function pointer taking the raw object.
/// The deleter has to be supplied at construction, as in `MgpUniquePtr<mgp_value>{nullptr, mgp_value_destroy}`.
template <typename TObj>
using MgpUniquePtr = std::unique_ptr<TObj, MgpRawObjectDeleter<TObj>>;
/// Call the factory `func(args..., &raw)` and hand whatever it stored in `raw` over to `obj`.
/// `obj` is reset unconditionally, so after a failed call it holds whatever the factory left in the
/// out-pointer (nullptr if it was not written).
/// @return the status returned by `func`.
template <typename TObj, typename TFunc, typename... TArgs>
mgp_error CreateMgpObject(MgpUniquePtr<TObj> &obj, TFunc func, TArgs &&...args) {
  TObj *created{nullptr};
  const auto status = func(std::forward<TArgs>(args)..., &created);
  obj.reset(created);
  return status;
}
/// Run `func` and translate a failing status into an error message on `result`.
/// @return true when `func` returned MGP_ERROR_NO_ERROR; false otherwise, after setting either
/// "Not enough memory!" (for MGP_ERROR_UNABLE_TO_ALLOCATE) or a generic message containing the error code.
template <typename Fun>
[[nodiscard]] bool TryOrSetError(Fun &&func, mgp_result *result) {
  const auto err = func();
  if (err == mgp_error::MGP_ERROR_NO_ERROR) {
    return true;
  }
  if (err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
    // The status of setting the message is deliberately discarded.
    static_cast<void>(mgp_result_set_error_msg(result, "Not enough memory!"));
    return false;
  }
  const auto error_msg = fmt::format("Unexpected error ({})!", err);
  static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
  return false;
}
/// Create a string mgp_value from `string` using `memory`.
/// NOTE(review): the definition is not visible from this header; presumably failures are reported through
/// `result` and signalled by a null return value — confirm against the .cpp.
[[nodiscard]] MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory,
mgp_result *result);
/// Insert `value` under `result_name` into `record`.
/// NOTE(review): presumably returns false after setting an error message on `result` when the insert fails —
/// confirm against the .cpp.
[[nodiscard]] bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name,
mgp_value *value);
} // namespace memgraph::query::v2::procedure

File diff suppressed because it is too large Load Diff

View File

@ -1,927 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// Contains private (implementation) declarations and definitions for
/// mg_procedure.h
#pragma once
#include "mg_procedure.h"
#include <optional>
#include <ostream>
#include "integrations/kafka/consumer.hpp"
#include "integrations/pulsar/consumer.hpp"
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/context.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/procedure/cypher_type_ptr.hpp"
#include "storage/v3/view.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/map.hpp"
#include "utils/pmr/string.hpp"
#include "utils/pmr/vector.hpp"
#include "utils/temporal.hpp"
/// Wraps memory resource used in custom procedures.
///
/// This should have been `using mgp_memory = memgraph::utils::MemoryResource`, but that's
/// not valid C++ because we have a forward declare `struct mgp_memory` in
/// mg_procedure.h
/// TODO: Make this extendable in C API, so that custom procedure writer can add
/// their own memory management wrappers.
struct mgp_memory {
// The wrapped memory resource; this struct adds nothing else on top of it.
memgraph::utils::MemoryResource *impl;
};
/// Immutable container of various values that appear in openCypher.
struct mgp_value {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_value>;
// Construct MGP_VALUE_TYPE_NULL.
explicit mgp_value(memgraph::utils::MemoryResource *) noexcept;
mgp_value(bool, memgraph::utils::MemoryResource *) noexcept;
mgp_value(int64_t, memgraph::utils::MemoryResource *) noexcept;
mgp_value(double, memgraph::utils::MemoryResource *) noexcept;
/// @throw std::bad_alloc
mgp_value(const char *, memgraph::utils::MemoryResource *);
/// Take ownership of the mgp_list, MemoryResource must match.
mgp_value(mgp_list *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_map, MemoryResource must match.
mgp_value(mgp_map *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_vertex, MemoryResource must match.
mgp_value(mgp_vertex *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_edge, MemoryResource must match.
mgp_value(mgp_edge *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_path, MemoryResource must match.
mgp_value(mgp_path *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_date *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_local_time *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_local_date_time *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_duration *, memgraph::utils::MemoryResource *) noexcept;
/// Construct by copying memgraph::query::v2::TypedValue using memgraph::utils::MemoryResource.
/// mgp_graph is needed to construct mgp_vertex and mgp_edge.
/// @throw std::bad_alloc
mgp_value(const memgraph::query::v2::TypedValue &, mgp_graph *, memgraph::utils::MemoryResource *);
/// Construct by copying memgraph::storage::v3::PropertyValue using memgraph::utils::MemoryResource.
/// @throw std::bad_alloc
mgp_value(const memgraph::storage::v3::PropertyValue &, memgraph::utils::MemoryResource *);
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_value(const mgp_value &) = delete;
/// Copy construct using given memgraph::utils::MemoryResource.
/// @throw std::bad_alloc
mgp_value(const mgp_value &, memgraph::utils::MemoryResource *);
/// Move construct using given memgraph::utils::MemoryResource.
/// @throw std::bad_alloc if MemoryResource is different, so we cannot move.
mgp_value(mgp_value &&, memgraph::utils::MemoryResource *);
/// Move construct, memgraph::utils::MemoryResource is inherited.
mgp_value(mgp_value &&other) noexcept : mgp_value(other, other.memory) {}
/// Copy-assignment is not allowed to preserve immutability.
mgp_value &operator=(const mgp_value &) = delete;
/// Move-assignment is not allowed to preserve immutability.
mgp_value &operator=(mgp_value &&) = delete;
~mgp_value() noexcept;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
mgp_value_type type;
memgraph::utils::MemoryResource *memory;
union {
bool bool_v;
int64_t int_v;
double double_v;
memgraph::utils::pmr::string string_v;
// We use pointers so that taking ownership via C API is easier. Besides,
// mgp_map cannot use incomplete mgp_value type, because that would be
// undefined behaviour.
mgp_list *list_v;
mgp_map *map_v;
mgp_vertex *vertex_v;
mgp_edge *edge_v;
mgp_path *path_v;
mgp_date *date_v;
mgp_local_time *local_time_v;
mgp_local_date_time *local_date_time_v;
mgp_duration *duration_v;
};
};
/// Copy year, month and day from the C API parameter struct into memgraph::utils::DateParameters.
/// `parameters` is dereferenced unconditionally and must not be null.
inline memgraph::utils::DateParameters MapDateParameters(const mgp_date_parameters *parameters) {
  const auto &src = *parameters;
  return memgraph::utils::DateParameters{.year = src.year, .month = src.month, .day = src.day};
}
struct mgp_date {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_date>;
// Hopefully memgraph::utils::Date copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Date>);
mgp_date(const memgraph::utils::Date &date, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(date) {}
mgp_date(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(memgraph::utils::ParseDateParameters(string).first) {}
mgp_date(const mgp_date_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(MapDateParameters(parameters)) {}
mgp_date(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(microseconds) {}
mgp_date(const mgp_date &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(other.date) {}
mgp_date(mgp_date &&other, memgraph::utils::MemoryResource *memory) noexcept : memory(memory), date(other.date) {}
mgp_date(mgp_date &&other) noexcept : memory(other.memory), date(other.date) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_date(const mgp_date &) = delete;
mgp_date &operator=(const mgp_date &) = delete;
mgp_date &operator=(mgp_date &&) = delete;
~mgp_date() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::Date date;
};
/// Copy hour through microsecond from the C API parameter struct into memgraph::utils::LocalTimeParameters.
/// `parameters` is dereferenced unconditionally and must not be null.
inline memgraph::utils::LocalTimeParameters MapLocalTimeParameters(const mgp_local_time_parameters *parameters) {
  const auto &src = *parameters;
  return memgraph::utils::LocalTimeParameters{.hour = src.hour,
                                              .minute = src.minute,
                                              .second = src.second,
                                              .millisecond = src.millisecond,
                                              .microsecond = src.microsecond};
}
struct mgp_local_time {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_local_time>;
// Hopefully memgraph::utils::LocalTime copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalTime>);
mgp_local_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(memgraph::utils::ParseLocalTimeParameters(string).first) {}
mgp_local_time(const mgp_local_time_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(MapLocalTimeParameters(parameters)) {}
mgp_local_time(const memgraph::utils::LocalTime &local_time, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(local_time) {}
mgp_local_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(microseconds) {}
mgp_local_time(const mgp_local_time &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(other.local_time) {}
mgp_local_time(mgp_local_time &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(other.local_time) {}
mgp_local_time(mgp_local_time &&other) noexcept : memory(other.memory), local_time(other.local_time) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_local_time(const mgp_local_time &) = delete;
mgp_local_time &operator=(const mgp_local_time &) = delete;
mgp_local_time &operator=(mgp_local_time &&) = delete;
~mgp_local_time() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::LocalTime local_time;
};
/// Parse `string` into its date and local-time parameter parts and build a LocalDateTime from the two.
inline memgraph::utils::LocalDateTime CreateLocalDateTimeFromString(const std::string_view string) {
  const auto &[date_part, time_part] = memgraph::utils::ParseLocalDateTimeParameters(string);
  return memgraph::utils::LocalDateTime{date_part, time_part};
}
struct mgp_local_date_time {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_local_date_time>;
// Hopefully memgraph::utils::LocalDateTime copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalDateTime>);
mgp_local_date_time(const memgraph::utils::LocalDateTime &local_date_time,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(local_date_time) {}
mgp_local_date_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(CreateLocalDateTimeFromString(string)) {}
mgp_local_date_time(const mgp_local_date_time_parameters *parameters,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory),
local_date_time(MapDateParameters(parameters->date_parameters),
MapLocalTimeParameters(parameters->local_time_parameters)) {}
mgp_local_date_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(microseconds) {}
mgp_local_date_time(const mgp_local_date_time &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(other.local_date_time) {}
mgp_local_date_time(mgp_local_date_time &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(other.local_date_time) {}
mgp_local_date_time(mgp_local_date_time &&other) noexcept
: memory(other.memory), local_date_time(other.local_date_time) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_local_date_time(const mgp_local_date_time &) = delete;
mgp_local_date_time &operator=(const mgp_local_date_time &) = delete;
mgp_local_date_time &operator=(mgp_local_date_time &&) = delete;
~mgp_local_date_time() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::LocalDateTime local_date_time;
};
/// Copy day through microsecond from the C API parameter struct into memgraph::utils::DurationParameters.
/// `parameters` is dereferenced unconditionally and must not be null.
inline memgraph::utils::DurationParameters MapDurationParameters(const mgp_duration_parameters *parameters) {
  const auto &src = *parameters;
  return memgraph::utils::DurationParameters{.day = src.day,
                                             .hour = src.hour,
                                             .minute = src.minute,
                                             .second = src.second,
                                             .millisecond = src.millisecond,
                                             .microsecond = src.microsecond};
}
struct mgp_duration {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_duration>;
// Hopefully memgraph::utils::Duration copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Duration>);
mgp_duration(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(memgraph::utils::ParseDurationParameters(string)) {}
mgp_duration(const mgp_duration_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(MapDurationParameters(parameters)) {}
mgp_duration(const memgraph::utils::Duration &duration, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(duration) {}
mgp_duration(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(microseconds) {}
mgp_duration(const mgp_duration &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(other.duration) {}
mgp_duration(mgp_duration &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(other.duration) {}
mgp_duration(mgp_duration &&other) noexcept : memory(other.memory), duration(other.duration) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_duration(const mgp_duration &) = delete;
mgp_duration &operator=(const mgp_duration &) = delete;
mgp_duration &operator=(mgp_duration &&) = delete;
~mgp_duration() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::Duration duration;
};
struct mgp_list {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_list>;
explicit mgp_list(memgraph::utils::MemoryResource *memory) : elems(memory) {}
mgp_list(memgraph::utils::pmr::vector<mgp_value> &&elems, memgraph::utils::MemoryResource *memory)
: elems(std::move(elems), memory) {}
mgp_list(const mgp_list &other, memgraph::utils::MemoryResource *memory) : elems(other.elems, memory) {}
mgp_list(mgp_list &&other, memgraph::utils::MemoryResource *memory) : elems(std::move(other.elems), memory) {}
mgp_list(mgp_list &&other) noexcept : elems(std::move(other.elems)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_list(const mgp_list &) = delete;
mgp_list &operator=(const mgp_list &) = delete;
mgp_list &operator=(mgp_list &&) = delete;
~mgp_list() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return elems.get_allocator().GetMemoryResource();
}
// C++17 vector can work with incomplete type.
memgraph::utils::pmr::vector<mgp_value> elems;
};
struct mgp_map {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_map>;
explicit mgp_map(memgraph::utils::MemoryResource *memory) : items(memory) {}
mgp_map(memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> &&items,
memgraph::utils::MemoryResource *memory)
: items(std::move(items), memory) {}
mgp_map(const mgp_map &other, memgraph::utils::MemoryResource *memory) : items(other.items, memory) {}
mgp_map(mgp_map &&other, memgraph::utils::MemoryResource *memory) : items(std::move(other.items), memory) {}
mgp_map(mgp_map &&other) noexcept : items(std::move(other.items)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_map(const mgp_map &) = delete;
mgp_map &operator=(const mgp_map &) = delete;
mgp_map &operator=(mgp_map &&) = delete;
~mgp_map() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return items.get_allocator().GetMemoryResource();
}
// Unfortunately using incomplete type with map is undefined, so mgp_map
// needs to be defined after mgp_value.
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> items;
};
// Non-owning view of a single map entry: both members are plain pointers.
struct mgp_map_item {
// Key of the entry as a C string.
const char *key;
// Value stored under `key`.
mgp_value *value;
};
/// Iterator over the items of an mgp_map that exposes the current entry as an mgp_map_item.
struct mgp_map_items_iterator {
  using allocator_type = memgraph::utils::Allocator<mgp_map_items_iterator>;

  /// Position the iterator on the first item of `map`.
  /// `map` is dereferenced unconditionally and must not be null. When the map is empty, `current` keeps its
  /// value-initialized (null key, null value) state.
  mgp_map_items_iterator(mgp_map *map, memgraph::utils::MemoryResource *memory)
      : memory(memory), map(map), current_it(map->items.begin()) {
    if (current_it != map->items.end()) {
      current.key = current_it->first.c_str();
      current.value = &current_it->second;
    }
  }

  mgp_map_items_iterator(const mgp_map_items_iterator &) = delete;
  mgp_map_items_iterator(mgp_map_items_iterator &&) = delete;
  mgp_map_items_iterator &operator=(const mgp_map_items_iterator &) = delete;
  mgp_map_items_iterator &operator=(mgp_map_items_iterator &&) = delete;

  ~mgp_map_items_iterator() = default;

  memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }

  memgraph::utils::MemoryResource *memory;
  mgp_map *map;
  decltype(map->items.begin()) current_it;
  // Value-initialized so that an iterator over an empty map never exposes indeterminate pointers.
  mgp_map_item current{};
};
struct mgp_vertex {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_vertex>;
// Hopefully VertexAccessor copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::VertexAccessor>);
mgp_vertex(memgraph::query::v2::VertexAccessor v, mgp_graph *graph, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(v), graph(graph) {}
mgp_vertex(const mgp_vertex &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other) noexcept : memory(other.memory), impl(other.impl), graph(other.graph) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_vertex(const mgp_vertex &) = delete;
mgp_vertex &operator=(const mgp_vertex &) = delete;
mgp_vertex &operator=(mgp_vertex &&) = delete;
bool operator==(const mgp_vertex &other) const noexcept { return this->impl == other.impl; }
bool operator!=(const mgp_vertex &other) const noexcept { return !(*this == other); };
~mgp_vertex() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::query::v2::VertexAccessor impl;
mgp_graph *graph;
};
struct mgp_edge {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_edge>;
// TODO(antaljanosbenjamin): Handle this static assert failure when we will support procedures again
// Hopefully EdgeAccessor copy constructor remains noexcept, so that we can
// have everything noexcept here.
// static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::EdgeAccessor>);
static mgp_edge *Copy(const mgp_edge &edge, mgp_memory &memory);
mgp_edge(const memgraph::query::v2::EdgeAccessor &impl, mgp_graph *graph,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(impl), from(impl.From(), graph, memory), to(impl.To(), graph, memory) {}
mgp_edge(const mgp_edge &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), from(other.from, memory), to(other.to, memory) {}
mgp_edge(mgp_edge &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(other.memory), impl(other.impl), from(std::move(other.from), memory), to(std::move(other.to), memory) {}
mgp_edge(mgp_edge &&other) noexcept
: memory(other.memory), impl(other.impl), from(std::move(other.from)), to(std::move(other.to)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_edge(const mgp_edge &) = delete;
mgp_edge &operator=(const mgp_edge &) = delete;
mgp_edge &operator=(mgp_edge &&) = delete;
~mgp_edge() = default;
bool operator==(const mgp_edge &other) const noexcept { return this->impl == other.impl; }
bool operator!=(const mgp_edge &other) const noexcept { return !(*this == other); };
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::query::v2::EdgeAccessor impl;
mgp_vertex from;
mgp_vertex to;
};
/// Holds the vertices and edges of a path. Every memory-aware constructor
/// hands the same memory resource to both containers.
struct mgp_path {
  /// Allocator type so that STL containers are aware that we need one.
  using allocator_type = memgraph::utils::Allocator<mgp_path>;

  /// Creates an empty path whose containers allocate from `memory`.
  explicit mgp_path(memgraph::utils::MemoryResource *memory)
      : vertices(memory),
        edges(memory) {}

  /// Copies the content of `other` into containers that allocate from `memory`.
  mgp_path(const mgp_path &other, memgraph::utils::MemoryResource *memory)
      : vertices(other.vertices, memory),
        edges(other.edges, memory) {}

  /// Moves the content of `other` into containers bound to `memory`.
  mgp_path(mgp_path &&other, memgraph::utils::MemoryResource *memory)
      : vertices(std::move(other.vertices), memory),
        edges(std::move(other.edges), memory) {}

  /// Plain move construction; no memory resource is passed explicitly.
  mgp_path(mgp_path &&other) noexcept
      : vertices(std::move(other.vertices)),
        edges(std::move(other.edges)) {}

  /// Copy construction without memgraph::utils::MemoryResource is not allowed.
  mgp_path(const mgp_path &) = delete;

  mgp_path &operator=(const mgp_path &) = delete;
  mgp_path &operator=(mgp_path &&) = delete;

  ~mgp_path() = default;

  /// Returns the memory resource used by the `vertices` container.
  memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
    const auto vertices_allocator = vertices.get_allocator();
    return vertices_allocator.GetMemoryResource();
  }

  memgraph::utils::pmr::vector<mgp_vertex> vertices;
  memgraph::utils::pmr::vector<mgp_edge> edges;
};
/// A single record (row) of a procedure result.
struct mgp_result_record {
  /// Result record signature as defined for mgp_proc: a non-owning pointer to
  /// a map from field name to a (type, bool) pair.
  const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
                                  std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature;

  /// Values of this record, keyed by field name.
  memgraph::utils::pmr::map<memgraph::utils::pmr::string, memgraph::query::v2::TypedValue> values;
};
/// Accumulated result of a procedure call: the produced rows plus an optional
/// error message.
struct mgp_result {
  /// Stores the `signature` pointer as given and creates an empty row list
  /// that allocates from `mem`. `error_msg` starts out empty.
  explicit mgp_result(
      const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
                                      std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature,
      memgraph::utils::MemoryResource *mem)
      : signature(signature),
        rows(mem) {}

  /// Result record signature as defined for mgp_proc.
  const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
                                  std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature;

  /// Records produced so far.
  memgraph::utils::pmr::vector<mgp_result_record> rows;

  /// Error message; empty unless one has been set.
  std::optional<memgraph::utils::pmr::string> error_msg;
};
/// Result slot of a Magic function call. Both members start out empty.
struct mgp_func_result {
  mgp_func_result() {}

  /// Value returned by the Magic function. According to the original note, an
  /// error is raised if the user forgets to set it.
  std::optional<memgraph::query::v2::TypedValue> value;

  /// Error message accompanying the Magic function result, if any.
  std::optional<memgraph::utils::pmr::string> error_msg;
};
/// Bundles the database accessor, the storage view and (for writable graphs)
/// the execution context that are handed to query modules.
struct mgp_graph {
  memgraph::query::v2::DbAccessor *impl;
  memgraph::storage::v3::View view;
  // TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The
  // `ctx` field is out of place here.
  memgraph::query::v2::ExecutionContext *ctx;

  /// Builds a graph that points at `acc` and `ctx` and uses `view`.
  static mgp_graph WritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view,
                                 memgraph::query::v2::ExecutionContext &ctx) {
    return mgp_graph{
        &acc,
        view,
        &ctx,
    };
  }

  /// Builds a graph that points at `acc`, uses `view`, and has a null `ctx`.
  static mgp_graph NonWritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view) {
    return mgp_graph{
        &acc,
        view,
        nullptr,
    };
  }
};
// Context handed to function callbacks. It carries only the accessor and the
// view, which prevents the user from reaching an ExecutionContext.
struct mgp_func_context {
  memgraph::query::v2::DbAccessor *impl;
  memgraph::storage::v3::View view;
};
/// Iterator over a property container. The container type is deduced from what
/// VertexAccessor::Properties yields when dereferenced. The current entry is
/// exposed through `property`.
struct mgp_properties_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>;
// Define members at the start because we use decltype a lot here, so members
// need to be visible in method definitions.
memgraph::utils::MemoryResource *memory;
mgp_graph *graph;
// Property container owned by the iterator.
std::remove_reference_t<decltype(*std::declval<memgraph::query::v2::VertexAccessor>().Properties(graph->view))> pvs;
// Position inside `pvs`. Declared after `pvs` because both its type and its
// initializer depend on it.
decltype(pvs.begin()) current_it;
// Name and value of the entry at `current_it`, built with `memory`. Left empty
// by both constructors when `pvs` has no entries.
std::optional<std::pair<memgraph::utils::pmr::string, mgp_value>> current;
// Pointers into `current`; both stay null while `current` is empty.
mgp_property property{nullptr, nullptr};
// Construct with no properties.
explicit mgp_properties_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory)
: memory(memory), graph(graph), current_it(pvs.begin()) {}
// Takes `pvs` by value and moves it into the member. If it is not empty, the
// first entry is materialised into `current` and `property` is pointed at it.
// @throw unspecified: PropertyValueStore does not document which exceptions
// it (or the libraries it is built on) may throw.
mgp_properties_iterator(mgp_graph *graph, decltype(pvs) pvs, memgraph::utils::MemoryResource *memory)
: memory(memory), graph(graph), pvs(std::move(pvs)), current_it(this->pvs.begin()) {
if (current_it != this->pvs.end()) {
current.emplace(memgraph::utils::pmr::string(graph->impl->PropertyToName(current_it->first), memory),
mgp_value(current_it->second, memory));
property.name = current->first.c_str();
property.value = &current->second;
}
}
// Neither copyable nor movable: `current_it` and `property` point into this
// object's own members.
mgp_properties_iterator(const mgp_properties_iterator &) = delete;
mgp_properties_iterator(mgp_properties_iterator &&) = delete;
mgp_properties_iterator &operator=(const mgp_properties_iterator &) = delete;
mgp_properties_iterator &operator=(mgp_properties_iterator &&) = delete;
~mgp_properties_iterator() = default;
// Returns the memory resource this iterator was constructed with.
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
};
/// Iterator state for walking the in- and out-edges of one source vertex.
struct mgp_edges_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_edges_iterator>;
// Hopefully mgp_vertex copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &, memgraph::utils::MemoryResource *>);
// Copies `v` using `memory`. All optional members start out empty.
mgp_edges_iterator(const mgp_vertex &v, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), source_vertex(v, memory) {}
// Member-wise move that keeps the memory resource of `other`.
// NOTE(review): `in_it`/`out_it` are moved alongside `in`/`out`; whether they
// remain valid for the moved-to containers depends on the container type,
// which is not visible here - confirm.
mgp_edges_iterator(mgp_edges_iterator &&other) noexcept
: memory(other.memory),
source_vertex(std::move(other.source_vertex)),
in(std::move(other.in)),
in_it(std::move(other.in_it)),
out(std::move(other.out)),
out_it(std::move(other.out_it)),
current_e(std::move(other.current_e)) {}
mgp_edges_iterator(const mgp_edges_iterator &) = delete;
mgp_edges_iterator &operator=(const mgp_edges_iterator &) = delete;
mgp_edges_iterator &operator=(mgp_edges_iterator &&) = delete;
~mgp_edges_iterator() = default;
// Returns the memory resource this iterator was constructed with.
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
memgraph::utils::MemoryResource *memory;
// Vertex whose edges are iterated. The member types below are deduced from
// it, so it has to be declared before them.
mgp_vertex source_vertex;
// In-edge container (type deduced from InEdges) and an iterator of that
// container type.
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.InEdges(source_vertex.graph->view))>> in;
std::optional<decltype(in->begin())> in_it;
// Out-edge container (type deduced from OutEdges) and an iterator of that
// container type.
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.OutEdges(source_vertex.graph->view))>> out;
std::optional<decltype(out->begin())> out_it;
// Edge slot; never set inside this struct.
std::optional<mgp_edge> current_e;
};
/// Iterator over the vertices of a graph as seen through `graph->view`.
struct mgp_vertices_iterator {
  using allocator_type = memgraph::utils::Allocator<mgp_vertices_iterator>;

  /// Fetches the vertices of `graph` and, if there is at least one, builds
  /// `current_v` from the first of them using `memory`.
  ///
  /// @throw anything VerticesIterable may throw
  mgp_vertices_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory)
      : memory(memory),
        graph(graph),
        vertices(graph->impl->Vertices(graph->view)),
        current_it(vertices.begin()) {
    const bool has_first_vertex = current_it != vertices.end();
    if (has_first_vertex) {
      current_v.emplace(*current_it, graph, memory);
    }
  }

  /// Returns the memory resource this iterator was constructed with.
  memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }

  memgraph::utils::MemoryResource *memory;
  mgp_graph *graph;
  // The declaration order of the next two members matters: `current_it` is
  // initialized from `vertices`.
  decltype(graph->impl->Vertices(graph->view)) vertices;
  decltype(vertices.begin()) current_it;
  /// Vertex at `current_it`; empty when there are no vertices.
  std::optional<mgp_vertex> current_v;
};
/// Thin wrapper around a CypherTypePtr.
struct mgp_type {
  memgraph::query::v2::procedure::CypherTypePtr impl;
};
/// Additional attributes attached to a procedure.
struct ProcedureInfo {
  /// Write flag of the procedure; false by default.
  bool is_write = false;

  /// Privilege associated with the procedure; empty by default.
  std::optional<memgraph::query::v2::AuthQuery::Privilege> required_privilege = std::nullopt;
};
/// Description of a registered procedure: its name, entry point, argument
/// lists, result fields and additional info.
struct mgp_proc {
  using allocator_type = memgraph::utils::Allocator<mgp_proc>;

  /// Creates a procedure without arguments or results from a plain callback.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_proc(const char *name, mgp_proc_cb cb, memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
      : name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}

  /// Creates a procedure without arguments or results from a std::function.
  /// `cb` is a sink parameter, so it is moved into the member.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
           memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
      : name(name, memory), cb(std::move(cb)), args(memory), opt_args(memory), results(memory), info(info) {}

  /// Same as above, but takes the name as a string view.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_proc(const std::string_view name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
           memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
      : name(name, memory), cb(std::move(cb)), args(memory), opt_args(memory), results(memory), info(info) {}

  /// Copies `other`, allocating every container from `memory`.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_proc(const mgp_proc &other, memgraph::utils::MemoryResource *memory)
      : name(other.name, memory),
        cb(other.cb),
        args(other.args, memory),
        opt_args(other.opt_args, memory),
        results(other.results, memory),
        info(other.info) {}

  /// Moves `other` into containers bound to `memory`.
  mgp_proc(mgp_proc &&other, memgraph::utils::MemoryResource *memory)
      : name(std::move(other.name), memory),
        cb(std::move(other.cb)),
        args(std::move(other.args), memory),
        opt_args(std::move(other.opt_args), memory),
        results(std::move(other.results), memory),
        info(other.info) {}

  mgp_proc(const mgp_proc &other) = default;
  mgp_proc(mgp_proc &&other) = default;

  mgp_proc &operator=(const mgp_proc &) = delete;
  mgp_proc &operator=(mgp_proc &&) = delete;

  ~mgp_proc() = default;

  /// Name of the procedure.
  memgraph::utils::pmr::string name;
  /// Entry-point for the procedure.
  std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
  /// Required, positional arguments as a (name, type) pair.
  memgraph::utils::pmr::vector<
      std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>>
      args;
  /// Optional positional arguments as a (name, type, default_value) tuple.
  memgraph::utils::pmr::vector<
      std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *,
                 memgraph::query::v2::TypedValue>>
      opt_args;
  /// Fields this procedure returns, as a (name -> (type, is_deprecated)) map.
  memgraph::utils::pmr::map<memgraph::utils::pmr::string,
                            std::pair<const memgraph::query::v2::procedure::CypherType *, bool>>
      results;
  ProcedureInfo info;
};
/// Description of a registered transformation: its name, entry point and
/// result fields.
struct mgp_trans {
  using allocator_type = memgraph::utils::Allocator<mgp_trans>;

  /// Creates a transformation without result fields from a plain callback.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_trans(const char *name, mgp_trans_cb cb, memgraph::utils::MemoryResource *memory)
      : name(name, memory), cb(cb), results(memory) {}

  /// Creates a transformation without result fields from a std::function.
  /// `cb` is a sink parameter, so it is moved into the member.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_trans(const char *name, std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
            memgraph::utils::MemoryResource *memory)
      : name(name, memory), cb(std::move(cb)), results(memory) {}

  /// Copies `other`, allocating every container from `memory`. `results` is
  /// given `memory` as well, so that it does not end up on a different
  /// resource than `name` (mgp_proc does the same for its containers).
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_trans(const mgp_trans &other, memgraph::utils::MemoryResource *memory)
      : name(other.name, memory), cb(other.cb), results(other.results, memory) {}

  /// Moves `other` into containers bound to `memory`, including `results`.
  mgp_trans(mgp_trans &&other, memgraph::utils::MemoryResource *memory)
      : name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results), memory) {}

  mgp_trans(const mgp_trans &other) = default;
  mgp_trans(mgp_trans &&other) = default;

  mgp_trans &operator=(const mgp_trans &) = delete;
  mgp_trans &operator=(mgp_trans &&) = delete;

  ~mgp_trans() = default;

  /// Name of the transformation.
  memgraph::utils::pmr::string name;
  /// Entry-point for the transformation.
  std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
  /// Fields this transformation returns.
  memgraph::utils::pmr::map<memgraph::utils::pmr::string,
                            std::pair<const memgraph::query::v2::procedure::CypherType *, bool>>
      results;
};
/// Description of a registered function: its name, entry point and argument
/// lists.
struct mgp_func {
  using allocator_type = memgraph::utils::Allocator<mgp_func>;

  /// Creates a function without arguments from a plain callback.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory)
      : name(name, memory), cb(cb), args(memory), opt_args(memory) {}

  /// Creates a function without arguments from a std::function. `cb` is a
  /// sink parameter, so it is moved into the member.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb,
           memgraph::utils::MemoryResource *memory)
      : name(name, memory), cb(std::move(cb)), args(memory), opt_args(memory) {}

  /// Copies `other`, allocating every container from `memory`.
  ///
  /// @throw std::bad_alloc
  /// @throw std::length_error
  mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory)
      : name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {}

  /// Moves `other` into containers bound to `memory`.
  mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory)
      : name(std::move(other.name), memory),
        cb(std::move(other.cb)),
        args(std::move(other.args), memory),
        opt_args(std::move(other.opt_args), memory) {}

  mgp_func(const mgp_func &other) = default;
  mgp_func(mgp_func &&other) = default;

  mgp_func &operator=(const mgp_func &) = delete;
  mgp_func &operator=(mgp_func &&) = delete;

  ~mgp_func() = default;

  /// Name of the function.
  memgraph::utils::pmr::string name;
  /// Entry-point for the function.
  std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb;
  /// Required, positional arguments as a (name, type) pair.
  memgraph::utils::pmr::vector<
      std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>>
      args;
  /// Optional positional arguments as a (name, type, default_value) tuple.
  memgraph::utils::pmr::vector<
      std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *,
                 memgraph::query::v2::TypedValue>>
      opt_args;
};
/// Declared here and defined elsewhere; never throws and reports the outcome as an mgp_error.
/// NOTE(review): presumably registers the fixed set of result fields on `trans` - confirm against the definition.
mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept;
/// Container of everything a module registers: procedures, transformations
/// and functions, each kept in a map keyed by name.
struct mgp_module {
  using allocator_type = memgraph::utils::Allocator<mgp_module>;

  /// Creates an empty module whose maps allocate from `memory`.
  explicit mgp_module(memgraph::utils::MemoryResource *memory)
      : procedures(memory),
        transformations(memory),
        functions(memory) {}

  /// Copies `other` into maps that allocate from `memory`.
  mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory)
      : procedures(other.procedures, memory),
        transformations(other.transformations, memory),
        functions(other.functions, memory) {}

  /// Moves `other` into maps bound to `memory`.
  mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory)
      : procedures(std::move(other.procedures), memory),
        transformations(std::move(other.transformations), memory),
        functions(std::move(other.functions), memory) {}

  mgp_module(const mgp_module &) = default;
  mgp_module(mgp_module &&) = default;

  mgp_module &operator=(const mgp_module &) = delete;
  mgp_module &operator=(mgp_module &&) = delete;

  ~mgp_module() = default;

  memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures;
  memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations;
  memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions;
};
namespace memgraph::query::v2::procedure {
/// Going by its name, prints the signature of the given procedure to the given
/// stream; the definition lives elsewhere.
/// NOTE(review): takes the stream by pointer, whereas PrintFuncSignature takes
/// it by reference.
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintProcSignature(const mgp_proc &, std::ostream *);
/// Going by its name, prints the signature of the given function to the given
/// stream; the definition lives elsewhere.
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintFuncSignature(const mgp_func &, std::ostream &);
/// Going by its name, tells whether `name` is a valid identifier name.
/// NOTE(review): the validity rules are part of the definition, which is not
/// visible here.
bool IsValidIdentifierName(const char *name);
} // namespace memgraph::query::v2::procedure
/// A single stream message, coming either from Kafka or from Pulsar.
struct mgp_message {
  /// Kafka messages are referenced through a pointer.
  using KafkaMessage = const memgraph::integrations::kafka::Message *;
  /// Pulsar messages are stored by value.
  using PulsarMessage = memgraph::integrations::pulsar::Message;

  /// Stores the address of `message`; the message itself is not copied, so it
  /// has to outlive this object.
  explicit mgp_message(const memgraph::integrations::kafka::Message &message) : msg{&message} {}

  /// Stores a copy of `message`.
  explicit mgp_message(const memgraph::integrations::pulsar::Message &message) : msg{message} {}

  std::variant<KafkaMessage, PulsarMessage> msg;
};
/// A batch of stream messages. The object can be neither copied nor moved.
struct mgp_messages {
  using allocator_type = memgraph::utils::Allocator<mgp_messages>;
  using storage_type = memgraph::utils::pmr::vector<mgp_message>;

  /// Takes over the content of `storage`.
  explicit mgp_messages(storage_type &&storage) : messages(std::move(storage)) {}

  mgp_messages(const mgp_messages &) = delete;
  mgp_messages(mgp_messages &&) = delete;
  mgp_messages &operator=(const mgp_messages &) = delete;
  mgp_messages &operator=(mgp_messages &&) = delete;

  ~mgp_messages() = default;

  storage_type messages;
};
/// Going by its name and signature, converts `val` into a TypedValue; the definition lives elsewhere.
/// NOTE(review): `memory` is presumably the resource the returned value allocates from - confirm.
memgraph::query::v2::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory);

File diff suppressed because it is too large Load Diff

View File

@ -1,246 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// API for loading and registering modules providing custom openCypher (oC) procedures
#pragma once
#include <dlfcn.h>
#include <filesystem>
#include <functional>
#include <optional>
#include <shared_mutex>
#include <string>
#include <string_view>
#include <unordered_map>
#include "query/v2/procedure/cypher_types.hpp"
#include "query/v2/procedure/mg_procedure_impl.hpp"
#include "utils/memory.hpp"
#include "utils/rw_lock.hpp"
class CypherMainVisitorTest;
namespace memgraph::query::v2::procedure {
/// Abstract interface of a loaded module. It can be neither copied nor moved.
class Module {
 public:
  Module() {}
  virtual ~Module();

  Module(const Module &) = delete;
  Module(Module &&) = delete;
  Module &operator=(const Module &) = delete;
  Module &operator=(Module &&) = delete;

  /// Invokes the (optional) shutdown function and closes the module.
  virtual bool Close() = 0;

  /// Returns registered procedures of this module
  virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;

  /// Returns registered transformations of this module
  virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0;

  /// Returns registered functions of this module
  virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0;

  /// Returns the path of the module, if it has one.
  virtual std::optional<std::filesystem::path> Path() const = 0;
};
/// Proxy for a registered Module, acquires a read lock from ModuleRegistry.
class ModulePtr final {
 public:
  /// Creates an empty proxy: no module and a default-constructed lock.
  ModulePtr() = default;

  /// Allows `nullptr` to be converted into an empty proxy.
  ModulePtr(std::nullptr_t) {}

  /// Wraps `module` and takes over `lock`.
  ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock) : module_(module), lock_(std::move(lock)) {}

  /// True when the proxy refers to a module.
  explicit operator bool() const { return module_ != nullptr; }

  const Module &operator*() const { return *module_; }
  const Module *operator->() const { return module_; }

 private:
  const Module *module_{nullptr};
  std::shared_lock<utils::RWLock> lock_;
};
/// Thread-safe registration of modules from libraries, uses utils::RWLock.
class ModuleRegistry final {
  friend CypherMainVisitorTest;

  std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_;
  mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE};
  std::unique_ptr<utils::MemoryResource> shared_{std::make_unique<utils::ResourceWithOutOfMemoryException>()};

  bool RegisterModule(std::string_view name, std::unique_ptr<Module> module);

  void DoUnloadAllModules();

  /// Loads the module if it's in the modules_dir directory
  /// @return Whether the module was loaded
  bool LoadModuleIfFound(const std::filesystem::path &modules_dir, std::string_view name);

  void LoadModulesFromDirectory(const std::filesystem::path &modules_dir);

 public:
  ModuleRegistry();

  /// Set the modules directories that will be used when (re)loading modules.
  void SetModulesDirectory(std::vector<std::filesystem::path> modules_dir, const std::filesystem::path &data_directory);
  const std::vector<std::filesystem::path> &GetModulesDirectory() const;

  /// Atomically load or reload a module with a particular name from the given
  /// directory.
  ///
  /// Takes a write lock. If the module exists it is reloaded. Otherwise, the
  /// module is loaded from the file whose filename, without the extension,
  /// matches the module's name. If multiple such files exist, only one is
  /// chosen, in an unspecified manner. If loading of the chosen file fails, no
  /// other files are tried.
  ///
  /// Return true if the module was loaded or reloaded successfully, false
  /// otherwise.
  bool LoadOrReloadModuleFromName(std::string_view name);

  /// Atomically unload all modules and then load all possible modules from the
  /// set directories.
  ///
  /// Takes a write lock.
  void UnloadAndLoadModulesFromDirectories();

  /// Find a module with given name or return nullptr.
  /// Takes a read lock.
  ModulePtr GetModuleNamed(std::string_view name) const;

  /// Remove all loaded (non-builtin) modules.
  /// Takes a write lock.
  void UnloadAllModules();

  /// Returns the shared memory allocator used by modules
  utils::MemoryResource &GetSharedMemoryResource() noexcept;

  bool RegisterMgProcedure(std::string_view name, mgp_proc proc);

  const std::filesystem::path &InternalModuleDir() const noexcept;

 private:
  /// Owner of a `dlopen` handle: opens the library on construction and, if
  /// that produced a non-null handle, calls `dlclose` on destruction.
  class SharedLibraryHandle {
   public:
    SharedLibraryHandle(const std::string &shared_library, int mode) : handle_{dlopen(shared_library.c_str(), mode)} {}
    SharedLibraryHandle(const SharedLibraryHandle &) = delete;
    SharedLibraryHandle(SharedLibraryHandle &&) = delete;
    // Deleted either way, but declared with the conventional reference return
    // type instead of returning by value.
    SharedLibraryHandle &operator=(const SharedLibraryHandle &) = delete;
    SharedLibraryHandle &operator=(SharedLibraryHandle &&) = delete;

    ~SharedLibraryHandle() {
      if (handle_) {
        dlclose(handle_);
      }
    }

   private:
    void *handle_;
  };

#if __has_feature(address_sanitizer)
  // This is why we need RTLD_NODELETE and we must not use RTLD_DEEPBIND with
  // ASAN: https://github.com/google/sanitizers/issues/89
  SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE};
#else
  // The reason behind opening share library during runtime is to avoid issues
  // with loading symbols from stdlib. We have encounter issues with locale
  // that cause std::cout not being printed and issues when python libraries
  // would call stdlib (e.g. pytorch).
  // The way that those issues were solved was
  // by using RTLD_DEEPBIND. RTLD_DEEPBIND ensures that the lookup for the
  // mentioned library will be first performed in the already existing binded
  // libraries and then the global namespace.
  // RTLD_DEEPBIND => https://linux.die.net/man/3/dlopen
  SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND};
#endif
  std::vector<std::filesystem::path> modules_dirs_;
  std::filesystem::path internal_module_dir_;
};
/// Single, global module registry.
extern ModuleRegistry gModuleRegistry;
/// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving
/// `fully_qualified_procedure_name`. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be
/// unloaded.
/// NOTE(review): presumably std::nullopt is returned when there is no such
/// procedure, as documented for FindFunction - confirm.
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
const ModuleRegistry &module_registry, std::string_view fully_qualified_procedure_name,
utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_trans *` of the found transformation after resolving
/// `fully_qualified_transformation_name`. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be
/// unloaded.
/// NOTE(review): presumably std::nullopt is returned when there is no such
/// transformation, as documented for FindFunction - confirm.
std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation(
const ModuleRegistry &module_registry, std::string_view fully_qualified_transformation_name,
utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_func *` of the found function after resolving
/// `fully_qualified_function_name` if found. If there is no such function
/// std::nullopt is returned. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction(
const ModuleRegistry &module_registry, std::string_view fully_qualified_function_name,
utils::MemoryResource *memory);
/// Constrains a type to the callables handled by ConstructArguments: `mgp_proc` or `mgp_func`
/// (as checked by utils::SameAsAnyOf, which is defined elsewhere).
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
/// Checks `args` against the signature of `callable` and appends them to
/// `args_list`, followed by the default values of those optional arguments
/// that were not passed in.
///
/// @param args values passed by the caller, in positional order.
/// @param callable procedure or function whose `args`/`opt_args` describe the expected arguments.
/// @param fully_qualified_name name used in the error messages.
/// @param args_list list that receives the resulting values.
/// @param graph graph whose address is forwarded to every constructed list element.
/// @throw QueryRuntimeException if the number of arguments does not fit the
///        signature or if an argument does not satisfy its declared type.
template <IsCallable TCall>
void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable,
                        const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) {
  const auto n_args = args.size();
  const auto c_args_sz = callable.args.size();
  const auto c_opt_args_sz = callable.opt_args.size();

  // Too few required arguments, or more arguments than required + optional.
  if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) {
    if (callable.args.empty() && callable.opt_args.empty()) {
      throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name);
    }

    if (callable.opt_args.empty()) {
      throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz,
                                  c_args_sz == 1U ? "argument" : "arguments");
    }

    throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz,
                                c_args_sz + c_opt_args_sz);
  }

  // Every required and every optional argument ends up in the list (either
  // passed in or defaulted), so reserve room for all of them at once.
  args_list.elems.reserve(c_args_sz + c_opt_args_sz);

  // The index is unsigned like the sizes it is compared with; taking it as
  // `int` would narrow the loop counter and mix signedness in the comparison.
  const auto is_not_optional_arg = [c_args_sz](const size_t i) { return c_args_sz > i; };

  for (size_t i = 0; i < n_args; ++i) {
    // Copy is intended: `args` is const and the value is moved into the list below.
    auto arg = args[i];
    std::string_view name;
    const query::v2::procedure::CypherType *type;
    if (is_not_optional_arg(i)) {
      name = callable.args[i].first;
      type = callable.args[i].second;
    } else {
      name = std::get<0>(callable.opt_args[i - c_args_sz]);
      type = std::get<1>(callable.opt_args[i - c_args_sz]);
    }
    if (!type->SatisfiesType(arg)) {
      throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name,
                                  name, i, type->GetPresentableName());
    }
    args_list.elems.emplace_back(std::move(arg), &graph);
  }

  // Fill missing optional arguments with their default values.
  const size_t passed_in_opt_args = n_args - c_args_sz;
  for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) {
    args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph);
  }
}
} // namespace memgraph::query::v2::procedure

File diff suppressed because it is too large Load Diff

View File

@ -1,82 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// Functions and types for loading Query Modules written in Python.
#pragma once
#include "py/py.hpp"
struct mgp_graph;
struct mgp_memory;
struct mgp_module;
struct mgp_value;
namespace memgraph::query::v2::procedure {
struct PyGraph;
/// Convert an `mgp_value` into a Python object, referencing the given `PyGraph`
/// instance and using the same allocator as the graph.
///
/// Values of type `MGP_VALUE_TYPE_VERTEX`, `MGP_VALUE_TYPE_EDGE` and
/// `MGP_VALUE_TYPE_PATH` are returned as `mgp.Vertex`, `mgp.Edge` and
/// `mgp.Path` respectively, and *not* their internal `_mgp`
/// representations. Other value types are converted to equivalent builtin
/// Python objects.
///
/// Return a non-null `py::Object` instance on success. Otherwise, return a null
/// `py::Object` instance and set the appropriate Python exception.
py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph);
/// Overload of the conversion above that takes the graph as a generic `PyObject *`.
py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph);
/// Convert a Python object into `mgp_value`, constructing it using the given
/// `mgp_memory` allocator.
///
/// If the user-facing 'mgp' module can be imported, this function will handle
/// conversion of 'mgp.Vertex', 'mgp.Edge' and 'mgp.Path' values.
///
/// @throw std::bad_alloc
/// @throw std::overflow_error if attempting to convert a Python integer which
/// is too large to fit into int64_t.
/// @throw std::invalid_argument if the given Python object cannot be converted
/// to an mgp_value (e.g. a dictionary whose keys aren't strings or an object
/// of unsupported type).
mgp_value *PyObjectToMgpValue(PyObject *, mgp_memory *);
/// Create the _mgp module for use in embedded Python.
///
/// The function is to be used before Py_Initialize via the following code.
///
/// PyImport_AppendInittab("_mgp", &query::v2::procedure::PyInitMgpModule);
PyObject *PyInitMgpModule();
/// Create an instance of _mgp.Graph class.
PyObject *MakePyGraph(mgp_graph *, mgp_memory *);
/// Import a module with given name in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ImportPyModule(const char *, mgp_module *);
/// Reload already loaded Python module in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ReloadPyModule(PyObject *, mgp_module *);
} // namespace memgraph::query::v2::procedure

View File

@ -47,6 +47,7 @@ inline bool operator==(const VertexId &lhs, const VertexId &rhs) {
using Gid = size_t;
using PropertyId = memgraph::storage::v3::PropertyId;
using EdgeTypeId = memgraph::storage::v3::EdgeTypeId;
struct EdgeType {
uint64_t id;

View File

@ -105,12 +105,19 @@ class ShardRequestManagerInterface {
virtual ~ShardRequestManagerInterface() = default;
virtual void StartTransaction() = 0;
virtual void Commit() = 0;
virtual std::vector<VertexAccessor> Request(ExecutionState<ScanVerticesRequest> &state) = 0;
virtual std::vector<CreateVerticesResponse> Request(ExecutionState<CreateVerticesRequest> &state,
std::vector<NewVertex> new_vertices) = 0;
virtual std::vector<ExpandOneResponse> Request(ExecutionState<ExpandOneRequest> &state) = 0;
virtual memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
virtual memgraph::storage::v3::LabelId LabelNameToLabelId(const std::string &name) const = 0;
// TODO(antaljanosbenjamin): unify the GetXXXId and NameToId functions to have consistent naming, return type and
// implementation
virtual storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) const = 0;
virtual storage::v3::PropertyId NameToProperty(const std::string &name) const = 0;
virtual storage::v3::LabelId LabelNameToLabelId(const std::string &name) const = 0;
virtual const std::string &PropertyToName(memgraph::storage::v3::PropertyId prop) const = 0;
virtual const std::string &LabelToName(memgraph::storage::v3::LabelId label) const = 0;
virtual const std::string &EdgeTypeToName(memgraph::storage::v3::EdgeTypeId type) const = 0;
virtual bool IsPrimaryKey(PropertyId name) const = 0;
};
@ -155,7 +162,48 @@ class ShardRequestManager : public ShardRequestManagerInterface {
}
}
memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) const override {
void Commit() override {
memgraph::coordinator::HlcRequest req{.last_shard_map_version = shards_map_.GetHlc()};
CoordinatorWriteRequests write_req = req;
auto write_res = coord_cli_.SendWriteRequest(write_req);
if (write_res.HasError()) {
throw std::runtime_error("HLC request for commit failed");
}
auto coordinator_write_response = write_res.GetValue();
auto hlc_response = std::get<memgraph::coordinator::HlcResponse>(coordinator_write_response);
if (hlc_response.fresher_shard_map) {
shards_map_ = hlc_response.fresher_shard_map.value();
}
auto commit_timestamp = hlc_response.new_hlc;
msgs::CommitRequest commit_req{.transaction_id = transaction_id_, .commit_timestamp = commit_timestamp};
for (const auto &[label, space] : shards_map_.label_spaces) {
for (const auto &[key, shard] : space.shards) {
auto &storage_client = GetStorageClientForShard(shard, label);
// TODO(kostasrim) Currently requests return the result directly. Adjust this when the API works MgFuture
// instead.
auto commit_response = storage_client.SendWriteRequest(commit_req);
// RETRY on timeouts?
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test
if (commit_response.HasError()) {
throw std::runtime_error("Commit request timed out");
}
WriteResponses write_response_variant = commit_response.GetValue();
auto &response = std::get<CommitResponse>(write_response_variant);
if (!response.success) {
throw std::runtime_error("Commit request did not succeed");
}
}
}
}
/// Placeholder: ignores the name and always returns the edge type id built from 0.
storage::v3::EdgeTypeId NameToEdgeType(const std::string & /*name*/) const override {
  const auto placeholder_id = memgraph::storage::v3::EdgeTypeId::FromUint(0);
  return placeholder_id;
}
/// Returns the id the shard map holds for the property called `name`.
///
/// @throw std::runtime_error if the shard map has no property with that name.
///        The lookup result used to be dereferenced unconditionally, which is
///        undefined behaviour for an unknown name.
storage::v3::PropertyId NameToProperty(const std::string &name) const override {
  const auto maybe_property_id = shards_map_.GetPropertyId(name);
  if (!maybe_property_id) {
    throw std::runtime_error("Property '" + name + "' does not exist in the shard map");
  }
  return *maybe_property_id;
}
@ -163,6 +211,19 @@ class ShardRequestManager : public ShardRequestManagerInterface {
return shards_map_.GetLabelId(name);
}
/// Placeholder reverse mapping: the property id is ignored and a fixed dummy name is returned.
const std::string &PropertyToName(memgraph::storage::v3::PropertyId /*prop*/) const override {
  // Function-local static, so the returned reference refers to an object that outlives the call.
  static std::string placeholder_name = "dummy__prop";
  return placeholder_name;
}
/// Placeholder reverse mapping: the label id is ignored and a fixed dummy name is returned.
const std::string &LabelToName(memgraph::storage::v3::LabelId /*label*/) const override {
  // Function-local static, so the returned reference refers to an object that outlives the call.
  static std::string placeholder_name = "dummy__label";
  return placeholder_name;
}
/// Placeholder reverse mapping: the edge type id is ignored and a fixed dummy name is returned.
const std::string &EdgeTypeToName(memgraph::storage::v3::EdgeTypeId /*type*/) const override {
  // Function-local static, so the returned reference refers to an object that outlives the call.
  static std::string placeholder_name = "dummy__edgetype";
  return placeholder_name;
}
bool IsPrimaryKey(const PropertyId name) const override {
return std::find_if(shards_map_.properties.begin(), shards_map_.properties.end(),
[name](auto &pr) { return pr.second == name; }) != shards_map_.properties.end();
@ -222,7 +283,6 @@ class ShardRequestManager : public ShardRequestManagerInterface {
auto primary_key = state.requests[id].new_vertices[0].primary_key;
auto &storage_client = GetStorageClientForShard(*shard_it, labels[0].id);
WriteRequests req = state.requests[id];
auto ladaksd = std::get<CreateVerticesRequest>(req);
auto write_response_result = storage_client.SendWriteRequest(req);
// RETRY on timeouts?
// Sometimes this produces a timeout. Temporary solution is to use a while(true) as was done in shard_map test

View File

@ -1,45 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/stream/common.hpp"
#include <json/json.hpp>
namespace memgraph::query::v2::stream {
namespace {
// JSON keys under which the individual CommonStreamInfo fields are stored.
const std::string kBatchIntervalKey = "batch_interval";
const std::string kBatchSizeKey = "batch_size";
const std::string kTransformationName = "transformation_name";
}  // namespace
/// Writes the fields of `common_info` into `data` under the batch interval, batch size and transformation
/// name keys. The batch interval is stored as its raw tick count.
void to_json(nlohmann::json &data, CommonStreamInfo &&common_info) {
  const auto interval_ticks = common_info.batch_interval.count();
  data[kBatchIntervalKey] = interval_ticks;
  data[kBatchSizeKey] = common_info.batch_size;
  data[kTransformationName] = common_info.transformation_name;
}
/// Fills `common_info` from `data`.
///
/// The batch interval and batch size entries are looked up with `at()`; a null entry selects the
/// corresponding default (kDefaultBatchInterval / kDefaultBatchSize). The transformation name is read as is.
void from_json(const nlohmann::json &data, CommonStreamInfo &common_info) {
  using BatchInterval = decltype(common_info.batch_interval);
  using BatchSize = decltype(common_info.batch_size);

  const auto stored_interval = data.at(kBatchIntervalKey);
  if (stored_interval.is_null()) {
    common_info.batch_interval = kDefaultBatchInterval;
  } else {
    common_info.batch_interval = BatchInterval{stored_interval.get<typename BatchInterval::rep>()};
  }

  const auto stored_size = data.at(kBatchSizeKey);
  if (stored_size.is_null()) {
    common_info.batch_size = kDefaultBatchSize;
  } else {
    common_info.batch_size = stored_size.get<BatchSize>();
  }

  data.at(kTransformationName).get_to(common_info.transformation_name);
}
} // namespace memgraph::query::v2::stream

View File

@ -1,87 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <json/json.hpp>
#include "query/v2/procedure/mg_procedure_impl.hpp"
namespace memgraph::query::v2::stream {
// Default batch interval and batch size for streams.
inline constexpr std::chrono::milliseconds kDefaultBatchInterval = std::chrono::milliseconds{100};
inline constexpr int64_t kDefaultBatchSize = 1000;
/// Callback type that receives a batch of messages of type `TMessage`.
template <typename TMessage>
using ConsumerFunction = std::function<void(const std::vector<TMessage> &)>;
/// Settings shared by every stream source type.
struct CommonStreamInfo {
// Batch interval of the stream, in milliseconds.
std::chrono::milliseconds batch_interval;
// Batch size of the stream.
int64_t batch_size;
// Name of the transformation associated with the stream.
std::string transformation_name;
};
/// Satisfied by types for which `to_json(json, T&&)` and `from_json(json, T&)` are callable and return void.
template <typename T>
concept ConvertableToJson = requires(T value, nlohmann::json data) {
{ to_json(data, std::move(value)) } -> std::same_as<void>;
{ from_json(data, value) } -> std::same_as<void>;
};
/// Satisfied by types from which an `mgp_message` can be brace-initialized.
template <typename T>
concept ConvertableToMgpMessage = requires(T value) {
mgp_message{value};
};
/// Requirements on a stream source type:
///  - nested `StreamInfo` and `Message` types;
///  - constructible from (name, StreamInfo, ConsumerFunction<Message>);
///  - `Start`, `StartWithLimit`, `Stop` and `Check` returning void, `IsRunning` returning bool;
///  - `StreamInfo` has a `common_info` member of type CommonStreamInfo and is JSON convertible;
///  - `Message` is convertible to `mgp_message`.
template <typename TStream>
concept Stream = requires(TStream stream) {
typename TStream::StreamInfo;
typename TStream::Message;
TStream{std::string{""}, typename TStream::StreamInfo{}, ConsumerFunction<typename TStream::Message>{}};
{ stream.Start() } -> std::same_as<void>;
{ stream.StartWithLimit(uint64_t{}, std::optional<std::chrono::milliseconds>{}) } -> std::same_as<void>;
{ stream.Stop() } -> std::same_as<void>;
{ stream.IsRunning() } -> std::same_as<bool>;
{
stream.Check(std::optional<std::chrono::milliseconds>{}, std::optional<uint64_t>{},
ConsumerFunction<typename TStream::Message>{})
} -> std::same_as<void>;
requires std::same_as<std::decay_t<decltype(std::declval<typename TStream::StreamInfo>().common_info)>,
CommonStreamInfo>;
requires ConvertableToMgpMessage<typename TStream::Message>;
requires ConvertableToJson<typename TStream::StreamInfo>;
};
/// Kinds of stream sources.
enum class StreamSourceType : uint8_t { KAFKA, PULSAR };

/// Returns the lowercase textual name of a stream source type.
///
/// @param type stream source type to convert.
/// @return "kafka" or "pulsar" for the known enumerators; "unknown" for any other value, so that control never
///         flows off the end of this value-returning function.
constexpr std::string_view StreamSourceTypeToString(StreamSourceType type) {
  switch (type) {
    case StreamSourceType::KAFKA:
      return "kafka";
    case StreamSourceType::PULSAR:
      return "pulsar";
  }
  // Reached only for values that are not named enumerators (the underlying type admits them).
  return "unknown";
}
/// Maps a stream object to its StreamSourceType. Only declared here; the definitions are provided elsewhere.
template <Stream T>
StreamSourceType StreamType(const T & /*stream*/);
// JSON key used for the CommonStreamInfo part of a stream info object.
const std::string kCommonInfoKey = "common_info";
// JSON (de)serialization hooks for CommonStreamInfo.
void to_json(nlohmann::json &data, CommonStreamInfo &&info);
void from_json(const nlohmann::json &data, CommonStreamInfo &common_info);
} // namespace memgraph::query::v2::stream

View File

@ -1,137 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/stream/sources.hpp"
#include <json/json.hpp>
#include "integrations/constants.hpp"
namespace memgraph::query::v2::stream {
/// Builds the Kafka consumer configuration from `stream_name` and `stream_info` and constructs the consumer
/// in place together with `consumer_function`.
KafkaStream::KafkaStream(std::string stream_name, StreamInfo stream_info,
                         ConsumerFunction<integrations::kafka::Message> consumer_function) {
  // `configs` become the consumer's public configs, `credentials` its private configs.
  integrations::kafka::ConsumerInfo kafka_consumer_info{
      .consumer_name = std::move(stream_name),
      .topics = std::move(stream_info.topics),
      .consumer_group = std::move(stream_info.consumer_group),
      .bootstrap_servers = std::move(stream_info.bootstrap_servers),
      .batch_interval = stream_info.common_info.batch_interval,
      .batch_size = stream_info.common_info.batch_size,
      .public_configs = std::move(stream_info.configs),
      .private_configs = std::move(stream_info.credentials),
  };

  consumer_.emplace(std::move(kafka_consumer_info), std::move(consumer_function));
}
/// Reconstructs the StreamInfo of this stream from the consumer's current info.
///
/// @param transformation_name stored in the returned `common_info`; the consumer itself does not supply it.
/// @return a StreamInfo whose `configs` are the consumer's public configs and whose `credentials` are its
///         private configs.
KafkaStream::StreamInfo KafkaStream::Info(std::string transformation_name) const {
  const auto &info = consumer_->Info();
  // Every member is designated, including `common_info`: C++20 does not allow mixing positional and designated
  // initializers within one initializer list.
  return {.common_info = {.batch_interval = info.batch_interval,
                          .batch_size = info.batch_size,
                          .transformation_name = std::move(transformation_name)},
          .topics = info.topics,
          .consumer_group = info.consumer_group,
          .bootstrap_servers = info.bootstrap_servers,
          .configs = info.public_configs,
          .credentials = info.private_configs};
}
/// Forwards to the consumer's Start().
void KafkaStream::Start() {
  consumer_->Start();
}

/// Forwards to the consumer's StartWithLimit() with the given batch limit and optional timeout.
void KafkaStream::StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const {
  consumer_->StartWithLimit(batch_limit, timeout);
}

/// Forwards to the consumer's Stop().
void KafkaStream::Stop() {
  consumer_->Stop();
}

/// Returns what the consumer's IsRunning() reports.
bool KafkaStream::IsRunning() const {
  return consumer_->IsRunning();
}

/// Forwards to the consumer's Check(), passing `consumer_function` along.
void KafkaStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
                        const ConsumerFunction<integrations::kafka::Message> &consumer_function) const {
  consumer_->Check(timeout, batch_limit, consumer_function);
}

/// Forwards `offset` to the consumer's SetConsumerOffsets() and returns its result.
utils::BasicResult<std::string> KafkaStream::SetStreamOffset(const int64_t offset) {
  return consumer_->SetConsumerOffsets(offset);
}
namespace {
// JSON keys of the Kafka-specific StreamInfo fields.
const std::string kTopicsKey = "topics";
const std::string kConsumerGroupKey = "consumer_group";
const std::string kBoostrapServers = "bootstrap_servers";
const std::string kConfigs = "configs";
const std::string kCredentials = "credentials";

// Empty map used as the fallback when configs or credentials are missing from the JSON object.
const std::unordered_map<std::string, std::string> kDefaultConfigsMap{};
}  // namespace
/// Writes a Kafka StreamInfo into `data`.
///
/// `info` is an rvalue, so its members are moved into the JSON object. `consumer_group` used to be the only
/// member that was copied; it is now moved like its siblings.
void to_json(nlohmann::json &data, KafkaStream::StreamInfo &&info) {
  data[kCommonInfoKey] = std::move(info.common_info);
  data[kTopicsKey] = std::move(info.topics);
  data[kConsumerGroupKey] = std::move(info.consumer_group);
  data[kBoostrapServers] = std::move(info.bootstrap_servers);
  data[kConfigs] = std::move(info.configs);
  data[kCredentials] = std::move(info.credentials);
}
/// Fills a Kafka StreamInfo from `data`.
///
/// Common info, topics, consumer group and bootstrap servers are looked up with `at()`. Configs and
/// credentials fall back to an empty map when absent.
void from_json(const nlohmann::json &data, KafkaStream::StreamInfo &info) {
  const auto &common_info_json = data.at(kCommonInfoKey);
  common_info_json.get_to(info.common_info);

  const auto &topics_json = data.at(kTopicsKey);
  topics_json.get_to(info.topics);

  const auto &consumer_group_json = data.at(kConsumerGroupKey);
  consumer_group_json.get_to(info.consumer_group);

  const auto &bootstrap_servers_json = data.at(kBoostrapServers);
  bootstrap_servers_json.get_to(info.bootstrap_servers);

  // These two entries might be missing from the persisted JSON object.
  info.configs = data.value(kConfigs, kDefaultConfigsMap);
  info.credentials = data.value(kCredentials, kDefaultConfigsMap);
}
/// Builds the Pulsar consumer configuration from `stream_name` and `stream_info` and constructs the consumer
/// in place together with `consumer_function`.
PulsarStream::PulsarStream(std::string stream_name, StreamInfo stream_info,
                           ConsumerFunction<integrations::pulsar::Message> consumer_function) {
  integrations::pulsar::ConsumerInfo pulsar_consumer_info{
      .batch_size = stream_info.common_info.batch_size,
      .batch_interval = stream_info.common_info.batch_interval,
      .topics = std::move(stream_info.topics),
      .consumer_name = std::move(stream_name),
      .service_url = std::move(stream_info.service_url),
  };

  consumer_.emplace(std::move(pulsar_consumer_info), std::move(consumer_function));
}
/// Reconstructs the StreamInfo of this stream from the consumer's current info.
///
/// @param transformation_name stored in the returned `common_info`; the consumer itself does not supply it.
/// @return a StreamInfo holding the consumer's batch settings, topics and service URL.
PulsarStream::StreamInfo PulsarStream::Info(std::string transformation_name) const {
  const auto &info = consumer_->Info();
  // Every member is designated, including `common_info`: C++20 does not allow mixing positional and designated
  // initializers within one initializer list.
  return {.common_info = {.batch_interval = info.batch_interval,
                          .batch_size = info.batch_size,
                          .transformation_name = std::move(transformation_name)},
          .topics = info.topics,
          .service_url = info.service_url};
}
/// Forwards to the consumer's Start().
void PulsarStream::Start() {
  consumer_->Start();
}

/// Forwards to the consumer's StartWithLimit() with the given batch limit and optional timeout.
void PulsarStream::StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const {
  consumer_->StartWithLimit(batch_limit, timeout);
}

/// Forwards to the consumer's Stop().
void PulsarStream::Stop() {
  consumer_->Stop();
}

/// Returns what the consumer's IsRunning() reports.
bool PulsarStream::IsRunning() const {
  return consumer_->IsRunning();
}

/// Forwards to the consumer's Check(), passing `consumer_function` along.
void PulsarStream::Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
                         const ConsumerFunction<Message> &consumer_function) const {
  consumer_->Check(timeout, batch_limit, consumer_function);
}
namespace {
// JSON key of the Pulsar-specific service URL field.
const std::string kServiceUrl = "service_url";
}  // namespace
/// Writes a Pulsar StreamInfo into `data`, moving every member out of the rvalue `info`.
void to_json(nlohmann::json &data, PulsarStream::StreamInfo &&info) {
  auto &source = info;  // named alias; the members are still moved out individually below
  data[kCommonInfoKey] = std::move(source.common_info);
  data[kTopicsKey] = std::move(source.topics);
  data[kServiceUrl] = std::move(source.service_url);
}
/// Fills a Pulsar StreamInfo from `data`; all three entries are looked up with `at()`.
void from_json(const nlohmann::json &data, PulsarStream::StreamInfo &info) {
  const auto &common_info_json = data.at(kCommonInfoKey);
  common_info_json.get_to(info.common_info);

  const auto &topics_json = data.at(kTopicsKey);
  topics_json.get_to(info.topics);

  const auto &service_url_json = data.at(kServiceUrl);
  service_url_json.get_to(info.service_url);
}
} // namespace memgraph::query::v2::stream

View File

@ -1,95 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/stream/common.hpp"
#include "integrations/kafka/consumer.hpp"
#include "integrations/pulsar/consumer.hpp"
namespace memgraph::query::v2::stream {
/// Stream source backed by an integrations::kafka::Consumer.
struct KafkaStream {
// Description of a Kafka stream: the common settings plus the Kafka-specific ones.
struct StreamInfo {
CommonStreamInfo common_info;
std::vector<std::string> topics;
std::string consumer_group;
std::string bootstrap_servers;
// Key/value configuration entries.
std::unordered_map<std::string, std::string> configs;
// Key/value credential entries, kept separately from `configs`.
std::unordered_map<std::string, std::string> credentials;
};
using Message = integrations::kafka::Message;
KafkaStream(std::string stream_name, StreamInfo stream_info,
ConsumerFunction<integrations::kafka::Message> consumer_function);
// Returns the current stream description; `transformation_name` is supplied by the caller.
StreamInfo Info(std::string transformation_name) const;
void Start();
void StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const;
void Stop();
bool IsRunning() const;
void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<Message> &consumer_function) const;
// Kafka-only operation; the result carries a std::string error on failure.
utils::BasicResult<std::string> SetStreamOffset(int64_t offset);
private:
using Consumer = integrations::kafka::Consumer;
// The wrapped consumer, held in an optional.
std::optional<Consumer> consumer_;
};
// JSON (de)serialization hooks for KafkaStream::StreamInfo.
void to_json(nlohmann::json &data, KafkaStream::StreamInfo &&info);
void from_json(const nlohmann::json &data, KafkaStream::StreamInfo &info);
/// Specialization for Kafka streams: the stream object is ignored and KAFKA is always reported.
template <>
inline StreamSourceType StreamType<KafkaStream>(const KafkaStream & /*stream*/) {
  return StreamSourceType::KAFKA;
}
/// Stream source backed by an integrations::pulsar::Consumer.
struct PulsarStream {
// Description of a Pulsar stream: the common settings plus the Pulsar-specific ones.
struct StreamInfo {
CommonStreamInfo common_info;
std::vector<std::string> topics;
std::string service_url;
};
using Message = integrations::pulsar::Message;
PulsarStream(std::string stream_name, StreamInfo stream_info, ConsumerFunction<Message> consumer_function);
// Returns the current stream description; `transformation_name` is supplied by the caller.
StreamInfo Info(std::string transformation_name) const;
void Start();
void StartWithLimit(uint64_t batch_limit, std::optional<std::chrono::milliseconds> timeout) const;
void Stop();
bool IsRunning() const;
void Check(std::optional<std::chrono::milliseconds> timeout, std::optional<uint64_t> batch_limit,
const ConsumerFunction<Message> &consumer_function) const;
private:
using Consumer = integrations::pulsar::Consumer;
// The wrapped consumer, held in an optional.
std::optional<Consumer> consumer_;
};
// JSON (de)serialization hooks for PulsarStream::StreamInfo.
void to_json(nlohmann::json &data, PulsarStream::StreamInfo &&info);
void from_json(const nlohmann::json &data, PulsarStream::StreamInfo &info);
/// Specialization for Pulsar streams: the stream object is ignored and PULSAR is always reported.
template <>
inline StreamSourceType StreamType<PulsarStream>(const PulsarStream & /*stream*/) {
  return StreamSourceType::PULSAR;
}
} // namespace memgraph::query::v2::stream

View File

@ -1,773 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/stream/streams.hpp"
#include <shared_mutex>
#include <string_view>
#include <utility>
#include <spdlog/spdlog.h>
#include <json/json.hpp>
#include "integrations/constants.hpp"
#include "mg_procedure.h"
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/discard_value_stream.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/interpreter.hpp"
#include "query/v2/procedure/mg_procedure_helpers.hpp"
#include "query/v2/procedure/mg_procedure_impl.hpp"
#include "query/v2/procedure/module.hpp"
#include "query/v2/stream/sources.hpp"
#include "storage/v3/conversions.hpp"
#include "utils/event_counter.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
#include "utils/on_scope_exit.hpp"
#include "utils/pmr/string.hpp"
#include "utils/variant_helpers.hpp"
// Event counter defined in another translation unit; this file increments it with the size of each
// consumed message batch.
namespace EventCounter {
extern const Event MessagesConsumed;
}  // namespace EventCounter
namespace memgraph::query::v2::stream {
namespace {
// Number of fields a transformation result record must contain (query and parameters).
inline constexpr auto kExpectedTransformationResultSize = 2;
inline constexpr auto kCheckStreamResultSize = 2;
// Field names looked up in transformation results and required in a transformation's result signature.
const utils::pmr::string query_param_name{"query", utils::NewDeleteResource()};
const utils::pmr::string params_param_name{"parameters", utils::NewDeleteResource()};
// Empty parameter map.
const std::map<std::string, storage::v3::PropertyValue> empty_parameters{};
/// Looks up `stream_name` in `map`.
///
/// @return the iterator returned by `map.find`, which is guaranteed not to be `map.end()`.
/// @throws StreamsException if no stream with that name exists.
auto GetStream(auto &map, const std::string &stream_name) {
  auto stream_it = map.find(stream_name);
  if (stream_it == map.end()) {
    throw StreamsException("Couldn't find stream '{}'", stream_name);
  }
  return stream_it;
}
/// Extracts the query and parameters values from a single transformation result record.
///
/// @param values record fields keyed by field name.
/// @param transformation_name used only in error messages.
/// @param stream_name used only in error messages.
/// @return copies of the "query" and "parameters" values, in that order.
/// @throws StreamsException if the record does not have exactly the expected number of fields or if one of the
///         two fields is missing.
std::pair<TypedValue /*query*/, TypedValue /*parameters*/> ExtractTransformationResult(
    const utils::pmr::map<utils::pmr::string, TypedValue> &values, const std::string_view transformation_name,
    const std::string_view stream_name) {
  if (values.size() != kExpectedTransformationResultSize) {
    throw StreamsException(
        "Transformation '{}' in stream '{}' did not yield all fields (query, parameters) as required.",
        transformation_name, stream_name);
  }

  const auto find_field = [&](const utils::pmr::string &field_name) -> const TypedValue & {
    const auto field_it = values.find(field_name);
    if (field_it != values.end()) {
      return field_it->second;
    }
    throw StreamsException{"Transformation '{}' in stream '{}' did not yield a record with '{}' field.",
                           transformation_name, stream_name, field_name};
  };

  // The query is looked up and checked before the parameters are touched.
  const auto &query = find_field(query_param_name);
  MG_ASSERT(query.IsString());

  const auto &parameters = find_field(params_param_name);
  MG_ASSERT(parameters.IsNull() || parameters.IsMap());

  return {query, parameters};
}
/// Runs the transformation named `transformation_name` over `messages` and stores its output in `result`.
///
/// `result` is cleared (rows and error message) and its signature is pointed at the transformation's results
/// before the transformation callback is invoked.
///
/// @throws StreamsException if the transformation cannot be found or if it reported an error message.
template <typename TMessage>
void CallCustomTransformation(const std::string &transformation_name, const std::vector<TMessage> &messages,
mgp_result &result, storage::v3::Shard::Accessor &storage_accessor,
utils::MemoryResource &memory_resource, const std::string &stream_name) {
DbAccessor db_accessor{&storage_accessor};
// Inner scope: the objects created for the call are destroyed before the error message is inspected below.
{
auto maybe_transformation =
procedure::FindTransformation(procedure::gModuleRegistry, transformation_name, utils::NewDeleteResource());
if (!maybe_transformation) {
throw StreamsException("Couldn't find transformation {} for stream '{}'", transformation_name, stream_name);
};
const auto &trans = *maybe_transformation->second;
// Wrap every incoming message into an mgp_message.
mgp_messages mgp_messages{mgp_messages::storage_type{&memory_resource}};
std::transform(messages.begin(), messages.end(), std::back_inserter(mgp_messages.messages),
[](const TMessage &message) { return mgp_message{message}; });
mgp_graph graph{&db_accessor, storage::v3::View::OLD, nullptr};
mgp_memory memory{&memory_resource};
// Reset the reusable result object before handing it to the transformation.
result.rows.clear();
result.error_msg.reset();
result.signature = &trans.results;
// The transformation must declare exactly the query and parameters result fields.
MG_ASSERT(result.signature->size() == kExpectedTransformationResultSize);
MG_ASSERT(result.signature->contains(query_param_name));
MG_ASSERT(result.signature->contains(params_param_name));
spdlog::trace("Calling transformation in stream '{}'", stream_name);
trans.cb(&mgp_messages, &graph, &result, &memory);
}
if (result.error_msg.has_value()) {
throw StreamsException(result.error_msg->c_str());
}
}
/// Assembles a StreamStatus for `stream`.
///
/// The type, running flag and info are queried from `stream`; name and owner are taken from the arguments, and
/// `transformation_name` is forwarded to `stream.Info()`.
template <Stream TStream>
StreamStatus<TStream> CreateStatus(std::string stream_name, std::string transformation_name,
                                   std::optional<std::string> owner, const TStream &stream) {
  return StreamStatus<TStream>{
      .name = std::move(stream_name),
      .type = StreamType(stream),
      .is_running = stream.IsRunning(),
      .info = stream.Info(std::move(transformation_name)),
      .owner = std::move(owner),
  };
}
// JSON keys of the StreamStatus fields. They are std::string rather than std::string_view because
// nlohmann::json doesn't support string_view access yet.
const std::string kStreamName = "name";
const std::string kIsRunningKey = "is_running";
const std::string kOwner = "owner";
const std::string kType = "type";
} // namespace
/// Writes a StreamStatus into `data`: name, type, running flag and owner (JSON null when there is no owner),
/// followed by the stream info, which is written into the same JSON object.
template <Stream TStream>
void to_json(nlohmann::json &data, StreamStatus<TStream> &&status) {
  data[kStreamName] = std::move(status.name);
  data[kType] = status.type;
  data[kIsRunningKey] = status.is_running;

  if (!status.owner.has_value()) {
    data[kOwner] = nullptr;
  } else {
    data[kOwner] = std::move(*status.owner);
  }

  to_json(data, std::move(status.info));
}
/// Fills a StreamStatus from `data`: name, running flag, owner (left empty when the stored owner is null) and
/// the stream info, which is read from the same JSON object. The `type` member is not assigned here.
template <Stream TStream>
void from_json(const nlohmann::json &data, StreamStatus<TStream> &status) {
  data.at(kStreamName).get_to(status.name);
  data.at(kIsRunningKey).get_to(status.is_running);

  const auto &stored_owner = data.at(kOwner);
  if (stored_owner.is_null()) {
    status.owner = {};
  } else {
    status.owner = stored_owner.get<typename decltype(status.owner)::value_type>();
  }

  from_json(data, status.info);
}
/// Stores the interpreter context, initializes the storage from `directory` and registers the stream
/// procedures.
Streams::Streams(InterpreterContext *interpreter_context, std::filesystem::path directory)
    : interpreter_context_(interpreter_context),
      storage_(std::move(directory)) {
  RegisterProcedures();
}
/// Registers the procedures of every supported stream source: Kafka first, then Pulsar.
void Streams::RegisterProcedures() {
  RegisterKafkaProcedures();

  RegisterPulsarProcedures();
}
/// Registers the Kafka-specific procedures in the global module registry:
///  - `kafka_set_stream_offset(stream_name, offset)`: forwards the offset to the named Kafka stream;
///  - `kafka_stream_info(stream_name)`: yields one record with the consumer group, topics, bootstrap servers,
///    configs and credentials of the named Kafka stream (credential values are replaced by a placeholder).
/// Both procedures throw QueryRuntimeException when the named stream is not a Kafka stream.
void Streams::RegisterKafkaProcedures() {
{
static constexpr std::string_view proc_name = "kafka_set_stream_offset";
auto set_stream_offset = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result,
mgp_memory * /*memory*/) {
// Argument 0 is the stream name, argument 1 the offset.
auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0);
const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name);
auto *arg_offset = procedure::Call<mgp_value *>(mgp_list_at, args, 1);
const auto offset = procedure::Call<int64_t>(mgp_value_get_int, arg_offset);
auto lock_ptr = streams_.Lock();
auto it = GetStream(*lock_ptr, std::string(stream_name));
std::visit(utils::Overloaded{[&](StreamData<KafkaStream> &kafka_stream) {
auto stream_source_ptr = kafka_stream.stream_source->Lock();
const auto error = stream_source_ptr->SetStreamOffset(offset);
// A failure is reported through the procedure result rather than by throwing.
if (error.HasError()) {
MG_ASSERT(mgp_result_set_error_msg(result, error.GetError().c_str()) ==
mgp_error::MGP_ERROR_NO_ERROR,
"Unable to set procedure error message of procedure: {}", proc_name);
}
},
[](auto && /*other*/) {
throw QueryRuntimeException("'{}' can be only used for Kafka stream sources",
proc_name);
}},
it->second);
};
mgp_proc proc(proc_name, set_stream_offset, utils::NewDeleteResource());
MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) ==
mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_arg(&proc, "offset", procedure::Call<mgp_type *>(mgp_type_int)) ==
mgp_error::MGP_ERROR_NO_ERROR);
procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc));
}
{
static constexpr std::string_view proc_name = "kafka_stream_info";
static constexpr std::string_view consumer_group_result_name = "consumer_group";
static constexpr std::string_view topics_result_name = "topics";
static constexpr std::string_view bootstrap_servers_result_name = "bootstrap_servers";
static constexpr std::string_view configs_result_name = "configs";
static constexpr std::string_view credentials_result_name = "credentials";
auto get_stream_info = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, mgp_memory *memory) {
auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0);
const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name);
auto lock_ptr = streams_.Lock();
auto it = GetStream(*lock_ptr, std::string(stream_name));
std::visit(
utils::Overloaded{
[&](StreamData<KafkaStream> &kafka_stream) {
auto stream_source_ptr = kafka_stream.stream_source->Lock();
const auto info = stream_source_ptr->Info(kafka_stream.transformation_name);
// Every step below returns early when it fails; the *OrSetError helpers receive `result` for that case.
mgp_result_record *record{nullptr};
if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) {
return;
}
const auto consumer_group_value =
procedure::GetStringValueOrSetError(info.consumer_group.c_str(), memory, result);
if (!consumer_group_value) {
return;
}
// Build the list of topic names.
procedure::MgpUniquePtr<mgp_list> topic_names{nullptr, mgp_list_destroy};
if (!procedure::TryOrSetError(
[&] {
return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(),
memory);
},
result)) {
return;
}
for (const auto &topic : info.topics) {
auto topic_value = procedure::GetStringValueOrSetError(topic.c_str(), memory, result);
if (!topic_value) {
return;
}
topic_names->elems.push_back(std::move(*topic_value));
}
procedure::MgpUniquePtr<mgp_value> topics_value{nullptr, mgp_value_destroy};
if (!procedure::TryOrSetError(
[&] {
return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.get());
},
result)) {
return;
}
// The list is released only after the value wrapping it was created successfully.
static_cast<void>(topic_names.release());
const auto bootstrap_servers_value =
procedure::GetStringValueOrSetError(info.bootstrap_servers.c_str(), memory, result);
if (!bootstrap_servers_value) {
return;
}
// Converts a string-to-string map into an mgp map value; returns a null pointer on failure.
const auto convert_config_map =
[result, memory](const std::unordered_map<std::string, std::string> &configs_to_convert)
-> procedure::MgpUniquePtr<mgp_value> {
procedure::MgpUniquePtr<mgp_value> configs_value{nullptr, mgp_value_destroy};
procedure::MgpUniquePtr<mgp_map> configs{nullptr, mgp_map_destroy};
if (!procedure::TryOrSetError(
[&] { return procedure::CreateMgpObject(configs, mgp_map_make_empty, memory); }, result)) {
return configs_value;
}
for (const auto &[key, value] : configs_to_convert) {
auto value_value = procedure::GetStringValueOrSetError(value.c_str(), memory, result);
if (!value_value) {
return configs_value;
}
configs->items.emplace(key, std::move(*value_value));
}
if (!procedure::TryOrSetError(
[&] { return procedure::CreateMgpObject(configs_value, mgp_value_make_map, configs.get()); },
result)) {
return configs_value;
}
// Same hand-off pattern as for the topic list above.
static_cast<void>(configs.release());
return configs_value;
};
const auto configs_value = convert_config_map(info.configs);
if (configs_value == nullptr) {
return;
}
// Credential values are never returned: each one is replaced by integrations::kReducted.
using CredentialsType = decltype(KafkaStream::StreamInfo::credentials);
CredentialsType reducted_credentials;
std::transform(info.credentials.begin(), info.credentials.end(),
std::inserter(reducted_credentials, reducted_credentials.end()),
[](const auto &pair) -> CredentialsType::value_type {
return {pair.first, integrations::kReducted};
});
const auto credentials_value = convert_config_map(reducted_credentials);
if (credentials_value == nullptr) {
return;
}
// Insert the five result fields into the record.
if (!procedure::InsertResultOrSetError(result, record, consumer_group_result_name.data(),
consumer_group_value.get())) {
return;
}
if (!procedure::InsertResultOrSetError(result, record, topics_result_name.data(), topics_value.get())) {
return;
}
if (!procedure::InsertResultOrSetError(result, record, bootstrap_servers_result_name.data(),
bootstrap_servers_value.get())) {
return;
}
if (!procedure::InsertResultOrSetError(result, record, configs_result_name.data(),
configs_value.get())) {
return;
}
if (!procedure::InsertResultOrSetError(result, record, credentials_result_name.data(),
credentials_value.get())) {
return;
}
},
[](auto && /*other*/) {
throw QueryRuntimeException("'{}' can be only used for Kafka stream sources", proc_name);
}},
it->second);
};
// Declare the procedure signature: one string argument and five result fields.
mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource());
MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) ==
mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&proc, consumer_group_result_name.data(),
procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(
mgp_proc_add_result(&proc, topics_result_name.data(),
procedure::Call<mgp_type *>(mgp_type_list, procedure::Call<mgp_type *>(mgp_type_string))) ==
mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&proc, bootstrap_servers_result_name.data(),
procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&proc, configs_result_name.data(), procedure::Call<mgp_type *>(mgp_type_map)) ==
mgp_error::MGP_ERROR_NO_ERROR);
MG_ASSERT(mgp_proc_add_result(&proc, credentials_result_name.data(), procedure::Call<mgp_type *>(mgp_type_map)) ==
mgp_error::MGP_ERROR_NO_ERROR);
procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc));
}
}
/// Registers the Pulsar-specific procedure `pulsar_stream_info(stream_name)` in the global module registry.
///
/// The procedure yields one record with the service URL and the topics of the named Pulsar stream and throws
/// QueryRuntimeException when the named stream is not a Pulsar stream.
void Streams::RegisterPulsarProcedures() {
  {
    static constexpr std::string_view proc_name = "pulsar_stream_info";
    static constexpr std::string_view service_url_result_name = "service_url";
    static constexpr std::string_view topics_result_name = "topics";
    auto get_stream_info = [this](mgp_list *args, mgp_graph * /*graph*/, mgp_result *result, mgp_memory *memory) {
      auto *arg_stream_name = procedure::Call<mgp_value *>(mgp_list_at, args, 0);
      const auto *stream_name = procedure::Call<const char *>(mgp_value_get_string, arg_stream_name);
      auto lock_ptr = streams_.Lock();
      auto it = GetStream(*lock_ptr, std::string(stream_name));
      std::visit(
          utils::Overloaded{
              [&](StreamData<PulsarStream> &pulsar_stream) {
                auto stream_source_ptr = pulsar_stream.stream_source->Lock();
                const auto info = stream_source_ptr->Info(pulsar_stream.transformation_name);
                // Every step below returns early when it fails; the *OrSetError helpers receive `result`.
                mgp_result_record *record{nullptr};
                if (!procedure::TryOrSetError([&] { return mgp_result_new_record(result, &record); }, result)) {
                  return;
                }
                auto service_url_value = procedure::GetStringValueOrSetError(info.service_url.c_str(), memory, result);
                if (!service_url_value) {
                  return;
                }
                // Build the list of topic names.
                procedure::MgpUniquePtr<mgp_list> topic_names{nullptr, mgp_list_destroy};
                if (!procedure::TryOrSetError(
                        [&] {
                          return procedure::CreateMgpObject(topic_names, mgp_list_make_empty, info.topics.size(),
                                                            memory);
                        },
                        result)) {
                  return;
                }
                for (const auto &topic : info.topics) {
                  auto topic_value = procedure::GetStringValueOrSetError(topic.c_str(), memory, result);
                  if (!topic_value) {
                    return;
                  }
                  topic_names->elems.push_back(std::move(*topic_value));
                }
                procedure::MgpUniquePtr<mgp_value> topics_value{nullptr, mgp_value_destroy};
                // Pass the list with get() and give up ownership only after the wrapping value exists. Releasing
                // it up front would leak the list whenever the value creation fails. This mirrors what
                // kafka_stream_info does.
                if (!procedure::TryOrSetError(
                        [&] {
                          return procedure::CreateMgpObject(topics_value, mgp_value_make_list, topic_names.get());
                        },
                        result)) {
                  return;
                }
                static_cast<void>(topic_names.release());
                if (!procedure::InsertResultOrSetError(result, record, topics_result_name.data(), topics_value.get())) {
                  return;
                }
                if (!procedure::InsertResultOrSetError(result, record, service_url_result_name.data(),
                                                       service_url_value.get())) {
                  return;
                }
              },
              [](auto && /*other*/) {
                throw QueryRuntimeException("'{}' can be only used for Pulsar stream sources", proc_name);
              }},
          it->second);
    };
    // Declare the procedure signature: one string argument and two result fields.
    mgp_proc proc(proc_name, get_stream_info, utils::NewDeleteResource());
    MG_ASSERT(mgp_proc_add_arg(&proc, "stream_name", procedure::Call<mgp_type *>(mgp_type_string)) ==
              mgp_error::MGP_ERROR_NO_ERROR);
    MG_ASSERT(mgp_proc_add_result(&proc, service_url_result_name.data(),
                                  procedure::Call<mgp_type *>(mgp_type_string)) == mgp_error::MGP_ERROR_NO_ERROR);
    MG_ASSERT(
        mgp_proc_add_result(&proc, topics_result_name.data(),
                            procedure::Call<mgp_type *>(mgp_type_list, procedure::Call<mgp_type *>(mgp_type_string))) ==
        mgp_error::MGP_ERROR_NO_ERROR);
    procedure::gModuleRegistry.RegisterMgProcedure(proc_name, std::move(proc));
  }
}
template <Stream TStream>
void Streams::Create(const std::string &stream_name, typename TStream::StreamInfo info,
std::optional<std::string> owner) {
auto locked_streams = streams_.Lock();
auto it = CreateConsumer<TStream>(*locked_streams, stream_name, std::move(info), std::move(owner));
try {
std::visit(
[&](const auto &stream_data) {
const auto stream_source_ptr = stream_data.stream_source->ReadLock();
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr));
},
it->second);
} catch (...) {
locked_streams->erase(it);
throw;
}
}
// Explicit instantiations of Streams::Create for the two stream source types defined in sources.hpp.
template void Streams::Create<KafkaStream>(const std::string &stream_name, KafkaStream::StreamInfo info,
std::optional<std::string> owner);
template void Streams::Create<PulsarStream>(const std::string &stream_name, PulsarStream::StreamInfo info,
std::optional<std::string> owner);
/// Builds the stream source for `stream_name` together with its message-consuming callback and
/// inserts it into `map`.
///
/// @param map the streams map to insert into; the caller passes the already locked map
/// @param stream_name key under which the stream is stored
/// @param stream_info stream configuration; it is moved into the created stream source
/// @param owner optional owner, checked against the privileges of every executed query
///
/// @returns iterator to the newly inserted map entry
///
/// @throws StreamsException if `map` already contains `stream_name`
template <Stream TStream>
Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name,
typename TStream::StreamInfo stream_info,
std::optional<std::string> owner) {
if (map.contains(stream_name)) {
throw StreamsException{"Stream already exists with name '{}'", stream_name};
}
auto *memory_resource = utils::NewDeleteResource();
// The callback owns copies of everything it needs (stream name, transformation name, owner, retry
// settings), so it does not refer back to this function's parameters. The interpreter is held
// through a shared_ptr and `result` is reused across invocations, hence the lambda is `mutable`.
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, stream_name,
transformation_name = stream_info.common_info.transformation_name, owner = owner,
interpreter = std::make_shared<Interpreter>(interpreter_context_),
result = mgp_result{nullptr, memory_resource},
total_retries = interpreter_context_->config.stream_transaction_conflict_retries,
retry_interval = interpreter_context_->config.stream_transaction_retry_interval](
const std::vector<typename TStream::Message> &messages) mutable {
auto accessor = interpreter_context->db->Access(coordinator::Hlc{});
EventCounter::IncrementCounter(EventCounter::MessagesConsumed, messages.size());
// Run the transformation over the batch; its output rows end up in `result`.
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
DiscardValueResultStream stream;
spdlog::trace("Start transaction in stream '{}'", stream_name);
// Clears the produced rows and aborts the interpreter; presumably executed when this scope is
// left, on both the success and the exception path.
utils::OnScopeExit cleanup{[&interpreter, &result]() {
result.rows.clear();
interpreter->Abort();
}};
const static std::map<std::string, storage::v3::PropertyValue> empty_parameters{};
// Number of retries performed so far after serialization conflicts.
uint32_t i = 0;
while (true) {
try {
// All rows of one batch are executed inside a single explicit transaction.
interpreter->BeginTransaction();
for (auto &row : result.rows) {
spdlog::trace("Processing row in stream '{}'", stream_name);
auto [query_value, params_value] = ExtractTransformationResult(row.values, transformation_name, stream_name);
storage::v3::PropertyValue params_prop = storage::v3::TypedToPropertyValue(params_value);
std::string query{query_value.ValueString()};
spdlog::trace("Executing query '{}' in stream '{}'", query, stream_name);
// A null parameters value is treated as "no parameters".
auto prepare_result =
interpreter->Prepare(query, params_prop.IsNull() ? empty_parameters : params_prop.ValueMap(), nullptr);
// The query only runs if the stream owner holds the privileges the prepared query requires.
if (!interpreter_context->auth_checker->IsUserAuthorized(owner, prepare_result.privileges)) {
throw StreamsException{
"Couldn't execute query '{}' for stream '{}' because the owner is not authorized to execute the "
"query!",
query, stream_name};
}
interpreter->PullAll(&stream);
}
spdlog::trace("Commit transaction in stream '{}'", stream_name);
interpreter->CommitTransaction();
result.rows.clear();
break;
} catch (const query::v2::TransactionSerializationException &e) {
// Serialization conflict: abort and replay the whole batch, at most `total_retries` times,
// waiting `retry_interval` between attempts. Once the retries are used up the exception is rethrown.
interpreter->Abort();
if (i == total_retries) {
throw;
}
++i;
std::this_thread::sleep_for(retry_interval);
}
}
};
// Note: the transformation name is moved out of `stream_info.common_info` before `stream_info`
// itself is moved into the stream source; the lambda above already holds its own copy.
auto insert_result = map.try_emplace(
stream_name, StreamData<TStream>{std::move(stream_info.common_info.transformation_name), std::move(owner),
std::make_unique<SynchronizedStreamSource<TStream>>(
stream_name, std::move(stream_info), std::move(consumer_function))});
MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name);
return insert_result.first;
}
/// Recreates every stream found in the persistent store and restarts the ones that were running.
///
/// Restoration is best effort: a stream whose persisted data cannot be parsed, converted or turned
/// into a consumer is skipped with a warning and the remaining streams are still loaded.
/// Must only be called while the in-memory streams map is empty (asserted).
void Streams::RestoreStreams() {
  spdlog::info("Loading streams...");
  auto locked_streams_map = streams_.Lock();
  MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");

  for (const auto &[stream_name, stream_data] : storage_) {
    const auto get_failed_message = [&stream_name = stream_name](const std::string_view message,
                                                                 const std::string_view nested_message) {
      return fmt::format("Failed to load stream '{}', because: {} caused by {}", stream_name, message, nested_message);
    };

    // Deserializes the status of a stream of type T, creates its consumer and starts it when the
    // persisted status says it was running.
    const auto create_consumer = [&, &stream_name = stream_name, this]<typename T>(StreamStatus<T> status,
                                                                                   auto &&stream_json_data) {
      try {
        stream_json_data.get_to(status);
      } catch (const nlohmann::json::type_error &exception) {
        spdlog::warn(get_failed_message("invalid type conversion", exception.what()));
        return;
      } catch (const nlohmann::json::out_of_range &exception) {
        spdlog::warn(get_failed_message("non existing field", exception.what()));
        return;
      }
      // The storage key is the expected name; the deserialized name is what was actually read.
      MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", stream_name, status.name);

      try {
        auto it = CreateConsumer<T>(*locked_streams_map, stream_name, std::move(status.info), std::move(status.owner));
        if (status.is_running) {
          std::visit(
              [&](const auto &stream_data) {
                auto stream_source_ptr = stream_data.stream_source->Lock();
                stream_source_ptr->Start();
              },
              it->second);
        }
        spdlog::info("Stream '{}' is loaded", stream_name);
      } catch (const utils::BasicException &exception) {
        spdlog::warn(get_failed_message("unexpected error", exception.what()));
      }
    };

    // A corrupted entry must not abort the restoration of the remaining streams.
    nlohmann::json stream_json_data;
    try {
      stream_json_data = nlohmann::json::parse(stream_data);
    } catch (const nlohmann::json::parse_error &exception) {
      spdlog::warn(get_failed_message("invalid JSON", exception.what()));
      continue;
    }

    if (const auto it = stream_json_data.find(kType); it != stream_json_data.end()) {
      const auto stream_type = static_cast<StreamSourceType>(*it);
      switch (stream_type) {
        case StreamSourceType::KAFKA:
          create_consumer(StreamStatus<KafkaStream>{}, std::move(stream_json_data));
          break;
        case StreamSourceType::PULSAR:
          create_consumer(StreamStatus<PulsarStream>{}, std::move(stream_json_data));
          break;
      }
    } else {
      // Report the storage key: streams are persisted under their own name, and reading the name
      // back out of the JSON document would throw if the document is not an object.
      spdlog::warn(
          "Unable to load stream '{}', because it does not contain the type of the stream. Most probably the stream "
          "was saved before Memgraph 2.1. Please recreate the stream manually to make it work. For more information "
          "please check https://memgraph.com/docs/memgraph/changelog#v210---nov-22-2021 .",
          stream_name);
    }
  }
}
void Streams::Drop(const std::string &stream_name) {
auto locked_streams = streams_.Lock();
auto it = GetStream(*locked_streams, stream_name);
// streams_ is write locked, which means there is no access to it outside of this function, thus only the Test
// function can be executing with the consumer, nothing else.
// By acquiring the write lock here for the consumer, we make sure there is
// no running Test function for this consumer, therefore it can be erased.
std::visit([&](const auto &stream_data) { stream_data.stream_source->Lock(); }, it->second);
locked_streams->erase(it);
if (!storage_.Delete(stream_name)) {
throw StreamsException("Couldn't delete stream '{}' from persistent store!", stream_name);
}
// TODO(antaljanosbenjamin) Release the transformation
}
void Streams::Start(const std::string &stream_name) {
auto locked_streams = streams_.Lock();
auto it = GetStream(*locked_streams, stream_name);
std::visit(
[&, this](const auto &stream_data) {
auto stream_source_ptr = stream_data.stream_source->Lock();
stream_source_ptr->Start();
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr));
},
it->second);
}
void Streams::StartWithLimit(const std::string &stream_name, uint64_t batch_limit,
std::optional<std::chrono::milliseconds> timeout) const {
std::optional locked_streams{streams_.ReadLock()};
auto it = GetStream(**locked_streams, stream_name);
std::visit(
[&](const auto &stream_data) {
const auto locked_stream_source = stream_data.stream_source->ReadLock();
locked_streams.reset();
locked_stream_source->StartWithLimit(batch_limit, timeout);
},
it->second);
}
void Streams::Stop(const std::string &stream_name) {
auto locked_streams = streams_.Lock();
auto it = GetStream(*locked_streams, stream_name);
std::visit(
[&, this](const auto &stream_data) {
auto stream_source_ptr = stream_data.stream_source->Lock();
stream_source_ptr->Stop();
Persist(CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *stream_source_ptr));
},
it->second);
}
void Streams::StartAll() {
for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) {
std::visit(
[&stream_name = stream_name, this](const auto &stream_data) {
auto locked_stream_source = stream_data.stream_source->Lock();
if (!locked_stream_source->IsRunning()) {
locked_stream_source->Start();
Persist(
CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_stream_source));
}
},
stream_data);
}
}
void Streams::StopAll() {
for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) {
std::visit(
[&stream_name = stream_name, this](const auto &stream_data) {
auto locked_stream_source = stream_data.stream_source->Lock();
if (locked_stream_source->IsRunning()) {
locked_stream_source->Stop();
Persist(
CreateStatus(stream_name, stream_data.transformation_name, stream_data.owner, *locked_stream_source));
}
},
stream_data);
}
}
/// Collects the current status of every stream.
///
/// Each stream source is read-locked while its entry is built; the streams map itself is
/// read-locked for the duration of the iteration.
std::vector<StreamStatus<>> Streams::GetStreamInfo() const {
  std::vector<StreamStatus<>> statuses;
  {
    auto streams_guard = streams_.ReadLock();
    for (const auto &[name, variant] : *streams_guard) {
      const auto append_status = [&statuses, &name = name](const auto &data) {
        auto source_guard = data.stream_source->ReadLock();
        auto info = source_guard->Info(data.transformation_name);
        statuses.emplace_back(StreamStatus<>{name, StreamType(*source_guard), source_guard->IsRunning(),
                                             std::move(info.common_info), data.owner});
      };
      std::visit(append_status, variant);
    }
  }
  return statuses;
}
/// Runs the transformation of stream `stream_name` on consumed messages without executing the
/// resulting queries, and returns what the transformation produced.
///
/// Every invocation of the consumer callback appends one row to the result. A row has two elements:
/// a list of maps with the keys "query" and "parameters" (one map per transformation result row),
/// and a list with the payload of every message in the batch.
///
/// @param stream_name name of the stream to check
/// @param timeout forwarded unchanged to the stream source's Check
/// @param batch_limit forwarded unchanged to the stream source's Check
TransformationResult Streams::Check(const std::string &stream_name, std::optional<std::chrono::milliseconds> timeout,
std::optional<uint64_t> batch_limit) const {
// Held in an optional so the map lock can be dropped early, once the stream source is locked.
std::optional locked_streams{streams_.ReadLock()};
auto it = GetStream(**locked_streams, stream_name);
return std::visit(
[&](const auto &stream_data) {
// This depends on the fact that Drop will first acquire a write lock to the consumer, and erase it only after
// that
const auto locked_stream_source = stream_data.stream_source->ReadLock();
// Copied while the streams map is still locked; the callback below refers to this local copy.
const auto transformation_name = stream_data.transformation_name;
locked_streams.reset();
auto *memory_resource = utils::NewDeleteResource();
mgp_result result{nullptr, memory_resource};
TransformationResult test_result;
// Captures locals of this scope by reference; it is only handed to the Check call below, and
// test_result is returned right after that call.
auto consumer_function = [interpreter_context = interpreter_context_, memory_resource, &stream_name,
&transformation_name = transformation_name, &result,
&test_result]<typename T>(const std::vector<T> &messages) mutable {
auto accessor = interpreter_context->db->Access(coordinator::Hlc{});
CallCustomTransformation(transformation_name, messages, result, accessor, *memory_resource, stream_name);
auto result_row = std::vector<TypedValue>();
result_row.reserve(kCheckStreamResultSize);
// First element of the row: one {"query", "parameters"} map per transformation result row.
auto queries_and_parameters = std::vector<TypedValue>(result.rows.size());
std::transform(
result.rows.cbegin(), result.rows.cend(), queries_and_parameters.begin(), [&](const auto &row) {
auto [query, parameters] = ExtractTransformationResult(row.values, transformation_name, stream_name);
return std::map<std::string, TypedValue>{{"query", std::move(query)},
{"parameters", std::move(parameters)}};
});
result_row.emplace_back(std::move(queries_and_parameters));
// Second element of the row: the payload of every message of the batch.
auto messages_list = std::vector<TypedValue>(messages.size());
std::transform(messages.cbegin(), messages.cend(), messages_list.begin(), [](const auto &message) {
return std::string_view(message.Payload().data(), message.Payload().size());
});
result_row.emplace_back(std::move(messages_list));
test_result.emplace_back(std::move(result_row));
};
locked_stream_source->Check(timeout, batch_limit, consumer_function);
return test_result;
},
it->second);
}
} // namespace memgraph::query::v2::stream

View File

@ -1,206 +0,0 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <concepts>
#include <functional>
#include <map>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <json/json.hpp>
#include "integrations/kafka/consumer.hpp"
#include "kvstore/kvstore.hpp"
#include "query/v2/bindings/typed_value.hpp"
#include "query/v2/stream/common.hpp"
#include "query/v2/stream/sources.hpp"
#include "storage/v3/property_value.hpp"
#include "utils/event_counter.hpp"
#include "utils/exceptions.hpp"
#include "utils/rw_lock.hpp"
#include "utils/synchronized.hpp"
class StreamsTest;
namespace memgraph::query::v2 {
struct InterpreterContext;
namespace stream {
/// Exception type thrown by the stream management code in this header.
/// It inherits all constructors of utils::BasicException.
class StreamsException : public utils::BasicException {
public:
using BasicException::BasicException;
};
/// Type trait that selects the info type stored in a StreamStatus<T>.
/// Only the specializations below are defined.
template <typename T>
struct StreamInfo;
/// `void` selects CommonStreamInfo.
template <>
struct StreamInfo<void> {
using Type = CommonStreamInfo;
};
/// A type satisfying the Stream concept selects that stream's own StreamInfo type.
template <Stream TStream>
struct StreamInfo<TStream> {
using Type = typename TStream::StreamInfo;
};
/// Shorthand for the info type selected by StreamInfo<T>.
template <typename T>
using StreamInfoType = typename StreamInfo<T>::Type;
/// Status of a single stream.
/// With the default `T = void`, `info` is a CommonStreamInfo; with a concrete stream type it is
/// that stream's own StreamInfo (see StreamInfoType).
template <typename T = void>
struct StreamStatus {
// Name of the stream.
std::string name;
// Kind of the stream source.
StreamSourceType type;
// Whether the stream was running when the status was taken.
bool is_running;
// Stream configuration; its type depends on T.
StreamInfoType<T> info;
// Optional owner of the stream.
std::optional<std::string> owner;
};
/// Result type of Streams::Check: a list of rows, each row being a list of TypedValue.
using TransformationResult = std::vector<std::vector<TypedValue>>;
/// Manages stream consumers (Kafka and Pulsar stream sources).
///
/// This class is responsible for all query supported actions to happen.
class Streams final {
  friend StreamsTest;

 public:
  /// Initializes the streams.
  ///
  /// @param interpreter_context context to use to run the result of transformations
  /// @param directory a directory path to store the persisted streams metadata
  Streams(InterpreterContext *interpreter_context, std::filesystem::path directory);

  /// Restores the streams from the persisted metadata.
  /// The restoration is done in a best effort manner, therefore no exception is thrown on failure, but the error is
  /// logged. If a stream was running previously, then after restoration it will be started.
  /// This function should only be called when there are no existing streams.
  void RestoreStreams();

  /// Creates a new import stream.
  /// The create implies connecting to the server to get metadata necessary to initialize the stream. This
  /// method assures there is no other stream with the same name.
  ///
  /// @param stream_name the name of the stream which can be used to uniquely identify the stream
  /// @param info the necessary information needed to create the consumer and transform the messages
  /// @param owner the optional owner of the stream
  ///
  /// @throws StreamsException if the stream with the same name exists or if the creation of the consumer fails
  template <Stream TStream>
  void Create(const std::string &stream_name, typename TStream::StreamInfo info, std::optional<std::string> owner);

  /// Deletes an existing stream and all the data that was persisted.
  ///
  /// @param stream_name name of the stream that needs to be deleted.
  ///
  /// @throws StreamsException if the stream doesn't exist or if the persisted metadata can't be deleted.
  void Drop(const std::string &stream_name);

  /// Start consuming from a stream.
  ///
  /// @param stream_name name of the stream that needs to be started
  ///
  /// @throws StreamsException if the stream doesn't exist or if the metadata cannot be persisted
  /// @throws ConsumerRunningException if the consumer is already running
  void Start(const std::string &stream_name);

  /// Start consuming from a stream.
  ///
  /// @param stream_name name of the stream that needs to be started
  /// @param batch_limit number of batches we want to consume before stopping
  /// @param timeout the maximum duration during which the command should run.
  ///
  /// @throws StreamsException if the stream doesn't exist
  /// @throws ConsumerRunningException if the consumer is already running
  void StartWithLimit(const std::string &stream_name, uint64_t batch_limit,
                      std::optional<std::chrono::milliseconds> timeout) const;

  /// Stop consuming from a stream.
  ///
  /// @param stream_name name of the stream that needs to be stopped
  ///
  /// @throws StreamsException if the stream doesn't exist or if the metadata cannot be persisted
  /// @throws ConsumerStoppedException if the consumer is already stopped
  void Stop(const std::string &stream_name);

  /// Start consuming from all streams that are stopped.
  ///
  /// @throws StreamsException if the metadata cannot be persisted
  void StartAll();

  /// Stop consuming from all streams that are running.
  ///
  /// @throws StreamsException if the metadata cannot be persisted
  void StopAll();

  /// Return current status for all streams.
  /// It might happen that the is_running field is out of date if one of the streams stops during the invocation of
  /// this function because of an error.
  std::vector<StreamStatus<>> GetStreamInfo() const;

  /// Do a dry-run consume from a stream.
  ///
  /// @param stream_name name of the stream we want to test
  /// @param batch_limit number of batches we want to test before stopping
  /// @param timeout the maximum duration during which the command should run.
  ///
  /// @returns A vector of vectors of TypedValue. Each subvector contains two elements: the list of produced
  /// queries (each a map with the keys "query" and "parameters") and the list of the payloads of the consumed
  /// messages.
  ///
  /// @throws StreamsException if the stream doesn't exist
  /// @throws ConsumerRunningException if the consumer is already running
  /// @throws ConsumerCheckFailedException if the transformation function throws any std::exception during processing
  TransformationResult Check(const std::string &stream_name,
                             std::optional<std::chrono::milliseconds> timeout = std::nullopt,
                             std::optional<uint64_t> batch_limit = std::nullopt) const;

 private:
  /// A stream source guarded by a write-prioritized read-write lock.
  template <Stream TStream>
  using SynchronizedStreamSource = utils::Synchronized<TStream, utils::WritePrioritizedRWLock>;

  /// Everything stored per stream: its transformation, its owner and the synchronized source.
  template <Stream TStream>
  struct StreamData {
    std::string transformation_name;
    std::optional<std::string> owner;
    std::unique_ptr<SynchronizedStreamSource<TStream>> stream_source;
  };

  using StreamDataVariant = std::variant<StreamData<KafkaStream>, StreamData<PulsarStream>>;
  using StreamsMap = std::unordered_map<std::string, StreamDataVariant>;
  using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;

  /// Creates the stream source with its consumer callback and inserts it into `map`.
  template <Stream TStream>
  StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name,
                                      typename TStream::StreamInfo stream_info, std::optional<std::string> owner);

  /// Serializes `status` to JSON and stores it under the stream's name.
  ///
  /// @throws StreamsException if the status cannot be written to the store
  template <Stream TStream>
  void Persist(StreamStatus<TStream> &&status) {
    // Copy the name first: `status` is moved into the JSON value below.
    const std::string stream_name = status.name;
    if (!storage_.Put(stream_name, nlohmann::json(std::move(status)).dump())) {
      throw StreamsException{"Couldn't persist stream data for stream '{}'", stream_name};
    }
  }

  void RegisterProcedures();
  void RegisterKafkaProcedures();
  void RegisterPulsarProcedures();

  InterpreterContext *interpreter_context_;
  kvstore::KVStore storage_;
  SynchronizedStreamsMap streams_;
};
} // namespace stream
} // namespace memgraph::query::v2

View File

@ -183,7 +183,7 @@ std::shared_ptr<Trigger::TriggerPlan> Trigger::GetPlan(DbAccessor *db_accessor,
[](auto &identifier) { return &identifier.first; });
auto logical_plan = MakeLogicalPlan(std::move(ast_storage), utils::Downcast<CypherQuery>(parsed_statements_.query),
parsed_statements_.parameters, db_accessor, predefined_identifiers);
parsed_statements_.parameters, nullptr, predefined_identifiers);
trigger_plan_ = std::make_shared<TriggerPlan>(std::move(logical_plan), std::move(identifiers));
}
@ -210,8 +210,8 @@ void Trigger::Execute(DbAccessor *dba, utils::MonotonicBufferResource *execution
ctx.symbol_table = plan.symbol_table();
ctx.evaluation_context.timestamp = QueryTimestamp();
ctx.evaluation_context.parameters = parsed_statements_.parameters;
ctx.evaluation_context.properties = NamesToProperties(plan.ast_storage().properties_, dba);
ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, dba);
ctx.evaluation_context.properties = NamesToProperties(plan.ast_storage().properties_, nullptr);
ctx.evaluation_context.labels = NamesToLabels(plan.ast_storage().labels_, nullptr);
ctx.timer = utils::AsyncTimer(max_execution_time_sec);
ctx.is_shutting_down = is_shutting_down;
ctx.is_profile_query = false;

View File

@ -77,7 +77,7 @@ struct SetObjectProperty {
std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const {
return {{ObjectString<TAccessor>(), TypedValue{object}},
{"key", TypedValue{dba->PropertyToName(key)}},
{"key", TypedValue{1}}, // TODO Fix trigger
{"old", old_value},
{"new", new_value}};
}
@ -97,7 +97,7 @@ struct RemovedObjectProperty {
std::map<std::string, TypedValue> ToMap(DbAccessor *dba) const {
return {{ObjectString<TAccessor>(), TypedValue{object}},
{"key", TypedValue{dba->PropertyToName(key)}},
{"key", TypedValue{1}}, // TODO Fix triggers
{"old", old_value}};
}

View File

@ -11,6 +11,7 @@
#pragma once
#include "storage/v3/result.hpp"
#include "storage/v3/shard.hpp"
namespace memgraph::storage::v3 {
@ -88,7 +89,7 @@ class DbAccessor final {
if (maybe_vertex_acc.HasError()) {
return {std::move(maybe_vertex_acc.GetError())};
}
return VertexAccessor{maybe_vertex_acc.GetValue()};
return maybe_vertex_acc.GetValue();
}
storage::v3::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to,

View File

@ -340,6 +340,7 @@ Shard::~Shard() {}
Shard::Accessor::Accessor(Shard &shard, Transaction &transaction)
: shard_(&shard), transaction_(&transaction), config_(shard_->config_.items) {}
// TODO(jbajic) Remove with next PR
ResultSchema<VertexAccessor> Shard::Accessor::CreateVertexAndValidate(
LabelId primary_label, const std::vector<LabelId> &labels,
const std::vector<std::pair<PropertyId, PropertyValue>> &properties) {
@ -387,16 +388,23 @@ ResultSchema<VertexAccessor> Shard::Accessor::CreateVertexAndValidate(
}
ResultSchema<VertexAccessor> Shard::Accessor::CreateVertexAndValidate(
LabelId primary_label, const std::vector<LabelId> &labels, const std::vector<PropertyValue> &primary_properties,
const std::vector<LabelId> &labels, const std::vector<PropertyValue> &primary_properties,
const std::vector<std::pair<PropertyId, PropertyValue>> &properties) {
if (primary_label != shard_->primary_label_) {
throw utils::BasicException("Cannot add vertex to shard which does not hold the given primary label!");
OOMExceptionEnabler oom_exception;
const auto schema = shard_->GetSchema(shard_->primary_label_)->second;
std::vector<std::pair<PropertyId, PropertyValue>> primary_properties_ordered;
// TODO(jbajic) Maybe react immediately and send Violation
MG_ASSERT("PrimaryKey is invalid size");
for (auto i{0}; i < schema.size(); ++i) {
primary_properties_ordered.emplace_back(schema[i].property_id, primary_properties[i]);
}
auto maybe_schema_violation = GetSchemaValidator().ValidateVertexCreate(primary_label, labels, properties);
auto maybe_schema_violation =
GetSchemaValidator().ValidateVertexCreate(shard_->primary_label_, labels, primary_properties_ordered);
if (maybe_schema_violation) {
return {std::move(*maybe_schema_violation)};
}
OOMExceptionEnabler oom_exception;
auto acc = shard_->vertices_.access();
auto *delta = CreateDeleteObjectDelta(transaction_);
auto [it, inserted] = acc.insert({Vertex{delta, primary_properties}});
@ -408,7 +416,7 @@ ResultSchema<VertexAccessor> Shard::Accessor::CreateVertexAndValidate(
// TODO(jbajic) Improve, maybe delay index update
for (const auto &[property_id, property_value] : properties) {
if (!shard_->schemas_.IsPropertyKey(primary_label, property_id)) {
if (!shard_->schemas_.IsPropertyKey(shard_->primary_label_, property_id)) {
if (const auto err = vertex_acc.SetProperty(property_id, property_value); err.HasError()) {
return {err.GetError()};
}
@ -697,6 +705,8 @@ const std::string &Shard::Accessor::EdgeTypeToName(EdgeTypeId edge_type) const {
return shard_->EdgeTypeToName(edge_type);
}
LabelId Shard::PrimaryLabel() const { return primary_label_; }
void Shard::Accessor::AdvanceCommand() { ++transaction_->command_id; }
void Shard::Accessor::Commit(coordinator::Hlc commit_timestamp) {

View File

@ -208,14 +208,13 @@ class Shard final {
// TODO(gvolfing) this is just a workaround for stitching remove this later.
LabelId GetPrimaryLabel() const noexcept { return shard_->primary_label_; }
/// @throw std::bad_alloc
ResultSchema<VertexAccessor> CreateVertexAndValidate(
LabelId primary_label, const std::vector<LabelId> &labels,
const std::vector<std::pair<PropertyId, PropertyValue>> &properties);
/// @throw std::bad_alloc
ResultSchema<VertexAccessor> CreateVertexAndValidate(
LabelId primary_label, const std::vector<LabelId> &labels, const std::vector<PropertyValue> &primary_properties,
const std::vector<LabelId> &labels, const std::vector<PropertyValue> &primary_properties,
const std::vector<std::pair<PropertyId, PropertyValue>> &properties);
std::optional<VertexAccessor> FindVertex(std::vector<PropertyValue> primary_key, View view);
@ -341,6 +340,8 @@ class Shard final {
const std::string &EdgeTypeToName(EdgeTypeId edge_type) const;
LabelId PrimaryLabel() const;
/// @throw std::bad_alloc
bool CreateIndex(LabelId label, std::optional<uint64_t> desired_commit_timestamp = {});

View File

@ -26,6 +26,7 @@
#include <query/v2/requests.hpp>
#include <storage/v3/shard.hpp>
#include <storage/v3/shard_rsm.hpp>
#include "coordinator/shard_map.hpp"
#include "storage/v3/config.hpp"
namespace memgraph::storage::v3 {
@ -76,7 +77,8 @@ static_assert(kMinimumCronInterval < kMaximumCronInterval,
template <typename IoImpl>
class ShardManager {
public:
ShardManager(io::Io<IoImpl> io, Address coordinator_leader) : io_(io), coordinator_leader_(coordinator_leader) {}
ShardManager(io::Io<IoImpl> io, Address coordinator_leader, coordinator::ShardMap shard_map)
: io_(io), coordinator_leader_(coordinator_leader), shard_map_{std::move(shard_map)} {}
/// Periodic protocol maintenance. Returns the time that Cron should be called again
/// in the future.
@ -135,6 +137,7 @@ class ShardManager {
std::priority_queue<std::pair<Time, uuid>, std::vector<std::pair<Time, uuid>>, std::greater<>> cron_schedule_;
Time next_cron_;
Address coordinator_leader_;
coordinator::ShardMap shard_map_;
std::optional<ResponseFuture<WriteResponse<CoordinatorWriteResponses>>> heartbeat_res_;
// TODO(tyler) over time remove items from initialized_but_not_confirmed_rsm_
@ -212,6 +215,17 @@ class ShardManager {
std::unique_ptr<Shard> shard =
std::make_unique<Shard>(to_init.label_id, to_init.min_key, to_init.max_key, to_init.schema, to_init.config);
// TODO(jbajic) Should be sync with coordinator and not passed
std::unordered_map<uint64_t, std::string> id_to_name;
const auto map_type_ids = [&id_to_name](const auto &name_to_id_type) {
for (const auto &[name, id] : name_to_id_type) {
id_to_name.insert({id.AsUint(), name});
}
};
map_type_ids(shard_map_.edge_types);
map_type_ids(shard_map_.labels);
map_type_ids(shard_map_.properties);
shard->StoreMapping(std::move(id_to_name));
ShardRsm rsm_state{std::move(shard)};

View File

@ -15,6 +15,9 @@
#include "parser/opencypher/parser.hpp"
#include "query/v2/requests.hpp"
#include "storage/v3/key_store.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/schemas.hpp"
#include "storage/v3/shard_rsm.hpp"
#include "storage/v3/value_conversions.hpp"
#include "storage/v3/vertex_accessor.hpp"
@ -78,18 +81,29 @@ std::optional<std::map<PropertyId, Value>> CollectSpecificPropertiesFromAccessor
}
std::optional<std::map<PropertyId, Value>> CollectAllPropertiesFromAccessor(
const memgraph::storage::v3::VertexAccessor &acc, memgraph::storage::v3::View view) {
const memgraph::storage::v3::VertexAccessor &acc, memgraph::storage::v3::View view,
const memgraph::storage::v3::Schemas::Schema *schema) {
std::map<PropertyId, Value> ret;
auto iter = acc.Properties(view);
if (iter.HasError()) {
auto props = acc.Properties(view);
if (props.HasError()) {
spdlog::debug("Encountered an error while trying to get vertex properties.");
return std::nullopt;
}
for (const auto &[prop_key, prop_val] : iter.GetValue()) {
for (const auto &[prop_key, prop_val] : props.GetValue()) {
ret.emplace(prop_key, ToValue(prop_val));
}
auto maybe_pk = acc.PrimaryKey(view);
if (maybe_pk.HasError()) {
spdlog::debug("Encountered an error while trying to get vertex primary key.");
}
const auto pk = maybe_pk.GetValue();
MG_ASSERT(schema->second.size() == pk.size(), "PrimaryKey size does not match schema!");
for (size_t i{0}; i < schema->second.size(); ++i) {
ret.emplace(schema->second[i].property_id, ToValue(pk[i]));
}
return ret;
}
@ -406,10 +420,6 @@ namespace memgraph::storage::v3 {
msgs::WriteResponses ShardRsm::ApplyWrite(msgs::CreateVerticesRequest &&req) {
auto acc = shard_->Access(req.transaction_id);
// Workaround untill we have access to CreateVertexAndValidate()
// with the new signature that does not require the primary label.
const auto prim_label = acc.GetPrimaryLabel();
bool action_successful = true;
for (auto &new_vertex : req.new_vertices) {
@ -428,10 +438,12 @@ msgs::WriteResponses ShardRsm::ApplyWrite(msgs::CreateVerticesRequest &&req) {
for (const auto &label_id : new_vertex.label_ids) {
converted_label_ids.emplace_back(label_id.id);
}
auto result_schema =
acc.CreateVertexAndValidate(prim_label, converted_label_ids,
ConvertPropertyVector(std::move(new_vertex.primary_key)), converted_property_map);
// TODO(jbajic) sending primary key as vector breaks validation on storage side
// cannot map id -> value
PrimaryKey transformed_pk;
std::transform(new_vertex.primary_key.begin(), new_vertex.primary_key.end(), std::back_inserter(transformed_pk),
[](const auto &val) { return ToPropertyValue(val); });
auto result_schema = acc.CreateVertexAndValidate(converted_label_ids, transformed_pk, converted_property_map);
if (result_schema.HasError()) {
auto &error = result_schema.GetError();
@ -695,17 +707,18 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ScanVerticesRequest &&req) {
for (auto it = vertex_iterable.begin(); it != vertex_iterable.end(); ++it) {
const auto &vertex = *it;
if (start_ids == vertex.PrimaryKey(View(req.storage_view)).GetValue()) {
if (start_ids <= vertex.PrimaryKey(View(req.storage_view)).GetValue()) {
did_reach_starting_point = true;
}
if (did_reach_starting_point) {
std::optional<std::map<PropertyId, Value>> found_props;
const auto *schema = shard_->GetSchema(shard_->PrimaryLabel());
if (req.props_to_return) {
found_props = CollectSpecificPropertiesFromAccessor(vertex, req.props_to_return.value(), view);
} else {
found_props = CollectAllPropertiesFromAccessor(vertex, view);
found_props = CollectAllPropertiesFromAccessor(vertex, view, schema);
}
// TODO(gvolfing) -VERIFY-
@ -720,7 +733,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::ScanVerticesRequest &&req) {
.props = FromMap(found_props.value())});
++sample_counter;
if (sample_counter == req.batch_limit) {
if (req.batch_limit && sample_counter == req.batch_limit) {
// Reached the maximum specified batch size.
// Get the next element before exiting.
const auto &next_vertex = *(++it);

View File

@ -151,7 +151,7 @@ MachineManager<LocalTransport> MkMm(LocalSystem &local_system, std::vector<Addre
Coordinator coordinator{shard_map};
return MachineManager{io, config, coordinator};
return MachineManager{io, config, coordinator, shard_map};
}
void RunMachine(MachineManager<LocalTransport> mm) { mm.Run(); }

View File

@ -26,7 +26,6 @@
#include "query/v2/plan/operator.hpp"
#include "query/v2/plan/operator.hpp"
#include "query_v2_query_plan_common.hpp"
class Dummy : public testing::Test {
protected: