Merge remote-tracking branch 'origin/project-pineapples' into E118-MG-lexicographically-ordered-storage

This commit is contained in:
János Benjamin Antal 2022-08-02 08:19:43 +02:00
commit cc5ee6a496
119 changed files with 45898 additions and 421 deletions

View File

@ -70,6 +70,11 @@ jobs:
# branches and tags. (default: 1)
fetch-depth: 0
# This is also needed if we want do to comparison against other branches
# See https://github.community/t/checkout-code-fails-when-it-runs-lerna-run-test-since-master/17920
- name: Fetch all history for all tags and branches
run: git fetch
- name: Build combined ASAN, UBSAN and coverage binaries
run: |
# Activate toolchain.
@ -110,12 +115,22 @@ jobs:
name: "Code coverage"
path: tools/github/generated/code_coverage.tar.gz
- name: Set base branch
if: ${{ github.event_name == 'pull_request' }}
run: |
echo "BASE_BRANCH=origin/${{ github.base_ref }}" >> $GITHUB_ENV
- name: Set base branch # if we manually dispatch or push to master
if: ${{ github.event_name != 'pull_request' }}
run: |
echo "BASE_BRANCH=origin/master" >> $GITHUB_ENV
- name: Run clang-tidy
run: |
source /opt/toolchain-v4/activate
# Restrict clang-tidy results only to the modified parts
git diff -U0 master... -- src | ./tools/github/clang-tidy/clang-tidy-diff.py -p 1 -j $THREADS -path build | tee ./build/clang_tidy_output.txt
git diff -U0 ${{ env.BASE_BRANCH }}... -- src | ./tools/github/clang-tidy/clang-tidy-diff.py -p 1 -j $THREADS -path build | tee ./build/clang_tidy_output.txt
# Fail if any warning is reported
! cat ./build/clang_tidy_output.txt | ./tools/github/clang-tidy/grep_error_lines.sh > /dev/null

7
.gitignore vendored
View File

@ -23,6 +23,7 @@ cmake-build-*
cmake/DownloadProject/
dist/
src/query/frontend/opencypher/generated/
src/query/v2/frontend/opencypher/generated/
tags
ve/
ve3/
@ -50,15 +51,21 @@ src/distributed/pull_produce_rpc_messages.hpp
src/distributed/storage_gc_rpc_messages.hpp
src/distributed/token_sharing_rpc_messages.hpp
src/distributed/updates_rpc_messages.hpp
src/query/v2/frontend/ast/ast.hpp
src/query/frontend/ast/ast.hpp
src/query/distributed/frontend/ast/ast_serialization.hpp
src/query/v2/distributed/frontend/ast/ast_serialization.hpp
src/durability/distributed/state_delta.hpp
src/durability/single_node/state_delta.hpp
src/durability/single_node_ha/state_delta.hpp
src/query/frontend/semantic/symbol.hpp
src/query/v2/frontend/semantic/symbol.hpp
src/query/distributed/frontend/semantic/symbol_serialization.hpp
src/query/v2/distributed/frontend/semantic/symbol_serialization.hpp
src/query/distributed/plan/ops.hpp
src/query/v2/distributed/plan/ops.hpp
src/query/plan/operator.hpp
src/query/v2/plan/operator.hpp
src/raft/log_entry.hpp
src/raft/raft_rpc_messages.hpp
src/raft/snapshot_metadata.hpp

1
libs/.gitignore vendored
View File

@ -5,3 +5,4 @@
!CMakeLists.txt
!__main.cpp
!pulsar.patch
!antlr4.10.1.patch

View File

@ -106,6 +106,7 @@ import_external_library(antlr4 STATIC
-DWITH_LIBCXX=OFF # because of debian bug
-DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=true
-DCMAKE_CXX_STANDARD=20
-DANTLR_BUILD_CPP_TESTS=OFF
BUILD_COMMAND $(MAKE) antlr4_static
INSTALL_COMMAND $(MAKE) install)

13
libs/antlr4.10.1.patch Normal file
View File

@ -0,0 +1,13 @@
diff --git a/runtime/Cpp/runtime/CMakeLists.txt b/runtime/Cpp/runtime/CMakeLists.txt
index baf46cac9..2e7756de8 100644
--- a/runtime/Cpp/runtime/CMakeLists.txt
+++ b/runtime/Cpp/runtime/CMakeLists.txt
@@ -134,7 +134,7 @@ set_target_properties(antlr4_static
ARCHIVE_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIR}
COMPILE_FLAGS "${disabled_compile_warnings} ${extra_static_compile_flags}")
-install(TARGETS antlr4_shared
+install(TARGETS antlr4_shared OPTIONAL
EXPORT antlr4-targets
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}

View File

@ -1,43 +0,0 @@
diff --git a/runtime/Cpp/runtime/CMakeLists.txt b/runtime/Cpp/runtime/CMakeLists.txt
index a8503bb..11362cf 100644
--- a/runtime/Cpp/runtime/CMakeLists.txt
+++ b/runtime/Cpp/runtime/CMakeLists.txt
@@ -5,8 +5,8 @@ set(THIRDPARTY_DIR ${CMAKE_BINARY_DIR}/runtime/thirdparty)
set(UTFCPP_DIR ${THIRDPARTY_DIR}/utfcpp)
ExternalProject_Add(
utfcpp
- GIT_REPOSITORY "git://github.com/nemtrif/utfcpp"
- GIT_TAG "v3.1.1"
+ GIT_REPOSITORY "https://github.com/nemtrif/utfcpp"
+ GIT_TAG "v3.2.1"
SOURCE_DIR ${UTFCPP_DIR}
UPDATE_DISCONNECTED 1
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${UTFCPP_DIR}/install -Dgtest_force_shared_crt=ON
@@ -118,7 +118,7 @@ set_target_properties(antlr4_static
ARCHIVE_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIR}
COMPILE_FLAGS "${disabled_compile_warnings} ${extra_static_compile_flags}")
-install(TARGETS antlr4_shared
+install(TARGETS antlr4_shared OPTIONAL
DESTINATION lib
EXPORT antlr4-targets)
install(TARGETS antlr4_static
diff --git a/runtime/Cpp/runtime/src/support/Any.h b/runtime/Cpp/runtime/src/support/Any.h
index 468db98..65a473b 100644
--- a/runtime/Cpp/runtime/src/support/Any.h
+++ b/runtime/Cpp/runtime/src/support/Any.h
@@ -122,12 +122,12 @@ private:
}
private:
- template<int N = 0, typename std::enable_if<N == N && std::is_nothrow_copy_constructible<T>::value, int>::type = 0>
+ template<int N = 0, typename std::enable_if<N == N && std::is_copy_constructible<T>::value, int>::type = 0>
Base* clone() const {
return new Derived<T>(value);
}
- template<int N = 0, typename std::enable_if<N == N && !std::is_nothrow_copy_constructible<T>::value, int>::type = 0>
+ template<int N = 0, typename std::enable_if<N == N && !std::is_copy_constructible<T>::value, int>::type = 0>
Base* clone() const {
return nullptr;
}

View File

@ -105,7 +105,7 @@ repo_clone_try_double () {
# Download from primary_urls might fail because the cache is not installed.
declare -A primary_urls=(
["antlr4-code"]="http://$local_cache_host/git/antlr4.git"
["antlr4-generator"]="http://$local_cache_host/file/antlr-4.9.2-complete.jar"
["antlr4-generator"]="http://$local_cache_host/file/antlr-4.10.1-complete.jar"
["cppitertools"]="http://$local_cache_host/git/cppitertools.git"
["rapidcheck"]="http://$local_cache_host/git/rapidcheck.git"
["gbenchmark"]="http://$local_cache_host/git/benchmark.git"
@ -130,7 +130,7 @@ declare -A primary_urls=(
# should fail.
declare -A secondary_urls=(
["antlr4-code"]="https://github.com/antlr/antlr4.git"
["antlr4-generator"]="http://www.antlr.org/download/antlr-4.9.2-complete.jar"
["antlr4-generator"]="https://www.antlr.org/download/antlr-4.10.1-complete.jar"
["cppitertools"]="https://github.com/ryanhaining/cppitertools.git"
["rapidcheck"]="https://github.com/emil-e/rapidcheck.git"
["gbenchmark"]="https://github.com/google/benchmark.git"
@ -152,10 +152,10 @@ declare -A secondary_urls=(
# antlr
file_get_try_double "${primary_urls[antlr4-generator]}" "${secondary_urls[antlr4-generator]}"
antlr4_tag="4.9.2" # v4.9.2
antlr4_tag="4.10.1" # v4.10.1
repo_clone_try_double "${primary_urls[antlr4-code]}" "${secondary_urls[antlr4-code]}" "antlr4" "$antlr4_tag" true
pushd antlr4
git apply ../antlr4.patch
git apply ../antlr4.10.1.patch
popd
# cppitertools v2.0 2019-12-23
@ -199,7 +199,7 @@ git apply ../rocksdb.patch
popd
# mgclient
mgclient_tag="96e95c6845463cbe88948392be58d26da0d5ffd3" # (2022-02-08)
mgclient_tag="v1.4.0" # (2022-06-14)
repo_clone_try_double "${primary_urls[mgclient]}" "${secondary_urls[mgclient]}" "mgclient" "$mgclient_tag"
sed -i 's/\${CMAKE_INSTALL_LIBDIR}/lib/' mgclient/src/CMakeLists.txt

View File

@ -13,6 +13,7 @@ add_subdirectory(storage/v2)
add_subdirectory(storage/v3)
add_subdirectory(integrations)
add_subdirectory(query)
add_subdirectory(query/v2)
add_subdirectory(slk)
add_subdirectory(rpc)
add_subdirectory(auth)

View File

@ -1252,9 +1252,8 @@ int main(int argc, char **argv) {
// the triggers
auto storage_accessor = interpreter_context.db->Access();
auto dba = memgraph::query::DbAccessor{&storage_accessor};
interpreter_context.trigger_store.RestoreTriggers(&interpreter_context.ast_cache, &dba,
&interpreter_context.antlr_lock, interpreter_context.config.query,
interpreter_context.auth_checker);
interpreter_context.trigger_store.RestoreTriggers(
&interpreter_context.ast_cache, &dba, interpreter_context.config.query, interpreter_context.auth_checker);
}
// As the Stream transformations are using modules, they have to be restored after the query modules are loaded.

View File

@ -82,7 +82,7 @@ add_custom_command(
OUTPUT ${antlr_opencypher_generated_src} ${antlr_opencypher_generated_include}
COMMAND ${CMAKE_COMMAND} -E make_directory ${opencypher_generated}
COMMAND
java -jar ${CMAKE_SOURCE_DIR}/libs/antlr-4.9.2-complete.jar
java -jar ${CMAKE_SOURCE_DIR}/libs/antlr-4.10.1-complete.jar
-Dlanguage=Cpp -visitor -package antlropencypher
-o ${opencypher_generated}
${opencypher_lexer_grammar} ${opencypher_parser_grammar}

View File

@ -21,8 +21,7 @@ namespace memgraph::query {
CachedPlan::CachedPlan(std::unique_ptr<LogicalPlan> plan) : plan_(std::move(plan)) {}
ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::PropertyValue> &params,
utils::SkipList<QueryCacheEntry> *cache, utils::SpinLock *antlr_lock,
const InterpreterConfig::Query &query_config) {
utils::SkipList<QueryCacheEntry> *cache, const InterpreterConfig::Query &query_config) {
// Strip the query for caching purposes. The process of stripping a query
// "normalizes" it by replacing any literals with new parameters. This
// results in just the *structure* of the query being taken into account for
@ -63,20 +62,16 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri
};
if (it == accessor.end()) {
{
std::unique_lock<utils::SpinLock> guard(*antlr_lock);
try {
parser = std::make_unique<frontend::opencypher::Parser>(stripped_query.query());
} catch (const SyntaxException &e) {
// There is a syntax exception in the stripped query. Re-run the parser
// on the original query to get an appropriate error messsage.
parser = std::make_unique<frontend::opencypher::Parser>(query_string);
try {
parser = std::make_unique<frontend::opencypher::Parser>(stripped_query.query());
} catch (const SyntaxException &e) {
// There is a syntax exception in the stripped query. Re-run the parser
// on the original query to get an appropriate error messsage.
parser = std::make_unique<frontend::opencypher::Parser>(query_string);
// If an exception was not thrown here, the stripper messed something
// up.
LOG_FATAL("The stripped query can't be parsed, but the original can.");
}
// If an exception was not thrown here, the stripper messed something
// up.
LOG_FATAL("The stripped query can't be parsed, but the original can.");
}
// Convert the ANTLR4 parse tree into an AST.

View File

@ -111,8 +111,7 @@ struct ParsedQuery {
};
ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::PropertyValue> &params,
utils::SkipList<QueryCacheEntry> *cache, utils::SpinLock *antlr_lock,
const InterpreterConfig::Query &query_config);
utils::SkipList<QueryCacheEntry> *cache, const InterpreterConfig::Query &query_config);
class SingleNodeLogicalPlan final : public LogicalPlan {
public:

File diff suppressed because it is too large Load Diff

View File

@ -115,7 +115,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
auto operators = ExtractOperators(all_children, allowed_operators);
for (auto *expression : _expressions) {
expressions.push_back(expression->accept(this));
expressions.push_back(std::any_cast<Expression *>(expression->accept(this)));
}
Expression *first_operand = expressions[0];
@ -131,7 +131,7 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
DMG_ASSERT(_expression, "can't happen");
auto operators = ExtractOperators(all_children, allowed_operators);
Expression *expression = _expression->accept(this);
Expression *expression = std::any_cast<Expression *>(_expression->accept(this));
for (int i = (int)operators.size() - 1; i >= 0; --i) {
expression = CreateUnaryOperatorByToken(operators[i], expression);
}

View File

@ -1181,7 +1181,7 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string
// full query string) when given just the inner query to execute.
ParsedQuery parsed_inner_query =
ParseQuery(parsed_query.query_string.substr(kExplainQueryStart.size()), parsed_query.user_parameters,
&interpreter_context->ast_cache, &interpreter_context->antlr_lock, interpreter_context->config.query);
&interpreter_context->ast_cache, interpreter_context->config.query);
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in EXPLAIN");
@ -1248,7 +1248,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// full query string) when given just the inner query to execute.
ParsedQuery parsed_inner_query =
ParseQuery(parsed_query.query_string.substr(kProfileQueryStart.size()), parsed_query.user_parameters,
&interpreter_context->ast_cache, &interpreter_context->antlr_lock, interpreter_context->config.query);
&interpreter_context->ast_cache, interpreter_context->config.query);
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_inner_query.query);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE");
@ -1566,8 +1566,7 @@ Callback CreateTrigger(TriggerQuery *trigger_query,
interpreter_context->trigger_store.AddTrigger(
std::move(trigger_name), trigger_statement, user_parameters, ToTriggerEventType(event_type),
before_commit ? TriggerPhase::BEFORE_COMMIT : TriggerPhase::AFTER_COMMIT, &interpreter_context->ast_cache,
dba, &interpreter_context->antlr_lock, interpreter_context->config.query, std::move(owner),
interpreter_context->auth_checker);
dba, interpreter_context->config.query, std::move(owner), interpreter_context->auth_checker);
return {};
}};
}
@ -2123,8 +2122,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
query_execution->summary["cost_estimate"] = 0.0;
utils::Timer parsing_timer;
ParsedQuery parsed_query = ParseQuery(query_string, params, &interpreter_context_->ast_cache,
&interpreter_context_->antlr_lock, interpreter_context_->config.query);
ParsedQuery parsed_query =
ParseQuery(query_string, params, &interpreter_context_->ast_cache, interpreter_context_->config.query);
query_execution->summary["parsing_time"] = parsing_timer.Elapsed().count();
// Some queries require an active transaction in order to be prepared.

View File

@ -173,13 +173,6 @@ struct InterpreterContext {
storage::Storage *db;
// ANTLR has singleton instance that is shared between threads. It is
// protected by locks inside of ANTLR. Unfortunately, they are not protected
// in a very good way. Once we have ANTLR version without race conditions we
// can remove this lock. This will probably never happen since ANTLR
// developers introduce more bugs in each version. Fortunately, we have
// cache so this lock probably won't impact performance much...
utils::SpinLock antlr_lock;
std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
std::atomic<bool> is_shutting_down{false};

View File

@ -2600,13 +2600,13 @@ namespace {
* when there are */
TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::MemoryResource *memory) {
switch (element.op) {
case Aggregation::Op::COUNT:
return TypedValue(0, memory);
case Aggregation::Op::SUM:
case Aggregation::Op::MIN:
case Aggregation::Op::MAX:
case Aggregation::Op::AVG:
return TypedValue(memory);
case Aggregation::Op::COUNT:
case Aggregation::Op::SUM:
return TypedValue(0, memory);
case Aggregation::Op::COLLECT_LIST:
return TypedValue(TypedValue::TVector(memory));
case Aggregation::Op::COLLECT_MAP:
@ -2628,9 +2628,7 @@ class AggregateCursor : public Cursor {
pulled_all_input_ = true;
aggregation_it_ = aggregation_.begin();
// in case there is no input and no group_bys we need to return true
// just this once
if (aggregation_.empty() && self_.group_by_.empty()) {
if (aggregation_.empty()) {
auto *pull_memory = context.evaluation_context.memory;
// place default aggregation values on the frame
for (const auto &elem : self_.aggregations_)

View File

@ -153,10 +153,10 @@ std::vector<std::pair<Identifier, TriggerIdentifierTag>> GetPredefinedIdentifier
Trigger::Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
const TriggerEventType event_type, utils::SkipList<QueryCacheEntry> *query_cache,
DbAccessor *db_accessor, utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
DbAccessor *db_accessor, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker)
: name_{std::move(name)},
parsed_statements_{ParseQuery(query, user_parameters, query_cache, antlr_lock, query_config)},
parsed_statements_{ParseQuery(query, user_parameters, query_cache, query_config)},
event_type_{event_type},
owner_{std::move(owner)} {
// We check immediately if the query is valid by trying to create a plan.
@ -257,7 +257,7 @@ inline constexpr uint64_t kVersion{2};
TriggerStore::TriggerStore(std::filesystem::path directory) : storage_{std::move(directory)} {}
void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker) {
MG_ASSERT(before_commit_triggers_.size() == 0 && after_commit_triggers_.size() == 0,
"Cannot restore trigger when some triggers already exist!");
@ -317,8 +317,8 @@ void TriggerStore::RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache
std::optional<Trigger> trigger;
try {
trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config, std::move(owner), auth_checker);
trigger.emplace(trigger_name, statement, user_parameters, event_type, query_cache, db_accessor, query_config,
std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
spdlog::warn("Failed to create trigger '{}' because: {}", trigger_name, e.what());
continue;
@ -336,8 +336,8 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters,
TriggerEventType event_type, TriggerPhase phase,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker) {
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker) {
std::unique_lock store_guard{store_lock_};
if (storage_.Get(name)) {
throw utils::BasicException("Trigger with the same name already exists.");
@ -345,8 +345,8 @@ void TriggerStore::AddTrigger(std::string name, const std::string &query,
std::optional<Trigger> trigger;
try {
trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, antlr_lock,
query_config, std::move(owner), auth_checker);
trigger.emplace(std::move(name), query, user_parameters, event_type, query_cache, db_accessor, query_config,
std::move(owner), auth_checker);
} catch (const utils::BasicException &e) {
const auto identifiers = GetPredefinedIdentifiers(event_type);
std::stringstream identifier_names_stream;

View File

@ -34,7 +34,7 @@ namespace memgraph::query {
struct Trigger {
explicit Trigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock,
utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);
@ -81,14 +81,13 @@ struct TriggerStore {
explicit TriggerStore(std::filesystem::path directory);
void RestoreTriggers(utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
const query::AuthChecker *auth_checker);
const InterpreterConfig::Query &query_config, const query::AuthChecker *auth_checker);
void AddTrigger(std::string name, const std::string &query,
const std::map<std::string, storage::PropertyValue> &user_parameters, TriggerEventType event_type,
TriggerPhase phase, utils::SkipList<QueryCacheEntry> *query_cache, DbAccessor *db_accessor,
utils::SpinLock *antlr_lock, const InterpreterConfig::Query &query_config,
std::optional<std::string> owner, const query::AuthChecker *auth_checker);
const InterpreterConfig::Query &query_config, std::optional<std::string> owner,
const query::AuthChecker *auth_checker);
void DropTrigger(const std::string &name);

104
src/query/v2/CMakeLists.txt Normal file
View File

@ -0,0 +1,104 @@
define_add_lcp(add_lcp_query lcp_query_v2_cpp_files generated_lcp_query_v2_files)
add_lcp_query(frontend/ast/ast.lcp)
add_lcp_query(frontend/semantic/symbol.lcp)
add_lcp_query(plan/operator.lcp)
add_custom_target(generate_lcp_query_v2 DEPENDS ${generated_lcp_query_v2_files})
set(mg_query_v2_sources
${lcp_query_v2_cpp_files}
common.cpp
cypher_query_interpreter.cpp
dump.cpp
frontend/ast/cypher_main_visitor.cpp
frontend/ast/pretty_print.cpp
frontend/parsing.cpp
frontend/semantic/required_privileges.cpp
frontend/semantic/symbol_generator.cpp
frontend/stripped.cpp
interpret/awesome_memgraph_functions.cpp
interpret/eval.cpp
interpreter.cpp
metadata.cpp
plan/operator.cpp
plan/preprocess.cpp
plan/pretty_print.cpp
plan/profile.cpp
plan/read_write_type_checker.cpp
plan/rewrite/index_lookup.cpp
plan/rule_based_planner.cpp
plan/variable_start_planner.cpp
procedure/mg_procedure_impl.cpp
procedure/mg_procedure_helpers.cpp
procedure/module.cpp
procedure/py_module.cpp
serialization/property_value.cpp
stream/streams.cpp
stream/sources.cpp
stream/common.cpp
trigger.cpp
trigger_context.cpp
typed_value.cpp)
find_package(Boost REQUIRED)
add_library(mg-query-v2 STATIC ${mg_query_v2_sources})
add_dependencies(mg-query-v2 generate_lcp_query_v2)
target_include_directories(mg-query-v2 PUBLIC ${CMAKE_SOURCE_DIR}/include)
target_link_libraries(mg-query-v2 dl cppitertools Boost::headers)
target_link_libraries(mg-query-v2 mg-integrations-pulsar mg-integrations-kafka mg-storage-v3 mg-license mg-utils mg-kvstore mg-memory)
if(NOT "${MG_PYTHON_PATH}" STREQUAL "")
set(Python3_ROOT_DIR "${MG_PYTHON_PATH}")
endif()
if("${MG_PYTHON_VERSION}" STREQUAL "")
find_package(Python3 3.5 REQUIRED COMPONENTS Development)
else()
find_package(Python3 "${MG_PYTHON_VERSION}" EXACT REQUIRED COMPONENTS Development)
endif()
target_link_libraries(mg-query-v2 Python3::Python)
# Generate Antlr openCypher parser
set(opencypher_frontend ${CMAKE_CURRENT_SOURCE_DIR}/frontend/opencypher)
set(opencypher_generated ${opencypher_frontend}/generated)
set(opencypher_lexer_grammar ${opencypher_frontend}/grammar/MemgraphCypherLexer.g4)
set(opencypher_parser_grammar ${opencypher_frontend}/grammar/MemgraphCypher.g4)
set(antlr_opencypher_generated_src
${opencypher_generated}/MemgraphCypherLexer.cpp
${opencypher_generated}/MemgraphCypher.cpp
${opencypher_generated}/MemgraphCypherBaseVisitor.cpp
${opencypher_generated}/MemgraphCypherVisitor.cpp
)
set(antlr_opencypher_generated_include
${opencypher_generated}/MemgraphCypherLexer.h
${opencypher_generated}/MemgraphCypher.h
${opencypher_generated}/MemgraphCypherBaseVisitor.h
${opencypher_generated}/MemgraphCypherVisitor.h
)
add_custom_command(
OUTPUT ${antlr_opencypher_generated_src} ${antlr_opencypher_generated_include}
COMMAND ${CMAKE_COMMAND} -E make_directory ${opencypher_generated}
COMMAND
java -jar ${CMAKE_SOURCE_DIR}/libs/antlr-4.10.1-complete.jar
-Dlanguage=Cpp -visitor -package antlropencypher
-o ${opencypher_generated}
${opencypher_lexer_grammar} ${opencypher_parser_grammar}
WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
DEPENDS
${opencypher_lexer_grammar} ${opencypher_parser_grammar}
${opencypher_frontend}/grammar/CypherLexer.g4
${opencypher_frontend}/grammar/Cypher.g4)
add_custom_target(generate_opencypher_parser_v2
DEPENDS ${antlr_opencypher_generated_src} ${antlr_opencypher_generated_include})
add_library(antlr_opencypher_parser_lib_v2 STATIC ${antlr_opencypher_generated_src})
add_dependencies(antlr_opencypher_parser_lib_v2 generate_opencypher_parser_v2)
target_link_libraries(antlr_opencypher_parser_lib_v2 antlr4)
target_link_libraries(mg-query-v2 antlr_opencypher_parser_lib_v2)

View File

@ -0,0 +1,29 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/frontend/ast/ast.hpp"
namespace memgraph::query::v2 {
class AuthChecker {
public:
virtual bool IsUserAuthorized(const std::optional<std::string> &username,
const std::vector<query::v2::AuthQuery::Privilege> &privileges) const = 0;
};
class AllowEverythingAuthChecker final : public query::v2::AuthChecker {
bool IsUserAuthorized(const std::optional<std::string> & /*username*/,
const std::vector<query::v2::AuthQuery::Privilege> & /*privileges*/) const override {
return true;
}
};
} // namespace memgraph::query::v2

76
src/query/v2/common.cpp Normal file
View File

@ -0,0 +1,76 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/common.hpp"
namespace memgraph::query::v2 {
namespace impl {
bool TypedValueCompare(const TypedValue &a, const TypedValue &b) {
// in ordering null comes after everything else
// at the same time Null is not less that null
// first deal with Null < Whatever case
if (a.IsNull()) return false;
// now deal with NotNull < Null case
if (b.IsNull()) return true;
// comparisons are from this point legal only between values of
// the same type, or int+float combinations
if ((a.type() != b.type() && !(a.IsNumeric() && b.IsNumeric())))
throw QueryRuntimeException("Can't compare value of type {} to value of type {}.", a.type(), b.type());
switch (a.type()) {
case TypedValue::Type::Bool:
return !a.ValueBool() && b.ValueBool();
case TypedValue::Type::Int:
if (b.type() == TypedValue::Type::Double)
return a.ValueInt() < b.ValueDouble();
else
return a.ValueInt() < b.ValueInt();
case TypedValue::Type::Double:
if (b.type() == TypedValue::Type::Int)
return a.ValueDouble() < b.ValueInt();
else
return a.ValueDouble() < b.ValueDouble();
case TypedValue::Type::String:
// NOLINTNEXTLINE(modernize-use-nullptr)
return a.ValueString() < b.ValueString();
case TypedValue::Type::Date:
// NOLINTNEXTLINE(modernize-use-nullptr)
return a.ValueDate() < b.ValueDate();
case TypedValue::Type::LocalTime:
// NOLINTNEXTLINE(modernize-use-nullptr)
return a.ValueLocalTime() < b.ValueLocalTime();
case TypedValue::Type::LocalDateTime:
// NOLINTNEXTLINE(modernize-use-nullptr)
return a.ValueLocalDateTime() < b.ValueLocalDateTime();
case TypedValue::Type::Duration:
// NOLINTNEXTLINE(modernize-use-nullptr)
return a.ValueDuration() < b.ValueDuration();
case TypedValue::Type::List:
case TypedValue::Type::Map:
case TypedValue::Type::Vertex:
case TypedValue::Type::Edge:
case TypedValue::Type::Path:
throw QueryRuntimeException("Comparison is not defined for values of type {}.", a.type());
case TypedValue::Type::Null:
LOG_FATAL("Invalid type");
}
}
} // namespace impl
int64_t QueryTimestamp() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
}
} // namespace memgraph::query::v2

111
src/query/v2/common.hpp Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <concepts>
#include <cstdint>
#include <string>
#include <string_view>
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/symbol.hpp"
#include "query/v2/typed_value.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/view.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2 {
namespace impl {
bool TypedValueCompare(const TypedValue &a, const TypedValue &b);
} // namespace impl
/// Custom Comparator type for comparing vectors of TypedValues.
///
/// Does lexicographical ordering of elements based on the above
/// defined TypedValueCompare, and also accepts a vector of Orderings
/// the define how respective elements compare.
class TypedValueVectorCompare final {
public:
TypedValueVectorCompare() {}
explicit TypedValueVectorCompare(const std::vector<Ordering> &ordering) : ordering_(ordering) {}
template <class TAllocator>
bool operator()(const std::vector<TypedValue, TAllocator> &c1, const std::vector<TypedValue, TAllocator> &c2) const {
// ordering is invalid if there are more elements in the collections
// then there are in the ordering_ vector
MG_ASSERT(c1.size() <= ordering_.size() && c2.size() <= ordering_.size(),
"Collections contain more elements then there are orderings");
auto c1_it = c1.begin();
auto c2_it = c2.begin();
auto ordering_it = ordering_.begin();
for (; c1_it != c1.end() && c2_it != c2.end(); c1_it++, c2_it++, ordering_it++) {
if (impl::TypedValueCompare(*c1_it, *c2_it)) return *ordering_it == Ordering::ASC;
if (impl::TypedValueCompare(*c2_it, *c1_it)) return *ordering_it == Ordering::DESC;
}
// at least one collection is exhausted
// c1 is less then c2 iff c1 reached the end but c2 didn't
return (c1_it == c1.end()) && (c2_it != c2.end());
}
// TODO: Remove this, member is public
const auto &ordering() const { return ordering_; }
std::vector<Ordering> ordering_;
};
/// Raise QueryRuntimeException if the value for symbol isn't of expected type.
inline void ExpectType(const Symbol &symbol, const TypedValue &value, TypedValue::Type expected) {
if (value.type() != expected)
throw QueryRuntimeException("Expected a {} for '{}', but got {}.", expected, symbol.name(), value.type());
}
template <typename T>
concept AccessorWithSetProperty = requires(T accessor, const storage::v3::PropertyId key,
const storage::v3::PropertyValue new_value) {
{ accessor.SetProperty(key, new_value) } -> std::same_as<storage::v3::Result<storage::v3::PropertyValue>>;
};
/// Set a property `value` mapped with given `key` on a `record`.
///
/// @throw QueryRuntimeException if value cannot be set as a property value
template <AccessorWithSetProperty T>
storage::v3::PropertyValue PropsSetChecked(T *record, const storage::v3::PropertyId &key, const TypedValue &value) {
try {
auto maybe_old_value = record->SetProperty(key, storage::v3::PropertyValue(value));
if (maybe_old_value.HasError()) {
switch (maybe_old_value.GetError()) {
case storage::v3::Error::SERIALIZATION_ERROR:
throw TransactionSerializationException();
case storage::v3::Error::DELETED_OBJECT:
throw QueryRuntimeException("Trying to set properties on a deleted object.");
case storage::v3::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException("Can't set property because properties on edges are disabled.");
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::NONEXISTENT_OBJECT:
throw QueryRuntimeException("Unexpected error when setting a property.");
}
}
return std::move(*maybe_old_value);
} catch (const TypedValueException &) {
throw QueryRuntimeException("'{}' cannot be used as a property value.", value.type());
}
}
int64_t QueryTimestamp();
} // namespace memgraph::query::v2

32
src/query/v2/config.hpp Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <chrono>
#include <string>
namespace memgraph::query::v2 {
struct InterpreterConfig {
struct Query {
bool allow_load_csv{true};
} query;
// The default execution timeout is 10 minutes.
double execution_timeout_sec{600.0};
// The same as \ref memgraph::storage::v3::replication::ReplicationClientConfig
std::chrono::seconds replication_replica_check_frequency{1};
std::string default_kafka_bootstrap_servers;
std::string default_pulsar_service_url;
uint32_t stream_transaction_conflict_retries;
std::chrono::milliseconds stream_transaction_retry_interval;
};
} // namespace memgraph::query::v2

View File

@ -0,0 +1,19 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <string>
namespace memgraph::query::v2 {
inline constexpr uint16_t kDefaultReplicationPort = 10000;
inline constexpr auto *kDefaultReplicationServerIp = "0.0.0.0";
} // namespace memgraph::query::v2

89
src/query/v2/context.hpp Normal file
View File

@ -0,0 +1,89 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <type_traits>
#include "query/v2/common.hpp"
#include "query/v2/frontend/semantic/symbol_table.hpp"
#include "query/v2/metadata.hpp"
#include "query/v2/parameters.hpp"
#include "query/v2/plan/profile.hpp"
#include "query/v2/trigger.hpp"
#include "utils/async_timer.hpp"
namespace memgraph::query::v2 {
struct EvaluationContext {
/// Memory for allocations during evaluation of a *single* Pull call.
///
/// Although the assigned memory may live longer than the duration of a Pull
/// (e.g. memory is the same as the whole execution memory), you have to treat
/// it as if the lifetime is only valid during the Pull.
utils::MemoryResource *memory{utils::NewDeleteResource()};
int64_t timestamp{-1};
Parameters parameters;
/// All properties indexable via PropertyIx
std::vector<storage::v3::PropertyId> properties;
/// All labels indexable via LabelIx
std::vector<storage::v3::LabelId> labels;
/// All counters generated by `counter` function, mutable because the function
/// modifies the values
mutable std::unordered_map<std::string, int64_t> counters;
};
inline std::vector<storage::v3::PropertyId> NamesToProperties(const std::vector<std::string> &property_names,
DbAccessor *dba) {
std::vector<storage::v3::PropertyId> properties;
properties.reserve(property_names.size());
for (const auto &name : property_names) {
properties.push_back(dba->NameToProperty(name));
}
return properties;
}
inline std::vector<storage::v3::LabelId> NamesToLabels(const std::vector<std::string> &label_names, DbAccessor *dba) {
std::vector<storage::v3::LabelId> labels;
labels.reserve(label_names.size());
for (const auto &name : label_names) {
labels.push_back(dba->NameToLabel(name));
}
return labels;
}
struct ExecutionContext {
DbAccessor *db_accessor{nullptr};
SymbolTable symbol_table;
EvaluationContext evaluation_context;
std::atomic<bool> *is_shutting_down{nullptr};
bool is_profile_query{false};
std::chrono::duration<double> profile_execution_time;
plan::ProfilingStats stats;
plan::ProfilingStats *stats_root{nullptr};
ExecutionStats execution_stats;
TriggerContextCollector *trigger_context_collector{nullptr};
utils::AsyncTimer timer;
};
static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!");
static_assert(std::is_move_constructible_v<ExecutionContext>, "ExecutionContext must be move constructible!");
inline bool MustAbort(const ExecutionContext &context) noexcept {
return (context.is_shutting_down != nullptr && context.is_shutting_down->load(std::memory_order_acquire)) ||
context.timer.IsExpired();
}
inline plan::ProfilingStatsWithTotalTime GetStatsWithTotalTime(const ExecutionContext &context) {
return plan::ProfilingStatsWithTotalTime{context.stats, context.profile_execution_time};
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,153 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/cypher_query_interpreter.hpp"
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_HIDDEN_bool(query_cost_planner, true, "Use the cost-estimating query planner.");
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_VALIDATED_int32(query_plan_cache_ttl, 60, "Time to live for cached query plans, in seconds.",
FLAG_IN_RANGE(0, std::numeric_limits<int32_t>::max()));
namespace memgraph::query::v2 {
CachedPlan::CachedPlan(std::unique_ptr<LogicalPlan> plan) : plan_(std::move(plan)) {}
ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::v3::PropertyValue> &params,
utils::SkipList<QueryCacheEntry> *cache, const InterpreterConfig::Query &query_config) {
// Strip the query for caching purposes. The process of stripping a query
// "normalizes" it by replacing any literals with new parameters. This
// results in just the *structure* of the query being taken into account for
// caching.
frontend::StrippedQuery stripped_query{query_string};
// Copy over the parameters that were introduced during stripping.
Parameters parameters{stripped_query.literals()};
// Check that all user-specified parameters are provided.
for (const auto &param_pair : stripped_query.parameters()) {
auto it = params.find(param_pair.second);
if (it == params.end()) {
throw query::v2::UnprovidedParameterError("Parameter ${} not provided.", param_pair.second);
}
parameters.Add(param_pair.first, it->second);
}
// Cache the query's AST if it isn't already.
auto hash = stripped_query.hash();
auto accessor = cache->access();
auto it = accessor.find(hash);
std::unique_ptr<frontend::opencypher::Parser> parser;
// Return a copy of both the AST storage and the query.
CachedQuery result;
bool is_cacheable = true;
auto get_information_from_cache = [&](const auto &cached_query) {
result.ast_storage.properties_ = cached_query.ast_storage.properties_;
result.ast_storage.labels_ = cached_query.ast_storage.labels_;
result.ast_storage.edge_types_ = cached_query.ast_storage.edge_types_;
result.query = cached_query.query->Clone(&result.ast_storage);
result.required_privileges = cached_query.required_privileges;
};
if (it == accessor.end()) {
try {
parser = std::make_unique<frontend::opencypher::Parser>(stripped_query.query());
} catch (const SyntaxException &e) {
// There is a syntax exception in the stripped query. Re-run the parser
// on the original query to get an appropriate error messsage.
parser = std::make_unique<frontend::opencypher::Parser>(query_string);
// If an exception was not thrown here, the stripper messed something
// up.
LOG_FATAL("The stripped query can't be parsed, but the original can.");
}
// Convert the ANTLR4 parse tree into an AST.
AstStorage ast_storage;
frontend::ParsingContext context{true};
frontend::CypherMainVisitor visitor(context, &ast_storage);
visitor.visit(parser->tree());
if (visitor.GetQueryInfo().has_load_csv && !query_config.allow_load_csv) {
throw utils::BasicException("Load CSV not allowed on this instance because it was disabled by a config.");
}
if (visitor.GetQueryInfo().is_cacheable) {
CachedQuery cached_query{std::move(ast_storage), visitor.query(),
query::v2::GetRequiredPrivileges(visitor.query())};
it = accessor.insert({hash, std::move(cached_query)}).first;
get_information_from_cache(it->second);
} else {
result.ast_storage.properties_ = ast_storage.properties_;
result.ast_storage.labels_ = ast_storage.labels_;
result.ast_storage.edge_types_ = ast_storage.edge_types_;
result.query = visitor.query()->Clone(&result.ast_storage);
result.required_privileges = query::v2::GetRequiredPrivileges(visitor.query());
is_cacheable = false;
}
} else {
get_information_from_cache(it->second);
}
return ParsedQuery{query_string,
params,
std::move(parameters),
std::move(stripped_query),
std::move(result.ast_storage),
result.query,
std::move(result.required_privileges),
is_cacheable};
}
std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters,
DbAccessor *db_accessor,
const std::vector<Identifier *> &predefined_identifiers) {
auto vertex_counts = plan::MakeVertexCountCache(db_accessor);
auto symbol_table = MakeSymbolTable(query, predefined_identifiers);
auto planning_context = plan::MakePlanningContext(&ast_storage, &symbol_table, query, &vertex_counts);
auto [root, cost] = plan::MakeLogicalPlan(&planning_context, parameters, FLAGS_query_cost_planner);
return std::make_unique<SingleNodeLogicalPlan>(std::move(root), cost, std::move(ast_storage),
std::move(symbol_table));
}
std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query,
const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache,
DbAccessor *db_accessor,
const std::vector<Identifier *> &predefined_identifiers) {
std::optional<utils::SkipList<PlanCacheEntry>::Accessor> plan_cache_access;
if (plan_cache) {
plan_cache_access.emplace(plan_cache->access());
auto it = plan_cache_access->find(hash);
if (it != plan_cache_access->end()) {
if (it->second->IsExpired()) {
plan_cache_access->remove(hash);
} else {
return it->second;
}
}
}
auto plan = std::make_shared<CachedPlan>(
MakeLogicalPlan(std::move(ast_storage), query, parameters, db_accessor, predefined_identifiers));
if (plan_cache_access) {
plan_cache_access->insert({hash, plan});
}
return plan;
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,151 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/config.hpp"
#include "query/v2/frontend/ast/cypher_main_visitor.hpp"
#include "query/v2/frontend/opencypher/parser.hpp"
#include "query/v2/frontend/semantic/required_privileges.hpp"
#include "query/v2/frontend/semantic/symbol_generator.hpp"
#include "query/v2/frontend/stripped.hpp"
#include "query/v2/plan/planner.hpp"
#include "utils/flag_validation.hpp"
#include "utils/timer.hpp"
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_bool(query_cost_planner);
// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
DECLARE_int32(query_plan_cache_ttl);
namespace memgraph::query::v2 {
// TODO: Maybe this should move to query/plan/planner.
/// Interface for accessing the root operator of a logical plan.
class LogicalPlan {
public:
explicit LogicalPlan() = default;
virtual ~LogicalPlan() = default;
LogicalPlan(const LogicalPlan &) = default;
LogicalPlan &operator=(const LogicalPlan &) = default;
LogicalPlan(LogicalPlan &&) = default;
LogicalPlan &operator=(LogicalPlan &&) = default;
virtual const plan::LogicalOperator &GetRoot() const = 0;
virtual double GetCost() const = 0;
virtual const SymbolTable &GetSymbolTable() const = 0;
virtual const AstStorage &GetAstStorage() const = 0;
};
class CachedPlan {
public:
explicit CachedPlan(std::unique_ptr<LogicalPlan> plan);
const auto &plan() const { return plan_->GetRoot(); }
double cost() const { return plan_->GetCost(); }
const auto &symbol_table() const { return plan_->GetSymbolTable(); }
const auto &ast_storage() const { return plan_->GetAstStorage(); }
bool IsExpired() const {
// NOLINTNEXTLINE (modernize-use-nullptr)
return cache_timer_.Elapsed() > std::chrono::seconds(FLAGS_query_plan_cache_ttl);
};
private:
std::unique_ptr<LogicalPlan> plan_;
utils::Timer cache_timer_;
};
struct CachedQuery {
AstStorage ast_storage;
Query *query;
std::vector<AuthQuery::Privilege> required_privileges;
};
struct QueryCacheEntry {
bool operator==(const QueryCacheEntry &other) const { return first == other.first; }
bool operator<(const QueryCacheEntry &other) const { return first < other.first; }
bool operator==(const uint64_t &other) const { return first == other; }
bool operator<(const uint64_t &other) const { return first < other; }
uint64_t first;
// TODO: Maybe store the query string here and use it as a key with the hash
// so that we eliminate the risk of hash collisions.
CachedQuery second;
};
struct PlanCacheEntry {
bool operator==(const PlanCacheEntry &other) const { return first == other.first; }
bool operator<(const PlanCacheEntry &other) const { return first < other.first; }
bool operator==(const uint64_t &other) const { return first == other; }
bool operator<(const uint64_t &other) const { return first < other; }
uint64_t first;
// TODO: Maybe store the query string here and use it as a key with the hash
// so that we eliminate the risk of hash collisions.
std::shared_ptr<CachedPlan> second;
};
/**
* A container for data related to the parsing of a query.
*/
struct ParsedQuery {
std::string query_string;
std::map<std::string, storage::v3::PropertyValue> user_parameters;
Parameters parameters;
frontend::StrippedQuery stripped_query;
AstStorage ast_storage;
Query *query;
std::vector<AuthQuery::Privilege> required_privileges;
bool is_cacheable{true};
};
ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::string, storage::v3::PropertyValue> &params,
utils::SkipList<QueryCacheEntry> *cache, const InterpreterConfig::Query &query_config);
class SingleNodeLogicalPlan final : public LogicalPlan {
public:
SingleNodeLogicalPlan(std::unique_ptr<plan::LogicalOperator> root, double cost, AstStorage storage,
const SymbolTable &symbol_table)
: root_(std::move(root)), cost_(cost), storage_(std::move(storage)), symbol_table_(symbol_table) {}
const plan::LogicalOperator &GetRoot() const override { return *root_; }
double GetCost() const override { return cost_; }
const SymbolTable &GetSymbolTable() const override { return symbol_table_; }
const AstStorage &GetAstStorage() const override { return storage_; }
private:
std::unique_ptr<plan::LogicalOperator> root_;
double cost_;
AstStorage storage_;
SymbolTable symbol_table_;
};
std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters,
DbAccessor *db_accessor,
const std::vector<Identifier *> &predefined_identifiers);
/**
* Return the parsed *Cypher* query's AST cached logical plan, or create and
* cache a fresh one if it doesn't yet exist.
* @param predefined_identifiers optional identifiers you want to inject into a query.
* If an identifier is not defined in a scope, we check the predefined identifiers.
* If an identifier is contained there, we inject it at that place and remove it,
* because a predefined identifier can be used only in one scope.
*/
std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query,
const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache,
DbAccessor *db_accessor,
const std::vector<Identifier *> &predefined_identifiers = {});
} // namespace memgraph::query::v2

View File

@ -0,0 +1,384 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <optional>
#include <cppitertools/filter.hpp>
#include <cppitertools/imap.hpp>
#include "query/v2/exceptions.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/result.hpp"
///////////////////////////////////////////////////////////
// Our communication layer and query engine don't mix
// very well on Centos because OpenSSL version avaialable
// on Centos 7 include libkrb5 which has brilliant macros
// called TRUE and FALSE. For more detailed explanation go
// to memgraph.cpp.
//
// Because of the replication storage now uses some form of
// communication so we have some unwanted macros.
// This cannot be avoided by simple include orderings so we
// simply undefine those macros as we're sure that libkrb5
// won't and can't be used anywhere in the query engine.
#include "storage/v3/storage.hpp"
#undef FALSE
#undef TRUE
///////////////////////////////////////////////////////////
#include "storage/v3/view.hpp"
#include "utils/bound.hpp"
#include "utils/exceptions.hpp"
namespace memgraph::query::v2 {
class VertexAccessor;
class EdgeAccessor final {
public:
storage::v3::EdgeAccessor impl_;
public:
explicit EdgeAccessor(storage::v3::EdgeAccessor impl) : impl_(std::move(impl)) {}
bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); }
storage::v3::EdgeTypeId EdgeType() const { return impl_.EdgeType(); }
auto Properties(storage::v3::View view) const { return impl_.Properties(view); }
storage::v3::Result<storage::v3::PropertyValue> GetProperty(storage::v3::View view,
storage::v3::PropertyId key) const {
return impl_.GetProperty(key, view);
}
storage::v3::Result<storage::v3::PropertyValue> SetProperty(storage::v3::PropertyId key,
const storage::v3::PropertyValue &value) {
return impl_.SetProperty(key, value);
}
storage::v3::Result<storage::v3::PropertyValue> RemoveProperty(storage::v3::PropertyId key) {
return SetProperty(key, storage::v3::PropertyValue());
}
storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> ClearProperties() {
return impl_.ClearProperties();
}
VertexAccessor To() const;
VertexAccessor From() const;
bool IsCycle() const;
int64_t CypherId() const { return impl_.Gid().AsInt(); }
storage::v3::Gid Gid() const noexcept { return impl_.Gid(); }
bool operator==(const EdgeAccessor &e) const noexcept { return impl_ == e.impl_; }
bool operator!=(const EdgeAccessor &e) const noexcept { return !(*this == e); }
};
class VertexAccessor final {
public:
storage::v3::VertexAccessor impl_;
static EdgeAccessor MakeEdgeAccessor(const storage::v3::EdgeAccessor impl) { return EdgeAccessor(impl); }
public:
explicit VertexAccessor(storage::v3::VertexAccessor impl) : impl_(impl) {}
bool IsVisible(storage::v3::View view) const { return impl_.IsVisible(view); }
auto Labels(storage::v3::View view) const { return impl_.Labels(view); }
storage::v3::Result<bool> AddLabel(storage::v3::LabelId label) { return impl_.AddLabel(label); }
storage::v3::Result<bool> RemoveLabel(storage::v3::LabelId label) { return impl_.RemoveLabel(label); }
storage::v3::Result<bool> HasLabel(storage::v3::View view, storage::v3::LabelId label) const {
return impl_.HasLabel(label, view);
}
auto Properties(storage::v3::View view) const { return impl_.Properties(view); }
storage::v3::Result<storage::v3::PropertyValue> GetProperty(storage::v3::View view,
storage::v3::PropertyId key) const {
return impl_.GetProperty(key, view);
}
storage::v3::Result<storage::v3::PropertyValue> SetProperty(storage::v3::PropertyId key,
const storage::v3::PropertyValue &value) {
return impl_.SetProperty(key, value);
}
storage::v3::Result<storage::v3::PropertyValue> RemoveProperty(storage::v3::PropertyId key) {
return SetProperty(key, storage::v3::PropertyValue());
}
storage::v3::Result<std::map<storage::v3::PropertyId, storage::v3::PropertyValue>> ClearProperties() {
return impl_.ClearProperties();
}
auto InEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types) const
-> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
auto maybe_edges = impl_.InEdges(view, edge_types);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
auto InEdges(storage::v3::View view) const { return InEdges(view, {}); }
auto InEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types,
const VertexAccessor &dest) const
-> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.InEdges(view)))> {
auto maybe_edges = impl_.InEdges(view, edge_types, &dest.impl_);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
auto OutEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types) const
-> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
auto maybe_edges = impl_.OutEdges(view, edge_types);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
auto OutEdges(storage::v3::View view) const { return OutEdges(view, {}); }
auto OutEdges(storage::v3::View view, const std::vector<storage::v3::EdgeTypeId> &edge_types,
const VertexAccessor &dest) const
-> storage::v3::Result<decltype(iter::imap(MakeEdgeAccessor, *impl_.OutEdges(view)))> {
auto maybe_edges = impl_.OutEdges(view, edge_types, &dest.impl_);
if (maybe_edges.HasError()) return maybe_edges.GetError();
return iter::imap(MakeEdgeAccessor, std::move(*maybe_edges));
}
storage::v3::Result<size_t> InDegree(storage::v3::View view) const { return impl_.InDegree(view); }
storage::v3::Result<size_t> OutDegree(storage::v3::View view) const { return impl_.OutDegree(view); }
int64_t CypherId() const { return impl_.Gid().AsInt(); }
storage::v3::Gid Gid() const noexcept { return impl_.Gid(); }
bool operator==(const VertexAccessor &v) const noexcept {
static_assert(noexcept(impl_ == v.impl_));
return impl_ == v.impl_;
}
bool operator!=(const VertexAccessor &v) const noexcept { return !(*this == v); }
};
inline VertexAccessor EdgeAccessor::To() const { return VertexAccessor(impl_.ToVertex()); }
inline VertexAccessor EdgeAccessor::From() const { return VertexAccessor(impl_.FromVertex()); }
inline bool EdgeAccessor::IsCycle() const { return To() == From(); }
class DbAccessor final {
storage::v3::Storage::Accessor *accessor_;
class VerticesIterable final {
storage::v3::VerticesIterable iterable_;
public:
class Iterator final {
storage::v3::VerticesIterable::Iterator it_;
public:
explicit Iterator(storage::v3::VerticesIterable::Iterator it) : it_(it) {}
VertexAccessor operator*() const { return VertexAccessor(*it_); }
Iterator &operator++() {
++it_;
return *this;
}
bool operator==(const Iterator &other) const { return it_ == other.it_; }
bool operator!=(const Iterator &other) const { return !(other == *this); }
};
explicit VerticesIterable(storage::v3::VerticesIterable iterable) : iterable_(std::move(iterable)) {}
Iterator begin() { return Iterator(iterable_.begin()); }
Iterator end() { return Iterator(iterable_.end()); }
};
public:
explicit DbAccessor(storage::v3::Storage::Accessor *accessor) : accessor_(accessor) {}
std::optional<VertexAccessor> FindVertex(storage::v3::Gid gid, storage::v3::View view) {
auto maybe_vertex = accessor_->FindVertex(gid, view);
if (maybe_vertex) return VertexAccessor(*maybe_vertex);
return std::nullopt;
}
void FinalizeTransaction() { accessor_->FinalizeTransaction(); }
VerticesIterable Vertices(storage::v3::View view) { return VerticesIterable(accessor_->Vertices(view)); }
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label) {
return VerticesIterable(accessor_->Vertices(label, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property) {
return VerticesIterable(accessor_->Vertices(label, property, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) {
return VerticesIterable(accessor_->Vertices(label, property, value, view));
}
VerticesIterable Vertices(storage::v3::View view, storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) {
return VerticesIterable(accessor_->Vertices(label, property, lower, upper, view));
}
VertexAccessor InsertVertex() { return VertexAccessor(accessor_->CreateVertex()); }
storage::v3::Result<EdgeAccessor> InsertEdge(VertexAccessor *from, VertexAccessor *to,
const storage::v3::EdgeTypeId &edge_type) {
auto maybe_edge = accessor_->CreateEdge(&from->impl_, &to->impl_, edge_type);
if (maybe_edge.HasError()) return storage::v3::Result<EdgeAccessor>(maybe_edge.GetError());
return EdgeAccessor(*maybe_edge);
}
storage::v3::Result<std::optional<EdgeAccessor>> RemoveEdge(EdgeAccessor *edge) {
auto res = accessor_->DeleteEdge(&edge->impl_);
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<EdgeAccessor>{};
}
return std::make_optional<EdgeAccessor>(*value);
}
storage::v3::Result<std::optional<std::pair<VertexAccessor, std::vector<EdgeAccessor>>>> DetachRemoveVertex(
VertexAccessor *vertex_accessor) {
using ReturnType = std::pair<VertexAccessor, std::vector<EdgeAccessor>>;
auto res = accessor_->DetachDeleteVertex(&vertex_accessor->impl_);
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<ReturnType>{};
}
const auto &[vertex, edges] = *value;
std::vector<EdgeAccessor> deleted_edges;
deleted_edges.reserve(edges.size());
std::transform(edges.begin(), edges.end(), std::back_inserter(deleted_edges),
[](const auto &deleted_edge) { return EdgeAccessor{deleted_edge}; });
return std::make_optional<ReturnType>(vertex, std::move(deleted_edges));
}
storage::v3::Result<std::optional<VertexAccessor>> RemoveVertex(VertexAccessor *vertex_accessor) {
auto res = accessor_->DeleteVertex(&vertex_accessor->impl_);
if (res.HasError()) {
return res.GetError();
}
const auto &value = res.GetValue();
if (!value) {
return std::optional<VertexAccessor>{};
}
return std::make_optional<VertexAccessor>(*value);
}
storage::v3::PropertyId NameToProperty(const std::string_view name) { return accessor_->NameToProperty(name); }
storage::v3::LabelId NameToLabel(const std::string_view name) { return accessor_->NameToLabel(name); }
storage::v3::EdgeTypeId NameToEdgeType(const std::string_view name) { return accessor_->NameToEdgeType(name); }
const std::string &PropertyToName(storage::v3::PropertyId prop) const { return accessor_->PropertyToName(prop); }
const std::string &LabelToName(storage::v3::LabelId label) const { return accessor_->LabelToName(label); }
const std::string &EdgeTypeToName(storage::v3::EdgeTypeId type) const { return accessor_->EdgeTypeToName(type); }
void AdvanceCommand() { accessor_->AdvanceCommand(); }
utils::BasicResult<storage::v3::ConstraintViolation, void> Commit() { return accessor_->Commit(); }
void Abort() { accessor_->Abort(); }
bool LabelIndexExists(storage::v3::LabelId label) const { return accessor_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId prop) const {
return accessor_->LabelPropertyIndexExists(label, prop);
}
int64_t VerticesCount() const { return accessor_->ApproximateVertexCount(); }
int64_t VerticesCount(storage::v3::LabelId label) const { return accessor_->ApproximateVertexCount(label); }
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) const {
return accessor_->ApproximateVertexCount(label, property);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) const {
return accessor_->ApproximateVertexCount(label, property, value);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) const {
return accessor_->ApproximateVertexCount(label, property, lower, upper);
}
storage::v3::IndicesInfo ListAllIndices() const { return accessor_->ListAllIndices(); }
storage::v3::ConstraintsInfo ListAllConstraints() const { return accessor_->ListAllConstraints(); }
};
} // namespace memgraph::query::v2
namespace std {
template <>
struct hash<memgraph::query::v2::VertexAccessor> {
size_t operator()(const memgraph::query::v2::VertexAccessor &v) const {
return std::hash<decltype(v.impl_)>{}(v.impl_);
}
};
template <>
struct hash<memgraph::query::v2::EdgeAccessor> {
size_t operator()(const memgraph::query::v2::EdgeAccessor &e) const {
return std::hash<decltype(e.impl_)>{}(e.impl_);
}
};
} // namespace std

View File

@ -0,0 +1,24 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <vector>
#include "query/v2/typed_value.hpp"
namespace memgraph::query::v2 {
struct DiscardValueResultStream final {
void Result(const std::vector<query::v2::TypedValue> & /*values*/) {
// do nothing
}
};
} // namespace memgraph::query::v2

541
src/query/v2/dump.cpp Normal file
View File

@ -0,0 +1,541 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/dump.hpp"
#include <iomanip>
#include <limits>
#include <map>
#include <optional>
#include <ostream>
#include <utility>
#include <vector>
#include <fmt/format.h>
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/stream.hpp"
#include "query/v2/typed_value.hpp"
#include "storage/v3/property_value.hpp"
#include "storage/v3/storage.hpp"
#include "utils/algorithm.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"
#include "utils/temporal.hpp"
namespace memgraph::query::v2 {
namespace {
// Property that is used to make a difference among vertices. It is added to
// property set of vertices to match edges and removed after the entire graph
// is built.
const char *kInternalPropertyId = "__mg_id__";
// Label that is attached to each vertex and is used for easier creation of
// index on internal property id.
const char *kInternalVertexLabel = "__mg_vertex__";
/// A helper function that escapes label, edge type and property names.
std::string EscapeName(const std::string_view value) {
std::string out;
out.reserve(value.size() + 2);
out.append(1, '`');
for (auto c : value) {
if (c == '`') {
out.append("``");
} else {
out.append(1, c);
}
}
out.append(1, '`');
return out;
}
void DumpPreciseDouble(std::ostream *os, double value) {
// A temporary stream is used to keep precision of the original output
// stream unchanged.
std::ostringstream temp_oss;
temp_oss << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
*os << temp_oss.str();
}
namespace {
void DumpDate(std::ostream &os, const storage::v3::TemporalData &value) {
utils::Date date(value.microseconds);
os << "DATE(\"" << date << "\")";
}
void DumpLocalTime(std::ostream &os, const storage::v3::TemporalData &value) {
utils::LocalTime lt(value.microseconds);
os << "LOCALTIME(\"" << lt << "\")";
}
void DumpLocalDateTime(std::ostream &os, const storage::v3::TemporalData &value) {
utils::LocalDateTime ldt(value.microseconds);
os << "LOCALDATETIME(\"" << ldt << "\")";
}
void DumpDuration(std::ostream &os, const storage::v3::TemporalData &value) {
utils::Duration dur(value.microseconds);
os << "DURATION(\"" << dur << "\")";
}
void DumpTemporalData(std::ostream &os, const storage::v3::TemporalData &value) {
switch (value.type) {
case storage::v3::TemporalType::Date: {
DumpDate(os, value);
return;
}
case storage::v3::TemporalType::LocalTime: {
DumpLocalTime(os, value);
return;
}
case storage::v3::TemporalType::LocalDateTime: {
DumpLocalDateTime(os, value);
return;
}
case storage::v3::TemporalType::Duration: {
DumpDuration(os, value);
return;
}
}
}
} // namespace
void DumpPropertyValue(std::ostream *os, const storage::v3::PropertyValue &value) {
switch (value.type()) {
case storage::v3::PropertyValue::Type::Null:
*os << "Null";
return;
case storage::v3::PropertyValue::Type::Bool:
*os << (value.ValueBool() ? "true" : "false");
return;
case storage::v3::PropertyValue::Type::String:
*os << utils::Escape(value.ValueString());
return;
case storage::v3::PropertyValue::Type::Int:
*os << value.ValueInt();
return;
case storage::v3::PropertyValue::Type::Double:
DumpPreciseDouble(os, value.ValueDouble());
return;
case storage::v3::PropertyValue::Type::List: {
*os << "[";
const auto &list = value.ValueList();
utils::PrintIterable(*os, list, ", ", [](auto &os, const auto &item) { DumpPropertyValue(&os, item); });
*os << "]";
return;
}
case storage::v3::PropertyValue::Type::Map: {
*os << "{";
const auto &map = value.ValueMap();
utils::PrintIterable(*os, map, ", ", [](auto &os, const auto &kv) {
os << EscapeName(kv.first) << ": ";
DumpPropertyValue(&os, kv.second);
});
*os << "}";
return;
}
case storage::v3::PropertyValue::Type::TemporalData: {
DumpTemporalData(*os, value.ValueTemporalData());
return;
}
}
}
void DumpProperties(std::ostream *os, query::v2::DbAccessor *dba,
const std::map<storage::v3::PropertyId, storage::v3::PropertyValue> &store,
std::optional<int64_t> property_id = std::nullopt) {
*os << "{";
if (property_id) {
*os << kInternalPropertyId << ": " << *property_id;
if (store.size() > 0) *os << ", ";
}
utils::PrintIterable(*os, store, ", ", [&dba](auto &os, const auto &kv) {
os << EscapeName(dba->PropertyToName(kv.first)) << ": ";
DumpPropertyValue(&os, kv.second);
});
*os << "}";
}
void DumpVertex(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::VertexAccessor &vertex) {
*os << "CREATE (";
*os << ":" << kInternalVertexLabel;
auto maybe_labels = vertex.Labels(storage::v3::View::OLD);
if (maybe_labels.HasError()) {
switch (maybe_labels.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get labels from a deleted node.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get labels from a node that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw query::v2::QueryRuntimeException("Unexpected error when getting labels.");
}
}
for (const auto &label : *maybe_labels) {
*os << ":" << EscapeName(dba->LabelToName(label));
}
*os << " ";
auto maybe_props = vertex.Properties(storage::v3::View::OLD);
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from a node that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw query::v2::QueryRuntimeException("Unexpected error when getting properties.");
}
}
DumpProperties(os, dba, *maybe_props, vertex.CypherId());
*os << ");";
}
void DumpEdge(std::ostream *os, query::v2::DbAccessor *dba, const query::v2::EdgeAccessor &edge) {
*os << "MATCH ";
*os << "(u:" << kInternalVertexLabel << "), ";
*os << "(v:" << kInternalVertexLabel << ")";
*os << " WHERE ";
*os << "u." << kInternalPropertyId << " = " << edge.From().CypherId();
*os << " AND ";
*os << "v." << kInternalPropertyId << " = " << edge.To().CypherId() << " ";
*os << "CREATE (u)-[";
*os << ":" << EscapeName(dba->EdgeTypeToName(edge.EdgeType()));
auto maybe_props = edge.Properties(storage::v3::View::OLD);
if (maybe_props.HasError()) {
switch (maybe_props.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from a deleted object.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get properties from an edge that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw query::v2::QueryRuntimeException("Unexpected error when getting properties.");
}
}
if (maybe_props->size() > 0) {
*os << " ";
DumpProperties(os, dba, *maybe_props);
}
*os << "]->(v);";
}
void DumpLabelIndex(std::ostream *os, query::v2::DbAccessor *dba, const storage::v3::LabelId label) {
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << ";";
}
void DumpLabelPropertyIndex(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
storage::v3::PropertyId property) {
*os << "CREATE INDEX ON :" << EscapeName(dba->LabelToName(label)) << "(" << EscapeName(dba->PropertyToName(property))
<< ");";
}
void DumpExistenceConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
storage::v3::PropertyId property) {
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT EXISTS (u."
<< EscapeName(dba->PropertyToName(property)) << ");";
}
void DumpUniqueConstraint(std::ostream *os, query::v2::DbAccessor *dba, storage::v3::LabelId label,
const std::set<storage::v3::PropertyId> &properties) {
*os << "CREATE CONSTRAINT ON (u:" << EscapeName(dba->LabelToName(label)) << ") ASSERT ";
utils::PrintIterable(*os, properties, ", ", [&dba](auto &stream, const auto &property) {
stream << "u." << EscapeName(dba->PropertyToName(property));
});
*os << " IS UNIQUE;";
}
} // namespace
PullPlanDump::PullPlanDump(DbAccessor *dba)
: dba_(dba),
vertices_iterable_(dba->Vertices(storage::v3::View::OLD)),
pull_chunks_{// Dump all label indices
CreateLabelIndicesPullChunk(),
// Dump all label property indices
CreateLabelPropertyIndicesPullChunk(),
// Dump all existence constraints
CreateExistenceConstraintsPullChunk(),
// Dump all unique constraints
CreateUniqueConstraintsPullChunk(),
// Create internal index for faster edge creation
CreateInternalIndexPullChunk(),
// Dump all vertices
CreateVertexPullChunk(),
// Dump all edges
CreateEdgePullChunk(),
// Drop the internal index
CreateDropInternalIndexPullChunk(),
// Internal index cleanup
CreateInternalIndexCleanupPullChunk()} {}
bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
// Iterate all functions that stream some results.
// Each function should return number of results it streamed after it
// finishes. If the function did not finish streaming all the results,
// std::nullopt should be returned because n results have already been sent.
while (current_chunk_index_ < pull_chunks_.size() && (!n || *n > 0)) {
const auto maybe_streamed_count = pull_chunks_[current_chunk_index_](stream, n);
if (!maybe_streamed_count) {
// n wasn't large enough to stream all the results from the current chunk
break;
}
if (n) {
// chunk finished streaming its results
// subtract number of results streamed in current pull
// so we know how many results we need to stream from future
// chunks.
*n -= *maybe_streamed_count;
}
++current_chunk_index_;
}
return current_chunk_index_ == pull_chunks_.size();
}
PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
// Dump all label indices
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of indices vectors
if (!indices_info_) {
indices_info_.emplace(dba_->ListAllIndices());
}
const auto &label = indices_info_->label;
size_t local_counter = 0;
while (global_index < label.size() && (!n || local_counter < *n)) {
std::ostringstream os;
DumpLabelIndex(&os, dba_, label[global_index]);
stream->Result({TypedValue(os.str())});
++global_index;
++local_counter;
}
if (global_index == label.size()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of indices vectors
if (!indices_info_) {
indices_info_.emplace(dba_->ListAllIndices());
}
const auto &label_property = indices_info_->label_property;
size_t local_counter = 0;
while (global_index < label_property.size() && (!n || local_counter < *n)) {
std::ostringstream os;
const auto &label_property_index = label_property[global_index];
DumpLabelPropertyIndex(&os, dba_, label_property_index.first, label_property_index.second);
stream->Result({TypedValue(os.str())});
++global_index;
++local_counter;
}
if (global_index == label_property.size()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of constraint vectors
if (!constraints_info_) {
constraints_info_.emplace(dba_->ListAllConstraints());
}
const auto &existence = constraints_info_->existence;
size_t local_counter = 0;
while (global_index < existence.size() && (!n || local_counter < *n)) {
const auto &constraint = existence[global_index];
std::ostringstream os;
DumpExistenceConstraint(&os, dba_, constraint.first, constraint.second);
stream->Result({TypedValue(os.str())});
++global_index;
++local_counter;
}
if (global_index == existence.size()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
return [this, global_index = 0U](AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the construction of constraint vectors
if (!constraints_info_) {
constraints_info_.emplace(dba_->ListAllConstraints());
}
const auto &unique = constraints_info_->unique;
size_t local_counter = 0;
while (global_index < unique.size() && (!n || local_counter < *n)) {
const auto &constraint = unique[global_index];
std::ostringstream os;
DumpUniqueConstraint(&os, dba_, constraint.first, constraint.second);
stream->Result({TypedValue(os.str())});
++global_index;
++local_counter;
}
if (global_index == unique.size()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
return [this](AnyStream *stream, std::optional<int>) mutable -> std::optional<size_t> {
if (vertices_iterable_.begin() != vertices_iterable_.end()) {
std::ostringstream os;
os << "CREATE INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
stream->Result({TypedValue(os.str())});
internal_index_created_ = true;
return 1;
}
return 0;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
return [this, maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
// states.
if (!maybe_current_iter) {
maybe_current_iter.emplace(vertices_iterable_.begin());
}
auto &current_iter{*maybe_current_iter};
size_t local_counter = 0;
while (current_iter != vertices_iterable_.end() && (!n || local_counter < *n)) {
std::ostringstream os;
DumpVertex(&os, dba_, *current_iter);
stream->Result({TypedValue(os.str())});
++local_counter;
++current_iter;
}
if (current_iter == vertices_iterable_.end()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
return [this, maybe_current_vertex_iter = std::optional<VertexAccessorIterableIterator>{},
// we need to save the iterable which contains list of accessor so
// our saved iterator is valid in the next run
maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<size_t> {
// Delay the call of begin() function
// If multiple begins are called before an iteration,
// one iteration will make the rest of iterators be in undefined
// states.
if (!maybe_current_vertex_iter) {
maybe_current_vertex_iter.emplace(vertices_iterable_.begin());
}
auto &current_vertex_iter{*maybe_current_vertex_iter};
size_t local_counter = 0U;
for (; current_vertex_iter != vertices_iterable_.end() && (!n || local_counter < *n); ++current_vertex_iter) {
const auto &vertex = *current_vertex_iter;
// If we have a saved iterable from a previous pull
// we need to use the same iterable
if (!maybe_edge_iterable) {
maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(vertex.OutEdges(storage::v3::View::OLD));
}
auto &maybe_edges = *maybe_edge_iterable;
MG_ASSERT(maybe_edges.HasValue(), "Invalid database state!");
auto current_edge_iter = maybe_current_edge_iter ? *maybe_current_edge_iter : maybe_edges->begin();
for (; current_edge_iter != maybe_edges->end() && (!n || local_counter < *n); ++current_edge_iter) {
std::ostringstream os;
DumpEdge(&os, dba_, *current_edge_iter);
stream->Result({TypedValue(os.str())});
++local_counter;
}
if (current_edge_iter != maybe_edges->end()) {
maybe_current_edge_iter.emplace(current_edge_iter);
return std::nullopt;
}
maybe_current_edge_iter = std::nullopt;
maybe_edge_iterable = nullptr;
}
if (current_vertex_iter == vertices_iterable_.end()) {
return local_counter;
}
return std::nullopt;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() {
return [this](AnyStream *stream, std::optional<int>) {
if (internal_index_created_) {
std::ostringstream os;
os << "DROP INDEX ON :" << kInternalVertexLabel << "(" << kInternalPropertyId << ");";
stream->Result({TypedValue(os.str())});
return 1;
}
return 0;
};
}
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
return [this](AnyStream *stream, std::optional<int>) {
if (internal_index_created_) {
std::ostringstream os;
os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u." << kInternalPropertyId << ";";
stream->Result({TypedValue(os.str())});
return 1;
}
return 0;
};
}
void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream) { PullPlanDump(dba).Pull(stream, {}); }
} // namespace memgraph::query::v2

66
src/query/v2/dump.hpp Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <ostream>
#include "query/v2/db_accessor.hpp"
#include "query/v2/stream.hpp"
#include "storage/v3/storage.hpp"
namespace memgraph::query::v2 {
void DumpDatabaseToCypherQueries(query::v2::DbAccessor *dba, AnyStream *stream);
struct PullPlanDump {
explicit PullPlanDump(query::v2::DbAccessor *dba);
/// Pull the dump results lazily
/// @return true if all results were returned, false otherwise
bool Pull(AnyStream *stream, std::optional<int> n);
private:
query::v2::DbAccessor *dba_ = nullptr;
std::optional<storage::v3::IndicesInfo> indices_info_ = std::nullopt;
std::optional<storage::v3::ConstraintsInfo> constraints_info_ = std::nullopt;
using VertexAccessorIterable = decltype(std::declval<query::v2::DbAccessor>().Vertices(storage::v3::View::OLD));
using VertexAccessorIterableIterator = decltype(std::declval<VertexAccessorIterable>().begin());
using EdgeAccessorIterable = decltype(std::declval<VertexAccessor>().OutEdges(storage::v3::View::OLD));
using EdgeAccessorIterableIterator = decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());
VertexAccessorIterable vertices_iterable_;
bool internal_index_created_ = false;
size_t current_chunk_index_ = 0;
using PullChunk = std::function<std::optional<size_t>(AnyStream *stream, std::optional<int> n)>;
// We define every part of the dump query in a self contained function.
// Each functions is responsible of keeping track of its execution status.
// If a function did finish its execution, it should return number of results
// it streamed so we know how many rows should be pulled from the next
// function, otherwise std::nullopt is returned.
std::vector<PullChunk> pull_chunks_;
PullChunk CreateLabelIndicesPullChunk();
PullChunk CreateLabelPropertyIndicesPullChunk();
PullChunk CreateExistenceConstraintsPullChunk();
PullChunk CreateUniqueConstraintsPullChunk();
PullChunk CreateInternalIndexPullChunk();
PullChunk CreateVertexPullChunk();
PullChunk CreateEdgePullChunk();
PullChunk CreateDropInternalIndexPullChunk();
PullChunk CreateInternalIndexCleanupPullChunk();
};
} // namespace memgraph::query::v2

227
src/query/v2/exceptions.hpp Normal file
View File

@ -0,0 +1,227 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "utils/exceptions.hpp"
#include <fmt/format.h>
namespace memgraph::query::v2 {
/**
* @brief Base class of all query language related exceptions. All exceptions
* derived from this one will be interpreted as ClientError-s, i. e. if client
* executes same query again without making modifications to the database data,
* query will fail again.
*/
class QueryException : public utils::BasicException {
using utils::BasicException::BasicException;
};
class LexingException : public QueryException {
public:
using QueryException::QueryException;
LexingException() : QueryException("") {}
};
class SyntaxException : public QueryException {
public:
using QueryException::QueryException;
SyntaxException() : QueryException("") {}
};
// TODO: Figure out what information to put in exception.
// Error reporting is tricky since we get stripped query and position of error
// in original query is not same as position of error in stripped query. Most
// correct approach would be to do semantic analysis with original query even
// for already hashed queries, but that has obvious performance issues. Other
// approach would be to report some of the semantic errors in runtime of the
// query and only report line numbers of semantic errors (not position in the
// line) if multiple line strings are not allowed by grammar. We could also
// print whole line that contains error instead of specifying line number.
class SemanticException : public QueryException {
public:
using QueryException::QueryException;
SemanticException() : QueryException("") {}
};
class UnboundVariableError : public SemanticException {
public:
explicit UnboundVariableError(const std::string &name) : SemanticException("Unbound variable: " + name + ".") {}
};
class RedeclareVariableError : public SemanticException {
public:
explicit RedeclareVariableError(const std::string &name) : SemanticException("Redeclaring variable: " + name + ".") {}
};
class TypeMismatchError : public SemanticException {
public:
TypeMismatchError(const std::string &name, const std::string &datum, const std::string &expected)
: SemanticException(fmt::format("Type mismatch: {} already defined as {}, expected {}.", name, datum, expected)) {
}
};
class UnprovidedParameterError : public QueryException {
public:
using QueryException::QueryException;
};
class ProfileInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
ProfileInMulticommandTxException() : QueryException("PROFILE not allowed in multicommand transactions.") {}
};
class IndexInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
IndexInMulticommandTxException() : QueryException("Index manipulation not allowed in multicommand transactions.") {}
};
class ConstraintInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
ConstraintInMulticommandTxException()
: QueryException(
"Constraint manipulation not allowed in multicommand "
"transactions.") {}
};
class InfoInMulticommandTxException : public QueryException {
public:
using QueryException::QueryException;
InfoInMulticommandTxException() : QueryException("Info reporting not allowed in multicommand transactions.") {}
};
/**
* An exception for an illegal operation that can not be detected
* before the query starts executing over data.
*/
class QueryRuntimeException : public QueryException {
public:
using QueryException::QueryException;
};
// This one is inherited from BasicException and will be treated as
// TransientError, i. e. client will be encouraged to retry execution because it
// could succeed if executed again.
class HintedAbortError : public utils::BasicException {
public:
using utils::BasicException::BasicException;
HintedAbortError()
: utils::BasicException(
"Transaction was asked to abort, most likely because it was "
"executing longer than time specified by "
"--query-execution-timeout-sec flag.") {}
};
class ExplicitTransactionUsageException : public QueryRuntimeException {
public:
using QueryRuntimeException::QueryRuntimeException;
};
/**
* An exception for serialization error
*/
class TransactionSerializationException : public QueryException {
public:
using QueryException::QueryException;
TransactionSerializationException()
: QueryException(
"Cannot resolve conflicting transactions. You can retry this transaction when the conflicting transaction "
"is finished") {}
};
class ReconstructionException : public QueryException {
public:
ReconstructionException()
: QueryException(
"Record invalid after WITH clause. Most likely deleted by a "
"preceeding DELETE.") {}
};
class RemoveAttachedVertexException : public QueryRuntimeException {
public:
RemoveAttachedVertexException()
: QueryRuntimeException(
"Failed to remove node because of it's existing "
"connections. Consider using DETACH DELETE.") {}
};
class UserModificationInMulticommandTxException : public QueryException {
public:
UserModificationInMulticommandTxException()
: QueryException("Authentication clause not allowed in multicommand transactions.") {}
};
class InvalidArgumentsException : public QueryException {
public:
InvalidArgumentsException(const std::string &argument_name, const std::string &message)
: QueryException(fmt::format("Invalid arguments sent: {} - {}", argument_name, message)) {}
};
class ReplicationModificationInMulticommandTxException : public QueryException {
public:
ReplicationModificationInMulticommandTxException()
: QueryException("Replication clause not allowed in multicommand transactions.") {}
};
class LockPathModificationInMulticommandTxException : public QueryException {
public:
LockPathModificationInMulticommandTxException()
: QueryException("Lock path query not allowed in multicommand transactions.") {}
};
class FreeMemoryModificationInMulticommandTxException : public QueryException {
public:
FreeMemoryModificationInMulticommandTxException()
: QueryException("Free memory query not allowed in multicommand transactions.") {}
};
class TriggerModificationInMulticommandTxException : public QueryException {
public:
TriggerModificationInMulticommandTxException()
: QueryException("Trigger queries not allowed in multicommand transactions.") {}
};
class StreamQueryInMulticommandTxException : public QueryException {
public:
StreamQueryInMulticommandTxException()
: QueryException("Stream queries are not allowed in multicommand transactions.") {}
};
class IsolationLevelModificationInMulticommandTxException : public QueryException {
public:
IsolationLevelModificationInMulticommandTxException()
: QueryException("Isolation level cannot be modified in multicommand transactions.") {}
};
class CreateSnapshotInMulticommandTxException final : public QueryException {
public:
CreateSnapshotInMulticommandTxException()
: QueryException("Snapshot cannot be created in multicommand transactions.") {}
};
class SettingConfigInMulticommandTxException final : public QueryException {
public:
SettingConfigInMulticommandTxException()
: QueryException("Settings cannot be changed or fetched in multicommand transactions.") {}
};
class VersionInfoInMulticommandTxException : public QueryException {
public:
VersionInfoInMulticommandTxException()
: QueryException("Version info query not allowed in multicommand transactions.") {}
};
} // namespace memgraph::query::v2

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,133 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "utils/visitor.hpp"
namespace memgraph::query::v2 {
// Forward declares for Tree visitors.
class CypherQuery;
class SingleQuery;
class CypherUnion;
class NamedExpression;
class Identifier;
class PropertyLookup;
class LabelsTest;
class Aggregation;
class Function;
class Reduce;
class Coalesce;
class Extract;
class All;
class Single;
class Any;
class None;
class ParameterLookup;
class CallProcedure;
class Create;
class Match;
class Return;
class With;
class Pattern;
class NodeAtom;
class EdgeAtom;
class PrimitiveLiteral;
class ListLiteral;
class MapLiteral;
class OrOperator;
class XorOperator;
class AndOperator;
class NotOperator;
class AdditionOperator;
class SubtractionOperator;
class MultiplicationOperator;
class DivisionOperator;
class ModOperator;
class UnaryPlusOperator;
class UnaryMinusOperator;
class IsNullOperator;
class NotEqualOperator;
class EqualOperator;
class LessOperator;
class GreaterOperator;
class LessEqualOperator;
class GreaterEqualOperator;
class InListOperator;
class SubscriptOperator;
class ListSlicingOperator;
class IfOperator;
class Delete;
class Where;
class SetProperty;
class SetProperties;
class SetLabels;
class RemoveProperty;
class RemoveLabels;
class Merge;
class Unwind;
class AuthQuery;
class ExplainQuery;
class ProfileQuery;
class IndexQuery;
class InfoQuery;
class ConstraintQuery;
class RegexMatch;
class DumpQuery;
class ReplicationQuery;
class LockPathQuery;
class LoadCsv;
class FreeMemoryQuery;
class TriggerQuery;
class IsolationLevelQuery;
class CreateSnapshotQuery;
class StreamQuery;
class SettingQuery;
class VersionQuery;
class Foreach;
using TreeCompositeVisitor = utils::CompositeVisitor<
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any, None, CallProcedure,
Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels,
RemoveProperty, RemoveLabels, Merge, Unwind, RegexMatch, LoadCsv, Foreach>;
using TreeLeafVisitor = utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;
class HierarchicalTreeVisitor : public TreeCompositeVisitor, public TreeLeafVisitor {
public:
using TreeCompositeVisitor::PostVisit;
using TreeCompositeVisitor::PreVisit;
using TreeLeafVisitor::Visit;
using typename TreeLeafVisitor::ReturnType;
};
template <class TResult>
class ExpressionVisitor
: public utils::Visitor<
TResult, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
SubtractionOperator, MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, EqualOperator,
LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral,
MapLiteral, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {};
template <class TResult>
class QueryVisitor
: public utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery, TriggerQuery,
IsolationLevelQuery, CreateSnapshotQuery, StreamQuery, SettingQuery, VersionQuery> {};
} // namespace memgraph::query::v2

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,886 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <antlr4-runtime.h>
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherBaseVisitor.h"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2::frontend {
using antlropencypher::MemgraphCypher;
struct ParsingContext {
bool is_query_cached = false;
};
class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
public:
explicit CypherMainVisitor(ParsingContext context, AstStorage *storage) : context_(context), storage_(storage) {}
private:
Expression *CreateBinaryOperatorByToken(size_t token, Expression *e1, Expression *e2) {
switch (token) {
case MemgraphCypher::OR:
return storage_->Create<OrOperator>(e1, e2);
case MemgraphCypher::XOR:
return storage_->Create<XorOperator>(e1, e2);
case MemgraphCypher::AND:
return storage_->Create<AndOperator>(e1, e2);
case MemgraphCypher::PLUS:
return storage_->Create<AdditionOperator>(e1, e2);
case MemgraphCypher::MINUS:
return storage_->Create<SubtractionOperator>(e1, e2);
case MemgraphCypher::ASTERISK:
return storage_->Create<MultiplicationOperator>(e1, e2);
case MemgraphCypher::SLASH:
return storage_->Create<DivisionOperator>(e1, e2);
case MemgraphCypher::PERCENT:
return storage_->Create<ModOperator>(e1, e2);
case MemgraphCypher::EQ:
return storage_->Create<EqualOperator>(e1, e2);
case MemgraphCypher::NEQ1:
case MemgraphCypher::NEQ2:
return storage_->Create<NotEqualOperator>(e1, e2);
case MemgraphCypher::LT:
return storage_->Create<LessOperator>(e1, e2);
case MemgraphCypher::GT:
return storage_->Create<GreaterOperator>(e1, e2);
case MemgraphCypher::LTE:
return storage_->Create<LessEqualOperator>(e1, e2);
case MemgraphCypher::GTE:
return storage_->Create<GreaterEqualOperator>(e1, e2);
default:
throw utils::NotYetImplemented("binary operator");
}
}
Expression *CreateUnaryOperatorByToken(size_t token, Expression *e) {
switch (token) {
case MemgraphCypher::NOT:
return storage_->Create<NotOperator>(e);
case MemgraphCypher::PLUS:
return storage_->Create<UnaryPlusOperator>(e);
case MemgraphCypher::MINUS:
return storage_->Create<UnaryMinusOperator>(e);
default:
throw utils::NotYetImplemented("unary operator");
}
}
auto ExtractOperators(std::vector<antlr4::tree::ParseTree *> &all_children,
const std::vector<size_t> &allowed_operators) {
std::vector<size_t> operators;
for (auto *child : all_children) {
antlr4::tree::TerminalNode *operator_node = nullptr;
if ((operator_node = dynamic_cast<antlr4::tree::TerminalNode *>(child))) {
if (std::find(allowed_operators.begin(), allowed_operators.end(), operator_node->getSymbol()->getType()) !=
allowed_operators.end()) {
operators.push_back(operator_node->getSymbol()->getType());
}
}
}
return operators;
}
/**
* Convert opencypher's n-ary production to ast binary operators.
*
* @param _expressions Subexpressions of child for which we construct ast
* operators, for example expression6 if we want to create ast nodes for
* expression7.
*/
template <typename TExpression>
Expression *LeftAssociativeOperatorExpression(std::vector<TExpression *> _expressions,
std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
DMG_ASSERT(_expressions.size(), "can't happen");
std::vector<Expression *> expressions;
auto operators = ExtractOperators(all_children, allowed_operators);
for (auto *expression : _expressions) {
expressions.push_back(std::any_cast<Expression *>(expression->accept(this)));
}
Expression *first_operand = expressions[0];
for (int i = 1; i < (int)expressions.size(); ++i) {
first_operand = CreateBinaryOperatorByToken(operators[i - 1], first_operand, expressions[i]);
}
return first_operand;
}
template <typename TExpression>
Expression *PrefixUnaryOperator(TExpression *_expression, std::vector<antlr4::tree::ParseTree *> all_children,
const std::vector<size_t> &allowed_operators) {
DMG_ASSERT(_expression, "can't happen");
auto operators = ExtractOperators(all_children, allowed_operators);
Expression *expression = std::any_cast<Expression *>(_expression->accept(this));
for (int i = (int)operators.size() - 1; i >= 0; --i) {
expression = CreateUnaryOperatorByToken(operators[i], expression);
}
return expression;
}
/**
* @return CypherQuery*
*/
antlrcpp::Any visitCypherQuery(MemgraphCypher::CypherQueryContext *ctx) override;
/**
* @return IndexQuery*
*/
antlrcpp::Any visitIndexQuery(MemgraphCypher::IndexQueryContext *ctx) override;
/**
* @return ExplainQuery*
*/
antlrcpp::Any visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) override;
/**
* @return ProfileQuery*
*/
antlrcpp::Any visitProfileQuery(MemgraphCypher::ProfileQueryContext *ctx) override;
/**
* @return InfoQuery*
*/
antlrcpp::Any visitInfoQuery(MemgraphCypher::InfoQueryContext *ctx) override;
/**
* @return Constraint
*/
antlrcpp::Any visitConstraint(MemgraphCypher::ConstraintContext *ctx) override;
/**
* @return ConstraintQuery*
*/
antlrcpp::Any visitConstraintQuery(MemgraphCypher::ConstraintQueryContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitAuthQuery(MemgraphCypher::AuthQueryContext *ctx) override;
/**
* @return DumpQuery*
*/
antlrcpp::Any visitDumpQuery(MemgraphCypher::DumpQueryContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitReplicationQuery(MemgraphCypher::ReplicationQueryContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitSetReplicationRole(MemgraphCypher::SetReplicationRoleContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitShowReplicationRole(MemgraphCypher::ShowReplicationRoleContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitRegisterReplica(MemgraphCypher::RegisterReplicaContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitDropReplica(MemgraphCypher::DropReplicaContext *ctx) override;
/**
* @return ReplicationQuery*
*/
antlrcpp::Any visitShowReplicas(MemgraphCypher::ShowReplicasContext *ctx) override;
/**
* @return LockPathQuery*
*/
antlrcpp::Any visitLockPathQuery(MemgraphCypher::LockPathQueryContext *ctx) override;
/**
* @return LoadCsvQuery*
*/
antlrcpp::Any visitLoadCsv(MemgraphCypher::LoadCsvContext *ctx) override;
/**
* @return FreeMemoryQuery*
*/
antlrcpp::Any visitFreeMemoryQuery(MemgraphCypher::FreeMemoryQueryContext *ctx) override;
/**
* @return TriggerQuery*
*/
antlrcpp::Any visitTriggerQuery(MemgraphCypher::TriggerQueryContext *ctx) override;
/**
* @return CreateTrigger*
*/
antlrcpp::Any visitCreateTrigger(MemgraphCypher::CreateTriggerContext *ctx) override;
/**
* @return DropTrigger*
*/
antlrcpp::Any visitDropTrigger(MemgraphCypher::DropTriggerContext *ctx) override;
/**
* @return ShowTriggers*
*/
antlrcpp::Any visitShowTriggers(MemgraphCypher::ShowTriggersContext *ctx) override;
/**
* @return IsolationLevelQuery*
*/
antlrcpp::Any visitIsolationLevelQuery(MemgraphCypher::IsolationLevelQueryContext *ctx) override;
/**
* @return CreateSnapshotQuery*
*/
antlrcpp::Any visitCreateSnapshotQuery(MemgraphCypher::CreateSnapshotQueryContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStreamQuery(MemgraphCypher::StreamQueryContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitCreateStream(MemgraphCypher::CreateStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitConfigKeyValuePair(MemgraphCypher::ConfigKeyValuePairContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitConfigMap(MemgraphCypher::ConfigMapContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitKafkaCreateStream(MemgraphCypher::KafkaCreateStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitKafkaCreateStreamConfig(MemgraphCypher::KafkaCreateStreamConfigContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitPulsarCreateStreamConfig(MemgraphCypher::PulsarCreateStreamConfigContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitPulsarCreateStream(MemgraphCypher::PulsarCreateStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitCommonCreateStreamConfig(MemgraphCypher::CommonCreateStreamConfigContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitDropStream(MemgraphCypher::DropStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStartStream(MemgraphCypher::StartStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStartAllStreams(MemgraphCypher::StartAllStreamsContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStopStream(MemgraphCypher::StopStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStopAllStreams(MemgraphCypher::StopAllStreamsContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitShowStreams(MemgraphCypher::ShowStreamsContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitCheckStream(MemgraphCypher::CheckStreamContext *ctx) override;
/**
* @return SettingQuery*
*/
antlrcpp::Any visitSettingQuery(MemgraphCypher::SettingQueryContext *ctx) override;
/**
* @return SetSetting*
*/
antlrcpp::Any visitSetSetting(MemgraphCypher::SetSettingContext *ctx) override;
/**
* @return ShowSetting*
*/
antlrcpp::Any visitShowSetting(MemgraphCypher::ShowSettingContext *ctx) override;
/**
* @return ShowSettings*
*/
antlrcpp::Any visitShowSettings(MemgraphCypher::ShowSettingsContext *ctx) override;
/**
* @return VersionQuery*
*/
antlrcpp::Any visitVersionQuery(MemgraphCypher::VersionQueryContext *ctx) override;
/**
* @return CypherUnion*
*/
antlrcpp::Any visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) override;
/**
* @return SingleQuery*
*/
antlrcpp::Any visitSingleQuery(MemgraphCypher::SingleQueryContext *ctx) override;
/**
* @return Clause* or vector<Clause*>!!!
*/
antlrcpp::Any visitClause(MemgraphCypher::ClauseContext *ctx) override;
/**
* @return Match*
*/
antlrcpp::Any visitCypherMatch(MemgraphCypher::CypherMatchContext *ctx) override;
/**
* @return Create*
*/
antlrcpp::Any visitCreate(MemgraphCypher::CreateContext *ctx) override;
/**
* @return CallProcedure*
*/
antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override;
/**
* @return std::string
*/
antlrcpp::Any visitUserOrRoleName(MemgraphCypher::UserOrRoleNameContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitCreateRole(MemgraphCypher::CreateRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitDropRole(MemgraphCypher::DropRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowRoles(MemgraphCypher::ShowRolesContext *ctx) override;
/**
* @return IndexQuery*
*/
antlrcpp::Any visitCreateIndex(MemgraphCypher::CreateIndexContext *ctx) override;
/**
* @return DropIndex*
*/
antlrcpp::Any visitDropIndex(MemgraphCypher::DropIndexContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitCreateUser(MemgraphCypher::CreateUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitSetPassword(MemgraphCypher::SetPasswordContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitDropUser(MemgraphCypher::DropUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowUsers(MemgraphCypher::ShowUsersContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitSetRole(MemgraphCypher::SetRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitClearRole(MemgraphCypher::ClearRoleContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitGrantPrivilege(MemgraphCypher::GrantPrivilegeContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitDenyPrivilege(MemgraphCypher::DenyPrivilegeContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitRevokePrivilege(MemgraphCypher::RevokePrivilegeContext *ctx) override;
/**
* @return AuthQuery::Privilege
*/
antlrcpp::Any visitPrivilege(MemgraphCypher::PrivilegeContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowPrivileges(MemgraphCypher::ShowPrivilegesContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowRoleForUser(MemgraphCypher::ShowRoleForUserContext *ctx) override;
/**
* @return AuthQuery*
*/
antlrcpp::Any visitShowUsersForRole(MemgraphCypher::ShowUsersForRoleContext *ctx) override;
/**
* @return Return*
*/
antlrcpp::Any visitCypherReturn(MemgraphCypher::CypherReturnContext *ctx) override;
/**
* @return Return*
*/
antlrcpp::Any visitReturnBody(MemgraphCypher::ReturnBodyContext *ctx) override;
/**
* @return pair<bool, vector<NamedExpression*>> first member is true if
* asterisk was found in return
* expressions.
*/
antlrcpp::Any visitReturnItems(MemgraphCypher::ReturnItemsContext *ctx) override;
/**
* @return vector<NamedExpression*>
*/
antlrcpp::Any visitReturnItem(MemgraphCypher::ReturnItemContext *ctx) override;
/**
* @return vector<SortItem>
*/
antlrcpp::Any visitOrder(MemgraphCypher::OrderContext *ctx) override;
/**
* @return SortItem
*/
antlrcpp::Any visitSortItem(MemgraphCypher::SortItemContext *ctx) override;
/**
* @return NodeAtom*
*/
antlrcpp::Any visitNodePattern(MemgraphCypher::NodePatternContext *ctx) override;
/**
* @return vector<LabelIx>
*/
antlrcpp::Any visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) override;
/**
* @return unordered_map<PropertyIx, Expression*>
*/
antlrcpp::Any visitProperties(MemgraphCypher::PropertiesContext *ctx) override;
/**
* @return map<std::string, Expression*>
*/
antlrcpp::Any visitMapLiteral(MemgraphCypher::MapLiteralContext *ctx) override;
/**
* @return vector<Expression*>
*/
antlrcpp::Any visitListLiteral(MemgraphCypher::ListLiteralContext *ctx) override;
/**
* @return PropertyIx
*/
antlrcpp::Any visitPropertyKeyName(MemgraphCypher::PropertyKeyNameContext *ctx) override;
/**
* @return string
*/
antlrcpp::Any visitSymbolicName(MemgraphCypher::SymbolicNameContext *ctx) override;
/**
* @return vector<Pattern*>
*/
antlrcpp::Any visitPattern(MemgraphCypher::PatternContext *ctx) override;
/**
* @return Pattern*
*/
antlrcpp::Any visitPatternPart(MemgraphCypher::PatternPartContext *ctx) override;
/**
* @return Pattern*
*/
antlrcpp::Any visitPatternElement(MemgraphCypher::PatternElementContext *ctx) override;
/**
* @return vector<pair<EdgeAtom*, NodeAtom*>>
*/
antlrcpp::Any visitPatternElementChain(MemgraphCypher::PatternElementChainContext *ctx) override;
/**
*@return EdgeAtom*
*/
antlrcpp::Any visitRelationshipPattern(MemgraphCypher::RelationshipPatternContext *ctx) override;
/**
* This should never be called. Everything is done directly in
* visitRelationshipPattern.
*/
antlrcpp::Any visitRelationshipDetail(MemgraphCypher::RelationshipDetailContext *ctx) override;
/**
* This should never be called. Everything is done directly in
* visitRelationshipPattern.
*/
antlrcpp::Any visitRelationshipLambda(MemgraphCypher::RelationshipLambdaContext *ctx) override;
/**
* @return vector<EdgeTypeIx>
*/
antlrcpp::Any visitRelationshipTypes(MemgraphCypher::RelationshipTypesContext *ctx) override;
/**
* @return std::tuple<EdgeAtom::Type, int64_t, int64_t>.
*/
antlrcpp::Any visitVariableExpansion(MemgraphCypher::VariableExpansionContext *ctx) override;
/**
* Top level expression, does nothing.
*
* @return Expression*
*/
antlrcpp::Any visitExpression(MemgraphCypher::ExpressionContext *ctx) override;
/**
* OR.
*
* @return Expression*
*/
antlrcpp::Any visitExpression12(MemgraphCypher::Expression12Context *ctx) override;
/**
* XOR.
*
* @return Expression*
*/
antlrcpp::Any visitExpression11(MemgraphCypher::Expression11Context *ctx) override;
/**
* AND.
*
* @return Expression*
*/
antlrcpp::Any visitExpression10(MemgraphCypher::Expression10Context *ctx) override;
/**
* NOT.
*
* @return Expression*
*/
antlrcpp::Any visitExpression9(MemgraphCypher::Expression9Context *ctx) override;
/**
* Comparisons.
*
* @return Expression*
*/
antlrcpp::Any visitExpression8(MemgraphCypher::Expression8Context *ctx) override;
/**
* Never call this. Everything related to generating code for comparison
* operators should be done in visitExpression8.
*/
antlrcpp::Any visitPartialComparisonExpression(MemgraphCypher::PartialComparisonExpressionContext *ctx) override;
/**
* Addition and subtraction.
*
* @return Expression*
*/
antlrcpp::Any visitExpression7(MemgraphCypher::Expression7Context *ctx) override;
/**
* Multiplication, division, modding.
*
* @return Expression*
*/
antlrcpp::Any visitExpression6(MemgraphCypher::Expression6Context *ctx) override;
/**
* Power.
*
* @return Expression*
*/
antlrcpp::Any visitExpression5(MemgraphCypher::Expression5Context *ctx) override;
/**
* Unary minus and plus.
*
* @return Expression*
*/
antlrcpp::Any visitExpression4(MemgraphCypher::Expression4Context *ctx) override;
/**
* IS NULL, IS NOT NULL, STARTS WITH, END WITH, =~, ...
*
* @return Expression*
*/
antlrcpp::Any visitExpression3a(MemgraphCypher::Expression3aContext *ctx) override;
/**
* Does nothing, everything is done in visitExpression3a.
*
* @return Expression*
*/
antlrcpp::Any visitStringAndNullOperators(MemgraphCypher::StringAndNullOperatorsContext *ctx) override;
/**
* List indexing and slicing.
*
* @return Expression*
*/
antlrcpp::Any visitExpression3b(MemgraphCypher::Expression3bContext *ctx) override;
/**
* Does nothing, everything is done in visitExpression3b.
*/
antlrcpp::Any visitListIndexingOrSlicing(MemgraphCypher::ListIndexingOrSlicingContext *ctx) override;
/**
* Node labels test.
*
* @return Expression*
*/
antlrcpp::Any visitExpression2a(MemgraphCypher::Expression2aContext *ctx) override;
/**
* Property lookup.
*
* @return Expression*
*/
antlrcpp::Any visitExpression2b(MemgraphCypher::Expression2bContext *ctx) override;
/**
* Literals, params, list comprehension...
*
* @return Expression*
*/
antlrcpp::Any visitAtom(MemgraphCypher::AtomContext *ctx) override;
/**
* @return ParameterLookup*
*/
antlrcpp::Any visitParameter(MemgraphCypher::ParameterContext *ctx) override;
/**
* @return Expression*
*/
antlrcpp::Any visitParenthesizedExpression(MemgraphCypher::ParenthesizedExpressionContext *ctx) override;
/**
* @return Expression*
*/
antlrcpp::Any visitFunctionInvocation(MemgraphCypher::FunctionInvocationContext *ctx) override;
/**
* @return string - uppercased
*/
antlrcpp::Any visitFunctionName(MemgraphCypher::FunctionNameContext *ctx) override;
/**
* @return Expression*
*/
antlrcpp::Any visitLiteral(MemgraphCypher::LiteralContext *ctx) override;
/**
* Convert escaped string from a query to unescaped utf8 string.
*
* @return string
*/
antlrcpp::Any visitStringLiteral(const std::string &escaped);
/**
* @return bool
*/
antlrcpp::Any visitBooleanLiteral(MemgraphCypher::BooleanLiteralContext *ctx) override;
/**
* @return TypedValue with either double or int
*/
antlrcpp::Any visitNumberLiteral(MemgraphCypher::NumberLiteralContext *ctx) override;
/**
* @return int64_t
*/
antlrcpp::Any visitIntegerLiteral(MemgraphCypher::IntegerLiteralContext *ctx) override;
/**
* @return double
*/
antlrcpp::Any visitDoubleLiteral(MemgraphCypher::DoubleLiteralContext *ctx) override;
/**
* @return Delete*
*/
antlrcpp::Any visitCypherDelete(MemgraphCypher::CypherDeleteContext *ctx) override;
/**
* @return Where*
*/
antlrcpp::Any visitWhere(MemgraphCypher::WhereContext *ctx) override;
/**
* return vector<Clause*>
*/
antlrcpp::Any visitSet(MemgraphCypher::SetContext *ctx) override;
/**
* @return Clause*
*/
antlrcpp::Any visitSetItem(MemgraphCypher::SetItemContext *ctx) override;
/**
* return vector<Clause*>
*/
antlrcpp::Any visitRemove(MemgraphCypher::RemoveContext *ctx) override;
/**
* @return Clause*
*/
antlrcpp::Any visitRemoveItem(MemgraphCypher::RemoveItemContext *ctx) override;
/**
* @return PropertyLookup*
*/
antlrcpp::Any visitPropertyExpression(MemgraphCypher::PropertyExpressionContext *ctx) override;
/**
* @return IfOperator*
*/
antlrcpp::Any visitCaseExpression(MemgraphCypher::CaseExpressionContext *ctx) override;
/**
* Never call this. Ast generation for this production is done in
* @c visitCaseExpression.
*/
antlrcpp::Any visitCaseAlternatives(MemgraphCypher::CaseAlternativesContext *ctx) override;
/**
* @return With*
*/
antlrcpp::Any visitWith(MemgraphCypher::WithContext *ctx) override;
/**
* @return Merge*
*/
antlrcpp::Any visitMerge(MemgraphCypher::MergeContext *ctx) override;
/**
* @return Unwind*
*/
antlrcpp::Any visitUnwind(MemgraphCypher::UnwindContext *ctx) override;
/**
* Never call this. Ast generation for these expressions should be done by
* explicitly visiting the members of @c FilterExpressionContext.
*/
antlrcpp::Any visitFilterExpression(MemgraphCypher::FilterExpressionContext *) override;
/**
* @return Foreach*
*/
antlrcpp::Any visitForeach(MemgraphCypher::ForeachContext *ctx) override;
public:
Query *query() { return query_; }
const static std::string kAnonPrefix;
struct QueryInfo {
bool is_cacheable{true};
bool has_load_csv{false};
};
const auto &GetQueryInfo() const { return query_info_; }
private:
LabelIx AddLabel(const std::string &name);
PropertyIx AddProperty(const std::string &name);
EdgeTypeIx AddEdgeType(const std::string &name);
ParsingContext context_;
AstStorage *storage_;
std::unordered_map<uint8_t, std::variant<Expression *, std::string, std::vector<std::string>,
std::unordered_map<Expression *, Expression *>>>
memory_;
// Set of identifiers from queries.
std::unordered_set<std::string> users_identifiers;
// Identifiers that user didn't name.
std::vector<Identifier **> anonymous_identifiers;
Query *query_ = nullptr;
// All return items which are not variables must be aliased in with.
// We use this variable in visitReturnItem to check if we are in with or
// return.
bool in_with_ = false;
QueryInfo query_info_;
};
} // namespace memgraph::query::v2::frontend

View File

@ -0,0 +1,311 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/frontend/ast/pretty_print.hpp"
#include <type_traits>
#include "query/v2/frontend/ast/ast.hpp"
#include "utils/algorithm.hpp"
#include "utils/string.hpp"
namespace memgraph::query::v2 {
namespace {
class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
public:
explicit ExpressionPrettyPrinter(std::ostream *out);
// Unary operators
void Visit(NotOperator &op) override;
void Visit(UnaryPlusOperator &op) override;
void Visit(UnaryMinusOperator &op) override;
void Visit(IsNullOperator &op) override;
// Binary operators
void Visit(OrOperator &op) override;
void Visit(XorOperator &op) override;
void Visit(AndOperator &op) override;
void Visit(AdditionOperator &op) override;
void Visit(SubtractionOperator &op) override;
void Visit(MultiplicationOperator &op) override;
void Visit(DivisionOperator &op) override;
void Visit(ModOperator &op) override;
void Visit(NotEqualOperator &op) override;
void Visit(EqualOperator &op) override;
void Visit(LessOperator &op) override;
void Visit(GreaterOperator &op) override;
void Visit(LessEqualOperator &op) override;
void Visit(GreaterEqualOperator &op) override;
void Visit(InListOperator &op) override;
void Visit(SubscriptOperator &op) override;
// Other
void Visit(ListSlicingOperator &op) override;
void Visit(IfOperator &op) override;
void Visit(ListLiteral &op) override;
void Visit(MapLiteral &op) override;
void Visit(LabelsTest &op) override;
void Visit(Aggregation &op) override;
void Visit(Function &op) override;
void Visit(Reduce &op) override;
void Visit(Coalesce &op) override;
void Visit(Extract &op) override;
void Visit(All &op) override;
void Visit(Single &op) override;
void Visit(Any &op) override;
void Visit(None &op) override;
void Visit(Identifier &op) override;
void Visit(PrimitiveLiteral &op) override;
void Visit(PropertyLookup &op) override;
void Visit(ParameterLookup &op) override;
void Visit(NamedExpression &op) override;
void Visit(RegexMatch &op) override;
private:
std::ostream *out_;
};
// Declare all of the different `PrintObject` overloads upfront since they're
// mutually recursive. Without this, overload resolution depends on the ordering
// of the overloads within the source, which is quite fragile.
template <typename T>
void PrintObject(std::ostream *out, const T &arg);
void PrintObject(std::ostream *out, const std::string &str);
void PrintObject(std::ostream *out, Aggregation::Op op);
void PrintObject(std::ostream *out, Expression *expr);
void PrintObject(std::ostream *out, Identifier *expr);
void PrintObject(std::ostream *out, const storage::v3::PropertyValue &value);
template <typename T>
void PrintObject(std::ostream *out, const std::vector<T> &vec);
template <typename K, typename V>
void PrintObject(std::ostream *out, const std::map<K, V> &map);
template <typename T>
void PrintObject(std::ostream *out, const T &arg) {
static_assert(!std::is_convertible<T, Expression *>::value,
"This overload shouldn't be called with pointers convertible "
"to Expression *. This means your other PrintObject overloads aren't "
"being called for certain AST nodes when they should (or perhaps such "
"overloads don't exist yet).");
*out << arg;
}
void PrintObject(std::ostream *out, const std::string &str) { *out << utils::Escape(str); }
void PrintObject(std::ostream *out, Aggregation::Op op) { *out << Aggregation::OpToString(op); }
void PrintObject(std::ostream *out, Expression *expr) {
if (expr) {
ExpressionPrettyPrinter printer{out};
expr->Accept(printer);
} else {
*out << "<null>";
}
}
void PrintObject(std::ostream *out, Identifier *expr) { PrintObject(out, static_cast<Expression *>(expr)); }
void PrintObject(std::ostream *out, const storage::v3::PropertyValue &value) {
switch (value.type()) {
case storage::v3::PropertyValue::Type::Null:
*out << "null";
break;
case storage::v3::PropertyValue::Type::String:
PrintObject(out, value.ValueString());
break;
case storage::v3::PropertyValue::Type::Bool:
*out << (value.ValueBool() ? "true" : "false");
break;
case storage::v3::PropertyValue::Type::Int:
PrintObject(out, value.ValueInt());
break;
case storage::v3::PropertyValue::Type::Double:
PrintObject(out, value.ValueDouble());
break;
case storage::v3::PropertyValue::Type::List:
PrintObject(out, value.ValueList());
break;
case storage::v3::PropertyValue::Type::Map:
PrintObject(out, value.ValueMap());
break;
case storage::v3::PropertyValue::Type::TemporalData:
PrintObject(out, value.ValueTemporalData());
break;
}
}
template <typename T>
void PrintObject(std::ostream *out, const std::vector<T> &vec) {
*out << "[";
utils::PrintIterable(*out, vec, ", ", [](auto &stream, const auto &item) { PrintObject(&stream, item); });
*out << "]";
}
template <typename K, typename V>
void PrintObject(std::ostream *out, const std::map<K, V> &map) {
*out << "{";
utils::PrintIterable(*out, map, ", ", [](auto &stream, const auto &item) {
PrintObject(&stream, item.first);
stream << ": ";
PrintObject(&stream, item.second);
});
*out << "}";
}
template <typename T>
void PrintOperatorArgs(std::ostream *out, const T &arg) {
*out << " ";
PrintObject(out, arg);
*out << ")";
}
template <typename T, typename... Ts>
void PrintOperatorArgs(std::ostream *out, const T &arg, const Ts &...args) {
*out << " ";
PrintObject(out, arg);
PrintOperatorArgs(out, args...);
}
template <typename... Ts>
void PrintOperator(std::ostream *out, const std::string &name, const Ts &...args) {
*out << "(" << name;
PrintOperatorArgs(out, args...);
}
ExpressionPrettyPrinter::ExpressionPrettyPrinter(std::ostream *out) : out_(out) {}
#define UNARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression_); }
UNARY_OPERATOR_VISIT(NotOperator, "Not");
UNARY_OPERATOR_VISIT(UnaryPlusOperator, "+");
UNARY_OPERATOR_VISIT(UnaryMinusOperator, "-");
UNARY_OPERATOR_VISIT(IsNullOperator, "IsNull");
#undef UNARY_OPERATOR_VISIT
#define BINARY_OPERATOR_VISIT(OP_NODE, OP_STR) \
void ExpressionPrettyPrinter::Visit(OP_NODE &op) { PrintOperator(out_, OP_STR, op.expression1_, op.expression2_); }
BINARY_OPERATOR_VISIT(OrOperator, "Or");
BINARY_OPERATOR_VISIT(XorOperator, "Xor");
BINARY_OPERATOR_VISIT(AndOperator, "And");
BINARY_OPERATOR_VISIT(AdditionOperator, "+");
BINARY_OPERATOR_VISIT(SubtractionOperator, "-");
BINARY_OPERATOR_VISIT(MultiplicationOperator, "*");
BINARY_OPERATOR_VISIT(DivisionOperator, "/");
BINARY_OPERATOR_VISIT(ModOperator, "%");
BINARY_OPERATOR_VISIT(NotEqualOperator, "!=");
BINARY_OPERATOR_VISIT(EqualOperator, "==");
BINARY_OPERATOR_VISIT(LessOperator, "<");
BINARY_OPERATOR_VISIT(GreaterOperator, ">");
BINARY_OPERATOR_VISIT(LessEqualOperator, "<=");
BINARY_OPERATOR_VISIT(GreaterEqualOperator, ">=");
BINARY_OPERATOR_VISIT(InListOperator, "In");
BINARY_OPERATOR_VISIT(SubscriptOperator, "Subscript");
#undef BINARY_OPERATOR_VISIT
void ExpressionPrettyPrinter::Visit(ListSlicingOperator &op) {
PrintOperator(out_, "ListSlicing", op.list_, op.lower_bound_, op.upper_bound_);
}
void ExpressionPrettyPrinter::Visit(IfOperator &op) {
PrintOperator(out_, "If", op.condition_, op.then_expression_, op.else_expression_);
}
void ExpressionPrettyPrinter::Visit(ListLiteral &op) { PrintOperator(out_, "ListLiteral", op.elements_); }
void ExpressionPrettyPrinter::Visit(MapLiteral &op) {
std::map<std::string, Expression *> map;
for (const auto &kv : op.elements_) {
map[kv.first.name] = kv.second;
}
PrintObject(out_, map);
}
void ExpressionPrettyPrinter::Visit(LabelsTest &op) { PrintOperator(out_, "LabelsTest", op.expression_); }
void ExpressionPrettyPrinter::Visit(Aggregation &op) { PrintOperator(out_, "Aggregation", op.op_); }
void ExpressionPrettyPrinter::Visit(Function &op) { PrintOperator(out_, "Function", op.function_name_, op.arguments_); }
void ExpressionPrettyPrinter::Visit(Reduce &op) {
PrintOperator(out_, "Reduce", op.accumulator_, op.initializer_, op.identifier_, op.list_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(Coalesce &op) { PrintOperator(out_, "Coalesce", op.expressions_); }
void ExpressionPrettyPrinter::Visit(Extract &op) {
PrintOperator(out_, "Extract", op.identifier_, op.list_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(All &op) {
PrintOperator(out_, "All", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Single &op) {
PrintOperator(out_, "Single", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Any &op) {
PrintOperator(out_, "Any", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(None &op) {
PrintOperator(out_, "None", op.identifier_, op.list_expression_, op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Identifier &op) { PrintOperator(out_, "Identifier", op.name_); }
void ExpressionPrettyPrinter::Visit(PrimitiveLiteral &op) { PrintObject(out_, op.value_); }
void ExpressionPrettyPrinter::Visit(PropertyLookup &op) {
PrintOperator(out_, "PropertyLookup", op.expression_, op.property_.name);
}
void ExpressionPrettyPrinter::Visit(ParameterLookup &op) { PrintOperator(out_, "ParameterLookup", op.token_position_); }
void ExpressionPrettyPrinter::Visit(NamedExpression &op) {
PrintOperator(out_, "NamedExpression", op.name_, op.expression_);
}
void ExpressionPrettyPrinter::Visit(RegexMatch &op) { PrintOperator(out_, "=~", op.string_expr_, op.regex_); }
} // namespace
void PrintExpression(Expression *expr, std::ostream *out) {
ExpressionPrettyPrinter printer{out};
expr->Accept(printer);
}
void PrintExpression(NamedExpression *expr, std::ostream *out) {
ExpressionPrettyPrinter printer{out};
expr->Accept(printer);
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,23 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <iostream>
#include "query/v2/frontend/ast/ast.hpp"
namespace memgraph::query::v2 {
void PrintExpression(Expression *expr, std::ostream *out);
void PrintExpression(NamedExpression *expr, std::ostream *out);
} // namespace memgraph::query::v2

View File

@ -0,0 +1,391 @@
/*
* Copyright (c) 2015-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
parser grammar Cypher;
options { tokenVocab=CypherLexer; }
cypher : statement ';'? EOF ;
statement : query ;
query : cypherQuery
| indexQuery
| explainQuery
| profileQuery
| infoQuery
| constraintQuery
;
constraintQuery : ( CREATE | DROP ) CONSTRAINT ON constraint ;
constraint : '(' nodeName=variable ':' labelName ')' ASSERT EXISTS '(' constraintPropertyList ')'
| '(' nodeName=variable ':' labelName ')' ASSERT constraintPropertyList IS UNIQUE
| '(' nodeName=variable ':' labelName ')' ASSERT '(' constraintPropertyList ')' IS NODE KEY
;
constraintPropertyList : variable propertyLookup ( ',' variable propertyLookup )* ;
storageInfo : STORAGE INFO ;
indexInfo : INDEX INFO ;
constraintInfo : CONSTRAINT INFO ;
infoQuery : SHOW ( storageInfo | indexInfo | constraintInfo ) ;
explainQuery : EXPLAIN cypherQuery ;
profileQuery : PROFILE cypherQuery ;
cypherQuery : singleQuery ( cypherUnion )* ( queryMemoryLimit )? ;
indexQuery : createIndex | dropIndex;
singleQuery : clause ( clause )* ;
cypherUnion : ( UNION ALL singleQuery )
| ( UNION singleQuery )
;
clause : cypherMatch
| unwind
| merge
| create
| set
| cypherDelete
| remove
| with
| cypherReturn
| callProcedure
;
cypherMatch : OPTIONAL? MATCH pattern where? ;
unwind : UNWIND expression AS variable ;
merge : MERGE patternPart ( mergeAction )* ;
mergeAction : ( ON MATCH set )
| ( ON CREATE set )
;
create : CREATE pattern ;
set : SET setItem ( ',' setItem )* ;
setItem : ( propertyExpression '=' expression )
| ( variable '=' expression )
| ( variable '+=' expression )
| ( variable nodeLabels )
;
cypherDelete : DETACH? DELETE expression ( ',' expression )* ;
remove : REMOVE removeItem ( ',' removeItem )* ;
removeItem : ( variable nodeLabels )
| propertyExpression
;
with : WITH ( DISTINCT )? returnBody ( where )? ;
cypherReturn : RETURN ( DISTINCT )? returnBody ;
callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( procedureMemoryLimit )? ( yieldProcedureResults )? ;
procedureName : symbolicName ( '.' symbolicName )* ;
yieldProcedureResults : YIELD ( '*' | ( procedureResult ( ',' procedureResult )* ) ) ;
memoryLimit : MEMORY ( UNLIMITED | LIMIT literal ( MB | KB ) ) ;
queryMemoryLimit : QUERY memoryLimit ;
procedureMemoryLimit : PROCEDURE memoryLimit ;
procedureResult : ( variable AS variable ) | variable ;
returnBody : returnItems ( order )? ( skip )? ( limit )? ;
returnItems : ( '*' ( ',' returnItem )* )
| ( returnItem ( ',' returnItem )* )
;
returnItem : ( expression AS variable )
| expression
;
order : ORDER BY sortItem ( ',' sortItem )* ;
skip : L_SKIP expression ;
limit : LIMIT expression ;
sortItem : expression ( ASCENDING | ASC | DESCENDING | DESC )? ;
where : WHERE expression ;
pattern : patternPart ( ',' patternPart )* ;
patternPart : ( variable '=' anonymousPatternPart )
| anonymousPatternPart
;
anonymousPatternPart : patternElement ;
patternElement : ( nodePattern ( patternElementChain )* )
| ( '(' patternElement ')' )
;
nodePattern : '(' ( variable )? ( nodeLabels )? ( properties )? ')' ;
patternElementChain : relationshipPattern nodePattern ;
relationshipPattern : ( leftArrowHead dash ( relationshipDetail )? dash rightArrowHead )
| ( leftArrowHead dash ( relationshipDetail )? dash )
| ( dash ( relationshipDetail )? dash rightArrowHead )
| ( dash ( relationshipDetail )? dash )
;
leftArrowHead : '<' | LeftArrowHeadPart ;
rightArrowHead : '>' | RightArrowHeadPart ;
dash : '-' | DashPart ;
relationshipDetail : '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? properties ']'
| '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? relationshipLambda ( total_weight=variable )? (relationshipLambda )? ']'
| '[' ( name=variable )? ( relationshipTypes )? ( variableExpansion )? (properties )* ( relationshipLambda total_weight=variable )? (relationshipLambda )? ']';
relationshipLambda: '(' traversed_edge=variable ',' traversed_node=variable '|' expression ')';
variableExpansion : '*' (BFS | WSHORTEST)? ( expression )? ( '..' ( expression )? )? ;
properties : mapLiteral
| parameter
;
relationshipTypes : ':' relTypeName ( '|' ':'? relTypeName )* ;
nodeLabels : nodeLabel ( nodeLabel )* ;
nodeLabel : ':' labelName ;
labelName : symbolicName ;
relTypeName : symbolicName ;
expression : expression12 ;
expression12 : expression11 ( OR expression11 )* ;
expression11 : expression10 ( XOR expression10 )* ;
expression10 : expression9 ( AND expression9 )* ;
expression9 : ( NOT )* expression8 ;
expression8 : expression7 ( partialComparisonExpression )* ;
expression7 : expression6 ( ( '+' expression6 ) | ( '-' expression6 ) )* ;
expression6 : expression5 ( ( '*' expression5 ) | ( '/' expression5 ) | ( '%' expression5 ) )* ;
expression5 : expression4 ( '^' expression4 )* ;
expression4 : ( ( '+' | '-' ) )* expression3a ;
expression3a : expression3b ( stringAndNullOperators )* ;
stringAndNullOperators : ( ( ( ( '=~' ) | ( IN ) | ( STARTS WITH ) | ( ENDS WITH ) | ( CONTAINS ) ) expression3b) | ( IS CYPHERNULL ) | ( IS NOT CYPHERNULL ) ) ;
expression3b : expression2a ( listIndexingOrSlicing )* ;
listIndexingOrSlicing : ( '[' expression ']' )
| ( '[' lower_bound=expression? '..' upper_bound=expression? ']' )
;
expression2a : expression2b ( nodeLabels )? ;
expression2b : atom ( propertyLookup )* ;
atom : literal
| parameter
| caseExpression
| ( COUNT '(' '*' ')' )
| listComprehension
| patternComprehension
| ( FILTER '(' filterExpression ')' )
| ( EXTRACT '(' extractExpression ')' )
| ( REDUCE '(' reduceExpression ')' )
| ( COALESCE '(' expression ( ',' expression )* ')' )
| ( ALL '(' filterExpression ')' )
| ( ANY '(' filterExpression ')' )
| ( NONE '(' filterExpression ')' )
| ( SINGLE '(' filterExpression ')' )
| relationshipsPattern
| parenthesizedExpression
| functionInvocation
| variable
;
literal : numberLiteral
| StringLiteral
| booleanLiteral
| CYPHERNULL
| mapLiteral
| listLiteral
;
booleanLiteral : TRUE
| FALSE
;
listLiteral : '[' ( expression ( ',' expression )* )? ']' ;
partialComparisonExpression : ( '=' expression7 )
| ( '<>' expression7 )
| ( '!=' expression7 )
| ( '<' expression7 )
| ( '>' expression7 )
| ( '<=' expression7 )
| ( '>=' expression7 )
;
parenthesizedExpression : '(' expression ')' ;
relationshipsPattern : nodePattern ( patternElementChain )+ ;
filterExpression : idInColl ( where )? ;
reduceExpression : accumulator=variable '=' initial=expression ',' idInColl '|' expression ;
extractExpression : idInColl '|' expression ;
idInColl : variable IN expression ;
functionInvocation : functionName '(' ( DISTINCT )? ( expression ( ',' expression )* )? ')' ;
functionName : symbolicName ( '.' symbolicName )* ;
listComprehension : '[' filterExpression ( '|' expression )? ']' ;
patternComprehension : '[' ( variable '=' )? relationshipsPattern ( WHERE expression )? '|' expression ']' ;
propertyLookup : '.' ( propertyKeyName ) ;
caseExpression : ( ( CASE ( caseAlternatives )+ ) | ( CASE test=expression ( caseAlternatives )+ ) ) ( ELSE else_expression=expression )? END ;
caseAlternatives : WHEN when_expression=expression THEN then_expression=expression ;
variable : symbolicName ;
numberLiteral : doubleLiteral
| integerLiteral
;
mapLiteral : '{' ( propertyKeyName ':' expression ( ',' propertyKeyName ':' expression )* )? '}' ;
parameter : '$' ( symbolicName | DecimalLiteral ) ;
propertyExpression : atom ( propertyLookup )+ ;
propertyKeyName : symbolicName ;
integerLiteral : DecimalLiteral
| OctalLiteral
| HexadecimalLiteral
;
createIndex : CREATE INDEX ON ':' labelName ( '(' propertyKeyName ')' )? ;
dropIndex : DROP INDEX ON ':' labelName ( '(' propertyKeyName ')' )? ;
doubleLiteral : FloatingLiteral ;
cypherKeyword : ALL
| AND
| ANY
| AS
| ASC
| ASCENDING
| ASSERT
| BFS
| BY
| CALL
| CASE
| CONSTRAINT
| CONTAINS
| COUNT
| CREATE
| CYPHERNULL
| DELETE
| DESC
| DESCENDING
| DETACH
| DISTINCT
| ELSE
| END
| ENDS
| EXISTS
| EXPLAIN
| EXTRACT
| FALSE
| FILTER
| IN
| INDEX
| INFO
| IS
| KEY
| LIMIT
| L_SKIP
| MATCH
| MERGE
| NODE
| NONE
| NOT
| ON
| OPTIONAL
| OR
| ORDER
| PROCEDURE
| PROFILE
| QUERY
| REDUCE
| REMOVE
| RETURN
| SET
| SHOW
| SINGLE
| STARTS
| STORAGE
| THEN
| TRUE
| UNION
| UNIQUE
| UNWIND
| WHEN
| WHERE
| WITH
| WSHORTEST
| XOR
| YIELD
;
symbolicName : UnescapedSymbolicName
| EscapedSymbolicName
| cypherKeyword
;

View File

@ -0,0 +1,208 @@
/*
* When changing this grammar make sure to update constants in
* src/query/frontend/stripped_lexer_constants.hpp (kKeywords, kSpecialTokens
* and bitsets) if needed.
*/
lexer grammar CypherLexer ;
import UnicodeCategories ;
/* Skip whitespace and comments. */
Skipped : ( Whitespace | Comment ) -> skip ;
fragment Whitespace : '\u0020'
| [\u0009-\u000D]
| [\u001C-\u001F]
| '\u1680' | '\u180E'
| [\u2000-\u200A]
| '\u2028' | '\u2029'
| '\u205F'
| '\u3000'
| '\u00A0'
| '\u202F'
;
fragment Comment : '/*' .*? '*/'
| '//' ~[\r\n]*
;
/* Special symbols. */
LPAREN : '(' ;
RPAREN : ')' ;
LBRACK : '[' ;
RBRACK : ']' ;
LBRACE : '{' ;
RBRACE : '}' ;
COMMA : ',' ;
DOT : '.' ;
DOTS : '..' ;
COLON : ':' ;
SEMICOLON : ';' ;
DOLLAR : '$' ;
PIPE : '|' ;
EQ : '=' ;
LT : '<' ;
GT : '>' ;
LTE : '<=' ;
GTE : '>=' ;
NEQ1 : '<>' ;
NEQ2 : '!=' ;
SIM : '=~' ;
PLUS : '+' ;
MINUS : '-' ;
ASTERISK : '*' ;
SLASH : '/' ;
PERCENT : '%' ;
CARET : '^' ;
PLUS_EQ : '+=' ;
/* Some random unicode characters that can be used to draw arrows. */
LeftArrowHeadPart : '⟨' | '〈' | '﹤' | '' ;
RightArrowHeadPart : '⟩' | '〉' | '﹥' | '' ;
DashPart : '­' | '' | '' | '' | '' | '—' | '―'
| '' | '' | '﹣' | ''
;
/* Cypher reserved words. */
ALL : A L L ;
AND : A N D ;
ANY : A N Y ;
AS : A S ;
ASC : A S C ;
ASCENDING : A S C E N D I N G ;
ASSERT : A S S E R T ;
BFS : B F S ;
BY : B Y ;
CALL : C A L L ;
CASE : C A S E ;
COALESCE : C O A L E S C E ;
CONSTRAINT : C O N S T R A I N T ;
CONTAINS : C O N T A I N S ;
COUNT : C O U N T ;
CREATE : C R E A T E ;
CYPHERNULL : N U L L ;
DELETE : D E L E T E ;
DESC : D E S C ;
DESCENDING : D E S C E N D I N G ;
DETACH : D E T A C H ;
DISTINCT : D I S T I N C T ;
DROP : D R O P ;
ELSE : E L S E ;
END : E N D ;
ENDS : E N D S ;
EXISTS : E X I S T S ;
EXPLAIN : E X P L A I N ;
EXTRACT : E X T R A C T ;
FALSE : F A L S E ;
FILTER : F I L T E R ;
IN : I N ;
INDEX : I N D E X ;
INFO : I N F O ;
IS : I S ;
KB : K B ;
KEY : K E Y ;
LIMIT : L I M I T ;
L_SKIP : S K I P ;
MATCH : M A T C H ;
MB : M B ;
MEMORY : M E M O R Y ;
MERGE : M E R G E ;
NODE : N O D E ;
NONE : N O N E ;
NOT : N O T ;
ON : O N ;
OPTIONAL : O P T I O N A L ;
OR : O R ;
ORDER : O R D E R ;
PROCEDURE : P R O C E D U R E ;
PROFILE : P R O F I L E ;
QUERY : Q U E R Y ;
REDUCE : R E D U C E ;
REMOVE : R E M O V E ;
RETURN : R E T U R N ;
SET : S E T ;
SHOW : S H O W ;
SINGLE : S I N G L E ;
STARTS : S T A R T S ;
STORAGE : S T O R A G E ;
THEN : T H E N ;
TRUE : T R U E ;
UNION : U N I O N ;
UNIQUE : U N I Q U E ;
UNLIMITED : U N L I M I T E D ;
UNWIND : U N W I N D ;
WHEN : W H E N ;
WHERE : W H E R E ;
WITH : W I T H ;
WSHORTEST : W S H O R T E S T ;
XOR : X O R ;
YIELD : Y I E L D ;
/* Double and single quoted string literals. */
StringLiteral : '"' ( ~[\\"] | EscapeSequence )* '"'
| '\'' ( ~[\\'] | EscapeSequence )* '\''
;
fragment EscapeSequence : '\\' ( B | F | N | R | T | '\\' | '\'' | '"' )
| '\\u' HexDigit HexDigit HexDigit HexDigit
| '\\U' HexDigit HexDigit HexDigit HexDigit
HexDigit HexDigit HexDigit HexDigit
;
/* Number literals. */
DecimalLiteral : '0' | NonZeroDigit ( DecDigit )* ;
OctalLiteral : '0' ( OctDigit )+ ;
HexadecimalLiteral : '0x' ( HexDigit )+ ;
FloatingLiteral : DecDigit* '.' DecDigit+ ( E '-'? DecDigit+ )?
| DecDigit+ ( '.' DecDigit* )? ( E '-'? DecDigit+ )
| DecDigit+ ( E '-'? DecDigit+ )
;
fragment NonZeroDigit : [1-9] ;
fragment DecDigit : [0-9] ;
fragment OctDigit : [0-7] ;
fragment HexDigit : [0-9] | [a-f] | [A-F] ;
/* Symbolic names. */
UnescapedSymbolicName : IdentifierStart ( IdentifierPart )* ;
EscapedSymbolicName : ( '`' ~[`]* '`' )+ ;
/**
* Based on the unicode identifier and pattern syntax
* (http://www.unicode.org/reports/tr31/)
* and extended with a few characters.
*/
IdentifierStart : ID_Start | Pc ;
IdentifierPart : ID_Continue | Sc ;
/* Hack for case-insensitive reserved words */
fragment A : 'A' | 'a' ;
fragment B : 'B' | 'b' ;
fragment C : 'C' | 'c' ;
fragment D : 'D' | 'd' ;
fragment E : 'E' | 'e' ;
fragment F : 'F' | 'f' ;
fragment G : 'G' | 'g' ;
fragment H : 'H' | 'h' ;
fragment I : 'I' | 'i' ;
fragment J : 'J' | 'j' ;
fragment K : 'K' | 'k' ;
fragment L : 'L' | 'l' ;
fragment M : 'M' | 'm' ;
fragment N : 'N' | 'n' ;
fragment O : 'O' | 'o' ;
fragment P : 'P' | 'p' ;
fragment Q : 'Q' | 'q' ;
fragment R : 'R' | 'r' ;
fragment S : 'S' | 's' ;
fragment T : 'T' | 't' ;
fragment U : 'U' | 'u' ;
fragment V : 'V' | 'v' ;
fragment W : 'W' | 'w' ;
fragment X : 'X' | 'x' ;
fragment Y : 'Y' | 'y' ;
fragment Z : 'Z' | 'z' ;

View File

@ -0,0 +1,376 @@
/*
* Copyright 2021 Memgraph Ltd.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
* License, and you may not use this file except in compliance with the Business Source License.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/
/* Memgraph specific part of Cypher grammar with enterprise features. */
parser grammar MemgraphCypher ;
options { tokenVocab=MemgraphCypherLexer; }
import Cypher ;
memgraphCypherKeyword : cypherKeyword
| AFTER
| ALTER
| ASYNC
| AUTH
| BAD
| BATCH_INTERVAL
| BATCH_LIMIT
| BATCH_SIZE
| BEFORE
| BOOTSTRAP_SERVERS
| CHECK
| CLEAR
| COMMIT
| COMMITTED
| CONFIG
| CONFIGS
| CONSUMER_GROUP
| CREDENTIALS
| CSV
| DATA
| DELIMITER
| DATABASE
| DENY
| DROP
| DUMP
| EXECUTE
| FOR
| FOREACH
| FREE
| FROM
| GLOBAL
| GRANT
| HEADER
| IDENTIFIED
| ISOLATION
| KAFKA
| LEVEL
| LOAD
| LOCK
| MAIN
| MODE
| NEXT
| NO
| PASSWORD
| PULSAR
| PORT
| PRIVILEGES
| READ
| REGISTER
| REPLICA
| REPLICAS
| REPLICATION
| REVOKE
| ROLE
| ROLES
| QUOTE
| SESSION
| SETTING
| SETTINGS
| SNAPSHOT
| START
| STATS
| STREAM
| STREAMS
| SYNC
| TIMEOUT
| TO
| TOPICS
| TRANSACTION
| TRANSFORM
| TRIGGER
| TRIGGERS
| UNCOMMITTED
| UNLOCK
| UPDATE
| USER
| USERS
| VERSION
;
symbolicName : UnescapedSymbolicName
| EscapedSymbolicName
| memgraphCypherKeyword
;
query : cypherQuery
| indexQuery
| explainQuery
| profileQuery
| infoQuery
| constraintQuery
| authQuery
| dumpQuery
| replicationQuery
| lockPathQuery
| freeMemoryQuery
| triggerQuery
| isolationLevelQuery
| createSnapshotQuery
| streamQuery
| settingQuery
| versionQuery
;
authQuery : createRole
| dropRole
| showRoles
| createUser
| setPassword
| dropUser
| showUsers
| setRole
| clearRole
| grantPrivilege
| denyPrivilege
| revokePrivilege
| showPrivileges
| showRoleForUser
| showUsersForRole
;
replicationQuery : setReplicationRole
| showReplicationRole
| registerReplica
| dropReplica
| showReplicas
;
triggerQuery : createTrigger
| dropTrigger
| showTriggers
;
clause : cypherMatch
| unwind
| merge
| create
| set
| cypherDelete
| remove
| with
| cypherReturn
| callProcedure
| loadCsv
| foreach
;
updateClause : set
| remove
| create
| merge
| cypherDelete
| foreach
;
foreach : FOREACH '(' variable IN expression '|' updateClause+ ')' ;
streamQuery : checkStream
| createStream
| dropStream
| startStream
| startAllStreams
| stopStream
| stopAllStreams
| showStreams
;
settingQuery : setSetting
| showSetting
| showSettings
;
loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER
( IGNORE BAD ) ?
( DELIMITER delimiter ) ?
( QUOTE quote ) ?
AS rowVar ;
csvFile : literal ;
delimiter : literal ;
quote : literal ;
rowVar : variable ;
userOrRoleName : symbolicName ;
createRole : CREATE ROLE role=userOrRoleName ;
dropRole : DROP ROLE role=userOrRoleName ;
showRoles : SHOW ROLES ;
createUser : CREATE USER user=userOrRoleName
( IDENTIFIED BY password=literal )? ;
setPassword : SET PASSWORD FOR user=userOrRoleName TO password=literal;
dropUser : DROP USER user=userOrRoleName ;
showUsers : SHOW USERS ;
setRole : SET ROLE FOR user=userOrRoleName TO role=userOrRoleName;
clearRole : CLEAR ROLE FOR user=userOrRoleName ;
grantPrivilege : GRANT ( ALL PRIVILEGES | privileges=privilegeList ) TO userOrRole=userOrRoleName ;
denyPrivilege : DENY ( ALL PRIVILEGES | privileges=privilegeList ) TO userOrRole=userOrRoleName ;
revokePrivilege : REVOKE ( ALL PRIVILEGES | privileges=privilegeList ) FROM userOrRole=userOrRoleName ;
privilege : CREATE
| DELETE
| MATCH
| MERGE
| SET
| REMOVE
| INDEX
| STATS
| AUTH
| CONSTRAINT
| DUMP
| REPLICATION
| READ_FILE
| FREE_MEMORY
| TRIGGER
| CONFIG
| DURABILITY
| STREAM
| MODULE_READ
| MODULE_WRITE
| WEBSOCKET
;
privilegeList : privilege ( ',' privilege )* ;
showPrivileges : SHOW PRIVILEGES FOR userOrRole=userOrRoleName ;
showRoleForUser : SHOW ROLE FOR user=userOrRoleName ;
showUsersForRole : SHOW USERS FOR role=userOrRoleName ;
dumpQuery: DUMP DATABASE ;
setReplicationRole : SET REPLICATION ROLE TO ( MAIN | REPLICA )
( WITH PORT port=literal ) ? ;
showReplicationRole : SHOW REPLICATION ROLE ;
replicaName : symbolicName ;
socketAddress : literal ;
registerReplica : REGISTER REPLICA replicaName ( SYNC | ASYNC )
( WITH TIMEOUT timeout=literal ) ?
TO socketAddress ;
dropReplica : DROP REPLICA replicaName ;
showReplicas : SHOW REPLICAS ;
lockPathQuery : ( LOCK | UNLOCK ) DATA DIRECTORY ;
freeMemoryQuery : FREE MEMORY ;
triggerName : symbolicName ;
triggerStatement : .*? ;
emptyVertex : '(' ')' ;
emptyEdge : dash dash rightArrowHead ;
createTrigger : CREATE TRIGGER triggerName ( ON ( emptyVertex | emptyEdge ) ? ( CREATE | UPDATE | DELETE ) ) ?
( AFTER | BEFORE ) COMMIT EXECUTE triggerStatement ;
dropTrigger : DROP TRIGGER triggerName ;
showTriggers : SHOW TRIGGERS ;
isolationLevel : SNAPSHOT ISOLATION | READ COMMITTED | READ UNCOMMITTED ;
isolationLevelScope : GLOBAL | SESSION | NEXT ;
isolationLevelQuery : SET isolationLevelScope TRANSACTION ISOLATION LEVEL isolationLevel ;
createSnapshotQuery : CREATE SNAPSHOT ;
streamName : symbolicName ;
symbolicNameWithMinus : symbolicName ( MINUS symbolicName )* ;
symbolicNameWithDotsAndMinus: symbolicNameWithMinus ( DOT symbolicNameWithMinus )* ;
symbolicTopicNames : symbolicNameWithDotsAndMinus ( COMMA symbolicNameWithDotsAndMinus )* ;
topicNames : symbolicTopicNames | literal ;
commonCreateStreamConfig : TRANSFORM transformationName=procedureName
| BATCH_INTERVAL batchInterval=literal
| BATCH_SIZE batchSize=literal
;
createStream : kafkaCreateStream | pulsarCreateStream ;
configKeyValuePair : literal ':' literal ;
configMap : '{' ( configKeyValuePair ( ',' configKeyValuePair )* )? '}' ;
kafkaCreateStreamConfig : TOPICS topicNames
| CONSUMER_GROUP consumerGroup=symbolicNameWithDotsAndMinus
| BOOTSTRAP_SERVERS bootstrapServers=literal
| CONFIGS configsMap=configMap
| CREDENTIALS credentialsMap=configMap
| commonCreateStreamConfig
;
kafkaCreateStream : CREATE KAFKA STREAM streamName ( kafkaCreateStreamConfig ) * ;
pulsarCreateStreamConfig : TOPICS topicNames
| SERVICE_URL serviceUrl=literal
| commonCreateStreamConfig
;
pulsarCreateStream : CREATE PULSAR STREAM streamName ( pulsarCreateStreamConfig ) * ;
dropStream : DROP STREAM streamName ;
startStream : START STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ( TIMEOUT timeout=literal ) ? ;
startAllStreams : START ALL STREAMS ;
stopStream : STOP STREAM streamName ;
stopAllStreams : STOP ALL STREAMS ;
showStreams : SHOW STREAMS ;
checkStream : CHECK STREAM streamName ( BATCH_LIMIT batchLimit=literal ) ? ( TIMEOUT timeout=literal ) ? ;
settingName : literal ;
settingValue : literal ;
setSetting : SET DATABASE SETTING settingName TO settingValue ;
showSetting : SHOW DATABASE SETTING settingName ;
showSettings : SHOW DATABASE SETTINGS ;
versionQuery : SHOW VERSION ;

View File

@ -0,0 +1,116 @@
/*
* Copyright 2021 Memgraph Ltd.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
* License, and you may not use this file except in compliance with the Business Source License.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/
/* Memgraph specific Cypher reserved words used for enterprise features. */
/*
* When changing this grammar make sure to update constants in
* src/query/frontend/stripped_lexer_constants.hpp (kKeywords, kSpecialTokens
* and bitsets) if needed.
*/
lexer grammar MemgraphCypherLexer ;
import CypherLexer ;
UNDERSCORE : '_' ;
AFTER : A F T E R ;
ALTER : A L T E R ;
ASYNC : A S Y N C ;
AUTH : A U T H ;
BAD : B A D ;
BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ;
BATCH_LIMIT : B A T C H UNDERSCORE L I M I T ;
BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
BEFORE : B E F O R E ;
BOOTSTRAP_SERVERS : B O O T S T R A P UNDERSCORE S E R V E R S ;
CHECK : C H E C K ;
CLEAR : C L E A R ;
COMMIT : C O M M I T ;
COMMITTED : C O M M I T T E D ;
CONFIG : C O N F I G ;
CONFIGS : C O N F I G S;
CONSUMER_GROUP : C O N S U M E R UNDERSCORE G R O U P ;
CREDENTIALS : C R E D E N T I A L S ;
CSV : C S V ;
DATA : D A T A ;
DELIMITER : D E L I M I T E R ;
DATABASE : D A T A B A S E ;
DENY : D E N Y ;
DIRECTORY : D I R E C T O R Y ;
DROP : D R O P ;
DUMP : D U M P ;
DURABILITY : D U R A B I L I T Y ;
EXECUTE : E X E C U T E ;
FOR : F O R ;
FOREACH : F O R E A C H;
FREE : F R E E ;
FREE_MEMORY : F R E E UNDERSCORE M E M O R Y ;
FROM : F R O M ;
GLOBAL : G L O B A L ;
GRANT : G R A N T ;
GRANTS : G R A N T S ;
HEADER : H E A D E R ;
IDENTIFIED : I D E N T I F I E D ;
IGNORE : I G N O R E ;
ISOLATION : I S O L A T I O N ;
KAFKA : K A F K A ;
LEVEL : L E V E L ;
LOAD : L O A D ;
LOCK : L O C K ;
MAIN : M A I N ;
MODE : M O D E ;
MODULE_READ : M O D U L E UNDERSCORE R E A D ;
MODULE_WRITE : M O D U L E UNDERSCORE W R I T E ;
NEXT : N E X T ;
NO : N O ;
PASSWORD : P A S S W O R D ;
PORT : P O R T ;
PRIVILEGES : P R I V I L E G E S ;
PULSAR : P U L S A R ;
READ : R E A D ;
READ_FILE : R E A D UNDERSCORE F I L E ;
REGISTER : R E G I S T E R ;
REPLICA : R E P L I C A ;
REPLICAS : R E P L I C A S ;
REPLICATION : R E P L I C A T I O N ;
REVOKE : R E V O K E ;
ROLE : R O L E ;
ROLES : R O L E S ;
QUOTE : Q U O T E ;
SERVICE_URL : S E R V I C E UNDERSCORE U R L ;
SESSION : S E S S I O N ;
SETTING : S E T T I N G ;
SETTINGS : S E T T I N G S ;
SNAPSHOT : S N A P S H O T ;
START : S T A R T ;
STATS : S T A T S ;
STOP : S T O P ;
STREAM : S T R E A M ;
STREAMS : S T R E A M S ;
SYNC : S Y N C ;
TIMEOUT : T I M E O U T ;
TO : T O ;
TOPICS : T O P I C S;
TRANSACTION : T R A N S A C T I O N ;
TRANSFORM : T R A N S F O R M ;
TRIGGER : T R I G G E R ;
TRIGGERS : T R I G G E R S ;
UNCOMMITTED : U N C O M M I T T E D ;
UNLOCK : U N L O C K ;
UPDATE : U P D A T E ;
USER : U S E R ;
USERS : U S E R S ;
VERSION : V E R S I O N ;
WEBSOCKET : W E B S O C K E T ;

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,68 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <string>
#include "antlr4-runtime.h"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/opencypher/generated/MemgraphCypher.h"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherLexer.h"
namespace memgraph::query::v2::frontend::opencypher {
/**
* Generates openCypher AST
* This thing must me a class since parser.cypher() returns pointer and there is
* no way for us to get ownership over the object.
*/
class Parser {
public:
/**
* @param query incoming query that has to be compiled into query plan
* the first step is to generate AST
*/
Parser(const std::string query) : query_(std::move(query)) {
parser_.removeErrorListeners();
parser_.addErrorListener(&error_listener_);
tree_ = parser_.cypher();
if (parser_.getNumberOfSyntaxErrors()) {
throw query::v2::SyntaxException(error_listener_.error_);
}
}
auto tree() { return tree_; }
private:
class FirstMessageErrorListener : public antlr4::BaseErrorListener {
void syntaxError(antlr4::Recognizer *, antlr4::Token *, size_t line, size_t position, const std::string &message,
std::exception_ptr) override {
if (error_.empty()) {
error_ = "line " + std::to_string(line) + ":" + std::to_string(position + 1) + " " + message;
}
}
public:
std::string error_;
};
FirstMessageErrorListener error_listener_;
std::string query_;
antlr4::ANTLRInputStream input_{query_};
antlropencypher::MemgraphCypherLexer lexer_{&input_};
antlr4::CommonTokenStream tokens_{&lexer_};
// generate ast
antlropencypher::MemgraphCypher parser_{&tokens_};
antlr4::tree::ParseTree *tree_ = nullptr;
};
} // namespace memgraph::query::v2::frontend::opencypher

View File

@ -0,0 +1,184 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/frontend/parsing.hpp"
#include <cctype>
#include <codecvt>
#include <locale>
#include <stdexcept>
#include "query/v2/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"
namespace memgraph::query::v2::frontend {
int64_t ParseIntegerLiteral(const std::string &s) {
try {
// Not really correct since long long can have a bigger range than int64_t.
return static_cast<int64_t>(std::stoll(s, 0, 0));
} catch (const std::out_of_range &) {
throw SemanticException("Integer literal exceeds 64 bits.");
}
}
std::string ParseStringLiteral(const std::string &s) {
// These functions is declared as lambda since its semantics is highly
// specific for this conxtext and shouldn't be used elsewhere.
auto EncodeEscapedUnicodeCodepointUtf32 = [](const std::string &s, int &i) {
const int kLongUnicodeLength = 8;
int j = i + 1;
while (j < static_cast<int>(s.size()) - 1 && j < i + kLongUnicodeLength + 1 && isxdigit(s[j])) {
++j;
}
if (j - i == kLongUnicodeLength + 1) {
char32_t t = stoi(s.substr(i + 1, kLongUnicodeLength), 0, 16);
i += kLongUnicodeLength;
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
return converter.to_bytes(t);
}
throw SyntaxException(
"Expected 8 hex digits as unicode codepoint started with \\U. "
"Use \\u for 4 hex digits format.");
};
auto EncodeEscapedUnicodeCodepointUtf16 = [](const std::string &s, int &i) {
const int kShortUnicodeLength = 4;
int j = i + 1;
while (j < static_cast<int>(s.size()) - 1 && j < i + kShortUnicodeLength + 1 && isxdigit(s[j])) {
++j;
}
if (j - i >= kShortUnicodeLength + 1) {
char16_t t = stoi(s.substr(i + 1, kShortUnicodeLength), 0, 16);
if (t >= 0xD800 && t <= 0xDBFF) {
// t is high surrogate pair. Expect one more utf16 codepoint.
j = i + kShortUnicodeLength + 1;
if (j >= static_cast<int>(s.size()) - 1 || s[j] != '\\') {
throw SemanticException("Invalid UTF codepoint.");
}
++j;
if (j >= static_cast<int>(s.size()) - 1 || (s[j] != 'u' && s[j] != 'U')) {
throw SemanticException("Invalid UTF codepoint.");
}
++j;
int k = j;
while (k < static_cast<int>(s.size()) - 1 && k < j + kShortUnicodeLength && isxdigit(s[k])) {
++k;
}
if (k != j + kShortUnicodeLength) {
throw SemanticException("Invalid UTF codepoint.");
}
char16_t surrogates[3] = {t, static_cast<char16_t>(stoi(s.substr(j, kShortUnicodeLength), 0, 16)), 0};
i += kShortUnicodeLength + 2 + kShortUnicodeLength;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
return converter.to_bytes(surrogates);
} else {
i += kShortUnicodeLength;
std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t> converter;
return converter.to_bytes(t);
}
}
throw SyntaxException(
"Expected 4 hex digits as unicode codepoint started with \\u. "
"Use \\U for 8 hex digits format.");
};
std::string unescaped;
bool escape = false;
// First and last char is quote, we don't need to look at them.
for (int i = 1; i < static_cast<int>(s.size()) - 1; ++i) {
if (escape) {
switch (s[i]) {
case '\\':
unescaped += '\\';
break;
case '\'':
unescaped += '\'';
break;
case '"':
unescaped += '"';
break;
case 'B':
case 'b':
unescaped += '\b';
break;
case 'F':
case 'f':
unescaped += '\f';
break;
case 'N':
case 'n':
unescaped += '\n';
break;
case 'R':
case 'r':
unescaped += '\r';
break;
case 'T':
case 't':
unescaped += '\t';
break;
case 'U':
try {
unescaped += EncodeEscapedUnicodeCodepointUtf32(s, i);
} catch (const std::range_error &) {
throw SemanticException("Invalid UTF codepoint.");
}
break;
case 'u':
try {
unescaped += EncodeEscapedUnicodeCodepointUtf16(s, i);
} catch (const std::range_error &) {
throw SemanticException("Invalid UTF codepoint.");
}
break;
default:
// This should never happen, except grammar changes and we don't
// notice change in this production.
DLOG_FATAL("can't happen");
throw std::exception();
}
escape = false;
} else if (s[i] == '\\') {
escape = true;
} else {
unescaped += s[i];
}
}
return unescaped;
}
double ParseDoubleLiteral(const std::string &s) {
try {
return utils::ParseDouble(s);
} catch (const utils::BasicException &) {
throw SemanticException("Couldn't parse string to double.");
}
}
std::string ParseParameter(const std::string &s) {
DMG_ASSERT(s[0] == '$', "Invalid string passed as parameter name");
if (s[1] != '`') return s.substr(1);
// If parameter name is escaped symbolic name then symbolic name should be
// unescaped and leading and trailing backquote should be removed.
DMG_ASSERT(s.size() > 3U && s.back() == '`', "Invalid string passed as parameter name");
std::string out;
for (int i = 2; i < static_cast<int>(s.size()) - 1; ++i) {
if (s[i] == '`') {
++i;
}
out.push_back(s[i]);
}
return out;
}
} // namespace memgraph::query::v2::frontend

View File

@ -0,0 +1,27 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <cstdint>
#include <string>
namespace memgraph::query::v2::frontend {
// These are the functions for parsing literals and parameter names from
// opencypher query.
int64_t ParseIntegerLiteral(const std::string &s);
std::string ParseStringLiteral(const std::string &s);
double ParseDoubleLiteral(const std::string &s);
std::string ParseParameter(const std::string &s);
} // namespace memgraph::query::v2::frontend

View File

@ -0,0 +1,152 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/ast/ast_visitor.hpp"
#include "query/v2/procedure/module.hpp"
#include "utils/memory.hpp"
namespace memgraph::query::v2 {
class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVisitor {
public:
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
using QueryVisitor<void>::Visit;
std::vector<AuthQuery::Privilege> privileges() { return privileges_; }
void Visit(IndexQuery &) override { AddPrivilege(AuthQuery::Privilege::INDEX); }
void Visit(AuthQuery &) override { AddPrivilege(AuthQuery::Privilege::AUTH); }
void Visit(ExplainQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(ProfileQuery &query) override { query.cypher_query_->Accept(*this); }
void Visit(InfoQuery &info_query) override {
switch (info_query.info_type_) {
case InfoQuery::InfoType::INDEX:
// TODO: This should be INDEX | STATS, but we don't have support for
// *or* with privileges.
AddPrivilege(AuthQuery::Privilege::INDEX);
break;
case InfoQuery::InfoType::STORAGE:
AddPrivilege(AuthQuery::Privilege::STATS);
break;
case InfoQuery::InfoType::CONSTRAINT:
// TODO: This should be CONSTRAINT | STATS, but we don't have support
// for *or* with privileges.
AddPrivilege(AuthQuery::Privilege::CONSTRAINT);
break;
}
}
void Visit(ConstraintQuery &constraint_query) override { AddPrivilege(AuthQuery::Privilege::CONSTRAINT); }
void Visit(CypherQuery &query) override {
query.single_query_->Accept(*this);
for (auto *cypher_union : query.cypher_unions_) {
cypher_union->Accept(*this);
}
}
void Visit(DumpQuery &dump_query) override { AddPrivilege(AuthQuery::Privilege::DUMP); }
void Visit(LockPathQuery &lock_path_query) override { AddPrivilege(AuthQuery::Privilege::DURABILITY); }
void Visit(FreeMemoryQuery &free_memory_query) override { AddPrivilege(AuthQuery::Privilege::FREE_MEMORY); }
void Visit(TriggerQuery &trigger_query) override { AddPrivilege(AuthQuery::Privilege::TRIGGER); }
void Visit(StreamQuery &stream_query) override { AddPrivilege(AuthQuery::Privilege::STREAM); }
void Visit(ReplicationQuery &replication_query) override { AddPrivilege(AuthQuery::Privilege::REPLICATION); }
void Visit(IsolationLevelQuery &isolation_level_query) override { AddPrivilege(AuthQuery::Privilege::CONFIG); }
void Visit(CreateSnapshotQuery &create_snapshot_query) override { AddPrivilege(AuthQuery::Privilege::DURABILITY); }
void Visit(SettingQuery & /*setting_query*/) override { AddPrivilege(AuthQuery::Privilege::CONFIG); }
void Visit(VersionQuery & /*version_query*/) override { AddPrivilege(AuthQuery::Privilege::STATS); }
bool PreVisit(Create & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::CREATE);
return false;
}
bool PreVisit(CallProcedure &procedure) override {
const auto maybe_proc =
procedure::FindProcedure(procedure::gModuleRegistry, procedure.procedure_name_, utils::NewDeleteResource());
if (maybe_proc && maybe_proc->second->info.required_privilege) {
AddPrivilege(*maybe_proc->second->info.required_privilege);
}
return false;
}
bool PreVisit(Delete & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::DELETE);
return false;
}
bool PreVisit(Match & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::MATCH);
return false;
}
bool PreVisit(Merge & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::MERGE);
return false;
}
bool PreVisit(SetProperty & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::SET);
return false;
}
bool PreVisit(SetProperties & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::SET);
return false;
}
bool PreVisit(SetLabels & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::SET);
return false;
}
bool PreVisit(RemoveProperty & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::REMOVE);
return false;
}
bool PreVisit(RemoveLabels & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::REMOVE);
return false;
}
bool PreVisit(LoadCsv & /*unused*/) override {
AddPrivilege(AuthQuery::Privilege::READ_FILE);
return false;
}
bool Visit(Identifier & /*unused*/) override { return true; }
bool Visit(PrimitiveLiteral & /*unused*/) override { return true; }
bool Visit(ParameterLookup & /*unused*/) override { return true; }
private:
void AddPrivilege(AuthQuery::Privilege privilege) {
if (!utils::Contains(privileges_, privilege)) {
privileges_.push_back(privilege);
}
}
std::vector<AuthQuery::Privilege> privileges_;
};
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query) {
PrivilegeExtractor extractor;
query->Accept(extractor);
return extractor.privileges();
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,18 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/frontend/ast/ast.hpp"
namespace memgraph::query::v2 {
std::vector<AuthQuery::Privilege> GetRequiredPrivileges(Query *query);
} // namespace memgraph::query::v2

View File

@ -0,0 +1,89 @@
;; Copyright 2022 Memgraph Ltd.
;;
;; Use of this software is governed by the Business Source License
;; included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
;; License, and you may not use this file except in compliance with the Business Source License.
;;
;; As of the Change Date specified in that file, in accordance with
;; the Business Source License, use of this software will be governed
;; by the Apache License, Version 2.0, included in the file
;; licenses/APL.txt.
#>cpp
#pragma once
#include <string>
#include "utils/typeinfo.hpp"
cpp<#
(lcp:namespace memgraph)
(lcp:namespace query)
(lcp:namespace v2)
(lcp:define-class symbol ()
((name "std::string" :scope :public)
(position :int64_t :scope :public)
(user-declared :bool :initval "true" :scope :public)
(type "Type" :initval "Type::ANY" :scope :public)
(token-position :int64_t :initval "-1" :scope :public))
(:public
;; This is similar to TypedValue::Type, but this has `Any` type.
;; TODO: Make a better Type structure which can store a generic List.
(lcp:define-enum type (any vertex edge path number edge-list)
(:serialize))
#>cpp
// TODO: Generate enum to string conversion from LCP. Note, that this is
// displayed to the end user, so we may want to have a pretty name of each
// value.
static std::string TypeToString(Type type) {
const char *enum_string[] = {"Any", "Vertex", "Edge",
"Path", "Number", "EdgeList"};
return enum_string[static_cast<int>(type)];
}
Symbol() {}
Symbol(const std::string &name, int position, bool user_declared,
Type type = Type::ANY, int token_position = -1)
: name_(name),
position_(position),
user_declared_(user_declared),
type_(type),
token_position_(token_position) {}
bool operator==(const Symbol &other) const {
return position_ == other.position_ && name_ == other.name_ &&
type_ == other.type_;
}
bool operator!=(const Symbol &other) const { return !operator==(other); }
// TODO: Remove these since members are public
const auto &name() const { return name_; }
int position() const { return position_; }
Type type() const { return type_; }
bool user_declared() const { return user_declared_; }
int token_position() const { return token_position_; }
cpp<#)
(:serialize (:slk)))
(lcp:pop-namespace) ;; v2
(lcp:pop-namespace) ;; query
(lcp:pop-namespace) ;; memgraph
#>cpp
namespace std {
template <>
struct hash<memgraph::query::v2::Symbol> {
size_t operator()(const memgraph::query::v2::Symbol &symbol) const {
size_t prime = 265443599u;
size_t hash = std::hash<int>{}(symbol.position());
hash ^= prime * std::hash<std::string>{}(symbol.name());
hash ^= prime * std::hash<int>{}(static_cast<int>(symbol.type()));
return hash;
}
};
} // namespace std
cpp<#

View File

@ -0,0 +1,625 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
// Copyright 2017 Memgraph
//
// Created by Teon Banek on 24-03-2017
#include "query/v2/frontend/semantic/symbol_generator.hpp"
#include <algorithm>
#include <optional>
#include <ranges>
#include <unordered_set>
#include <variant>
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/ast/ast_visitor.hpp"
#include "utils/algorithm.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2 {
namespace {
std::unordered_map<std::string, Identifier *> GeneratePredefinedIdentifierMap(
const std::vector<Identifier *> &predefined_identifiers) {
std::unordered_map<std::string, Identifier *> identifier_map;
for (const auto &identifier : predefined_identifiers) {
identifier_map.emplace(identifier->name_, identifier);
}
return identifier_map;
}
} // namespace
SymbolGenerator::SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers)
: symbol_table_(symbol_table),
predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)},
scopes_(1, Scope()) {}
std::optional<Symbol> SymbolGenerator::FindSymbolInScope(const std::string &name, const Scope &scope,
Symbol::Type type) {
if (auto it = scope.symbols.find(name); it != scope.symbols.end()) {
const auto &symbol = it->second;
// Unless we have `ANY` type, check that types match.
if (type != Symbol::Type::ANY && symbol.type() != Symbol::Type::ANY && type != symbol.type()) {
throw TypeMismatchError(name, Symbol::TypeToString(symbol.type()), Symbol::TypeToString(type));
}
return symbol;
}
return std::nullopt;
}
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) {
auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position);
scopes_.back().symbols[name] = symbol;
return symbol;
}
auto SymbolGenerator::GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type) {
auto &scope = scopes_.back();
if (auto maybe_symbol = FindSymbolInScope(name, scope, type); maybe_symbol) {
return *maybe_symbol;
}
return CreateSymbol(name, user_declared, type);
}
auto SymbolGenerator::GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type) {
// NOLINTNEXTLINE
for (auto scope = scopes_.rbegin(); scope != scopes_.rend(); ++scope) {
if (auto maybe_symbol = FindSymbolInScope(name, *scope, type); maybe_symbol) {
return *maybe_symbol;
}
}
return CreateSymbol(name, user_declared, type);
}
void SymbolGenerator::VisitReturnBody(ReturnBody &body, Where *where) {
auto &scope = scopes_.back();
for (auto &expr : body.named_expressions) {
expr->Accept(*this);
}
std::vector<Symbol> user_symbols;
if (body.all_identifiers) {
// Carry over user symbols because '*' appeared.
for (const auto &sym_pair : scope.symbols) {
if (!sym_pair.second.user_declared()) {
continue;
}
user_symbols.emplace_back(sym_pair.second);
}
if (user_symbols.empty()) {
throw SemanticException("There are no variables in scope to use for '*'.");
}
}
// WITH/RETURN clause removes declarations of all the previous variables and
// declares only those established through named expressions. New declarations
// must not be visible inside named expressions themselves.
bool removed_old_names = false;
if ((!where && body.order_by.empty()) || scope.has_aggregation) {
// WHERE and ORDER BY need to see both the old and new symbols, unless we
// have an aggregation. Therefore, we can clear the symbols immediately if
// there is neither ORDER BY nor WHERE, or we have an aggregation.
scope.symbols.clear();
removed_old_names = true;
}
// Create symbols for named expressions.
std::unordered_set<std::string> new_names;
for (const auto &user_sym : user_symbols) {
new_names.insert(user_sym.name());
scope.symbols[user_sym.name()] = user_sym;
}
for (auto &named_expr : body.named_expressions) {
const auto &name = named_expr->name_;
if (!new_names.insert(name).second) {
throw SemanticException("Multiple results with the same name '{}' are not allowed.", name);
}
// An improvement would be to infer the type of the expression, so that the
// new symbol would have a more specific type.
named_expr->MapTo(CreateSymbol(name, true, Symbol::Type::ANY, named_expr->token_position_));
}
scope.in_order_by = true;
for (const auto &order_pair : body.order_by) {
order_pair.expression->Accept(*this);
}
scope.in_order_by = false;
if (body.skip) {
scope.in_skip = true;
body.skip->Accept(*this);
scope.in_skip = false;
}
if (body.limit) {
scope.in_limit = true;
body.limit->Accept(*this);
scope.in_limit = false;
}
if (where) where->Accept(*this);
if (!removed_old_names) {
// We have an ORDER BY or WHERE, but no aggregation, which means we didn't
// clear the old symbols, so do it now. We cannot just call clear, because
// we've added new symbols.
for (auto sym_it = scope.symbols.begin(); sym_it != scope.symbols.end();) {
if (new_names.find(sym_it->first) == new_names.end()) {
sym_it = scope.symbols.erase(sym_it);
} else {
sym_it++;
}
}
}
scopes_.back().has_aggregation = false;
}
// Query
bool SymbolGenerator::PreVisit(SingleQuery &) {
prev_return_names_ = curr_return_names_;
curr_return_names_.clear();
return true;
}
// Union
bool SymbolGenerator::PreVisit(CypherUnion &) {
scopes_.back() = Scope();
return true;
}
bool SymbolGenerator::PostVisit(CypherUnion &cypher_union) {
if (prev_return_names_ != curr_return_names_) {
throw SemanticException("All subqueries in an UNION must have the same column names.");
}
// create new symbols for the result of the union
for (const auto &name : curr_return_names_) {
auto symbol = CreateSymbol(name, false);
cypher_union.union_symbols_.push_back(symbol);
}
return true;
}
// Clauses
bool SymbolGenerator::PreVisit(Create &) {
scopes_.back().in_create = true;
return true;
}
bool SymbolGenerator::PostVisit(Create &) {
scopes_.back().in_create = false;
return true;
}
bool SymbolGenerator::PreVisit(CallProcedure &call_proc) {
for (auto *expr : call_proc.arguments_) {
expr->Accept(*this);
}
return false;
}
bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
for (auto *ident : call_proc.result_identifiers_) {
if (HasSymbolLocalScope(ident->name_)) {
throw RedeclareVariableError(ident->name_);
}
ident->MapTo(CreateSymbol(ident->name_, true));
}
return true;
}
bool SymbolGenerator::PreVisit(LoadCsv &load_csv) { return false; }
bool SymbolGenerator::PostVisit(LoadCsv &load_csv) {
if (HasSymbolLocalScope(load_csv.row_var_->name_)) {
throw RedeclareVariableError(load_csv.row_var_->name_);
}
load_csv.row_var_->MapTo(CreateSymbol(load_csv.row_var_->name_, true));
return true;
}
bool SymbolGenerator::PreVisit(Return &ret) {
auto &scope = scopes_.back();
scope.in_return = true;
VisitReturnBody(ret.body_);
scope.in_return = false;
return false; // We handled the traversal ourselves.
}
bool SymbolGenerator::PostVisit(Return &) {
for (const auto &name_symbol : scopes_.back().symbols) curr_return_names_.insert(name_symbol.first);
return true;
}
bool SymbolGenerator::PreVisit(With &with) {
auto &scope = scopes_.back();
scope.in_with = true;
VisitReturnBody(with.body_, with.where_);
scope.in_with = false;
return false; // We handled the traversal ourselves.
}
bool SymbolGenerator::PreVisit(Where &) {
scopes_.back().in_where = true;
return true;
}
bool SymbolGenerator::PostVisit(Where &) {
scopes_.back().in_where = false;
return true;
}
bool SymbolGenerator::PreVisit(Merge &) {
scopes_.back().in_merge = true;
return true;
}
bool SymbolGenerator::PostVisit(Merge &) {
scopes_.back().in_merge = false;
return true;
}
bool SymbolGenerator::PostVisit(Unwind &unwind) {
const auto &name = unwind.named_expression_->name_;
if (HasSymbolLocalScope(name)) {
throw RedeclareVariableError(name);
}
unwind.named_expression_->MapTo(CreateSymbol(name, true));
return true;
}
bool SymbolGenerator::PreVisit(Match &) {
scopes_.back().in_match = true;
return true;
}
bool SymbolGenerator::PostVisit(Match &) {
auto &scope = scopes_.back();
scope.in_match = false;
// Check variables in property maps after visiting Match, so that they can
// reference symbols out of bind order.
for (auto &ident : scope.identifiers_in_match) {
if (!HasSymbolLocalScope(ident->name_) && !ConsumePredefinedIdentifier(ident->name_))
throw UnboundVariableError(ident->name_);
ident->MapTo(scope.symbols[ident->name_]);
}
scope.identifiers_in_match.clear();
return true;
}
bool SymbolGenerator::PreVisit(Foreach &for_each) {
const auto &name = for_each.named_expression_->name_;
scopes_.emplace_back(Scope());
scopes_.back().in_foreach = true;
for_each.named_expression_->MapTo(
CreateSymbol(name, true, Symbol::Type::ANY, for_each.named_expression_->token_position_));
return true;
}
bool SymbolGenerator::PostVisit([[maybe_unused]] Foreach &for_each) {
scopes_.pop_back();
return true;
}
// Expressions
SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) {
auto &scope = scopes_.back();
if (scope.in_skip || scope.in_limit) {
throw SemanticException("Variables are not allowed in {}.", scope.in_skip ? "SKIP" : "LIMIT");
}
Symbol symbol;
if (scope.in_pattern && !(scope.in_node_atom || scope.visiting_edge)) {
// If we are in the pattern, and outside of a node or an edge, the
// identifier is the pattern name.
symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, Symbol::Type::PATH);
} else if (scope.in_pattern && scope.in_pattern_atom_identifier) {
// Patterns used to create nodes and edges cannot redeclare already
// established bindings. Declaration only happens in single node
// patterns and in edge patterns. OpenCypher example,
// `MATCH (n) CREATE (n)` should throw an error that `n` is already
// declared. While `MATCH (n) CREATE (n) -[:R]-> (n)` is allowed,
// since `n` now references the bound node instead of declaring it.
if ((scope.in_create_node || scope.in_create_edge) && HasSymbolLocalScope(ident.name_)) {
throw RedeclareVariableError(ident.name_);
}
auto type = Symbol::Type::VERTEX;
if (scope.visiting_edge) {
// Edge referencing is not allowed (like in Neo4j):
// `MATCH (n) - [r] -> (n) - [r] -> (n) RETURN r` is not allowed.
if (HasSymbolLocalScope(ident.name_)) {
throw RedeclareVariableError(ident.name_);
}
type = scope.visiting_edge->IsVariable() ? Symbol::Type::EDGE_LIST : Symbol::Type::EDGE;
}
symbol = GetOrCreateSymbolLocalScope(ident.name_, ident.user_declared_, type);
} else if (scope.in_pattern && !scope.in_pattern_atom_identifier && scope.in_match) {
if (scope.in_edge_range && scope.visiting_edge->identifier_->name_ == ident.name_) {
// Prevent variable path bounds to reference the identifier which is bound
// by the variable path itself.
throw UnboundVariableError(ident.name_);
}
// Variables in property maps or bounds of variable length path during MATCH
// can reference symbols bound later in the same MATCH. We collect them
// here, so that they can be checked after visiting Match.
scope.identifiers_in_match.emplace_back(&ident);
} else {
// Everything else references a bound symbol.
if (!HasSymbol(ident.name_) && !ConsumePredefinedIdentifier(ident.name_)) throw UnboundVariableError(ident.name_);
symbol = GetOrCreateSymbol(ident.name_, ident.user_declared_, Symbol::Type::ANY);
}
ident.MapTo(symbol);
return true;
}
bool SymbolGenerator::PreVisit(Aggregation &aggr) {
auto &scope = scopes_.back();
// Check if the aggregation can be used in this context. This check should
// probably move to a separate phase, which checks if the query is well
// formed.
if ((!scope.in_return && !scope.in_with) || scope.in_order_by || scope.in_skip || scope.in_limit || scope.in_where) {
throw SemanticException("Aggregation functions are only allowed in WITH and RETURN.");
}
if (scope.in_aggregation) {
throw SemanticException(
"Using aggregation functions inside aggregation functions is not "
"allowed.");
}
if (scope.num_if_operators) {
// Neo allows aggregations here and produces very interesting behaviors.
// To simplify implementation at this moment we decided to completely
// disallow aggregations inside of the CASE.
// However, in some cases aggregation makes perfect sense, for example:
// CASE count(n) WHEN 10 THEN "YES" ELSE "NO" END.
// TODO: Rethink of allowing aggregations in some parts of the CASE
// construct.
throw SemanticException("Using aggregation functions inside of CASE is not allowed.");
}
// Create a virtual symbol for aggregation result.
// Currently, we only have aggregation operators which return numbers.
auto aggr_name = Aggregation::OpToString(aggr.op_) + std::to_string(aggr.symbol_pos_);
aggr.MapTo(CreateSymbol(aggr_name, false, Symbol::Type::NUMBER));
scope.in_aggregation = true;
scope.has_aggregation = true;
return true;
}
bool SymbolGenerator::PostVisit(Aggregation &) {
scopes_.back().in_aggregation = false;
return true;
}
bool SymbolGenerator::PreVisit(IfOperator &) {
++scopes_.back().num_if_operators;
return true;
}
bool SymbolGenerator::PostVisit(IfOperator &) {
--scopes_.back().num_if_operators;
return true;
}
bool SymbolGenerator::PreVisit(All &all) {
all.list_expression_->Accept(*this);
VisitWithIdentifiers(all.where_->expression_, {all.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Single &single) {
single.list_expression_->Accept(*this);
VisitWithIdentifiers(single.where_->expression_, {single.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Any &any) {
any.list_expression_->Accept(*this);
VisitWithIdentifiers(any.where_->expression_, {any.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(None &none) {
none.list_expression_->Accept(*this);
VisitWithIdentifiers(none.where_->expression_, {none.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Reduce &reduce) {
reduce.initializer_->Accept(*this);
reduce.list_->Accept(*this);
VisitWithIdentifiers(reduce.expression_, {reduce.accumulator_, reduce.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Extract &extract) {
extract.list_->Accept(*this);
VisitWithIdentifiers(extract.expression_, {extract.identifier_});
return false;
}
// Pattern and its subparts.
bool SymbolGenerator::PreVisit(Pattern &pattern) {
auto &scope = scopes_.back();
scope.in_pattern = true;
if ((scope.in_create || scope.in_merge) && pattern.atoms_.size() == 1U) {
MG_ASSERT(utils::IsSubtype(*pattern.atoms_[0], NodeAtom::kType), "Expected a single NodeAtom in Pattern");
scope.in_create_node = true;
}
return true;
}
bool SymbolGenerator::PostVisit(Pattern &) {
auto &scope = scopes_.back();
scope.in_pattern = false;
scope.in_create_node = false;
return true;
}
bool SymbolGenerator::PreVisit(NodeAtom &node_atom) {
auto &scope = scopes_.back();
auto check_node_semantic = [&node_atom, &scope, this](const bool props_or_labels) {
const auto &node_name = node_atom.identifier_->name_;
if ((scope.in_create || scope.in_merge) && props_or_labels && HasSymbolLocalScope(node_name)) {
throw SemanticException("Cannot create node '" + node_name +
"' with labels or properties, because it is already declared.");
}
scope.in_pattern_atom_identifier = true;
node_atom.identifier_->Accept(*this);
scope.in_pattern_atom_identifier = false;
};
scope.in_node_atom = true;
if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node_atom.properties_)) {
bool props_or_labels = !properties->empty() || !node_atom.labels_.empty();
check_node_semantic(props_or_labels);
for (auto kv : *properties) {
kv.second->Accept(*this);
}
return false;
}
auto &properties_parameter = std::get<ParameterLookup *>(node_atom.properties_);
bool props_or_labels = !properties_parameter || !node_atom.labels_.empty();
check_node_semantic(props_or_labels);
properties_parameter->Accept(*this);
return false;
}
bool SymbolGenerator::PostVisit(NodeAtom &) {
scopes_.back().in_node_atom = false;
return true;
}
bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
auto &scope = scopes_.back();
scope.visiting_edge = &edge_atom;
if (scope.in_create || scope.in_merge) {
scope.in_create_edge = true;
if (edge_atom.edge_types_.size() != 1U) {
throw SemanticException(
"A single relationship type must be specified "
"when creating an edge.");
}
if (scope.in_create && // Merge allows bidirectionality
edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
throw SemanticException(
"Bidirectional relationship are not supported "
"when creating an edge");
}
if (edge_atom.IsVariable()) {
throw SemanticException(
"Variable length relationships are not supported when creating an "
"edge.");
}
}
if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&edge_atom.properties_)) {
for (auto kv : *properties) {
kv.second->Accept(*this);
}
} else {
std::get<ParameterLookup *>(edge_atom.properties_)->Accept(*this);
}
if (edge_atom.IsVariable()) {
scope.in_edge_range = true;
if (edge_atom.lower_bound_) {
edge_atom.lower_bound_->Accept(*this);
}
if (edge_atom.upper_bound_) {
edge_atom.upper_bound_->Accept(*this);
}
scope.in_edge_range = false;
scope.in_pattern = false;
if (edge_atom.filter_lambda_.expression) {
VisitWithIdentifiers(edge_atom.filter_lambda_.expression,
{edge_atom.filter_lambda_.inner_edge, edge_atom.filter_lambda_.inner_node});
} else {
// Create inner symbols, but don't bind them in scope, since they are to
// be used in the missing filter expression.
auto *inner_edge = edge_atom.filter_lambda_.inner_edge;
inner_edge->MapTo(symbol_table_->CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE));
auto *inner_node = edge_atom.filter_lambda_.inner_node;
inner_node->MapTo(
symbol_table_->CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX));
}
if (edge_atom.weight_lambda_.expression) {
VisitWithIdentifiers(edge_atom.weight_lambda_.expression,
{edge_atom.weight_lambda_.inner_edge, edge_atom.weight_lambda_.inner_node});
}
scope.in_pattern = true;
}
scope.in_pattern_atom_identifier = true;
edge_atom.identifier_->Accept(*this);
scope.in_pattern_atom_identifier = false;
if (edge_atom.total_weight_) {
if (HasSymbolLocalScope(edge_atom.total_weight_->name_)) {
throw RedeclareVariableError(edge_atom.total_weight_->name_);
}
edge_atom.total_weight_->MapTo(GetOrCreateSymbolLocalScope(
edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, Symbol::Type::NUMBER));
}
return false;
}
bool SymbolGenerator::PostVisit(EdgeAtom &) {
auto &scope = scopes_.back();
scope.visiting_edge = nullptr;
scope.in_create_edge = false;
return true;
}
void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<Identifier *> &identifiers) {
auto &scope = scopes_.back();
std::vector<std::pair<std::optional<Symbol>, Identifier *>> prev_symbols;
// Collect previous symbols if they exist.
for (const auto &identifier : identifiers) {
std::optional<Symbol> prev_symbol;
auto prev_symbol_it = scope.symbols.find(identifier->name_);
if (prev_symbol_it != scope.symbols.end()) {
prev_symbol = prev_symbol_it->second;
}
identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_));
prev_symbols.emplace_back(prev_symbol, identifier);
}
// Visit the expression with the new symbols bound.
expr->Accept(*this);
// Restore back to previous symbols.
for (const auto &prev : prev_symbols) {
const auto &prev_symbol = prev.first;
const auto &identifier = prev.second;
if (prev_symbol) {
scope.symbols[identifier->name_] = *prev_symbol;
} else {
scope.symbols.erase(identifier->name_);
}
}
}
bool SymbolGenerator::HasSymbol(const std::string &name) const {
return std::ranges::any_of(scopes_, [&name](const auto &scope) { return scope.symbols.contains(name); });
}
bool SymbolGenerator::HasSymbolLocalScope(const std::string &name) const {
return scopes_.back().symbols.contains(name);
}
bool SymbolGenerator::ConsumePredefinedIdentifier(const std::string &name) {
auto it = predefined_identifiers_.find(name);
if (it == predefined_identifiers_.end()) {
return false;
}
// we can only use the predefined identifier in a single scope so we remove it after creating
// a symbol for it
auto &identifier = it->second;
MG_ASSERT(!identifier->user_declared_, "Predefined symbols cannot be user declared!");
identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_));
predefined_identifiers_.erase(it);
return true;
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,176 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
// Copyright 2017 Memgraph
//
// Created by Teon Banek on 11-03-2017
#pragma once
#include <optional>
#include <vector>
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/symbol_table.hpp"
namespace memgraph::query::v2 {
/// Visits the AST and generates symbols for variables.
///
/// During the process of symbol generation, simple semantic checks are
/// performed. Such as, redeclaring a variable or conflicting expectations of
/// variable types.
class SymbolGenerator : public HierarchicalTreeVisitor {
public:
explicit SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers);
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
using typename HierarchicalTreeVisitor::ReturnType;
// Query
bool PreVisit(SingleQuery &) override;
// Union
bool PreVisit(CypherUnion &) override;
bool PostVisit(CypherUnion &) override;
// Clauses
bool PreVisit(Create &) override;
bool PostVisit(Create &) override;
bool PreVisit(CallProcedure &) override;
bool PostVisit(CallProcedure &) override;
bool PreVisit(LoadCsv &) override;
bool PostVisit(LoadCsv &) override;
bool PreVisit(Return &) override;
bool PostVisit(Return &) override;
bool PreVisit(With &) override;
bool PreVisit(Where &) override;
bool PostVisit(Where &) override;
bool PreVisit(Merge &) override;
bool PostVisit(Merge &) override;
bool PostVisit(Unwind &) override;
bool PreVisit(Match &) override;
bool PostVisit(Match &) override;
bool PreVisit(Foreach &) override;
bool PostVisit(Foreach &) override;
// Expressions
ReturnType Visit(Identifier &) override;
ReturnType Visit(PrimitiveLiteral &) override { return true; }
ReturnType Visit(ParameterLookup &) override { return true; }
bool PreVisit(Aggregation &) override;
bool PostVisit(Aggregation &) override;
bool PreVisit(IfOperator &) override;
bool PostVisit(IfOperator &) override;
bool PreVisit(All &) override;
bool PreVisit(Single &) override;
bool PreVisit(Any &) override;
bool PreVisit(None &) override;
bool PreVisit(Reduce &) override;
bool PreVisit(Extract &) override;
// Pattern and its subparts.
bool PreVisit(Pattern &) override;
bool PostVisit(Pattern &) override;
bool PreVisit(NodeAtom &) override;
bool PostVisit(NodeAtom &) override;
bool PreVisit(EdgeAtom &) override;
bool PostVisit(EdgeAtom &) override;
private:
// Scope stores the state of where we are when visiting the AST and a map of
// names to symbols.
struct Scope {
bool in_pattern{false};
bool in_merge{false};
bool in_create{false};
// in_create_node is true if we are creating or merging *only* a node.
// Therefore, it is *not* equivalent to (in_create || in_merge) &&
// in_node_atom.
bool in_create_node{false};
// True if creating an edge;
// shortcut for (in_create || in_merge) && visiting_edge.
bool in_create_edge{false};
bool in_node_atom{false};
EdgeAtom *visiting_edge{nullptr};
bool in_aggregation{false};
bool in_return{false};
bool in_with{false};
bool in_skip{false};
bool in_limit{false};
bool in_order_by{false};
bool in_where{false};
bool in_match{false};
bool in_foreach{false};
// True when visiting a pattern atom (node or edge) identifier, which can be
// reused or created in the pattern itself.
bool in_pattern_atom_identifier{false};
// True when visiting range bounds of a variable path.
bool in_edge_range{false};
// True if the return/with contains an aggregation in any named expression.
bool has_aggregation{false};
// Map from variable names to symbols.
std::map<std::string, Symbol> symbols;
// Identifiers found in property maps of patterns or as variable length path
// bounds in a single Match clause. They need to be checked after visiting
// Match. Identifiers created by naming vertices, edges and paths are *not*
// stored in here.
std::vector<Identifier *> identifiers_in_match;
// Number of nested IfOperators.
int num_if_operators{0};
};
static std::optional<Symbol> FindSymbolInScope(const std::string &name, const Scope &scope, Symbol::Type type);
bool HasSymbol(const std::string &name) const;
bool HasSymbolLocalScope(const std::string &name) const;
// @return true if it added a predefined identifier with that name
bool ConsumePredefinedIdentifier(const std::string &name);
// Returns a freshly generated symbol. Previous mapping of the same name to a
// different symbol is replaced with the new one.
auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
int token_position = -1);
auto GetOrCreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
// Returns the symbol by name. If the mapping already exists, checks if the
// types match. Otherwise, returns a new symbol.
auto GetOrCreateSymbolLocalScope(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY);
void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
void VisitWithIdentifiers(Expression *, const std::vector<Identifier *> &);
SymbolTable *symbol_table_;
// Identifiers which are injected from outside the query. Each identifier
// is mapped by its name.
std::unordered_map<std::string, Identifier *> predefined_identifiers_;
std::vector<Scope> scopes_;
std::unordered_set<std::string> prev_return_names_;
std::unordered_set<std::string> curr_return_names_;
};
inline SymbolTable MakeSymbolTable(CypherQuery *query, const std::vector<Identifier *> &predefined_identifiers = {}) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(&symbol_table, predefined_identifiers);
query->single_query_->Accept(symbol_generator);
for (auto *cypher_union : query->cypher_unions_) {
cypher_union->Accept(symbol_generator);
}
return symbol_table;
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,64 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <map>
#include <string>
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/symbol.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2 {
class SymbolTable final {
public:
SymbolTable() {}
const Symbol &CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY,
int32_t token_position = -1) {
MG_ASSERT(table_.size() <= std::numeric_limits<int32_t>::max(),
"SymbolTable size doesn't fit into 32-bit integer!");
auto got = table_.emplace(position_, Symbol(name, position_, user_declared, type, token_position));
MG_ASSERT(got.second, "Duplicate symbol ID!");
position_++;
return got.first->second;
}
// TODO(buda): This is the same logic as in the cypher_main_visitor. During
// parsing phase symbol table doesn't exist. Figure out a better solution.
const Symbol &CreateAnonymousSymbol(Symbol::Type type = Symbol::Type::ANY) {
int id = 1;
while (true) {
static const std::string &kAnonPrefix = "anon";
std::string name_candidate = kAnonPrefix + std::to_string(id++);
if (std::find_if(std::begin(table_), std::end(table_), [&name_candidate](const auto &item) -> bool {
return item.second.name_ == name_candidate;
}) == std::end(table_)) {
return CreateSymbol(name_candidate, false, type);
}
}
}
const Symbol &at(const Identifier &ident) const { return table_.at(ident.symbol_pos_); }
const Symbol &at(const NamedExpression &nexpr) const { return table_.at(nexpr.symbol_pos_); }
const Symbol &at(const Aggregation &aggr) const { return table_.at(aggr.symbol_pos_); }
// TODO: Remove these since members are public
int32_t max_position() const { return static_cast<int32_t>(table_.size()); }
const auto &table() const { return table_; }
int32_t position_{0};
std::map<int32_t, Symbol> table_;
};
} // namespace memgraph::query::v2

View File

@ -0,0 +1,535 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/frontend/stripped.hpp"
#include <cctype>
#include <cstdint>
#include <iostream>
#include <span>
#include <string>
#include <vector>
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/opencypher/generated/MemgraphCypher.h"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherBaseVisitor.h"
#include "query/v2/frontend/opencypher/generated/MemgraphCypherLexer.h"
#include "query/v2/frontend/parsing.hpp"
#include "query/v2/frontend/stripped_lexer_constants.hpp"
#include "utils/fnv.hpp"
#include "utils/logging.hpp"
#include "utils/string.hpp"
namespace memgraph::query::v2::frontend {
using namespace lexer_constants;
StrippedQuery::StrippedQuery(const std::string &query) : original_(query) {
enum class Token {
UNMATCHED,
KEYWORD, // Including true, false and null.
SPECIAL, // +, .., +=, (, { and so on.
STRING,
INT, // Decimal, octal and hexadecimal.
REAL,
PARAMETER,
ESCAPED_NAME,
UNESCAPED_NAME,
SPACE
};
std::vector<std::pair<Token, std::string>> tokens;
std::string unstripped_chunk;
for (int i = 0; i < static_cast<int>(original_.size());) {
Token token = Token::UNMATCHED;
int len = 0;
auto update = [&](int new_len, Token new_token) {
if (new_len > len) {
len = new_len;
token = new_token;
}
};
update(MatchKeyword(i), Token::KEYWORD);
update(MatchSpecial(i), Token::SPECIAL);
update(MatchString(i), Token::STRING);
update(MatchDecimalInt(i), Token::INT);
update(MatchOctalInt(i), Token::INT);
update(MatchHexadecimalInt(i), Token::INT);
update(MatchReal(i), Token::REAL);
update(MatchParameter(i), Token::PARAMETER);
update(MatchEscapedName(i), Token::ESCAPED_NAME);
update(MatchUnescapedName(i), Token::UNESCAPED_NAME);
update(MatchWhitespaceAndComments(i), Token::SPACE);
if (token == Token::UNMATCHED) throw LexingException("Invalid query.");
tokens.emplace_back(token, original_.substr(i, len));
i += len;
// If we notice execute, we possibly create a trigger which has defined statements.
// The statements will be parsed separately later on so we skip it for now.
if (utils::IEquals(tokens.back().second, "execute")) {
// check if it's CREATE TRIGGER query
std::span token_span{tokens};
// query could start with spaces and/or comments
if (token_span.front().first == Token::SPACE) {
token_span = token_span.subspan(1);
}
// we need to check that first and third elements are correct keywords
// CREATE<SPACE>TRIGGER<SPACE>trigger-name...EXECUTE
// trigger-name (5th element) can also be "execute" so we verify that the size is larger than 5
if (token_span.size() > 5 && utils::IEquals(token_span[0].second, "create") &&
utils::IEquals(token_span[2].second, "trigger")) {
unstripped_chunk = original_.substr(i);
break;
}
}
}
std::vector<std::string> token_strings;
// A helper function that stores literal and its token position in a
// literals_. In stripped query text literal is replaced with a new_value.
// new_value can be any value that is lexed as a literal.
auto replace_stripped = [this, &token_strings](int position, const auto &value, const std::string &new_value) {
literals_.Add(position, storage::v3::PropertyValue(value));
token_strings.push_back(new_value);
};
// Copy original tokens because we need to use original case in named
// expressions and keywords in tokens will be lowercased in the next loop.
auto original_tokens = tokens;
// For every token in original query remember token index in stripped query.
std::vector<int> position_mapping(tokens.size(), -1);
// Convert tokens to strings, perform filtering, store literals and nonaliased
// named expressions in return.
for (int i = 0; i < static_cast<int>(tokens.size()); ++i) {
auto &token = tokens[i];
// We need to shift token index for every parameter since antlr's parser
// thinks of parameter as two tokens.
int token_index = token_strings.size() + parameters_.size();
switch (token.first) {
case Token::UNMATCHED:
LOG_FATAL("Shouldn't happen");
case Token::KEYWORD: {
// We don't strip NULL, since it can appear in special expressions
// like IS NULL and IS NOT NULL, but we strip true and false keywords.
if (utils::IEquals(token.second, "true")) {
replace_stripped(token_index, true, kStrippedBooleanToken);
} else if (utils::IEquals(token.second, "false")) {
replace_stripped(token_index, false, kStrippedBooleanToken);
} else {
token_strings.push_back(token.second);
}
} break;
case Token::SPACE:
break;
case Token::STRING:
replace_stripped(token_index, ParseStringLiteral(token.second), kStrippedStringToken);
break;
case Token::INT:
replace_stripped(token_index, ParseIntegerLiteral(token.second), kStrippedIntToken);
break;
case Token::REAL:
replace_stripped(token_index, ParseDoubleLiteral(token.second), kStrippedDoubleToken);
break;
case Token::SPECIAL:
case Token::ESCAPED_NAME:
case Token::UNESCAPED_NAME:
token_strings.push_back(token.second);
break;
case Token::PARAMETER:
parameters_[token_index] = ParseParameter(token.second);
token_strings.push_back(token.second);
break;
}
if (token.first != Token::SPACE) {
position_mapping[i] = token_index;
}
}
if (!unstripped_chunk.empty()) {
token_strings.push_back(std::move(unstripped_chunk));
}
query_ = utils::Join(token_strings, " ");
hash_ = utils::Fnv(query_);
auto it = tokens.begin();
while (it != tokens.end()) {
// Store nonaliased named expressions in returns in named_exprs_.
it = std::find_if(it, tokens.end(),
[](const std::pair<Token, std::string> &a) { return utils::IEquals(a.second, "return"); });
// There is no RETURN so there is nothing to do here.
if (it == tokens.end()) return;
// Skip RETURN;
++it;
// Now we need to parse cypherReturn production from opencypher grammar.
// Skip leading whitespaces and DISTINCT statemant if there is one.
while (it != tokens.end() && it->first == Token::SPACE) {
++it;
}
if (it != tokens.end() && utils::IEquals(it->second, "distinct")) {
++it;
}
// If the query is invalid, either antlr parser or cypher_main_visitor will
// report an error.
// TODO: we shouldn't rely on the fact that those checks will be done
// after this step. We should do them here.
while (it < tokens.end()) {
// Disregard leading whitespace
while (it != tokens.end() && it->first == Token::SPACE) {
++it;
}
// There is only whitespace, nothing to do...
if (it == tokens.end()) break;
bool has_as = false;
auto last_non_space = it;
auto jt = it;
// We should track number of opened braces and parantheses so that we can
// recognize if comma is a named expression separator or part of the
// list literal / function call.
int num_open_braces = 0;
int num_open_parantheses = 0;
int num_open_brackets = 0;
for (;
jt != tokens.end() && (jt->second != "," || num_open_braces || num_open_parantheses || num_open_brackets) &&
!utils::IEquals(jt->second, "order") && !utils::IEquals(jt->second, "skip") &&
!utils::IEquals(jt->second, "limit") && !utils::IEquals(jt->second, "union") &&
!utils::IEquals(jt->second, "query") && jt->second != ";";
++jt) {
if (jt->second == "(") {
++num_open_parantheses;
} else if (jt->second == ")") {
--num_open_parantheses;
} else if (jt->second == "[") {
++num_open_braces;
} else if (jt->second == "]") {
--num_open_braces;
} else if (jt->second == "{") {
++num_open_brackets;
} else if (jt->second == "}") {
--num_open_brackets;
}
has_as |= utils::IEquals(jt->second, "as");
if (jt->first != Token::SPACE) {
last_non_space = jt;
}
}
if (!has_as) {
// Named expression is not aliased. Save string disregarding leading and
// trailing whitespaces.
std::string s;
auto begin_token = it - tokens.begin() + original_tokens.begin();
auto end_token = last_non_space - tokens.begin() + original_tokens.begin() + 1;
for (auto kt = begin_token; kt != end_token; ++kt) {
s += kt->second;
}
named_exprs_[position_mapping[it - tokens.begin()]] = s;
}
if (jt != tokens.end() && jt->second == ",") {
// There are more named expressions.
it = jt + 1;
} else {
// We're done with this return statement
break;
}
}
}
}
std::string GetFirstUtf8Symbol(const char *_s) {
// According to
// https://stackoverflow.com/questions/16260033/reinterpret-cast-between-char-and-stduint8-t-safe
// this checks if casting from const char * to uint8_t is undefined behaviour.
static_assert(std::is_same<std::uint8_t, unsigned char>::value,
"This library requires std::uint8_t to be implemented as "
"unsigned char.");
const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
if ((*s >> 7) == 0x00) return std::string(_s, _s + 1);
if ((*s >> 5) == 0x06) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 2);
}
if ((*s >> 4) == 0x0e) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 3);
}
if ((*s >> 3) == 0x1e) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s3 = s + 3;
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
return std::string(_s, _s + 4);
}
throw LexingException("Invalid character.");
}
// Return codepoint of first utf8 symbol and its encoded length.
std::pair<int, int> GetFirstUtf8SymbolCodepoint(const char *_s) {
static_assert(std::is_same<std::uint8_t, unsigned char>::value,
"This library requires std::uint8_t to be implemented as "
"unsigned char.");
const uint8_t *s = reinterpret_cast<const uint8_t *>(_s);
if ((*s >> 7) == 0x00) return {*s & 0x7f, 1};
if ((*s >> 5) == 0x06) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x1f) << 6) | (*s1 & 0x3f), 2};
}
if ((*s >> 4) == 0x0e) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x0f) << 12) | ((*s1 & 0x3f) << 6) | (*s2 & 0x3f), 3};
}
if ((*s >> 3) == 0x1e) {
auto *s1 = s + 1;
if ((*s1 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s2 = s + 2;
if ((*s2 >> 6) != 0x02) throw LexingException("Invalid character.");
auto *s3 = s + 3;
if ((*s3 >> 6) != 0x02) throw LexingException("Invalid character.");
return {((*s & 0x07) << 18) | ((*s1 & 0x3f) << 12) | ((*s2 & 0x3f) << 6) | (*s3 & 0x3f), 4};
}
throw LexingException("Invalid character.");
}
// From here until end of file there are functions that calculate matches for
// every possible token. Functions are more or less compatible with Cypher.g4
// grammar. Unfortunately, they contain a lof of special cases and shouldn't
// be changed without good reasons.
//
// Here be dragons, do not touch!
// ____ __
// { --.\ | .)%%%)%%
// '-._\\ | (\___ %)%%(%%(%%%
// `\\|{/ ^ _)-%(%%%%)%%;%%%
// .'^^^^^^^ /` %%)%%%%)%%%'
// //\ ) , / '%%%%(%%'
// , _.'/ `\<-- \<
// `^^^` ^^ ^^
int StrippedQuery::MatchKeyword(int start) const { return kKeywords.Match<tolower>(original_.c_str() + start); }
int StrippedQuery::MatchSpecial(int start) const { return kSpecialTokens.Match(original_.c_str() + start); }
int StrippedQuery::MatchString(int start) const {
if (original_[start] != '"' && original_[start] != '\'') return 0;
char start_char = original_[start];
for (auto *p = original_.data() + start + 1; *p; ++p) {
if (*p == start_char) return p - (original_.data() + start) + 1;
if (*p == '\\') {
++p;
if (*p == '\\' || *p == '\'' || *p == '"' || *p == 'B' || *p == 'b' || *p == 'F' || *p == 'f' || *p == 'N' ||
*p == 'n' || *p == 'R' || *p == 'r' || *p == 'T' || *p == 't') {
// Allowed escaped characters.
continue;
} else if (*p == 'U' || *p == 'u') {
int cnt = 0;
auto *r = p + 1;
while (isxdigit(*r) && cnt < 8) {
++cnt;
++r;
}
if (!*r) return 0;
if (cnt < 4) return 0;
if (cnt >= 4 && cnt < 8) {
p += 4;
}
if (cnt >= 8) {
p += 8;
}
} else {
return 0;
}
}
}
return 0;
}
int StrippedQuery::MatchDecimalInt(int start) const {
if (original_[start] == '0') return 1;
int i = start;
while (i < static_cast<int>(original_.size()) && isdigit(original_[i])) {
++i;
}
return i - start;
}
int StrippedQuery::MatchOctalInt(int start) const {
if (original_[start] != '0') return 0;
int i = start + 1;
while (i < static_cast<int>(original_.size()) && '0' <= original_[i] && original_[i] <= '7') {
++i;
}
if (i == start + 1) return 0;
return i - start;
}
int StrippedQuery::MatchHexadecimalInt(int start) const {
if (original_[start] != '0') return 0;
if (start + 1 >= static_cast<int>(original_.size())) return 0;
if (original_[start + 1] != 'x') return 0;
int i = start + 2;
while (i < static_cast<int>(original_.size()) && isxdigit(original_[i])) {
++i;
}
if (i == start + 2) return 0;
return i - start;
}
int StrippedQuery::MatchReal(int start) const {
enum class State { START, BEFORE_DOT, DOT, AFTER_DOT, E, E_MINUS, AFTER_E };
State state = State::START;
auto i = start;
while (i < static_cast<int>(original_.size())) {
if (original_[i] == '.') {
if (state != State::BEFORE_DOT && state != State::START) break;
state = State::DOT;
} else if ('0' <= original_[i] && original_[i] <= '9') {
if (state == State::START) {
state = State::BEFORE_DOT;
} else if (state == State::DOT) {
state = State::AFTER_DOT;
} else if (state == State::E || state == State::E_MINUS) {
state = State::AFTER_E;
}
} else if (original_[i] == 'e' || original_[i] == 'E') {
if (state != State::BEFORE_DOT && state != State::AFTER_DOT) break;
state = State::E;
} else if (original_[i] == '-') {
if (state != State::E) break;
state = State::E_MINUS;
} else {
break;
}
++i;
}
if (state == State::DOT) --i;
if (state == State::E) --i;
if (state == State::E_MINUS) i -= 2;
return i - start;
}
int StrippedQuery::MatchParameter(int start) const {
int len = original_.size();
if (start + 1 == len) return 0;
if (original_[start] != '$') return 0;
int max_len = 0;
max_len = std::max(max_len, MatchUnescapedName(start + 1));
max_len = std::max(max_len, MatchEscapedName(start + 1));
max_len = std::max(max_len, MatchKeyword(start + 1));
max_len = std::max(max_len, MatchDecimalInt(start + 1));
if (max_len == 0) return 0;
return 1 + max_len;
}
int StrippedQuery::MatchEscapedName(int start) const {
int len = original_.size();
int i = start;
while (i < len) {
if (original_[i] != '`') break;
int j = i + 1;
while (j < len && original_[j] != '`') {
++j;
}
if (j == len) break;
i = j + 1;
}
return i - start;
}
int StrippedQuery::MatchUnescapedName(int start) const {
auto i = start;
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedStarts[got.first]) {
return 0;
}
i += got.second;
while (i < static_cast<int>(original_.size())) {
got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first >= lexer_constants::kBitsetSize || !kUnescapedNameAllowedParts[got.first]) {
break;
}
i += got.second;
}
return i - start;
}
int StrippedQuery::MatchWhitespaceAndComments(int start) const {
enum class State { OUT, IN_LINE_COMMENT, IN_BLOCK_COMMENT };
State state = State::OUT;
int i = start;
int len = original_.size();
// We need to remember at which position comment started because if we fail
// to match comment finish we have a match until comment start position.
int comment_position = -1;
while (i < len) {
if (state == State::OUT) {
auto got = GetFirstUtf8SymbolCodepoint(original_.data() + i);
if (got.first < lexer_constants::kBitsetSize && kSpaceParts[got.first]) {
i += got.second;
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '*') {
comment_position = i;
state = State::IN_BLOCK_COMMENT;
i += 2;
} else if (i + 1 < len && original_[i] == '/' && original_[i + 1] == '/') {
comment_position = i;
if (i + 2 < len) {
// Special case for an empty line comment starting right at the end of
// the query.
state = State::IN_LINE_COMMENT;
}
i += 2;
} else {
break;
}
} else if (state == State::IN_LINE_COMMENT) {
if (original_[i] == '\n') {
state = State::OUT;
++i;
} else if (i + 1 < len && original_[i] == '\r' && original_[i + 1] == '\n') {
state = State::OUT;
i += 2;
} else if (original_[i] == '\r') {
break;
} else if (i + 1 == len) {
state = State::OUT;
++i;
} else {
++i;
}
} else if (state == State::IN_BLOCK_COMMENT) {
if (i + 1 < len && original_[i] == '*' && original_[i + 1] == '/') {
i += 2;
state = State::OUT;
} else {
++i;
}
}
}
if (state != State::OUT) return comment_position - start;
return i - start;
}
} // namespace memgraph::query::v2::frontend

View File

@ -0,0 +1,103 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <string>
#include <unordered_map>
#include "query/v2/parameters.hpp"
#include "utils/fnv.hpp"
namespace memgraph::query::v2::frontend {
// Strings used to replace original tokens. Different types are replaced with
// different token.
const std::string kStrippedIntToken = "0";
const std::string kStrippedDoubleToken = "0.0";
const std::string kStrippedStringToken = "\"a\"";
const std::string kStrippedBooleanToken = "true";
/**
* StrippedQuery contains:
* * stripped query
* * literals stripped from query
* * hash of stripped query
*/
class StrippedQuery {
public:
/**
* Strips the input query and stores stripped query, stripped arguments and
* stripped query hash.
*
* @param query Input query.
*/
explicit StrippedQuery(const std::string &query);
/**
* Copy constructor is deleted because we don't want to make unnecessary
* copies of this object (copying of string and vector could be expensive)
*/
StrippedQuery(const StrippedQuery &other) = delete;
StrippedQuery &operator=(const StrippedQuery &other) = delete;
/**
* Move is allowed operation because it is not expensive and we can
* move the object after it was created.
*/
StrippedQuery(StrippedQuery &&other) = default;
StrippedQuery &operator=(StrippedQuery &&other) = default;
const std::string &query() const { return query_; }
const auto &original_query() const { return original_; }
const auto &literals() const { return literals_; }
const auto &named_expressions() const { return named_exprs_; }
const auto &parameters() const { return parameters_; }
uint64_t hash() const { return hash_; }
private:
// Return len of matched keyword if something is matched, otherwise 0.
int MatchKeyword(int start) const;
int MatchString(int start) const;
int MatchSpecial(int start) const;
int MatchDecimalInt(int start) const;
int MatchOctalInt(int start) const;
int MatchHexadecimalInt(int start) const;
int MatchReal(int start) const;
int MatchParameter(int start) const;
int MatchEscapedName(int start) const;
int MatchUnescapedName(int start) const;
int MatchWhitespaceAndComments(int start) const;
// Original query.
std::string original_;
// Stripped query.
std::string query_;
// Token positions of stripped out literals mapped to their values.
// TODO: Parameters class really doesn't provide anything interesting. This
// could be changed to std::unordered_map, but first we need to rewrite (or
// get rid of) hardcoded queries which expect Parameters.
Parameters literals_;
// Token positions of query parameters mapped to their names.
std::unordered_map<int, std::string> parameters_;
// Token positions of nonaliased named expressions in return statement mapped
// to their original (unstripped) string.
std::unordered_map<int, std::string> named_exprs_;
// Hash based on the stripped query.
uint64_t hash_;
};
} // namespace memgraph::query::v2::frontend

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,50 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include "storage/v3/view.hpp"
#include "utils/memory.hpp"
namespace memgraph::query::v2 {
class DbAccessor;
class TypedValue;
namespace {
const char kStartsWith[] = "STARTSWITH";
const char kEndsWith[] = "ENDSWITH";
const char kContains[] = "CONTAINS";
const char kId[] = "ID";
} // namespace
struct FunctionContext {
DbAccessor *db_accessor;
utils::MemoryResource *memory;
int64_t timestamp;
std::unordered_map<std::string, int64_t> *counters;
storage::v3::View view;
};
/// Return the function implementation with the given name.
///
/// Note, returned function signature uses C-style access to an array to allow
/// having an array stored anywhere the caller likes, as long as it is
/// contiguous in memory. Since most functions don't take many arguments, it's
/// convenient to have them stored in the calling stack frame.
std::function<TypedValue(const TypedValue *arguments, int64_t num_arguments, const FunctionContext &context)>
NameToFunction(const std::string &function_name);
} // namespace memgraph::query::v2

View File

@ -0,0 +1,35 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/interpret/eval.hpp"
namespace memgraph::query::v2 {
int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what) {
TypedValue value = expr->Accept(*evaluator);
try {
return value.ValueInt();
} catch (TypedValueException &e) {
throw QueryRuntimeException(what + " must be an int");
}
}
std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale) {
if (!memory_limit) return std::nullopt;
auto limit_value = memory_limit->Accept(*eval);
if (!limit_value.IsInt() || limit_value.ValueInt() <= 0)
throw QueryRuntimeException("Memory limit must be a non-negative integer.");
size_t limit = limit_value.ValueInt();
if (std::numeric_limits<size_t>::max() / memory_scale < limit) throw QueryRuntimeException("Memory limit overflow.");
return limit * memory_scale;
}
} // namespace memgraph::query::v2

View File

@ -0,0 +1,764 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <algorithm>
#include <limits>
#include <map>
#include <optional>
#include <regex>
#include <vector>
#include "query/v2/common.hpp"
#include "query/v2/context.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/symbol_table.hpp"
#include "query/v2/interpret/frame.hpp"
#include "query/v2/typed_value.hpp"
#include "utils/exceptions.hpp"
namespace memgraph::query::v2 {
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
storage::v3::View view)
: frame_(frame), symbol_table_(&symbol_table), ctx_(&ctx), dba_(dba), view_(view) {}
using ExpressionVisitor<TypedValue>::Visit;
utils::MemoryResource *GetMemoryResource() const { return ctx_->memory; }
TypedValue Visit(NamedExpression &named_expression) override {
const auto &symbol = symbol_table_->at(named_expression);
auto value = named_expression.expression_->Accept(*this);
frame_->at(symbol) = value;
return value;
}
TypedValue Visit(Identifier &ident) override {
return TypedValue(frame_->at(symbol_table_->at(ident)), ctx_->memory);
}
#define BINARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val1 = op.expression1_->Accept(*this); \
auto val2 = op.expression2_->Accept(*this); \
try { \
return val1 CPP_OP val2; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid types: {} and {} for '{}'.", val1.type(), val2.type(), #CYPHER_OP); \
} \
}
#define UNARY_OPERATOR_VISITOR(OP_NODE, CPP_OP, CYPHER_OP) \
TypedValue Visit(OP_NODE &op) override { \
auto val = op.expression_->Accept(*this); \
try { \
return CPP_OP val; \
} catch (const TypedValueException &) { \
throw QueryRuntimeException("Invalid type {} for '{}'.", val.type(), #CYPHER_OP); \
} \
}
BINARY_OPERATOR_VISITOR(OrOperator, ||, OR);
BINARY_OPERATOR_VISITOR(XorOperator, ^, XOR);
BINARY_OPERATOR_VISITOR(AdditionOperator, +, +);
BINARY_OPERATOR_VISITOR(SubtractionOperator, -, -);
BINARY_OPERATOR_VISITOR(MultiplicationOperator, *, *);
BINARY_OPERATOR_VISITOR(DivisionOperator, /, /);
BINARY_OPERATOR_VISITOR(ModOperator, %, %);
BINARY_OPERATOR_VISITOR(NotEqualOperator, !=, <>);
BINARY_OPERATOR_VISITOR(EqualOperator, ==, =);
BINARY_OPERATOR_VISITOR(LessOperator, <, <);
BINARY_OPERATOR_VISITOR(GreaterOperator, >, >);
BINARY_OPERATOR_VISITOR(LessEqualOperator, <=, <=);
BINARY_OPERATOR_VISITOR(GreaterEqualOperator, >=, >=);
UNARY_OPERATOR_VISITOR(NotOperator, !, NOT);
UNARY_OPERATOR_VISITOR(UnaryPlusOperator, +, +);
UNARY_OPERATOR_VISITOR(UnaryMinusOperator, -, -);
#undef BINARY_OPERATOR_VISITOR
#undef UNARY_OPERATOR_VISITOR
TypedValue Visit(AndOperator &op) override {
auto value1 = op.expression1_->Accept(*this);
if (value1.IsBool() && !value1.ValueBool()) {
// If first expression is false, don't evaluate the second one.
return value1;
}
auto value2 = op.expression2_->Accept(*this);
try {
return value1 && value2;
} catch (const TypedValueException &) {
throw QueryRuntimeException("Invalid types: {} and {} for AND.", value1.type(), value2.type());
}
}
TypedValue Visit(IfOperator &if_operator) override {
auto condition = if_operator.condition_->Accept(*this);
if (condition.IsNull()) {
return if_operator.then_expression_->Accept(*this);
}
if (condition.type() != TypedValue::Type::Bool) {
// At the moment IfOperator is used only in CASE construct.
throw QueryRuntimeException("CASE expected boolean expression, got {}.", condition.type());
}
if (condition.ValueBool()) {
return if_operator.then_expression_->Accept(*this);
}
return if_operator.else_expression_->Accept(*this);
}
TypedValue Visit(InListOperator &in_list) override {
auto literal = in_list.expression1_->Accept(*this);
auto _list = in_list.expression2_->Accept(*this);
if (_list.IsNull()) {
return TypedValue(ctx_->memory);
}
// Exceptions have higher priority than returning nulls when list expression
// is not null.
if (_list.type() != TypedValue::Type::List) {
throw QueryRuntimeException("IN expected a list, got {}.", _list.type());
}
const auto &list = _list.ValueList();
// If literal is NULL there is no need to try to compare it with every
// element in the list since result of every comparison will be NULL. There
// is one special case that we must test explicitly: if list is empty then
// result is false since no comparison will be performed.
if (list.empty()) return TypedValue(false, ctx_->memory);
if (literal.IsNull()) return TypedValue(ctx_->memory);
auto has_null = false;
for (const auto &element : list) {
auto result = literal == element;
if (result.IsNull()) {
has_null = true;
} else if (result.ValueBool()) {
return TypedValue(true, ctx_->memory);
}
}
if (has_null) {
return TypedValue(ctx_->memory);
}
return TypedValue(false, ctx_->memory);
}
TypedValue Visit(SubscriptOperator &list_indexing) override {
auto lhs = list_indexing.expression1_->Accept(*this);
auto index = list_indexing.expression2_->Accept(*this);
if (!lhs.IsList() && !lhs.IsMap() && !lhs.IsVertex() && !lhs.IsEdge() && !lhs.IsNull())
throw QueryRuntimeException(
"Expected a list, a map, a node or an edge to index with '[]', got "
"{}.",
lhs.type());
if (lhs.IsNull() || index.IsNull()) return TypedValue(ctx_->memory);
if (lhs.IsList()) {
if (!index.IsInt()) throw QueryRuntimeException("Expected an integer as a list index, got {}.", index.type());
auto index_int = index.ValueInt();
// NOTE: Take non-const reference to list, so that we can move out the
// indexed element as the result.
auto &list = lhs.ValueList();
if (index_int < 0) {
index_int += static_cast<int64_t>(list.size());
}
if (index_int >= static_cast<int64_t>(list.size()) || index_int < 0) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(list[index_int]);
}
if (lhs.IsMap()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a map index, got {}.", index.type());
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
auto &map = lhs.ValueMap();
auto found = map.find(index.ValueString());
if (found == map.end()) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(found->second);
}
if (lhs.IsVertex()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueVertex(), index.ValueString()), ctx_->memory);
}
if (lhs.IsEdge()) {
if (!index.IsString()) throw QueryRuntimeException("Expected a string as a property name, got {}.", index.type());
return TypedValue(GetProperty(lhs.ValueEdge(), index.ValueString()), ctx_->memory);
}
// lhs is Null
return TypedValue(ctx_->memory);
}
TypedValue Visit(ListSlicingOperator &op) override {
// If some type is null we can't return null, because throwing exception
// on illegal type has higher priority.
auto is_null = false;
auto get_bound = [&](Expression *bound_expr, int64_t default_value) {
if (bound_expr) {
auto bound = bound_expr->Accept(*this);
if (bound.type() == TypedValue::Type::Null) {
is_null = true;
} else if (bound.type() != TypedValue::Type::Int) {
throw QueryRuntimeException("Expected an integer for a bound in list slicing, got {}.", bound.type());
}
return bound;
}
return TypedValue(default_value, ctx_->memory);
};
auto _upper_bound = get_bound(op.upper_bound_, std::numeric_limits<int64_t>::max());
auto _lower_bound = get_bound(op.lower_bound_, 0);
auto _list = op.list_->Accept(*this);
if (_list.type() == TypedValue::Type::Null) {
is_null = true;
} else if (_list.type() != TypedValue::Type::List) {
throw QueryRuntimeException("Expected a list to slice, got {}.", _list.type());
}
if (is_null) {
return TypedValue(ctx_->memory);
}
const auto &list = _list.ValueList();
auto normalise_bound = [&](int64_t bound) {
if (bound < 0) {
bound = static_cast<int64_t>(list.size()) + bound;
}
return std::max(static_cast<int64_t>(0), std::min(bound, static_cast<int64_t>(list.size())));
};
auto lower_bound = normalise_bound(_lower_bound.ValueInt());
auto upper_bound = normalise_bound(_upper_bound.ValueInt());
if (upper_bound <= lower_bound) {
return TypedValue(TypedValue::TVector(ctx_->memory), ctx_->memory);
}
return TypedValue(TypedValue::TVector(list.begin() + lower_bound, list.begin() + upper_bound, ctx_->memory));
}
TypedValue Visit(IsNullOperator &is_null) override {
auto value = is_null.expression_->Accept(*this);
return TypedValue(value.IsNull(), ctx_->memory);
}
TypedValue Visit(PropertyLookup &property_lookup) override {
auto expression_result = property_lookup.expression_->Accept(*this);
auto maybe_date = [this](const auto &date, const auto &prop_name) -> std::optional<TypedValue> {
if (prop_name == "year") {
return TypedValue(date.year, ctx_->memory);
}
if (prop_name == "month") {
return TypedValue(date.month, ctx_->memory);
}
if (prop_name == "day") {
return TypedValue(date.day, ctx_->memory);
}
return std::nullopt;
};
auto maybe_local_time = [this](const auto &lt, const auto &prop_name) -> std::optional<TypedValue> {
if (prop_name == "hour") {
return TypedValue(lt.hour, ctx_->memory);
}
if (prop_name == "minute") {
return TypedValue(lt.minute, ctx_->memory);
}
if (prop_name == "second") {
return TypedValue(lt.second, ctx_->memory);
}
if (prop_name == "millisecond") {
return TypedValue(lt.millisecond, ctx_->memory);
}
if (prop_name == "microsecond") {
return TypedValue(lt.microsecond, ctx_->memory);
}
return std::nullopt;
};
auto maybe_duration = [this](const auto &dur, const auto &prop_name) -> std::optional<TypedValue> {
if (prop_name == "day") {
return TypedValue(dur.Days(), ctx_->memory);
}
if (prop_name == "hour") {
return TypedValue(dur.SubDaysAsHours(), ctx_->memory);
}
if (prop_name == "minute") {
return TypedValue(dur.SubDaysAsMinutes(), ctx_->memory);
}
if (prop_name == "second") {
return TypedValue(dur.SubDaysAsSeconds(), ctx_->memory);
}
if (prop_name == "millisecond") {
return TypedValue(dur.SubDaysAsMilliseconds(), ctx_->memory);
}
if (prop_name == "microsecond") {
return TypedValue(dur.SubDaysAsMicroseconds(), ctx_->memory);
}
if (prop_name == "nanosecond") {
return TypedValue(dur.SubDaysAsNanoseconds(), ctx_->memory);
}
return std::nullopt;
};
switch (expression_result.type()) {
case TypedValue::Type::Null:
return TypedValue(ctx_->memory);
case TypedValue::Type::Vertex:
return TypedValue(GetProperty(expression_result.ValueVertex(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Edge:
return TypedValue(GetProperty(expression_result.ValueEdge(), property_lookup.property_), ctx_->memory);
case TypedValue::Type::Map: {
// NOTE: Take non-const reference to map, so that we can move out the
// looked-up element as the result.
auto &map = expression_result.ValueMap();
auto found = map.find(property_lookup.property_.name.c_str());
if (found == map.end()) return TypedValue(ctx_->memory);
// NOTE: Explicit move is needed, so that we return the move constructed
// value and preserve the correct MemoryResource.
return std::move(found->second);
}
case TypedValue::Type::Duration: {
const auto &prop_name = property_lookup.property_.name;
const auto &dur = expression_result.ValueDuration();
if (auto dur_field = maybe_duration(dur, prop_name); dur_field) {
return std::move(*dur_field);
}
throw QueryRuntimeException("Invalid property name {} for Duration", prop_name);
}
case TypedValue::Type::Date: {
const auto &prop_name = property_lookup.property_.name;
const auto &date = expression_result.ValueDate();
if (auto date_field = maybe_date(date, prop_name); date_field) {
return std::move(*date_field);
}
throw QueryRuntimeException("Invalid property name {} for Date", prop_name);
}
case TypedValue::Type::LocalTime: {
const auto &prop_name = property_lookup.property_.name;
const auto &lt = expression_result.ValueLocalTime();
if (auto lt_field = maybe_local_time(lt, prop_name); lt_field) {
return std::move(*lt_field);
}
throw QueryRuntimeException("Invalid property name {} for LocalTime", prop_name);
}
case TypedValue::Type::LocalDateTime: {
const auto &prop_name = property_lookup.property_.name;
const auto &ldt = expression_result.ValueLocalDateTime();
if (auto date_field = maybe_date(ldt.date, prop_name); date_field) {
return std::move(*date_field);
}
if (auto lt_field = maybe_local_time(ldt.local_time, prop_name); lt_field) {
return std::move(*lt_field);
}
throw QueryRuntimeException("Invalid property name {} for LocalDateTime", prop_name);
}
default:
throw QueryRuntimeException("Only nodes, edges, maps and temporal types have properties to be looked-up.");
}
}
TypedValue Visit(LabelsTest &labels_test) override {
auto expression_result = labels_test.expression_->Accept(*this);
switch (expression_result.type()) {
case TypedValue::Type::Null:
return TypedValue(ctx_->memory);
case TypedValue::Type::Vertex: {
const auto &vertex = expression_result.ValueVertex();
for (const auto &label : labels_test.labels_) {
auto has_label = vertex.HasLabel(view_, GetLabel(label));
if (has_label.HasError() && has_label.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE
// work. The old storage had the following logic when returning an
// `OLD` view: `return old ? old : new`. That means that if the
// `OLD` view didn't exist, it returned the NEW view. With this hack
// we simulate that behavior.
// TODO (mferencevic, teon.banek): Remove once MERGE is
// reimplemented.
has_label = vertex.HasLabel(storage::v3::View::NEW, GetLabel(label));
}
if (has_label.HasError()) {
switch (has_label.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw QueryRuntimeException("Trying to access labels on a deleted node.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to access labels from a node that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException("Unexpected error when accessing labels.");
}
}
if (!*has_label) {
return TypedValue(false, ctx_->memory);
}
}
return TypedValue(true, ctx_->memory);
}
default:
throw QueryRuntimeException("Only nodes have labels.");
}
}
TypedValue Visit(PrimitiveLiteral &literal) override {
// TODO: no need to evaluate constants, we can write it to frame in one
// of the previous phases.
return TypedValue(literal.value_, ctx_->memory);
}
TypedValue Visit(ListLiteral &literal) override {
TypedValue::TVector result(ctx_->memory);
result.reserve(literal.elements_.size());
for (const auto &expression : literal.elements_) result.emplace_back(expression->Accept(*this));
return TypedValue(result, ctx_->memory);
}
TypedValue Visit(MapLiteral &literal) override {
TypedValue::TMap result(ctx_->memory);
for (const auto &pair : literal.elements_) result.emplace(pair.first.name, pair.second->Accept(*this));
return TypedValue(result, ctx_->memory);
}
TypedValue Visit(Aggregation &aggregation) override {
return TypedValue(frame_->at(symbol_table_->at(aggregation)), ctx_->memory);
}
TypedValue Visit(Coalesce &coalesce) override {
auto &exprs = coalesce.expressions_;
if (exprs.size() == 0) {
throw QueryRuntimeException("'coalesce' requires at least one argument.");
}
for (int64_t i = 0; i < exprs.size(); ++i) {
TypedValue val(exprs[i]->Accept(*this), ctx_->memory);
if (!val.IsNull()) {
return val;
}
}
return TypedValue(ctx_->memory);
}
TypedValue Visit(Function &function) override {
FunctionContext function_ctx{dba_, ctx_->memory, ctx_->timestamp, &ctx_->counters, view_};
// Stack allocate evaluated arguments when there's a small number of them.
if (function.arguments_.size() <= 8) {
TypedValue arguments[8] = {TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory), TypedValue(ctx_->memory),
TypedValue(ctx_->memory), TypedValue(ctx_->memory)};
for (size_t i = 0; i < function.arguments_.size(); ++i) {
arguments[i] = function.arguments_[i]->Accept(*this);
}
auto res = function.function_(arguments, function.arguments_.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
} else {
TypedValue::TVector arguments(ctx_->memory);
arguments.reserve(function.arguments_.size());
for (const auto &argument : function.arguments_) {
arguments.emplace_back(argument->Accept(*this));
}
auto res = function.function_(arguments.data(), arguments.size(), function_ctx);
MG_ASSERT(res.GetMemoryResource() == ctx_->memory);
return res;
}
}
TypedValue Visit(Reduce &reduce) override {
auto list_value = reduce.list_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("REDUCE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &element_symbol = symbol_table_->at(*reduce.identifier_);
const auto &accumulator_symbol = symbol_table_->at(*reduce.accumulator_);
auto accumulator = reduce.initializer_->Accept(*this);
for (const auto &element : list) {
frame_->at(accumulator_symbol) = accumulator;
frame_->at(element_symbol) = element;
accumulator = reduce.expression_->Accept(*this);
}
return accumulator;
}
TypedValue Visit(Extract &extract) override {
auto list_value = extract.list_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("EXTRACT expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &element_symbol = symbol_table_->at(*extract.identifier_);
TypedValue::TVector result(ctx_->memory);
result.reserve(list.size());
for (const auto &element : list) {
if (element.IsNull()) {
result.emplace_back();
} else {
frame_->at(element_symbol) = element;
result.emplace_back(extract.expression_->Accept(*this));
}
}
return TypedValue(result, ctx_->memory);
}
TypedValue Visit(All &all) override {
auto list_value = all.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ALL expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*all.identifier_);
bool has_null_elements = false;
bool has_value = false;
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = all.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException("Predicate of ALL must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
if (!result.ValueBool()) {
return TypedValue(false, ctx_->memory);
}
} else {
has_null_elements = true;
}
}
if (!has_value) {
return TypedValue(ctx_->memory);
}
if (has_null_elements) {
return TypedValue(false, ctx_->memory);
} else {
return TypedValue(true, ctx_->memory);
}
}
TypedValue Visit(Single &single) override {
auto list_value = single.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("SINGLE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*single.identifier_);
bool has_value = false;
bool predicate_satisfied = false;
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = single.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException("Predicate of SINGLE must evaluate to boolean, got {}.", result.type());
}
if (result.type() == TypedValue::Type::Bool) {
has_value = true;
}
if (result.IsNull() || !result.ValueBool()) {
continue;
}
// Return false if more than one element satisfies the predicate.
if (predicate_satisfied) {
return TypedValue(false, ctx_->memory);
} else {
predicate_satisfied = true;
}
}
if (!has_value) {
return TypedValue(ctx_->memory);
} else {
return TypedValue(predicate_satisfied, ctx_->memory);
}
}
TypedValue Visit(Any &any) override {
auto list_value = any.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ANY expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*any.identifier_);
bool has_value = false;
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = any.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException("Predicate of ANY must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
if (result.ValueBool()) {
return TypedValue(true, ctx_->memory);
}
}
}
// Return Null if all elements are Null
if (!has_value) {
return TypedValue(ctx_->memory);
} else {
return TypedValue(false, ctx_->memory);
}
}
TypedValue Visit(None &none) override {
auto list_value = none.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("NONE expected a list, got {}.", list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*none.identifier_);
bool has_value = false;
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = none.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException("Predicate of NONE must evaluate to boolean, got {}.", result.type());
}
if (!result.IsNull()) {
has_value = true;
if (result.ValueBool()) {
return TypedValue(false, ctx_->memory);
}
}
}
// Return Null if all elements are Null
if (!has_value) {
return TypedValue(ctx_->memory);
} else {
return TypedValue(true, ctx_->memory);
}
}
TypedValue Visit(ParameterLookup &param_lookup) override {
return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory);
}
TypedValue Visit(RegexMatch &regex_match) override {
auto target_string_value = regex_match.string_expr_->Accept(*this);
auto regex_value = regex_match.regex_->Accept(*this);
if (target_string_value.IsNull() || regex_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (regex_value.type() != TypedValue::Type::String) {
throw QueryRuntimeException("Regular expression must evaluate to a string, got {}.", regex_value.type());
}
if (target_string_value.type() != TypedValue::Type::String) {
// Instead of error, we return Null which makes it compatible in case we
// use indexed lookup which filters out any non-string properties.
// Assuming a property lookup is the target_string_value.
return TypedValue(ctx_->memory);
}
const auto &target_string = target_string_value.ValueString();
try {
std::regex regex(regex_value.ValueString());
return TypedValue(std::regex_match(target_string, regex), ctx_->memory);
} catch (const std::regex_error &e) {
throw QueryRuntimeException("Regex error in '{}': {}", regex_value.ValueString(), e.what());
}
}
private:
template <class TRecordAccessor>
storage::v3::PropertyValue GetProperty(const TRecordAccessor &record_accessor, PropertyIx prop) {
auto maybe_prop = record_accessor.GetProperty(view_, ctx_->properties[prop.ix]);
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE work.
// The old storage had the following logic when returning an `OLD` view:
// `return old ? old : new`. That means that if the `OLD` view didn't
// exist, it returned the NEW view. With this hack we simulate that
// behavior.
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
maybe_prop = record_accessor.GetProperty(storage::v3::View::NEW, ctx_->properties[prop.ix]);
}
if (maybe_prop.HasError()) {
switch (maybe_prop.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw QueryRuntimeException("Trying to get a property from a deleted object.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException("Unexpected error when getting a property.");
}
}
return *maybe_prop;
}
template <class TRecordAccessor>
storage::v3::PropertyValue GetProperty(const TRecordAccessor &record_accessor, const std::string_view name) {
auto maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
if (maybe_prop.HasError() && maybe_prop.GetError() == storage::v3::Error::NONEXISTENT_OBJECT) {
// This is a very nasty and temporary hack in order to make MERGE work.
// The old storage had the following logic when returning an `OLD` view:
// `return old ? old : new`. That means that if the `OLD` view didn't
// exist, it returned the NEW view. With this hack we simulate that
// behavior.
// TODO (mferencevic, teon.banek): Remove once MERGE is reimplemented.
maybe_prop = record_accessor.GetProperty(view_, dba_->NameToProperty(name));
}
if (maybe_prop.HasError()) {
switch (maybe_prop.GetError()) {
case storage::v3::Error::DELETED_OBJECT:
throw QueryRuntimeException("Trying to get a property from a deleted object.");
case storage::v3::Error::NONEXISTENT_OBJECT:
throw query::v2::QueryRuntimeException("Trying to get a property from an object that doesn't exist.");
case storage::v3::Error::SERIALIZATION_ERROR:
case storage::v3::Error::VERTEX_HAS_EDGES:
case storage::v3::Error::PROPERTIES_DISABLED:
throw QueryRuntimeException("Unexpected error when getting a property.");
}
}
return *maybe_prop;
}
storage::v3::LabelId GetLabel(LabelIx label) { return ctx_->labels[label.ix]; }
Frame *frame_;
const SymbolTable *symbol_table_;
const EvaluationContext *ctx_;
DbAccessor *dba_;
// which switching approach should be used when evaluating
storage::v3::View view_;
};
/// A helper function for evaluating an expression that's an int.
///
/// @param what - Name of what's getting evaluated. Used for user feedback (via
/// exception) when the evaluated value is not an int.
/// @throw QueryRuntimeException if expression doesn't evaluate to an int.
int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what);
std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale);
} // namespace memgraph::query::v2

View File

@ -0,0 +1,45 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <vector>
#include "query/v2/frontend/semantic/symbol_table.hpp"
#include "query/v2/typed_value.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/vector.hpp"
namespace memgraph::query::v2 {
class Frame {
public:
/// Create a Frame of given size backed by a utils::NewDeleteResource()
explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); }
Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); }
TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; }
const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; }
TypedValue &at(const Symbol &symbol) { return elems_.at(symbol.position()); }
const TypedValue &at(const Symbol &symbol) const { return elems_.at(symbol.position()); }
auto &elems() { return elems_; }
utils::MemoryResource *GetMemoryResource() const { return elems_.get_allocator().GetMemoryResource(); }
private:
utils::pmr::vector<TypedValue> elems_;
};
} // namespace memgraph::query::v2

2411
src/query/v2/interpreter.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,429 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <gflags/gflags.h>
#include "query/v2/auth_checker.hpp"
#include "query/v2/config.hpp"
#include "query/v2/context.hpp"
#include "query/v2/cypher_query_interpreter.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/ast/cypher_main_visitor.hpp"
#include "query/v2/frontend/stripped.hpp"
#include "query/v2/interpret/frame.hpp"
#include "query/v2/metadata.hpp"
#include "query/v2/plan/operator.hpp"
#include "query/v2/plan/read_write_type_checker.hpp"
#include "query/v2/stream.hpp"
#include "query/v2/stream/streams.hpp"
#include "query/v2/trigger.hpp"
#include "query/v2/typed_value.hpp"
#include "storage/v3/isolation_level.hpp"
#include "utils/event_counter.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
#include "utils/settings.hpp"
#include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp"
#include "utils/thread_pool.hpp"
#include "utils/timer.hpp"
#include "utils/tsc.hpp"
namespace EventCounter {
extern const Event FailedQuery;
} // namespace EventCounter
namespace memgraph::query::v2 {
inline constexpr size_t kExecutionMemoryBlockSize = 1UL * 1024UL * 1024UL;
class AuthQueryHandler {
public:
AuthQueryHandler() = default;
virtual ~AuthQueryHandler() = default;
AuthQueryHandler(const AuthQueryHandler &) = delete;
AuthQueryHandler(AuthQueryHandler &&) = delete;
AuthQueryHandler &operator=(const AuthQueryHandler &) = delete;
AuthQueryHandler &operator=(AuthQueryHandler &&) = delete;
/// Return false if the user already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateUser(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the user does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetPassword(const std::string &username, const std::optional<std::string> &password) = 0;
/// Return false if the role already exists.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool CreateRole(const std::string &rolename) = 0;
/// Return false if the role does not exist.
/// @throw QueryRuntimeException if an error ocurred.
virtual bool DropRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetRolenames() = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::optional<std::string> GetRolenameForUser(const std::string &username) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<TypedValue> GetUsernamesForRole(const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetRole(const std::string &username, const std::string &rolename) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void ClearRole(const std::string &username) = 0;
virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DenyPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RevokePrivilege(const std::string &user_or_role,
const std::vector<AuthQuery::Privilege> &privileges) = 0;
};
enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
class ReplicationQueryHandler {
public:
ReplicationQueryHandler() = default;
virtual ~ReplicationQueryHandler() = default;
ReplicationQueryHandler(const ReplicationQueryHandler &) = default;
ReplicationQueryHandler &operator=(const ReplicationQueryHandler &) = default;
ReplicationQueryHandler(ReplicationQueryHandler &&) = default;
ReplicationQueryHandler &operator=(ReplicationQueryHandler &&) = default;
struct Replica {
std::string name;
std::string socket_address;
ReplicationQuery::SyncMode sync_mode;
std::optional<double> timeout;
ReplicationQuery::ReplicaState state;
};
/// @throw QueryRuntimeException if an error ocurred.
virtual void SetReplicationRole(ReplicationQuery::ReplicationRole replication_role, std::optional<int64_t> port) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual ReplicationQuery::ReplicationRole ShowReplicationRole() const = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void RegisterReplica(const std::string &name, const std::string &socket_address,
const ReplicationQuery::SyncMode sync_mode, const std::optional<double> timeout,
const std::chrono::seconds replica_check_frequency) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual void DropReplica(const std::string &replica_name) = 0;
/// @throw QueryRuntimeException if an error ocurred.
virtual std::vector<Replica> ShowReplicas() const = 0;
};
/**
* A container for data related to the preparation of a query.
*/
struct PreparedQuery {
std::vector<std::string> header;
std::vector<AuthQuery::Privilege> privileges;
std::function<std::optional<QueryHandlerResult>(AnyStream *stream, std::optional<int> n)> query_handler;
plan::ReadWriteTypeChecker::RWType rw_type;
};
/**
* Holds data shared between multiple `Interpreter` instances (which might be
* running concurrently).
*
* Users should initialize the context but should not modify it after it has
* been passed to an `Interpreter` instance.
*/
struct InterpreterContext {
explicit InterpreterContext(storage::v3::Storage *db, InterpreterConfig config,
const std::filesystem::path &data_directory);
storage::v3::Storage *db;
std::optional<double> tsc_frequency{utils::GetTSCFrequency()};
std::atomic<bool> is_shutting_down{false};
AuthQueryHandler *auth{nullptr};
AuthChecker *auth_checker{nullptr};
utils::SkipList<QueryCacheEntry> ast_cache;
utils::SkipList<PlanCacheEntry> plan_cache;
TriggerStore trigger_store;
utils::ThreadPool after_commit_trigger_pool{1};
const InterpreterConfig config;
query::v2::stream::Streams streams;
};
/// Function that is used to tell all active interpreters that they should stop
/// their ongoing execution.
inline void Shutdown(InterpreterContext *context) { context->is_shutting_down.store(true, std::memory_order_release); }
class Interpreter final {
public:
explicit Interpreter(InterpreterContext *interpreter_context);
Interpreter(const Interpreter &) = delete;
Interpreter &operator=(const Interpreter &) = delete;
Interpreter(Interpreter &&) = delete;
Interpreter &operator=(Interpreter &&) = delete;
~Interpreter() { Abort(); }
struct PrepareResult {
std::vector<std::string> headers;
std::vector<query::v2::AuthQuery::Privilege> privileges;
std::optional<int> qid;
};
/**
* Prepare a query for execution.
*
* Preparing a query means to preprocess the query and save it for
* future calls of `Pull`.
*
* @throw query::v2::QueryException
*/
PrepareResult Prepare(const std::string &query, const std::map<std::string, storage::v3::PropertyValue> &params,
const std::string *username);
/**
* Execute the last prepared query and stream *all* of the results into the
* given stream.
*
* It is not possible to prepare a query once and execute it multiple times,
* i.e. `Prepare` has to be called before *every* call to `PullAll`.
*
* TStream should be a type implementing the `Stream` concept, i.e. it should
* contain the member function `void Result(const std::vector<TypedValue> &)`.
* The provided vector argument is valid only for the duration of the call to
* `Result`. The stream should make an explicit copy if it wants to use it
* further.
*
* @throw utils::BasicException
* @throw query::v2::QueryException
*/
template <typename TStream>
std::map<std::string, TypedValue> PullAll(TStream *result_stream) {
return Pull(result_stream);
}
/**
* Execute a prepared query and stream result into the given stream.
*
* TStream should be a type implementing the `Stream` concept, i.e. it should
* contain the member function `void Result(const std::vector<TypedValue> &)`.
* The provided vector argument is valid only for the duration of the call to
* `Result`. The stream should make an explicit copy if it wants to use it
* further.
*
* @param n If set, amount of rows to be pulled from result,
* otherwise all the rows are pulled.
* @param qid If set, id of the query from which the result should be pulled,
* otherwise the last query should be used.
*
* @throw utils::BasicException
* @throw query::v2::QueryException
*/
template <typename TStream>
std::map<std::string, TypedValue> Pull(TStream *result_stream, std::optional<int> n = {},
std::optional<int> qid = {});
void BeginTransaction();
void CommitTransaction();
void RollbackTransaction();
void SetNextTransactionIsolationLevel(storage::v3::IsolationLevel isolation_level);
void SetSessionIsolationLevel(storage::v3::IsolationLevel isolation_level);
/**
* Abort the current multicommand transaction.
*/
void Abort();
private:
struct QueryExecution {
std::optional<PreparedQuery> prepared_query;
utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize};
utils::ResourceWithOutOfMemoryException execution_memory_with_exception{&execution_memory};
std::map<std::string, TypedValue> summary;
std::vector<Notification> notifications;
explicit QueryExecution() = default;
QueryExecution(const QueryExecution &) = delete;
QueryExecution(QueryExecution &&) = default;
QueryExecution &operator=(const QueryExecution &) = delete;
QueryExecution &operator=(QueryExecution &&) = default;
~QueryExecution() {
// We should always release the execution memory AFTER we
// destroy the prepared query which is using that instance
// of execution memory.
prepared_query.reset();
execution_memory.Release();
}
};
// Interpreter supports multiple prepared queries at the same time.
// The client can reference a specific query for pull using an arbitrary qid
// which is in our case the index of the query in the vector.
// To simplify the handling of the qid we avoid modifying the vector if it
// affects the position of the currently running queries in any way.
// For example, we cannot delete the prepared query from the vector because
// every prepared query after the deleted one will be moved by one place
// making their qid not equal to the their index inside the vector.
// To avoid this, we use unique_ptr with which we manualy control construction
// and deletion of a single query execution, i.e. when a query finishes,
// we reset the corresponding unique_ptr.
std::vector<std::unique_ptr<QueryExecution>> query_executions_;
InterpreterContext *interpreter_context_;
// This cannot be std::optional because we need to move this accessor later on into a lambda capture
// which is assigned to std::function. std::function requires every object to be copyable, so we
// move this unique_ptr into a shrared_ptr.
std::unique_ptr<storage::v3::Storage::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_;
std::optional<TriggerContextCollector> trigger_context_collector_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
std::optional<storage::v3::IsolationLevel> interpreter_isolation_level;
std::optional<storage::v3::IsolationLevel> next_transaction_isolation_level;
PreparedQuery PrepareTransactionQuery(std::string_view query_upper);
void Commit();
void AdvanceCommand();
void AbortCommand(std::unique_ptr<QueryExecution> *query_execution);
std::optional<storage::v3::IsolationLevel> GetIsolationLevelOverride();
size_t ActiveQueryExecutions() {
return std::count_if(query_executions_.begin(), query_executions_.end(),
[](const auto &execution) { return execution && execution->prepared_query; });
}
};
template <typename TStream>
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream, std::optional<int> n,
std::optional<int> qid) {
MG_ASSERT(in_explicit_transaction_ || !qid, "qid can be only used in explicit transaction!");
const int qid_value = qid ? *qid : static_cast<int>(query_executions_.size() - 1);
if (qid_value < 0 || qid_value >= query_executions_.size()) {
throw InvalidArgumentsException("qid", "Query with specified ID does not exist!");
}
if (n && n < 0) {
throw InvalidArgumentsException("n", "Cannot fetch negative number of results!");
}
auto &query_execution = query_executions_[qid_value];
MG_ASSERT(query_execution && query_execution->prepared_query, "Query already finished executing!");
// Each prepared query has its own summary so we need to somehow preserve
// it after it finishes executing because it gets destroyed alongside
// the prepared query and its execution memory.
std::optional<std::map<std::string, TypedValue>> maybe_summary;
try {
// Wrap the (statically polymorphic) stream type into a common type which
// the handler knows.
AnyStream stream{result_stream, &query_execution->execution_memory};
const auto maybe_res = query_execution->prepared_query->query_handler(&stream, n);
// Stream is using execution memory of the query_execution which
// can be deleted after its execution so the stream should be cleared
// first.
stream.~AnyStream();
// If the query finished executing, we have received a value which tells
// us what to do after.
if (maybe_res) {
// Save its summary
maybe_summary.emplace(std::move(query_execution->summary));
if (!query_execution->notifications.empty()) {
std::vector<TypedValue> notifications;
notifications.reserve(query_execution->notifications.size());
for (const auto &notification : query_execution->notifications) {
notifications.emplace_back(notification.ConvertToMap());
}
maybe_summary->insert_or_assign("notifications", std::move(notifications));
}
if (!in_explicit_transaction_) {
switch (*maybe_res) {
case QueryHandlerResult::COMMIT:
Commit();
break;
case QueryHandlerResult::ABORT:
Abort();
break;
case QueryHandlerResult::NOTHING:
// The only cases in which we have nothing to do are those where
// we're either in an explicit transaction or the query is such that
// a transaction wasn't started on a call to `Prepare()`.
MG_ASSERT(in_explicit_transaction_ || !db_accessor_);
break;
}
// As the transaction is done we can clear all the executions
// NOTE: we cannot clear query_execution inside the Abort and Commit
// methods as we will delete summary contained in them which we need
// after our query finished executing.
query_executions_.clear();
} else {
// We can only clear this execution as some of the queries
// in the transaction can be in unfinished state
query_execution.reset(nullptr);
}
}
} catch (const ExplicitTransactionUsageException &) {
query_execution.reset(nullptr);
throw;
} catch (const utils::BasicException &) {
EventCounter::IncrementCounter(EventCounter::FailedQuery);
AbortCommand(&query_execution);
throw;
}
if (maybe_summary) {
// return the execution summary
maybe_summary->insert_or_assign("has_more", false);
return std::move(*maybe_summary);
}
// don't return the execution summary as it's not finished
return {{"has_more", TypedValue(true)}};
}
} // namespace memgraph::query::v2

117
src/query/v2/metadata.cpp Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/metadata.hpp"
#include <algorithm>
#include <compare>
#include <string>
#include <string_view>
namespace memgraph::query::v2 {
namespace {
using namespace std::literals;
constexpr std::string_view GetSeverityLevelString(const SeverityLevel level) {
switch (level) {
case SeverityLevel::INFO:
return "INFO"sv;
case SeverityLevel::WARNING:
return "WARNING"sv;
}
}
constexpr std::string_view GetCodeString(const NotificationCode code) {
switch (code) {
case NotificationCode::CREATE_CONSTRAINT:
return "CreateConstraint"sv;
case NotificationCode::CREATE_INDEX:
return "CreateIndex"sv;
case NotificationCode::CREATE_STREAM:
return "CreateStream"sv;
case NotificationCode::CHECK_STREAM:
return "CheckStream"sv;
case NotificationCode::CREATE_TRIGGER:
return "CreateTrigger"sv;
case NotificationCode::DROP_CONSTRAINT:
return "DropConstraint"sv;
case NotificationCode::DROP_REPLICA:
return "DropReplica"sv;
case NotificationCode::DROP_INDEX:
return "DropIndex"sv;
case NotificationCode::DROP_STREAM:
return "DropStream"sv;
case NotificationCode::DROP_TRIGGER:
return "DropTrigger"sv;
case NotificationCode::EXISTANT_CONSTRAINT:
return "ConstraintAlreadyExists"sv;
case NotificationCode::EXISTANT_INDEX:
return "IndexAlreadyExists"sv;
case NotificationCode::LOAD_CSV_TIP:
return "LoadCSVTip"sv;
case NotificationCode::NONEXISTANT_INDEX:
return "IndexDoesNotExist"sv;
case NotificationCode::NONEXISTANT_CONSTRAINT:
return "ConstraintDoesNotExist"sv;
case NotificationCode::REGISTER_REPLICA:
return "RegisterReplica"sv;
case NotificationCode::REPLICA_PORT_WARNING:
return "ReplicaPortWarning"sv;
case NotificationCode::SET_REPLICA:
return "SetReplica"sv;
case NotificationCode::START_STREAM:
return "StartStream"sv;
case NotificationCode::START_ALL_STREAMS:
return "StartAllStreams"sv;
case NotificationCode::STOP_STREAM:
return "StopStream"sv;
case NotificationCode::STOP_ALL_STREAMS:
return "StopAllStreams"sv;
}
}
} // namespace
Notification::Notification(SeverityLevel level) : level{level} {};
Notification::Notification(SeverityLevel level, NotificationCode code, std::string title, std::string description)
: level{level}, code{code}, title(std::move(title)), description(std::move(description)){};
Notification::Notification(SeverityLevel level, NotificationCode code, std::string title)
: level{level}, code{code}, title(std::move(title)){};
std::map<std::string, TypedValue> Notification::ConvertToMap() const {
return std::map<std::string, TypedValue>{{"severity", TypedValue(GetSeverityLevelString(level))},
{"code", TypedValue(GetCodeString(code))},
{"title", TypedValue(title)},
{"description", TypedValue(description)}};
}
std::string ExecutionStatsKeyToString(const ExecutionStats::Key key) {
switch (key) {
case ExecutionStats::Key::CREATED_NODES:
return std::string("nodes-created");
case ExecutionStats::Key::DELETED_NODES:
return std::string("nodes-deleted");
case ExecutionStats::Key::CREATED_EDGES:
return std::string("relationships-created");
case ExecutionStats::Key::DELETED_EDGES:
return std::string("relationships-deleted");
case ExecutionStats::Key::CREATED_LABELS:
return std::string("labels-added");
case ExecutionStats::Key::DELETED_LABELS:
return std::string("labels-removed");
case ExecutionStats::Key::UPDATED_PROPERTIES:
return std::string("properties-set");
}
}
} // namespace memgraph::query::v2

90
src/query/v2/metadata.hpp Normal file
View File

@ -0,0 +1,90 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <map>
#include <string>
#include <string_view>
#include <type_traits>
#include "query/v2/typed_value.hpp"
namespace memgraph::query::v2 {
enum class SeverityLevel : uint8_t { INFO, WARNING };
enum class NotificationCode : uint8_t {
CREATE_CONSTRAINT,
CREATE_INDEX,
CHECK_STREAM,
CREATE_STREAM,
CREATE_TRIGGER,
DROP_CONSTRAINT,
DROP_INDEX,
DROP_REPLICA,
DROP_STREAM,
DROP_TRIGGER,
EXISTANT_INDEX,
EXISTANT_CONSTRAINT,
LOAD_CSV_TIP,
NONEXISTANT_INDEX,
NONEXISTANT_CONSTRAINT,
REPLICA_PORT_WARNING,
REGISTER_REPLICA,
SET_REPLICA,
START_STREAM,
START_ALL_STREAMS,
STOP_STREAM,
STOP_ALL_STREAMS,
};
struct Notification {
SeverityLevel level;
NotificationCode code;
std::string title;
std::string description;
explicit Notification(SeverityLevel level);
Notification(SeverityLevel level, NotificationCode code, std::string title, std::string description);
Notification(SeverityLevel level, NotificationCode code, std::string title);
std::map<std::string, TypedValue> ConvertToMap() const;
};
struct ExecutionStats {
public:
// All the stats have specific key to be compatible with neo4j
enum class Key : uint8_t {
CREATED_NODES,
DELETED_NODES,
CREATED_EDGES,
DELETED_EDGES,
CREATED_LABELS,
DELETED_LABELS,
UPDATED_PROPERTIES,
};
int64_t &operator[](Key key) { return counters[static_cast<size_t>(key)]; }
private:
static constexpr auto kExecutionStatsCountersSize = std::underlying_type_t<Key>(Key::UPDATED_PROPERTIES) + 1;
public:
std::array<int64_t, kExecutionStatsCountersSize> counters{0};
};
std::string ExecutionStatsKeyToString(ExecutionStats::Key key);
} // namespace memgraph::query::v2

View File

@ -0,0 +1,71 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "storage/v3/property_value.hpp"
#include "utils/logging.hpp"
/**
* Encapsulates user provided parameters (and stripped literals)
* and provides ways of obtaining them by position.
*/
namespace memgraph::query::v2 {
struct Parameters {
public:
/**
* Adds a value to the stripped arguments under a token position.
*
* @param position Token position in query of value.
* @param value
*/
void Add(int position, const storage::v3::PropertyValue &value) { storage_.emplace_back(position, value); }
/**
* Returns the value found for the given token position.
*
* @param position Token position in query of value.
* @return Value for the given token position.
*/
const storage::v3::PropertyValue &AtTokenPosition(int position) const {
auto found = std::find_if(storage_.begin(), storage_.end(), [&](const auto &a) { return a.first == position; });
MG_ASSERT(found != storage_.end(), "Token position must be present in container");
return found->second;
}
/**
* Returns the position-th stripped value. Asserts that this
* container has at least (position + 1) elements.
*
* @param position Which stripped param is sought.
* @return Token position and value for sought param.
*/
const std::pair<int, storage::v3::PropertyValue> &At(int position) const {
MG_ASSERT(position < static_cast<int>(storage_.size()), "Invalid position");
return storage_[position];
}
/** Returns the number of arguments in this container */
auto size() const { return storage_.size(); }
auto begin() const { return storage_.begin(); }
auto end() const { return storage_.end(); }
private:
std::vector<std::pair<int, storage::v3::PropertyValue>> storage_;
};
} // namespace memgraph::query::v2

146
src/query/v2/path.hpp Normal file
View File

@ -0,0 +1,146 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <functional>
#include <utility>
#include "query/v2/db_accessor.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/vector.hpp"
namespace memgraph::query::v2 {
/**
* A data structure that holds a graph path. A path consists of at least one
* vertex, followed by zero or more edge + vertex extensions (thus having one
* vertex more then edges).
*/
class Path {
public:
/** Allocator type so that STL containers are aware that we need one */
using allocator_type = utils::Allocator<char>;
/**
* Create the path starting with the given vertex.
* Allocations are done using the given MemoryResource.
*/
explicit Path(const VertexAccessor &vertex, utils::MemoryResource *memory = utils::NewDeleteResource())
: vertices_(memory), edges_(memory) {
Expand(vertex);
}
/**
* Create the path starting with the given vertex and containing all other
* elements.
* Allocations are done using the default utils::NewDeleteResource().
*/
template <typename... TOthers>
explicit Path(const VertexAccessor &vertex, const TOthers &...others)
: vertices_(utils::NewDeleteResource()), edges_(utils::NewDeleteResource()) {
Expand(vertex);
Expand(others...);
}
/**
* Create the path starting with the given vertex and containing all other
* elements.
* Allocations are done using the given MemoryResource.
*/
template <typename... TOthers>
Path(std::allocator_arg_t, utils::MemoryResource *memory, const VertexAccessor &vertex, const TOthers &...others)
: vertices_(memory), edges_(memory) {
Expand(vertex);
Expand(others...);
}
/**
* Construct a copy of other.
* utils::MemoryResource is obtained by calling
* std::allocator_traits<>::
* select_on_container_copy_construction(other.GetMemoryResource()).
* Since we use utils::Allocator, which does not propagate, this means that we
* will default to utils::NewDeleteResource().
*/
Path(const Path &other)
: Path(other,
std::allocator_traits<allocator_type>::select_on_container_copy_construction(other.GetMemoryResource())
.GetMemoryResource()) {}
/** Construct a copy using the given utils::MemoryResource */
Path(const Path &other, utils::MemoryResource *memory)
: vertices_(other.vertices_, memory), edges_(other.edges_, memory) {}
/**
* Construct with the value of other.
* utils::MemoryResource is obtained from other. After the move, other will be
* empty.
*/
Path(Path &&other) noexcept : Path(std::move(other), other.GetMemoryResource()) {}
/**
* Construct with the value of other, but use the given utils::MemoryResource.
* After the move, other may not be empty if `*memory !=
* *other.GetMemoryResource()`, because an element-wise move will be
* performed.
*/
Path(Path &&other, utils::MemoryResource *memory)
: vertices_(std::move(other.vertices_), memory), edges_(std::move(other.edges_), memory) {}
/** Copy assign other, utils::MemoryResource of `this` is used */
Path &operator=(const Path &) = default;
/** Move assign other, utils::MemoryResource of `this` is used. */
Path &operator=(Path &&) = default;
~Path() = default;
/** Expands the path with the given vertex. */
void Expand(const VertexAccessor &vertex) {
DMG_ASSERT(vertices_.size() == edges_.size(), "Illegal path construction order");
vertices_.emplace_back(vertex);
}
/** Expands the path with the given edge. */
void Expand(const EdgeAccessor &edge) {
DMG_ASSERT(vertices_.size() - 1 == edges_.size(), "Illegal path construction order");
edges_.emplace_back(edge);
}
/** Expands the path with the given elements. */
template <typename TFirst, typename... TOthers>
void Expand(const TFirst &first, const TOthers &...others) {
Expand(first);
Expand(others...);
}
/** Returns the number of expansions (edges) in this path. */
auto size() const { return edges_.size(); }
auto &vertices() { return vertices_; }
auto &edges() { return edges_; }
const auto &vertices() const { return vertices_; }
const auto &edges() const { return edges_; }
utils::MemoryResource *GetMemoryResource() const { return vertices_.get_allocator().GetMemoryResource(); }
bool operator==(const Path &other) const { return vertices_ == other.vertices_ && edges_ == other.edges_; }
private:
// Contains all the vertices in the path.
utils::pmr::vector<VertexAccessor> vertices_;
// Contains all the edges in the path (one less then there are vertices).
utils::pmr::vector<EdgeAccessor> edges_;
};
} // namespace memgraph::query::v2

View File

@ -0,0 +1,267 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/parameters.hpp"
#include "query/v2/plan/operator.hpp"
#include "query/v2/typed_value.hpp"
namespace memgraph::query::v2::plan {
/**
* Query plan execution time cost estimator, for comparing and choosing optimal
* execution plans.
*
* In Cypher the write part of the query always executes in the same
* cardinality. It is not allowed to execute a write operation before all the
* expansion for that query part (WITH splits a query into parts) have executed.
* For that reason cost estimation comes down to cardinality estimation for the
* read parts of the query, and their expansion. We want to compare different
* plans and try to figure out which has the optimal organization of scans,
* expansions and filters.
*
* Note that expansions and filtering can also happen during Merge, which is a
* write operation. We let that get evaluated like all other cardinality
* influencing ops. Also, Merge cardinality modification should be contained (it
* can never reduce it's input cardinality), but since Merge always happens
* after the read part, and can't be reoredered, we can ignore that.
*
* Limiting and accumulating (Aggregate, OrderBy, Accumulate) operations are
* cardinality modifiers that always execute at the end of the query part. Their
* cardinality influence is irrelevant because they execute the same
* for all plans for a single query part, and query part reordering is not
* allowed.
*
* This kind of cost estimation can only be used for comparing logical plans.
* It's aim is to estimate cost(A) to be less then cost(B) in every case where
* actual query execution for plan A is less then that of plan B. It can NOT be
* used to estimate how MUCH execution between A and B will differ.
*/
template <class TDbAccessor>
class CostEstimator : public HierarchicalLogicalOperatorVisitor {
public:
struct CostParam {
static constexpr double kScanAll{1.0};
static constexpr double kScanAllByLabel{1.1};
static constexpr double MakeScanAllByLabelPropertyValue{1.1};
static constexpr double MakeScanAllByLabelPropertyRange{1.1};
static constexpr double MakeScanAllByLabelProperty{1.1};
static constexpr double kExpand{2.0};
static constexpr double kExpandVariable{3.0};
static constexpr double kFilter{1.5};
static constexpr double kEdgeUniquenessFilter{1.5};
static constexpr double kUnwind{1.3};
static constexpr double kForeach{1.0};
};
struct CardParam {
static constexpr double kExpand{3.0};
static constexpr double kExpandVariable{9.0};
static constexpr double kFilter{0.25};
static constexpr double kEdgeUniquenessFilter{0.95};
};
struct MiscParam {
static constexpr double kUnwindNoLiteral{10.0};
static constexpr double kForeachNoLiteral{10.0};
};
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
CostEstimator(TDbAccessor *db_accessor, const Parameters &parameters)
: db_accessor_(db_accessor), parameters(parameters) {}
bool PostVisit(ScanAll &) override {
cardinality_ *= db_accessor_->VerticesCount();
// ScanAll performs some work for every element that is produced
IncrementCost(CostParam::kScanAll);
return true;
}
bool PostVisit(ScanAllByLabel &scan_all_by_label) override {
cardinality_ *= db_accessor_->VerticesCount(scan_all_by_label.label_);
// ScanAll performs some work for every element that is produced
IncrementCost(CostParam::kScanAllByLabel);
return true;
}
bool PostVisit(ScanAllByLabelPropertyValue &logical_op) override {
// This cardinality estimation depends on the property value (expression).
// If it's a constant, we can evaluate cardinality exactly, otherwise
// we estimate
auto property_value = ConstPropertyValue(logical_op.expression_);
double factor = 1.0;
if (property_value)
// get the exact influence based on ScanAll(label, property, value)
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, property_value.value());
else
// estimate the influence as ScanAll(label, property) * filtering
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_) * CardParam::kFilter;
cardinality_ *= factor;
// ScanAll performs some work for every element that is produced
IncrementCost(CostParam::MakeScanAllByLabelPropertyValue);
return true;
}
bool PostVisit(ScanAllByLabelPropertyRange &logical_op) override {
// this cardinality estimation depends on Bound expressions.
// if they are literals we can evaluate cardinality properly
auto lower = BoundToPropertyValue(logical_op.lower_bound_);
auto upper = BoundToPropertyValue(logical_op.upper_bound_);
int64_t factor = 1;
if (upper || lower)
// if we have either Bound<PropertyValue>, use the value index
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_, lower, upper);
else
// no values, but we still have the label
factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
// if we failed to take either bound from the op into account, then apply
// the filtering constant to the factor
if ((logical_op.upper_bound_ && !upper) || (logical_op.lower_bound_ && !lower)) factor *= CardParam::kFilter;
cardinality_ *= factor;
// ScanAll performs some work for every element that is produced
IncrementCost(CostParam::MakeScanAllByLabelPropertyRange);
return true;
}
bool PostVisit(ScanAllByLabelProperty &logical_op) override {
const auto factor = db_accessor_->VerticesCount(logical_op.label_, logical_op.property_);
cardinality_ *= factor;
IncrementCost(CostParam::MakeScanAllByLabelProperty);
return true;
}
// TODO: Cost estimate ScanAllById?
// For the given op first increments the cardinality and then cost.
#define POST_VISIT_CARD_FIRST(NAME) \
bool PostVisit(NAME &) override { \
cardinality_ *= CardParam::k##NAME; \
IncrementCost(CostParam::k##NAME); \
return true; \
}
POST_VISIT_CARD_FIRST(Expand);
POST_VISIT_CARD_FIRST(ExpandVariable);
#undef POST_VISIT_CARD_FIRST
// For the given op first increments the cost and then cardinality.
#define POST_VISIT_COST_FIRST(LOGICAL_OP, PARAM_NAME) \
bool PostVisit(LOGICAL_OP &) override { \
IncrementCost(CostParam::PARAM_NAME); \
cardinality_ *= CardParam::PARAM_NAME; \
return true; \
}
POST_VISIT_COST_FIRST(Filter, kFilter)
POST_VISIT_COST_FIRST(EdgeUniquenessFilter, kEdgeUniquenessFilter);
#undef POST_VISIT_COST_FIRST
bool PostVisit(Unwind &unwind) override {
// Unwind cost depends more on the number of lists that get unwound
// much less on the number of outputs
// for that reason first increment cost, then modify cardinality
IncrementCost(CostParam::kUnwind);
// try to determine how many values will be yielded by Unwind
// if the Unwind expression is a list literal, we can deduce cardinality
// exactly, otherwise we approximate
double unwind_value;
if (auto *literal = utils::Downcast<query::v2::ListLiteral>(unwind.input_expression_))
unwind_value = literal->elements_.size();
else
unwind_value = MiscParam::kUnwindNoLiteral;
cardinality_ *= unwind_value;
return true;
}
bool PostVisit(Foreach &foreach) override {
// Foreach cost depends both on the number elements in the list that get unwound
// as well as the total clauses that get called for each unwounded element.
// First estimate cardinality and then increment the cost.
double foreach_elements{0};
if (auto *literal = utils::Downcast<query::v2::ListLiteral>(foreach.expression_)) {
foreach_elements = literal->elements_.size();
} else {
foreach_elements = MiscParam::kForeachNoLiteral;
}
cardinality_ *= foreach_elements;
IncrementCost(CostParam::kForeach);
return true;
}
bool Visit(Once &) override { return true; }
auto cost() const { return cost_; }
auto cardinality() const { return cardinality_; }
private:
// cost estimation that gets accumulated as the visitor
// tours the logical plan
double cost_{0};
// cardinality estimation (how many times an operator gets executed)
// cardinality is a double to make it easier to work with
double cardinality_{1};
// accessor used for cardinality estimates in ScanAll and ScanAllByLabel
TDbAccessor *db_accessor_;
const Parameters &parameters;
void IncrementCost(double param) { cost_ += param * cardinality_; }
// converts an optional ScanAll range bound into a property value
// if the bound is present and is a constant expression convertible to
// a property value. otherwise returns nullopt
std::optional<utils::Bound<storage::v3::PropertyValue>> BoundToPropertyValue(
std::optional<ScanAllByLabelPropertyRange::Bound> bound) {
if (bound) {
auto property_value = ConstPropertyValue(bound->value());
if (property_value) return utils::Bound<storage::v3::PropertyValue>(*property_value, bound->type());
}
return std::nullopt;
}
// If the expression is a constant property value, it is returned. Otherwise,
// return nullopt.
std::optional<storage::v3::PropertyValue> ConstPropertyValue(const Expression *expression) {
if (auto *literal = utils::Downcast<const PrimitiveLiteral>(expression)) {
return literal->value_;
} else if (auto *param_lookup = utils::Downcast<const ParameterLookup>(expression)) {
return parameters.AtTokenPosition(param_lookup->token_position_);
}
return std::nullopt;
}
};
/** Returns the estimated cost of the given plan. */
template <class TDbAccessor>
double EstimatePlanCost(TDbAccessor *db, const Parameters &parameters, LogicalOperator &plan) {
CostEstimator<TDbAccessor> estimator(db, parameters);
plan.Accept(estimator);
return estimator.cost();
}
} // namespace memgraph::query::v2::plan

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,158 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// This file is an entry point for invoking various planners via the following
/// API:
/// * `MakeLogicalPlanForSingleQuery`
/// * `MakeLogicalPlan`
#pragma once
#include "query/v2/plan/cost_estimator.hpp"
#include "query/v2/plan/operator.hpp"
#include "query/v2/plan/preprocess.hpp"
#include "query/v2/plan/pretty_print.hpp"
#include "query/v2/plan/rewrite/index_lookup.hpp"
#include "query/v2/plan/rule_based_planner.hpp"
#include "query/v2/plan/variable_start_planner.hpp"
#include "query/v2/plan/vertex_count_cache.hpp"
namespace memgraph::query::v2 {
class AstStorage;
class SymbolTable;
namespace plan {
class PostProcessor final {
Parameters parameters_;
public:
using ProcessedPlan = std::unique_ptr<LogicalOperator>;
explicit PostProcessor(const Parameters &parameters) : parameters_(parameters) {}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> Rewrite(std::unique_ptr<LogicalOperator> plan, TPlanningContext *context) {
return RewriteWithIndexLookup(std::move(plan), context->symbol_table, context->ast_storage, context->db);
}
template <class TVertexCounts>
double EstimatePlanCost(const std::unique_ptr<LogicalOperator> &plan, TVertexCounts *vertex_counts) {
return query::v2::plan::EstimatePlanCost(vertex_counts, parameters_, *plan);
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MergeWithCombinator(std::unique_ptr<LogicalOperator> curr_op,
std::unique_ptr<LogicalOperator> last_op, const Tree &combinator,
TPlanningContext *context) {
if (const auto *union_ = utils::Downcast<const CypherUnion>(&combinator)) {
return std::unique_ptr<LogicalOperator>(
impl::GenUnion(*union_, std::move(last_op), std::move(curr_op), *context->symbol_table));
}
throw utils::NotYetImplemented("query combinator");
}
template <class TPlanningContext>
std::unique_ptr<LogicalOperator> MakeDistinct(std::unique_ptr<LogicalOperator> last_op, TPlanningContext *context) {
auto output_symbols = last_op->OutputSymbols(*context->symbol_table);
return std::make_unique<Distinct>(std::move(last_op), output_symbols);
}
};
/// @brief Generates the LogicalOperator tree for a single query and returns the
/// resulting plan.
///
/// @tparam TPlanner Type of the planner used for generation.
/// @tparam TDbAccessor Type of the database accessor used for generation.
/// @param vector of @c SingleQueryPart from the single query
/// @param context PlanningContext used for generating plans.
/// @return @c PlanResult which depends on the @c TPlanner used.
///
/// @sa PlanningContext
/// @sa RuleBasedPlanner
/// @sa VariableStartPlanner
template <template <class> class TPlanner, class TDbAccessor>
auto MakeLogicalPlanForSingleQuery(std::vector<SingleQueryPart> single_query_parts,
PlanningContext<TDbAccessor> *context) {
context->bound_symbols.clear();
return TPlanner<PlanningContext<TDbAccessor>>(context).Plan(single_query_parts);
}
/// Generates the LogicalOperator tree and returns the resulting plan.
///
/// @tparam TPlanningContext Type of the context used.
/// @tparam TPlanPostProcess Type of the plan post processor used.
///
/// @param context PlanningContext used for generating plans.
/// @param post_process performs plan rewrites and cost estimation.
/// @param use_variable_planner boolean flag to choose which planner to use.
///
/// @return pair consisting of the final `TPlanPostProcess::ProcessedPlan` and
/// the estimated cost of that plan as a `double`.
template <class TPlanningContext, class TPlanPostProcess>
auto MakeLogicalPlan(TPlanningContext *context, TPlanPostProcess *post_process, bool use_variable_planner) {
auto query_parts = CollectQueryParts(*context->symbol_table, *context->ast_storage, context->query);
auto &vertex_counts = *context->db;
double total_cost = 0;
using ProcessedPlan = typename TPlanPostProcess::ProcessedPlan;
ProcessedPlan last_plan;
for (const auto &query_part : query_parts.query_parts) {
std::optional<ProcessedPlan> curr_plan;
double min_cost = std::numeric_limits<double>::max();
if (use_variable_planner) {
auto plans = MakeLogicalPlanForSingleQuery<VariableStartPlanner>(query_part.single_query_parts, context);
for (auto plan : plans) {
// Plans are generated lazily and the current plan will disappear, so
// it's ok to move it.
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
double cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
if (!curr_plan || cost < min_cost) {
curr_plan.emplace(std::move(rewritten_plan));
min_cost = cost;
}
}
} else {
auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(query_part.single_query_parts, context);
auto rewritten_plan = post_process->Rewrite(std::move(plan), context);
min_cost = post_process->EstimatePlanCost(rewritten_plan, &vertex_counts);
curr_plan.emplace(std::move(rewritten_plan));
}
total_cost += min_cost;
if (query_part.query_combinator) {
last_plan = post_process->MergeWithCombinator(std::move(*curr_plan), std::move(last_plan),
*query_part.query_combinator, context);
} else {
last_plan = std::move(*curr_plan);
}
}
if (query_parts.distinct) {
last_plan = post_process->MakeDistinct(std::move(last_plan), context);
}
return std::make_pair(std::move(last_plan), total_cost);
}
template <class TPlanningContext>
auto MakeLogicalPlan(TPlanningContext *context, const Parameters &parameters, bool use_variable_planner) {
PostProcessor post_processor(parameters);
return MakeLogicalPlan(context, &post_processor, use_variable_planner);
}
} // namespace plan
} // namespace memgraph::query::v2

View File

@ -0,0 +1,599 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <algorithm>
#include <functional>
#include <stack>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include "query/v2/exceptions.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/ast/ast_visitor.hpp"
#include "query/v2/plan/preprocess.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph::query::v2::plan {
namespace {
void ForEachPattern(Pattern &pattern, std::function<void(NodeAtom *)> base,
std::function<void(NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
DMG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern.atoms_.begin();
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
DMG_ASSERT(current_node, "First pattern atom is not a node");
base(current_node);
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
while (atoms_it != pattern.atoms_.end()) {
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
DMG_ASSERT(edge, "Expected an edge atom in pattern.");
DMG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
auto prev_node = current_node;
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
DMG_ASSERT(current_node, "Expected a node atom in pattern.");
collect(prev_node, edge, current_node);
}
}
// Converts multiple Patterns to Expansions. Each Pattern can contain an
// arbitrarily long chain of nodes and edges. The conversion to an Expansion is
// done by splitting a pattern into triplets (node1, edge, node2). The triplets
// conserve the semantics of the pattern. For example, in a pattern:
// (m) -[e]- (n) -[f]- (o) the same can be achieved with:
// (m) -[e]- (n), (n) -[f]- (o).
// This representation makes it easier to permute from which node or edge we
// want to start expanding.
std::vector<Expansion> NormalizePatterns(const SymbolTable &symbol_table, const std::vector<Pattern *> &patterns) {
std::vector<Expansion> expansions;
auto ignore_node = [&](auto *) {};
auto collect_expansion = [&](auto *prev_node, auto *edge, auto *current_node) {
UsedSymbolsCollector collector(symbol_table);
if (edge->IsVariable()) {
if (edge->lower_bound_) edge->lower_bound_->Accept(collector);
if (edge->upper_bound_) edge->upper_bound_->Accept(collector);
if (edge->filter_lambda_.expression) edge->filter_lambda_.expression->Accept(collector);
// Remove symbols which are bound by lambda arguments.
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_edge));
collector.symbols_.erase(symbol_table.at(*edge->filter_lambda_.inner_node));
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_edge));
collector.symbols_.erase(symbol_table.at(*edge->weight_lambda_.inner_node));
}
}
expansions.emplace_back(Expansion{prev_node, edge, edge->direction_, false, collector.symbols_, current_node});
};
for (const auto &pattern : patterns) {
if (pattern->atoms_.size() == 1U) {
auto *node = utils::Downcast<NodeAtom>(pattern->atoms_[0]);
DMG_ASSERT(node, "First pattern atom is not a node");
expansions.emplace_back(Expansion{node});
} else {
ForEachPattern(*pattern, ignore_node, collect_expansion);
}
}
return expansions;
}
// Fills the given Matching, by converting the Match patterns to normalized
// representation as Expansions. Filters used in the Match are also collected,
// as well as edge symbols which determine Cyphermorphism. Collecting filters
// will lift them out of a pattern and generate new expressions (just like they
// were in a Where clause).
void AddMatching(const std::vector<Pattern *> &patterns, Where *where, SymbolTable &symbol_table, AstStorage &storage,
Matching &matching) {
auto expansions = NormalizePatterns(symbol_table, patterns);
std::unordered_set<Symbol> edge_symbols;
for (const auto &expansion : expansions) {
// Matching may already have some expansions, so offset our index.
const size_t expansion_ix = matching.expansions.size();
// Map node1 symbol to expansion
const auto &node1_sym = symbol_table.at(*expansion.node1->identifier_);
matching.node_symbol_to_expansions[node1_sym].insert(expansion_ix);
// Add node1 to all symbols.
matching.expansion_symbols.insert(node1_sym);
if (expansion.edge) {
const auto &edge_sym = symbol_table.at(*expansion.edge->identifier_);
// Fill edge symbols for Cyphermorphism.
edge_symbols.insert(edge_sym);
// Map node2 symbol to expansion
const auto &node2_sym = symbol_table.at(*expansion.node2->identifier_);
matching.node_symbol_to_expansions[node2_sym].insert(expansion_ix);
// Add edge and node2 to all symbols
matching.expansion_symbols.insert(edge_sym);
matching.expansion_symbols.insert(node2_sym);
}
matching.expansions.push_back(expansion);
}
if (!edge_symbols.empty()) {
matching.edge_symbols.emplace_back(edge_symbols);
}
for (auto *pattern : patterns) {
matching.filters.CollectPatternFilters(*pattern, symbol_table, storage);
if (pattern->identifier_->user_declared_) {
std::vector<Symbol> path_elements;
for (auto *pattern_atom : pattern->atoms_)
path_elements.emplace_back(symbol_table.at(*pattern_atom->identifier_));
matching.named_paths.emplace(symbol_table.at(*pattern->identifier_), std::move(path_elements));
}
}
if (where) {
matching.filters.CollectWhereFilter(*where, symbol_table);
}
}
void AddMatching(const Match &match, SymbolTable &symbol_table, AstStorage &storage, Matching &matching) {
return AddMatching(match.patterns_, match.where_, symbol_table, storage, matching);
}
auto SplitExpressionOnAnd(Expression *expression) {
// TODO: Think about converting all filtering expression into CNF to improve
// the granularity of filters which can be stand alone.
std::vector<Expression *> expressions;
std::stack<Expression *> pending_expressions;
pending_expressions.push(expression);
while (!pending_expressions.empty()) {
auto *current_expression = pending_expressions.top();
pending_expressions.pop();
if (auto *and_op = utils::Downcast<AndOperator>(current_expression)) {
pending_expressions.push(and_op->expression1_);
pending_expressions.push(and_op->expression2_);
} else {
expressions.push_back(current_expression);
}
}
return expressions;
}
} // namespace
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
Expression *value, Type type)
: symbol_(symbol), property_(property), type_(type), value_(value) {
MG_ASSERT(type != Type::RANGE);
UsedSymbolsCollector collector(symbol_table);
value->Accept(collector);
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
}
PropertyFilter::PropertyFilter(const SymbolTable &symbol_table, const Symbol &symbol, PropertyIx property,
const std::optional<PropertyFilter::Bound> &lower_bound,
const std::optional<PropertyFilter::Bound> &upper_bound)
: symbol_(symbol), property_(property), type_(Type::RANGE), lower_bound_(lower_bound), upper_bound_(upper_bound) {
UsedSymbolsCollector collector(symbol_table);
if (lower_bound) {
lower_bound->value()->Accept(collector);
}
if (upper_bound) {
upper_bound->value()->Accept(collector);
}
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
}
PropertyFilter::PropertyFilter(const Symbol &symbol, PropertyIx property, Type type)
: symbol_(symbol), property_(property), type_(type) {
// As this constructor is used for property filters where
// we don't have to evaluate the filter expression, we set
// the is_symbol_in_value_ to false, although the filter
// expression may actually contain the symbol whose property
// we may be looking up.
}
IdFilter::IdFilter(const SymbolTable &symbol_table, const Symbol &symbol, Expression *value)
: symbol_(symbol), value_(value) {
MG_ASSERT(value);
UsedSymbolsCollector collector(symbol_table);
value->Accept(collector);
is_symbol_in_value_ = utils::Contains(collector.symbols_, symbol);
}
void Filters::EraseFilter(const FilterInfo &filter) {
// TODO: Ideally, we want to determine the equality of both expression trees,
// instead of a simple pointer compare.
all_filters_.erase(std::remove_if(all_filters_.begin(), all_filters_.end(),
[&filter](const auto &f) { return f.expression == filter.expression; }),
all_filters_.end());
}
void Filters::EraseLabelFilter(const Symbol &symbol, LabelIx label, std::vector<Expression *> *removed_filters) {
for (auto filter_it = all_filters_.begin(); filter_it != all_filters_.end();) {
if (filter_it->type != FilterInfo::Type::Label) {
++filter_it;
continue;
}
if (!utils::Contains(filter_it->used_symbols, symbol)) {
++filter_it;
continue;
}
auto label_it = std::find(filter_it->labels.begin(), filter_it->labels.end(), label);
if (label_it == filter_it->labels.end()) {
++filter_it;
continue;
}
filter_it->labels.erase(label_it);
DMG_ASSERT(!utils::Contains(filter_it->labels, label), "Didn't expect duplicated labels");
if (filter_it->labels.empty()) {
// If there are no labels to filter, then erase the whole FilterInfo.
if (removed_filters) {
removed_filters->push_back(filter_it->expression);
}
filter_it = all_filters_.erase(filter_it);
} else {
++filter_it;
}
}
}
void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, AstStorage &storage) {
UsedSymbolsCollector collector(symbol_table);
auto add_properties_variable = [&](EdgeAtom *atom) {
const auto &symbol = symbol_table.at(*atom->identifier_);
if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&atom->properties_)) {
for (auto &prop_pair : *properties) {
// We need to store two property-lookup filters in all_filters. One is
// used for inlining property filters into variable expansion, and
// utilizes the inner_edge symbol. The other is used for post-expansion
// filtering and does not use the inner_edge symbol, but the edge symbol
// (a list of edges).
{
collector.symbols_.clear();
prop_pair.second->Accept(collector);
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_node));
collector.symbols_.emplace(symbol_table.at(*atom->filter_lambda_.inner_edge));
// First handle the inline property filter.
auto *property_lookup = storage.Create<PropertyLookup>(atom->filter_lambda_.inner_edge, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
// Currently, variable expand has no gains if we set PropertyFilter.
all_filters_.emplace_back(FilterInfo{FilterInfo::Type::Generic, prop_equal, collector.symbols_});
}
{
collector.symbols_.clear();
prop_pair.second->Accept(collector);
collector.symbols_.insert(symbol); // PropertyLookup uses the symbol.
// Now handle the post-expansion filter.
// Create a new identifier and a symbol which will be filled in All.
auto *identifier = storage.Create<Identifier>(atom->identifier_->name_, atom->identifier_->user_declared_)
->MapTo(symbol_table.CreateSymbol(atom->identifier_->name_, false));
// Create an equality expression and store it in all_filters_.
auto *property_lookup = storage.Create<PropertyLookup>(identifier, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
// Currently, variable expand has no gains if we set PropertyFilter.
all_filters_.emplace_back(
FilterInfo{FilterInfo::Type::Generic,
storage.Create<All>(identifier, atom->identifier_, storage.Create<Where>(prop_equal)),
collector.symbols_});
}
}
return;
}
throw SemanticException("Property map matching not supported in MATCH/MERGE clause!");
};
auto add_properties = [&](auto *atom) {
const auto &symbol = symbol_table.at(*atom->identifier_);
if (auto *properties = std::get_if<std::unordered_map<PropertyIx, Expression *>>(&atom->properties_)) {
for (auto &prop_pair : *properties) {
// Create an equality expression and store it in all_filters_.
auto *property_lookup = storage.Create<PropertyLookup>(atom->identifier_, prop_pair.first);
auto *prop_equal = storage.Create<EqualOperator>(property_lookup, prop_pair.second);
collector.symbols_.clear();
prop_equal->Accept(collector);
FilterInfo filter_info{FilterInfo::Type::Property, prop_equal, collector.symbols_};
// Store a PropertyFilter on the value of the property.
filter_info.property_filter.emplace(symbol_table, symbol, prop_pair.first, prop_pair.second,
PropertyFilter::Type::EQUAL);
all_filters_.emplace_back(filter_info);
}
return;
}
throw SemanticException("Property map matching not supported in MATCH/MERGE clause!");
};
auto add_node_filter = [&](NodeAtom *node) {
const auto &node_symbol = symbol_table.at(*node->identifier_);
if (!node->labels_.empty()) {
// Create a LabelsTest and store it.
auto *labels_test = storage.Create<LabelsTest>(node->identifier_, node->labels_);
auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test, std::unordered_set<Symbol>{node_symbol}};
label_filter.labels = node->labels_;
all_filters_.emplace_back(label_filter);
}
add_properties(node);
};
auto add_expand_filter = [&](NodeAtom *, EdgeAtom *edge, NodeAtom *node) {
if (edge->IsVariable())
add_properties_variable(edge);
else
add_properties(edge);
add_node_filter(node);
};
ForEachPattern(pattern, add_node_filter, add_expand_filter);
}
// Adds the where filter expression to `all_filters_` and collects additional
// information for potential property and label indexing.
void Filters::CollectWhereFilter(Where &where, const SymbolTable &symbol_table) {
CollectFilterExpression(where.expression_, symbol_table);
}
// Adds the expression to `all_filters_` and collects additional
// information for potential property and label indexing.
void Filters::CollectFilterExpression(Expression *expr, const SymbolTable &symbol_table) {
auto filters = SplitExpressionOnAnd(expr);
for (const auto &filter : filters) {
AnalyzeAndStoreFilter(filter, symbol_table);
}
}
// Analyzes the filter expression by collecting information on filtering labels
// and properties to be used with indexing.
void Filters::AnalyzeAndStoreFilter(Expression *expr, const SymbolTable &symbol_table) {
using Bound = PropertyFilter::Bound;
UsedSymbolsCollector collector(symbol_table);
expr->Accept(collector);
auto make_filter = [&collector, &expr](FilterInfo::Type type) { return FilterInfo{type, expr, collector.symbols_}; };
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup, auto *&ident) -> bool {
return (prop_lookup = utils::Downcast<PropertyLookup>(maybe_lookup)) &&
(ident = utils::Downcast<Identifier>(prop_lookup->expression_));
};
// Checks if maybe_lookup is a property lookup, stores it as a
// PropertyFilter and returns true. If it isn't, returns false.
auto add_prop_equal = [&](auto *maybe_lookup, auto *val_expr) -> bool {
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::EQUAL);
all_filters_.emplace_back(filter);
return true;
}
return false;
};
// Like add_prop_equal, but for adding regex match property filter.
auto add_prop_regex_match = [&](auto *maybe_lookup, auto *val_expr) -> bool {
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::REGEX_MATCH);
all_filters_.emplace_back(filter);
return true;
}
return false;
};
// Checks if either the expr1 and expr2 are property lookups, adds them as
// PropertyFilter and returns true. Otherwise, returns false.
auto add_prop_greater = [&](auto *expr1, auto *expr2, auto bound_type) -> bool {
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
bool is_prop_filter = false;
if (get_property_lookup(expr1, prop_lookup, ident)) {
// n.prop > value
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_,
Bound(expr2, bound_type), std::nullopt);
all_filters_.emplace_back(filter);
is_prop_filter = true;
}
if (get_property_lookup(expr2, prop_lookup, ident)) {
// value > n.prop
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter.emplace(symbol_table, symbol_table.at(*ident), prop_lookup->property_, std::nullopt,
Bound(expr1, bound_type));
all_filters_.emplace_back(filter);
is_prop_filter = true;
}
return is_prop_filter;
};
// Check if maybe_id_fun is ID invocation on an indentifier and add it as
// IdFilter.
auto add_id_equal = [&](auto *maybe_id_fun, auto *val_expr) -> bool {
auto *id_fun = utils::Downcast<Function>(maybe_id_fun);
if (!id_fun) return false;
if (id_fun->function_name_ != kId) return false;
if (id_fun->arguments_.size() != 1U) return false;
auto *ident = utils::Downcast<Identifier>(id_fun->arguments_.front());
if (!ident) return false;
auto filter = make_filter(FilterInfo::Type::Id);
filter.id_filter.emplace(symbol_table, symbol_table.at(*ident), val_expr);
all_filters_.emplace_back(filter);
return true;
};
// Checks if maybe_lookup is a property lookup, stores it as a
// PropertyFilter and returns true. If it isn't, returns false.
auto add_prop_in_list = [&](auto *maybe_lookup, auto *val_expr) -> bool {
if (!utils::Downcast<ListLiteral>(val_expr)) return false;
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
if (get_property_lookup(maybe_lookup, prop_lookup, ident)) {
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter = PropertyFilter(symbol_table, symbol_table.at(*ident), prop_lookup->property_, val_expr,
PropertyFilter::Type::IN);
all_filters_.emplace_back(filter);
return true;
}
return false;
};
// Checks whether maybe_prop_not_null_check is the null check on a property,
// ("prop IS NOT NULL"), stores it as a PropertyFilter if it is, and returns
// true. If it isn't returns false.
auto add_prop_is_not_null_check = [&](auto *maybe_is_not_null_check) -> bool {
// Strip away the outer NOT operator, and figure out
// whether the inner expression is of the form "prop IS NULL"
if (!maybe_is_not_null_check) {
return false;
}
auto *maybe_is_null_check = utils::Downcast<IsNullOperator>(maybe_is_not_null_check->expression_);
if (!maybe_is_null_check) {
return false;
}
PropertyLookup *prop_lookup = nullptr;
Identifier *ident = nullptr;
if (!get_property_lookup(maybe_is_null_check->expression_, prop_lookup, ident)) {
return false;
}
auto filter = make_filter(FilterInfo::Type::Property);
filter.property_filter =
PropertyFilter(symbol_table.at(*ident), prop_lookup->property_, PropertyFilter::Type::IS_NOT_NULL);
all_filters_.emplace_back(filter);
return true;
};
// We are only interested to see the insides of And, because Or prevents
// indexing since any labels and properties found there may be optional.
DMG_ASSERT(!utils::IsSubtype(*expr, AndOperator::kType), "Expected AndOperators have been split.");
if (auto *labels_test = utils::Downcast<LabelsTest>(expr)) {
// Since LabelsTest may contain any expression, we can only use the
// simplest test on an identifier.
if (utils::Downcast<Identifier>(labels_test->expression_)) {
auto filter = make_filter(FilterInfo::Type::Label);
filter.labels = labels_test->labels_;
all_filters_.emplace_back(filter);
} else {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *eq = utils::Downcast<EqualOperator>(expr)) {
// Try to get property equality test from the top expressions.
// Unfortunately, we cannot go deeper inside Equal, because chained equals
// need not correspond to And. For example, `(n.prop = value) = false)`:
// EQ
// / \
// EQ false -- top expressions
// / \
// n.prop value
// Here the `prop` may be different than `value` resulting in `false`. This
// would compare with the top level `false`, producing `true`. Therefore, it
// is incorrect to pick up `n.prop = value` for scanning by property index.
bool is_prop_filter = add_prop_equal(eq->expression1_, eq->expression2_);
// And reversed.
is_prop_filter |= add_prop_equal(eq->expression2_, eq->expression1_);
// Try to get ID equality filter.
bool is_id_filter = add_id_equal(eq->expression1_, eq->expression2_);
is_id_filter |= add_id_equal(eq->expression2_, eq->expression1_);
if (!is_prop_filter && !is_id_filter) {
// No special filter was added, so just store a generic filter.
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *regex_match = utils::Downcast<RegexMatch>(expr)) {
if (!add_prop_regex_match(regex_match->string_expr_, regex_match->regex_)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *gt = utils::Downcast<GreaterOperator>(expr)) {
if (!add_prop_greater(gt->expression1_, gt->expression2_, Bound::Type::EXCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *ge = utils::Downcast<GreaterEqualOperator>(expr)) {
if (!add_prop_greater(ge->expression1_, ge->expression2_, Bound::Type::INCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *lt = utils::Downcast<LessOperator>(expr)) {
// Like greater, but in reverse.
if (!add_prop_greater(lt->expression2_, lt->expression1_, Bound::Type::EXCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *le = utils::Downcast<LessEqualOperator>(expr)) {
// Like greater equal, but in reverse.
if (!add_prop_greater(le->expression2_, le->expression1_, Bound::Type::INCLUSIVE)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *in = utils::Downcast<InListOperator>(expr)) {
// IN isn't equivalent to Equal because IN isn't a symmetric operator. The
// IN filter is captured here only if the property lookup occurs on the
// left side of the operator. In that case, it's valid to do the IN list
// optimization during the index lookup rewrite phase.
if (!add_prop_in_list(in->expression1_, in->expression2_)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else if (auto *is_not_null = utils::Downcast<NotOperator>(expr)) {
if (!add_prop_is_not_null_check(is_not_null)) {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
} else {
all_filters_.emplace_back(make_filter(FilterInfo::Type::Generic));
}
// TODO: Collect comparisons like `expr1 < n.prop < expr2` for potential
// indexing by range. Note, that the generated Ast uses AND for chained
// relation operators. Therefore, `expr1 < n.prop < expr2` will be represented
// as `expr1 < n.prop AND n.prop < expr2`.
}
static void ParseForeach(query::v2::Foreach &foreach, SingleQueryPart &query_part, AstStorage &storage,
SymbolTable &symbol_table) {
for (auto *clause : foreach.clauses_) {
if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) {
query_part.merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part.merge_matching.back());
} else if (auto *nested = utils::Downcast<query::v2::Foreach>(clause)) {
ParseForeach(*nested, query_part, storage, symbol_table);
}
}
}
// Converts a Query to multiple QueryParts. In the process new Ast nodes may be
// created, e.g. filter expressions.
std::vector<SingleQueryPart> CollectSingleQueryParts(SymbolTable &symbol_table, AstStorage &storage,
SingleQuery *single_query) {
std::vector<SingleQueryPart> query_parts(1);
auto *query_part = &query_parts.back();
for (auto &clause : single_query->clauses_) {
if (auto *match = utils::Downcast<Match>(clause)) {
if (match->optional_) {
query_part->optional_matching.emplace_back(Matching{});
AddMatching(*match, symbol_table, storage, query_part->optional_matching.back());
} else {
DMG_ASSERT(query_part->optional_matching.empty(), "Match clause cannot follow optional match.");
AddMatching(*match, symbol_table, storage, query_part->matching);
}
} else {
query_part->remaining_clauses.push_back(clause);
if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) {
query_part->merge_matching.emplace_back(Matching{});
AddMatching({merge->pattern_}, nullptr, symbol_table, storage, query_part->merge_matching.back());
} else if (auto *foreach = utils::Downcast<query::v2::Foreach>(clause)) {
ParseForeach(*foreach, *query_part, storage, symbol_table);
} else if (utils::IsSubtype(*clause, With::kType) || utils::IsSubtype(*clause, query::v2::Unwind::kType) ||
utils::IsSubtype(*clause, query::v2::CallProcedure::kType) ||
utils::IsSubtype(*clause, query::v2::LoadCsv::kType)) {
// This query part is done, continue with a new one.
query_parts.emplace_back(SingleQueryPart{});
query_part = &query_parts.back();
} else if (utils::IsSubtype(*clause, Return::kType)) {
return query_parts;
}
}
}
return query_parts;
}
QueryParts CollectQueryParts(SymbolTable &symbol_table, AstStorage &storage, CypherQuery *query) {
std::vector<QueryPart> query_parts;
auto *single_query = query->single_query_;
MG_ASSERT(single_query, "Expected at least a single query");
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query)});
bool distinct = false;
for (auto *cypher_union : query->cypher_unions_) {
if (cypher_union->distinct_) {
distinct = true;
}
auto *single_query = cypher_union->single_query_;
MG_ASSERT(single_query, "Expected UNION to have a query");
query_parts.push_back(QueryPart{CollectSingleQueryParts(symbol_table, storage, single_query), cypher_union});
}
return QueryParts{query_parts, distinct};
}
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,360 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <optional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/semantic/symbol_table.hpp"
#include "query/v2/plan/operator.hpp"
namespace memgraph::query::v2::plan {
/// Collects symbols from identifiers found in visited AST nodes.
class UsedSymbolsCollector : public HierarchicalTreeVisitor {
public:
explicit UsedSymbolsCollector(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*all.identifier_));
return true;
}
bool PostVisit(Single &single) override {
// Remove the symbol which is bound by single, because we are only
// interested in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*single.identifier_));
return true;
}
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by any, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*any.identifier_));
return true;
}
bool PostVisit(None &none) override {
// Remove the symbol which is bound by none, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*none.identifier_));
return true;
}
bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*reduce.accumulator_));
symbols_.erase(symbol_table_.at(*reduce.identifier_));
return true;
}
bool Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
return true;
}
bool Visit(PrimitiveLiteral &) override { return true; }
bool Visit(ParameterLookup &) override { return true; }
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
};
/// Normalized representation of a pattern that needs to be matched.
struct Expansion {
/// The first node in the expansion, it can be a single node.
NodeAtom *node1 = nullptr;
/// Optional edge which connects the 2 nodes.
EdgeAtom *edge = nullptr;
/// Direction of the edge, it may be flipped compared to original
/// @c EdgeAtom during plan generation.
EdgeAtom::Direction direction = EdgeAtom::Direction::BOTH;
/// True if the direction and nodes were flipped.
bool is_flipped = false;
/// Set of symbols found inside the range expressions of a variable path edge.
std::unordered_set<Symbol> symbols_in_range{};
/// Optional node at the other end of an edge. If the expansion
/// contains an edge, then this node is required.
NodeAtom *node2 = nullptr;
};
/// Stores the symbols and expression used to filter a property.
class PropertyFilter {
public:
using Bound = ScanAllByLabelPropertyRange::Bound;
/// Depending on type, this PropertyFilter may be a value equality, regex
/// matched value or a range with lower and (or) upper bounds, IN list filter.
enum class Type { EQUAL, REGEX_MATCH, RANGE, IN, IS_NOT_NULL };
/// Construct with Expression being the equality or regex match check.
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, Expression *, Type);
/// Construct the range based filter.
PropertyFilter(const SymbolTable &, const Symbol &, PropertyIx, const std::optional<Bound> &,
const std::optional<Bound> &);
/// Construct a filter without an expression that produces a value.
/// Used for the "PROP IS NOT NULL" filter, and can be used for any
/// property filter that doesn't need to use an expression to produce
/// values that should be filtered further.
PropertyFilter(const Symbol &, PropertyIx, Type);
/// Symbol whose property is looked up.
Symbol symbol_;
PropertyIx property_;
Type type_;
/// True if the same symbol is used in expressions for value or bounds.
bool is_symbol_in_value_ = false;
/// Expression which when evaluated produces the value a property must
/// equal or regex match depending on type_.
Expression *value_ = nullptr;
/// Expressions which produce lower and upper bounds for a property.
std::optional<Bound> lower_bound_{};
std::optional<Bound> upper_bound_{};
};
/// Filtering by ID, for example `MATCH (n) WHERE id(n) = 42 ...`
class IdFilter {
public:
/// Construct with Expression being the required value for ID.
IdFilter(const SymbolTable &, const Symbol &, Expression *);
/// Symbol whose id is looked up.
Symbol symbol_;
/// Expression which when evaluted produces the value an ID must satisfy.
Expression *value_;
/// True if the same symbol is used in expressions for value.
bool is_symbol_in_value_{false};
};
/// Stores additional information for a filter expression.
struct FilterInfo {
/// A FilterInfo can be a generic filter expression or a specific filtering
/// applied for labels or a property. Non generic types contain extra
/// information which can be used to produce indexed scans of graph
/// elements.
enum class Type { Generic, Label, Property, Id };
Type type;
/// The original filter expression which must be satisfied.
Expression *expression;
/// Set of used symbols by the filter @c expression.
std::unordered_set<Symbol> used_symbols;
/// Labels for Type::Label filtering.
std::vector<LabelIx> labels;
/// Property information for Type::Property filtering.
std::optional<PropertyFilter> property_filter;
/// Information for Type::Id filtering.
std::optional<IdFilter> id_filter;
};
/// Stores information on filters used inside the @c Matching of a @c QueryPart.
///
/// Info is stored as a list of FilterInfo objects corresponding to all filter
/// expressions that should be generated.
class Filters final {
public:
using iterator = std::vector<FilterInfo>::iterator;
using const_iterator = std::vector<FilterInfo>::const_iterator;
auto begin() { return all_filters_.begin(); }
auto begin() const { return all_filters_.begin(); }
auto end() { return all_filters_.end(); }
auto end() const { return all_filters_.end(); }
auto empty() const { return all_filters_.empty(); }
auto erase(iterator pos) { return all_filters_.erase(pos); }
auto erase(const_iterator pos) { return all_filters_.erase(pos); }
auto erase(iterator first, iterator last) { return all_filters_.erase(first, last); }
auto erase(const_iterator first, const_iterator last) { return all_filters_.erase(first, last); }
auto FilteredLabels(const Symbol &symbol) const {
std::unordered_set<LabelIx> labels;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Label && utils::Contains(filter.used_symbols, symbol)) {
MG_ASSERT(filter.used_symbols.size() == 1U, "Expected a single used symbol for label filter");
labels.insert(filter.labels.begin(), filter.labels.end());
}
}
return labels;
}
/// Remove a filter; may invalidate iterators.
/// Removal is done by comparing only the expression, so that multiple
/// FilterInfo objects using the same original expression are removed.
void EraseFilter(const FilterInfo &);
/// Remove a label filter for symbol; may invalidate iterators.
/// If removed_filters is not nullptr, fills the vector with original
/// `Expression *` which are now completely removed.
void EraseLabelFilter(const Symbol &, LabelIx, std::vector<Expression *> *removed_filters = nullptr);
/// Returns a vector of FilterInfo for properties.
auto PropertyFilters(const Symbol &symbol) const {
std::vector<FilterInfo> filters;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Property && filter.property_filter->symbol_ == symbol) {
filters.push_back(filter);
}
}
return filters;
}
/// Return a vector of FilterInfo for ID equality filtering.
auto IdFilters(const Symbol &symbol) const {
std::vector<FilterInfo> filters;
for (const auto &filter : all_filters_) {
if (filter.type == FilterInfo::Type::Id && filter.id_filter->symbol_ == symbol) {
filters.push_back(filter);
}
}
return filters;
}
/// Collects filtering information from a pattern.
///
/// Goes through all the atoms in a pattern and generates filter expressions
/// for found labels, properties and edge types. The generated expressions are
/// stored.
void CollectPatternFilters(Pattern &, SymbolTable &, AstStorage &);
/// Collects filtering information from a where expression.
///
/// Takes the where expression and stores it, then analyzes the expression for
/// additional information. The additional information is used to populate
/// label filters and property filters, so that indexed scanning can use it.
void CollectWhereFilter(Where &, const SymbolTable &);
/// Collects filtering information from an expression.
///
/// Takes the where expression and stores it, then analyzes the expression for
/// additional information. The additional information is used to populate
/// label filters and property filters, so that indexed scanning can use it.
void CollectFilterExpression(Expression *, const SymbolTable &);
private:
void AnalyzeAndStoreFilter(Expression *, const SymbolTable &);
std::vector<FilterInfo> all_filters_;
};
/// Normalized representation of a single or multiple Match clauses.
///
/// For example, `MATCH (a :Label) -[e1]- (b) -[e2]- (c) MATCH (n) -[e3]- (m)
/// WHERE c.prop < 42` will produce the following.
/// Expansions will store `(a) -[e1]-(b)`, `(b) -[e2]- (c)` and
/// `(n) -[e3]- (m)`.
/// Edge symbols for Cyphermorphism will only contain the set `{e1, e2}` for the
/// first `MATCH` and the set `{e3}` for the second.
/// Filters will contain 2 pairs. One for testing `:Label` on symbol `a` and the
/// other obtained from `WHERE` on symbol `c`.
struct Matching {
/// All expansions that need to be performed across @c Match clauses.
std::vector<Expansion> expansions;
/// Symbols for edges established in match, used to ensure Cyphermorphism.
///
/// There are multiple sets, because each Match clause determines a single
/// set.
std::vector<std::unordered_set<Symbol>> edge_symbols;
/// Information on used filter expressions while matching.
Filters filters;
/// Maps node symbols to expansions which bind them.
std::unordered_map<Symbol, std::set<size_t>> node_symbol_to_expansions{};
/// Maps named path symbols to a vector of Symbols that define its pattern.
std::unordered_map<Symbol, std::vector<Symbol>> named_paths{};
/// All node and edge symbols across all expansions (from all matches).
std::unordered_set<Symbol> expansion_symbols{};
};
/// @brief Represents a read (+ write) part of a query. Parts are split on
/// `WITH` clauses.
///
/// Each part ends with either:
///
/// * `RETURN` clause;
/// * `WITH` clause;
/// * `UNWIND` clause;
/// * `CALL` clause or
/// * any of the write clauses.
///
/// For a query `MATCH (n) MERGE (n) -[e]- (m) SET n.x = 42 MERGE (l)` the
/// generated SingleQueryPart will have `matching` generated for the `MATCH`.
/// `remaining_clauses` will contain `Merge`, `SetProperty` and `Merge` clauses
/// in that exact order. The pattern inside the first `MERGE` will be used to
/// generate the first `merge_matching` element, and the second `MERGE` pattern
/// will produce the second `merge_matching` element. This way, if someone
/// traverses `remaining_clauses`, the order of appearance of `Merge` clauses is
/// in the same order as their respective `merge_matching` elements.
/// An exception to the above rule is Foreach. Its update clauses will not be contained in
/// the `remaining_clauses`, but rather inside the foreach itself. The order guarantee is not
/// violated because the update clauses of the foreach are immediately processed in
/// the `RuleBasedPlanner` as if as they were pushed into the `remaining_clauses`.
struct SingleQueryPart {
/// @brief All `MATCH` clauses merged into one @c Matching.
Matching matching;
/// @brief Each `OPTIONAL MATCH` converted to @c Matching.
std::vector<Matching> optional_matching{};
/// @brief @c Matching for each `MERGE` clause.
///
/// Storing the normalized pattern of a @c Merge does not preclude storing the
/// @c Merge clause itself inside `remaining_clauses`. The reason is that we
/// need to have access to other parts of the clause, such as `SET` clauses
/// which need to be run.
///
/// Since @c Merge is contained in `remaining_clauses`, this vector contains
/// matching in the same order as @c Merge appears.
//
/// Foreach @c does not violate this gurantee. However, update clauses are not stored
/// in the `remaining_clauses` but rather in the `Foreach` itself and are guranteed
/// to be processed in the same order by the semantics of the `RuleBasedPlanner`.
std::vector<Matching> merge_matching{};
/// @brief All the remaining clauses (without @c Match).
std::vector<Clause *> remaining_clauses{};
};
/// Holds query parts of a single query together with the optional information
/// about the combinator used between this single query and the previous one.
struct QueryPart {
std::vector<SingleQueryPart> single_query_parts = {};
/// Optional AST query combinator node
Tree *query_combinator = nullptr;
};
/// Holds query parts of all single queries together with the information
/// whether or not the resulting set should contain distinct elements.
struct QueryParts {
std::vector<QueryPart> query_parts = {};
/// Distinct flag, determined by the query combinator
bool distinct = false;
};
/// @brief Convert the AST to multiple @c QueryParts.
///
/// This function will normalize patterns inside @c Match and @c Merge clauses
/// and do some other preprocessing in order to generate multiple @c QueryPart
/// structures. @c AstStorage and @c SymbolTable may be used to create new
/// AST nodes.
QueryParts CollectQueryParts(SymbolTable &, AstStorage &, CypherQuery *);
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,910 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/pretty_print.hpp"
#include <variant>
#include "query/v2/db_accessor.hpp"
#include "query/v2/frontend/ast/pretty_print.hpp"
#include "utils/string.hpp"
namespace memgraph::query::v2::plan {
PlanPrinter::PlanPrinter(const DbAccessor *dba, std::ostream *out) : dba_(dba), out_(out) {}
#define PRE_VISIT(TOp) \
bool PlanPrinter::PreVisit(TOp &) { \
WithPrintLn([](auto &out) { out << "* " << #TOp; }); \
return true; \
}
PRE_VISIT(CreateNode);
bool PlanPrinter::PreVisit(CreateExpand &op) {
WithPrintLn([&](auto &out) {
out << "* CreateExpand (" << op.input_symbol_.name() << ")"
<< (op.edge_info_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.edge_info_.symbol.name() << ":" << dba_->EdgeTypeToName(op.edge_info_.edge_type) << "]"
<< (op.edge_info_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.node_info_.symbol.name() << ")";
});
return true;
}
PRE_VISIT(Delete);
bool PlanPrinter::PreVisit(query::v2::plan::ScanAll &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAll"
<< " (" << op.output_symbol_.name() << ")";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabel &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabel"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << ")";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyValue &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyValue"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelPropertyRange &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelPropertyRange"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllByLabelProperty"
<< " (" << op.output_symbol_.name() << " :" << dba_->LabelToName(op.label_) << " {"
<< dba_->PropertyToName(op.property_) << "})";
});
return true;
}
bool PlanPrinter::PreVisit(ScanAllById &op) {
WithPrintLn([&](auto &out) {
out << "* ScanAllById"
<< " (" << op.output_symbol_.name() << ")";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::Expand &op) {
WithPrintLn([&](auto &out) {
*out_ << "* Expand (" << op.input_symbol_.name() << ")"
<< (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::ExpandVariable &op) {
using Type = query::v2::EdgeAtom::Type;
WithPrintLn([&](auto &out) {
*out_ << "* ";
switch (op.type_) {
case Type::DEPTH_FIRST:
*out_ << "ExpandVariable";
break;
case Type::BREADTH_FIRST:
*out_ << (op.common_.existing_node ? "STShortestPath" : "BFSExpand");
break;
case Type::WEIGHTED_SHORTEST_PATH:
*out_ << "WeightedShortestPath";
break;
case Type::SINGLE:
LOG_FATAL("Unexpected ExpandVariable::type_");
}
*out_ << " (" << op.input_symbol_.name() << ")"
<< (op.common_.direction == query::v2::EdgeAtom::Direction::IN ? "<-" : "-") << "["
<< op.common_.edge_symbol.name();
utils::PrintIterable(*out_, op.common_.edge_types, "|", [this](auto &stream, const auto &edge_type) {
stream << ":" << dba_->EdgeTypeToName(edge_type);
});
*out_ << "]" << (op.common_.direction == query::v2::EdgeAtom::Direction::OUT ? "->" : "-") << "("
<< op.common_.node_symbol.name() << ")";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::Produce &op) {
WithPrintLn([&](auto &out) {
out << "* Produce {";
utils::PrintIterable(out, op.named_expressions_, ", ", [](auto &out, const auto &nexpr) { out << nexpr->name_; });
out << "}";
});
return true;
}
PRE_VISIT(ConstructNamedPath);
PRE_VISIT(Filter);
PRE_VISIT(SetProperty);
PRE_VISIT(SetProperties);
PRE_VISIT(SetLabels);
PRE_VISIT(RemoveProperty);
PRE_VISIT(RemoveLabels);
PRE_VISIT(EdgeUniquenessFilter);
PRE_VISIT(Accumulate);
bool PlanPrinter::PreVisit(query::v2::plan::Aggregate &op) {
WithPrintLn([&](auto &out) {
out << "* Aggregate {";
utils::PrintIterable(out, op.aggregations_, ", ",
[](auto &out, const auto &aggr) { out << aggr.output_sym.name(); });
out << "} {";
utils::PrintIterable(out, op.remember_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
}
PRE_VISIT(Skip);
PRE_VISIT(Limit);
bool PlanPrinter::PreVisit(query::v2::plan::OrderBy &op) {
WithPrintLn([&op](auto &out) {
out << "* OrderBy {";
utils::PrintIterable(out, op.output_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::Merge &op) {
WithPrintLn([](auto &out) { out << "* Merge"; });
Branch(*op.merge_match_, "On Match");
Branch(*op.merge_create_, "On Create");
op.input_->Accept(*this);
return false;
}
bool PlanPrinter::PreVisit(query::v2::plan::Optional &op) {
WithPrintLn([](auto &out) { out << "* Optional"; });
Branch(*op.optional_);
op.input_->Accept(*this);
return false;
}
PRE_VISIT(Unwind);
PRE_VISIT(Distinct);
bool PlanPrinter::PreVisit(query::v2::plan::Union &op) {
WithPrintLn([&op](auto &out) {
out << "* Union {";
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << " : ";
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
Branch(*op.right_op_);
op.left_op_->Accept(*this);
return false;
}
bool PlanPrinter::PreVisit(query::v2::plan::CallProcedure &op) {
WithPrintLn([&op](auto &out) {
out << "* CallProcedure<" << op.procedure_name_ << "> {";
utils::PrintIterable(out, op.result_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::LoadCsv &op) {
WithPrintLn([&op](auto &out) { out << "* LoadCsv {" << op.row_var_.name() << "}"; });
return true;
}
bool PlanPrinter::Visit(query::v2::plan::Once & /*op*/) {
WithPrintLn([](auto &out) { out << "* Once"; });
return true;
}
bool PlanPrinter::PreVisit(query::v2::plan::Cartesian &op) {
WithPrintLn([&op](auto &out) {
out << "* Cartesian {";
utils::PrintIterable(out, op.left_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << " : ";
utils::PrintIterable(out, op.right_symbols_, ", ", [](auto &out, const auto &sym) { out << sym.name(); });
out << "}";
});
Branch(*op.right_op_);
op.left_op_->Accept(*this);
return false;
}
bool PlanPrinter::PreVisit(query::v2::plan::Foreach &op) {
WithPrintLn([](auto &out) { out << "* Foreach"; });
Branch(*op.update_clauses_);
op.input_->Accept(*this);
return false;
}
#undef PRE_VISIT
bool PlanPrinter::DefaultPreVisit() {
WithPrintLn([](auto &out) { out << "* Unknown operator!"; });
return true;
}
void PlanPrinter::Branch(query::v2::plan::LogicalOperator &op, const std::string &branch_name) {
WithPrintLn([&](auto &out) { out << "|\\ " << branch_name; });
++depth_;
op.Accept(*this);
--depth_;
}
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out) {
PlanPrinter printer(&dba, out);
// FIXME(mtomic): We should make visitors that take const arguments.
const_cast<LogicalOperator *>(plan_root)->Accept(printer);
}
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root) {
impl::PlanToJsonVisitor visitor(&dba);
// FIXME(mtomic): We should make visitors that take const arguments.
const_cast<LogicalOperator *>(plan_root)->Accept(visitor);
return visitor.output();
}
namespace impl {
///////////////////////////////////////////////////////////////////////////////
//
// PlanToJsonVisitor implementation
//
// The JSON formatted plan is consumed (or will be) by Memgraph Lab, and
// therefore should not be changed before synchronizing with whoever is
// maintaining Memgraph Lab. Hopefully, one day integration tests will exist and
// there will be no need to be super careful.
using nlohmann::json;
//////////////////////////// HELPER FUNCTIONS /////////////////////////////////
// TODO: It would be nice to have enum->string functions auto-generated.
std::string ToString(EdgeAtom::Direction dir) {
switch (dir) {
case EdgeAtom::Direction::BOTH:
return "both";
case EdgeAtom::Direction::IN:
return "in";
case EdgeAtom::Direction::OUT:
return "out";
}
}
std::string ToString(EdgeAtom::Type type) {
switch (type) {
case EdgeAtom::Type::BREADTH_FIRST:
return "bfs";
case EdgeAtom::Type::DEPTH_FIRST:
return "dfs";
case EdgeAtom::Type::WEIGHTED_SHORTEST_PATH:
return "wsp";
case EdgeAtom::Type::SINGLE:
return "single";
}
}
std::string ToString(Ordering ord) {
switch (ord) {
case Ordering::ASC:
return "asc";
case Ordering::DESC:
return "desc";
}
}
json ToJson(Expression *expression) {
std::stringstream sstr;
PrintExpression(expression, &sstr);
return sstr.str();
}
json ToJson(const utils::Bound<Expression *> &bound) {
json json;
switch (bound.type()) {
case utils::BoundType::INCLUSIVE:
json["type"] = "inclusive";
break;
case utils::BoundType::EXCLUSIVE:
json["type"] = "exclusive";
break;
}
json["value"] = ToJson(bound.value());
return json;
}
json ToJson(const Symbol &symbol) { return symbol.name(); }
json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba) { return dba.EdgeTypeToName(edge_type); }
json ToJson(storage::v3::LabelId label, const DbAccessor &dba) { return dba.LabelToName(label); }
json ToJson(storage::v3::PropertyId property, const DbAccessor &dba) { return dba.PropertyToName(property); }
json ToJson(NamedExpression *nexpr) {
json json;
json["expression"] = ToJson(nexpr->expression_);
json["name"] = nexpr->name_;
return json;
}
json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties, const DbAccessor &dba) {
json json;
for (const auto &prop_pair : properties) {
json.emplace(ToJson(prop_pair.first, dba), ToJson(prop_pair.second));
}
return json;
}
json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba) {
json self;
self["symbol"] = ToJson(node_info.symbol);
self["labels"] = ToJson(node_info.labels, dba);
const auto *props = std::get_if<PropertiesMapList>(&node_info.properties);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba);
return self;
}
json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba) {
json self;
self["symbol"] = ToJson(edge_info.symbol);
const auto *props = std::get_if<PropertiesMapList>(&edge_info.properties);
self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba);
self["edge_type"] = ToJson(edge_info.edge_type, dba);
self["direction"] = ToString(edge_info.direction);
return self;
}
json ToJson(const Aggregate::Element &elem) {
json json;
if (elem.value) {
json["value"] = ToJson(elem.value);
}
if (elem.key) {
json["key"] = ToJson(elem.key);
}
json["op"] = utils::ToLowerCase(Aggregation::OpToString(elem.op));
json["output_symbol"] = ToJson(elem.output_sym);
return json;
}
////////////////////////// END HELPER FUNCTIONS ////////////////////////////////
bool PlanToJsonVisitor::Visit(Once &) {
json self;
self["name"] = "Once";
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAll &op) {
json self;
self["name"] = "ScanAll";
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAllByLabel &op) {
json self;
self["name"] = "ScanAllByLabel";
self["label"] = ToJson(op.label_, *dba_);
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyRange &op) {
json self;
self["name"] = "ScanAllByLabelPropertyRange";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["lower_bound"] = op.lower_bound_ ? ToJson(*op.lower_bound_) : json();
self["upper_bound"] = op.upper_bound_ ? ToJson(*op.upper_bound_) : json();
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelPropertyValue &op) {
json self;
self["name"] = "ScanAllByLabelPropertyValue";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["expression"] = ToJson(op.expression_);
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) {
json self;
self["name"] = "ScanAllByLabelProperty";
self["label"] = ToJson(op.label_, *dba_);
self["property"] = ToJson(op.property_, *dba_);
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ScanAllById &op) {
json self;
self["name"] = "ScanAllById";
self["output_symbol"] = ToJson(op.output_symbol_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(CreateNode &op) {
json self;
self["name"] = "CreateNode";
self["node_info"] = ToJson(op.node_info_, *dba_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(CreateExpand &op) {
json self;
self["name"] = "CreateExpand";
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_info"] = ToJson(op.node_info_, *dba_);
self["edge_info"] = ToJson(op.edge_info_, *dba_);
self["existing_node"] = op.existing_node_;
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Expand &op) {
json self;
self["name"] = "Expand";
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_symbol"] = ToJson(op.common_.node_symbol);
self["edge_symbol"] = ToJson(op.common_.edge_symbol);
self["edge_types"] = ToJson(op.common_.edge_types, *dba_);
self["direction"] = ToString(op.common_.direction);
self["existing_node"] = op.common_.existing_node;
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ExpandVariable &op) {
json self;
self["name"] = "ExpandVariable";
self["input_symbol"] = ToJson(op.input_symbol_);
self["node_symbol"] = ToJson(op.common_.node_symbol);
self["edge_symbol"] = ToJson(op.common_.edge_symbol);
self["edge_types"] = ToJson(op.common_.edge_types, *dba_);
self["direction"] = ToString(op.common_.direction);
self["type"] = ToString(op.type_);
self["is_reverse"] = op.is_reverse_;
self["lower_bound"] = op.lower_bound_ ? ToJson(op.lower_bound_) : json();
self["upper_bound"] = op.upper_bound_ ? ToJson(op.upper_bound_) : json();
self["existing_node"] = op.common_.existing_node;
self["filter_lambda"] = op.filter_lambda_.expression ? ToJson(op.filter_lambda_.expression) : json();
if (op.type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
self["weight_lambda"] = ToJson(op.weight_lambda_->expression);
self["total_weight_symbol"] = ToJson(*op.total_weight_);
}
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(ConstructNamedPath &op) {
json self;
self["name"] = "ConstructNamedPath";
self["path_symbol"] = ToJson(op.path_symbol_);
self["path_elements"] = ToJson(op.path_elements_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Filter &op) {
json self;
self["name"] = "Filter";
self["expression"] = ToJson(op.expression_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Produce &op) {
json self;
self["name"] = "Produce";
self["named_expressions"] = ToJson(op.named_expressions_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Delete &op) {
json self;
self["name"] = "Delete";
self["expressions"] = ToJson(op.expressions_);
self["detach"] = op.detach_;
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(SetProperty &op) {
json self;
self["name"] = "SetProperty";
self["property"] = ToJson(op.property_, *dba_);
self["lhs"] = ToJson(op.lhs_);
self["rhs"] = ToJson(op.rhs_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(SetProperties &op) {
json self;
self["name"] = "SetProperties";
self["input_symbol"] = ToJson(op.input_symbol_);
self["rhs"] = ToJson(op.rhs_);
switch (op.op_) {
case SetProperties::Op::UPDATE:
self["op"] = "update";
break;
case SetProperties::Op::REPLACE:
self["op"] = "replace";
break;
}
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(SetLabels &op) {
json self;
self["name"] = "SetLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(RemoveProperty &op) {
json self;
self["name"] = "RemoveProperty";
self["property"] = ToJson(op.property_, *dba_);
self["lhs"] = ToJson(op.lhs_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(RemoveLabels &op) {
json self;
self["name"] = "RemoveLabels";
self["input_symbol"] = ToJson(op.input_symbol_);
self["labels"] = ToJson(op.labels_, *dba_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(EdgeUniquenessFilter &op) {
json self;
self["name"] = "EdgeUniquenessFilter";
self["expand_symbol"] = ToJson(op.expand_symbol_);
self["previous_symbols"] = ToJson(op.previous_symbols_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Accumulate &op) {
json self;
self["name"] = "Accumulate";
self["symbols"] = ToJson(op.symbols_);
self["advance_command"] = op.advance_command_;
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Aggregate &op) {
json self;
self["name"] = "Aggregate";
self["aggregations"] = ToJson(op.aggregations_);
self["group_by"] = ToJson(op.group_by_);
self["remember"] = ToJson(op.remember_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Skip &op) {
json self;
self["name"] = "Skip";
self["expression"] = ToJson(op.expression_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Limit &op) {
json self;
self["name"] = "Limit";
self["expression"] = ToJson(op.expression_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(OrderBy &op) {
json self;
self["name"] = "OrderBy";
for (auto i = 0; i < op.order_by_.size(); ++i) {
json json;
json["ordering"] = ToString(op.compare_.ordering_[i]);
json["expression"] = ToJson(op.order_by_[i]);
self["order_by"].push_back(json);
}
self["output_symbols"] = ToJson(op.output_symbols_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Merge &op) {
json self;
self["name"] = "Merge";
op.input_->Accept(*this);
self["input"] = PopOutput();
op.merge_match_->Accept(*this);
self["merge_match"] = PopOutput();
op.merge_create_->Accept(*this);
self["merge_create"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Optional &op) {
json self;
self["name"] = "Optional";
self["optional_symbols"] = ToJson(op.optional_symbols_);
op.input_->Accept(*this);
self["input"] = PopOutput();
op.optional_->Accept(*this);
self["optional"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Unwind &op) {
json self;
self["name"] = "Unwind";
self["output_symbol"] = ToJson(op.output_symbol_);
self["input_expression"] = ToJson(op.input_expression_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(query::v2::plan::CallProcedure &op) {
json self;
self["name"] = "CallProcedure";
self["procedure_name"] = op.procedure_name_;
self["arguments"] = ToJson(op.arguments_);
self["result_fields"] = op.result_fields_;
self["result_symbols"] = ToJson(op.result_symbols_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(query::v2::plan::LoadCsv &op) {
json self;
self["name"] = "LoadCsv";
self["file"] = ToJson(op.file_);
self["with_header"] = op.with_header_;
self["ignore_bad"] = op.ignore_bad_;
self["delimiter"] = ToJson(op.delimiter_);
self["quote"] = ToJson(op.quote_);
self["row_variable"] = ToJson(op.row_var_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Distinct &op) {
json self;
self["name"] = "Distinct";
self["value_symbols"] = ToJson(op.value_symbols_);
op.input_->Accept(*this);
self["input"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Union &op) {
json self;
self["name"] = "Union";
self["union_symbols"] = ToJson(op.union_symbols_);
self["left_symbols"] = ToJson(op.left_symbols_);
self["right_symbols"] = ToJson(op.right_symbols_);
op.left_op_->Accept(*this);
self["left_op"] = PopOutput();
op.right_op_->Accept(*this);
self["right_op"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Cartesian &op) {
json self;
self["name"] = "Cartesian";
self["left_symbols"] = ToJson(op.left_symbols_);
self["right_symbols"] = ToJson(op.right_symbols_);
op.left_op_->Accept(*this);
self["left_op"] = PopOutput();
op.right_op_->Accept(*this);
self["right_op"] = PopOutput();
output_ = std::move(self);
return false;
}
bool PlanToJsonVisitor::PreVisit(Foreach &op) {
json self;
self["name"] = "Foreach";
self["loop_variable_symbol"] = ToJson(op.loop_variable_symbol_);
self["expression"] = ToJson(op.expression_);
op.input_->Accept(*this);
self["input"] = PopOutput();
op.update_clauses_->Accept(*this);
self["update_clauses"] = PopOutput();
output_ = std::move(self);
return false;
}
} // namespace impl
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,230 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <iostream>
#include <json/json.hpp>
#include "query/v2/plan/operator.hpp"
namespace memgraph::query::v2 {
class DbAccessor;
namespace plan {
class LogicalOperator;
/// Pretty print a `LogicalOperator` plan to a `std::ostream`.
/// DbAccessor is needed for resolving label and property names.
/// Note that `plan_root` isn't modified, but we can't take it as a const
/// because we don't have support for visiting a const LogicalOperator.
void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root, std::ostream *out);
/// Overload of `PrettyPrint` which defaults the `std::ostream` to `std::cout`.
inline void PrettyPrint(const DbAccessor &dba, const LogicalOperator *plan_root) {
PrettyPrint(dba, plan_root, &std::cout);
}
/// Convert a `LogicalOperator` plan to a JSON representation.
/// DbAccessor is needed for resolving label and property names.
nlohmann::json PlanToJson(const DbAccessor &dba, const LogicalOperator *plan_root);
class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor {
public:
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
PlanPrinter(const DbAccessor *dba, std::ostream *out);
bool DefaultPreVisit() override;
bool PreVisit(CreateNode &) override;
bool PreVisit(CreateExpand &) override;
bool PreVisit(Delete &) override;
bool PreVisit(SetProperty &) override;
bool PreVisit(SetProperties &) override;
bool PreVisit(SetLabels &) override;
bool PreVisit(RemoveProperty &) override;
bool PreVisit(RemoveLabels &) override;
bool PreVisit(ScanAll &) override;
bool PreVisit(ScanAllByLabel &) override;
bool PreVisit(ScanAllByLabelPropertyValue &) override;
bool PreVisit(ScanAllByLabelPropertyRange &) override;
bool PreVisit(ScanAllByLabelProperty &) override;
bool PreVisit(ScanAllById &) override;
bool PreVisit(Expand &) override;
bool PreVisit(ExpandVariable &) override;
bool PreVisit(ConstructNamedPath &) override;
bool PreVisit(Filter &) override;
bool PreVisit(EdgeUniquenessFilter &) override;
bool PreVisit(Merge &) override;
bool PreVisit(Optional &) override;
bool PreVisit(Cartesian &) override;
bool PreVisit(Produce &) override;
bool PreVisit(Accumulate &) override;
bool PreVisit(Aggregate &) override;
bool PreVisit(Skip &) override;
bool PreVisit(Limit &) override;
bool PreVisit(OrderBy &) override;
bool PreVisit(Distinct &) override;
bool PreVisit(Union &) override;
bool PreVisit(Unwind &) override;
bool PreVisit(CallProcedure &) override;
bool PreVisit(LoadCsv &) override;
bool PreVisit(Foreach &) override;
bool Visit(Once &) override;
/// Call fun with output stream. The stream is prefixed with amount of spaces
/// corresponding to the current depth_.
template <class TFun>
void WithPrintLn(TFun fun) {
*out_ << " ";
for (int64_t i = 0; i < depth_; ++i) {
*out_ << "| ";
}
fun(*out_);
*out_ << std::endl;
}
/// Forward this printer to another operator branch by incrementing the depth
/// and printing the branch name.
void Branch(LogicalOperator &op, const std::string &branch_name = "");
int64_t depth_{0};
const DbAccessor *dba_{nullptr};
std::ostream *out_{nullptr};
};
namespace impl {
std::string ToString(EdgeAtom::Direction dir);
std::string ToString(EdgeAtom::Type type);
std::string ToString(Ordering ord);
nlohmann::json ToJson(Expression *expression);
nlohmann::json ToJson(const utils::Bound<Expression *> &bound);
nlohmann::json ToJson(const Symbol &symbol);
nlohmann::json ToJson(storage::v3::EdgeTypeId edge_type, const DbAccessor &dba);
nlohmann::json ToJson(storage::v3::LabelId label, const DbAccessor &dba);
nlohmann::json ToJson(storage::v3::PropertyId property, const DbAccessor &dba);
nlohmann::json ToJson(NamedExpression *nexpr);
nlohmann::json ToJson(const std::vector<std::pair<storage::v3::PropertyId, Expression *>> &properties,
const DbAccessor &dba);
nlohmann::json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba);
nlohmann::json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba);
nlohmann::json ToJson(const Aggregate::Element &elem);
template <class T, class... Args>
nlohmann::json ToJson(const std::vector<T> &items, Args &&...args) {
nlohmann::json json;
for (const auto &item : items) {
json.emplace_back(ToJson(item, std::forward<Args>(args)...));
}
return json;
}
class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor {
public:
explicit PlanToJsonVisitor(const DbAccessor *dba) : dba_(dba) {}
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
bool PreVisit(CreateNode &) override;
bool PreVisit(CreateExpand &) override;
bool PreVisit(Delete &) override;
bool PreVisit(SetProperty &) override;
bool PreVisit(SetProperties &) override;
bool PreVisit(SetLabels &) override;
bool PreVisit(RemoveProperty &) override;
bool PreVisit(RemoveLabels &) override;
bool PreVisit(Expand &) override;
bool PreVisit(ExpandVariable &) override;
bool PreVisit(ConstructNamedPath &) override;
bool PreVisit(Merge &) override;
bool PreVisit(Optional &) override;
bool PreVisit(Filter &) override;
bool PreVisit(EdgeUniquenessFilter &) override;
bool PreVisit(Cartesian &) override;
bool PreVisit(ScanAll &) override;
bool PreVisit(ScanAllByLabel &) override;
bool PreVisit(ScanAllByLabelPropertyRange &) override;
bool PreVisit(ScanAllByLabelPropertyValue &) override;
bool PreVisit(ScanAllByLabelProperty &) override;
bool PreVisit(ScanAllById &) override;
bool PreVisit(Produce &) override;
bool PreVisit(Accumulate &) override;
bool PreVisit(Aggregate &) override;
bool PreVisit(Skip &) override;
bool PreVisit(Limit &) override;
bool PreVisit(OrderBy &) override;
bool PreVisit(Distinct &) override;
bool PreVisit(Union &) override;
bool PreVisit(Unwind &) override;
bool PreVisit(Foreach &) override;
bool PreVisit(CallProcedure &) override;
bool PreVisit(LoadCsv &) override;
bool Visit(Once &) override;
nlohmann::json output() { return output_; }
protected:
nlohmann::json output_;
const DbAccessor *dba_;
nlohmann::json PopOutput() {
nlohmann::json tmp;
tmp.swap(output_);
return tmp;
}
};
} // namespace impl
} // namespace plan
} // namespace memgraph::query::v2

View File

@ -0,0 +1,166 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/profile.hpp"
#include <algorithm>
#include <chrono>
#include <fmt/format.h>
#include <json/json.hpp>
#include "query/v2/context.hpp"
#include "utils/likely.hpp"
namespace memgraph::query::v2::plan {
namespace {
unsigned long long IndividualCycles(const ProfilingStats &cumulative_stats) {
return cumulative_stats.num_cycles - std::accumulate(cumulative_stats.children.begin(),
cumulative_stats.children.end(), 0ULL,
[](auto acc, auto &stats) { return acc + stats.num_cycles; });
}
double RelativeTime(unsigned long long num_cycles, unsigned long long total_cycles) {
return static_cast<double>(num_cycles) / total_cycles;
}
double AbsoluteTime(unsigned long long num_cycles, unsigned long long total_cycles,
std::chrono::duration<double> total_time) {
return (RelativeTime(num_cycles, total_cycles) * static_cast<std::chrono::duration<double, std::milli>>(total_time))
.count();
}
} // namespace
//////////////////////////////////////////////////////////////////////////////
//
// ProfilingStatsToTable
namespace {
class ProfilingStatsToTableHelper {
public:
ProfilingStatsToTableHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
: total_cycles_(total_cycles), total_time_(total_time) {}
void Output(const ProfilingStats &cumulative_stats) {
auto cycles = IndividualCycles(cumulative_stats);
rows_.emplace_back(std::vector<TypedValue>{
TypedValue(FormatOperator(cumulative_stats.name)), TypedValue(cumulative_stats.actual_hits),
TypedValue(FormatRelativeTime(cycles)), TypedValue(FormatAbsoluteTime(cycles))});
for (size_t i = 1; i < cumulative_stats.children.size(); ++i) {
Branch(cumulative_stats.children[i]);
}
if (cumulative_stats.children.size() >= 1) {
Output(cumulative_stats.children[0]);
}
}
std::vector<std::vector<TypedValue>> rows() { return rows_; }
private:
void Branch(const ProfilingStats &cumulative_stats) {
rows_.emplace_back(std::vector<TypedValue>{TypedValue("|\\"), TypedValue(""), TypedValue(""), TypedValue("")});
++depth_;
Output(cumulative_stats);
--depth_;
}
std::string Format(const char *str) {
std::ostringstream ss;
for (int64_t i = 0; i < depth_; ++i) {
ss << "| ";
}
ss << str;
return ss.str();
}
std::string Format(const std::string &str) { return Format(str.c_str()); }
std::string FormatOperator(const char *str) { return Format(std::string("* ") + str); }
std::string FormatRelativeTime(unsigned long long num_cycles) {
return fmt::format("{: 10.6f} %", RelativeTime(num_cycles, total_cycles_) * 100);
}
std::string FormatAbsoluteTime(unsigned long long num_cycles) {
return fmt::format("{: 10.6f} ms", AbsoluteTime(num_cycles, total_cycles_, total_time_));
}
int64_t depth_{0};
std::vector<std::vector<TypedValue>> rows_;
unsigned long long total_cycles_;
std::chrono::duration<double> total_time_;
};
} // namespace
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStatsWithTotalTime &stats) {
ProfilingStatsToTableHelper helper{stats.cumulative_stats.num_cycles, stats.total_time};
helper.Output(stats.cumulative_stats);
return helper.rows();
}
//////////////////////////////////////////////////////////////////////////////
//
// ProfilingStatsToJson
namespace {
class ProfilingStatsToJsonHelper {
private:
using json = nlohmann::json;
public:
ProfilingStatsToJsonHelper(unsigned long long total_cycles, std::chrono::duration<double> total_time)
: total_cycles_(total_cycles), total_time_(total_time) {}
void Output(const ProfilingStats &cumulative_stats) { return Output(cumulative_stats, &json_); }
json ToJson() { return json_; }
private:
void Output(const ProfilingStats &cumulative_stats, json *obj) {
auto cycles = IndividualCycles(cumulative_stats);
obj->emplace("name", cumulative_stats.name);
obj->emplace("actual_hits", cumulative_stats.actual_hits);
obj->emplace("relative_time", RelativeTime(cycles, total_cycles_));
obj->emplace("absolute_time", AbsoluteTime(cycles, total_cycles_, total_time_));
obj->emplace("children", json::array());
for (size_t i = 0; i < cumulative_stats.children.size(); ++i) {
json child;
Output(cumulative_stats.children[i], &child);
obj->at("children").emplace_back(std::move(child));
}
}
json json_;
unsigned long long total_cycles_;
std::chrono::duration<double> total_time_;
};
} // namespace
nlohmann::json ProfilingStatsToJson(const ProfilingStatsWithTotalTime &stats) {
ProfilingStatsToJsonHelper helper{stats.cumulative_stats.num_cycles, stats.total_time};
helper.Output(stats.cumulative_stats);
return helper.ToJson();
}
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,47 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include <vector>
#include <json/json.hpp>
#include "query/v2/typed_value.hpp"
namespace memgraph::query::v2 {
namespace plan {
/**
* Stores profiling statistics for a single logical operator.
*/
struct ProfilingStats {
int64_t actual_hits{0};
unsigned long long num_cycles{0};
uint64_t key{0};
const char *name{nullptr};
// TODO: This should use the allocator for query execution
std::vector<ProfilingStats> children;
};
struct ProfilingStatsWithTotalTime {
ProfilingStats cumulative_stats{};
std::chrono::duration<double> total_time{};
};
std::vector<std::vector<TypedValue>> ProfilingStatsToTable(const ProfilingStatsWithTotalTime &stats);
nlohmann::json ProfilingStatsToJson(const ProfilingStatsWithTotalTime &stats);
} // namespace plan
} // namespace memgraph::query::v2

View File

@ -0,0 +1,128 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/read_write_type_checker.hpp"
#define PRE_VISIT(TOp, RWType, continue_visiting) \
bool ReadWriteTypeChecker::PreVisit(TOp &op) { \
UpdateType(RWType); \
return continue_visiting; \
}
namespace memgraph::query::v2::plan {
PRE_VISIT(CreateNode, RWType::W, true)
PRE_VISIT(CreateExpand, RWType::R, true)
PRE_VISIT(Delete, RWType::W, true)
PRE_VISIT(SetProperty, RWType::W, true)
PRE_VISIT(SetProperties, RWType::W, true)
PRE_VISIT(SetLabels, RWType::W, true)
PRE_VISIT(RemoveProperty, RWType::W, true)
PRE_VISIT(RemoveLabels, RWType::W, true)
PRE_VISIT(ScanAll, RWType::R, true)
PRE_VISIT(ScanAllByLabel, RWType::R, true)
PRE_VISIT(ScanAllByLabelPropertyRange, RWType::R, true)
PRE_VISIT(ScanAllByLabelPropertyValue, RWType::R, true)
PRE_VISIT(ScanAllByLabelProperty, RWType::R, true)
PRE_VISIT(ScanAllById, RWType::R, true)
PRE_VISIT(Expand, RWType::R, true)
PRE_VISIT(ExpandVariable, RWType::R, true)
PRE_VISIT(ConstructNamedPath, RWType::R, true)
PRE_VISIT(Filter, RWType::NONE, true)
PRE_VISIT(EdgeUniquenessFilter, RWType::NONE, true)
PRE_VISIT(Merge, RWType::RW, false)
PRE_VISIT(Optional, RWType::NONE, true)
bool ReadWriteTypeChecker::PreVisit(Cartesian &op) {
op.left_op_->Accept(*this);
op.right_op_->Accept(*this);
return false;
}
PRE_VISIT(Produce, RWType::NONE, true)
PRE_VISIT(Accumulate, RWType::NONE, true)
PRE_VISIT(Aggregate, RWType::NONE, true)
PRE_VISIT(Skip, RWType::NONE, true)
PRE_VISIT(Limit, RWType::NONE, true)
PRE_VISIT(OrderBy, RWType::NONE, true)
PRE_VISIT(Distinct, RWType::NONE, true)
bool ReadWriteTypeChecker::PreVisit(Union &op) {
op.left_op_->Accept(*this);
op.right_op_->Accept(*this);
return false;
}
PRE_VISIT(Unwind, RWType::NONE, true)
bool ReadWriteTypeChecker::PreVisit(CallProcedure &op) {
if (op.is_write_) {
UpdateType(RWType::RW);
return false;
}
UpdateType(RWType::R);
return true;
}
bool ReadWriteTypeChecker::PreVisit([[maybe_unused]] Foreach &op) {
UpdateType(RWType::RW);
return false;
}
#undef PRE_VISIT
bool ReadWriteTypeChecker::Visit(Once &op) { return false; }
void ReadWriteTypeChecker::UpdateType(RWType op_type) {
// Update type only if it's not the NONE type and the current operator's type
// is different than the one that's currently inferred.
if (type != RWType::NONE && type != op_type) {
type = RWType::RW;
}
// Stop inference because RW is the most "dominant" type, i.e. it isn't
// affected by the type of nodes in the plan appearing after the node for
// which the type is set to RW.
if (type == RWType::RW) {
return;
}
if (type == RWType::NONE && op_type != RWType::NONE) {
type = op_type;
}
}
void ReadWriteTypeChecker::InferRWType(LogicalOperator &root) { root.Accept(*this); }
std::string ReadWriteTypeChecker::TypeToString(const RWType type) {
switch (type) {
// Unfortunately, neo4j Java drivers do not allow query types that differ
// from the ones defined by neo4j. We'll keep using the NONE type internally
// but we'll convert it to "rw" to keep in line with the neo4j definition.
// Oddly enough, but not surprisingly, Python drivers don't have any problems
// with non-neo4j query types.
case RWType::NONE:
return "rw";
case RWType::R:
return "r";
case RWType::W:
return "w";
case RWType::RW:
return "rw";
}
}
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,95 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include "query/v2/plan/operator.hpp"
namespace memgraph::query::v2::plan {
class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor {
public:
ReadWriteTypeChecker() = default;
ReadWriteTypeChecker(const ReadWriteTypeChecker &) = delete;
ReadWriteTypeChecker(ReadWriteTypeChecker &&) = delete;
ReadWriteTypeChecker &operator=(const ReadWriteTypeChecker &) = delete;
ReadWriteTypeChecker &operator=(ReadWriteTypeChecker &&) = delete;
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
// NONE type describes an operator whose action neither reads nor writes from
// the database (e.g. Produce or Once).
// R type describes an operator whose action involves reading from the
// database.
// W type describes an operator whose action involves writing to the
// database.
// RW type describes an operator whose action involves both reading and
// writing to the database.
enum class RWType : uint8_t { NONE, R, W, RW };
RWType type{RWType::NONE};
void InferRWType(LogicalOperator &root);
static std::string TypeToString(const RWType type);
bool PreVisit(CreateNode &) override;
bool PreVisit(CreateExpand &) override;
bool PreVisit(Delete &) override;
bool PreVisit(SetProperty &) override;
bool PreVisit(SetProperties &) override;
bool PreVisit(SetLabels &) override;
bool PreVisit(RemoveProperty &) override;
bool PreVisit(RemoveLabels &) override;
bool PreVisit(ScanAll &) override;
bool PreVisit(ScanAllByLabel &) override;
bool PreVisit(ScanAllByLabelPropertyValue &) override;
bool PreVisit(ScanAllByLabelPropertyRange &) override;
bool PreVisit(ScanAllByLabelProperty &) override;
bool PreVisit(ScanAllById &) override;
bool PreVisit(Expand &) override;
bool PreVisit(ExpandVariable &) override;
bool PreVisit(ConstructNamedPath &) override;
bool PreVisit(Filter &) override;
bool PreVisit(EdgeUniquenessFilter &) override;
bool PreVisit(Merge &) override;
bool PreVisit(Optional &) override;
bool PreVisit(Cartesian &) override;
bool PreVisit(Produce &) override;
bool PreVisit(Accumulate &) override;
bool PreVisit(Aggregate &) override;
bool PreVisit(Skip &) override;
bool PreVisit(Limit &) override;
bool PreVisit(OrderBy &) override;
bool PreVisit(Distinct &) override;
bool PreVisit(Union &) override;
bool PreVisit(Unwind &) override;
bool PreVisit(CallProcedure &) override;
bool PreVisit(Foreach &) override;
bool Visit(Once &) override;
private:
void UpdateType(RWType op_type);
};
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,50 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/rewrite/index_lookup.hpp"
#include "utils/flag_validation.hpp"
DEFINE_VALIDATED_HIDDEN_int64(query_vertex_count_to_expand_existing, 10,
"Maximum count of indexed vertices which provoke "
"indexed lookup and then expand to existing, instead of "
"a regular expand. Default is 10, to turn off use -1.",
FLAG_IN_RANGE(-1, std::numeric_limits<std::int64_t>::max()));
namespace memgraph::query::v2::plan::impl {
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove) {
auto *and_op = utils::Downcast<AndOperator>(expr);
if (!and_op) return expr;
if (utils::Contains(exprs_to_remove, and_op)) {
return nullptr;
}
if (utils::Contains(exprs_to_remove, and_op->expression1_)) {
and_op->expression1_ = nullptr;
}
if (utils::Contains(exprs_to_remove, and_op->expression2_)) {
and_op->expression2_ = nullptr;
}
and_op->expression1_ = RemoveAndExpressions(and_op->expression1_, exprs_to_remove);
and_op->expression2_ = RemoveAndExpressions(and_op->expression2_, exprs_to_remove);
if (!and_op->expression1_ && !and_op->expression2_) {
return nullptr;
}
if (and_op->expression1_ && !and_op->expression2_) {
return and_op->expression1_;
}
if (and_op->expression2_ && !and_op->expression1_) {
return and_op->expression2_;
}
return and_op;
}
} // namespace memgraph::query::v2::plan::impl

View File

@ -0,0 +1,668 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// This file provides a plan rewriter which replaces `Filter` and `ScanAll`
/// operations with `ScanAllBy<Index>` if possible. The public entrypoint is
/// `RewriteWithIndexLookup`.
#pragma once
#include <algorithm>
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <gflags/gflags.h>
#include "query/v2/plan/operator.hpp"
#include "query/v2/plan/preprocess.hpp"
DECLARE_int64(query_vertex_count_to_expand_existing);
namespace memgraph::query::v2::plan {
namespace impl {
// Return the new root expression after removing the given expressions from the
// given expression tree.
Expression *RemoveAndExpressions(Expression *expr, const std::unordered_set<Expression *> &exprs_to_remove);
template <class TDbAccessor>
class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor {
public:
IndexLookupRewriter(SymbolTable *symbol_table, AstStorage *ast_storage, TDbAccessor *db)
: symbol_table_(symbol_table), ast_storage_(ast_storage), db_(db) {}
using HierarchicalLogicalOperatorVisitor::PostVisit;
using HierarchicalLogicalOperatorVisitor::PreVisit;
using HierarchicalLogicalOperatorVisitor::Visit;
bool Visit(Once &) override { return true; }
bool PreVisit(Filter &op) override {
prev_ops_.push_back(&op);
filters_.CollectFilterExpression(op.expression_, *symbol_table_);
return true;
}
// Remove no longer needed Filter in PostVisit, this should be the last thing
// Filter::Accept does, so it should be safe to remove the last reference and
// free the memory.
bool PostVisit(Filter &op) override {
prev_ops_.pop_back();
op.expression_ = RemoveAndExpressions(op.expression_, filter_exprs_for_removal_);
if (!op.expression_ || utils::Contains(filter_exprs_for_removal_, op.expression_)) {
SetOnParent(op.input());
}
return true;
}
bool PreVisit(ScanAll &op) override {
prev_ops_.push_back(&op);
return true;
}
// Replace ScanAll with ScanAllBy<Index> in PostVisit, because removal of
// ScanAll may remove the last reference and thus free the memory. PostVisit
// should be the last thing ScanAll::Accept does, so it should be safe.
bool PostVisit(ScanAll &scan) override {
prev_ops_.pop_back();
auto indexed_scan = GenScanByIndex(scan);
if (indexed_scan) {
SetOnParent(std::move(indexed_scan));
}
return true;
}
bool PreVisit(Expand &op) override {
prev_ops_.push_back(&op);
return true;
}
// See if it might be better to do ScanAllBy<Index> of the destination and
// then do Expand to existing.
bool PostVisit(Expand &expand) override {
prev_ops_.pop_back();
if (expand.common_.existing_node) {
return true;
}
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, expand.view_);
auto indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
if (indexed_scan) {
expand.set_input(std::move(indexed_scan));
expand.common_.existing_node = true;
}
return true;
}
bool PreVisit(ExpandVariable &op) override {
prev_ops_.push_back(&op);
return true;
}
// See if it might be better to do ScanAllBy<Index> of the destination and
// then do ExpandVariable to existing.
bool PostVisit(ExpandVariable &expand) override {
prev_ops_.pop_back();
if (expand.common_.existing_node) {
return true;
}
std::unique_ptr<ScanAll> indexed_scan;
ScanAll dst_scan(expand.input(), expand.common_.node_symbol, storage::v3::View::OLD);
// With expand to existing we only get real gains with BFS, because we use a
// different algorithm then, so prefer expand to existing.
if (expand.type_ == EdgeAtom::Type::BREADTH_FIRST) {
// TODO: Perhaps take average node degree into consideration, instead of
// unconditionally creating an indexed scan.
indexed_scan = GenScanByIndex(dst_scan);
} else {
indexed_scan = GenScanByIndex(dst_scan, FLAGS_query_vertex_count_to_expand_existing);
}
if (indexed_scan) {
expand.set_input(std::move(indexed_scan));
expand.common_.existing_node = true;
}
return true;
}
// The following operators may only use index lookup in filters inside of
// their own branches. So we handle them all the same.
// * Input operator is visited with the current visitor.
// * Custom operator branches are visited with a new visitor.
bool PreVisit(Merge &op) override {
prev_ops_.push_back(&op);
op.input()->Accept(*this);
RewriteBranch(&op.merge_match_);
return false;
}
bool PostVisit(Merge &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Optional &op) override {
prev_ops_.push_back(&op);
op.input()->Accept(*this);
RewriteBranch(&op.optional_);
return false;
}
bool PostVisit(Optional &) override {
prev_ops_.pop_back();
return true;
}
// Rewriting Cartesian assumes that the input plan will have Filter operations
// as soon as they are possible. Therefore we do not track filters above
// Cartesian because they should be irrelevant.
//
// For example, the following plan is not expected to be an input to
// IndexLookupRewriter.
//
// Filter n.prop = 16
// |
// Cartesian
// |
// |\
// | ScanAll (n)
// |
// ScanAll (m)
//
// Instead, the equivalent set of operations should be done this way:
//
// Cartesian
// |
// |\
// | Filter n.prop = 16
// | |
// | ScanAll (n)
// |
// ScanAll (m)
bool PreVisit(Cartesian &op) override {
prev_ops_.push_back(&op);
RewriteBranch(&op.left_op_);
RewriteBranch(&op.right_op_);
return false;
}
bool PostVisit(Cartesian &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Union &op) override {
prev_ops_.push_back(&op);
RewriteBranch(&op.left_op_);
RewriteBranch(&op.right_op_);
return false;
}
bool PostVisit(Union &) override {
prev_ops_.pop_back();
return true;
}
// The remaining operators should work by just traversing into their input.
bool PreVisit(CreateNode &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(CreateNode &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(CreateExpand &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(CreateExpand &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ScanAllByLabel &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ScanAllByLabel &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ScanAllByLabelPropertyRange &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ScanAllByLabelPropertyRange &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ScanAllByLabelPropertyValue &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ScanAllByLabelPropertyValue &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ScanAllByLabelProperty &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ScanAllByLabelProperty &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ScanAllById &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ScanAllById &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(ConstructNamedPath &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(ConstructNamedPath &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Produce &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Produce &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Delete &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Delete &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(SetProperty &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(SetProperty &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(SetProperties &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(SetProperties &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(SetLabels &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(SetLabels &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(RemoveProperty &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(RemoveProperty &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(RemoveLabels &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(RemoveLabels &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(EdgeUniquenessFilter &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(EdgeUniquenessFilter &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Accumulate &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Accumulate &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Aggregate &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Aggregate &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Skip &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Skip &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Limit &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Limit &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(OrderBy &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(OrderBy &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Unwind &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Unwind &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Distinct &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(Distinct &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(CallProcedure &op) override {
prev_ops_.push_back(&op);
return true;
}
bool PostVisit(CallProcedure &) override {
prev_ops_.pop_back();
return true;
}
bool PreVisit(Foreach &op) override {
prev_ops_.push_back(&op);
return false;
}
bool PostVisit(Foreach &) override {
prev_ops_.pop_back();
return true;
}
std::shared_ptr<LogicalOperator> new_root_;
private:
SymbolTable *symbol_table_;
AstStorage *ast_storage_;
TDbAccessor *db_;
// Collected filters, pending for examination if they can be used for advanced
// lookup operations (by index, node ID, ...).
Filters filters_;
// Expressions which no longer need a plain Filter operator.
std::unordered_set<Expression *> filter_exprs_for_removal_;
std::vector<LogicalOperator *> prev_ops_;
struct LabelPropertyIndex {
LabelIx label;
// FilterInfo with PropertyFilter.
FilterInfo filter;
int64_t vertex_count;
};
bool DefaultPreVisit() override { throw utils::NotYetImplemented("optimizing index lookup"); }
void SetOnParent(const std::shared_ptr<LogicalOperator> &input) {
MG_ASSERT(input);
if (prev_ops_.empty()) {
MG_ASSERT(!new_root_);
new_root_ = input;
return;
}
prev_ops_.back()->set_input(input);
}
void RewriteBranch(std::shared_ptr<LogicalOperator> *branch) {
IndexLookupRewriter<TDbAccessor> rewriter(symbol_table_, ast_storage_, db_);
(*branch)->Accept(rewriter);
if (rewriter.new_root_) {
*branch = rewriter.new_root_;
}
}
storage::v3::LabelId GetLabel(LabelIx label) { return db_->NameToLabel(label.name); }
storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); }
std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) {
MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels.");
std::optional<LabelIx> best_label;
for (const auto &label : labels) {
if (!db_->LabelIndexExists(GetLabel(label))) continue;
if (!best_label) {
best_label = label;
continue;
}
if (db_->VerticesCount(GetLabel(label)) < db_->VerticesCount(GetLabel(*best_label))) best_label = label;
}
return best_label;
}
// Finds the label-property combination which has indexed the lowest amount of
// vertices. If the index cannot be found, nullopt is returned.
std::optional<LabelPropertyIndex> FindBestLabelPropertyIndex(const Symbol &symbol,
const std::unordered_set<Symbol> &bound_symbols) {
auto are_bound = [&bound_symbols](const auto &used_symbols) {
for (const auto &used_symbol : used_symbols) {
if (!utils::Contains(bound_symbols, used_symbol)) {
return false;
}
}
return true;
};
std::optional<LabelPropertyIndex> found;
for (const auto &label : filters_.FilteredLabels(symbol)) {
for (const auto &filter : filters_.PropertyFilters(symbol)) {
if (filter.property_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) {
// Skip filter expressions which use the symbol whose property we are
// looking up or aren't bound. We cannot scan by such expressions. For
// example, in `n.a = 2 + n.b` both sides of `=` refer to `n`, so we
// cannot scan `n` by property index.
continue;
}
const auto &property = filter.property_filter->property_;
if (!db_->LabelPropertyIndexExists(GetLabel(label), GetProperty(property))) {
continue;
}
int64_t vertex_count = db_->VerticesCount(GetLabel(label), GetProperty(property));
auto is_better_type = [&found](PropertyFilter::Type type) {
// Order the types by the most preferred index lookup type.
static const PropertyFilter::Type kFilterTypeOrder[] = {
PropertyFilter::Type::EQUAL, PropertyFilter::Type::RANGE, PropertyFilter::Type::REGEX_MATCH};
auto *found_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, found->filter.property_filter->type_);
auto *type_sort_ix = std::find(kFilterTypeOrder, kFilterTypeOrder + 3, type);
return type_sort_ix < found_sort_ix;
};
if (!found || vertex_count < found->vertex_count ||
(vertex_count == found->vertex_count && is_better_type(filter.property_filter->type_))) {
found = LabelPropertyIndex{label, filter, vertex_count};
}
}
}
return found;
}
// Creates a ScanAll by the best possible index for the `node_symbol`. Best
// index is defined as the index with least number of vertices. If the node
// does not have at least a label, no indexed lookup can be created and
// `nullptr` is returned. The operator is chained after `input`. Optional
// `max_vertex_count` controls, whether no operator should be created if the
// vertex count in the best index exceeds this number. In such a case,
// `nullptr` is returned and `input` is not chained.
std::unique_ptr<ScanAll> GenScanByIndex(const ScanAll &scan,
const std::optional<int64_t> &max_vertex_count = std::nullopt) {
const auto &input = scan.input();
const auto &node_symbol = scan.output_symbol_;
const auto &view = scan.view_;
const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_);
std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end());
auto are_bound = [&bound_symbols](const auto &used_symbols) {
for (const auto &used_symbol : used_symbols) {
if (!utils::Contains(bound_symbols, used_symbol)) {
return false;
}
}
return true;
};
// First, try to see if we can find a vertex by ID.
if (!max_vertex_count || *max_vertex_count >= 1) {
for (const auto &filter : filters_.IdFilters(node_symbol)) {
if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue;
auto *value = filter.id_filter->value_;
filter_exprs_for_removal_.insert(filter.expression);
filters_.EraseFilter(filter);
return std::make_unique<ScanAllById>(input, node_symbol, value, view);
}
}
// Now try to see if we can use label+property index. If not, try to use
// just the label index.
const auto labels = filters_.FilteredLabels(node_symbol);
if (labels.empty()) {
// Without labels, we cannot generate any indexed ScanAll.
return nullptr;
}
auto found_index = FindBestLabelPropertyIndex(node_symbol, bound_symbols);
if (found_index &&
// Use label+property index if we satisfy max_vertex_count.
(!max_vertex_count || *max_vertex_count >= found_index->vertex_count)) {
// Copy the property filter and then erase it from filters.
const auto prop_filter = *found_index->filter.property_filter;
if (prop_filter.type_ != PropertyFilter::Type::REGEX_MATCH) {
// Remove the original expression from Filter operation only if it's not
// a regex match. In such a case we need to perform the matching even
// after we've scanned the index.
filter_exprs_for_removal_.insert(found_index->filter.expression);
}
filters_.EraseFilter(found_index->filter);
std::vector<Expression *> removed_expressions;
filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions);
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
if (prop_filter.lower_bound_ || prop_filter.upper_bound_) {
return std::make_unique<ScanAllByLabelPropertyRange>(
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, prop_filter.lower_bound_, prop_filter.upper_bound_, view);
} else if (prop_filter.type_ == PropertyFilter::Type::REGEX_MATCH) {
// Generate index scan using the empty string as a lower bound.
Expression *empty_string = ast_storage_->Create<PrimitiveLiteral>("");
auto lower_bound = utils::MakeBoundInclusive(empty_string);
return std::make_unique<ScanAllByLabelPropertyRange>(
input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, std::make_optional(lower_bound), std::nullopt, view);
} else if (prop_filter.type_ == PropertyFilter::Type::IN) {
// TODO(buda): ScanAllByLabelProperty + Filter should be considered
// here once the operator and the right cardinality estimation exist.
auto const &symbol = symbol_table_->CreateAnonymousSymbol();
auto *expression = ast_storage_->Create<Identifier>(symbol.name_);
expression->MapTo(symbol);
auto unwind_operator = std::make_unique<Unwind>(input, prop_filter.value_, symbol);
return std::make_unique<ScanAllByLabelPropertyValue>(
std::move(unwind_operator), node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_),
prop_filter.property_.name, expression, view);
} else if (prop_filter.type_ == PropertyFilter::Type::IS_NOT_NULL) {
return std::make_unique<ScanAllByLabelProperty>(input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_), prop_filter.property_.name,
view);
} else {
MG_ASSERT(prop_filter.value_, "Property filter should either have bounds or a value expression.");
return std::make_unique<ScanAllByLabelPropertyValue>(input, node_symbol, GetLabel(found_index->label),
GetProperty(prop_filter.property_),
prop_filter.property_.name, prop_filter.value_, view);
}
}
auto maybe_label = FindBestLabelIndex(labels);
if (!maybe_label) return nullptr;
const auto &label = *maybe_label;
if (max_vertex_count && db_->VerticesCount(GetLabel(label)) > *max_vertex_count) {
// Don't create an indexed lookup, since we have more labeled vertices
// than the allowed count.
return nullptr;
}
std::vector<Expression *> removed_expressions;
filters_.EraseLabelFilter(node_symbol, label, &removed_expressions);
filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end());
return std::make_unique<ScanAllByLabel>(input, node_symbol, GetLabel(label), view);
}
};
} // namespace impl
template <class TDbAccessor>
std::unique_ptr<LogicalOperator> RewriteWithIndexLookup(std::unique_ptr<LogicalOperator> root_op,
SymbolTable *symbol_table, AstStorage *ast_storage,
TDbAccessor *db) {
impl::IndexLookupRewriter<TDbAccessor> rewriter(symbol_table, ast_storage, db);
root_op->Accept(rewriter);
if (rewriter.new_root_) {
// This shouldn't happen in real use case, because IndexLookupRewriter
// removes Filter operations and they cannot be the root op. In case we
// somehow missed this, raise NotYetImplemented instead of MG_ASSERT
// crashing the application.
throw utils::NotYetImplemented("optimizing index lookup");
}
return root_op;
}
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,594 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/rule_based_planner.hpp"
#include <algorithm>
#include <functional>
#include <limits>
#include <stack>
#include <unordered_set>
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2::plan {
namespace {
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols, const FilterInfo &filter) {
for (const auto &symbol : filter.used_symbols) {
if (bound_symbols.find(symbol) == bound_symbols.end()) {
return false;
}
}
return true;
}
// Ast tree visitor which collects the context for a return body.
// The return body of WITH and RETURN clauses consists of:
//
// * named expressions (used to produce results);
// * flag whether the results need to be DISTINCT;
// * optional SKIP expression;
// * optional LIMIT expression and
// * optional ORDER BY expressions.
//
// In addition to the above, we collect information on used symbols,
// aggregations and expressions used for group by.
class ReturnBodyContext : public HierarchicalTreeVisitor {
public:
ReturnBodyContext(const ReturnBody &body, SymbolTable &symbol_table, const std::unordered_set<Symbol> &bound_symbols,
AstStorage &storage, Where *where = nullptr)
: body_(body), symbol_table_(symbol_table), bound_symbols_(bound_symbols), storage_(storage), where_(where) {
// Collect symbols from named expressions.
output_symbols_.reserve(body_.named_expressions.size());
if (body.all_identifiers) {
// Expand '*' to expressions and symbols first, so that their results come
// before regular named expressions.
ExpandUserSymbols();
}
for (auto &named_expr : body_.named_expressions) {
output_symbols_.emplace_back(symbol_table_.at(*named_expr));
named_expr->Accept(*this);
named_expressions_.emplace_back(named_expr);
}
// Collect symbols used in group by expressions.
if (!aggregations_.empty()) {
UsedSymbolsCollector collector(symbol_table_);
for (auto &group_by : group_by_) {
group_by->Accept(collector);
}
group_by_used_symbols_ = collector.symbols_;
}
if (aggregations_.empty()) {
// Visit order_by and where if we do not have aggregations. This way we
// prevent collecting group_by expressions from order_by and where, which
// would be very wrong. When we have aggregation, order_by and where can
// only use new symbols (ensured in semantic analysis), so we don't care
// about collecting used_symbols. Also, semantic analysis should
// have prevented any aggregations from appearing here.
for (const auto &order_pair : body.order_by) {
order_pair.expression->Accept(*this);
}
if (where) {
where->Accept(*this);
}
MG_ASSERT(aggregations_.empty(), "Unexpected aggregations in ORDER BY or WHERE");
}
}
using HierarchicalTreeVisitor::PostVisit;
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::Visit;
bool Visit(PrimitiveLiteral &) override {
has_aggregation_.emplace_back(false);
return true;
}
private:
template <typename TLiteral, typename TIteratorToExpression>
void PostVisitCollectionLiteral(TLiteral &literal, TIteratorToExpression iterator_to_expression) {
// If there is an aggregation in the list, and there are group-bys, then we
// need to add the group-bys manually. If there are no aggregations, the
// whole list will be added as a group-by.
std::vector<Expression *> literal_group_by;
bool has_aggr = false;
auto it = has_aggregation_.end();
auto elements_it = literal.elements_.begin();
std::advance(it, -literal.elements_.size());
while (it != has_aggregation_.end()) {
if (*it) {
has_aggr = true;
} else {
literal_group_by.emplace_back(iterator_to_expression(elements_it));
}
elements_it++;
it = has_aggregation_.erase(it);
}
has_aggregation_.emplace_back(has_aggr);
if (has_aggr) {
for (auto expression_ptr : literal_group_by) group_by_.emplace_back(expression_ptr);
}
}
public:
bool PostVisit(ListLiteral &list_literal) override {
MG_ASSERT(list_literal.elements_.size() <= has_aggregation_.size(),
"Expected as many has_aggregation_ flags as there are list"
"elements.");
PostVisitCollectionLiteral(list_literal, [](auto it) { return *it; });
return true;
}
bool PostVisit(MapLiteral &map_literal) override {
MG_ASSERT(map_literal.elements_.size() <= has_aggregation_.size(),
"Expected has_aggregation_ flags as much as there are map elements.");
PostVisitCollectionLiteral(map_literal, [](auto it) { return it->second; });
return true;
}
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*all.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ALL arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Single &single) override {
// Remove the symbol which is bound by single, because we are only
// interested in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*single.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for SINGLE arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by any, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*any.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for ANY arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(None &none) override {
// Remove the symbol which is bound by none, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*none.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for NONE arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*reduce.accumulator_));
used_symbols_.erase(symbol_table_.at(*reduce.identifier_));
MG_ASSERT(has_aggregation_.size() >= 5U, "Expected 5 has_aggregation_ flags for REDUCE arguments");
bool has_aggr = false;
for (int i = 0; i < 5; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Coalesce &coalesce) override {
MG_ASSERT(has_aggregation_.size() >= coalesce.expressions_.size(),
"Expected >= {} has_aggregation_ flags for COALESCE arguments", has_aggregation_.size());
bool has_aggr = false;
for (size_t i = 0; i < coalesce.expressions_.size(); ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Extract &extract) override {
// Remove the symbol bound by extract, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*extract.identifier_));
MG_ASSERT(has_aggregation_.size() >= 3U, "Expected 3 has_aggregation_ flags for EXTRACT arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
if (!utils::Contains(output_symbols_, symbol)) {
// Don't pick up new symbols, even though they may be used in ORDER BY or
// WHERE.
used_symbols_.insert(symbol);
}
has_aggregation_.emplace_back(false);
return true;
}
bool PreVisit(ListSlicingOperator &list_slicing) override {
list_slicing.list_->Accept(*this);
bool list_has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
bool has_aggr = list_has_aggr;
if (list_slicing.lower_bound_) {
list_slicing.lower_bound_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
if (list_slicing.upper_bound_) {
list_slicing.upper_bound_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
if (has_aggr && !list_has_aggr) {
// We need to group by the list expression, because it didn't have an
// aggregation inside.
group_by_.emplace_back(list_slicing.list_);
}
has_aggregation_.emplace_back(has_aggr);
return false;
}
bool PreVisit(IfOperator &if_operator) override {
if_operator.condition_->Accept(*this);
bool has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
if_operator.then_expression_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
if_operator.else_expression_->Accept(*this);
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
has_aggregation_.emplace_back(has_aggr);
// TODO: Once we allow aggregations here, insert appropriate stuff in
// group_by.
MG_ASSERT(!has_aggr, "Currently aggregations in CASE are not allowed");
return false;
}
bool PostVisit(Function &function) override {
MG_ASSERT(function.arguments_.size() <= has_aggregation_.size(),
"Expected as many has_aggregation_ flags as there are"
"function arguments.");
bool has_aggr = false;
auto it = has_aggregation_.end();
std::advance(it, -function.arguments_.size());
while (it != has_aggregation_.end()) {
has_aggr = has_aggr || *it;
it = has_aggregation_.erase(it);
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
#define VISIT_BINARY_OPERATOR(BinaryOperator) \
bool PostVisit(BinaryOperator &op) override { \
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected at least 2 has_aggregation_ flags."); \
/* has_aggregation_ stack is reversed, last result is from the 2nd */ \
/* expression. */ \
bool aggr2 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool aggr1 = has_aggregation_.back(); \
has_aggregation_.pop_back(); \
bool has_aggr = aggr1 || aggr2; \
if (has_aggr && !(aggr1 && aggr2)) { \
/* Group by the expression which does not contain aggregation. */ \
/* Possible optimization is to ignore constant value expressions */ \
group_by_.emplace_back(aggr1 ? op.expression2_ : op.expression1_); \
} \
/* Propagate that this whole expression may contain an aggregation. */ \
has_aggregation_.emplace_back(has_aggr); \
return true; \
}
VISIT_BINARY_OPERATOR(OrOperator)
VISIT_BINARY_OPERATOR(XorOperator)
VISIT_BINARY_OPERATOR(AndOperator)
VISIT_BINARY_OPERATOR(AdditionOperator)
VISIT_BINARY_OPERATOR(SubtractionOperator)
VISIT_BINARY_OPERATOR(MultiplicationOperator)
VISIT_BINARY_OPERATOR(DivisionOperator)
VISIT_BINARY_OPERATOR(ModOperator)
VISIT_BINARY_OPERATOR(NotEqualOperator)
VISIT_BINARY_OPERATOR(EqualOperator)
VISIT_BINARY_OPERATOR(LessOperator)
VISIT_BINARY_OPERATOR(GreaterOperator)
VISIT_BINARY_OPERATOR(LessEqualOperator)
VISIT_BINARY_OPERATOR(GreaterEqualOperator)
VISIT_BINARY_OPERATOR(InListOperator)
VISIT_BINARY_OPERATOR(SubscriptOperator)
#undef VISIT_BINARY_OPERATOR
bool PostVisit(Aggregation &aggr) override {
// Aggregation contains a virtual symbol, where the result will be stored.
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol});
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
// two expressions, so we can have 0, 1 or 2 elements on the
// has_aggregation_stack for this Aggregation expression.
if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back();
if (aggr.expression1_)
has_aggregation_.back() = true;
else
has_aggregation_.emplace_back(true);
// Possible optimization is to skip remembering symbols inside aggregation.
// If and when implementing this, don't forget that Accumulate needs *all*
// the symbols, including those inside aggregation.
return true;
}
bool PostVisit(NamedExpression &named_expr) override {
MG_ASSERT(has_aggregation_.size() == 1U, "Expected to reduce has_aggregation_ to single boolean.");
if (!has_aggregation_.back()) {
group_by_.emplace_back(named_expr.expression_);
}
has_aggregation_.pop_back();
return true;
}
bool Visit(ParameterLookup &) override {
has_aggregation_.emplace_back(false);
return true;
}
bool PostVisit(RegexMatch &regex_match) override {
MG_ASSERT(has_aggregation_.size() >= 2U, "Expected 2 has_aggregation_ flags for RegexMatch arguments");
bool has_aggr = has_aggregation_.back();
has_aggregation_.pop_back();
has_aggregation_.back() |= has_aggr;
return true;
}
// Creates NamedExpression with an Identifier for each user declared symbol.
// This should be used when body.all_identifiers is true, to generate
// expressions for Produce operator.
void ExpandUserSymbols() {
MG_ASSERT(named_expressions_.empty(), "ExpandUserSymbols should be first to fill named_expressions_");
MG_ASSERT(output_symbols_.empty(), "ExpandUserSymbols should be first to fill output_symbols_");
for (const auto &symbol : bound_symbols_) {
if (!symbol.user_declared()) {
continue;
}
auto *ident = storage_.Create<Identifier>(symbol.name())->MapTo(symbol);
auto *named_expr = storage_.Create<NamedExpression>(symbol.name(), ident)->MapTo(symbol);
// Fill output expressions and symbols with expanded identifiers.
named_expressions_.emplace_back(named_expr);
output_symbols_.emplace_back(symbol);
used_symbols_.insert(symbol);
// Don't forget to group by expanded identifiers.
group_by_.emplace_back(ident);
}
// Cypher RETURN/WITH * expects to expand '*' sorted by name.
std::sort(output_symbols_.begin(), output_symbols_.end(),
[](const auto &a, const auto &b) { return a.name() < b.name(); });
std::sort(named_expressions_.begin(), named_expressions_.end(),
[](const auto &a, const auto &b) { return a->name_ < b->name_; });
}
// If true, results need to be distinct.
bool distinct() const { return body_.distinct; }
// Named expressions which are used to produce results.
const auto &named_expressions() const { return named_expressions_; }
// Pairs of (Ordering, Expression *) for sorting results.
const auto &order_by() const { return body_.order_by; }
// Optional expression which determines how many results to skip.
auto *skip() const { return body_.skip; }
// Optional expression which determines how many results to produce.
auto *limit() const { return body_.limit; }
// Optional Where clause for filtering.
const auto *where() const { return where_; }
// Set of symbols used inside the visited expressions, including the inside of
// aggregation expression. These only includes old symbols, even though new
// ones may have been used in ORDER BY or WHERE.
const auto &used_symbols() const { return used_symbols_; }
// List of aggregation elements found in expressions.
const auto &aggregations() const { return aggregations_; }
// When there is at least one aggregation element, all the non-aggregate (sub)
// expressions are used for grouping. For example, in `WITH sum(n.a) + 2 * n.b
// AS sum, n.c AS nc`, we will group by `2 * n.b` and `n.c`.
const auto &group_by() const { return group_by_; }
// Set of symbols used in group by expressions.
const auto &group_by_used_symbols() const { return group_by_used_symbols_; }
// All symbols generated by named expressions. They are collected in order of
// named_expressions.
const auto &output_symbols() const { return output_symbols_; }
private:
const ReturnBody &body_;
SymbolTable &symbol_table_;
const std::unordered_set<Symbol> &bound_symbols_;
AstStorage &storage_;
const Where *const where_ = nullptr;
std::unordered_set<Symbol> used_symbols_;
std::vector<Symbol> output_symbols_;
std::vector<Aggregate::Element> aggregations_;
std::vector<Expression *> group_by_;
std::unordered_set<Symbol> group_by_used_symbols_;
// Flag stack indicating whether an expression contains an aggregation. A
// stack is needed so that we differentiate the case where a child
// sub-expression has an aggregation, while the other child doesn't. For
// example AST, (+ (sum x) y)
// * (sum x) -- Has an aggregation.
// * y -- Doesn't, we need to group by this.
// * (+ (sum x) y) -- The whole expression has an aggregation, so we don't
// group by it.
std::list<bool> has_aggregation_;
std::vector<NamedExpression *> named_expressions_;
};
std::unique_ptr<LogicalOperator> GenReturnBody(std::unique_ptr<LogicalOperator> input_op, bool advance_command,
const ReturnBodyContext &body, bool accumulate = false) {
std::vector<Symbol> used_symbols(body.used_symbols().begin(), body.used_symbols().end());
auto last_op = std::move(input_op);
if (accumulate) {
// We only advance the command in Accumulate. This is done for WITH clause,
// when the first part updated the database. RETURN clause may only need an
// accumulation after updates, without advancing the command.
last_op = std::make_unique<Accumulate>(std::move(last_op), used_symbols, advance_command);
}
if (!body.aggregations().empty()) {
// When we have aggregation, SKIP/LIMIT should always come after it.
std::vector<Symbol> remember(body.group_by_used_symbols().begin(), body.group_by_used_symbols().end());
last_op = std::make_unique<Aggregate>(std::move(last_op), body.aggregations(), body.group_by(), remember);
}
last_op = std::make_unique<Produce>(std::move(last_op), body.named_expressions());
// Distinct in ReturnBody only makes Produce values unique, so plan after it.
if (body.distinct()) {
last_op = std::make_unique<Distinct>(std::move(last_op), body.output_symbols());
}
// Like Where, OrderBy can read from symbols established by named expressions
// in Produce, so it must come after it.
if (!body.order_by().empty()) {
last_op = std::make_unique<OrderBy>(std::move(last_op), body.order_by(), body.output_symbols());
}
// Finally, Skip and Limit must come after OrderBy.
if (body.skip()) {
last_op = std::make_unique<Skip>(std::move(last_op), body.skip());
}
// Limit is always after Skip.
if (body.limit()) {
last_op = std::make_unique<Limit>(std::move(last_op), body.limit());
}
// Where may see new symbols so it comes after we generate Produce and in
// general, comes after any OrderBy, Skip or Limit.
if (body.where()) {
last_op = std::make_unique<Filter>(std::move(last_op), body.where()->expression_);
}
return last_op;
}
} // namespace
namespace impl {
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols, Filters &filters, AstStorage &storage) {
Expression *filter_expr = nullptr;
for (auto filters_it = filters.begin(); filters_it != filters.end();) {
if (HasBoundFilterSymbols(bound_symbols, *filters_it)) {
filter_expr = impl::BoolJoin<AndOperator>(storage, filter_expr, filters_it->expression);
filters_it = filters.erase(filters_it);
} else {
filters_it++;
}
}
return filter_expr;
}
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator> last_op,
const std::unordered_set<Symbol> &bound_symbols, Filters &filters,
AstStorage &storage) {
auto *filter_expr = ExtractFilters(bound_symbols, filters, storage);
if (filter_expr) {
last_op = std::make_unique<Filter>(std::move(last_op), filter_expr);
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths) {
auto all_are_bound = [&bound_symbols](const std::vector<Symbol> &syms) {
for (const auto &sym : syms)
if (bound_symbols.find(sym) == bound_symbols.end()) return false;
return true;
};
for (auto named_path_it = named_paths.begin(); named_path_it != named_paths.end();) {
if (all_are_bound(named_path_it->second)) {
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), named_path_it->first,
std::move(named_path_it->second));
bound_symbols.insert(named_path_it->first);
named_path_it = named_paths.erase(named_path_it);
} else {
++named_path_it;
}
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
// Similar to WITH clause, but we want to accumulate when the query writes to
// the database. This way we handle the case when we want to return
// expressions with the latest updated results. For example, `MATCH (n) -- ()
// SET n.prop = n.prop + 1 RETURN n.prop`. If we match same `n` multiple 'k'
// times, we want to return 'k' results where the property value is the same,
// final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
return GenReturnBody(std::move(input_op), advance_command, body, accumulate);
}
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
// RETURN clause.
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage, with.where_);
auto last_op = GenReturnBody(std::move(input_op), advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (const auto &symbol : body.output_symbols()) {
bound_symbols.insert(symbol);
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table) {
return std::make_unique<Union>(left_op, right_op, cypher_union.union_symbols_, left_op->OutputSymbols(symbol_table),
right_op->OutputSymbols(symbol_table));
}
} // namespace impl
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,561 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <optional>
#include <variant>
#include "gflags/gflags.h"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/frontend/ast/ast_visitor.hpp"
#include "query/v2/plan/operator.hpp"
#include "query/v2/plan/preprocess.hpp"
#include "utils/logging.hpp"
#include "utils/typeinfo.hpp"
namespace memgraph::query::v2::plan {
/// @brief Context which contains variables commonly used during planning.
template <class TDbAccessor>
struct PlanningContext {
/// @brief SymbolTable is used to determine inputs and outputs of planned
/// operators.
///
/// Newly created AST nodes may be added to reference existing symbols.
SymbolTable *symbol_table{nullptr};
/// @brief The storage is used to create new AST nodes for use in operators.
AstStorage *ast_storage{nullptr};
/// @brief Cypher query to be planned
CypherQuery *query{nullptr};
/// @brief TDbAccessor, which may be used to get some information from the
/// database to generate better plans. The accessor is required only to live
/// long enough for the plan generation to finish.
TDbAccessor *db{nullptr};
/// @brief Symbol set is used to differentiate cycles in pattern matching.
/// During planning, symbols will be added as each operator produces values
/// for them. This way, the operator can be correctly initialized whether to
/// read a symbol or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and
/// write) the first `n`, but the latter `n` would only read the already
/// written information.
std::unordered_set<Symbol> bound_symbols{};
};
template <class TDbAccessor>
auto MakePlanningContext(AstStorage *ast_storage, SymbolTable *symbol_table, CypherQuery *query, TDbAccessor *db) {
return PlanningContext<TDbAccessor>{symbol_table, ast_storage, query, db};
}
// Contextual information used for generating match operators.
struct MatchContext {
const Matching &matching;
const SymbolTable &symbol_table;
// Already bound symbols, which are used to determine whether the operator
// should reference them or establish new. This is both read from and written
// to during generation.
std::unordered_set<Symbol> &bound_symbols;
// Determines whether the match should see the new graph state or not.
storage::v3::View view = storage::v3::View::OLD;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols{};
};
namespace impl {
// These functions are an internal implementation of RuleBasedPlanner. To avoid
// writing the whole code inline in this header file, they are declared here and
// defined in the cpp file.
// Iterates over `Filters` joining them in one expression via
// `AndOperator` if symbols they use are bound.. All the joined filters are
// removed from `Filters`.
Expression *ExtractFilters(const std::unordered_set<Symbol> &, Filters &, AstStorage &);
std::unique_ptr<LogicalOperator> GenFilters(std::unique_ptr<LogicalOperator>, const std::unordered_set<Symbol> &,
Filters &, AstStorage &);
/// Utility function for iterating pattern atoms and accumulating a result.
///
/// Each pattern is of the form `NodeAtom (, EdgeAtom, NodeAtom)*`. Therefore,
/// the `base` function is called on the first `NodeAtom`, while the `collect`
/// is called for the whole triplet. Result of the function is passed to the
/// next call. Final result is returned.
///
/// Example usage of counting edge atoms in the pattern.
///
/// auto base = [](NodeAtom *first_node) { return 0; };
/// auto collect = [](int accum, NodeAtom *prev_node, EdgeAtom *edge,
/// NodeAtom *node) {
/// return accum + 1;
/// };
/// int edge_count = ReducePattern<int>(pattern, base, collect);
///
// TODO: It might be a good idea to move this somewhere else, for easier usage
// in other files.
template <typename T>
auto ReducePattern(Pattern &pattern, std::function<T(NodeAtom *)> base,
std::function<T(T, NodeAtom *, EdgeAtom *, NodeAtom *)> collect) {
MG_ASSERT(!pattern.atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern.atoms_.begin();
auto current_node = utils::Downcast<NodeAtom>(*atoms_it++);
MG_ASSERT(current_node, "First pattern atom is not a node");
auto last_res = base(current_node);
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
while (atoms_it != pattern.atoms_.end()) {
auto edge = utils::Downcast<EdgeAtom>(*atoms_it++);
MG_ASSERT(edge, "Expected an edge atom in pattern.");
MG_ASSERT(atoms_it != pattern.atoms_.end(), "Edge atom should not end the pattern.");
auto prev_node = current_node;
current_node = utils::Downcast<NodeAtom>(*atoms_it++);
MG_ASSERT(current_node, "Expected a node atom in pattern.");
last_res = collect(std::move(last_res), prev_node, edge, current_node);
}
return last_res;
}
// For all given `named_paths` checks if all its symbols have been bound.
// If so, it creates a logical operator for named path generation, binds its
// symbol, removes that path from the collection of unhandled ones and returns
// the new op. Otherwise, returns `last_op`.
std::unique_ptr<LogicalOperator> GenNamedPaths(std::unique_ptr<LogicalOperator> last_op,
std::unordered_set<Symbol> &bound_symbols,
std::unordered_map<Symbol, std::vector<Symbol>> &named_paths);
std::unique_ptr<LogicalOperator> GenReturn(Return &ret, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenWith(With &with, std::unique_ptr<LogicalOperator> input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols, AstStorage &storage);
std::unique_ptr<LogicalOperator> GenUnion(const CypherUnion &cypher_union, std::shared_ptr<LogicalOperator> left_op,
std::shared_ptr<LogicalOperator> right_op, SymbolTable &symbol_table);
template <class TBoolOperator>
Expression *BoolJoin(AstStorage &storage, Expression *expr1, Expression *expr2) {
if (expr1 && expr2) {
return storage.Create<TBoolOperator>(expr1, expr2);
}
return expr1 ? expr1 : expr2;
}
} // namespace impl
/// @brief Planner which uses hardcoded rules to produce operators.
///
/// @sa MakeLogicalPlan
template <class TPlanningContext>
class RuleBasedPlanner {
public:
explicit RuleBasedPlanner(TPlanningContext *context) : context_(context) {}
/// @brief The result of plan generation is the root of the generated operator
/// tree.
using PlanResult = std::unique_ptr<LogicalOperator>;
/// @brief Generates the operator tree based on explicitly set rules.
PlanResult Plan(const std::vector<SingleQueryPart> &query_parts) {
auto &context = *context_;
std::unique_ptr<LogicalOperator> input_op;
// Set to true if a query command writes to the database.
bool is_write = false;
for (const auto &query_part : query_parts) {
MatchContext match_ctx{query_part.matching, *context.symbol_table, context.bound_symbols};
input_op = PlanMatching(match_ctx, std::move(input_op));
for (const auto &matching : query_part.optional_matching) {
MatchContext opt_ctx{matching, *context.symbol_table, context.bound_symbols};
auto match_op = PlanMatching(opt_ctx, nullptr);
if (match_op) {
input_op = std::make_unique<Optional>(std::move(input_op), std::move(match_op), opt_ctx.new_symbols);
}
}
uint64_t merge_id = 0;
for (auto *clause : query_part.remaining_clauses) {
MG_ASSERT(!utils::IsSubtype(*clause, Match::kType), "Unexpected Match in remaining clauses");
if (auto *ret = utils::Downcast<Return>(clause)) {
input_op = impl::GenReturn(*ret, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
} else if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) {
input_op = GenMerge(*merge, std::move(input_op), query_part.merge_matching[merge_id++]);
// Treat MERGE clause as write, because we do not know if it will
// create anything.
is_write = true;
} else if (auto *with = utils::Downcast<query::v2::With>(clause)) {
input_op = impl::GenWith(*with, std::move(input_op), *context.symbol_table, is_write, context.bound_symbols,
*context.ast_storage);
// WITH clause advances the command, so reset the flag.
is_write = false;
} else if (auto op = HandleWriteClause(clause, input_op, *context.symbol_table, context.bound_symbols)) {
is_write = true;
input_op = std::move(op);
} else if (auto *unwind = utils::Downcast<query::v2::Unwind>(clause)) {
const auto &symbol = context.symbol_table->at(*unwind->named_expression_);
context.bound_symbols.insert(symbol);
input_op =
std::make_unique<plan::Unwind>(std::move(input_op), unwind->named_expression_->expression_, symbol);
} else if (auto *call_proc = utils::Downcast<query::v2::CallProcedure>(clause)) {
std::vector<Symbol> result_symbols;
result_symbols.reserve(call_proc->result_identifiers_.size());
for (const auto *ident : call_proc->result_identifiers_) {
const auto &sym = context.symbol_table->at(*ident);
context.bound_symbols.insert(sym);
result_symbols.push_back(sym);
}
// TODO: When we add support for write and eager procedures, we will
// need to plan this operator with Accumulate and pass in
// storage::v3::View::NEW.
input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_, call_proc->is_write_);
} else if (auto *load_csv = utils::Downcast<query::v2::LoadCsv>(clause)) {
const auto &row_sym = context.symbol_table->at(*load_csv->row_var_);
context.bound_symbols.insert(row_sym);
input_op =
std::make_unique<plan::LoadCsv>(std::move(input_op), load_csv->file_, load_csv->with_header_,
load_csv->ignore_bad_, load_csv->delimiter_, load_csv->quote_, row_sym);
} else if (auto *foreach = utils::Downcast<query::v2::Foreach>(clause)) {
is_write = true;
input_op = HandleForeachClause(foreach, std::move(input_op), *context.symbol_table, context.bound_symbols,
query_part, merge_id);
} else {
throw utils::NotYetImplemented("clause '{}' conversion to operator(s)", clause->GetTypeInfo().name);
}
}
}
return input_op;
}
private:
TPlanningContext *context_;
storage::v3::LabelId GetLabel(LabelIx label) { return context_->db->NameToLabel(label.name); }
storage::v3::PropertyId GetProperty(PropertyIx prop) { return context_->db->NameToProperty(prop.name); }
storage::v3::EdgeTypeId GetEdgeType(EdgeTypeIx edge_type) { return context_->db->NameToEdgeType(edge_type.name); }
std::unique_ptr<LogicalOperator> GenCreate(Create &create, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto last_op = std::move(input_op);
for (auto pattern : create.patterns_) {
last_op = GenCreateForPattern(*pattern, std::move(last_op), symbol_table, bound_symbols);
}
return last_op;
}
std::unique_ptr<LogicalOperator> GenCreateForPattern(Pattern &pattern, std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto node_to_creation_info = [&](const NodeAtom &node) {
const auto &node_symbol = symbol_table.at(*node.identifier_);
std::vector<storage::v3::LabelId> labels;
labels.reserve(node.labels_.size());
for (const auto &label : node.labels_) {
labels.push_back(GetLabel(label));
}
auto properties = std::invoke([&]() -> std::variant<PropertiesMapList, ParameterLookup *> {
if (const auto *node_properties =
std::get_if<std::unordered_map<PropertyIx, Expression *>>(&node.properties_)) {
PropertiesMapList vector_props;
vector_props.reserve(node_properties->size());
for (const auto &kv : *node_properties) {
vector_props.push_back({GetProperty(kv.first), kv.second});
}
return std::move(vector_props);
}
return std::get<ParameterLookup *>(node.properties_);
});
return NodeCreationInfo{node_symbol, labels, properties};
};
auto base = [&](NodeAtom *node) -> std::unique_ptr<LogicalOperator> {
const auto &node_symbol = symbol_table.at(*node->identifier_);
if (bound_symbols.insert(node_symbol).second) {
auto node_info = node_to_creation_info(*node);
return std::make_unique<CreateNode>(std::move(input_op), node_info);
}
return std::move(input_op);
};
auto collect = [&](std::unique_ptr<LogicalOperator> last_op, NodeAtom *prev_node, EdgeAtom *edge, NodeAtom *node) {
// Store the symbol from the first node as the input to CreateExpand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand node was already bound, then we need to indicate this,
// so that CreateExpand only creates an edge.
bool node_existing = false;
if (!bound_symbols.insert(symbol_table.at(*node->identifier_)).second) {
node_existing = true;
}
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
if (!bound_symbols.insert(edge_symbol).second) {
LOG_FATAL("Symbols used for created edges cannot be redeclared.");
}
auto node_info = node_to_creation_info(*node);
auto properties = std::invoke([&]() -> std::variant<PropertiesMapList, ParameterLookup *> {
if (const auto *edge_properties =
std::get_if<std::unordered_map<PropertyIx, Expression *>>(&edge->properties_)) {
PropertiesMapList vector_props;
vector_props.reserve(edge_properties->size());
for (const auto &kv : *edge_properties) {
vector_props.push_back({GetProperty(kv.first), kv.second});
}
return std::move(vector_props);
}
return std::get<ParameterLookup *>(edge->properties_);
});
MG_ASSERT(edge->edge_types_.size() == 1, "Creating an edge with a single type should be required by syntax");
EdgeCreationInfo edge_info{edge_symbol, properties, GetEdgeType(edge->edge_types_[0]), edge->direction_};
return std::make_unique<CreateExpand>(node_info, edge_info, std::move(last_op), input_symbol, node_existing);
};
auto last_op = impl::ReducePattern<std::unique_ptr<LogicalOperator>>(pattern, base, collect);
// If the pattern is named, append the path constructing logical operator.
if (pattern.identifier_->user_declared_) {
std::vector<Symbol> path_elements;
for (const PatternAtom *atom : pattern.atoms_) path_elements.emplace_back(symbol_table.at(*atom->identifier_));
last_op = std::make_unique<ConstructNamedPath>(std::move(last_op), symbol_table.at(*pattern.identifier_),
path_elements);
}
return last_op;
}
// Generate an operator for a clause which writes to the database. Ownership
// of input_op is transferred to the newly created operator. If the clause
// isn't handled, returns nullptr and input_op is left as is.
std::unique_ptr<LogicalOperator> HandleWriteClause(Clause *clause, std::unique_ptr<LogicalOperator> &input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
if (auto *create = utils::Downcast<Create>(clause)) {
return GenCreate(*create, std::move(input_op), symbol_table, bound_symbols);
} else if (auto *del = utils::Downcast<query::v2::Delete>(clause)) {
return std::make_unique<plan::Delete>(std::move(input_op), del->expressions_, del->detach_);
} else if (auto *set = utils::Downcast<query::v2::SetProperty>(clause)) {
return std::make_unique<plan::SetProperty>(std::move(input_op), GetProperty(set->property_lookup_->property_),
set->property_lookup_, set->expression_);
} else if (auto *set = utils::Downcast<query::v2::SetProperties>(clause)) {
auto op = set->update_ ? plan::SetProperties::Op::UPDATE : plan::SetProperties::Op::REPLACE;
const auto &input_symbol = symbol_table.at(*set->identifier_);
return std::make_unique<plan::SetProperties>(std::move(input_op), input_symbol, set->expression_, op);
} else if (auto *set = utils::Downcast<query::v2::SetLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
std::vector<storage::v3::LabelId> labels;
labels.reserve(set->labels_.size());
for (const auto &label : set->labels_) {
labels.push_back(GetLabel(label));
}
return std::make_unique<plan::SetLabels>(std::move(input_op), input_symbol, labels);
} else if (auto *rem = utils::Downcast<query::v2::RemoveProperty>(clause)) {
return std::make_unique<plan::RemoveProperty>(std::move(input_op), GetProperty(rem->property_lookup_->property_),
rem->property_lookup_);
} else if (auto *rem = utils::Downcast<query::v2::RemoveLabels>(clause)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
std::vector<storage::v3::LabelId> labels;
labels.reserve(rem->labels_.size());
for (const auto &label : rem->labels_) {
labels.push_back(GetLabel(label));
}
return std::make_unique<plan::RemoveLabels>(std::move(input_op), input_symbol, labels);
}
return nullptr;
}
std::unique_ptr<LogicalOperator> PlanMatching(MatchContext &match_context,
std::unique_ptr<LogicalOperator> input_op) {
auto &bound_symbols = match_context.bound_symbols;
auto &storage = *context_->ast_storage;
const auto &symbol_table = match_context.symbol_table;
const auto &matching = match_context.matching;
// Copy filters, because we will modify them as we generate Filters.
auto filters = matching.filters;
// Copy the named_paths for the same reason.
auto named_paths = matching.named_paths;
// Try to generate any filters even before the 1st match operator. This
// optimizes the optional match which filters only on symbols bound in
// regular match.
auto last_op = impl::GenFilters(std::move(input_op), bound_symbols, filters, storage);
for (const auto &expansion : matching.expansions) {
const auto &node1_symbol = symbol_table.at(*expansion.node1->identifier_);
if (bound_symbols.insert(node1_symbol).second) {
// We have just bound this symbol, so generate ScanAll which fills it.
last_op = std::make_unique<ScanAll>(std::move(last_op), node1_symbol, match_context.view);
match_context.new_symbols.emplace_back(node1_symbol);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
}
// We have an edge, so generate Expand.
if (expansion.edge) {
auto *edge = expansion.edge;
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol = symbol_table.at(*expansion.node2->identifier_);
auto existing_node = utils::Contains(bound_symbols, node_symbol);
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
MG_ASSERT(!utils::Contains(bound_symbols, edge_symbol), "Existing edges are not supported");
std::vector<storage::v3::EdgeTypeId> edge_types;
edge_types.reserve(edge->edge_types_.size());
for (const auto &type : edge->edge_types_) {
edge_types.push_back(GetEdgeType(type));
}
if (edge->IsVariable()) {
std::optional<ExpansionLambda> weight_lambda;
std::optional<Symbol> total_weight;
if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) {
weight_lambda.emplace(ExpansionLambda{symbol_table.at(*edge->weight_lambda_.inner_edge),
symbol_table.at(*edge->weight_lambda_.inner_node),
edge->weight_lambda_.expression});
total_weight.emplace(symbol_table.at(*edge->total_weight_));
}
ExpansionLambda filter_lambda;
filter_lambda.inner_edge_symbol = symbol_table.at(*edge->filter_lambda_.inner_edge);
filter_lambda.inner_node_symbol = symbol_table.at(*edge->filter_lambda_.inner_node);
{
// Bind the inner edge and node symbols so they're available for
// inline filtering in ExpandVariable.
bool inner_edge_bound = bound_symbols.insert(filter_lambda.inner_edge_symbol).second;
bool inner_node_bound = bound_symbols.insert(filter_lambda.inner_node_symbol).second;
MG_ASSERT(inner_edge_bound && inner_node_bound, "An inner edge and node can't be bound from before");
}
// Join regular filters with lambda filter expression, so that they
// are done inline together. Semantic analysis should guarantee that
// lambda filtering uses bound symbols.
filter_lambda.expression = impl::BoolJoin<AndOperator>(
storage, impl::ExtractFilters(bound_symbols, filters, storage), edge->filter_lambda_.expression);
// At this point it's possible we have leftover filters for inline
// filtering (they use the inner symbols. If they were not collected,
// we have to remove them manually because no other filter-extraction
// will ever bind them again.
filters.erase(std::remove_if(
filters.begin(), filters.end(),
[e = filter_lambda.inner_edge_symbol, n = filter_lambda.inner_node_symbol](FilterInfo &fi) {
return utils::Contains(fi.used_symbols, e) || utils::Contains(fi.used_symbols, n);
}),
filters.end());
// Unbind the temporarily bound inner symbols for filtering.
bound_symbols.erase(filter_lambda.inner_edge_symbol);
bound_symbols.erase(filter_lambda.inner_node_symbol);
if (total_weight) {
bound_symbols.insert(*total_weight);
}
// TODO: Pass weight lambda.
MG_ASSERT(match_context.view == storage::v3::View::OLD,
"ExpandVariable should only be planned with storage::v3::View::OLD");
last_op = std::make_unique<ExpandVariable>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
edge->type_, expansion.direction, edge_types, expansion.is_flipped,
edge->lower_bound_, edge->upper_bound_, existing_node,
filter_lambda, weight_lambda, total_weight);
} else {
last_op = std::make_unique<Expand>(std::move(last_op), node1_symbol, node_symbol, edge_symbol,
expansion.direction, edge_types, existing_node, match_context.view);
}
// Bind the expanded edge and node.
bound_symbols.insert(edge_symbol);
match_context.new_symbols.emplace_back(edge_symbol);
if (bound_symbols.insert(node_symbol).second) {
match_context.new_symbols.emplace_back(node_symbol);
}
// Ensure Cyphermorphism (different edge symbols always map to
// different edges).
for (const auto &edge_symbols : matching.edge_symbols) {
if (edge_symbols.find(edge_symbol) == edge_symbols.end()) {
continue;
}
std::vector<Symbol> other_symbols;
for (const auto &symbol : edge_symbols) {
if (symbol == edge_symbol || bound_symbols.find(symbol) == bound_symbols.end()) {
continue;
}
other_symbols.push_back(symbol);
}
if (!other_symbols.empty()) {
last_op = std::make_unique<EdgeUniquenessFilter>(std::move(last_op), edge_symbol, other_symbols);
}
}
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
last_op = impl::GenNamedPaths(std::move(last_op), bound_symbols, named_paths);
last_op = impl::GenFilters(std::move(last_op), bound_symbols, filters, storage);
}
}
MG_ASSERT(named_paths.empty(), "Expected to generate all named paths");
// We bound all named path symbols, so just add them to new_symbols.
for (const auto &named_path : matching.named_paths) {
MG_ASSERT(utils::Contains(bound_symbols, named_path.first), "Expected generated named path to have bound symbol");
match_context.new_symbols.emplace_back(named_path.first);
}
MG_ASSERT(filters.empty(), "Expected to generate all filters");
return last_op;
}
auto GenMerge(query::v2::Merge &merge, std::unique_ptr<LogicalOperator> input_op, const Matching &matching) {
// Copy the bound symbol set, because we don't want to use the updated
// version when generating the create part.
std::unordered_set<Symbol> bound_symbols_copy(context_->bound_symbols);
MatchContext match_ctx{matching, *context_->symbol_table, bound_symbols_copy, storage::v3::View::NEW};
std::vector<Symbol> bound_symbols(context_->bound_symbols.begin(), context_->bound_symbols.end());
auto once_with_symbols = std::make_unique<Once>(bound_symbols);
auto on_match = PlanMatching(match_ctx, std::move(once_with_symbols));
once_with_symbols = std::make_unique<Once>(std::move(bound_symbols));
// Use the original bound_symbols, so we fill it with new symbols.
auto on_create = GenCreateForPattern(*merge.pattern_, std::move(once_with_symbols), *context_->symbol_table,
context_->bound_symbols);
for (auto &set : merge.on_create_) {
on_create = HandleWriteClause(set, on_create, *context_->symbol_table, context_->bound_symbols);
MG_ASSERT(on_create, "Expected SET in MERGE ... ON CREATE");
}
for (auto &set : merge.on_match_) {
on_match = HandleWriteClause(set, on_match, *context_->symbol_table, context_->bound_symbols);
MG_ASSERT(on_match, "Expected SET in MERGE ... ON MATCH");
}
return std::make_unique<plan::Merge>(std::move(input_op), std::move(on_match), std::move(on_create));
}
std::unique_ptr<LogicalOperator> HandleForeachClause(query::v2::Foreach *foreach,
std::unique_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols,
const SingleQueryPart &query_part, uint64_t &merge_id) {
const auto &symbol = symbol_table.at(*foreach->named_expression_);
bound_symbols.insert(symbol);
std::unique_ptr<LogicalOperator> op = std::make_unique<plan::Once>();
for (auto *clause : foreach->clauses_) {
if (auto *nested_for_each = utils::Downcast<query::v2::Foreach>(clause)) {
op = HandleForeachClause(nested_for_each, std::move(op), symbol_table, bound_symbols, query_part, merge_id);
} else if (auto *merge = utils::Downcast<query::v2::Merge>(clause)) {
op = GenMerge(*merge, std::move(op), query_part.merge_matching[merge_id++]);
} else {
op = HandleWriteClause(clause, op, symbol_table, bound_symbols);
}
}
return std::make_unique<plan::Foreach>(std::move(input_op), std::move(op), foreach->named_expression_->expression_,
symbol);
}
};
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,79 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <cstdint>
#include "query/v2/context.hpp"
#include "query/v2/plan/profile.hpp"
#include "utils/likely.hpp"
#include "utils/tsc.hpp"
namespace memgraph::query::v2::plan {
/**
* A RAII class used for profiling logical operators. Instances of this class
* update the profiling data stored within the `ExecutionContext` object and build
* up a tree of `ProfilingStats` instances. The structure of the `ProfilingStats`
* tree depends on the `LogicalOperator`s that were executed.
*/
class ScopedProfile {
public:
ScopedProfile(uint64_t key, const char *name, query::v2::ExecutionContext *context) noexcept : context_(context) {
if (UNLIKELY(context_->is_profile_query)) {
root_ = context_->stats_root;
// Are we the root logical operator?
if (!root_) {
stats_ = &context_->stats;
stats_->key = key;
stats_->name = name;
} else {
stats_ = nullptr;
// Was this logical operator already hit on one of the previous pulls?
auto it = std::find_if(root_->children.begin(), root_->children.end(),
[key](auto &stats) { return stats.key == key; });
if (it == root_->children.end()) {
root_->children.emplace_back();
stats_ = &root_->children.back();
stats_->key = key;
stats_->name = name;
} else {
stats_ = &(*it);
}
}
context_->stats_root = stats_;
stats_->actual_hits++;
start_time_ = utils::ReadTSC();
}
}
~ScopedProfile() noexcept {
if (UNLIKELY(context_->is_profile_query)) {
stats_->num_cycles += utils::ReadTSC() - start_time_;
// Restore the old root ("pop")
context_->stats_root = root_;
}
}
private:
query::v2::ExecutionContext *context_;
ProfilingStats *root_{nullptr};
ProfilingStats *stats_{nullptr};
unsigned long long start_time_{0};
};
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,296 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/plan/variable_start_planner.hpp"
#include <limits>
#include <queue>
#include "utils/flag_validation.hpp"
#include "utils/logging.hpp"
DEFINE_VALIDATED_HIDDEN_uint64(query_max_plans, 1000U, "Maximum number of generated plans for a query.",
FLAG_IN_RANGE(1, std::numeric_limits<std::uint64_t>::max()));
namespace memgraph::query::v2::plan::impl {
namespace {
// Add applicable expansions for `node_symbol` to `next_expansions`. These
// expansions are removed from `node_symbol_to_expansions`, while
// `seen_expansions` and `expanded_symbols` are populated with new data.
void AddNextExpansions(const Symbol &node_symbol, const Matching &matching, const SymbolTable &symbol_table,
std::unordered_set<Symbol> &expanded_symbols,
std::unordered_map<Symbol, std::set<size_t>> &node_symbol_to_expansions,
std::unordered_set<size_t> &seen_expansions, std::queue<Expansion> &next_expansions) {
auto node_to_expansions_it = node_symbol_to_expansions.find(node_symbol);
if (node_to_expansions_it == node_symbol_to_expansions.end()) {
return;
}
// Returns true if the expansion is a regular expand or if it is a variable
// path expand, but with bound symbols used inside the range expression.
auto can_expand = [&](auto &expansion) {
for (const auto &range_symbol : expansion.symbols_in_range) {
// If the symbols used in range need to be bound during this whole
// expansion, we must check whether they have already been expanded and
// therefore bound. If the symbols are not found in the whole expansion,
// then the semantic analysis should guarantee that the symbols have been
// bound long before we expand.
if (matching.expansion_symbols.find(range_symbol) != matching.expansion_symbols.end() &&
expanded_symbols.find(range_symbol) == expanded_symbols.end()) {
return false;
}
}
return true;
};
auto &node_expansions = node_to_expansions_it->second;
auto node_expansions_it = node_expansions.begin();
while (node_expansions_it != node_to_expansions_it->second.end()) {
auto expansion_id = *node_expansions_it;
if (seen_expansions.find(expansion_id) != seen_expansions.end()) {
// Skip and erase seen (already expanded) expansions.
node_expansions_it = node_expansions.erase(node_expansions_it);
continue;
}
auto expansion = matching.expansions[expansion_id];
if (!can_expand(expansion)) {
// Skip but save expansions which need other symbols for later.
++node_expansions_it;
continue;
}
if (symbol_table.at(*expansion.node1->identifier_) != node_symbol) {
// We are not expanding from node1, so flip the expansion.
DMG_ASSERT(expansion.node2 && symbol_table.at(*expansion.node2->identifier_) == node_symbol,
"Expected node_symbol to be bound in node2");
if (expansion.edge->type_ != EdgeAtom::Type::BREADTH_FIRST) {
// BFS must *not* be flipped. Doing that changes the BFS results.
std::swap(expansion.node1, expansion.node2);
expansion.is_flipped = true;
if (expansion.direction != EdgeAtom::Direction::BOTH) {
expansion.direction =
expansion.direction == EdgeAtom::Direction::IN ? EdgeAtom::Direction::OUT : EdgeAtom::Direction::IN;
}
}
}
seen_expansions.insert(expansion_id);
expanded_symbols.insert(symbol_table.at(*expansion.node1->identifier_));
if (expansion.edge) {
expanded_symbols.insert(symbol_table.at(*expansion.edge->identifier_));
expanded_symbols.insert(symbol_table.at(*expansion.node2->identifier_));
}
next_expansions.emplace(std::move(expansion));
node_expansions_it = node_expansions.erase(node_expansions_it);
}
if (node_expansions.empty()) {
node_symbol_to_expansions.erase(node_to_expansions_it);
}
}
// Generates expansions emanating from the start_node by forming a chain. When
// the chain can no longer be continued, a different starting node is picked
// among remaining expansions and the process continues. This is done until all
// matching.expansions are used.
std::vector<Expansion> ExpansionsFrom(const NodeAtom *start_node, const Matching &matching,
const SymbolTable &symbol_table) {
// Make a copy of node_symbol_to_expansions, because we will modify it as
// expansions are chained.
auto node_symbol_to_expansions = matching.node_symbol_to_expansions;
std::unordered_set<size_t> seen_expansions;
std::queue<Expansion> next_expansions;
std::unordered_set<Symbol> expanded_symbols({symbol_table.at(*start_node->identifier_)});
auto add_next_expansions = [&](const auto *node) {
AddNextExpansions(symbol_table.at(*node->identifier_), matching, symbol_table, expanded_symbols,
node_symbol_to_expansions, seen_expansions, next_expansions);
};
add_next_expansions(start_node);
// Potential optimization: expansions and next_expansions could be merge into
// a single vector and an index could be used to determine from which should
// additional expansions be added.
std::vector<Expansion> expansions;
while (!next_expansions.empty()) {
auto expansion = next_expansions.front();
next_expansions.pop();
expansions.emplace_back(expansion);
add_next_expansions(expansion.node1);
if (expansion.node2) {
add_next_expansions(expansion.node2);
}
}
if (!node_symbol_to_expansions.empty()) {
// We could pick a new starting expansion, but to avoid runtime
// complexity, simply append the remaining expansions. They should have the
// correct order, since the original expansions were verified during
// semantic analysis.
for (size_t i = 0; i < matching.expansions.size(); ++i) {
if (seen_expansions.find(i) != seen_expansions.end()) {
continue;
}
expansions.emplace_back(matching.expansions[i]);
}
}
return expansions;
}
// Collect all unique nodes from expansions. Uniqueness is determined by
// symbol uniqueness.
auto ExpansionNodes(const std::vector<Expansion> &expansions, const SymbolTable &symbol_table) {
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes(expansions.size(), NodeSymbolHash(symbol_table),
NodeSymbolEqual(symbol_table));
for (const auto &expansion : expansions) {
// TODO: Handle labels and properties from different node atoms.
nodes.insert(expansion.node1);
if (expansion.node2) {
nodes.insert(expansion.node2);
}
}
return nodes;
}
} // namespace
VaryMatchingStart::VaryMatchingStart(Matching matching, const SymbolTable &symbol_table)
: matching_(matching), symbol_table_(symbol_table), nodes_(ExpansionNodes(matching.expansions, symbol_table)) {}
VaryMatchingStart::iterator::iterator(VaryMatchingStart *self, bool is_done)
: self_(self),
// Use the original matching as the first matching. We are only
// interested in changing the expansions part, so the remaining fields
// should stay the same. This also produces a matching for the case
// when there are no nodes.
current_matching_(self->matching_) {
if (!self_->nodes_.empty()) {
// Overwrite the original matching expansions with the new ones by
// generating it from the first start node.
start_nodes_it_ = self_->nodes_.begin();
current_matching_.expansions = ExpansionsFrom(**start_nodes_it_, self_->matching_, self_->symbol_table_);
}
DMG_ASSERT(start_nodes_it_ || self_->nodes_.empty(),
"start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
if (is_done) {
start_nodes_it_ = self_->nodes_.end();
}
}
VaryMatchingStart::iterator &VaryMatchingStart::iterator::operator++() {
if (!start_nodes_it_) {
DMG_ASSERT(self_->nodes_.empty(), "start_nodes_it_ should only be nullopt when self_->nodes_ is empty");
start_nodes_it_ = self_->nodes_.end();
}
if (*start_nodes_it_ == self_->nodes_.end()) {
return *this;
}
++*start_nodes_it_;
// start_nodes_it_ can become equal to `end` and we shouldn't dereference
// iterator in that case.
if (*start_nodes_it_ == self_->nodes_.end()) {
return *this;
}
const auto &start_node = **start_nodes_it_;
current_matching_.expansions = ExpansionsFrom(start_node, self_->matching_, self_->symbol_table_);
return *this;
}
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &matchings,
const SymbolTable &symbol_table) {
std::vector<VaryMatchingStart> variants;
variants.reserve(matchings.size());
for (const auto &matching : matchings) {
variants.emplace_back(VaryMatchingStart(matching, symbol_table));
}
return MakeCartesianProduct(std::move(variants));
}
VaryQueryPartMatching::VaryQueryPartMatching(SingleQueryPart query_part, const SymbolTable &symbol_table)
: query_part_(std::move(query_part)),
matchings_(VaryMatchingStart(query_part_.matching, symbol_table)),
optional_matchings_(VaryMultiMatchingStarts(query_part_.optional_matching, symbol_table)),
merge_matchings_(VaryMultiMatchingStarts(query_part_.merge_matching, symbol_table)) {}
VaryQueryPartMatching::iterator::iterator(const SingleQueryPart &query_part,
VaryMatchingStart::iterator matchings_begin,
VaryMatchingStart::iterator matchings_end,
CartesianProduct<VaryMatchingStart>::iterator optional_begin,
CartesianProduct<VaryMatchingStart>::iterator optional_end,
CartesianProduct<VaryMatchingStart>::iterator merge_begin,
CartesianProduct<VaryMatchingStart>::iterator merge_end)
: current_query_part_(query_part),
matchings_it_(matchings_begin),
matchings_end_(matchings_end),
optional_it_(optional_begin),
optional_begin_(optional_begin),
optional_end_(optional_end),
merge_it_(merge_begin),
merge_begin_(merge_begin),
merge_end_(merge_end) {
if (matchings_it_ != matchings_end_) {
// Fill the query part with the first variation of matchings
SetCurrentQueryPart();
}
}
VaryQueryPartMatching::iterator &VaryQueryPartMatching::iterator::operator++() {
// Produce parts for each possible combination. E.g. if we have:
// * matchings (m1) and (m2)
// * optional matchings (o1) and (o2)
// * merge matching (g1)
// We want to produce parts for:
// * (m1), (o1), (g1)
// * (m1), (o2), (g1)
// * (m2), (o1), (g1)
// * (m2), (o2), (g1)
// Create variations by changing the merge part first.
if (merge_it_ != merge_end_) ++merge_it_;
// If all merge variations are done, start them from beginning and move to the
// next optional matching variation.
if (merge_it_ == merge_end_) {
merge_it_ = merge_begin_;
if (optional_it_ != optional_end_) ++optional_it_;
}
// If all optional matching variations are done (after exhausting merge
// variations), start them from beginning and move to the next regular
// matching variation.
if (optional_it_ == optional_end_ && merge_it_ == merge_begin_) {
optional_it_ = optional_begin_;
if (matchings_it_ != matchings_end_) ++matchings_it_;
}
// We have reached the end, so return;
if (matchings_it_ == matchings_end_) return *this;
// Fill the query part with the new variation of matchings.
SetCurrentQueryPart();
return *this;
}
void VaryQueryPartMatching::iterator::SetCurrentQueryPart() {
current_query_part_.matching = *matchings_it_;
DMG_ASSERT(optional_it_ != optional_end_ || optional_begin_ == optional_end_,
"Either there are no optional matchings or we can always "
"generate a variation");
if (optional_it_ != optional_end_) {
current_query_part_.optional_matching = *optional_it_;
}
DMG_ASSERT(merge_it_ != merge_end_ || merge_begin_ == merge_end_,
"Either there are no merge matchings or we can always generate "
"a variation");
if (merge_it_ != merge_end_) {
current_query_part_.merge_matching = *merge_it_;
}
}
bool VaryQueryPartMatching::iterator::operator==(const iterator &other) const {
if (matchings_it_ == other.matchings_it_ && matchings_it_ == matchings_end_) {
// matchings_it_ is the primary iterator. If both are at the end, then other
// iterators can be at any position.
return true;
}
return matchings_it_ == other.matchings_it_ && optional_it_ == other.optional_it_ && merge_it_ == other.merge_it_;
}
} // namespace memgraph::query::v2::plan::impl

View File

@ -0,0 +1,336 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include "cppitertools/imap.hpp"
#include "cppitertools/slice.hpp"
#include "gflags/gflags.h"
#include "query/v2/plan/rule_based_planner.hpp"
DECLARE_uint64(query_max_plans);
namespace memgraph::query::v2::plan {
/// Produces a Cartesian product among vectors between begin and end iterator.
/// For example:
///
/// std::vector<int> first_set{1,2,3};
/// std::vector<int> second_set{4,5};
/// std::vector<std::vector<int>> all_sets{first_set, second_set};
/// // prod should be {{1, 4}, {1, 5}, {2, 4}, {2, 5}, {3, 4}, {3, 5}}
/// auto product = MakeCartesianProduct(all_sets);
/// for (const auto &set : product) {
/// ...
/// }
///
/// The product is created lazily by iterating over the constructed
/// CartesianProduct instance.
template <typename TSet>
class CartesianProduct {
private:
// The original sets whose Cartesian product we are calculating.
std::vector<TSet> original_sets_;
// Iterators to the beginning and end of original_sets_.
decltype(original_sets_.begin()) begin_;
decltype(original_sets_.end()) end_;
// Type of the set element.
using TElement = typename decltype(begin_->begin())::value_type;
public:
CartesianProduct(std::vector<TSet> sets)
: original_sets_(std::move(sets)), begin_(original_sets_.begin()), end_(original_sets_.end()) {}
class iterator {
public:
typedef std::input_iterator_tag iterator_category;
typedef std::vector<TElement> value_type;
typedef long difference_type;
typedef const std::vector<TElement> &reference;
typedef const std::vector<TElement> *pointer;
explicit iterator(CartesianProduct *self, bool is_done) : self_(self), is_done_(is_done) {
if (is_done || self->begin_ == self->end_) {
is_done_ = true;
return;
}
auto begin = self->begin_;
while (begin != self->end_) {
auto set_it = begin->begin();
if (set_it == begin->end()) {
// One of the sets is empty, so there is no product.
is_done_ = true;
return;
}
// Collect the first product, by taking the first element of each set.
current_product_.emplace_back(*set_it);
// Store starting iterators to all sets.
sets_.emplace_back(begin, set_it);
begin++;
}
}
iterator &operator++() {
if (is_done_) return *this;
// Increment the leftmost set iterator.
auto sets_it = sets_.begin();
++sets_it->second;
// If the leftmost is at the end, reset it and increment the next
// leftmost.
while (sets_it->second == sets_it->first->end()) {
sets_it->second = sets_it->first->begin();
sets_it++;
if (sets_it == sets_.end()) {
// The leftmost set is the last set and it was exhausted, so we are
// done.
is_done_ = true;
return *this;
}
++sets_it->second;
}
// We can now collect another product from the modified set iterators.
DMG_ASSERT(current_product_.size() == sets_.size(),
"Expected size of current_product_ to match the size of sets_");
size_t i = 0;
// Change only the prefix of the product, remaining elements (after
// sets_it) should be the same.
auto last_unmodified = sets_it + 1;
for (auto kv_it = sets_.begin(); kv_it != last_unmodified; ++kv_it) {
current_product_[i++] = *kv_it->second;
}
return *this;
}
bool operator==(const iterator &other) const {
if (self_->begin_ != other.self_->begin_ || self_->end_ != other.self_->end_) return false;
return (is_done_ && other.is_done_) || (sets_ == other.sets_);
}
bool operator!=(const iterator &other) const { return !(*this == other); }
// Iterator interface says that dereferencing a past-the-end iterator is
// undefined, so don't bother checking if we are done.
reference operator*() const { return current_product_; }
pointer operator->() const { return &current_product_; }
private:
// Pointer instead of reference to auto generate copy constructor and
// assignment.
CartesianProduct *self_;
// Vector of (original_sets_iterator, set_iterator) pairs. The
// original_sets_iterator points to the set among all the sets, while the
// set_iterator points to an element inside the pointed to set.
std::vector<std::pair<decltype(self_->begin_), decltype(self_->begin_->begin())>> sets_;
// Currently built product from pointed to elements in all sets.
std::vector<TElement> current_product_;
// Set to true when we have generated all products.
bool is_done_ = false;
};
auto begin() { return iterator(this, false); }
auto end() { return iterator(this, true); }
private:
friend class iterator;
};
/// Convenience function for creating CartesianProduct by deducing template
/// arguments from function arguments.
template <typename TSet>
auto MakeCartesianProduct(std::vector<TSet> sets) {
return CartesianProduct<TSet>(std::move(sets));
}
namespace impl {
class NodeSymbolHash {
public:
explicit NodeSymbolHash(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
size_t operator()(const NodeAtom *node_atom) const {
return std::hash<Symbol>{}(symbol_table_.at(*node_atom->identifier_));
}
private:
const SymbolTable &symbol_table_;
};
class NodeSymbolEqual {
public:
explicit NodeSymbolEqual(const SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
bool operator()(const NodeAtom *node_atom1, const NodeAtom *node_atom2) const {
return symbol_table_.at(*node_atom1->identifier_) == symbol_table_.at(*node_atom2->identifier_);
}
private:
const SymbolTable &symbol_table_;
};
// Generates n matchings, where n is the number of nodes to match. Each Matching
// will have a different node as a starting node for expansion.
class VaryMatchingStart {
public:
VaryMatchingStart(Matching, const SymbolTable &);
class iterator {
public:
typedef std::input_iterator_tag iterator_category;
typedef Matching value_type;
typedef long difference_type;
typedef const Matching &reference;
typedef const Matching *pointer;
iterator(VaryMatchingStart *, bool);
iterator &operator++();
reference operator*() const { return current_matching_; }
pointer operator->() const { return &current_matching_; }
bool operator==(const iterator &other) const {
return self_ == other.self_ && start_nodes_it_ == other.start_nodes_it_;
}
bool operator!=(const iterator &other) const { return !(*this == other); }
private:
// Pointer instead of reference to auto generate copy constructor and
// assignment.
VaryMatchingStart *self_;
Matching current_matching_;
// Iterator over start nodes. Optional is used for differentiating the case
// when there are no start nodes vs. VaryMatchingStart::iterator itself
// being at the end. When there are no nodes, this iterator needs to produce
// a single result, which is the original matching passed in. Setting
// start_nodes_it_ to end signifies the end of our iteration.
std::optional<std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual>::iterator> start_nodes_it_;
};
auto begin() { return iterator(this, false); }
auto end() { return iterator(this, true); }
private:
friend class iterator;
Matching matching_;
const SymbolTable &symbol_table_;
std::unordered_set<NodeAtom *, NodeSymbolHash, NodeSymbolEqual> nodes_;
};
// Similar to VaryMatchingStart, but varies the starting nodes for all given
// matchings. After all matchings produce multiple alternative starts, the
// Cartesian product of all of them is returned.
CartesianProduct<VaryMatchingStart> VaryMultiMatchingStarts(const std::vector<Matching> &, const SymbolTable &);
// Produces alternative query parts out of a single part by varying how each
// graph matching is done.
class VaryQueryPartMatching {
public:
VaryQueryPartMatching(SingleQueryPart, const SymbolTable &);
class iterator {
public:
typedef std::input_iterator_tag iterator_category;
typedef SingleQueryPart value_type;
typedef long difference_type;
typedef const SingleQueryPart &reference;
typedef const SingleQueryPart *pointer;
iterator(const SingleQueryPart &, VaryMatchingStart::iterator, VaryMatchingStart::iterator,
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator,
CartesianProduct<VaryMatchingStart>::iterator, CartesianProduct<VaryMatchingStart>::iterator);
iterator &operator++();
reference operator*() const { return current_query_part_; }
pointer operator->() const { return &current_query_part_; }
bool operator==(const iterator &) const;
bool operator!=(const iterator &other) const { return !(*this == other); }
private:
void SetCurrentQueryPart();
SingleQueryPart current_query_part_;
VaryMatchingStart::iterator matchings_it_;
VaryMatchingStart::iterator matchings_end_;
CartesianProduct<VaryMatchingStart>::iterator optional_it_;
CartesianProduct<VaryMatchingStart>::iterator optional_begin_;
CartesianProduct<VaryMatchingStart>::iterator optional_end_;
CartesianProduct<VaryMatchingStart>::iterator merge_it_;
CartesianProduct<VaryMatchingStart>::iterator merge_begin_;
CartesianProduct<VaryMatchingStart>::iterator merge_end_;
};
auto begin() {
return iterator(query_part_, matchings_.begin(), matchings_.end(), optional_matchings_.begin(),
optional_matchings_.end(), merge_matchings_.begin(), merge_matchings_.end());
}
auto end() {
return iterator(query_part_, matchings_.end(), matchings_.end(), optional_matchings_.end(),
optional_matchings_.end(), merge_matchings_.end(), merge_matchings_.end());
}
private:
SingleQueryPart query_part_;
// Multiple regular matchings, each starting from different node.
VaryMatchingStart matchings_;
// Multiple optional matchings, where each combination has different starting
// nodes.
CartesianProduct<VaryMatchingStart> optional_matchings_;
// Like optional matching, but for merge matchings.
CartesianProduct<VaryMatchingStart> merge_matchings_;
};
} // namespace impl
/// @brief Planner which generates multiple plans by changing the order of graph
/// traversal.
///
/// This planner picks different starting nodes from which to start graph
/// traversal. Generating a single plan is backed by @c RuleBasedPlanner.
///
/// @sa MakeLogicalPlan
template <class TPlanningContext>
class VariableStartPlanner {
private:
TPlanningContext *context_;
// Generates different, equivalent query parts by taking different graph
// matching routes for each query part.
auto VaryQueryMatching(const std::vector<SingleQueryPart> &query_parts, const SymbolTable &symbol_table) {
std::vector<impl::VaryQueryPartMatching> alternative_query_parts;
alternative_query_parts.reserve(query_parts.size());
for (const auto &query_part : query_parts) {
alternative_query_parts.emplace_back(impl::VaryQueryPartMatching(query_part, symbol_table));
}
return iter::slice(MakeCartesianProduct(std::move(alternative_query_parts)), 0UL, FLAGS_query_max_plans);
}
public:
explicit VariableStartPlanner(TPlanningContext *context) : context_(context) {}
/// @brief Generate multiple plans by varying the order of graph traversal.
auto Plan(const std::vector<SingleQueryPart> &query_parts) {
return iter::imap(
[context = context_](const auto &alternative_query_parts) {
RuleBasedPlanner<TPlanningContext> rule_planner(context);
context->bound_symbols.clear();
return rule_planner.Plan(alternative_query_parts);
},
VaryQueryMatching(query_parts, *context_->symbol_table));
}
/// @brief The result of plan generation is an iterable of roots to multiple
/// generated operator trees.
using PlanResult = typename std::result_of<decltype (&VariableStartPlanner<TPlanningContext>::Plan)(
VariableStartPlanner<TPlanningContext>, std::vector<SingleQueryPart> &)>::type;
};
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,141 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include <optional>
#include "query/v2/typed_value.hpp"
#include "storage/v3/id_types.hpp"
#include "storage/v3/property_value.hpp"
#include "utils/bound.hpp"
#include "utils/fnv.hpp"
namespace memgraph::query::v2::plan {
/// A stand in class for `TDbAccessor` which provides memoized calls to
/// `VerticesCount`.
template <class TDbAccessor>
class VertexCountCache {
public:
VertexCountCache(TDbAccessor *db) : db_(db) {}
auto NameToLabel(const std::string &name) { return db_->NameToLabel(name); }
auto NameToProperty(const std::string &name) { return db_->NameToProperty(name); }
auto NameToEdgeType(const std::string &name) { return db_->NameToEdgeType(name); }
int64_t VerticesCount() {
if (!vertices_count_) vertices_count_ = db_->VerticesCount();
return *vertices_count_;
}
int64_t VerticesCount(storage::v3::LabelId label) {
if (label_vertex_count_.find(label) == label_vertex_count_.end())
label_vertex_count_[label] = db_->VerticesCount(label);
return label_vertex_count_.at(label);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property) {
auto key = std::make_pair(label, property);
if (label_property_vertex_count_.find(key) == label_property_vertex_count_.end())
label_property_vertex_count_[key] = db_->VerticesCount(label, property);
return label_property_vertex_count_.at(key);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const storage::v3::PropertyValue &value) {
auto label_prop = std::make_pair(label, property);
auto &value_vertex_count = property_value_vertex_count_[label_prop];
// TODO: Why do we even need TypedValue in this whole file?
TypedValue tv_value(value);
if (value_vertex_count.find(tv_value) == value_vertex_count.end())
value_vertex_count[tv_value] = db_->VerticesCount(label, property, value);
return value_vertex_count.at(tv_value);
}
int64_t VerticesCount(storage::v3::LabelId label, storage::v3::PropertyId property,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &lower,
const std::optional<utils::Bound<storage::v3::PropertyValue>> &upper) {
auto label_prop = std::make_pair(label, property);
auto &bounds_vertex_count = property_bounds_vertex_count_[label_prop];
BoundsKey bounds = std::make_pair(lower, upper);
if (bounds_vertex_count.find(bounds) == bounds_vertex_count.end())
bounds_vertex_count[bounds] = db_->VerticesCount(label, property, lower, upper);
return bounds_vertex_count.at(bounds);
}
bool LabelIndexExists(storage::v3::LabelId label) { return db_->LabelIndexExists(label); }
bool LabelPropertyIndexExists(storage::v3::LabelId label, storage::v3::PropertyId property) {
return db_->LabelPropertyIndexExists(label, property);
}
private:
typedef std::pair<storage::v3::LabelId, storage::v3::PropertyId> LabelPropertyKey;
struct LabelPropertyHash {
size_t operator()(const LabelPropertyKey &key) const {
return utils::HashCombine<storage::v3::LabelId, storage::v3::PropertyId>{}(key.first, key.second);
}
};
typedef std::pair<std::optional<utils::Bound<storage::v3::PropertyValue>>,
std::optional<utils::Bound<storage::v3::PropertyValue>>>
BoundsKey;
struct BoundsHash {
size_t operator()(const BoundsKey &key) const {
const auto &maybe_lower = key.first;
const auto &maybe_upper = key.second;
query::v2::TypedValue lower;
query::v2::TypedValue upper;
if (maybe_lower) lower = TypedValue(maybe_lower->value());
if (maybe_upper) upper = TypedValue(maybe_upper->value());
query::v2::TypedValue::Hash hash;
return utils::HashCombine<size_t, size_t>{}(hash(lower), hash(upper));
}
};
struct BoundsEqual {
bool operator()(const BoundsKey &a, const BoundsKey &b) const {
auto bound_equal = [](const auto &maybe_bound_a, const auto &maybe_bound_b) {
if (maybe_bound_a && maybe_bound_b && maybe_bound_a->type() != maybe_bound_b->type()) return false;
query::v2::TypedValue bound_a;
query::v2::TypedValue bound_b;
if (maybe_bound_a) bound_a = TypedValue(maybe_bound_a->value());
if (maybe_bound_b) bound_b = TypedValue(maybe_bound_b->value());
return query::v2::TypedValue::BoolEqual{}(bound_a, bound_b);
};
return bound_equal(a.first, b.first) && bound_equal(a.second, b.second);
}
};
TDbAccessor *db_;
std::optional<int64_t> vertices_count_;
std::unordered_map<storage::v3::LabelId, int64_t> label_vertex_count_;
std::unordered_map<LabelPropertyKey, int64_t, LabelPropertyHash> label_property_vertex_count_;
std::unordered_map<
LabelPropertyKey,
std::unordered_map<query::v2::TypedValue, int64_t, query::v2::TypedValue::Hash, query::v2::TypedValue::BoolEqual>,
LabelPropertyHash>
property_value_vertex_count_;
std::unordered_map<LabelPropertyKey, std::unordered_map<BoundsKey, int64_t, BoundsHash, BoundsEqual>,
LabelPropertyHash>
property_bounds_vertex_count_;
};
template <class TDbAccessor>
auto MakeVertexCountCache(TDbAccessor *db) {
return VertexCountCache<TDbAccessor>(db);
}
} // namespace memgraph::query::v2::plan

View File

@ -0,0 +1,20 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <functional>
#include <memory>
namespace memgraph::query::v2::procedure {
class CypherType;
using CypherTypePtr = std::unique_ptr<CypherType, std::function<void(CypherType *)>>;
} // namespace memgraph::query::v2::procedure

View File

@ -0,0 +1,293 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
#pragma once
#include "mg_procedure.h"
#include <functional>
#include <memory>
#include <string_view>
#include "query/v2/procedure/cypher_type_ptr.hpp"
#include "query/v2/procedure/mg_procedure_impl.hpp"
#include "query/v2/typed_value.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/string.hpp"
namespace memgraph::query::v2::procedure {
class ListType;
class NullableType;
/// Interface for all supported types in openCypher type system.
class CypherType {
public:
CypherType() = default;
virtual ~CypherType() = default;
CypherType(const CypherType &) = delete;
CypherType(CypherType &&) = delete;
CypherType &operator=(const CypherType &) = delete;
CypherType &operator=(CypherType &&) = delete;
/// Get name of the type as it should be presented to the user.
virtual std::string_view GetPresentableName() const = 0;
/// Return true if given mgp_value is of the type as described by `this`.
virtual bool SatisfiesType(const mgp_value &) const = 0;
/// Return true if given TypedValue is of the type as described by `this`.
virtual bool SatisfiesType(const query::v2::TypedValue &) const = 0;
// The following methods are a simple replacement for RTTI because we have
// some special cases we need to handle.
virtual const ListType *AsListType() const { return nullptr; }
virtual const NullableType *AsNullableType() const { return nullptr; }
};
// Simple Types
class AnyType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "ANY"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type != MGP_VALUE_TYPE_NULL; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return !value.IsNull(); }
};
class BoolType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "BOOLEAN"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_BOOL; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsBool(); }
};
class StringType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "STRING"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_STRING; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsString(); }
};
class IntType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "INTEGER"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_INT; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsInt(); }
};
class FloatType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "FLOAT"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_DOUBLE; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsDouble(); }
};
class NumberType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "NUMBER"; }
bool SatisfiesType(const mgp_value &value) const override {
return value.type == MGP_VALUE_TYPE_INT || value.type == MGP_VALUE_TYPE_DOUBLE;
}
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsInt() || value.IsDouble(); }
};
class NodeType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "NODE"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_VERTEX; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsVertex(); }
};
class RelationshipType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "RELATIONSHIP"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_EDGE; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsEdge(); }
};
class PathType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "PATH"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_PATH; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsPath(); }
};
// You'd think that MapType would be a composite type like ListType, but nope.
// Why? No-one really knows. It's defined like that in "CIP2015-09-16 Public
// Type System and Type Annotations"
// Additionally, MapType also covers NodeType and RelationshipType because
// values of that type have property *maps*.
class MapType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "MAP"; }
bool SatisfiesType(const mgp_value &value) const override {
return value.type == MGP_VALUE_TYPE_MAP || value.type == MGP_VALUE_TYPE_VERTEX || value.type == MGP_VALUE_TYPE_EDGE;
}
bool SatisfiesType(const query::v2::TypedValue &value) const override {
return value.IsMap() || value.IsVertex() || value.IsEdge();
}
};
// Temporal Types
class DateType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "DATE"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_DATE; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsDate(); }
};
class LocalTimeType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "LOCAL_TIME"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_LOCAL_TIME; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsLocalTime(); }
};
class LocalDateTimeType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "LOCAL_DATE_TIME"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_LOCAL_DATE_TIME; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsLocalDateTime(); }
};
class DurationType : public CypherType {
public:
std::string_view GetPresentableName() const override { return "DURATION"; }
bool SatisfiesType(const mgp_value &value) const override { return value.type == MGP_VALUE_TYPE_DURATION; }
bool SatisfiesType(const query::v2::TypedValue &value) const override { return value.IsDuration(); }
};
// Composite Types
class ListType : public CypherType {
public:
CypherTypePtr element_type_;
utils::pmr::string presentable_name_;
/// @throw std::bad_alloc
/// @throw std::length_error
explicit ListType(CypherTypePtr element_type, utils::MemoryResource *memory)
: element_type_(std::move(element_type)), presentable_name_("LIST OF ", memory) {
presentable_name_.append(element_type_->GetPresentableName());
}
std::string_view GetPresentableName() const override { return presentable_name_; }
bool SatisfiesType(const mgp_value &value) const override {
if (value.type != MGP_VALUE_TYPE_LIST) {
return false;
}
auto *list = value.list_v;
const auto list_size = list->elems.size();
for (size_t i = 0; i < list_size; ++i) {
if (!element_type_->SatisfiesType(list->elems[i])) {
return false;
};
}
return true;
}
bool SatisfiesType(const query::v2::TypedValue &value) const override {
if (!value.IsList()) return false;
for (const auto &elem : value.ValueList()) {
if (!element_type_->SatisfiesType(elem)) return false;
}
return true;
}
const ListType *AsListType() const override { return this; }
};
class NullableType : public CypherType {
CypherTypePtr type_;
utils::pmr::string presentable_name_;
// Constructor is private, because we use a factory method Create to prevent
// nesting NullableType on top of each other.
// @throw std::bad_alloc
// @throw std::length_error
explicit NullableType(CypherTypePtr type, utils::MemoryResource *memory)
: type_(std::move(type)), presentable_name_(memory) {
const auto *list_type = type_->AsListType();
// ListType is specially formatted
if (list_type) {
presentable_name_.assign("LIST? OF ").append(list_type->element_type_->GetPresentableName());
} else {
presentable_name_.assign(type_->GetPresentableName()).append("?");
}
}
public:
/// Create a NullableType of some CypherType.
/// If passed in `type` is already a NullableType, it is returned intact.
/// Otherwise, `type` is wrapped in a new instance of NullableType.
/// @throw std::bad_alloc
/// @throw std::length_error
static CypherTypePtr Create(CypherTypePtr type, utils::MemoryResource *memory) {
if (type->AsNullableType()) return type;
utils::Allocator<NullableType> alloc(memory);
auto *nullable = alloc.allocate(1);
try {
new (nullable) NullableType(std::move(type), memory);
} catch (...) {
alloc.deallocate(nullable, 1);
throw;
}
return CypherTypePtr(nullable, [alloc](CypherType *base_ptr) mutable {
alloc.delete_object(static_cast<NullableType *>(base_ptr));
});
}
std::string_view GetPresentableName() const override { return presentable_name_; }
bool SatisfiesType(const mgp_value &value) const override {
return value.type == MGP_VALUE_TYPE_NULL || type_->SatisfiesType(value);
}
bool SatisfiesType(const query::v2::TypedValue &value) const override {
return value.IsNull() || type_->SatisfiesType(value);
}
const NullableType *AsNullableType() const override { return this; }
};
} // namespace memgraph::query::v2::procedure

View File

@ -0,0 +1,36 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/procedure/mg_procedure_helpers.hpp"
namespace memgraph::query::v2::procedure {
MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory, mgp_result *result) {
procedure::MgpUniquePtr<mgp_value> value{nullptr, mgp_value_destroy};
const auto success =
TryOrSetError([&] { return procedure::CreateMgpObject(value, mgp_value_make_string, string, memory); }, result);
if (!success) {
value.reset();
}
return value;
}
bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name, mgp_value *value) {
if (const auto err = mgp_result_record_insert(record, result_name, value); err != mgp_error::MGP_ERROR_NO_ERROR) {
const auto error_msg = fmt::format("Unable to set the result for {}, error = {}", result_name, err);
static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
return false;
}
return true;
}
} // namespace memgraph::query::v2::procedure

View File

@ -0,0 +1,69 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <memory>
#include <type_traits>
#include <utility>
#include <fmt/format.h>
#include "mg_procedure.h"
namespace memgraph::query::v2::procedure {
template <typename TResult, typename TFunc, typename... TArgs>
TResult Call(TFunc func, TArgs... args) {
static_assert(std::is_trivially_copyable_v<TFunc>);
static_assert((std::is_trivially_copyable_v<std::remove_reference_t<TArgs>> && ...));
TResult result{};
MG_ASSERT(func(args..., &result) == mgp_error::MGP_ERROR_NO_ERROR);
return result;
}
template <typename TFunc, typename... TArgs>
bool CallBool(TFunc func, TArgs... args) {
return Call<int>(func, args...) != 0;
}
template <typename TObj>
using MgpRawObjectDeleter = void (*)(TObj *);
template <typename TObj>
using MgpUniquePtr = std::unique_ptr<TObj, MgpRawObjectDeleter<TObj>>;
template <typename TObj, typename TFunc, typename... TArgs>
mgp_error CreateMgpObject(MgpUniquePtr<TObj> &obj, TFunc func, TArgs &&...args) {
TObj *raw_obj{nullptr};
const auto err = func(std::forward<TArgs>(args)..., &raw_obj);
obj.reset(raw_obj);
return err;
}
template <typename Fun>
[[nodiscard]] bool TryOrSetError(Fun &&func, mgp_result *result) {
if (const auto err = func(); err == mgp_error::MGP_ERROR_UNABLE_TO_ALLOCATE) {
static_cast<void>(mgp_result_set_error_msg(result, "Not enough memory!"));
return false;
} else if (err != mgp_error::MGP_ERROR_NO_ERROR) {
const auto error_msg = fmt::format("Unexpected error ({})!", err);
static_cast<void>(mgp_result_set_error_msg(result, error_msg.c_str()));
return false;
}
return true;
}
[[nodiscard]] MgpUniquePtr<mgp_value> GetStringValueOrSetError(const char *string, mgp_memory *memory,
mgp_result *result);
[[nodiscard]] bool InsertResultOrSetError(mgp_result *result, mgp_result_record *record, const char *result_name,
mgp_value *value);
} // namespace memgraph::query::v2::procedure

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,926 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// Contains private (implementation) declarations and definitions for
/// mg_procedure.h
#pragma once
#include "mg_procedure.h"
#include <optional>
#include <ostream>
#include "integrations/kafka/consumer.hpp"
#include "integrations/pulsar/consumer.hpp"
#include "query/v2/context.hpp"
#include "query/v2/db_accessor.hpp"
#include "query/v2/frontend/ast/ast.hpp"
#include "query/v2/procedure/cypher_type_ptr.hpp"
#include "query/v2/typed_value.hpp"
#include "storage/v3/view.hpp"
#include "utils/memory.hpp"
#include "utils/pmr/map.hpp"
#include "utils/pmr/string.hpp"
#include "utils/pmr/vector.hpp"
#include "utils/temporal.hpp"
/// Wraps memory resource used in custom procedures.
///
/// This should have been `using mgp_memory = memgraph::utils::MemoryResource`, but that's
/// not valid C++ because we have a forward declare `struct mgp_memory` in
/// mg_procedure.h
/// TODO: Make this extendable in C API, so that custom procedure writer can add
/// their own memory management wrappers.
struct mgp_memory {
memgraph::utils::MemoryResource *impl;
};
/// Immutable container of various values that appear in openCypher.
struct mgp_value {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_value>;
// Construct MGP_VALUE_TYPE_NULL.
explicit mgp_value(memgraph::utils::MemoryResource *) noexcept;
mgp_value(bool, memgraph::utils::MemoryResource *) noexcept;
mgp_value(int64_t, memgraph::utils::MemoryResource *) noexcept;
mgp_value(double, memgraph::utils::MemoryResource *) noexcept;
/// @throw std::bad_alloc
mgp_value(const char *, memgraph::utils::MemoryResource *);
/// Take ownership of the mgp_list, MemoryResource must match.
mgp_value(mgp_list *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_map, MemoryResource must match.
mgp_value(mgp_map *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_vertex, MemoryResource must match.
mgp_value(mgp_vertex *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_edge, MemoryResource must match.
mgp_value(mgp_edge *, memgraph::utils::MemoryResource *) noexcept;
/// Take ownership of the mgp_path, MemoryResource must match.
mgp_value(mgp_path *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_date *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_local_time *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_local_date_time *, memgraph::utils::MemoryResource *) noexcept;
mgp_value(mgp_duration *, memgraph::utils::MemoryResource *) noexcept;
/// Construct by copying memgraph::query::v2::TypedValue using memgraph::utils::MemoryResource.
/// mgp_graph is needed to construct mgp_vertex and mgp_edge.
/// @throw std::bad_alloc
mgp_value(const memgraph::query::v2::TypedValue &, mgp_graph *, memgraph::utils::MemoryResource *);
/// Construct by copying memgraph::storage::v3::PropertyValue using memgraph::utils::MemoryResource.
/// @throw std::bad_alloc
mgp_value(const memgraph::storage::v3::PropertyValue &, memgraph::utils::MemoryResource *);
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_value(const mgp_value &) = delete;
/// Copy construct using given memgraph::utils::MemoryResource.
/// @throw std::bad_alloc
mgp_value(const mgp_value &, memgraph::utils::MemoryResource *);
/// Move construct using given memgraph::utils::MemoryResource.
/// @throw std::bad_alloc if MemoryResource is different, so we cannot move.
mgp_value(mgp_value &&, memgraph::utils::MemoryResource *);
/// Move construct, memgraph::utils::MemoryResource is inherited.
mgp_value(mgp_value &&other) noexcept : mgp_value(other, other.memory) {}
/// Copy-assignment is not allowed to preserve immutability.
mgp_value &operator=(const mgp_value &) = delete;
/// Move-assignment is not allowed to preserve immutability.
mgp_value &operator=(mgp_value &&) = delete;
~mgp_value() noexcept;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
mgp_value_type type;
memgraph::utils::MemoryResource *memory;
union {
bool bool_v;
int64_t int_v;
double double_v;
memgraph::utils::pmr::string string_v;
// We use pointers so that taking ownership via C API is easier. Besides,
// mgp_map cannot use incomplete mgp_value type, because that would be
// undefined behaviour.
mgp_list *list_v;
mgp_map *map_v;
mgp_vertex *vertex_v;
mgp_edge *edge_v;
mgp_path *path_v;
mgp_date *date_v;
mgp_local_time *local_time_v;
mgp_local_date_time *local_date_time_v;
mgp_duration *duration_v;
};
};
inline memgraph::utils::DateParameters MapDateParameters(const mgp_date_parameters *parameters) {
return {.year = parameters->year, .month = parameters->month, .day = parameters->day};
}
struct mgp_date {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_date>;
// Hopefully memgraph::utils::Date copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Date>);
mgp_date(const memgraph::utils::Date &date, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(date) {}
mgp_date(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(memgraph::utils::ParseDateParameters(string).first) {}
mgp_date(const mgp_date_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(MapDateParameters(parameters)) {}
mgp_date(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(microseconds) {}
mgp_date(const mgp_date &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), date(other.date) {}
mgp_date(mgp_date &&other, memgraph::utils::MemoryResource *memory) noexcept : memory(memory), date(other.date) {}
mgp_date(mgp_date &&other) noexcept : memory(other.memory), date(other.date) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_date(const mgp_date &) = delete;
mgp_date &operator=(const mgp_date &) = delete;
mgp_date &operator=(mgp_date &&) = delete;
~mgp_date() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::Date date;
};
inline memgraph::utils::LocalTimeParameters MapLocalTimeParameters(const mgp_local_time_parameters *parameters) {
return {.hour = parameters->hour,
.minute = parameters->minute,
.second = parameters->second,
.millisecond = parameters->millisecond,
.microsecond = parameters->microsecond};
}
struct mgp_local_time {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_local_time>;
// Hopefully memgraph::utils::LocalTime copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalTime>);
mgp_local_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(memgraph::utils::ParseLocalTimeParameters(string).first) {}
mgp_local_time(const mgp_local_time_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(MapLocalTimeParameters(parameters)) {}
mgp_local_time(const memgraph::utils::LocalTime &local_time, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(local_time) {}
mgp_local_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(microseconds) {}
mgp_local_time(const mgp_local_time &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(other.local_time) {}
mgp_local_time(mgp_local_time &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_time(other.local_time) {}
mgp_local_time(mgp_local_time &&other) noexcept : memory(other.memory), local_time(other.local_time) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_local_time(const mgp_local_time &) = delete;
mgp_local_time &operator=(const mgp_local_time &) = delete;
mgp_local_time &operator=(mgp_local_time &&) = delete;
~mgp_local_time() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::LocalTime local_time;
};
inline memgraph::utils::LocalDateTime CreateLocalDateTimeFromString(const std::string_view string) {
const auto &[date_parameters, local_time_parameters] = memgraph::utils::ParseLocalDateTimeParameters(string);
return memgraph::utils::LocalDateTime{date_parameters, local_time_parameters};
}
struct mgp_local_date_time {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_local_date_time>;
// Hopefully memgraph::utils::LocalDateTime copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::LocalDateTime>);
mgp_local_date_time(const memgraph::utils::LocalDateTime &local_date_time,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(local_date_time) {}
mgp_local_date_time(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(CreateLocalDateTimeFromString(string)) {}
mgp_local_date_time(const mgp_local_date_time_parameters *parameters,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory),
local_date_time(MapDateParameters(parameters->date_parameters),
MapLocalTimeParameters(parameters->local_time_parameters)) {}
mgp_local_date_time(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(microseconds) {}
mgp_local_date_time(const mgp_local_date_time &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(other.local_date_time) {}
mgp_local_date_time(mgp_local_date_time &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), local_date_time(other.local_date_time) {}
mgp_local_date_time(mgp_local_date_time &&other) noexcept
: memory(other.memory), local_date_time(other.local_date_time) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_local_date_time(const mgp_local_date_time &) = delete;
mgp_local_date_time &operator=(const mgp_local_date_time &) = delete;
mgp_local_date_time &operator=(mgp_local_date_time &&) = delete;
~mgp_local_date_time() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::LocalDateTime local_date_time;
};
inline memgraph::utils::DurationParameters MapDurationParameters(const mgp_duration_parameters *parameters) {
return {.day = parameters->day,
.hour = parameters->hour,
.minute = parameters->minute,
.second = parameters->second,
.millisecond = parameters->millisecond,
.microsecond = parameters->microsecond};
}
struct mgp_duration {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_duration>;
// Hopefully memgraph::utils::Duration copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::utils::Duration>);
mgp_duration(const std::string_view string, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(memgraph::utils::ParseDurationParameters(string)) {}
mgp_duration(const mgp_duration_parameters *parameters, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(MapDurationParameters(parameters)) {}
mgp_duration(const memgraph::utils::Duration &duration, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(duration) {}
mgp_duration(const int64_t microseconds, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(microseconds) {}
mgp_duration(const mgp_duration &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(other.duration) {}
mgp_duration(mgp_duration &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), duration(other.duration) {}
mgp_duration(mgp_duration &&other) noexcept : memory(other.memory), duration(other.duration) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_duration(const mgp_duration &) = delete;
mgp_duration &operator=(const mgp_duration &) = delete;
mgp_duration &operator=(mgp_duration &&) = delete;
~mgp_duration() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::utils::Duration duration;
};
struct mgp_list {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_list>;
explicit mgp_list(memgraph::utils::MemoryResource *memory) : elems(memory) {}
mgp_list(memgraph::utils::pmr::vector<mgp_value> &&elems, memgraph::utils::MemoryResource *memory)
: elems(std::move(elems), memory) {}
mgp_list(const mgp_list &other, memgraph::utils::MemoryResource *memory) : elems(other.elems, memory) {}
mgp_list(mgp_list &&other, memgraph::utils::MemoryResource *memory) : elems(std::move(other.elems), memory) {}
mgp_list(mgp_list &&other) noexcept : elems(std::move(other.elems)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_list(const mgp_list &) = delete;
mgp_list &operator=(const mgp_list &) = delete;
mgp_list &operator=(mgp_list &&) = delete;
~mgp_list() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return elems.get_allocator().GetMemoryResource();
}
// C++17 vector can work with incomplete type.
memgraph::utils::pmr::vector<mgp_value> elems;
};
struct mgp_map {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_map>;
explicit mgp_map(memgraph::utils::MemoryResource *memory) : items(memory) {}
mgp_map(memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> &&items,
memgraph::utils::MemoryResource *memory)
: items(std::move(items), memory) {}
mgp_map(const mgp_map &other, memgraph::utils::MemoryResource *memory) : items(other.items, memory) {}
mgp_map(mgp_map &&other, memgraph::utils::MemoryResource *memory) : items(std::move(other.items), memory) {}
mgp_map(mgp_map &&other) noexcept : items(std::move(other.items)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_map(const mgp_map &) = delete;
mgp_map &operator=(const mgp_map &) = delete;
mgp_map &operator=(mgp_map &&) = delete;
~mgp_map() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return items.get_allocator().GetMemoryResource();
}
// Unfortunately using incomplete type with map is undefined, so mgp_map
// needs to be defined after mgp_value.
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_value> items;
};
struct mgp_map_item {
const char *key;
mgp_value *value;
};
struct mgp_map_items_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_map_items_iterator>;
mgp_map_items_iterator(mgp_map *map, memgraph::utils::MemoryResource *memory)
: memory(memory), map(map), current_it(map->items.begin()) {
if (current_it != map->items.end()) {
current.key = current_it->first.c_str();
current.value = &current_it->second;
}
}
mgp_map_items_iterator(const mgp_map_items_iterator &) = delete;
mgp_map_items_iterator(mgp_map_items_iterator &&) = delete;
mgp_map_items_iterator &operator=(const mgp_map_items_iterator &) = delete;
mgp_map_items_iterator &operator=(mgp_map_items_iterator &&) = delete;
~mgp_map_items_iterator() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
memgraph::utils::MemoryResource *memory;
mgp_map *map;
decltype(map->items.begin()) current_it;
mgp_map_item current;
};
struct mgp_vertex {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_vertex>;
// Hopefully VertexAccessor copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::VertexAccessor>);
mgp_vertex(memgraph::query::v2::VertexAccessor v, mgp_graph *graph, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(v), graph(graph) {}
mgp_vertex(const mgp_vertex &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), graph(other.graph) {}
mgp_vertex(mgp_vertex &&other) noexcept : memory(other.memory), impl(other.impl), graph(other.graph) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_vertex(const mgp_vertex &) = delete;
mgp_vertex &operator=(const mgp_vertex &) = delete;
mgp_vertex &operator=(mgp_vertex &&) = delete;
bool operator==(const mgp_vertex &other) const noexcept { return this->impl == other.impl; }
bool operator!=(const mgp_vertex &other) const noexcept { return !(*this == other); };
~mgp_vertex() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::query::v2::VertexAccessor impl;
mgp_graph *graph;
};
struct mgp_edge {
/// Allocator type so that STL containers are aware that we need one.
/// We don't actually need this, but it simplifies the C API, because we store
/// the allocator which was used to allocate `this`.
using allocator_type = memgraph::utils::Allocator<mgp_edge>;
// Hopefully EdgeAccessor copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_copy_constructible_v<memgraph::query::v2::EdgeAccessor>);
static mgp_edge *Copy(const mgp_edge &edge, mgp_memory &memory);
mgp_edge(const memgraph::query::v2::EdgeAccessor &impl, mgp_graph *graph,
memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(impl), from(impl.From(), graph, memory), to(impl.To(), graph, memory) {}
mgp_edge(const mgp_edge &other, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), impl(other.impl), from(other.from, memory), to(other.to, memory) {}
mgp_edge(mgp_edge &&other, memgraph::utils::MemoryResource *memory) noexcept
: memory(other.memory), impl(other.impl), from(std::move(other.from), memory), to(std::move(other.to), memory) {}
mgp_edge(mgp_edge &&other) noexcept
: memory(other.memory), impl(other.impl), from(std::move(other.from)), to(std::move(other.to)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_edge(const mgp_edge &) = delete;
mgp_edge &operator=(const mgp_edge &) = delete;
mgp_edge &operator=(mgp_edge &&) = delete;
~mgp_edge() = default;
bool operator==(const mgp_edge &other) const noexcept { return this->impl == other.impl; }
bool operator!=(const mgp_edge &other) const noexcept { return !(*this == other); };
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept { return memory; }
memgraph::utils::MemoryResource *memory;
memgraph::query::v2::EdgeAccessor impl;
mgp_vertex from;
mgp_vertex to;
};
struct mgp_path {
/// Allocator type so that STL containers are aware that we need one.
using allocator_type = memgraph::utils::Allocator<mgp_path>;
explicit mgp_path(memgraph::utils::MemoryResource *memory) : vertices(memory), edges(memory) {}
mgp_path(const mgp_path &other, memgraph::utils::MemoryResource *memory)
: vertices(other.vertices, memory), edges(other.edges, memory) {}
mgp_path(mgp_path &&other, memgraph::utils::MemoryResource *memory)
: vertices(std::move(other.vertices), memory), edges(std::move(other.edges), memory) {}
mgp_path(mgp_path &&other) noexcept : vertices(std::move(other.vertices)), edges(std::move(other.edges)) {}
/// Copy construction without memgraph::utils::MemoryResource is not allowed.
mgp_path(const mgp_path &) = delete;
mgp_path &operator=(const mgp_path &) = delete;
mgp_path &operator=(mgp_path &&) = delete;
~mgp_path() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const noexcept {
return vertices.get_allocator().GetMemoryResource();
}
memgraph::utils::pmr::vector<mgp_vertex> vertices;
memgraph::utils::pmr::vector<mgp_edge> edges;
};
struct mgp_result_record {
/// Result record signature as defined for mgp_proc.
const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, memgraph::query::v2::TypedValue> values;
};
struct mgp_result {
explicit mgp_result(
const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature,
memgraph::utils::MemoryResource *mem)
: signature(signature), rows(mem) {}
/// Result record signature as defined for mgp_proc.
const memgraph::utils::pmr::map<memgraph::utils::pmr::string,
std::pair<const memgraph::query::v2::procedure::CypherType *, bool>> *signature;
memgraph::utils::pmr::vector<mgp_result_record> rows;
std::optional<memgraph::utils::pmr::string> error_msg;
};
struct mgp_func_result {
mgp_func_result() {}
/// Return Magic function result. If user forgets it, the error is raised
std::optional<memgraph::query::v2::TypedValue> value;
/// Return Magic function result with potential error
std::optional<memgraph::utils::pmr::string> error_msg;
};
struct mgp_graph {
memgraph::query::v2::DbAccessor *impl;
memgraph::storage::v3::View view;
// TODO: Merge `mgp_graph` and `mgp_memory` into a single `mgp_context`. The
// `ctx` field is out of place here.
memgraph::query::v2::ExecutionContext *ctx;
static mgp_graph WritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view,
memgraph::query::v2::ExecutionContext &ctx) {
return mgp_graph{&acc, view, &ctx};
}
static mgp_graph NonWritableGraph(memgraph::query::v2::DbAccessor &acc, memgraph::storage::v3::View view) {
return mgp_graph{&acc, view, nullptr};
}
};
// Prevents user to use ExecutionContext in writable callables
struct mgp_func_context {
memgraph::query::v2::DbAccessor *impl;
memgraph::storage::v3::View view;
};
struct mgp_properties_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_properties_iterator>;
// Define members at the start because we use decltype a lot here, so members
// need to be visible in method definitions.
memgraph::utils::MemoryResource *memory;
mgp_graph *graph;
std::remove_reference_t<decltype(*std::declval<memgraph::query::v2::VertexAccessor>().Properties(graph->view))> pvs;
decltype(pvs.begin()) current_it;
std::optional<std::pair<memgraph::utils::pmr::string, mgp_value>> current;
mgp_property property{nullptr, nullptr};
// Construct with no properties.
explicit mgp_properties_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory)
: memory(memory), graph(graph), current_it(pvs.begin()) {}
// May throw who the #$@! knows what because PropertyValueStore doesn't
// document what it throws, and it may surely throw some piece of !@#$
// exception because it's built on top of STL and other libraries.
mgp_properties_iterator(mgp_graph *graph, decltype(pvs) pvs, memgraph::utils::MemoryResource *memory)
: memory(memory), graph(graph), pvs(std::move(pvs)), current_it(this->pvs.begin()) {
if (current_it != this->pvs.end()) {
current.emplace(memgraph::utils::pmr::string(graph->impl->PropertyToName(current_it->first), memory),
mgp_value(current_it->second, memory));
property.name = current->first.c_str();
property.value = &current->second;
}
}
mgp_properties_iterator(const mgp_properties_iterator &) = delete;
mgp_properties_iterator(mgp_properties_iterator &&) = delete;
mgp_properties_iterator &operator=(const mgp_properties_iterator &) = delete;
mgp_properties_iterator &operator=(mgp_properties_iterator &&) = delete;
~mgp_properties_iterator() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
};
struct mgp_edges_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_edges_iterator>;
// Hopefully mgp_vertex copy constructor remains noexcept, so that we can
// have everything noexcept here.
static_assert(std::is_nothrow_constructible_v<mgp_vertex, const mgp_vertex &, memgraph::utils::MemoryResource *>);
mgp_edges_iterator(const mgp_vertex &v, memgraph::utils::MemoryResource *memory) noexcept
: memory(memory), source_vertex(v, memory) {}
mgp_edges_iterator(mgp_edges_iterator &&other) noexcept
: memory(other.memory),
source_vertex(std::move(other.source_vertex)),
in(std::move(other.in)),
in_it(std::move(other.in_it)),
out(std::move(other.out)),
out_it(std::move(other.out_it)),
current_e(std::move(other.current_e)) {}
mgp_edges_iterator(const mgp_edges_iterator &) = delete;
mgp_edges_iterator &operator=(const mgp_edges_iterator &) = delete;
mgp_edges_iterator &operator=(mgp_edges_iterator &&) = delete;
~mgp_edges_iterator() = default;
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
memgraph::utils::MemoryResource *memory;
mgp_vertex source_vertex;
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.InEdges(source_vertex.graph->view))>> in;
std::optional<decltype(in->begin())> in_it;
std::optional<std::remove_reference_t<decltype(*source_vertex.impl.OutEdges(source_vertex.graph->view))>> out;
std::optional<decltype(out->begin())> out_it;
std::optional<mgp_edge> current_e;
};
struct mgp_vertices_iterator {
using allocator_type = memgraph::utils::Allocator<mgp_vertices_iterator>;
/// @throw anything VerticesIterable may throw
mgp_vertices_iterator(mgp_graph *graph, memgraph::utils::MemoryResource *memory)
: memory(memory), graph(graph), vertices(graph->impl->Vertices(graph->view)), current_it(vertices.begin()) {
if (current_it != vertices.end()) {
current_v.emplace(*current_it, graph, memory);
}
}
memgraph::utils::MemoryResource *GetMemoryResource() const { return memory; }
memgraph::utils::MemoryResource *memory;
mgp_graph *graph;
decltype(graph->impl->Vertices(graph->view)) vertices;
decltype(vertices.begin()) current_it;
std::optional<mgp_vertex> current_v;
};
struct mgp_type {
memgraph::query::v2::procedure::CypherTypePtr impl;
};
struct ProcedureInfo {
bool is_write = false;
std::optional<memgraph::query::v2::AuthQuery::Privilege> required_privilege = std::nullopt;
};
struct mgp_proc {
using allocator_type = memgraph::utils::Allocator<mgp_proc>;
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, mgp_proc_cb cb, memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const char *name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const std::string_view name, std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory, const ProcedureInfo &info = {})
: name(name, memory), cb(cb), args(memory), opt_args(memory), results(memory), info(info) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_proc(const mgp_proc &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory),
cb(other.cb),
args(other.args, memory),
opt_args(other.opt_args, memory),
results(other.results, memory),
info(other.info) {}
mgp_proc(mgp_proc &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory),
cb(std::move(other.cb)),
args(std::move(other.args), memory),
opt_args(std::move(other.opt_args), memory),
results(std::move(other.results), memory),
info(other.info) {}
mgp_proc(const mgp_proc &other) = default;
mgp_proc(mgp_proc &&other) = default;
mgp_proc &operator=(const mgp_proc &) = delete;
mgp_proc &operator=(mgp_proc &&) = delete;
~mgp_proc() = default;
/// Name of the procedure.
memgraph::utils::pmr::string name;
/// Entry-point for the procedure.
std::function<void(mgp_list *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
/// Required, positional arguments as a (name, type) pair.
memgraph::utils::pmr::vector<
std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>>
args;
/// Optional positional arguments as a (name, type, default_value) tuple.
memgraph::utils::pmr::vector<
std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *,
memgraph::query::v2::TypedValue>>
opt_args;
/// Fields this procedure returns, as a (name -> (type, is_deprecated)) map.
memgraph::utils::pmr::map<memgraph::utils::pmr::string,
std::pair<const memgraph::query::v2::procedure::CypherType *, bool>>
results;
ProcedureInfo info;
};
struct mgp_trans {
using allocator_type = memgraph::utils::Allocator<mgp_trans>;
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const char *name, mgp_trans_cb cb, memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const char *name, std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), results(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_trans(const mgp_trans &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory), cb(other.cb), results(other.results) {}
mgp_trans(mgp_trans &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory), cb(std::move(other.cb)), results(std::move(other.results)) {}
mgp_trans(const mgp_trans &other) = default;
mgp_trans(mgp_trans &&other) = default;
mgp_trans &operator=(const mgp_trans &) = delete;
mgp_trans &operator=(mgp_trans &&) = delete;
~mgp_trans() = default;
/// Name of the transformation.
memgraph::utils::pmr::string name;
/// Entry-point for the transformation.
std::function<void(mgp_messages *, mgp_graph *, mgp_result *, mgp_memory *)> cb;
/// Fields this transformation returns.
memgraph::utils::pmr::map<memgraph::utils::pmr::string,
std::pair<const memgraph::query::v2::procedure::CypherType *, bool>>
results;
};
struct mgp_func {
using allocator_type = memgraph::utils::Allocator<mgp_func>;
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, mgp_func_cb cb, memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const char *name, std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb,
memgraph::utils::MemoryResource *memory)
: name(name, memory), cb(cb), args(memory), opt_args(memory) {}
/// @throw std::bad_alloc
/// @throw std::length_error
mgp_func(const mgp_func &other, memgraph::utils::MemoryResource *memory)
: name(other.name, memory), cb(other.cb), args(other.args, memory), opt_args(other.opt_args, memory) {}
mgp_func(mgp_func &&other, memgraph::utils::MemoryResource *memory)
: name(std::move(other.name), memory),
cb(std::move(other.cb)),
args(std::move(other.args), memory),
opt_args(std::move(other.opt_args), memory) {}
mgp_func(const mgp_func &other) = default;
mgp_func(mgp_func &&other) = default;
mgp_func &operator=(const mgp_func &) = delete;
mgp_func &operator=(mgp_func &&) = delete;
~mgp_func() = default;
/// Name of the function.
memgraph::utils::pmr::string name;
/// Entry-point for the function.
std::function<void(mgp_list *, mgp_func_context *, mgp_func_result *, mgp_memory *)> cb;
/// Required, positional arguments as a (name, type) pair.
memgraph::utils::pmr::vector<
std::pair<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *>>
args;
/// Optional positional arguments as a (name, type, default_value) tuple.
memgraph::utils::pmr::vector<
std::tuple<memgraph::utils::pmr::string, const memgraph::query::v2::procedure::CypherType *,
memgraph::query::v2::TypedValue>>
opt_args;
};
mgp_error MgpTransAddFixedResult(mgp_trans *trans) noexcept;
struct mgp_module {
using allocator_type = memgraph::utils::Allocator<mgp_module>;
explicit mgp_module(memgraph::utils::MemoryResource *memory)
: procedures(memory), transformations(memory), functions(memory) {}
mgp_module(const mgp_module &other, memgraph::utils::MemoryResource *memory)
: procedures(other.procedures, memory),
transformations(other.transformations, memory),
functions(other.functions, memory) {}
mgp_module(mgp_module &&other, memgraph::utils::MemoryResource *memory)
: procedures(std::move(other.procedures), memory),
transformations(std::move(other.transformations), memory),
functions(std::move(other.functions), memory) {}
mgp_module(const mgp_module &) = default;
mgp_module(mgp_module &&) = default;
mgp_module &operator=(const mgp_module &) = delete;
mgp_module &operator=(mgp_module &&) = delete;
~mgp_module() = default;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_proc> procedures;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_trans> transformations;
memgraph::utils::pmr::map<memgraph::utils::pmr::string, mgp_func> functions;
};
namespace memgraph::query::v2::procedure {
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintProcSignature(const mgp_proc &, std::ostream *);
/// @throw std::bad_alloc
/// @throw std::length_error
/// @throw anything std::ostream::operator<< may throw.
void PrintFuncSignature(const mgp_func &, std::ostream &);
bool IsValidIdentifierName(const char *name);
} // namespace memgraph::query::v2::procedure
struct mgp_message {
explicit mgp_message(const memgraph::integrations::kafka::Message &message) : msg{&message} {}
explicit mgp_message(const memgraph::integrations::pulsar::Message &message) : msg{message} {}
using KafkaMessage = const memgraph::integrations::kafka::Message *;
using PulsarMessage = memgraph::integrations::pulsar::Message;
std::variant<KafkaMessage, PulsarMessage> msg;
};
struct mgp_messages {
using allocator_type = memgraph::utils::Allocator<mgp_messages>;
using storage_type = memgraph::utils::pmr::vector<mgp_message>;
explicit mgp_messages(storage_type &&storage) : messages(std::move(storage)) {}
mgp_messages(const mgp_messages &) = delete;
mgp_messages &operator=(const mgp_messages &) = delete;
mgp_messages(mgp_messages &&) = delete;
mgp_messages &operator=(mgp_messages &&) = delete;
~mgp_messages() = default;
storage_type messages;
};
memgraph::query::v2::TypedValue ToTypedValue(const mgp_value &val, memgraph::utils::MemoryResource *memory);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,246 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// API for loading and registering modules providing custom oC procedures
#pragma once
#include <dlfcn.h>
#include <filesystem>
#include <functional>
#include <optional>
#include <shared_mutex>
#include <string>
#include <string_view>
#include <unordered_map>
#include "query/v2/procedure/cypher_types.hpp"
#include "query/v2/procedure/mg_procedure_impl.hpp"
#include "utils/memory.hpp"
#include "utils/rw_lock.hpp"
class CypherMainVisitorTest;
namespace memgraph::query::v2::procedure {
class Module {
public:
Module() {}
virtual ~Module();
Module(const Module &) = delete;
Module(Module &&) = delete;
Module &operator=(const Module &) = delete;
Module &operator=(Module &&) = delete;
/// Invokes the (optional) shutdown function and closes the module.
virtual bool Close() = 0;
/// Returns registered procedures of this module
virtual const std::map<std::string, mgp_proc, std::less<>> *Procedures() const = 0;
/// Returns registered transformations of this module
virtual const std::map<std::string, mgp_trans, std::less<>> *Transformations() const = 0;
// /// Returns registered functions of this module
virtual const std::map<std::string, mgp_func, std::less<>> *Functions() const = 0;
virtual std::optional<std::filesystem::path> Path() const = 0;
};
/// Proxy for a registered Module, acquires a read lock from ModuleRegistry.
class ModulePtr final {
const Module *module_{nullptr};
std::shared_lock<utils::RWLock> lock_;
public:
ModulePtr() = default;
ModulePtr(std::nullptr_t) {}
ModulePtr(const Module *module, std::shared_lock<utils::RWLock> lock) : module_(module), lock_(std::move(lock)) {}
explicit operator bool() const { return static_cast<bool>(module_); }
const Module &operator*() const { return *module_; }
const Module *operator->() const { return module_; }
};
/// Thread-safe registration of modules from libraries, uses utils::RWLock.
class ModuleRegistry final {
friend CypherMainVisitorTest;
std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_;
mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE};
std::unique_ptr<utils::MemoryResource> shared_{std::make_unique<utils::ResourceWithOutOfMemoryException>()};
bool RegisterModule(std::string_view name, std::unique_ptr<Module> module);
void DoUnloadAllModules();
/// Loads the module if it's in the modules_dir directory
/// @return Whether the module was loaded
bool LoadModuleIfFound(const std::filesystem::path &modules_dir, std::string_view name);
void LoadModulesFromDirectory(const std::filesystem::path &modules_dir);
public:
ModuleRegistry();
/// Set the modules directories that will be used when (re)loading modules.
void SetModulesDirectory(std::vector<std::filesystem::path> modules_dir, const std::filesystem::path &data_directory);
const std::vector<std::filesystem::path> &GetModulesDirectory() const;
/// Atomically load or reload a module with a particular name from the given
/// directory.
///
/// Takes a write lock. If the module exists it is reloaded. Otherwise, the
/// module is loaded from the file whose filename, without the extension,
/// matches the module's name. If multiple such files exist, only one is
/// chosen, in an unspecified manner. If loading of the chosen file fails, no
/// other files are tried.
///
/// Return true if the module was loaded or reloaded successfully, false
/// otherwise.
bool LoadOrReloadModuleFromName(std::string_view name);
/// Atomically unload all modules and then load all possible modules from the
/// set directories.
///
/// Takes a write lock.
void UnloadAndLoadModulesFromDirectories();
/// Find a module with given name or return nullptr.
/// Takes a read lock.
ModulePtr GetModuleNamed(std::string_view name) const;
/// Remove all loaded (non-builtin) modules.
/// Takes a write lock.
void UnloadAllModules();
/// Returns the shared memory allocator used by modules
utils::MemoryResource &GetSharedMemoryResource() noexcept;
bool RegisterMgProcedure(std::string_view name, mgp_proc proc);
const std::filesystem::path &InternalModuleDir() const noexcept;
private:
class SharedLibraryHandle {
public:
SharedLibraryHandle(const std::string &shared_library, int mode) : handle_{dlopen(shared_library.c_str(), mode)} {}
SharedLibraryHandle(const SharedLibraryHandle &) = delete;
SharedLibraryHandle(SharedLibraryHandle &&) = delete;
SharedLibraryHandle operator=(const SharedLibraryHandle &) = delete;
SharedLibraryHandle operator=(SharedLibraryHandle &&) = delete;
~SharedLibraryHandle() {
if (handle_) {
dlclose(handle_);
}
}
private:
void *handle_;
};
#if __has_feature(address_sanitizer)
// This is why we need RTLD_NODELETE and we must not use RTLD_DEEPBIND with
// ASAN: https://github.com/google/sanitizers/issues/89
SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE};
#else
// The reason behind opening share library during runtime is to avoid issues
// with loading symbols from stdlib. We have encounter issues with locale
// that cause std::cout not being printed and issues when python libraries
// would call stdlib (e.g. pytorch).
// The way that those issues were solved was
// by using RTLD_DEEPBIND. RTLD_DEEPBIND ensures that the lookup for the
// mentioned library will be first performed in the already existing binded
// libraries and then the global namespace.
// RTLD_DEEPBIND => https://linux.die.net/man/3/dlopen
SharedLibraryHandle libstd_handle{"libstdc++.so.6", RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND};
#endif
std::vector<std::filesystem::path> modules_dirs_;
std::filesystem::path internal_module_dir_;
};
/// Single, global module registry.
extern ModuleRegistry gModuleRegistry;
/// Return the ModulePtr and `mgp_proc *` of the found procedure after resolving
/// `fully_qualified_procedure_name`. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be
/// unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_proc *>> FindProcedure(
const ModuleRegistry &module_registry, std::string_view fully_qualified_procedure_name,
utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_trans *` of the found transformation after resolving
/// `fully_qualified_transformation_name`. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be
/// unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_trans *>> FindTransformation(
const ModuleRegistry &module_registry, std::string_view fully_qualified_transformation_name,
utils::MemoryResource *memory);
/// Return the ModulePtr and `mgp_func *` of the found function after resolving
/// `fully_qualified_function_name` if found. If there is no such function
/// std::nullopt is returned. `memory` is used for temporary allocations
/// inside this function. ModulePtr must be kept alive to make sure it won't be unloaded.
std::optional<std::pair<procedure::ModulePtr, const mgp_func *>> FindFunction(
const ModuleRegistry &module_registry, std::string_view fully_qualified_function_name,
utils::MemoryResource *memory);
template <typename T>
concept IsCallable = utils::SameAsAnyOf<T, mgp_proc, mgp_func>;
template <IsCallable TCall>
void ConstructArguments(const std::vector<TypedValue> &args, const TCall &callable,
const std::string_view fully_qualified_name, mgp_list &args_list, mgp_graph &graph) {
const auto n_args = args.size();
const auto c_args_sz = callable.args.size();
const auto c_opt_args_sz = callable.opt_args.size();
if (n_args < c_args_sz || (n_args - c_args_sz > c_opt_args_sz)) {
if (callable.args.empty() && callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires no arguments.", fully_qualified_name);
}
if (callable.opt_args.empty()) {
throw QueryRuntimeException("'{}' requires exactly {} {}.", fully_qualified_name, c_args_sz,
c_args_sz == 1U ? "argument" : "arguments");
}
throw QueryRuntimeException("'{}' requires between {} and {} arguments.", fully_qualified_name, c_args_sz,
c_args_sz + c_opt_args_sz);
}
args_list.elems.reserve(n_args);
auto is_not_optional_arg = [c_args_sz](int i) { return c_args_sz > i; };
for (size_t i = 0; i < n_args; ++i) {
auto arg = args[i];
std::string_view name;
const query::v2::procedure::CypherType *type;
if (is_not_optional_arg(i)) {
name = callable.args[i].first;
type = callable.args[i].second;
} else {
name = std::get<0>(callable.opt_args[i - c_args_sz]);
type = std::get<1>(callable.opt_args[i - c_args_sz]);
}
if (!type->SatisfiesType(arg)) {
throw QueryRuntimeException("'{}' argument named '{}' at position {} must be of type {}.", fully_qualified_name,
name, i, type->GetPresentableName());
}
args_list.elems.emplace_back(std::move(arg), &graph);
}
// Fill missing optional arguments with their default values.
const size_t passed_in_opt_args = n_args - c_args_sz;
for (size_t i = passed_in_opt_args; i < c_opt_args_sz; ++i) {
args_list.elems.emplace_back(std::get<2>(callable.opt_args[i]), &graph);
}
}
} // namespace memgraph::query::v2::procedure

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,82 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
/// @file
/// Functions and types for loading Query Modules written in Python.
#pragma once
#include "py/py.hpp"
struct mgp_graph;
struct mgp_memory;
struct mgp_module;
struct mgp_value;
namespace memgraph::query::v2::procedure {
struct PyGraph;
/// Convert an `mgp_value` into a Python object, referencing the given `PyGraph`
/// instance and using the same allocator as the graph.
///
/// Values of type `MGP_VALUE_TYPE_VERTEX`, `MGP_VALUE_TYPE_EDGE` and
/// `MGP_VALUE_TYPE_PATH` are returned as `mgp.Vertex`, `mgp.Edge` and
/// `mgp.Path` respectively, and *not* their internal `_mgp`
/// representations. Other value types are converted to equivalent builtin
/// Python objects.
///
/// Return a non-null `py::Object` instance on success. Otherwise, return a null
/// `py::Object` instance and set the appropriate Python exception.
py::Object MgpValueToPyObject(const mgp_value &value, PyGraph *py_graph);
py::Object MgpValueToPyObject(const mgp_value &value, PyObject *py_graph);
/// Convert a Python object into `mgp_value`, constructing it using the given
/// `mgp_memory` allocator.
///
/// If the user-facing 'mgp' module can be imported, this function will handle
/// conversion of 'mgp.Vertex', 'mgp.Edge' and 'mgp.Path' values.
///
/// @throw std::bad_alloc
/// @throw std::overflow_error if attempting to convert a Python integer which
/// too large to fit into int64_t.
/// @throw std::invalid_argument if the given Python object cannot be converted
/// to an mgp_value (e.g. a dictionary whose keys aren't strings or an object
/// of unsupported type).
mgp_value *PyObjectToMgpValue(PyObject *, mgp_memory *);
/// Create the _mgp module for use in embedded Python.
///
/// The function is to be used before Py_Initialize via the following code.
///
/// PyImport_AppendInittab("_mgp", &query::v2::procedure::PyInitMgpModule);
PyObject *PyInitMgpModule();
/// Create an instance of _mgp.Graph class.
PyObject *MakePyGraph(mgp_graph *, mgp_memory *);
/// Import a module with given name in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ImportPyModule(const char *, mgp_module *);
/// Reload already loaded Python module in the context of mgp_module.
///
/// This function can only be called when '_mgp' module has been initialized in
/// Python.
///
/// Return nullptr and set appropriate Python exception on failure.
py::Object ReloadPyModule(PyObject *, mgp_module *);
} // namespace memgraph::query::v2::procedure

View File

@ -0,0 +1,127 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include "query/v2/serialization/property_value.hpp"
#include "storage/v3/property_value.hpp"
#include "utils/logging.hpp"
namespace memgraph::query::v2::serialization {
namespace {
enum class ObjectType : uint8_t { MAP, TEMPORAL_DATA };
} // namespace
nlohmann::json SerializePropertyValue(const storage::v3::PropertyValue &property_value) {
using Type = storage::v3::PropertyValue::Type;
switch (property_value.type()) {
case Type::Null:
return {};
case Type::Bool:
return property_value.ValueBool();
case Type::Int:
return property_value.ValueInt();
case Type::Double:
return property_value.ValueDouble();
case Type::String:
return property_value.ValueString();
case Type::List:
return SerializePropertyValueVector(property_value.ValueList());
case Type::Map:
return SerializePropertyValueMap(property_value.ValueMap());
case Type::TemporalData:
const auto temporal_data = property_value.ValueTemporalData();
auto data = nlohmann::json::object();
data.emplace("type", static_cast<uint64_t>(ObjectType::TEMPORAL_DATA));
data.emplace("value", nlohmann::json::object({{"type", static_cast<uint64_t>(temporal_data.type)},
{"microseconds", temporal_data.microseconds}}));
return data;
}
}
nlohmann::json SerializePropertyValueVector(const std::vector<storage::v3::PropertyValue> &values) {
nlohmann::json array = nlohmann::json::array();
for (const auto &value : values) {
array.push_back(SerializePropertyValue(value));
}
return array;
}
nlohmann::json SerializePropertyValueMap(const std::map<std::string, storage::v3::PropertyValue> &parameters) {
nlohmann::json data = nlohmann::json::object();
data.emplace("type", static_cast<uint64_t>(ObjectType::MAP));
data.emplace("value", nlohmann::json::object());
for (const auto &[key, value] : parameters) {
data["value"][key] = SerializePropertyValue(value);
}
return data;
};
storage::v3::PropertyValue DeserializePropertyValue(const nlohmann::json &data) {
if (data.is_null()) {
return storage::v3::PropertyValue();
}
if (data.is_boolean()) {
return storage::v3::PropertyValue(data.get<bool>());
}
if (data.is_number_integer()) {
return storage::v3::PropertyValue(data.get<int64_t>());
}
if (data.is_number_float()) {
return storage::v3::PropertyValue(data.get<double>());
}
if (data.is_string()) {
return storage::v3::PropertyValue(data.get<std::string>());
}
if (data.is_array()) {
return storage::v3::PropertyValue(DeserializePropertyValueList(data));
}
MG_ASSERT(data.is_object(), "Unknown type found in the trigger storage");
switch (data["type"].get<ObjectType>()) {
case ObjectType::MAP:
return storage::v3::PropertyValue(DeserializePropertyValueMap(data));
case ObjectType::TEMPORAL_DATA:
return storage::v3::PropertyValue(storage::v3::TemporalData{
data["value"]["type"].get<storage::v3::TemporalType>(), data["value"]["microseconds"].get<int64_t>()});
}
}
std::vector<storage::v3::PropertyValue> DeserializePropertyValueList(const nlohmann::json::array_t &data) {
std::vector<storage::v3::PropertyValue> property_values;
property_values.reserve(data.size());
for (const auto &value : data) {
property_values.emplace_back(DeserializePropertyValue(value));
}
return property_values;
}
std::map<std::string, storage::v3::PropertyValue> DeserializePropertyValueMap(const nlohmann::json::object_t &data) {
MG_ASSERT(data.at("type").get<ObjectType>() == ObjectType::MAP, "Invalid map serialization");
std::map<std::string, storage::v3::PropertyValue> property_values;
const nlohmann::json::object_t &values = data.at("value");
for (const auto &[key, value] : values) {
property_values.emplace(key, DeserializePropertyValue(value));
}
return property_values;
}
} // namespace memgraph::query::v2::serialization

View File

@ -0,0 +1,32 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <json/json.hpp>
#include "storage/v3/property_value.hpp"
namespace memgraph::query::v2::serialization {
nlohmann::json SerializePropertyValue(const storage::v3::PropertyValue &property_value);
nlohmann::json SerializePropertyValueVector(const std::vector<storage::v3::PropertyValue> &values);
nlohmann::json SerializePropertyValueMap(const std::map<std::string, storage::v3::PropertyValue> &parameters);
storage::v3::PropertyValue DeserializePropertyValue(const nlohmann::json &data);
std::vector<storage::v3::PropertyValue> DeserializePropertyValueList(const nlohmann::json::array_t &data);
std::map<std::string, storage::v3::PropertyValue> DeserializePropertyValueMap(const nlohmann::json::object_t &data);
} // namespace memgraph::query::v2::serialization

63
src/query/v2/stream.hpp Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <memory>
#include <vector>
#include "query/v2/typed_value.hpp"
#include "utils/memory.hpp"
namespace memgraph::query::v2 {
/**
* `AnyStream` can wrap *any* type implementing the `Stream` concept into a
* single type.
*
* The type erasure technique is used. The original type which an `AnyStream`
* was constructed from is "erased", as `AnyStream` is not a class template and
* doesn't use the type in any way. Client code can then program just for
* `AnyStream`, rather than using static polymorphism to handle any type
* implementing the `Stream` concept.
*/
class AnyStream final {
public:
template <class TStream>
AnyStream(TStream *stream, utils::MemoryResource *memory_resource)
: content_{
utils::Allocator<GenericWrapper<TStream>>{memory_resource}.template new_object<GenericWrapper<TStream>>(
stream),
[memory_resource](Wrapper *ptr) {
utils::Allocator<GenericWrapper<TStream>>{memory_resource}
.template delete_object<GenericWrapper<TStream>>(static_cast<GenericWrapper<TStream> *>(ptr));
}} {}
void Result(const std::vector<TypedValue> &values) { content_->Result(values); }
private:
struct Wrapper {
virtual void Result(const std::vector<TypedValue> &values) = 0;
};
template <class TStream>
struct GenericWrapper final : public Wrapper {
explicit GenericWrapper(TStream *stream) : stream_{stream} {}
void Result(const std::vector<TypedValue> &values) override { stream_->Result(values); }
TStream *stream_;
};
std::unique_ptr<Wrapper, std::function<void(Wrapper *)>> content_;
};
} // namespace memgraph::query::v2

Some files were not shown because too many files have changed in this diff Show More